src/Tools/Metis/metis.ML
author blanchet
Wed, 15 Sep 2010 10:43:57 +0200
changeset 39385 0049301f7333
parent 39351 1e118007e41a
child 39417 0be01cad5df4
permissions -rw-r--r--
compile on SML/NJ

(*
   This file was generated by the "make-metis" script. A few changes were done
   manually on the script's output; these are marked as follows:

       MODIFIED by Jasmin Blanchette

   Some of these changes are needed so that the ML file compiles at all. Others
   are old tweaks by Lawrence C. Paulson that are needed, if nothing else, for
   backward compatibility. The BSD License is used with Joe Hurd's kind
   permission. Extract from a September 13, 2010 email written by Joe Hurd:

       I hereby give permission to the Isabelle team to release Metis as part
       of Isabelle, with the Metis code covered under the Isabelle BSD
       license.
*)

(******************************************************************)
(* GENERATED FILE -- DO NOT EDIT -- GENERATED FILE -- DO NOT EDIT *)
(* GENERATED FILE -- DO NOT EDIT -- GENERATED FILE -- DO NOT EDIT *)
(* GENERATED FILE -- DO NOT EDIT -- GENERATED FILE -- DO NOT EDIT *)
(******************************************************************)

(* Set the toplevel print depth to 0 before the rest of this large file is
   loaded.  NOTE(review): print_depth is presumably supplied by the host ML
   environment (e.g. Isabelle's ML system setup) -- confirm per compiler. *)
print_depth 0;

(* Seed the Metis namespace with the standard Array and Word structures, so
   that later code (e.g. signatures) can refer to Metis.Array and Metis.Word. *)
structure Metis =
struct
  structure Array = Array
  structure Word = Word
end;

(**** Original file: Random.sig ****)

(*  Title:      Tools/random_word.ML
    Author:     Makarius

Simple generator for pseudo-random numbers, using unboxed word
arithmetic only.  Unprotected concurrency introduces some true
randomness.
*)

signature Random =
sig

  (* A pseudo-random real in the half-open interval [0,1). *)
  val nextReal : unit -> real

  (* nextInt k is a pseudo-random integer in [0,k). *)
  val nextInt : int -> int

  (* A pseudo-random boolean. *)
  val nextBool : unit -> bool

  (* A pseudo-random machine word. *)
  val nextWord : unit -> word

end;

(**** Original file: Random.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(*  Title:      Tools/random_word.ML
    Author:     Makarius

Simple generator for pseudo-random numbers, using unboxed word
arithmetic only.  Unprotected concurrency introduces some true
randomness.
*)

structure Random :> Random =
struct

(* random words: 0w0 <= result <= max_word *)

(*minimum length of unboxed words on all supported ML platforms*)
val _ = Word.wordSize >= 30
  orelse raise Fail ("Bad platform word size");

(* 2^30 - 1: every generated word lies in [0w0, max_word]. *)
val max_word = 0wx3FFFFFFF;

(* 2^29: the most significant of the 30 generated bits. *)
val top_bit = 0wx20000000;

(*multiplier according to Borosh and Niederreiter (for modulus = 2^30),
  see http://random.mat.sbg.ac.at/~charly/server/server.html*)
val a = 0w777138309;

(* One linear-congruential step: x' = (a * x + 1) mod 2^30, where the
   reduction is done by masking with max_word. *)
fun step x = Word.andb (a * x + 0w1, max_word);

fun change r f = r := f (!r);

(* The generator state is a single shared, unsynchronized reference; as the
   header notes, races between concurrent callers only add randomness. *)
local val rand = Unsynchronized.ref 0w1
in fun nextWord () = (change rand step; ! rand) end;

(*NB: higher bits are more random than lower ones*)
fun nextBool () = Word.andb (nextWord (), top_bit) = 0w0;


(* random integers: 0 <= result < k *)

val max_int = Word.toInt max_word;

(* nextInt k returns a pseudo-random integer in [0,k) for 1 <= k <= max_int,
   and raises Fail otherwise.  Every admissible k is handled by reduction
   modulo k: returning the raw word for k = max_int would allow the result
   max_int = k, which lies outside [0,k).  (Reduction modulo k carries a
   slight bias whenever k does not divide 2^30.) *)
fun nextInt k =
  if k <= 0 orelse k > max_int then raise Fail ("next_int: out of range")
  else Word.toInt (Word.mod (nextWord (), Word.fromInt k));

(* random reals: 0.0 <= result < 1.0 *)

val scaling = real max_int + 1.0;
fun nextReal () = real (Word.toInt (nextWord ())) / scaling;

end;
end;

(**** Original file: Portable.sig ****)

(* ========================================================================= *)
(* ML SPECIFIC FUNCTIONS                                                     *)
(* Copyright (c) 2001-2007 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Portable =
sig

(* ------------------------------------------------------------------------- *)
(* The ML implementation.                                                    *)
(* ------------------------------------------------------------------------- *)

val ml : string

(* ------------------------------------------------------------------------- *)
(* Pointer equality using the run-time system.                               *)
(* ------------------------------------------------------------------------- *)

val pointerEqual : 'a * 'a -> bool

(* ------------------------------------------------------------------------- *)
(* Timing function applications.                                             *)
(* ------------------------------------------------------------------------- *)

val time : ('a -> 'b) -> 'a -> 'b

(* ------------------------------------------------------------------------- *)
(* Generating random values.                                                 *)
(* ------------------------------------------------------------------------- *)

val randomBool : unit -> bool

val randomInt : int -> int  (* n |-> [0,n) *)

val randomReal : unit -> real  (* () |-> [0,1): 1.0 itself is never returned *)

val randomWord : unit -> Metis.Word.word

end

(**** Original file: PortablePolyml.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* POLY/ML SPECIFIC FUNCTIONS                                                *)
(* Copyright (c) 2008 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Portable :> Portable =
struct

(* ------------------------------------------------------------------------- *)
(* The ML implementation this port targets.                                  *)
(* ------------------------------------------------------------------------- *)

val ml = "polyml";

(* ------------------------------------------------------------------------- *)
(* Pointer equality, delegated to the pointer_eq primitive.                  *)
(* ------------------------------------------------------------------------- *)

fun pointerEqual (x : 'a, y : 'a) = pointer_eq (x, y) (* MODIFIED by Jasmin Blanchette *)

(* ------------------------------------------------------------------------- *)
(* time f x evaluates f x and prints the user, system and real time spent,   *)
(* each with three decimal places.  The report is printed even when f x      *)
(* raises; the exception is then re-raised.                                  *)
(* ------------------------------------------------------------------------- *)

fun time f x =
    let
      (* Render t with exactly three digits after the decimal point. *)
      fun threeDecimals t =
          let
            val text = Time.fmt 3 t
            val digits =
                size (List.last (String.fields (fn ch => ch = #".") text))
          in
            if digits = 3 then text
            else if digits = 2 then text ^ "0"
            else if digits = 1 then text ^ "00"
            else raise Fail "Portable.time"
          end

      val cpuTimer = Timer.startCPUTimer ()

      val wallTimer = Timer.startRealTimer ()

      fun report () =
          let
            val {usr, sys} = Timer.checkCPUTimer cpuTimer
            val elapsed = Timer.checkRealTimer wallTimer
          in
            print
              ("User: " ^ threeDecimals usr ^ "  System: " ^ threeDecimals sys ^
               "  Real: " ^ threeDecimals elapsed ^ "\n")
          end

      val result = f x handle e => (report (); raise e)
    in
      report ();
      result
    end;

(* ------------------------------------------------------------------------- *)
(* Random values, all drawn from the shared Random generator.                *)
(* ------------------------------------------------------------------------- *)

val randomReal = Random.nextReal;

val randomInt = Random.nextInt;

val randomBool = Random.nextBool;

val randomWord = Random.nextWord;

end

(* ------------------------------------------------------------------------- *)
(* Quotations a la Moscow ML.                                                *)
(* ------------------------------------------------------------------------- *)

datatype 'a frag = QUOTE of string | ANTIQUOTE of 'a;
end;

(**** Original file: Useful.sig ****)

(* ========================================================================= *)
(* ML UTILITY FUNCTIONS                                                      *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

signature Useful =
sig

(* ------------------------------------------------------------------------- *)
(* Exceptions.                                                               *)
(* ------------------------------------------------------------------------- *)

exception Error of string

exception Bug of string

val total : ('a -> 'b) -> 'a -> 'b option

val can : ('a -> 'b) -> 'a -> bool

(* ------------------------------------------------------------------------- *)
(* Tracing.                                                                  *)
(* ------------------------------------------------------------------------- *)

val tracePrint : (string -> unit) Unsynchronized.ref

val trace : string -> unit

(* ------------------------------------------------------------------------- *)
(* Combinators.                                                              *)
(* ------------------------------------------------------------------------- *)

val C : ('a -> 'b -> 'c) -> 'b -> 'a -> 'c

val I : 'a -> 'a

val K : 'a -> 'b -> 'a

val S : ('a -> 'b -> 'c) -> ('a -> 'b) -> 'a -> 'c

val W : ('a -> 'a -> 'b) -> 'a -> 'b

val funpow : int -> ('a -> 'a) -> 'a -> 'a

val exp : ('a * 'a -> 'a) -> 'a -> int -> 'a -> 'a

(* ------------------------------------------------------------------------- *)
(* Pairs.                                                                    *)
(* ------------------------------------------------------------------------- *)

val fst : 'a * 'b -> 'a

val snd : 'a * 'b -> 'b

val pair : 'a -> 'b -> 'a * 'b

val swap : 'a * 'b -> 'b * 'a

val curry : ('a * 'b -> 'c) -> 'a -> 'b -> 'c

val uncurry : ('a -> 'b -> 'c) -> 'a * 'b -> 'c

val ## : ('a -> 'c) * ('b -> 'd) -> 'a * 'b -> 'c * 'd

(* ------------------------------------------------------------------------- *)
(* State transformers.                                                       *)
(* ------------------------------------------------------------------------- *)

val unit : 'a -> 's -> 'a * 's

val bind : ('s -> 'a * 's) -> ('a -> 's -> 'b * 's) -> 's -> 'b * 's

val mmap : ('a -> 'b) -> ('s -> 'a * 's) -> 's -> 'b * 's

val mjoin : ('s -> ('s -> 'a * 's) * 's) -> 's -> 'a * 's

val mwhile : ('a -> bool) -> ('a -> 's -> 'a * 's) -> 'a -> 's -> 'a * 's

(* ------------------------------------------------------------------------- *)
(* Equality.                                                                 *)
(* ------------------------------------------------------------------------- *)

val equal : ''a -> ''a -> bool

val notEqual : ''a -> ''a -> bool

val listEqual : ('a -> 'a -> bool) -> 'a list -> 'a list -> bool

(* ------------------------------------------------------------------------- *)
(* Comparisons.                                                              *)
(* ------------------------------------------------------------------------- *)

val mapCompare : ('a -> 'b) -> ('b * 'b -> order) -> 'a * 'a -> order

val revCompare : ('a * 'a -> order) -> 'a * 'a -> order

val prodCompare :
    ('a * 'a -> order) -> ('b * 'b -> order) -> ('a * 'b) * ('a * 'b) -> order

val lexCompare : ('a * 'a -> order) -> 'a list * 'a list -> order

val optionCompare : ('a * 'a -> order) -> 'a option * 'a option -> order

val boolCompare : bool * bool -> order  (* false < true *)

(* ------------------------------------------------------------------------- *)
(* Lists: note we count elements from 0.                                     *)
(* ------------------------------------------------------------------------- *)

val cons : 'a -> 'a list -> 'a list

val hdTl : 'a list -> 'a * 'a list

val append : 'a list -> 'a list -> 'a list

val singleton : 'a -> 'a list

val first : ('a -> 'b option) -> 'a list -> 'b option

val maps : ('a -> 's -> 'b * 's) -> 'a list -> 's -> 'b list * 's

val mapsPartial : ('a -> 's -> 'b option * 's) -> 'a list -> 's -> 'b list * 's

val zipWith : ('a -> 'b -> 'c) -> 'a list -> 'b list -> 'c list

val zip : 'a list -> 'b list -> ('a * 'b) list

val unzip : ('a * 'b) list -> 'a list * 'b list

val cartwith : ('a -> 'b -> 'c) -> 'a list -> 'b list -> 'c list

val cart : 'a list -> 'b list -> ('a * 'b) list

val takeWhile : ('a -> bool) -> 'a list -> 'a list

val dropWhile : ('a -> bool) -> 'a list -> 'a list

val divideWhile : ('a -> bool) -> 'a list -> 'a list * 'a list

val groups : ('a * 's -> bool * 's) -> 's -> 'a list -> 'a list list

val groupsBy : ('a * 'a -> bool) -> 'a list -> 'a list list

val groupsByFst : (''a * 'b) list -> (''a * 'b list) list

val groupsOf : int -> 'a list -> 'a list list

val index : ('a -> bool) -> 'a list -> int option

val enumerate : 'a list -> (int * 'a) list

val divide : 'a list -> int -> 'a list * 'a list  (* Subscript *)

val revDivide : 'a list -> int -> 'a list * 'a list  (* Subscript *)

val updateNth : int * 'a -> 'a list -> 'a list  (* Subscript *)

val deleteNth : int -> 'a list -> 'a list  (* Subscript *)

(* ------------------------------------------------------------------------- *)
(* Sets implemented with lists.                                              *)
(* ------------------------------------------------------------------------- *)

val mem : ''a -> ''a list -> bool

val insert : ''a -> ''a list -> ''a list

val delete : ''a -> ''a list -> ''a list

val setify : ''a list -> ''a list  (* removes duplicates *)

val union : ''a list -> ''a list -> ''a list

val intersect : ''a list -> ''a list -> ''a list

val difference : ''a list -> ''a list -> ''a list

val subset : ''a list -> ''a list -> bool

val distinct : ''a list -> bool

(* ------------------------------------------------------------------------- *)
(* Sorting and searching.                                                    *)
(* ------------------------------------------------------------------------- *)

val minimum : ('a * 'a -> order) -> 'a list -> 'a * 'a list  (* Empty *)

val maximum : ('a * 'a -> order) -> 'a list -> 'a * 'a list  (* Empty *)

val merge : ('a * 'a -> order) -> 'a list -> 'a list -> 'a list

val sort : ('a * 'a -> order) -> 'a list -> 'a list

val sortMap : ('a -> 'b) -> ('b * 'b -> order) -> 'a list -> 'a list

(* ------------------------------------------------------------------------- *)
(* Integers.                                                                 *)
(* ------------------------------------------------------------------------- *)

val interval : int -> int -> int list

val divides : int -> int -> bool

val gcd : int -> int -> int

val primes : int -> int list

val primesUpTo : int -> int list

(* ------------------------------------------------------------------------- *)
(* Strings.                                                                  *)
(* ------------------------------------------------------------------------- *)

val rot : int -> char -> char

val charToInt : char -> int option

val charFromInt : int -> char option

val nChars : char -> int -> string

val chomp : string -> string

val trim : string -> string

val join : string -> string list -> string

val split : string -> string -> string list

val capitalize : string -> string

val mkPrefix : string -> string -> string

val destPrefix : string -> string -> string

val isPrefix : string -> string -> bool

val stripPrefix : (char -> bool) -> string -> string

val mkSuffix : string -> string -> string

val destSuffix : string -> string -> string

val isSuffix : string -> string -> bool

val stripSuffix : (char -> bool) -> string -> string

(* ------------------------------------------------------------------------- *)
(* Tables.                                                                   *)
(* ------------------------------------------------------------------------- *)

type columnAlignment = {leftAlign : bool, padChar : char}

val alignColumn : columnAlignment -> string list -> string list -> string list

val alignTable : columnAlignment list -> string list list -> string list

(* ------------------------------------------------------------------------- *)
(* Reals.                                                                    *)
(* ------------------------------------------------------------------------- *)

val percentToString : real -> string

val pos : real -> real

(* Computed via Math.ln, which per the Basis returns NaN for negative input
   and ~inf for 0 instead of raising an exception. *)
val log2 : real -> real

(* ------------------------------------------------------------------------- *)
(* Sum datatype.                                                             *)
(* ------------------------------------------------------------------------- *)

datatype ('a,'b) sum = Left of 'a | Right of 'b

val destLeft : ('a,'b) sum -> 'a

val isLeft : ('a,'b) sum -> bool

val destRight : ('a,'b) sum -> 'b

val isRight : ('a,'b) sum -> bool

(* ------------------------------------------------------------------------- *)
(* Useful impure features.                                                   *)
(* ------------------------------------------------------------------------- *)

val newInt : unit -> int

val newInts : int -> int list

val withRef : 'r Unsynchronized.ref * 'r -> ('a -> 'b) -> 'a -> 'b

val cloneArray : 'a Metis.Array.array -> 'a Metis.Array.array

(* ------------------------------------------------------------------------- *)
(* The environment.                                                          *)
(* ------------------------------------------------------------------------- *)

val host : unit -> string

val time : unit -> string

val date : unit -> string

val readDirectory : {directory : string} -> {filename : string} list

val readTextFile : {filename : string} -> string

val writeTextFile : {contents : string, filename : string} -> unit

(* ------------------------------------------------------------------------- *)
(* Profiling and error reporting.                                            *)
(* ------------------------------------------------------------------------- *)

val try : ('a -> 'b) -> 'a -> 'b

val chat : string -> unit

val warn : string -> unit

val die : string -> 'exit

val timed : ('a -> 'b) -> 'a -> real * 'b

val timedMany : ('a -> 'b) -> 'a -> real * 'b

val executionTime : unit -> real  (* Wall clock execution time *)

end

(**** Original file: Useful.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* ML UTILITY FUNCTIONS                                                      *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Useful :> Useful =
struct

(* ------------------------------------------------------------------------- *)
(* Exceptions.                                                               *)
(* ------------------------------------------------------------------------- *)

exception Error of string;

exception Bug of string;

(* Render an Error exception as "Error: <message>"; NONE for anything else. *)
fun errorToStringOption (Error message) = SOME ("Error: " ^ message)
  | errorToStringOption _ = NONE;

(*mlton
val () = MLton.Exn.addExnMessager errorToStringOption;
*)

(* Newline-wrapped rendering of an Error; raises Bug for any other exception. *)
fun errorToString err =
    let
      val rendered = errorToStringOption err
    in
      if Option.isSome rendered then "\n" ^ Option.valOf rendered ^ "\n"
      else raise Bug "errorToString: not an Error exception"
    end;

(* Render a Bug exception as "Bug: <message>"; NONE for anything else. *)
fun bugToStringOption (Bug message) = SOME ("Bug: " ^ message)
  | bugToStringOption _ = NONE;

(*mlton
val () = MLton.Exn.addExnMessager bugToStringOption;
*)

(* Newline-wrapped rendering of a Bug; raises Bug for any other exception. *)
fun bugToString err =
    let
      val rendered = bugToStringOption err
    in
      if Option.isSome rendered then "\n" ^ Option.valOf rendered ^ "\n"
      else raise Bug "bugToString: not a Bug exception"
    end;

(* total f x is SOME (f x), or NONE when f x raises Error; any other exception
   propagates. *)
fun total f x = (SOME (f x)) handle Error _ => NONE;

(* can f x holds exactly when f x finishes without raising Error. *)
fun can f x = Option.isSome (total f x);

(* ------------------------------------------------------------------------- *)
(* Tracing.                                                                  *)
(* ------------------------------------------------------------------------- *)

(* The printer used by trace: initially print, replaceable at run time. *)
val tracePrint = Unsynchronized.ref print;

(* Send mesg to whichever printer tracePrint currently holds. *)
fun trace mesg =
    let
      val printer = !tracePrint
    in
      printer mesg
    end;

(* ------------------------------------------------------------------------- *)
(* Combinators.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Flip the two arguments of a curried function. *)
fun C f = fn x => fn y => f y x;

(* Identity. *)
val I = fn x => x;

(* Constant function: ignore the second argument. *)
fun K x = fn _ => x;

(* S f g x = f x (g x). *)
fun S f g = fn x => f x (g x);

(* Duplicate the argument: W f x = f x x. *)
fun W f = fn x => f x x;

(* Apply f to x exactly n times.  n is expected to be non-negative: a negative
   n never reaches the base case. *)
fun funpow n f x = if n = 0 then x else funpow (n - 1) f (f x);

(* exp m x y z combines z with y copies of x under m, by repeated squaring,
   e.g. exp op* x y 1 = x^y.  y is expected to be non-negative. *)
fun exp m =
    let
      fun loop _ 0 acc = acc
        | loop base n acc =
          let
            val square = m (base, base)
            val acc' = if n mod 2 = 0 then acc else m (acc, base)
          in
            loop square (n div 2) acc'
          end
    in
      loop
    end;

(* ------------------------------------------------------------------------- *)
(* Pairs.                                                                    *)
(* ------------------------------------------------------------------------- *)

(* Pair projections. *)
fun fst (a, _) = a;

fun snd (_, b) = b;

(* Curried pair constructor. *)
fun pair a = fn b => (a, b);

fun swap (a, b) = (b, a);

(* Convert between tupled and curried two-argument functions. *)
fun curry f a b = f (a, b);

fun uncurry f (a, b) = f a b;

(* Map a pair of functions componentwise over a pair. *)
fun op## (f, g) = fn (a, b) => (f a, g b);

(* ------------------------------------------------------------------------- *)
(* State transformers.                                                       *)
(* ------------------------------------------------------------------------- *)

(* Return x without touching the state. *)
fun unit x = fn s => (x, s);

(* Run f, then feed its result and the new state to g. *)
fun bind f (g : 'a -> 's -> 'b * 's) =
    fn s => let val (a, s') = f s in g a s' end;

(* Run m and transform its result with f. *)
fun mmap f (m : 's -> 'a * 's) =
    fn s => let val (a, s') = m s in (f a, s') end;

(* Run f, which yields a state transformer, and then run that transformer. *)
fun mjoin (f : 's -> ('s -> 'a * 's) * 's) =
    fn s => let val (m, s') = f s in m s' end;

(* While c holds of the current value, replace it by running b on it. *)
fun mwhile c b =
    let
      fun loop a =
          if c a then
            let
              val stepper = b a
            in
              fn s => let val (a', s') = stepper s in loop a' s' end
            end
          else fn s => (a, s)
    in
      loop
    end;

(* ------------------------------------------------------------------------- *)
(* Equality.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Curried (in)equality on equality types. *)
fun equal x y = x = y;

fun notEqual x y = x <> y;

(* Lift an element equality to lists: same length and pointwise equal,
   comparing from the front and stopping at the first mismatch. *)
fun listEqual xEq =
    let
      fun go (x1 :: r1, x2 :: r2) = xEq x1 x2 andalso go (r1, r2)
        | go ([], []) = true
        | go _ = false
    in
      fn xs => fn ys => go (xs, ys)
    end;

(* ------------------------------------------------------------------------- *)
(* Comparisons.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Compare two values through a projection f. *)
fun mapCompare f cmp (a, b) =
    let
      val fa = f a
      val fb = f b
    in
      cmp (fa, fb)
    end;

(* Flip the result of a comparison. *)
fun revCompare cmp x_y =
    (case cmp x_y of
       GREATER => LESS
     | LESS => GREATER
     | EQUAL => EQUAL);

(* Lexicographic comparison of pairs: first components, then second. *)
fun prodCompare xCmp yCmp ((x1, y1), (x2, y2)) =
    (case xCmp (x1, x2) of
       EQUAL => yCmp (y1, y2)
     | decided => decided);

(* Lexicographic comparison of lists; a proper prefix is smaller. *)
fun lexCompare cmp =
    let
      fun go (x :: xs, y :: ys) =
          (case cmp (x, y) of
             EQUAL => go (xs, ys)
           | decided => decided)
        | go ([], []) = EQUAL
        | go ([], _) = LESS
        | go (_, []) = GREATER
    in
      go
    end;

(* NONE is smaller than any SOME. *)
fun optionCompare cmp (xo, yo) =
    case (xo, yo) of
      (SOME x, SOME y) => cmp (x, y)
    | (NONE, NONE) => EQUAL
    | (NONE, SOME _) => LESS
    | (SOME _, NONE) => GREATER;

(* false < true. *)
fun boolCompare (x, y) =
    if x = y then EQUAL
    else if y then LESS
    else GREATER;

(* ------------------------------------------------------------------------- *)
(* Lists.                                                                    *)
(* ------------------------------------------------------------------------- *)

(* Curried list constructor. *)
fun cons x = fn xs => x :: xs;

(* Split a non-empty list into its head and tail; raises Empty on []. *)
fun hdTl (h :: t) = (h, t)
  | hdTl [] = raise Empty;

(* Curried list append. *)
fun append xs ys = foldr (op ::) ys xs;

fun singleton a = a :: [];

(* The first SOME result of f over the list, scanning left to right. *)
fun first f =
    let
      fun go [] = NONE
        | go (x :: rest) =
          (case f x of
             NONE => go rest
           | found => found)
    in
      go
    end;

(* Map a state transformer over a list, threading the state from left to
   right.  f x is applied as soon as the list cell is examined, before the
   state arrives. *)
fun maps (_ : 'a -> 's -> 'b * 's) [] = (fn s => ([], s))
  | maps f (x :: xs) =
    let
      val fx = f x
    in
      fn s =>
         let
           val (y, s1) = fx s
           val (ys, s2) = maps f xs s1
         in
           (y :: ys, s2)
         end
    end;

(* Like maps, but drop the elements for which f yields NONE. *)
fun mapsPartial (_ : 'a -> 's -> 'b option * 's) [] = (fn s => ([], s))
  | mapsPartial f (x :: xs) =
    let
      val fx = f x
    in
      fn s =>
         let
           val (yo, s1) = fx s
           val (ys, s2) = mapsPartial f xs s1
           val kept = case yo of NONE => ys | SOME y => y :: ys
         in
           (kept, s2)
         end
    end;

(* Combine two lists pointwise with f; raises Error if the lengths differ. *)
fun zipWith f =
    let
      fun loop acc (x :: xs) (y :: ys) = loop (f x y :: acc) xs ys
        | loop acc [] [] = rev acc
        | loop _ _ _ = raise Error "zipWith: lists different lengths"
    in
      fn xs => fn ys => loop [] xs ys
    end;

fun zip xs ys = zipWith (fn x => fn y => (x, y)) xs ys;

(* Split a list of pairs into a pair of lists. *)
fun unzip ab = foldr (fn ((x, y), (xs, ys)) => (x :: xs, y :: ys)) ([], []) ab;

(* All combinations f x y.  The result lists the pairings with ys's first
   element (over every x in order), then with its second element, and so on. *)
fun cartwith f =
    fn xs => fn ys =>
       let
         val xsRev = rev xs
         fun inner _ acc [] = acc
           | inner y acc (x :: rest) = inner y (f x y :: acc) rest
         fun outer acc [] = acc
           | outer acc (y :: rest) = outer (inner y acc xsRev) rest
       in
         outer [] (rev ys)
       end;

fun cart xs ys = cartwith (fn x => fn y => (x, y)) xs ys;

(* The longest prefix whose elements all satisfy p. *)
fun takeWhile p =
    let
      fun go acc l =
          case l of
            x :: rest => if p x then go (x :: acc) rest else rev acc
          | [] => rev acc
    in
      go []
    end;

(* Drop the longest prefix whose elements all satisfy p. *)
fun dropWhile p =
    let
      fun go l =
          case l of
            x :: rest => if p x then go rest else l
          | [] => l
    in
      go
    end;

(* (takeWhile p l, dropWhile p l), computed in one pass. *)
fun divideWhile p =
    let
      fun go acc l =
          case l of
            x :: rest => if p x then go (x :: acc) rest else (rev acc, l)
          | [] => (rev acc, l)
    in
      go []
    end;

(* Split a list into consecutive groups.  f (element, state) returns
   (startNew, state'); when startNew is true the element begins a new group.
   Note that if f says true for the very first element, an empty leading group
   is emitted (groupsBy below relies on this). *)
fun groups f =
    let
      (* done: finished groups, newest first; cur: current group, reversed. *)
      fun go done cur _ [] =
          rev (if null cur then done else rev cur :: done)
        | go done cur st (h :: t) =
          let
            val (startNew, st') = f (h, st)
          in
            if startNew then go (rev cur :: done) [h] st' t
            else go done (h :: cur) st' t
          end
    in
      go [] []
    end;

(* Group adjacent elements; a new group starts whenever eq (current, previous)
   fails. *)
fun groupsBy eq =
    let
      fun boundary (cur_prev as (cur, _)) = (not (eq cur_prev), cur)
    in
      fn l =>
         case l of
           [] => []
         | h :: t =>
           (case groups boundary h t of
              [] => [[h]]
            | g1 :: rest => (h :: g1) :: rest)
    end;

local
  (* Adjacent pairs belong together when their first components agree. *)
  fun sameKey ((k1, _), (k2, _)) = k1 = k2;

  (* Collapse a non-empty group of pairs into (shared key, second components). *)
  fun collect grp =
      let
        val key = fst (hd grp)
      in
        (key, map snd grp)
      end;
in
  fun groupsByFst l = map collect (groupsBy sameKey l);
end;

(* Chop a list into consecutive groups of n elements (the last may be shorter).
   Intended for n >= 1. *)
fun groupsOf n =
    let
      fun countdown (_, remaining) =
          if remaining = 1 then (true, n) else (false, remaining - 1)
    in
      groups countdown (n + 1)
    end;

(* Zero-based position of the first element satisfying p, if any. *)
fun index p =
    let
      fun from i l =
          case l of
            [] => NONE
          | x :: rest => if p x then SOME i else from (i + 1) rest
    in
      from 0
    end;

(* Pair each element with its zero-based position. *)
fun enumerate l =
    let
      fun number _ [] = []
        | number i (x :: rest) = (i, x) :: number (i + 1) rest
    in
      number 0 l
    end;

local
  (* Move the first n elements of l onto acc (so they end up reversed);
     raises Subscript if l runs out first. *)
  fun peel acc l n =
      if n = 0 then (acc, l)
      else
        case l of
          [] => raise Subscript
        | h :: t => peel (h :: acc) t (n - 1);
in
  fun revDivide l = peel [] l;
end;

(* Split l into its first n elements and the rest. *)
fun divide l n =
    case revDivide l n of
      (front, back) => (rev front, back);

(* Replace the n-th (zero-based) element by x; raises Subscript if absent. *)
fun updateNth (n, x) l =
    case revDivide l n of
      (_, []) => raise Subscript
    | (pre, _ :: post) => List.revAppend (pre, x :: post);

(* Remove the n-th (zero-based) element; raises Subscript if absent. *)
fun deleteNth n l =
    case revDivide l n of
      (_, []) => raise Subscript
    | (pre, _ :: post) => List.revAppend (pre, post);

(* ------------------------------------------------------------------------- *)
(* Sets implemented with lists.                                              *)
(* ------------------------------------------------------------------------- *)

(* Membership by structural equality. *)
fun mem x = List.exists (fn y => x = y);

(* Add x at the front unless it is already present. *)
fun insert x s = if not (mem x s) then x :: s else s;

(* Remove every occurrence of x. *)
fun delete x s = List.filter (fn y => x <> y) s;

(* Remove duplicates, keeping the first occurrence of each element in order. *)
fun setify s =
    let
      fun go seen [] = rev seen
        | go seen (v :: rest) = go (if mem v seen then seen else v :: seen) rest
    in
      go [] s
    end;

(* The elements of s not in t (in s order), followed by t. *)
fun union s t = List.filter (fn v => not (mem v t)) s @ t;

(* The elements of s that are also in t, in s order. *)
fun intersect s t = List.filter (fn v => mem v t) s;

(* The elements of s that are not in t, in s order. *)
fun difference s t = List.filter (fn v => not (mem v t)) s;

(* Is every element of s also in t? *)
fun subset s t = not (List.exists (fn x => not (mem x t)) s);

(* Are all elements pairwise different? *)
fun distinct s =
    case s of
      [] => true
    | x :: rest => not (mem x rest) andalso distinct rest;

(* ------------------------------------------------------------------------- *)
(* Sorting and searching.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Finding the minimum and maximum element of a list, wrt some order. *)

(* The first minimal element with respect to cmp, together with the remaining
   elements in their original order.  Raises Empty on []. *)
fun minimum cmp =
    let
      (* seen: visited elements, reversed.  The best-so-far triple is
         (elements before the minimum, reversed; the minimum; elements after). *)
      fun scan (best as (_, m, _)) seen (x :: rest) =
          let
            val best' =
                case cmp (x, m) of
                  LESS => (seen, x, rest)
                | _ => best
          in
            scan best' (x :: seen) rest
          end
        | scan (pre, m, post) _ [] = (m, List.revAppend (pre, post))
    in
      fn [] => raise Empty
       | h :: t => scan ([], h, t) [h] t
    end;

(* The first maximal element and the rest, via minimum on the reversed order. *)
fun maximum cmp l = minimum (revCompare cmp) l;

(* Merge (for the following merge-sort, but generally useful too). *)

(* Merge two sorted lists.  On ties the element from the first list goes
   first, which keeps the merge sort below stable. *)
fun merge cmp =
    let
      fun go acc xs ys =
          case (xs, ys) of
            ([], _) => List.revAppend (acc, ys)
          | (_, []) => List.revAppend (acc, xs)
          | (x :: xrest, y :: yrest) =>
            (case cmp (x, y) of
               GREATER => go (y :: acc) xs yrest
             | _ => go (x :: acc) xrest ys)
    in
      go []
    end;

(* Stable natural merge sort: split into ascending runs, then merge adjacent
   runs pairwise until one remains. *)
fun sort cmp =
    let
      (* cur is the latest element of the current run; run holds its earlier
         elements in reverse; done holds finished runs, newest first. *)
      fun runs done cur run [] = rev (rev (cur :: run) :: done)
        | runs done cur run (x :: rest) =
          (case cmp (cur, x) of
             GREATER => runs (rev (cur :: run) :: done) x [] rest
           | _ => runs done x (cur :: run) rest)

      (* One pass: merge runs in adjacent pairs, left to right. *)
      fun pass acc (r1 :: r2 :: more) = pass (merge cmp r1 r2 :: acc) more
        | pass acc [r] = List.revAppend (acc, [r])
        | pass acc [] = rev acc

      fun repeat [r] = r
        | repeat rs = repeat (pass [] rs)
    in
      fn l =>
         case l of
           [] => []
         | [_] => l
         | h :: t => repeat (runs [] h [] t)
    end;

(* Sort by the key f x, computing each key once.  Lists of length < 2 are
   returned as they are, without calling f. *)
fun sortMap f cmp xs =
    case xs of
      [] => []
    | [_] => xs
    | _ =>
      let
        fun byKey ((k1, _), (k2, _)) = cmp (k1, k2)
        val keyed = map (fn x => (f x, x)) xs
      in
        map snd (sort byKey keyed)
      end;

(* ------------------------------------------------------------------------- *)
(* Integers.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* interval m len = [m, m + 1, ..., m + len - 1].  A length of zero or less
   yields [] (a negative length used to recurse without end).  List.tabulate
   also avoids the deep non-tail recursion of the previous version. *)
fun interval m len =
    if len <= 0 then [] else List.tabulate (len, fn i => m + i);

(* Does a divide b exactly?  Every a divides 0; 0 divides nothing else. *)
fun divides a b =
    if b = 0 then true
    else if a = 0 then false
    else b mod Int.abs a = 0;

local
  (* Euclid's algorithm; gcd below passes the smaller magnitude first. *)
  fun euclid m n =
      case m of
        0 => n
      | 1 => 1
      | _ => euclid (n mod m) m;
in
  (* Greatest common divisor of the absolute values; gcd 0 0 = 0. *)
  fun gcd m n =
      let
        val a = Int.abs m
        val b = Int.abs n
      in
        if a < b then euclid a b else euclid b a
      end;
end;

local
  (* Append the next `remaining` primes to `found`, testing candidates upward
     from `candidate` by trial division against the primes found so far. *)
  fun extend found remaining candidate =
      if List.exists (fn p => divides p candidate) found then
        extend found remaining (candidate + 1)
      else
        let
          val found' = found @ [candidate]
          val remaining' = remaining - 1
        in
          if remaining' = 0 then found'
          else extend found' remaining' (candidate + 1)
        end;

  (* Module-level cache of the primes computed so far, in increasing order
     (an unsynchronized reference). *)
  val cache = Unsynchronized.ref [2];
in
  (* The first n primes, extending the cache when it is too short. *)
  fun primes n =
      let
        val known = !cache
        val missing = n - length known
      in
        if missing > 0 then
          let
            val all = extend known missing (List.last known + 1)
          in
            cache := all;
            all
          end
        else List.take (known, n)
      end;
end;

(* All primes p with p <= n.  The number of primes requested from primes is
   doubled (starting at 8) until the largest one reaches n. *)
fun primesUpTo n =
    let
      fun grow count =
          let
            val ps = primes count
          in
            if List.last ps < n then grow (2 * count)
            else takeWhile (fn p => p <= n) ps
          end
    in
      grow 8
    end;

(* ------------------------------------------------------------------------- *)
(* Strings.                                                                  *)
(* ------------------------------------------------------------------------- *)

local
  (* Shift letter c by k places, cyclically, within the 26-letter alphabet
     that starts at base. *)
  fun shift base c k =
      Char.chr (Char.ord base + (Char.ord c - Char.ord base + k) mod 26);
in
  (* Caesar rotation by k places; characters that are not letters are left
     unchanged. *)
  fun rot k c =
      if Char.isLower c then shift #"a" c k
      else if Char.isUpper c then shift #"A" c k
      else c;
end;

(* Value of a decimal digit character, or NONE for any other character. *)
fun charToInt c =
    if Char.isDigit c then SOME (Char.ord c - Char.ord #"0") else NONE;

(* The digit character for 0..9, or NONE for any other integer. *)
fun charFromInt i =
    if 0 <= i andalso i <= 9 then SOME (Char.chr (Char.ord #"0" + i)) else NONE;

(* nChars x n is the string made of n copies of x.  A count of zero or less
   gives "" (a negative count used to loop without end). *)
fun nChars x =
    fn n => if n <= 0 then "" else CharVector.tabulate (n, fn _ => x);

(* Remove one trailing newline character, if there is one. *)
fun chomp s =
    if String.isSuffix "\n" s then String.substring (s, 0, size s - 1) else s;

(* Strip leading and trailing whitespace (as judged by Char.isSpace). *)
fun trim s =
    Substring.string
      (Substring.dropl Char.isSpace
         (Substring.dropr Char.isSpace (Substring.full s)));

(* join sep strs concatenates strs with sep between adjacent elements
   (join _ [] = "").  String.concatWith builds the result in one pass, where
   the previous left fold over ^ re-copied the growing prefix each time
   (quadratic in the total length). *)
fun join sep strs = String.concatWith sep strs;

local
  (* If pat is a prefix of l, return what follows it in l. *)
  fun match [] l = SOME l
    | match _ [] = NONE
    | match (x :: xs) (y :: ys) = if x = y then match xs ys else NONE;

  (* Turn reversed groups of characters into strings in original order. *)
  fun stringify acc [] = acc
    | stringify acc (h :: t) = stringify (implode h :: acc) t;
in
  (* split sep s: the pieces of s between non-overlapping occurrences of sep,
     scanning left to right (split "," "a,,b" = ["a","","b"]).  An empty
     separator leaves s unsplit ([s]): it would match without consuming any
     input, so the scanning loop would never terminate. *)
  fun split sep =
      let
        val pat = String.explode sep
        fun div1 prev recent [] = stringify [] (rev recent :: prev)
          | div1 prev recent (l as h :: t) =
            case match pat l of
              NONE => div1 prev (h :: recent) t
            | SOME rest => div1 (rev recent :: prev) [] rest
      in
        if null pat then (fn s => [s])
        else fn s => div1 [] [] (explode s)
      end;
end;

(* Upper-case the first character of s; "" is returned unchanged. *)
fun capitalize s =
    case String.explode s of
      [] => s
    | c :: rest => String.implode (Char.toUpper c :: rest);

(* Prepend the prefix p. *)
fun mkPrefix p = fn s => p ^ s;

(* Remove the prefix p from s; raises Error "destPrefix" if s lacks it. *)
fun destPrefix p =
    let
      val prefixLen = size p
    in
      fn s =>
         if String.isPrefix p s then String.extract (s, prefixLen, NONE)
         else raise Error "destPrefix"
    end;

(* Does s start with p? *)
fun isPrefix p = fn s => String.isPrefix p s;

(* Drop the leading characters of s that satisfy pred. *)
fun stripPrefix pred s =
    let
      val whole = Substring.full s
    in
      Substring.string (Substring.dropl pred whole)
    end;

(* Append the suffix p. *)
fun mkSuffix p = fn s => s ^ p;

(* Remove the suffix p from s; raises Error "destSuffix" if s lacks it. *)
fun destSuffix p =
    let
      val suffixLen = size p
    in
      fn s =>
         if String.isSuffix p s then String.substring (s, 0, size s - suffixLen)
         else raise Error "destSuffix"
    end;

(* Does s end with p? *)
fun isSuffix p = fn s => String.isSuffix p s;

(* Drop the trailing characters of s that satisfy pred. *)
fun stripSuffix pred s =
    let
      val whole = Substring.full s
    in
      Substring.string (Substring.dropr pred whole)
    end;

(* ------------------------------------------------------------------------- *)
(* Tables.                                                                   *)
(* ------------------------------------------------------------------------- *)

(* How to lay out one table column: when leftAlign is true the padding goes *)
(* after the entry, otherwise before it; padChar is the padding character.  *)
type columnAlignment = {leftAlign : bool, padChar : char}

(* Pad every entry of column to a common width (taken from maximum over    *)
(* the entry sizes), then prepend each padded entry onto the corresponding *)
(* already-rendered remainder of its row. The result is a function that    *)
(* expects the list of row remainders.                                     *)
fun alignColumn {leftAlign,padChar} column =
    let
      val (width,_) = maximum Int.compare (map size column)

      fun padded cell =
          let
            val fill = nChars padChar (width - size cell)
          in
            if leftAlign then cell ^ fill else fill ^ cell
          end

      fun prepend cell rest = padded cell ^ rest
    in
      zipWith prepend column
    end;

local
  (* Render the rows column by column, right to left. A final column that *)
  (* is left-aligned with space padding is emitted without any padding.   *)
  fun alignTab [] rows = map (K "") rows
    | alignTab [{leftAlign = true, padChar = #" "}] rows = map hd rows
    | alignTab (align :: aligns) rows =
      let
        val firstColumn = map hd rows
        val restRendered = alignTab aligns (map tl rows)
      in
        alignColumn align firstColumn restRendered
      end;
in
  (* Lay out a table (a list of rows, each a list of cells) using one    *)
  (* columnAlignment per column; a table with no rows renders as [].     *)
  fun alignTable _ [] = []
    | alignTable aligns rows = alignTab aligns rows;
end;

(* ------------------------------------------------------------------------- *)
(* Reals.                                                                    *)
(* ------------------------------------------------------------------------- *)

(* Render a real with the Basis library's default formatting. *)
fun realToString x = Real.toString x;

(* Render a fraction as a rounded whole-number percentage: 0.5 gives "50%". *)
fun percentToString x =
    let
      val whole = Real.round (x * 100.0)
    in
      Int.toString whole ^ "%"
    end;

(* Clamp a real from below at zero, via Real.max with 0.0. *)
fun pos x = Real.max (x,0.0);

(* Base-2 logarithm: ln x scaled by 1 / ln 2, which is computed once. *)
val log2 =
    let
      val reciprocalLn2 = 1.0 / Math.ln 2.0
    in
      fn x => reciprocalLn2 * Math.ln x
    end;

(* ------------------------------------------------------------------------- *)
(* Sums.                                                                     *)
(* ------------------------------------------------------------------------- *)

(* A disjoint union: every value is either Left of an 'a or Right of a 'b. *)
datatype ('a,'b) sum = Left of 'a | Right of 'b

(* Extract the payload of a Left; a Right raises Error "destLeft". *)
fun destLeft s =
    case s of
      Left l => l
    | Right _ => raise Error "destLeft";

fun isLeft s =
    case s of
      Left _ => true
    | Right _ => false;

(* Extract the payload of a Right; a Left raises Error "destRight". *)
fun destRight s =
    case s of
      Right r => r
    | Left _ => raise Error "destRight";

fun isRight s = not (isLeft s);

(* ------------------------------------------------------------------------- *)
(* Useful impure features.                                                   *)
(* ------------------------------------------------------------------------- *)

(* A private counter from which fresh integers are handed out. *)
local
  val generator = Unsynchronized.ref 0
in
  (* Return the current counter value and advance the counter by one. *)
  fun newInt () =
      let
        val n = !generator
        val () = generator := n + 1
      in
        n
      end;

  (* Reserve k consecutive fresh integers starting at the current counter *)
  (* value and return interval n k (presumably [n, ..., n + k - 1]).       *)
  (* A negative k is rejected with Error: accepting it would move the      *)
  (* counter backwards, so later calls would return duplicate integers.    *)
  fun newInts k =
      if k < 0 then raise Error "newInts: negative count"
      else if k = 0 then []
      else
        let
          val n = !generator
          val () = generator := n + k
        in
          interval n k
        end;
end;

(* Evaluate f x with the reference r temporarily set to new. The previous *)
(* contents of r are put back afterwards, both when f x returns and when  *)
(* it raises (the exception is then re-raised).                           *)
fun withRef (r,new) f x =
    let
      val saved = !r
      fun restore () = r := saved
      val () = r := new
    in
      (f x before restore ())
      handle e => (restore (); raise e)
    end;

(* Return a fresh array with the same length and elements as a; updating *)
(* the copy does not affect the original.                                *)
fun cloneArray a = Array.tabulate (Array.length a, fn i => Array.sub (a,i));

(* ------------------------------------------------------------------------- *)
(* Environment.                                                              *)
(* ------------------------------------------------------------------------- *)

(* The HOSTNAME environment variable, or "unknown" when it is not set. *)
fun host () =
    case OS.Process.getEnv "HOSTNAME" of
      SOME name => name
    | NONE => "unknown";

local
  (* Format the current local date/time with the given Date.fmt pattern. *)
  fun nowFormatted pattern = Date.fmt pattern (Date.fromTimeLocal (Time.now ()));
in
  (* Current local time as HH:MM:SS. *)
  fun time () = nowFormatted "%H:%M:%S";

  (* Current local date as DD/MM/YYYY. *)
  fun date () = nowFormatted "%d/%m/%Y";
end;

(* List the entries of a directory as {filename} records, in the order    *)
(* OS.FileSys.readDir yields them, each joined onto the directory path.   *)
(* The directory stream is closed even if reading an entry raises; the    *)
(* exception is then re-raised.                                            *)
fun readDirectory {directory = dir} =
    let
      val dirStrm = OS.FileSys.openDir dir

      fun readAll acc =
          case OS.FileSys.readDir dirStrm of
            NONE => rev acc
          | SOME file =>
            let
              val filename = OS.Path.joinDirFile {dir = dir, file = file}
            in
              readAll ({filename = filename} :: acc)
            end

      val filenames =
          readAll [] handle e => (OS.FileSys.closeDir dirStrm; raise e)

      val () = OS.FileSys.closeDir dirStrm
    in
      filenames
    end;

(* Return the entire contents of a text file. The input stream is closed *)
(* even if reading fails; the exception is then re-raised.               *)
fun readTextFile {filename} =
  let
    val h = TextIO.openIn filename

    val contents =
        TextIO.inputAll h handle e => (TextIO.closeIn h; raise e)

    val () = TextIO.closeIn h
  in
    contents
  end;

(* Write contents to a file, replacing any previous contents. The output *)
(* stream is closed even if writing fails; the exception is then         *)
(* re-raised.                                                            *)
fun writeTextFile {contents,filename} =
  let
    val h = TextIO.openOut filename

    val () =
        TextIO.output (h,contents) handle e => (TextIO.closeOut h; raise e)
  in
    TextIO.closeOut h
  end;

(* ------------------------------------------------------------------------- *)
(* Profiling and error reporting.                                            *)
(* ------------------------------------------------------------------------- *)

(* Write s followed by a newline to standard error. *)
fun chat s = TextIO.output (TextIO.stdErr, String.concat [s, "\n"]);

local
  (* Emit "tag: msg" on standard error via chat. *)
  fun report tag msg = chat (tag ^ ": " ^ msg);
in
  (* Apply f to x. If it raises, describe the exception on standard error *)
  (* (using errorToString / bugToString for Error / Bug) and re-raise it. *)
  fun try f x =
      f x
      handle e =>
        let
          val description =
              case e of
                Error _ => errorToString e
              | Bug _ => bugToString e
              | _ => "strange exception raised"

          val () = report "try" description
        in
          raise e
        end;

  (* Print a message tagged "WARNING" on standard error. *)
  val warn = report "WARNING";

  (* Print a fatal-error message on standard error, then exit the process *)
  (* with OS.Process.failure.                                              *)
  fun die s = (report "\nFATAL ERROR" s; OS.Process.exit OS.Process.failure);
end;

(* Apply f to a and return (CPU seconds spent, user plus system, result). *)
fun timed f a =
  let
    val timer = Timer.startCPUTimer ()
    val result = f a
    val {usr,sys,...} = Timer.checkCPUTimer timer
    val seconds = Time.toReal usr + Time.toReal sys
  in
    (seconds, result)
  end;

local
  (* Keep re-running until at least this many CPU seconds have elapsed. *)
  val minTotal = 1.0;

  fun repeatUntil runs elapsed f a =
    let
      val (dt,result) = timed f a
      val elapsed = elapsed + dt
      val runs = runs + 1
    in
      if elapsed > minTotal then (elapsed / Real.fromInt runs, result)
      else repeatUntil runs elapsed f a
    end;
in
  (* Average CPU time of f a over as many runs as it takes to accumulate *)
  (* more than one second, paired with the result of the last run.       *)
  fun timedMany f a = repeatUntil 0 0.0 f a
end;

local
  (* Wall-clock time at which this binding was evaluated. *)
  val startTime = Time.toReal (Time.now ())
in
  (* Wall-clock seconds elapsed since startTime was captured. *)
  fun executionTime () = Time.toReal (Time.now ()) - startTime
end;

end
end;

(**** Original file: Lazy.sig ****)

(* ========================================================================= *)
(* SUPPORT FOR LAZY EVALUATION                                               *)
(* Copyright (c) 2007 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

signature Lazy =
sig

(* A suspended computation of type 'a whose result is cached once forced. *)
type 'a lazy

(* Wrap an already-computed value; forcing it just returns the value. *)
val quickly : 'a -> 'a lazy

(* Suspend a computation without running it. *)
val delay : (unit -> 'a) -> 'a lazy

(* Run a pending computation and cache its result, so later calls return *)
(* the cached value. If the computation raises, nothing is cached and the *)
(* next force runs it again.                                              *)
val force : 'a lazy -> 'a

(* Turn a thunk into one that runs the original at most once (when it   *)
(* returns normally) and thereafter returns the cached result.          *)
val memoize : (unit -> 'a) -> unit -> 'a

end

(**** Original file: Lazy.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* SUPPORT FOR LAZY EVALUATION                                               *)
(* Copyright (c) 2007 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Lazy :> Lazy =
struct

(* The state of a suspension: either evaluated or still pending. *)
datatype 'a thunk =
    Value of 'a
  | Thunk of unit -> 'a;

(* A mutable cell holding the suspension's current state. *)
datatype 'a lazy = Lazy of 'a thunk Unsynchronized.ref;

fun quickly v = Lazy (Unsynchronized.ref (Value v));

fun delay f = Lazy (Unsynchronized.ref (Thunk f));

(* Return the cached value, or run the pending computation and overwrite *)
(* the cell with its result. If the computation raises, the cell stays   *)
(* pending.                                                              *)
fun force (Lazy cell) =
    case !cell of
      Value v => v
    | Thunk compute =>
      let
        val v = compute ()
      in
        cell := Value v;
        v
      end;

(* Share one suspension between all calls of the returned thunk. *)
fun memoize f =
    let
      val shared = delay f
    in
      fn () => force shared
    end;

end
end;

(**** Original file: Stream.sig ****)

(* ========================================================================= *)
(* A POSSIBLY-INFINITE STREAM DATATYPE FOR ML                                *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Stream =
sig

(* ------------------------------------------------------------------------- *)
(* The stream type.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Either empty (Nil) or a head element paired with a function that        *)
(* produces the rest of the stream when called.                            *)
datatype 'a stream = Nil | Cons of 'a * (unit -> 'a stream)

(* If you're wondering how to create an infinite stream: *)
(* val stream4 = let fun s4 () = Metis.Stream.Cons (4,s4) in s4 () end; *)

(* ------------------------------------------------------------------------- *)
(* Stream constructors.                                                      *)
(* ------------------------------------------------------------------------- *)

(* The infinite stream x, x, x, ... *)
val repeat : 'a -> 'a stream

(* The infinite stream n, n + 1, n + 2, ... *)
val count : int -> int stream

(* The infinite stream x, f x, f (f x), ... *)
val funpows : ('a -> 'a) -> 'a -> 'a stream

(* ------------------------------------------------------------------------- *)
(* Stream versions of standard list operations: these should all terminate.  *)
(* ------------------------------------------------------------------------- *)

val cons : 'a -> (unit -> 'a stream) -> 'a stream

val null : 'a stream -> bool

val hd : 'a stream -> 'a  (* raises Empty *)

val tl : 'a stream -> 'a stream  (* raises Empty *)

val hdTl : 'a stream -> 'a * 'a stream  (* raises Empty *)

val singleton : 'a -> 'a stream

(* The second stream is only built once the first has been exhausted. *)
val append : 'a stream -> (unit -> 'a stream) -> 'a stream

val map : ('a -> 'b) -> 'a stream -> 'b stream

(* Map while threading a state: the function turns an element and the     *)
(* current state into an output and the next state; when the input runs  *)
(* out, the second function turns the final state into the rest of the   *)
(* output.                                                                *)
val maps :
    ('a -> 's -> 'b * 's) -> ('s -> 'b stream) -> 's -> 'a stream -> 'b stream

(* Stops as soon as either input stream ends. *)
val zipwith : ('a -> 'b -> 'c) -> 'a stream -> 'b stream -> 'c stream

val zip : 'a stream -> 'b stream -> ('a * 'b) stream

val take : int -> 'a stream -> 'a stream  (* raises Subscript *)

val drop : int -> 'a stream -> 'a stream  (* raises Subscript *)

(* ------------------------------------------------------------------------- *)
(* Stream versions of standard list operations: these might not terminate.   *)
(* ------------------------------------------------------------------------- *)

val length : 'a stream -> int

val exists : ('a -> bool) -> 'a stream -> bool

val all : ('a -> bool) -> 'a stream -> bool

val filter : ('a -> bool) -> 'a stream -> 'a stream

val foldl : ('a * 's -> 's) -> 's -> 'a stream -> 's

val concat : 'a stream stream -> 'a stream

val mapPartial : ('a -> 'b option) -> 'a stream -> 'b stream

(* Like maps, but elements mapped to NONE are dropped from the output. *)
val mapsPartial :
    ('a -> 's -> 'b option * 's) -> ('s -> 'b stream) -> 's ->
    'a stream -> 'b stream

val mapConcat : ('a -> 'b stream) -> 'a stream -> 'b stream

(* Like maps, but each element is mapped to a whole stream of outputs. *)
val mapsConcat :
    ('a -> 's -> 'b stream * 's) -> ('s -> 'b stream) -> 's ->
    'a stream -> 'b stream

(* ------------------------------------------------------------------------- *)
(* Stream operations.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Cache every tail (via Lazy.memoize) so each is computed at most once. *)
val memoize : 'a stream -> 'a stream

val listConcat : 'a list stream -> 'a stream

val concatList : 'a stream list -> 'a stream

val toList : 'a stream -> 'a list

val fromList : 'a list -> 'a stream

val toString : char stream -> string

val fromString : string -> char stream

(* Write each string of the stream in turn, adding no separators; the *)
(* filename "-" denotes standard output.                               *)
val toTextFile : {filename : string} -> string stream -> unit

(* The filename "-" denotes standard input. *)
val fromTextFile : {filename : string} -> string stream  (* line by line *)

end

(**** Original file: Stream.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* A POSSIBLY-INFINITE STREAM DATATYPE FOR ML                                *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Stream :> Stream =
struct

val K = Useful.K;

val pair = Useful.pair;

val funpow = Useful.funpow;

(* ------------------------------------------------------------------------- *)
(* The stream type.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Either empty, or a head element paired with a function producing the   *)
(* rest of the stream; the tail is re-run on each call unless memoized.   *)
datatype 'a stream =
    Nil
  | Cons of 'a * (unit -> 'a stream);

(* ------------------------------------------------------------------------- *)
(* Stream constructors.                                                      *)
(* ------------------------------------------------------------------------- *)

(* The infinite stream x, x, x, ... built from a single tail function. *)
fun repeat x = let fun rep () = Cons (x,rep) in rep () end;

(* The infinite stream n, n + 1, n + 2, ... *)
fun count n = Cons (n, fn () => count (n + 1));

(* The infinite stream x, f x, f (f x), ... *)
fun funpows f x = Cons (x, fn () => funpows f (f x));

(* ------------------------------------------------------------------------- *)
(* Stream versions of standard list operations: these should all terminate.  *)
(* ------------------------------------------------------------------------- *)

fun cons h t = Cons (h,t);

fun null Nil = true
  | null (Cons _) = false;

(* Raises Empty on Nil. *)
fun hd Nil = raise Empty
  | hd (Cons (h,_)) = h;

(* Raises Empty on Nil; otherwise runs the tail function. *)
fun tl Nil = raise Empty
  | tl (Cons (_,t)) = t ();

fun hdTl s = (hd s, tl s);

fun singleton s = Cons (s, K Nil);

(* The second stream is only built once the first has been exhausted. *)
fun append Nil s = s ()
  | append (Cons (h,t)) s = Cons (h, fn () => append (t ()) s);

fun map f =
    let
      fun m Nil = Nil
        | m (Cons (h,t)) = Cons (f h, m o t)
    in
      m
    end;

(* Map while threading a state: f turns an element and the current state  *)
(* into an output and the next state; once the input is exhausted, g turns *)
(* the final state into the rest of the output.                            *)
fun maps f g =
    let
      fun mm s Nil = g s
        | mm s (Cons (x,xs)) =
          let
            val (y,s') = f x s
          in
            Cons (y, mm s' o xs)
          end
    in
      mm
    end;

(* Combine elements pairwise, stopping when either stream ends. *)
fun zipwith f =
    let
      fun z Nil _ = Nil
        | z _ Nil = Nil
        | z (Cons (x,xs)) (Cons (y,ys)) =
          Cons (f x y, fn () => z (xs ()) (ys ()))
    in
      z
    end;

fun zip s t = zipwith pair s t;

(* The first n elements of a stream. Raises Subscript immediately when n  *)
(* is negative (previously a negative n streamed on, never terminating on *)
(* infinite input), and when forcing reaches the end of a stream holding  *)
(* fewer than n elements.                                                  *)
fun take n s =
    if n < 0 then raise Subscript
    else if n = 0 then Nil
    else
      case s of
        Nil => raise Subscript
      | Cons (x,xs) =>
        if n = 1 then Cons (x, K Nil)
        else Cons (x, fn () => take (n - 1) (xs ()));

(* Skip the first n elements. Raises Subscript when n is negative or the *)
(* stream holds fewer than n elements.                                   *)
fun drop n s =
    if n < 0 then raise Subscript
    else (funpow n tl s handle Empty => raise Subscript);

(* ------------------------------------------------------------------------- *)
(* Stream versions of standard list operations: these might not terminate.   *)
(* ------------------------------------------------------------------------- *)

local
  fun len n Nil = n
    | len n (Cons (_,t)) = len (n + 1) (t ());
in
  fun length s = len 0 s;
end;

fun exists pred =
    let
      fun f Nil = false
        | f (Cons (h,t)) = pred h orelse f (t ())
    in
      f
    end;

fun all pred = not o exists (not o pred);

fun filter p Nil = Nil
  | filter p (Cons (x,xs)) =
    if p x then Cons (x, fn () => filter p (xs ())) else filter p (xs ());

fun foldl f =
    let
      fun fold b Nil = b
        | fold b (Cons (h,t)) = fold (f (h,b)) (t ())
    in
      fold
    end;

(* Flatten a stream of streams, skipping empty inner streams. *)
fun concat Nil = Nil
  | concat (Cons (Nil, ss)) = concat (ss ())
  | concat (Cons (Cons (x, xs), ss)) =
    Cons (x, fn () => concat (Cons (xs (), ss)));

fun mapPartial f =
    let
      fun mp Nil = Nil
        | mp (Cons (h,t)) =
          case f h of
            NONE => mp (t ())
          | SOME h' => Cons (h', fn () => mp (t ()))
    in
      mp
    end;

(* Like maps, but elements mapped to NONE are dropped from the output. *)
fun mapsPartial f g =
    let
      fun mp s Nil = g s
        | mp s (Cons (h,t)) =
          let
            val (h,s) = f h s
          in
            case h of
              NONE => mp s (t ())
            | SOME h => Cons (h, fn () => mp s (t ()))
          end
    in
      mp
    end;

fun mapConcat f =
    let
      fun mc Nil = Nil
        | mc (Cons (h,t)) = append (f h) (fn () => mc (t ()))
    in
      mc
    end;

(* Like maps, but each element is mapped to a whole stream of outputs. *)
fun mapsConcat f g =
    let
      fun mc s Nil = g s
        | mc s (Cons (h,t)) =
          let
            val (l,s) = f h s
          in
            append l (fn () => mc s (t ()))
          end
    in
      mc
    end;

(* ------------------------------------------------------------------------- *)
(* Stream operations.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Wrap every tail in Lazy.memoize so each is computed at most once. *)
fun memoize Nil = Nil
  | memoize (Cons (h,t)) = Cons (h, Lazy.memoize (fn () => memoize (t ())));

fun concatList [] = Nil
  | concatList (h :: t) = append h (fn () => concatList t);

local
  fun toLst res Nil = rev res
    | toLst res (Cons (x, xs)) = toLst (x :: res) (xs ());
in
  fun toList s = toLst [] s;
end;

fun fromList [] = Nil
  | fromList (x :: xs) = Cons (x, fn () => fromList xs);

fun listConcat s = concat (map fromList s);

fun toString s = implode (toList s);

fun fromString s = fromList (explode s);

(* Write each string of the stream in turn (no separators are added); the *)
(* filename "-" denotes standard output, which is left open. A file that  *)
(* was opened here is closed even if producing or writing the stream      *)
(* raises; the exception is then re-raised.                                *)
fun toTextFile {filename = f} s =
    let
      val (h,close) =
          if f = "-" then (TextIO.stdOut, K ())
          else (TextIO.openOut f, TextIO.closeOut)

      fun toFile Nil = ()
        | toFile (Cons (x,y)) = (TextIO.output (h,x); toFile (y ()))

      val () = toFile s handle e => (close h; raise e)
    in
      close h
    end;

(* A memoized stream of the lines of a file (standard input for "-"). The *)
(* file is closed when the end is reached, or when a read raises; the     *)
(* exception is then re-raised.                                           *)
fun fromTextFile {filename = f} =
    let
      val (h,close) =
          if f = "-" then (TextIO.stdIn, K ())
          else (TextIO.openIn f, TextIO.closeIn)

      fun strm () =
          case (TextIO.inputLine h handle e => (close h; raise e)) of
            NONE => (close h; Nil)
          | SOME s => Cons (s,strm)
    in
      memoize (strm ())
    end;

end
end;

(**** Original file: Ordered.sig ****)

(* ========================================================================= *)
(* ORDERED TYPES                                                             *)
(* Copyright (c) 2004-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Ordered =
sig

(* The carrier type of the order. *)
type t

(* A total order on t; implementations must satisfy the laws listed below. *)
val compare : t * t -> order

(*
  PROVIDES

  !x : t. compare (x,x) = EQUAL

  !x y : t. compare (x,y) = LESS <=> compare (y,x) = GREATER

  !x y : t. compare (x,y) = EQUAL ==> compare (y,x) = EQUAL

  !x y z : t. compare (x,y) = EQUAL ==> compare (x,z) = compare (y,z)

  !x y z : t.
    compare (x,y) = LESS andalso compare (y,z) = LESS ==>
    compare (x,z) = LESS

  !x y z : t.
    compare (x,y) = GREATER andalso compare (y,z) = GREATER ==>
    compare (x,z) = GREATER
*)

end

(**** Original file: Ordered.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* ORDERED TYPES                                                             *)
(* Copyright (c) 2004-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

(* Integers in Int.compare order. *)
structure IntOrdered =
struct
  type t = int
  val compare = Int.compare
end;

(* Pairs of integers ordered lexicographically: by first component, with *)
(* the second component breaking ties.                                   *)
structure IntPairOrdered =
struct

type t = int * int;

fun compare ((i1,j1),(i2,j2)) =
    case Int.compare (i1,i2) of
      EQUAL => Int.compare (j1,j2)
    | firstOrder => firstOrder;

end;

(* Strings in String.compare order. *)
structure StringOrdered =
struct
  type t = string
  val compare = String.compare
end;
end;

(**** Original file: Map.sig ****)

(* ========================================================================= *)
(* FINITE MAPS IMPLEMENTED WITH RANDOMLY BALANCED TREES                      *)
(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

signature Map =
sig

(* ------------------------------------------------------------------------- *)
(* A type of finite maps.                                                    *)
(* ------------------------------------------------------------------------- *)

(* A finite map from 'key to 'a. The key order is supplied only when a map *)
(* is built (new, singleton, fromList); no other operation takes a key     *)
(* comparison argument.                                                    *)
type ('key,'a) map

(* ------------------------------------------------------------------------- *)
(* Constructors.                                                             *)
(* ------------------------------------------------------------------------- *)

(* Presumably the empty map ordered by the given comparison. *)
val new : ('key * 'key -> order) -> ('key,'a) map

val singleton : ('key * 'key -> order) -> 'key * 'a -> ('key,'a) map

(* ------------------------------------------------------------------------- *)
(* Map size.                                                                 *)
(* ------------------------------------------------------------------------- *)

val null : ('key,'a) map -> bool

val size : ('key,'a) map -> int

(* ------------------------------------------------------------------------- *)
(* Querying.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Like peek, but also returns the key as stored in the map. *)
val peekKey : ('key,'a) map -> 'key -> ('key * 'a) option

val peek : ('key,'a) map -> 'key -> 'a option

val get : ('key,'a) map -> 'key -> 'a  (* raises Error *)

val pick : ('key,'a) map -> 'key * 'a  (* an arbitrary key/value pair *)

val nth : ('key,'a) map -> int -> 'key * 'a  (* in the range [0,size-1] *)

val random : ('key,'a) map -> 'key * 'a

(* ------------------------------------------------------------------------- *)
(* Adding.                                                                   *)
(* ------------------------------------------------------------------------- *)

val insert : ('key,'a) map -> 'key * 'a -> ('key,'a) map

val insertList : ('key,'a) map -> ('key * 'a) list -> ('key,'a) map

(* ------------------------------------------------------------------------- *)
(* Removing.                                                                 *)
(* ------------------------------------------------------------------------- *)

val delete : ('key,'a) map -> 'key -> ('key,'a) map  (* must be present *)

(* Presumably tolerates an absent key, unlike delete -- confirm in Map.sml. *)
val remove : ('key,'a) map -> 'key -> ('key,'a) map

val deletePick : ('key,'a) map -> ('key * 'a) * ('key,'a) map

val deleteNth : ('key,'a) map -> int -> ('key * 'a) * ('key,'a) map

val deleteRandom : ('key,'a) map -> ('key * 'a) * ('key,'a) map

(* ------------------------------------------------------------------------- *)
(* Joining (all join operations prefer keys in the second map).              *)
(* ------------------------------------------------------------------------- *)

(* Presumably: first handles keys only in the first map, second keys only *)
(* in the second, both keys in both; NONE omits the key from the result.  *)
val merge :
    {first : 'key * 'a -> 'c option,
     second : 'key * 'b -> 'c option,
     both : ('key * 'a) * ('key * 'b) -> 'c option} ->
    ('key,'a) map -> ('key,'b) map -> ('key,'c) map

val union :
    (('key * 'a) * ('key * 'a) -> 'a option) ->
    ('key,'a) map -> ('key,'a) map -> ('key,'a) map

val intersect :
    (('key * 'a) * ('key * 'b) -> 'c option) ->
    ('key,'a) map -> ('key,'b) map -> ('key,'c) map

(* ------------------------------------------------------------------------- *)
(* Set operations on the domain.                                             *)
(* ------------------------------------------------------------------------- *)

val inDomain : 'key -> ('key,'a) map -> bool

val unionDomain : ('key,'a) map -> ('key,'a) map -> ('key,'a) map

val unionListDomain : ('key,'a) map list -> ('key,'a) map

val intersectDomain : ('key,'a) map -> ('key,'a) map -> ('key,'a) map

val intersectListDomain : ('key,'a) map list -> ('key,'a) map

val differenceDomain : ('key,'a) map -> ('key,'a) map -> ('key,'a) map

val symmetricDifferenceDomain : ('key,'a) map -> ('key,'a) map -> ('key,'a) map

val equalDomain : ('key,'a) map -> ('key,'a) map -> bool

val subsetDomain : ('key,'a) map -> ('key,'a) map -> bool

val disjointDomain : ('key,'a) map -> ('key,'a) map -> bool

(* ------------------------------------------------------------------------- *)
(* Mapping and folding.                                                      *)
(* ------------------------------------------------------------------------- *)

val mapPartial : ('key * 'a -> 'b option) -> ('key,'a) map -> ('key,'b) map

val map : ('key * 'a -> 'b) -> ('key,'a) map -> ('key,'b) map

val app : ('key * 'a -> unit) -> ('key,'a) map -> unit

(* Like map, but the function sees only the value, not the key. *)
val transform : ('a -> 'b) -> ('key,'a) map -> ('key,'b) map

val filter : ('key * 'a -> bool) -> ('key,'a) map -> ('key,'a) map

val partition :
    ('key * 'a -> bool) -> ('key,'a) map -> ('key,'a) map * ('key,'a) map

val foldl : ('key * 'a * 's -> 's) -> 's -> ('key,'a) map -> 's

val foldr : ('key * 'a * 's -> 's) -> 's -> ('key,'a) map -> 's

(* ------------------------------------------------------------------------- *)
(* Searching.                                                                *)
(* ------------------------------------------------------------------------- *)

val findl : ('key * 'a -> bool) -> ('key,'a) map -> ('key * 'a) option

val findr : ('key * 'a -> bool) -> ('key,'a) map -> ('key * 'a) option

val firstl : ('key * 'a -> 'b option) -> ('key,'a) map -> 'b option

val firstr : ('key * 'a -> 'b option) -> ('key,'a) map -> 'b option

val exists : ('key * 'a -> bool) -> ('key,'a) map -> bool

val all : ('key * 'a -> bool) -> ('key,'a) map -> bool

val count : ('key * 'a -> bool) -> ('key,'a) map -> int

(* ------------------------------------------------------------------------- *)
(* Comparing.                                                                *)
(* ------------------------------------------------------------------------- *)

val compare : ('a * 'a -> order) -> ('key,'a) map * ('key,'a) map -> order

val equal : ('a -> 'a -> bool) -> ('key,'a) map -> ('key,'a) map -> bool

(* ------------------------------------------------------------------------- *)
(* Converting to and from lists.                                             *)
(* ------------------------------------------------------------------------- *)

val keys : ('key,'a) map -> 'key list

val values : ('key,'a) map -> 'a list

val toList : ('key,'a) map -> ('key * 'a) list

val fromList : ('key * 'key -> order) -> ('key * 'a) list -> ('key,'a) map

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

val toString : ('key,'a) map -> string

(* ------------------------------------------------------------------------- *)
(* Iterators over maps.                                                      *)
(* ------------------------------------------------------------------------- *)

(* A cursor over a map's entries; the option-returning constructors and   *)
(* advanceIterator presumably yield NONE for an empty map or past the end. *)
type ('key,'a) iterator

val mkIterator : ('key,'a) map -> ('key,'a) iterator option

val mkRevIterator : ('key,'a) map -> ('key,'a) iterator option

val readIterator : ('key,'a) iterator -> 'key * 'a

val advanceIterator : ('key,'a) iterator -> ('key,'a) iterator option

end

(**** Original file: Map.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* FINITE MAPS IMPLEMENTED WITH RANDOMLY BALANCED TREES                      *)
(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Map :> Map =
struct

(* ------------------------------------------------------------------------- *)
(* Importing useful functionality.                                           *)
(* ------------------------------------------------------------------------- *)

(* Exceptions shared with the rest of Metis. *)
exception Error = Useful.Error;

exception Bug = Useful.Bug;

(* Local aliases for helpers from Useful and Portable. *)
val K = Useful.K;

val pointerEqual = Portable.pointerEqual;

val randomWord = Portable.randomWord;

val randomInt = Portable.randomInt;

(* ------------------------------------------------------------------------- *)
(* Converting a comparison function to an equality function.                 *)
(* ------------------------------------------------------------------------- *)

(* Two keys are equal exactly when compareKey reports EQUAL for them. *)
fun equalKey compareKey key1 key2 =
    case compareKey (key1,key2) of
      EQUAL => true
    | _ => false;

(* ------------------------------------------------------------------------- *)
(* Priorities.                                                               *)
(* ------------------------------------------------------------------------- *)

(* Node priorities are machine words; the tree is heap-ordered on them. *)
type priority = Word.word;

(* Fresh priorities are drawn from the random word generator. *)
val randomPriority = randomWord;

(* Priorities are ordered with Word.compare (unsigned word order). *)
val comparePriority = Word.compare;

(* ------------------------------------------------------------------------- *)
(* Priority search trees.                                                    *)
(* ------------------------------------------------------------------------- *)

(* A priority search tree (treap): E is the empty tree and T wraps a
   non-empty node. *)
datatype ('key,'value) tree =
    E
  | T of ('key,'value) node

(* One tree node.  size counts the entries in the subtree rooted here
   (left + 1 + right, as computed by mkNode).  In a well-formed tree the
   node's priority is not lower than either child's, keys in left compare
   LESS than key, and keys in right compare GREATER (these are the
   properties the BasicDebug checks below verify). *)
and ('key,'value) node =
    Node of
      {size : int,
       priority : priority,
       left : ('key,'value) tree,
       key : 'key,
       value : 'value,
       right : ('key,'value) tree};

(* True exactly when the first node's priority is strictly below the
   second node's priority. *)
fun lowerPriorityNode (Node {priority = p1, ...}) (Node {priority = p2, ...}) =
    case comparePriority (p1,p2) of
      LESS => true
    | _ => false;

(* ------------------------------------------------------------------------- *)
(* Tree debugging functions.                                                 *)
(* ------------------------------------------------------------------------- *)

(*BasicDebug
local
  fun checkSizes tree =
      case tree of
        E => 0
      | T (Node {size,left,right,...}) =>
        let
          val l = checkSizes left
          and r = checkSizes right

          val () = if l + 1 + r = size then () else raise Bug "wrong size"
        in
          size
        end;

  fun checkSorted compareKey x tree =
      case tree of
        E => x
      | T (Node {left,key,right,...}) =>
        let
          val x = checkSorted compareKey x left

          val () =
              case x of
                NONE => ()
              | SOME k =>
                case compareKey (k,key) of
                  LESS => ()
                | EQUAL => raise Bug "duplicate keys"
                | GREATER => raise Bug "unsorted"

          val x = SOME key
        in
          checkSorted compareKey x right
        end;

  fun checkPriorities compareKey tree =
      case tree of
        E => NONE
      | T node =>
        let
          val Node {left,right,...} = node

          val () =
              case checkPriorities compareKey left of
                NONE => ()
              | SOME lnode =>
                if not (lowerPriorityNode node lnode) then ()
                else raise Bug "left child has greater priority"

          val () =
              case checkPriorities compareKey right of
                NONE => ()
              | SOME rnode =>
                if not (lowerPriorityNode node rnode) then ()
                else raise Bug "right child has greater priority"
        in
          SOME node
        end;
in
  fun treeCheckInvariants compareKey tree =
      let
        val _ = checkSizes tree

        val _ = checkSorted compareKey NONE tree

        val _ = checkPriorities compareKey tree
      in
        tree
      end
      handle Error err => raise Bug err;
end;
*)

(* ------------------------------------------------------------------------- *)
(* Tree operations.                                                          *)
(* ------------------------------------------------------------------------- *)

(* A fresh empty tree. *)
fun treeNew () : ('key,'value) tree = E;

(* The number of entries in the subtree rooted at a node (its size field). *)
fun nodeSize node =
    let
      val Node {size, ...} = node
    in
      size
    end;

(* The number of entries in a tree; the empty tree holds none. *)
fun treeSize E = 0
  | treeSize (T node) = nodeSize node;

(* Assembles a node from its parts; the size field is recomputed as the
   sizes of both children plus one for the node itself. *)
fun mkNode priority left key value right =
    Node
      {priority = priority,
       size = treeSize right + treeSize left + 1,
       key = key,
       value = value,
       left = left,
       right = right};

(* Same as mkNode, but the result is wrapped as a non-empty tree. *)
fun mkTree priority left key value right =
    T (mkNode priority left key value right);

(* ------------------------------------------------------------------------- *)
(* Extracting the left and right spines of a tree.                           *)
(* ------------------------------------------------------------------------- *)

(* Pushes every node on the leftmost branch onto acc, top-down, so the
   deepest (leftmost) node ends up at the head of the returned list. *)
fun treeLeftSpine acc E = acc
  | treeLeftSpine acc (T node) = nodeLeftSpine acc node

and nodeLeftSpine acc (node as Node {left, ...}) =
    treeLeftSpine (node :: acc) left;

(* Mirror of treeLeftSpine: pushes the nodes of the rightmost branch onto
   acc, leaving the deepest (rightmost) node at the head. *)
fun treeRightSpine acc E = acc
  | treeRightSpine acc (T node) = nodeRightSpine acc node

and nodeRightSpine acc (node as Node {right, ...}) =
    treeRightSpine (node :: acc) right;

(* ------------------------------------------------------------------------- *)
(* Singleton trees.                                                          *)
(* ------------------------------------------------------------------------- *)

(* A leaf node holding one entry with the given priority and no children. *)
fun mkNodeSingleton priority key value =
    Node
      {size = 1,
       priority = priority,
       left = E,
       key = key,
       value = value,
       right = E};

(* A leaf node for the given entry with a freshly drawn random priority. *)
fun nodeSingleton (key,value) =
    mkNodeSingleton (randomPriority ()) key value;

(* A one-entry tree with a random priority. *)
fun treeSingleton key_value = T (nodeSingleton key_value);

(* ------------------------------------------------------------------------- *)
(* Appending two trees, where every element of the first tree is less than   *)
(* every element of the second tree.                                         *)
(* ------------------------------------------------------------------------- *)

(* Joins two trees where every key of tree1 precedes every key of tree2.
   The result's root is whichever root has the higher priority (on a tie
   tree1's root wins); the other tree is appended recursively into the
   adjacent subtree of that root. *)
fun treeAppend E tree2 = tree2
  | treeAppend tree1 E = tree1
  | treeAppend (tree1 as T node1) (tree2 as T node2) =
    if not (lowerPriorityNode node1 node2) then
      let
        val Node {priority,left,key,value,right,...} = node1
      in
        mkTree priority left key value (treeAppend right tree2)
      end
    else
      let
        val Node {priority,left,key,value,right,...} = node2
      in
        mkTree priority (treeAppend tree1 left) key value right
      end;

(* ------------------------------------------------------------------------- *)
(* Appending two trees and a node, where every element of the first tree is  *)
(* less than the node, which in turn is less than every element of the       *)
(* second tree.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Joins left, node and right (already in key order) into a single tree by
   appending the node onto left and then appending right. *)
fun treeCombine left node right =
    treeAppend (treeAppend left (T node)) right;

(* ------------------------------------------------------------------------- *)
(* Searching a tree for a value.                                             *)
(* ------------------------------------------------------------------------- *)

(* Binary search for pkey; returns SOME of the stored value when a key
   compares EQUAL, otherwise NONE. *)
fun treePeek compareKey pkey E = NONE
  | treePeek compareKey pkey (T node) = nodePeek compareKey pkey node

and nodePeek compareKey pkey (Node {left,key,value,right,...}) =
    case compareKey (pkey,key) of
      EQUAL => SOME value
    | LESS => treePeek compareKey pkey left
    | GREATER => treePeek compareKey pkey right;

(* ------------------------------------------------------------------------- *)
(* Tree paths.                                                               *)
(* ------------------------------------------------------------------------- *)

(* Generating a path by searching a tree for a key/value pair *)

(* Searches for pkey while recording the route taken: each step pushed onto
   path is (true,node) for a descent to the left or (false,node) for a
   descent to the right.  Returns the path (deepest step first) together
   with SOME matching node, or NONE when the search falls off the tree. *)
fun treePeekPath compareKey pkey path tree =
    case tree of
      T node => nodePeekPath compareKey pkey path node
    | E => (path,NONE)

and nodePeekPath compareKey pkey path node =
    let
      val Node {left,key,right,...} = node

      fun descend wentLeft subtree =
          treePeekPath compareKey pkey ((wentLeft,node) :: path) subtree
    in
      case compareKey (pkey,key) of
        LESS => descend true left
      | EQUAL => (path, SOME node)
      | GREATER => descend false right
    end;

(* A path splits a tree into left/right components *)

(* Extends a (leftTree,rightTree) split by one path step.  After a left
   descent the node and its right subtree join the right side, with
   rightTree becoming the node's left child; after a right descent the node
   and its left subtree join the left side, with leftTree becoming the
   node's right child. *)
fun addSidePath ((wentLeft,node),(leftTree,rightTree)) =
    let
      val Node {priority,left,key,value,right,...} = node
    in
      case wentLeft of
        true => (leftTree, mkTree priority rightTree key value right)
      | false => (mkTree priority left key value leftTree, rightTree)
    end;

(* Applies every step of a path (deepest step first) to an initial split. *)
fun addSidesPath left_right path = List.foldl addSidePath left_right path;

(* Splits along a path starting from a pair of empty trees. *)
fun mkSidesPath path = addSidesPath (E,E) path;

(* Updating the subtree at a path *)

local
  (* Rebuilds one path step on top of a replacement subtree: the new
     subtree takes the place of the child the path descended into. *)
  fun updateTree ((true,node),tree) =
      let
        val Node {priority,key,value,right,...} = node
      in
        mkTree priority tree key value right
      end
    | updateTree ((false,node),tree) =
      let
        val Node {priority,left,key,value,...} = node
      in
        mkTree priority left key value tree
      end;
in
  (* Substitutes tree for the subtree at the end of the path (deepest step
     first) and rebuilds every ancestor back up to the root. *)
  fun updateTreePath tree path = List.foldl updateTree tree path;
end;

(* Inserting a new node at a path position *)

(* Inserts node (in this file treeInsert passes a fresh nodeSingleton) at
   the position reached by a failed treePeekPath search, whose path lists
   the steps taken, deepest first.  Walking back up the path, each ancestor
   whose priority is lower than node's is peeled off into the (left,right)
   pair that will surround node.  At the first ancestor that is not lower,
   node is combined with that pair and the path from that ancestor upward
   is rebuilt above it; if the whole path is consumed, the combined tree is
   the entire result.  The returned function still expects the path. *)
fun insertNodePath node =
    let
      fun insert left_right path =
          case path of
            [] =>
            (* Every ancestor had lower priority: the combination is the
               whole tree. *)
            let
              val (left,right) = left_right
            in
              treeCombine left node right
            end
          | (step as (_,snode)) :: rest =>
            if lowerPriorityNode snode node then
              (* snode must end up below node: split it into the sides. *)
              let
                val left_right = addSidePath (step,left_right)
              in
                insert left_right rest
              end
            else
              (* snode stays above node: hang node here and rebuild the
                 path (including this step) up to the root. *)
              let
                val (left,right) = left_right

                val tree = treeCombine left node right
              in
                updateTreePath tree path
              end
    in
      insert (E,E)
    end;

(* ------------------------------------------------------------------------- *)
(* Using a key to split a node into three components: the keys comparing     *)
(* less than the supplied key, an optional equal key, and the keys comparing *)
(* greater.                                                                  *)
(* ------------------------------------------------------------------------- *)

(* Splits the tree rooted at node around pkey into a triple: the entries
   whose keys compare below pkey, SOME (key,value) of the entry whose key
   compares EQUAL (or NONE), and the entries whose keys compare above. *)
fun nodePartition compareKey pkey node =
    case nodePeekPath compareKey pkey [] node of
      (path, SOME (Node {left,key,value,right,...})) =>
      let
        val (lower,upper) = addSidesPath (left,right) path
      in
        (lower, SOME (key,value), upper)
      end
    | (path, NONE) =>
      let
        val (lower,upper) = mkSidesPath path
      in
        (lower, NONE, upper)
      end;

(* ------------------------------------------------------------------------- *)
(* Searching a tree for a key/value pair.                                    *)
(* ------------------------------------------------------------------------- *)

(* Like treePeek, but returns the stored key alongside the value, so
   callers see the key actually held in the tree. *)
fun treePeekKey compareKey pkey tree =
    case tree of
      E => NONE
    | T node => nodePeekKey compareKey pkey node

and nodePeekKey compareKey pkey (Node {left,key,value,right,...}) =
    case compareKey (pkey,key) of
      EQUAL => SOME (key,value)
    | cmp =>
      treePeekKey compareKey pkey (if cmp = LESS then left else right);

(* ------------------------------------------------------------------------- *)
(* Inserting new key/values into the tree.                                   *)
(* ------------------------------------------------------------------------- *)

(* Inserts (key,value).  A new key becomes a random-priority singleton that
   insertNodePath places in the tree.  A key comparing EQUAL to a stored one
   replaces both the stored key and value while keeping that node's size,
   priority and children. *)
fun treeInsert compareKey (key,value) tree =
    case treePeekPath compareKey key [] tree of
      (path, NONE) => insertNodePath (nodeSingleton (key,value)) path
    | (path, SOME (Node {size,priority,left,right,...})) =>
      let
        val replaced =
            Node
              {size = size,
               priority = priority,
               left = left,
               key = key,
               value = value,
               right = right}
      in
        updateTreePath (T replaced) path
      end;

(* ------------------------------------------------------------------------- *)
(* Deleting key/value pairs: it raises an exception if the supplied key is   *)
(* not present.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Removes the entry whose key compares EQUAL to dkey, raising Bug if no
   such entry exists.  Ancestors on the search path are rebuilt with their
   size reduced by one; the removed node's children are appended together. *)
fun treeDelete compareKey dkey E = raise Bug "Map.delete: element not found"
  | treeDelete compareKey dkey (T node) = nodeDelete compareKey dkey node

and nodeDelete compareKey dkey node =
    let
      val Node {size,priority,left,key,value,right} = node

      (* This node again, around children that together lost one entry. *)
      fun rebuild newLeft newRight =
          T (Node
               {size = size - 1,
                priority = priority,
                left = newLeft,
                key = key,
                value = value,
                right = newRight})
    in
      case compareKey (dkey,key) of
        EQUAL => treeAppend left right
      | LESS => rebuild (treeDelete compareKey dkey left) right
      | GREATER => rebuild left (treeDelete compareKey dkey right)
    end;

(* ------------------------------------------------------------------------- *)
(* Partial map is the basic operation for preserving tree structure.         *)
(* It applies its argument function to the elements *in order*.              *)
(* ------------------------------------------------------------------------- *)

(* Applies f to every entry in ascending key order.  SOME value' keeps the
   entry (in its node's position and priority) with the new value; NONE
   drops it and joins the surrounding subtrees. *)
fun treeMapPartial f E = E
  | treeMapPartial f (T node) = nodeMapPartial f node

and nodeMapPartial f (Node {priority,left,key,value,right,...}) =
    let
      (* Separate sequential bindings keep the in-order calls to f. *)
      val left' = treeMapPartial f left
      val result = f (key,value)
      val right' = treeMapPartial f right
    in
      case result of
        SOME value' => mkTree priority left' key value' right'
      | NONE => treeAppend left' right'
    end;

(* ------------------------------------------------------------------------- *)
(* Mapping tree values.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Replaces every value with f (key,value) in ascending key order, keeping
   the tree's shape, sizes and priorities unchanged. *)
fun treeMap f E = E
  | treeMap f (T node) = T (nodeMap f node)

and nodeMap f (Node {size,priority,left,key,value,right}) =
    let
      val left' = treeMap f left
      val value' = f (key,value)
      val right' = treeMap f right
    in
      Node
        {size = size,
         priority = priority,
         left = left',
         key = key,
         value = value',
         right = right'}
    end;

(* ------------------------------------------------------------------------- *)
(* Merge is the basic operation for joining two trees. Note that the merged  *)
(* key is always the one from the second map.                                *)
(* ------------------------------------------------------------------------- *)

(* Merges two trees.  Entries only in tree1 pass through f1, entries only
   in tree2 through f2, and entries in both through fb (tree1's entry
   first); NONE results are dropped.  Each root of tree2 partitions tree1,
   both sides are merged recursively, and only then is the root entry
   resolved and recombined using tree2's key and priority. *)
fun treeMerge compareKey f1 f2 fb tree1 tree2 =
    case (tree1,tree2) of
      (E,_) => treeMapPartial f2 tree2
    | (T _, E) => treeMapPartial f1 tree1
    | (T node1, T node2) => nodeMerge compareKey f1 f2 fb node1 node2

and nodeMerge compareKey f1 f2 fb node1 node2 =
    let
      val Node {priority,left,key,value,right,...} = node2

      val (lower,match,upper) = nodePartition compareKey key node1

      (* The subtrees are merged (left first) before the root is resolved. *)
      val mergedLeft = treeMerge compareKey f1 f2 fb lower left
      val mergedRight = treeMerge compareKey f1 f2 fb upper right

      val resolved =
          case match of
            SOME kv => fb (kv,(key,value))
          | NONE => f2 (key,value)
    in
      case resolved of
        SOME value' =>
        treeCombine mergedLeft (mkNodeSingleton priority key value') mergedRight
      | NONE => treeAppend mergedLeft mergedRight
    end;

(* ------------------------------------------------------------------------- *)
(* A union operation on trees.                                               *)
(* ------------------------------------------------------------------------- *)

(* Union preferring tree2's keys.  If either tree is empty the other is
   returned unchanged.  Keys present in both are combined with f (tree1's
   entry first), and a NONE result drops the key.  A node shared by both
   trees (pointer-equal) is mapped with f2 instead of being partitioned. *)
fun treeUnion compareKey f f2 E tree2 = tree2
  | treeUnion compareKey f f2 tree1 E = tree1
  | treeUnion compareKey f f2 (T node1) (T node2) =
    nodeUnion compareKey f f2 node1 node2

and nodeUnion compareKey f f2 node1 node2 =
    if not (pointerEqual (node1,node2)) then
      let
        val Node {priority,left,key,value,right,...} = node2

        val (lower,match,upper) = nodePartition compareKey key node1

        val unionLeft = treeUnion compareKey f f2 lower left
        val unionRight = treeUnion compareKey f f2 upper right

        val resolved =
            case match of
              SOME kv => f (kv,(key,value))
            | NONE => SOME value
      in
        case resolved of
          NONE => treeAppend unionLeft unionRight
        | SOME value' =>
          treeCombine unionLeft (mkNodeSingleton priority key value') unionRight
      end
    else nodeMapPartial f2 node1;

(* ------------------------------------------------------------------------- *)
(* An intersect operation on trees.                                          *)
(* ------------------------------------------------------------------------- *)

(* Intersection: only keys present in both trees survive, combined with f
   (t1's entry first), and a NONE result drops the key.  Surviving entries
   reuse t2's key and priority. *)
fun treeIntersect compareKey f t1 t2 =
    case (t1,t2) of
      (T n1, T n2) => nodeIntersect compareKey f n1 n2
    | _ => E

and nodeIntersect compareKey f n1 n2 =
    let
      val Node {priority,left,key,value,right,...} = n2

      val (lower,match,upper) = nodePartition compareKey key n1

      val commonLeft = treeIntersect compareKey f lower left
      val commonRight = treeIntersect compareKey f upper right

      val resolved = Option.mapPartial (fn kv => f (kv,(key,value))) match
    in
      case resolved of
        SOME value' => mkTree priority commonLeft key value' commonRight
      | NONE => treeAppend commonLeft commonRight
    end;

(* ------------------------------------------------------------------------- *)
(* A union operation on trees which simply chooses the second value.         *)
(* ------------------------------------------------------------------------- *)

(* Union of key sets in which tree2's entry always wins on a shared key.
   Shared (pointer-equal) nodes short-circuit to tree2. *)
fun treeUnionDomain compareKey E tree2 = tree2
  | treeUnionDomain compareKey tree1 E = tree1
  | treeUnionDomain compareKey (T node1) (tree2 as T node2) =
    if pointerEqual (node1,node2) then tree2
    else nodeUnionDomain compareKey node1 node2

and nodeUnionDomain compareKey node1 node2 =
    let
      val Node {priority,left,key,value,right,...} = node2

      val (lower,_,upper) = nodePartition compareKey key node1
    in
      treeCombine
        (treeUnionDomain compareKey lower left)
        (mkNodeSingleton priority key value)
        (treeUnionDomain compareKey upper right)
    end;

(* ------------------------------------------------------------------------- *)
(* An intersect operation on trees which simply chooses the second value.    *)
(* ------------------------------------------------------------------------- *)

(* Keeps tree2's entries whose keys also occur in tree1.  Shared
   (pointer-equal) nodes short-circuit to tree2. *)
fun treeIntersectDomain compareKey tree1 tree2 =
    case (tree1,tree2) of
      (T node1, T node2) =>
      if pointerEqual (node1,node2) then tree2
      else nodeIntersectDomain compareKey node1 node2
    | _ => E

and nodeIntersectDomain compareKey node1 node2 =
    let
      val Node {priority,left,key,value,right,...} = node2

      val (lower,match,upper) = nodePartition compareKey key node1

      val commonLeft = treeIntersectDomain compareKey lower left
      val commonRight = treeIntersectDomain compareKey upper right
    in
      case match of
        NONE => treeAppend commonLeft commonRight
      | SOME _ => mkTree priority commonLeft key value commonRight
    end;

(* ------------------------------------------------------------------------- *)
(* A difference operation on trees.                                          *)
(* ------------------------------------------------------------------------- *)

(* Removes from t1 every key that also occurs in t2.  A node shared by both
   trees (pointer-equal) disappears entirely. *)
fun treeDifferenceDomain compareKey E _ = E
  | treeDifferenceDomain compareKey t1 E = t1
  | treeDifferenceDomain compareKey (T n1) (T n2) =
    nodeDifferenceDomain compareKey n1 n2

and nodeDifferenceDomain compareKey n1 n2 =
    if pointerEqual (n1,n2) then E
    else
      let
        val Node {priority,left,key,value,right,...} = n1

        val (lower,match,upper) = nodePartition compareKey key n2

        val keptLeft = treeDifferenceDomain compareKey left lower
        val keptRight = treeDifferenceDomain compareKey right upper
      in
        case match of
          SOME _ => treeAppend keptLeft keptRight
        | NONE => mkTree priority keptLeft key value keptRight
      end;

(* ------------------------------------------------------------------------- *)
(* A subset operation on trees.                                              *)
(* ------------------------------------------------------------------------- *)

(* Tests whether every key of tree1 occurs in tree2.  A shared node is
   immediately a subset; a subtree larger than its counterpart cannot be. *)
fun treeSubsetDomain compareKey E _ = true
  | treeSubsetDomain compareKey _ E = false
  | treeSubsetDomain compareKey (T node1) (T node2) =
    nodeSubsetDomain compareKey node1 node2

and nodeSubsetDomain compareKey node1 node2 =
    if pointerEqual (node1,node2) then true
    else
      let
        val Node {size,left,key,right,...} = node1
      in
        if size > nodeSize node2 then false
        else
          case nodePartition compareKey key node2 of
            (_, NONE, _) => false
          | (lower, SOME _, upper) =>
            treeSubsetDomain compareKey left lower andalso
            treeSubsetDomain compareKey right upper
      end;

(* ------------------------------------------------------------------------- *)
(* Picking an arbitrary key/value pair from a tree.                          *)
(* ------------------------------------------------------------------------- *)

(* The entry stored directly in a node. *)
fun nodePick (Node {key,value,...}) = (key,value);

(* The root entry of a non-empty tree; raises Bug on the empty tree. *)
fun treePick E = raise Bug "Map.treePick"
  | treePick (T node) = nodePick node;

(* ------------------------------------------------------------------------- *)
(* Removing an arbitrary key/value pair from a tree.                         *)
(* ------------------------------------------------------------------------- *)

(* Detaches a node's own entry, returning it with its joined children. *)
fun nodeDeletePick (Node {left,key,value,right,...}) =
    ((key,value), treeAppend left right);

(* Removes the root entry of a non-empty tree; raises Bug when empty. *)
fun treeDeletePick E = raise Bug "Map.treeDeletePick"
  | treeDeletePick (T node) = nodeDeletePick node;

(* ------------------------------------------------------------------------- *)
(* Finding the nth smallest key/value (counting from 0).                     *)
(* ------------------------------------------------------------------------- *)

(* The entry with the n-th smallest key (0-based), located by comparing n
   with subtree sizes; raises Bug when n is out of range. *)
fun treeNth n E = raise Bug "Map.treeNth"
  | treeNth n (T node) = nodeNth n node

and nodeNth n (Node {left,key,value,right,...}) =
    let
      val leftSize = treeSize left
    in
      if n < leftSize then treeNth n left
      else if n > leftSize then treeNth (n - leftSize - 1) right
      else (key,value)
    end;

(* ------------------------------------------------------------------------- *)
(* Removing the nth smallest key/value (counting from 0).                    *)
(* ------------------------------------------------------------------------- *)

(* Removes the entry with the n-th smallest key (0-based), returning it
   with the remaining tree; raises Bug when n is out of range. *)
fun treeDeleteNth n E = raise Bug "Map.treeDeleteNth"
  | treeDeleteNth n (T node) = nodeDeleteNth n node

and nodeDeleteNth n node =
    let
      val Node {size,priority,left,key,value,right} = node

      val leftSize = treeSize left

      (* This node again, around children that together lost one entry. *)
      fun shrunk newLeft newRight =
          T (Node
               {size = size - 1,
                priority = priority,
                left = newLeft,
                key = key,
                value = value,
                right = newRight})
    in
      if n < leftSize then
        let
          val (key_value,left') = treeDeleteNth n left
        in
          (key_value, shrunk left' right)
        end
      else if n > leftSize then
        let
          val (key_value,right') = treeDeleteNth (n - leftSize - 1) right
        in
          (key_value, shrunk left right')
        end
      else ((key,value), treeAppend left right)
    end;

(* ------------------------------------------------------------------------- *)
(* Iterators.                                                                *)
(* ------------------------------------------------------------------------- *)

(* An iterator positioned at one entry.  LR walks keys in ascending order
   and carries the current entry, the current node's right subtree (still
   to be visited) and the stack of pending ancestor nodes.  RL is the
   mirror image for descending order and carries the current node's left
   subtree instead. *)
datatype ('key,'value) iterator =
    LR of ('key * 'value) * ('key,'value) tree * ('key,'value) node list
  | RL of ('key * 'value) * ('key,'value) tree * ('key,'value) node list;

(* Positions an ascending iterator at the head of a spine, remembering that
   node's right subtree and the rest of the spine; NONE for an empty spine. *)
fun fromSpineLR [] = NONE
  | fromSpineLR (Node {key,value,right,...} :: rest) =
    SOME (LR ((key,value),right,rest));

(* Descending counterpart of fromSpineLR, remembering the left subtree. *)
fun fromSpineRL [] = NONE
  | fromSpineRL (Node {key,value,left,...} :: rest) =
    SOME (RL ((key,value),left,rest));

(* Pushes tree's left spine onto the pending nodes and starts an ascending
   iterator at its leftmost entry. *)
fun addLR nodes tree =
    let
      val spine = treeLeftSpine nodes tree
    in
      fromSpineLR spine
    end;

(* Pushes tree's right spine and starts a descending iterator. *)
fun addRL nodes tree =
    let
      val spine = treeRightSpine nodes tree
    in
      fromSpineRL spine
    end;

(* Ascending iterator over a whole tree (NONE if it is empty). *)
fun treeMkIterator tree = addLR [] tree;

(* Descending iterator over a whole tree (NONE if it is empty). *)
fun treeMkRevIterator tree = addRL [] tree;

(* The entry the iterator currently points at. *)
fun readIterator (LR (key_value,_,_)) = key_value
  | readIterator (RL (key_value,_,_)) = key_value;

(* Moves to the next entry in the iterator's direction, or NONE at the end. *)
fun advanceIterator (LR (_,tree,nodes)) = addLR nodes tree
  | advanceIterator (RL (_,tree,nodes)) = addRL nodes tree;

(* Folds f over the remaining entries of an optional iterator, in iteration
   order, threading the accumulator through as the third component. *)
fun foldIterator f acc NONE = acc
  | foldIterator f acc (SOME iter) =
    let
      val (key,value) = readIterator iter
      val acc' = f (key,value,acc)
    in
      foldIterator f acc' (advanceIterator iter)
    end;

(* The first remaining entry (in iteration order) satisfying pred. *)
fun findIterator pred NONE = NONE
  | findIterator pred (SOME iter) =
    let
      val key_value = readIterator iter
    in
      if not (pred key_value) then findIterator pred (advanceIterator iter)
      else SOME key_value
    end;

(* The first non-NONE result of f over the remaining entries, in iteration
   order; stops calling f as soon as one succeeds. *)
fun firstIterator f NONE = NONE
  | firstIterator f (SOME iter) =
    case f (readIterator iter) of
      NONE => firstIterator f (advanceIterator iter)
    | result => result;

(* Lexicographic comparison of two entry sequences: keys first, then
   values; an exhausted sequence is LESS than a non-exhausted one. *)
fun compareIterator _ _ NONE NONE = EQUAL
  | compareIterator _ _ NONE (SOME _) = LESS
  | compareIterator _ _ (SOME _) NONE = GREATER
  | compareIterator compareKey compareValue (SOME i1) (SOME i2) =
    let
      val (k1,v1) = readIterator i1
      val (k2,v2) = readIterator i2
    in
      case compareKey (k1,k2) of
        EQUAL =>
        (case compareValue (v1,v2) of
           EQUAL =>
           compareIterator compareKey compareValue
             (advanceIterator i1) (advanceIterator i2)
         | unequal => unequal)
      | unequal => unequal
    end;

(* True when both sequences have the same length and pairwise equal keys
   and values; stops at the first mismatch. *)
fun equalIterator equalKey equalValue NONE NONE = true
  | equalIterator equalKey equalValue (SOME i1) (SOME i2) =
    let
      val (k1,v1) = readIterator i1
      val (k2,v2) = readIterator i2
    in
      equalKey k1 k2 andalso
      equalValue v1 v2 andalso
      equalIterator equalKey equalValue
        (advanceIterator i1) (advanceIterator i2)
    end
  | equalIterator _ _ _ _ = false;

(* ------------------------------------------------------------------------- *)
(* A type of finite maps.                                                    *)
(* ------------------------------------------------------------------------- *)

(* A finite map: the key comparison function paired with the tree holding
   the entries, so every operation below can reach the comparison. *)
datatype ('key,'value) map =
    Map of ('key * 'key -> order) * ('key,'value) tree;

(* ------------------------------------------------------------------------- *)
(* Map debugging functions.                                                  *)
(* ------------------------------------------------------------------------- *)

(*BasicDebug
fun checkInvariants s m =
    let
      val Map (compareKey,tree) = m

      val _ = treeCheckInvariants compareKey tree
    in
      m
    end
    handle Bug bug => raise Bug (s ^ "\n" ^ "Map.checkInvariants: " ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* Constructors.                                                             *)
(* ------------------------------------------------------------------------- *)

(* An empty map ordered by compareKey. *)
fun new compareKey = Map (compareKey, treeNew ());

(* A one-entry map ordered by compareKey. *)
fun singleton compareKey key_value =
    Map (compareKey, treeSingleton key_value);

(* ------------------------------------------------------------------------- *)
(* Map size.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* The number of entries, read from the root's cached size. *)
fun size m =
    let
      val Map (_,tree) = m
    in
      treeSize tree
    end;

(* True when the map has no entries. *)
fun null m =
    case size m of
      0 => true
    | _ => false;

(* ------------------------------------------------------------------------- *)
(* Querying.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* SOME of the stored (key,value) for a key comparing EQUAL, else NONE. *)
fun peekKey (Map (compareKey,tree)) key =
    treePeekKey compareKey key tree;

(* SOME of the value stored under key, else NONE. *)
fun peek (Map (compareKey,tree)) key =
    treePeek compareKey key tree;

(* True when key has an entry in m. *)
fun inDomain key m =
    case peek m key of
      SOME _ => true
    | NONE => false;

(* The value stored under key; raises Error when the key is absent. *)
fun get m key =
    case peek m key of
      SOME value => value
    | NONE => raise Error "Map.get: element not found";

(* Some entry of the map (the one at the tree's root); Bug when empty. *)
fun pick (Map (_,tree)) = treePick tree;

(* The entry with the n-th smallest key, counting from 0. *)
fun nth m n =
    let
      val Map (_,tree) = m
    in
      treeNth n tree
    end;

(* A uniformly chosen entry (via randomInt over the map size); raises Bug
   on an empty map. *)
fun random m =
    case size m of
      0 => raise Bug "Map.random: empty"
    | n => nth m (randomInt n);

(* ------------------------------------------------------------------------- *)
(* Adding.                                                                   *)
(* ------------------------------------------------------------------------- *)

(* Adds an entry, replacing any existing entry whose key compares EQUAL. *)
fun insert (Map (compareKey,tree)) key_value =
    Map (compareKey, treeInsert compareKey key_value tree);

(*BasicDebug
val insert = fn m => fn kv =>
    checkInvariants "Map.insert: result"
      (insert (checkInvariants "Map.insert: input" m) kv);
*)

(* Inserts the entries left to right, so later duplicates win. *)
fun insertList m key_values =
    List.foldl (fn (key_value,acc) => insert acc key_value) m key_values;

(* ------------------------------------------------------------------------- *)
(* Removing.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Removes the entry for dkey; the key must be present (otherwise Bug). *)
fun delete (Map (compareKey,tree)) dkey =
    Map (compareKey, treeDelete compareKey dkey tree);

(*BasicDebug
val delete = fn m => fn k =>
    checkInvariants "Map.delete: result"
      (delete (checkInvariants "Map.delete: input" m) k);
*)

(* Removes the entry for key if there is one; otherwise returns m itself. *)
fun remove m key =
    case peek m key of
      NONE => m
    | SOME _ => delete m key;

(* Removes the root entry, returning it with the remaining map; raises Bug
   on an empty map. *)
fun deletePick (Map (compareKey,tree)) =
    let
      val (picked,remaining) = treeDeletePick tree
    in
      (picked, Map (compareKey,remaining))
    end;

(*BasicDebug
val deletePick = fn m =>
    let
      val (kv,m) = deletePick (checkInvariants "Map.deletePick: input" m)
    in
      (kv, checkInvariants "Map.deletePick: result" m)
    end;
*)

(* Removes the entry with the n-th smallest key (0-based), returning it
   with the remaining map. *)
fun deleteNth (Map (compareKey,tree)) n =
    let
      val (removed,remaining) = treeDeleteNth n tree
    in
      (removed, Map (compareKey,remaining))
    end;

(*BasicDebug
val deleteNth = fn m => fn n =>
    let
      val (kv,m) = deleteNth (checkInvariants "Map.deleteNth: input" m) n
    in
      (kv, checkInvariants "Map.deleteNth: result" m)
    end;
*)

(* Remove a uniformly chosen position; an empty map is a Bug. *)
fun deleteRandom m =
    case size m of
      0 => raise Bug "Map.deleteRandom: empty"
    | n => deleteNth m (randomInt n);

(* ------------------------------------------------------------------------- *)
(* Joining (all join operations prefer keys in the second map).              *)
(* ------------------------------------------------------------------------- *)

(* General merge: first/second handle keys found in only one map, both handles
   keys present in each.  The result uses the first map's comparison. *)
fun merge {first,second,both} (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeMerge cmp first second both t1 t2);

(*BasicDebug
val merge = fn f => fn m1 => fn m2 =>
    checkInvariants "Map.merge: result"
      (merge f
         (checkInvariants "Map.merge: input 1" m1)
         (checkInvariants "Map.merge: input 2" m2));
*)

(* Union via treeUnion.  Besides f itself, treeUnion is given a variant of f
   that receives the same key/value pair on both sides. *)
fun union f (Map (cmp,t1)) (Map (_,t2)) =
    let
      fun sameOnBothSides kv = f (kv,kv)
    in
      Map (cmp, treeUnion cmp f sameOnBothSides t1 t2)
    end;

(*BasicDebug
val union = fn f => fn m1 => fn m2 =>
    checkInvariants "Map.union: result"
      (union f
         (checkInvariants "Map.union: input 1" m1)
         (checkInvariants "Map.union: input 2" m2));
*)

(* Intersection via treeIntersect, keeping the first map's comparison. *)
fun intersect f (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeIntersect cmp f t1 t2);

(*BasicDebug
val intersect = fn f => fn m1 => fn m2 =>
    checkInvariants "Map.intersect: result"
      (intersect f
         (checkInvariants "Map.intersect: input 1" m1)
         (checkInvariants "Map.intersect: input 2" m2));
*)

(* ------------------------------------------------------------------------- *)
(* Iterators over maps.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Forward (ascending) iterator over the underlying tree. *)
fun mkIterator m =
    let val Map (_,t) = m in treeMkIterator t end;

(* Reverse iterator over the underlying tree. *)
fun mkRevIterator m =
    let val Map (_,t) = m in treeMkRevIterator t end;

(* ------------------------------------------------------------------------- *)
(* Mapping and folding.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Keep only the pairs for which f returns SOME, replacing their values. *)
fun mapPartial f (Map (cmp,t)) = Map (cmp, treeMapPartial f t);

(*BasicDebug
val mapPartial = fn f => fn m =>
    checkInvariants "Map.mapPartial: result"
      (mapPartial f (checkInvariants "Map.mapPartial: input" m));
*)

(* Replace every value by f (key,value); the key set is unchanged. *)
fun map f (Map (cmp,t)) = Map (cmp, treeMap f t);

(*BasicDebug
val map = fn f => fn m =>
    checkInvariants "Map.map: result"
      (map f (checkInvariants "Map.map: input" m));
*)

(* Like map, but f only sees the value. *)
fun transform f m = map (fn (_,v) => f v) m;

(* Keep exactly the pairs satisfying pred. *)
fun filter pred m =
    mapPartial (fn (kv as (_,v)) => if pred kv then SOME v else NONE) m;

(* Split m into (pairs satisfying p, pairs not satisfying p). *)
fun partition p m =
    (filter p m, filter (fn kv => not (p kv)) m);

(* Fold over the pairs using the forward iterator. *)
fun foldl f b m =
    let val iter = mkIterator m in foldIterator f b iter end;

(* Fold over the pairs using the reverse iterator. *)
fun foldr f b m =
    let val iter = mkRevIterator m in foldIterator f b iter end;

(* Apply f to every pair, in forward iterator order, for its effect. *)
fun app f m =
    let fun step (k,v,()) = f (k,v) in foldl step () m end;

(* ------------------------------------------------------------------------- *)
(* Searching.                                                                *)
(* ------------------------------------------------------------------------- *)

(* First pair satisfying p, scanning forwards. *)
fun findl p m =
    let val iter = mkIterator m in findIterator p iter end;

(* First pair satisfying p, scanning backwards. *)
fun findr p m =
    let val iter = mkRevIterator m in findIterator p iter end;

(* First SOME result of f, scanning forwards. *)
fun firstl f m =
    let val iter = mkIterator m in firstIterator f iter end;

(* First SOME result of f, scanning backwards. *)
fun firstr f m =
    let val iter = mkRevIterator m in firstIterator f iter end;

(* Does some pair satisfy p? *)
fun exists p m =
    case findl p m of
      NONE => false
    | SOME _ => true;

(* Every pair satisfies p iff no pair violates it. *)
fun all p m = not (exists (fn kv => not (p kv)) m);

(* Number of pairs satisfying pred. *)
fun count pred m =
    foldl (fn (k,v,n) => if pred (k,v) then n + 1 else n) 0 m;

(* ------------------------------------------------------------------------- *)
(* Comparing.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Order maps first by size, then lexicographically by their pairs in
   iterator order; identical maps short-circuit to EQUAL. *)
fun compare compareValue (m1,m2) =
    if pointerEqual (m1,m2) then EQUAL
    else
      let
        val sizeOrder = Int.compare (size m1, size m2)
      in
        if sizeOrder <> EQUAL then sizeOrder
        else
          let
            val Map (cmp,_) = m1
          in
            compareIterator cmp compareValue (mkIterator m1) (mkIterator m2)
          end
      end;

(* Maps are equal when they are the same object, or have the same size and
   pairwise equal keys and values in iterator order. *)
fun equal equalValue m1 m2 =
    if pointerEqual (m1,m2) then true
    else if size m1 <> size m2 then false
    else
      let
        val Map (cmp,_) = m1
      in
        equalIterator (equalKey cmp) equalValue (mkIterator m1) (mkIterator m2)
      end;

(* ------------------------------------------------------------------------- *)
(* Set operations on the domain.                                             *)
(* ------------------------------------------------------------------------- *)

(* Union of the key sets via treeUnionDomain. *)
fun unionDomain (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeUnionDomain cmp t1 t2);

(*BasicDebug
val unionDomain = fn m1 => fn m2 =>
    checkInvariants "Map.unionDomain: result"
      (unionDomain
         (checkInvariants "Map.unionDomain: input 1" m1)
         (checkInvariants "Map.unionDomain: input 2" m2));
*)

(* Fold unionDomain over a non-empty list of maps; [] is a Bug. *)
fun unionListDomain ms =
    case ms of
      [] => raise Bug "Map.unionListDomain: no sets"
    | m0 :: rest => List.foldl (fn (m,acc) => unionDomain acc m) m0 rest;

(* Intersection of the key sets via treeIntersectDomain. *)
fun intersectDomain (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeIntersectDomain cmp t1 t2);

(*BasicDebug
val intersectDomain = fn m1 => fn m2 =>
    checkInvariants "Map.intersectDomain: result"
      (intersectDomain
         (checkInvariants "Map.intersectDomain: input 1" m1)
         (checkInvariants "Map.intersectDomain: input 2" m2));
*)

(* Fold intersectDomain over a non-empty list of maps; [] is a Bug. *)
fun intersectListDomain ms =
    case ms of
      [] => raise Bug "Map.intersectListDomain: no sets"
    | m0 :: rest => List.foldl (fn (m,acc) => intersectDomain acc m) m0 rest;

(* Pairs of the first map whose keys are absent from the second. *)
fun differenceDomain (Map (cmp,t1)) (Map (_,t2)) =
    Map (cmp, treeDifferenceDomain cmp t1 t2);

(*BasicDebug
val differenceDomain = fn m1 => fn m2 =>
    checkInvariants "Map.differenceDomain: result"
      (differenceDomain
         (checkInvariants "Map.differenceDomain: input 1" m1)
         (checkInvariants "Map.differenceDomain: input 2" m2));
*)

(* Keys in exactly one of the two maps. *)
fun symmetricDifferenceDomain m1 m2 =
    let
      val onlyIn1 = differenceDomain m1 m2
      and onlyIn2 = differenceDomain m2 m1
    in
      unionDomain onlyIn1 onlyIn2
    end;

(* Same key sets: map equality with every value comparison succeeding. *)
fun equalDomain m1 m2 = equal (fn _ => fn _ => true) m1 m2;

(* Is every key of m1 also a key of m2? *)
fun subsetDomain m1 m2 =
    let
      val Map (cmp,t1) = m1
      and Map (_,t2) = m2
    in
      treeSubsetDomain cmp t1 t2
    end;

(* No key occurs in both maps. *)
fun disjointDomain m1 m2 =
    let val common = intersectDomain m1 m2 in null common end;

(* ------------------------------------------------------------------------- *)
(* Converting to and from lists.                                             *)
(* ------------------------------------------------------------------------- *)

(* The keys, in forward iterator order (built by a reverse fold). *)
fun keys m = foldr (fn (k,_,acc) => k :: acc) [] m;

(* The values, in forward iterator order. *)
fun values m = foldr (fn (_,v,acc) => v :: acc) [] m;

(* All pairs, in forward iterator order. *)
fun toList m = foldr (fn (k,v,acc) => (k,v) :: acc) [] m;

(* Build a map from pairs; later pairs are inserted after earlier ones. *)
fun fromList compareKey l = insertList (new compareKey) l;

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Render only the size: "<>" for an empty map, "<n>" otherwise. *)
fun toString m =
    let
      val sizeText = if null m then "" else Int.toString (size m)
    in
      "<" ^ sizeText ^ ">"
    end;

end
end;

(**** Original file: KeyMap.sig ****)

(* ========================================================================= *)
(* FINITE MAPS WITH A FIXED KEY TYPE                                         *)
(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

(* Finite maps over a single fixed key type.  The KeyMap functor below
   implements this signature for any Key : Ordered, with key = Key.t. *)
signature KeyMap =
sig

(* ------------------------------------------------------------------------- *)
(* A type of map keys.                                                       *)
(* ------------------------------------------------------------------------- *)

type key

(* ------------------------------------------------------------------------- *)
(* A type of finite maps.                                                    *)
(* ------------------------------------------------------------------------- *)

type 'a map

(* ------------------------------------------------------------------------- *)
(* Constructors.                                                             *)
(* ------------------------------------------------------------------------- *)

val new : unit -> 'a map

val singleton : key * 'a -> 'a map

(* ------------------------------------------------------------------------- *)
(* Map size.                                                                 *)
(* ------------------------------------------------------------------------- *)

val null : 'a map -> bool

val size : 'a map -> int

(* ------------------------------------------------------------------------- *)
(* Querying.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Returns the key actually stored in the map together with its value. *)
val peekKey : 'a map -> key -> (key * 'a) option

val peek : 'a map -> key -> 'a option

val get : 'a map -> key -> 'a  (* raises Error *)

val pick : 'a map -> key * 'a  (* an arbitrary key/value pair *)

val nth : 'a map -> int -> key * 'a  (* in the range [0,size-1] *)

(* NOTE(review): presumably fails on an empty map, like Map.random; confirm. *)
val random : 'a map -> key * 'a

(* ------------------------------------------------------------------------- *)
(* Adding.                                                                   *)
(* ------------------------------------------------------------------------- *)

val insert : 'a map -> key * 'a -> 'a map

val insertList : 'a map -> (key * 'a) list -> 'a map

(* ------------------------------------------------------------------------- *)
(* Removing.                                                                 *)
(* ------------------------------------------------------------------------- *)

val delete : 'a map -> key -> 'a map  (* must be present *)

(* NOTE(review): presumably a no-op for absent keys, as in Map.remove. *)
val remove : 'a map -> key -> 'a map

val deletePick : 'a map -> (key * 'a) * 'a map

val deleteNth : 'a map -> int -> (key * 'a) * 'a map

val deleteRandom : 'a map -> (key * 'a) * 'a map

(* ------------------------------------------------------------------------- *)
(* Joining (all join operations prefer keys in the second map).              *)
(* ------------------------------------------------------------------------- *)

(* first/second: keys in only one map; both: keys in both maps.  Returning
   NONE drops the key from the result. *)
val merge :
    {first : key * 'a -> 'c option,
     second : key * 'b -> 'c option,
     both : (key * 'a) * (key * 'b) -> 'c option} ->
    'a map -> 'b map -> 'c map

val union :
    ((key * 'a) * (key * 'a) -> 'a option) ->
    'a map -> 'a map -> 'a map

val intersect :
    ((key * 'a) * (key * 'b) -> 'c option) ->
    'a map -> 'b map -> 'c map

(* ------------------------------------------------------------------------- *)
(* Set operations on the domain.                                             *)
(* ------------------------------------------------------------------------- *)

val inDomain : key -> 'a map -> bool

val unionDomain : 'a map -> 'a map -> 'a map

val unionListDomain : 'a map list -> 'a map

val intersectDomain : 'a map -> 'a map -> 'a map

val intersectListDomain : 'a map list -> 'a map

val differenceDomain : 'a map -> 'a map -> 'a map

val symmetricDifferenceDomain : 'a map -> 'a map -> 'a map

val equalDomain : 'a map -> 'a map -> bool

val subsetDomain : 'a map -> 'a map -> bool

val disjointDomain : 'a map -> 'a map -> bool

(* ------------------------------------------------------------------------- *)
(* Mapping and folding.                                                      *)
(* ------------------------------------------------------------------------- *)

val mapPartial : (key * 'a -> 'b option) -> 'a map -> 'b map

val map : (key * 'a -> 'b) -> 'a map -> 'b map

val app : (key * 'a -> unit) -> 'a map -> unit

val transform : ('a -> 'b) -> 'a map -> 'b map

val filter : (key * 'a -> bool) -> 'a map -> 'a map

val partition :
    (key * 'a -> bool) -> 'a map -> 'a map * 'a map

val foldl : (key * 'a * 's -> 's) -> 's -> 'a map -> 's

val foldr : (key * 'a * 's -> 's) -> 's -> 'a map -> 's

(* ------------------------------------------------------------------------- *)
(* Searching.                                                                *)
(* ------------------------------------------------------------------------- *)

val findl : (key * 'a -> bool) -> 'a map -> (key * 'a) option

val findr : (key * 'a -> bool) -> 'a map -> (key * 'a) option

val firstl : (key * 'a -> 'b option) -> 'a map -> 'b option

val firstr : (key * 'a -> 'b option) -> 'a map -> 'b option

val exists : (key * 'a -> bool) -> 'a map -> bool

val all : (key * 'a -> bool) -> 'a map -> bool

val count : (key * 'a -> bool) -> 'a map -> int

(* ------------------------------------------------------------------------- *)
(* Comparing.                                                                *)
(* ------------------------------------------------------------------------- *)

val compare : ('a * 'a -> order) -> 'a map * 'a map -> order

val equal : ('a -> 'a -> bool) -> 'a map -> 'a map -> bool

(* ------------------------------------------------------------------------- *)
(* Converting to and from lists.                                             *)
(* ------------------------------------------------------------------------- *)

val keys : 'a map -> key list

val values : 'a map -> 'a list

val toList : 'a map -> (key * 'a) list

val fromList : (key * 'a) list -> 'a map

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

val toString : 'a map -> string

(* ------------------------------------------------------------------------- *)
(* Iterators over maps.                                                      *)
(* ------------------------------------------------------------------------- *)

type 'a iterator

(* NOTE(review): NONE presumably means there is nothing to iterate over
   (empty map / exhausted iterator); confirm against the implementation. *)
val mkIterator : 'a map -> 'a iterator option

val mkRevIterator : 'a map -> 'a iterator option

val readIterator : 'a iterator -> key * 'a

val advanceIterator : 'a iterator -> 'a iterator option

end

(**** Original file: KeyMap.sml ****)

(* ========================================================================= *)
(* FINITE MAPS WITH A FIXED KEY TYPE                                         *)
(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

functor KeyMap (Key : Ordered) :> KeyMap where type key = Key.t =
struct

(* ------------------------------------------------------------------------- *)
(* Importing from the input signature.                                       *)
(* ------------------------------------------------------------------------- *)

(* Bring the Metis substructures (Useful, Portable, ...) into scope. *)
open Metis; (* MODIFIED by Jasmin Blanchette *)

type key = Key.t;

(* The key ordering used by all of the tree operations below. *)
val compareKey = Key.compare;

(* ------------------------------------------------------------------------- *)
(* Importing useful functionality.                                           *)
(* ------------------------------------------------------------------------- *)

exception Bug = Useful.Bug;

exception Error = Useful.Error;

(* Physical-identity test; used as a fast path in nodeUnion. *)
val pointerEqual = Portable.pointerEqual;

val K = Useful.K;

val randomInt = Portable.randomInt;

(* Source of the random node priorities (see randomPriority). *)
val randomWord = Portable.randomWord;

(* ------------------------------------------------------------------------- *)
(* Converting a comparison function to an equality function.                 *)
(* ------------------------------------------------------------------------- *)

(* Two keys are equal exactly when compareKey says EQUAL. *)
fun equalKey key1 key2 =
    case compareKey (key1,key2) of
      EQUAL => true
    | _ => false;

(* ------------------------------------------------------------------------- *)
(* Priorities.                                                               *)
(* ------------------------------------------------------------------------- *)

type priority = Word.word;

(* Fresh nodes draw their priority from the random word generator. *)
fun randomPriority () = randomWord ();

(* Priorities are compared as words. *)
fun comparePriority (p1,p2) = Word.compare (p1,p2);

(* ------------------------------------------------------------------------- *)
(* Priority search trees.                                                    *)
(* ------------------------------------------------------------------------- *)

(* E is the empty tree; T wraps a non-empty node.  The debug-only checks
   (BasicDebug treeCheckInvariants) state the invariants of these trees:
     - size = (size of left) + 1 + (size of right);
     - keys strictly increase in an in-order traversal (no duplicates);
     - no child has a strictly greater priority than its parent. *)
datatype 'value tree =
    E
  | T of 'value node

and 'value node =
    Node of
      {size : int,            (* number of nodes in this subtree *)
       priority : priority,   (* heap priority; fresh nodes get randomPriority *)
       left : 'value tree,    (* keys comparing less than key *)
       key : key,
       value : 'value,
       right : 'value tree};  (* keys comparing greater than key *)

(* Is the first node's priority strictly below the second's? *)
fun lowerPriorityNode (Node {priority = p1, ...}) (Node {priority = p2, ...}) =
    comparePriority (p1,p2) = LESS;

(* ------------------------------------------------------------------------- *)
(* Tree debugging functions.                                                 *)
(* ------------------------------------------------------------------------- *)

(*BasicDebug
local
  fun checkSizes tree =
      case tree of
        E => 0
      | T (Node {size,left,right,...}) =>
        let
          val l = checkSizes left
          and r = checkSizes right

          val () = if l + 1 + r = size then () else raise Bug "wrong size"
        in
          size
        end;

  fun checkSorted x tree =
      case tree of
        E => x
      | T (Node {left,key,right,...}) =>
        let
          val x = checkSorted x left

          val () =
              case x of
                NONE => ()
              | SOME k =>
                case compareKey (k,key) of
                  LESS => ()
                | EQUAL => raise Bug "duplicate keys"
                | GREATER => raise Bug "unsorted"

          val x = SOME key
        in
          checkSorted x right
        end;

  fun checkPriorities tree =
      case tree of
        E => NONE
      | T node =>
        let
          val Node {left,right,...} = node

          val () =
              case checkPriorities left of
                NONE => ()
              | SOME lnode =>
                if not (lowerPriorityNode node lnode) then ()
                else raise Bug "left child has greater priority"

          val () =
              case checkPriorities right of
                NONE => ()
              | SOME rnode =>
                if not (lowerPriorityNode node rnode) then ()
                else raise Bug "right child has greater priority"
        in
          SOME node
        end;
in
  fun treeCheckInvariants tree =
      let
        val _ = checkSizes tree

        val _ = checkSorted NONE tree

        val _ = checkPriorities tree
      in
        tree
      end
      handle Error err => raise Bug err;
end;
*)

(* ------------------------------------------------------------------------- *)
(* Tree operations.                                                          *)
(* ------------------------------------------------------------------------- *)

(* The empty tree. *)
fun treeNew () = E;

(* Cached subtree size stored in a node. *)
fun nodeSize (Node {size = n, ...}) = n;

(* Number of nodes in a tree (0 for E). *)
fun treeSize E = 0
  | treeSize (T node) = nodeSize node;

(* Build a node from its parts, recomputing the cached size. *)
fun mkNode priority left key value right =
    Node
      {size = treeSize left + 1 + treeSize right,
       priority = priority,
       left = left,
       key = key,
       value = value,
       right = right};

(* Like mkNode, but wrapped as a non-empty tree. *)
fun mkTree priority left key value right =
    T (mkNode priority left key value right);

(* ------------------------------------------------------------------------- *)
(* Extracting the left and right spines of a tree.                           *)
(* ------------------------------------------------------------------------- *)

(* Push the nodes along the leftmost path onto acc, so the deepest node ends
   up at the head of the result. *)
fun treeLeftSpine acc E = acc
  | treeLeftSpine acc (T node) = nodeLeftSpine acc node

and nodeLeftSpine acc (node as Node {left, ...}) =
    treeLeftSpine (node :: acc) left;

(* Push the nodes along the rightmost path onto acc, so the deepest node ends
   up at the head of the result. *)
fun treeRightSpine acc E = acc
  | treeRightSpine acc (T node) = nodeRightSpine acc node

and nodeRightSpine acc (node as Node {right, ...}) =
    treeRightSpine (node :: acc) right;

(* ------------------------------------------------------------------------- *)
(* Singleton trees.                                                          *)
(* ------------------------------------------------------------------------- *)

(* A leaf node with the given priority (both children empty, size 1). *)
fun mkNodeSingleton priority key value =
    Node
      {size = 1,
       priority = priority,
       left = E,
       key = key,
       value = value,
       right = E};

(* A leaf node with a freshly drawn random priority. *)
fun nodeSingleton (key,value) =
    mkNodeSingleton (randomPriority ()) key value;

(* A one-element tree with a random priority. *)
fun treeSingleton key_value = T (nodeSingleton key_value);

(* ------------------------------------------------------------------------- *)
(* Appending two trees, where every element of the first tree is less than   *)
(* every element of the second tree.                                         *)
(* ------------------------------------------------------------------------- *)

(* Join two trees where all keys of the first precede all keys of the second.
   The root with the higher priority stays on top and the other tree is
   appended into its inner subtree. *)
fun treeAppend E tree2 = tree2
  | treeAppend tree1 E = tree1
  | treeAppend (tree1 as T node1) (tree2 as T node2) =
    if lowerPriorityNode node1 node2 then
      let
        val Node {priority,left,key,value,right,...} = node2
      in
        mkTree priority (treeAppend tree1 left) key value right
      end
    else
      let
        val Node {priority,left,key,value,right,...} = node1
      in
        mkTree priority left key value (treeAppend right tree2)
      end;

(* ------------------------------------------------------------------------- *)
(* Appending two trees and a node, where every element of the first tree is  *)
(* less than the node, which in turn is less than every element of the       *)
(* second tree.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Join left, then the single node, then right (keys in that order). *)
fun treeCombine left node right =
    treeAppend (treeAppend left (T node)) right;

(* ------------------------------------------------------------------------- *)
(* Searching a tree for a value.                                             *)
(* ------------------------------------------------------------------------- *)

(* Binary search for pkey, returning its value if present. *)
fun treePeek _ E = NONE
  | treePeek pkey (T node) = nodePeek pkey node

and nodePeek pkey (Node {left,key,value,right,...}) =
    case compareKey (pkey,key) of
      EQUAL => SOME value
    | LESS => treePeek pkey left
    | GREATER => treePeek pkey right;

(* ------------------------------------------------------------------------- *)
(* Tree paths.                                                               *)
(* ------------------------------------------------------------------------- *)

(* Generating a path by searching a tree for a key/value pair *)

(* Search for pkey, recording each step taken as (wentLeft,node) consed onto
   path, so the head of the returned path is the deepest step.  Also returns
   the matching node, if any. *)
fun treePeekPath _ path E = (path,NONE)
  | treePeekPath pkey path (T node) = nodePeekPath pkey path node

and nodePeekPath pkey path (node as Node {left,key,right,...}) =
    case compareKey (pkey,key) of
      EQUAL => (path, SOME node)
    | LESS => treePeekPath pkey ((true,node) :: path) left
    | GREATER => treePeekPath pkey ((false,node) :: path) right;

(* A path splits a tree into left/right components *)

(* Absorb one path step into the accumulated (left,right) split: when the
   search went left, the step node and its right subtree join the right side;
   otherwise the step node and its left subtree join the left side. *)
fun addSidePath ((wentLeft, Node {priority,left,key,value,right,...}), (lt,rt)) =
    if wentLeft then (lt, mkTree priority rt key value right)
    else (mkTree priority left key value lt, rt);

(* Absorb a whole path (deepest step first) into a (left,right) split. *)
fun addSidesPath left_right path = List.foldl addSidePath left_right path;

(* Split along a path, starting from two empty sides. *)
fun mkSidesPath path = List.foldl addSidePath (E,E) path;

(* Updating the subtree at a path *)

(* Rebuild the ancestors along path (deepest first) around a replacement
   subtree, reattaching it on the side the search originally took. *)
fun updateTreePath tree path =
    let
      fun rebuild ((true, Node {priority,key,value,right,...}), sub) =
          mkTree priority sub key value right
        | rebuild ((false, Node {priority,left,key,value,...}), sub) =
          mkTree priority left key value sub
    in
      List.foldl rebuild tree path
    end;

(* Inserting a new node at a path position *)

(* Insert a fresh node at the position described by a search path (as built
   by treePeekPath: a list of (wentLeft,ancestor) steps, deepest first).
   Ancestors with lower priority than node are split into the accumulated
   (left,right) sides; the first ancestor that does not have lower priority
   becomes node's parent.  Partially applied, this returns a function from
   path to the resulting tree. *)
fun insertNodePath node =
    let
      fun insert left_right path =
          case path of
            [] =>
            (* Every ancestor was split off: node becomes the new root. *)
            let
              val (left,right) = left_right
            in
              treeCombine left node right
            end
          | (step as (_,snode)) :: rest =>
            if lowerPriorityNode snode node then
              (* node must sit above snode, so fold snode into the sides *)
              let
                val left_right = addSidePath (step,left_right)
              in
                insert left_right rest
              end
            else
              (* snode stays above node: graft the combined subtree in under
                 the remaining path, which still starts with this step *)
              let
                val (left,right) = left_right

                val tree = treeCombine left node right
              in
                updateTreePath tree path
              end
    in
      insert (E,E)
    end;

(* ------------------------------------------------------------------------- *)
(* Using a key to split a node into three components: the keys comparing     *)
(* less than the supplied key, an optional equal key, and the keys comparing *)
(* greater.                                                                  *)
(* ------------------------------------------------------------------------- *)

(* Split a node's tree by pkey into (keys below, matching pair, keys above). *)
fun nodePartition pkey node =
    case nodePeekPath pkey [] node of
      (path,NONE) =>
      let
        val (lt,rt) = mkSidesPath path
      in
        (lt,NONE,rt)
      end
    | (path, SOME (Node {left,key,value,right,...})) =>
      let
        val (lt,rt) = addSidesPath (left,right) path
      in
        (lt, SOME (key,value), rt)
      end;

(* ------------------------------------------------------------------------- *)
(* Searching a tree for a key/value pair.                                    *)
(* ------------------------------------------------------------------------- *)

(* Binary search for pkey, returning the stored key together with its value. *)
fun treePeekKey _ E = NONE
  | treePeekKey pkey (T node) = nodePeekKey pkey node

and nodePeekKey pkey (Node {left,key,value,right,...}) =
    case compareKey (pkey,key) of
      EQUAL => SOME (key,value)
    | LESS => treePeekKey pkey left
    | GREATER => treePeekKey pkey right;

(* ------------------------------------------------------------------------- *)
(* Inserting new key/values into the tree.                                   *)
(* ------------------------------------------------------------------------- *)

(* Insert a key/value pair.  A new key gets a fresh random-priority node
   spliced in along its search path; an existing key has its node rebuilt in
   place with the supplied key and value (size, priority, children kept). *)
fun treeInsert (key,value) tree =
    case treePeekPath key [] tree of
      (path,NONE) => insertNodePath (nodeSingleton (key,value)) path
    | (path, SOME (Node {size,priority,left,right,...})) =>
      let
        val replaced =
            Node
              {size = size,
               priority = priority,
               left = left,
               key = key,
               value = value,
               right = right}
      in
        updateTreePath (T replaced) path
      end;

(* ------------------------------------------------------------------------- *)
(* Deleting key/value pairs: it raises an exception if the supplied key is   *)
(* not present.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Remove dkey; reaching an empty subtree means the key was absent, which is
   treated as a Bug.  The matching node is replaced by the join of its two
   children, and each ancestor on the way down loses one from its size. *)
fun treeDelete _ E = raise Bug "KeyMap.delete: element not found"
  | treeDelete dkey (T node) = nodeDelete dkey node

and nodeDelete dkey node =
    let
      val Node {size,priority,left,key,value,right} = node

      (* Same key/value/priority, new children, one fewer element below. *)
      fun shrunk (l,r) =
          T (Node
               {size = size - 1,
                priority = priority,
                left = l,
                key = key,
                value = value,
                right = r})
    in
      case compareKey (dkey,key) of
        EQUAL => treeAppend left right
      | LESS => shrunk (treeDelete dkey left, right)
      | GREATER => shrunk (left, treeDelete dkey right)
    end;

(* ------------------------------------------------------------------------- *)
(* Partial map is the basic operation for preserving tree structure.         *)
(* It applies its argument function to the elements *in order*.              *)
(* ------------------------------------------------------------------------- *)

(* Apply f to every pair, keeping those mapped to SOME.  The three bindings
   below run in sequence so f is applied to the elements in key order. *)
fun treeMapPartial _ E = E
  | treeMapPartial f (T node) = nodeMapPartial f node

and nodeMapPartial f (Node {priority,left,key,value,right,...}) =
    let
      val left' = treeMapPartial f left
      val result = f (key,value)
      val right' = treeMapPartial f right
    in
      case result of
        SOME value' => mkTree priority left' key value' right'
      | NONE => treeAppend left' right'
    end;

(* ------------------------------------------------------------------------- *)
(* Mapping tree values.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Replace every value by f (key,value), keeping the shape, sizes and
   priorities; f is applied in key order (left subtree, node, right). *)
fun treeMap _ E = E
  | treeMap f (T node) = T (nodeMap f node)

and nodeMap f (Node {size,priority,left,key,value,right}) =
    let
      val left' = treeMap f left
      val value' = f (key,value)
      val right' = treeMap f right
    in
      Node
        {size = size,
         priority = priority,
         left = left',
         key = key,
         value = value',
         right = right'}
    end;

(* ------------------------------------------------------------------------- *)
(* Merge is the basic operation for joining two trees. Note that the merged  *)
(* key is always the one from the second map.                                *)
(* ------------------------------------------------------------------------- *)

(* Merge two trees: f1 handles keys only in tree1, f2 keys only in tree2, and
   fb keys in both.  Each root of tree2 partitions tree1 by its key; the
   surviving root is rebuilt with tree2's key and priority. *)
fun treeMerge _ f2 _ E tree2 = treeMapPartial f2 tree2
  | treeMerge f1 _ _ tree1 E = treeMapPartial f1 tree1
  | treeMerge f1 f2 fb (T node1) (T node2) = nodeMerge f1 f2 fb node1 node2

and nodeMerge f1 f2 fb node1 (Node {priority,left,key,value,right,...}) =
    let
      val (l,kvo,r) = nodePartition key node1
      val left' = treeMerge f1 f2 fb l left
      val right' = treeMerge f1 f2 fb r right
      val vo =
          case kvo of
            SOME kv => fb (kv,(key,value))
          | NONE => f2 (key,value)
    in
      case vo of
        SOME v => treeCombine left' (mkNodeSingleton priority key v) right'
      | NONE => treeAppend left' right'
    end;

(* ------------------------------------------------------------------------- *)
(* A union operation on trees.                                               *)
(* ------------------------------------------------------------------------- *)

(* Union of two trees.  Keys found in only one tree keep their entry.     *)
(* For keys in both, f receives (pair from tree1, pair from tree2) and    *)
(* may return NONE to drop the key.  Subtrees that are physically shared  *)
(* (pointerEqual) are not split; f2 is mapped over them instead.          *)

fun treeUnion f f2 tree1 tree2 =
    case (tree1,tree2) of
      (E,_) => tree2
    | (_,E) => tree1
    | (T node1, T node2) => nodeUnion f f2 node1 node2

and nodeUnion f f2 node1 node2 =
    if pointerEqual (node1,node2) then nodeMapPartial f2 node1
    else
      let
        val Node {priority,left,key,value,right,...} = node2

        val (lower,match,upper) = nodePartition key node1

        val unionLeft = treeUnion f f2 lower left
        val unionRight = treeUnion f f2 upper right

        val rootValue =
            case match of
              SOME kv => f (kv,(key,value))
            | NONE => SOME value
      in
        case rootValue of
          SOME v =>
          let
            val root = mkNodeSingleton priority key v
          in
            treeCombine unionLeft root unionRight
          end
        | NONE => treeAppend unionLeft unionRight
      end;

(* ------------------------------------------------------------------------- *)
(* An intersect operation on trees.                                          *)
(* ------------------------------------------------------------------------- *)

(* Intersection of two trees.  Only keys present in both can survive; f  *)
(* receives (pair from t1, pair from t2) and returns the value to keep,  *)
(* or NONE to drop the key.                                              *)

fun treeIntersect f t1 t2 =
    case (t1,t2) of
      (T n1, T n2) => nodeIntersect f n1 n2
    | _ => E

and nodeIntersect f n1 (Node {priority,left,key,value,right,...}) =
    let
      val (lower,match,upper) = nodePartition key n1

      val leftPart = treeIntersect f lower left
      val rightPart = treeIntersect f upper right

      val kept =
          case match of
            SOME kv => f (kv,(key,value))
          | NONE => NONE
    in
      case kept of
        SOME v => mkTree priority leftPart key v rightPart
      | NONE => treeAppend leftPart rightPart
    end;

(* ------------------------------------------------------------------------- *)
(* A union operation on trees which simply chooses the second value.         *)
(* ------------------------------------------------------------------------- *)

(* Union of the key sets of two trees.  When a key occurs in both trees, *)
(* the entry from tree2 (key and value) is the one kept.                 *)

fun treeUnionDomain tree1 tree2 =
    case (tree1,tree2) of
      (E,_) => tree2
    | (_,E) => tree1
    | (T node1, T node2) =>
      if pointerEqual (node1,node2) then tree2
      else nodeUnionDomain node1 node2

and nodeUnionDomain node1 (Node {priority,left,key,value,right,...}) =
    let
      val (lower,_,upper) = nodePartition key node1

      val leftUnion = treeUnionDomain lower left
      val rightUnion = treeUnionDomain upper right

      val root = mkNodeSingleton priority key value
    in
      treeCombine leftUnion root rightUnion
    end;

(* ------------------------------------------------------------------------- *)
(* An intersect operation on trees which simply chooses the second value.    *)
(* ------------------------------------------------------------------------- *)

(* Intersection of the key sets of two trees, keeping tree2's entries. *)

fun treeIntersectDomain tree1 tree2 =
    case (tree1,tree2) of
      (T node1, T node2) =>
      if pointerEqual (node1,node2) then tree2
      else nodeIntersectDomain node1 node2
    | _ => E

and nodeIntersectDomain node1 (Node {priority,left,key,value,right,...}) =
    let
      val (lower,match,upper) = nodePartition key node1

      val leftPart = treeIntersectDomain lower left
      val rightPart = treeIntersectDomain upper right
    in
      case match of
        NONE => treeAppend leftPart rightPart
      | SOME _ => mkTree priority leftPart key value rightPart
    end;

(* ------------------------------------------------------------------------- *)
(* A difference operation on trees.                                          *)
(* ------------------------------------------------------------------------- *)

(* Entries of t1 whose keys do not occur in t2. *)

fun treeDifferenceDomain t1 t2 =
    case (t1,t2) of
      (E,_) => E
    | (_,E) => t1
    | (T n1, T n2) => nodeDifferenceDomain n1 n2

and nodeDifferenceDomain n1 n2 =
    if pointerEqual (n1,n2) then E
    else
      let
        val Node {priority,left,key,value,right,...} = n1

        (* Split the subtracted tree around this node's key. *)
        val (lower,match,upper) = nodePartition key n2

        val keptLeft = treeDifferenceDomain left lower
        val keptRight = treeDifferenceDomain right upper
      in
        case match of
          SOME _ => treeAppend keptLeft keptRight
        | NONE => mkTree priority keptLeft key value keptRight
      end;

(* ------------------------------------------------------------------------- *)
(* A subset operation on trees.                                              *)
(* ------------------------------------------------------------------------- *)

(* true iff every key of tree1 also occurs in tree2.  A node larger than *)
(* its counterpart in tree2 is rejected before any partitioning.         *)

fun treeSubsetDomain E _ = true
  | treeSubsetDomain (T _) E = false
  | treeSubsetDomain (T node1) (T node2) = nodeSubsetDomain node1 node2

and nodeSubsetDomain node1 node2 =
    pointerEqual (node1,node2) orelse
    let
      val Node {size = size1, left, key, right, ...} = node1
    in
      if size1 > nodeSize node2 then false
      else
        let
          val (lower,match,upper) = nodePartition key node2
        in
          Option.isSome match andalso
          treeSubsetDomain left lower andalso
          treeSubsetDomain right upper
        end
    end;

(* ------------------------------------------------------------------------- *)
(* Picking an arbitrary key/value pair from a tree.                          *)
(* ------------------------------------------------------------------------- *)

(* The root entry is returned as the "arbitrary" pair. *)
fun nodePick (Node {key,value,...}) = (key,value);

fun treePick E = raise Bug "KeyMap.treePick"
  | treePick (T node) = nodePick node;

(* ------------------------------------------------------------------------- *)
(* Removing an arbitrary key/value pair from a tree.                         *)
(* ------------------------------------------------------------------------- *)

(* Remove the root entry, joining its two subtrees with treeAppend. *)
fun nodeDeletePick (Node {left,key,value,right,...}) =
    ((key,value), treeAppend left right);

fun treeDeletePick E = raise Bug "KeyMap.treeDeletePick"
  | treeDeletePick (T node) = nodeDeletePick node;

(* ------------------------------------------------------------------------- *)
(* Finding the nth smallest key/value (counting from 0).                     *)
(* ------------------------------------------------------------------------- *)

(* Descend using subtree sizes: index n is in the left subtree, at the *)
(* node itself, or in the right subtree (re-based past the node).      *)

fun treeNth _ E = raise Bug "KeyMap.treeNth"
  | treeNth n (T node) = nodeNth n node

and nodeNth n (Node {left,key,value,right,...}) =
    let
      val leftSize = treeSize left
    in
      case Int.compare (n,leftSize) of
        EQUAL => (key,value)
      | LESS => treeNth n left
      | GREATER => treeNth (n - (leftSize + 1)) right
    end;

(* ------------------------------------------------------------------------- *)
(* Removing the nth smallest key/value (counting from 0).                    *)
(* ------------------------------------------------------------------------- *)

(* Remove the entry at index n (counting from 0) and return it together *)
(* with the remaining tree.  Nodes on the search path are rebuilt with   *)
(* their size decreased by one.                                          *)

fun treeDeleteNth _ E = raise Bug "KeyMap.treeDeleteNth"
  | treeDeleteNth n (T node) = nodeDeleteNth n node

and nodeDeleteNth n (Node {size,priority,left,key,value,right}) =
    let
      (* Reassemble this node around new subtrees, one entry smaller. *)
      fun rebuild (newLeft,newRight) =
          T (Node
               {size = size - 1,
                priority = priority,
                left = newLeft,
                key = key,
                value = value,
                right = newRight})

      val leftSize = treeSize left
    in
      if n = leftSize then ((key,value), treeAppend left right)
      else if n < leftSize then
        let
          val (removed,left') = treeDeleteNth n left
        in
          (removed, rebuild (left',right))
        end
      else
        let
          val (removed,right') = treeDeleteNth (n - (leftSize + 1)) right
        in
          (removed, rebuild (left,right'))
        end
    end;

(* ------------------------------------------------------------------------- *)
(* Iterators.                                                                *)
(* ------------------------------------------------------------------------- *)

(* An iterator holds the current key/value pair, the subtree on the far    *)
(* side of the current node that has not been visited yet, and the stack  *)
(* of ancestor nodes still pending.  LR visits nodes left to right (it     *)
(* keeps the right subtree); RL visits right to left (it keeps the left).  *)
datatype 'value iterator =
    LR of (key * 'value) * 'value tree * 'value node list
  | RL of (key * 'value) * 'value tree * 'value node list;

(* Turn a spine (a stack of nodes, head first) into an iterator positioned *)
(* at its head, or NONE if the spine is empty.                            *)
fun fromSpineLR [] = NONE
  | fromSpineLR (Node {key,value,right,...} :: rest) =
    SOME (LR ((key,value),right,rest));

fun fromSpineRL [] = NONE
  | fromSpineRL (Node {key,value,left,...} :: rest) =
    SOME (RL ((key,value),left,rest));

fun addLR stack t = fromSpineLR (treeLeftSpine stack t);

fun addRL stack t = fromSpineRL (treeRightSpine stack t);

fun treeMkIterator t = addLR [] t;

fun treeMkRevIterator t = addRL [] t;

(* The pair the iterator currently points at. *)
fun readIterator (LR (kv,_,_)) = kv
  | readIterator (RL (kv,_,_)) = kv;

(* Step past the current pair: descend into its pending subtree. *)
fun advanceIterator (LR (_,pending,stack)) = addLR stack pending
  | advanceIterator (RL (_,pending,stack)) = addRL stack pending;

(* Fold f over every remaining pair, in iteration order. *)
fun foldIterator _ acc NONE = acc
  | foldIterator f acc (SOME iter) =
    let
      val (key,value) = readIterator iter
      val acc' = f (key,value,acc)
    in
      foldIterator f acc' (advanceIterator iter)
    end;

(* The first remaining pair satisfying pred, if any. *)
fun findIterator _ NONE = NONE
  | findIterator pred (SOME iter) =
    let
      val kv = readIterator iter
    in
      if not (pred kv) then findIterator pred (advanceIterator iter)
      else SOME kv
    end;

(* The first SOME result of f over the remaining pairs, if any. *)
fun firstIterator _ NONE = NONE
  | firstIterator f (SOME iter) =
    (case f (readIterator iter) of
       NONE => firstIterator f (advanceIterator iter)
     | found => found);

(* Lexicographic comparison of two iterations: by key, then by value.  A *)
(* sequence that runs out first is the smaller.                          *)
fun compareIterator compareValue io1 io2 =
    case (io1,io2) of
      (NONE,NONE) => EQUAL
    | (NONE, SOME _) => LESS
    | (SOME _, NONE) => GREATER
    | (SOME i1, SOME i2) =>
      let
        val (k1,v1) = readIterator i1
        val (k2,v2) = readIterator i2

        (* The heads tie: compare what remains of both iterations. *)
        fun compareRest () =
            compareIterator compareValue
              (advanceIterator i1) (advanceIterator i2)
      in
        case compareKey (k1,k2) of
          EQUAL =>
          (case compareValue (v1,v2) of
             EQUAL => compareRest ()
           | valueOrder => valueOrder)
        | keyOrder => keyOrder
      end;

(* Two iterations are equal when they yield pairwise equal keys and *)
(* values and end at the same time.                                 *)
fun equalIterator _ NONE NONE = true
  | equalIterator equalValue (SOME i1) (SOME i2) =
    let
      val (k1,v1) = readIterator i1
      val (k2,v2) = readIterator i2
    in
      equalKey k1 k2 andalso
      equalValue v1 v2 andalso
      equalIterator equalValue (advanceIterator i1) (advanceIterator i2)
    end
  | equalIterator _ _ _ = false;

(* ------------------------------------------------------------------------- *)
(* A type of finite maps.                                                    *)
(* ------------------------------------------------------------------------- *)

(* A finite map is a thin wrapper around a single tree of key/value nodes. *)
datatype 'value map =
    Map of 'value tree;

(* ------------------------------------------------------------------------- *)
(* Map debugging functions.                                                  *)
(* ------------------------------------------------------------------------- *)

(*BasicDebug
fun checkInvariants s m =
    let
      val Map tree = m

      val _ = treeCheckInvariants tree
    in
      m
    end
    handle Bug bug => raise Bug (s ^ "\n" ^ "KeyMap.checkInvariants: " ^ bug);
*)

(* ------------------------------------------------------------------------- *)
(* Constructors.                                                             *)
(* ------------------------------------------------------------------------- *)

(* The empty map. *)
fun new () = Map (treeNew ());

(* A map holding exactly one key/value pair. *)
fun singleton key_value = Map (treeSingleton key_value);

(* ------------------------------------------------------------------------- *)
(* Map size.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Number of entries, as stored in the tree. *)
fun size (Map t) = treeSize t;

fun null m = (size m = 0);

(* ------------------------------------------------------------------------- *)
(* Querying.                                                                 *)
(* ------------------------------------------------------------------------- *)

fun peekKey (Map t) k = treePeekKey k t;

fun peek (Map t) k = treePeek k t;

fun inDomain key m =
    case peek m key of
      SOME _ => true
    | NONE => false;

(* Like peek, but raises Error when the key is absent. *)
fun get m key =
    case peek m key of
      SOME value => value
    | NONE => raise Error "KeyMap.get: element not found";

fun pick (Map t) = treePick t;

fun nth (Map t) n = treeNth n t;

(* A uniformly chosen entry; raises Bug on an empty map. *)
fun random m =
    case size m of
      0 => raise Bug "KeyMap.random: empty"
    | n => nth m (randomInt n);

(* ------------------------------------------------------------------------- *)
(* Adding.                                                                   *)
(* ------------------------------------------------------------------------- *)

(* Add a key/value pair (delegates to treeInsert). *)
fun insert (Map t) key_value = Map (treeInsert key_value t);

(*BasicDebug
val insert = fn m => fn kv =>
    checkInvariants "KeyMap.insert: result"
      (insert (checkInvariants "KeyMap.insert: input" m) kv);
*)

(* Insert the pairs of a list one by one, from the head of the list. *)
fun insertList m key_values =
    List.foldl (fn (key_value,acc) => insert acc key_value) m key_values;

(* ------------------------------------------------------------------------- *)
(* Removing.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Remove a key (delegates to treeDelete). *)
fun delete (Map t) dkey = Map (treeDelete dkey t);

(*BasicDebug
val delete = fn m => fn k =>
    checkInvariants "KeyMap.delete: result"
      (delete (checkInvariants "KeyMap.delete: input" m) k);
*)

(* Like delete, but returns the map unchanged when the key is absent. *)
fun remove m key =
    if not (inDomain key m) then m else delete m key;

fun deletePick (Map t) =
    let
      val (kv,rest) = treeDeletePick t
    in
      (kv, Map rest)
    end;

(*BasicDebug
val deletePick = fn m =>
    let
      val (kv,m) = deletePick (checkInvariants "KeyMap.deletePick: input" m)
    in
      (kv, checkInvariants "KeyMap.deletePick: result" m)
    end;
*)

(* Remove the entry at index n, returning it with the smaller map. *)
fun deleteNth (Map t) n =
    let
      val (kv,rest) = treeDeleteNth n t
    in
      (kv, Map rest)
    end;

(*BasicDebug
val deleteNth = fn m => fn n =>
    let
      val (kv,m) = deleteNth (checkInvariants "KeyMap.deleteNth: input" m) n
    in
      (kv, checkInvariants "KeyMap.deleteNth: result" m)
    end;
*)

(* Remove a uniformly chosen entry; raises Bug on an empty map. *)
fun deleteRandom m =
    case size m of
      0 => raise Bug "KeyMap.deleteRandom: empty"
    | n => deleteNth m (randomInt n);

(* ------------------------------------------------------------------------- *)
(* Joining (all join operations prefer keys in the second map).              *)
(* ------------------------------------------------------------------------- *)

(* first/second handle keys found only in m1/m2, both handles shared keys. *)
fun merge {first,second,both} (Map t1) (Map t2) =
    Map (treeMerge first second both t1 t2);

(*BasicDebug
val merge = fn f => fn m1 => fn m2 =>
    checkInvariants "KeyMap.merge: result"
      (merge f
         (checkInvariants "KeyMap.merge: input 1" m1)
         (checkInvariants "KeyMap.merge: input 2" m2));
*)

(* f resolves shared keys; for physically shared subtrees it is applied *)
(* to each pair as f (kv,kv).                                          *)
fun op union f (Map t1) (Map t2) =
    Map (treeUnion f (fn kv => f (kv,kv)) t1 t2);

(*BasicDebug
val op union = fn f => fn m1 => fn m2 =>
    checkInvariants "KeyMap.union: result"
      (union f
         (checkInvariants "KeyMap.union: input 1" m1)
         (checkInvariants "KeyMap.union: input 2" m2));
*)

fun intersect f (Map t1) (Map t2) = Map (treeIntersect f t1 t2);

(*BasicDebug
val intersect = fn f => fn m1 => fn m2 =>
    checkInvariants "KeyMap.intersect: result"
      (intersect f
         (checkInvariants "KeyMap.intersect: input 1" m1)
         (checkInvariants "KeyMap.intersect: input 2" m2));
*)

(* ------------------------------------------------------------------------- *)
(* Iterators over maps.                                                      *)
(* ------------------------------------------------------------------------- *)

fun mkIterator (Map t) = treeMkIterator t;

fun mkRevIterator (Map t) = treeMkRevIterator t;

(* ------------------------------------------------------------------------- *)
(* Mapping and folding.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Keep the entries for which f returns SOME, with the new value. *)
fun mapPartial f (Map t) = Map (treeMapPartial f t);

(*BasicDebug
val mapPartial = fn f => fn m =>
    checkInvariants "KeyMap.mapPartial: result"
      (mapPartial f (checkInvariants "KeyMap.mapPartial: input" m));
*)

fun map f (Map t) = Map (treeMap f t);

(*BasicDebug
val map = fn f => fn m =>
    checkInvariants "KeyMap.map: result"
      (map f (checkInvariants "KeyMap.map: input" m));
*)

(* Map over values only, ignoring keys. *)
fun transform f m = map (fn (_,v) => f v) m;

fun filter pred m =
    mapPartial (fn kv as (_,v) => if pred kv then SOME v else NONE) m;

(* (entries satisfying p, entries not satisfying p) *)
fun partition p m =
    let
      val kept = filter p m
      val rejected = filter (fn x => not (p x)) m
    in
      (kept,rejected)
    end;

fun foldl f acc m = foldIterator f acc (mkIterator m);

fun foldr f acc m = foldIterator f acc (mkRevIterator m);

fun app f m = foldl (fn (k,v,()) => f (k,v)) () m;

(* ------------------------------------------------------------------------- *)
(* Searching.                                                                *)
(* ------------------------------------------------------------------------- *)

fun findl pred m = findIterator pred (mkIterator m);

fun findr pred m = findIterator pred (mkRevIterator m);

fun firstl f m = firstIterator f (mkIterator m);

fun firstr f m = firstIterator f (mkRevIterator m);

fun exists p m =
    case findl p m of
      SOME _ => true
    | NONE => false;

(* All entries satisfy p iff none fails it. *)
fun all p m = not (exists (fn x => not (p x)) m);

fun count pred m =
    foldl (fn (k,v,n) => if pred (k,v) then n + 1 else n) 0 m;

(* ------------------------------------------------------------------------- *)
(* Comparing.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Total order on maps: a physically identical pair is EQUAL, then smaller *)
(* maps come first, then equal-sized maps are compared entry by entry in  *)
(* iteration order (key, then value via compareValue).                    *)
(* The former `val Map _ = m1` bindings were removed: they always matched *)
(* and bound nothing.                                                     *)
fun compare compareValue (m1,m2) =
    if pointerEqual (m1,m2) then EQUAL
    else
      case Int.compare (size m1, size m2) of
        EQUAL => compareIterator compareValue (mkIterator m1) (mkIterator m2)
      | sizeOrder => sizeOrder;

(* Equality on maps: same keys (via equalKey) with equalValue-equal values. *)
(* Sizes are compared first so differing maps usually fail fast.          *)
fun equal equalValue m1 m2 =
    pointerEqual (m1,m2) orelse
    (size m1 = size m2 andalso
     equalIterator equalValue (mkIterator m1) (mkIterator m2));

(* ------------------------------------------------------------------------- *)
(* Set operations on the domain.                                             *)
(* ------------------------------------------------------------------------- *)

(* Union of domains; shared keys take the entry from the second map. *)
fun unionDomain (Map t1) (Map t2) = Map (treeUnionDomain t1 t2);

(*BasicDebug
val unionDomain = fn m1 => fn m2 =>
    checkInvariants "KeyMap.unionDomain: result"
      (unionDomain
         (checkInvariants "KeyMap.unionDomain: input 1" m1)
         (checkInvariants "KeyMap.unionDomain: input 2" m2));
*)

(* Fold unionDomain over a non-empty list, left to right. *)
fun unionListDomain [] = raise Bug "KeyMap.unionListDomain: no sets"
  | unionListDomain (m :: ms) =
    List.foldl (fn (next,acc) => unionDomain acc next) m ms;

(* Intersection of domains, keeping the second map's entries. *)
fun intersectDomain (Map t1) (Map t2) = Map (treeIntersectDomain t1 t2);

(*BasicDebug
val intersectDomain = fn m1 => fn m2 =>
    checkInvariants "KeyMap.intersectDomain: result"
      (intersectDomain
         (checkInvariants "KeyMap.intersectDomain: input 1" m1)
         (checkInvariants "KeyMap.intersectDomain: input 2" m2));
*)

(* Fold intersectDomain over a non-empty list, left to right. *)
fun intersectListDomain [] = raise Bug "KeyMap.intersectListDomain: no sets"
  | intersectListDomain (m :: ms) =
    List.foldl (fn (next,acc) => intersectDomain acc next) m ms;

(* Entries of the first map whose keys are absent from the second. *)
fun differenceDomain (Map t1) (Map t2) = Map (treeDifferenceDomain t1 t2);

(*BasicDebug
val differenceDomain = fn m1 => fn m2 =>
    checkInvariants "KeyMap.differenceDomain: result"
      (differenceDomain
         (checkInvariants "KeyMap.differenceDomain: input 1" m1)
         (checkInvariants "KeyMap.differenceDomain: input 2" m2));
*)

fun symmetricDifferenceDomain m1 m2 =
    let
      val onlyIn1 = differenceDomain m1 m2
      val onlyIn2 = differenceDomain m2 m1
    in
      unionDomain onlyIn1 onlyIn2
    end;

(* Domains are equal when the maps are equal with every value pair deemed equal. *)
fun equalDomain m1 m2 = equal (fn _ => fn _ => true) m1 m2;

fun subsetDomain (Map t1) (Map t2) = treeSubsetDomain t1 t2;

fun disjointDomain m1 m2 = null (intersectDomain m1 m2);

(* ------------------------------------------------------------------------- *)
(* Converting to and from lists.                                             *)
(* ------------------------------------------------------------------------- *)

(* Folding with the reverse iterator and consing yields lists in forward *)
(* iteration order.                                                      *)
fun keys m = foldr (fn (k,_,acc) => k :: acc) [] m;

fun values m = foldr (fn (_,v,acc) => v :: acc) [] m;

fun toList m = foldr (fn (k,v,acc) => (k,v) :: acc) [] m;

fun fromList l = insertList (new ()) l;

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Shows only the size, e.g. "<3>"; an empty map prints as "<>". *)
fun toString m =
    let
      val n = size m
    in
      "<" ^ (if n = 0 then "" else Int.toString n) ^ ">"
    end;

end

(* Instances of the KeyMap functor for integer, integer-pair and string *)
(* keys, using the corresponding orderings from Metis.                  *)

structure IntMap =
KeyMap (Metis.IntOrdered); (* MODIFIED by Jasmin Blanchette *)

structure IntPairMap =
KeyMap (Metis.IntPairOrdered); (* MODIFIED by Jasmin Blanchette *)

structure StringMap =
KeyMap (Metis.StringOrdered); (* MODIFIED by Jasmin Blanchette *)

(**** Original file: Set.sig ****)

(* ========================================================================= *)
(* FINITE SETS IMPLEMENTED WITH RANDOMLY BALANCED TREES                      *)
(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

signature Set =
sig

(* ------------------------------------------------------------------------- *)
(* A type of finite sets.                                                    *)
(* ------------------------------------------------------------------------- *)

type 'elt set

(* ------------------------------------------------------------------------- *)
(* Constructors.                                                             *)
(* ------------------------------------------------------------------------- *)

(* The comparison function supplied here orders the elements of the set. *)

val empty : ('elt * 'elt -> order) -> 'elt set

val singleton : ('elt * 'elt -> order) -> 'elt -> 'elt set

(* ------------------------------------------------------------------------- *)
(* Set size.                                                                 *)
(* ------------------------------------------------------------------------- *)

val null : 'elt set -> bool

val size : 'elt set -> int

(* ------------------------------------------------------------------------- *)
(* Querying.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* SOME of the element held in the set that matches the argument, if any. *)
val peek : 'elt set -> 'elt -> 'elt option

val member : 'elt -> 'elt set -> bool

val pick : 'elt set -> 'elt  (* an arbitrary element *)

val nth : 'elt set -> int -> 'elt  (* in the range [0,size-1] *)

val random : 'elt set -> 'elt

(* ------------------------------------------------------------------------- *)
(* Adding.                                                                   *)
(* ------------------------------------------------------------------------- *)

val add : 'elt set -> 'elt -> 'elt set

val addList : 'elt set -> 'elt list -> 'elt set

(* ------------------------------------------------------------------------- *)
(* Removing.                                                                 *)
(* ------------------------------------------------------------------------- *)

val delete : 'elt set -> 'elt -> 'elt set  (* must be present *)

val remove : 'elt set -> 'elt -> 'elt set

val deletePick : 'elt set -> 'elt * 'elt set

val deleteNth : 'elt set -> int -> 'elt * 'elt set

val deleteRandom : 'elt set -> 'elt * 'elt set

(* ------------------------------------------------------------------------- *)
(* Joining.                                                                  *)
(* ------------------------------------------------------------------------- *)

val union : 'elt set -> 'elt set -> 'elt set

val unionList : 'elt set list -> 'elt set

val intersect : 'elt set -> 'elt set -> 'elt set

val intersectList : 'elt set list -> 'elt set

val difference : 'elt set -> 'elt set -> 'elt set

val symmetricDifference : 'elt set -> 'elt set -> 'elt set

(* ------------------------------------------------------------------------- *)
(* Mapping and folding.                                                      *)
(* ------------------------------------------------------------------------- *)

val filter : ('elt -> bool) -> 'elt set -> 'elt set

val partition : ('elt -> bool) -> 'elt set -> 'elt set * 'elt set

val app : ('elt -> unit) -> 'elt set -> unit

val foldl : ('elt * 's -> 's) -> 's -> 'elt set -> 's

val foldr : ('elt * 's -> 's) -> 's -> 'elt set -> 's

(* ------------------------------------------------------------------------- *)
(* Searching.                                                                *)
(* ------------------------------------------------------------------------- *)

val findl : ('elt -> bool) -> 'elt set -> 'elt option

val findr : ('elt -> bool) -> 'elt set -> 'elt option

val firstl : ('elt -> 'a option) -> 'elt set -> 'a option

val firstr : ('elt -> 'a option) -> 'elt set -> 'a option

val exists : ('elt -> bool) -> 'elt set -> bool

val all : ('elt -> bool) -> 'elt set -> bool

val count : ('elt -> bool) -> 'elt set -> int

(* ------------------------------------------------------------------------- *)
(* Comparing.                                                                *)
(* ------------------------------------------------------------------------- *)

val compare : 'elt set * 'elt set -> order

val equal : 'elt set -> 'elt set -> bool

val subset : 'elt set -> 'elt set -> bool

val disjoint : 'elt set -> 'elt set -> bool

(* ------------------------------------------------------------------------- *)
(* Converting to and from lists.                                             *)
(* ------------------------------------------------------------------------- *)

val transform : ('elt -> 'a) -> 'elt set -> 'a list

val toList : 'elt set -> 'elt list

val fromList : ('elt * 'elt -> order) -> 'elt list -> 'elt set

(* ------------------------------------------------------------------------- *)
(* Converting to and from maps.                                              *)
(* ------------------------------------------------------------------------- *)

type ('elt,'a) map = ('elt,'a) Metis.Map.map

(* mapPartial and map build a map keyed by the set's elements; domain *)
(* forgets a map's values and keeps its keys as a set.                *)

val mapPartial : ('elt -> 'a option) -> 'elt set -> ('elt,'a) map

val map : ('elt -> 'a) -> 'elt set -> ('elt,'a) map

val domain : ('elt,'a) map -> 'elt set

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

val toString : 'elt set -> string

(* ------------------------------------------------------------------------- *)
(* Iterators over sets                                                       *)
(* ------------------------------------------------------------------------- *)

type 'elt iterator

(* NONE means there is no (further) element to visit. *)

val mkIterator : 'elt set -> 'elt iterator option

val mkRevIterator : 'elt set -> 'elt iterator option

val readIterator : 'elt iterator -> 'elt

val advanceIterator : 'elt iterator -> 'elt iterator option

end

(**** Original file: Set.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment: remove any infix status from ++, -- and *)
(* RL, and rebind explode/implode/print/foldl/foldr to the Basis Library *)
(* versions.  NOTE(review): presumably needed because the enclosing       *)
(* environment gives these names other fixities or meanings; confirm.    *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* FINITE SETS IMPLEMENTED WITH RANDOMLY BALANCED TREES                      *)
(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Set :> Set =
struct

(* ------------------------------------------------------------------------- *)
(* A type of finite sets.                                                    *)
(* ------------------------------------------------------------------------- *)

type ('elt,'a) map = ('elt,'a) Map.map;

datatype 'elt set = Set of ('elt,unit) map;

(* ------------------------------------------------------------------------- *)
(* Converting to and from maps.                                              *)
(* ------------------------------------------------------------------------- *)

fun dest (Set m) = m;

fun mapPartial f =
    let
      fun mf (elt,()) = f elt
    in
      fn Set m => Map.mapPartial mf m
    end;

fun map f =
    let
      fun mf (elt,()) = f elt
    in
      fn Set m => Map.map mf m
    end;

fun domain m = Set (Map.transform (fn _ => ()) m);

(* ------------------------------------------------------------------------- *)
(* Constructors.                                                             *)
(* ------------------------------------------------------------------------- *)

fun empty cmp = Set (Map.new cmp);

fun singleton cmp elt = Set (Map.singleton cmp (elt,()));

(* ------------------------------------------------------------------------- *)
(* Set size.                                                                 *)
(* ------------------------------------------------------------------------- *)

fun null (Set m) = Map.null m;

fun size (Set m) = Map.size m;

(* ------------------------------------------------------------------------- *)
(* Querying.                                                                 *)
(* ------------------------------------------------------------------------- *)

fun peek (Set m) elt =
    case Map.peekKey m elt of
      SOME (elt,()) => SOME elt
    | NONE => NONE;

fun member elt (Set m) = Map.inDomain elt m;

fun pick (Set m) =
    let
      val (elt,_) = Map.pick m
    in
      elt
    end;

fun nth (Set m) n =
    let
      val (elt,_) = Map.nth m n
    in
      elt
    end;

fun random (Set m) =
    let
      val (elt,_) = Map.random m
    in
      elt
    end;

(* ------------------------------------------------------------------------- *)
(* Adding.                                                                   *)
(* ------------------------------------------------------------------------- *)

fun add (Set m) elt =
    let
      val m = Map.insert m (elt,())
    in
      Set m
    end;

local
  fun uncurriedAdd (elt,set) = add set elt;
in
  fun addList set = List.foldl uncurriedAdd set;
end;

(* ------------------------------------------------------------------------- *)
(* Removing.                                                                 *)
(* ------------------------------------------------------------------------- *)

fun delete (Set m) elt =
    let
      val m = Map.delete m elt
    in
      Set m
    end;

fun remove (Set m) elt =
    let
      val m = Map.remove m elt
    in
      Set m
    end;

fun deletePick (Set m) =
    let
      val ((elt,()),m) = Map.deletePick m
    in
      (elt, Set m)
    end;

fun deleteNth (Set m) n =
    let
      val ((elt,()),m) = Map.deleteNth m n
    in
      (elt, Set m)
    end;

fun deleteRandom (Set m) =
    let
      val ((elt,()),m) = Map.deleteRandom m
    in
      (elt, Set m)
    end;

(* ------------------------------------------------------------------------- *)
(* Joining.                                                                  *)
(* ------------------------------------------------------------------------- *)

fun union (Set m1) (Set m2) = Set (Map.unionDomain m1 m2);

fun unionList sets =
    let
      val ms = List.map dest sets
    in
      Set (Map.unionListDomain ms)
    end;

fun intersect (Set m1) (Set m2) = Set (Map.intersectDomain m1 m2);

fun intersectList sets =
    let
      val ms = List.map dest sets
    in
      Set (Map.intersectListDomain ms)
    end;

fun difference (Set m1) (Set m2) =
    Set (Map.differenceDomain m1 m2);

fun symmetricDifference (Set m1) (Set m2) =
    Set (Map.symmetricDifferenceDomain m1 m2);

(* ------------------------------------------------------------------------- *)
(* Mapping and folding.                                                      *)
(* ------------------------------------------------------------------------- *)

fun filter pred =
    let
      fun mpred (elt,()) = pred elt
    in
      fn Set m => Set (Map.filter mpred m)
    end;

fun partition pred =
    let
      fun mpred (elt,()) = pred elt
    in
      fn Set m =>
         let
           val (m1,m2) = Map.partition mpred m
         in
           (Set m1, Set m2)
         end
    end;

fun app f =
    let
      fun mf (elt,()) = f elt
    in
      fn Set m => Map.app mf m
    end;

fun foldl f =
    let
      fun mf (elt,(),acc) = f (elt,acc)
    in
      fn acc => fn Set m => Map.foldl mf acc m
    end;

fun foldr f =
    let
      fun mf (elt,(),acc) = f (elt,acc)
    in
      fn acc => fn Set m => Map.foldr mf acc m
    end;

(* ------------------------------------------------------------------------- *)
(* Searching.                                                                *)
(* ------------------------------------------------------------------------- *)

fun findl p =
    let
      fun mp (elt,()) = p elt
    in
      fn Set m =>
         case Map.findl mp m of
           SOME (elt,()) => SOME elt
         | NONE => NONE
    end;

fun findr p =
    let
      fun mp (elt,()) = p elt
    in
      fn Set m =>
         case Map.findr mp m of
           SOME (elt,()) => SOME elt
         | NONE => NONE
    end;

fun firstl f =
    let
      fun mf (elt,()) = f elt
    in
      fn Set m => Map.firstl mf m
    end;

fun firstr f =
    let
      fun mf (elt,()) = f elt
    in
      fn Set m => Map.firstr mf m
    end;

fun exists p =
    let
      fun mp (elt,()) = p elt
    in
      fn Set m => Map.exists mp m
    end;

fun all p =
    let
      fun mp (elt,()) = p elt
    in
      fn Set m => Map.all mp m
    end;

fun count p =
    let
      fun mp (elt,()) = p elt
    in
      fn Set m => Map.count mp m
    end;

(* ------------------------------------------------------------------------- *)
(* Comparing.                                                                *)
(* ------------------------------------------------------------------------- *)

fun compareValue ((),()) = EQUAL;

fun equalValue () () = true;

fun compare (Set m1, Set m2) = Map.compare compareValue (m1,m2);

fun equal (Set m1) (Set m2) = Map.equal equalValue m1 m2;

fun subset (Set m1) (Set m2) = Map.subsetDomain m1 m2;

fun disjoint (Set m1) (Set m2) = Map.disjointDomain m1 m2;

(* ------------------------------------------------------------------------- *)
(* Converting to and from lists.                                             *)
(* ------------------------------------------------------------------------- *)

fun transform f =
    let
      fun inc (x,l) = f x :: l
    in
      foldr inc []
    end;

fun toList (Set m) = Map.keys m;

fun fromList cmp elts = addList (empty cmp) elts;

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

fun toString set =
    "{" ^ (if null set then "" else Int.toString (size set)) ^ "}";

(* ------------------------------------------------------------------------- *)
(* Iterators over sets                                                       *)
(* ------------------------------------------------------------------------- *)

type 'elt iterator = ('elt,unit) Map.iterator;

fun mkIterator (Set m) = Map.mkIterator m;

fun mkRevIterator (Set m) = Map.mkRevIterator m;

fun readIterator iter =
    let
      val (elt,()) = Map.readIterator iter
    in
      elt
    end;

fun advanceIterator iter = Map.advanceIterator iter;

end
end;

(**** Original file: ElementSet.sig ****)

(* ========================================================================= *)
(* FINITE SETS WITH A FIXED ELEMENT TYPE                                     *)
(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

signature ElementSet =
sig

(* Finite sets over one fixed element type.  In this file the interface is
   implemented by the ElementSet functor, which represents a set as a KeyMap
   map from each element to (). *)

(* ------------------------------------------------------------------------- *)
(* A type of set elements.                                                   *)
(* ------------------------------------------------------------------------- *)

type element

(* ------------------------------------------------------------------------- *)
(* A type of finite sets.                                                    *)
(* ------------------------------------------------------------------------- *)

type set

(* ------------------------------------------------------------------------- *)
(* Constructors.                                                             *)
(* ------------------------------------------------------------------------- *)

val empty : set

val singleton : element -> set

(* ------------------------------------------------------------------------- *)
(* Set size.                                                                 *)
(* ------------------------------------------------------------------------- *)

val null : set -> bool

val size : set -> int

(* ------------------------------------------------------------------------- *)
(* Querying.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* SOME of the matching element held by the set (presumably the stored copy,
   as found by the key map's peekKey), or NONE when absent. *)
val peek : set -> element -> element option

val member : element -> set -> bool

val pick : set -> element  (* an arbitrary element *)

val nth : set -> int -> element  (* in the range [0,size-1] *)

val random : set -> element

(* ------------------------------------------------------------------------- *)
(* Adding.                                                                   *)
(* ------------------------------------------------------------------------- *)

val add : set -> element -> set

(* Adds the list's elements one at a time, starting from its head. *)
val addList : set -> element list -> set

(* ------------------------------------------------------------------------- *)
(* Removing.                                                                 *)
(* ------------------------------------------------------------------------- *)

val delete : set -> element -> set  (* must be present *)

(* NOTE(review): presumably tolerates an absent element, unlike delete. *)
val remove : set -> element -> set

(* The deleteX functions return the removed element and the remaining set. *)
val deletePick : set -> element * set

val deleteNth : set -> int -> element * set

val deleteRandom : set -> element * set

(* ------------------------------------------------------------------------- *)
(* Joining.                                                                  *)
(* ------------------------------------------------------------------------- *)

val union : set -> set -> set

val unionList : set list -> set

val intersect : set -> set -> set

val intersectList : set list -> set

val difference : set -> set -> set

val symmetricDifference : set -> set -> set

(* ------------------------------------------------------------------------- *)
(* Mapping and folding.                                                      *)
(* ------------------------------------------------------------------------- *)

val filter : (element -> bool) -> set -> set

(* Returns the elements satisfying the predicate and those that do not. *)
val partition : (element -> bool) -> set -> set * set

val app : (element -> unit) -> set -> unit

val foldl : (element * 's -> 's) -> 's -> set -> 's

val foldr : (element * 's -> 's) -> 's -> set -> 's

(* ------------------------------------------------------------------------- *)
(* Searching.                                                                *)
(* ------------------------------------------------------------------------- *)

val findl : (element -> bool) -> set -> element option

val findr : (element -> bool) -> set -> element option

val firstl : (element -> 'a option) -> set -> 'a option

val firstr : (element -> 'a option) -> set -> 'a option

val exists : (element -> bool) -> set -> bool

val all : (element -> bool) -> set -> bool

val count : (element -> bool) -> set -> int

(* ------------------------------------------------------------------------- *)
(* Comparing.                                                                *)
(* ------------------------------------------------------------------------- *)

val compare : set * set -> order

val equal : set -> set -> bool

val subset : set -> set -> bool

val disjoint : set -> set -> bool

(* ------------------------------------------------------------------------- *)
(* Converting to and from lists.                                             *)
(* ------------------------------------------------------------------------- *)

(* Applies the function to every element, collecting the results via foldr. *)
val transform : (element -> 'a) -> set -> 'a list

val toList : set -> element list

val fromList : element list -> set

(* ------------------------------------------------------------------------- *)
(* Converting to and from maps.                                              *)
(* ------------------------------------------------------------------------- *)

type 'a map

val mapPartial : (element -> 'a option) -> set -> 'a map

val map : (element -> 'a) -> set -> 'a map

(* The set of keys of a map; the map's values are discarded. *)
val domain : 'a map -> set

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Shows only the size: {} when empty, otherwise {n}; elements are omitted. *)
val toString : set -> string

(* ------------------------------------------------------------------------- *)
(* Iterators over sets                                                       *)
(* ------------------------------------------------------------------------- *)

type iterator

(* NONE when the set is empty -- presumably; the result comes straight from
   the key map's mkIterator / mkRevIterator. *)
val mkIterator : set -> iterator option

val mkRevIterator : set -> iterator option

val readIterator : iterator -> element

val advanceIterator : iterator -> iterator option

end

(**** Original file: ElementSet.sml ****)

(* ========================================================================= *)
(* FINITE SETS WITH A FIXED ELEMENT TYPE                                     *)
(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

functor ElementSet (KM : KeyMap) :> ElementSet
where type element = KM.key and type 'a map = 'a KM.map =
struct

(* ------------------------------------------------------------------------- *)
(* A type of set elements.                                                   *)
(* ------------------------------------------------------------------------- *)

type element = KM.key;

(* ------------------------------------------------------------------------- *)
(* A type of finite sets.                                                    *)
(* ------------------------------------------------------------------------- *)

type 'a map = 'a KM.map;

(* A set is a KM map from each of its elements to (); every operation below
   unwraps the Set constructor and delegates to the matching KM function. *)
datatype set = Set of unit map;

(* The element held in a (key, ()) map entry. *)
fun entryKey (elt, ()) = elt;

(* Turn a function on elements into a function on (key, ()) map entries. *)
fun onKey f (elt, ()) = f elt;

(* ------------------------------------------------------------------------- *)
(* Converting to and from maps.                                              *)
(* ------------------------------------------------------------------------- *)

(* The underlying map of a set. *)
fun dest (Set m) = m;

fun mapPartial f (Set m) = KM.mapPartial (onKey f) m;

fun map f (Set m) = KM.map (onKey f) m;

(* The set of keys of an arbitrary map; the map's values are discarded. *)
fun domain m =
    let
      val units = KM.transform (fn _ => ()) m
    in
      Set units
    end;

(* ------------------------------------------------------------------------- *)
(* Constructors.                                                             *)
(* ------------------------------------------------------------------------- *)

val empty = Set (KM.new ());

fun singleton elt = Set (KM.singleton (elt, ()));

(* ------------------------------------------------------------------------- *)
(* Set size.                                                                 *)
(* ------------------------------------------------------------------------- *)

fun null set = KM.null (dest set);

fun size set = KM.size (dest set);

(* ------------------------------------------------------------------------- *)
(* Querying.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* The key KM.peekKey finds for elt, if any. *)
fun peek (Set m) elt = Option.map entryKey (KM.peekKey m elt);

fun member elt set = KM.inDomain elt (dest set);

fun pick set = entryKey (KM.pick (dest set));

fun nth set n = entryKey (KM.nth (dest set) n);

fun random set = entryKey (KM.random (dest set));

(* ------------------------------------------------------------------------- *)
(* Adding.                                                                   *)
(* ------------------------------------------------------------------------- *)

fun add (Set m) elt = Set (KM.insert m (elt, ()));

(* Adds the list's elements one at a time, starting from its head. *)
fun addList set elts = List.foldl (fn (elt, acc) => add acc elt) set elts;

(* ------------------------------------------------------------------------- *)
(* Removing.                                                                 *)
(* ------------------------------------------------------------------------- *)

fun delete (Set m) elt = Set (KM.delete m elt);

fun remove (Set m) elt = Set (KM.remove m elt);

local
  (* Split a ((elt, ()), rest) result from KM.deletePick, KM.deleteNth or
     KM.deleteRandom into the element and the remaining set. *)
  fun rewrap ((elt, ()), m) = (elt, Set m);
in
  fun deletePick (Set m) = rewrap (KM.deletePick m);

  fun deleteNth (Set m) n = rewrap (KM.deleteNth m n);

  fun deleteRandom (Set m) = rewrap (KM.deleteRandom m);
end;

(* ------------------------------------------------------------------------- *)
(* Joining.                                                                  *)
(* ------------------------------------------------------------------------- *)

(* "op" is kept because union may be declared infix in the enclosing
   top-level environment. *)
fun op union s1 s2 = Set (KM.unionDomain (dest s1) (dest s2));

fun unionList sets = Set (KM.unionListDomain (List.map dest sets));

fun intersect s1 s2 = Set (KM.intersectDomain (dest s1) (dest s2));

fun intersectList sets = Set (KM.intersectListDomain (List.map dest sets));

fun difference s1 s2 = Set (KM.differenceDomain (dest s1) (dest s2));

fun symmetricDifference s1 s2 =
    Set (KM.symmetricDifferenceDomain (dest s1) (dest s2));

(* ------------------------------------------------------------------------- *)
(* Mapping and folding.                                                      *)
(* ------------------------------------------------------------------------- *)

fun filter pred (Set m) = Set (KM.filter (onKey pred) m);

fun partition pred (Set m) =
    let
      val (kept, dropped) = KM.partition (onKey pred) m
    in
      (Set kept, Set dropped)
    end;

fun app f (Set m) = KM.app (onKey f) m;

fun foldl f acc (Set m) = KM.foldl (fn (elt, (), s) => f (elt, s)) acc m;

fun foldr f acc (Set m) = KM.foldr (fn (elt, (), s) => f (elt, s)) acc m;

(* ------------------------------------------------------------------------- *)
(* Searching.                                                                *)
(* ------------------------------------------------------------------------- *)

fun findl p (Set m) = Option.map entryKey (KM.findl (onKey p) m);

fun findr p (Set m) = Option.map entryKey (KM.findr (onKey p) m);

fun firstl f (Set m) = KM.firstl (onKey f) m;

fun firstr f (Set m) = KM.firstr (onKey f) m;

fun exists p (Set m) = KM.exists (onKey p) m;

fun all p (Set m) = KM.all (onKey p) m;

fun count p (Set m) = KM.count (onKey p) m;

(* ------------------------------------------------------------------------- *)
(* Comparing.                                                                *)
(* ------------------------------------------------------------------------- *)

(* All map values are (), so values never distinguish two sets. *)
fun compareValue ((), ()) = EQUAL;

fun equalValue () () = true;

fun compare (s1, s2) = KM.compare compareValue (dest s1, dest s2);

fun equal s1 s2 = KM.equal equalValue (dest s1) (dest s2);

(* "op" is kept for the same reason as union above. *)
fun op subset s1 s2 = KM.subsetDomain (dest s1) (dest s2);

fun disjoint s1 s2 = KM.disjointDomain (dest s1) (dest s2);

(* ------------------------------------------------------------------------- *)
(* Converting to and from lists.                                             *)
(* ------------------------------------------------------------------------- *)

(* Applies f to every element, collecting the results with this foldr. *)
fun transform f set = foldr (fn (elt, acc) => f elt :: acc) [] set;

fun toList set = KM.keys (dest set);

fun fromList elts = addList empty elts;

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Shows only the size: {} when empty, otherwise {n}. *)
fun toString set =
    let
      val body = if null set then "" else Int.toString (size set)
    in
      "{" ^ body ^ "}"
    end;

(* ------------------------------------------------------------------------- *)
(* Iterators over sets                                                       *)
(* ------------------------------------------------------------------------- *)

type iterator = unit KM.iterator;

fun mkIterator set = KM.mkIterator (dest set);

fun mkRevIterator set = KM.mkRevIterator (dest set);

fun readIterator iter = entryKey (KM.readIterator iter);

fun advanceIterator iter = KM.advanceIterator iter;

end

(* Concrete element sets, obtained by applying the ElementSet functor to a key
   map; each set's element type is that map's key type. *)

(* Element type IntMap.key (presumably int). *)
structure IntSet =
ElementSet (IntMap);

(* Element type IntPairMap.key (presumably int * int). *)
structure IntPairSet =
ElementSet (IntPairMap);

(* Element type StringMap.key (presumably string). *)
structure StringSet =
ElementSet (StringMap);

(**** Original file: Sharing.sig ****)

(* ========================================================================= *)
(* PRESERVING SHARING OF ML VALUES                                           *)
(* Copyright (c) 2005-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Sharing =
sig

(* Variants of common operations that return their argument itself when the
   result would be equal to it under pointer equality, so that unchanged
   values keep their existing sharing. *)

(* ------------------------------------------------------------------------- *)
(* Option operations.                                                        *)
(* ------------------------------------------------------------------------- *)

val mapOption : ('a -> 'a) -> 'a option -> 'a option

(* As mapOption, with the function also threading a state value. *)
val mapsOption : ('a -> 's -> 'a * 's) -> 'a option -> 's -> 'a option * 's

(* ------------------------------------------------------------------------- *)
(* List operations.                                                          *)
(* ------------------------------------------------------------------------- *)

val map : ('a -> 'a) -> 'a list -> 'a list

(* Like map, but the function is applied from the last element to the first. *)
val revMap : ('a -> 'a) -> 'a list -> 'a list

val maps : ('a -> 's -> 'a * 's) -> 'a list -> 's -> 'a list * 's

val revMaps : ('a -> 's -> 'a * 's) -> 'a list -> 's -> 'a list * 's

(* Replace one list position; raises Subscript when it is out of range. *)
val updateNth : int * 'a -> 'a list -> 'a list

val setify : ''a list -> ''a list

(* ------------------------------------------------------------------------- *)
(* Function caching.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Memoise a function on arguments ordered by the given comparison. *)
val cache : ('a * 'a -> order) -> ('a -> 'b) -> 'a -> 'b

(* ------------------------------------------------------------------------- *)
(* Hash consing.                                                             *)
(* ------------------------------------------------------------------------- *)

(* Map each value to a canonical representative of its comparison class. *)
val hashCons : ('a * 'a -> order) -> 'a -> 'a

end

(**** Original file: Sharing.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* PRESERVING SHARING OF ML VALUES                                           *)
(* Copyright (c) 2005-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

(* The option and list operations below return their input value itself,
   rather than a rebuilt copy, whenever the supplied function returned each
   argument unchanged (as judged by ==).  The exact choice of which cells are
   reused is part of the behaviour, so edit these with care. *)

structure Sharing :> Sharing =
struct

(* x == y is Portable.pointerEqual; judging by the name this is a physical
   (pointer) identity test, not structural equality -- TODO confirm. *)
infix ==

val op== = Portable.pointerEqual;

(* ------------------------------------------------------------------------- *)
(* Option operations.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Apply f under SOME; if the result is == to the argument, return xo itself
   instead of allocating a new SOME.  NONE is returned unchanged. *)
fun mapOption f xo =
    case xo of
      SOME x =>
      let
        val y = f x
      in
        if x == y then xo else SOME y
      end
    | NONE => xo;

(* As mapOption, but f also threads an accumulator, which is returned next
   to the (possibly shared) option. *)
fun mapsOption f xo acc =
    case xo of
      SOME x =>
      let
        val (y,acc) = f x acc
      in
        if x == y then (xo,acc) else (SOME y, acc)
      end
    | NONE => (xo,acc);

(* ------------------------------------------------------------------------- *)
(* List operations.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Map f over a list, head first, while preserving sharing.  Returns xs itself
   when every f x == x; otherwise the result reuses the longest suffix of xs
   on which f returned its argument unchanged.
     ys    - images computed so far, in reverse order
     ys_xs - (reversed images, untouched input suffix) as recorded at the most
             recent element whose image differed; the answer is
             List.revAppend ys_xs *)
fun map f =
    let
      fun m ys ys_xs xs =
          case xs of
            [] => List.revAppend ys_xs
          | x :: xs =>
            let
              val y = f x
              val ys = y :: ys
              val ys_xs = if x == y then ys_xs else (ys,xs)
            in
              m ys ys_xs xs
            end
    in
      fn xs => m [] ([],xs) xs
    end;

(* As map, but f also threads an accumulator through the list, head first. *)
fun maps f =
    let
      fun m acc ys ys_xs xs =
          case xs of
            [] => (List.revAppend ys_xs, acc)
          | x :: xs =>
            let
              val (y,acc) = f x acc
              val ys = y :: ys
              val ys_xs = if x == y then ys_xs else (ys,xs)
            in
              m acc ys ys_xs xs
            end
    in
      fn xs => fn acc => m acc [] ([],xs) xs
    end;

local
  (* revTails [] xs pairs each element of xs with the suffix of xs starting
     at it, listing the pairs from the last element to the first. *)
  fun revTails acc xs =
      case xs of
        [] => acc
      | x :: xs' => revTails ((x,xs) :: acc) xs';
in
  (* Like map, but f is applied from the last element to the first; the
     result is still in the original order.  While every element seen so far
     (counting from the end) satisfies f x == x, the partial result is the
     original suffix itself, so an all-unchanged list is returned as xs. *)
  fun revMap f =
      let
        fun m ys same xxss =
            case xxss of
              [] => ys
            | (x,xs) :: xxss =>
              let
                val y = f x
                val same = same andalso x == y
                val ys = if same then xs else y :: ys
              in
                m ys same xxss
              end
      in
        fn xs => m [] true (revTails [] xs)
      end;

  (* As revMap, threading an accumulator through f from the last element to
     the first. *)
  fun revMaps f =
      let
        fun m acc ys same xxss =
            case xxss of
              [] => (ys,acc)
            | (x,xs) :: xxss =>
              let
                val (y,acc) = f x acc
                val same = same andalso x == y
                val ys = if same then xs else y :: ys
              in
                m acc ys same xxss
              end
      in
        fn xs => fn acc => m acc [] true (revTails [] xs)
      end;
end;

(* Replace the element at position n with x.  Assumes Useful.revDivide l n
   returns the first n elements of l (reversed) together with the rest --
   TODO confirm.  Returns l itself when x == the element being replaced and
   raises Subscript when the rest is empty. *)
fun updateNth (n,x) l =
    let
      val (a,b) = Useful.revDivide l n
    in
      case b of
        [] => raise Subscript
      | h :: t => if x == h then l else List.revAppend (a, x :: t)
    end;

(* Apply Useful.setify (presumably duplicate removal), but return l itself
   when the result has the same length, i.e. nothing was dropped. *)
fun setify l =
    let
      val l' = Useful.setify l
    in
      if length l' = length l then l else l'
    end;

(* ------------------------------------------------------------------------- *)
(* Function caching.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Memoise f on arguments ordered by cmp.  Results are kept in a Map held in
   an Unsynchronized.ref and are never evicted; the lookup-then-insert below
   takes no lock. *)
fun cache cmp f =
    let
      val cache = Unsynchronized.ref (Map.new cmp)
    in
      fn a =>
         case Map.peek (!cache) a of
           SOME b => b
         | NONE =>
           let
             val b = f a
             val () = cache := Map.insert (!cache) (a,b)
           in
             b
           end
    end;

(* ------------------------------------------------------------------------- *)
(* Hash consing.                                                             *)
(* ------------------------------------------------------------------------- *)

(* Canonicalise values: the first value seen in each cmp-equivalence class is
   stored and returned for every later argument that compares EQUAL to it
   (assuming Useful.I is the identity function). *)
fun hashCons cmp = cache cmp Useful.I;

end
end;

(**** Original file: Heap.sig ****)

(* ========================================================================= *)
(* A HEAP DATATYPE FOR ML                                                    *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Heap =
sig

(* Priority heaps, each carrying its own ordering.  The implementation in this
   file keeps an element that is minimal under that ordering at the top. *)

type 'a heap

(* An empty heap ordered by the given comparison. *)
val new : ('a * 'a -> order) -> 'a heap

val add : 'a heap -> 'a -> 'a heap

(* True exactly when size is 0. *)
val null : 'a heap -> bool

val top : 'a heap -> 'a  (* raises Empty *)

val remove : 'a heap -> 'a * 'a heap  (* raises Empty *)

(* The number of elements; the implementation stores it, so this is O(1). *)
val size : 'a heap -> int

(* Visits every element once, in the heap's internal tree order (not sorted). *)
val app : ('a -> unit) -> 'a heap -> unit

(* All elements in the order repeated remove produces them. *)
val toList : 'a heap -> 'a list

(* As toList, but lazily: each element is removed only when demanded. *)
val toStream : 'a heap -> 'a Metis.Stream.stream

(* Heap[] when empty, otherwise Heap[n] with n the size; elements omitted. *)
val toString : 'a heap -> string

end

(**** Original file: Heap.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* A HEAP DATATYPE FOR ML                                                    *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Heap :> Heap =
struct

(* Leftist heaps as in Purely Functional Data Structures, by Chris Okasaki *)

datatype 'a node = E | T of int * 'a * 'a node * 'a node;

(* A heap packages its ordering, its element count and its root node. *)
datatype 'a heap = Heap of ('a * 'a -> order) * int * 'a node;

(* The rank stored in a node; makeT sets it to one more than the rank of the
   right child, so it is the length of the node's right spine. *)
fun rank node =
    case node of
      E => 0
    | T (r, _, _, _) => r;

(* Build a node, placing the child of greater-or-equal rank on the left. *)
fun makeT (x, a, b) =
    let
      val ra = rank a
      val rb = rank b
    in
      if ra < rb then T (ra + 1, x, b, a) else T (rb + 1, x, a, b)
    end;

(* Merge two trees along their right spines.  The root of the right tree wins
   only when the left root compares GREATER, so ties keep the left root. *)
fun merge cmp =
    let
      fun meld (left, right) =
          case (left, right) of
            (_, E) => left
          | (E, _) => right
          | (T (_, x, lx, rx), T (_, y, ly, ry)) =>
            if cmp (x, y) = GREATER then makeT (y, ly, meld (left, ry))
            else makeT (x, lx, meld (rx, right))
    in
      meld
    end;

fun new cmp = Heap (cmp, 0, E);

fun add (Heap (cmp, n, root)) x =
    Heap (cmp, n + 1, merge cmp (T (1, x, E, E), root));

fun size (Heap (_, n, _)) = n;

fun null (Heap (_, n, _)) = n = 0;

fun top (Heap (_, _, root)) =
    case root of
      E => raise Empty
    | T (_, x, _, _) => x;

fun remove (Heap (cmp, n, root)) =
    case root of
      E => raise Empty
    | T (_, x, a, b) => (x, Heap (cmp, n - 1, merge cmp (a, b)));

(* Pre-order walk driven by an explicit worklist of pending subtrees, so deep
   left spines do not consume call stack. *)
fun app f (Heap (_, _, root)) =
    let
      fun walk pending =
          case pending of
            [] => ()
          | E :: rest => walk rest
          | T (_, x, a, b) :: rest => (f x; walk (a :: b :: rest))
    in
      walk [root]
    end;

(* Repeatedly remove the top, accumulating in reverse and flipping at the
   end, which yields the same list as consing each removed element in order. *)
fun toList h =
    let
      fun drain (heap, acc) =
          if null heap then List.rev acc
          else
            let
              val (x, rest) = remove heap
            in
              drain (rest, x :: acc)
            end
    in
      drain (h, [])
    end;

(* Lazy version of toList: each removal happens only when the tail is forced. *)
fun toStream h =
    if not (null h) then
      let
        val (x, rest) = remove h
      in
        Stream.Cons (x, fn () => toStream rest)
      end
    else Stream.Nil;

(* Shows only the size: Heap[] when empty, otherwise Heap[n]. *)
fun toString h =
    let
      val body = if null h then "" else Int.toString (size h)
    in
      "Heap[" ^ body ^ "]"
    end;

end
end;

(**** Original file: Print.sig ****)

(* ========================================================================= *)
(* PRETTY-PRINTING                                                           *)
(* Copyright (c) 2001-2008 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Print =
sig

(* ------------------------------------------------------------------------- *)
(* A type of pretty-printers.                                                *)
(* ------------------------------------------------------------------------- *)

(* Once a Consistent block needs a line break, all of its breaks become line
   breaks; an Inconsistent block breaks only where a line overflows. *)
datatype breakStyle = Consistent | Inconsistent

(* The instructions a pretty-printer emits.  BeginBlock carries the block's
   break style and extra indentation; AddBreak carries the number of spaces
   printed when the break is not taken. *)
datatype ppStep =
    BeginBlock of breakStyle * int
  | EndBlock
  | AddString of string
  | AddBreak of int
  | AddNewline

(* A pretty-printer is a lazy stream of steps. *)
type ppstream = ppStep Metis.Stream.stream

type 'a pp = 'a -> ppstream

(* ------------------------------------------------------------------------- *)
(* Pretty-printer primitives.                                                *)
(* ------------------------------------------------------------------------- *)

val beginBlock : breakStyle -> int -> ppstream

val endBlock : ppstream

val addString : string -> ppstream

val addBreak : int -> ppstream

val addNewline : ppstream

(* The empty pretty-printer. *)
val skip : ppstream

val sequence : ppstream -> ppstream -> ppstream

(* duplicate n pp repeats pp n times. *)
val duplicate : int -> ppstream -> ppstream

val program : ppstream list -> ppstream

val stream : ppstream Metis.Stream.stream -> ppstream

(* block style indent pp wraps pp between beginBlock and endBlock. *)
val block : breakStyle -> int -> ppstream -> ppstream

val blockProgram : breakStyle -> int -> ppstream list -> ppstream

(* bracket l r pp surrounds pp with the strings l and r. *)
val bracket : string -> string -> ppstream -> ppstream

val field : string -> ppstream -> ppstream

val record : (string * ppstream) list -> ppstream

(* ------------------------------------------------------------------------- *)
(* Pretty-printer combinators.                                               *)
(* ------------------------------------------------------------------------- *)

val ppMap : ('a -> 'b) -> 'b pp -> 'a pp

val ppBracket : string -> string -> 'a pp -> 'a pp

(* An operator token followed by a break. *)
val ppOp : string -> ppstream

val ppOp2 : string -> 'a pp -> 'b pp -> ('a * 'b) pp

val ppOp3 : string -> string -> 'a pp -> 'b pp -> 'c pp -> ('a * 'b * 'c) pp

val ppOpList : string -> 'a pp -> 'a list pp

val ppOpStream : string -> 'a pp -> 'a Metis.Stream.stream pp

(* ------------------------------------------------------------------------- *)
(* Pretty-printers for common types.                                         *)
(* ------------------------------------------------------------------------- *)

val ppChar : char pp

val ppString : string pp

(* Prints a string with backslash escapes for backslash, newline, tab and
   every character listed in escape. *)
val ppEscapeString : {escape : char list} -> string pp

val ppUnit : unit pp

val ppBool : bool pp

val ppInt : int pp

(* Integers with comma-separated groups of three digits. *)
val ppPrettyInt : int pp

val ppReal : real pp

val ppPercent : real pp

val ppOrder : order pp

val ppList : 'a pp -> 'a list pp

val ppStream : 'a pp -> 'a Metis.Stream.stream pp

val ppOption : 'a pp -> 'a option pp

val ppPair : 'a pp -> 'b pp -> ('a * 'b) pp

val ppTriple : 'a pp -> 'b pp -> 'c pp -> ('a * 'b * 'c) pp

val ppBreakStyle : breakStyle pp

val ppPpStep : ppStep pp

val ppPpstream : ppstream pp

(* ------------------------------------------------------------------------- *)
(* Pretty-printing infix operators.                                          *)
(* ------------------------------------------------------------------------- *)

(* A table of infix operator tokens; a token may carry surrounding spaces,
   which are printed around the bare operator. *)
datatype infixes =
    Infixes of
      {token : string,
       precedence : int,
       leftAssoc : bool} list

val tokensInfixes : infixes -> StringSet.set (* MODIFIED by Jasmin Blanchette *)

(* Groups the operators into layers of equal precedence. *)
val layerInfixes :
    infixes ->
    {tokens : {leftSpaces : int, token : string, rightSpaces : int} list,
     leftAssoc : bool} list

(* ppInfixes ops dest ppSub: dest deconstructs a term into an operator token
   and two operands, and ppSub prints terms that are not infix applications. *)
val ppInfixes :
    infixes -> ('a -> (string * 'a * 'a) option) -> ('a * bool) pp ->
    ('a * bool) pp

(* ------------------------------------------------------------------------- *)
(* Executing pretty-printers to generate lines.                              *)
(* ------------------------------------------------------------------------- *)

val execute : {lineLength : int} -> ppstream -> string Metis.Stream.stream

(* ------------------------------------------------------------------------- *)
(* Executing pretty-printers with a global line length.                      *)
(* ------------------------------------------------------------------------- *)

(* NOTE(review): presumably the line length used by toString, toStream and
   trace; the implementation is not in view here, so confirm. *)
val lineLength : int Unsynchronized.ref

val toString : 'a pp -> 'a -> string

val toStream : 'a pp -> 'a -> string Metis.Stream.stream

val trace : 'a pp -> string -> 'a -> unit

end

(**** Original file: Print.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
(* These declarations make ++, -- and RL ordinary identifiers and rebind the
   following names to their Basis Library versions.  NOTE(review): this is
   presumably so that bindings from the opened Metis structure or from the
   host environment do not change the meaning of the code below; confirm. *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* PRETTY-PRINTING                                                           *)
(* Copyright (c) 2001-2008 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Print :> Print =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Constants.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Size of the precomputed table of space strings used by nSpaces.
   NOTE(review): presumably also the default for the global line length,
   whose definition is not in view; confirm. *)
val initialLineLength = 75;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Push the elements of xs, in reverse order, onto the front of the stream
   produced by s (); e.g. revAppend [1,2] s yields 2, 1 and then s (). *)
fun revAppend xs s =
    case xs of
      [] => s ()
    | y :: ys => revAppend ys (fn () => Stream.Cons (y,s));

(* Flatten a stream of lists, emitting each list's elements in reverse. *)
fun revConcat strm =
    case strm of
      Stream.Nil => Stream.Nil
    | Stream.Cons (xs, rest) => revAppend xs (fn () => revConcat (rest ()));

local
  (* Emit the previous line with a newline appended and carry the current
     line forward as the new state. *)
  fun appendNewline current previous = (previous ^ "\n", current);
in
  (* Terminate every line except the last with a newline.  NOTE(review):
     relies on Stream.maps threading a state and finishing it with
     Stream.singleton. *)
  fun joinNewline strm =
      case strm of
        Stream.Nil => Stream.Nil
      | Stream.Cons (first, rest) =>
        Stream.maps appendNewline Stream.singleton first (rest ());
end;

local
  fun spaces n = nChars #" " n;

  (* Space strings of length 0 .. initialLineLength - 1, built once. *)
  val spaceTable = Vector.tabulate (initialLineLength, spaces);
in
  (* A string of n spaces; short ones come from the precomputed table. *)
  fun nSpaces n =
      if n >= initialLineLength then spaces n
      else Vector.sub (spaceTable, n);
end;

(* ------------------------------------------------------------------------- *)
(* A type of pretty-printers.                                                *)
(* ------------------------------------------------------------------------- *)

(* How a block breaks once it has to: Consistent blocks break at every break
   (see breakStyle and mkBreak below), Inconsistent only where needed. *)
datatype breakStyle = Consistent | Inconsistent;

(* One pretty-printing instruction. *)
datatype ppStep =
    BeginBlock of breakStyle * int
  | EndBlock
  | AddString of string
  | AddBreak of int
  | AddNewline;

(* A pretty-printer output is a lazy stream of steps. *)
type ppstream = ppStep Stream.stream;

type 'a pp = 'a -> ppstream;

(* Constructor name of a break style, for debugging output. *)
fun breakStyleToString Consistent = "Consistent"
  | breakStyleToString Inconsistent = "Inconsistent";

(* Constructor name of a pretty-printing step, ignoring its arguments. *)
fun ppStepToString (BeginBlock _) = "BeginBlock"
  | ppStepToString EndBlock = "EndBlock"
  | ppStepToString (AddString _) = "AddString"
  | ppStepToString (AddBreak _) = "AddBreak"
  | ppStepToString AddNewline = "AddNewline";

(* ------------------------------------------------------------------------- *)
(* Pretty-printer primitives.                                                *)
(* ------------------------------------------------------------------------- *)

(* Each primitive emits exactly one step. *)
fun beginBlock style indent =
    Stream.singleton (BeginBlock (style,indent));

val endBlock = Stream.singleton EndBlock;

fun addString s =
    Stream.singleton (AddString s);

fun addBreak spaces =
    Stream.singleton (AddBreak spaces);

val addNewline = Stream.singleton AddNewline;

(* The pretty-printer that emits nothing. *)
val skip : ppstream = Stream.Nil;

(* The steps of pp1 followed by the steps of pp2. *)
fun sequence pp1 pp2 : ppstream = Stream.append pp1 (fn () => pp2);

local
  (* Concatenate n >= 1 copies of pp; later copies are appended on demand. *)
  fun dup pp n () = if n = 1 then pp else Stream.append pp (dup pp (n - 1));
in
  (* duplicate n pp repeats pp n times.  Any non-positive count yields skip:
     a negative count would otherwise never reach dup's n = 1 base case and
     would produce an infinite stream. *)
  fun duplicate n pp = if n <= 0 then skip else dup pp n ();
end;

(* Run a list of pretty-printers one after another. *)
fun program pps : ppstream = Stream.concatList pps;

(* Run a (lazy) stream of pretty-printers one after another. *)
fun stream pps : ppstream = Stream.concat pps;

(* Wrap pp in a block with the given break style and extra indentation. *)
fun block style indent pp =
    program [beginBlock style indent, pp, endBlock];

(* Like block, for a list of pretty-printers run in sequence. *)
fun blockProgram style indent pps =
    let
      val body = program pps
    in
      block style indent body
    end;

(* Surround pp with the strings l and r inside an Inconsistent block indented
   by the width of l. *)
fun bracket l r pp =
    let
      val indent = size l
    in
      blockProgram Inconsistent indent [addString l, pp, addString r]
    end;

(* The label f followed by an equals sign, a break and pp, indented by 2. *)
fun field f pp =
    let
      val label = addString (f ^ " =")
    in
      blockProgram Inconsistent 2 [label, addBreak 1, pp]
    end;

(* Print a list of named fields as a brace-delimited record, with a comma
   and break between consecutive fields, in a Consistent block. *)
val record =
    let
      val separator = sequence (addString ",") (addBreak 1)

      fun ppField (name,pp) = field name pp

      fun ppFields [] = []
        | ppFields (first :: rest) =
          ppField first :: map (fn f => sequence separator (ppField f)) rest
    in
      fn fs => bracket "{" "}" (blockProgram Consistent 0 (ppFields fs))
    end;

(* ------------------------------------------------------------------------- *)
(* Pretty-printer combinators.                                               *)
(* ------------------------------------------------------------------------- *)

(* Pretty-print a by transforming it with f and printing the result. *)
fun ppMap f ppB a : ppstream =
    let
      val b = f a
    in
      ppB b
    end;

(* Print a with ppA, surrounded by l and r. *)
fun ppBracket l r ppA a =
    let
      val inner = ppA a
    in
      bracket l r inner
    end;

(* An operator token followed by a break of one space; the empty token
   contributes only the break. *)
fun ppOp s =
    let
      val token = if s = "" then skip else addString s
    in
      sequence token (addBreak 1)
    end;

(* Print a and b separated by the operator ab, in an Inconsistent block. *)
fun ppOp2 ab ppA ppB (a,b) =
    blockProgram Inconsistent 0 [ppA a, ppOp ab, ppB b];

(* Print a, b and c separated by the operators ab and bc. *)
fun ppOp3 ab bc ppA ppB ppC (a,b,c) =
    blockProgram Inconsistent 0 [ppA a, ppOp ab, ppB b, ppOp bc, ppC c];

(* Print list elements separated by the operator s; the empty list prints
   nothing. *)
fun ppOpList s ppA =
    let
      fun ppRest a = sequence (ppOp s) (ppA a)
    in
      fn xs =>
         case xs of
           [] => skip
         | x :: rest => blockProgram Inconsistent 0 (ppA x :: map ppRest rest)
    end;

(* Stream version of ppOpList; later elements are printed on demand. *)
fun ppOpStream s ppA =
    let
      fun ppRest a = sequence (ppOp s) (ppA a)
    in
      fn strm =>
         case strm of
           Stream.Nil => skip
         | Stream.Cons (x,rest) =>
           blockProgram Inconsistent 0
             [ppA x, Stream.concat (Stream.map ppRest (rest ()))]
    end;

(* ------------------------------------------------------------------------- *)
(* Pretty-printers for common types.                                         *)
(* ------------------------------------------------------------------------- *)

(* A single character as a one-character string. *)
fun ppChar c =
    addString (String.str c);

(* Strings are printed verbatim. *)
val ppString = addString;

(* Print a string with backslash escapes: backslash, newline and tab get
   their usual escapes, and every character in escape is prefixed with a
   backslash. *)
fun ppEscapeString {escape} =
    let
      val escapeMap = map (fn c => (c, "\\" ^ str c)) escape

      fun lookup c =
          case List.find (fn (c',_) => c' = c) escapeMap of
            SOME (_,s) => s
          | NONE => str c

      fun escapeChar #"\\" = "\\\\"
        | escapeChar #"\n" = "\\n"
        | escapeChar #"\t" = "\\t"
        | escapeChar c = lookup c
    in
      fn s => addString (String.translate escapeChar s)
    end;

val ppUnit : unit pp = fn () => addString "()";

fun ppBool true = addString "true"
  | ppBool false = addString "false";

fun ppInt i =
    addString (Int.toString i);

local
  val ppNeg = addString "~"
  and ppSep = addString ","
  and ppZero = addString "0"
  and ppZeroZero = addString "00";

  (* A group of three digits, zero-padded on the left (0 <= i < 1000). *)
  fun ppIntBlock i =
      if i >= 100 then ppInt i
      else if i >= 10 then sequence ppZero (ppInt i)
      else sequence ppZeroZero (ppInt i);

  (* A non-negative integer with a comma between groups of three digits. *)
  fun ppIntBlocks i =
      if i >= 1000 then
        sequence (ppIntBlocks (i div 1000))
          (sequence ppSep (ppIntBlock (i mod 1000)))
      else ppInt i;
in
  fun ppPrettyInt i =
      if i >= 0 then ppIntBlocks i
      else sequence ppNeg (ppIntBlocks (~i));
end;

fun ppReal r =
    addString (Real.toString r);

fun ppPercent p =
    addString (percentToString p);

(* Orders print capitalized: Less, Equal, Greater. *)
fun ppOrder LESS = addString "Less"
  | ppOrder EQUAL = addString "Equal"
  | ppOrder GREATER = addString "Greater";

(* Lists and streams print as comma-separated elements in square brackets. *)
fun ppList ppA =
    let
      val elements = ppOpList "," ppA
    in
      ppBracket "[" "]" elements
    end;

fun ppStream ppA =
    let
      val elements = ppOpStream "," ppA
    in
      ppBracket "[" "]" elements
    end;

(* NONE prints as a dash. *)
fun ppOption _ NONE = addString "-"
  | ppOption ppA (SOME a) = ppA a;

(* Tuples print as comma-separated components in parentheses. *)
fun ppPair ppA ppB =
    let
      val components = ppOp2 "," ppA ppB
    in
      ppBracket "(" ")" components
    end;

fun ppTriple ppA ppB ppC =
    let
      val components = ppOp3 "," "," ppA ppB ppC
    in
      ppBracket "(" ")" components
    end;

fun ppBreakStyle style =
    ppString (breakStyleToString style);

(* Debug printer for a step: its constructor name followed by its argument,
   if any. *)
fun ppPpStep step =
    let
      val args =
          case step of
            BeginBlock style_indent =>
              [addBreak 1, ppPair ppBreakStyle ppInt style_indent]
          | AddString s => [addBreak 1, addString ("\"" ^ s ^ "\"")]
          | AddBreak n => [addBreak 1, ppInt n]
          | EndBlock => []
          | AddNewline => []

      val name = addString (ppStepToString step)
    in
      blockProgram Inconsistent 2 (name :: args)
    end;

fun ppPpstream steps = ppStream ppPpStep steps;

(* ------------------------------------------------------------------------- *)
(* Pretty-printing infix operators.                                          *)
(* ------------------------------------------------------------------------- *)

(* A table of infix operators.  A token may include leading and trailing
   spaces (split off by opSpaces below); precedence orders the operators and
   leftAssoc selects their associativity. *)
datatype infixes =
    Infixes of
      {token : string,
       precedence : int,
       leftAssoc : bool} list;

local
  (* Count and strip the leading space characters of a character list. *)
  fun chop l =
      let
        fun count (n, #" " :: rest) = count (n + 1, rest)
          | count (n, rest) = (n, rest)
      in
        count (0, l)
      end;
in
  (* Split an operator token into its number of leading spaces, the bare
     token, and its number of trailing spaces.  Trailing spaces are counted
     first, so an all-space token counts entirely as trailing. *)
  fun opSpaces tok =
      let
        val chars = explode tok
        val (r, rest) = chop (rev chars)
        val (l, core) = chop (rev rest)
      in
        {leftSpaces = l, token = implode core, rightSpaces = r}
      end;
end;

(* Print an operator: its leading spaces and token as one string, then a
   break whose width is the trailing space count. *)
fun ppOpSpaces {leftSpaces,token,rightSpaces} =
    let
      val text =
          case leftSpaces of
            0 => token
          | n => nSpaces n ^ token
    in
      sequence (addString text) (addBreak rightSpaces)
    end;

local
  (* Start a new layer holding the single token t with associativity l. *)
  fun new t l acc = {tokens = [opSpaces t], leftAssoc = l} :: acc;

  (* Add token t to the current (head) layer.  All tokens in a layer must
     share the same associativity. *)
  fun add t l acc =
      case acc of
        [] => raise Bug "Print.layerInfixes.add: no layer"
      | {tokens = ts, leftAssoc = l'} :: acc =>
        if l = l' then {tokens = opSpaces t :: ts, leftAssoc = l} :: acc
        else raise Bug "Print.layerInfixes: mixed assocs";

  (* Fold step: a token whose precedence equals that of the previous token
     joins the current layer; otherwise it starts a new one. *)
  fun layer ({token = t, precedence = p, leftAssoc = l}, (acc,p')) =
      let
        val acc = if p = p' then add t l acc else new t l acc
      in
        (acc,p)
      end;
in
  (* Group the operators into layers of equal precedence.  New layers are
     pushed onto the front of the result as the precedence-sorted input is
     folded over.  Raises Bug if one precedence level mixes associativities. *)
  fun layerInfixes (Infixes i) =
      case sortMap #precedence Int.compare i of
        [] => []
      | {token = t, precedence = p, leftAssoc = l} :: i =>
        let
          val acc = new t l []

          val (acc,_) = List.foldl layer (acc,p) i
        in
          acc
        end;
end;

(* The set of bare operator tokens occurring in any layer. *)
fun tokensLayeredInfixes layers =
    let
      fun addLayer ({tokens, leftAssoc = _}, set) =
          List.foldl
            (fn ({leftSpaces = _, token, rightSpaces = _}, s) =>
                StringSet.add s token)
            set tokens
    in
      List.foldl addLayer StringSet.empty layers
    end;

fun tokensInfixes ops = tokensLayeredInfixes (layerInfixes ops);

local
  (* Map each bare token of a layer to its printer: the token together with
     its surrounding spaces, as produced by ppOpSpaces. *)
  val mkTokenMap =
      let
        fun add (token,m) =
            let
              val {leftSpaces = _, token = t, rightSpaces = _} = token
            in
              StringMap.insert m (t, ppOpSpaces token)
            end
      in
        List.foldl add (StringMap.new ())
      end;
in
  (* Printer for one precedence layer.  A term whose top-level operator
     belongs to this layer is printed as a flat chain of operands and
     operators; all other terms go to ppSub.  For a left-associative layer
     the left operand recurses into this layer and the right operand goes to
     ppSub; a right-associative layer does the reverse.  The bool in each
     (term,bool) pair is true for left operands and is inherited by right
     operands.  NOTE(review): its meaning is defined by ppSub; confirm. *)
  fun ppGenInfix {tokens,leftAssoc} =
      let
        val tokenMap = mkTokenMap tokens
      in
        fn dest => fn ppSub =>
           let
             (* Like dest, but only succeeds for operators in this layer,
                and returns the operator's printer instead of its name. *)
             fun dest' tm =
                 case dest tm of
                   NONE => NONE
                 | SOME (t,a,b) =>
                   case StringMap.peek tokenMap t of
                     NONE => NONE
                   | SOME p => SOME (p,a,b)

             fun ppGo (tmr as (tm,r)) =
                 case dest' tm of
                   NONE => ppSub tmr
                 | SOME (p,a,b) =>
                   program
                     [(if leftAssoc then ppGo else ppSub) (a,true),
                      p,
                      (if leftAssoc then ppSub else ppGo) (b,r)]
           in
             (* Wrap a whole chain in one Inconsistent block; terms outside
                this layer are passed straight to ppSub. *)
             fn tmr as (tm,_) =>
                if Option.isSome (dest' tm) then
                  block Inconsistent 0 (ppGo tmr)
                else
                  ppSub tmr
           end
      end;
end

(* Printer for terms built from the infix operators in ops.  One printer is
   built per precedence layer and they are nested by folding over the layers
   in order.  A term whose top operator is in the table but must be printed
   as a sub-term is parenthesized. *)
fun ppInfixes ops =
    let
      val layers = layerInfixes ops

      val opTokens = tokensLayeredInfixes layers

      val layerPrinters = List.map ppGenInfix layers
    in
      fn dest => fn ppSub =>
         let
           fun wrap sub =
               List.foldl (fn (layerPrinter,inner) => layerPrinter dest inner)
                 sub layerPrinters

           fun isOp t =
               case dest t of
                 NONE => false
               | SOME (tok,_,_) => StringSet.member tok opTokens

           fun parenthesize (tmr as (tm,_)) =
               if not (isOp tm) then ppSub tmr
               else
                 blockProgram Inconsistent 1
                   [addString "(",
                    wrap parenthesize (tm,false),
                    addString ")"]
         in
           fn tmr => block Inconsistent 0 (wrap parenthesize tmr)
         end
    end;

(* ------------------------------------------------------------------------- *)
(* Executing pretty-printers to generate lines.                              *)
(* ------------------------------------------------------------------------- *)

(* Internal break style of an open block.  A ConsistentBlock becomes a
   BreakingBlock once it has been broken (see breakStyle below). *)
datatype blockBreakStyle =
    InconsistentBlock
  | ConsistentBlock
  | BreakingBlock;

(* An open block.  chunks are kept newest-first (new chunks are consed onto
   the front), and size is the total width of the chunks. *)
datatype block =
    Block of
      {indent : int,
       style : blockBreakStyle,
       size : int,
       chunks : chunk list}

(* A piece of the current line: a string with its width, a break carrying
   its width in spaces, or a closed nested block. *)
and chunk =
    StringChunk of {size : int, string : string}
  | BreakChunk of int
  | BlockChunk of block;

(* Executor state: the stack of open blocks (innermost first), plus the
   indentation and current width of the line being built. *)
datatype state =
    State of
      {blocks : block list,
       lineIndent : int,
       lineSize : int};

val initialIndent = 0;

val initialStyle = Inconsistent;

(* Convert a user-facing break style into the internal block style. *)
fun liftStyle Inconsistent = InconsistentBlock
  | liftStyle Consistent = ConsistentBlock;

(* The style a block takes on once it is broken: consistent blocks switch to
   breaking every break; other styles are unchanged. *)
fun breakStyle ConsistentBlock = BreakingBlock
  | breakStyle style = style;

fun sizeBlock (Block {size = n, ...}) = n;

(* Width of a chunk: string width, break spaces, or nested block size. *)
fun sizeChunk (StringChunk {size = n, ...}) = n
  | sizeChunk (BreakChunk spaces) = spaces
  | sizeChunk (BlockChunk b) = sizeBlock b;

(* Split a chunk list at its first BreakChunk, returning the chunks before
   it (in list order) and the chunks after it, or NONE if there is no
   break.  Chunks are stored newest-first, so this is the most recent break. *)
fun splitChunks chunks =
    let
      fun go (_, []) = NONE
        | go (acc, BreakChunk _ :: rest) = SOME (List.rev acc, rest)
        | go (acc, c :: rest) = go (c :: acc, rest)
    in
      go ([], chunks)
    end;

fun sizeChunks chunks = List.foldl (fn (c,total) => sizeChunk c + total) 0 chunks;

local
  (* Append the text of oldest-first chunks to acc, flattening nested blocks
     (whose own chunks are newest-first) in place. *)
  fun render acc chunks =
      case chunks of
        [] => acc
      | StringChunk {string = s, ...} :: rest => render (acc ^ s) rest
      | BreakChunk n :: rest => render (acc ^ nSpaces n) rest
      | BlockChunk (Block {chunks = inner, ...}) :: rest =>
        render acc (List.revAppend (inner, rest));
in
  (* Render newest-first chunks as one line prefixed by indent spaces. *)
  fun renderChunks indent chunks = render (nSpaces indent) (List.rev chunks);

  fun renderChunk indent chunk = renderChunks indent [chunk];
end;

(* True iff the block has no chunks.  The size field is bound only for the
   optional BasicDebug consistency check below. *)
fun isEmptyBlock block =
    let
      val Block {indent = _, style = _, size, chunks} = block

      val empty = null chunks

(*BasicDebug
      val _ = not empty orelse size = 0 orelse
              raise Bug "Print.isEmptyBlock: bad size"
*)
    in
      empty
    end;

(* Consistency checks used for debugging.  Each block must be indented at
   least as far as its parent, and its recorded size must equal the sum of
   its chunk sizes.  Returns the size; raises Bug on a violation. *)
fun checkBlock ind block =
    let
      val Block {indent, style = _, size, chunks} = block

      val () =
          if indent >= ind then ()
          else raise Bug "Print.checkBlock: bad indents"

      val actual = checkChunks indent chunks

      val () =
          if size = actual then ()
          else raise Bug "Print.checkBlock: wrong size"
    in
      size
    end

and checkChunks ind chunks =
    List.foldl (fn (chunk,total) => total + checkChunk ind chunk) 0 chunks

and checkChunk ind chunk =
    case chunk of
      BlockChunk block => checkBlock ind block
    | BreakChunk spaces => spaces
    | StringChunk {size,...} => size;

(* Check a block stack (innermost first) from the outermost block inwards,
   each against its enclosing block's indent, and return the summed sizes. *)
fun checkBlocks blocks =
    let
      fun check (block, (ind, total)) =
          let
            val Block {indent,...} = block
          in
            (indent, total + checkBlock ind block)
          end

      val (_, total) = List.foldl check (initialIndent, 0) (List.rev blocks)
    in
      total
    end;

(* The outermost block: no indentation, initial style, no content. *)
val initialBlock =
    Block
      {indent = initialIndent,
       style = liftStyle initialStyle,
       size = 0,
       chunks = []};

(* Executor state before any step: only the outermost block, empty line. *)
val initialState =
    State
      {blocks = [initialBlock],
       lineIndent = initialIndent,
       lineSize = 0};

(*BasicDebug
fun checkState state =
    (let
       val State {blocks, lineIndent = _, lineSize} = state
       val lineSize' = checkBlocks blocks
       val _ = lineSize = lineSize' orelse
               raise Error "wrong lineSize"
     in
       ()
     end
     handle Error err => raise Bug err)
    handle Bug bug => raise Bug ("Print.checkState: " ^ bug);
*)

(* A state is final when only the outermost block remains and it is empty. *)
fun isFinalState (State {blocks, ...}) =
    case blocks of
      [] => raise Bug "Print.isFinalState: no block"
    | [block] => isEmptyBlock block
    | _ => false;

local
  (* Render one broken line and push it onto the newest-first output. *)
  fun renderBreak lineIndent (chunks,lines) =
      renderChunks lineIndent chunks :: lines;

  (* Render a newest-first list of broken lines.  The oldest is rendered at
     the current line indent and the rest at the new indent lineIndent'. *)
  fun renderBreaks lineIndent lineIndent' breaks lines =
      case List.rev breaks of
        [] => raise Bug "Print.renderBreaks"
      | first :: rest =>
        let
          val afterFirst = renderBreak lineIndent (first,lines)
        in
          List.foldl (renderBreak lineIndent') afterFirst rest
        end;

  (* Split chunks at every break.  The final (oldest) piece also carries
     the cumulating chunks of enclosing blocks. *)
  fun splitAllChunks cumulatingChunks chunks =
      case splitChunks chunks of
        NONE => [List.concat (chunks :: cumulatingChunks)]
      | SOME (prefix,rest) => prefix :: splitAllChunks cumulatingChunks rest;

  (* Break a block's chunks at its most recent break, if it has one.
     Returns the lines to emit and the chunks kept on the new line.  An
     inconsistent block emits everything before that break as one line;
     other styles break at every earlier break too. *)
  fun mkBreak style cumulatingChunks chunks =
      let
        fun breaksOf broken =
            if style = InconsistentBlock then
              [List.concat (broken :: cumulatingChunks)]
            else
              splitAllChunks cumulatingChunks broken
      in
        Option.map (fn (kept,broken) => (breaksOf broken, kept))
          (splitChunks chunks)
      end;

  (* Look for a break chunk already present in the open blocks.  The
     recursion handles the tail (outer blocks) first, so the outermost block
     containing a break is the one broken.  Outer blocks without a break are
     emptied and their chunks collected as cumulatingChunks, to be emitted
     as part of the broken line.
     Returns Left (breaks, blocks, lineIndent, lineSize) on success, or
     Right (cumulatingChunks, emptiedBlocks) when no block has a break. *)
  fun naturalBreak blocks =
      case blocks of
        [] => Right ([],[])
      | block :: blocks =>
        case naturalBreak blocks of
          Left (breaks,blocks,lineIndent,lineSize) =>
          let
            (* An outer block already broke: this block stays on the new
               line unchanged and contributes its full size. *)
            val Block {size,...} = block

            val blocks = block :: blocks

            val lineSize = lineSize + size
          in
            Left (breaks,blocks,lineIndent,lineSize)
          end
        | Right (cumulatingChunks,blocks) =>
          let
            val Block {indent,style,size,chunks} = block

            val style = breakStyle style
          in
            case mkBreak style cumulatingChunks chunks of
              SOME (breaks,chunks) =>
              let
                val size = sizeChunks chunks

                val block =
                    Block
                      {indent = indent,
                       style = style,
                       size = size,
                       chunks = chunks}

                val blocks = block :: blocks

                (* The new line starts at this block's indentation. *)
                val lineIndent = indent

                val lineSize = size
              in
                Left (breaks,blocks,lineIndent,lineSize)
              end
            | NONE =>
              let
                (* No break here: move this block's chunks into the pending
                   line and leave the block empty. *)
                val cumulatingChunks = chunks :: cumulatingChunks

                val size = 0

                val chunks = []

                val block =
                    Block
                      {indent = indent,
                       style = style,
                       size = size,
                       chunks = chunks}

                val blocks = block :: blocks
              in
                Right (cumulatingChunks,blocks)
              end
          end;

  (* Try to break inside a block: first at its own breaks (via mkBreak),
     otherwise inside a nested block among its chunks.  cumulatingChunks are
     older chunks that belong on the emitted line.  Returns
     SOME (breaks, newBlock, lineIndent, lineSize) or NONE. *)
  fun forceBreakBlock cumulatingChunks block =
      let
        val Block {indent, style, size = _, chunks} = block

        val style = breakStyle style

        val break =
            case mkBreak style cumulatingChunks chunks of
              SOME (breaks,chunks) =>
              let
                val lineSize = sizeChunks chunks
                val lineIndent = indent
              in
                SOME (breaks,chunks,lineIndent,lineSize)
              end
            | NONE => forceBreakChunks cumulatingChunks chunks
      in
        case break of
          SOME (breaks,chunks,lineIndent,lineSize) =>
          let
            (* The rebuilt block holds exactly what stays on the new line. *)
            val size = lineSize

            val block =
                Block
                  {indent = indent,
                   style = style,
                   size = size,
                   chunks = chunks}
          in
            SOME (breaks,block,lineIndent,lineSize)
          end
        | NONE => NONE
      end

  (* Try each chunk, newest first.  When breaking inside a chunk, the older
     sibling chunks are passed on as cumulating chunks, so only the broken
     chunk remains on the new line. *)
  and forceBreakChunks cumulatingChunks chunks =
      case chunks of
        [] => NONE
      | chunk :: chunks =>
        case forceBreakChunk (chunks :: cumulatingChunks) chunk of
          SOME (breaks,chunk,lineIndent,lineSize) =>
          let
            val chunks = [chunk]
          in
            SOME (breaks,chunks,lineIndent,lineSize)
          end
        | NONE =>
          case forceBreakChunks cumulatingChunks chunks of
            SOME (breaks,chunks,lineIndent,lineSize) =>
            let
              (* The newer chunk stays after the break and adds to the line. *)
              val chunks = chunk :: chunks

              val lineSize = lineSize + sizeChunk chunk
            in
              SOME (breaks,chunks,lineIndent,lineSize)
            end
          | NONE => NONE

  (* Only nested blocks can be broken.  Break chunks are not expected here:
     callers reach this only after mkBreak found no break at this level. *)
  and forceBreakChunk cumulatingChunks chunk =
      case chunk of
        StringChunk _ => NONE
      | BreakChunk _ => raise Bug "Print.forceBreakChunk: BreakChunk"
      | BlockChunk block =>
        case forceBreakBlock cumulatingChunks block of
          SOME (breaks,block,lineIndent,lineSize) =>
          let
            val chunk = BlockChunk block
          in
            SOME (breaks,chunk,lineIndent,lineSize)
          end
        | NONE => NONE;

  (* Force a break inside a closed nested block when no open block has a
     break of its own.  The inputs are the Right results of naturalBreak:
     cumulatingChunks and blocks' run parallel to blocks (innermost first),
     holding each block's chunks and its emptied copy.  The innermost block
     is tried first.  If it breaks, the outer blocks are replaced by their
     emptied copies, because their chunks go onto the emitted line.  If not,
     the block is kept whole and the search moves outwards. *)
  fun forceBreak cumulatingChunks blocks' blocks =
      case blocks of
        [] => NONE
      | block :: blocks =>
        let
          (* Drop this block's own entries; the remainder describes the
             enclosing blocks. *)
          val cumulatingChunks =
              case cumulatingChunks of
                [] => raise Bug "Print.forceBreak: null cumulatingChunks"
              | _ :: cumulatingChunks => cumulatingChunks

          val blocks' =
              case blocks' of
                [] => raise Bug "Print.forceBreak: null blocks'"
              | _ :: blocks' => blocks'
        in
          case forceBreakBlock cumulatingChunks block of
            SOME (breaks,block,lineIndent,lineSize) =>
            let
              val blocks = block :: blocks'
            in
              SOME (breaks,blocks,lineIndent,lineSize)
            end
          | NONE =>
            case forceBreak cumulatingChunks blocks' blocks of
              SOME (breaks,blocks,lineIndent,lineSize) =>
              let
                val blocks = block :: blocks

                val Block {size,...} = block

                val lineSize = lineSize + size
              in
                SOME (breaks,blocks,lineIndent,lineSize)
              end
            | NONE => NONE
        end;

  (* While the current line overflows lineLength, break it, preferring a
     natural break and falling back to a forced one, and add the rendered
     lines to lines (newest first).  If no break is possible the overlong
     line is left as is. *)
  fun normalize lineLength lines state =
      let
        val State {blocks,lineIndent,lineSize} = state
      in
        if lineIndent + lineSize <= lineLength then (lines,state)
        else
          let
            val break =
                case naturalBreak blocks of
                  Left break => SOME break
                | Right (c,b) => forceBreak c b blocks
          in
            case break of
              SOME (breaks,blocks,lineIndent',lineSize) =>
              let
                val lines = renderBreaks lineIndent lineIndent' breaks lines

                val state =
                    State
                      {blocks = blocks,
                       lineIndent = lineIndent',
                       lineSize = lineSize}
              in
                normalize lineLength lines state
              end
            | NONE => (lines,state)
          end
      end;

(*BasicDebug
  val normalize = fn lineLength => fn lines => fn state =>
      let
        val () = checkState state
      in
        normalize lineLength lines state
      end
      handle Bug bug =>
        raise Bug ("Print.normalize: before normalize:\n" ^ bug)
*)

  (* Open a new empty block, indented ind further than the enclosing block.
     The current line is unchanged. *)
  fun executeBeginBlock (style,ind) lines state =
      let
        val State {blocks,lineIndent,lineSize} = state

        val parentIndent =
            case blocks of
              Block {indent,...} :: _ => indent
            | [] => raise Bug "Print.executeBeginBlock: no block"

        val newBlock =
            Block
              {indent = parentIndent + ind,
               style = liftStyle style,
               size = 0,
               chunks = []}
      in
        (lines,
         State
           {blocks = newBlock :: blocks,
            lineIndent = lineIndent,
            lineSize = lineSize})
      end;

  (* Close the innermost block and fold it, as a BlockChunk, into its
     enclosing block.  An empty block is simply discarded.  A trailing break
     in the closing block is dropped and its spaces are removed from the
     line and block sizes.  The outermost block has no enclosing block, so
     an unbalanced EndBlock raises Bug at once.  Previously an empty root
     was popped silently, leaving an empty block stack that failed later
     with a misleading message. *)
  fun executeEndBlock lines state =
      let
        val State {blocks,lineIndent,lineSize} = state

        val (topBlock,parentBlock,outerBlocks) =
            case blocks of
              topBlock :: parentBlock :: outerBlocks =>
              (topBlock,parentBlock,outerBlocks)
            | _ => raise Bug "Print.executeEndBlock: no block"

        val Block
              {indent = topIndent,
               style = topStyle,
               size = topSize,
               chunks = topChunks} = topBlock

        val (lineSize,blocks) =
            case topChunks of
              [] => (lineSize, parentBlock :: outerBlocks)
            | headTopChunks :: tailTopChunks =>
              let
                (* Chunks are newest-first, so the head is the last chunk
                   added to the block. *)
                val (lineSize,topSize,topChunks) =
                    case headTopChunks of
                      BreakChunk spaces =>
                      (lineSize - spaces, topSize - spaces, tailTopChunks)
                    | _ => (lineSize,topSize,topChunks)

                val topBlock =
                    Block
                      {indent = topIndent,
                       style = topStyle,
                       size = topSize,
                       chunks = topChunks}

                val Block {indent,style,size,chunks} = parentBlock

                val parentBlock =
                    Block
                      {indent = indent,
                       style = style,
                       size = size + topSize,
                       chunks = BlockChunk topBlock :: chunks}
              in
                (lineSize, parentBlock :: outerBlocks)
              end

        val state =
            State
              {blocks = blocks,
               lineIndent = lineIndent,
               lineSize = lineSize}
      in
        (lines,state)
      end;

  (* Append a string chunk to the innermost block, grow the line by its
     width, and then re-break the line if it now overflows. *)
  fun executeAddString lineLength s lines state =
      let
        val State {blocks,lineIndent,lineSize} = state

        val n = size s

        val (top,rest) =
            case blocks of
              [] => raise Bug "Print.executeAddString: no block"
            | top :: rest => (top,rest)

        val Block {indent,style,size = topSize,chunks} = top

        val top' =
            Block
              {indent = indent,
               style = style,
               size = topSize + n,
               chunks = StringChunk {size = n, string = s} :: chunks}

        val state' =
            State
              {blocks = top' :: rest,
               lineIndent = lineIndent,
               lineSize = lineSize + n}
      in
        normalize lineLength lines state'
      end;

  (* Add a break of the given width to the innermost block, then re-break
     the line if it now overflows.  A break at the very start of a block
     (no chunks yet) is dropped.  A break directly after another break is
     merged into it.  In a BreakingBlock every break is widened to
     lineLength + 1 spaces, so it can never fit and will always be taken. *)
  fun executeAddBreak lineLength spaces lines state =
      let
        val State {blocks,lineIndent,lineSize} = state

        val (blocks,lineSize) =
            case blocks of
              [] => raise Bug "Print.executeAddBreak: no block"
            | Block {indent,style,size,chunks} :: blocks' =>
              case chunks of
                [] => (blocks,lineSize)
              | chunk :: chunks' =>
                let
                  val spaces =
                      case style of
                        BreakingBlock => lineLength + 1
                      | _ => spaces

                  val size = size + spaces

                  val chunks =
                      case chunk of
                        BreakChunk k => BreakChunk (k + spaces) :: chunks'
                      | _ => BreakChunk spaces :: chunks

                  val block =
                      Block
                        {indent = indent,
                         style = style,
                         size = size,
                         chunks = chunks}

                  val blocks = block :: blocks'

                  val lineSize = lineSize + spaces
                in
                  (blocks,lineSize)
                end

        val state =
            State
              {blocks = blocks,
               lineIndent = lineIndent,
               lineSize = lineSize}
      in
        normalize lineLength lines state
      end;

  (* Adds a break wider than the whole line (lineLength + 1, the same width
     executeAddBreak uses inside a BreakingBlock). *)
  fun executeBigBreak lineLength lines state =
      let
        val forcedWidth = lineLength + 1
      in
        executeAddBreak lineLength forcedWidth lines state
      end;

  (* A newline is an empty string, then a big break, then another empty
     string.  The line list and state are threaded through all three steps. *)
  fun executeAddNewline lineLength lines0 state0 =
      let
        val (lines1,state1) = executeAddString lineLength "" lines0 state0
        val (lines2,state2) = executeBigBreak lineLength lines1 state1
      in
        executeAddString lineLength "" lines2 state2
      end;

  (* Flushes pending output.  Unless the state is already final, a big break
     is issued first.  The collected lines are returned as a single stream
     element, or as an empty stream when there are none. *)
  fun final lineLength lines state =
      let
        fun flush () =
            let
              val (flushed,state) = executeBigBreak lineLength lines state

(*BasicDebug
              val _ = isFinalState state orelse raise Bug "Print.final"
*)
            in
              flushed
            end

        val output = if isFinalState state then lines else flush ()
      in
        case output of
          [] => Stream.Nil
        | _ :: _ => Stream.singleton output
      end;
in
  (* Runs a stream of pretty-printing steps against the layout state,
     starting from initialState and finishing with final.  Each step starts
     from an empty line list, and the produced line lists are combined with
     revConcat. *)
  fun execute {lineLength} =
      let
        fun advance step state =
            case step of
              BeginBlock styleIndent => executeBeginBlock styleIndent [] state
            | EndBlock => executeEndBlock [] state
            | AddString str => executeAddString lineLength str [] state
            | AddBreak width => executeAddBreak lineLength width [] state
            | AddNewline => executeAddNewline lineLength [] state

(*BasicDebug
        val advance = fn step => fn state =>
            let
              val (lines,state) = advance step state
              val () = checkState state
            in
              (lines,state)
            end
            handle Bug bug =>
              raise Bug ("Print.advance: after " ^ ppStepToString step ^
                         " command:\n" ^ bug)
*)
      in
        revConcat o Stream.maps advance (final lineLength []) initialState
      end;
end;

(* ------------------------------------------------------------------------- *)
(* Executing pretty-printers with a global line length.                      *)
(* ------------------------------------------------------------------------- *)

val lineLength = Unsynchronized.ref initialLineLength;

(* Pretty-print a with ppA at the current global line length, producing a
   stream of output lines, each terminated by a newline. *)
fun toStream ppA a =
    let
      val rawLines = execute {lineLength = !lineLength} (ppA a)
    in
      Stream.map (fn line => line ^ "\n") rawLines
    end;

(* Pretty-print a with ppA at the current global line length, joining the
   output lines with newlines and adding no trailing newline.  An empty
   output gives the empty string.  The lines are collected first and then
   joined once.  Previously each line was appended with left-nested ^, which
   recopied the whole accumulated string for every line. *)
fun toString ppA a =
    let
      val lines = execute {lineLength = !lineLength} (ppA a)

      (* Stream.foldl passes (element, accumulator), so op :: builds the
         line list in reverse. *)
      val revLines = Stream.foldl (op ::) [] lines
    in
      String.concatWith "\n" (rev revLines)
    end;

(* Emit "nameX = x", pretty-printed with ppX, through Useful.trace. *)
fun trace ppX nameX x =
    let
      val text = toString (ppOp2 " =" ppString ppX) (nameX,x)
    in
      Useful.trace (text ^ "\n")
    end;

end
end;

(**** Original file: Parse.sig ****)

(* ========================================================================= *)
(* PARSING                                                                   *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

signature Parse =
sig

(* ------------------------------------------------------------------------- *)
(* A "cannot parse" exception.                                               *)
(* ------------------------------------------------------------------------- *)

(* Raised by any parser that does not accept its input. *)
exception NoParse

(* ------------------------------------------------------------------------- *)
(* Recursive descent parsing combinators.                                    *)
(* ------------------------------------------------------------------------- *)

(*
  Recommended fixities:

  infixr 9 >>++
  infixr 8 ++
  infixr 7 >>
  infixr 6 ||
*)

(* A parser maps its input to a result paired with the remaining input, or
   raises NoParse. *)

(* The parser that rejects every input. *)
val error : 'a -> 'b * 'a

(* Sequencing: run the first parser, run the second on the input it leaves,
   and pair the two results. *)
val ++ : ('a -> 'b * 'a) * ('a -> 'c * 'a) -> 'a -> ('b * 'c) * 'a

(* Apply a function to a parser's result. *)
val >> : ('a -> 'b * 'a) * ('b -> 'c) -> 'a -> 'c * 'a

(* The first parser's result selects the parser that runs next. *)
val >>++ : ('a -> 'b * 'a) * ('b -> 'a -> 'c * 'a) -> 'a -> 'c * 'a

(* Alternation: if the first parser raises NoParse, run the second on the
   original input. *)
val || : ('a -> 'b * 'a) * ('a -> 'b * 'a) -> 'a -> 'b * 'a

(* Result of the first parser in the list that succeeds; NoParse if none. *)
val first : ('a -> 'b * 'a) list -> 'a -> 'b * 'a

(* Thread a state through the parser until it raises NoParse, then return
   the last state and the input left at that point. *)
val mmany : ('s -> 'a -> 's * 'a) -> 's -> 'a -> 's * 'a

(* Zero or more repetitions, with results in input order. *)
val many : ('a -> 'b * 'a) -> 'a -> 'b list * 'a

(* One or more repetitions. *)
val atLeastOne : ('a -> 'b * 'a) -> 'a -> 'b list * 'a

(* Consume nothing and return (). *)
val nothing : 'a -> unit * 'a

(* SOME of the parser's result, or NONE without consuming input. *)
val optional : ('a -> 'b * 'a) -> 'a -> 'b option * 'a

(* ------------------------------------------------------------------------- *)
(* Stream-based parsers.                                                     *)
(* ------------------------------------------------------------------------- *)

type ('a,'b) parser = 'a Metis.Stream.stream -> 'b * 'a Metis.Stream.stream

(* Consume one element for which the function returns SOME result. *)
val maybe : ('a -> 'b option) -> ('a,'b) parser

(* Succeed only at the end of the stream. *)
val finished : ('a,unit) parser

(* Consume one element that satisfies the predicate. *)
val some : ('a -> bool) -> ('a,'a) parser

(* Consume any single element. *)
val any : ('a,'a) parser

(* ------------------------------------------------------------------------- *)
(* Parsing whole streams.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Run a parser that must consume the entire stream; NoParse otherwise. *)
val fromStream : ('a,'b) parser -> 'a Metis.Stream.stream -> 'b

(* As fromStream, on the elements of a list. *)
val fromList : ('a,'b) parser -> 'a list -> 'b

(* Apply the parser repeatedly and lazily concatenate the result lists.  A
   NoParse at the end of input ends the stream; elsewhere it propagates. *)
val everything : ('a, 'b list) parser -> 'a Metis.Stream.stream -> 'b Metis.Stream.stream

(* ------------------------------------------------------------------------- *)
(* Parsing lines of text.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Turn a stream of lines into a stream of character lists.  The most
   recently read lines are remembered so that parseErrorLocation can
   describe how far reading got. *)
val initialize :
    {lines : string Metis.Stream.stream} ->
    {chars : char list Metis.Stream.stream,
     parseErrorLocation : unit -> string}

(* Match exactly the given character, character list or string. *)
val exactChar : char -> (char,unit) parser

val exactCharList : char list -> (char,unit) parser

val exactString : string -> (char,unit) parser

(* Parse a run of ordinary characters and backslash escapes.  A backslash
   may be followed by a backslash, n, t or one of the escape characters.
   The run stops at the first character that fits neither case, such as an
   unescaped escape character, newline or tab. *)
val escapeString : {escape : char list} -> (char,string) parser

(* Zero or more whitespace characters (Char.isSpace). *)
val manySpace : (char,unit) parser

(* One or more whitespace characters (Char.isSpace). *)
val atLeastOneSpace : (char,unit) parser

(* Run a parser that must consume the whole string. *)
val fromString : (char,'a) parser -> string -> 'a

(* ------------------------------------------------------------------------- *)
(* Infix operators.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Build an expression parser for the given infix operators.  Operands come
   from the supplied subparser and are combined with the constructor, which
   receives the operator token and the two operands. *)
val parseInfixes :
    Metis.Print.infixes -> (string * 'a * 'a -> 'a) -> (string,'a) parser ->
    (string,'a) parser

(* ------------------------------------------------------------------------- *)
(* Quotations.                                                               *)
(* ------------------------------------------------------------------------- *)

type 'a quotation = 'a Metis.frag list

(* Render a quotation to a string, printing antiquotations with the given
   function, and pass that string to the parser. *)
val parseQuotation : ('a -> string) -> (string -> 'b) -> 'a quotation -> 'b

end

(**** Original file: Parse.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* PARSING                                                                   *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Parse :> Parse =
struct

open Useful;

infixr 9 >>++
infixr 8 ++
infixr 7 >>
infixr 6 ||

(* ------------------------------------------------------------------------- *)
(* A "cannot parse" exception.                                               *)
(* ------------------------------------------------------------------------- *)

exception NoParse;

(* ------------------------------------------------------------------------- *)
(* Recursive descent parsing combinators.                                    *)
(* ------------------------------------------------------------------------- *)

(* The parser that rejects every input. *)
fun error (_ : 'a) : 'b * 'a = raise NoParse;

(* Run p1, then run p2 on the input p1 leaves, and pair their results. *)
fun op ++ (p1,p2) input =
    let
      val (x,afterFirst) = p1 input
      val (y,afterSecond) = p2 afterFirst
    in
      ((x,y),afterSecond)
    end;

(* Apply f to the parser's result; the remaining input is unchanged. *)
fun op >> (parser : 'a -> 'b * 'a, f) input =
    case parser input of
      (result,rest) => (f result, rest);

(* The parser's result selects the parser that continues on the rest. *)
fun op >>++ (parser,continue) input =
    case parser input of
      (result,rest) => continue result rest;

(* Try p1; if it raises NoParse, run p2 on the same input instead. *)
fun op || (p1,p2) input =
    (p1 input) handle NoParse => p2 input;

(* Result of the first parser in the list that succeeds. *)
fun first parsers input =
    case parsers of
      [] => raise NoParse
    | p :: ps => (p input handle NoParse => first ps input);

(* Thread a state through the parser repeatedly until it raises NoParse,
   then return the last state and the input left at that point. *)
fun mmany parser =
    let
      fun loop state input =
          case (SOME (parser state input) handle NoParse => NONE) of
            NONE => (state,input)
          | SOME (state',input') => loop state' input'
    in
      loop
    end;

(* Zero or more occurrences.  Results are accumulated in reverse and then
   flipped back into input order. *)
fun many parser =
    mmany (fn acc => parser >> (fn x => x :: acc)) [] >> rev;

(* One or more occurrences. *)
fun atLeastOne p = (p ++ many p) >> (fn (x,xs) => x :: xs);

(* Consume nothing. *)
fun nothing inp = ((), inp);

(* SOME result if p succeeds, otherwise NONE without consuming input. *)
fun optional p input =
    (case p input of (x,rest) => (SOME x, rest))
    handle NoParse => (NONE, input);

(* ------------------------------------------------------------------------- *)
(* Stream-based parsers.                                                     *)
(* ------------------------------------------------------------------------- *)

type ('a,'b) parser = 'a Stream.stream -> 'b * 'a Stream.stream

(* Consume the first element when f maps it to SOME result. *)
fun maybe f input =
    case input of
      Stream.Nil => raise NoParse
    | Stream.Cons (h,t) =>
      (case f h of
         NONE => raise NoParse
       | SOME r => (r, t ()));

(* Succeed, consuming nothing, only on an empty stream. *)
fun finished input =
    case input of
      Stream.Nil => ((), Stream.Nil)
    | Stream.Cons _ => raise NoParse;

(* Consume the first element when it satisfies pred. *)
fun some pred = maybe (fn x => if pred x then SOME x else NONE);

(* Consume any single element. *)
fun any input = some (fn _ => true) input;

(* ------------------------------------------------------------------------- *)
(* Parsing whole streams.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Run the parser and insist that it consumes the entire stream. *)
fun fromStream parser input =
    let
      val (result,rest) = parser input
      val ((),_) = finished rest
    in
      result
    end;

fun fromList parser l = fromStream parser (Stream.fromList l);

(* Apply the parser repeatedly, lazily concatenating the lists it returns.
   NoParse on an empty remainder ends the stream; elsewhere it propagates. *)
fun everything parser =
    let
      fun chunks input =
          let
            val next =
                SOME (parser input)
                handle e as NoParse =>
                  if Stream.null input then NONE else raise e
          in
            case next of
              NONE => Stream.Nil
            | SOME (items,rest) =>
              Stream.append (Stream.fromList items) (fn () => chunks rest)
          end
    in
      chunks
    end;

(* ------------------------------------------------------------------------- *)
(* Parsing lines of text.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Turn a stream of lines into a memoized stream of character lists.  The
   history ref records the index of the latest line read (starting at ~1)
   and the three most recent lines, so that parseErrorLocation can report
   how far reading got. *)
fun initialize {lines} =
    let
      val history = Unsynchronized.ref (~1,"","","")

      fun record line =
          let
            val (n,_,l2,l3) = !history
          in
            history := (n + 1, l2, l3, line);
            explode line
          end

      val chars = Stream.memoize (Stream.map record lines)

      fun parseErrorLocation () =
          let
            val (n,l1,l2,l3) = !history

            val location =
                if n <= 0 then "at start"
                else "around line " ^ Int.toString n
          in
            location ^ chomp (":\n" ^ l1 ^ l2 ^ l3)
          end
    in
      {chars = chars,
       parseErrorLocation = parseErrorLocation}
    end;

(* Match one specific character. *)
fun exactChar (c : char) = some (fn d => d = c) >> (fn _ => ());

(* Match the characters of the list in sequence. *)
fun exactCharList cs =
    foldr (fn (c,rest) => (exactChar c ++ rest) >> snd) nothing cs;

fun exactString s = exactCharList (explode s);

(* Parse a run of ordinary characters and backslash escapes.  Backslash,
   newline, tab and the escape characters are not ordinary.  A backslash
   may be followed by a backslash, n, t or an escape character. *)
fun escapeString {escape} =
    let
      fun isEscape c = mem c escape

      fun isNormal c =
          not (c = #"\\" orelse c = #"\n" orelse c = #"\t" orelse isEscape c)

      val escapeParser =
          first
            [exactChar #"\\" >> (fn () => #"\\"),
             exactChar #"n" >> (fn () => #"\n"),
             exactChar #"t" >> (fn () => #"\t"),
             some isEscape]

      val charParser =
          (exactChar #"\\" >>++ (fn () => escapeParser)) || some isNormal
    in
      many charParser >> implode
    end;

local
  val space = some Char.isSpace;
in
  val manySpace = many space >> (fn _ => ());

  val atLeastOneSpace = atLeastOne space >> (fn _ => ());
end;

fun fromString parser s = fromList parser (explode s);

(* ------------------------------------------------------------------------- *)
(* Infix operators.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Parse an operand.  While the next token is in toks, consume it and
   continue, folding the operands into the accumulated finisher with
   update.  sof finishes the last operand parsed. *)
fun parseGenInfix update sof toks parse inp =
    let
      val (e,rest) = parse inp
    in
      case rest of
        Stream.Cons (tok,tail) =>
        if StringSet.member tok toks then
          parseGenInfix update (update sof tok e) toks parse (tail ())
        else (sof e, rest)
      | Stream.Nil => (sof e, rest)
    end;

local
  (* Collect the token strings of one precedence layer into a set. *)
  fun addToken ({leftSpaces = _, token, rightSpaces = _}, set) =
      StringSet.add set token;

  (* Choose left- or right-nesting of con for one layer. *)
  fun parse leftAssoc toks con =
      let
        fun leftUpdate f t a b = con (t, f a, b)

        fun rightUpdate f t a b = f (con (t, a, b))

        val update = if leftAssoc then leftUpdate else rightUpdate
      in
        parseGenInfix update I toks
      end;
in
  fun parseLayeredInfixes {tokens,leftAssoc} =
      parse leftAssoc (List.foldl addToken StringSet.empty tokens);
end;

(* Stack one layered parser per precedence layer on top of the operand
   subparser, in the order Print.layerInfixes returns the layers. *)
fun parseInfixes ops =
    let
      val layers = List.map parseLayeredInfixes (Print.layerInfixes ops)

      fun build con subparser =
          foldl (fn (layer,acc) => layer con acc) subparser layers
    in
      build
    end;

(* ------------------------------------------------------------------------- *)
(* Quotations.                                                               *)
(* ------------------------------------------------------------------------- *)

type 'a quotation = 'a frag list;

(* Flatten a quotation into one string, rendering each antiquotation with
   printer (left to right), and run parser on the result. *)
fun parseQuotation printer parser quote =
    let
      fun fragText (QUOTE q) = q
        | fragText (ANTIQUOTE a) = printer a
    in
      parser (String.concat (map fragText quote))
    end;

end
end;

(**** Original file: Options.sig ****)

(* ========================================================================= *)
(* PROCESSING COMMAND LINE OPTIONS                                           *)
(* Copyright (c) 2003-2004 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Options =
sig

(* ------------------------------------------------------------------------- *)
(* Option processors take an option with its associated arguments.           *)
(* ------------------------------------------------------------------------- *)

(* Receives the switch as written together with the arguments that follow
   it. *)
type proc = string * string list -> unit

(* Given f, which turns a value of type 'x into the processor for the
   remaining arguments, and p, which turns the parsed 'a into 'x, yield a
   processor. *)
type ('a,'x) mkProc = ('x -> proc) -> ('a -> 'x) -> proc

(* ------------------------------------------------------------------------- *)
(* One command line option: names, arguments, description and a processor.   *)
(* ------------------------------------------------------------------------- *)

(* switches are the names selecting the option; arguments are the names of
   the arguments it takes, as shown in the usage text. *)
type opt =
     {switches : string list, arguments : string list,
      description : string, processor : proc}

(* ------------------------------------------------------------------------- *)
(* Option processors may raise an OptionExit exception.                      *)
(* ------------------------------------------------------------------------- *)

(* message goes to stderr; usage also prints the usage text; success selects
   the process exit status. *)
type optionExit = {message : string option, usage : bool, success : bool}

exception OptionExit of optionExit

(* ------------------------------------------------------------------------- *)
(* Constructing option processors.                                           *)
(* ------------------------------------------------------------------------- *)

(* Passes the switch itself to p. *)
val beginOpt : (string,'x) mkProc

(* Requires that no arguments remain (an internal Bug otherwise). *)
val endOpt : unit -> proc

(* Consumes one string argument. *)
val stringOpt : (string,'x) mkProc

(* Consumes one integer argument within the optional inclusive bounds. *)
val intOpt : int option * int option -> (int,'x) mkProc

(* Consumes one real argument within the optional inclusive bounds. *)
val realOpt : real option * real option -> (real,'x) mkProc

(* Consumes one argument that must be one of the listed choices. *)
val enumOpt : string list -> (string,'x) mkProc

(* If the next argument equals the given string, passes NONE; otherwise
   parses it with the inner mkProc and wraps the value in SOME. *)
val optionOpt : string * ('a,'x) mkProc -> ('a option,'x) mkProc

(* ------------------------------------------------------------------------- *)
(* Basic options useful for all programs.                                    *)
(* ------------------------------------------------------------------------- *)

(* The --, help (-?, -h, --help) and version (-v, --version) options. *)
val basicOptions : opt list

(* ------------------------------------------------------------------------- *)
(* All the command line options of a program.                                *)
(* ------------------------------------------------------------------------- *)

type allOptions =
     {name : string, version : string, header : string,
      footer : string, options : opt list}

(* ------------------------------------------------------------------------- *)
(* Usage information.                                                        *)
(* ------------------------------------------------------------------------- *)

(* The version field. *)
val versionInformation : allOptions -> string

(* header, an aligned table of the options, then footer. *)
val usageInformation : allOptions -> string

(* ------------------------------------------------------------------------- *)
(* Exit the program gracefully.                                              *)
(* ------------------------------------------------------------------------- *)

(* Report the message and usage on stderr as requested, then terminate the
   process with the success or failure status. *)
val exit : allOptions -> optionExit -> 'exit

val succeed : allOptions -> 'exit

(* Exit with failure after printing the message. *)
val fail : allOptions -> string -> 'exit

(* As fail, but also print the usage text. *)
val usage : allOptions -> string -> 'exit

(* Print the version information to stdout and exit successfully. *)
val version : allOptions -> 'exit

(* ------------------------------------------------------------------------- *)
(* Process the command line options passed to the program.                   *)
(* ------------------------------------------------------------------------- *)

(* Run the processors of the leading options.  Returns the option words
   consumed (switches followed by their arguments) and the remaining
   arguments.  Exits the program on -v/--version and on OptionExit. *)
val processOptions : allOptions -> string list -> string list * string list

end

(**** Original file: Options.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* PROCESSING COMMAND LINE OPTIONS                                           *)
(* Copyright (c) 2003-2004 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Options :> Options =
struct

infix ##

open Useful;

(* ------------------------------------------------------------------------- *)
(* One command line option: names, arguments, description and a processor    *)
(* ------------------------------------------------------------------------- *)

(* A processor receives the switch and the arguments following it. *)
type proc = string * string list -> unit;

(* f builds the processor for the rest from an 'x; p converts the parsed 'a
   into that 'x. *)
type ('a,'x) mkProc = ('x -> proc) -> ('a -> 'x) -> proc;

type opt = {switches : string list, arguments : string list,
            description : string, processor : proc};

(* ------------------------------------------------------------------------- *)
(* Option processors may raise an OptionExit exception                       *)
(* ------------------------------------------------------------------------- *)

(* Consumed by exit: optional stderr message, whether to print usage, and
   which exit status to use. *)
type optionExit = {message : string option, usage : bool, success : bool};

exception OptionExit of optionExit;

(* ------------------------------------------------------------------------- *)
(* Wrappers for option processors                                            *)
(* ------------------------------------------------------------------------- *)

(* Feed the switch name itself to p, leaving all arguments for f. *)
fun beginOpt f p (switch : string, args : string list) : unit =
    f (p switch) (switch, args);

(* Every argument must already have been consumed. *)
fun endOpt () (_ : string, args : string list) : unit =
    case args of
      [] => ()
    | _ :: _ => raise Bug "endOpt";

(* Consume exactly one string argument. *)
fun stringOpt f p (switch : string, args : string list) : unit =
    case args of
      [] => raise Bug "stringOpt"
    | arg :: rest => f (p arg) (switch, rest);

local
  (* Describe the admissible integer range in set notation. *)
  fun rangeText lo hi =
      case (lo,hi) of
        (NONE, NONE) => "Z"
      | (SOME i, NONE) => "{n IN Z | " ^ Int.toString i ^ " <= n}"
      | (NONE, SOME j) => "{n IN Z | n <= " ^ Int.toString j ^ "}"
      | (SOME i, SOME j) =>
        "{n IN Z | " ^ Int.toString i ^ " <= n <= " ^ Int.toString j ^ "}";

  (* An absent bound is always satisfied. *)
  fun inBounds lo hi (i : int) =
      (case lo of NONE => true | SOME l => l <= i) andalso
      (case hi of NONE => true | SOME h => i <= h);

  (* Parse x as the integer argument of option arg and check the bounds.
     Failures raise OptionExit without usage. *)
  fun argToInt arg lo hi x =
      let
        fun reject msg =
            raise OptionExit {success = false, usage = false, message = SOME msg}
      in
        case Int.fromString x of
          NONE =>
          reject (arg ^ " option needs an integer argument (not \"" ^ x ^ "\")")
        | SOME i =>
          if inBounds lo hi i then i
          else
            reject (arg ^ " option needs an integer argument in the range "
                    ^ rangeText lo hi ^ " (not " ^ x ^ ")")
      end
      handle Overflow =>
        raise OptionExit
          {success = false, usage = false, message =
           SOME (arg ^ " option suffered integer overflow on argument " ^ x)};
in
  (* Processor for an option taking one integer within (omin,omax). *)
  fun intOpt (omin,omax) f p (s : string, args : string list) : unit =
      case args of
        [] => raise Bug "intOpt"
      | h :: t => f (p (argToInt s omin omax h)) (s,t);
end;

local
  (* Describe the admissible real range in set notation; a missing bound
     leaves that side open. *)
  fun range NONE NONE = "R"
    | range (SOME i) NONE = "{n IN R | " ^ Real.toString i ^ " <= n}"
    | range NONE (SOME j) = "{n IN R | n <= " ^ Real.toString j ^ "}"
    | range (SOME i) (SOME j) =
    "{n IN R | " ^ Real.toString i ^ " <= n <= " ^ Real.toString j ^ "}";

  (* Comparison against an optional bound; an absent bound always holds. *)
  fun oLeq (SOME (x:real)) (SOME y) = x <= y | oLeq _ _ = true;

  (* Parse x as the real argument of option arg and check that it lies in
     the range.  Failures raise OptionExit without usage.  The messages now
     read "a real argument" (previously the ungrammatical "an real"). *)
  fun argToReal arg omin omax x =
    (case Real.fromString x of
       SOME i =>
       if oLeq omin (SOME i) andalso oLeq (SOME i) omax then i else
         raise OptionExit
           {success = false, usage = false, message =
            SOME (arg ^ " option needs a real argument in the range "
                  ^ range omin omax ^ " (not " ^ x ^ ")")}
     | NONE =>
       raise OptionExit
         {success = false, usage = false, message =
          SOME (arg ^ " option needs a real argument (not \"" ^ x ^ "\")")})
in
  (* Processor for an option taking one real within (omin,omax). *)
  fun realOpt _ _ _ (_,[]) = raise Bug "realOpt"
    | realOpt (omin,omax) f p (s:string, h :: (t : string list)) : unit =
      f (p (argToReal s omin omax h)) (s,t);
end;

(* Consume one argument, which must be one of choices. *)
fun enumOpt (choices : string list) f p (s : string, args : string list) : unit =
    case args of
      [] => raise Bug "enumOpt"
    | h :: t =>
      if mem h choices then f (p h) (s,t)
      else
        raise OptionExit
          {success = false, usage = false,
           message = SOME ("follow parameter " ^ s ^ " with one of {" ^
                           join "," choices ^ "}, not \"" ^ h ^ "\"")};

(* The literal argument x stands for NONE; anything else goes through p and
   is wrapped in SOME. *)
fun optionOpt (x : string, p) f q (s : string, args : string list) : unit =
    case args of
      [] => raise Bug "optionOpt"
    | h :: t =>
      if h = x then f (q NONE) (s,t) else p f (q o SOME) (s,args);

(* ------------------------------------------------------------------------- *)
(* Basic options useful for all programs                                     *)
(* ------------------------------------------------------------------------- *)

(* Options every program accepts.  The -- and version entries are handled
   directly by processOptions; their processors only raise Fail. *)
val basicOptions : opt list =
    let
      val endOfOptions =
          {switches = ["--"], arguments = [],
           description = "no more options",
           processor = fn _ => raise Fail "basicOptions: --"}

      val help =
          {switches = ["-?","-h","--help"], arguments = [],
           description = "display option information and exit",
           processor =
             fn _ =>
               raise OptionExit
                 {message = SOME "displaying option information",
                  usage = true, success = true}}

      val showVersion =
          {switches = ["-v", "--version"], arguments = [],
           description = "display version information",
           processor = fn _ => raise Fail "basicOptions: -v, --version"}
    in
      [endOfOptions, help, showVersion]
    end;

(* ------------------------------------------------------------------------- *)
(* All the command line options of a program                                 *)
(* ------------------------------------------------------------------------- *)

type allOptions =
     {name : string, version : string, header : string,
      footer : string, options : opt list};

(* ------------------------------------------------------------------------- *)
(* Usage information                                                         *)
(* ------------------------------------------------------------------------- *)

fun versionInformation (allopts : allOptions) = #version allopts;

(* header, then one aligned row per option (switches and argument names,
   dot-padded, then the description), then footer. *)
fun usageInformation ({header,footer,options,...} : allOptions) =
  let
    (* Each leading empty switch name adds two more spaces of indentation. *)
    fun indent (pad, "" :: rest) = indent (pad ^ "  ", rest)
      | indent other = other

    fun optionRow ({switches, arguments, description, ...} : opt) =
        let
          val (pad, names) = indent ("  ", switches)

          val line =
              foldl (fn (arg, acc) => acc ^ " " ^ arg)
                (pad ^ join ", " names) arguments
        in
          [line ^ " ...", " " ^ description]
        end

    val alignment =
        [{leftAlign = true, padChar = #"."},
         {leftAlign = true, padChar = #" "}]
  in
    header ^ join "\n" (alignTable alignment (map optionRow options)) ^
    "\n" ^ footer
  end;

(* ------------------------------------------------------------------------- *)
(* Exit the program gracefully                                               *)
(* ------------------------------------------------------------------------- *)

(* Print the message (prefixed by the program name) and, if requested, the
   usage text to stderr, then terminate with the chosen status. *)
fun exit (allopts : allOptions) ({message, usage, success} : optionExit) =
  let
    val {name, ...} = allopts

    fun err s = TextIO.output (TextIO.stdErr, s)

    val status = if success then OS.Process.success else OS.Process.failure
  in
    Option.app (fn m => err (name ^ ": " ^ m ^ "\n")) message;
    if usage then err (usageInformation allopts) else ();
    OS.Process.exit status
  end;

fun succeed allopts =
    exit allopts {message = NONE, usage = false, success = true};

fun fail allopts mesg =
    exit allopts {message = SOME mesg, usage = false, success = false};

fun usage allopts mesg =
    exit allopts {message = SOME mesg, usage = true, success = false};

(* Print the version information to stdout, then exit successfully. *)
fun version allopts =
    (print (versionInformation allopts);
     succeed allopts);

(* ------------------------------------------------------------------------- *)
(* Process the command line options passed to the program                    *)
(* ------------------------------------------------------------------------- *)

(* Consume leading command line options, running each one's processor.
   Returns the consumed option words (each switch followed by its
   arguments) and the remaining arguments.  Any OptionExit ends the
   program through exit. *)
fun processOptions (allopts : allOptions) =
  let
    (* Argument names and processor of switch x; unknown switches are
       reported together with the usage text. *)
    fun findOption x =
        case List.find (fn {switches, ...} => mem x switches) (#options allopts) of
          SOME {arguments, processor, ...} => (arguments, processor)
        | NONE =>
          raise OptionExit
                  {message = SOME ("unknown switch \"" ^ x ^ "\""),
                   usage = true, success = false}

    (* Split off as many following arguments as the argument names r. *)
    fun getArgs x r xs =
        let
          val m = length r

          val needed =
              if m = 1 then "a following argument"
              else Int.toString m ^ " following arguments"
        in
          if length xs < m then
            raise OptionExit
              {usage = false, success = false, message = SOME
               (x ^ " option needs " ^ needed ^ ": " ^ join " " r)}
          else divide xs m
        end

    fun process args =
        case args of
          [] => ([], [])
        | "--" :: rest => ([("--",[])], rest)
        | "-v" :: _ => version allopts
        | "--version" :: _ => version allopts
        | x :: rest =>
          if x = "" orelse x = "-" orelse String.sub (x,0) <> #"-" then
            ([], args)
          else
            let
              val (argNames, handler) = findOption x
              val (ys, remaining) = getArgs x argNames rest
              val () = handler (x, ys)
              val (parsed, unparsed) = process remaining
            in
              ((x, ys) :: parsed, unparsed)
            end
  in
    fn l =>
       let
         val (parsed, unparsed) = process l
         val flattened = List.concat (map (fn (x, xs) => x :: xs) parsed)
       in
         (flattened, unparsed)
       end
       handle OptionExit x => exit allopts x
  end;

end
end;

(**** Original file: Name.sig ****)

(* ========================================================================= *)
(* NAMES                                                                     *)
(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

signature Name =
sig

(* ------------------------------------------------------------------------- *)
(* A type of names.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Names are plain strings. *)
type name = string (* MODIFIED by Jasmin Blanchette *)

(* ------------------------------------------------------------------------- *)
(* A total ordering.                                                         *)
(* ------------------------------------------------------------------------- *)

(* String ordering. *)
val compare : name * name -> order

val equal : name -> name -> bool

(* ------------------------------------------------------------------------- *)
(* Fresh names.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Built from an underscore prefix and an integer from newInt/newInts. *)
val newName : unit -> name

val newNames : int -> name list

(* Append primes until the predicate accepts the name. *)
val variantPrime : (name -> bool) -> name -> name

(* Return the name unchanged if it is acceptable.  Otherwise strip its
   trailing digits and primes and append 0, 1, 2, ... until acceptable. *)
val variantNum : (name -> bool) -> name -> name

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty printing.                                              *)
(* ------------------------------------------------------------------------- *)

val pp : name Metis.Print.pp

val toString : name -> string

val fromString : string -> name

end

(**** Original file: Name.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* NAMES                                                                     *)
(* Copyright (c) 2004 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Name :> Name =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of names: plain strings.                                           *)
(* ------------------------------------------------------------------------- *)

type name = string;

(* ------------------------------------------------------------------------- *)
(* A total ordering: the standard string ordering.                           *)
(* ------------------------------------------------------------------------- *)

val compare = String.compare;

fun equal a b = (a = b);

(* ------------------------------------------------------------------------- *)
(* Fresh names: Useful.mkPrefix "_" applied to numbers from Useful.newInt(s). *)
(* ------------------------------------------------------------------------- *)

local
  fun fromNumber k = mkPrefix "_" (Int.toString k);
in
  fun newName () = fromNumber (newInt ());

  fun newNames count = map fromNumber (newInts count);
end;

(* Append primes to the name until it passes the acceptability test. *)
fun variantPrime acceptable =
    let
      fun loop candidate =
          if acceptable candidate then candidate
          else loop (candidate ^ "'")
    in
      loop
    end;

(* If the name is unacceptable, strip any trailing digits and primes and then
   append 0, 1, 2, ... until an acceptable name is found. *)
local
  fun isDigitOrPrime ch = ch = #"'" orelse Char.isDigit ch;
in
  fun variantNum acceptable n =
      if acceptable n then n
      else
        let
          val base = stripSuffix isDigitOrPrime n

          fun tryFrom k =
              let
                val candidate = base ^ Int.toString k
              in
                if acceptable candidate then candidate else tryFrom (k + 1)
              end
        in
          tryFrom 0
        end;
end;

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty printing: names are their own string form.            *)
(* ------------------------------------------------------------------------- *)

val pp = Print.ppString;

fun toString (n : name) : string = n;

fun fromString (s : string) : name = s;

end

(* Names ordered by Name.compare, used as the key type of NameMap. *)
structure NameOrdered =
struct
  type t = Name.name

  val compare = Name.compare
end

structure NameMap = KeyMap (NameOrdered);

structure NameSet =
struct

  (* Re-export every operation of the NameMap-based element set without
     leaving the auxiliary structure S visible. *)
  local
    structure S = ElementSet (NameMap);
  in
    open S;
  end;

  (* Print a set as its toList elements, joined with Print.ppOpList ","
     and wrapped in braces. *)
  local
    val ppElements = Print.ppOpList "," Name.pp
  in
    val pp = Print.ppMap toList (Print.ppBracket "{" "}" ppElements)
  end;

end
end;

(**** Original file: NameArity.sig ****)

(* ========================================================================= *)
(* NAME/ARITY PAIRS                                                          *)
(* Copyright (c) 2004-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature NameArity =
sig

(* ------------------------------------------------------------------------- *)
(* A type of name/arity pairs.                                               *)
(* ------------------------------------------------------------------------- *)

(* A function or relation symbol paired with the number of arguments. *)
type nameArity = Metis.Name.name * int

(* The first component of the pair. *)
val name : nameArity -> Metis.Name.name

(* The second component of the pair. *)
val arity : nameArity -> int

(* ------------------------------------------------------------------------- *)
(* Testing for different arities.                                            *)
(* ------------------------------------------------------------------------- *)

(* nary k na holds exactly when the arity of na is k. *)
val nary : int -> nameArity -> bool

(* Shorthands for nary 0, nary 1, nary 2 and nary 3. *)
val nullary : nameArity -> bool

val unary : nameArity -> bool

val binary : nameArity -> bool

val ternary : nameArity -> bool

(* ------------------------------------------------------------------------- *)
(* A total ordering.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Orders by name first, then by arity. *)
val compare : nameArity * nameArity -> order

val equal : nameArity -> nameArity -> bool

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty printing.                                              *)
(* ------------------------------------------------------------------------- *)

(* Prints the pair as name/arity. *)
val pp : nameArity Metis.Print.pp

end

(**** Original file: NameArity.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment: clear any infix status of ++, -- and RL,
   and bind the string/list helpers to their Basis Library versions. *)
nonfix ++ -- RL;
val explode = String.explode
and implode = String.implode
and print = TextIO.print
and foldl = List.foldl
and foldr = List.foldr;

(* ========================================================================= *)
(* NAME/ARITY PAIRS                                                          *)
(* Copyright (c) 2004-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure NameArity :> NameArity =
struct

(* ------------------------------------------------------------------------- *)
(* A type of name/arity pairs.                                               *)
(* ------------------------------------------------------------------------- *)

type nameArity = Name.name * int;

fun name (na : nameArity) = #1 na;

fun arity (na : nameArity) = #2 na;

(* ------------------------------------------------------------------------- *)
(* Testing for different arities: nary k na holds when na has arity k.       *)
(* ------------------------------------------------------------------------- *)

fun nary k na = (arity na = k);

val nullary = nary 0;

val unary = nary 1;

val binary = nary 2;

val ternary = nary 3;

(* ------------------------------------------------------------------------- *)
(* A total ordering: lexicographic on (name, arity).                         *)
(* ------------------------------------------------------------------------- *)

fun compare ((n1,i1),(n2,i2)) =
    case Name.compare (n1,n2) of
      EQUAL => Int.compare (i1,i2)
    | order => order;

(* The cheap integer test runs first; names are only compared when the
   arities agree. *)
fun equal (n1,i1) (n2,i2) =
    if i1 <> i2 then false else Name.equal n1 n2;

(* ------------------------------------------------------------------------- *)
(* Pretty printing as name/arity.                                            *)
(* ------------------------------------------------------------------------- *)

fun pp ((n,i) : nameArity) =
    let
      val parts = [Name.pp n, Print.addString "/", Print.ppInt i]
    in
      Print.blockProgram Print.Inconsistent 0 parts
    end;

end

(* Name/arity pairs ordered by NameArity.compare. *)
structure NameArityOrdered =
struct
  type t = NameArity.nameArity

  val compare = NameArity.compare
end

structure NameArityMap =
struct

  local
    structure S = KeyMap (NameArityOrdered);
  in
    open S;
  end;

  (* compose m1 m2: for every key (f,a) of m1 with value g, look up (g,a) in
     m2 and keep the result.  Entries whose lookup yields NONE are handled
     by KeyMap.mapPartial, presumably by being dropped. *)
  fun compose m1 m2 =
      mapPartial (fn ((_,a),target) => peek m2 (target,a)) m1;

end

structure NameAritySet =
struct

  local
    structure S = ElementSet (NameArityMap);
  in
    open S;
  end;

  (* allNullary s holds when every element of s has arity 0. *)
  fun allNullary set = all NameArity.nullary set;

  (* Print as the toList elements joined with Print.ppOpList "," inside
     braces. *)
  local
    val ppElements = Print.ppOpList "," NameArity.pp
  in
    val pp = Print.ppMap toList (Print.ppBracket "{" "}" ppElements)
  end;

end
end;

(**** Original file: Term.sig ****)

(* ========================================================================= *)
(* FIRST ORDER LOGIC TERMS                                                   *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

signature Term =
sig

(* ------------------------------------------------------------------------- *)
(* A type of first order logic terms.                                        *)
(* ------------------------------------------------------------------------- *)

type var = Metis.Name.name

type functionName = Metis.Name.name

(* A function symbol paired with its arity. *)
type function = functionName * int

type const = functionName

(* Constants are represented as function applications with no arguments. *)
datatype term =
    Var of var
  | Fn of functionName * term list

(* ------------------------------------------------------------------------- *)
(* Constructors and destructors.                                             *)
(* ------------------------------------------------------------------------- *)

(* Variables *)

(* Raises Error if the term is not a variable. *)
val destVar : term -> var

val isVar : term -> bool

(* equalVar v tm holds exactly when tm is the variable v. *)
val equalVar : var -> term -> bool

(* Functions *)

(* Raises Error if the term is a variable. *)
val destFn : term -> functionName * term list

val isFn : term -> bool

val fnName : term -> functionName

val fnArguments : term -> term list

val fnArity : term -> int

val fnFunction : term -> function

(* All function symbols, with arities, occurring anywhere in the term. *)
val functions : term -> Metis.NameAritySet.set

val functionNames : term -> Metis.NameSet.set

(* Constants *)

val mkConst : const -> term

val destConst : term -> const

val isConst : term -> bool

(* Binary functions *)

val mkBinop : functionName -> term * term -> term

(* Raises Error unless the term is a two-argument application of the given
   function name. *)
val destBinop : functionName -> term -> term * term

val isBinop : functionName -> term -> bool

(* ------------------------------------------------------------------------- *)
(* The size of a term in symbols.                                            *)
(* ------------------------------------------------------------------------- *)

val symbols : term -> int

(* ------------------------------------------------------------------------- *)
(* A total comparison function for terms.                                    *)
(* ------------------------------------------------------------------------- *)

val compare : term * term -> order

val equal : term -> term -> bool

(* ------------------------------------------------------------------------- *)
(* Subterms.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* A path is a list of 0-based argument positions, read from the root. *)
type path = int list

val subterm : term -> path -> term

val subterms : term -> (path * term) list

val replace : term -> path * term -> term

(* The path to the first subterm satisfying the predicate, if any. *)
val find : (term -> bool) -> term -> path option

val ppPath : path Metis.Print.pp

val pathToString : path -> string

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

val freeIn : var -> term -> bool

val freeVars : term -> Metis.NameSet.set

val freeVarsList : term list -> Metis.NameSet.set

(* ------------------------------------------------------------------------- *)
(* Fresh variables.                                                          *)
(* ------------------------------------------------------------------------- *)

val newVar : unit -> term

val newVars : int -> term list

(* Variants of a variable name that avoid every name in the given set. *)
val variantPrime : Metis.NameSet.set -> var -> var

val variantNum : Metis.NameSet.set -> var -> var

(* ------------------------------------------------------------------------- *)
(* Special support for terms with type annotations.                          *)
(* ------------------------------------------------------------------------- *)

val hasTypeFunctionName : functionName

val hasTypeFunction : function

val isTypedVar : term -> bool

val typedSymbols : term -> int

val nonVarTypedSubterms : term -> (path * term) list

(* ------------------------------------------------------------------------- *)
(* Special support for terms with an explicit function application operator. *)
(* ------------------------------------------------------------------------- *)

val appName : Metis.Name.name

val mkApp : term * term -> term

val destApp : term -> term * term

val isApp : term -> bool

(* listMkApp and stripApp convert between a left-nested chain of
   applications and a head term plus its argument list. *)
val listMkApp : term * term list -> term

val stripApp : term -> term * term list

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty printing.                                              *)
(* ------------------------------------------------------------------------- *)

(* The following settings are mutable references that the printer and the
   parser read on every call, so callers can customize the concrete syntax. *)

(* Infix symbols *)

val infixes : Metis.Print.infixes Unsynchronized.ref

(* The negation symbol *)

val negation : string Unsynchronized.ref

(* Binder symbols *)

val binders : string list Unsynchronized.ref

(* Bracket symbols *)

val brackets : (string * string) list Unsynchronized.ref

(* Pretty printing *)

val pp : term Metis.Print.pp

val toString : term -> string

(* Parsing *)

val fromString : string -> term

val parse : term Metis.Parse.quotation -> term

end

(**** Original file: Term.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment: clear any infix status of ++, -- and RL,
   and bind the string/list helpers to their Basis Library versions. *)
nonfix ++ -- RL;
val explode = String.explode
and implode = String.implode
and print = TextIO.print
and foldl = List.foldl
and foldr = List.foldr;

(* ========================================================================= *)
(* FIRST ORDER LOGIC TERMS                                                   *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Term :> Term =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of first order logic terms.                                        *)
(* ------------------------------------------------------------------------- *)

type var = Name.name;

type functionName = Name.name;

(* A function symbol together with its number of arguments. *)
type function = functionName * int;

type const = functionName;

(* A term is a variable, or a function symbol applied to a list of argument
   terms; constants are applications with an empty argument list. *)
datatype term =
    Var of var
  | Fn of functionName * term list;

(* ------------------------------------------------------------------------- *)
(* Constructors and destructors.                                             *)
(* ------------------------------------------------------------------------- *)

(* Variables *)

(* The name of a variable term; raises Error for function applications. *)
fun destVar tm =
    case tm of
      Var v => v
    | Fn _ => raise Error "destVar";

val isVar = can destVar;

(* equalVar v tm holds exactly when tm is the variable v. *)
fun equalVar v tm =
    case tm of
      Var w => Name.equal v w
    | Fn _ => false;

(* Functions *)

(* The (name, arguments) of a function application; raises Error for
   variables. *)
fun destFn tm =
    case tm of
      Fn func => func
    | Var _ => raise Error "destFn";

val isFn = can destFn;

fun fnName tm = #1 (destFn tm);

fun fnArguments tm = #2 (destFn tm);

fun fnArity tm = length (fnArguments tm);

(* The function symbol of an application paired with its argument count. *)
fun fnFunction tm =
    let
      val (f,args) = destFn tm
    in
      (f, length args)
    end;

local
  (* Worklist traversal: record the (name, arity) of every application and
     push its arguments onto the front of the pending list. *)
  fun collect acc [] = acc
    | collect acc (tm :: pending) =
      case tm of
        Var _ => collect acc pending
      | Fn (n,args) =>
        collect (NameAritySet.add acc (n, length args)) (args @ pending);
in
  fun functions tm = collect NameAritySet.empty [tm];
end;

local
  (* Worklist traversal collecting the name of every function application. *)
  fun collect acc [] = acc
    | collect acc (tm :: pending) =
      case tm of
        Var _ => collect acc pending
      | Fn (n,args) => collect (NameSet.add acc n) (args @ pending);
in
  fun functionNames tm = collect NameSet.empty [tm];
end;

(* Constants *)

(* Constants are nullary function applications. *)
fun mkConst c = Fn (c, []);

fun destConst tm =
    case tm of
      Fn (c, []) => c
    | _ => raise Error "destConst";

val isConst = can destConst;

(* Binary functions *)

fun mkBinop f (a,b) = Fn (f, [a,b]);

(* Split a two-argument application of f; raises Error with a message that
   distinguishes a different function name from a non-binary term. *)
fun destBinop f tm =
    case tm of
      Fn (x, [a,b]) =>
      if not (Name.equal x f) then raise Error "Term.destBinop: wrong binop"
      else (a,b)
    | _ => raise Error "Term.destBinop: not a binop";

fun isBinop f tm = can (destBinop f) tm;

(* ------------------------------------------------------------------------- *)
(* The size of a term in symbols.                                            *)
(* ------------------------------------------------------------------------- *)

(* Weights contributed by each variable and each function symbol. *)
val VAR_SYMBOLS = 1;

val FN_SYMBOLS = 1;

local
  (* Worklist traversal summing the symbol weights. *)
  fun count acc [] = acc
    | count acc (tm :: pending) =
      case tm of
        Var _ => count (acc + VAR_SYMBOLS) pending
      | Fn (_,args) => count (acc + FN_SYMBOLS) (args @ pending);
in
  fun symbols tm = count 0 [tm];
end;

(* ------------------------------------------------------------------------- *)
(* A total comparison function for terms.                                    *)
(* ------------------------------------------------------------------------- *)

local
  (* Compare two worklists of terms pairwise, left to right.
     Variables sort before function applications.  Variables are ordered by
     name.  Applications are ordered by function name, then by argument
     count, then by their arguments, which are pushed onto the front of
     both worklists. *)
  fun cmp [] [] = EQUAL
    | cmp (tm1 :: tms1) (tm2 :: tms2) =
      let
        val tm1_tm2 = (tm1,tm2)
      in
        (* Shortcut via Portable.pointerEqual: physically shared subterms
           are skipped without being traversed. *)
        if Portable.pointerEqual tm1_tm2 then cmp tms1 tms2
        else
          case tm1_tm2 of
            (Var v1, Var v2) =>
            (case Name.compare (v1,v2) of
               LESS => LESS
             | EQUAL => cmp tms1 tms2
             | GREATER => GREATER)
          | (Var _, Fn _) => LESS
          | (Fn _, Var _) => GREATER
          | (Fn (f1,a1), Fn (f2,a2)) =>
            (case Name.compare (f1,f2) of
               LESS => LESS
             | EQUAL =>
               (* Arguments are pushed only once the counts agree, so both
                  worklists always keep the same length. *)
               (case Int.compare (length a1, length a2) of
                  LESS => LESS
                | EQUAL => cmp (a1 @ tms1) (a2 @ tms2)
                | GREATER => GREATER)
             | GREATER => GREATER)
      end
    (* Unequal worklist lengths would indicate an internal error. *)
    | cmp _ _ = raise Bug "Term.compare";
in
  fun compare (tm1,tm2) = cmp [tm1] [tm2];
end;

(* Structural equality, derived from the total order. *)
fun equal tm1 tm2 = compare (tm1,tm2) = EQUAL;

(* ------------------------------------------------------------------------- *)
(* Subterms.                                                                 *)
(* ------------------------------------------------------------------------- *)

type path = int list;

(* subterm tm p returns the subterm of tm reached by following path p.
   Each element of p is a 0-based index into the current argument list.
   Raises Error if the path steps into a variable or uses an index outside
   the argument list. *)
fun subterm tm [] = tm
  | subterm (Var _) (_ :: _) = raise Error "Term.subterm: Var"
  | subterm (Fn (_,tms)) (h :: t) =
    (* Check both bounds so a bad index is reported as Error (with this
       function's name) instead of escaping as List.nth's Subscript. *)
    if h < 0 orelse h >= length tms then raise Error "Term.subterm: Fn"
    else subterm (List.nth (tms,h)) t;

local
  (* Worklist of (reversed path, term).  Each visited term is consed onto
     the accumulator with its path put back in root-first order, and its
     arguments are queued in front of the remaining work. *)
  fun collect acc [] = acc
    | collect acc ((revPath,tm) :: pending) =
      let
        val acc' = (rev revPath, tm) :: acc

        fun child (i,arg) = (i :: revPath, arg)
      in
        case tm of
          Var _ => collect acc' pending
        | Fn (_,args) => collect acc' (map child (enumerate args) @ pending)
      end;
in
  fun subterms tm = collect [] [([],tm)];
end;

(* replace tm (p,res) returns tm with the subterm at path p replaced by res.
   Path elements are 0-based argument indices.  Raises Error if the path
   steps into a variable or uses an index outside the argument list.
   When res equals the existing subterm, the original term object is
   returned, so unchanged terms stay physically shared. *)
fun replace tm ([],res) = if equal res tm then tm else res
  | replace tm (h :: t, res) =
    case tm of
      Var _ => raise Error "Term.replace: Var"
    | Fn (func,tms) =>
      (* Reject negative indices too: List.nth would raise Subscript. *)
      if h < 0 orelse h >= length tms then raise Error "Term.replace: Fn"
      else
        let
          val arg = List.nth (tms,h)
          val arg' = replace arg (t,res)
        in
          (* Rebuild the node only if the argument actually changed. *)
          if Portable.pointerEqual (arg',arg) then tm
          else Fn (func, updateNth (h,arg') tms)
        end;

(* Return the path of the first subterm satisfying pred, visiting a term
   before its arguments and arguments left to right; NONE if there is no
   such subterm. *)
fun find pred tm =
    let
      fun search [] = NONE
        | search ((revPath,t) :: pending) =
          if pred t then SOME (rev revPath)
          else
            case t of
              Var _ => search pending
            | Fn (_,args) =>
              let
                val children =
                    map (fn (i,a) => (i :: revPath, a)) (enumerate args)
              in
                search (children @ pending)
              end
    in
      search [([],tm)]
    end;

(* Paths are printed as lists of integers. *)
val ppPath = Print.ppList Print.ppInt;

fun pathToString path = Print.toString ppPath path;

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

local
  (* Worklist search for an occurrence of the variable v. *)
  fun occurs _ [] = false
    | occurs v (tm :: pending) =
      case tm of
        Var w => Name.equal v w orelse occurs v pending
      | Fn (_,args) => occurs v (args @ pending);
in
  fun freeIn v tm = occurs v [tm];
end;

local
  (* Worklist traversal collecting every variable name encountered. *)
  fun collect acc [] = acc
    | collect acc (tm :: pending) =
      case tm of
        Var v => collect (NameSet.add acc v) pending
      | Fn (_,args) => collect acc (args @ pending);
in
  fun freeVarsList tms = collect NameSet.empty tms;

  fun freeVars tm = collect NameSet.empty [tm];
end;

(* ------------------------------------------------------------------------- *)
(* Fresh variables.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Fresh variables, named by Name.newName / Name.newNames. *)
fun newVar () =
    let
      val v = Name.newName ()
    in
      Var v
    end;

fun newVars count = map Var (Name.newNames count);

local
  (* A name is acceptable when it is not in the avoid set. *)
  fun notIn avoid v = not (NameSet.member v avoid);
in
  fun variantPrime avoid v = Name.variantPrime (notIn avoid) v;

  fun variantNum avoid v = Name.variantNum (notIn avoid) v;
end;

(* ------------------------------------------------------------------------- *)
(* Special support for terms with type annotations.                          *)
(* ------------------------------------------------------------------------- *)

(* Type annotations are encoded as the binary function ":" applied to
   (term, type). *)
val hasTypeFunctionName = Name.fromString ":";

val hasTypeFunction = (hasTypeFunctionName, 2);

(* Split an application of ":" into its term and type; raises Error for
   any other function or argument count. *)
fun destFnHasType ((f,a) : functionName * term list) =
    case a of
      [tm,ty] =>
      if Name.equal f hasTypeFunctionName then (tm,ty)
      else raise Error "Term.destFnHasType"
    | _ => raise Error "Term.destFnHasType";

val isFnHasType = can destFnHasType;

(* A plain variable, or a variable wrapped in a ":" type annotation. *)
fun isTypedVar (Var _) = true
  | isTypedVar (Fn func) =
    (case total destFnHasType func of
       SOME (Var _, _) => true
     | _ => false);

local
  (* Like symbols, but a ":" annotation contributes nothing itself and its
     type part is skipped; only the annotated term is counted. *)
  fun count acc [] = acc
    | count acc (Var _ :: pending) = count (acc + 1) pending
    | count acc (Fn func :: pending) =
      case total destFnHasType func of
        SOME (body,_) => count acc (body :: pending)
      | NONE => count (acc + 1) (#2 func @ pending);
in
  fun typedSymbols tm = count 0 [tm];
end;

local
  (* Worklist of (reversed path, term).  The result omits variables and
     typed variables.  For an annotated non-variable, the annotation node is
     recorded and only its term part (argument 0) is descended into. *)
  fun collect acc [] = acc
    | collect acc ((revPath,tm) :: pending) =
      case tm of
        Var _ => collect acc pending
      | Fn (func as (_,args)) =>
        case total destFnHasType func of
          SOME (Var _, _) => collect acc pending
        | SOME (inner as Fn _, _) =>
          collect ((rev revPath, tm) :: acc) ((0 :: revPath, inner) :: pending)
        | NONE =>
          let
            fun child (i,arg) = (i :: revPath, arg)
          in
            collect ((rev revPath, tm) :: acc)
              (map child (enumerate args) @ pending)
          end;
in
  fun nonVarTypedSubterms tm = collect [] [([],tm)];
end;

(* ------------------------------------------------------------------------- *)
(* Special support for terms with an explicit function application operator. *)
(* ------------------------------------------------------------------------- *)

(* Explicit application is encoded as the binary function "." applied to
   (function term, argument term). *)
val appName = Name.fromString ".";

fun mkFnApp (fTm,aTm) = (appName, [fTm,aTm]);

fun mkApp (fTm,aTm) = Fn (mkFnApp (fTm,aTm));

(* Split an application of "." into (function, argument); raises Error for
   any other function or argument count. *)
fun destFnApp ((f,a) : Name.name * term list) =
    case a of
      [fTm,aTm] =>
      if Name.equal f appName then (fTm,aTm) else raise Error "Term.destFnApp"
    | _ => raise Error "Term.destFnApp";

val isFnApp = can destFnApp;

fun destApp (Var _) = raise Error "Term.destApp"
  | destApp (Fn func) = destFnApp func;

val isApp = can destApp;

(* listMkApp (f,[a1,...,an]) builds the left-nested application
   (...((f . a1) . a2) ...) . an, so that stripApp is its inverse.
   List.foldl passes (element, accumulator) to its function, whereas mkApp
   expects (function, argument), so the pair is swapped here.  Passing
   mkApp directly would wrongly build a1 . f. *)
fun listMkApp (f,l) = foldl (fn (a,acc) => mkApp (acc,a)) f l;

(* Peel applications off the left spine: returns the head term and the
   arguments in application order. *)
fun stripApp tm =
    let
      fun strip args t =
          case total destApp t of
            NONE => (t,args)
          | SOME (fTm,aTm) => strip (aTm :: args) fTm
    in
      strip [] tm
    end;

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty printing.                                              *)
(* ------------------------------------------------------------------------- *)

(* Operators parsed and printed infix *)

(* Default infix table.  Each entry gives the token as printed, including
   its surrounding spaces, a precedence and an associativity.  Judging by
   the ML entries (* above +, logical connectives negative), higher numbers
   bind tighter.  The table lives in a mutable ref so that callers can
   replace it. *)
val infixes =
    (Unsynchronized.ref o Print.Infixes)
      [(* ML symbols *)
       {token = " / ", precedence = 7, leftAssoc = true},
       {token = " div ", precedence = 7, leftAssoc = true},
       {token = " mod ", precedence = 7, leftAssoc = true},
       {token = " * ", precedence = 7, leftAssoc = true},
       {token = " + ", precedence = 6, leftAssoc = true},
       {token = " - ", precedence = 6, leftAssoc = true},
       {token = " ^ ", precedence = 6, leftAssoc = true},
       {token = " @ ", precedence = 5, leftAssoc = false},
       {token = " :: ", precedence = 5, leftAssoc = false},
       {token = " = ", precedence = 4, leftAssoc = true},
       {token = " <> ", precedence = 4, leftAssoc = true},
       {token = " <= ", precedence = 4, leftAssoc = true},
       {token = " < ", precedence = 4, leftAssoc = true},
       {token = " >= ", precedence = 4, leftAssoc = true},
       {token = " > ", precedence = 4, leftAssoc = true},
       {token = " o ", precedence = 3, leftAssoc = true},
       {token = " -> ", precedence = 2, leftAssoc = false},  (* inferred prec *)
       {token = " : ", precedence = 1, leftAssoc = false},  (* inferred prec *)
       {token = ", ", precedence = 0, leftAssoc = false},  (* inferred prec *)

       (* Logical connectives *)
       {token = " /\\ ", precedence = ~1, leftAssoc = false},
       {token = " \\/ ", precedence = ~2, leftAssoc = false},
       {token = " ==> ", precedence = ~3, leftAssoc = false},
       {token = " <=> ", precedence = ~4, leftAssoc = false},

       (* Other symbols *)
       {token = " . ", precedence = 9, leftAssoc = true},  (* function app *)
       {token = " ** ", precedence = 8, leftAssoc = true},
       {token = " ++ ", precedence = 6, leftAssoc = true},
       {token = " -- ", precedence = 6, leftAssoc = true},
       {token = " == ", precedence = 4, leftAssoc = true}];

(* The negation symbol *)

val negation = Unsynchronized.ref "~" : string Unsynchronized.ref;

(* Binder symbols: tokens printed and parsed as prefix quantifiers *)

val binders =
    Unsynchronized.ref ["\\","!","?","?!"] : string list Unsynchronized.ref;

(* Bracket symbols: pairs of opening and closing tokens *)

val brackets =
    Unsynchronized.ref [("[","]"),("{","}")]
      : (string * string) list Unsynchronized.ref;

(* Pretty printing *)

(* Pretty-print a term using the current binders, infixes, negation and
   brackets settings; these refs are read again on every call.  The bv
   argument threaded through the helpers is the set of variable names bound
   by enclosing quantifiers.  Bound variables print bare and all other
   variables get a "$" prefix.  Function names that clash with a special
   token or a bound variable are wrapped in parentheses. *)
fun pp inputTerm =
    let
      val quants = !binders
      and iOps = !infixes
      and neg = !negation
      and bracks = !brackets

      (* Each bracket pair is keyed by the concatenation of its two tokens,
         which is the function name used for bracketed terms. *)
      val bracks = map (fn (b1,b2) => (b1 ^ b2, b1, b2)) bracks

      val bTokens = map #2 bracks @ map #3 bracks

      val iTokens = Print.tokensInfixes iOps

      (* A two-argument application whose name is an infix token. *)
      fun destI tm =
          case tm of
            Fn (f,[a,b]) =>
            let
              val f = Name.toString f
            in
              if StringSet.member f iTokens then SOME (f,a,b) else NONE
            end
          | _ => NONE

      val iPrinter = Print.ppInfixes iOps destI

      val specialTokens =
          StringSet.addList iTokens (neg :: quants @ ["$","(",")"] @ bTokens)

      fun vName bv s = StringSet.member s bv

      fun checkVarName bv n =
          let
            val s = Name.toString n
          in
            if vName bv s then s else "$" ^ s
          end

      fun varName bv = Print.ppMap (checkVarName bv) Print.ppString

      fun checkFunctionName bv n =
          let
            val s = Name.toString n
          in
            if StringSet.member s specialTokens orelse vName bv s then
              "(" ^ s ^ ")"
            else
              s
          end

      fun functionName bv = Print.ppMap (checkFunctionName bv) Print.ppString

      fun isI tm = Option.isSome (destI tm)

      (* Count and strip leading applications of the negation symbol. *)
      fun stripNeg tm =
          case tm of
            Fn (f,[a]) =>
            if Name.toString f <> neg then (0,tm)
            else let val (n,tm) = stripNeg a in (n + 1, tm) end
          | _ => (0,tm)

      (* Collect a run of nested applications of the same binder q, each
         with a variable as first argument: (q, first var, other vars, body). *)
      val destQuant =
          let
            fun dest q (Fn (q', [Var v, body])) =
                if Name.toString q' <> q then NONE
                else
                  (case dest q body of
                     NONE => SOME (q,v,[],body)
                   | SOME (_,v',vs,body) => SOME (q, v, v' :: vs, body))
              | dest _ _ = NONE
          in
            fn tm => Useful.first (fn q => dest q tm) quants
          end

      fun isQuant tm = Option.isSome (destQuant tm)

      (* A one-argument application whose name is a bracket-pair key. *)
      fun destBrack (Fn (b,[tm])) =
          let
            val s = Name.toString b
          in
            case List.find (fn (n,_,_) => n = s) bracks of
              NONE => NONE
            | SOME (_,b1,b2) => SOME (b1,tm,b2)
          end
        | destBrack _ = NONE

      fun isBrack tm = Option.isSome (destBrack tm)

      (* Arguments other than variables, constants and custom-bracketed
         terms are wrapped in parentheses. *)
      fun functionArgument bv tm =
          Print.sequence
            (Print.addBreak 1)
            (if isBrack tm then customBracket bv tm
             else if isVar tm orelse isConst tm then basic bv tm
             else bracket bv tm)

      and basic bv (Var v) = varName bv v
        | basic bv (Fn (f,args)) =
          Print.blockProgram Print.Inconsistent 2
            (functionName bv f :: map (functionArgument bv) args)

      and customBracket bv tm =
          case destBrack tm of
            SOME (b1,tm,b2) => Print.ppBracket b1 b2 (term bv) tm
          | NONE => basic bv tm

      (* Print "q v1 v2 ... . body", adding the bound names to bv. *)
      and innerQuant bv tm =
          case destQuant tm of
            NONE => term bv tm
          | SOME (q,v,vs,tm) =>
            let
              val bv = StringSet.addList bv (map Name.toString (v :: vs))
            in
              Print.program
                [Print.addString q,
                 varName bv v,
                 Print.program
                   (map (Print.sequence (Print.addBreak 1) o varName bv) vs),
                 Print.addString ".",
                 Print.addBreak 1,
                 innerQuant bv tm]
            end

      and quantifier bv tm =
          if not (isQuant tm) then customBracket bv tm
          else Print.block Print.Inconsistent 2 (innerQuant bv tm)

      (* Print leading negations, then the negated term.  An infix term, or
         a quantifier when the flag r is set, is parenthesized. *)
      and molecule bv (tm,r) =
          let
            val (n,tm) = stripNeg tm
          in
            Print.blockProgram Print.Inconsistent n
              [Print.duplicate n (Print.addString neg),
               if isI tm orelse (r andalso isQuant tm) then bracket bv tm
               else quantifier bv tm]
          end

      and term bv tm = iPrinter (molecule bv) (tm,false)

      and bracket bv tm = Print.ppBracket "(" ")" (term bv) tm
    in
      term StringSet.empty
    end inputTerm;

val toString = Print.toString pp;

(* Parsing *)

(* Parsing: a character-level lexer produces string tokens, and a
   token-level parser mirrors the concrete syntax produced by pp. *)
local
  open Parse;

  infixr 9 >>++
  infixr 8 ++
  infixr 7 >>
  infixr 6 ||

  (* Alphanumeric characters plus underscore and prime. *)
  val isAlphaNum =
      let
        val alphaNumChars = explode "_'"
      in
        fn c => mem c alphaNumChars orelse Char.isAlphaNum c
      end;

  local
    val alphaNumToken = atLeastOne (some isAlphaNum) >> implode;

    (* The negation character is always a token on its own (the negation
       ref is read when each character is tested).  Any other run of
       symbol characters forms a single token. *)
    val symbolToken =
        let
          fun isNeg c = str c = !negation

          val symbolChars = explode "<>=-*+/\\?@|!$%&#^:;~"

          fun isSymbol c = mem c symbolChars

          fun isNonNegSymbol c = not (isNeg c) andalso isSymbol c
        in
          some isNeg >> str ||
          (some isNonNegSymbol ++ many (some isSymbol)) >> (implode o op::)
        end;

    (* Each punctuation character is a token on its own. *)
    val punctToken =
        let
          val punctChars = explode "()[]{}.,"

          fun isPunct c = mem c punctChars
        in
          some isPunct >> str
        end;

    val lexToken = alphaNumToken || symbolToken || punctToken;

    val space = many (some Char.isSpace);
  in
    (* One token with the surrounding whitespace discarded. *)
    val lexer = (space ++ lexToken ++ space) >> (fn (_,(tok,_)) => tok);
  end;

  (* Parse one term from a token stream.  The bv argument is the set of
     variable names bound by enclosing quantifiers.  Other variables must
     be written with a "$" prefix, and a function name that is a special
     token may be written in parentheses. *)
  fun termParser inputStream =
      let
        val quants = !binders
        and iOps = !infixes
        and neg = !negation
        and bracks = ("(",")") :: !brackets

        val bracks = map (fn (b1,b2) => (b1 ^ b2, b1, b2)) bracks

        val bTokens = map #2 bracks @ map #3 bracks

        fun possibleVarName "" = false
          | possibleVarName s = isAlphaNum (String.sub (s,0))

        fun vName bv s = StringSet.member s bv

        val iTokens = Print.tokensInfixes iOps

        val iParser =
            parseInfixes iOps (fn (f,a,b) => Fn (Name.fromString f, [a,b]))

        val specialTokens =
            StringSet.addList iTokens (neg :: quants @ ["$"] @ bTokens)

        fun varName bv =
            some (vName bv) ||
            (some (Useful.equal "$") ++ some possibleVarName) >> snd

        fun fName bv s =
            not (StringSet.member s specialTokens) andalso not (vName bv s)

        fun functionName bv =
            some (fName bv) ||
            (some (Useful.equal "(") ++ any ++ some (Useful.equal ")")) >>
            (fn (_,(s,_)) => s)

        (* A variable, a constant, a bracketed term or a quantified term. *)
        fun basic bv tokens =
            let
              val var = varName bv >> (Var o Name.fromString)

              val const =
                  functionName bv >> (fn f => Fn (Name.fromString f, []))

              (* Plain parentheses disappear; other bracket pairs become a
                 unary application named by the concatenated tokens. *)
              fun bracket (ab,a,b) =
                  (some (Useful.equal a) ++ term bv ++ some (Useful.equal b)) >>
                  (fn (_,(tm,_)) =>
                      if ab = "()" then tm else Fn (Name.fromString ab, [tm]))

              (* "q v1 ... vn . body" becomes nested q applications, with
                 v1 ... vn added to the bound-variable set for the body. *)
              fun quantifier q =
                  let
                    fun bind (v,t) =
                        Fn (Name.fromString q, [Var (Name.fromString v), t])
                  in
                    (some (Useful.equal q) ++
                     atLeastOne (some possibleVarName) ++
                     some (Useful.equal ".")) >>++
                    (fn (_,(vs,_)) =>
                        term (StringSet.addList bv vs) >>
                        (fn body => foldr bind body vs))
                  end
            in
              var ||
              const ||
              first (map bracket bracks) ||
              first (map quantifier quants)
            end tokens

        (* Any number of negations followed by a function application (name
           plus basic arguments) or a basic term. *)
        and molecule bv tokens =
            let
              val negations = many (some (Useful.equal neg)) >> length

              val function =
                  (functionName bv ++ many (basic bv)) >>
                  (fn (f,args) => Fn (Name.fromString f, args)) ||
                  basic bv
            in
              (negations ++ function) >>
              (fn (n,tm) => funpow n (fn t => Fn (Name.fromString neg, [t])) tm)
            end tokens

        and term bv tokens = iParser (molecule bv) tokens
      in
        term StringSet.empty
      end inputStream;
in
  (* Lex and parse the whole string; raises Error unless it contains
     exactly one term. *)
  fun fromString input =
      let
        val chars = Stream.fromList (explode input)

        val tokens = everything (lexer >> singleton) chars

        val terms = everything (termParser >> singleton) tokens
      in
        case Stream.toList terms of
          [tm] => tm
        | _ => raise Error "Term.fromString"
      end;
end;

local
  (* Antiquoted terms are printed in parentheses before re-parsing. *)
  fun antiquotedTermToString tm =
      Print.toString (Print.ppBracket "(" ")" pp) tm;
in
  val parse = Parse.parseQuotation antiquotedTermToString fromString;
end;

end

(* Terms ordered by Term.compare, used as keys of TermMap and elements of
   TermSet. *)
structure TermOrdered =
struct
  type t = Term.term

  val compare = Term.compare
end

structure TermMap = KeyMap (TermOrdered);

structure TermSet = ElementSet (TermMap);
end;

(**** Original file: Subst.sig ****)

(* ========================================================================= *)
(* FIRST ORDER LOGIC SUBSTITUTIONS                                           *)
(* Copyright (c) 2002-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Subst =
sig

(* Substitutions are finite maps from variables to terms; the concrete       *)
(* representation is hidden by opaque ascription in the Subst structure.     *)

(* ------------------------------------------------------------------------- *)
(* A type of first order logic substitutions.                                *)
(* ------------------------------------------------------------------------- *)

type subst

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

val empty : subst

val null : subst -> bool

(* The number of bindings. *)
val size : subst -> int

(* The term bound to a variable, if there is one. *)
val peek : subst -> Metis.Term.var -> Metis.Term.term option

val insert : subst -> Metis.Term.var * Metis.Term.term -> subst

(* singleton (v,tm) is insert empty (v,tm). *)
val singleton : Metis.Term.var * Metis.Term.term -> subst

val toList : subst -> (Metis.Term.var * Metis.Term.term) list

val fromList : (Metis.Term.var * Metis.Term.term) list -> subst

(* Folds over the bindings; the function receives (variable, term, acc). *)
val foldl : (Metis.Term.var * Metis.Term.term * 'a -> 'a) -> 'a -> subst -> 'a

val foldr : (Metis.Term.var * Metis.Term.term * 'a -> 'a) -> 'a -> subst -> 'a

val pp : subst Metis.Print.pp

val toString : subst -> string

(* ------------------------------------------------------------------------- *)
(* Normalizing removes identity substitutions.                               *)
(* ------------------------------------------------------------------------- *)

(* Drops every binding v |-> v; the argument itself is returned when there  *)
(* is nothing to drop.                                                       *)
val normalize : subst -> subst

(* ------------------------------------------------------------------------- *)
(* Applying a substitution to a first order logic term.                      *)
(* ------------------------------------------------------------------------- *)

(* Each variable of the term is replaced by its bound term in a single pass; *)
(* the bound terms are not themselves substituted again.                     *)
val subst : subst -> Metis.Term.term -> Metis.Term.term

(* ------------------------------------------------------------------------- *)
(* Restricting a substitution to a smaller set of variables.                 *)
(* ------------------------------------------------------------------------- *)

(* restrict keeps only the bindings of variables in the set; remove drops    *)
(* exactly those bindings.                                                   *)
val restrict : subst -> Metis.NameSet.set -> subst

val remove : subst -> Metis.NameSet.set -> subst

(* ------------------------------------------------------------------------- *)
(* Composing two substitutions so that the following identity holds:         *)
(*                                                                           *)
(* subst (compose sub1 sub2) tm = subst sub2 (subst sub1 tm)                 *)
(* ------------------------------------------------------------------------- *)

val compose : subst -> subst -> subst

(* ------------------------------------------------------------------------- *)
(* Creating the union of two compatible substitutions.                       *)
(* ------------------------------------------------------------------------- *)

(* Raises Error if the two substitutions bind some variable to different    *)
(* terms.                                                                    *)
val union : subst -> subst -> subst  (* raises Error *)

(* ------------------------------------------------------------------------- *)
(* Substitutions can be inverted iff they are renaming substitutions.        *)
(* ------------------------------------------------------------------------- *)

(* Raises Error if a variable is bound to a non-variable term, or if two     *)
(* variables are bound to the same variable.                                 *)
val invert : subst -> subst  (* raises Error *)

(* True exactly when invert succeeds. *)
val isRenaming : subst -> bool

(* ------------------------------------------------------------------------- *)
(* Creating a substitution to freshen variables.                             *)
(* ------------------------------------------------------------------------- *)

(* Binds each variable of the set to a new variable from Metis.Term.newVar. *)
val freshVars : Metis.NameSet.set -> subst

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

(* redexes: the bound variables.  residueFreeVars: the free variables of the *)
(* bound terms.  freeVars: the union of both.                                *)
val redexes : subst -> Metis.NameSet.set

val residueFreeVars : subst -> Metis.NameSet.set

val freeVars : subst -> Metis.NameSet.set

(* ------------------------------------------------------------------------- *)
(* Functions.                                                                *)
(* ------------------------------------------------------------------------- *)

(* The function symbols, with arities, occurring in the bound terms. *)
val functions : subst -> Metis.NameAritySet.set

(* ------------------------------------------------------------------------- *)
(* Matching for first order logic terms.                                     *)
(* ------------------------------------------------------------------------- *)

(* match sub tm1 tm2 extends sub with bindings for variables of tm1 only;    *)
(* variables of tm2 are never bound.  A variable already bound in sub must   *)
(* be bound to exactly the corresponding subterm of tm2.                     *)
val match : subst -> Metis.Term.term -> Metis.Term.term -> subst  (* raises Error *)

(* ------------------------------------------------------------------------- *)
(* Unification for first order logic terms.                                  *)
(* ------------------------------------------------------------------------- *)

(* Extends sub to a unifier of the two terms; raises Error on an occurs      *)
(* check failure or when function symbols or arities clash.                  *)
val unify : subst -> Metis.Term.term -> Metis.Term.term -> subst  (* raises Error *)

end

(**** Original file: Subst.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* FIRST ORDER LOGIC SUBSTITUTIONS                                           *)
(* Copyright (c) 2002-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Subst :> Subst =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of first order logic substitutions.                                *)
(* ------------------------------------------------------------------------- *)

(* A substitution wraps a finite map from variable names to terms.  The      *)
(* Subst constructor is hidden outside this structure by the opaque         *)
(* ascription Subst :> Subst.                                                *)
datatype subst = Subst of Term.term NameMap.map;

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

(* The substitution with no bindings. *)
val empty = Subst (NameMap.new ());

(* Emptiness and size are those of the underlying map. *)
fun null (Subst bindings) = NameMap.null bindings;

fun size (Subst bindings) = NameMap.size bindings;

(* The term bound to a variable, if there is one. *)
fun peek (Subst bindings) var = NameMap.peek bindings var;

(* Add one binding; NameMap.insert decides what happens to an existing      *)
(* binding for the same variable.                                           *)
fun insert (Subst bindings) binding = Subst (NameMap.insert bindings binding);

fun singleton binding = insert empty binding;

(* Conversions to and from association lists. *)
fun toList (Subst bindings) = NameMap.toList bindings;

fun fromList bindingList = Subst (NameMap.fromList bindingList);

(* Folds over the bindings in NameMap's order; f gets (var, term, acc). *)
fun foldl f acc (Subst bindings) = NameMap.foldl f acc bindings;

fun foldr f acc (Subst bindings) = NameMap.foldr f acc bindings;

(* Printed between <[ and ]>, comma-separated, each binding as v |-> tm. *)
fun pp sub =
    let
      val ppBinding = Print.ppOp2 " |->" Name.pp Term.pp
    in
      Print.ppBracket "<[" "]>" (Print.ppOpList "," ppBinding) (toList sub)
    end;

val toString = Print.toString pp;

(* ------------------------------------------------------------------------- *)
(* Normalizing removes identity substitutions.                               *)
(* ------------------------------------------------------------------------- *)

(* Drop every identity binding v |-> v.  When nothing is dropped the        *)
(* original substitution value is returned unchanged.                       *)
local
  fun isIdentity (v,tm) = Term.equalVar v tm;
in
  fun normalize (sub as Subst m) =
      let
        val kept = NameMap.filter (fn binding => not (isIdentity binding)) m
      in
        if NameMap.size kept = NameMap.size m then sub else Subst kept
      end;
end;

(* ------------------------------------------------------------------------- *)
(* Applying a substitution to a first order logic term.                      *)
(* ------------------------------------------------------------------------- *)

(* Apply a substitution to a term in one simultaneous pass.  Subterms that  *)
(* do not change are returned as the very same values (checked with        *)
(* Portable.pointerEqual), so unchanged structure stays shared.  The empty  *)
(* substitution returns its argument untouched.                             *)
fun subst sub =
    let
      fun apply (tm as Term.Var v) =
          (case peek sub v of
             NONE => tm
           | SOME bound =>
             if Portable.pointerEqual (tm,bound) then tm else bound)
        | apply (tm as Term.Fn (f,args)) =
          let
            val newArgs = Sharing.map apply args
          in
            if Portable.pointerEqual (args,newArgs) then tm
            else Term.Fn (f,newArgs)
          end
    in
      fn tm => if null sub then tm else apply tm
    end;

(* ------------------------------------------------------------------------- *)
(* Restricting a substitution to a given set of variables.                   *)
(* ------------------------------------------------------------------------- *)

local
  (* Keep only the bindings whose variable satisfies keep; the original      *)
  (* substitution value is returned when no binding is removed.              *)
  fun filterVars keep (sub as Subst m) =
      let
        val m' = NameMap.filter (fn (v,_) => keep v) m
      in
        if NameMap.size m = NameMap.size m' then sub else Subst m'
      end;
in
  (* Keep the bindings of the variables in varSet. *)
  fun restrict sub varSet = filterVars (fn v => NameSet.member v varSet) sub;

  (* Drop the bindings of the variables in varSet. *)
  fun remove sub varSet =
      filterVars (fn v => not (NameSet.member v varSet)) sub;
end;

(* ------------------------------------------------------------------------- *)
(* Composing two substitutions so that the following identity holds:         *)
(*                                                                           *)
(* subst (compose sub1 sub2) tm = subst sub2 (subst sub1 tm)                 *)
(* ------------------------------------------------------------------------- *)

(* compose sub1 sub2 starts from sub2 and adds v |-> subst sub2 tm for every *)
(* binding v |-> tm of sub1, overriding sub2's binding for v.  Composing     *)
(* with an empty sub2 just returns sub1.                                     *)
fun compose (sub1 as Subst m1) sub2 =
    if null sub2 then sub1
    else
      let
        fun addBinding (v,tm,acc) = insert acc (v, subst sub2 tm)
      in
        NameMap.foldl addBinding sub2 m1
      end;

(* ------------------------------------------------------------------------- *)
(* Creating the union of two compatible substitutions.                       *)
(* ------------------------------------------------------------------------- *)

local
  (* Two bindings of the same variable may be merged only if they agree. *)
  fun agree ((_,tm1),(_,tm2)) =
      if not (Term.equal tm1 tm2) then raise Error "Subst.union: incompatible"
      else SOME tm1;
in
  (* Merge two substitutions; an empty side yields the other one as is. *)
  fun union (s1 as Subst m1) (s2 as Subst m2) =
      if NameMap.null m1 then s2
      else if NameMap.null m2 then s1
      else Subst (NameMap.union agree m1 m2);
end;

(* ------------------------------------------------------------------------- *)
(* Substitutions can be inverted iff they are renaming substitutions.        *)
(* ------------------------------------------------------------------------- *)

local
  (* Turn a binding v |-> w into w |-> v, rejecting bindings to non-variable *)
  (* terms and any target w that an earlier binding already produced.       *)
  fun flip (_, Term.Fn _, _) = raise Error "Subst.invert: non-variable"
    | flip (v, Term.Var w, acc) =
      if not (NameMap.inDomain w acc) then NameMap.insert acc (w, Term.Var v)
      else raise Error "Subst.invert: non-injective";
in
  fun invert (Subst m) =
      let
        val reversed = NameMap.foldl flip (NameMap.new ()) m
      in
        Subst reversed
      end;
end;

(* A substitution is a renaming exactly when invert succeeds on it. *)
val isRenaming = can invert;

(* ------------------------------------------------------------------------- *)
(* Creating a substitution to freshen variables.                             *)
(* ------------------------------------------------------------------------- *)

(* Bind every variable of the set to a new variable from Term.newVar. *)
val freshVars =
    let
      fun bindFresh (v,acc) = insert acc (v, Term.newVar ())
    in
      NameSet.foldl bindFresh empty
    end;

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

(* The bound variables (the domain of the substitution). *)
val redexes =
    let
      fun addVar (v,_,acc) = NameSet.add acc v
    in
      foldl addVar NameSet.empty
    end;

(* The free variables of the bound terms. *)
val residueFreeVars =
    let
      fun addResidue (_,tm,acc) = NameSet.union acc (Term.freeVars tm)
    in
      foldl addResidue NameSet.empty
    end;

(* Bound variables together with the free variables of the bound terms. *)
val freeVars =
    let
      fun addBoth (v,tm,acc) =
          let
            val withVar = NameSet.add acc v
          in
            NameSet.union withVar (Term.freeVars tm)
          end
    in
      foldl addBoth NameSet.empty
    end;

(* ------------------------------------------------------------------------- *)
(* Functions.                                                                *)
(* ------------------------------------------------------------------------- *)

(* All function symbols, with arities, occurring in the bound terms. *)
val functions =
    let
      fun addFunctions (_,tm,acc) = NameAritySet.union acc (Term.functions tm)
    in
      foldl addFunctions NameAritySet.empty
    end;

(* ------------------------------------------------------------------------- *)
(* Matching for first order logic terms.                                     *)
(* ------------------------------------------------------------------------- *)

local
  (* Work through pending (pattern, target) pairs, extending sub so that     *)
  (* every pattern variable is bound consistently.  Variables of the target  *)
  (* are never bound, so a function pattern against a target variable fails. *)
  fun matchPairs sub [] = sub
    | matchPairs sub ((pat,target) :: pending) =
      case (pat,target) of
        (Term.Var v, _) =>
        let
          val sub' =
              case peek sub v of
                SOME bound =>
                if not (Term.equal target bound)
                then raise Error "Subst.match: incompatible matches"
                else sub
              | NONE => insert sub (v,target)
        in
          matchPairs sub' pending
        end
      | (Term.Fn (f1,args1), Term.Fn (f2,args2)) =>
        if not (Name.equal f1 f2 andalso length args1 = length args2)
        then raise Error "Subst.match: different structure"
        else matchPairs sub (zip args1 args2 @ pending)
      | _ => raise Error "Subst.match: functions can't match vars";
in
  fun match sub tm1 tm2 = matchPairs sub [(tm1,tm2)];
end;

(* ------------------------------------------------------------------------- *)
(* Unification for first order logic terms.                                  *)
(* ------------------------------------------------------------------------- *)

local
  (* Solve a list of pending equations, threading the substitution built so *)
  (* far.  Both sides are instantiated with the current substitution before  *)
  (* they are compared; pointer-equal pairs are skipped immediately.         *)
  fun solve sub [] = sub
    | solve sub ((tm1_tm2 as (tm1,tm2)) :: rest) =
      if Portable.pointerEqual tm1_tm2 then solve sub rest
      else solve' sub (subst sub tm1) (subst sub tm2) rest

  (* Solve one instantiated equation, then carry on with the rest. *)
  and solve' sub (Term.Var v) tm rest =
      if Term.equalVar v tm then solve sub rest
      else if Term.freeIn v tm then raise Error "Subst.unify: occurs check"
      else
        (case peek sub v of
           NONE => bindVar sub v tm rest
         | SOME tm' =>
           (* An identity binding v |-> v (Subst.match can produce one) leaves *)
           (* v effectively unbound.  Following it would make solve' call     *)
           (* itself with the same arguments forever, so bind v instead.      *)
           if Term.equalVar v tm' then bindVar sub v tm rest
           else solve' sub tm' tm rest)
    | solve' sub tm1 (tm2 as Term.Var _) rest = solve' sub tm2 tm1 rest
    | solve' sub (Term.Fn (f1,args1)) (Term.Fn (f2,args2)) rest =
      if Name.equal f1 f2 andalso length args1 = length args2 then
        solve sub (zip args1 args2 @ rest)
      else
        raise Error "Subst.unify: different structure"

  (* Record v |-> tm, pushing it through every existing binding via compose. *)
  and bindVar sub v tm rest = solve (compose sub (singleton (v,tm))) rest;
in
  (* Extend sub to a unifier of tm1 and tm2; raises Error on failure. *)
  fun unify sub tm1 tm2 = solve sub [(tm1,tm2)];
end;

end
end;

(**** Original file: Atom.sig ****)

(* ========================================================================= *)
(* FIRST ORDER LOGIC ATOMS                                                   *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Atom =
sig

(* ------------------------------------------------------------------------- *)
(* A type for storing first order logic atoms.                               *)
(* ------------------------------------------------------------------------- *)

type relationName = Metis.Name.name

(* A relation is a relation name paired with its arity. *)
type relation = relationName * int

(* An atom is a relation name applied to a list of argument terms. *)
type atom = relationName * Metis.Term.term list

(* ------------------------------------------------------------------------- *)
(* Constructors and destructors.                                             *)
(* ------------------------------------------------------------------------- *)

val name : atom -> relationName

val arguments : atom -> Metis.Term.term list

(* The number of arguments. *)
val arity : atom -> int

(* The relation name paired with the arity. *)
val relation : atom -> relation

(* Function symbols and function names, collected from all arguments. *)
val functions : atom -> Metis.NameAritySet.set

val functionNames : atom -> Metis.NameSet.set

(* Binary relations *)

val mkBinop : relationName -> Metis.Term.term * Metis.Term.term -> atom

(* Raises Error unless the atom has the given relation name and exactly two  *)
(* arguments.                                                                *)
val destBinop : relationName -> atom -> Metis.Term.term * Metis.Term.term

val isBinop : relationName -> atom -> bool

(* ------------------------------------------------------------------------- *)
(* The size of an atom in symbols.                                           *)
(* ------------------------------------------------------------------------- *)

(* One for the relation name plus Metis.Term.symbols of each argument. *)
val symbols : atom -> int

(* ------------------------------------------------------------------------- *)
(* A total comparison function for atoms.                                    *)
(* ------------------------------------------------------------------------- *)

(* Orders by relation name first, then lexicographically by arguments. *)
val compare : atom * atom -> order

val equal : atom -> atom -> bool

(* ------------------------------------------------------------------------- *)
(* Subterms.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Paths start with an argument index.  subterm and replace raise Bug on an  *)
(* empty path and Error when the index is not below the number of arguments. *)
val subterm : atom -> Metis.Term.path -> Metis.Term.term

val subterms : atom -> (Metis.Term.path * Metis.Term.term) list

val replace : atom -> Metis.Term.path * Metis.Term.term -> atom

(* The path of a subterm satisfying the predicate, if there is one. *)
val find : (Metis.Term.term -> bool) -> atom -> Metis.Term.path option

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

val freeIn : Metis.Term.var -> atom -> bool

val freeVars : atom -> Metis.NameSet.set

(* ------------------------------------------------------------------------- *)
(* Substitutions.                                                            *)
(* ------------------------------------------------------------------------- *)

val subst : Metis.Subst.subst -> atom -> atom

(* ------------------------------------------------------------------------- *)
(* Matching.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Also raises Error when the relation names or argument counts differ. *)
val match : Metis.Subst.subst -> atom -> atom -> Metis.Subst.subst  (* raises Error *)

(* ------------------------------------------------------------------------- *)
(* Unification.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Also raises Error when the relation names or argument counts differ. *)
val unify : Metis.Subst.subst -> atom -> atom -> Metis.Subst.subst  (* raises Error *)

(* ------------------------------------------------------------------------- *)
(* The equality relation.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Equality is the binary relation whose name is the string "=". *)
val eqRelationName : relationName

val eqRelation : relation

val mkEq : Metis.Term.term * Metis.Term.term -> atom

val destEq : atom -> Metis.Term.term * Metis.Term.term

val isEq : atom -> bool

(* mkRefl tm is the equation tm = tm; destRefl raises Error unless the atom *)
(* is an equation with equal sides.                                         *)
val mkRefl : Metis.Term.term -> atom

val destRefl : atom -> Metis.Term.term

val isRefl : atom -> bool

val sym : atom -> atom  (* raises Error if given a refl *)

(* The two sides of an equation; raise Error on a non-equation. *)
val lhs : atom -> Metis.Term.term

val rhs : atom -> Metis.Term.term

(* ------------------------------------------------------------------------- *)
(* Special support for terms with type annotations.                          *)
(* ------------------------------------------------------------------------- *)

val typedSymbols : atom -> int

val nonVarTypedSubterms : atom -> (Metis.Term.path * Metis.Term.term) list

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty printing.                                              *)
(* ------------------------------------------------------------------------- *)

(* Atoms are printed and parsed as function-application terms whose function *)
(* name is the relation name.                                                *)
val pp : atom Metis.Print.pp

val toString : atom -> string

val fromString : string -> atom

val parse : Metis.Term.term Metis.Parse.quotation -> atom

end

(**** Original file: Atom.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* FIRST ORDER LOGIC ATOMS                                                   *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Atom :> Atom =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type for storing first order logic atoms.                               *)
(* ------------------------------------------------------------------------- *)

(* An atom is a relation name applied to a list of argument terms; a        *)
(* relation pairs a relation name with its arity.                           *)
type relationName = Name.name;

type relation = relationName * int;

type atom = relationName * Term.term list;

(* ------------------------------------------------------------------------- *)
(* Constructors and destructors.                                             *)
(* ------------------------------------------------------------------------- *)

(* Projections out of an atom, and its relation (name, arity). *)
fun name ((p,_) : atom) = p;

fun arguments ((_,tms) : atom) = tms;

fun arity atm =
    let
      val args = arguments atm
    in
      length args
    end;

fun relation atm =
    let
      val p = name atm
      val n = arity atm
    in
      (p,n)
    end;

(* Function symbols (with arities) and function names, unioned over all of *)
(* the atom's arguments.                                                   *)
local
  fun addFunctions (tm,acc) = NameAritySet.union (Term.functions tm) acc;

  fun addFunctionNames (tm,acc) = NameSet.union (Term.functionNames tm) acc;
in
  fun functions atm = foldl addFunctions NameAritySet.empty (arguments atm);

  fun functionNames atm =
      foldl addFunctionNames NameSet.empty (arguments atm);
end;

(* Binary relations *)

(* A binary relation atom p(a,b) is represented as (p,[a,b]). *)
fun mkBinop p (a,b) : atom = (p,[a,b]);

fun destBinop p (x,args) =
    case args of
      [a,b] =>
      if not (Name.equal x p) then raise Error "Atom.destBinop: wrong binop"
      else (a,b)
    | _ => raise Error "Atom.destBinop: not a binop";

fun isBinop p = can (destBinop p);

(* ------------------------------------------------------------------------- *)
(* The size of an atom in symbols.                                           *)
(* ------------------------------------------------------------------------- *)

(* One symbol for the relation name plus the symbols of every argument. *)
fun symbols atm =
    let
      fun addTerm (tm,total) = total + Term.symbols tm
    in
      foldl addTerm 1 (arguments atm)
    end;

(* ------------------------------------------------------------------------- *)
(* A total comparison function for atoms.                                    *)
(* ------------------------------------------------------------------------- *)

(* Order atoms by relation name, breaking ties lexicographically on the     *)
(* argument lists with Term.compare.                                        *)
fun compare ((p1,tms1),(p2,tms2)) =
    let
      val byName = Name.compare (p1,p2)
    in
      if byName <> EQUAL then byName
      else lexCompare Term.compare (tms1,tms2)
    end;

fun equal atm1 atm2 = compare (atm1,atm2) = EQUAL;

(* ------------------------------------------------------------------------- *)
(* Subterms.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* The subterm of an atom at a path.  The first path element selects an     *)
(* argument (0-based, as for List.nth); the rest of the path is followed    *)
(* inside that argument by Term.subterm.                                     *)
(* Raises Bug on an empty path and Error when the index is out of range.     *)
fun subterm _ [] = raise Bug "Atom.subterm: empty path"
  | subterm ((_,tms) : atom) (h :: t) =
    (* Check the lower bound too: a negative index must be reported as Error *)
    (* rather than escaping as List.nth's Subscript exception.              *)
    if h < 0 orelse h >= length tms then raise Error "Atom.subterm: bad path"
    else Term.subterm (List.nth (tms,h)) t;

(* All (path, subterm) pairs of the atom: each argument's Term.subterms     *)
(* with that argument's index prepended to every path.  Arguments are       *)
(* folded left to right with each one's pairs put in front, so later        *)
(* arguments' subterms come first in the result.                            *)
fun subterms ((_,tms) : atom) =
    let
      fun prefix n (path,sub) = (n :: path, sub)

      fun addArg ((n,tm),acc) = map (prefix n) (Term.subterms tm) @ acc
    in
      foldl addArg [] (enumerate tms)
    end;

(* Replace the subterm at a path with res.  The first path element selects  *)
(* an argument (0-based, as for List.nth); the rest of the path is handled  *)
(* by Term.replace.  If the argument comes back pointer-equal, the original *)
(* atom is returned so that sharing is preserved.                           *)
(* Raises Bug on an empty path and Error when the index is out of range.     *)
fun replace _ ([],_) = raise Bug "Atom.replace: empty path"
  | replace (atm as (rel,tms)) (h :: t, res) : atom =
    (* Check the lower bound too: a negative index must be reported as Error *)
    (* rather than escaping as List.nth's Subscript exception.              *)
    if h < 0 orelse h >= length tms then raise Error "Atom.replace: bad path"
    else
      let
        val tm = List.nth (tms,h)
        val tm' = Term.replace tm (t,res)
      in
        if Portable.pointerEqual (tm,tm') then atm
        else (rel, updateNth (h,tm') tms)
      end;

(* Search the arguments (paired with their indices by enumerate) for a      *)
(* subterm satisfying pred, using Term.find inside each argument and        *)
(* prepending the argument index to any path found.                         *)
fun find pred =
    let
      fun searchArg (i,tm) =
          case Term.find pred tm of
            NONE => NONE
          | SOME path => SOME (i :: path)
    in
      fn (_,tms) : atom => first searchArg (enumerate tms)
    end;

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

(* A variable is free in an atom when it is free in some argument. *)
fun freeIn v atm =
    let
      val occursIn = Term.freeIn v
    in
      List.exists occursIn (arguments atm)
    end;

(* The union of the free variables of all arguments. *)
val freeVars =
    let
      fun addArg (tm,acc) = NameSet.union (Term.freeVars tm) acc
    in
      fn atm => foldl addArg NameSet.empty (arguments atm)
    end;

(* ------------------------------------------------------------------------- *)
(* Substitutions.                                                            *)
(* ------------------------------------------------------------------------- *)

(* Apply a substitution to every argument; if the argument list comes back  *)
(* pointer-equal, the original atom is returned to preserve sharing.        *)
fun subst sub (atm as (p,tms)) : atom =
    let
      val newTms = Sharing.map (Subst.subst sub) tms
    in
      if not (Portable.pointerEqual (newTms,tms)) then (p,newTms) else atm
    end;

(* ------------------------------------------------------------------------- *)
(* Matching.                                                                 *)
(* ------------------------------------------------------------------------- *)

local
  (* Fold one pair of corresponding arguments into the substitution. *)
  fun matchArg ((pat,target),sub) = Subst.match sub pat target;
in
  (* Both atoms must use the same relation name with the same number of     *)
  (* arguments; the arguments are then matched pairwise, left to right.     *)
  fun match sub (p1,tms1) (p2,tms2) =
      if not (Name.equal p1 p2 andalso length tms1 = length tms2)
      then raise Error "Atom.match"
      else foldl matchArg sub (zip tms1 tms2);
end;

(* ------------------------------------------------------------------------- *)
(* Unification.                                                              *)
(* ------------------------------------------------------------------------- *)

local
  (* Fold one pair of corresponding arguments into the unifier. *)
  fun unifyArg ((lhsTm,rhsTm),sub) = Subst.unify sub lhsTm rhsTm;
in
  (* Both atoms must use the same relation name with the same number of     *)
  (* arguments; the arguments are then unified pairwise, left to right.     *)
  fun unify sub (p1,tms1) (p2,tms2) =
      if not (Name.equal p1 p2 andalso length tms1 = length tms2)
      then raise Error "Atom.unify"
      else foldl unifyArg sub (zip tms1 tms2);
end;

(* ------------------------------------------------------------------------- *)
(* The equality relation.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Equality is the binary relation named by the string "=". *)
val eqRelationName = Name.fromString "=";

val eqRelationArity = 2;

val eqRelation = (eqRelationName,eqRelationArity);

val mkEq = mkBinop eqRelationName;

fun destEq atm = destBinop eqRelationName atm;

fun isEq atm = isBinop eqRelationName atm;

(* A reflexivity atom is an equation whose two sides are equal terms. *)
fun mkRefl tm = mkEq (tm,tm);

fun destRefl atm =
    let
      val (l,r) = destEq atm
    in
      if not (Term.equal l r) then raise Error "Atom.destRefl" else l
    end;

fun isRefl atm = can destRefl atm;

(* Swap the sides of an equation; reflexivity atoms are rejected. *)
fun sym atm =
    let
      val (l,r) = destEq atm
    in
      if Term.equal l r then raise Error "Atom.sym: refl" else mkEq (r,l)
    end;

(* The left and right sides of an equation. *)
fun lhs atm =
    let
      val (l,_) = destEq atm
    in
      l
    end;

fun rhs atm =
    let
      val (_,r) = destEq atm
    in
      r
    end;

(* ------------------------------------------------------------------------- *)
(* Special support for terms with type annotations.                          *)
(* ------------------------------------------------------------------------- *)

(* One for the relation name plus Term.typedSymbols of every argument. *)
fun typedSymbols ((_,tms) : atom) =
    let
      fun addTerm (tm,total) = total + Term.typedSymbols tm
    in
      foldl addTerm 1 tms
    end;

(* Term.nonVarTypedSubterms of each argument, with the argument index       *)
(* prepended to every path.                                                  *)
fun nonVarTypedSubterms (_,tms) =
    let
      fun addArg ((n,arg),acc) =
          foldl (fn ((path,tm),acc') => (n :: path, tm) :: acc') acc
            (Term.nonVarTypedSubterms arg)
    in
      foldl addArg [] (enumerate tms)
    end;

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty printing.                                              *)
(* ------------------------------------------------------------------------- *)

(* Atoms are printed and parsed by viewing them as function-application     *)
(* terms: the relation name plays the role of the function name.            *)
val pp = Print.ppMap Term.Fn Term.pp;

val toString = Print.toString pp;

fun fromString str =
    let
      val tm = Term.fromString str
    in
      Term.destFn tm
    end;

val parse = Parse.parseQuotation Term.toString fromString;

end

(* Atoms ordered by Atom.compare, and the finite maps and sets built on it. *)
structure AtomOrdered =
struct
  type t = Atom.atom

  val compare = Atom.compare
end

structure AtomMap = KeyMap (AtomOrdered);

structure AtomSet = ElementSet (AtomMap);
end;

(**** Original file: Formula.sig ****)

(* ========================================================================= *)
(* FIRST ORDER LOGIC FORMULAS                                                *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Formula =
sig

(* ------------------------------------------------------------------------- *)
(* A type of first order logic formulas.                                     *)
(* ------------------------------------------------------------------------- *)

datatype formula =
    True
  | False
  | Atom of Metis.Atom.atom
  | Not of formula
  | And of formula * formula
  | Or of formula * formula
  | Imp of formula * formula
  | Iff of formula * formula
  | Forall of Metis.Term.var * formula
  | Exists of Metis.Term.var * formula

(* ------------------------------------------------------------------------- *)
(* Constructors and destructors.                                             *)
(* ------------------------------------------------------------------------- *)

(* Booleans *)

(* mkBoolean maps true/false to True/False.  destBoolean is its inverse and  *)
(* raises Error on any other formula; isBoolean tests whether it succeeds.   *)
val mkBoolean : bool -> formula

val destBoolean : formula -> bool

val isBoolean : formula -> bool

val isTrue : formula -> bool

val isFalse : formula -> bool

(* Functions *)

val functions : formula -> Metis.NameAritySet.set

val functionNames : formula -> Metis.NameSet.set

(* Relations *)

val relations : formula -> Metis.NameAritySet.set

val relationNames : formula -> Metis.NameSet.set

(* Atoms *)

val destAtom : formula -> Metis.Atom.atom

val isAtom : formula -> bool

(* Negations *)

val destNeg : formula -> formula

val isNeg : formula -> bool

val stripNeg : formula -> int * formula

(* Conjunctions *)

val listMkConj : formula list -> formula

val stripConj : formula -> formula list

val flattenConj : formula -> formula list

(* Disjunctions *)

val listMkDisj : formula list -> formula

val stripDisj : formula -> formula list

val flattenDisj : formula -> formula list

(* Equivalences *)

val listMkEquiv : formula list -> formula

val stripEquiv : formula -> formula list

val flattenEquiv : formula -> formula list

(* Universal quantification *)

val destForall : formula -> Metis.Term.var * formula

val isForall : formula -> bool

val listMkForall : Metis.Term.var list * formula -> formula

val setMkForall : Metis.NameSet.set * formula -> formula

val stripForall : formula -> Metis.Term.var list * formula

(* Existential quantification *)

val destExists : formula -> Metis.Term.var * formula

val isExists : formula -> bool

val listMkExists : Metis.Term.var list * formula -> formula

val setMkExists : Metis.NameSet.set * formula -> formula

val stripExists : formula -> Metis.Term.var list * formula

(* ------------------------------------------------------------------------- *)
(* The size of a formula in symbols.                                         *)
(* ------------------------------------------------------------------------- *)

val symbols : formula -> int

(* ------------------------------------------------------------------------- *)
(* A total comparison function for formulas.                                 *)
(* ------------------------------------------------------------------------- *)

val compare : formula * formula -> order

val equal : formula -> formula -> bool

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

val freeIn : Metis.Term.var -> formula -> bool

val freeVars : formula -> Metis.NameSet.set

val freeVarsList : formula list -> Metis.NameSet.set

val specialize : formula -> formula

val generalize : formula -> formula

(* ------------------------------------------------------------------------- *)
(* Substitutions.                                                            *)
(* ------------------------------------------------------------------------- *)

val subst : Metis.Subst.subst -> formula -> formula

(* ------------------------------------------------------------------------- *)
(* The equality relation.                                                    *)
(* ------------------------------------------------------------------------- *)

val mkEq : Metis.Term.term * Metis.Term.term -> formula

val destEq : formula -> Metis.Term.term * Metis.Term.term

val isEq : formula -> bool

val mkNeq : Metis.Term.term * Metis.Term.term -> formula

val destNeq : formula -> Metis.Term.term * Metis.Term.term

val isNeq : formula -> bool

val mkRefl : Metis.Term.term -> formula

val destRefl : formula -> Metis.Term.term

val isRefl : formula -> bool

val sym : formula -> formula  (* raises Error if given a refl *)

val lhs : formula -> Metis.Term.term

val rhs : formula -> Metis.Term.term

(* ------------------------------------------------------------------------- *)
(* Splitting goals.                                                          *)
(* ------------------------------------------------------------------------- *)

val splitGoal : formula -> formula list

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty-printing.                                              *)
(* ------------------------------------------------------------------------- *)

type quotation = formula Metis.Parse.quotation

val pp : formula Metis.Print.pp

val toString : formula -> string

val fromString : string -> formula

val parse : quotation -> formula

end

(**** Original file: Formula.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* FIRST ORDER LOGIC FORMULAS                                                *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Formula :> Formula =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of first order logic formulas.                                     *)
(* ------------------------------------------------------------------------- *)

(* Abstract syntax of first order formulas.  An atomic formula wraps an
   Atom.atom, which (as the Atom (p,tms) patterns later in this structure
   show) is a relation name applied to a list of argument terms.  Each
   quantifier binds a single variable. *)
datatype formula =
    True                              (* the constant truth *)
  | False                             (* the constant falsity *)
  | Atom of Atom.atom                 (* atomic formula *)
  | Not of formula
  | And of formula * formula
  | Or of formula * formula
  | Imp of formula * formula          (* implication *)
  | Iff of formula * formula          (* equivalence *)
  | Forall of Term.var * formula
  | Exists of Term.var * formula;

(* ------------------------------------------------------------------------- *)
(* Constructors and destructors.                                             *)
(* ------------------------------------------------------------------------- *)

(* Booleans *)

(* Convert an ML boolean into the corresponding formula constant. *)
fun mkBoolean true = True
  | mkBoolean false = False;

(* Inverse of mkBoolean; raises Error on any formula other than True/False.
   The error message is qualified with the structure name, matching the
   other destructors here (e.g. "Formula.destAtom", "Formula.destNeg"). *)
fun destBoolean True = true
  | destBoolean False = false
  | destBoolean _ = raise Error "Formula.destBoolean";

val isBoolean = can destBoolean;

(* Is the formula the constant True? *)
fun isTrue fm =
    case fm of
      True => true
    | _ => false;

(* Is the formula the constant False? *)
fun isFalse fm =
    case fm of
      False => true
    | _ => false;

(* Functions *)

local
  (* Fold f over every atom occurring in a worklist of formulas.  Operands
     of binary connectives are visited left before right; quantifier
     binders are ignored.  Shared by the four symbol-collecting functions
     below, which previously each repeated this traversal. *)
  fun foldAtoms f acc [] = acc
    | foldAtoms f acc (fm :: fms) =
      case fm of
        True => foldAtoms f acc fms
      | False => foldAtoms f acc fms
      | Atom atm => foldAtoms f (f (atm,acc)) fms
      | Not p => foldAtoms f acc (p :: fms)
      | And (p,q) => foldAtoms f acc (p :: q :: fms)
      | Or (p,q) => foldAtoms f acc (p :: q :: fms)
      | Imp (p,q) => foldAtoms f acc (p :: q :: fms)
      | Iff (p,q) => foldAtoms f acc (p :: q :: fms)
      | Forall (_,p) => foldAtoms f acc (p :: fms)
      | Exists (_,p) => foldAtoms f acc (p :: fms);
in
  (* The function symbols (with arities) occurring in any atom of fm. *)
  fun functions fm =
      foldAtoms (fn (atm,fs) => NameAritySet.union (Atom.functions atm) fs)
        NameAritySet.empty [fm];

  (* The function names occurring in any atom of fm. *)
  fun functionNames fm =
      foldAtoms (fn (atm,fs) => NameSet.union (Atom.functionNames atm) fs)
        NameSet.empty [fm];

  (* Relations *)

  (* The relations (name and arity) of the atoms of fm. *)
  fun relations fm =
      foldAtoms (fn (atm,rs) => NameAritySet.add rs (Atom.relation atm))
        NameAritySet.empty [fm];

  (* The relation names of the atoms of fm. *)
  fun relationNames fm =
      foldAtoms (fn (atm,rs) => NameSet.add rs (Atom.name atm))
        NameSet.empty [fm];
end;

(* Atoms *)

(* The atom of an atomic formula; raises Error for any other formula. *)
fun destAtom fm =
    case fm of
      Atom atm => atm
    | _ => raise Error "Formula.destAtom";

val isAtom = can destAtom;

(* Negations *)

(* Remove one outermost negation; raises Error if fm is not a negation. *)
fun destNeg fm =
    case fm of
      Not p => p
    | _ => raise Error "Formula.destNeg";

val isNeg = can destNeg;

(* Peel off every leading negation, returning how many were removed
   together with the remaining formula. *)
fun stripNeg fm =
    let
      fun peel count (Not inner) = peel (count + 1) inner
        | peel count rest = (count,rest)
    in
      peel 0 fm
    end;

(* Conjunctions *)

(* Build a right-nested conjunction a /\ (b /\ ...); [] gives True. *)
fun listMkConj fms =
    case rev fms of
      [] => True
    | last :: earlier => foldl (fn (fm,acc) => And (fm,acc)) last earlier;

(* Split a right-nested conjunction into its conjuncts (the inverse of
   listMkConj); True gives []. *)
fun stripConj True = []
  | stripConj fm =
    let
      fun go acc (And (p,q)) = go (p :: acc) q
        | go acc last = rev (last :: acc)
    in
      go [] fm
    end;

(* All conjuncts of fm in left-to-right order, flattening And nodes however
   they are associated and dropping True conjuncts. *)
fun flattenConj fm =
    let
      fun walk acc [] = acc
        | walk acc (f :: rest) =
          case f of
            And (p,q) => walk acc (q :: p :: rest)
          | True => walk acc rest
          | _ => walk (f :: acc) rest
    in
      walk [] [fm]
    end;

(* Disjunctions *)

(* Build a right-nested disjunction a \/ (b \/ ...); [] gives False. *)
fun listMkDisj fms =
    case rev fms of
      [] => False
    | last :: earlier => foldl (fn (fm,acc) => Or (fm,acc)) last earlier;

(* Split a right-nested disjunction into its disjuncts (the inverse of
   listMkDisj); False gives []. *)
fun stripDisj False = []
  | stripDisj fm =
    let
      fun go acc (Or (p,q)) = go (p :: acc) q
        | go acc last = rev (last :: acc)
    in
      go [] fm
    end;

(* All disjuncts of fm in left-to-right order, flattening Or nodes however
   they are associated and dropping False disjuncts. *)
fun flattenDisj fm =
    let
      fun walk acc [] = acc
        | walk acc (f :: rest) =
          case f of
            Or (p,q) => walk acc (q :: p :: rest)
          | False => walk acc rest
          | _ => walk (f :: acc) rest
    in
      walk [] [fm]
    end;

(* Equivalences *)

(* Build a right-nested equivalence a <=> (b <=> ...); [] gives True. *)
fun listMkEquiv fms =
    case rev fms of
      [] => True
    | last :: earlier => foldl (fn (fm,acc) => Iff (fm,acc)) last earlier;

(* Split a right-nested equivalence into its components (the inverse of
   listMkEquiv); True gives []. *)
fun stripEquiv True = []
  | stripEquiv fm =
    let
      fun go acc (Iff (p,q)) = go (p :: acc) q
        | go acc last = rev (last :: acc)
    in
      go [] fm
    end;

(* All components of fm in left-to-right order, flattening Iff nodes however
   they are associated and dropping True components. *)
fun flattenEquiv fm =
    let
      fun walk acc [] = acc
        | walk acc (f :: rest) =
          case f of
            Iff (p,q) => walk acc (q :: p :: rest)
          | True => walk acc rest
          | _ => walk (f :: acc) rest
    in
      walk [] [fm]
    end;

(* Universal quantifiers *)

(* Split a universally quantified formula into its bound variable and body.
   Raises Error otherwise; the message is qualified with the structure name
   like the other destructors here (e.g. "Formula.destAtom"). *)
fun destForall (Forall v_f) = v_f
  | destForall _ = raise Error "Formula.destForall";

val isForall = can destForall;

(* Quantify body universally over the listed variables, the first variable
   being the outermost quantifier. *)
fun listMkForall ([],body) = body
  | listMkForall (v :: vs, body) = Forall (v, listMkForall (vs,body));

(* Quantify body universally over a set of variables; the nesting order is
   whatever NameSet.foldr produces. *)
fun setMkForall (vs,body) = NameSet.foldr Forall body vs;

(* Strip all leading universal quantifiers, returning the bound variables
   (outermost first) and the remaining body. *)
local
  fun strip vs (Forall (v,b)) = strip (v :: vs) b
    | strip vs tm = (rev vs, tm);
in
  val stripForall = strip [];
end;

(* Existential quantifiers *)

(* Split an existentially quantified formula into its bound variable and
   body.  Raises Error otherwise; the message is qualified with the structure
   name like the other destructors here (e.g. "Formula.destAtom"). *)
fun destExists (Exists v_f) = v_f
  | destExists _ = raise Error "Formula.destExists";

val isExists = can destExists;

(* Quantify body existentially over the listed variables, the first variable
   being the outermost quantifier. *)
fun listMkExists ([],body) = body
  | listMkExists (v :: vs, body) = Exists (v, listMkExists (vs,body));

(* Quantify body existentially over a set of variables; the nesting order is
   whatever NameSet.foldr produces. *)
fun setMkExists (vs,body) = NameSet.foldr Exists body vs;

(* Strip all leading existential quantifiers, returning the bound variables
   (outermost first) and the remaining body. *)
local
  fun strip vs (Exists (v,b)) = strip (v :: vs) b
    | strip vs tm = (rev vs, tm);
in
  val stripExists = strip [];
end;

(* ------------------------------------------------------------------------- *)
(* The size of a formula in symbols.                                         *)
(* ------------------------------------------------------------------------- *)

local
  (* Sum sizes over a worklist of formulas.  Every constant, connective and
     quantifier counts 1 (bound variables are not counted separately), and
     an atom contributes Atom.symbols. *)
  fun count n [] = n
    | count n (fm :: rest) =
      case fm of
        Atom atm => count (n + Atom.symbols atm) rest
      | True => count (n + 1) rest
      | False => count (n + 1) rest
      | Not p => count (n + 1) (p :: rest)
      | And (p,q) => count (n + 1) (p :: q :: rest)
      | Or (p,q) => count (n + 1) (p :: q :: rest)
      | Imp (p,q) => count (n + 1) (p :: q :: rest)
      | Iff (p,q) => count (n + 1) (p :: q :: rest)
      | Forall (_,p) => count (n + 1) (p :: rest)
      | Exists (_,p) => count (n + 1) (p :: rest);
in
  fun symbols fm = count 0 [fm];
end;

(* ------------------------------------------------------------------------- *)
(* A total comparison function for formulas.                                 *)
(* ------------------------------------------------------------------------- *)

local
  (* Compare a worklist of formula pairs in order, stopping at the first
     pair that differs.  Constructors are ordered
       True < False < Atom < Not < And < Or < Imp < Iff < Forall < Exists;
     pairs with the same constructor compare their components, left operand
     first.  Bound variables are compared by name, so this is NOT equality
     modulo renaming of bound variables. *)
  fun cmp [] = EQUAL
    | cmp (f1_f2 :: fs) =
      (* Physically identical formulas are equal; skip inspecting them. *)
      if Portable.pointerEqual f1_f2 then cmp fs
      else
        case f1_f2 of
          (True,True) => cmp fs
        | (True,_) => LESS
        | (_,True) => GREATER
        | (False,False) => cmp fs
        | (False,_) => LESS
        | (_,False) => GREATER
        | (Atom atm1, Atom atm2) =>
          (case Atom.compare (atm1,atm2) of
             LESS => LESS
           | EQUAL => cmp fs
           | GREATER => GREATER)
        | (Atom _, _) => LESS
        | (_, Atom _) => GREATER
        | (Not p1, Not p2) => cmp ((p1,p2) :: fs)
        | (Not _, _) => LESS
        | (_, Not _) => GREATER
        | (And (p1,q1), And (p2,q2)) => cmp ((p1,p2) :: (q1,q2) :: fs)
        | (And _, _) => LESS
        | (_, And _) => GREATER
        | (Or (p1,q1), Or (p2,q2)) => cmp ((p1,p2) :: (q1,q2) :: fs)
        | (Or _, _) => LESS
        | (_, Or _) => GREATER
        | (Imp (p1,q1), Imp (p2,q2)) => cmp ((p1,p2) :: (q1,q2) :: fs)
        | (Imp _, _) => LESS
        | (_, Imp _) => GREATER
        | (Iff (p1,q1), Iff (p2,q2)) => cmp ((p1,p2) :: (q1,q2) :: fs)
        | (Iff _, _) => LESS
        | (_, Iff _) => GREATER
        (* Only Forall and Exists remain: compare the bound variable name
           first, then the body. *)
        | (Forall (v1,p1), Forall (v2,p2)) =>
          (case Name.compare (v1,v2) of
             LESS => LESS
           | EQUAL => cmp ((p1,p2) :: fs)
           | GREATER => GREATER)
        | (Forall _, Exists _) => LESS
        | (Exists _, Forall _) => GREATER
        | (Exists (v1,p1), Exists (v2,p2)) =>
          (case Name.compare (v1,v2) of
             LESS => LESS
           | EQUAL => cmp ((p1,p2) :: fs)
           | GREATER => GREATER);
in
  (* A total order on formulas; used as the key order of FormulaMap. *)
  fun compare fm1_fm2 = cmp [fm1_fm2];
end;

(* Syntactic equality, derived from compare. *)
fun equal fm1 fm2 = compare (fm1,fm2) = EQUAL;

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

(* Does variable v occur free in fm?  A quantifier binding v hides its body
   from the search. *)
fun freeIn v fm =
    let
      fun search [] = false
        | search (g :: rest) =
          case g of
            True => search rest
          | False => search rest
          | Atom atm => Atom.freeIn v atm orelse search rest
          | Not p => search (p :: rest)
          | And (p,q) => search (p :: q :: rest)
          | Or (p,q) => search (p :: q :: rest)
          | Imp (p,q) => search (p :: q :: rest)
          | Iff (p,q) => search (p :: q :: rest)
          | Forall (w,p) => search (if Name.equal v w then rest else p :: rest)
          | Exists (w,p) => search (if Name.equal v w then rest else p :: rest)
    in
      search [fm]
    end;

local
  (* Worklist of (variables bound so far, subformula) pairs.  The free
     variables of each atom, minus those bound above it, are added to acc. *)
  fun collect acc [] = acc
    | collect acc ((bound,fm) :: rest) =
      case fm of
        Atom atm =>
        let
          val free = NameSet.difference (Atom.freeVars atm) bound
        in
          collect (NameSet.union acc free) rest
        end
      | True => collect acc rest
      | False => collect acc rest
      | Not p => collect acc ((bound,p) :: rest)
      | And (p,q) => collect acc ((bound,p) :: (bound,q) :: rest)
      | Or (p,q) => collect acc ((bound,p) :: (bound,q) :: rest)
      | Imp (p,q) => collect acc ((bound,p) :: (bound,q) :: rest)
      | Iff (p,q) => collect acc ((bound,p) :: (bound,q) :: rest)
      | Forall (v,p) => collect acc ((NameSet.add bound v, p) :: rest)
      | Exists (v,p) => collect acc ((NameSet.add bound v, p) :: rest);

  (* Add the free variables of fm to acc. *)
  fun addFree (fm,acc) = collect acc [(NameSet.empty,fm)];
in
  (* The free variables of a formula. *)
  fun freeVars fm = addFree (fm,NameSet.empty);

  (* The union of the free variables of a list of formulas. *)
  fun freeVarsList fms = List.foldl addFree NameSet.empty fms;
end;

(* Drop the leading universal quantifiers of fm. *)
fun specialize fm =
    let
      val (_,body) = stripForall fm
    in
      body
    end;

(* Universally quantify fm over all of its free variables. *)
fun generalize fm =
    let
      val vars = NameSet.toList (freeVars fm)
    in
      listMkForall (vars,fm)
    end;

(* ------------------------------------------------------------------------- *)
(* Substitutions.                                                            *)
(* ------------------------------------------------------------------------- *)

local
  (* Apply sub to fm, skipping the traversal entirely when sub is empty. *)
  fun substCheck sub fm = if Subst.null sub then fm else substFm sub fm

  (* Apply sub to fm.  Wherever nothing changes, the original (physically
     identical) value is returned instead of a rebuilt copy, so callers can
     detect "unchanged" with Portable.pointerEqual and sharing is kept. *)
  and substFm sub fm =
      case fm of
        True => fm
      | False => fm
      | Atom (p,tms) =>
        let
          (* NOTE(review): the pointerEqual test below assumes Sharing.map
             returns tms itself when no term changes -- confirm. *)
          val tms' = Sharing.map (Subst.subst sub) tms
        in
          if Portable.pointerEqual (tms,tms') then fm else Atom (p,tms')
        end
      | Not p =>
        let
          val p' = substFm sub p
        in
          if Portable.pointerEqual (p,p') then fm else Not p'
        end
      | And (p,q) => substConn sub fm And p q
      | Or (p,q) => substConn sub fm Or p q
      | Imp (p,q) => substConn sub fm Imp p q
      | Iff (p,q) => substConn sub fm Iff p q
      | Forall (v,p) => substQuant sub fm Forall v p
      | Exists (v,p) => substQuant sub fm Exists v p

  (* Binary connective conn: rebuild only if an operand changed. *)
  and substConn sub fm conn p q =
      let
        val p' = substFm sub p
        and q' = substFm sub q
      in
        if Portable.pointerEqual (p,p') andalso
           Portable.pointerEqual (q,q')
        then fm
        else conn (p',q')
      end

  (* Quantifier quant binding v over body p. *)
  and substQuant sub fm quant v p =
      let
        (* Pick a name v' for the bound variable that avoids every variable
           that will be free in the substituted body.  Each free variable
           w <> v of p contributes either itself (when sub does not map it)
           or the free variables of its image under sub.  The v = v' test
           below suggests Term.variantPrime returns v itself when v is not
           in vars -- presumably a primed variant otherwise. *)
        val v' =
            let
              fun f (w,s) =
                  if Name.equal w v then s
                  else
                    case Subst.peek sub w of
                      NONE => NameSet.add s w
                    | SOME tm => NameSet.union s (Term.freeVars tm)

              val vars = freeVars p
              val vars = NameSet.foldl f NameSet.empty vars
            in
              Term.variantPrime vars v
            end

        (* If the binder keeps its name, the bound v must not be touched, so
           drop any binding for v from sub.  Otherwise rename occurrences of
           v in the body to v'. *)
        val sub =
            if Name.equal v v' then Subst.remove sub (NameSet.singleton v)
            else Subst.insert sub (v, Term.Var v')

        val p' = substCheck sub p
      in
        if Name.equal v v' andalso Portable.pointerEqual (p,p') then fm
        else quant (v',p')
      end;
in
  (* Apply a substitution to the free variables of a formula, renaming
     bound variables where needed to avoid capture. *)
  val subst = substCheck;
end;

(* ------------------------------------------------------------------------- *)
(* The equality relation.                                                    *)
(* ------------------------------------------------------------------------- *)

(* The atomic formula l = r. *)
fun mkEq (l,r) = Atom (Atom.mkEq (l,r));

(* The two sides of an equality atom; raises Error otherwise. *)
fun destEq fm =
    let
      val atm = destAtom fm
    in
      Atom.destEq atm
    end;

val isEq = can destEq;

(* The formula ~(l = r). *)
fun mkNeq (l,r) = Not (mkEq (l,r));

(* The two sides of a negated equality; raises Error otherwise. *)
fun destNeq fm =
    case fm of
      Not p => destEq p
    | _ => raise Error "Formula.destNeq";

val isNeq = can destNeq;

(* The reflexive equality tm = tm. *)
fun mkRefl tm = Atom (Atom.mkRefl tm);

fun destRefl fm =
    let
      val atm = destAtom fm
    in
      Atom.destRefl atm
    end;

val isRefl = can destRefl;

(* Swap the sides of an equality atom (delegates to Atom.sym). *)
fun sym fm =
    let
      val atm = destAtom fm
    in
      Atom (Atom.sym atm)
    end;

(* Left-hand side of an equality formula. *)
fun lhs fm =
    let
      val (l,_) = destEq fm
    in
      l
    end;

(* Right-hand side of an equality formula. *)
fun rhs fm =
    let
      val (_,r) = destEq fm
    in
      r
    end;

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty-printing.                                              *)
(* ------------------------------------------------------------------------- *)

type quotation = formula Parse.quotation;

(* Function-symbol names that stand for the constants, connectives and
   quantifiers when formulas are converted to and from terms for printing
   and parsing. *)
val truthName = Name.fromString "T";
val falsityName = Name.fromString "F";
val conjunctionName = Name.fromString "/\\";
val disjunctionName = Name.fromString "\\/";
val implicationName = Name.fromString "==>";
val equivalenceName = Name.fromString "<=>";
val universalName = Name.fromString "!";
val existentialName = Name.fromString "?";

local
  (* Encode a formula as a term whose function symbols name the
     connectives, so that the term pretty-printer can display it. *)
  fun demote fm =
      case fm of
        True => Term.Fn (truthName,[])
      | False => Term.Fn (falsityName,[])
      | Atom (p,tms) => Term.Fn (p,tms)
      | Not p =>
        let
          (* The negation symbol is read from Term.negation on every call. *)
          val Unsynchronized.ref s = Term.negation
        in
          Term.Fn (Name.fromString s, [demote p])
        end
      | And (p,q) => demoteBinary conjunctionName p q
      | Or (p,q) => demoteBinary disjunctionName p q
      | Imp (p,q) => demoteBinary implicationName p q
      | Iff (p,q) => demoteBinary equivalenceName p q
      | Forall (v,b) => Term.Fn (universalName, [Term.Var v, demote b])
      | Exists (v,b) => Term.Fn (existentialName, [Term.Var v, demote b])

  (* A binary connective becomes a two-argument application of conn. *)
  and demoteBinary conn p q = Term.Fn (conn, [demote p, demote q]);
in
  fun pp fm = Term.pp (demote fm);
end;

val toString = Print.toString pp;

local
  (* Decode a term produced by the term parser back into a formula.  Any
     application that does not match a connective or quantifier name with
     the right arity, or a quantifier whose first argument is not a
     variable, becomes an atom. *)
  fun promote tm =
      case tm of
        Term.Var v => Atom (v,[])
      | Term.Fn (f,tms) => promoteFn f tms

  (* Dispatch on the number of arguments, then on the symbol name. *)
  and promoteFn f tms =
      case tms of
        [] =>
        if Name.equal f truthName then True
        else if Name.equal f falsityName then False
        else Atom (f,tms)
      | [a] =>
        if Name.toString f = !Term.negation then Not (promote a)
        else Atom (f,tms)
      | [a,b] =>
        if Name.equal f conjunctionName then And (promote a, promote b)
        else if Name.equal f disjunctionName then Or (promote a, promote b)
        else if Name.equal f implicationName then Imp (promote a, promote b)
        else if Name.equal f equivalenceName then Iff (promote a, promote b)
        else
          (case a of
             Term.Var v =>
             if Name.equal f universalName then Forall (v, promote b)
             else if Name.equal f existentialName then Exists (v, promote b)
             else Atom (f,tms)
           | _ => Atom (f,tms))
      | _ => Atom (f,tms);
in
  fun fromString s = promote (Term.fromString s);
end;

val parse = Parse.parseQuotation toString fromString;

(* ------------------------------------------------------------------------- *)
(* Splitting goals.                                                          *)
(* ------------------------------------------------------------------------- *)

local
  (* Attach the accumulated assumptions to goal as the antecedent of an
     implication.  asms is built by consing, so it is reversed back into the
     order the assumptions were added.  No assumptions leaves goal as is. *)
  fun add_asms asms goal =
      if null asms then goal else Imp (listMkConj (rev asms), goal);

  (* Universally quantify goal over v, then attach asms. *)
  fun add_var_asms asms v goal = add_asms asms (Forall (v,goal));

  (* split asms pol fm: break fm into a list of subgoals.  pol = true means
     fm itself is the goal; pol = false means its negation is.  asms holds
     the assumptions gathered on the way down.  Trivial goals (True
     positively, False negatively) yield no subgoals. *)
  fun split asms pol fm =
      case (pol,fm) of
        (* Positive splittables *)
        (true,True) => []
      | (true, Not f) => split asms false f
        (* prove f1, then prove f2 under the extra assumption f1 *)
      | (true, And (f1,f2)) => split asms true f1 @ split (f1 :: asms) true f2
        (* prove f2 assuming ~f1 *)
      | (true, Or (f1,f2)) => split (Not f1 :: asms) true f2
      | (true, Imp (f1,f2)) => split (f1 :: asms) true f2
        (* both directions of the equivalence *)
      | (true, Iff (f1,f2)) =>
        split (f1 :: asms) true f2 @ split (f2 :: asms) true f1
        (* split the body with no assumptions, then re-quantify each
           subgoal over v and attach asms outside the quantifier *)
      | (true, Forall (v,f)) => map (add_var_asms asms v) (split [] true f)
        (* Negative splittables *)
      | (false,False) => []
      | (false, Not f) => split asms true f
      | (false, And (f1,f2)) => split (f1 :: asms) false f2
      | (false, Or (f1,f2)) =>
        split asms false f1 @ split (Not f1 :: asms) false f2
      | (false, Imp (f1,f2)) => split asms true f1 @ split (f1 :: asms) false f2
      | (false, Iff (f1,f2)) =>
        split (f1 :: asms) false f2 @ split (f2 :: asms) false f1
      | (false, Exists (v,f)) => map (add_var_asms asms v) (split [] false f)
        (* Unsplittables: the (possibly negated) formula under asms *)
      | _ => [add_asms asms (if pol then fm else Not fm)];
in
  (* Split a goal formula into a list of subgoals.  Presumably proving all
     of them establishes fm -- TODO confirm this contract with callers. *)
  fun splitGoal fm = split [] true fm;
end;

(*MetisTrace3
val splitGoal = fn fm =>
    let
      val result = splitGoal fm
      val () = Print.trace pp "Formula.splitGoal: fm" fm
      val () = Print.trace (Print.ppList pp) "Formula.splitGoal: result" result
    in
      result
    end;
*)

end

(* Formulas ordered by Formula.compare, for use as map keys and set
   elements. *)
structure FormulaOrdered =
struct
  type t = Formula.formula

  val compare = Formula.compare
end

structure FormulaMap = KeyMap (FormulaOrdered);

structure FormulaSet = ElementSet (FormulaMap);
end;

(**** Original file: Literal.sig ****)

(* ========================================================================= *)
(* FIRST ORDER LOGIC LITERALS                                                *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Literal =
sig

(* ------------------------------------------------------------------------- *)
(* A type for storing first order logic literals.                            *)
(* ------------------------------------------------------------------------- *)

(* true for a positive literal, false for a negated one. *)
type polarity = bool

(* A literal is an atom together with its polarity. *)
type literal = polarity * Metis.Atom.atom

(* ------------------------------------------------------------------------- *)
(* Constructors and destructors.                                             *)
(* ------------------------------------------------------------------------- *)

val polarity : literal -> polarity

val atom : literal -> Metis.Atom.atom

(* Relation name of the literal's atom. *)
val name : literal -> Metis.Atom.relationName

(* Argument terms of the literal's atom. *)
val arguments : literal -> Metis.Term.term list

val arity : literal -> int

(* positive is the polarity; negative is its complement. *)
val positive : literal -> bool

val negative : literal -> bool

(* Flip the polarity, keeping the atom. *)
val negate : literal -> literal

val relation : literal -> Metis.Atom.relation

val functions : literal -> Metis.NameAritySet.set

val functionNames : literal -> Metis.NameSet.set

(* Binary relations *)

(* Build or take apart a literal whose atom applies the named binary relation
   to two terms; destBinop fails when Atom.destBinop does, and isBinop tests
   whether destBinop succeeds. *)

val mkBinop : Metis.Atom.relationName -> polarity * Metis.Term.term * Metis.Term.term -> literal

val destBinop : Metis.Atom.relationName -> literal -> polarity * Metis.Term.term * Metis.Term.term

val isBinop : Metis.Atom.relationName -> literal -> bool

(* Formulas *)

(* A positive literal maps to Formula.Atom, a negative one to
   Formula.Not (Formula.Atom _); fromFormula is the inverse and raises Error
   on any other formula. *)

val toFormula : literal -> Metis.Formula.formula

val fromFormula : Metis.Formula.formula -> literal

(* ------------------------------------------------------------------------- *)
(* The size of a literal in symbols.                                         *)
(* ------------------------------------------------------------------------- *)

(* The polarity does not contribute to the size. *)
val symbols : literal -> int

(* ------------------------------------------------------------------------- *)
(* A total comparison function for literals.                                 *)
(* ------------------------------------------------------------------------- *)

(* Polarity is compared first, then the atom. *)
val compare : literal * literal -> order  (* negative < positive *)

val equal : literal -> literal -> bool

(* ------------------------------------------------------------------------- *)
(* Subterms.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* These act on the literal's atom. *)

val subterm : literal -> Metis.Term.path -> Metis.Term.term

val subterms : literal -> (Metis.Term.path * Metis.Term.term) list

(* Returns the original literal when the atom is physically unchanged. *)
val replace : literal -> Metis.Term.path * Metis.Term.term -> literal

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

val freeIn : Metis.Term.var -> literal -> bool

val freeVars : literal -> Metis.NameSet.set

(* ------------------------------------------------------------------------- *)
(* Substitutions.                                                            *)
(* ------------------------------------------------------------------------- *)

(* Returns the original literal when the atom is physically unchanged. *)
val subst : Metis.Subst.subst -> literal -> literal

(* ------------------------------------------------------------------------- *)
(* Matching.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Fails if the polarities differ; otherwise matches the atoms, extending
   the given substitution. *)
val match :  (* raises Error *)
    Metis.Subst.subst -> literal -> literal -> Metis.Subst.subst

(* ------------------------------------------------------------------------- *)
(* Unification.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Fails if the polarities differ; otherwise unifies the atoms, extending
   the given substitution. *)
val unify :  (* raises Error *)
    Metis.Subst.subst -> literal -> literal -> Metis.Subst.subst

(* ------------------------------------------------------------------------- *)
(* The equality relation.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Eq/Refl literals are positive; Neq/Irrefl literals are negative.  Each
   dest function raises Error on the wrong polarity. *)

val mkEq : Metis.Term.term * Metis.Term.term -> literal

val destEq : literal -> Metis.Term.term * Metis.Term.term

val isEq : literal -> bool

val mkNeq : Metis.Term.term * Metis.Term.term -> literal

val destNeq : literal -> Metis.Term.term * Metis.Term.term

val isNeq : literal -> bool

val mkRefl : Metis.Term.term -> literal

val destRefl : literal -> Metis.Term.term

val isRefl : literal -> bool

val mkIrrefl : Metis.Term.term -> literal

val destIrrefl : literal -> Metis.Term.term

val isIrrefl : literal -> bool

(* The following work with both equalities and disequalities *)

val sym : literal -> literal  (* raises Error if given a refl or irrefl *)

val lhs : literal -> Metis.Term.term

val rhs : literal -> Metis.Term.term

(* ------------------------------------------------------------------------- *)
(* Special support for terms with type annotations.                          *)
(* ------------------------------------------------------------------------- *)

val typedSymbols : literal -> int

val nonVarTypedSubterms : literal -> (Metis.Term.path * Metis.Term.term) list

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty-printing.                                              *)
(* ------------------------------------------------------------------------- *)

(* Printing and parsing go through the formula representation. *)

val pp : literal Metis.Print.pp

val toString : literal -> string

val fromString : string -> literal

val parse : Metis.Term.term Metis.Parse.quotation -> literal

end

(**** Original file: Literal.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* FIRST ORDER LOGIC LITERALS                                                *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Literal :> Literal =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type for storing first order logic literals.                            *)
(* ------------------------------------------------------------------------- *)

(* true for a positive literal, false for a negated one (see toFormula). *)
type polarity = bool;

(* A literal pairs its polarity with an atom. *)
type literal = polarity * Atom.atom;

(* ------------------------------------------------------------------------- *)
(* Constructors and destructors.                                             *)
(* ------------------------------------------------------------------------- *)

(* Projections of the two components of a literal. *)
fun polarity (lit : literal) = #1 lit;

fun atom (lit : literal) = #2 lit;

(* The following delegate to Atom on the literal's atom. *)
fun name ((_,atm) : literal) = Atom.name atm;

fun arguments ((_,atm) : literal) = Atom.arguments atm;

fun arity ((_,atm) : literal) = Atom.arity atm;

fun positive ((pol,_) : literal) = pol;

fun negative ((pol,_) : literal) = not pol;

(* Flip the polarity, keeping the atom. *)
fun negate ((pol,atm) : literal) : literal = (not pol, atm);

fun relation ((_,atm) : literal) = Atom.relation atm;

fun functions ((_,atm) : literal) = Atom.functions atm;

fun functionNames ((_,atm) : literal) = Atom.functionNames atm;

(* Binary relations *)

(* A literal of the given polarity whose atom is rel applied to a and b. *)
fun mkBinop rel (pol,a,b) : literal =
    let
      val atm = Atom.mkBinop rel (a,b)
    in
      (pol,atm)
    end;

(* Inverse of mkBinop; fails when Atom.destBinop fails on the atom. *)
fun destBinop rel ((pol,atm) : literal) =
    let
      val (a,b) = Atom.destBinop rel atm
    in
      (pol,a,b)
    end;

fun isBinop rel lit = can (destBinop rel) lit;

(* Formulas *)

(* Positive literals become atoms; negative literals become negated atoms. *)
fun toFormula (pol,atm) =
    let
      val fm = Formula.Atom atm
    in
      if pol then fm else Formula.Not fm
    end;

(* Inverse of toFormula; raises Error on any other formula shape. *)
fun fromFormula fm =
    case fm of
      Formula.Atom atm => (true,atm)
    | Formula.Not (Formula.Atom atm) => (false,atm)
    | _ => raise Error "Literal.fromFormula";

(* ------------------------------------------------------------------------- *)
(* The size of a literal in symbols.                                         *)
(* ------------------------------------------------------------------------- *)

fun symbols (lit : literal) = Atom.symbols (#2 lit);

(* ------------------------------------------------------------------------- *)
(* A total comparison function for literals.                                 *)
(* ------------------------------------------------------------------------- *)

(* Order by polarity first, then by atom. *)
val compare = prodCompare boolCompare Atom.compare;

fun equal (pol1,atm1) (pol2,atm2) =
    if pol1 <> pol2 then false else Atom.equal atm1 atm2;

(* ------------------------------------------------------------------------- *)
(* Subterms.                                                                 *)
(* ------------------------------------------------------------------------- *)

fun subterm ((_,atm) : literal) path = Atom.subterm atm path;

fun subterms ((_,atm) : literal) = Atom.subterms atm;

(* Replace the subterm at a path in the atom, returning the original literal
   when the atom comes back physically unchanged. *)
fun replace lit path_tm =
    let
      val (pol,atm) = lit
      val newAtm = Atom.replace atm path_tm
      val unchanged = Portable.pointerEqual (atm,newAtm)
    in
      if unchanged then lit else (pol,newAtm)
    end;

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

fun freeIn v ((_,atm) : literal) = Atom.freeIn v atm;

fun freeVars ((_,atm) : literal) = Atom.freeVars atm;

(* ------------------------------------------------------------------------- *)
(* Substitutions.                                                            *)
(* ------------------------------------------------------------------------- *)

(* Apply sub to the atom, returning the original literal when the atom comes
   back physically unchanged. *)
fun subst sub lit : literal =
    let
      val (pol,atm) = lit
      val substituted = Atom.subst sub atm
    in
      if Portable.pointerEqual (substituted,atm) then lit
      else (pol,substituted)
    end;

(* ------------------------------------------------------------------------- *)
(* Matching.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Match the atoms, provided the polarities agree. *)
fun match sub ((pol1,atm1) : literal) (pol2,atm2) =
    if pol1 <> pol2 then raise Error "Literal.match"
    else Atom.match sub atm1 atm2;

(* ------------------------------------------------------------------------- *)
(* Unification.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Unify the atoms, provided the polarities agree. *)
fun unify sub ((pol1,atm1) : literal) (pol2,atm2) =
    if pol1 <> pol2 then raise Error "Literal.unify"
    else Atom.unify sub atm1 atm2;

(* ------------------------------------------------------------------------- *)
(* The equality relation.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Equalities are positive literals; disequalities are negative ones. *)
fun mkEq (l,r) : literal = (true, Atom.mkEq (l,r));

fun destEq ((pol,atm) : literal) =
    if pol then Atom.destEq atm else raise Error "Literal.destEq";

val isEq = can destEq;

fun mkNeq (l,r) : literal = (false, Atom.mkEq (l,r));

fun destNeq ((pol,atm) : literal) =
    if pol then raise Error "Literal.destNeq" else Atom.destEq atm;

val isNeq = can destNeq;

fun mkRefl tm : literal = (true, Atom.mkRefl tm);

fun destRefl ((pol,atm) : literal) =
    if pol then Atom.destRefl atm else raise Error "Literal.destRefl";

val isRefl = can destRefl;

fun mkIrrefl tm : literal = (false, Atom.mkRefl tm);

fun destIrrefl ((pol,atm) : literal) =
    if pol then raise Error "Literal.destIrrefl" else Atom.destRefl atm;

val isIrrefl = can destIrrefl;

(* These ignore the polarity, so they work on equalities and disequalities
   alike. *)
fun sym ((pol,atm) : literal) : literal = (pol, Atom.sym atm);

fun lhs (lit : literal) = Atom.lhs (#2 lit);

fun rhs (lit : literal) = Atom.rhs (#2 lit);

(* ------------------------------------------------------------------------- *)
(* Special support for terms with type annotations.                          *)
(* ------------------------------------------------------------------------- *)

fun typedSymbols (lit : literal) = Atom.typedSymbols (#2 lit);

fun nonVarTypedSubterms (lit : literal) = Atom.nonVarTypedSubterms (#2 lit);

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty-printing.                                              *)
(* ------------------------------------------------------------------------- *)

(* Literals are printed and parsed through their formula representation. *)
val pp = Print.ppMap toFormula Formula.pp;

val toString = Print.toString pp;

val fromString = fromFormula o Formula.fromString;

val parse = Parse.parseQuotation Term.toString fromString;

end

(* Literals ordered by Literal.compare, for use as map keys. *)
structure LiteralOrdered =
struct
  type t = Literal.literal

  val compare = Literal.compare
end

structure LiteralMap = KeyMap (LiteralOrdered);

structure LiteralSet =
struct

  (* Generic ordered sets whose elements are keyed through LiteralMap    *)
  (* (i.e. Literal.compare); the definitions below add literal-specific *)
  (* operations on top.                                                 *)
  local
    structure S = ElementSet (LiteralMap);
  in
    open S;
  end;

  (* Does the set contain the negation of lit? *)
  fun negateMember lit = member (Literal.negate lit);

  (* The set of negations of all the literals. *)
  fun negate lits =
      foldl (fn (lit,acc) => add acc (Literal.negate lit)) empty lits;

  (* The relation name/arity of every literal. *)
  fun relations lits =
      foldl (fn (lit,acc) => NameAritySet.add acc (Literal.relation lit))
        NameAritySet.empty lits;

  (* Every function name/arity occurring in the literals. *)
  fun functions lits =
      foldl (fn (lit,acc) => NameAritySet.union acc (Literal.functions lit))
        NameAritySet.empty lits;

  (* Does variable v occur free in some literal? *)
  fun freeIn v lits = exists (Literal.freeIn v) lits;

  (* Union of the free variables of all the literals. *)
  fun freeVars lits =
      foldl (fn (lit,acc) => NameSet.union acc (Literal.freeVars lit))
        NameSet.empty lits;

  (* Union of the free variables across a list of literal sets. *)
  fun freeVarsList litSets =
      List.foldl (fn (lits,acc) => NameSet.union acc (freeVars lits))
        NameSet.empty litSets;

  (* Sum of Literal.symbols over the literals. *)
  fun symbols lits = foldl (fn (lit,n) => Literal.symbols lit + n) 0 lits;

  (* Sum of Literal.typedSymbols over the literals. *)
  fun typedSymbols lits =
      foldl (fn (lit,n) => Literal.typedSymbols lit + n) 0 lits;

  (* Apply sub to every literal. If every literal comes back physically *)
  (* identical (Portable.pointerEqual), the input set itself is returned *)
  (* so that callers can detect "no change" by pointer equality too.    *)
  fun subst sub lits =
      let
        fun step (lit,(unchanged,acc)) =
            let
              val lit' = Literal.subst sub lit
            in
              (unchanged andalso Portable.pointerEqual (lit,lit'),
               add acc lit')
            end

        val (unchanged,lits') = foldl step (true,empty) lits
      in
        if unchanged then lits else lits'
      end;

  (* The literals as a single conjunction formula. *)
  fun conjoin set =
      let
        val fms = List.map Literal.toFormula (toList set)
      in
        Formula.listMkConj fms
      end;

  (* The literals as a single disjunction formula. *)
  fun disjoin set =
      let
        val fms = List.map Literal.toFormula (toList set)
      in
        Formula.listMkDisj fms
      end;

  (* Print the literals in braces, separated by commas. *)
  val pp =
      Print.ppMap toList
        (Print.ppBracket "{" "}" (Print.ppOpList "," Literal.pp));

end

(* Literal sets (clauses) under LiteralSet.compare, giving finite maps keyed *)
(* by clauses and sets of clauses.                                          *)
structure LiteralSetOrdered =
struct
  type t = LiteralSet.set
  val compare = LiteralSet.compare
end

structure LiteralSetMap = KeyMap (LiteralSetOrdered);

structure LiteralSetSet = ElementSet (LiteralSetMap);
end;

(**** Original file: Thm.sig ****)

(* ========================================================================= *)
(* A LOGICAL KERNEL FOR FIRST ORDER CLAUSAL THEOREMS                         *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

signature Thm =
sig

(* ------------------------------------------------------------------------- *)
(* An abstract type of first order logic theorems.                           *)
(* ------------------------------------------------------------------------- *)

(* Abstract: the only functions in this signature that return a thm are    *)
(* the primitive rules of inference at the end (axiom through equality).   *)
type thm

(* ------------------------------------------------------------------------- *)
(* Theorem destructors.                                                      *)
(* ------------------------------------------------------------------------- *)

(* The literals of a theorem, read as a disjunction. *)
type clause = Metis.LiteralSet.set

(* Tag naming the rule that produced a theorem.                            *)
(* NOTE(review): none of the primitive rules below produces Factor.        *)
datatype inferenceType =
    Axiom
  | Assume
  | Subst
  | Factor
  | Resolve
  | Refl
  | Equality

(* The rule tag together with the parent theorems it was applied to. *)
type inference = inferenceType * thm list

val clause : thm -> clause

val inference : thm -> inference

(* Tautologies *)

(* True if the clause has a positive reflexive literal, or mentions the    *)
(* same atom twice (presumably as both L and ~L, since clauses are sets).  *)
val isTautology : thm -> bool

(* Contradictions *)

(* True exactly for the empty clause. *)
val isContradiction : thm -> bool

(* Unit theorems *)

(* destUnit raises Error unless the clause has exactly one literal. *)
val destUnit : thm -> Metis.Literal.literal

val isUnit : thm -> bool

(* Unit equality theorems *)

(* destUnitEq raises Error unless the clause is a single literal accepted  *)
(* by Literal.destEq.                                                      *)
val destUnitEq : thm -> Metis.Term.term * Metis.Term.term

val isUnitEq : thm -> bool

(* Literals *)

val member : Metis.Literal.literal -> thm -> bool

val negateMember : Metis.Literal.literal -> thm -> bool

(* ------------------------------------------------------------------------- *)
(* A total order.                                                            *)
(* ------------------------------------------------------------------------- *)

(* Theorems are ordered by their clauses alone; the recorded inference is  *)
(* ignored.                                                                *)
val compare : thm * thm -> order

val equal : thm -> thm -> bool

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

val freeIn : Metis.Term.var -> thm -> bool

val freeVars : thm -> Metis.NameSet.set

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

val ppInferenceType : inferenceType Metis.Print.pp

val inferenceTypeToString : inferenceType -> string

val pp : thm Metis.Print.pp

val toString : thm -> string

(* ------------------------------------------------------------------------- *)
(* Primitive rules of inference.                                             *)
(* ------------------------------------------------------------------------- *)

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* ----- axiom C                                                             *)
(*   C                                                                       *)
(* ------------------------------------------------------------------------- *)

val axiom : clause -> thm

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* ----------- assume L                                                      *)
(*   L \/ ~L                                                                 *)
(* ------------------------------------------------------------------------- *)

val assume : Metis.Literal.literal -> thm

(* ------------------------------------------------------------------------- *)
(*    C                                                                      *)
(* -------- subst s                                                          *)
(*   C[s]                                                                    *)
(* ------------------------------------------------------------------------- *)

(* Returns the input theorem itself when the substitution leaves the       *)
(* clause physically unchanged.                                            *)
val subst : Metis.Subst.subst -> thm -> thm

(* ------------------------------------------------------------------------- *)
(*   L \/ C    ~L \/ D                                                       *)
(* --------------------- resolve L                                           *)
(*        C \/ D                                                             *)
(*                                                                           *)
(* The literal L must occur in the first theorem, and the literal ~L must    *)
(* occur in the second theorem.                                              *)
(* ------------------------------------------------------------------------- *)

val resolve : Metis.Literal.literal -> thm -> thm -> thm

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* --------- refl t                                                          *)
(*   t = t                                                                   *)
(* ------------------------------------------------------------------------- *)

val refl : Metis.Term.term -> thm

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* ------------------------ equality L p t                                   *)
(*   ~(s = t) \/ ~L \/ L'                                                    *)
(*                                                                           *)
(* where s is the subterm of L at path p, and L' is L with the subterm at    *)
(* path p being replaced by t.                                               *)
(* ------------------------------------------------------------------------- *)

val equality : Metis.Literal.literal -> Metis.Term.path -> Metis.Term.term -> thm

end

(**** Original file: Thm.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* A LOGICAL KERNEL FOR FIRST ORDER CLAUSAL THEOREMS                         *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Thm :> Thm =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* An abstract type of first order logic theorems.                           *)
(* ------------------------------------------------------------------------- *)

(* The literals of a theorem, read as a disjunction. *)
type clause = LiteralSet.set;

(* Which primitive rule produced a theorem. *)
datatype inferenceType =
    Axiom
  | Assume
  | Subst
  | Factor
  | Resolve
  | Refl
  | Equality;

(* A theorem is its clause plus the rule and parent theorems that derived  *)
(* it. The ascription :> Thm keeps thm abstract, so theorems can only be   *)
(* built by the primitive rules defined at the end of this structure.      *)
datatype thm = Thm of clause * (inferenceType * thm list);

type inference = inferenceType * thm list;

(* ------------------------------------------------------------------------- *)
(* Theorem destructors.                                                      *)
(* ------------------------------------------------------------------------- *)

fun clause (Thm (lits,_)) = lits;

fun inference (Thm (_,derivation)) = derivation;

(* Tautologies *)

local
  (* Fold step over a clause. The accumulator becomes NONE once a tautology *)
  (* is detected and is otherwise SOME of the atoms seen so far. A positive *)
  (* reflexive literal, or an atom seen a second time (presumably with the  *)
  (* opposite polarity, since the clause is a set), marks a tautology.      *)
  fun scan (_,NONE) = NONE
    | scan ((positive,atm), SOME seen) =
      let
        val reflexive = positive andalso Atom.isRefl atm
      in
        if reflexive orelse AtomSet.member atm seen then NONE
        else SOME (AtomSet.add seen atm)
      end;
in
  fun isTautology th =
      case LiteralSet.foldl scan (SOME AtomSet.empty) (clause th) of
        NONE => true
      | SOME _ => false;
end;

(* Contradictions *)

(* The empty clause is a contradiction. *)
fun isContradiction (Thm (lits,_)) = LiteralSet.null lits;

(* Unit theorems *)

(* The single literal of a one-literal clause; Error for any other size. *)
fun destUnit (Thm (lits,_)) =
    if LiteralSet.size lits <> 1 then raise Error "Thm.destUnit"
    else LiteralSet.pick lits;

val isUnit = can destUnit;

(* Unit equality theorems *)

val destUnitEq = Literal.destEq o destUnit;

val isUnitEq = can destUnitEq;

(* Literals *)

fun member lit th = LiteralSet.member lit (clause th);

fun negateMember lit th = LiteralSet.negateMember lit (clause th);

(* ------------------------------------------------------------------------- *)
(* A total order.                                                            *)
(* ------------------------------------------------------------------------- *)

(* Only the clauses are compared; the recorded inference is ignored. *)
fun compare (Thm (lits1,_), Thm (lits2,_)) = LiteralSet.compare (lits1,lits2);

fun equal (Thm (lits1,_)) (Thm (lits2,_)) = LiteralSet.equal lits1 lits2;

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

fun freeIn v th = LiteralSet.freeIn v (clause th);

fun freeVars th = LiteralSet.freeVars (clause th);

(* ------------------------------------------------------------------------- *)
(* Pretty-printing.                                                          *)
(* ------------------------------------------------------------------------- *)

fun inferenceTypeToString infType =
    case infType of
      Axiom => "Axiom"
    | Assume => "Assume"
    | Subst => "Subst"
    | Factor => "Factor"
    | Resolve => "Resolve"
    | Refl => "Refl"
    | Equality => "Equality";

fun ppInferenceType infType = Print.ppString (inferenceTypeToString infType);

local
  (* The clause rendered as one disjunctive formula. *)
  fun disjunction (Thm (lits,_)) =
      Formula.listMkDisj (List.map Literal.toFormula (LiteralSet.toList lits));
in
  (* Print a theorem as a turnstile followed by its clause. *)
  fun pp th =
      Print.blockProgram Print.Inconsistent 3
        [Print.addString "|- ", Formula.pp (disjunction th)];
end;

val toString = Print.toString pp;

(* ------------------------------------------------------------------------- *)
(* Primitive rules of inference.                                             *)
(* ------------------------------------------------------------------------- *)

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* ----- axiom C                                                             *)
(*   C                                                                       *)
(* ------------------------------------------------------------------------- *)

fun axiom cl = Thm (cl, (Axiom,[]));

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* ----------- assume L                                                      *)
(*   L \/ ~L                                                                 *)
(* ------------------------------------------------------------------------- *)

fun assume lit =
    let
      val lits = LiteralSet.fromList [lit, Literal.negate lit]
    in
      Thm (lits, (Assume,[]))
    end;

(* ------------------------------------------------------------------------- *)
(*    C                                                                      *)
(* -------- subst s                                                          *)
(*   C[s]                                                                    *)
(* ------------------------------------------------------------------------- *)

fun subst sub (th as Thm (cl,inf)) =
    let
      val cl' = LiteralSet.subst sub cl
    in
      (* LiteralSet.subst hands back the same set when nothing changed. *)
      if Portable.pointerEqual (cl,cl') then th
      else
        let
          (* Collapse chains of substitutions: a theorem that is already a *)
          (* Subst result keeps its recorded parent instead of gaining     *)
          (* another Subst layer.                                          *)
          val inf' = if #1 inf = Subst then inf else (Subst,[th])
        in
          Thm (cl',inf')
        end
    end;

(* ------------------------------------------------------------------------- *)
(*   L \/ C    ~L \/ D                                                       *)
(* --------------------- resolve L                                           *)
(*        C \/ D                                                             *)
(*                                                                           *)
(* The literal L must occur in the first theorem, and the literal ~L must    *)
(* occur in the second theorem.                                              *)
(* ------------------------------------------------------------------------- *)

fun resolve lit (th1 as Thm (cl1,_)) (th2 as Thm (cl2,_)) =
    let
      val rest1 = LiteralSet.delete cl1 lit
      val rest2 = LiteralSet.delete cl2 (Literal.negate lit)
      val resolvent = LiteralSet.union rest1 rest2
    in
      Thm (resolvent, (Resolve,[th1,th2]))
    end;

(*MetisDebug
val resolve = fn lit => fn pos => fn neg =>
    resolve lit pos neg
    handle Error err =>
      raise Error ("Thm.resolve:\nlit = " ^ Literal.toString lit ^
                   "\npos = " ^ toString pos ^
                   "\nneg = " ^ toString neg ^ "\n" ^ err);
*)

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* --------- refl t                                                          *)
(*   t = t                                                                   *)
(* ------------------------------------------------------------------------- *)

fun refl tm =
    let
      val reflLit = (true, Atom.mkRefl tm)
    in
      Thm (LiteralSet.singleton reflLit, (Refl,[]))
    end;

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* ------------------------ equality L p t                                   *)
(*   ~(s = t) \/ ~L \/ L'                                                    *)
(*                                                                           *)
(* where s is the subterm of L at path p, and L' is L with the subterm at    *)
(* path p being replaced by t.                                               *)
(* ------------------------------------------------------------------------- *)

fun equality lit path t =
    let
      val s = Literal.subterm lit path
      val rewritten = Literal.replace lit (path,t)
      val cl =
          LiteralSet.fromList
            [Literal.mkNeq (s,t), Literal.negate lit, rewritten]
    in
      Thm (cl, (Equality,[]))
    end;

end
end;

(**** Original file: Proof.sig ****)

(* ========================================================================= *)
(* PROOFS IN FIRST ORDER LOGIC                                               *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Proof =
sig

(* ------------------------------------------------------------------------- *)
(* A type of first order logic proofs.                                       *)
(* ------------------------------------------------------------------------- *)

(* One proof step. Assume and Resolve record the atom of the positive      *)
(* literal: Assume a stands for a \/ ~a, and Resolve (a,pos,neg) resolves  *)
(* on a, which occurs positively in pos and negatively in neg.             *)
datatype inference =
    Axiom of Metis.LiteralSet.set
  | Assume of Metis.Atom.atom
  | Subst of Metis.Subst.subst * Metis.Thm.thm
  | Resolve of Metis.Atom.atom * Metis.Thm.thm * Metis.Thm.thm
  | Refl of Metis.Term.term
  | Equality of Metis.Literal.literal * Metis.Term.path * Metis.Term.term

(* Steps, each paired with the theorem it proves. *)
type proof = (Metis.Thm.thm * inference) list

(* ------------------------------------------------------------------------- *)
(* Reconstructing single inferences.                                         *)
(* ------------------------------------------------------------------------- *)

(* The kernel rule tag of a step. *)
val inferenceType : inference -> Metis.Thm.inferenceType

(* The parent theorems of a step; only Subst and Resolve have any. *)
val parents : inference -> Metis.Thm.thm list

(* Replay a step through the Thm kernel. *)
val inferenceToThm : inference -> Metis.Thm.thm

(* Recover a step from a kernel theorem's clause and recorded inference;   *)
(* raises Bug when the recorded inference cannot be reconstructed.         *)
val thmToInference : Metis.Thm.thm -> inference

(* ------------------------------------------------------------------------- *)
(* Reconstructing whole proofs.                                              *)
(* ------------------------------------------------------------------------- *)

(* All steps needed to derive the theorem, parents before children, with   *)
(* at most one step per distinct clause.                                   *)
val proof : Metis.Thm.thm -> proof

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

(* Variables free in the axioms, assumptions, substitution results, refl   *)
(* terms and equality steps of a proof; Resolve steps contribute nothing.  *)
val freeIn : Metis.Term.var -> proof -> bool

val freeVars : proof -> Metis.NameSet.set

(* ------------------------------------------------------------------------- *)
(* Printing.                                                                 *)
(* ------------------------------------------------------------------------- *)

val ppInference : inference Metis.Print.pp

val inferenceToString : inference -> string

val pp : proof Metis.Print.pp

val toString : proof -> string

end

(**** Original file: Proof.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* PROOFS IN FIRST ORDER LOGIC                                               *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Proof :> Proof =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of first order logic proofs.                                       *)
(* ------------------------------------------------------------------------- *)

(* One proof step with enough data to replay it through the Thm kernel.    *)
(* Assume and Resolve store the atom of the positive literal involved      *)
(* (see inferenceToThm below).                                             *)
datatype inference =
    Axiom of LiteralSet.set
  | Assume of Atom.atom
  | Subst of Subst.subst * Thm.thm
  | Resolve of Atom.atom * Thm.thm * Thm.thm
  | Refl of Term.term
  | Equality of Literal.literal * Term.path * Term.term;

(* Steps, each paired with the theorem it proves. *)
type proof = (Thm.thm * inference) list;

(* ------------------------------------------------------------------------- *)
(* Inference types and printing.                                             *)
(* ------------------------------------------------------------------------- *)

(* The kernel rule tag of a proof step. *)
fun inferenceType (Axiom _) = Thm.Axiom
  | inferenceType (Assume _) = Thm.Assume
  | inferenceType (Subst _) = Thm.Subst
  | inferenceType (Resolve _) = Thm.Resolve
  | inferenceType (Refl _) = Thm.Refl
  | inferenceType (Equality _) = Thm.Equality;

local
  (* Each payload printer below starts with a break of width 1, so that it *)
  (* can directly follow the inference type name printed by ppInf.         *)
  fun ppAssume atm = Print.sequence (Print.addBreak 1) (Atom.pp atm);

  fun ppSubst ppThm (sub,thm) =
      Print.sequence (Print.addBreak 1)
        (Print.blockProgram Print.Inconsistent 1
           [Print.addString "{",
            Print.ppOp2 " =" Print.ppString Subst.pp ("sub",sub),
            Print.addString ",",
            Print.addBreak 1,
            Print.ppOp2 " =" Print.ppString ppThm ("thm",thm),
            Print.addString "}"]);

  fun ppResolve ppThm (res,pos,neg) =
      Print.sequence (Print.addBreak 1)
        (Print.blockProgram Print.Inconsistent 1
           [Print.addString "{",
            Print.ppOp2 " =" Print.ppString Atom.pp ("res",res),
            Print.addString ",",
            Print.addBreak 1,
            Print.ppOp2 " =" Print.ppString ppThm ("pos",pos),
            Print.addString ",",
            Print.addBreak 1,
            Print.ppOp2 " =" Print.ppString ppThm ("neg",neg),
            Print.addString "}"]);

  fun ppRefl tm = Print.sequence (Print.addBreak 1) (Term.pp tm);

  fun ppEquality (lit,path,res) =
      Print.sequence (Print.addBreak 1)
        (Print.blockProgram Print.Inconsistent 1
           [Print.addString "{",
            Print.ppOp2 " =" Print.ppString Literal.pp ("lit",lit),
            Print.addString ",",
            Print.addBreak 1,
            Print.ppOp2 " =" Print.ppString Term.ppPath ("path",path),
            Print.addString ",",
            Print.addBreak 1,
            Print.ppOp2 " =" Print.ppString Term.pp ("res",res),
            Print.addString "}"]);

  (* Print the inference type name followed by its payload. Axiom clauses  *)
  (* and parent theorems go through the supplied ppAxiom and ppThm, so a   *)
  (* whole-proof printer can refer to parents by step number instead.      *)
  fun ppInf ppAxiom ppThm inf =
      let
        val infString = Thm.inferenceTypeToString (inferenceType inf)
      in
        Print.block Print.Inconsistent 2
          (Print.sequence
             (Print.addString infString)
             (case inf of
                Axiom cl => ppAxiom cl
              | Assume x => ppAssume x
              | Subst x => ppSubst ppThm x
              | Resolve x => ppResolve ppThm x
              | Refl x => ppRefl x
              | Equality x => ppEquality x))
      end;

  fun ppAxiom cl =
      Print.sequence
        (Print.addBreak 1)
        (Print.ppMap
           LiteralSet.toList
           (Print.ppBracket "{" "}" (Print.ppOpList "," Literal.pp)) cl);
in
  val ppInference = ppInf ppAxiom Thm.pp;

  (* Print a whole proof, one numbered step per line. A parent theorem is  *)
  (* printed as the number of the first step whose clause equals its own,  *)
  (* or as (?) when no step matches. Axiom clauses are not repeated in the *)
  (* inference annotation since the step's theorem already shows them.     *)
  fun pp prf =
      let
        fun thmString n = "(" ^ Int.toString n ^ ")"

        val prf = enumerate prf

        fun ppThm th =
            Print.addString
            let
              val cl = Thm.clause th

              fun pred (_,(th',_)) = LiteralSet.equal (Thm.clause th') cl
            in
              case List.find pred prf of
                NONE => "(?)"
              | SOME (n,_) => thmString n
            end

        fun ppStep (n,(th,inf)) =
            let
              val s = thmString n
            in
              Print.sequence
                (Print.blockProgram Print.Consistent (1 + size s)
                   [Print.addString (s ^ " "),
                    Thm.pp th,
                    Print.addBreak 2,
                    Print.ppBracket "[" "]" (ppInf (K Print.skip) ppThm) inf])
                Print.addNewline
            end
      in
        Print.blockProgram Print.Consistent 0
          [Print.addString "START OF PROOF",
           Print.addNewline,
           Print.program (map ppStep prf),
           Print.addString "END OF PROOF"]
      end
(*MetisDebug
      handle Error err => raise Bug ("Proof.pp: shouldn't fail:\n" ^ err);
*)

end;

val inferenceToString = Print.toString ppInference;

val toString = Print.toString pp;

(* ------------------------------------------------------------------------- *)
(* Reconstructing single inferences.                                         *)
(* ------------------------------------------------------------------------- *)

fun parents (Axiom _) = []
  | parents (Assume _) = []
  | parents (Subst (_,th)) = [th]
  | parents (Resolve (_,th,th')) = [th,th']
  | parents (Refl _) = []
  | parents (Equality _) = [];

(* Replay a step through the kernel; the stored atoms of Assume and        *)
(* Resolve are used with positive polarity.                                *)
fun inferenceToThm (Axiom cl) = Thm.axiom cl
  | inferenceToThm (Assume atm) = Thm.assume (true,atm)
  | inferenceToThm (Subst (sub,th)) = Thm.subst sub th
  | inferenceToThm (Resolve (atm,th,th')) = Thm.resolve (true,atm) th th'
  | inferenceToThm (Refl tm) = Thm.refl tm
  | inferenceToThm (Equality (lit,path,r)) = Thm.equality lit path r;

local
  (* Find a substitution mapping clause cl onto clause cl'. Each search    *)
  (* state pairs the literals of cl still to be matched with the           *)
  (* substitution built so far. The next literal is matched against every  *)
  (* literal of cl', and each success pushes a new state onto the front of *)
  (* the agenda (depth-first search). A state with no literals left wins   *)
  (* if applying its substitution to cl gives exactly cl'. Raises Bug when *)
  (* the agenda runs out.                                                  *)
  fun reconstructSubst cl cl' =
      let
        fun recon [] =
            let
(*MetisTrace3
              val () = Print.trace LiteralSet.pp "reconstructSubst: cl" cl
              val () = Print.trace LiteralSet.pp "reconstructSubst: cl'" cl'
*)
            in
              raise Bug "can't reconstruct Subst rule"
            end
          | recon (([],sub) :: others) =
            if LiteralSet.equal (LiteralSet.subst sub cl) cl' then sub
            else recon others
          | recon ((lit :: lits, sub) :: others) =
            let
              fun checkLit (lit',acc) =
                  case total (Literal.match sub lit) lit' of
                    NONE => acc
                  | SOME sub => (lits,sub) :: acc
            in
              recon (LiteralSet.foldl checkLit others cl')
            end
      in
        Subst.normalize (recon [(LiteralSet.toList cl, Subst.empty)])
      end
(*MetisDebug
      handle Error err =>
        raise Bug ("Proof.recontructSubst: shouldn't fail:\n" ^ err);
*)

  (* Recover the literal resolved on, given parent clauses cl1, cl2 and    *)
  (* resolvent cl. It is a literal of cl1 missing from cl, or else the     *)
  (* negation of a literal of cl2 missing from cl. When both parents       *)
  (* survive intact, pick a literal L such that L and ~L both occur in     *)
  (* cl1 and in cl2; raises Bug if there is none.                          *)
  fun reconstructResolvant cl1 cl2 cl =
      (if not (LiteralSet.subset cl1 cl) then
         LiteralSet.pick (LiteralSet.difference cl1 cl)
       else if not (LiteralSet.subset cl2 cl) then
         Literal.negate (LiteralSet.pick (LiteralSet.difference cl2 cl))
       else
         (* A useless resolution, but we must reconstruct it anyway *)
         let
           val cl1' = LiteralSet.negate cl1
           and cl2' = LiteralSet.negate cl2
           val lits = LiteralSet.intersectList [cl1,cl1',cl2,cl2']
         in
           if not (LiteralSet.null lits) then LiteralSet.pick lits
           else raise Bug "can't reconstruct Resolve rule"
         end)
(*MetisDebug
      handle Error err =>
        raise Bug ("Proof.recontructResolvant: shouldn't fail:\n" ^ err);
*)

  (* Recover (L,p,t) from a clause of the form ~(s = t) \/ ~L \/ L'.       *)
  (* Raises Bug if no assignment of roles to the clause's literals works.  *)
  fun reconstructEquality cl =
      let
(*MetisTrace3
        val () = Print.trace LiteralSet.pp "Proof.reconstructEquality: cl" cl
*)

        (* Walk two applications in parallel. At each level they must have *)
        (* the same head and arity and differ in exactly one argument; the *)
        (* walk succeeds with the path to the position where s in the      *)
        (* first is replaced by t in the second.                            *)
        fun sync s t path (f,a) (f',a') =
            if not (Name.equal f f' andalso length a = length a') then NONE
            else
              let
                val itms = enumerate (zip a a')
              in
                case List.filter (not o uncurry Term.equal o snd) itms of
                  [(i,(tm,tm'))] =>
                  let
                    val path = i :: path
                  in
                    if Term.equal tm s andalso Term.equal tm' t then
                      SOME (rev path)
                    else
                      case (tm,tm') of
                        (Term.Fn f_a, Term.Fn f_a') => sync s t path f_a f_a'
                      | _ => NONE
                  end
                | _ => NONE
              end

        (* Try one role assignment (neq, ~L, L'). The two non-equation     *)
        (* literals must have opposite polarities. When s = t the atoms    *)
        (* must coincide and any occurrence of s serves as the path.       *)
        fun recon (neq,(pol,atm),(pol',atm')) =
            if pol = pol' then NONE
            else
              let
                val (s,t) = Literal.destNeq neq

                val path =
                    if not (Term.equal s t) then sync s t [] atm atm'
                    else if not (Atom.equal atm atm') then NONE
                    else Atom.find (Term.equal s) atm
              in
                case path of
                  SOME path => SOME ((pol',atm),path,t)
                | NONE => NONE
              end

        (* Split into negative equations and the rest, then enumerate the  *)
        (* possible role assignments for the observed shapes.              *)
        val candidates =
            case List.partition Literal.isNeq (LiteralSet.toList cl) of
              ([l1],[l2,l3]) => [(l1,l2,l3),(l1,l3,l2)]
            | ([l1,l2],[l3]) => [(l1,l2,l3),(l1,l3,l2),(l2,l1,l3),(l2,l3,l1)]
            | ([l1],[l2]) => [(l1,l1,l2),(l1,l2,l1)]
            | _ => raise Bug "reconstructEquality: malformed"

(*MetisTrace3
        val ppCands =
            Print.ppList (Print.ppTriple Literal.pp Literal.pp Literal.pp)
        val () = Print.trace ppCands
                   "Proof.reconstructEquality: candidates" candidates
*)
      in
        case first recon candidates of
          SOME info => info
        | NONE => raise Bug "can't reconstruct Equality rule"
      end
(*MetisDebug
      handle Error err =>
        raise Bug ("Proof.recontructEquality: shouldn't fail:\n" ^ err);
*)

  (* Rebuild a proof step from a theorem's clause and its recorded kernel  *)
  (* inference. Any other tag/parents combination, including Factor,       *)
  (* raises Bug.                                                           *)
  fun reconstruct cl (Thm.Axiom,[]) = Axiom cl
    | reconstruct cl (Thm.Assume,[]) =
      (case LiteralSet.findl Literal.positive cl of
         SOME (_,atm) => Assume atm
       | NONE => raise Bug "malformed Assume inference")
    | reconstruct cl (Thm.Subst,[th]) =
      Subst (reconstructSubst (Thm.clause th) cl, th)
    | reconstruct cl (Thm.Resolve,[th1,th2]) =
      let
        val cl1 = Thm.clause th1
        and cl2 = Thm.clause th2
        val (pol,atm) = reconstructResolvant cl1 cl2 cl
      in
        (* Order the parents so that the first holds atm positively. *)
        if pol then Resolve (atm,th1,th2) else Resolve (atm,th2,th1)
      end
    | reconstruct cl (Thm.Refl,[]) =
      (case LiteralSet.findl (K true) cl of
         SOME lit => Refl (Literal.destRefl lit)
       | NONE => raise Bug "malformed Refl inference")
    | reconstruct cl (Thm.Equality,[]) = Equality (reconstructEquality cl)
    | reconstruct _ _ = raise Bug "malformed inference";
in
  fun thmToInference th =
      let
(*MetisTrace3
        val () = Print.trace Thm.pp "Proof.thmToInference: th" th
*)

        val cl = Thm.clause th

        val thmInf = Thm.inference th

(*MetisTrace3
        val ppThmInf = Print.ppPair Thm.ppInferenceType (Print.ppList Thm.pp)
        val () = Print.trace ppThmInf "Proof.thmToInference: thmInf" thmInf
*)

        val inf = reconstruct cl thmInf

(*MetisTrace3
        val () = Print.trace ppInference "Proof.thmToInference: inf" inf
*)
(*MetisDebug
        val () =
            let
              val th' = inferenceToThm inf
            in
              if LiteralSet.equal (Thm.clause th') cl then ()
              else
                raise
                  Bug
                    ("Proof.thmToInference: bad inference reconstruction:" ^
                     "\n  th = " ^ Thm.toString th ^
                     "\n  inf = " ^ inferenceToString inf ^
                     "\n  inf th = " ^ Thm.toString th')
            end
*)
      in
        inf
      end
(*MetisDebug
      handle Error err =>
        raise Bug ("Proof.thmToInference: shouldn't fail:\n" ^ err);
*)
end;

(* ------------------------------------------------------------------------- *)
(* Reconstructing whole proofs.                                              *)
(* ------------------------------------------------------------------------- *)

local
  val emptyThms : Thm.thm LiteralSetMap.map = LiteralSetMap.new ();

  (* Map each clause in the derivation of th to one theorem with that      *)
  (* clause. Parents are visited before the theorem itself, so when        *)
  (* several theorems share a clause the one inserted first is kept.       *)
  fun addThms (th,ths) =
      let
        val cl = Thm.clause th
      in
        if LiteralSetMap.inDomain cl ths then ths
        else
          let
            val (_,pars) = Thm.inference th
            val ths = List.foldl addThms ths pars
          in
            if LiteralSetMap.inDomain cl ths then ths
            else LiteralSetMap.insert ths (cl,th)
          end
      end;

  fun mkThms th = addThms (th,emptyThms);

  (* Prepend proof steps to acc, parents first, using the theorem stored   *)
  (* in ths for each clause in place of th. A clause is deleted from ths   *)
  (* once its step is emitted, so each clause yields at most one step.     *)
  fun addProof (th,(ths,acc)) =
      let
        val cl = Thm.clause th
      in
        case LiteralSetMap.peek ths cl of
          NONE => (ths,acc)
        | SOME th =>
          let
            val (_,pars) = Thm.inference th
            val (ths,acc) = List.foldl addProof (ths,acc) pars
            val ths = LiteralSetMap.delete ths cl
            val acc = (th, thmToInference th) :: acc
          in
            (ths,acc)
          end
      end;

  (* acc is built newest-first, so reversing it puts parents first. *)
  fun mkProof ths th =
      let
        val (ths,acc) = addProof (th,(ths,[]))
(*MetisTrace4
        val () = Print.trace Print.ppInt "Proof.proof: unnecessary clauses" (LiteralSetMap.size ths)
*)
      in
        rev acc
      end;
in
  fun proof th =
      let
(*MetisTrace3
        val () = Print.trace Thm.pp "Proof.proof: th" th
*)
        val ths = mkThms th
        val infs = mkProof ths th
(*MetisTrace3
        val () = Print.trace Print.ppInt "Proof.proof: size" (length infs)
*)
      in
        infs
      end;
end;

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

(* A Subst step is checked through its resulting theorem; Resolve steps    *)
(* contribute no variables here.                                           *)
fun freeIn v =
    let
      fun free th_inf =
          case th_inf of
            (_, Axiom lits) => LiteralSet.freeIn v lits
          | (_, Assume atm) => Atom.freeIn v atm
          | (th, Subst _) => Thm.freeIn v th
          | (_, Resolve _) => false
          | (_, Refl tm) => Term.freeIn v tm
          | (_, Equality (lit,_,tm)) =>
            Literal.freeIn v lit orelse Term.freeIn v tm
    in
      List.exists free
    end;

(* Union over the steps, following the same per-step rules as freeIn. *)
val freeVars =
    let
      fun inc (th_inf,set) =
          NameSet.union set
          (case th_inf of
             (_, Axiom lits) => LiteralSet.freeVars lits
           | (_, Assume atm) => Atom.freeVars atm
           | (th, Subst _) => Thm.freeVars th
           | (_, Resolve _) => NameSet.empty
           | (_, Refl tm) => Term.freeVars tm
           | (_, Equality (lit,_,tm)) =>
             NameSet.union (Literal.freeVars lit) (Term.freeVars tm))
    in
      List.foldl inc NameSet.empty
    end;

end
end;

(**** Original file: Rule.sig ****)

(* ========================================================================= *)
(* DERIVED RULES FOR CREATING FIRST ORDER LOGIC THEOREMS                     *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Rule =
sig

(* ------------------------------------------------------------------------- *)
(* An equation consists of two terms (t,u) plus a theorem (stronger than)    *)
(* t = u \/ C.                                                               *)
(* ------------------------------------------------------------------------- *)

type equation = (Metis.Term.term * Metis.Term.term) * Metis.Thm.thm

val ppEquation : equation Metis.Print.pp

val equationToString : equation -> string

(* Returns t = u if the equation theorem contains this literal *)
val equationLiteral : equation -> Metis.Literal.literal option

val reflEqn : Metis.Term.term -> equation

val symEqn : equation -> equation

val transEqn : equation -> equation -> equation

(* ------------------------------------------------------------------------- *)
(* A conversion takes a term t and either:                                   *)
(* 1. Returns a term u together with a theorem (stronger than) t = u \/ C.   *)
(* 2. Raises an Error exception.                                             *)
(* ------------------------------------------------------------------------- *)

type conv = Metis.Term.term -> Metis.Term.term * Metis.Thm.thm

val allConv : conv

val noConv : conv

(* thenConv conv1 conv2 applies conv1 first, then conv2 to its result.       *)
val thenConv : conv -> conv -> conv

val orelseConv : conv -> conv -> conv

val tryConv : conv -> conv

val repeatConv : conv -> conv

val firstConv : conv list -> conv

val everyConv : conv list -> conv

val rewrConv : equation -> Metis.Term.path -> conv

val pathConv : conv -> Metis.Term.path -> conv

val subtermConv : conv -> int -> conv

val subtermsConv : conv -> conv  (* All function arguments *)

(* ------------------------------------------------------------------------- *)
(* Applying a conversion to every subterm, with some traversal strategy.     *)
(* ------------------------------------------------------------------------- *)

val bottomUpConv : conv -> conv

val topDownConv : conv -> conv

val repeatTopDownConv : conv -> conv  (* useful for rewriting *)

(* ------------------------------------------------------------------------- *)
(* A literule (bad pun) takes a literal L and either:                        *)
(* 1. Returns a literal L' with a theorem (stronger than) ~L \/ L' \/ C.     *)
(* 2. Raises an Error exception.                                             *)
(* ------------------------------------------------------------------------- *)

type literule = Metis.Literal.literal -> Metis.Literal.literal * Metis.Thm.thm

val allLiterule : literule

val noLiterule : literule

(* thenLiterule lr1 lr2 applies lr1 first, then lr2 to the literal it gives. *)
val thenLiterule : literule -> literule -> literule

val orelseLiterule : literule -> literule -> literule

val tryLiterule : literule -> literule

val repeatLiterule : literule -> literule

val firstLiterule : literule list -> literule

val everyLiterule : literule list -> literule

val rewrLiterule : equation -> Metis.Term.path -> literule

val pathLiterule : conv -> Metis.Term.path -> literule

val argumentLiterule : conv -> int -> literule

val allArgumentsLiterule : conv -> literule

(* ------------------------------------------------------------------------- *)
(* A rule takes one theorem and either deduces another or raises an Error    *)
(* exception.                                                                *)
(* ------------------------------------------------------------------------- *)

type rule = Metis.Thm.thm -> Metis.Thm.thm

val allRule : rule

val noRule : rule

(* NOTE: unlike thenConv/thenLiterule, thenRule rule1 rule2 applies rule2    *)
(* FIRST and then rule1 (function composition rule1 o rule2); everyRule is   *)
(* built on it and therefore applies its list from right to left.           *)
val thenRule : rule -> rule -> rule

val orelseRule : rule -> rule -> rule

val tryRule : rule -> rule

val changedRule : rule -> rule

val repeatRule : rule -> rule

val firstRule : rule list -> rule

val everyRule : rule list -> rule

val literalRule : literule -> Metis.Literal.literal -> rule

val rewrRule : equation -> Metis.Literal.literal -> Metis.Term.path -> rule

val pathRule : conv -> Metis.Literal.literal -> Metis.Term.path -> rule

val literalsRule : literule -> Metis.LiteralSet.set -> rule

val allLiteralsRule : literule -> rule

val convRule : conv -> rule  (* All arguments of all literals *)

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* --------- reflexivity                                                     *)
(*   x = x                                                                   *)
(* ------------------------------------------------------------------------- *)

val reflexivityRule : Metis.Term.term -> Metis.Thm.thm

val reflexivity : Metis.Thm.thm

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* --------------------- symmetry                                            *)
(*   ~(x = y) \/ y = x                                                       *)
(* ------------------------------------------------------------------------- *)

val symmetryRule : Metis.Term.term -> Metis.Term.term -> Metis.Thm.thm

val symmetry : Metis.Thm.thm

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* --------------------------------- transitivity                            *)
(*   ~(x = y) \/ ~(y = z) \/ x = z                                           *)
(* ------------------------------------------------------------------------- *)

val transitivity : Metis.Thm.thm

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* ---------------------------------------------- functionCongruence (f,n)   *)
(*   ~(x0 = y0) \/ ... \/ ~(x{n-1} = y{n-1}) \/                              *)
(*   f x0 ... x{n-1} = f y0 ... y{n-1}                                       *)
(* ------------------------------------------------------------------------- *)

val functionCongruence : Metis.Term.function -> Metis.Thm.thm

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* ---------------------------------------------- relationCongruence (R,n)   *)
(*   ~(x0 = y0) \/ ... \/ ~(x{n-1} = y{n-1}) \/                              *)
(*   ~R x0 ... x{n-1} \/ R y0 ... y{n-1}                                     *)
(* ------------------------------------------------------------------------- *)

val relationCongruence : Metis.Atom.relation -> Metis.Thm.thm

(* ------------------------------------------------------------------------- *)
(*   x = y \/ C                                                              *)
(* -------------- symEq (x = y)                                              *)
(*   y = x \/ C                                                              *)
(* ------------------------------------------------------------------------- *)

val symEq : Metis.Literal.literal -> rule

(* ------------------------------------------------------------------------- *)
(*   ~(x = y) \/ C                                                           *)
(* ----------------- symNeq ~(x = y)                                         *)
(*   ~(y = x) \/ C                                                           *)
(* ------------------------------------------------------------------------- *)

val symNeq : Metis.Literal.literal -> rule

(* ------------------------------------------------------------------------- *)
(* sym (x = y) = symEq (x = y)  /\  sym ~(x = y) = symNeq ~(x = y)           *)
(* ------------------------------------------------------------------------- *)

val sym : Metis.Literal.literal -> rule

(* ------------------------------------------------------------------------- *)
(*   ~(x = x) \/ C                                                           *)
(* ----------------- removeIrrefl                                            *)
(*         C                                                                 *)
(*                                                                           *)
(* where all literals of the form ~(x = x) are removed.                      *)
(* ------------------------------------------------------------------------- *)

val removeIrrefl : rule

(* ------------------------------------------------------------------------- *)
(*   x = y \/ y = x \/ C                                                     *)
(* ----------------------- removeSym                                         *)
(*       x = y \/ C                                                          *)
(*                                                                           *)
(* where all duplicate copies of equalities and disequalities are removed.   *)
(* ------------------------------------------------------------------------- *)

val removeSym : rule

(* ------------------------------------------------------------------------- *)
(*   ~(v = t) \/ C                                                           *)
(* ----------------- expandAbbrevs                                           *)
(*      C[t/v]                                                               *)
(*                                                                           *)
(* where t must not contain any occurrence of the variable v.                *)
(* ------------------------------------------------------------------------- *)

val expandAbbrevs : rule

(* ------------------------------------------------------------------------- *)
(* simplify = isTautology + expandAbbrevs + removeSym                        *)
(* ------------------------------------------------------------------------- *)

val simplify : Metis.Thm.thm -> Metis.Thm.thm option

(* ------------------------------------------------------------------------- *)
(*    C                                                                      *)
(* -------- freshVars                                                        *)
(*   C[s]                                                                    *)
(*                                                                           *)
(* where s is a renaming substitution chosen so that all of the variables in *)
(* C are replaced by fresh variables.                                        *)
(* ------------------------------------------------------------------------- *)

val freshVars : rule

(* ------------------------------------------------------------------------- *)
(*               C                                                           *)
(* ---------------------------- factor                                       *)
(*   C_s_1, C_s_2, ..., C_s_n                                                *)
(*                                                                           *)
(* where each s_i is a substitution that factors C, meaning that the theorem *)
(*                                                                           *)
(*   C_s_i = (removeIrrefl o removeSym o Metis.Thm.subst s_i) C              *)
(*                                                                           *)
(* has fewer literals than C.                                                *)
(*                                                                           *)
(* Also, if s is any substitution that factors C, then one of the s_i will   *)
(* result in a theorem C_s_i that strictly subsumes the theorem C_s.         *)
(* ------------------------------------------------------------------------- *)

val factor' : Metis.Thm.clause -> Metis.Subst.subst list

val factor : Metis.Thm.thm -> Metis.Thm.thm list

end

(**** Original file: Rule.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
(* Generated prelude: drop any infix status of ++, -- and RL, and rebind    *)
(* explode, implode, print, foldl and foldr to their SML Basis versions for *)
(* the code that follows (presumably to shadow differently-behaved bindings *)
(* from the enclosing environment -- confirm against make-metis).          *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* DERIVED RULES FOR CREATING FIRST ORDER LOGIC THEOREMS                     *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Rule :> Rule =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Variable names.                                                           *)
(* ------------------------------------------------------------------------- *)

(* Fixed variables x, y, z for stating the generic equality theorems, and *)
(* indexed families x0, x1, ... / y0, y1, ... for the congruence rules.   *)
local
  fun indexedName prefix i = Name.fromString (prefix ^ Int.toString i);
in
  val xVarName = Name.fromString "x";
  val xVar = Term.Var xVarName;

  val yVarName = Name.fromString "y";
  val yVar = Term.Var yVarName;

  val zVarName = Name.fromString "z";
  val zVar = Term.Var zVarName;

  fun xIVarName i = indexedName "x" i;
  fun xIVar i = Term.Var (xIVarName i);

  fun yIVarName i = indexedName "y" i;
  fun yIVar i = Term.Var (yIVarName i);
end;

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* --------- reflexivity                                                     *)
(*   x = x                                                                   *)
(* ------------------------------------------------------------------------- *)

(* The reflexivity rule is exactly the kernel primitive Thm.refl. *)
val reflexivityRule = Thm.refl;

(* Generic instance, stated for the variable x. *)
val reflexivity = reflexivityRule xVar;

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* --------------------- symmetry                                            *)
(*   ~(x = y) \/ y = x                                                       *)
(* ------------------------------------------------------------------------- *)

(* Derives ~(x = y) \/ y = x: rewrite position [0] of the unit theorem    *)
(* x = x to y, then cut away the x = x hypothesis with reflexivity.       *)
fun symmetryRule x y =
    let
      val xEqXTh = reflexivityRule x
      val xEqX = Thm.destUnit xEqXTh
      val rewriteTh = Thm.equality xEqX [0] y
    in
      Thm.resolve xEqX xEqXTh rewriteTh
    end;

(* Generic instance for the variables x and y. *)
val symmetry = symmetryRule xVar yVar;

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* --------------------------------- transitivity                            *)
(*   ~(x = y) \/ ~(y = z) \/ x = z                                           *)
(* ------------------------------------------------------------------------- *)

(* Rewrite position [0] of y = z to x, giving a clause with a ~(y = x)      *)
(* hypothesis, then resolve that against symmetry to turn it into ~(x = y). *)
val transitivity =
    let
      val rewriteTh = Thm.equality (Literal.mkEq (yVar,zVar)) [0] xVar
      val yEqX = Literal.mkEq (yVar,xVar)
    in
      Thm.resolve yEqX symmetry rewriteTh
    end;

(* ------------------------------------------------------------------------- *)
(*   x = y \/ C                                                              *)
(* -------------- symEq (x = y)                                              *)
(*   y = x \/ C                                                              *)
(* ------------------------------------------------------------------------- *)

(* Flips the equality literal lit = (x = y) in th to y = x by resolving    *)
(* with an instance of symmetry.  A reflexive literal is left alone.       *)
fun symEq lit th =
    let
      val (lhs,rhs) = Literal.destEq lit
    in
      if Term.equal lhs rhs then th
      else
        let
          val inst = Subst.fromList [(xVarName,lhs),(yVarName,rhs)]
          val flipTh = Thm.subst inst symmetry
        in
          Thm.resolve lit th flipTh
        end
    end;

(* ------------------------------------------------------------------------- *)
(* An equation consists of two terms (t,u) plus a theorem (stronger than)    *)
(* t = u \/ C.                                                               *)
(* ------------------------------------------------------------------------- *)

type equation = (Term.term * Term.term) * Thm.thm;

(* Only the underlying theorem is printed; the (t,u) pair is not shown. *)
fun ppEquation (eqn : equation) = Thm.pp (snd eqn);

val equationToString = Print.toString ppEquation;

(* SOME (t = u) when the equation's theorem actually contains that literal, *)
(* NONE otherwise.                                                          *)
fun equationLiteral (t_u,th) =
    let
      val lit = Literal.mkEq t_u
      val present = LiteralSet.member lit (Thm.clause th)
    in
      if present then SOME lit else NONE
    end;

fun reflEqn tm = let val th = Thm.refl tm in ((tm,tm), th) end;

(* Reverses an equation (t,u) to (u,t).  The theorem is flipped with symEq *)
(* only when it actually contains t = u; otherwise it is reused as is.     *)
fun symEqn (eqn as ((t,u), th)) =
    if Term.equal t u then eqn
    else
      let
        val th' =
            case equationLiteral eqn of
              NONE => th
            | SOME t_u => symEq t_u th
      in
        ((u,t), th')
      end;

(* Chains eqn1 (x = y) and eqn2 (y = z) into an equation x = z.             *)
(*                                                                           *)
(* Trivial cases are short-circuited: x = y returns eqn2 unchanged, y = z   *)
(* returns eqn1 unchanged, and x = z returns reflEqn x.  Otherwise, a       *)
(* theorem that does not contain its own equation literal is returned as    *)
(* the result: being stronger than t = u \/ C without t = u, it also serves *)
(* for x = z \/ C.  Only when both literals are present is transitivity     *)
(* instantiated and resolved against th1 and th2.                           *)
(*                                                                           *)
(* NOTE(review): the left term of eqn2 is ignored (matched by _) and is     *)
(* assumed to be y; there is no explicit check.                             *)
fun transEqn (eqn1 as ((x,y), th1)) (eqn2 as ((_,z), th2)) =
    if Term.equal x y then eqn2
    else if Term.equal y z then eqn1
    else if Term.equal x z then reflEqn x
    else
      ((x,z),
       case equationLiteral eqn1 of
         NONE => th1
       | SOME x_y =>
         case equationLiteral eqn2 of
           NONE => th2
         | SOME y_z =>
           let
             (* ~(x = y) \/ ~(y = z) \/ x = z for the concrete x, y, z *)
             val sub = Subst.fromList [(xVarName,x),(yVarName,y),(zVarName,z)]
             val th = Thm.subst sub transitivity
             (* cut ~(x = y) using th1, then ~(y = z) using th2 *)
             val th = Thm.resolve x_y th1 th
             val th = Thm.resolve y_z th2 th
           in
             th
           end);

(*MetisDebug
val transEqn = fn eqn1 => fn eqn2 =>
    transEqn eqn1 eqn2
    handle Error err =>
      raise Error ("Rule.transEqn:\neqn1 = " ^ equationToString eqn1 ^
                   "\neqn2 = " ^ equationToString eqn2 ^ "\n" ^ err);
*)

(* ------------------------------------------------------------------------- *)
(* A conversion takes a term t and either:                                   *)
(* 1. Returns a term u together with a theorem (stronger than) t = u \/ C.   *)
(* 2. Raises an Error exception.                                             *)
(* ------------------------------------------------------------------------- *)

type conv = Term.term -> Term.term * Thm.thm;

(* Identity conversion: tm with the theorem tm = tm. *)
fun allConv tm =
    let
      val th = Thm.refl tm
    in
      (tm, th)
    end;

(* Conversion that always fails. *)
fun noConv (_ : Term.term) : Term.term * Thm.thm = raise Error "noConv";

(* Debugging wrapper: runs conv on tm and prints the input, the result and *)
(* its theorem, prefixed by the label s.  On Error it prints the failure    *)
(* and re-raises the error with s prepended to the message.                 *)
fun traceConv s conv tm =
    let
      val res as (tm',th) = conv tm
      val line =
          String.concat [s, ": ", Term.toString tm, " --> ",
                         Term.toString tm', " ", Thm.toString th, "\n"]
    in
      print line;
      res
    end
    handle Error err =>
      (print (s ^ ": " ^ Term.toString tm ^ " --> Error: " ^ err ^ "\n");
       raise Error (s ^ ": " ^ err));

(* Glues tm --> tm' and tm' --> tm'' into tm --> tm'' via transEqn. *)
fun thenConvTrans tm (tm',th1) (tm'',th2) =
    let
      val (_,composedTh) = transEqn ((tm,tm'),th1) ((tm',tm''),th2)
    in
      (tm'',composedTh)
    end;

(* Runs conv1 on tm, then conv2 on conv1's output, and chains the theorems. *)
fun thenConv conv1 conv2 tm =
    let
      val step1 = conv1 tm
      val step2 = conv2 (fst step1)
    in
      thenConvTrans tm step1 step2
    end;

(* Tries conv1; if it raises Error, falls back to conv2. *)
fun orelseConv (conv1 : conv) conv2 tm =
    (conv1 tm)
    handle Error _ => conv2 tm;

(* Never fails: falls back to the identity conversion. *)
fun tryConv conv tm = orelseConv conv allConv tm;

(* Applies conv, but fails with Error "changedConv" when the resulting term *)
(* is unchanged, so that "no progress" can be treated as failure.          *)
(* Uses Term.equal like the rest of this structure rather than the         *)
(* polymorphic = on the term representation.                               *)
fun changedConv conv tm =
    let
      val res as (tm',_) = conv tm
    in
      if Term.equal tm tm' then raise Error "changedConv" else res
    end;

(* Applies conv as long as it succeeds; stops at the first Error. *)
fun repeatConv conv tm =
    thenConv conv (repeatConv conv) tm handle Error _ => allConv tm;

(* First conversion in the list that succeeds; fails on the empty list. *)
fun firstConv convs tm =
    case convs of
      [] => raise Error "firstConv"
    | [conv] => conv tm
    | conv :: rest => orelseConv conv (firstConv rest) tm;

(* Runs every conversion in list order, chaining the results. *)
fun everyConv convs tm =
    case convs of
      [] => allConv tm
    | [conv] => conv tm
    | conv :: rest => thenConv conv (everyConv rest) tm;

(* Rewrites the subterm of tm at path to y using the equation, returning    *)
(* (tm', th) with tm' = tm[path := y] and th (stronger than) tm = tm' \/ C. *)
(* NOTE(review): the subterm of tm at path is assumed to be x; this is not  *)
(* checked explicitly here.                                                 *)
fun rewrConv (eqn as ((x,y), eqTh)) path tm =
    if Term.equal x y then allConv tm
    (* at the root, the equation itself relates tm (assumed = x) to y *)
    else if null path then (y,eqTh)
    else
      let
        (* start from tm = tm and rewrite inside its right side (1 :: path) *)
        val reflTh = Thm.refl tm
        val reflLit = Thm.destUnit reflTh
        val th = Thm.equality reflLit (1 :: path) y
        val th = Thm.resolve reflLit reflTh th
        (* discharge the x = y hypothesis when eqTh actually contains it *)
        val th =
            case equationLiteral eqn of
              NONE => th
            | SOME x_y => Thm.resolve x_y eqTh th
        val tm' = Term.replace tm (path,y)
      in
        (tm',th)
      end;

(*MetisDebug
val rewrConv = fn eqn as ((x,y),eqTh) => fn path => fn tm =>
    rewrConv eqn path tm
    handle Error err =>
      raise Error ("Rule.rewrConv:\nx = " ^ Term.toString x ^
                   "\ny = " ^ Term.toString y ^
                   "\neqTh = " ^ Thm.toString eqTh ^
                   "\npath = " ^ Term.pathToString path ^
                   "\ntm = " ^ Term.toString tm ^ "\n" ^ err);
*)

(* Applies conv to the subterm of tm at path and rewrites tm with the result. *)
fun pathConv conv path tm =
    let
      val sub = Term.subterm tm path
      val (sub',th) = conv sub
      val eqn = ((sub,sub'),th)
    in
      rewrConv eqn path tm
    end;

(* Applies conv to the i-th immediate argument. *)
fun subtermConv conv i tm = pathConv conv [i] tm;

(* Applies conv to every immediate argument of a function application; a *)
(* variable has no arguments and is left unchanged.                       *)
fun subtermsConv conv tm =
    case tm of
      Term.Var _ => allConv tm
    | Term.Fn (_,args) =>
      let
        val argConvs = map (subtermConv conv) (interval 0 (length args))
      in
        everyConv argConvs tm
      end;

(* ------------------------------------------------------------------------- *)
(* Applying a conversion to every subterm, with some traversal strategy.     *)
(* ------------------------------------------------------------------------- *)

(* Children first (recursively), then conv repeatedly at this position. *)
fun bottomUpConv conv tm =
    let
      val children = subtermsConv (bottomUpConv conv)
    in
      thenConv children (repeatConv conv) tm
    end;

(* conv repeatedly at this position first, then the children recursively. *)
fun topDownConv conv tm =
    let
      val children = subtermsConv (topDownConv conv)
    in
      thenConv (repeatConv conv) children tm
    end;

fun repeatTopDownConv conv =
    let
      (* rewrite at this position as long as possible, then descend *)
      fun here tm = thenConv (repeatConv conv) below tm
      (* process all immediate subterms, then look at this position again *)
      and below tm = thenConv (subtermsConv here) again tm
      (* if conv fires again here, start over; otherwise stop *)
      and again tm = tryConv (thenConv conv here) tm
    in
      here
    end;

(*MetisDebug
val repeatTopDownConv = fn conv => fn tm =>
    repeatTopDownConv conv tm
    handle Error err => raise Error ("repeatTopDownConv: " ^ err);
*)

(* ------------------------------------------------------------------------- *)
(* A literule (bad pun) takes a literal L and either:                        *)
(* 1. Returns a literal L' with a theorem (stronger than) ~L \/ L' \/ C.     *)
(* 2. Raises an Error exception.                                             *)
(* ------------------------------------------------------------------------- *)

type literule = Literal.literal -> Literal.literal * Thm.thm;

(* Identity literule: lit with the theorem produced by Thm.assume lit. *)
fun allLiterule lit =
    let
      val th = Thm.assume lit
    in
      (lit, th)
    end;

(* Literule that always fails. *)
fun noLiterule (_ : Literal.literal) : Literal.literal * Thm.thm =
    raise Error "noLiterule";

(* Sequential composition: literule1 takes lit to lit', then literule2      *)
(* takes lit' to lit''.  The theorems ~lit \/ lit' \/ C1 and                *)
(* ~lit' \/ lit'' \/ C2 are combined by resolving on lit'.  Shortcuts: an   *)
(* unchanged step is skipped, a round trip back to lit gives                *)
(* allLiterule lit, and a theorem missing the literal to resolve on is      *)
(* returned directly, since it is already stronger than what is needed.     *)
fun thenLiterule literule1 literule2 lit =
    let
      val res1 as (lit',th1) = literule1 lit
      val res2 as (lit'',th2) = literule2 lit'
    in
      if Literal.equal lit lit' then res2
      else if Literal.equal lit' lit'' then res1
      else if Literal.equal lit lit'' then allLiterule lit
      else
        (lit'',
         if not (Thm.member lit' th1) then th1
         else if not (Thm.negateMember lit' th2) then th2
         else Thm.resolve lit' th1 th2)
    end;

(* Tries literule1; if it raises Error, falls back to literule2. *)
fun orelseLiterule (literule1 : literule) literule2 lit =
    (literule1 lit)
    handle Error _ => literule2 lit;

(* Never fails: falls back to the identity literule. *)
fun tryLiterule literule lit = orelseLiterule literule allLiterule lit;

(* Applies literule, but fails with Error "changedLiterule" when the        *)
(* resulting literal is unchanged.  Compares with Literal.equal, like the   *)
(* rest of this structure, rather than the polymorphic =.                   *)
fun changedLiterule literule lit =
    let
      val res as (lit',_) = literule lit
    in
      if Literal.equal lit lit' then raise Error "changedLiterule" else res
    end;

(* Applies literule as long as it succeeds; stops at the first Error. *)
fun repeatLiterule literule lit =
    thenLiterule literule (repeatLiterule literule) lit
    handle Error _ => allLiterule lit;

(* First literule in the list that succeeds; fails on the empty list. *)
fun firstLiterule literules lit =
    case literules of
      [] => raise Error "firstLiterule"
    | [literule] => literule lit
    | literule :: rest => orelseLiterule literule (firstLiterule rest) lit;

(* Runs every literule in list order, chaining the results. *)
fun everyLiterule literules lit =
    case literules of
      [] => allLiterule lit
    | [literule] => literule lit
    | literule :: rest => thenLiterule literule (everyLiterule rest) lit;

(* Rewrites the subterm of lit at path to y, discharging the equality      *)
(* hypothesis with eqTh when eqTh actually contains x = y.                 *)
fun rewrLiterule (eqn as ((x,y),eqTh)) path lit =
    if Term.equal x y then allLiterule lit
    else
      let
        val rewriteTh = Thm.equality lit path y
        val th =
            case equationLiteral eqn of
              SOME x_y => Thm.resolve x_y eqTh rewriteTh
            | NONE => rewriteTh
      in
        (Literal.replace lit (path,y), th)
      end;

(*MetisDebug
val rewrLiterule = fn eqn => fn path => fn lit =>
    rewrLiterule eqn path lit
    handle Error err =>
      raise Error ("Rule.rewrLiterule:\neqn = " ^ equationToString eqn ^
                   "\npath = " ^ Term.pathToString path ^
                   "\nlit = " ^ Literal.toString lit ^ "\n" ^ err);
*)

(* Applies conv to the subterm of lit at path and rewrites lit with it. *)
fun pathLiterule conv path lit =
    let
      val sub = Literal.subterm lit path
      val (sub',th) = conv sub
      val eqn = ((sub,sub'),th)
    in
      rewrLiterule eqn path lit
    end;

(* Applies conv to the i-th argument of the literal. *)
fun argumentLiterule conv i lit = pathLiterule conv [i] lit;

(* Applies conv to each argument of the literal, in order. *)
fun allArgumentsLiterule conv lit =
    let
      val argRules = map (argumentLiterule conv) (interval 0 (Literal.arity lit))
    in
      everyLiterule argRules lit
    end;

(* ------------------------------------------------------------------------- *)
(* A rule takes one theorem and either deduces another or raises an Error    *)
(* exception.                                                                *)
(* ------------------------------------------------------------------------- *)

type rule = Thm.thm -> Thm.thm;

(* Identity rule. *)
fun allRule (th : Thm.thm) = th;

(* Rule that always fails. *)
fun noRule (_ : Thm.thm) : Thm.thm = raise Error "noRule";

(* NOTE: unlike thenConv and thenLiterule, this applies rule2 FIRST and then *)
(* rule1 to its result, i.e. it is the composition rule1 o rule2.           *)
fun thenRule (rule1 : rule) (rule2 : rule) th =
    let
      val th' = rule2 th
    in
      rule1 th'
    end;

(* Tries rule1; if it raises Error, falls back to rule2. *)
fun orelseRule (rule1 : rule) rule2 th =
    (rule1 th)
    handle Error _ => rule2 th;

(* Never fails: falls back to the identity rule. *)
fun tryRule rule th = orelseRule rule allRule th;

(* Applies rule, but fails with Error "changedRule" if the clause is unchanged. *)
fun changedRule rule th =
    let
      val th' = rule th
      val unchanged = LiteralSet.equal (Thm.clause th) (Thm.clause th')
    in
      if unchanged then raise Error "changedRule" else th'
    end;

(* Applies rule repeatedly until it raises Error, and returns the last     *)
(* theorem produced (the input itself if rule fails straight away).        *)
(*                                                                         *)
(* The previous definition, tryRule (thenRule rule (repeatRule rule)),     *)
(* never terminated: thenRule applies its SECOND argument first, so it     *)
(* recursed into repeatRule on the same theorem before applying rule.      *)
(* This version is an explicit loop and is tail recursive.                 *)
fun repeatRule rule th =
    let
      val next = SOME (rule th) handle Error _ => NONE
    in
      case next of
        NONE => th
      | SOME th' => repeatRule rule th'
    end;

(* First rule in the list that succeeds; fails on the empty list. *)
fun firstRule rules th =
    case rules of
      [] => raise Error "firstRule"
    | [rule] => rule th
    | rule :: rest => orelseRule rule (firstRule rest) th;

(* Combines all rules with thenRule (so the LAST rule is applied first). *)
fun everyRule rules th =
    case rules of
      [] => allRule th
    | [rule] => rule th
    | rule :: rest => thenRule rule (everyRule rest) th;

(* Rewrites literal lit of th using literule.  If the literule theorem     *)
(* contains ~lit, it is resolved against th; otherwise it already omits    *)
(* the hypothesis and is returned directly.                                *)
fun literalRule literule lit th =
    let
      val (lit',litTh) = literule lit
    in
      if Literal.equal lit lit' then th
      else if Thm.negateMember lit litTh then Thm.resolve lit th litTh
      else litTh
    end;

(*MetisDebug
val literalRule = fn literule => fn lit => fn th =>
    literalRule literule lit th
    handle Error err =>
      raise Error ("Rule.literalRule:\nlit = " ^ Literal.toString lit ^
                   "\nth = " ^ Thm.toString th ^ "\n" ^ err);
*)

(* Rewrites literal lit of a theorem at path with the given equation. *)
fun rewrRule eqTh lit path th = literalRule (rewrLiterule eqTh path) lit th;

(* Applies conv at path inside literal lit of a theorem. *)
fun pathRule conv lit path th = literalRule (pathLiterule conv path) lit th;

(* Applies literule to each literal of lits that is still present in the  *)
(* evolving theorem; literals already gone are skipped.                   *)
fun literalsRule literule lits th =
    let
      fun step (lit, acc) =
          if Thm.member lit acc then literalRule literule lit acc else acc
    in
      LiteralSet.foldl step th lits
    end;

(* Applies literule to every literal of the theorem's clause. *)
fun allLiteralsRule literule th =
    let
      val lits = Thm.clause th
    in
      literalsRule literule lits th
    end;

(* Applies conv to every argument of every literal. *)
fun convRule conv th = allLiteralsRule (allArgumentsLiterule conv) th;

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* ---------------------------------------------- functionCongruence (f,n)   *)
(*   ~(x0 = y0) \/ ... \/ ~(x{n-1} = y{n-1}) \/                              *)
(*   f x0 ... x{n-1} = f y0 ... y{n-1}                                       *)
(* ------------------------------------------------------------------------- *)

(* Starts from f x0 ... x{n-1} = f x0 ... x{n-1} and, for each i, rewrites *)
(* argument i of the right-hand side (path [1,i]) from xi to yi, adding a  *)
(* ~(xi = yi) hypothesis each time.                                        *)
fun functionCongruence (f,n) =
    let
      val xs = List.tabulate (n,xIVar)
      and ys = List.tabulate (n,yIVar)

      fun congArg ((i,yi),(th,lit)) =
          let
            val path = [1,i]
            val eqTh = Thm.equality lit path yi
          in
            (Thm.resolve lit th eqTh, Literal.replace lit (path,yi))
          end

      val reflTh = Thm.refl (Term.Fn (f,xs))
      val (finalTh,_) =
          List.foldl congArg (reflTh, Thm.destUnit reflTh) (enumerate ys)
    in
      finalTh
    end;

(* ------------------------------------------------------------------------- *)
(*                                                                           *)
(* ---------------------------------------------- relationCongruence (R,n)   *)
(*   ~(x0 = y0) \/ ... \/ ~(x{n-1} = y{n-1}) \/                              *)
(*   ~R x0 ... x{n-1} \/ R y0 ... y{n-1}                                     *)
(* ------------------------------------------------------------------------- *)

(* Starts from Thm.assume of the negative literal ~R x0 ... x{n-1} and,   *)
(* for each i, rewrites that literal's argument i from xi to yi, adding a *)
(* ~(xi = yi) hypothesis each time.                                       *)
(* NOTE(review): because the tracked literal is the negative one, confirm *)
(* that the orientation of the final R literals matches the header       *)
(* comment of this rule.                                                  *)
fun relationCongruence (R,n) =
    let
      val xs = List.tabulate (n,xIVar)
      and ys = List.tabulate (n,yIVar)

      fun congArg ((i,yi),(th,lit)) =
          let
            val path = [i]
            val eqTh = Thm.equality lit path yi
          in
            (Thm.resolve lit th eqTh, Literal.replace lit (path,yi))
          end

      val assumeLit = (false,(R,xs))
      val (finalTh,_) =
          List.foldl congArg (Thm.assume assumeLit, assumeLit) (enumerate ys)
    in
      finalTh
    end;

(* ------------------------------------------------------------------------- *)
(*   ~(x = y) \/ C                                                           *)
(* ----------------- symNeq ~(x = y)                                         *)
(*   ~(y = x) \/ C                                                           *)
(* ------------------------------------------------------------------------- *)

(* Flips the disequality literal lit = ~(x = y) in th to ~(y = x).        *)
(* Symmetry with its variables swapped reads ~(y = x) \/ x = y, and       *)
(* resolving on lit against it performs the flip.                         *)
fun symNeq lit th =
    let
      val (lhs,rhs) = Literal.destNeq lit
    in
      if Term.equal lhs rhs then th
      else
        let
          val swap = Subst.fromList [(xVarName,rhs),(yVarName,lhs)]
          val swappedSym = Thm.subst swap symmetry
        in
          Thm.resolve lit th swappedSym
        end
    end;

(* ------------------------------------------------------------------------- *)
(* sym (x = y) = symEq (x = y)  /\  sym ~(x = y) = symNeq ~(x = y)           *)
(* ------------------------------------------------------------------------- *)

(* Dispatches on polarity: symEq for equalities, symNeq for disequalities. *)
fun sym (lit as (pol,_)) th =
    case pol of
      true => symEq lit th
    | false => symNeq lit th;

(* ------------------------------------------------------------------------- *)
(*   ~(x = x) \/ C                                                           *)
(* ----------------- removeIrrefl                                            *)
(*         C                                                                 *)
(*                                                                           *)
(* where all literals of the form ~(x = x) are removed.                      *)
(* ------------------------------------------------------------------------- *)

local
  (* A negative literal whose atom Atom.destRefl accepts is resolved away  *)
  (* against the reflexivity theorem for its term; all other literals      *)
  (* leave the theorem unchanged.                                          *)
  fun irrefl (lit as (pol,atm), th) =
      if pol then th
      else
        case total Atom.destRefl atm of
          NONE => th
        | SOME x => Thm.resolve lit th (Thm.refl x);
in
  fun removeIrrefl th = LiteralSet.foldl irrefl th (Thm.clause th);
end;

(* ------------------------------------------------------------------------- *)
(*   x = y \/ y = x \/ C                                                     *)
(* ----------------------- removeSym                                         *)
(*       x = y \/ C                                                          *)
(*                                                                           *)
(* where all duplicate copies of equalities and disequalities are removed.   *)
(* ------------------------------------------------------------------------- *)

local
  (* eqs accumulates the mirror images of the (dis)equations seen so far.  *)
  (* A literal found in eqs has its mirror image earlier in the clause, so *)
  (* flipping it with sym makes the two copies coincide.                   *)
  fun rem (lit as (pol,atm), acc as (eqs,th)) =
      case total Atom.sym atm of
        NONE => acc
      | SOME atm' =>
        if not (LiteralSet.member lit eqs) then
          (LiteralSet.add eqs (pol,atm'), th)
        else
          (eqs, sym lit th);
in
  fun removeSym th =
      let
        val (_,th') = LiteralSet.foldl rem (LiteralSet.empty,th) (Thm.clause th)
      in
        th'
      end;
end;

(* ------------------------------------------------------------------------- *)
(*   ~(v = t) \/ C                                                           *)
(* ----------------- expandAbbrevs                                           *)
(*      C[t/v]                                                               *)
(*                                                                           *)
(* where t must not contain any occurrence of the variable v.                *)
(* ------------------------------------------------------------------------- *)

local
  (* For a literal ~(x = y) in which x or y satisfies Term.isTypedVar and *)
  (* x and y differ, returns the substitution from Subst.unify; raises    *)
  (* Error otherwise.  (Literal.destNeq and Subst.unify presumably raise  *)
  (* Error too when they fail -- confirm.)                                *)
  fun expand lit =
      let
        val (x,y) = Literal.destNeq lit
        (* "cond orelse raise" guards: fail unless the condition holds *)
        val _ = Term.isTypedVar x orelse Term.isTypedVar y orelse
                raise Error "Rule.expandAbbrevs: no vars"
        val _ = not (Term.equal x y) orelse
                raise Error "Rule.expandAbbrevs: equal vars"
      in
        Subst.unify Subst.empty x y
      end;
in
  (* Applies the substitution from the first qualifying literal to the    *)
  (* whole theorem and repeats; when no literal qualifies, the leftover   *)
  (* ~(t = t) literals are removed with removeIrrefl.                     *)
  fun expandAbbrevs th =
      case LiteralSet.firstl (total expand) (Thm.clause th) of
        NONE => removeIrrefl th
      | SOME sub => expandAbbrevs (Thm.subst sub th);
end;

(* ------------------------------------------------------------------------- *)
(* simplify = isTautology + expandAbbrevs + removeSym                        *)
(* ------------------------------------------------------------------------- *)

(* NONE for tautologies.  Otherwise applies expandAbbrevs and then         *)
(* removeSym until the theorem stops changing, and returns SOME of it.     *)
fun simplify th =
    if Thm.isTautology th then NONE
    else
      let
        val th' = removeSym (expandAbbrevs th)
      in
        if Thm.equal th th' then SOME th else simplify th'
      end;

(* ------------------------------------------------------------------------- *)
(*    C                                                                      *)
(* -------- freshVars                                                        *)
(*   C[s]                                                                    *)
(*                                                                           *)
(* where s is a renaming substitution chosen so that all of the variables in *)
(* C are replaced by fresh variables.                                        *)
(* ------------------------------------------------------------------------- *)

(* Rename every free variable of th apart using a renaming substitution. *)
fun freshVars th =
    let
      val renaming = Subst.freshVars (Thm.freeVars th)
    in
      Thm.subst renaming th
    end;

(* ------------------------------------------------------------------------- *)
(*               C                                                           *)
(* ---------------------------- factor                                       *)
(*   C_s_1, C_s_2, ..., C_s_n                                                *)
(*                                                                           *)
(* where each s_i is a substitution that factors C, meaning that the theorem *)
(*                                                                           *)
(*   C_s_i = (removeIrrefl o removeSym o Thm.subst s_i) C                    *)
(*                                                                           *)
(* has fewer literals than C.                                                *)
(*                                                                           *)
(* Also, if s is any substitution that factors C, then one of the s_i will   *)
(* result in a theorem C_s_i that strictly subsumes the theorem C_s.         *)
(* ------------------------------------------------------------------------- *)

local
  (* A candidate unification inside the clause.  FactorEdge (atm,atm') asks
     for two atoms to be unified (so two literals of the same polarity
     merge); ReflEdge (tm,tm') asks for the two sides of an equation atom to
     be unified. *)
  datatype edge =
      FactorEdge of Atom.atom * Atom.atom
    | ReflEdge of Term.term * Term.term;

  (* Pretty-printer for edges.  It is local, and its only use is in the
     commented-out MetisTrace6 code in factor' below. *)
  fun ppEdge (FactorEdge atm_atm') = Print.ppPair Atom.pp Atom.pp atm_atm'
    | ppEdge (ReflEdge tm_tm') = Print.ppPair Term.pp Term.pp tm_tm';

  (* Status of an edge relative to a substitution:
     - Joined: the substitution already unifies it, needing no extension;
     - Joinable: it can be unified by the given extended substitution;
     - Apart: it cannot be unified. *)
  datatype joinStatus =
      Joined
    | Joinable of Subst.subst
    | Apart;

  (* Try to extend sub so that it unifies the two sides of edge.  A unifier
     that is pointer-equal to sub is taken to mean no extension was needed. *)
  fun joinEdge sub edge =
      let
        val result =
            case edge of
              FactorEdge (atm,atm') => total (Atom.unify sub atm) atm'
            | ReflEdge (tm,tm') => total (Subst.unify sub tm) tm'
      in
        case result of
          NONE => Apart
        | SOME sub' =>
          if Portable.pointerEqual (sub,sub') then Joined else Joinable sub'
      end;

  (* Check the edges that must be kept apart against a new substitution:
     - NONE if sub already joins one of them, so sub is rejected;
     - otherwise SOME of the edges sub could still join, dropping those
       that have become Apart.
     The retained edges are returned in reverse order. *)
  fun updateApart sub =
      let
        fun update acc [] = SOME acc
          | update acc (edge :: edges) =
            case joinEdge sub edge of
              Joined => NONE
            | Joinable _ => update (edge :: acc) edges
            | Apart => update acc edges
      in
        update []
      end;

  (* Fold step used by mk_edges.  For a literal (pol,atm) and a later
     literal (pol',atm') of the same polarity whose atoms unify, record the
     unifier together with the FactorEdge as a factoring candidate.  An edge
     that already holds under the empty substitution is reported as a Bug. *)
  fun addFactorEdge (pol,atm) ((pol',atm'),acc) =
      if pol <> pol' then acc
      else
        let
          val edge = FactorEdge (atm,atm')
        in
          case joinEdge Subst.empty edge of
            Joined => raise Bug "addFactorEdge: joined"
          | Joinable sub => (sub,edge) :: acc
          | Apart => acc
        end;

  (* For a positive equation whose sides unify, add its ReflEdge to the list
     of edges that must stay apart.  Negative literals are ignored.
     NOTE(review): presumably this is because joining it would make the
     equation reflexive and the clause a tautology; confirm. *)
  fun addReflEdge (false,_) acc = acc
    | addReflEdge (true,atm) acc =
      let
        val edge = ReflEdge (Atom.destEq atm)
      in
        case joinEdge Subst.empty edge of
          Joined => raise Bug "addRefl: joined"
        | Joinable _ => edge :: acc
        | Apart => acc
      end;

  (* For a negative equation whose sides unify, record the unifier and the
     ReflEdge as a factoring candidate: under that unifier the literal
     becomes an irreflexive disequation.  Positive literals are ignored.
     Shares the "addRefl: joined" Bug message with addReflEdge. *)
  fun addIrreflEdge (true,_) acc = acc
    | addIrreflEdge (false,atm) acc =
      let
        val edge = ReflEdge (Atom.destEq atm)
      in
        case joinEdge Subst.empty edge of
          Joined => raise Bug "addRefl: joined"
        | Joinable sub => (sub,edge) :: acc
        | Apart => acc
      end;

  (* Turn the candidate (sub,edge) list into the initial search states.
     A candidate is kept only if its sub joins none of the apart edges
     (updateApart); the edge of every kept candidate is then added to the
     apart list seen by the candidates after it.  Once the list is
     exhausted, the kept candidates (accumulated in reverse) are folded
     into triples (apart,sub,edges).  In each triple, edges holds the edges
     of the candidates kept after that one. *)
  fun init_edges acc _ [] =
      let
        fun init ((apart,sub,edge),(edges,acc)) =
            (edge :: edges, (apart,sub,edges) :: acc)
      in
        snd (List.foldl init ([],[]) acc)
      end
    | init_edges acc apart ((sub,edge) :: sub_edges) =
      let
(*MetisDebug
        val () = if not (Subst.null sub) then ()
                 else raise Bug "Rule.factor.init_edges: empty subst"
*)
        val (acc,apart) =
            case updateApart sub apart of
              SOME apart' => ((apart',sub,edge) :: acc, edge :: apart)
            | NONE => (acc,apart)
      in
        init_edges acc apart sub_edges
      end;

  (* Walk the clause's literals, collecting the apart edges and the
     factoring candidates:
     - each literal gets a FactorEdge candidate with every later literal;
     - if Literal.sym succeeds on it (presumably: it is an equation), it
       also contributes a ReflEdge (apart or candidate, depending on its
       polarity) and FactorEdge candidates between its symmetric form and
       every later literal. *)
  fun mk_edges apart sub_edges [] = init_edges [] apart sub_edges
    | mk_edges apart sub_edges (lit :: lits) =
      let
        val sub_edges = List.foldl (addFactorEdge lit) sub_edges lits

        val (apart,sub_edges) =
            case total Literal.sym lit of
              NONE => (apart,sub_edges)
            | SOME lit' =>
              let
                val apart = addReflEdge lit apart
                val sub_edges = addIrreflEdge lit sub_edges
                val sub_edges = List.foldl (addFactorEdge lit') sub_edges lits
              in
                (apart,sub_edges)
              end
      in
        mk_edges apart sub_edges lits
      end;

  (* Depth-first search over states (apart,sub,edges).  A state with no
     remaining edges contributes sub to the result.  Otherwise, if sub can
     be extended to join the next edge, the search branches in two:
     - one state skips the edge and adds it to apart;
     - the other uses the extension sub', provided it joins no apart edge.
     Edges that are already Joined or Apart under sub are dropped. *)
  fun fact acc [] = acc
    | fact acc ((_,sub,[]) :: others) = fact (sub :: acc) others
    | fact acc ((apart, sub, edge :: edges) :: others) =
      let
        val others =
            case joinEdge sub edge of
              Joinable sub' =>
              let
                val others = (edge :: apart, sub, edges) :: others
              in
                case updateApart sub' apart of
                  NONE => others
                | SOME apart' => (apart',sub',edges) :: others
              end
            | _ => (apart,sub,edges) :: others
      in
        fact acc others
      end;
in
  (* All factoring substitutions found for the clause cl. *)
  fun factor' cl =
      let
(*MetisTrace6
        val () = Print.trace LiteralSet.pp "Rule.factor': cl" cl
*)
        val edges = mk_edges [] [] (LiteralSet.toList cl)
(*MetisTrace6
        val ppEdgesSize = Print.ppMap length Print.ppInt
        val ppEdgel = Print.ppList ppEdge
        val ppEdges = Print.ppList (Print.ppTriple ppEdgel Subst.pp ppEdgel)
        val () = Print.trace ppEdgesSize "Rule.factor': |edges|" edges
        val () = Print.trace ppEdges "Rule.factor': edges" edges
*)
        val result = fact [] edges
(*MetisTrace6
        val ppResult = Print.ppList Subst.pp
        val () = Print.trace ppResult "Rule.factor': result" result
*)
      in
        result
      end;
end;

(* Return every factor of th: one theorem per substitution found by factor'.
   Each is obtained by instantiating th and then removing two kinds of
   literal:
   - literals that became symmetric duplicates (removeSym);
   - disequations that became irreflexive (removeIrrefl).
   Both steps are needed to meet the contract stated above,
   C_s_i = (removeIrrefl o removeSym o Thm.subst s_i) C.  In particular,
   the substitutions that factor' derives from irreflexive edges only
   shrink the clause once removeIrrefl drops the resulting ~(t = t)
   literal. *)
fun factor th =
    let
      fun fact sub = removeIrrefl (removeSym (Thm.subst sub th))
    in
      map fact (factor' (Thm.clause th))
    end;

end
end;

(**** Original file: Normalize.sig ****)

(* ========================================================================= *)
(* NORMALIZING FORMULAS                                                      *)
(* Copyright (c) 2001-2009 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Normalize =
sig

(* ------------------------------------------------------------------------- *)
(* Negation normal form.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Convert a formula to negation normal form. *)
val nnf : Metis.Formula.formula -> Metis.Formula.formula

(* ------------------------------------------------------------------------- *)
(* Conjunctive normal form derivations.                                      *)
(* ------------------------------------------------------------------------- *)

(* Abstract type of theorems in a CNF derivation. *)
type thm

(* Justification recorded for a thm (returned by destThm). *)
datatype inference =
    Axiom of Metis.Formula.formula
  | Definition of string * Metis.Formula.formula
  | Simplify of thm * thm list
  | Conjunct of thm
  | Specialize of thm
  | Skolemize of thm
  | Clausify of thm

(* Presumably introduces a formula as an axiom thm; confirm. *)
val mkAxiom : Metis.Formula.formula -> thm

(* The formula of a thm together with its justification. *)
val destThm : thm -> Metis.Formula.formula * inference

(* NOTE(review): presumably, for each derived thm, its formula, the inference
   used, and the formulas it depends on; confirm against the implementation. *)
val proveThms :
    thm list -> (Metis.Formula.formula * inference * Metis.Formula.formula list) list

val toStringInference : inference -> string

val ppInference : inference Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* Conjunctive normal form.                                                  *)
(* ------------------------------------------------------------------------- *)

(* Abstract CNF-conversion state threaded through addCnf. *)
type cnf

(* The state to start a conversion from. *)
val initialCnf : cnf

(* Add a thm to the state, returning clauses paired with their thms (presumably
   the newly produced ones) and the updated state. *)
val addCnf : thm -> cnf -> (Metis.Thm.clause * thm) list * cnf

val proveCnf : thm list -> (Metis.Thm.clause * thm) list

(* The clauses of a formula's conjunctive normal form. *)
val cnf : Metis.Formula.formula -> Metis.Thm.clause list

end

(**** Original file: Normalize.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
(* Clear any infix status of ++, -- and RL, and rebind the following names
   to their Basis Library definitions for the code below. *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* NORMALIZING FORMULAS                                                      *)
(* Copyright (c) 2001-2007 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Normalize :> Normalize =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Constants.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Common suffix of the names of symbols introduced by CNF conversion. *)
val prefix = "FOFtoCNF";

(* Prefix of generated Skolem function names (used by newSkolemFunction). *)
val skolemPrefix = String.concat ["skolem", prefix];

(* Presumably used when naming introduced definitions; confirm. *)
val definitionPrefix = String.concat ["definition", prefix];

(* ------------------------------------------------------------------------- *)
(* Storing huge real numbers as their log.                                   *)
(* ------------------------------------------------------------------------- *)

(* A nonnegative quantity stored as its natural logarithm, so that very large
   clause counts stay representable.  A negative stored value encodes zero. *)
datatype logReal = LogReal of real;

(* The logarithm is monotone, so comparing logs orders the quantities. *)
fun compareLogReal (LogReal x, LogReal y) = Real.compare (x,y);

(* Zero, encoded by the negative sentinel ~1.0. *)
val zeroLogReal = LogReal ~1.0;

(* One: ln 1 = 0. *)
val oneLogReal = LogReal 0.0;

local
  (* A stored logarithm below 0.0 encodes the quantity zero. *)
  fun isZero lg = lg < 0.0;

  (* log (e^hi + e^lo) = hi + log (1 + e^(lo - hi)).  Requires
     hi >= lo >= 0.0, which keeps the exponent non-positive. *)
  fun add hi lo = hi + Math.ln (1.0 + Math.exp (lo - hi));
in
  fun isZeroLogReal (LogReal lg) = isZero lg;

  (* Multiplication is addition of logarithms; zero is absorbing. *)
  fun multiplyLogReal (LogReal logX) (LogReal logY) =
      case (isZero logX, isZero logY) of
        (false,false) => LogReal (logX + logY)
      | _ => zeroLogReal;

  (* Addition; zero is the identity.  The larger logarithm is passed to add
     first so that its precondition holds. *)
  fun addLogReal (lx as LogReal logX) (ly as LogReal logY) =
      if isZero logX then ly
      else if isZero logY then lx
      else
        let
          val (hi,lo) = if logX < logY then (logY,logX) else (logX,logY)
        in
          LogReal (add hi lo)
        end;

  (* Whether x is below y up to an additive slack logDelta on the logs,
     i.e. a relative tolerance on the quantities.  Zero is below everything,
     and a nonzero x is never below zero. *)
  fun withinRelativeLogReal logDelta (LogReal logX) (LogReal logY) =
      if isZero logX then true
      else if isZero logY then false
      else logX < logY + logDelta;
end;

(* Render the stored logarithm (not the quantity itself). *)
fun toStringLogReal lr =
    case lr of
      LogReal lg => Real.toString lg;

(* ------------------------------------------------------------------------- *)
(* Counting the clauses that would be generated by conjunctive normal form.  *)
(* ------------------------------------------------------------------------- *)

(* Slack on logarithms tolerated by countLeqish. *)
val countLogDelta = 0.01;

(* Estimated number of CNF clauses of a formula (positive) and of its
   negation (negative), both stored as logReal. *)
datatype count = Count of {positive : logReal, negative : logReal};

(* Order counts by their positive components only. *)
fun countCompare (Count {positive = p1, ...}, Count {positive = p2, ...}) =
    compareLogReal (p1,p2);

(* Negating a formula swaps the counts of the formula and its negation. *)
fun countNegate (Count {positive = p, negative = n}) =
    Count {negative = p, positive = n};

(* The positive component of count1 is below that of count2, up to the
   tolerance countLogDelta. *)
fun countLeqish (Count {positive = p1, ...}) (Count {positive = p2, ...}) =
    withinRelativeLogReal countLogDelta p1 p2;

(*MetisDebug
fun countEqualish count1 count2 =
    countLeqish count1 count2 andalso
    countLeqish count2 count1;

fun countEquivish count1 count2 =
    countEqualish count1 count2 andalso
    countEqualish (countNegate count1) (countNegate count2);
*)

(* True yields no clauses; its negation yields one. *)
val countTrue = Count {positive = zeroLogReal, negative = oneLogReal};

(* False yields one clause; its negation yields none. *)
val countFalse = Count {positive = oneLogReal, negative = zeroLogReal};

(* A literal and its negation each yield a single clause. *)
val countLiteral = Count {positive = oneLogReal, negative = oneLogReal};

(* Conjunction: the clause counts add.  The negation is a disjunction, whose
   counts multiply. *)
fun countAnd2 (Count {positive = p1, negative = n1},
               Count {positive = p2, negative = n2}) =
    Count {positive = addLogReal p1 p2, negative = multiplyLogReal n1 n2};

(* Disjunction: the dual of countAnd2. *)
fun countOr2 (Count {positive = p1, negative = n1},
              Count {positive = p2, negative = n2}) =
    Count {positive = multiplyLogReal p1 p2, negative = addLogReal n1 n2};

(* Whether countXor2 is associative or not is an open question. *)

fun countXor2 (Count {positive = p1, negative = n1},
               Count {positive = p2, negative = n2}) =
    let
      fun times a b = multiplyLogReal a b
    in
      Count {positive = addLogReal (times p1 p2) (times n1 n2),
             negative = addLogReal (times p1 n2) (times n1 p2)}
    end;

(* A definition's clauses are counted as the Xor of one literal with the
   body. *)
fun countDefinition bodyCount = countXor2 (countLiteral, bodyCount);

(* Render a count as "(+P,-N)", where P and N are the stored logarithms of
   the positive and negative components. *)
fun countToString (Count {positive = p, negative = n}) =
    "(+" ^ toStringLogReal p ^ ",-" ^ toStringLogReal n ^ ")";

(* Pretty-printer for counts, via countToString. *)
val ppCount = Print.ppMap countToString Print.ppString;

(* ------------------------------------------------------------------------- *)
(* A type of normalized formula.                                             *)
(* ------------------------------------------------------------------------- *)

(* Normalized formulas.  Every constructor other than True/False caches, as
   its first field, the set of free variables (see freeVars).
   And/Or/Xor/Exists/Forall also cache a clause count.  Xor carries a
   polarity flag and a set of operands.  Exists/Forall bind a set of
   variables rather than a single one. *)
datatype formula =
    True
  | False
  | Literal of NameSet.set * Literal.literal
  | And of NameSet.set * count * formula Set.set
  | Or of NameSet.set * count * formula Set.set
  | Xor of NameSet.set * count * bool * formula Set.set
  | Exists of NameSet.set * count * NameSet.set * formula
  | Forall of NameSet.set * count * NameSet.set * formula;

(* Total order on formulas.  Formulas are ordered first by constructor
   (True < False < Literal < And < Or < Xor < Exists < Forall), then by
   contents.  The cached free-variable sets and counts are ignored.
   Pointer-equal arguments short-circuit to EQUAL. *)
fun compare f1_f2 =
    if Portable.pointerEqual f1_f2 then EQUAL
    else
      case f1_f2 of
        (True,True) => EQUAL
      | (True,_) => LESS
      | (_,True) => GREATER
      | (False,False) => EQUAL
      | (False,_) => LESS
      | (_,False) => GREATER
      | (Literal (_,l1), Literal (_,l2)) => Literal.compare (l1,l2)
      | (Literal _, _) => LESS
      | (_, Literal _) => GREATER
      | (And (_,_,s1), And (_,_,s2)) => Set.compare (s1,s2)
      | (And _, _) => LESS
      | (_, And _) => GREATER
      | (Or (_,_,s1), Or (_,_,s2)) => Set.compare (s1,s2)
      | (Or _, _) => LESS
      | (_, Or _) => GREATER
      | (Xor (_,_,p1,s1), Xor (_,_,p2,s2)) =>
        (* polarity first, then the operand sets *)
        (case boolCompare (p1,p2) of
           LESS => LESS
         | EQUAL => Set.compare (s1,s2)
         | GREATER => GREATER)
      | (Xor _, _) => LESS
      | (_, Xor _) => GREATER
      | (Exists (_,_,n1,f1), Exists (_,_,n2,f2)) =>
        (* bound-variable sets first, then the bodies *)
        (case NameSet.compare (n1,n2) of
           LESS => LESS
         | EQUAL => compare (f1,f2)
         | GREATER => GREATER)
      | (Exists _, _) => LESS
      | (_, Exists _) => GREATER
      | (Forall (_,_,n1,f1), Forall (_,_,n2,f2)) =>
        (case NameSet.compare (n1,n2) of
           LESS => LESS
         | EQUAL => compare (f1,f2)
         | GREATER => GREATER);

(* The empty set of formulas, ordered by compare. *)
val empty : formula Set.set = Set.empty compare;

(* The one-element set {fm}, ordered by compare. *)
fun singleton fm = Set.singleton compare fm;

local
  (* Negate by dualizing the top connective:
     - And/Or swap and Exists/Forall swap, negating the count and recursing;
     - a Xor just flips its polarity flag;
     - a literal is negated in place. *)
  fun neg fm =
      case fm of
        True => False
      | False => True
      | Literal (fv,lit) => Literal (fv, Literal.negate lit)
      | And (fv,c,s) => Or (fv, countNegate c, negSet s)
      | Or (fv,c,s) => And (fv, countNegate c, negSet s)
      | Xor (fv,c,p,s) => Xor (fv, c, not p, s)
      | Exists (fv,c,n,f) => Forall (fv, countNegate c, n, neg f)
      | Forall (fv,c,n,f) => Exists (fv, countNegate c, n, neg f)

  (* Negate every element of a formula set. *)
  and negSet s = Set.foldl (fn (f,acc) => Set.add acc (neg f)) empty s;
in
  val negate = neg;

  val negateSet = negSet;
end;

(* Whether the negation of x belongs to s. *)
fun negateMember x s = Set.member (negate x) s;

local
  fun member s x = negateMember x s;
in
  (* True when no element of one set is the negation of an element of the
     other.  The smaller set is the one that gets iterated over. *)
  fun negateDisjoint s1 s2 =
      let
        val (small,large) =
            if Set.size s1 < Set.size s2 then (s1,s2) else (s2,s1)
      in
        not (Set.exists (member large) small)
      end;
end;

(* Which of fm and its negation is the canonical representative: true for
   True, And, Exists and negative literals; false for their duals.  For a
   Xor it is the stored flag.  The commented-out debug check below relates
   this to compare (f, negate f) = LESS. *)
fun polarity fm =
    case fm of
      True => true
    | False => false
    | Literal (_,(pol,_)) => not pol
    | And _ => true
    | Or _ => false
    | Xor (_,_,pol,_) => pol
    | Exists _ => true
    | Forall _ => false;

(*MetisDebug
val polarity = fn f =>
    let
      val res1 = compare (f, negate f) = LESS
      val res2 = polarity f
      val _ = res1 = res2 orelse raise Bug "polarity"
    in
      res2
    end;
*)

(* Keep fm when the polarity is true, negate it otherwise. *)
fun applyPolarity pol fm = if pol then fm else negate fm;

(* The cached free-variable set of a formula (empty for True/False). *)
fun freeVars fm =
    case fm of
      True => NameSet.empty
    | False => NameSet.empty
    | Literal (fv,_) => fv
    | And (fv,_,_) => fv
    | Or (fv,_,_) => fv
    | Xor (fv,_,_,_) => fv
    | Exists (fv,_,_,_) => fv
    | Forall (fv,_,_,_) => fv;

(* Whether variable v occurs free in fm. *)
fun freeIn v fm = NameSet.member v (freeVars fm);

(* Union of the free variables of every formula in a set. *)
val freeVarsSet =
    Set.foldl (fn (fm,acc) => NameSet.union (freeVars fm) acc) NameSet.empty;

(* Clause count of a formula.  For a Xor the cached count belongs to
   polarity true and is negated otherwise. *)
fun count fm =
    case fm of
      True => countTrue
    | False => countFalse
    | Literal _ => countLiteral
    | And (_,c,_) => c
    | Or (_,c,_) => c
    | Xor (_,c,p,_) => if p then c else countNegate c
    | Exists (_,c,_,_) => c
    | Forall (_,c,_,_) => c;

(* Counts of the conjunction / disjunction / Xor of a formula set. *)
val countAndSet = Set.foldl (fn (fm,c) => countAnd2 (count fm, c)) countTrue;

val countOrSet = Set.foldl (fn (fm,c) => countOr2 (count fm, c)) countFalse;

val countXorSet = Set.foldl (fn (fm,c) => countXor2 (count fm, c)) countFalse;

(* Smart constructor for binary conjunction.  It:
   - simplifies with the constants True/False;
   - flattens nested Ands into a single set of conjuncts;
   - returns False when a conjunct of one side is the negation of a
     conjunct of the other;
   - collapses results of size 0 or 1. *)
fun And2 (False,_) = False
  | And2 (_,False) = False
  | And2 (True,f2) = f2
  | And2 (f1,True) = f1
  | And2 (f1,f2) =
    let
      (* view each side as (free vars, count, conjunct set) *)
      val (fv1,c1,s1) =
          case f1 of
            And fv_c_s => fv_c_s
          | _ => (freeVars f1, count f1, singleton f1)

      and (fv2,c2,s2) =
          case f2 of
            And fv_c_s => fv_c_s
          | _ => (freeVars f2, count f2, singleton f2)
    in
      if not (negateDisjoint s1 s2) then False
      else
        let
          val s = Set.union s1 s2
        in
          case Set.size s of
            0 => True
          | 1 => Set.pick s
          | n =>
            (* no shared conjuncts: the cached data can simply be combined;
               otherwise recompute it from the union *)
            if n = Set.size s1 + Set.size s2 then
              And (NameSet.union fv1 fv2, countAnd2 (c1,c2), s)
            else
              And (freeVarsSet s, countAndSet s, s)
        end
    end;

(* Conjunction of a list / set of formulas. *)
val AndList = List.foldl And2 True;

val AndSet = Set.foldl And2 True;

(* Smart constructor for binary disjunction; the dual of And2.  It:
   - simplifies with the constants True/False;
   - flattens nested Ors into a single set of disjuncts;
   - returns True when a disjunct of one side is the negation of a
     disjunct of the other;
   - collapses results of size 0 or 1. *)
fun Or2 (True,_) = True
  | Or2 (_,True) = True
  | Or2 (False,f2) = f2
  | Or2 (f1,False) = f1
  | Or2 (f1,f2) =
    let
      (* view each side as (free vars, count, disjunct set) *)
      val (fv1,c1,s1) =
          case f1 of
            Or fv_c_s => fv_c_s
          | _ => (freeVars f1, count f1, singleton f1)

      and (fv2,c2,s2) =
          case f2 of
            Or fv_c_s => fv_c_s
          | _ => (freeVars f2, count f2, singleton f2)
    in
      if not (negateDisjoint s1 s2) then True
      else
        let
          val s = Set.union s1 s2
        in
          case Set.size s of
            0 => False
          | 1 => Set.pick s
          | n =>
            (* no shared disjuncts: the cached data can simply be combined;
               otherwise recompute it from the union *)
            if n = Set.size s1 + Set.size s2 then
              Or (NameSet.union fv1 fv2, countOr2 (c1,c2), s)
            else
              Or (freeVarsSet s, countOrSet s, s)
        end
    end;

(* Disjunction of a list / set of formulas. *)
val OrList = List.foldl Or2 False;

val OrSet = Set.foldl Or2 False;

(* Distribute a disjunction over conjunctions.  Each side is viewed as a set
   of conjuncts, and the result is the conjunction of the pairwise
   disjunctions. *)
fun pushOr2 (f1,f2) =
    let
      fun conjuncts (And (_,_,s)) = s
        | conjuncts fm = singleton fm

      val s2 = conjuncts f2

      fun distrib (x1,acc) =
          Set.foldl (fn (x2,acc') => And2 (Or2 (x1,x2), acc')) acc s2
    in
      Set.foldl distrib True (conjuncts f1)
    end;

val pushOrList = List.foldl pushOr2 False;

local
  (* View a non-Xor formula in Xor shape: split off its polarity so that the
     stored operand is the canonical one (applyPolarity). *)
  fun normalize fm =
      let
        val p = polarity fm
        val fm = applyPolarity p fm
      in
        (freeVars fm, count fm, p, singleton fm)
      end;
in
  (* Smart constructor for binary exclusive or.  False is the identity and
     True negates the other side.  Operands present on both sides cancel
     (symmetric difference).  The result's polarity is true exactly when
     both sides have the same polarity. *)
  fun Xor2 (False,f2) = f2
    | Xor2 (f1,False) = f1
    | Xor2 (True,f2) = negate f2
    | Xor2 (f1,True) = negate f1
    | Xor2 (f1,f2) =
      let
        val (fv1,c1,p1,s1) = case f1 of Xor x => x | _ => normalize f1
        and (fv2,c2,p2,s2) = case f2 of Xor x => x | _ => normalize f2

        val s = Set.symmetricDifference s1 s2

        val fm =
            case Set.size s of
              0 => False
            | 1 => Set.pick s
            | n =>
              (* nothing cancelled: combine the cached data directly *)
              if n = Set.size s1 + Set.size s2 then
                Xor (NameSet.union fv1 fv2, countXor2 (c1,c2), true, s)
              else
                Xor (freeVarsSet s, countXorSet s, true, s)

        val p = p1 = p2
      in
        applyPolarity p fm
      end;
end;

(* Xor of a list / set of formulas, folding Xor2 starting from False. *)
val XorList = List.foldl Xor2 False;

val XorSet = Set.foldl Xor2 False;

(* Xor of the formulas, negated when the polarity p is false. *)
fun XorPolarityList (p,fms) =
    let
      val x = XorList fms
    in
      if p then x else negate x
    end;

fun XorPolaritySet (p,s) =
    let
      val x = XorSet s
    in
      if p then x else negate x
    end;

(* Split a Xor node into two formulas: one operand picked from the set, and
   the Xor of the remaining operands carrying the node's polarity.  Raises
   Error on any other formula. *)
fun destXor fm =
    case fm of
      Xor (_,_,p,s) =>
      let
        val (first,rest) = Set.deletePick s
        val second =
            if Set.size rest <> 1 then
              Xor (freeVarsSet rest, countXorSet rest, p, rest)
            else
              applyPolarity p (Set.pick rest)
      in
        (first,second)
      end
    | _ => raise Error "destXor";

(* Expand one level of a Xor into the conjunction
   (f1 \/ f2) /\ (~f1 \/ ~f2). *)
fun pushXor fm =
    let
      val (f1,f2) = destXor fm
    in
      And2 (Or2 (f1,f2), Or2 (negate f1, negate f2))
    end;

(* Existentially quantify v over init_fm, pushing the quantifier inward
   where possible:
   - the quantifier is dropped if v is not free;
   - it distributes over Or;
   - in an And with exactly one conjunct mentioning v, the other conjuncts
     stay outside it;
   - an inner Exists absorbs v into its set of bound variables. *)
fun Exists1 (v,init_fm) =
    let
      (* fall back to an explicit Exists node over fm *)
      fun exists_gen fm =
          let
            val fv = NameSet.delete (freeVars fm) v
            val c = count fm
            val n = NameSet.singleton v
          in
            Exists (fv,c,n,fm)
          end

      fun exists fm = if freeIn v fm then exists_free fm else fm

      (* only called on formulas in which v is free *)
      and exists_free (Or (_,_,s)) = OrList (Set.transform exists s)
        | exists_free (fm as And (_,_,s)) =
          let
            val sv = Set.filter (freeIn v) s
          in
            if Set.size sv <> 1 then exists_gen fm
            else
              let
                val fm = Set.pick sv
                val s = Set.delete s fm
              in
                And2 (exists_free fm, AndSet s)
              end
          end
        | exists_free (Exists (fv,c,n,f)) =
          Exists (NameSet.delete fv v, c, NameSet.add n v, f)
        | exists_free fm = exists_gen fm
    in
      exists init_fm
    end;

(* Quantify a list / set of variables, innermost first in fold order. *)
fun ExistsList (vs,f) = List.foldl Exists1 f vs;

fun ExistsSet (n,f) = NameSet.foldl Exists1 f n;

(* Universally quantify v over init_fm; the dual of Exists1:
   - the quantifier is dropped if v is not free;
   - it distributes over And;
   - in an Or with exactly one disjunct mentioning v, the other disjuncts
     stay outside it;
   - an inner Forall absorbs v into its set of bound variables. *)
fun Forall1 (v,init_fm) =
    let
      (* fall back to an explicit Forall node over fm *)
      fun forall_gen fm =
          let
            val fv = NameSet.delete (freeVars fm) v
            val c = count fm
            val n = NameSet.singleton v
          in
            Forall (fv,c,n,fm)
          end

      fun forall fm = if freeIn v fm then forall_free fm else fm

      (* only called on formulas in which v is free *)
      and forall_free (And (_,_,s)) = AndList (Set.transform forall s)
        | forall_free (fm as Or (_,_,s)) =
          let
            val sv = Set.filter (freeIn v) s
          in
            if Set.size sv <> 1 then forall_gen fm
            else
              let
                val fm = Set.pick sv
                val s = Set.delete s fm
              in
                Or2 (forall_free fm, OrSet s)
              end
          end
        | forall_free (Forall (fv,c,n,f)) =
          Forall (NameSet.delete fv v, c, NameSet.add n v, f)
        | forall_free fm = forall_gen fm
    in
      forall init_fm
    end;

(* Quantify a list / set of variables, in fold order. *)
fun ForallList (vs,f) = List.foldl Forall1 f vs;

fun ForallSet (n,f) = NameSet.foldl Forall1 f n;

(* Universal closure of f over all of its free variables. *)
fun generalize f = ForallSet (freeVars f, f);

local
  (* Union of the free variables of the replacement terms for the given
     domain variables; fvSub maps each domain variable to that set. *)
  fun subst_fv fvSub =
      let
        fun add_fv (v,s) = NameSet.union (NameMap.get fvSub v) s
      in
        NameSet.foldl add_fv NameSet.empty
      end;

  (* Rename a captured bound variable v to a variant v' chosen to avoid
     avoid, extending the substitution, its domain and fvSub so that v is
     mapped to v' inside the quantifier body. *)
  fun subst_rename (v,(avoid,bv,sub,domain,fvSub)) =
      let
        val v' = Term.variantPrime avoid v
        val avoid = NameSet.add avoid v'
        val bv = NameSet.add bv v'
        val sub = Subst.insert sub (v, Term.Var v')
        val domain = NameSet.add domain v
        val fvSub = NameMap.insert fvSub (v, NameSet.singleton v')
      in
        (avoid,bv,sub,domain,fvSub)
      end;

  (* Skip subformulas in which no domain variable is free.  Since True and
     False have no free variables, subst_domain never sees them. *)
  fun subst_check sub domain fvSub fm =
      let
        val domain = NameSet.intersect domain (freeVars fm)
      in
        if NameSet.null domain then fm
        else subst_domain sub domain fvSub fm
      end

  (* Apply the substitution, rebuilding compound formulas through the smart
     constructors so that they are re-normalized. *)
  and subst_domain sub domain fvSub fm =
      case fm of
        Literal (fv,lit) =>
        let
          val fv = NameSet.difference fv domain
          val fv = NameSet.union fv (subst_fv fvSub domain)
          val lit = Literal.subst sub lit
        in
          Literal (fv,lit)
        end
      | And (_,_,s) =>
        AndList (Set.transform (subst_check sub domain fvSub) s)
      | Or (_,_,s) =>
        OrList (Set.transform (subst_check sub domain fvSub) s)
      | Xor (_,_,p,s) =>
        XorPolarityList (p, Set.transform (subst_check sub domain fvSub) s)
      | Exists fv_c_n_f => subst_quant Exists sub domain fvSub fv_c_n_f
      | Forall fv_c_n_f => subst_quant Forall sub domain fvSub fv_c_n_f
      | _ => raise Bug "subst_domain"

  (* Substitute under a quantifier.  Bound variables that also occur free in
     the replacement terms are renamed first, so they do not capture them. *)
  and subst_quant quant sub domain fvSub (fv,c,bv,fm) =
      let
        val sub_fv = subst_fv fvSub domain
        val fv = NameSet.union sub_fv (NameSet.difference fv domain)
        val captured = NameSet.intersect bv sub_fv
        val bv = NameSet.difference bv captured
        val avoid = NameSet.union fv bv
        val (_,bv,sub,domain,fvSub) =
            NameSet.foldl subst_rename (avoid,bv,sub,domain,fvSub) captured
        val fm = subst_domain sub domain fvSub fm
      in
        quant (fv,c,bv,fm)
      end;
in
  (* Capture-avoiding application of sub to a normalized formula.  The
     domain of sub and the free variables of each replacement term are
     precomputed once. *)
  fun subst sub =
      let
        fun mk_dom (v,tm,(d,fv)) =
            (NameSet.add d v, NameMap.insert fv (v, Term.freeVars tm))

        val domain_fvSub = (NameSet.empty, NameMap.new ())
        val (domain,fvSub) = Subst.foldl mk_dom domain_fvSub sub
      in
        subst_check sub domain fvSub
      end;
end;

(* Convert a Formula.formula into a normalized formula via the smart
   constructors.  negateFromFormula converts the negation of its argument,
   so negations are pushed down to the atoms.  Imp is rewritten with Or/And
   and Iff becomes Xor. *)
fun fromFormula fm =
    case fm of
      Formula.True => True
    | Formula.False => False
    | Formula.Atom atm => Literal (Atom.freeVars atm, (true,atm))
    | Formula.Not p => negateFromFormula p
    | Formula.And (p,q) => And2 (fromFormula p, fromFormula q)
    | Formula.Or (p,q) => Or2 (fromFormula p, fromFormula q)
    | Formula.Imp (p,q) => Or2 (negateFromFormula p, fromFormula q)
    (* p <=> q is ~p xor q *)
    | Formula.Iff (p,q) => Xor2 (negateFromFormula p, fromFormula q)
    | Formula.Forall (v,p) => Forall1 (v, fromFormula p)
    | Formula.Exists (v,p) => Exists1 (v, fromFormula p)

and negateFromFormula fm =
    case fm of
      Formula.True => False
    | Formula.False => True
    | Formula.Atom atm => Literal (Atom.freeVars atm, (false,atm))
    | Formula.Not p => fromFormula p
    | Formula.And (p,q) => Or2 (negateFromFormula p, negateFromFormula q)
    | Formula.Or (p,q) => And2 (negateFromFormula p, negateFromFormula q)
    | Formula.Imp (p,q) => And2 (fromFormula p, negateFromFormula q)
    (* ~(p <=> q) is p xor q *)
    | Formula.Iff (p,q) => Xor2 (fromFormula p, fromFormula q)
    | Formula.Forall (v,p) => Exists1 (v, negateFromFormula p)
    | Formula.Exists (v,p) => Forall1 (v, negateFromFormula p);

local
  (* The greatest element of a non-empty formula set. *)
  fun lastElt (s : formula Set.set) =
      case Set.findr (K true) s of
        NONE => raise Bug "lastElt: empty set"
      | SOME fm => fm;

  (* Replace the greatest element of s by its negation. *)
  fun negateLastElt s =
      let
        val fm = lastElt s
      in
        Set.add (Set.delete s fm) (negate fm)
      end;

  (* Convert back to a Formula.formula.  A Xor node becomes a chain of
     equivalences (listMkEquiv); when its polarity flag is true, one operand
     (the greatest) is negated first.  NOTE(review): this relies on the
     Xor/polarity encoding produced by Xor2; confirm before changing. *)
  fun form fm =
      case fm of
        True => Formula.True
      | False => Formula.False
      | Literal (_,lit) => Literal.toFormula lit
      | And (_,_,s) => Formula.listMkConj (Set.transform form s)
      | Or (_,_,s) => Formula.listMkDisj (Set.transform form s)
      | Xor (_,_,p,s) =>
        let
          val s = if p then negateLastElt s else s
        in
          Formula.listMkEquiv (Set.transform form s)
        end
      | Exists (_,_,n,f) => Formula.listMkExists (NameSet.toList n, form f)
      | Forall (_,_,n,f) => Formula.listMkForall (NameSet.toList n, form f);
in
  val toFormula = form;
end;

(* The literal of a Literal node; raises Error for anything else. *)
fun toLiteral fm =
    case fm of
      Literal (_,lit) => lit
    | _ => raise Error "Normalize.toLiteral";

(* Read a clause off a formula that is False, a single literal, or a
   disjunction of literals.  Otherwise toLiteral raises Error. *)
fun toClause fm =
    case fm of
      False => LiteralSet.empty
    | Or (_,_,s) =>
      Set.foldl (fn (l,acc) => LiteralSet.add acc (toLiteral l)) LiteralSet.empty s
    | _ => LiteralSet.singleton (toLiteral fm);

(* Printing goes through the conversion back to Formula.formula. *)
val pp = Print.ppMap toFormula Formula.pp;

val toString = Print.toString pp;

(* ------------------------------------------------------------------------- *)
(* Negation normal form.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Negation normal form: round-trip through the normalized representation,
   whose construction pushes negations down to the atoms. *)
fun nnf fm =
    let
      val normalized = fromFormula fm
    in
      toFormula normalized
    end;

(* ------------------------------------------------------------------------- *)
(* Basic conjunctive normal form.                                            *)
(* ------------------------------------------------------------------------- *)

(* Generate a Skolem function name from a variable name n:
   skolemPrefix ^ "_" ^ n, plus the suffix "_k" on the k-th reuse of the
   same n (k >= 1).  The per-name use counts live in an unsynchronized
   reference cell shared by all calls. *)
val newSkolemFunction =
    let
      val uses : int StringMap.map Unsynchronized.ref = Unsynchronized.ref (StringMap.new ())

      fun suffix 0 = ""
        | suffix k = "_" ^ Int.toString k
    in
      fn n =>
         let
           val Unsynchronized.ref table = uses
           val base = Name.toString n
           val k = Option.getOpt (StringMap.peek table base, 0)
           val () = uses := StringMap.insert table (base, k + 1)
         in
           Name.fromString (skolemPrefix ^ "_" ^ base ^ suffix k)
         end
    end;

(* Replace each variable in bv by a fresh Skolem function applied to the
   variables in fv. *)
fun skolemize fv bv fm =
    let
      val args = NameSet.transform Term.Var fv

      fun addSkolem (v,sub) =
          Subst.insert sub (v, Term.Fn (newSkolemFunction v, args))

      val skolemSub = NameSet.foldl addSkolem Subst.empty bv
    in
      subst skolemSub fm
    end;

local
  (* If some variables in bv are also in avoid, rename those captured ones
     in fm to variants that avoid avoid, fv and bv; otherwise return fm
     unchanged. *)
  fun rename avoid fv bv fm =
      let
        val captured = NameSet.intersect avoid bv
      in
        if NameSet.null captured then fm
        else
          let
            fun ren (v,(a,s)) =
                let
                  val v' = Term.variantPrime a v
                in
                  (NameSet.add a v', Subst.insert s (v, Term.Var v'))
                end

            val avoid = NameSet.union (NameSet.union avoid fv) bv

            val (_,sub) = NameSet.foldl ren (avoid,Subst.empty) captured
          in
            subst sub fm
          end
      end;

  (* Basic CNF conversion with no introduced definitions.  avoid holds the
     variable names already in use:
     - existentials are Skolemized;
     - Xor is expanded with pushXor;
     - universal quantifiers are dropped after renaming their variables
       apart from avoid, so those variables stay free in the result;
     - within a disjunction, avoid grows with the free variables of each
       converted disjunct, and the disjuncts are then distributed by
       pushOrList. *)
  fun cnfFm avoid fm =
(*MetisTrace5
      let
        val fm' = cnfFm' avoid fm
        val () = Print.trace pp "Normalize.cnfFm: fm" fm
        val () = Print.trace pp "Normalize.cnfFm: fm'" fm'
      in
        fm'
      end
  and cnfFm' avoid fm =
*)
      case fm of
        True => True
      | False => False
      | Literal _ => fm
      | And (_,_,s) => AndList (Set.transform (cnfFm avoid) s)
      | Or (fv,_,s) =>
        let
          val avoid = NameSet.union avoid fv
          val (fms,_) = Set.foldl cnfOr ([],avoid) s
        in
          pushOrList fms
        end
      | Xor _ => cnfFm avoid (pushXor fm)
      | Exists (fv,_,n,f) => cnfFm avoid (skolemize fv n f)
      | Forall (fv,_,n,f) => cnfFm avoid (rename avoid fv n f)

  (* Fold step for disjunctions: convert one disjunct, then add its free
     variables to avoid for the disjuncts that follow. *)
  and cnfOr (fm,(fms,avoid)) =
      let
        val fm = cnfFm avoid fm
        val fms = fm :: fms
        val avoid = NameSet.union avoid (freeVars fm)
      in
        (fms,avoid)
      end;
in
  val basicCnf = cnfFm NameSet.empty;
end;

(* ------------------------------------------------------------------------- *)
(* Finding the formula definition that minimizes the number of clauses.      *)
(* ------------------------------------------------------------------------- *)

(* Search for the subformula whose replacement by a fresh definition gives the
   smallest clause count.  A value of type best pairs the best count found so
   far with the formula to define (NONE = define nothing).  The count
   combinators used here (count, count2, countClauses, countLeqish, ...) are
   defined earlier in this file. *)
local
  type best = count * formula option;

  (* Walk fm looking for better definitions.  countClauses turns a count for
     the current subformula position into the count of the whole formula (it
     is I at the top level, see minimumDefinition).  Quantifiers are looked
     through; constants and literals offer nothing to define. *)
  fun minBreak countClauses fm best =
      case fm of
        True => best
      | False => best
      | Literal _ => best
      | And (_,_,s) =>
        minBreakSet countClauses countAnd2 countTrue AndSet s best
      | Or (_,_,s) =>
        minBreakSet countClauses countOr2 countFalse OrSet s best
      | Xor (_,_,_,s) =>
        minBreakSet countClauses countXor2 countFalse XorSet s best
      | Exists (_,_,_,f) => minBreak countClauses f best
      | Forall (_,_,_,f) => minBreak countClauses f best

  (* Handle an n-ary connective over fmSet.  count2 and count0 are the
     connective's binary count combinator and its unit (e.g. countAnd2 and
     countTrue for And); mkSet rebuilds a connective formula from a subset of
     the arguments. *)
  and minBreakSet countClauses count2 count0 mkSet fmSet best =
      let
        (* For each element fm of fms, in order, produce (c1,s1,fm,c2,s2):
           c1/s1 are the combined count and set of the elements before fm,
           c2/s2 those of the elements after fm. *)
        fun cumulatives fms =
            let
              fun fwd (fm,(c1,s1,l)) =
                  let
                    val c1' = count2 (count fm, c1)
                    and s1' = Set.add s1 fm
                  in
                    (c1', s1', (c1,s1,fm) :: l)
                  end

              fun bwd ((c1,s1,fm),(c2,s2,l)) =
                  let
                    val c2' = count2 (count fm, c2)
                    and s2' = Set.add s2 fm
                  in
                    (c2', s2', (c1,s1,fm,c2,s2) :: l)
                  end

              (* fwd leaves the list reversed; bwd walks it back, so the
                 final list is in the original order again. *)
              val (c1,_,fms) = foldl fwd (count0,empty,[]) fms
              val (c2,_,fms) = foldl bwd (count0,empty,[]) fms

(*MetisDebug
              val _ = countEquivish c1 c2 orelse
                      raise Bug ("cumulativeCounts: c1 = " ^ countToString c1 ^
                                 ", c2 = " ^ countToString c2)
*)
            in
              fms
            end

        (* Recurse into one element, folding the counts of all the other
           elements into its context. *)
        fun breakSing ((c1,_,fm,c2,_),best) =
            let
              val cFms = count2 (c1,c2)

              fun countCls cFm = countClauses (count2 (cFms,cFm))
            in
              minBreak countCls fm best
            end

        (* Consider defining the elements before fm together with fm, and
           symmetrically fm together with the elements after it.  A candidate
           whose other part is empty is skipped, and a candidate replaces
           best only if best's count is not countLeqish to the candidate's. *)
        val breakSet1 =
            let
              fun break c1 s1 fm c2 (best as (bcl,_)) =
                  if Set.null s1 then best
                  else
                    let
                      val cDef = countDefinition (countXor2 (c1, count fm))
                      val cFm = count2 (countLiteral,c2)
                      val cl = countAnd2 (cDef, countClauses cFm)
                      val noBetter = countLeqish bcl cl
                    in
                      if noBetter then best
                      else (cl, SOME (mkSet (Set.add s1 fm)))
                    end
            in
              fn ((c1,s1,fm,c2,s2),best) =>
                 break c1 s1 fm c2 (break c2 s2 fm c1 best)
            end

        val fms = Set.toList fmSet

        (* Sort the elements by measure of their counts and try every
           prefix/suffix split of that ordering. *)
        fun breakSet measure best =
            let
              val fms = sortMap (measure o count) countCompare fms
            in
              foldl breakSet1 best (cumulatives fms)
            end

        (* First recurse into each element, then try splits under three
           different orderings. *)
        val best = foldl breakSing best (cumulatives fms)
        val best = breakSet I best
        val best = breakSet countNegate best
        val best = breakSet countClauses best
      in
        best
      end
in
  (* NONE if fm already counts as no more than a literal, or if no
     definition beats the count of fm itself; otherwise SOME of the
     subformula that should be given a definition. *)
  fun minimumDefinition fm =
      let
        val cl = count fm
      in
        if countLeqish cl countLiteral then NONE
        else
          let
            val (cl',def) = minBreak I fm (cl,NONE)
(*MetisTrace1
            val () =
                case def of
                  NONE => ()
                | SOME d =>
                  Print.trace pp ("defCNF: before = " ^ countToString cl ^
                                  ", after = " ^ countToString cl' ^
                                  ", definition") d
*)
          in
            def
          end
      end;
end;

(* ------------------------------------------------------------------------- *)
(* Conjunctive normal form derivations.                                      *)
(* ------------------------------------------------------------------------- *)

(* A step of a CNF derivation: a normalized formula together with the
   inference that produced it.  parentsInference lists the parent theorems
   of each kind of inference. *)
datatype thm = Thm of formula * inference

and inference =
    Axiom of Formula.formula              (* input formula, see mkAxiom *)
  | Definition of string * Formula.formula
                          (* fresh relation name and defining formula *)
  | Simplify of thm * thm list  (* theorem and the definitions used on it *)
  | Conjunct of thm             (* a conjunct of the parent *)
  | Specialize of thm           (* parent with a universal quantifier dropped *)
  | Skolemize of thm            (* parent with an existential skolemized *)
  | Clausify of thm;            (* parent converted by basicCnf *)

(* The theorems an inference step depends on; axioms and definitions have
   none. *)
fun parentsInference (Simplify (th,ths)) = th :: ths
  | parentsInference (Conjunct th) = [th]
  | parentsInference (Specialize th) = [th]
  | parentsInference (Skolemize th) = [th]
  | parentsInference (Clausify th) = [th]
  | parentsInference (Axiom _) = []
  | parentsInference (Definition _) = [];

(* Theorems are ordered by their formulas alone; inferences are ignored. *)
fun compareThm (th1,th2) =
    let
      val Thm (fm1,_) = th1
      val Thm (fm2,_) = th2
    in
      compare (fm1,fm2)
    end;

(* Parent theorems of a theorem's inference. *)
fun parentsThm th =
    case th of
      Thm (_,inf) => parentsInference inf;

(* Turn an input formula into an axiom theorem over its normalized form. *)
fun mkAxiom fm =
    let
      val normalized = fromFormula fm
    in
      Thm (normalized, Axiom fm)
    end;

(* Split a theorem into its formula (converted back with toFormula) and its
   inference. *)
fun destThm th =
    case th of
      Thm (fm,inf) => (toFormula fm, inf);

(* Linearize a derivation: proveThms returns, for every theorem reachable
   from the given list, a (formula, inference, parent formulas) triple, with
   each theorem listed after all of its parents and each theorem once. *)
local
  val emptyProved : (thm,Formula.formula) Map.map = Map.new compareThm;

  (* A theorem is unproved while it has no entry in the proved map. *)
  fun isUnproved proved th = not (Map.inDomain th proved);

  fun lookupProved proved th =
      case Map.peek proved th of
        NONE => raise Bug "Normalize.lookupProved"
      | SOME fm => fm;

  (* Worklist traversal: a theorem whose parents are not all proved yet is
     kept on the list behind its unproved parents. *)
  fun prove acc _ [] = rev acc
    | prove acc proved (pending as th :: rest) =
      if not (isUnproved proved th) then prove acc proved rest
      else
        let
          val parents = parentsThm th
        in
          case List.filter (isUnproved proved) parents of
            [] =>
            let
              val (fm,inf) = destThm th
              val parentFms = map (lookupProved proved) parents
            in
              prove ((fm,inf,parentFms) :: acc)
                (Map.insert proved (th,fm)) rest
            end
          | deps => prove acc proved (deps @ pending)
        end;
in
  val proveThms = prove [] emptyProved;
end;

(* Name of the inference rule, used when printing derivations. *)
fun toStringInference (Axiom _) = "Axiom"
  | toStringInference (Definition _) = "Definition"
  | toStringInference (Simplify _) = "Simplify"
  | toStringInference (Conjunct _) = "Conjunct"
  | toStringInference (Specialize _) = "Specialize"
  | toStringInference (Skolemize _) = "Skolemize"
  | toStringInference (Clausify _) = "Clausify";

(* Print an inference as its rule name. *)
val ppInference : inference Print.pp =
    Print.ppMap toStringInference Print.ppString;

(* ------------------------------------------------------------------------- *)
(* Simplifying with definitions.                                             *)
(* ------------------------------------------------------------------------- *)

(* Rewriting data gathered from processed theorems (see simplifyAdd).
   formula maps a formula to an equivalent formula plus the theorem that
   justifies the rewrite.  andSet, orSet and xorSet hold (argument set,
   defined formula, theorem) triples used to rewrite a subset of the
   arguments of an And, Or or Xor respectively. *)
datatype simplify =
    Simp of
      {formula : (formula, formula * thm) Map.map,
       andSet : (formula Set.set * formula * thm) list,
       orSet : (formula Set.set * formula * thm) list,
       xorSet : (formula Set.set * formula * thm) list};

(* Simplification data with nothing recorded yet. *)
val simplifyEmpty =
    let
      val noDefs = []
    in
      Simp
        {formula = Map.new compare,
         andSet = noDefs,
         orSet = noDefs,
         xorSet = noDefs}
    end;

(* Recording theorems as rewrite rules for later simplification. *)
local
  (* True unless s is a singleton whose only element is True, False or a
     Literal.  The fm argument is not used. *)
  fun simpler fm s =
      Set.size s <> 1 orelse
      case Set.pick s of
        True => false
      | False => false
      | Literal _ => false
      | _ => true;

  (* Insert body_def = (body,def,th) into set_defs, a list kept in
     non-increasing order of body size: the new entry goes in front of the
     first entry with a smaller body.  If an entry with an equal body is
     already present, the list is returned unchanged and the existing entry
     wins. *)
  fun addSet set_defs body_def =
      let
        fun def_body_size (body,_,_) = Set.size body

        val body_size = def_body_size body_def

        val (body,_,_) = body_def

        fun add acc [] = List.revAppend (acc,[body_def])
          | add acc (l as (bd as (b,_,_)) :: bds) =
            case Int.compare (def_body_size bd, body_size) of
              LESS => List.revAppend (acc, body_def :: l)
            | EQUAL =>
              if Set.equal b body then List.revAppend (acc,l)
              else add (bd :: acc) bds
            | GREATER => add (bd :: acc) bds
      in
        add [] set_defs
      end;

  (* Record that body is equivalent to def, justified by th. *)
  (* A body equivalent to False is recorded as its negation being True. *)
  fun add simp (body,False,th) = add simp (negate body, True, th)
    | add simp (True,_,_) = simp
    (* And: record the argument set in andSet and, dually, the negated set
       with the negated definition in orSet. *)
    | add (Simp {formula,andSet,orSet,xorSet}) (And (_,_,s), def, th) =
      let
        val andSet = addSet andSet (s,def,th)
        and orSet = addSet orSet (negateSet s, negate def, th)
      in
        Simp
          {formula = formula,
           andSet = andSet,
           orSet = orSet,
           xorSet = xorSet}
      end
    (* Or: the mirror image of the And case. *)
    | add (Simp {formula,andSet,orSet,xorSet}) (Or (_,_,s), def, th) =
      let
        val orSet = addSet orSet (s,def,th)
        and andSet = addSet andSet (negateSet s, negate def, th)
      in
        Simp
          {formula = formula,
           andSet = andSet,
           orSet = orSet,
           xorSet = xorSet}
      end
    (* Xor: record the argument set, with the polarity folded into def.
       When def is True, additionally record, for each literal of the set,
       that the remaining elements define that literal (with the opposite
       polarity), provided simpler holds for the remaining set. *)
    | add simp (Xor (_,_,p,s), def, th) =
      let
        val simp = addXorSet simp (s, applyPolarity p def, th)
      in
        case def of
          True =>
          let
            fun addXorLiteral (fm as Literal _, simp) =
                let
                  val s = Set.delete s fm
                in
                  if not (simpler fm s) then simp
                  else addXorSet simp (s, applyPolarity (not p) fm, th)
                end
              | addXorLiteral (_,simp) = simp
          in
            Set.foldl addXorLiteral simp s
          end
        | _ => simp
      end
    (* Any other body: map it (and its negation) in the formula map, unless
       the body already has an entry. *)
    | add (simp as Simp {formula,andSet,orSet,xorSet}) (body,def,th) =
      if Map.inDomain body formula then simp
      else
        let
          val formula = Map.insert formula (body,(def,th))
          val formula = Map.insert formula (negate body, (negate def, th))
        in
          Simp
            {formula = formula,
             andSet = andSet,
             orSet = orSet,
             xorSet = xorSet}
        end

  (* Add an xor-set definition; a singleton set is handed back to add as an
     ordinary formula. *)
  and addXorSet (simp as Simp {formula,andSet,orSet,xorSet}) (s,def,th) =
      if Set.size s = 1 then add simp (Set.pick s, def, th)
      else
        let
          val xorSet = addSet xorSet (s,def,th)
        in
          Simp
            {formula = formula,
             andSet = andSet,
             orSet = orSet,
             xorSet = xorSet}
        end;
in
  (* Record theorem th as a rewrite: its formula is equivalent to True. *)
  fun simplifyAdd simp (th as Thm (fm,_)) = add simp (fm,True,th);
end;

(* Simplifying theorems with the recorded rewrite data. *)
local
  (* Find the first entry (s,f,th) of set_defs whose body s is a subset of
     set, and replace those elements of set by f.  Lists built by addSet are
     ordered by non-increasing body size, so larger bodies are tried
     first. *)
  fun simplifySet set_defs set =
      let
        fun pred (s,_,_) = Set.subset s set
      in
        case List.find pred set_defs of
          NONE => NONE
        | SOME (s,f,th) =>
          let
            val set = Set.add (Set.difference set s) f
          in
            SOME (set,th)
          end
      end;
in
  (* Rewrite a theorem with the recorded data.  The theorem is returned
     unchanged when nothing applies; otherwise the result is a Simplify
     theorem whose list holds the theorems used, most recent first. *)
  fun simplify (Simp {formula,andSet,orSet,xorSet}) =
      let
        (* Simplify subformulas first, then at the top level.  Every helper
           returns NONE for no change, or SOME (new formula, theorems used)
           where inf accumulates the theorems. *)
        fun simp fm inf =
            case simp_sub fm inf of
              NONE => simp_top fm inf
            | SOME (fm,inf) => try_simp_top fm inf

        (* Keep rewriting at the top level; fm has already changed, so the
           result is SOME even if no further rewrite applies. *)
        and try_simp_top fm inf =
            case simp_top fm inf of
              NONE => SOME (fm,inf)
            | x => x

        (* One top-level rewrite: for And/Or/Xor replace a subset of the
           arguments by a recorded definition, otherwise look fm up in the
           formula map. *)
        and simp_top fm inf =
            case fm of
              And (_,_,s) =>
              (case simplifySet andSet s of
                 NONE => NONE
               | SOME (s,th) =>
                 let
                   val fm = AndSet s
                   val inf = th :: inf
                 in
                   try_simp_top fm inf
                 end)
            | Or (_,_,s) =>
              (case simplifySet orSet s of
                 NONE => NONE
               | SOME (s,th) =>
                 let
                   val fm = OrSet s
                   val inf = th :: inf
                 in
                   try_simp_top fm inf
                 end)
            | Xor (_,_,p,s) =>
              (case simplifySet xorSet s of
                 NONE => NONE
               | SOME (s,th) =>
                 let
                   val fm = XorPolaritySet (p,s)
                   val inf = th :: inf
                 in
                   try_simp_top fm inf
                 end)
            | _ =>
              (case Map.peek formula fm of
                 NONE => NONE
               | SOME (fm,th) =>
                 let
                   val inf = th :: inf
                 in
                   try_simp_top fm inf
                 end)

        (* Simplify the immediate subformulas; NONE if none of them
           changed. *)
        and simp_sub fm inf =
            case fm of
              And (_,_,s) =>
              (case simp_set s inf of
                 NONE => NONE
               | SOME (l,inf) => SOME (AndList l, inf))
            | Or (_,_,s) =>
              (case simp_set s inf of
                 NONE => NONE
               | SOME (l,inf) => SOME (OrList l, inf))
            | Xor (_,_,p,s) =>
              (case simp_set s inf of
                 NONE => NONE
               | SOME (l,inf) => SOME (XorPolarityList (p,l), inf))
            | Exists (_,_,n,f) =>
              (case simp f inf of
                 NONE => NONE
               | SOME (f,inf) => SOME (ExistsSet (n,f), inf))
            | Forall (_,_,n,f) =>
              (case simp f inf of
                 NONE => NONE
               | SOME (f,inf) => SOME (ForallSet (n,f), inf))
            | _ => NONE

        (* Simplify every element of s; SOME only if at least one element
           changed. *)
        and simp_set s inf =
            let
              val (changed,l,inf) = Set.foldr simp_set_elt (false,[],inf) s
            in
              if changed then SOME (l,inf) else NONE
            end

        and simp_set_elt (fm,(changed,l,inf)) =
            case simp fm inf of
              NONE => (changed, fm :: l, inf)
            | SOME (fm,inf) => (true, fm :: l, inf)
      in
        fn th as Thm (fm,_) =>
           case simp fm [] of
             SOME (fm,ths) =>
             let
               val inf = Simplify (th,ths)
             in
               Thm (fm,inf)
             end
           | NONE => th
      end;
end;

(*MetisTrace2
val simplify = fn simp => fn th as Thm (fm,_) =>
    let
      val th' as Thm (fm',_) = simplify simp th
      val () = if compare (fm,fm') = EQUAL then ()
               else (Print.trace pp "Normalize.simplify: fm" fm;
                     Print.trace pp "Normalize.simplify: fm'" fm')
    in
      th'
    end;
*)

(* ------------------------------------------------------------------------- *)
(* Definitions.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Generator of fresh definition relation names: definitionPrefix followed by
   an underscore and a counter that starts at 0 and increases with every
   call.  The counter is an unsynchronized reference. *)
val newDefinitionRelation =
    let
      val next : int Unsynchronized.ref = Unsynchronized.ref 0

      fun fresh () =
          let
            val n = !next
            val () = next := n + 1
          in
            definitionPrefix ^ "_" ^ Int.toString n
          end
    in
      fresh
    end;

(* Introduce a fresh relation R over the free variables of def.  The
   Definition inference carries the universally closed formula R(vars) <=>
   def; the theorem's own formula is the xor of the negative literal on R
   with def. *)
fun newDefinition def =
    let
      val vars = freeVars def
      val name = newDefinitionRelation ()
      val atom = (Name.fromString name, NameSet.transform Term.Var vars)
      val axiom =
          Formula.setMkForall
            (vars, Formula.Iff (Formula.Atom atom, toFormula def))
    in
      Thm (Xor2 (Literal (vars,(false,atom)), def), Definition (name,axiom))
    end;

(* ------------------------------------------------------------------------- *)
(* Definitional conjunctive normal form.                                     *)
(* ------------------------------------------------------------------------- *)

(* State of the incremental definitional CNF conversion: either the
   simplification data gathered so far, or a marker that False has been
   derived. *)
datatype cnf =
    ConsistentCnf of simplify
  | InconsistentCnf;

(* Starting state: consistent, with nothing recorded. *)
val initialCnf = ConsistentCnf simplifyEmpty;

(* Definitional CNF: add one theorem at a time to a cnf state. *)
local
  (* Result once False has been derived: the single empty clause justified
     by th, and the state becomes InconsistentCnf.  Clauses already collected
     during the current addCnf call are not returned. *)
  fun def_cnf_inconsistent th =
      let
        val cls = [(LiteralSet.empty,th)]
      in
        (cls,InconsistentCnf)
      end;

  (* Convert the clausal formula fm into a clause and prepend it, with its
     justifying theorem, to acc. *)
  fun def_cnf_clause inf (fm,acc) =
      let
        val cl = toClause fm
        val th = Thm (fm,inf)
      in
        (cl,th) :: acc
      end
(*MetisDebug
      handle Error err =>
        (Print.trace pp "Normalize.addCnf.def_cnf_clause: fm" fm;
         raise Bug ("Normalize.addCnf.def_cnf_clause: " ^ err));
*)

  (* Process the worklist ths; each theorem is first simplified with the
     data recorded in simp. *)
  fun def_cnf cls simp ths =
      case ths of
        [] => (cls, ConsistentCnf simp)
      | th :: ths => def_cnf_formula cls simp (simplify simp th) ths

  (* Split conjunctions onto the worklist, skolemize existentials and drop
     universal quantifiers.  For any other formula, either introduce the
     definition suggested by minimumDefinition (re-queuing th behind it) or,
     if there is none, record th in simp and clausify it with basicCnf. *)
  and def_cnf_formula cls simp (th as Thm (fm,_)) ths =
      case fm of
        True => def_cnf cls simp ths
      | False => def_cnf_inconsistent th
      | And (_,_,s) =>
        let
          fun add (f,z) = Thm (f, Conjunct th) :: z
        in
          def_cnf cls simp (Set.foldr add ths s)
        end
      | Exists (fv,_,n,f) =>
        let
          val th = Thm (skolemize fv n f, Skolemize th)
        in
          def_cnf_formula cls simp th ths
        end
      | Forall (_,_,_,f) =>
        let
          val th = Thm (f, Specialize th)
        in
          def_cnf_formula cls simp th ths
        end
      | _ =>
        case minimumDefinition fm of
          SOME def =>
          let
            val ths = th :: ths
            val th = newDefinition def
          in
            def_cnf_formula cls simp th ths
          end
        | NONE =>
          let
            val simp = simplifyAdd simp th

            val fm = basicCnf fm

            val inf = Clausify th
          in
            case fm of
              True => def_cnf cls simp ths
            | False => def_cnf_inconsistent (Thm (fm,inf))
            | And (_,_,s) =>
              let
                val inf = Conjunct (Thm (fm,inf))
                val cls = Set.foldl (def_cnf_clause inf) cls s
              in
                def_cnf cls simp ths
              end
            | fm => def_cnf (def_cnf_clause inf (fm,cls)) simp ths
          end;
in
  (* Add theorem th to a cnf state, returning its clauses (most recently
     generated first) and the new state.  An inconsistent state absorbs
     further theorems without producing clauses. *)
  fun addCnf th cnf =
      case cnf of
        ConsistentCnf simp => def_cnf [] simp [th]
      | InconsistentCnf => ([],cnf);
end;

(* Convert a list of theorems to clauses, in order of generation. *)
local
  (* Fold step: feed one theorem to addCnf and prepend its (reversed)
     clauses to the accumulated (reversed) list. *)
  fun step (th,(clauses,state)) =
      let
        val (fresh,state') = addCnf th state
      in
        (fresh @ clauses, state')
      end;
in
  fun proveCnf ths =
      case List.foldl step ([],initialCnf) ths of
        (clauses,_) => rev clauses;
end;

(* Clauses of the definitional CNF of a single formula, without their
   justifying theorems. *)
fun cnf fm = map fst (proveCnf [mkAxiom fm]);

end
end;

(**** Original file: Model.sig ****)

(* ========================================================================= *)
(* RANDOM FINITE MODELS                                                      *)
(* Copyright (c) 2003 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

(* Interface of the Model structure: finite models whose elements are the
   integers 0 .. N-1, the fixed part of a model (functions and relations
   whose values are given in advance), valuations of variables,
   interpretation of terms, formulas and clauses, random checking, updating
   and perturbation.  The implementation follows (Model.sml). *)
signature Model =
sig

(* ------------------------------------------------------------------------- *)
(* Model size.                                                               *)
(* ------------------------------------------------------------------------- *)

type size = {size : int}

(* ------------------------------------------------------------------------- *)
(* A model of size N has integer elements 0...N-1.                           *)
(* ------------------------------------------------------------------------- *)

type element = int

val zeroElement : element

(* The next element, or NONE after the last element N-1. *)
val incrementElement : size -> element -> element option

(* ------------------------------------------------------------------------- *)
(* The parts of the model that are fixed.                                    *)
(* ------------------------------------------------------------------------- *)

type fixedFunction = size -> element list -> element option

type fixedRelation = size -> element list -> bool option

datatype fixed =
    Fixed of
      {functions : fixedFunction Metis.NameArityMap.map,
       relations : fixedRelation Metis.NameArityMap.map}

val emptyFixed : fixed

val unionFixed : fixed -> fixed -> fixed

val getFunctionFixed : fixed -> Metis.NameArity.nameArity -> fixedFunction

val getRelationFixed : fixed -> Metis.NameArity.nameArity -> fixedRelation

val insertFunctionFixed : fixed -> Metis.NameArity.nameArity * fixedFunction -> fixed

val insertRelationFixed : fixed -> Metis.NameArity.nameArity * fixedRelation -> fixed

val unionListFixed : fixed list -> fixed

val basicFixed : fixed  (* interprets equality and hasType *)

(* ------------------------------------------------------------------------- *)
(* Renaming fixed model parts.                                               *)
(* ------------------------------------------------------------------------- *)

type fixedMap =
     {functionMap : Metis.Name.name Metis.NameArityMap.map,
      relationMap : Metis.Name.name Metis.NameArityMap.map}

val mapFixed : fixedMap -> fixed -> fixed

val ppFixedMap : fixedMap Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* Standard fixed model parts.                                               *)
(* ------------------------------------------------------------------------- *)

(* Projections *)

val projectionMin : int

val projectionMax : int

val projectionName : int -> Metis.Name.name

val projectionFixed : fixed

(* Arithmetic *)

val numeralMin : int

val numeralMax : int

val numeralName : int -> Metis.Name.name

val addName : Metis.Name.name

val divName : Metis.Name.name

val dividesName : Metis.Name.name

val evenName : Metis.Name.name

val expName : Metis.Name.name

val geName : Metis.Name.name

val gtName : Metis.Name.name

val isZeroName : Metis.Name.name

val leName : Metis.Name.name

val ltName : Metis.Name.name

val modName : Metis.Name.name

val multName : Metis.Name.name

val negName : Metis.Name.name

val oddName : Metis.Name.name

val preName : Metis.Name.name

val subName : Metis.Name.name

val sucName : Metis.Name.name

val modularFixed : fixed

val overflowFixed : fixed

(* Sets *)

val cardName : Metis.Name.name

val complementName : Metis.Name.name

val differenceName : Metis.Name.name

val emptyName : Metis.Name.name

val memberName : Metis.Name.name

val insertName : Metis.Name.name

val intersectName : Metis.Name.name

val singletonName : Metis.Name.name

val subsetName : Metis.Name.name

val symmetricDifferenceName : Metis.Name.name

val unionName : Metis.Name.name

val universeName : Metis.Name.name

val setFixed : fixed

(* Lists *)

val appendName : Metis.Name.name

val consName : Metis.Name.name

val lengthName : Metis.Name.name

val nilName : Metis.Name.name

val nullName : Metis.Name.name

val tailName : Metis.Name.name

val listFixed : fixed

(* ------------------------------------------------------------------------- *)
(* Valuations.                                                               *)
(* ------------------------------------------------------------------------- *)

type valuation

val emptyValuation : valuation

val zeroValuation : Metis.NameSet.set -> valuation

val constantValuation : element -> Metis.NameSet.set -> valuation

val peekValuation : valuation -> Metis.Name.name -> element option

val getValuation : valuation -> Metis.Name.name -> element

val insertValuation : valuation -> Metis.Name.name * element -> valuation

val randomValuation : {size : int} -> Metis.NameSet.set -> valuation

val incrementValuation :
    {size : int} -> Metis.NameSet.set -> valuation -> valuation option

val foldValuation :
    {size : int} -> Metis.NameSet.set -> (valuation * 'a -> 'a) -> 'a -> 'a

(* ------------------------------------------------------------------------- *)
(* A type of random finite models.                                           *)
(* ------------------------------------------------------------------------- *)

type parameters = {size : int, fixed : fixed}

type model

val default : parameters

val new : parameters -> model

val size : model -> int

(* ------------------------------------------------------------------------- *)
(* Interpreting terms and formulas in the model.                             *)
(* ------------------------------------------------------------------------- *)

val interpretFunction : model -> Metis.Term.functionName * element list -> element

val interpretRelation : model -> Metis.Atom.relationName * element list -> bool

val interpretTerm : model -> valuation -> Metis.Term.term -> element

val interpretAtom : model -> valuation -> Metis.Atom.atom -> bool

val interpretFormula : model -> valuation -> Metis.Formula.formula -> bool

val interpretLiteral : model -> valuation -> Metis.Literal.literal -> bool

val interpretClause : model -> valuation -> Metis.Thm.clause -> bool

(* ------------------------------------------------------------------------- *)
(* Check whether random groundings of a formula are true in the model.       *)
(* Note: if it's cheaper, a systematic check will be performed instead.      *)
(* ------------------------------------------------------------------------- *)

(* The result record presumably counts the groundings found true (T) and
   false (F); confirm against the implementation. *)
val check :
    (model -> valuation -> 'a -> bool) -> {maxChecks : int option} -> model ->
    Metis.NameSet.set -> 'a -> {T : int, F : int}

val checkAtom :
    {maxChecks : int option} -> model -> Metis.Atom.atom -> {T : int, F : int}

val checkFormula :
    {maxChecks : int option} -> model -> Metis.Formula.formula -> {T : int, F : int}

val checkLiteral :
    {maxChecks : int option} -> model -> Metis.Literal.literal -> {T : int, F : int}

val checkClause :
    {maxChecks : int option} -> model -> Metis.Thm.clause -> {T : int, F : int}

(* ------------------------------------------------------------------------- *)
(* Updating the model.                                                       *)
(* ------------------------------------------------------------------------- *)

val updateFunction :
    model -> (Metis.Term.functionName * element list) * element -> unit

val updateRelation :
    model -> (Metis.Atom.relationName * element list) * bool -> unit

(* ------------------------------------------------------------------------- *)
(* Choosing a random perturbation to make a formula true in the model.       *)
(* ------------------------------------------------------------------------- *)

val perturbTerm : model -> valuation -> Metis.Term.term * element list -> unit

val perturbAtom : model -> valuation -> Metis.Atom.atom * bool -> unit

val perturbLiteral : model -> valuation -> Metis.Literal.literal -> unit

val perturbClause : model -> valuation -> Metis.Thm.clause -> unit

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

val pp : model Metis.Print.pp

end

(**** Original file: Model.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment: make ++, -- and RL ordinary (non-infix)
   identifiers and bind explode, implode, print, foldl and foldr to their
   Basis library versions, presumably so that definitions from the host
   environment (e.g. Isabelle) do not leak into the Metis code below. *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* RANDOM FINITE MODELS                                                      *)
(* Copyright (c) 2003 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Model :> Model =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Constants.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Largest number of argument tuples (size raised to the arity) for which
   elementListSpace returns SOME; anything larger gives NONE. *)
val maxSpace = 1000;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Overflow-guarded multiplication: multInt x y is SOME (x * y) when the
   product is certain to be representable and NONE otherwise.  With
   arbitrary-precision integers (Int.maxInt = NONE) it always succeeds.
   Otherwise both operands must lie in [~bound, bound], where
   bound * bound <= maxInt.  The test is conservative: some representable
   products are still rejected.

   The bound is estimated with floating point and then corrected downwards,
   because Real.fromInt can round maxInt up.  For 63-bit integers, 2^62 - 1
   becomes 2^62, whose square root 2^31 squares to more than maxInt, so the
   uncorrected guard let the multiplication itself raise Overflow.
   Negative operands are bounded as well, so a large negative operand
   yields NONE instead of raising Overflow. *)
val multInt =
    case Int.maxInt of
      NONE => (fn x => fn y => SOME (x * y))
    | SOME m =>
      let
        (* Decrease b until b * b <= m, tested by division so that the
           check itself cannot overflow. *)
        fun shrink b = if b > 0 andalso b > m div b then shrink (b - 1) else b

        val bound = shrink (Real.floor (Math.sqrt (Real.fromInt m)))

        fun small z = ~bound <= z andalso z <= bound
      in
        fn x => fn y => if small x andalso small y then SOME (x * y) else NONE
      end;

local
  (* Exponentiation by repeated squaring.  Each step multiplies the
     accumulator by the current base when the exponent is odd, stops when
     the exponent is 1, and otherwise squares the base and halves the
     exponent.  Every multiplication goes through multInt, so the result is
     NONE as soon as a step might overflow. *)
  fun squareMult base e acc =
      case (if e mod 2 = 0 then SOME acc else multInt acc base) of
        NONE => NONE
      | SOME acc' =>
        if e = 1 then SOME acc'
        else
          case multInt base base of
            NONE => NONE
          | SOME base' => squareMult base' (e div 2) acc';
in
  (* expInt x y = SOME (x ^ y), or NONE if the computation might overflow.
     Raises Bug for a negative exponent, or for a negative base with an
     exponent of at least 2. *)
  fun expInt x y =
      if y < 0 then raise Bug "expInt: negative exponent"
      else if y = 0 then SOME 1
      else if y = 1 then SOME x
      else if x < 0 then raise Bug "expInt: negative exponand"
      else if x <= 1 then SOME x
      else squareMult x y 1;
end;

(* Encode a boolean as 1 (true) or 0 (false). *)
fun boolToInt b = if b then 1 else 0;

(* Inverse of boolToInt; any value other than 0 or 1 is a Bug. *)
fun intToBool n =
    case n of
      0 => false
    | 1 => true
    | _ => raise Bug "Model.intToBool";

(* The integers lo, lo+1, ..., hi, built with interval start length. *)
fun minMaxInterval lo hi =
    let
      val len = 1 + hi - lo
    in
      interval lo len
    end;

(* ------------------------------------------------------------------------- *)
(* Model size.                                                               *)
(* ------------------------------------------------------------------------- *)

(* The number N of elements in a model. *)
type size = {size : int};

(* ------------------------------------------------------------------------- *)
(* A model of size N has integer elements 0...N-1.                           *)
(* ------------------------------------------------------------------------- *)

(* Elements are plain integers in the range 0 .. N-1. *)
type element = int;

(* The first element. *)
val zeroElement = 0;

(* Successor of element i in a model of size N, or NONE when i is the last
   element N-1. *)
fun incrementElement {size = N} i =
    if i + 1 = N then NONE else SOME (i + 1);

(* Number of element lists of the given arity (N ^ arity), or NONE if it
   exceeds maxSpace or cannot be computed without overflow. *)
fun elementListSpace {size = N} arity =
    case expInt N arity of
      SOME m => if m > maxSpace then NONE else SOME m
    | NONE => NONE;

(* Read a list of elements as the digits (most significant first) of a base-N
   number, giving its index among all lists of that length. *)
fun elementListIndex {size = N} =
    List.foldl (fn (elt,acc) => N * acc + elt) 0;

(* ------------------------------------------------------------------------- *)
(* The parts of the model that are fixed.                                    *)
(* ------------------------------------------------------------------------- *)

(* A fixed function or relation receives the model size and the argument
   elements.  NONE presumably means the fixed part leaves the value
   undetermined (the uselessFixed defaults below return NONE for every
   input). *)
type fixedFunction = size -> element list -> element option;

type fixedRelation = size -> element list -> bool option;

(* Fixed interpretations, keyed by (name, arity). *)
datatype fixed =
    Fixed of
      {functions : fixedFunction NameArityMap.map,
       relations : fixedRelation NameArityMap.map};

(* Defaults that determine nothing: NONE for every size and argument list. *)
val uselessFixedFunction : fixedFunction = fn _ => fn _ => NONE;

val uselessFixedRelation : fixedRelation = fn _ => fn _ => NONE;

(* Empty (name, arity) maps for the two halves of a fixed part. *)
val emptyFunctions : fixedFunction NameArityMap.map = NameArityMap.new ();

val emptyRelations : fixedRelation NameArityMap.map = NameArityMap.new ();

(* Adapt a curried function of arity 0, 1 or 2 to the element-list
   interface, raising Bug when the list has the wrong length. *)
fun fixed0 f sz [] = f sz
  | fixed0 _ _ _ = raise Bug "Model.fixed0: wrong arity";

fun fixed1 f sz [x] = f sz x
  | fixed1 _ _ _ = raise Bug "Model.fixed1: wrong arity";

fun fixed2 f sz [x,y] = f sz x y
  | fixed2 _ _ _ = raise Bug "Model.fixed2: wrong arity";

(* A fixed part with no functions and no relations. *)
val emptyFixed = Fixed {functions = emptyFunctions, relations = emptyRelations};

(* Look up the fixed interpretation of a (name, arity) pair, if any. *)
fun peekFunctionFixed (Fixed {functions, ...}) name_arity =
    NameArityMap.peek functions name_arity;

fun peekRelationFixed (Fixed {relations, ...}) name_arity =
    NameArityMap.peek relations name_arity;

(* As above, but fall back to the interpretation that determines nothing. *)
fun getFunctionFixed fix name_arity =
    case peekFunctionFixed fix name_arity of
      NONE => uselessFixedFunction
    | SOME f => f;

fun getRelationFixed fix name_arity =
    case peekRelationFixed fix name_arity of
      NONE => uselessFixedRelation
    | SOME rel => rel;

(* Add (or replace, per NameArityMap.insert) one fixed function or relation,
   leaving the other half of the fixed part untouched. *)
fun insertFunctionFixed (Fixed {functions, relations}) name_arity_fn =
    Fixed
      {functions = NameArityMap.insert functions name_arity_fn,
       relations = relations};

fun insertRelationFixed (Fixed {functions, relations}) name_arity_rel =
    Fixed
      {functions = functions,
       relations = NameArityMap.insert relations name_arity_rel};

(* Union of two fixed parts; a (name, arity) pair present in both is a
   Bug. *)
local
  fun clash _ = raise Bug "Model.unionFixed: nameArity clash";
in
  fun unionFixed (Fixed {functions = fns1, relations = rels1})
                 (Fixed {functions = fns2, relations = rels2}) =
      Fixed
        {functions = NameArityMap.union clash fns1 fns2,
         relations = NameArityMap.union clash rels1 rels2};
end;

(* Union of a list of fixed parts, starting from emptyFixed. *)
fun unionListFixed fixes =
    List.foldl (fn (fix,acc) => unionFixed acc fix) emptyFixed fixes;

(* The fixed part every model shares: hasType returns its first argument and
   equality is equality of elements. *)
local
  fun hasTypeFn _ [x,_] = SOME x
    | hasTypeFn _ _ = raise Bug "Model.hasTypeFn: wrong arity";

  fun eqRel _ [x,y] = SOME (x = y)
    | eqRel _ _ = raise Bug "Model.eqRel: wrong arity";
in
  val basicFixed =
      Fixed
        {functions = NameArityMap.singleton (Term.hasTypeFunction,hasTypeFn),
         relations = NameArityMap.singleton (Atom.eqRelation,eqRel)};
end;

(* ------------------------------------------------------------------------- *)
(* Renaming fixed model parts.                                               *)
(* ------------------------------------------------------------------------- *)

type fixedMap =
     {functionMap : Name.name NameArityMap.map,
      relationMap : Name.name NameArityMap.map};

(* Rename the symbols of a fixed part by composing each renaming map with *)
(* the corresponding interpretation map. Presumably each (name,arity) key  *)
(* of the renaming picks up the interpretation of its target name at the   *)
(* same arity -- confirm against NameArityMap.compose.                     *)
fun mapFixed fixMap fix =
    let
      val {functionMap, relationMap} = fixMap
      val Fixed {functions, relations} = fix
    in
      Fixed
        {functions = NameArityMap.compose functionMap functions,
         relations = NameArityMap.compose relationMap relations}
    end;

local
  (* Tag a (name/arity, target) map entry with the kind of symbol it renames. *)
  fun mkEntry tag (na,n) = (tag,na,n);

  (* All entries of one renaming map, tagged with the given kind. *)
  fun mkList tag m = map (mkEntry tag) (NameArityMap.toList m);

  (* Print one entry as "<tag> <source name/arity> -> <target name>". *)
  fun ppEntry (tag,source_arity,target) =
      Print.blockProgram Print.Inconsistent 2
        [Print.addString tag,
         Print.addBreak 1,
         NameArity.pp source_arity,
         Print.addString " ->",
         Print.addBreak 1,
         Name.pp target];
in
  (* Pretty-print a renaming: function entries first, then relation *)
  (* entries, one per line. An empty renaming prints nothing.        *)
  fun ppFixedMap fixMap =
      let
        val {functionMap = fnMap, relationMap = relMap} = fixMap
      in
        case mkList "function" fnMap @ mkList "relation" relMap of
          [] => Print.skip
        | entry :: entries =>
          (* Newlines go before every entry except the first. *)
          Print.blockProgram Print.Consistent 0
            (ppEntry entry ::
             map (Print.sequence Print.addNewline o ppEntry) entries)
      end;
end;

(* ------------------------------------------------------------------------- *)
(* Standard fixed model parts.                                               *)
(* ------------------------------------------------------------------------- *)

(* Projections *)

(* Bounds on the index i of the projection functions "project<i>"; *)
(* projectionName rejects indices outside this range.             *)
val projectionMin = 1
and projectionMax = 9;

(* Every permitted projection index, built by minMaxInterval from the bounds. *)
val projectionList = minMaxInterval projectionMin projectionMax;

(* The name "project<i>" of the i-th projection function. Raises Bug when i *)
(* lies outside [projectionMin, projectionMax].                             *)
fun projectionName i =
    if i < projectionMin then
      raise Bug "Model.projectionName: less than projectionMin"
    else if i > projectionMax then
      raise Bug "Model.projectionName: greater than projectionMax"
    else
      Name.fromString ("project" ^ Int.toString i);

(* The i-th projection (1-based) returns the i-th argument; the model size *)
(* is ignored. List.nth raises Subscript if there are fewer than i args.   *)
fun projectionFn i _ elts = SOME (List.nth (elts, i - 1));

(* The fixed part with the projections project<projectionMin> .. *)
(* project<arity>, all at the given arity, and no relations.      *)
fun arityProjectionFixed arity =
    let
      fun loop i fns =
          if i <= arity then
            let
              val entry = ((projectionName i, arity), projectionFn i)
            in
              loop (i + 1) (NameArityMap.insert fns entry)
            end
          else fns
    in
      Fixed
        {functions = loop projectionMin emptyFunctions,
         relations = emptyRelations}
    end;

(* Projections for every arity in projectionList, merged into one part. *)
val projectionFixed =
    unionListFixed (List.map arityProjectionFixed projectionList);

(* Arithmetic *)

(* Bounds on the integers that get a fixed numeral symbol (see numeralName). *)
val numeralMin = ~100
and numeralMax = 100;

(* Every integer with a numeral symbol, built by minMaxInterval. *)
val numeralList = minMaxInterval numeralMin numeralMax;

(* The constant symbol for numeral i: "negative<|i|>" for negative i, and the *)
(* decimal digits of i otherwise. Raises Bug outside [numeralMin, numeralMax]. *)
fun numeralName i =
    if i < numeralMin then
      raise Bug "Model.numeralName: less than numeralMin"
    else if i > numeralMax then
      raise Bug "Model.numeralName: greater than numeralMax"
    else if i < 0 then
      Name.fromString ("negative" ^ Int.toString (~i))
    else
      Name.fromString (Int.toString i);

(* Symbol names for the fixed arithmetic functions and relations, shared *)
(* by modularFixed and overflowFixed below.                              *)
val addName = Name.fromString "+"
and divName = Name.fromString "div"
and dividesName = Name.fromString "divides"
and evenName = Name.fromString "even"
and expName = Name.fromString "exp"
and geName = Name.fromString ">="
and gtName = Name.fromString ">"
and isZeroName = Name.fromString "isZero"
and leName = Name.fromString "<="
and ltName = Name.fromString "<"
and modName = Name.fromString "mod"
and multName = Name.fromString "*"
and negName = Name.fromString "~"
and oddName = Name.fromString "odd"
and preName = Name.fromString "pre"
and subName = Name.fromString "-"
and sucName = Name.fromString "suc";

local
  (* Support *)

  (* Reduce an integer modulo the model size N. *)
  fun modN {size = N} x = x mod N;

  fun oneN sz = modN sz 1;

  fun multN sz (x,y) = modN sz (x * y);

  (* Functions *)

  (* Numerals (including negative ones) are reduced modulo N. *)
  fun numeralFn i sz = SOME (modN sz i);

  fun addFn sz x y = SOME (modN sz (x + y));

  (* A zero divisor is replaced by N, so (assuming elements lie in [0,N)) *)
  (* x div 0 evaluates to 0.                                              *)
  fun divFn {size = N} x y =
      let
        val y = if y = 0 then N else y
      in
        SOME (x div y)
      end;

  (* Exponentiation by repeated modular multiplication, starting from 1 mod N. *)
  fun expFn sz x y = SOME (exp (multN sz) x y (oneN sz));

  (* As in divFn, a zero divisor is replaced by N; for x in [0,N) this *)
  (* makes x mod 0 evaluate to x.                                      *)
  fun modFn {size = N} x y =
      let
        val y = if y = 0 then N else y
      in
        SOME (x mod y)
      end;

  fun multFn sz x y = SOME (multN sz (x,y));

  (* Additive inverse modulo N. *)
  fun negFn {size = N} x = SOME (if x = 0 then 0 else N - x);

  (* Predecessor and successor wrap around between 0 and N - 1. *)
  fun preFn {size = N} x = SOME (if x = 0 then N - 1 else x - 1);

  (* (x - y) mod N, computed without a negative intermediate result. *)
  fun subFn {size = N} x y = SOME (if x < y then N + x - y else x - y);

  fun sucFn {size = N} x = SOME (if x = N - 1 then 0 else x + 1);

  (* Relations *)

  (* Relations compare the element representatives directly. *)
  fun dividesRel _ x y = SOME (divides x y);

  fun evenRel _ x = SOME (x mod 2 = 0);

  fun geRel _ x y = SOME (x >= y);

  fun gtRel _ x y = SOME (x > y);

  fun isZeroRel _ x = SOME (x = 0);

  fun leRel _ x y = SOME (x <= y);

  fun ltRel _ x y = SOME (x < y);

  fun oddRel _ x = SOME (x mod 2 = 1);
in
  (* Arithmetic modulo the model size. Every function and relation returns *)
  (* SOME, so none of these symbols fall back to the random tables.        *)
  (* NOTE(review): fixed0/fixed1/fixed2 are defined elsewhere; presumably  *)
  (* they adapt curried functions of 0/1/2 arguments to argument lists.    *)
  val modularFixed =
      let
        val fns =
            NameArityMap.fromList
              (map (fn i => ((numeralName i,0), fixed0 (numeralFn i)))
                 numeralList @
               [((addName,2), fixed2 addFn),
                ((divName,2), fixed2 divFn),
                ((expName,2), fixed2 expFn),
                ((modName,2), fixed2 modFn),
                ((multName,2), fixed2 multFn),
                ((negName,1), fixed1 negFn),
                ((preName,1), fixed1 preFn),
                ((subName,2), fixed2 subFn),
                ((sucName,1), fixed1 sucFn)])

        val rels =
            NameArityMap.fromList
              [((dividesName,2), fixed2 dividesRel),
               ((evenName,1), fixed1 evenRel),
               ((geName,2), fixed2 geRel),
               ((gtName,2), fixed2 gtRel),
               ((isZeroName,1), fixed1 isZeroRel),
               ((leName,2), fixed2 leRel),
               ((ltName,2), fixed2 ltRel),
               ((oddName,1), fixed1 oddRel)]
      in
        Fixed
          {functions = fns,
           relations = rels}
      end;
end;

local
  (* Support *)

  (* Saturating arithmetic: elements 0 .. N-2 stand for themselves, and  *)
  (* every value >= N-1 is cut down to the top element N-1. The top      *)
  (* element therefore means "some number >= N-1", not an exact value.   *)
  fun cutN {size = N} x = if x >= N then N - 1 else x;

  fun oneN sz = cutN sz 1;

  fun multN sz (x,y) = cutN sz (x * y);

  (* Functions *)

  (* A result of NONE means "no fixed answer"; interpretFunction then *)
  (* falls back to the model's random function tables.               *)

  (* Negative numerals have no natural-number value. *)
  fun numeralFn i sz = if i < 0 then NONE else SOME (cutN sz i);

  fun addFn sz x y = SOME (cutN sz (x + y));

  (* The quotient of a saturated dividend is unknown (it can be any value *)
  (* >= (N-1) div y), so give an answer only for unsaturated dividends.   *)
  (* This matches how modFn and subFn treat a saturated argument.         *)
  (* A saturated divisor is fine: an unsaturated x then gives x div y = 0. *)
  fun divFn {size = N} x y =
      if y = 0 orelse x = N - 1 then NONE else SOME (x div y);

  fun expFn sz x y = SOME (exp (multN sz) x y (oneN sz));

  fun modFn {size = N} x y =
      if y = 0 orelse x = N - 1 then NONE else SOME (x mod y);

  fun multFn sz x y = SOME (multN sz (x,y));

  (* Only zero has a natural-number negation. *)
  fun negFn _ x = if x = 0 then SOME 0 else NONE;

  (* 0 has no predecessor. The predecessor of the saturated value is     *)
  (* unknown: it could be N-2 or still >= N-1.                           *)
  fun preFn {size = N} x =
      if x = 0 orelse x = N - 1 then NONE else SOME (x - 1);

  (* Truncated subtraction is undefined (NONE) when it would go negative *)
  (* or when the minuend is saturated.                                   *)
  fun subFn {size = N} x y =
      if y = 0 then SOME x
      else if x = N - 1 orelse x < y then NONE
      else SOME (x - y);

  fun sucFn sz x = SOME (cutN sz (x + 1));

  (* Relations *)

  (* An answer that depends on the exact value of a saturated argument is NONE. *)
  fun dividesRel {size = N} x y =
      if x = 1 orelse y = 0 then SOME true
      else if x = 0 then SOME false
      else if y = N - 1 then NONE
      else SOME (divides x y);

  fun evenRel {size = N} x =
      if x = N - 1 then NONE else SOME (x mod 2 = 0);

  (* geRel/gtRel take their arguments swapped: y >= x is x <= y. *)
  fun geRel {size = N} y x =
      if x = N - 1 then if y = N - 1 then NONE else SOME false
      else if y = N - 1 then SOME true else SOME (x <= y);

  fun gtRel {size = N} y x =
      if x = N - 1 then if y = N - 1 then NONE else SOME false
      else if y = N - 1 then SOME true else SOME (x < y);

  fun isZeroRel _ x = SOME (x = 0);

  fun leRel {size = N} x y =
      if x = N - 1 then if y = N - 1 then NONE else SOME false
      else if y = N - 1 then SOME true else SOME (x <= y);

  fun ltRel {size = N} x y =
      if x = N - 1 then if y = N - 1 then NONE else SOME false
      else if y = N - 1 then SOME true else SOME (x < y);

  fun oddRel {size = N} x =
      if x = N - 1 then NONE else SOME (x mod 2 = 1);
in
  (* Natural-number arithmetic that saturates at the model size. *)
  val overflowFixed =
      let
        val fns =
            NameArityMap.fromList
              (map (fn i => ((numeralName i,0), fixed0 (numeralFn i)))
                 numeralList @
               [((addName,2), fixed2 addFn),
                ((divName,2), fixed2 divFn),
                ((expName,2), fixed2 expFn),
                ((modName,2), fixed2 modFn),
                ((multName,2), fixed2 multFn),
                ((negName,1), fixed1 negFn),
                ((preName,1), fixed1 preFn),
                ((subName,2), fixed2 subFn),
                ((sucName,1), fixed1 sucFn)])

        val rels =
            NameArityMap.fromList
              [((dividesName,2), fixed2 dividesRel),
               ((evenName,1), fixed1 evenRel),
               ((geName,2), fixed2 geRel),
               ((gtName,2), fixed2 gtRel),
               ((isZeroName,1), fixed1 isZeroRel),
               ((leName,2), fixed2 leRel),
               ((ltName,2), fixed2 ltRel),
               ((oddName,1), fixed1 oddRel)]
      in
        Fixed
          {functions = fns,
           relations = rels}
      end;
end;

(* Sets *)

(* Symbol names for the fixed finite-set functions and relations used by *)
(* setFixed below.                                                        *)
val cardName = Name.fromString "card"
and complementName = Name.fromString "complement"
and differenceName = Name.fromString "difference"
and emptyName = Name.fromString "empty"
and memberName = Name.fromString "member"
and insertName = Name.fromString "insert"
and intersectName = Name.fromString "intersect"
and singletonName = Name.fromString "singleton"
and subsetName = Name.fromString "subset"
and symmetricDifferenceName = Name.fromString "symmetricDifference"
and unionName = Name.fromString "union"
and universeName = Name.fromString "universe";

local
  (* Support *)

  (* Number of potential set members for a model of size N (N >= 1): the *)
  (* largest k with 2^k <= N, so every bit mask over k members is itself *)
  (* an element of the model. It is 0 when N = 1.                       *)
  fun eltN {size = N} =
      let
        fun f 0 acc = acc
          | f x acc = f (x div 2) (acc + 1)
      in
        f N ~1
      end;

  (* The word with only bit i set. *)
  fun posN i = Word.<< (0w1, Word.fromInt i);

  (* The mask with all eltN bits set, i.e. the full set. *)
  fun univN sz = Word.- (posN (eltN sz), 0w1);

  (* Read an element as a set by keeping only the bits of real members. *)
  fun setN sz x = Word.andb (Word.fromInt x, univN sz);

  (* The bit for member x, with x reduced modulo the number of members.  *)
  (* A model of size 1 has no members (eltN = 0), so there is no such    *)
  (* bit: return the empty mask instead of raising Div on "x mod 0".     *)
  fun bitN sz x =
      case eltN sz of
        0 => 0w0
      | k => posN (x mod k);

  (* Functions *)

  (* Population count of the set. *)
  fun cardFn sz x =
      let
        fun f 0w0 acc = acc
          | f s acc =
            let
              val acc = if Word.andb (s,0w1) = 0w0 then acc else acc + 1
            in
              f (Word.>> (s,0w1)) acc
            end
      in
        SOME (f (setN sz x) 0)
      end;

  fun complementFn sz x = SOME (Word.toInt (Word.xorb (univN sz, setN sz x)));

  fun differenceFn sz x y =
      let
        val x = setN sz x
        and y = setN sz y
      in
        SOME (Word.toInt (Word.andb (x, Word.notb y)))
      end;

  fun emptyFn _ = SOME 0;

  fun insertFn sz x y =
      SOME (Word.toInt (Word.orb (bitN sz x, setN sz y)));

  fun intersectFn sz x y =
      SOME (Word.toInt (Word.andb (setN sz x, setN sz y)));

  fun singletonFn sz x = SOME (Word.toInt (bitN sz x));

  fun symmetricDifferenceFn sz x y =
      let
        val x = setN sz x
        and y = setN sz y
      in
        SOME (Word.toInt (Word.xorb (x,y)))
      end;

  fun unionFn sz x y =
      SOME (Word.toInt (Word.orb (setN sz x, setN sz y)));

  fun universeFn sz = SOME (Word.toInt (univN sz));

  (* Relations *)

  fun memberRel sz x y =
      SOME (Word.andb (bitN sz x, setN sz y) <> 0w0);

  fun subsetRel sz x y =
      let
        val x = setN sz x
        and y = setN sz y
      in
        SOME (Word.andb (x, Word.notb y) = 0w0)
      end;
in
  (* Finite sets over eltN members, encoded as bit masks. *)
  val setFixed =
      let
        val fns =
            NameArityMap.fromList
              [((cardName,1), fixed1 cardFn),
               ((complementName,1), fixed1 complementFn),
               ((differenceName,2), fixed2 differenceFn),
               ((emptyName,0), fixed0 emptyFn),
               ((insertName,2), fixed2 insertFn),
               ((intersectName,2), fixed2 intersectFn),
               ((singletonName,1), fixed1 singletonFn),
               ((symmetricDifferenceName,2), fixed2 symmetricDifferenceFn),
               ((unionName,2), fixed2 unionFn),
               ((universeName,0), fixed0 universeFn)]

        val rels =
            NameArityMap.fromList
              [((memberName,2), fixed2 memberRel),
               ((subsetName,2), fixed2 subsetRel)]
      in
        Fixed
          {functions = fns,
           relations = rels}
      end;
end;

(* Lists *)

(* Symbol names for the fixed list functions and relations used by *)
(* listFixed below.                                                 *)
val appendName = Name.fromString "@"
and consName = Name.fromString "::"
and lengthName = Name.fromString "length"
and nilName = Name.fromString "nil"
and nullName = Name.fromString "null"
and tailName = Name.fromString "tail";

local
  (* Base interpretation: projections plus saturating arithmetic, extended *)
  (* with a binary successor that ignores its first argument.             *)
  val baseFix =
      let
        val fix = unionFixed projectionFixed overflowFixed

        val sucFn = getFunctionFixed fix (sucName,1)

        fun suc2Fn sz _ x = sucFn sz [x]
      in
        insertFunctionFixed fix ((sucName,2), fixed2 suc2Fn)
      end;

  (* Renaming of list symbols onto base symbols. In effect a list is read *)
  (* as its length: append is +, cons is the binary successor, length is  *)
  (* the identity projection, nil is 0, tail is pre and null is isZero.   *)
  val fixMap =
      {functionMap = NameArityMap.fromList
                       [((appendName,2),addName),
                        ((consName,2),sucName),
                        ((lengthName,1), projectionName 1),
                        ((nilName,0), numeralName 0),
                        ((tailName,1),preName)],
       relationMap = NameArityMap.fromList
                       [((nullName,1),isZeroName)]};

in
  val listFixed = mapFixed fixMap baseFix;
end;

(* ------------------------------------------------------------------------- *)
(* Valuations.                                                               *)
(* ------------------------------------------------------------------------- *)

datatype valuation = Valuation of element NameMap.map;

(* The valuation that binds no variables. *)
val emptyValuation = Valuation (NameMap.new ());

(* Extend a valuation with a (variable,element) binding. *)
fun insertValuation (Valuation m) v_i =
    let
      val m' = NameMap.insert m v_i
    in
      Valuation m'
    end;

(* The element bound to a variable, if any. *)
fun peekValuation V v =
    case V of
      Valuation m => NameMap.peek m v;

(* The valuation binding every variable of a set to the same element i. *)
fun constantValuation i vs =
    NameSet.foldl (fn (v,V) => insertValuation V (v,i)) emptyValuation vs;

(* Every variable of a set bound to zeroElement. *)
val zeroValuation = constantValuation zeroElement;

(* The element bound to v; raises Error if the valuation does not bind v. *)
fun getValuation V v =
    case peekValuation V v of
      NONE => raise Error "Model.getValuation: incomplete valuation"
    | SOME i => i;

(* Bind every variable of vs to an element drawn by Portable.randomInt N. *)
fun randomValuation {size = N} vs =
    NameSet.foldl
      (fn (v,V) => insertValuation V (v, Portable.randomInt N))
      emptyValuation vs;

(* Step to the next valuation of vars, treating the variables (in          *)
(* NameSet.toList order) as the digits of a counter, the first one lowest. *)
(* A digit that incrementElement cannot increase (it returns NONE) is reset *)
(* to zeroElement, and the carry moves on to the next variable. Returns    *)
(* NONE when the carry runs past the last variable, i.e. every valuation   *)
(* has been visited.                                                       *)
fun incrementValuation N vars =
    let
      fun inc vs V =
          case vs of
            [] => NONE
          | v :: vs =>
            let
              val (carry,i) =
                  case incrementElement N (getValuation V v) of
                    SOME i => (false,i)
                  | NONE => (true,zeroElement)

              val V = insertValuation V (v,i)
            in
              if carry then inc vs V else SOME V
            end
    in
      inc (NameSet.toList vars)
    end;

(* Fold f over every valuation of vars: start from the all-zero valuation  *)
(* and step with incrementValuation until it returns NONE. The result is a *)
(* function from the initial accumulator to the final one.                 *)
fun foldValuation N vars f =
    let
      val inc = incrementValuation N vars

      fun fold V acc =
          let
            val acc = f (V,acc)
          in
            case inc V of
              NONE => acc
            | SOME V => fold V acc
          end

      val zero = zeroValuation vars
    in
      fold zero
    end;

(* ------------------------------------------------------------------------- *)
(* A type of random finite mapping Z^n -> Z.                                 *)
(* ------------------------------------------------------------------------- *)

(* Marker for a table entry that has not been drawn yet. *)
val UNKNOWN = ~1;

datatype table =
    ForgetfulTable
  | ArrayTable of int Array.array;

(* A table for a symbol of the given arity over a domain of size N. If     *)
(* elementListSpace gives a size, the table is an array of that size with   *)
(* every entry UNKNOWN. Otherwise it is a table that remembers nothing.     *)
fun newTable N arity =
    case elementListSpace {size = N} arity of
      SOME space => ArrayTable (Array.array (space,UNKNOWN))
    | NONE => ForgetfulTable;

local
  fun randomResult R = Portable.randomInt R;

  (* The entry at index i, drawing and storing a random result the first *)
  (* time the entry is read.                                             *)
  fun memoized R a i =
      let
        val cached = Array.sub (a,i)
      in
        if cached = UNKNOWN then
          let
            val fresh = randomResult R
          in
            Array.update (a,i,fresh);
            fresh
          end
        else cached
      end;
in
  (* The result for an argument list. A forgetful table draws a new random *)
  (* result on every lookup.                                               *)
  fun lookupTable N R table elts =
      case table of
        ForgetfulTable => randomResult R
      | ArrayTable a => memoized R a (elementListIndex {size = N} elts);
end;

(* Overwrite the result stored for an argument list. A forgetful table *)
(* ignores the update.                                                 *)
fun updateTable N table (elts,r) =
    case table of
      ArrayTable a => Array.update (a, elementListIndex {size = N} elts, r)
    | ForgetfulTable => ();

(* ------------------------------------------------------------------------- *)
(* A type of random finite mappings name * arity -> Z^arity -> Z.            *)
(* ------------------------------------------------------------------------- *)

datatype tables =
    Tables of
      {domainSize : int,
       rangeSize : int,
       tableMap : table NameArityMap.map Unsynchronized.ref};

(* An empty collection of tables: arguments range over N elements, and *)
(* results are drawn from R values.                                    *)
fun newTables N R =
    let
      val noTables = NameArityMap.new ()
    in
      Tables
        {domainSize = N,
         rangeSize = R,
         tableMap = Unsynchronized.ref noTables}
    end;

(* The table for a (name,arity) pair. It is created, and recorded in the *)
(* mutable map, the first time the pair is requested.                    *)
fun getTables tables n_a =
    let
      val Tables {domainSize = N, tableMap = tm, ...} = tables
      val Unsynchronized.ref m = tm
    in
      case NameArityMap.peek m n_a of
        SOME t => t
      | NONE =>
        let
          val (_,arity) = n_a
          val t = newTable N arity
          val () = tm := NameArityMap.insert m (n_a,t)
        in
          t
        end
    end;

(* Look up the result for applying symbol n to elts. *)
fun lookupTables tables (n,elts) =
    let
      val Tables {domainSize = N, rangeSize = R, ...} = tables
    in
      lookupTable N R (getTables tables (n, length elts)) elts
    end;

(* Store result r for applying symbol n to elts. *)
fun updateTables tables ((n,elts),r) =
    let
      val Tables {domainSize = N, ...} = tables
    in
      updateTable N (getTables tables (n, length elts)) (elts,r)
    end;

(* ------------------------------------------------------------------------- *)
(* A type of random finite models.                                           *)
(* ------------------------------------------------------------------------- *)

type parameters = {size : int, fixed : fixed};

datatype model =
    Model of
      {size : int,
       fixedFunctions : (element list -> element option) NameArityMap.map,
       fixedRelations : (element list -> bool option) NameArityMap.map,
       randomFunctions : tables,
       randomRelations : tables};

(* Build a model of size N. The fixed interpretations are specialised to *)
(* the size; every other symbol gets a random table. Function tables draw *)
(* results from N values and relation tables from 2.                      *)
fun new {size = N, fixed} =
    let
      val Fixed {functions, relations} = fixed
      val sz = {size = N}
    in
      Model
        {size = N,
         fixedFunctions = NameArityMap.transform (fn f => f sz) functions,
         fixedRelations = NameArityMap.transform (fn r => r sz) relations,
         randomFunctions = newTables N N,
         randomRelations = newTables N 2}
    end;

fun size M =
    case M of
      Model {size = N, ...} => N;

(* The fixed value of function n on elts. NONE if n has no fixed *)
(* interpretation at this arity, or if that interpretation itself *)
(* returns NONE.                                                 *)
fun peekFixedFunction M (n,elts) =
    let
      val Model {fixedFunctions, ...} = M
    in
      Option.mapPartial (fn fixFn => fixFn elts)
        (NameArityMap.peek fixedFunctions (n, length elts))
    end;

fun isFixedFunction M n_elts =
    case peekFixedFunction M n_elts of
      NONE => false
    | SOME _ => true;

(* The fixed truth value of relation n on elts, in the same sense as *)
(* peekFixedFunction.                                                *)
fun peekFixedRelation M (n,elts) =
    let
      val Model {fixedRelations, ...} = M
    in
      Option.mapPartial (fn fixRel => fixRel elts)
        (NameArityMap.peek fixedRelations (n, length elts))
    end;

fun isFixedRelation M n_elts =
    case peekFixedRelation M n_elts of
      NONE => false
    | SOME _ => true;

(* A default model *)

(* Model size used by the default parameters. *)
val defaultSize = 8;

(* Fixed part of the default model: the basic symbols, projections, modular *)
(* arithmetic, finite sets and lists.                                        *)
val defaultFixed =
    unionListFixed
      [basicFixed,
       projectionFixed,
       modularFixed,
       setFixed,
       listFixed];

val default = {size = defaultSize, fixed = defaultFixed};

(* ------------------------------------------------------------------------- *)
(* Taking apart terms to interpret them.                                     *)
(* ------------------------------------------------------------------------- *)

(* Flatten curried applications so that a function symbol can be applied *)
(* to all of its arguments at once. A variable in head position is       *)
(* applied through Term.appName.                                          *)
fun destTerm tm =
    case tm of
      Term.Fn _ =>
      (case Term.stripApp tm of
         (_,[]) => tm
       | (v as Term.Var _, args) => Term.Fn (Term.appName, v :: args)
       | (Term.Fn (f,args), extra) => Term.Fn (f, args @ extra))
    | Term.Var _ => tm;

(* ------------------------------------------------------------------------- *)
(* Interpreting terms and formulas in the model.                             *)
(* ------------------------------------------------------------------------- *)

(* The value of a function application: its fixed interpretation when that *)
(* gives an answer, otherwise an entry from the random function tables.    *)
fun interpretFunction M n_elts =
    case peekFixedFunction M n_elts of
      NONE =>
      let
        val Model {randomFunctions, ...} = M
      in
        lookupTables randomFunctions n_elts
      end
    | SOME r => r;

(* The truth value of a relation application, falling back to the random *)
(* relation tables (read through intToBool).                             *)
fun interpretRelation M n_elts =
    case peekFixedRelation M n_elts of
      NONE =>
      let
        val Model {randomRelations, ...} = M
      in
        intToBool (lookupTables randomRelations n_elts)
      end
    | SOME r => r;

(* The value of a term under valuation V; arguments are evaluated left to right. *)
fun interpretTerm M V tm =
    case destTerm tm of
      Term.Fn (f,args) => interpretFunction M (f, map (interpretTerm M V) args)
    | Term.Var v => getValuation V v;

fun interpretAtom M V (r,tms) =
    let
      val elts = map (interpretTerm M V) tms
    in
      interpretRelation M (r,elts)
    end;

(* Evaluate a formula in model M under valuation V. Quantifiers range over  *)
(* the model elements 0 .. N-1. Implication is evaluated as (not p) or q,   *)
(* and "exists" as "not forall not". Connectives short-circuit, so the order *)
(* of evaluation also decides which random table entries are drawn.         *)
fun interpretFormula M =
    let
      val N = size M

      fun interpret V fm =
          case fm of
            Formula.True => true
          | Formula.False => false
          | Formula.Atom atm => interpretAtom M V atm
          | Formula.Not p => not (interpret V p)
          | Formula.Or (p,q) => interpret V p orelse interpret V q
          | Formula.And (p,q) => interpret V p andalso interpret V q
          | Formula.Imp (p,q) => interpret V (Formula.Or (Formula.Not p, q))
          | Formula.Iff (p,q) => interpret V p = interpret V q
          | Formula.Forall (v,p) => interpret' V p v N
          | Formula.Exists (v,p) =>
            interpret V (Formula.Not (Formula.Forall (v, Formula.Not p)))

      (* fm holds with v bound to each of i-1, i-2, .. 0 (checked in that order). *)
      and interpret' V fm v i =
          i = 0 orelse
          let
            val i = i - 1
            val V' = insertValuation V (v,i)
          in
            interpret V' fm andalso interpret' V fm v i
          end
    in
      interpret
    end;

(* A literal is true when its atom's truth value matches its polarity. *)
fun interpretLiteral M V (pol,atm) =
    interpretAtom M V atm = pol;

(* A clause is true when at least one of its literals is true. *)
fun interpretClause M V cl =
    LiteralSet.exists (fn lit => interpretLiteral M V lit) cl;

(* ------------------------------------------------------------------------- *)
(* Check whether random groundings of a formula are true in the model.       *)
(* Note: if it's cheaper, a systematic check will be performed instead.      *)
(* ------------------------------------------------------------------------- *)

(* Count how many groundings of x are true ({T}) and false ({F}) in M.   *)
(* With maxChecks = SOME m, m random valuations of the free variables fv *)
(* are tried. If expInt reports that there are at most m valuations in   *)
(* total (N^|fv|), all of them are enumerated instead. With NONE, the    *)
(* enumeration is always systematic.                                     *)
fun check interpret {maxChecks} M fv x =
    let
      val N = size M

      fun score (V,{T,F}) =
          if interpret M V x then {T = T + 1, F = F} else {T = T, F = F + 1}

      fun randomCheck acc = score (randomValuation {size = N} fv, acc)

      (* NONE here means "check systematically". *)
      val maxChecks =
          case maxChecks of
            NONE => maxChecks
          | SOME m =>
            case expInt N (NameSet.size fv) of
              SOME n => if n <= m then NONE else maxChecks
            | NONE => maxChecks
    in
      case maxChecks of
        SOME m => funpow m randomCheck {T = 0, F = 0}
      | NONE => foldValuation {size = N} fv score {T = 0, F = 0}
    end;

(* Instances of check for atoms, formulas, literals and clauses. *)
fun checkAtom maxChecks M atm =
    check interpretAtom maxChecks M (Atom.freeVars atm) atm;

fun checkFormula maxChecks M fm =
    check interpretFormula maxChecks M (Formula.freeVars fm) fm;

fun checkLiteral maxChecks M lit =
    check interpretLiteral maxChecks M (Literal.freeVars lit) lit;

fun checkClause maxChecks M cl =
    check interpretClause maxChecks M (LiteralSet.freeVars cl) cl;

(* ------------------------------------------------------------------------- *)
(* Updating the model.                                                       *)
(* ------------------------------------------------------------------------- *)

(* Record a new result for a function application in the random tables. *)
fun updateFunction M func_elts_elt =
    case M of
      Model {randomFunctions, ...} =>
      updateTables randomFunctions func_elts_elt;

(* Record a new truth value for a relation application in the random *)
(* tables (stored through boolToInt).                                *)
fun updateRelation M (rel_elts,pol) =
    case M of
      Model {randomRelations, ...} =>
      updateTables randomRelations (rel_elts, boolToInt pol);

(* ------------------------------------------------------------------------- *)
(* A type of terms with interpretations embedded in the subterms.            *)
(* ------------------------------------------------------------------------- *)

(* A term shape in which every function node also keeps the values of its *)
(* arguments.                                                             *)
datatype modelTerm =
    ModelVar
  | ModelFn of Term.functionName * modelTerm list * int list;

(* Interpret a term in M under V. Returns the annotated term together *)
(* with the term's own value.                                         *)
fun modelTerm M V =
    let
      fun modelTm tm =
          case destTerm tm of
            Term.Var v => (ModelVar, getValuation V v)
          | Term.Fn (f,tms) =>
            let
              val (tms,xs) = unzip (map modelTm tms)
            in
              (ModelFn (f,tms,xs), interpretFunction M (f,xs))
            end
    in
      modelTm
    end;

(* ------------------------------------------------------------------------- *)
(* Perturbing the model.                                                     *)
(* ------------------------------------------------------------------------- *)

datatype perturbation =
    FunctionPerturbation of (Term.functionName * element list) * element
  | RelationPerturbation of (Atom.relationName * element list) * bool;

(* Apply one perturbation by overwriting the matching random table entry. *)
fun perturb M (FunctionPerturbation func_elts_elt) =
    updateFunction M func_elts_elt
  | perturb M (RelationPerturbation rel_elts_pol) =
    updateRelation M rel_elts_pol;

local
  (* Collect perturbations that could make an annotated term evaluate to *)
  (* one of the target values. At a non-fixed function node, the entry   *)
  (* for its current arguments may be set to any target value. Then each  *)
  (* argument is retargeted through pertTerms.                            *)
  fun pertTerm _ [] _ acc = acc
    | pertTerm M target tm acc =
      case tm of
        ModelVar => acc
      | ModelFn (func,tms,xs) =>
        let
          fun onTarget ys = mem (interpretFunction M (func,ys)) target

          val func_xs = (func,xs)

          val acc =
              if isFixedFunction M func_xs then acc
              else
                let
                  fun add (y,acc) = FunctionPerturbation (func_xs,y) :: acc
                in
                  foldl add acc target
                end
        in
          pertTerms M onTarget tms xs acc
        end

  (* For each argument position, the new targets are the elements y <> x   *)
  (* that satisfy onTarget when only that argument is changed to y (the    *)
  (* other arguments keep their current values xs).                        *)
  and pertTerms M onTarget =
      let
        val N = size M

        (* The elements 0 .. N-1 that satisfy pred, in ascending order. *)
        fun filterElements pred =
            let
              fun filt 0 acc = acc
                | filt i acc =
                  let
                    val i = i - 1
                    val acc = if pred i then i :: acc else acc
                  in
                    filt i acc
                  end
            in
              filt N []
            end

        (* ys holds the arguments already visited, in reverse order. *)
        fun pert _ [] [] acc = acc
          | pert ys (tm :: tms) (x :: xs) acc =
            let
              fun pred y =
                  y <> x andalso onTarget (List.revAppend (ys, y :: xs))

              val target = filterElements pred

              val acc = pertTerm M target tm acc
            in
              pert (x :: ys) tms xs acc
            end
          | pert _ _ _ _ = raise Bug "Model.pertTerms.pert"
      in
        pert []
      end;

  (* Collect perturbations that could give an atom the truth value target. *)
  fun pertAtom M V target (rel,tms) acc =
      let
        fun onTarget ys = interpretRelation M (rel,ys) = target

        val (tms,xs) = unzip (map (modelTerm M V) tms)

        val rel_xs = (rel,xs)

        val acc =
            if isFixedRelation M rel_xs then acc
            else RelationPerturbation (rel_xs,target) :: acc
      in
        pertTerms M onTarget tms xs acc
      end;

  (* The atom is targeted to the literal's polarity, i.e. to make it true. *)
  fun pertLiteral M V ((pol,atm),acc) = pertAtom M V pol atm acc;

  fun pertClause M V cl acc = LiteralSet.foldl (pertLiteral M V) acc cl;

  (* Apply one candidate perturbation chosen with Portable.randomInt, *)
  (* or do nothing if there are no candidates.                       *)
  fun pickPerturb M perts =
      if null perts then ()
      else perturb M (List.nth (perts, Portable.randomInt (length perts)));
in
  fun perturbTerm M V (tm,target) =
      pickPerturb M (pertTerm M target (fst (modelTerm M V tm)) []);

  fun perturbAtom M V (atm,target) =
      pickPerturb M (pertAtom M V target atm []);

  fun perturbLiteral M V lit = pickPerturb M (pertLiteral M V (lit,[]));

  fun perturbClause M V cl = pickPerturb M (pertClause M V cl []);
end;

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Print a model as "Model{<size>}". *)
fun pp M =
    let
      val n = size M
    in
      Print.program [Print.addString "Model{", Print.ppInt n, Print.addString "}"]
    end;

end
end;

(**** Original file: Problem.sig ****)

(* ========================================================================= *)
(* CNF PROBLEMS                                                              *)
(* Copyright (c) 2001-2008 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Problem =
sig

(* ------------------------------------------------------------------------- *)
(* Problems.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* A CNF problem: axiom clauses plus the clauses of the negated conjecture *)
(* (presumably -- check against the problem's producer).                   *)
type problem =
     {axioms : Metis.Thm.clause list,
      conjecture : Metis.Thm.clause list}

(* Counts of clauses, literals, symbols and typed symbols over all clauses. *)
val size : problem -> {clauses : int,
                       literals : int,
                       symbols : int,
                       typedSymbols : int}

val freeVars : problem -> Metis.NameSet.set

(* The axioms followed by the conjecture clauses. *)
val toClauses : problem -> Metis.Thm.clause list

(* The conjunction of all clauses, each read as a disjunction. *)
val toFormula : problem -> Metis.Formula.formula

(* "axioms ==> conjecture ==> false" over generalized clauses; the middle *)
(* implication is omitted when there are no conjecture clauses.          *)
val toGoal : problem -> Metis.Formula.formula

val toString : problem -> string

(* ------------------------------------------------------------------------- *)
(* Categorizing problems.                                                    *)
(* ------------------------------------------------------------------------- *)

datatype propositional =
    Propositional
  | EffectivelyPropositional
  | NonPropositional

datatype equality =
    NonEquality
  | Equality
  | PureEquality

datatype horn =
    Trivial
  | Unit
  | DoubleHorn
  | Horn
  | NegativeHorn
  | NonHorn

type category =
     {propositional : propositional,
      equality : equality,
      horn : horn}

val categorize : problem -> category

val categoryToString : category -> string

end

(**** Original file: Problem.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
(* Make ++, -- and RL nonfix, and bind explode, implode, print, foldl and *)
(* foldr to the Basis Library versions for the code that follows.        *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* CNF PROBLEMS                                                              *)
(* Copyright (c) 2001-2008 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Problem :> Problem =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Problems.                                                                 *)
(* ------------------------------------------------------------------------- *)

type problem =
     {axioms : Thm.clause list,
      conjecture : Thm.clause list};

(* All clauses of a problem: the axioms followed by the conjecture clauses. *)
fun toClauses (prob : problem) = #axioms prob @ #conjecture prob;

(* Size statistics summed over every clause of the problem. *)
fun size prob =
    let
      val cls = toClauses prob

      (* Sum a per-clause measure over all clauses. *)
      fun total measure = List.foldl (fn (cl,n) => n + measure cl) 0 cls
    in
      {clauses = length cls,
       literals = total LiteralSet.size,
       symbols = total LiteralSet.symbols,
       typedSymbols = total LiteralSet.typedSymbols}
    end;

(* Free variables of the axioms together with those of the conjecture. *)
fun freeVars (prob : problem) =
    let
      val axiomVars = LiteralSet.freeVarsList (#axioms prob)
      val conjectureVars = LiteralSet.freeVarsList (#conjecture prob)
    in
      NameSet.union axiomVars conjectureVars
    end;

local
  (* A clause read as the disjunction of its literals. *)
  fun clauseToFormula cl =
      Formula.listMkDisj (LiteralSet.transform Literal.toFormula cl);

  (* The conjunction of a list of clauses, each generalized. *)
  fun closedConj cls =
      Formula.listMkConj (map (Formula.generalize o clauseToFormula) cls);
in
  fun toFormula prob =
      Formula.listMkConj (map clauseToFormula (toClauses prob));

  (* axioms ==> (conjecture ==> false), or axioms ==> false when there are *)
  (* no conjecture clauses.                                                *)
  fun toGoal {axioms,conjecture} =
      let
        val conclusion =
            case conjecture of
              [] => Formula.False
            | _ => Formula.Imp (closedConj conjecture, Formula.False)
      in
        Formula.Imp (closedConj axioms, conclusion)
      end;
end;

fun toString prob = Formula.toString (toFormula prob);

(* ------------------------------------------------------------------------- *)
(* Categorizing problems.                                                    *)
(* ------------------------------------------------------------------------- *)

(* Propositional: every relation symbol is nullary.                  *)
(* EffectivelyPropositional: otherwise, every function symbol is nullary. *)
datatype propositional =
    Propositional
  | EffectivelyPropositional
  | NonPropositional;

(* NonEquality: no equality relation; PureEquality: equality is the only *)
(* relation; Equality: equality together with other relations.           *)
datatype equality =
    NonEquality
  | Equality
  | PureEquality;

(* Trivial: some clause is empty. Unit: all clauses have one literal.   *)
(* DoubleHorn / Horn / NegativeHorn: every clause has at most one       *)
(* positive and at most one negative / at most one positive / at most   *)
(* one negative literal. NonHorn: none of these.                        *)
datatype horn =
    Trivial
  | Unit
  | DoubleHorn
  | Horn
  | NegativeHorn
  | NonHorn;

type category =
     {propositional : propositional,
      equality : equality,
      horn : horn};

(* Classify a problem by term shape, use of equality and clause shape. *)
fun categorize prob =
    let
      val cls = toClauses prob

      (* Union, over all clauses, of the name/arity set selected by f. *)
      fun collect f =
          foldl (fn (cl, set) => NameAritySet.union set (f cl))
            NameAritySet.empty cls

      val rels = collect LiteralSet.relations
      val funs = collect LiteralSet.functions

      val propositional =
          case (NameAritySet.allNullary rels, NameAritySet.allNullary funs) of
            (true, _) => Propositional
          | (false, true) => EffectivelyPropositional
          | (false, false) => NonPropositional

      val equality =
          if NameAritySet.member Atom.eqRelation rels then
            (if NameAritySet.size rels = 1 then PureEquality else Equality)
          else NonEquality

      (* Does a clause have at most one literal satisfying p? *)
      fun atMostOne p cl = LiteralSet.count p cl <= 1

      val horn =
          if List.exists LiteralSet.null cls then Trivial
          else if List.all (fn cl => LiteralSet.size cl = 1) cls then Unit
          else
            case (List.all (atMostOne Literal.positive) cls,
                  List.all (atMostOne Literal.negative) cls) of
              (true, true) => DoubleHorn
            | (true, false) => Horn
            | (false, true) => NegativeHorn
            | (false, false) => NonHorn
    in
      {propositional = propositional,
       equality = equality,
       horn = horn}
    end;

(* Describe a category as three comma-separated phrases. *)
fun categoryToString {propositional,equality,horn} =
    let
      val propText =
          case propositional of
            Propositional => "propositional"
          | EffectivelyPropositional => "effectively propositional"
          | NonPropositional => "non-propositional"

      val equalityText =
          case equality of
            NonEquality => "non-equality"
          | Equality => "equality"
          | PureEquality => "pure equality"

      val hornText =
          case horn of
            Trivial => "trivial"
          | Unit => "unit"
          | DoubleHorn => "horn (and negative horn)"
          | Horn => "horn"
          | NegativeHorn => "negative horn"
          | NonHorn => "non-horn"
    in
      propText ^ ", " ^ equalityText ^ ", " ^ hornText
    end;

end
end;

(**** Original file: TermNet.sig ****)

(* ========================================================================= *)
(* MATCHING AND UNIFICATION FOR SETS OF FIRST ORDER LOGIC TERMS              *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature TermNet =
sig

(* ------------------------------------------------------------------------- *)
(* A type of term sets that can be efficiently matched and unified.          *)
(* ------------------------------------------------------------------------- *)

(* When fifo is true, query results are sorted by insertion order; otherwise
   they come back in traversal order. *)
type parameters = {fifo : bool}

(* A collection of (term, value) entries indexed by the shape of the terms. *)
type 'a termNet

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

val new : parameters -> 'a termNet

val null : 'a termNet -> bool

(* Number of stored entries. *)
val size : 'a termNet -> int

val insert : 'a termNet -> Metis.Term.term * 'a -> 'a termNet

(* Inserts the entries in list order. *)
val fromList : parameters -> (Metis.Term.term * 'a) list -> 'a termNet

(* Keeps only the entries whose value satisfies the predicate. *)
val filter : ('a -> bool) -> 'a termNet -> 'a termNet

val toString : 'a termNet -> string

val pp : 'a Metis.Print.pp -> 'a termNet Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* Matching and unification queries.                                         *)
(*                                                                           *)
(* These functions return OVER-APPROXIMATIONS!                               *)
(* Filter afterwards to get the precise set of satisfying values.            *)
(* ------------------------------------------------------------------------- *)

(* Values whose stored term may match the query term (stored term used as
   the pattern). *)
val match : 'a termNet -> Metis.Term.term -> 'a list

(* Values whose stored term may be matched by the query term (query term
   used as the pattern). *)
val matched : 'a termNet -> Metis.Term.term -> 'a list

(* Values whose stored term may unify with the query term. *)
val unify : 'a termNet -> Metis.Term.term -> 'a list

end

(**** Original file: TermNet.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* MATCHING AND UNIFICATION FOR SETS OF FIRST ORDER LOGIC TERMS              *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure TermNet :> TermNet =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Anonymous variables.                                                      *)
(* ------------------------------------------------------------------------- *)

(* The variable used to print qterm positions that have lost their variable
   names; it is named by a single underscore. *)
val (anonymousName, anonymousVar) =
    let
      val underscore = Name.fromString "_"
    in
      (underscore, Term.Var underscore)
    end;

(* ------------------------------------------------------------------------- *)
(* Quotient terms.                                                           *)
(* ------------------------------------------------------------------------- *)

(* A quotient term forgets variable names: every variable becomes Var, and
   each function symbol is paired with its arity (see termToQterm). *)
datatype qterm =
    Var
  | Fn of NameArity.nameArity * qterm list;

local
  (* Compare a worklist of qterm pairs lexicographically.  Pairs for which
     Portable.pointerEqual holds are skipped as a fast path.  Two Var
     positions are equal, so the comparison must continue with the remaining
     pairs rather than stop.  Portable.pointerEqual may report false for
     them, for example on a system without real pointer equality.  In that
     case, returning EQUAL at once would ignore every later position and
     identify distinct qterms such as f(_,a) and f(_,b). *)
  fun cmp [] = EQUAL
    | cmp (q1_q2 :: qs) =
      if Portable.pointerEqual q1_q2 then cmp qs
      else
        case q1_q2 of
          (Var,Var) => cmp qs
        | (Var, Fn _) => LESS
        | (Fn _, Var) => GREATER
        | (Fn f1, Fn f2) => fnCmp f1 f2 qs

  (* Compare two function nodes: first by name/arity, then by their argument
     lists (equal in length once the name/arity pairs agree), then by the
     rest of the worklist. *)
  and fnCmp (n1,q1) (n2,q2) qs =
    case NameArity.compare (n1,n2) of
      LESS => LESS
    | EQUAL => cmp (zip q1 q2 @ qs)
    | GREATER => GREATER;
in
  fun compareQterm q1_q2 = cmp [q1_q2];

  fun compareFnQterm (f1,f2) = fnCmp f1 f2 [];
end;

(* Structural equality of qterms, derived from compareQterm. *)
fun equalQterm q1 q2 =
    case compareQterm (q1,q2) of EQUAL => true | _ => false;

(* Structural equality of function nodes, derived from compareFnQterm. *)
fun equalFnQterm f1 f2 =
    case compareFnQterm (f1,f2) of EQUAL => true | _ => false;

(* Forget variable names, keeping the term's shape and each function
   symbol's arity. *)
fun termToQterm tm =
    case tm of
      Term.Var _ => Var
    | Term.Fn (f, args) => Fn ((f, length args), map termToQterm args);

local
  (* Worklist check that the pattern side (left) matches the other side: a
     pattern Var accepts anything, and a pattern function node needs a
     function node with the same name/arity whose arguments match. *)
  fun loop [] = true
    | loop ((pat, qtm) :: pending) =
      case (pat, qtm) of
        (Var, _) => loop pending
      | (Fn _, Var) => false
      | (Fn (f, patArgs), Fn (g, args)) =>
        NameArity.equal f g andalso loop (zip patArgs args @ pending);
in
  fun matchQtermQterm qtm qtm' = loop [(qtm,qtm')];
end;

local
  (* Worklist check that a qterm pattern matches a term: a qterm Var accepts
     anything, and a qterm function node needs a term function with the
     same name and number of arguments. *)
  fun loop [] = true
    | loop ((qtm, tm) :: pending) =
      case (qtm, tm) of
        (Var, _) => loop pending
      | (Fn _, Term.Var _) => false
      | (Fn ((f, arity), qargs), Term.Fn (g, args)) =>
        Name.equal f g andalso arity = length args andalso
        loop (zip qargs args @ pending);
in
  fun matchQtermTerm qtm tm = loop [(qtm,tm)];
end;

local
  (* Worklist match of a term pattern against a qterm, threading qsub (term
     variable -> qterm).  A repeated variable must meet an equal qterm. *)
  fun loop qsub [] = SOME qsub
    | loop qsub ((tm, qtm) :: pending) =
      case (tm, qtm) of
        (Term.Var v, _) =>
        (case NameMap.peek qsub v of
           NONE => loop (NameMap.insert qsub (v, qtm)) pending
         | SOME bound =>
           if equalQterm qtm bound then loop qsub pending else NONE)
      | (Term.Fn _, Var) => NONE
      | (Term.Fn (f, args), Fn ((g, arity), qargs)) =>
        if Name.equal f g andalso length args = arity
        then loop qsub (zip args qargs @ pending)
        else NONE;
in
  fun matchTermQterm qsub tm qtm = loop qsub [(tm,qtm)];
end;

local
  (* Unify two qterms: a Var on either side yields the other side, and two
     function nodes must agree on name/arity (Error is raised otherwise),
     their arguments being unified pairwise. *)
  fun qv Var x = x
    | qv x Var = x
    | qv (Fn (f,a)) (Fn (g,b)) =
      let
        val _ = NameArity.equal f g orelse raise Error "TermNet.qv"
      in
        Fn (f, zipWith qv a b)
      end;

  (* Extend qsub (term variable -> qterm) so that each qterm in the worklist
     unifies with its paired term.  A qterm Var imposes nothing; a term
     variable already in qsub has its old qterm combined with the new one
     via qv.  Raises Error on a clash. *)
  fun qu qsub [] = qsub
    | qu qsub ((Var, _) :: rest) = qu qsub rest
    | qu qsub ((qtm, Term.Var v) :: rest) =
      let
        val qtm =
            case NameMap.peek qsub v of NONE => qtm | SOME qtm' => qv qtm qtm'
      in
        qu (NameMap.insert qsub (v,qtm)) rest
      end
    | qu qsub ((Fn ((f,n),a), Term.Fn (g,b)) :: rest) =
      if Name.equal f g andalso n = length b then qu qsub (zip a b @ rest)
      else raise Error "TermNet.qu";
in
  (* NONE when the qterms do not unify (Useful.total presumably turns the
     Error raised by qv into NONE). *)
  fun unifyQtermQterm qtm qtm' = total (qv qtm) qtm';

  (* NONE when the qterm and term do not unify under qsub. *)
  fun unifyQtermTerm qsub qtm tm = total (qu qsub) [(qtm,tm)];
end;

local
  (* Turn a qterm back into a term, printing every variable position as the
     anonymous variable. *)
  fun toTerm qtm =
      case qtm of
        Var => anonymousVar
      | Fn ((f, _), args) => Term.Fn (f, map toTerm args);
in
  val ppQterm = Print.ppMap toTerm Term.pp;
end;

(* ------------------------------------------------------------------------- *)
(* A type of term sets that can be efficiently matched and unified.          *)
(* ------------------------------------------------------------------------- *)

type parameters = {fifo : bool};

(* The indexing tree.  Result holds the entries at a leaf; Single is a path
   segment carrying exactly one qterm; Multiple branches on an optional Var
   child and one child per function symbol/arity, whose arguments follow on
   the path. *)
datatype 'a net =
    Result of 'a list
  | Single of qterm * 'a net
  | Multiple of 'a net option * 'a net NameArityMap.map;

(* A term net: its parameters, the insertion counter used to number new
   entries, and an optional root paired with the number of stored entries.
   Entries are stored as (insertion number, value). *)
datatype 'a termNet = Net of parameters * int * (int * (int * 'a) net) option;

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

(* An empty net whose insertion counter starts at zero. *)
fun new parm = Net (parm, 0, NONE);

local
  (* Count the entries stored in the leaves of a net. *)
  fun count (Result entries) = length entries
    | count (Single (_, next)) = count next
    | count (Multiple (varBranch, fnBranches)) =
      let
        val fromVar = case varBranch of NONE => 0 | SOME next => count next
      in
        NameArityMap.foldl (fn (_, next, sum) => sum + count next)
          fromVar fnBranches
      end;
in
  (* Pair an optional net with its entry count. *)
  fun netSize root = Option.map (fn n => (count n, n)) root;
end;

(* The cached number of entries. *)
fun size (Net (_, _, root)) =
    case root of
      NONE => 0
    | SOME (entries, _) => entries;

fun null net = (size net = 0);

(* A chain of Single nodes, one per qterm, ending in the given net. *)
fun singles qtms a = List.foldr Single a qtms;

local
  (* Split an optional (size, net) root into its size and optional net. *)
  fun pre NONE = (0,NONE)
    | pre (SOME (i,n)) = (i, SOME n);

  (* add a qtms n threads the qterm path qtms through net n and attaches the
     leaf a at its end.  Newly added entries go before the existing ones at
     a leaf.  When a Single's qterm differs from the next path element, the
     existing subtree is first re-homed under its qterm in a fresh Multiple
     node, and the insertion then continues into that node. *)
  fun add (Result l) [] (Result l') = Result (l @ l')
    | add a (input1 as qtm :: qtms) (Single (qtm',n)) =
      if equalQterm qtm qtm' then Single (qtm, add a qtms n)
      else add a input1 (add n [qtm'] (Multiple (NONE, NameArityMap.new ())))
    | add a (Var :: qtms) (Multiple (vs,fs)) =
      Multiple (SOME (oadd a qtms vs), fs)
    | add a (Fn (f,l) :: qtms) (Multiple (vs,fs)) =
      (* The function's arguments are placed on the path before the rest. *)
      let
        val n = NameArityMap.peek fs f
      in
        Multiple (vs, NameArityMap.insert fs (f, oadd a (l @ qtms) n))
      end
    | add _ _ _ = raise Bug "TermNet.insert: Match"

  (* Like add, but a missing subtree becomes a chain of Single nodes. *)
  and oadd a qtms NONE = singles qtms a
    | oadd a qtms (SOME n) = add a qtms n;

  (* Add the entry a under qtm and increment the cached entry count. *)
  fun ins a qtm (i,n) = SOME (i + 1, oadd (Result [a]) [qtm] n);
in
  (* The entry is stored as (k, a), where k is the current insertion
     counter; the counter is then incremented. *)
  fun insert (Net (p,k,n)) (tm,a) =
      Net (p, k + 1, ins (k,a) (termToQterm tm) (pre n))
      handle Error _ => raise Bug "TermNet.insert: should never fail";
end;

(* Insert each (term, value) pair, in list order, into a fresh net. *)
fun fromList parm l =
    List.foldl (fn (entry, net) => insert net entry) (new parm) l;

(* Keep only the entries whose value satisfies pred, pruning branches that
   become empty and recomputing the cached entry count.  The insertion
   counter is unchanged, so later insertions still get fresh numbers. *)
fun filter pred =
    let
      fun filt (Result l) =
          (case List.filter (fn (_,a) => pred a) l of
             [] => NONE
           | l => SOME (Result l))
        | filt (Single (qtm,n)) =
          (case filt n of
             NONE => NONE
           | SOME n => SOME (Single (qtm,n)))
        | filt (Multiple (vs,fs)) =
          let
            val vs = Option.mapPartial filt vs

            val fs = NameArityMap.mapPartial (fn (_,n) => filt n) fs
          in
            if not (Option.isSome vs) andalso NameArityMap.null fs then NONE
            else SOME (Multiple (vs,fs))
          end

      fun filterNet (net as Net (_,_,NONE)) = net
        | filterNet (Net (p, k, SOME (_,n))) = Net (p, k, netSize (filt n))
    in
      (* The handler must sit inside the returned function.  Attached to the
         let expression, it would only guard building this closure and
         never the filtering itself. *)
      fn net =>
         (filterNet net
          handle Error _ => raise Bug "TermNet.filter: should never fail")
    end;

(* A short summary showing only the number of entries. *)
fun toString net = String.concat ["TermNet[", Int.toString (size net), "]"];

(* ------------------------------------------------------------------------- *)
(* Specialized fold operations to support matching and unification.          *)
(* ------------------------------------------------------------------------- *)

local
  (* A stack (ks, fs, qtms) rebuilds a qterm from the pieces read along a
     net path in prefix order.  fs are the function symbols still waiting
     for arguments, innermost first.  ks are how many arguments each of
     them still needs.  qtms are the completed qterms, most recent first.
     norm closes every innermost function whose count has reached 0.
     revDivide presumably splits off its n arguments, restoring their
     original order. *)
  fun norm (0 :: ks, (f as (_,n)) :: fs, qtms) =
      let
        val (a,qtms) = revDivide qtms n
      in
        addQterm (Fn (f,a)) (ks,fs,qtms)
      end
    | norm stack = stack

  (* Push a completed qterm, using up one argument slot of the innermost
     open function, if there is one. *)
  and addQterm qtm (ks,fs,qtms) =
      let
        val ks = case ks of [] => [] | k :: ks => (k - 1) :: ks
      in
        norm (ks, fs, qtms)
      end

  (* Open a function symbol that expects n arguments; a constant (n = 0) is
     closed at once by norm. *)
  and addFn (f as (_,n)) (ks,fs,qtms) = norm (n :: ks, f :: fs, qtms);
in
  val stackEmpty = ([],[],[]);

  val stackAddQterm = addQterm;

  val stackAddFn = addFn;

  (* The single qterm of a fully closed stack; anything else is a bug. *)
  fun stackValue ([],[],[qtm]) = qtm
    | stackValue _ = raise Bug "TermNet.stackValue";
end;

local
  (* Worklist items (n, stack, net): n more qterm positions must be read
     from net before the term rebuilt in stack is complete.  When n reaches
     0, inc is applied to the completed qterm, the subnet below it and the
     accumulator. *)
  fun fold _ acc [] = acc
    | fold inc acc ((0,stack,net) :: rest) =
      fold inc (inc (stackValue stack, net, acc)) rest
    | fold inc acc ((n, stack, Single (qtm,net)) :: rest) =
      fold inc acc ((n - 1, stackAddQterm qtm stack, net) :: rest)
    | fold inc acc ((n, stack, Multiple (v,fns)) :: rest) =
      let
        val n = n - 1

        (* The Var branch fills one position. *)
        val rest =
            case v of
              NONE => rest
            | SOME net => (n, stackAddQterm Var stack, net) :: rest

        (* A function branch of arity k fills one position but opens k
           argument positions. *)
        fun getFns (f as (_,k), net, x) =
            (k + n, stackAddFn f stack, net) :: x
      in
        fold inc acc (NameArityMap.foldr getFns rest fns)
      end
    | fold _ _ _ = raise Bug "TermNet.foldTerms.fold";
in
  (* Fold inc over every complete term stored at the top of net. *)
  fun foldTerms inc acc net = fold inc acc [(1,stackEmpty,net)];
end;

(* Follow the path of pat through a net.  If the whole path is present,
   return inc (pat, subnet below it, acc); otherwise return acc unchanged. *)
fun foldEqualTerms pat inc acc =
    let
      fun fold ([],net) = inc (pat,net,acc)
        | fold (pat :: pats, Single (qtm,net)) =
          if equalQterm pat qtm then fold (pats,net) else acc
        | fold (Var :: pats, Multiple (v,_)) =
          (case v of NONE => acc | SOME net => fold (pats,net))
        | fold (Fn (f,a) :: pats, Multiple (_,fns)) =
          (case NameArityMap.peek fns f of
             NONE => acc
           | SOME net => fold (a @ pats, net))
        | fold _ = raise Bug "TermNet.foldEqualTerms.fold";
    in
      fn net => fold ([pat],net)
    end;

local
  (* Worklist items (pats, stack, net): pats are the pattern qterms still to
     be unified against net, and stack rebuilds the unified qterm.  When
     pats is empty, inc is applied to that qterm, the subnet below it and
     the accumulator. *)
  fun fold _ acc [] = acc
    | fold inc acc (([],stack,net) :: rest) =
      fold inc (inc (stackValue stack, net, acc)) rest
    | fold inc acc ((Var :: pats, stack, net) :: rest) =
      (* A pattern variable unifies with every term stored at this point. *)
      let
        fun harvest (qtm,n,l) = (pats, stackAddQterm qtm stack, n) :: l
      in
        fold inc acc (foldTerms harvest rest net)
      end
    | fold inc acc ((pat :: pats, stack, Single (qtm,net)) :: rest) =
      (case unifyQtermQterm pat qtm of
         NONE => fold inc acc rest
       | SOME qtm =>
         fold inc acc ((pats, stackAddQterm qtm stack, net) :: rest))
    | fold
        inc acc
        (((pat as Fn (f,a)) :: pats, stack, Multiple (v,fns)) :: rest) =
      (* A function pattern unifies with a stored Var (yielding the pattern
         itself) and with the branch for its own symbol and arity. *)
      let
        val rest =
            case v of
              NONE => rest
            | SOME net => (pats, stackAddQterm pat stack, net) :: rest

        val rest =
            case NameArityMap.peek fns f of
              NONE => rest
            | SOME net => (a @ pats, stackAddFn f stack, net) :: rest
      in
        fold inc acc rest
      end
    | fold _ _ _ = raise Bug "TermNet.foldUnifiableTerms.fold";
in
  (* Fold inc over the stored terms at the top of net that unify with pat,
     each passed as the unified qterm. *)
  fun foldUnifiableTerms pat inc acc net =
      fold inc acc [([pat],stackEmpty,net)];
end;

(* ------------------------------------------------------------------------- *)
(* Matching and unification queries.                                         *)
(*                                                                           *)
(* These functions return OVER-APPROXIMATIONS!                               *)
(* Filter afterwards to get the precise set of satisfying values.            *)
(* ------------------------------------------------------------------------- *)

local
  fun byId ((id1, _), (id2, _)) = Int.compare (id1, id2);
in
  (* Drop the insertion numbers from query results.  When the fifo
     parameter is set, the results are first sorted into insertion order. *)
  fun finally ({fifo, ...} : parameters) l =
      let
        val ordered = if fifo then sort byId l else l
      in
        map snd ordered
      end;
end;

local
  (* Worklist items (net, tms): tms are the query terms still to be read
     against net, in prefix order.  The stored terms act as patterns: a
     stored Var accepts any query term, whereas a query variable can only
     follow a stored Var. *)
  fun mat acc [] = acc
    | mat acc ((Result l, []) :: rest) = mat (l @ acc) rest
    | mat acc ((Single (qtm,n), tm :: tms) :: rest) =
      mat acc (if matchQtermTerm qtm tm then (n,tms) :: rest else rest)
    | mat acc ((Multiple (vs,fs), tm :: tms) :: rest) =
      let
        (* The stored Var branch applies whatever tm is. *)
        val rest = case vs of NONE => rest | SOME n => (n,tms) :: rest

        (* A function query term may also follow the branch for its own
           symbol and arity; its arguments are read next. *)
        val rest =
            case tm of
              Term.Var _ => rest
            | Term.Fn (f,l) =>
              case NameArityMap.peek fs (f, length l) of
                NONE => rest
              | SOME n => (n, l @ tms) :: rest
      in
        mat acc rest
      end
    | mat _ _ = raise Bug "TermNet.match: Match";
in
  fun match (Net (_,_,NONE)) _ = []
    | match (Net (p, _, SOME (_,n))) tm =
      finally p (mat [] [(n,[tm])])
      handle Error _ => raise Bug "TermNet.match: should never fail";
end;

local
  (* Bind the not-yet-seen query variable v to the stored qterm found at
     this position. *)
  fun unseenInc qsub v tms (qtm,net,rest) =
      (NameMap.insert qsub (v,qtm), net, tms) :: rest;

  (* The query variable is already bound, so qsub is kept unchanged. *)
  fun seenInc qsub tms (_,net,rest) = (qsub,net,tms) :: rest;

  (* Worklist items (qsub, net, tms): the query terms tms act as patterns
     against the stored terms.  qsub records the qterm each query variable
     has matched so far, so repeated variables must meet equal qterms. *)
  fun mat acc [] = acc
    | mat acc ((_, Result l, []) :: rest) = mat (l @ acc) rest
    | mat acc ((qsub, Single (qtm,net), tm :: tms) :: rest) =
      (case matchTermQterm qsub tm qtm of
         NONE => mat acc rest
       | SOME qsub => mat acc ((qsub,net,tms) :: rest))
    | mat acc ((qsub, net as Multiple _, Term.Var v :: tms) :: rest) =
      (* A fresh variable matches every stored term at this position; a
         bound one matches only the stored terms equal to its binding. *)
      (case NameMap.peek qsub v of
         NONE => mat acc (foldTerms (unseenInc qsub v tms) rest net)
       | SOME qtm => mat acc (foldEqualTerms qtm (seenInc qsub tms) rest net))
    | mat acc ((qsub, Multiple (_,fns), Term.Fn (f,a) :: tms) :: rest) =
      (* A function pattern cannot match a stored Var, so only the branch
         for its own symbol and arity is followed. *)
      let
        val rest =
            case NameArityMap.peek fns (f, length a) of
              NONE => rest
            | SOME net => (qsub, net, a @ tms) :: rest
      in
        mat acc rest
      end
    | mat _ _ = raise Bug "TermNet.matched.mat";
in
  fun matched (Net (_,_,NONE)) _ = []
    | matched (Net (parm, _, SOME (_,net))) tm =
      finally parm (mat [] [(NameMap.new (), net, [tm])])
      handle Error _ => raise Bug "TermNet.matched: should never fail";
end;

local
  (* Record the qterm found at this position as the binding of query
     variable v. *)
  fun inc qsub v tms (qtm,net,rest) =
      (NameMap.insert qsub (v,qtm), net, tms) :: rest;

  (* Worklist items (qsub, net, tms): tms are the query terms still to be
     unified with net, and qsub maps query variables to the qterms they have
     been unified with so far. *)
  fun mat acc [] = acc
    | mat acc ((_, Result l, []) :: rest) = mat (l @ acc) rest
    | mat acc ((qsub, Single (qtm,net), tm :: tms) :: rest) =
      (case unifyQtermTerm qsub qtm tm of
         NONE => mat acc rest
       | SOME qsub => mat acc ((qsub,net,tms) :: rest))
    | mat acc ((qsub, net as Multiple _, Term.Var v :: tms) :: rest) =
      (* A fresh variable unifies with every stored term here.  A bound one
         unifies only with stored terms that unify with its binding, and it
         is then rebound to the qterm that foldUnifiableTerms rebuilds. *)
      (case NameMap.peek qsub v of
         NONE => mat acc (foldTerms (inc qsub v tms) rest net)
       | SOME qtm => mat acc (foldUnifiableTerms qtm (inc qsub v tms) rest net))
    | mat acc ((qsub, Multiple (v,fns), Term.Fn (f,a) :: tms) :: rest) =
      (* A function query term unifies with a stored Var as well as with the
         branch for its own symbol and arity. *)
      let
        val rest = case v of NONE => rest | SOME net => (qsub,net,tms) :: rest

        val rest =
            case NameArityMap.peek fns (f, length a) of
              NONE => rest
            | SOME net => (qsub, net, a @ tms) :: rest
      in
        mat acc rest
      end
    | mat _ _ = raise Bug "TermNet.unify.mat";
in
  fun unify (Net (_,_,NONE)) _ = []
    | unify (Net (parm, _, SOME (_,net))) tm =
      finally parm (mat [] [(NameMap.new (), net, [tm])])
      handle Error _ => raise Bug "TermNet.unify: should never fail";
end;

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

local
  (* Collect the leaf entries under a completed qterm as
     (insertion number, (qterm, value)). *)
  fun collect (qtm, Result entries, acc) =
      foldl (fn ((id, a), acc') => (id, (qtm, a)) :: acc') acc entries
    | collect _ = raise Bug "TermNet.pp.inc";

  (* All entries as (qterm, value) pairs, ordered as query results are. *)
  fun toList net =
      case net of
        Net (_, _, NONE) => []
      | Net (parm, _, SOME (_, root)) =>
        finally parm (foldTerms collect [] root);
in
  fun pp ppA =
      let
        val ppEntry = Print.ppOp2 " |->" ppQterm ppA
      in
        Print.ppMap toList (Print.ppList ppEntry)
      end;
end;

end
end;

(**** Original file: AtomNet.sig ****)

(* ========================================================================= *)
(* MATCHING AND UNIFICATION FOR SETS OF FIRST ORDER LOGIC ATOMS              *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature AtomNet =
sig

(* ------------------------------------------------------------------------- *)
(* A type of atom sets that can be efficiently matched and unified.          *)
(* ------------------------------------------------------------------------- *)

(* When fifo is true, query results are sorted by insertion order. *)
type parameters = {fifo : bool}

(* A collection of (atom, value) entries, indexed like a term net. *)
type 'a atomNet

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

val new : parameters -> 'a atomNet

(* Number of stored entries. *)
val size : 'a atomNet -> int

val insert : 'a atomNet -> Metis.Atom.atom * 'a -> 'a atomNet

(* Inserts the entries in list order. *)
val fromList : parameters -> (Metis.Atom.atom * 'a) list -> 'a atomNet

(* Keeps only the entries whose value satisfies the predicate. *)
val filter : ('a -> bool) -> 'a atomNet -> 'a atomNet

val toString : 'a atomNet -> string

val pp : 'a Metis.Print.pp -> 'a atomNet Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* Matching and unification queries.                                         *)
(*                                                                           *)
(* These functions return OVER-APPROXIMATIONS!                               *)
(* Filter afterwards to get the precise set of satisfying values.            *)
(* ------------------------------------------------------------------------- *)

(* Values whose stored atom may match the query atom. *)
val match : 'a atomNet -> Metis.Atom.atom -> 'a list

(* Values whose stored atom may be matched by the query atom. *)
val matched : 'a atomNet -> Metis.Atom.atom -> 'a list

(* Values whose stored atom may unify with the query atom. *)
val unify : 'a atomNet -> Metis.Atom.atom -> 'a list

end

(**** Original file: AtomNet.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* MATCHING AND UNIFICATION FOR SETS OF FIRST ORDER LOGIC ATOMS              *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure AtomNet :> AtomNet =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* An atom (relation, arguments) is stored as the function term with the
   same name and arguments. *)
fun atomToTerm (rel, args) = Term.Fn (rel, args);

(* Inverse of atomToTerm; a variable never encodes an atom. *)
fun termToAtom tm =
    case tm of
      Term.Fn (rel, args) => (rel, args)
    | Term.Var _ => raise Bug "AtomNet.termToAtom";

(* ------------------------------------------------------------------------- *)
(* A type of atom sets that can be efficiently matched and unified.          *)
(* ------------------------------------------------------------------------- *)

(* An atom net is a term net whose keys are the term forms of atoms. *)
type parameters = TermNet.parameters;

type 'a atomNet = 'a TermNet.termNet;

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Operations that do not look at keys are the term-net ones. *)
val new = TermNet.new;

val size = TermNet.size;

val filter = TermNet.filter;

val pp = TermNet.pp;

(* Store the entry under the term form of its atom. *)
fun insert net (atm, a) = TermNet.insert net (atomToTerm atm, a);

(* Insert each (atom, value) pair, in list order, into a fresh net. *)
fun fromList parm l =
    List.foldl (fn (entry, net) => insert net entry) (new parm) l;

fun toString net = String.concat ["AtomNet[", Int.toString (size net), "]"];

(* ------------------------------------------------------------------------- *)
(* Matching and unification queries.                                         *)
(*                                                                           *)
(* These functions return OVER-APPROXIMATIONS!                               *)
(* Filter afterwards to get the precise set of satisfying values.            *)
(* ------------------------------------------------------------------------- *)

local
  (* Lift a term-net query to atoms by querying with the atom's term form. *)
  fun viaTerm query net atm = query net (atomToTerm atm);
in
  fun match net atm = viaTerm TermNet.match net atm;

  fun matched net atm = viaTerm TermNet.matched net atm;

  fun unify net atm = viaTerm TermNet.unify net atm;
end;

end
end;

(**** Original file: LiteralNet.sig ****)

(* ========================================================================= *)
(* MATCHING AND UNIFICATION FOR SETS OF FIRST ORDER LOGIC LITERALS           *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature LiteralNet =
sig

(* ------------------------------------------------------------------------- *)
(* A type of literal sets that can be efficiently matched and unified.       *)
(* ------------------------------------------------------------------------- *)

(* When fifo is true, query results are sorted by insertion order. *)
type parameters = {fifo : bool}

(* A collection of (literal, value) entries; positive and negative literals
   are indexed separately. *)
type 'a literalNet

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

val new : parameters -> 'a literalNet

(* Total number of stored entries. *)
val size : 'a literalNet -> int

(* Number of stored entries of each polarity. *)
val profile : 'a literalNet -> {positive : int, negative : int}

val insert : 'a literalNet -> Metis.Literal.literal * 'a -> 'a literalNet

(* Inserts the entries in list order. *)
val fromList : parameters -> (Metis.Literal.literal * 'a) list -> 'a literalNet

(* Keeps only the entries whose value satisfies the predicate. *)
val filter : ('a -> bool) -> 'a literalNet -> 'a literalNet

val toString : 'a literalNet -> string

val pp : 'a Metis.Print.pp -> 'a literalNet Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* Matching and unification queries.                                         *)
(*                                                                           *)
(* These functions return OVER-APPROXIMATIONS!                               *)
(* Filter afterwards to get the precise set of satisfying values.            *)
(* ------------------------------------------------------------------------- *)

(* Each query only looks at stored literals of the query's polarity. *)
val match : 'a literalNet -> Metis.Literal.literal -> 'a list

val matched : 'a literalNet -> Metis.Literal.literal -> 'a list

val unify : 'a literalNet -> Metis.Literal.literal -> 'a list

end

(**** Original file: LiteralNet.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* MATCHING AND UNIFICATION FOR SETS OF FIRST ORDER LOGIC LITERALS           *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure LiteralNet :> LiteralNet =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of literal sets that can be efficiently matched and unified.       *)
(* ------------------------------------------------------------------------- *)

type parameters = AtomNet.parameters;

(* Two atom nets: one for the atoms of positive literals, one for the atoms
   of negative literals. *)
type 'a literalNet =
    {positive : 'a AtomNet.atomNet,
     negative : 'a AtomNet.atomNet};

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Both polarities start out as empty atom nets with the same parameters. *)
fun new parm =
    let
      fun emptyNet () = AtomNet.new parm
    in
      {positive = emptyNet (), negative = emptyNet ()}
    end;

(* Entry counts per polarity. *)
fun profile ({positive, negative} : 'a literalNet) =
    {positive = AtomNet.size positive, negative = AtomNet.size negative};

(* Total entry count over both polarities. *)
fun size net =
    let
      val {positive, negative} = profile net
    in
      positive + negative
    end;

(* Store the entry's atom in the net for its literal's polarity. *)
fun insert {positive,negative} ((pol,atm),a) =
    if pol then {positive = AtomNet.insert positive (atm,a), negative = negative}
    else {positive = positive, negative = AtomNet.insert negative (atm,a)};

(* Insert each (literal, value) pair, in list order, into a fresh net. *)
fun fromList parm l =
    List.foldl (fn (entry, net) => insert net entry) (new parm) l;

(* Filter both polarities with the same predicate. *)
fun filter pred {positive,negative} =
    let
      val keep = AtomNet.filter pred
    in
      {positive = keep positive, negative = keep negative}
    end;

fun toString net = String.concat ["LiteralNet[", Int.toString (size net), "]"];

(* Print the positive net, then the negative one. *)
fun pp ppA =
    let
      val ppNet = AtomNet.pp ppA
    in
      Print.ppMap
        (fn {positive,negative} => (positive,negative))
        (Print.ppOp2 " + NEGATIVE" ppNet ppNet)
    end;

(* ------------------------------------------------------------------------- *)
(* Matching and unification queries.                                         *)
(*                                                                           *)
(* These functions return OVER-APPROXIMATIONS!                               *)
(* Filter afterwards to get the precise set of satisfying values.            *)
(* ------------------------------------------------------------------------- *)

local
  (* The atom net holding literals of the given polarity. *)
  fun side ({positive, negative} : 'a literalNet) pol =
      if pol then positive else negative;
in
  fun match net (pol, atm) = AtomNet.match (side net pol) atm;

  fun matched net (pol, atm) = AtomNet.matched (side net pol) atm;

  fun unify net (pol, atm) = AtomNet.unify (side net pol) atm;
end;

end
end;

(**** Original file: Subsume.sig ****)

(* ========================================================================= *)
(* SUBSUMPTION CHECKING FOR FIRST ORDER LOGIC CLAUSES                        *)
(* Copyright (c) 2002-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Subsume =
sig

(* ------------------------------------------------------------------------- *)
(* A type of clause sets that supports efficient subsumption checking.       *)
(* ------------------------------------------------------------------------- *)

(* A set of clauses, each stored with a value of type 'a. *)
type 'a subsume

val new : unit -> 'a subsume

(* Number of stored clauses (empty, unit and non-unit together). *)
val size : 'a subsume -> int

val insert : 'a subsume -> Metis.Thm.clause * 'a -> 'a subsume

val filter : ('a -> bool) -> 'a subsume -> 'a subsume

val pp : 'a subsume Metis.Print.pp

val toString : 'a subsume -> string

(* ------------------------------------------------------------------------- *)
(* Subsumption checking.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Returns a stored clause with a substitution and its value, or NONE.  The
   first argument is a predicate over such candidates. *)
val subsumes :
    (Metis.Thm.clause * Metis.Subst.subst * 'a -> bool) -> 'a subsume -> Metis.Thm.clause ->
    (Metis.Thm.clause * Metis.Subst.subst * 'a) option

val isSubsumed : 'a subsume -> Metis.Thm.clause -> bool

val strictlySubsumes :  (* exclude subsuming clauses with more literals *)
    (Metis.Thm.clause * Metis.Subst.subst * 'a -> bool) -> 'a subsume -> Metis.Thm.clause ->
    (Metis.Thm.clause * Metis.Subst.subst * 'a) option

val isStrictlySubsumed : 'a subsume -> Metis.Thm.clause -> bool

(* ------------------------------------------------------------------------- *)
(* Single clause versions.                                                   *)
(* ------------------------------------------------------------------------- *)

val clauseSubsumes : Metis.Thm.clause -> Metis.Thm.clause -> Metis.Subst.subst option

val clauseStrictlySubsumes : Metis.Thm.clause -> Metis.Thm.clause -> Metis.Subst.subst option

end

(**** Original file: Subsume.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* SUBSUMPTION CHECKING FOR FIRST ORDER LOGIC CLAUSES                        *)
(* Copyright (c) 2002-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Subsume :> Subsume =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Find the first element satisfying pred.  Return it together with the
   other elements in their original order, or NONE if there is none. *)
fun findRest pred =
    let
      fun search _ [] = NONE
        | search seen (x :: rest) =
          if pred x then SOME (x, List.rev seen @ rest)
          else search (x :: seen) rest
    in
      search []
    end;

(* The given literals preceded by the symmetric forms (per Literal.sym) of
   those that have one, in reverse order of the literals they came from. *)
fun clauseSym lits =
    List.revAppend (List.mapPartial (total Literal.sym) lits, lits);

(* The literals of a clause, those with more typed symbols first. *)
fun sortClause cl =
    sortMap Literal.typedSymbols (revCompare Int.compare) (LiteralSet.toList cl);

(* A test for literals that unify with neither lit nor its symmetric form
   (when it has one). *)
fun incompatible lit =
    let
      val lits = clauseSym [lit]
    in
      fn lit' =>
         List.all (fn l => not (can (Literal.unify Subst.empty lit') l)) lits
    end;

(* ------------------------------------------------------------------------- *)
(* Clause ids and lengths.                                                   *)
(* ------------------------------------------------------------------------- *)

type clauseId = int;

type clauseLength = int;

local
  type idSet = (clauseId * clauseLength) Set.set;

  (* Shorter clauses sort first; equal lengths are ordered by clause id. *)
  fun idCompare ((id1,len1),(id2,len2)) =
      case Int.compare (len1,len2) of
        EQUAL => Int.compare (id1,id2)
      | cmp => cmp;
in
  val idSetEmpty : idSet = Set.empty idCompare;

  fun idSetAdd (idLen,ids) : idSet = Set.add ids idLen;

  (* Record the clause only when its length does not exceed max. *)
  fun idSetAddMax max (idLen as (_,len), ids) : idSet =
      if len > max then ids else Set.add ids idLen;

  fun idSetIntersect ids1 ids2 : idSet = Set.intersect ids1 ids2;
end;

(* ------------------------------------------------------------------------- *)
(* A type of clause sets that supports efficient subsumption checking.       *)
(* ------------------------------------------------------------------------- *)

(* A clause set indexed for subsumption checking.  Every stored clause    *)
(* carries a payload of type 'a.                                          *)
(*   empty   - clauses with no literals, each stored with Subst.empty.    *)
(*   unit    - single-literal clauses, indexed by that literal.           *)
(*   nonunit - longer clauses.  clauses maps an id to (reordered literal  *)
(*             list, clause, payload).  fstLits and sndLits index each    *)
(*             clause's two chosen literals to (id, clause length).       *)
(*             nextId is the id the next inserted clause will receive.    *)
datatype 'a subsume =
    Subsume of
      {empty : (Thm.clause * Subst.subst * 'a) list,
       unit : (Literal.literal * Thm.clause * 'a)  LiteralNet.literalNet,
       nonunit :
         {nextId : clauseId,
          clauses : (Literal.literal list * Thm.clause * 'a) IntMap.map,
          fstLits : (clauseId * clauseLength) LiteralNet.literalNet,
          sndLits : (clauseId * clauseLength) LiteralNet.literalNet}};

(* An empty subsumption set.  Every literal net is non-FIFO. *)
fun new () =
    let
      fun emptyNet () = LiteralNet.new {fifo = false}

      val nonunit =
          {nextId = 0,
           clauses = IntMap.new (),
           fstLits = emptyNet (),
           sndLits = emptyNet ()}
    in
      Subsume {empty = [], unit = emptyNet (), nonunit = nonunit}
    end;

(* Total number of stored clauses: empty, unit and nonunit. *)
fun size (Subsume {empty, unit, nonunit}) =
    let
      val {clauses, ...} = nonunit
    in
      length empty + LiteralNet.size unit + IntMap.size clauses
    end;
      
(* Add clause cl' with payload a.  Empty clauses are consed onto the empty *)
(* list.  Unit clauses are indexed by their only literal.  Longer clauses  *)
(* get the next id and are indexed by two chosen literals in fstLits and   *)
(* sndLits.                                                                *)
fun insert (Subsume {empty,unit,nonunit}) (cl',a) =
    case sortClause cl' of
      [] =>
      let
        val empty = (cl',Subst.empty,a) :: empty
      in
        Subsume {empty = empty, unit = unit, nonunit = nonunit}
      end
    | [lit] =>
      let
        val unit = LiteralNet.insert unit (lit,(lit,cl',a))
      in
        Subsume {empty = empty, unit = unit, nonunit = nonunit}
      end
    | fstLit :: (nonFstLits as sndLit :: otherLits) =>
      let
        val {nextId,clauses,fstLits,sndLits} = nonunit
        val id_length = (nextId, LiteralSet.size cl')
        (* The first literal in sortClause order is always indexed. *)
        val fstLits = LiteralNet.insert fstLits (fstLit,id_length)
        (* Prefer as second index literal one that cannot unify with fstLit *)
        (* or its Literal.sym form.  Otherwise use the next literal in order. *)
        val (sndLit,otherLits) =
            case findRest (incompatible fstLit) nonFstLits of
              SOME sndLit_otherLits => sndLit_otherLits
            | NONE => (sndLit,otherLits)
        val sndLits = LiteralNet.insert sndLits (sndLit,id_length)
        (* The stored literal list puts the two index literals last. *)
        val lits' = otherLits @ [fstLit,sndLit]
        val clauses = IntMap.insert clauses (nextId,(lits',cl',a))
        val nextId = nextId + 1
        val nonunit = {nextId = nextId, clauses = clauses,
                       fstLits = fstLits, sndLits = sndLits}
      in
        Subsume {empty = empty, unit = unit, nonunit = nonunit}
      end;

(* Keep only the entries whose payload satisfies pred.  The nonunit literal *)
(* indexes are rebuilt only when some nonunit clause was actually removed. *)
(* nextId is left unchanged, so later inserts never reissue a removed id.  *)
fun filter pred (Subsume {empty,unit,nonunit}) =
    let
      val empty = List.filter (pred o #3) empty

      val unit = LiteralNet.filter (pred o #3) unit

      val nonunit =
          let
            val {nextId,clauses,fstLits,sndLits} = nonunit
            val clauses' = IntMap.filter (pred o #3 o snd) clauses
          in
            (* Equal sizes mean nothing was dropped, so the indexes stay valid. *)
            if IntMap.size clauses = IntMap.size clauses' then nonunit
            else
              let
                (* Drop index entries whose clause id no longer exists. *)
                fun predId (id,_) = IntMap.inDomain id clauses'
                val fstLits = LiteralNet.filter predId fstLits
                and sndLits = LiteralNet.filter predId sndLits
              in
                {nextId = nextId, clauses = clauses',
                 fstLits = fstLits, sndLits = sndLits}
              end
          end
    in
      Subsume {empty = empty, unit = unit, nonunit = nonunit}
    end;

(* Summary string showing only the number of stored clauses. *)
fun toString subsume =
    String.concat ["Subsume{", Int.toString (size subsume), "}"];

(* Pretty-print as the summary string above. *)
fun pp subsume =
    let
      val render = Print.ppMap toString Print.ppString
    in
      render subsume
    end;

(* ------------------------------------------------------------------------- *)
(* Subsumption checking.                                                     *)
(* ------------------------------------------------------------------------- *)

local
  (* Accumulate every substitution that matches lit' onto the literal lit. *)
  fun matchLit lit' (lit,acc) =
      case total (Literal.match Subst.empty lit') lit of
        SOME sub => sub :: acc
      | NONE => acc;
in
  (* Check whether clause cl', with literal list lits' and payload a,       *)
  (* subsumes the literal list cl.  Returns the first (cl', sub, a) that     *)
  (* satisfies pred, where sub matches every literal of lits' onto some      *)
  (* literal of cl and all the matches are combined with Subst.union.        *)
  fun genClauseSubsumes pred cl' lits' cl a =
      let
        (* For each literal of lits':                                         *)
        (*   - no match anywhere in cl fails immediately;                     *)
        (*   - a single match is merged into sub at once;                     *)
        (*   - several matches are kept as a choice list.                     *)
        (* The choice lists are finally sorted by length                      *)
        (* (sortMap length Int.compare).                                      *)
        fun mkSubsl acc sub [] = SOME (sub, sortMap length Int.compare acc)
          | mkSubsl acc sub (lit' :: lits') =
            case List.foldl (matchLit lit') [] cl of
              [] => NONE
            | [sub'] =>
              (case total (Subst.union sub) sub' of
                 NONE => NONE
               | SOME sub => mkSubsl acc sub lits')
            | subs => mkSubsl (subs :: acc) sub lits'

        (* Depth-first backtracking over a stack of states.  Each state is    *)
        (* (substitution so far, remaining choice lists).  One alternative is *)
        (* taken from the head choice list, and the untried alternatives are *)
        (* pushed back as a separate state.  A state with no choice lists     *)
        (* left is a complete candidate, which is tested against pred.        *)
        fun search [] = NONE
          | search ((sub,[]) :: others) =
            let
              val x = (cl',sub,a)
            in
              if pred x then SOME x else search others
            end
          | search ((_, [] :: _) :: others) = search others
          | search ((sub, (sub' :: subs) :: subsl) :: others) =
            let
              val others = (sub, subs :: subsl) :: others
            in
              case total (Subst.union sub) sub' of
                NONE => search others
              | SOME sub => search ((sub,subsl) :: others)
            end
      in
        case mkSubsl [] Subst.empty lits' of
          NONE => NONE
        | SOME sub_subsl => search [sub_subsl]
      end;
end;

local
  (* First stored empty-clause entry that satisfies pred. *)
  fun emptySubsumes pred empty = List.find pred empty;

  (* Find the first unit clause whose literal matches onto some literal of  *)
  (* the given list and whose (clause, substitution, payload) entry         *)
  (* satisfies pred.                                                        *)
  fun unitSubsumes pred unit =
      let
        fun subLit lit =
            let
              fun subUnit (lit',cl',a) =
                  case total (Literal.match Subst.empty lit') lit of
                    NONE => NONE
                  | SOME sub =>
                    let
                      val x = (cl',sub,a)
                    in
                      if pred x then SOME x else NONE
                    end
            in
              first subUnit (LiteralNet.match unit lit)
            end
      in
        first subLit
      end;

  (* Candidates are the nonunit clauses whose first and second index        *)
  (* literals both match literals of cl.  When max is SOME n, only clauses  *)
  (* of length at most n are kept.  Each candidate is then checked in full  *)
  (* with genClauseSubsumes, in idSet order.                                *)
  fun nonunitSubsumes pred nonunit max cl =
      let
        val addId = case max of NONE => idSetAdd | SOME n => idSetAddMax n

        fun subLit lits (lit,acc) =
            List.foldl addId acc (LiteralNet.match lits lit)

        val {nextId = _, clauses, fstLits, sndLits} = nonunit

        fun subCl' (id,_) =
            let
              val (lits',cl',a) = IntMap.get clauses id
            in
              genClauseSubsumes pred cl' lits' cl a
            end

        val fstCands = List.foldl (subLit fstLits) idSetEmpty cl
        val sndCands = List.foldl (subLit sndLits) idSetEmpty cl
        val cands = idSetIntersect fstCands sndCands
      in
        Set.firstl subCl' cands
      end;

  (* Try empty clauses, then unit clauses, then nonunit clauses.  A bound   *)
  (* of SOME 0 or SOME 1 skips the later stages, which could only offer     *)
  (* longer clauses.  Before matching, cl is extended with its symmetric    *)
  (* literals (clauseSym).                                                  *)
  fun genSubsumes pred (Subsume {empty,unit,nonunit}) max cl =
      case emptySubsumes pred empty of
        s as SOME _ => s
      | NONE =>
        if max = SOME 0 then NONE
        else
          let
            val cl = clauseSym (LiteralSet.toList cl)
          in
            case unitSubsumes pred unit cl of
              s as SOME _ => s
            | NONE =>
              if max = SOME 1 then NONE
              else nonunitSubsumes pred nonunit max cl
          end;
in
  (* First stored clause that subsumes cl and satisfies pred, if any. *)
  fun subsumes pred subsume cl = genSubsumes pred subsume NONE cl;

  (* As subsumes, but only stored clauses with no more literals than cl   *)
  (* are considered.                                                       *)
  fun strictlySubsumes pred subsume cl =
      genSubsumes pred subsume (SOME (LiteralSet.size cl)) cl;
end;

(*MetisTrace4
val subsumes = fn pred => fn subsume => fn cl =>
    let
      val ppCl = LiteralSet.pp
      val ppSub = Subst.pp
      val () = Print.trace ppCl "Subsume.subsumes: cl" cl
      val result = subsumes pred subsume cl
      val () =
          case result of
            NONE => trace "Subsume.subsumes: not subsumed\n"
          | SOME (cl,sub,_) =>
            (Print.trace ppCl "Subsume.subsumes: subsuming cl" cl;
             Print.trace ppSub "Subsume.subsumes: subsuming sub" sub)
    in
      result
    end;

val strictlySubsumes = fn pred => fn subsume => fn cl =>
    let
      val ppCl = LiteralSet.pp
      val ppSub = Subst.pp
      val () = Print.trace ppCl "Subsume.strictlySubsumes: cl" cl
      val result = strictlySubsumes pred subsume cl
      val () =
          case result of
            NONE => trace "Subsume.strictlySubsumes: not subsumed\n"
          | SOME (cl,sub,_) =>
            (Print.trace ppCl "Subsume.strictlySubsumes: subsuming cl" cl;
             Print.trace ppSub "Subsume.strictlySubsumes: subsuming sub" sub)
    in
      result
    end;
*)

(* True when some stored clause subsumes cl. *)
fun isSubsumed subs cl =
    case subsumes (K true) subs cl of
      SOME _ => true
    | NONE => false;

(* True when some stored clause no longer than cl subsumes it. *)
fun isStrictlySubsumed subs cl =
    case strictlySubsumes (K true) subs cl of
      SOME _ => true
    | NONE => false;

(* ------------------------------------------------------------------------- *)
(* Single clause versions.                                                   *)
(* ------------------------------------------------------------------------- *)

(* Single-clause check: the substitution under which cl' subsumes cl, if any. *)
fun clauseSubsumes cl' cl =
    let
      val sorted = sortClause cl'
      val symLits = clauseSym (LiteralSet.toList cl)
    in
      Option.map (fn (_,sub,()) => sub)
        (genClauseSubsumes (K true) cl' sorted symLits ())
    end;

(* As clauseSubsumes, but cl' must not have more literals than cl. *)
fun clauseStrictlySubsumes cl' cl =
    if LiteralSet.size cl' <= LiteralSet.size cl then clauseSubsumes cl' cl
    else NONE;

end
end;

(**** Original file: KnuthBendixOrder.sig ****)

(* ========================================================================= *)
(* THE KNUTH-BENDIX TERM ORDERING                                            *)
(* Copyright (c) 2002-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature KnuthBendixOrder =
sig

(* ------------------------------------------------------------------------- *)
(* The weight of all constants must be at least 1, and there must be at most *)
(* one unary function with weight 0.                                         *)
(* ------------------------------------------------------------------------- *)

type kbo =
     {weight : Metis.Term.function -> int,
      precedence : Metis.Term.function * Metis.Term.function -> order}

(* Weight 1 for every function; precedence compares arity, then name. *)
val default : kbo

(* SOME EQUAL for equal terms, SOME LESS or SOME GREATER when the terms are *)
(* ordered, and NONE when they are incomparable.                            *)
val compare : kbo -> Metis.Term.term * Metis.Term.term -> order option

end

(**** Original file: KnuthBendixOrder.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* KNUTH-BENDIX TERM ORDERING CONSTRAINTS                                    *)
(* Copyright (c) 2002-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure KnuthBendixOrder :> KnuthBendixOrder =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

fun notEqualTerm (tm1,tm2) = not (Term.equal tm1 tm2);

(* Apply f to the first pair of terms that are not equal.  Raises Bug if *)
(* every pair is equal.                                                  *)
fun firstNotEqualTerm f =
    let
      fun scan [] = raise Bug "firstNotEqualTerm"
        | scan ((x,y) :: rest) = if Term.equal x y then scan rest else f x y
    in
      scan
    end;

(* ------------------------------------------------------------------------- *)
(* The weight of all constants must be at least 1, and there must be at most *)
(* one unary function with weight 0.                                         *)
(* ------------------------------------------------------------------------- *)

(* A KBO instance has two parts:                                          *)
(*   weight     - a weight for each function symbol, i.e. a name/arity    *)
(*                pair as built from Term.Fn in weightTerm below;         *)
(*   precedence - an ordering on function symbols.                        *)
type kbo =
     {weight : Term.function -> int,
      precedence : Term.function * Term.function -> order};

(* Default weight: every function symbol weighs 1 *)

val uniformWeight : Term.function -> int = fn _ => 1;

(* Default precedence: compare arities first, then names *)

val arityPrecedence : Term.function * Term.function -> order =
    fn ((name1,arity1),(name2,arity2)) =>
       case Int.compare (arity1,arity2) of
         EQUAL => Name.compare (name1,name2)
       | cmp => cmp;

(* The default order *)

val default = {precedence = arityPrecedence, weight = uniformWeight};

(* ------------------------------------------------------------------------- *)
(* Term weight-1 represented as a linear function of the weight-1 of the     *)
(* variables in the term (plus a constant).                                  *)
(*                                                                           *)
(* Note that the conditions on weight functions ensure that all weights are  *)
(* at least 1, so all weight-1s are at least 0.                              *)
(* ------------------------------------------------------------------------- *)

datatype weight = Weight of int NameMap.map * int;

val weightEmpty : int NameMap.map = NameMap.new ();

val weightZero = Weight (weightEmpty,0);

(* Zero constant and no variable coefficients. *)
fun weightIsZero (Weight (coeffs,const)) =
    const = 0 andalso NameMap.null coeffs;

(* Negate the constant and every variable coefficient. *)
fun weightNeg (Weight (coeffs,const)) =
    Weight (NameMap.transform (fn n => ~n) coeffs, ~const);

local
  (* Sum two coefficients, dropping the variable when they cancel out. *)
  fun addCoeff ((_,a),(_,b)) =
      case a + b of
        0 => NONE
      | s => SOME s;
in
  fun weightAdd (Weight (m1,c1)) (Weight (m2,c2)) =
      Weight (NameMap.union addCoeff m1 m2, c1 + c2);
end;

(* w1 - w2 *)
fun weightSubtract w1 w2 =
    let
      val negated = weightNeg w2
    in
      weightAdd w1 negated
    end;

(* Weight-1 of a term as a linear expression.  Each variable occurrence  *)
(* adds 1 to that variable's coefficient and 1 to the constant.  Each    *)
(* function application adds the weight of its symbol to the constant.   *)
fun weightTerm weight =
    let
      fun loop vars const pending =
          case pending of
            [] => Weight (vars,const)
          | Term.Var v :: rest =>
            let
              val count =
                  case NameMap.peek vars v of
                    SOME k => k
                  | NONE => 0
            in
              loop (NameMap.insert vars (v, count + 1)) (const + 1) rest
            end
          | Term.Fn (f,args) :: rest =>
            loop vars (const + weight (f, length args)) (args @ rest)
    in
      fn tm => loop weightEmpty ~1 [tm]
    end;

(* Lower bound on the value of a weight expression.  Variable weight-1s  *)
(* are at least 0, so when no variable has a negative coefficient the    *)
(* constant c is a lower bound.  Otherwise no bound exists and the       *)
(* result is NONE.                                                       *)
fun weightLowerBound (Weight (m,c)) =
    if NameMap.exists (fn (_,n) => n < 0) m then NONE else SOME c;

(*MetisDebug
val ppWeightList =
    let
      fun ppCoeff n =
          if n < 0 then Print.sequence (Print.addString "~") (ppCoeff (~n))
          else if n = 1 then Print.skip
          else Print.ppInt n

      fun pp_tm (NONE,n) = Print.ppInt n
        | pp_tm (SOME v, n) = Print.sequence (ppCoeff n) (Name.pp v)
    in
      fn [] => Print.ppInt 0
       | tms => Print.ppOpList " +" pp_tm tms
    end;

fun ppWeight (Weight (m,c)) =
    let
      val l = NameMap.toList m
      val l = map (fn (v,n) => (SOME v, n)) l
      val l = if c = 0 then l else l @ [(NONE,c)]
    in
      ppWeightList l
    end;

val weightToString = Print.toString ppWeight;
*)

(* ------------------------------------------------------------------------- *)
(* The Knuth-Bendix term order.                                              *)
(* ------------------------------------------------------------------------- *)

(* Compare two terms in the Knuth-Bendix order given by weight and         *)
(* precedence.  Returns SOME EQUAL for Term.equal terms, SOME LESS or      *)
(* SOME GREATER when the terms are ordered, and NONE when they are         *)
(* incomparable.                                                           *)
fun compare {weight,precedence} =
    let
      (* weight tm2 - weight tm1, as a linear weight expression. *)
      fun weightDifference tm1 tm2 =
          let
            val w1 = weightTerm weight tm1
            and w2 = weightTerm weight tm2
          in
            weightSubtract w2 w1
          end

      (* Is tm1 strictly smaller than tm2? *)
      fun weightLess tm1 tm2 =
          let
            val w = weightDifference tm1 tm2
          in
            if weightIsZero w then precedenceLess tm1 tm2
            else weightDiffLess w tm1 tm2
          end

      (* Here w = weight tm2 - weight tm1.  The outcome depends on the       *)
      (* lower bound of w:                                                   *)
      (*   - no bound (a negative variable coefficient): false;             *)
      (*   - bound 0: fall back to precedence;                              *)
      (*   - any other bound: true exactly when it is positive.             *)
      and weightDiffLess w tm1 tm2 =
          case weightLowerBound w of
            NONE => false
          | SOME 0 => precedenceLess tm1 tm2
          | SOME n => n > 0

      (* Compare head symbols by precedence, then the first pair of differing *)
      (* arguments.  A variable on either side gives false.                  *)
      and precedenceLess (Term.Fn (f1,a1)) (Term.Fn (f2,a2)) =
          (case precedence ((f1, length a1), (f2, length a2)) of
             LESS => true
           | EQUAL => firstNotEqualTerm weightLess (zip a1 a2)
           | GREATER => false)
        | precedenceLess _ _ = false

      fun weightDiffGreater w tm1 tm2 = weightDiffLess (weightNeg w) tm2 tm1

      (* Three-way comparison built from the strict tests above. *)
      fun weightCmp tm1 tm2 =
          let
            val w = weightDifference tm1 tm2
          in
            if weightIsZero w then precedenceCmp tm1 tm2
            else if weightDiffLess w tm1 tm2 then SOME LESS
            else if weightDiffGreater w tm1 tm2 then SOME GREATER
            else NONE
          end

      (* NOTE(review): this raises Bug if either term is a variable, which    *)
      (* happens when a variable and a different term have zero weight       *)
      (* difference.  That looks reachable only when some unary function     *)
      (* has weight 0 (e.g. f(x) against x), which the weight rules allow.   *)
      (* Not possible with the default uniform weights -- confirm.           *)
      and precedenceCmp (Term.Fn (f1,a1)) (Term.Fn (f2,a2)) =
          (case precedence ((f1, length a1), (f2, length a2)) of
             LESS => SOME LESS
           | EQUAL => firstNotEqualTerm weightCmp (zip a1 a2)
           | GREATER => SOME GREATER)
        | precedenceCmp _ _ = raise Bug "kboOrder.precendenceCmp"
    in
      fn (tm1,tm2) =>
         if Term.equal tm1 tm2 then SOME EQUAL else weightCmp tm1 tm2
    end;

(*MetisTrace7
val compare = fn kbo => fn (tm1,tm2) =>
    let
      val () = Print.trace Term.pp "KnuthBendixOrder.compare: tm1" tm1
      val () = Print.trace Term.pp "KnuthBendixOrder.compare: tm2" tm2
      val result = compare kbo (tm1,tm2)
      val () =
          case result of
            NONE => trace "KnuthBendixOrder.compare: result = Incomparable\n"
          | SOME x =>
            Print.trace Print.ppOrder "KnuthBendixOrder.compare: result" x
    in
      result
    end;
*)

end
end;

(**** Original file: Rewrite.sig ****)

(* ========================================================================= *)
(* ORDERED REWRITING FOR FIRST ORDER TERMS                                   *)
(* Copyright (c) 2003-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Rewrite =
sig

(* ------------------------------------------------------------------------- *)
(* Orientations of equations.                                                *)
(* ------------------------------------------------------------------------- *)

datatype orient = LeftToRight | RightToLeft

(* Renders LeftToRight as "-->" and RightToLeft as "<--". *)
val toStringOrient : orient -> string

val ppOrient : orient Metis.Print.pp

(* As toStringOrient; NONE (an unoriented equation) is rendered as "<->". *)
val toStringOrientOption : orient option -> string

val ppOrientOption : orient option Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* A type of rewrite systems.                                                *)
(* ------------------------------------------------------------------------- *)

(* NONE means the two terms are incomparable. *)
type reductionOrder = Metis.Term.term * Metis.Term.term -> order option

type equationId = int

type equation = Metis.Rule.equation

type rewrite

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

(* An empty rewrite system using the given reduction order. *)
val new : reductionOrder -> rewrite

(* Look up an equation and its orientation (NONE = unoriented) by id. *)
val peek : rewrite -> equationId -> (equation * orient option) option

(* Number of equations in the system. *)
val size : rewrite -> int

val equations : rewrite -> equation list

val toString : rewrite -> string

val pp : rewrite Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* Add equations into the system.                                            *)
(* ------------------------------------------------------------------------- *)

(* Add an equation under the given id; this is a no-op if the id is already *)
(* used.  The equation is oriented with the system's reduction order, and   *)
(* Error is raised if that order reports the two sides as EQUAL.            *)
val add : rewrite -> equationId * equation -> rewrite

(* Add each (id, equation) pair with add, in list order. *)
val addList : rewrite -> (equationId * equation) list -> rewrite

(* ------------------------------------------------------------------------- *)
(* Rewriting (the order must be a refinement of the rewrite order).          *)
(* ------------------------------------------------------------------------- *)

val rewrConv : rewrite -> reductionOrder -> Metis.Rule.conv

val rewriteConv : rewrite -> reductionOrder -> Metis.Rule.conv

val rewriteLiteralsRule :
    rewrite -> reductionOrder -> Metis.LiteralSet.set -> Metis.Rule.rule

val rewriteRule : rewrite -> reductionOrder -> Metis.Rule.rule

val rewrIdConv : rewrite -> reductionOrder -> equationId -> Metis.Rule.conv

val rewriteIdConv : rewrite -> reductionOrder -> equationId -> Metis.Rule.conv

val rewriteIdLiteralsRule :
    rewrite -> reductionOrder -> equationId -> Metis.LiteralSet.set -> Metis.Rule.rule

val rewriteIdRule : rewrite -> reductionOrder -> equationId -> Metis.Rule.rule

(* ------------------------------------------------------------------------- *)
(* Inter-reduce the equations in the system.                                 *)
(* ------------------------------------------------------------------------- *)

val reduce' : rewrite -> rewrite * equationId list

val reduce : rewrite -> rewrite

val isReduced : rewrite -> bool

(* ------------------------------------------------------------------------- *)
(* Rewriting as a derived rule.                                              *)
(* ------------------------------------------------------------------------- *)

val rewrite : equation list -> Metis.Thm.thm -> Metis.Thm.thm

val orderedRewrite : reductionOrder -> equation list -> Metis.Thm.thm -> Metis.Thm.thm

end

(**** Original file: Rewrite.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* ORDERED REWRITING FOR FIRST ORDER TERMS                                   *)
(* Copyright (c) 2003-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Rewrite :> Rewrite =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Orientations of equations.                                                *)
(* ------------------------------------------------------------------------- *)

datatype orient = LeftToRight | RightToLeft;

fun toStringOrient LeftToRight = "-->"
  | toStringOrient RightToLeft = "<--";

val ppOrient = Print.ppMap toStringOrient Print.ppString;

(* NONE stands for an unoriented equation. *)
fun toStringOrientOption NONE = "<->"
  | toStringOrientOption (SOME ort) = toStringOrient ort;

val ppOrientOption = Print.ppMap toStringOrientOption Print.ppString;

(* ------------------------------------------------------------------------- *)
(* A type of rewrite systems.                                                *)
(* ------------------------------------------------------------------------- *)

(* NONE means the two terms are incomparable. *)
type reductionOrder = Term.term * Term.term -> order option;

type equationId = int;

type equation = Rule.equation;

(* The fields of a rewrite system:                                          *)
(*   order    - the reduction order used to orient equations when added.   *)
(*   known    - every equation by id, with its orientation                 *)
(*              (NONE = unoriented).                                       *)
(*   redexes  - term net from each usable redex side of an equation to     *)
(*              (equation id, direction of rewriting).                     *)
(*   subterms - not updated by any code visible here (add copies it        *)
(*              unchanged).  NOTE(review): presumably maintained by the    *)
(*              inter-reduction code.                                      *)
(*   waiting  - ids inserted by add and removed by deleteWaiting.          *)
(*              Presumably equations still awaiting inter-reduction --     *)
(*              confirm.                                                   *)
datatype rewrite =
    Rewrite of
      {order : reductionOrder,
       known : (equation * orient option) IntMap.map,
       redexes : (equationId * orient) TermNet.termNet,
       subterms : (equationId * bool * Term.path) TermNet.termNet,
       waiting : IntSet.set};

(* Replace the waiting set, keeping every other field. *)
fun updateWaiting (Rewrite {order, known, redexes, subterms, ...}) waiting =
    Rewrite
      {order = order, known = known, redexes = redexes,
       subterms = subterms, waiting = waiting};

(* Remove one id from the waiting set. *)
fun deleteWaiting rw id =
    let
      val Rewrite {waiting, ...} = rw
    in
      updateWaiting rw (IntSet.delete waiting id)
    end;

(* ------------------------------------------------------------------------- *)
(* Basic operations                                                          *)
(* ------------------------------------------------------------------------- *)

(* An empty rewrite system over the given reduction order. *)
fun new order =
    let
      fun emptyNet () = TermNet.new {fifo = false}
    in
      Rewrite
        {order = order,
         known = IntMap.new (),
         redexes = emptyNet (),
         subterms = emptyNet (),
         waiting = IntSet.empty}
    end;

fun peek (Rewrite {known, ...}) eqnId = IntMap.peek known eqnId;

fun size (Rewrite {known, ...}) = IntMap.size known;

(* All known equations, in IntMap.foldr order. *)
fun equations (Rewrite {known, ...}) =
    let
      fun consEqn (_, (eqn,_), acc) = eqn :: acc
    in
      IntMap.foldr consEqn [] known
    end;

val pp = Print.ppMap equations (Print.ppList Rule.ppEquation);

(*MetisTrace1
local
  fun ppEq ((x_y,_),ort) =
      Print.ppOp2 (" " ^ toStringOrientOption ort) Term.pp Term.pp x_y;

  fun ppField f ppA a =
      Print.blockProgram Print.Inconsistent 2
        [Print.addString (f ^ " ="),
         Print.addBreak 1,
         ppA a];

  val ppKnown =
      ppField "known"
        (Print.ppMap IntMap.toList
           (Print.ppList (Print.ppPair Print.ppInt ppEq)));

  val ppRedexes =
      ppField "redexes"
        (TermNet.pp (Print.ppPair Print.ppInt ppOrient));

  val ppSubterms =
      ppField "subterms"
        (TermNet.pp
           (Print.ppMap
              (fn (i,l,p) => (i, (if l then 0 else 1) :: p))
              (Print.ppPair Print.ppInt Term.ppPath)));

  val ppWaiting =
      ppField "waiting"
        (Print.ppMap (IntSet.toList) (Print.ppList Print.ppInt));
in
  fun pp (Rewrite {known,redexes,subterms,waiting,...}) =
      Print.blockProgram Print.Inconsistent 2
        [Print.addString "Rewrite",
         Print.addBreak 1,
         Print.blockProgram Print.Inconsistent 1
           [Print.addString "{",
            ppKnown known,
(*MetisTrace5
            Print.addString ",",
            Print.addBreak 1,
            ppRedexes redexes,
            Print.addString ",",
            Print.addBreak 1,
            ppSubterms subterms,
            Print.addString ",",
            Print.addBreak 1,
            ppWaiting waiting,
*)
            Print.skip],
         Print.addString "}"]
end;
*)

fun toString rw = Print.toString pp rw;

(* ------------------------------------------------------------------------- *)
(* Debug functions.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Debug check: could some known equation other than id rewrite the term, *)
(* or any of its subterms, into something strictly smaller under order?  *)
fun termReducible order known id =
    let
      fun eqnRed ((l,r),_) tm =
          case total (Subst.match Subst.empty l) tm of
            SOME sub =>
            order (tm, Subst.subst (Subst.normalize sub) r) = SOME GREATER
          | NONE => false

      fun knownRed tm (eqnId,(eqn,ort)) =
          if eqnId = id then false
          else
            (ort <> SOME RightToLeft andalso eqnRed eqn tm) orelse
            (ort <> SOME LeftToRight andalso eqnRed (Rule.symEqn eqn) tm)

      fun termRed tm =
          IntMap.exists (knownRed tm) known orelse
          (case tm of
             Term.Var _ => false
           | Term.Fn (_,args) => List.exists termRed args)
    in
      termRed
    end;

fun literalReducible order known id lit =
    let
      val reducible = termReducible order known id
    in
      List.exists reducible (Literal.arguments lit)
    end;

fun literalsReducible order known id =
    LiteralSet.exists (literalReducible order known id);

fun thmReducible order known id =
    literalsReducible order known id o Thm.clause;

(* ------------------------------------------------------------------------- *)
(* Add equations into the system.                                            *)
(* ------------------------------------------------------------------------- *)

(* Orientation from the order's verdict on (lhs, rhs).  NONE (incomparable) *)
(* leaves the equation unoriented.                                          *)
fun orderToOrient ordo =
    case ordo of
      NONE => NONE
    | SOME GREATER => SOME LeftToRight
    | SOME LESS => SOME RightToLeft
    | SOME EQUAL => raise Error "Rewrite.orient: reflexive";

(* Index the redex side(s) of an equation.  An unoriented equation is     *)
(* indexed both ways: left side first, then right side.                   *)
fun addRedexes id (((l,r),_),ort) redexes =
    let
      fun ins net (redex,dir) = TermNet.insert net (redex,(id,dir))
    in
      case ort of
        SOME LeftToRight => ins redexes (l,LeftToRight)
      | SOME RightToLeft => ins redexes (r,RightToLeft)
      | NONE => ins (ins redexes (l,LeftToRight)) (r,RightToLeft)
    end;

(* Add an equation under id, unless id is already known.  The new equation *)
(* is oriented, indexed by its redexes and marked as waiting.              *)
fun add rw (id,eqn) =
    let
      val Rewrite {order,known,redexes,subterms,waiting} = rw
    in
      if IntMap.inDomain id known then rw
      else
        let
          val ort = orderToOrient (order (fst eqn))
          val rw' =
              Rewrite
                {order = order,
                 known = IntMap.insert known (id,(eqn,ort)),
                 redexes = addRedexes id (eqn,ort) redexes,
                 subterms = subterms,
                 waiting = IntSet.add waiting id}
(*MetisTrace5
          val () = Print.trace pp "Rewrite.add: result" rw'
*)
        in
          rw'
        end
    end;

fun addList rw eqns = foldl (fn (eqn,acc) => add acc eqn) rw eqns;

(* ------------------------------------------------------------------------- *)
(* Rewriting (the order must be a refinement of the rewrite order).          *)
(* ------------------------------------------------------------------------- *)

(* Redex entries matching tm, sorted by equation id in descending order. *)
fun matchingRedexes redexes tm =
    let
      fun byIdDescending ((i,_),(j,_)) = Int.compare (j,i)
    in
      sort byIdDescending (TermNet.match redexes tm)
    end;

(* An unoriented equation may be used in either direction; an oriented one *)
(* only in its own direction.                                              *)
fun wellOriented ort lr =
    case ort of
      NONE => true
    | SOME dir => dir = lr;

(* The (redex, residue) pair for rewriting in direction lr. *)
fun redexResidue lr (((l,r),_) : equation) =
    case lr of
      LeftToRight => (l,r)
    | RightToLeft => (r,l);

fun orientedEquation lr eqn =
    case lr of
      LeftToRight => eqn
    | RightToLeft => Rule.symEqn eqn;

(* Rewrite tm at the top level using one equation from known, found via    *)
(* redexes; equation id itself is skipped.  Candidates are tried in        *)
(* matchingRedexes order and the first that applies wins.  An equation     *)
(* without a fixed orientation is used only when order (tm,tm') is         *)
(* SOME GREATER.  Raises Error if no equation applies.                     *)
fun rewrIdConv' order known redexes id tm =
    let
      (* Try one candidate.  Each failed check raises Error; the total     *)
      (* below is relied on to turn that into NONE and move on.           *)
      fun rewr (id',lr) =
          let
            val _ = id <> id' orelse raise Error "same theorem"
            val (eqn,ort) = IntMap.get known id'
            val _ = wellOriented ort lr orelse raise Error "orientation"
            val (l,r) = redexResidue lr eqn
            val sub = Subst.normalize (Subst.match Subst.empty l tm)
            val tm' = Subst.subst sub r
            (* Oriented equations need no order check here. *)
            val _ = Option.isSome ort orelse
                    order (tm,tm') = SOME GREATER orelse
                    raise Error "order"
            val (_,th) = orientedEquation lr eqn
          in
            (tm', Thm.subst sub th)
          end
    in
      case first (total rewr) (matchingRedexes redexes tm) of
        NONE => raise Error "Rewrite.rewrIdConv: no matching rewrites"
      | SOME res => res
    end;

(* Repeated top-down rewriting.  With no known equations this is allConv. *)
fun rewriteIdConv' order known redexes id =
    if not (IntMap.null known) then
      Rule.repeatTopDownConv (rewrIdConv' order known redexes id)
    else Rule.allConv;

(* Turn a disequation literal (taken apart with Literal.destNeq) into a     *)
(* conversion that rewrites whichever side the order ranks greater into    *)
(* the other side.  The justifying theorem is built once, before the       *)
(* conversion is returned.  Raises Error if the sides are incomparable or  *)
(* ordered EQUAL.  The conversion itself raises Error on any term other    *)
(* than that greater side.                                                 *)
fun mkNeqConv order lit =
    let
      val (l,r) = Literal.destNeq lit
    in
      case order (l,r) of
        NONE => raise Error "incomparable"
      | SOME LESS =>
        (* r is the greater side: rewrite r to l. *)
        let
          val th = Rule.symmetryRule l r
        in
          fn tm =>
             if Term.equal tm r then (l,th) else raise Error "mkNeqConv: RL"
        end
      | SOME EQUAL => raise Error "irreflexive"
      | SOME GREATER =>
        (* l is the greater side: rewrite l to r. *)
        let
          val th = Thm.assume lit
        in
          fn tm =>
             if Term.equal tm l then (r,th) else raise Error "mkNeqConv: LR"
        end
    end;

(* Conversions built from disequation literals, keyed by the literal. *)
datatype neqConvs = NeqConvs of Rule.conv LiteralMap.map;

val neqConvsEmpty = NeqConvs (LiteralMap.new ());

fun neqConvsNull neq =
    case neq of
      NeqConvs m => LiteralMap.null m;

(* Try to add lit as a disequation conversion.  Returns NONE when           *)
(* mkNeqConv order fails on lit (as reported through total); otherwise the  *)
(* set extended with lit's conversion.                                      *)
fun neqConvsAdd order (NeqConvs m) lit =
    case total (mkNeqConv order) lit of
      NONE => NONE
    | SOME conv => SOME (NeqConvs (LiteralMap.insert m (lit,conv)));

(* Split a literal set in two: literals that become disequation conversions, *)
(* and the remaining literals as a set.                                      *)
fun mkNeqConvs order =
    let
      fun classify (lit,(neq,others)) =
          case neqConvsAdd order neq lit of
            NONE => (neq, LiteralSet.add others lit)
          | SOME neq' => (neq',others)
    in
      LiteralSet.foldl classify (neqConvsEmpty,LiteralSet.empty)
    end;

fun neqConvsDelete (NeqConvs m) lit =
    NeqConvs (LiteralMap.delete m lit);

(* A single conversion that tries each stored conversion in map order. *)
fun neqConvsToConv (NeqConvs m) =
    let
      val convs = LiteralMap.foldr (fn (_,conv,acc) => conv :: acc) [] m
    in
      Rule.firstConv convs
    end;

(* Fold f over the disequation literals, ignoring their conversions. *)
fun neqConvsFoldl f b (NeqConvs m) =
    let
      fun step (lit,_,acc) = f (lit,acc)
    in
      LiteralMap.foldl step b m
    end;

(* Literal rule that rewrites every argument top-down with the disequation *)
(* conversions, falling back to the known equations.  It is allLiterule    *)
(* when both sources are empty.                                            *)
fun neqConvsRewrIdLiterule order known redexes id neq =
    if not (IntMap.null known andalso neqConvsNull neq) then
      let
        val combined =
            Rule.orelseConv (neqConvsToConv neq)
              (rewrIdConv' order known redexes id)
      in
        Rule.allArgumentsLiterule (Rule.repeatTopDownConv combined)
      end
    else Rule.allLiterule;

(* Rewrite the equation eqn = (l_r,th).  It uses the equations in known     *)
(* (except id) and conversions built from the disequations in th's clause.  *)
(* Returns eqn itself when its literal is unchanged.                        *)
fun rewriteIdEqn' order known redexes id (eqn as (l_r,th)) =
    let
      val (neq,_) = mkNeqConvs order (Thm.clause th)
      val literule = neqConvsRewrIdLiterule order known redexes id neq
      (* When Rule.equationLiteral gives NONE, rewrite the literal built   *)
      (* from l_r instead, and keep th as the theorem of the result.       *)
      val (strongEqn,lit) =
          case Rule.equationLiteral eqn of
            NONE => (true, Literal.mkEq l_r)
          | SOME lit => (false,lit)
      val (lit',litTh) = literule lit
    in
      if Literal.equal lit lit' then eqn
      else
        (Literal.destEq lit',
         (* Theorem for the rewritten equation: th itself, litTh when it does *)
         (* not contain the negation of lit, or else the resolvent of th and *)
         (* litTh on lit.                                                    *)
         if strongEqn then th
         else if not (Thm.negateMember lit litTh) then litTh
         else Thm.resolve lit th litTh)
    end
(*MetisDebug
    handle Error err => raise Error ("Rewrite.rewriteIdEqn':\n" ^ err);
*)

(* Rewrite the literals lits of theorem th using the known equations
   (excluding equation id) together with the disequalities found among lits.
   Phase 1 (rewr_neq_lits): rewrite each usable disequality with all the
   others and resolve the rewritten literal back into th.  Repeat while a
   pass turns some rewritten literal into a new usable disequality.
   Phase 2 (rewr_lit): rewrite each remaining non-disequality literal that is
   still a member of th, using Rule.literalRule and the final set of
   disequality conversions. *)
fun rewriteIdLiteralsRule' order known redexes id lits th =
    let
      val mk_literule = neqConvsRewrIdLiterule order known redexes id

      (* Fold step over one disequality literal lit.  The accumulator holds:
         whether a new disequality was added, the current conversions, the
         non-disequality literals, and the current theorem. *)
      fun rewr_neq_lit (lit, acc as (changed,neq,lits,th)) =
          let
            (* Remove lit's own conversion so it cannot rewrite itself. *)
            val neq = neqConvsDelete neq lit
            val (lit',litTh) = mk_literule neq lit
          in
            (* Unchanged: keep the old accumulator, which still holds lit's
               conversion. *)
            if Literal.equal lit lit' then acc
            else
              let
                (* Resolve th with litTh on lit, replacing lit by lit'
                   (presumably litTh relates lit and lit'). *)
                val th = Thm.resolve lit th litTh
              in
                case neqConvsAdd order neq lit' of
                  SOME neq => (true,neq,lits,th)
                | NONE => (changed, neq, LiteralSet.add lits lit', th)
              end
          end

      (* Iterate passes to a fixpoint.  Note that neqConvsFoldl walks the
         collection as it was at the start of the pass (its last argument),
         while the accumulator threads the updated collection. *)
      fun rewr_neq_lits neq lits th =
          let
            val (changed,neq,lits,th) =
                neqConvsFoldl rewr_neq_lit (false,neq,lits,th) neq
          in
            if changed then rewr_neq_lits neq lits th
            else (neq,lits,th)
          end

      val (neq,lits) = mkNeqConvs order lits

      val (neq,lits,th) = rewr_neq_lits neq lits th

      val rewr_literule = mk_literule neq

      (* Skip literals no longer in th, e.g. ones removed by earlier
         resolution steps. *)
      fun rewr_lit (lit,th) =
          if Thm.member lit th then Rule.literalRule rewr_literule lit th
          else th
    in
      LiteralSet.foldl rewr_lit th lits
    end;

(* Rewrite every literal of th's clause with rewriteIdLiteralsRule'. *)
fun rewriteIdRule' order known redexes id th =
    let
      val allLits = Thm.clause th
    in
      rewriteIdLiteralsRule' order known redexes id allLits th
    end;

(*MetisDebug
val rewriteIdRule' = fn order => fn known => fn redexes => fn id => fn th =>
    let
(*MetisTrace6
      val () = Print.trace Thm.pp "Rewrite.rewriteIdRule': th" th
*)
      val result = rewriteIdRule' order known redexes id th
(*MetisTrace6
      val () = Print.trace Thm.pp "Rewrite.rewriteIdRule': result" result
*)
      val _ = not (thmReducible order known id result) orelse
              raise Bug "rewriteIdRule: should be normalized"
    in
      result
    end
    handle Error err => raise Error ("Rewrite.rewriteIdRule:\n" ^ err);
*)

(* Public entry points.  Each unpacks the known equations and the redex index
   from a rewrite system and delegates to the matching primed worker.  The
   variants without "Id" in their name pass ~1 as the id of the equation to
   exclude; presumably no stored equation carries that id. *)
fun rewrIdConv rw order =
    case rw of Rewrite {known,redexes,...} => rewrIdConv' order known redexes;

fun rewrConv rw order = rewrIdConv rw order ~1;

fun rewriteIdConv rw order =
    case rw of Rewrite {known,redexes,...} => rewriteIdConv' order known redexes;

fun rewriteConv rw order = rewriteIdConv rw order ~1;

fun rewriteIdLiteralsRule rw order =
    case rw of
      Rewrite {known,redexes,...} => rewriteIdLiteralsRule' order known redexes;

fun rewriteLiteralsRule rw order = rewriteIdLiteralsRule rw order ~1;

fun rewriteIdRule rw order =
    case rw of Rewrite {known,redexes,...} => rewriteIdRule' order known redexes;

fun rewriteRule rw order = rewriteIdRule rw order ~1;

(* ------------------------------------------------------------------------- *)
(* Inter-reduce the equations in the system.                                 *)
(* ------------------------------------------------------------------------- *)

(* Index every subterm of both sides of an equation in the subterm net.
   Entries are (id, true, path) for the left side and (id, false, path) for
   the right side.  Left-side subterms are inserted first. *)
fun addSubterms id (((l,r),_) : equation) subterms =
    let
      fun insertSide onLeft ((path,tm),net) =
          TermNet.insert net (tm,(id,onLeft,path))

      val withLeft = foldl (insertSide true) subterms (Term.subterms l)
    in
      foldl (insertSide false) withLeft (Term.subterms r)
    end;

(* Do the old and new versions of an equation share their redex?  An
   unoriented equation never does.  Otherwise the side used as the redex
   under the orientation is compared. *)
fun sameRedexes ort (l0,r0) (l,r) =
    case ort of
      NONE => false
    | SOME LeftToRight => Term.equal l0 l
    | SOME RightToLeft => Term.equal r0 r;

(* The (redex, residue, oriented) triples under which an equation may rewrite.
   An unoriented equation gives both directions, flagged false.  An oriented
   equation gives its single direction, flagged true. *)
fun redexResidues ort (l,r) =
    case ort of
      NONE => [(l,r,false),(r,l,false)]
    | SOME LeftToRight => [(l,r,true)]
    | SOME RightToLeft => [(r,l,true)];

(* Partially applied: returns a function from a todo set and a list of
   (redex, residue, oriented) triples for equation id to a new todo set.
   For each triple it adds every other known equation containing an instance
   of the redex at an indexed subterm.  An instance qualifies when the redex
   matches it and, for unoriented triples, the order puts the instance
   strictly above the instantiated residue (SOME GREATER). *)
fun findReducibles order known subterms id =
    let
      fun checkValidRewr (l,r,oriented) id' onLeft path =
          let
            val (((lhs,rhs),_),_) = IntMap.get known id'
            val tm = Term.subterm (if onLeft then lhs else rhs) path
            val sub = Subst.match Subst.empty l tm
          in
            if oriented orelse
               order (tm, Subst.subst (Subst.normalize sub) r) = SOME GREATER
            then ()
            else raise Error "order"
          end

      fun addRed lr ((id',onLeft,path),todo) =
          if id = id' orelse IntSet.member id' todo then todo
          else if can (checkValidRewr lr id' onLeft) path then IntSet.add todo id'
          else todo

      fun findRed (lr as (l,_,_), todo) =
          List.foldl (addRed lr) todo (TermNet.matched subterms l)
    in
      List.foldl findRed
    end;

(* Re-process equation id against the current rewrite system.
   new is true when id is taken from the waiting set and false when it is
   re-reduced from todo (see reduceAcc).  (eqn0,ort0) is the equation and its
   stored orientation.  The accumulator carries:
     rpl     - ids whose redex-index entries must be rebuilt,
     spl     - ids whose subterm-index entries must be rebuilt,
     todo    - ids of other equations that may now be reducible,
     rw      - the rewrite system being updated,
     changed - ids of equations that changed or were newly processed. *)
fun reduce1 new id (eqn0,ort0) (rpl,spl,todo,rw,changed) =
    let
      val (eq0,_) = eqn0
      val Rewrite {order,known,redexes,subterms,waiting} = rw
      (* Rewrite the equation with every other known equation. *)
      val eqn as (eq,_) = rewriteIdEqn' order known redexes id eqn0
      (* Did rewriting leave both sides syntactically unchanged? *)
      val identical =
          let
            val (l0,r0) = eq0
            and (l,r) = eq
          in
            Term.equal l l0 andalso Term.equal r r0
          end
      val same_redexes = identical orelse sameRedexes ort0 eq0 eq
      val rpl = if same_redexes then rpl else IntSet.add rpl id
      val spl = if new orelse identical then spl else IntSet.add spl id
      val changed =
          if not new andalso identical then changed else IntSet.add changed id
      (* Keep the old orientation when the redex is unchanged.  Otherwise
         re-orient via the term order; total gives NONE when orderToOrient
         fails. *)
      val ort =
          if same_redexes then SOME ort0 else total orderToOrient (order eq)
    in
      case ort of
        NONE =>
        (* The equation cannot be used any more: drop it from known.  Its
           stale index entries are cleaned later by rebuild via rpl/spl. *)
        let
          val known = IntMap.delete known id
          val rw =
              Rewrite
                {order = order, known = known, redexes = redexes,
                 subterms = subterms, waiting = waiting}
        in
          (rpl,spl,todo,rw,changed)
        end
      | SOME ort =>
        let
          (* Look for equations this one can now rewrite, unless it is an
             old equation whose redex did not change. *)
          val todo =
              if not new andalso same_redexes then todo
              else
                findReducibles
                  order known subterms id todo (redexResidues ort eq)
          val known =
              if identical then known else IntMap.insert known (id,(eqn,ort))
          val redexes =
              if same_redexes then redexes
              else addRedexes id (eqn,ort) redexes
          val subterms =
              if new orelse not identical then addSubterms id eqn subterms
              else subterms
          val rw =
              Rewrite
                {order = order, known = known, redexes = redexes,
                 subterms = subterms, waiting = waiting}
        in
          (rpl,spl,todo,rw,changed)
        end
    end;

(* Choose an id from set whose equation is still known.  An id whose stored
   orientation is SOME _ is preferred.  Returns SOME (id, (eqn, ort)), or NONE
   when no id in set is known. *)
fun pick known set =
    let
      fun lookup id =
          case IntMap.peek known id of
            NONE => NONE
          | SOME entry => SOME (id,entry)

      fun oriented id =
          case lookup id of
            found as SOME (_,(_, SOME _)) => found
          | _ => NONE
    in
      case IntSet.firstl oriented set of
        NONE => IntSet.firstl lookup set
      | found => found
    end;

local
  (* Drop every redex entry of an id in rpl, then re-add the redexes of those
     ids that are still known. *)
  fun cleanRedexes known redexes rpl =
      if IntSet.null rpl then redexes
      else
        let
          fun keep (id,_) = not (IntSet.member id rpl)

          fun reAdd (id,net) =
              case IntMap.peek known id of
                SOME eqn_ort => addRedexes id eqn_ort net
              | NONE => net
        in
          IntSet.foldl reAdd (TermNet.filter keep redexes) rpl
        end;

  (* The same treatment for the subterm index, driven by spl. *)
  fun cleanSubterms known subterms spl =
      if IntSet.null spl then subterms
      else
        let
          fun keep (id,_,_) = not (IntSet.member id spl)

          fun reAdd (id,net) =
              case IntMap.peek known id of
                SOME (eqn,_) => addSubterms id eqn net
              | NONE => net
        in
          IntSet.foldl reAdd (TermNet.filter keep subterms) spl
        end;
in
  (* Rebuild the redex and subterm indexes of rw for the ids in rpl and spl.
     The order, known equations and waiting set are passed through. *)
  fun rebuild rpl spl rw =
      let
(*MetisTrace5
        val ppPl = Print.ppMap IntSet.toList (Print.ppList Print.ppInt)
        val () = Print.trace ppPl "Rewrite.rebuild: rpl" rpl
        val () = Print.trace ppPl "Rewrite.rebuild: spl" spl
*)
        val Rewrite {order,known,redexes,subterms,waiting} = rw
      in
        Rewrite
          {order = order,
           known = known,
           redexes = cleanRedexes known redexes rpl,
           subterms = cleanSubterms known subterms spl,
           waiting = waiting}
      end;
end;

(* Worklist loop behind reduce'.  Ids pending in todo are re-reduced first.
   Then ids are taken from the waiting set, each removed via deleteWaiting.
   When both are exhausted, the indexes are rebuilt and the changed ids are
   returned as a list. *)
fun reduceAcc (rpl, spl, todo, rw as Rewrite {known,waiting,...}, changed) =
    case pick known todo of
      SOME (id,eqn_ort) =>
      reduceAcc
        (reduce1 false id eqn_ort (rpl, spl, IntSet.delete todo id, rw, changed))
    | NONE =>
      (case pick known waiting of
         NONE => (rebuild rpl spl rw, IntSet.toList changed)
       | SOME (id,eqn_ort) =>
         reduceAcc
           (reduce1 true id eqn_ort (rpl, spl, todo, deleteWaiting rw id, changed)));

(* A rewrite system is reduced when its waiting set is empty. *)
fun isReduced rw =
    case rw of Rewrite {waiting,...} => IntSet.null waiting;

(* Inter-reduce rw, returning the new system and the ids that changed. *)
fun reduce' rw =
    if not (isReduced rw) then
      reduceAcc (IntSet.empty,IntSet.empty,IntSet.empty,rw,IntSet.empty)
    else (rw,[]);

(*MetisDebug
val reduce' = fn rw =>
    let
(*MetisTrace4
      val () = Print.trace pp "Rewrite.reduce': rw" rw
*)
      val Rewrite {known,order,...} = rw
      val result as (Rewrite {known = known', ...}, _) = reduce' rw
(*MetisTrace4
      val ppResult = Print.ppPair pp (Print.ppList Print.ppInt)
      val () = Print.trace ppResult "Rewrite.reduce': result" result
*)
      val ths = map (fn (id,((_,th),_)) => (id,th)) (IntMap.toList known')
      val _ =
          not (List.exists (uncurry (thmReducible order known')) ths) orelse
          raise Bug "Rewrite.reduce': not fully reduced"
    in
      result
    end
    handle Error err => raise Bug ("Rewrite.reduce': shouldn't fail\n" ^ err);
*)

(* Inter-reduce rw and discard the list of changed ids. *)
fun reduce rw =
    let
      val (reduced,_) = reduce' rw
    in
      reduced
    end;

(* ------------------------------------------------------------------------- *)
(* Rewriting as a derived rule.                                              *)
(* ------------------------------------------------------------------------- *)

(* Build a rewrite system from ths (numbered by enumerate) under the given
   term order, and return the resulting rewriteRule. *)
fun orderedRewrite order ths =
    let
      fun addEqn (id_eqn,acc) = add acc id_eqn

      val rw = foldl addEqn (new order) (enumerate ths)
    in
      rewriteRule rw order
    end;

(* Rewriting where the order always answers SOME GREATER. *)
val rewrite = orderedRewrite (fn _ => SOME GREATER);

end
end;

(**** Original file: Units.sig ****)

(* ========================================================================= *)
(* A STORE FOR UNIT THEOREMS                                                 *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Units =
sig

(* ------------------------------------------------------------------------- *)
(* A type of unit store.                                                     *)
(* ------------------------------------------------------------------------- *)

(* A literal paired with a theorem; reduce resolves on (an instance of) the
   literal against the theorem. *)
type unitThm = Metis.Literal.literal * Metis.Thm.thm

type units

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

val empty : units

(* Number of stored entries.  An equation added once may account for two
   entries, one per orientation. *)
val size : units -> int

(* Renders as U{n}, where n is the size. *)
val toString : units -> string

val pp : units Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* Add units into the store.                                                 *)
(* ------------------------------------------------------------------------- *)

(* Insert a pair.  When Literal.sym can flip the literal, the symmetric form
   is stored as well. *)
val add : units -> unitThm -> units

(* Add each pair of the list in turn. *)
val addList : units -> unitThm list -> units

(* ------------------------------------------------------------------------- *)
(* Matching.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* The first stored pair whose literal matches onto the given literal, with
   the matching substitution. *)
val match : units -> Metis.Literal.literal -> (unitThm * Metis.Subst.subst) option

(* ------------------------------------------------------------------------- *)
(* Reducing by repeated matching and resolution.                             *)
(* ------------------------------------------------------------------------- *)

(* Remove literals of a theorem by reflexivity and by resolution with stored
   units, repeating until nothing new is produced. *)
val reduce : units -> Metis.Rule.rule

end

(**** Original file: Units.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* A STORE FOR UNIT THEOREMS                                                 *)
(* Copyright (c) 2001-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Units :> Units =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of unit store.                                                     *)
(* ------------------------------------------------------------------------- *)

(* A literal paired with a theorem; reduce resolves on an instance of the
   literal against the theorem. *)
type unitThm = Literal.literal * Thm.thm;

(* The store: a literal net keyed by literal, whose values are the
   (literal, theorem) pairs themselves. *)
datatype units = Units of unitThm LiteralNet.literalNet;

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

(* The empty store; its net is created with fifo = false. *)
val empty = Units (LiteralNet.new {fifo = false});

(* Number of entries in the underlying net. *)
fun size units =
    case units of Units net => LiteralNet.size net;

(* Render as U{n}, where n is the size. *)
fun toString units = String.concat ["U{", Int.toString (size units), "}"];

(* Pretty-print via toString. *)
val pp = Print.ppMap toString Print.ppString;

(* ------------------------------------------------------------------------- *)
(* Add units into the store.                                                 *)
(* ------------------------------------------------------------------------- *)

(* Insert the pair (lit, th).  If Literal.sym can flip lit, also insert the
   flipped literal, so that lookups succeed in either orientation.  Its
   theorem comes from Rule.symEq for a positive literal or Rule.symNeq for a
   negative one. *)
fun add (Units net) (uTh as (lit,th)) =
    let
      val net = LiteralNet.insert net (lit,uTh)
    in
      case total Literal.sym lit of
        NONE => Units net
      | SOME (lit' as (pol,_)) =>
        let
          val th' = (if pol then Rule.symEq else Rule.symNeq) lit th
        in
          Units (LiteralNet.insert net (lit',(lit',th')))
        end
    end;

(* Add each pair of the list in turn, from left to right. *)
fun addList units uThs = foldl (fn (uTh,acc) => add acc uTh) units uThs;

(* ------------------------------------------------------------------------- *)
(* Matching.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Return the first candidate from LiteralNet.match whose stored literal
   one-way matches onto lit, together with the matching substitution. *)
fun match units lit =
    let
      val Units net = units
      val candidates = LiteralNet.match net lit

      fun tryCandidate (entry as (pattern,_)) =
          case total (Literal.match Subst.empty pattern) lit of
            NONE => NONE
          | SOME sub => SOME (entry,sub)
    in
      first tryCandidate candidates
    end;

(* ------------------------------------------------------------------------- *)
(* Reducing by repeated matching and resolution.                             *)
(* ------------------------------------------------------------------------- *)

(* Simplify a theorem with the store.  In each round, every new literal is
   handled as follows:
   - a literal accepted by Literal.destIrrefl is cut with Thm.refl;
   - a literal whose negation matches a stored pair is resolved away, and the
     remaining literals of the instantiated stored theorem join the next
     round.
   Rounds repeat until a round yields no new literals.  Rule.removeSym is
   then applied to the result. *)
fun reduce units =
    let
      fun red1 (lit, acc as (news,th)) =
          case total Literal.destIrrefl lit of
            SOME tm => (news, Thm.resolve lit th (Thm.refl tm))
          | NONE =>
            let
              val neg = Literal.negate lit
            in
              case match units neg of
                NONE => acc
              | SOME ((_,unitTh),sub) =>
                let
                  val unitTh = Thm.subst sub unitTh
                  val th = Thm.resolve lit th unitTh
                  val extra = LiteralSet.delete (Thm.clause unitTh) neg
                in
                  (LiteralSet.union extra news, th)
                end
            end

      fun loop (news,th) =
          if LiteralSet.null news then th
          else loop (LiteralSet.foldl red1 (LiteralSet.empty,th) news)
    in
      fn th => Rule.removeSym (loop (Thm.clause th, th))
    end;

end
end;

(**** Original file: Clause.sig ****)

(* ========================================================================= *)
(* CLAUSE = ID + THEOREM                                                     *)
(* Copyright (c) 2002-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Clause =
sig

(* ------------------------------------------------------------------------- *)
(* A type of clause.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Strategy used to select the largest literals of a clause. *)
datatype literalOrder =
    NoLiteralOrder
  | UnsignedLiteralOrder
  | PositiveLiteralOrder

(* ordering: the KBO term order used for comparisons.
   orderLiterals: the literal-selection strategy.
   orderTerms: when false, every equation side counts as larger. *)
type parameters =
     {ordering : Metis.KnuthBendixOrder.kbo,
      orderLiterals : literalOrder,
      orderTerms : bool}

(* Clause identifiers; fresh ones come from newId. *)
type clauseId = int

type clauseInfo = {parameters : parameters, id : clauseId, thm : Metis.Thm.thm}

type clause

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

val default : parameters

(* The next id from a counter shared by the whole structure. *)
val newId : unit -> clauseId

(* Convert between a clause and its record. *)
val mk : clauseInfo -> clause

val dest : clause -> clauseInfo

val id : clause -> clauseId

val thm : clause -> Metis.Thm.thm

(* Compare the theorems of two clauses with Thm.equal; ids are ignored. *)
val equalThms : clause -> clause -> bool

(* The literal set of the clause's theorem. *)
val literals : clause -> Metis.Thm.clause

val isTautology : clause -> bool

val isContradiction : clause -> bool

(* ------------------------------------------------------------------------- *)
(* The term ordering is used to cut down inferences.                         *)
(* ------------------------------------------------------------------------- *)

(* Literals selected by the literal ordering of the clause's parameters. *)
val largestLiterals : clause -> Metis.LiteralSet.set

(* For each equation among the largest literals: (literal, orientation,
   redex side), for each orientation that passes the term-ordering check. *)
val largestEquations :
    clause -> (Metis.Literal.literal * Metis.Rewrite.orient * Metis.Term.term) list

(* (literal, path, subterm) for the subterms of the largest literals. *)
val largestSubterms :
    clause -> (Metis.Literal.literal * Metis.Term.path * Metis.Term.term) list

(* As largestSubterms, but over every literal of the clause. *)
val allSubterms : clause -> (Metis.Literal.literal * Metis.Term.path * Metis.Term.term) list

(* ------------------------------------------------------------------------- *)
(* Subsumption.                                                              *)
(* ------------------------------------------------------------------------- *)

(* True when the clause's literals are strictly subsumed by the index. *)
val subsumes : clause Metis.Subsume.subsume -> clause -> bool

(* ------------------------------------------------------------------------- *)
(* Simplifying rules: these preserve the clause id.                          *)
(* ------------------------------------------------------------------------- *)

val freshVars : clause -> clause

val simplify : clause -> clause option

val reduce : Metis.Units.units -> clause -> clause

val rewrite : Metis.Rewrite.rewrite -> clause -> clause

(* ------------------------------------------------------------------------- *)
(* Inference rules: these generate new clause ids.                           *)
(* ------------------------------------------------------------------------- *)

val factor : clause -> clause list

(* Binary resolution on the given literals; raises Error when unification
   or the ordering restrictions fail. *)
val resolve : clause * Metis.Literal.literal -> clause * Metis.Literal.literal -> clause

(* Paramodulate with the first (equation) argument into the second; raises
   Error when unification or the ordering restrictions fail. *)
val paramodulate :
    clause * Metis.Literal.literal * Metis.Rewrite.orient * Metis.Term.term ->
    clause * Metis.Literal.literal * Metis.Term.path * Metis.Term.term -> clause

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* When set, pp prints the clause id together with the theorem. *)
val showId : bool Unsynchronized.ref

val pp : clause Metis.Print.pp

val toString : clause -> string

end

(**** Original file: Clause.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* CLAUSE = ID + THEOREM                                                     *)
(* Copyright (c) 2002-2004 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Clause :> Clause =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Allocate clause ids 0, 1, 2, ... from a counter private to this closure.
   The read-then-write is not synchronized. *)
val newId =
    let
      val counter = Unsynchronized.ref 0

      fun next () =
          let
            val current = !counter
            val () = counter := current + 1
          in
            current
          end
    in
      next
    end;

(* ------------------------------------------------------------------------- *)
(* A type of clause.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Literal-selection strategy, interpreted by isLargerLiteral:
   NoLiteralOrder selects every literal.  UnsignedLiteralOrder compares
   against the terms of all literals regardless of sign.  PositiveLiteralOrder
   compares only against literals sharing the sign of the first literal found
   in the set. *)
datatype literalOrder =
    NoLiteralOrder
  | UnsignedLiteralOrder
  | PositiveLiteralOrder;

(* ordering: the KBO used for term comparisons.
   orderLiterals: the literal-selection strategy above.
   orderTerms: when false, isLargerTerm accepts every orientation. *)
type parameters =
     {ordering : KnuthBendixOrder.kbo,
      orderLiterals : literalOrder,
      orderTerms : bool};

type clauseId = int;

type clauseInfo = {parameters : parameters, id : clauseId, thm : Thm.thm};

(* A clause is just its info record behind an abstract constructor. *)
datatype clause = Clause of clauseInfo;

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* When set, pp prints the clause id alongside its theorem. *)
val showId = Unsynchronized.ref false;

local
  val ppWithId = Print.ppPair Print.ppInt Thm.pp;
in
  fun pp (Clause {id,thm,...}) =
      if not (!showId) then Thm.pp thm else ppWithId (id,thm);
end;

fun toString c = Print.toString pp c;

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Default parameters: the default KBO, unsigned literal ordering, and term
   ordering switched on. *)
val default : parameters =
    {ordering = KnuthBendixOrder.default,
     orderLiterals = UnsignedLiteralOrder (* PositiveLiteralOrder *) (* MODIFIED by Jasmin Blanchette *),
     orderTerms = true};

fun mk info = Clause info;

fun dest cl = case cl of Clause info => info;

fun id cl = #id (dest cl);

fun thm cl = #thm (dest cl);

(* Compare theorems only; ids and parameters are ignored. *)
fun equalThms cl1 cl2 = Thm.equal (thm cl1) (thm cl2);

(* A clause for th with a freshly allocated id. *)
fun new parameters th =
    Clause {parameters = parameters, id = newId (), thm = th};

fun literals cl = Thm.clause (thm cl);

fun isTautology cl = Thm.isTautology (thm cl);

fun isContradiction cl = Thm.isContradiction (thm cl);

(* ------------------------------------------------------------------------- *)
(* The term ordering is used to cut down inferences.                         *)
(* ------------------------------------------------------------------------- *)

(* True exactly when the KBO comparison of the pair is SOME LESS. *)
fun strictlyLess ordering x_y =
    KnuthBendixOrder.compare ordering x_y = SOME LESS;

(* With term ordering off, every side counts as larger.  Otherwise l is
   larger unless l is strictly less than r. *)
fun isLargerTerm ({ordering,orderTerms,...} : parameters) l_r =
    if orderTerms then not (strictlyLess ordering l_r) else true;

local
  (* Terms an atom contributes to comparisons: both sides of an equation, or
     otherwise the atom itself viewed as a function term. *)
  fun atomToTerms atm =
      case total Atom.destEq atm of
        SOME (l,r) => [l,r]
      | NONE => [Term.Fn atm];

  (* True unless every x in xs is strictly less than some y in ys. *)
  fun notStrictlyLess ordering (xs,ys) =
      let
        fun dominated x = List.exists (fn y => strictlyLess ordering (x,y)) ys
      in
        List.exists (fn x => not (dominated x)) xs
      end;
in
  (* Build a predicate telling whether a literal counts as largest relative
     to lits under the clause parameters.  The comparison terms are
     collected once, when the predicate is built.  For PositiveLiteralOrder,
     a literal whose sign differs from the first literal found in lits
     yields that first literal's sign. *)
  fun isLargerLiteral ({ordering,orderLiterals,...} : parameters) lits =
      let
        fun collect keep =
            LiteralSet.foldl
              (fn ((p,atm),acc) => if keep p then atomToTerms atm @ acc else acc)
              [] lits
      in
        case orderLiterals of
          NoLiteralOrder => K true
        | UnsignedLiteralOrder =>
          let
            val tms = collect (K true)
          in
            fn (_,atm') => notStrictlyLess ordering (atomToTerms atm', tms)
          end
        | PositiveLiteralOrder =>
          (case LiteralSet.findl (K true) lits of
             NONE => K true
           | SOME (pol,_) =>
             let
               val tms = collect (fn p => p = pol)
             in
               fn (pol',atm') =>
                  if pol = pol' then notStrictlyLess ordering (atomToTerms atm', tms)
                  else pol
             end)
      end;
end;

(* The literals of the clause's theorem that isLargerLiteral selects under
   the clause's parameters. *)
fun largestLiterals cl =
    let
      val Clause {parameters,thm,...} = cl
      val lits = Thm.clause thm
      val selected = isLargerLiteral parameters lits
    in
      LiteralSet.foldr
        (fn (lit,acc) => if selected lit then LiteralSet.add acc lit else acc)
        LiteralSet.empty lits
    end;

(*MetisTrace6
val largestLiterals = fn cl =>
    let
      val ppResult = LiteralSet.pp
      val () = Print.trace pp "Clause.largestLiterals: cl" cl
      val result = largestLiterals cl
      val () = Print.trace ppResult "Clause.largestLiterals: result" result
    in
      result
    end;
*)

(* For each equation literal among the largest literals, list
   (literal, orientation, redex side) for each orientation whose redex side
   passes isLargerTerm.  Per literal, the LeftToRight entry comes before the
   RightToLeft entry. *)
fun largestEquations (cl as Clause {parameters,...}) =
    let
      fun addLit (lit,acc) =
          case total Literal.destEq lit of
            NONE => acc
          | SOME (l,r) =>
            let
              fun keep (ort,lhs,rhs) rest =
                  if isLargerTerm parameters (lhs,rhs) then (lit,ort,lhs) :: rest
                  else rest
            in
              keep (Rewrite.LeftToRight,l,r) (keep (Rewrite.RightToLeft,r,l) acc)
            end
    in
      LiteralSet.foldr addLit [] (largestLiterals cl)
    end;

local
  (* Prepend (lit, path, tm) for every subterm that
     Literal.nonVarTypedSubterms reports for lit. *)
  fun addLit (lit,acc) =
      foldl (fn ((path,tm),rest) => (lit,path,tm) :: rest) acc
        (Literal.nonVarTypedSubterms lit);

  fun collectSubterms lits = LiteralSet.foldl addLit [] lits;
in
  (* Subterms of the largest literals only. *)
  fun largestSubterms cl = collectSubterms (largestLiterals cl);

  (* Subterms of every literal of the clause. *)
  fun allSubterms cl = collectSubterms (literals cl);
end;

(* ------------------------------------------------------------------------- *)
(* Subsumption.                                                              *)
(* ------------------------------------------------------------------------- *)

(* Is the clause's literal set strictly subsumed by an entry of subs? *)
fun subsumes (subs : clause Subsume.subsume) cl =
    let
      val lits = literals cl
    in
      Subsume.isStrictlySubsumed subs lits
    end;

(* ------------------------------------------------------------------------- *)
(* Simplifying rules: these preserve the clause id.                          *)
(* ------------------------------------------------------------------------- *)

(* Rename the theorem's variables with Rule.freshVars; the id is kept. *)
fun freshVars (Clause {parameters,id,thm}) =
    let
      val th = Rule.freshVars thm
    in
      Clause {parameters = parameters, id = id, thm = th}
    end;

(* Apply Rule.simplify, propagating its NONE; otherwise keep the id. *)
fun simplify (Clause {parameters,id,thm}) =
    case Rule.simplify thm of
      SOME th => SOME (Clause {parameters = parameters, id = id, thm = th})
    | NONE => NONE;

(* Simplify the theorem with a unit store; the id is kept. *)
fun reduce units (Clause {parameters,id,thm}) =
    let
      val th = Units.reduce units thm
    in
      Clause {parameters = parameters, id = id, thm = th}
    end;

(* Rewrite the clause's theorem with rewr, excluding the equation stored under
   this clause's own id.  The id is kept.
   - If rewr holds no entry for id, the clause's own theorem is rewritten.
   - If rewr holds an entry and is fully reduced, that stored theorem is used
     as-is.
   - If rewr holds an entry but is not reduced, the stored theorem (not the
     clause's own) is rewritten. *)
fun rewrite rewr (cl as Clause {parameters,id,thm}) =
    let
      val {ordering,...} = parameters

      fun simp th =
          Rewrite.rewriteIdRule rewr (KnuthBendixOrder.compare ordering) id th

(*MetisTrace4
      val () = Print.trace Rewrite.pp "Clause.rewrite: rewr" rewr
      val () = Print.trace Print.ppInt "Clause.rewrite: id" id
      val () = Print.trace pp "Clause.rewrite: cl" cl
*)

      val thm =
          case Rewrite.peek rewr id of
            SOME ((_,storedTh),_) =>
            if Rewrite.isReduced rewr then storedTh else simp storedTh
          | NONE => simp thm

      val result = Clause {parameters = parameters, id = id, thm = thm}

(*MetisTrace4
      val () = Print.trace pp "Clause.rewrite: result" result
*)
    in
      result
    end
(*MetisDebug
    handle Error err => raise Error ("Clause.rewrite:\n" ^ err);
*)

(* ------------------------------------------------------------------------- *)
(* Inference rules: these generate new clause ids.                           *)
(* ------------------------------------------------------------------------- *)

(* Build one new clause, with a fresh id, for each substitution that
   Rule.factor' returns on the clause's largest literals. *)
fun factor (cl as Clause {parameters,thm,...}) =
    let
      val subs = Rule.factor' (largestLiterals cl)
    in
      map (fn sub => new parameters (Thm.subst sub thm)) subs
    end;

(*MetisTrace5
val factor = fn cl =>
    let
      val () = Print.trace pp "Clause.factor: cl" cl
      val result = factor cl
      val () = Print.trace (Print.ppList pp) "Clause.factor: result" result
    in
      result
    end;
*)

(* Binary resolution of cl1 on lit1 with cl2 on lit2.
   lit1 is unified with the negation of lit2, and both theorems are
   instantiated.  Each resolved literal must still be selected by
   isLargerLiteral in its instantiated clause; otherwise Error is raised.
   The resolvent gets a fresh id and cl1's parameters. *)
fun resolve (cl1,lit1) (cl2,lit2) =
    let
(*MetisTrace5
      val () = Print.trace pp "Clause.resolve: cl1" cl1
      val () = Print.trace Literal.pp "Clause.resolve: lit1" lit1
      val () = Print.trace pp "Clause.resolve: cl2" cl2
      val () = Print.trace Literal.pp "Clause.resolve: lit2" lit2
*)
      val Clause {parameters, thm = th1, ...} = cl1
      and Clause {thm = th2, ...} = cl2
      (* Unifier making lit1 the complement of lit2. *)
      val sub = Literal.unify Subst.empty lit1 (Literal.negate lit2)
(*MetisTrace5
      val () = Print.trace Subst.pp "Clause.resolve: sub" sub
*)
      val lit1 = Literal.subst sub lit1
      (* After unification, the instantiated lit2 is the negation of the
         instantiated lit1. *)
      val lit2 = Literal.negate lit1
      val th1 = Thm.subst sub th1
      and th2 = Thm.subst sub th2
      (* Ordering restrictions on both premises. *)
      val _ = isLargerLiteral parameters (Thm.clause th1) lit1 orelse
(*MetisTrace5
              (trace "Clause.resolve: th1 violates ordering\n"; false) orelse
*)
              raise Error "resolve: clause1: ordering constraints"
      val _ = isLargerLiteral parameters (Thm.clause th2) lit2 orelse
(*MetisTrace5
              (trace "Clause.resolve: th2 violates ordering\n"; false) orelse
*)
              raise Error "resolve: clause2: ordering constraints"
      val th = Thm.resolve lit1 th1 th2
(*MetisTrace5
      val () = Print.trace Thm.pp "Clause.resolve: th" th
*)
      val cl = Clause {parameters = parameters, id = newId (), thm = th}
(*MetisTrace5
      val () = Print.trace pp "Clause.resolve: cl" cl
*)
    in
      cl
    end;

(* Paramodulation of cl1's equation lit1, oriented by ort1, into the subterm
   at path2 of literal lit2 in cl2, after unifying tm1 with tm2.
   The following checks raise Error on failure:
   - both instantiated literals must still be selected by isLargerLiteral;
   - the oriented equation must pass isLargerTerm.
   The result gets a fresh id and cl1's parameters. *)
fun paramodulate (cl1,lit1,ort1,tm1) (cl2,lit2,path2,tm2) =
    let
(*MetisTrace5
      val () = Print.trace pp "Clause.paramodulate: cl1" cl1
      val () = Print.trace Literal.pp "Clause.paramodulate: lit1" lit1
      val () = Print.trace Rewrite.ppOrient "Clause.paramodulate: ort1" ort1
      val () = Print.trace Term.pp "Clause.paramodulate: tm1" tm1
      val () = Print.trace pp "Clause.paramodulate: cl2" cl2
      val () = Print.trace Literal.pp "Clause.paramodulate: lit2" lit2
      val () = Print.trace Term.ppPath "Clause.paramodulate: path2" path2
      val () = Print.trace Term.pp "Clause.paramodulate: tm2" tm2
*)
      val Clause {parameters, thm = th1, ...} = cl1
      and Clause {thm = th2, ...} = cl2
      (* Unify the two subterms, then instantiate literals and theorems.  The
         and-bindings all refer to the un-instantiated originals. *)
      val sub = Subst.unify Subst.empty tm1 tm2
      val lit1 = Literal.subst sub lit1
      and lit2 = Literal.subst sub lit2
      and th1 = Thm.subst sub th1
      and th2 = Thm.subst sub th2

      val _ = isLargerLiteral parameters (Thm.clause th1) lit1 orelse
              raise Error "Clause.paramodulate: with clause: ordering"
      val _ = isLargerLiteral parameters (Thm.clause th2) lit2 orelse
              raise Error "Clause.paramodulate: into clause: ordering"

      (* Turn lit1 into an equation and flip it for RightToLeft. *)
      val eqn = (Literal.destEq lit1, th1)
      val eqn as (l_r,_) =
          case ort1 of
            Rewrite.LeftToRight => eqn
          | Rewrite.RightToLeft => Rule.symEqn eqn
(*MetisTrace6
      val () = Print.trace Rule.ppEquation "Clause.paramodulate: eqn" eqn
*)
      val _ = isLargerTerm parameters l_r orelse
              raise Error "Clause.paramodulate: equation: ordering constraints"
      (* Rewrite lit2 at path2 in th2 with the oriented equation. *)
      val th = Rule.rewrRule eqn lit2 path2 th2
(*MetisTrace5
      val () = Print.trace Thm.pp "Clause.paramodulate: th" th
*)
    in
      Clause {parameters = parameters, id = newId (), thm = th}
    end
(*MetisTrace5
    handle Error err =>
      let
        val () = trace ("Clause.paramodulate: failed: " ^ err ^ "\n")
      in
        raise Error err
      end;
*)

end
end;

(**** Original file: Active.sig ****)

(* ========================================================================= *)
(* THE ACTIVE SET OF CLAUSES                                                 *)
(* Copyright (c) 2002-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Active =
sig

(* ------------------------------------------------------------------------- *)
(* A type of active clause sets.                                             *)
(* ------------------------------------------------------------------------- *)

(* Flags enabling subsumption, unit reduction and rewriting.  Presumably
   these correspond to Clause.subsumes, Clause.reduce and Clause.rewrite;
   confirm in Active.sml. *)
type simplify =
     {subsume : bool,
      reduce : bool,
      rewrite : bool}

(* Clause parameters, plus the simplifications to apply before (prefactor)
   and after (postfactor) factoring; the split is inferred from the field
   names, so verify against the implementation. *)
type parameters =
     {clause : Metis.Clause.parameters,
      prefactor : simplify,
      postfactor : simplify}

type active

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

val default : parameters

val size : active -> int

(* Presumably the clauses currently held in the active set; confirm in
   Active.sml. *)
val saturation : active -> Metis.Clause.clause list

(* ------------------------------------------------------------------------- *)
(* Create a new active clause set and initialize clauses.                    *)
(* ------------------------------------------------------------------------- *)

(* Returns the new set together with the clauses built from the axiom and
   conjecture theorems. *)
val new :
    parameters -> {axioms : Metis.Thm.thm list, conjecture : Metis.Thm.thm list} ->
    active * {axioms : Metis.Clause.clause list, conjecture : Metis.Clause.clause list}

(* ------------------------------------------------------------------------- *)
(* Add a clause into the active set and deduce all consequences.             *)
(* ------------------------------------------------------------------------- *)

(* Returns the updated set and a list of clauses; presumably the deduced
   consequences to be processed further. *)
val add : active -> Metis.Clause.clause -> active * Metis.Clause.clause list

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

val pp : active Metis.Print.pp

end

(**** Original file: Active.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* THE ACTIVE SET OF CLAUSES                                                 *)
(* Copyright (c) 2002-2006 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Active :> Active =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(*MetisDebug
local
  fun mkRewrite ordering =
      let
        fun add (cl,rw) =
            let
              val {id, thm = th, ...} = Clause.dest cl
            in
              case total Thm.destUnitEq th of
                SOME l_r => Rewrite.add rw (id,(l_r,th))
              | NONE => rw
            end
      in
        foldl add (Rewrite.new (KnuthBendixOrder.compare ordering))
      end;

  fun allFactors red =
      let
        fun allClause cl =
            List.all red (cl :: Clause.factor cl) orelse
            let
              val () = Print.trace Clause.pp
                         "Active.isSaturated.allFactors: cl" cl
            in
              false
            end
      in
        List.all allClause
      end;

  fun allResolutions red =
      let
        fun allClause2 cl_lit cl =
            let
              fun allLiteral2 lit =
                  case total (Clause.resolve cl_lit) (cl,lit) of
                    NONE => true
                  | SOME cl => allFactors red [cl]
            in
              LiteralSet.all allLiteral2 (Clause.literals cl)
            end orelse
            let
              val () = Print.trace Clause.pp
                         "Active.isSaturated.allResolutions: cl2" cl
            in
              false
            end

        fun allClause1 allCls cl =
            let
              val cl = Clause.freshVars cl

              fun allLiteral1 lit = List.all (allClause2 (cl,lit)) allCls
            in
              LiteralSet.all allLiteral1 (Clause.literals cl)
            end orelse
            let
              val () = Print.trace Clause.pp
                         "Active.isSaturated.allResolutions: cl1" cl
            in
              false
            end

      in
        fn [] => true
         | allCls as cl :: cls =>
           allClause1 allCls cl andalso allResolutions red cls
      end;

  fun allParamodulations red cls =
      let
        fun allClause2 cl_lit_ort_tm cl =
            let
              fun allLiteral2 lit =
                  let
                    val para = Clause.paramodulate cl_lit_ort_tm

                    fun allSubterms (path,tm) =
                        case total para (cl,lit,path,tm) of
                          NONE => true
                        | SOME cl => allFactors red [cl]
                  in
                    List.all allSubterms (Literal.nonVarTypedSubterms lit)
                  end orelse
                  let
                    val () = Print.trace Literal.pp
                               "Active.isSaturated.allParamodulations: lit2" lit
                  in
                    false
                  end
            in
              LiteralSet.all allLiteral2 (Clause.literals cl)
            end orelse
            let
              val () = Print.trace Clause.pp
                         "Active.isSaturated.allParamodulations: cl2" cl
              val (_,_,ort,_) = cl_lit_ort_tm
              val () = Print.trace Rewrite.ppOrient
                         "Active.isSaturated.allParamodulations: ort1" ort
            in
              false
            end

        fun allClause1 cl =
            let
              val cl = Clause.freshVars cl

              fun allLiteral1 lit =
                  let
                    fun allCl2 x = List.all (allClause2 x) cls
                  in
                    case total Literal.destEq lit of
                      NONE => true
                    | SOME (l,r) =>
                      allCl2 (cl,lit,Rewrite.LeftToRight,l) andalso
                      allCl2 (cl,lit,Rewrite.RightToLeft,r)
                  end orelse
                  let
                    val () = Print.trace Literal.pp
                               "Active.isSaturated.allParamodulations: lit1" lit
                  in
                    false
                  end
            in
              LiteralSet.all allLiteral1 (Clause.literals cl)
            end orelse
            let
              val () = Print.trace Clause.pp
                         "Active.isSaturated.allParamodulations: cl1" cl
            in
              false
            end
      in
        List.all allClause1 cls
      end;

  fun redundant {subsume,reduce,rewrite} =
      let
        fun simp cl =
            case Clause.simplify cl of
              NONE => true
            | SOME cl =>
              Subsume.isStrictlySubsumed subsume (Clause.literals cl) orelse
              let
                val cl' = cl
                val cl' = Clause.reduce reduce cl'
                val cl' = Clause.rewrite rewrite cl'
              in
                not (Clause.equalThms cl cl') andalso
                (simp cl' orelse
                 let
                   val () = Print.trace Clause.pp
                              "Active.isSaturated.redundant: cl'" cl'
                 in
                   false
                 end)
              end
      in
        fn cl =>
           simp cl orelse
           let
             val () = Print.trace Clause.pp
                        "Active.isSaturated.redundant: cl" cl
           in
             false
           end
      end;
in
  fun isSaturated ordering subs cls =
      let
        val rd = Units.empty
        val rw = mkRewrite ordering cls
        val red = redundant {subsume = subs, reduce = rd, rewrite = rw}
      in
        (allFactors red cls andalso
         allResolutions red cls andalso
         allParamodulations red cls) orelse
        let
          val () = Print.trace Rewrite.pp "Active.isSaturated: rw" rw
          val () = Print.trace (Print.ppList Clause.pp)
                     "Active.isSaturated: clauses" cls
        in
          false
        end
      end;
end;

fun checkSaturated ordering subs cls =
    if isSaturated ordering subs cls then ()
    else raise Bug "Active.checkSaturated";
*)

(* ------------------------------------------------------------------------- *)
(* A type of active clause sets.                                             *)
(* ------------------------------------------------------------------------- *)

(* Simplification switches: subsumption, unit reduction, rewriting. *)
type simplify = {subsume : bool, reduce : bool, rewrite : bool};

(* Clause parameters plus the simplifications used before (prefactor) and *)
(* after (postfactor) factoring.                                           *)
type parameters =
     {clause : Clause.parameters,
      prefactor : simplify,
      postfactor : simplify};

(* The active set together with its indexes:                               *)
(*   clauses     - every active clause, keyed by clause id                 *)
(*   units       - unit theorems, used to reduce clauses                   *)
(*   rewrite     - unit equations keyed by clause id, used for rewriting   *)
(*   subsume     - clauses indexed by literal set, for subsumption checks  *)
(*   literals    - largest non-equality literals, resolved against a new   *)
(*                 clause's literals                                       *)
(*   equations   - largest equations with their orientation, used to       *)
(*                 paramodulate into a new clause's largest subterms       *)
(*   subterms    - largest subterms, into which a new clause's equations   *)
(*                 are paramodulated                                       *)
(*   allSubterms - every subterm, used to find clauses a changed rewrite   *)
(*                 system can now rewrite                                  *)
datatype active =
    Active of
      {parameters : parameters,
       clauses : Clause.clause IntMap.map,
       units : Units.units,
       rewrite : Rewrite.rewrite,
       subsume : Clause.clause Subsume.subsume,
       literals : (Clause.clause * Literal.literal) LiteralNet.literalNet,
       equations :
         (Clause.clause * Literal.literal * Rewrite.orient * Term.term)
         TermNet.termNet,
       subterms :
         (Clause.clause * Literal.literal * Term.path * Term.term)
         TermNet.termNet,
       allSubterms : (Clause.clause * Term.term) TermNet.termNet};

(* The subsumption index of an active set. *)
fun getSubsume active =
    case active of Active {subsume, ...} => subsume;

(* Replace the rewrite system of an active set; every other field is kept. *)
fun setRewrite (Active {parameters, clauses, units, subsume, literals,
                        equations, subterms, allSubterms, rewrite = _})
               rewrite =
    Active
      {parameters = parameters,
       clauses = clauses,
       units = units,
       rewrite = rewrite,
       subsume = subsume,
       literals = literals,
       equations = equations,
       subterms = subterms,
       allSubterms = allSubterms};

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Every simplification switched on. *)
val maxSimplify : simplify = {rewrite = true, reduce = true, subsume = true};

(* Default clause parameters, with full simplification both before and    *)
(* after factoring.                                                        *)
val default : parameters =
    {prefactor = maxSimplify,
     postfactor = maxSimplify,
     clause = Clause.default};

(* An active set with no clauses.  Its rewrite system compares terms with *)
(* the Knuth-Bendix ordering taken from the clause parameters, and every  *)
(* index is created with fifo = false.                                     *)
fun empty (parameters : parameters) =
    let
      val {clause = {ordering, ...}, ...} = parameters
      val kbo = KnuthBendixOrder.compare ordering
    in
      Active
        {parameters = parameters,
         clauses = IntMap.new (),
         units = Units.empty,
         rewrite = Rewrite.new kbo,
         subsume = Subsume.new (),
         literals = LiteralNet.new {fifo = false},
         equations = TermNet.new {fifo = false},
         subterms = TermNet.new {fifo = false},
         allSubterms = TermNet.new {fifo = false}}
    end;

(* The number of clauses held in the active set. *)
fun size active =
    let
      val Active {clauses = cls, ...} = active
    in
      IntMap.size cls
    end;

(* Every clause in the active set, collected into a list by IntMap.foldr. *)
fun clauses (Active {clauses = clauseMap, ...}) =
    IntMap.foldr (fn (_, cl, acc) => cl :: acc) [] clauseMap;

(* The active clauses with strictly subsumed clauses filtered out. *)
fun saturation active =
    let
      (* Keep cl only if no clause already kept in this pass strictly     *)
      (* subsumes it.  Because results are consed, each pass also         *)
      (* reverses the list.                                                *)
      fun remove (cl,(cls,subs)) =
          let
            val lits = Clause.literals cl
          in
            if Subsume.isStrictlySubsumed subs lits then (cls,subs)
            else (cl :: cls, Subsume.insert subs (lits,()))
          end

      (* Two passes in opposite directions let a clause be dropped when it *)
      (* is subsumed by a clause on either side of it.  The double          *)
      (* reversal restores the original order.  The second pass's subs is  *)
      (* read only by the debug check below.                               *)
      val cls = clauses active
      val (cls,_) = foldl remove ([], Subsume.new ()) cls
      val (cls,subs) = foldl remove ([], Subsume.new ()) cls

(*MetisDebug
      val Active {parameters,...} = active
      val {clause,...} = parameters
      val {ordering,...} = clause
      val () = checkSaturated ordering subs cls
*)
    in
      cls
    end;

(* ------------------------------------------------------------------------- *)
(* Pretty printing.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Print an active set in summary form: Active{<number of clauses>}. *)
val pp =
    Print.ppMap
      (fn active => "Active{" ^ Int.toString (size active) ^ "}")
      Print.ppString;

(*MetisDebug
local
  fun ppField f ppA a =
      Print.blockProgram Print.Inconsistent 2
        [Print.addString (f ^ " ="),
         Print.addBreak 1,
         ppA a];

  val ppClauses =
      ppField "clauses"
        (Print.ppMap IntMap.toList
           (Print.ppList (Print.ppPair Print.ppInt Clause.pp)));

  val ppRewrite = ppField "rewrite" Rewrite.pp;

  val ppSubterms =
      ppField "subterms"
        (TermNet.pp
           (Print.ppMap (fn (c,l,p,t) => ((Clause.id c, l, p), t))
              (Print.ppPair
                 (Print.ppTriple Print.ppInt Literal.pp Term.ppPath)
                 Term.pp)));
in
  fun pp (Active {clauses,rewrite,subterms,...}) =
      Print.blockProgram Print.Inconsistent 2
        [Print.addString "Active",
         Print.addBreak 1,
         Print.blockProgram Print.Inconsistent 1
           [Print.addString "{",
            ppClauses clauses,
            Print.addString ",",
            Print.addBreak 1,
            ppRewrite rewrite,
(*MetisTrace5
            Print.addString ",",
            Print.addBreak 1,
            ppSubterms subterms,
*)
            Print.skip],
         Print.addString "}"];
end;
*)

(* Render an active set as a string with the summary printer pp. *)
fun toString active = Print.toString pp active;

(* ------------------------------------------------------------------------- *)
(* Simplify clauses.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Build a clause simplifier.  Each clause goes through Clause.simplify,  *)
(* then (if enabled) rewriting with rewr, then (if enabled) reduction by  *)
(* units, and finally (if enabled) a subsumption check against subs.      *)
(* NONE means the clause was simplified away or subsumed.                  *)
fun simplify simp units rewr subs =
    let
      val {subsume = useSubsume, reduce = useReduce, rewrite = useRewrite} =
          simp

      (* Only re-run Clause.simplify when rewriting actually changed cl. *)
      fun rewriteStep cl =
          let
            val rewritten = Clause.rewrite rewr cl
          in
            if Clause.equalThms cl rewritten then SOME cl
            else Clause.simplify rewritten
          end

      fun reduceStep cl = if useReduce then Clause.reduce units cl else cl

      fun subsumeStep cl =
          if useSubsume andalso Clause.subsumes subs cl then NONE else SOME cl
    in
      fn cl =>
         case Clause.simplify cl of
           NONE => NONE
         | SOME cl =>
           case (if useRewrite then rewriteStep cl else SOME cl) of
             NONE => NONE
           | SOME cl => subsumeStep (reduceStep cl)
    end;

(*MetisDebug
val simplify = fn simp => fn units => fn rewr => fn subs => fn cl =>
    let
      fun traceCl s = Print.trace Clause.pp ("Active.simplify: " ^ s)
(*MetisTrace4
      val ppClOpt = Print.ppOption Clause.pp
      val () = traceCl "cl" cl
*)
      val cl' = simplify simp units rewr subs cl
(*MetisTrace4
      val () = Print.trace ppClOpt "Active.simplify: cl'" cl'
*)
      val () =
          case cl' of
            NONE => ()
          | SOME cl' =>
            case
              (case simplify simp units rewr subs cl' of
                 NONE => SOME ("away", K ())
               | SOME cl'' =>
                 if Clause.equalThms cl' cl'' then NONE
                 else SOME ("further", fn () => traceCl "cl''" cl'')) of
              NONE => ()
            | SOME (e,f) =>
              let
                val () = traceCl "cl" cl
                val () = traceCl "cl'" cl'
                val () = f ()
              in
                raise
                  Bug
                    ("Active.simplify: clause should have been simplified "^e)
              end
    in
      cl'
    end;
*)

(* A simplifier that uses the units, rewrite system and subsumption index *)
(* of an active set.                                                       *)
fun simplifyActive simp (Active {units, rewrite, subsume, ...}) =
    simplify simp units rewrite subsume;

(* ------------------------------------------------------------------------- *)
(* Add a clause into the active set.                                         *)
(* ------------------------------------------------------------------------- *)

(* If the clause's theorem is a unit, record its literal and theorem in *)
(* the unit set.  Otherwise the unit set is returned unchanged.           *)
fun addUnit units cl =
    let
      val unitThm = Clause.thm cl
    in
      case total Thm.destUnit unitThm of
        NONE => units
      | SOME unitLit => Units.add units (unitLit, unitThm)
    end;

(* If the clause's theorem is a unit equation, add it to the rewrite     *)
(* system under the clause's id.  Otherwise the rewrite system is         *)
(* returned unchanged.                                                     *)
fun addRewrite rewrite cl =
    let
      val eqnThm = Clause.thm cl
    in
      case total Thm.destUnitEq eqnThm of
        NONE => rewrite
      | SOME eqn => Rewrite.add rewrite (Clause.id cl, (eqn, eqnThm))
    end;

(* Index the clause in the subsumption structure by its literal set. *)
fun addSubsume subsume cl =
    let
      val lits = Clause.literals cl
    in
      Subsume.insert subsume (lits, cl)
    end;

(* Index each of the clause's largest literals for resolution, skipping *)
(* literals whose atom is an equality.                                    *)
fun addLiterals literals cl =
    let
      fun insertLit (lit as (_,atm), net) =
          if not (Atom.isEq atm) then LiteralNet.insert net (lit,(cl,lit))
          else net
    in
      LiteralSet.foldl insertLit literals (Clause.largestLiterals cl)
    end;

(* Index each of the clause's largest equations under its chosen side tm. *)
fun addEquations equations cl =
    foldl
      (fn ((lit,ort,tm), net) => TermNet.insert net (tm,(cl,lit,ort,tm)))
      equations (Clause.largestEquations cl);

(* Index each of the clause's largest subterms, remembering its literal *)
(* and path.                                                              *)
fun addSubterms subterms cl =
    foldl
      (fn ((lit,path,tm), net) => TermNet.insert net (tm,(cl,lit,path,tm)))
      subterms (Clause.largestSubterms cl);

(* Index every subterm of the clause, paired with the clause itself. *)
fun addAllSubterms allSubterms cl =
    foldl
      (fn ((_,_,tm), net) => TermNet.insert net (tm,(cl,tm)))
      allSubterms (Clause.allSubterms cl);

(* Insert a clause into the clause map and the subsume, literal, equation *)
(* and subterm indexes.  The unit set and rewrite system are not touched *)
(* here; addFactorClause is the function that updates those.              *)
fun addClause active cl =
    let
      val Active
            {parameters, clauses, units, rewrite, subsume, literals,
             equations, subterms, allSubterms} = active
    in
      Active
        {parameters = parameters,
         clauses = IntMap.insert clauses (Clause.id cl, cl),
         units = units,
         rewrite = rewrite,
         subsume = addSubsume subsume cl,
         literals = addLiterals literals cl,
         equations = addEquations equations cl,
         subterms = addSubterms subterms cl,
         allSubterms = addAllSubterms allSubterms cl}
    end;

(* Record a factored clause in the unit set and rewrite system only.  The *)
(* clause map and the indexes are left unchanged.                          *)
fun addFactorClause active cl =
    let
      val Active
            {parameters, clauses, units, rewrite, subsume, literals,
             equations, subterms, allSubterms} = active
    in
      Active
        {parameters = parameters,
         clauses = clauses,
         units = addUnit units cl,
         rewrite = addRewrite rewrite cl,
         subsume = subsume,
         literals = literals,
         equations = equations,
         subterms = subterms,
         allSubterms = allSubterms}
    end;

(* ------------------------------------------------------------------------- *)
(* Derive (unfactored) consequences of a clause.                             *)
(* ------------------------------------------------------------------------- *)

(* Resolve literal lit of cl against each candidate that LiteralNet.unify *)
(* returns for the negation of lit, consing every resolvent onto acc.     *)
(* Failed resolutions are dropped, and equality literals produce nothing. *)
fun deduceResolution literals cl (lit as (_,atm), acc) =
    let
      fun tryResolve (partner, results) =
          case total (Clause.resolve partner) (cl,lit) of
            NONE => results
          | SOME resolvent => resolvent :: results
(*MetisTrace4
      val () = Print.trace Literal.pp "Active.deduceResolution: lit" lit
*)
    in
      if not (Atom.isEq atm) then
        foldl tryResolve acc (LiteralNet.unify literals (Literal.negate lit))
      else acc
    end;

(* Paramodulate from the equation (lit,ort,tm) of cl into every indexed  *)
(* subterm that TermNet.unify pairs with tm, consing each result onto    *)
(* acc.  Failed paramodulations are dropped.                              *)
fun deduceParamodulationWith subterms cl ((lit,ort,tm),acc) =
    let
      fun step (target, results) =
          case total (Clause.paramodulate (cl,lit,ort,tm)) target of
            NONE => results
          | SOME paramodulant => paramodulant :: results
    in
      foldl step acc (TermNet.unify subterms tm)
    end;

(* Paramodulate every indexed equation that TermNet.unify pairs with tm  *)
(* into the subterm (lit,path,tm) of cl, consing each result onto acc.   *)
(* Failed paramodulations are dropped.                                    *)
fun deduceParamodulationInto equations cl ((lit,path,tm),acc) =
    let
      val target = (cl,lit,path,tm)

      fun step (eqn, results) =
          case total (Clause.paramodulate eqn) target of
            NONE => results
          | SOME paramodulant => paramodulant :: results
    in
      foldl step acc (TermNet.unify equations tm)
    end;

(* Derive the unfactored consequences of cl with respect to the active    *)
(* set, in the order they are generated:                                  *)
(*   1. resolve each largest literal of cl against the literal index;     *)
(*   2. paramodulate from each largest equation of cl into the subterm    *)
(*      index;                                                            *)
(*   3. paramodulate from the equation index into the largest subterms    *)
(*      of cl.  These subterms are not even computed when the equation    *)
(*      index is empty.                                                   *)
fun deduce active cl =
    let
      (* The parameters field is not needed here, so it is not bound. *)
      val Active {literals,equations,subterms,...} = active

      val lits = Clause.largestLiterals cl
      val eqns = Clause.largestEquations cl
      val subtms =
          if TermNet.null equations then [] else Clause.largestSubterms cl
(*MetisTrace5
      val () = Print.trace LiteralSet.pp "Active.deduce: lits" lits
      val () = Print.trace
                 (Print.ppList
                    (Print.ppMap (fn (lit,ort,_) => (lit,ort))
                      (Print.ppPair Literal.pp Rewrite.ppOrient)))
                 "Active.deduce: eqns" eqns
      val () = Print.trace
                 (Print.ppList
                    (Print.ppTriple Literal.pp Term.ppPath Term.pp))
                 "Active.deduce: subtms" subtms
*)

      val acc = []
      val acc = LiteralSet.foldl (deduceResolution literals cl) acc lits
      val acc = foldl (deduceParamodulationWith subterms cl) acc eqns
      val acc = foldl (deduceParamodulationInto equations cl) acc subtms
      (* Results were consed onto acc, so reverse to restore generation order. *)
      val acc = rev acc

(*MetisTrace5
      val () = Print.trace (Print.ppList Clause.pp) "Active.deduce: acc" acc
*)
    in
      acc
    end;

(* ------------------------------------------------------------------------- *)
(* Extract clauses from the active set that can be simplified.               *)
(* ------------------------------------------------------------------------- *)

local
  (* Ids of active clauses that the current rewrite system changes, found *)
  (* by rewriting every clause in the active set.                          *)
  fun clause_rewritables active =
      let
        val Active {clauses,rewrite,...} = active

        fun collect (id,cl,ids) =
            if Clause.equalThms cl (Clause.rewrite rewrite cl) then ids
            else IntSet.add ids id
      in
        IntMap.foldr collect IntSet.empty clauses
      end;

  (* An oriented equation gives one (redex, residue, true) triple that     *)
  (* follows its orientation.  An unoriented one gives nothing.            *)
  fun orderedRedexResidues (((l,r),_),ort) =
      case ort of
        SOME Rewrite.LeftToRight => [(l,r,true)]
      | SOME Rewrite.RightToLeft => [(r,l,true)]
      | NONE => [];

  (* An unoriented equation gives both directions, flagged false.  An      *)
  (* oriented one gives nothing.                                           *)
  fun unorderedRedexResidues (((l,r),_),ort) =
      case ort of
        SOME _ => []
      | NONE => [(l,r,false),(r,l,false)];

  (* Ids of active clauses that may be rewritable after the rewrite        *)
  (* equations with ids rewr_ids changed.  Instead of rewriting every      *)
  (* clause, this looks up the redex side of each such equation in the    *)
  (* allSubterms index.  rewr_ids are rewrite ids, which addRewrite keys   *)
  (* by clause id; the caller here passes the ids from Rewrite.reduce'.    *)
  fun rewrite_rewritables active rewr_ids =
      let
        val Active {parameters,rewrite,clauses,allSubterms,...} = active
        val {clause = {ordering,...}, ...} = parameters
        val order = KnuthBendixOrder.compare ordering

        (* The equations' own clauses, if they are still in the active set. *)
        fun addRewr (id,acc) =
            if IntMap.inDomain id clauses then IntSet.add acc id else acc

        (* Add clauses that contain a subterm matching redex l and can be  *)
        (* rewritten to the matching instance of residue r.               *)
        fun addReduce ((l,r,ord),acc) =
            let
              (* tm must be an instance of l.  When the equation is not     *)
              (* oriented (ord = false), the instantiated step tm -> tm'    *)
              (* must also be decreasing in the Knuth-Bendix order.         *)
              fun isValidRewr tm =
                  case total (Subst.match Subst.empty l) tm of
                    NONE => false
                  | SOME sub =>
                    ord orelse
                    let
                      val tm' = Subst.subst (Subst.normalize sub) r
                    in
                      order (tm,tm') = SOME GREATER
                    end

              (* Clauses already collected are skipped without re-checking. *)
              fun addRed ((cl,tm),acc) =
                  let
(*MetisTrace5
                    val () = Print.trace Clause.pp "Active.addRed: cl" cl
                    val () = Print.trace Term.pp "Active.addRed: tm" tm
*)
                    val id = Clause.id cl
                  in
                    if IntSet.member id acc then acc
                    else if not (isValidRewr tm) then acc
                    else IntSet.add acc id
                  end

(*MetisTrace5
              val () = Print.trace Term.pp "Active.addReduce: l" l
              val () = Print.trace Term.pp "Active.addReduce: r" r
              val () = Print.trace Print.ppBool "Active.addReduce: ord" ord
*)
            in
              List.foldl addRed acc (TermNet.matched allSubterms l)
            end

        (* Look up equation id in the rewrite system and process each      *)
        (* (redex, residue, ordered) triple that redexResidues gives for it. *)
        fun addEquation redexResidues (id,acc) =
            case Rewrite.peek rewrite id of
              NONE => acc
            | SOME eqn_ort => List.foldl addReduce acc (redexResidues eqn_ort)

        val addOrdered = addEquation orderedRedexResidues

        val addUnordered = addEquation unorderedRedexResidues

        val ids = IntSet.empty
        val ids = List.foldl addRewr ids rewr_ids
        val ids = List.foldl addOrdered ids rewr_ids
        val ids = List.foldl addUnordered ids rewr_ids
      in
        ids
      end;

  (* Rewrite every clause directly when there are at least as many        *)
  (* changed rewrite ids as active clauses.  Otherwise use the index.      *)
  fun choose_clause_rewritables active ids = length ids >= size active

  (* Ids of active clauses the rewrite system can rewrite, computed by the *)
  (* strategy that choose_clause_rewritables selects.                       *)
  fun rewritables active ids =
      case choose_clause_rewritables active ids of
        true => clause_rewritables active
      | false => rewrite_rewritables active ids;

(*MetisDebug
  val rewritables = fn active => fn ids =>
      let
        val clause_ids = clause_rewritables active
        val rewrite_ids = rewrite_rewritables active ids

        val () =
            if IntSet.equal rewrite_ids clause_ids then ()
            else
              let
                val ppIdl = Print.ppList Print.ppInt
                val ppIds = Print.ppMap IntSet.toList ppIdl
                val () = Print.trace pp "Active.rewritables: active" active
                val () = Print.trace ppIdl "Active.rewritables: ids" ids
                val () = Print.trace ppIds
                           "Active.rewritables: clause_ids" clause_ids
                val () = Print.trace ppIds
                           "Active.rewritables: rewrite_ids" rewrite_ids
              in
                raise Bug "Active.rewritables: ~(rewrite_ids SUBSET clause_ids)"
              end
      in
        if choose_clause_rewritables active ids then clause_ids else rewrite_ids
      end;
*)

  (* Remove the clauses whose ids are in ids from the clause map and from *)
  (* every index.  The unit set and rewrite system are kept.  An empty id  *)
  (* set returns active unchanged.                                          *)
  fun delete active ids =
      if IntSet.null ids then active
      else
        let
          fun keepId id = not (IntSet.member id ids)

          fun keepCl cl = keepId (Clause.id cl)

          val Active
                {parameters, clauses, units, rewrite, subsume, literals,
                 equations, subterms, allSubterms} = active
        in
          Active
            {parameters = parameters,
             clauses = IntMap.filter (keepId o fst) clauses,
             units = units,
             rewrite = rewrite,
             subsume = Subsume.filter keepCl subsume,
             literals = LiteralNet.filter (keepCl o #1) literals,
             equations = TermNet.filter (keepCl o #1) equations,
             subterms = TermNet.filter (keepCl o #1) subterms,
             allSubterms = TermNet.filter (keepCl o fst) allSubterms}
        end;
in
  (* If the rewrite system is already inter-reduced, nothing changes.      *)
  (* Otherwise inter-reduce it, install the result, remove every clause    *)
  (* that the new system can rewrite, and return those clauses so that the *)
  (* caller (e.g. factor') can process them again.                         *)
  fun extract_rewritables (active as Active {clauses,rewrite,...}) =
      if Rewrite.isReduced rewrite then (active,[])
      else
        let
(*MetisTrace3
          val () = trace "Active.extract_rewritables: inter-reducing\n"
*)
          val (rewrite,ids) = Rewrite.reduce' rewrite
          val active = setRewrite active rewrite
          val ids = rewritables active ids
          (* clauses is the map from the argument; setRewrite did not change it. *)
          val cls = IntSet.transform (IntMap.get clauses) ids
(*MetisTrace3
          val ppCls = Print.ppList Clause.pp
          val () = Print.trace ppCls "Active.extract_rewritables: cls" cls
*)
        in
          (delete active ids, cls)
        end
(*MetisDebug
        handle Error err =>
          raise Bug ("Active.extract_rewritables: shouldn't fail\n" ^ err);
*)
end;

(* ------------------------------------------------------------------------- *)
(* Factor clauses.                                                           *)
(* ------------------------------------------------------------------------- *)

local
  (* Simplifier using the prefactor settings and the active set's units   *)
  (* and rewrite system, checking subsumption against the given index.     *)
  fun prefactor_simplify
        (Active {parameters = {prefactor, ...}, units, rewrite, ...})
        subsume =
      simplify prefactor units rewrite subsume;

  (* Simplifier using the postfactor settings and the active set's units  *)
  (* and rewrite system, checking subsumption against the given index.     *)
  fun postfactor_simplify
        (Active {parameters = {postfactor, ...}, units, rewrite, ...})
        subsume =
      simplify postfactor units rewrite subsume;

  (* Sort clauses by a utility score, using sortMap with Int.compare.      *)
  (* Scores: the empty clause ~1, a unit equation 0, any other unit 1, and *)
  (* n for a clause with n >= 2 literals.                                   *)
  val sort_utilitywise =
      let
        fun utility cl =
            let
              val n = LiteralSet.size (Clause.literals cl)
            in
              if n = 0 then ~1
              else if n = 1 then
                (if Thm.isUnitEq (Clause.thm cl) then 0 else 1)
              else n
            end
      in
        sortMap utility Int.compare
      end;

  (* Post-simplify one factored clause.  If it survives, record it in the *)
  (* active units/rewrites, in the local subsumption index, and in the     *)
  (* accumulated results.                                                   *)
  fun factor_add (cl, (active,subsume,acc)) =
      case postfactor_simplify active subsume cl of
        NONE => (active,subsume,acc)
      | SOME kept =>
        (addFactorClause active kept, addSubsume subsume kept, kept :: acc);

  (* Pre-simplify one clause.  If it survives, pass it and all its factors *)
  (* through factor_add in utility order.                                    *)
  fun factor1 (cl, state as (active,subsume,_)) =
      case prefactor_simplify active subsume cl of
        NONE => state
      | SOME simplified =>
        foldl factor_add state
          (sort_utilitywise (simplified :: Clause.factor simplified));

  (* Factor a batch of clauses.  Then extract any active clauses the       *)
  (* updated rewrite system can rewrite and factor them as the next batch. *)
  (* Stops when a batch is empty and returns the kept clauses in the order  *)
  (* they were added.                                                        *)
  fun factor' active acc cls =
      case cls of
        [] => (active, rev acc)
      | _ =>
        let
          val (active',_,acc') =
              foldl factor1 (active, getSubsume active, acc)
                (sort_utilitywise cls)
          val (active'',cls') = extract_rewritables active'
        in
          factor' active'' acc' cls'
        end;
in
  (* Factor a list of clauses into the active set, starting with no results. *)
  fun factor active = factor' active [];
end;

(*MetisTrace4
val factor = fn active => fn cls =>
    let
      val ppCls = Print.ppList Clause.pp
      val () = Print.trace ppCls "Active.factor: cls" cls
      val (active,cls') = factor active cls
      val () = Print.trace ppCls "Active.factor: cls'" cls'
    in
      (active,cls')
    end;
*)

(* ------------------------------------------------------------------------- *)
(* Create a new active clause set and initialize clauses.                    *)
(* ------------------------------------------------------------------------- *)

(* Create an empty active set, then turn the axiom theorems and then the *)
(* conjecture theorems into fresh clauses and factor them into it.        *)
(* Returns the final set and the surviving clauses of each group.          *)
fun new parameters {axioms,conjecture} =
    let
      val {clause = clauseParameters, ...} = parameters

      fun toClause th =
          Clause.mk
            {parameters = clauseParameters, id = Clause.newId (), thm = th}

      val (active0, axiomClauses) =
          factor (empty parameters) (map toClause axioms)
      val (active1, conjectureClauses) =
          factor active0 (map toClause conjecture)
    in
      (active1, {axioms = axiomClauses, conjecture = conjectureClauses})
    end;

(* ------------------------------------------------------------------------- *)
(* Add a clause into the active set and deduce all consequences.             *)
(* ------------------------------------------------------------------------- *)

(* Add a clause to the active set and return the consequences to process *)
(* next.  The clause is first simplified against the active set:          *)
(*   - simplified away: nothing changes and nothing is returned;          *)
(*   - a contradiction: it is returned as the only result;                *)
(*   - changed by simplification: the new clause goes to factor without   *)
(*     being inserted;                                                    *)
(*   - unchanged: it is inserted, then its consequences are deduced from  *)
(*     a fresh-variable copy and factored.                                *)
fun add active cl =
    case simplifyActive maxSimplify active cl of
      NONE => (active,[])
    | SOME simplified =>
      if Clause.isContradiction simplified then (active,[simplified])
      else if Clause.equalThms cl simplified then
        let
(*MetisTrace2
          val () = Print.trace Clause.pp "Active.add: cl" cl
*)
          val active' = addClause active cl
          val fresh = Clause.freshVars cl
          val (active'',derived) = factor active' (deduce active' fresh)
(*MetisTrace2
          val ppCls = Print.ppList Clause.pp
          val () = Print.trace ppCls "Active.add: cls" derived
*)
        in
          (active'',derived)
        end
      else factor active [simplified];

end
end;

(**** Original file: Waiting.sig ****)

(* ========================================================================= *)
(* THE WAITING SET OF CLAUSES                                                *)
(* Copyright (c) 2002-2007 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

signature Waiting =
sig

(* ------------------------------------------------------------------------- *)
(* The parameters control the order that clauses are removed from the        *)
(* waiting set: clauses are assigned a weight and removed in strict weight   *)
(* order, with smaller weights being removed before larger weights.          *)
(*                                                                           *)
(* The weight of a clause is defined to be                                   *)
(*                                                                           *)
(*   d * s^symbolsWeight * v^variablesWeight * l^literalsWeight * m          *)
(*                                                                           *)
(* where                                                                     *)
(*                                                                           *)
(*   d = the derivation distance of the clause from the axioms               *)
(*   s = the number of symbols in the clause                                 *)
(*   v = the number of distinct variables in the clause                      *)
(*   l = the number of literals in the clause                                *)
(*   m = the truth of the clause wrt the models                              *)
(* ------------------------------------------------------------------------- *)

type weight = real

(* Settings for one model used to compute the factor m above.  The exact   *)
(* meaning of each field is defined in Waiting.sml (NOTE(review): confirm). *)
type modelParameters =
     {model : Metis.Model.parameters,
      initialPerturbations : int,
      maxChecks : int option,
      perturbations : int,
      weight : weight}

(* The exponents from the weight formula above, plus the models used for m. *)
type parameters =
     {symbolsWeight : weight,
      variablesWeight : weight,
      literalsWeight : weight,
      models : modelParameters list}

(* ------------------------------------------------------------------------- *)
(* A type of waiting sets of clauses.                                        *)
(* ------------------------------------------------------------------------- *)

type waiting

(* Abstract derivation distance of a clause: the d in the weight formula. *)
type distance

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

val default : parameters

(* Create a waiting set from the initial axiom and conjecture clauses. *)
val new :
    parameters ->
    {axioms : Metis.Clause.clause list,
     conjecture : Metis.Clause.clause list} -> waiting

(* Presumably the number of clauses waiting (NOTE(review): confirm). *)
val size : waiting -> int

val pp : waiting Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* Adding new clauses.                                                       *)
(* ------------------------------------------------------------------------- *)

(* Add clauses together with the derivation distance they were derived at. *)
val add : waiting -> distance * Metis.Clause.clause list -> waiting

(* ------------------------------------------------------------------------- *)
(* Removing the lightest clause.                                             *)
(* ------------------------------------------------------------------------- *)

(* Remove the lightest clause with its distance and return the rest of the *)
(* set.  NONE presumably means the set is empty (NOTE(review): confirm).   *)
val remove : waiting -> ((distance * Metis.Clause.clause) * waiting) option

end

(**** Original file: Waiting.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* THE WAITING SET OF CLAUSES                                                *)
(* Copyright (c) 2002-2007 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Waiting :> Waiting =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of waiting sets of clauses.                                        *)
(* ------------------------------------------------------------------------- *)

(* Clause weights; smaller weights are removed from the waiting set first. *)
type weight = real;

(* Settings for a single finite model used when weighting clauses. *)
type modelParameters =
     {model : Model.parameters,
      initialPerturbations : int,
      maxChecks : int option,
      perturbations : int,
      weight : weight}

(* Exponents for the symbol, variable and literal factors of a clause weight,
   together with the list of model settings. *)
type parameters =
     {symbolsWeight : weight,
      variablesWeight : weight,
      literalsWeight : weight,
      models : modelParameters list};

(* Derivation distance of a clause; abstract in the signature. *)
type distance = real;

(* A waiting set: its parameters, a heap of (weight, (distance, clause))
   entries ordered by weight, and the finite models built from
   parameters.models, one model per entry and in the same order. *)
datatype waiting =
    Waiting of
      {parameters : parameters,
       clauses : (weight * (distance * Clause.clause)) Heap.heap,
       models : Model.model list};

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Default model settings.  The original Metis default (one Model.default
   model, kept below in a comment) is disabled.  The list is therefore empty,
   and with these defaults no finite models are built or perturbed. *)
val defaultModels : modelParameters list =
    [(* MODIFIED by Jasmin Blanchette
      {model = Model.default,
      initialPerturbations = 100,
      maxChecks = SOME 20,
      perturbations = 0,
      weight = 1.0} *)];

(* Default weighting.  literalsWeight and variablesWeight were 1.0 in the
   original and are now 0.0.  The corresponding factors in clauseWeight are
   therefore Math.pow (x, 0.0) = 1.0, and a weight depends only on the
   distance, the symbol count and the clause-id tie-breaker. *)
val default : parameters =
     {symbolsWeight = 1.0,
      literalsWeight = (* 1.0 *) 0.0, (* MODIFIED by Jasmin Blanchette *)
      variablesWeight = (* 1.0 *) 0.0, (* MODIFIED by Jasmin Blanchette *)
      models = defaultModels};

(* Number of clauses currently held in the waiting heap. *)
fun size (Waiting {clauses = heap, ...}) = Heap.size heap;

(* Compact printer: shows only the number of waiting clauses, as Waiting{n}. *)
val pp =
    let
      fun summary w = "Waiting{" ^ Int.toString (size w) ^ "}"
    in
      Print.ppMap summary Print.ppString
    end;

(*MetisDebug
val pp =
    Print.ppMap
      (fn Waiting {clauses,...} =>
          map (fn (w,(_,cl)) => (w, Clause.id cl, cl)) (Heap.toList clauses))
      (Print.ppList (Print.ppTriple Print.ppReal Print.ppInt Clause.pp));
*)

(* ------------------------------------------------------------------------- *)
(* Perturbing the models.                                                    *)
(* ------------------------------------------------------------------------- *)

(* A clause prepared for model checking: its free variables paired with its
   literal set. *)
type modelClause = NameSet.set * Thm.clause;

fun mkModelClause cl =
    let
      val lits = Clause.literals cl
    in
      (LiteralSet.freeVars lits, lits)
    end;

(* Prepare every clause in a list, preserving order. *)
fun mkModelClauses cls = map mkModelClause cls;

(* Returns a function of a round count.  Each round visits every clause of
   cls in order, draws a random valuation of the clause's free variables, and
   calls Model.perturbClause when the clause is false under that valuation.
   For an empty clause list the returned function does nothing. *)
fun perturbModel model cls =
    case cls of
      [] => K ()
    | _ :: _ =>
      let
        val domain = {size = Model.size model}

        fun repair (fv,cl) =
            let
              val valuation = Model.randomValuation domain fv
            in
              if not (Model.interpretClause model valuation cl) then
                Model.perturbClause model valuation cl
              else ()
            end
      in
        fn rounds => funpow rounds (fn () => List.app repair cls) ()
      end;

(* Create the model described by parm, then apply its initialPerturbations
   rounds: first against the conjecture clauses, then against the axioms. *)
fun initialModel axioms conjecture parm =
    let
      val {model = modelParm, initialPerturbations = rounds, ...}
          : modelParameters = parm
      val m = Model.new modelParm
    in
      perturbModel m conjecture rounds;
      perturbModel m axioms rounds;
      m
    end;

(* Model factor m of a clause weight.  It is the product over the (parameter,
   model) pairs of (1 + T / (T + F)) ^ weight, where T and F are the counts
   Model.check reports for the clause.  parms and models are paired
   positionally.

   When a check reports T + F = 0, that model contributes a neutral factor.
   Otherwise 0/0 would yield NaN, and a NaN weight makes Real.compare (used to
   order the waiting heap) raise IEEEReal.Unordered. *)
fun checkModels parms models (fv,cl) =
    let
      fun check ((parm,model),z) =
          let
            val {maxChecks,weight,...} : modelParameters = parm
            val n = {maxChecks = maxChecks}
            val {T,F} = Model.check Model.interpretClause n model fv cl
            val checks = T + F
          in
            if checks = 0 then z
            else
              Math.pow (1.0 + Real.fromInt T / Real.fromInt checks, weight) * z
          end
    in
      List.foldl check 1.0 (zip parms models)
    end;

(* Perturb each model against the new model clauses cls, using that model's
   own perturbations count; parms and models are paired positionally. *)
fun perturbModels parms models cls =
    List.app
      (fn ({perturbations, ...} : modelParameters, model) =>
          perturbModel model cls perturbations)
      (zip parms models);

(* ------------------------------------------------------------------------- *)
(* Clause weights.                                                           *)
(* ------------------------------------------------------------------------- *)

local
  (* Clause measures over a literal set, as reals so they can be raised to
     real powers. *)
  fun symbolCount lits = Real.fromInt (LiteralSet.typedSymbols lits);

  (* One more than the number of distinct free variables, so a clause with no
     free variables measures 1.0 rather than 0.0. *)
  fun variableCount lits =
      Real.fromInt (1 + NameSet.size (LiteralSet.freeVars lits));

  fun literalCount lits = Real.fromInt (LiteralSet.size lits);

  (* A tiny term proportional to the clause id, added to every weight so that
     clauses with equal base weight are still told apart, assuming their ids
     differ. *)
  fun tieBreak cl = Real.fromInt (Clause.id cl) * 1e~12;
in
  (* Weight of clause cl at derivation distance dist: the product
     dist, s^symbolsWeight, v^variablesWeight, l^literalsWeight and m,
     plus the tie-breaker.  m is fixed at 1.0 because model checking is
     disabled, so mods and mcl are currently unused. *)
  fun clauseWeight (parm : parameters) mods dist mcl cl =
      let
(*MetisTrace3
        val () = Print.trace Clause.pp "Waiting.clauseWeight: cl" cl
*)
        val {symbolsWeight, variablesWeight, literalsWeight, ...} = parm
        val lits = Clause.literals cl
        fun factor (measure, exponent) = Math.pow (measure lits, exponent)
        val symbolsW = factor (symbolCount, symbolsWeight)
        val variablesW = factor (variableCount, variablesWeight)
        val literalsW = factor (literalCount, literalsWeight)
        (* Originally: checkModels (#models parm) mods mcl.
           MODIFIED by Jasmin Blanchette *)
        val modelsW = 1.0
(*MetisTrace4
        val () = trace ("Waiting.clauseWeight: dist = " ^
                        Real.toString dist ^ "\n")
        val () = trace ("Waiting.clauseWeight: symbolsW = " ^
                        Real.toString symbolsW ^ "\n")
        val () = trace ("Waiting.clauseWeight: variablesW = " ^
                        Real.toString variablesW ^ "\n")
        val () = trace ("Waiting.clauseWeight: literalsW = " ^
                        Real.toString literalsW ^ "\n")
        val () = trace ("Waiting.clauseWeight: modelsW = " ^
                        Real.toString modelsW ^ "\n")
*)
        val weight =
            dist * symbolsW * variablesW * literalsW * modelsW + tieBreak cl
(*MetisTrace3
        val () = trace ("Waiting.clauseWeight: weight = " ^
                        Real.toString weight ^ "\n")
*)
      in
        weight
      end;
end;

(* ------------------------------------------------------------------------- *)
(* Adding new clauses.                                                       *)
(* ------------------------------------------------------------------------- *)

(* Insert clauses cls into the waiting heap, then perturb the models against
   mcls.  mcls holds the model clauses of cls, in the same order.  Every
   clause in the batch gets distance dist + ln (length cls), so clauses from
   larger batches receive a larger distance. *)
fun add' waiting dist mcls cls =
    let
      val Waiting {parameters, clauses = heap, models} = waiting
      val {models = modelParms, ...} = parameters
      val dist' = dist + Math.ln (Real.fromInt (length cls))

      fun insert ((mcl,cl), h) =
          Heap.add h (clauseWeight parameters models dist' mcl cl, (dist',cl))

      val heap' = List.foldl insert heap (zip mcls cls)
    in
      perturbModels modelParms models mcls;
      Waiting {parameters = parameters, clauses = heap', models = models}
    end;

(* Add a batch of clauses derived at distance dist.  An empty batch returns
   the waiting set unchanged. *)
fun add waiting (dist,cls) =
    case cls of
      [] => waiting
    | _ :: _ =>
      let
(*MetisTrace3
        val () = Print.trace pp "Waiting.add: waiting" waiting
        val () = Print.trace (Print.ppList Clause.pp) "Waiting.add: cls" cls
*)
        val waiting' = add' waiting dist (mkModelClauses cls) cls
(*MetisTrace3
        val () = Print.trace pp "Waiting.add: waiting" waiting'
*)
      in
        waiting'
      end;

local
  (* Heap order: compare entries by weight only. *)
  fun byWeight ((w1,_),(w2,_)) = Real.compare (w1,w2);

  (* An empty heap, plus one initial model per model setting, each perturbed
     against the given model clauses. *)
  fun emptyWaiting parameters axioms conjecture =
      let
        val {models = modelParms, ...} = parameters
      in
        Waiting
          {parameters = parameters,
           clauses = Heap.new byWeight,
           models = List.map (initialModel axioms conjecture) modelParms}
      end;
in
  (* Build a waiting set whose heap initially holds the axioms followed by the
     conjecture clauses, all added as one batch at distance 0.0. *)
  fun new parameters {axioms,conjecture} =
      let
        val mAxioms = mkModelClauses axioms
        val mConjecture = mkModelClauses conjecture
      in
        add' (emptyWaiting parameters mAxioms mConjecture) 0.0
          (mAxioms @ mConjecture) (axioms @ conjecture)
      end;
end;

(* ------------------------------------------------------------------------- *)
(* Removing the lightest clause.                                             *)
(* ------------------------------------------------------------------------- *)

(* Pop the lightest entry: NONE when the heap is empty, otherwise the
   (distance, clause) pair and the waiting set without it. *)
fun remove (Waiting {parameters, clauses, models}) =
    if not (Heap.null clauses) then
      let
        val ((_, lightest), rest) = Heap.remove clauses
        val remaining =
            Waiting {parameters = parameters, clauses = rest, models = models}
      in
        SOME (lightest, remaining)
      end
    else NONE;

end
end;

(**** Original file: Resolution.sig ****)

(* ========================================================================= *)
(* THE RESOLUTION PROOF PROCEDURE                                            *)
(* Copyright (c) 2001-2007 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

(* Interface of the resolution proof loop, which repeatedly moves the lightest
   waiting clause into the active set. *)
signature Resolution =
sig

(* ------------------------------------------------------------------------- *)
(* A type of resolution proof procedures.                                    *)
(* ------------------------------------------------------------------------- *)

type parameters =
     {active : Metis.Active.parameters,
      waiting : Metis.Waiting.parameters}

type resolution

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

val default : parameters

(* Build the active set from the theorems; the clauses it returns seed the
   waiting set. *)
val new :
    parameters -> {axioms : Metis.Thm.thm list, conjecture : Metis.Thm.thm list} ->
    resolution

val active : resolution -> Metis.Active.active

val waiting : resolution -> Metis.Waiting.waiting

val pp : resolution Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* The main proof loop.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Contradiction carries the theorem of the contradictory clause;
   Satisfiable carries the theorems of the active set's saturation. *)
datatype decision =
    Contradiction of Metis.Thm.thm
  | Satisfiable of Metis.Thm.thm list

datatype state =
    Decided of decision
  | Undecided of resolution

(* One step: remove the lightest waiting clause.  Satisfiable if none is left,
   Contradiction if it is contradictory, otherwise add it to the active set
   and its consequences to the waiting set. *)
val iterate : resolution -> state

(* Iterate until a decision is reached; may not terminate. *)
val loop : resolution -> decision

end

(**** Original file: Resolution.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* THE RESOLUTION PROOF PROCEDURE                                            *)
(* Copyright (c) 2001-2007 Joe Hurd, distributed under the BSD License       *)
(* ========================================================================= *)

structure Resolution :> Resolution =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of resolution proof procedures.                                    *)
(* ------------------------------------------------------------------------- *)

(* Parameters for the two clause sets the prover maintains. *)
type parameters =
     {active : Active.parameters,
      waiting : Waiting.parameters};

(* Prover state: its parameters, the active clause set, and the waiting set
   of clauses still to be processed. *)
datatype resolution =
    Resolution of
      {parameters : parameters,
       active : Active.active,
       waiting : Waiting.waiting};

(* ------------------------------------------------------------------------- *)
(* Basic operations.                                                         *)
(* ------------------------------------------------------------------------- *)

(* Default settings: the defaults of each component set. *)
val default : parameters =
    {waiting = Waiting.default,
     active = Active.default};

(* Create a prover from theorems ths.  Active.new also returns the factored
   clauses of ths, which seed the waiting set. *)
fun new parameters ths =
    let
      val {active = activeParms, waiting = waitingParms} = parameters
      val (active, factored) = Active.new activeParms ths
    in
      Resolution
        {parameters = parameters,
         active = active,
         waiting = Waiting.new waitingParms factored}
    end;

(* Accessor for the active clause set. *)
fun active (Resolution r) = #active r;

(* Accessor for the waiting clause set. *)
fun waiting (Resolution r) = #waiting r;

(* Compact printer: Resolution(a<-w), where a and w are the sizes of the
   active and waiting sets. *)
val pp =
    let
      fun summary (Resolution {active = a, waiting = w, ...}) =
          String.concat
            ["Resolution(", Int.toString (Active.size a),
             "<-", Int.toString (Waiting.size w), ")"]
    in
      Print.ppMap summary Print.ppString
    end;

(* ------------------------------------------------------------------------- *)
(* The main proof loop.                                                      *)
(* ------------------------------------------------------------------------- *)

(* Final outcome of the proof loop: a refutation theorem, or the theorems of
   a saturated active set. *)
datatype decision =
    Contradiction of Thm.thm
  | Satisfiable of Thm.thm list;

(* Result of a single iteration: either decided, or a new prover state. *)
datatype state =
    Decided of decision
  | Undecided of resolution;

(* One proof step.  An empty waiting set means the active set is saturated
   (Satisfiable).  A contradictory lightest clause ends the proof
   (Contradiction).  Otherwise the clause joins the active set, and the
   clauses this produces are added to the waiting set at the clause's
   distance. *)
fun iterate resolution =
    let
      val Resolution {parameters, active, waiting} = resolution
(*MetisTrace2
      val () = Print.trace Active.pp "Resolution.iterate: active" active
      val () = Print.trace Waiting.pp "Resolution.iterate: waiting" waiting
*)
      fun continueWith (d, cl, rest) =
          let
(*MetisTrace1
            val () = Print.trace Clause.pp "Resolution.iterate: cl" cl
*)
            val (active', derived) = Active.add active cl
          in
            Undecided
              (Resolution
                 {parameters = parameters,
                  active = active',
                  waiting = Waiting.add rest (d, derived)})
          end
    in
      case Waiting.remove waiting of
        NONE =>
        Decided (Satisfiable (List.map Clause.thm (Active.saturation active)))
      | SOME ((d, cl), rest) =>
        if Clause.isContradiction cl then Decided (Contradiction (Clause.thm cl))
        else continueWith (d, cl, rest)
    end;

(* Run iterate until it reports a decision (tail-recursive; may not
   terminate). *)
fun loop resolution =
    let
      fun step (Decided decision) = decision
        | step (Undecided next) = step (iterate next)
    in
      step (iterate resolution)
    end;

end
end;

(**** Original file: Tptp.sig ****)

(* ========================================================================= *)
(* THE TPTP PROBLEM FILE FORMAT                                              *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

(* Interface for TPTP problem files: symbol name mapping, roles, SZS
   statuses, formulas, problems, and TSTP proof output. *)
signature Tptp =
sig

(* ------------------------------------------------------------------------- *)
(* Mapping to and from TPTP variable, function and relation names.           *)
(* ------------------------------------------------------------------------- *)

type mapping

val defaultMapping : mapping

(* Each table entry pairs a Metis symbol name and arity with its TPTP name. *)
val mkMapping :
    {functionMapping : {name : Metis.Name.name, arity : int, tptp : string} list,
     relationMapping : {name : Metis.Name.name, arity : int, tptp : string} list} ->
    mapping

(* NOTE(review): presumably extends the mapping with TPTP names for the given
   variables; confirm against the implementation. *)
val addVarSetMapping : mapping -> Metis.NameSet.set -> mapping

(* ------------------------------------------------------------------------- *)
(* Interpreting TPTP functions and relations in a finite model.              *)
(* ------------------------------------------------------------------------- *)

val defaultFixedMap : Metis.Model.fixedMap

val defaultModel : Metis.Model.parameters

val ppFixedMap : mapping -> Metis.Model.fixedMap Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* TPTP roles.                                                               *)
(* ------------------------------------------------------------------------- *)

(* OtherRole carries any role string not covered by the other constructors. *)
datatype role =
    AxiomRole
  | ConjectureRole
  | DefinitionRole
  | NegatedConjectureRole
  | PlainRole
  | TheoremRole
  | OtherRole of string;

val isCnfConjectureRole : role -> bool

val isFofConjectureRole : role -> bool

val toStringRole : role -> string

val fromStringRole : string -> role

val ppRole : role Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* SZS statuses.                                                             *)
(* ------------------------------------------------------------------------- *)

datatype status =
    CounterSatisfiableStatus
  | TheoremStatus
  | SatisfiableStatus
  | UnknownStatus
  | UnsatisfiableStatus

val toStringStatus : status -> string

val ppStatus : status Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* TPTP literals.                                                            *)
(* ------------------------------------------------------------------------- *)

(* A TPTP literal is either a truth constant or an ordinary Metis literal. *)
datatype literal =
    Boolean of bool
  | Literal of Metis.Literal.literal

val negateLiteral : literal -> literal

val functionsLiteral : literal -> Metis.NameAritySet.set

val relationLiteral : literal -> Metis.Atom.relation option

val freeVarsLiteral : literal -> Metis.NameSet.set

(* ------------------------------------------------------------------------- *)
(* TPTP formula names.                                                       *)
(* ------------------------------------------------------------------------- *)

datatype formulaName =
    FormulaName of string

val ppFormulaName : formulaName Metis.Print.pp

(* ------------------------------------------------------------------------- *)
(* TPTP formula bodies.                                                      *)
(* ------------------------------------------------------------------------- *)

datatype formulaBody =
    CnfFormulaBody of literal list
  | FofFormulaBody of Metis.Formula.formula

(* ------------------------------------------------------------------------- *)
(* TPTP formula sources.                                                     *)
(* ------------------------------------------------------------------------- *)

datatype formulaSource =
    NoFormulaSource
  | StripFormulaSource of
      {inference : string,
       parents : formulaName list}
  | NormalizeFormulaSource of
      {inference : Metis.Normalize.inference,
       parents : formulaName list}
  | ProofFormulaSource of
      {inference : Metis.Proof.inference,
       parents : formulaName list}

(* ------------------------------------------------------------------------- *)
(* TPTP formulas.                                                            *)
(* ------------------------------------------------------------------------- *)

datatype formula =
    Formula of
      {name : formulaName,
       role : role,
       body : formulaBody,
       source : formulaSource}

val nameFormula : formula -> formulaName

val roleFormula : formula -> role

val bodyFormula : formula -> formulaBody

val sourceFormula : formula -> formulaSource

val functionsFormula : formula -> Metis.NameAritySet.set

val relationsFormula : formula -> Metis.NameAritySet.set

val freeVarsFormula : formula -> Metis.NameSet.set

val freeVarsListFormula : formula list -> Metis.NameSet.set

val isCnfConjectureFormula : formula -> bool
val isFofConjectureFormula : formula -> bool
val isConjectureFormula : formula -> bool

(* ------------------------------------------------------------------------- *)
(* Clause information.                                                       *)
(* ------------------------------------------------------------------------- *)

datatype clauseSource =
    CnfClauseSource of formulaName * literal list
  | FofClauseSource of Metis.Normalize.thm

(* Per-clause information; presumably keyed by each clause's literal set. *)
type 'a clauseInfo = 'a Metis.LiteralSetMap.map

type clauseNames = formulaName clauseInfo

type clauseRoles = role clauseInfo

type clauseSources = clauseSource clauseInfo

val noClauseNames : clauseNames

val noClauseRoles : clauseRoles

val noClauseSources : clauseSources

(* ------------------------------------------------------------------------- *)
(* TPTP problems.                                                            *)
(* ------------------------------------------------------------------------- *)

type comments = string list

type includes = string list

datatype problem =
    Problem of
      {comments : comments,
       includes : includes,
       formulas : formula list}

val hasCnfConjecture : problem -> bool
val hasFofConjecture : problem -> bool
val hasConjecture : problem -> bool

val freeVars : problem -> Metis.NameSet.set

val mkProblem :
    {comments : comments,
     includes : includes,
     names : clauseNames,
     roles : clauseRoles,
     problem : Metis.Problem.problem} -> problem

(* NOTE(review): the result list presumably has one entry per subgoal of the
   problem's goal, each with its own clause problem and sources. *)
val normalize :
    problem ->
    {subgoal : Metis.Formula.formula * formulaName list,
     problem : Metis.Problem.problem,
     sources : clauseSources} list

val goal : problem -> Metis.Formula.formula

val read : {mapping : mapping, filename : string} -> problem

val write :
    {problem : problem,
     mapping : mapping,
     filename : string} -> unit

(* NOTE(review): the bool presumably reports whether the problem in the file
   was proved; confirm against the implementation. *)
val prove : {filename : string, mapping : mapping} -> bool

(* ------------------------------------------------------------------------- *)
(* TSTP proofs.                                                              *)
(* ------------------------------------------------------------------------- *)

val fromProof :
    {problem : problem,
     proofs : {subgoal : Metis.Formula.formula * formulaName list,
               sources : clauseSources,
               refutation : Metis.Thm.thm} list} -> formula list

end

(**** Original file: Tptp.sml ****)

structure Metis = struct open Metis
(* Metis-specific ML environment *)
nonfix ++ -- RL;
val explode = String.explode;
val implode = String.implode;
val print = TextIO.print;
val foldl = List.foldl;
val foldr = List.foldr;

(* ========================================================================= *)
(* THE TPTP PROBLEM FILE FORMAT                                              *)
(* Copyright (c) 2001 Joe Hurd, distributed under the BSD License            *)
(* ========================================================================= *)

structure Tptp :> Tptp =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Default TPTP function and relation name mapping.                          *)
(* ------------------------------------------------------------------------- *)

(* Default function-name table.  Each entry gives a Metis function name and
   arity, and the TPTP alphanumeric name used for it. *)
val defaultFunctionMapping =
    [(* Mapping TPTP functions to infix symbols *)
     {name = "~", arity = 1, tptp = "negate"},
     {name = "*", arity = 2, tptp = "multiply"},
     {name = "/", arity = 2, tptp = "divide"},
     {name = "+", arity = 2, tptp = "add"},
     {name = "-", arity = 2, tptp = "subtract"},
     {name = "::", arity = 2, tptp = "cons"},
     {name = "@", arity = 2, tptp = "append"},
     {name = ",", arity = 2, tptp = "pair"},
     (* Expanding HOL symbols to TPTP alphanumerics *)
     {name = ":", arity = 2, tptp = "has_type"},
     {name = ".", arity = 2, tptp = "apply"}];

(* Default relation-name table, with the same entry shape as the function
   table. *)
val defaultRelationMapping =
    [(* Mapping TPTP relations to infix symbols *)
     {name = "=", arity = 2, tptp = "="},  (* this preserves the = symbol *)
     {name = "==", arity = 2, tptp = "equalish"},
     {name = "<=", arity = 2, tptp = "less_equal"},
     {name = "<", arity = 2, tptp = "less_than"},
     {name = ">=", arity = 2, tptp = "greater_equal"},
     {name = ">", arity = 2, tptp = "greater_than"},
     (* Expanding HOL symbols to TPTP alphanumerics *)
     {name = "{}", arity = 1, tptp = "bool"}];

(* ------------------------------------------------------------------------- *)
(* Interpreting TPTP functions and relations in a finite model.              *)
(* ------------------------------------------------------------------------- *)

(* Default interpretations of function symbols.  Each entry names a function
   symbol and arity and gives the Model name it is interpreted as (numerals
   via Model.numeralName, projections via Model.projectionName). *)
val defaultFunctionModel =
    [{name = "~", arity = 1, model = Model.negName},
     {name = "*", arity = 2, model = Model.multName},
     {name = "/", arity = 2, model = Model.divName},
     {name = "+", arity = 2, model = Model.addName},
     {name = "-", arity = 2, model = Model.subName},
     {name = "::", arity = 2, model = Model.consName},
     {name = "@", arity = 2, model = Model.appendName},
     {name = ":", arity = 2, model = Term.hasTypeFunctionName},
     {name = "additive_identity", arity = 0, model = Model.numeralName 0},
     {name = "app", arity = 2, model = Model.appendName},
     {name = "complement", arity = 1, model = Model.complementName},
     {name = "difference", arity = 2, model = Model.differenceName},
     {name = "divide", arity = 2, model = Model.divName},
     {name = "empty_set", arity = 0, model = Model.emptyName},
     {name = "identity", arity = 0, model = Model.numeralName 1},
     {name = "identity_map", arity = 1, model = Model.projectionName 1},
     {name = "intersection", arity = 2, model = Model.intersectName},
     {name = "minus", arity = 1, model = Model.negName},
     {name = "multiplicative_identity", arity = 0, model = Model.numeralName 1},
     {name = "n0", arity = 0, model = Model.numeralName 0},
     {name = "n1", arity = 0, model = Model.numeralName 1},
     {name = "n2", arity = 0, model = Model.numeralName 2},
     {name = "n3", arity = 0, model = Model.numeralName 3},
     {name = "n4", arity = 0, model = Model.numeralName 4},
     {name = "n5", arity = 0, model = Model.numeralName 5},
     {name = "n6", arity = 0, model = Model.numeralName 6},
     {name = "n7", arity = 0, model = Model.numeralName 7},
     {name = "n8", arity = 0, model = Model.numeralName 8},
     {name = "n9", arity = 0, model = Model.numeralName 9},
     {name = "nil", arity = 0, model = Model.nilName},
     {name = "null_class", arity = 0, model = Model.emptyName},
     {name = "singleton", arity = 1, model = Model.singletonName},
     {name = "successor", arity = 1, model = Model.sucName},
     {name = "symmetric_difference", arity = 2,
      model = Model.symmetricDifferenceName},
     {name = "union", arity = 2, model = Model.unionName},
     {name = "universal_class", arity = 0, model = Model.universeName}];

(* Default interpretations of relation symbols.  The various TPTP equality
   spellings all map to Atom.eqRelationName. *)
val defaultRelationModel =
    [{name = "=", arity = 2, model = Atom.eqRelationName},
     {name = "==", arity = 2, model = Atom.eqRelationName},
     {name = "<=", arity = 2, model = Model.leName},
     {name = "<", arity = 2, model = Model.ltName},
     {name = ">=", arity = 2, model = Model.geName},
     {name = ">", arity = 2, model = Model.gtName},
     {name = "divides", arity = 2, model = Model.dividesName},
     {name = "element_of_set", arity = 2, model = Model.memberName},
     {name = "equal", arity = 2, model = Atom.eqRelationName},
     {name = "equal_elements", arity = 2, model = Atom.eqRelationName},
     {name = "equal_sets", arity = 2, model = Atom.eqRelationName},
     {name = "equivalent", arity = 2, model = Atom.eqRelationName},
     {name = "less", arity = 2, model = Model.ltName},
     {name = "less_or_equal", arity = 2, model = Model.leName},
     {name = "member", arity = 2, model = Model.memberName},
     {name = "subclass", arity = 2, model = Model.subsetName},
     {name = "subset", arity = 2, model = Model.subsetName}];

(* ------------------------------------------------------------------------- *)
(* Helper functions.                                                         *)
(* ------------------------------------------------------------------------- *)

(* True when s is non-empty, its first character satisfies hp, and every
   later character satisfies tp.  Characters are checked last to first. *)
fun isHdTlString hp tp s =
    let
      val len = size s

      fun tailOk i = i < 1 orelse (tp (String.sub (s,i)) andalso tailOk (i - 1))
    in
      len > 0 andalso hp (String.sub (s,0)) andalso tailOk (len - 1)
    end;

(* Remove the longest suffix of s whose characters all satisfy pred. *)
fun stripSuffix pred s =
    let
      (* Length of the prefix that remains once the suffix is dropped. *)
      fun keep 0 = 0
        | keep n = if pred (String.sub (s, n - 1)) then keep (n - 1) else n
    in
      String.substring (s, 0, keep (size s))
    end;

(* Return s if it is not in avoid.  Otherwise strip any trailing digits and
   append the first counter 0, 1, 2, ... that yields a name outside avoid. *)
fun variant avoid s =
    if StringSet.member s avoid then
      let
        val stem = stripSuffix Char.isDigit s

        fun attempt i =
            let
              val candidate = stem ^ Int.toString i
            in
              if StringSet.member candidate avoid then attempt (i + 1)
              else candidate
            end
      in
        attempt 0
      end
    else s;

(* ------------------------------------------------------------------------- *)
(* Mapping to legal TPTP names.                                              *)
(* ------------------------------------------------------------------------- *)

local
  (* Apply p to the head and tail of a non-empty list; [] fails. *)
  fun headTail _ [] = false
    | headTail p (c :: cs) = p (c, cs);

  (* Does any predicate in preds accept x? *)
  fun anyOf preds x = List.exists (fn p => p x) preds;

  (* Characters allowed in a TPTP word after its first character. *)
  fun isTptpChar c = c = #"_" orelse Char.isAlphaNum c;

  (* s is non-empty and its (head, tail) matches one of the given forms. *)
  fun isTptpName forms s = headTail (anyOf forms) (explode s);

  (* A lowercase letter followed by letters, digits or underscores. *)
  fun isRegular (c, cs) = Char.isLower c andalso List.all isTptpChar cs;

  (* One or more decimal digits. *)
  fun isNumber (c, cs) = List.all Char.isDigit (c :: cs);

  (* A dollar sign followed by a regular word. *)
  fun isDefined (c, cs) = c = #"$" andalso headTail isRegular cs;

  (* Two dollar signs followed by a regular word. *)
  fun isSystem (c, cs) = c = #"$" andalso headTail isDefined cs;
in
  (* Turn an arbitrary name into a TPTP variable name: drop illegal
     characters, then make the first character uppercase, prefixing X when
     it is not a letter (or when nothing is left). *)
  fun mkTptpVarName s =
      let
        fun fixHead [] = [#"X"]
          | fixHead (cs as c :: rest) =
            if Char.isUpper c then cs
            else if Char.isLower c then Char.toUpper c :: rest
            else #"X" :: cs
      in
        implode (fixHead (List.filter isTptpChar (explode s)))
      end;

  val isTptpConstName = isTptpName [isRegular, isNumber, isDefined, isSystem];

  val isTptpFnName = isTptpName [isRegular, isDefined, isSystem];

  val isTptpPropName = isTptpFnName;

  val isTptpRelName = isTptpFnName;

  val isTptpFormulaName = isTptpName [isRegular, isNumber];
end;

(* ------------------------------------------------------------------------- *)
(* Mapping to legal TPTP variable names.                                     *)
(* ------------------------------------------------------------------------- *)

(* Variable renaming state: the set of TPTP names already in use, and the map
   from each Metis variable name to its chosen TPTP name. *)
datatype varToTptp = VarToTptp of StringSet.set * string NameMap.map;

val emptyVarToTptp = VarToTptp (StringSet.empty, NameMap.new ());

(* Give variable v a TPTP name unless it already has one.  The new name is
   derived from v's string form and made distinct from every name in use. *)
fun addVarToTptp (vm as VarToTptp (avoid, mapping)) v =
    if NameMap.inDomain v mapping then vm
    else
      let
        val tptpName = variant avoid (mkTptpVarName (Name.toString v))
      in
        VarToTptp
          (StringSet.add avoid tptpName, NameMap.insert mapping (v, tptpName))
      end;

(* Bulk versions of addVarToTptp: fold a list or a set of variables into an
   existing mapping, or into the empty mapping. *)
local
  fun step (v,acc) = addVarToTptp acc v;
in
  fun addListVarToTptp vm vs = List.foldl step vm vs;

  fun addSetVarToTptp vm vs = NameSet.foldl step vm vs;
end;

fun fromListVarToTptp vs = addListVarToTptp emptyVarToTptp vs;

fun fromSetVarToTptp vs = addSetVarToTptp emptyVarToTptp vs;

(* The TPTP name previously assigned to variable v.  Raises Bug when v was
   never added to the mapping. *)
fun getVarToTptp (VarToTptp (_,assigned)) v =
    (case NameMap.peek assigned v of
       NONE => raise Bug "Tptp.getVarToTptp: unknown var"
     | SOME s => s);

(* ------------------------------------------------------------------------- *)
(* Mapping from TPTP variable names.                                         *)
(* ------------------------------------------------------------------------- *)

(* A TPTP variable name is used directly as the Metis variable name. *)
val getVarFromTptp = Name.fromString;

(* ------------------------------------------------------------------------- *)
(* Mapping to TPTP function and relation names.                              *)
(* ------------------------------------------------------------------------- *)

(* Explicit translations of (name, arity) pairs to TPTP strings. *)
datatype nameToTptp = NameToTptp of string NameArityMap.map;

(* Build the translation table from a list of {name, arity, tptp} records. *)
fun mkNameToTptp entries =
    let
      fun insertEntry ({name,arity,tptp}, table) =
          NameArityMap.insert table ((name,arity),tptp)

      val table : string NameArityMap.map =
          List.foldl insertEntry (NameArityMap.new ()) entries
    in
      NameToTptp table
    end;

(* Wrap s in single quotes.  Backslash, single quote, newline and tab are
   written as the two-character escapes \\, \', \n and \t. *)
local
  fun escape #"\\" = "\\\\"
    | escape #"'" = "\\'"
    | escape #"\n" = "\\n"
    | escape #"\t" = "\\t"
    | escape c = String.str c;
in
  fun singleQuote s = String.concat ["'", String.translate escape s, "'"];
end;

(* Print s verbatim when isTptp accepts it, and quoted otherwise. *)
fun getNameToTptp isTptp s =
    if not (isTptp s) then singleQuote s else s;

(* Translate a (name, arity) pair to TPTP.  An explicit entry in the mapping
   wins.  Otherwise the name's string is used, checked with isZeroTptp for
   arity 0 and with isPlusTptp for any other arity, and quoted if that
   check fails. *)
fun getNameArityToTptp isZeroTptp isPlusTptp (NameToTptp mapping) (na as (n,a)) =
    case NameArityMap.peek mapping na of
      SOME s => s
    | NONE =>
      getNameToTptp (if a = 0 then isZeroTptp else isPlusTptp) (Name.toString n);

(* ------------------------------------------------------------------------- *)
(* Mapping from TPTP function and relation names.                            *)
(* ------------------------------------------------------------------------- *)

(* Explicit translations of (TPTP string, arity) pairs back to names. *)
datatype nameFromTptp = NameFromTptp of (string * int, Name.name) Map.map;

(* Build the inverse translation table from a list of {name, arity, tptp}
   records, keyed on (tptp, arity). *)
local
  val emptyTable : (string * int, Name.name) Map.map =
      Map.new (prodCompare String.compare Int.compare);

  fun insertEntry ({name,arity,tptp}, table) =
      Map.insert table ((tptp,arity),name);
in
  fun mkNameFromTptp entries =
      NameFromTptp (List.foldl insertEntry emptyTable entries);
end;

(* Translate a (TPTP string, arity) pair back to a name.  An explicit entry
   wins; otherwise the string itself becomes the name. *)
fun getNameFromTptp (NameFromTptp mapping) (sa as (s,_)) =
    (case Map.peek mapping sa of
       NONE => Name.fromString s
     | SOME n => n);

(* ------------------------------------------------------------------------- *)
(* Mapping to and from TPTP variable, function and relation names.           *)
(* ------------------------------------------------------------------------- *)

(* Translations in both directions for variables, functions and relations. *)
datatype mapping =
    Mapping of
      {varTo : varToTptp,
       fnTo : nameToTptp,
       relTo : nameToTptp,
       fnFrom : nameFromTptp,
       relFrom : nameFromTptp};

(* Build a mapping from explicit function and relation translation lists.
   The variable part starts out empty. *)
fun mkMapping {functionMapping,relationMapping} =
    Mapping
      {varTo = emptyVarToTptp,
       fnTo = mkNameToTptp functionMapping,
       relTo = mkNameToTptp relationMapping,
       fnFrom = mkNameFromTptp functionMapping,
       relFrom = mkNameFromTptp relationMapping};

(* Assign TPTP names to the given variables (a list, or a set) in the
   variable part of a mapping.  The function and relation parts are left
   unchanged. *)
local
  fun updateVarTo f (Mapping {varTo,fnTo,relTo,fnFrom,relFrom}) =
      Mapping
        {varTo = f varTo,
         fnTo = fnTo,
         relTo = relTo,
         fnFrom = fnFrom,
         relFrom = relFrom};
in
  fun addVarListMapping mapping vs =
      updateVarTo (fn vt => addListVarToTptp vt vs) mapping;

  fun addVarSetMapping mapping vs =
      updateVarTo (fn vt => addSetVarToTptp vt vs) mapping;
end;

(* Translate a variable to TPTP; raises Bug if it has no assigned name. *)
fun varToTptp (Mapping {varTo,...}) v = getVarToTptp varTo v;

(* Translate a (function, arity) pair to TPTP.  Constants and proper
   functions are checked against different legal-name predicates. *)
fun fnToTptp (Mapping {fnTo,...}) fa =
    getNameArityToTptp isTptpConstName isTptpFnName fnTo fa;

(* Translate a (relation, arity) pair to TPTP. *)
fun relToTptp (Mapping {relTo,...}) ra =
    getNameArityToTptp isTptpPropName isTptpRelName relTo ra;

(* Variables are read back directly; the mapping is not consulted. *)
fun varFromTptp (_ : mapping) v = getVarFromTptp v;

fun fnFromTptp (Mapping {fnFrom,...}) fa = getNameFromTptp fnFrom fa;

fun relFromTptp (Mapping {relFrom,...}) ra = getNameFromTptp relFrom ra;

(* The mapping built from the default function and relation translation
   tables, with their names converted from strings to Name.name. *)
val defaultMapping =
    let
      fun liftEntry {name,arity,tptp} =
          {name = Name.fromString name, arity = arity, tptp = tptp}

      fun liftAll entries = map liftEntry entries
    in
      mkMapping
        {functionMapping = liftAll defaultFunctionMapping,
         relationMapping = liftAll defaultRelationMapping}
    end;

(* ------------------------------------------------------------------------- *)
(* Interpreting TPTP functions and relations in a finite model.              *)
(* ------------------------------------------------------------------------- *)

(* Build fixed function and relation interpretations from lists of
   {name, arity, model} records, keyed on (name, arity). *)
fun mkFixedMap funcModel relModel =
    let
      fun toMap entries =
          NameArityMap.fromList
            (map (fn {name,arity,model} => ((Name.fromString name, arity), model))
               entries)
    in
      {functionMap = toMap funcModel,
       relationMap = toMap relModel}
    end;

val defaultFixedMap = mkFixedMap defaultFunctionModel defaultRelationModel;

(* Model.default with the TPTP default interpretations added to its fixed
   part via Model.mapFixed. *)
val defaultModel =
    let
      val {size = n, fixed = fixedPart} = Model.default
    in
      {size = n, fixed = Model.mapFixed defaultFixedMap fixedPart}
    end;

(* Pretty-print fixed interpretations.  Every (name, arity) key is first
   renamed to its TPTP spelling under the given mapping. *)
local
  fun renameKeys toTptp m =
      let
        fun rekey ((n,arity),value,acc) =
            NameArityMap.insert acc
              ((Name.fromString (toTptp (n,arity)), arity), value)
      in
        NameArityMap.foldl rekey (NameArityMap.new ()) m
      end;

  fun toTptpFixedMap mapping {functionMap, relationMap} =
      {functionMap = renameKeys (fnToTptp mapping) functionMap,
       relationMap = renameKeys (relToTptp mapping) relationMap};
in
  fun ppFixedMap mapping fixMap =
      Model.ppFixedMap (toTptpFixedMap mapping fixMap);
end;

(* ------------------------------------------------------------------------- *)
(* TPTP roles.                                                               *)
(* ------------------------------------------------------------------------- *)

datatype role =
    AxiomRole
  | ConjectureRole
  | DefinitionRole
  | NegatedConjectureRole
  | PlainRole
  | TheoremRole
  | OtherRole of string;

(* In CNF problems only negated_conjecture marks a conjecture. *)
fun isCnfConjectureRole NegatedConjectureRole = true
  | isCnfConjectureRole _ = false;

(* In FOF problems only conjecture marks a conjecture. *)
fun isFofConjectureRole ConjectureRole = true
  | isFofConjectureRole _ = false;

(* The TPTP keyword of a role; OtherRole carries its own spelling. *)
fun toStringRole AxiomRole = "axiom"
  | toStringRole ConjectureRole = "conjecture"
  | toStringRole DefinitionRole = "definition"
  | toStringRole NegatedConjectureRole = "negated_conjecture"
  | toStringRole PlainRole = "plain"
  | toStringRole TheoremRole = "theorem"
  | toStringRole (OtherRole s) = s;

(* Read a role keyword back.  Unrecognized keywords become OtherRole. *)
local
  val knownRoles =
      [AxiomRole, ConjectureRole, DefinitionRole, NegatedConjectureRole,
       PlainRole, TheoremRole];
in
  fun fromStringRole s =
      case List.find (fn r => toStringRole r = s) knownRoles of
        SOME r => r
      | NONE => OtherRole s;
end;

(* Pretty-print a role as the keyword given by toStringRole. *)
val ppRole = Print.ppMap toStringRole Print.ppString;

(* ------------------------------------------------------------------------- *)
(* SZS statuses.                                                             *)
(* ------------------------------------------------------------------------- *)

datatype status =
    CounterSatisfiableStatus
  | TheoremStatus
  | SatisfiableStatus
  | UnknownStatus
  | UnsatisfiableStatus;

(* The name printed for each SZS status. *)
fun toStringStatus CounterSatisfiableStatus = "CounterSatisfiable"
  | toStringStatus TheoremStatus = "Theorem"
  | toStringStatus SatisfiableStatus = "Satisfiable"
  | toStringStatus UnknownStatus = "Unknown"
  | toStringStatus UnsatisfiableStatus = "Unsatisfiable";

(* Pretty-print a status as the name given by toStringStatus. *)
val ppStatus = Print.ppMap toStringStatus Print.ppString;

(* ------------------------------------------------------------------------- *)
(* TPTP literals.                                                            *)
(* ------------------------------------------------------------------------- *)

(* A TPTP literal is either a Boolean constant or an ordinary literal. *)
datatype literal =
    Boolean of bool
  | Literal of Literal.literal;

(* The ordinary literal inside lit; raises Error on a Boolean constant. *)
fun destLiteral (Literal l) = l
  | destLiteral (Boolean _) = raise Error "Tptp.destLiteral";

fun isBooleanLiteral (Boolean _) = true
  | isBooleanLiteral (Literal _) = false;

(* True exactly when lit is the Boolean constant b. *)
fun equalBooleanLiteral b (Boolean b') = b = b'
  | equalBooleanLiteral _ (Literal _) = false;

fun negateLiteral lit =
    case lit of
      Boolean b => Boolean (not b)
    | Literal l => Literal (Literal.negate l);

fun functionsLiteral lit =
    case lit of
      Boolean _ => NameAritySet.empty
    | Literal l => Literal.functions l;

(* The relation of an ordinary literal; Boolean constants have none. *)
fun relationLiteral lit =
    case lit of
      Boolean _ => NONE
    | Literal l => SOME (Literal.relation l);

(* Boolean constants map to Formula.True and Formula.False. *)
fun literalToFormula lit =
    case lit of
      Boolean b => if b then Formula.True else Formula.False
    | Literal l => Literal.toFormula l;

fun literalFromFormula fm =
    case fm of
      Formula.True => Boolean true
    | Formula.False => Boolean false
    | _ => Literal (Literal.fromFormula fm);

fun freeVarsLiteral lit =
    case lit of
      Boolean _ => NameSet.empty
    | Literal l => Literal.freeVars l;

(* Substitutions leave Boolean constants untouched. *)
fun literalSubst sub lit =
    case lit of
      Literal l => Literal (Literal.subst sub l)
    | Boolean _ => lit;

(* ------------------------------------------------------------------------- *)
(* Printing formulas using TPTP syntax.                                      *)
(* ------------------------------------------------------------------------- *)

(* Print a variable under its TPTP name in the mapping. *)
fun ppVar mapping v = Print.addString (varToTptp mapping v);

(* Print a function symbol of a given arity under its TPTP name. *)
fun ppFnName mapping fa = Print.addString (fnToTptp mapping fa);

(* A constant is a function symbol of arity 0. *)
fun ppConst mapping c = ppFnName mapping (c,0);

(* Print a term in TPTP syntax.  Variables and constants appear bare, and an
   application of f to one or more arguments appears as f(t1,...,tn).  The
   spelling of f is chosen by the mapping according to its arity. *)
fun ppTerm mapping =
    let
      fun pp tm =
          case tm of
            Term.Var v => ppVar mapping v
          | Term.Fn (f,[]) => ppConst mapping f
          | Term.Fn (f,args) =>
            Print.blockProgram Print.Inconsistent 2
              [ppFnName mapping (f, length args),
               Print.addString "(",
               Print.ppOpList "," pp args,
               Print.addString ")"]
    in
      Print.block Print.Inconsistent 0 o pp
    end;

(* Print a relation symbol of a given arity under its TPTP name. *)
fun ppRelName mapping ra = Print.addString (relToTptp mapping ra);

(* A proposition is a relation symbol of arity 0. *)
fun ppProp mapping p = ppRelName mapping (p,0);

(* Print an atom: a bare proposition, or r(t1,...,tn). *)
fun ppAtom mapping (r,[]) = ppProp mapping r
  | ppAtom mapping (r,tms) =
    Print.blockProgram Print.Inconsistent 2
      [ppRelName mapping (r, length tms),
       Print.addString "(",
       Print.ppOpList "," (ppTerm mapping) tms,
       Print.addString ")"];

(* Pretty-printing of first-order formulas in TPTP FOF syntax.
   - "fof" handles the binary connectives.
   - "assoc_binary" and "nonassoc_binary" print operator lists.
   - "unitary" handles everything that can be an operand.
   Every operand is printed with unitary.  Its last case wraps any remaining
   formula (a binary connective) in parentheses, so nested connectives come
   out parenthesized. *)
local
  (* One negation sign followed by a break. *)
  val neg = Print.sequence (Print.addString "~") (Print.addBreak 1);

  (* Conjunctions and disjunctions are flattened into operator lists;
     implications and equivalences print their two operands. *)
  fun fof mapping fm =
      case fm of
        Formula.And _ => assoc_binary mapping ("&", Formula.stripConj fm)
      | Formula.Or _ => assoc_binary mapping ("|", Formula.stripDisj fm)
      | Formula.Imp a_b => nonassoc_binary mapping ("=>",a_b)
      | Formula.Iff a_b => nonassoc_binary mapping ("<=>",a_b)
      | _ => unitary mapping fm

  (* Both operands are unitary, so a nested => or <=> gets parentheses. *)
  and nonassoc_binary mapping (s,a_b) =
      Print.ppOp2 (" " ^ s) (unitary mapping) (unitary mapping) a_b

  and assoc_binary mapping (s,l) = Print.ppOpList (" " ^ s) (unitary mapping) l

  (* Constants, quantifier blocks, negations (a negated equation prints as
     "a != b"), atoms (an equation prints as "a = b"), and a parenthesized
     fallback for everything else. *)
  and unitary mapping fm =
      case fm of
        Formula.True => Print.addString "$true"
      | Formula.False => Print.addString "$false"
      | Formula.Forall _ => quantified mapping ("!", Formula.stripForall fm)
      | Formula.Exists _ => quantified mapping ("?", Formula.stripExists fm)
      | Formula.Not _ =>
        (case total Formula.destNeq fm of
           SOME a_b => Print.ppOp2 " !=" (ppTerm mapping) (ppTerm mapping) a_b
         | NONE =>
           let
             (* n is used as the number of negation signs to print before
                the remaining formula (presumably the count of leading
                Nots that stripNeg removed). *)
             val (n,fm) = Formula.stripNeg fm
           in
             Print.blockProgram Print.Inconsistent 2
               [Print.duplicate n neg,
                unitary mapping fm]
           end)
      | Formula.Atom atm =>
        (case total Formula.destEq fm of
           SOME a_b => Print.ppOp2 " =" (ppTerm mapping) (ppTerm mapping) a_b
         | NONE => ppAtom mapping atm)
      | _ =>
        Print.blockProgram Print.Inconsistent 1
          [Print.addString "(",
           fof mapping fm,
           Print.addString ")"]

  (* A block of like quantifiers, "q [X,Y] : body".  The bound variables
     are first added to the variable mapping so that ppVar can name them. *)
  and quantified mapping (q,(vs,fm)) =
      let
        val mapping = addVarListMapping mapping vs
      in
        Print.blockProgram Print.Inconsistent 2
          [Print.addString q,
           Print.addString " ",
           Print.blockProgram Print.Inconsistent (String.size q)
             [Print.addString "[",
              Print.ppOpList "," (ppVar mapping) vs,
              Print.addString "] :"],
           Print.addBreak 1,
           unitary mapping fm]
      end;
in
  fun ppFof mapping fm = Print.block Print.Inconsistent 0 (fof mapping fm);
end;

(* ------------------------------------------------------------------------- *)
(* Lexing TPTP files.                                                        *)
(* ------------------------------------------------------------------------- *)

(* Lexical tokens of TPTP input. *)
datatype token =
    AlphaNum of string
  | Punct of char
  | Quote of string;

(* Characters of an alphanumeric token: letters, digits and underscore. *)
fun isAlphaNum c = c = #"_" orelse Char.isAlphaNum c;

(* Lexer for TPTP input.  Each token may be surrounded by white space and
   is one of:
   - AlphaNum: one or more letters, digits or underscores;
   - Punct: a single punctuation character;
   - Quote: the contents of a single-quoted string. *)
local
  open Parse;

  infixr 9 >>++
  infixr 8 ++
  infixr 7 >>
  infixr 6 ||

  val alphaNumToken = atLeastOne (some isAlphaNum) >> (AlphaNum o implode);

  val punctToken =
      let
        val punctChars = "<>=-*+/\\?@|!$%&#^:;~()[]{}.,"
      in
        some (Char.contains punctChars) >> Punct
      end;

  (* A single-quoted string.  Inside the quotes, a backslash escapes a single
     quote or another backslash.  An unescaped quote closes the string, and
     a newline may not occur in it.  The escaping backslash is dropped, so
     the token holds the unescaped text and reading back the output of
     singleQuote recovers the original name for these two characters.
     (singleQuote's \n and \t escapes are not recognized here.) *)
  val quoteToken =
      let
        val escapeParser =
            some (equal #"'") >> singleton ||
            some (equal #"\\") >> singleton

        fun stopOn #"'" = true
          | stopOn #"\n" = true
          | stopOn _ = false

        (* Keep only the escaped character: the backslash is syntax, not
           part of the quoted text. *)
        val quotedParser =
            some (equal #"\\") ++ escapeParser >> snd ||
            some (not o stopOn) >> singleton
      in
        exactChar #"'" ++ many quotedParser ++ exactChar #"'" >>
        (fn (_,(l,_)) => Quote (implode (List.concat l)))
      end;

  val lexToken = alphaNumToken || punctToken || quoteToken;

  val space = many (some Char.isSpace) >> K ();
in
  (* One token, with the white space around it discarded. *)
  val lexer = (space ++ lexToken ++ space) >> (fn ((),(tok,())) => tok);
end;

(* ------------------------------------------------------------------------- *)
(* TPTP clauses.                                                             *)
(* ------------------------------------------------------------------------- *)

(* A TPTP clause is a list of literals, read as their disjunction. *)
type clause = literal list;

(* Function symbols occurring anywhere in a clause. *)
fun clauseFunctions lits =
    foldl (fn (lit,acc) => NameAritySet.union (functionsLiteral lit) acc)
      NameAritySet.empty lits;

(* Relation symbols of the non-Boolean literals of a clause. *)
fun clauseRelations lits =
    let
      fun addRel (lit,acc) =
          case relationLiteral lit of
            SOME r => NameAritySet.add acc r
          | NONE => acc
    in
      foldl addRel NameAritySet.empty lits
    end;

(* Free variables of all literals of a clause. *)
fun clauseFreeVars lits =
    foldl (fn (lit,acc) => NameSet.union (freeVarsLiteral lit) acc)
      NameSet.empty lits;

(* Apply a substitution to every literal of a clause. *)
fun clauseSubst sub lits = List.map (fn lit => literalSubst sub lit) lits;

(* A clause as a formula: the disjunction of its literals. *)
fun clauseToFormula lits = Formula.listMkDisj (List.map literalToFormula lits);

(* Split a formula into its disjuncts and read each back as a literal. *)
fun clauseFromFormula fm =
    List.map literalFromFormula (Formula.stripDisj fm);

(* Convert a Metis literal set by way of the disjunction of its literals,
   so each disjunct is read back through literalFromFormula. *)
fun clauseFromLiteralSet cl =
    let
      val disjuncts = LiteralSet.transform Literal.toFormula cl
    in
      clauseFromFormula (Formula.listMkDisj disjuncts)
    end;

fun clauseFromThm th = clauseFromLiteralSet (Thm.clause th);

(* Print a clause as the FOF rendering of its disjunction. *)
fun ppClause mapping = Print.ppMap clauseToFormula (ppFof mapping);

(* ------------------------------------------------------------------------- *)
(* TPTP formula names.                                                       *)
(* ------------------------------------------------------------------------- *)

(* Names of TPTP annotated formulas, and sets of them. *)
datatype formulaName = FormulaName of string;

datatype formulaNameSet = FormulaNameSet of formulaName Set.set;

(* Formula names are ordered by their underlying strings. *)
fun compareFormulaName (FormulaName a, FormulaName b) = String.compare (a,b);

(* The TPTP spelling of a formula name, quoted when it is not a legal bare
   formula name. *)
fun toTptpFormulaName (FormulaName s) = getNameToTptp isTptpFormulaName s;

val ppFormulaName = Print.ppMap toTptpFormulaName Print.ppString;

val emptyFormulaNameSet = FormulaNameSet (Set.empty compareFormulaName);

fun memberFormulaNameSet name (FormulaNameSet names) = Set.member name names;

fun addFormulaNameSet (FormulaNameSet names) name =
    FormulaNameSet (Set.add names name);

fun addListFormulaNameSet (FormulaNameSet names) l =
    FormulaNameSet (Set.addList names l);

(* ------------------------------------------------------------------------- *)
(* TPTP formula bodies.                                                      *)
(* ------------------------------------------------------------------------- *)

(* The body of an annotated formula: a CNF clause or a FOF formula. *)
datatype formulaBody =
    CnfFormulaBody of literal list
  | FofFormulaBody of Formula.formula;

(* The clause of a CNF body; raises Error on a FOF body. *)
fun destCnfFormulaBody (CnfFormulaBody x) = x
  | destCnfFormulaBody (FofFormulaBody _) = raise Error "destCnfFormulaBody";

val isCnfFormulaBody = can destCnfFormulaBody;

(* The formula of a FOF body; raises Error on a CNF body. *)
fun destFofFormulaBody (FofFormulaBody x) = x
  | destFofFormulaBody (CnfFormulaBody _) = raise Error "destFofFormulaBody";

val isFofFormulaBody = can destFofFormulaBody;

fun formulaBodyFunctions (CnfFormulaBody cl) = clauseFunctions cl
  | formulaBodyFunctions (FofFormulaBody fm) = Formula.functions fm;

fun formulaBodyRelations (CnfFormulaBody cl) = clauseRelations cl
  | formulaBodyRelations (FofFormulaBody fm) = Formula.relations fm;

fun formulaBodyFreeVars (CnfFormulaBody cl) = clauseFreeVars cl
  | formulaBodyFreeVars (FofFormulaBody fm) = Formula.freeVars fm;

(* FOF bodies are passed through Formula.generalize before printing. *)
fun ppFormulaBody mapping (CnfFormulaBody cl) = ppClause mapping cl
  | ppFormulaBody mapping (FofFormulaBody fm) =
    ppFof mapping (Formula.generalize fm);

(* ------------------------------------------------------------------------- *)
(* TPTP formula sources.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Where an annotated formula came from: nothing recorded, a named strip
   step, a normalization step, or a kernel proof step, each with its
   parent formula names. *)
datatype formulaSource =
    NoFormulaSource
  | StripFormulaSource of
      {inference : string,
       parents : formulaName list}
  | NormalizeFormulaSource of
      {inference : Normalize.inference,
       parents : formulaName list}
  | ProofFormulaSource of
      {inference : Proof.inference,
       parents : formulaName list};

fun isNoFormulaSource NoFormulaSource = true
  | isNoFormulaSource _ = false;

(* Function symbols mentioned by the inference recorded in a source.  Empty
   and strip sources mention none. *)
fun functionsFormulaSource source =
    let
      fun fromNormalize (Normalize.Axiom fm) = Formula.functions fm
        | fromNormalize (Normalize.Definition (_,fm)) = Formula.functions fm
        | fromNormalize _ = NameAritySet.empty

      fun fromProof inf =
          case inf of
            Proof.Axiom cl => LiteralSet.functions cl
          | Proof.Assume atm => Atom.functions atm
          | Proof.Subst (sub,_) => Subst.functions sub
          | Proof.Resolve (atm,_,_) => Atom.functions atm
          | Proof.Refl tm => Term.functions tm
          | Proof.Equality (lit,_,tm) =>
            NameAritySet.union (Literal.functions lit) (Term.functions tm)
    in
      case source of
        NormalizeFormulaSource {inference, ...} => fromNormalize inference
      | ProofFormulaSource {inference, ...} => fromProof inference
      | _ => NameAritySet.empty
    end;

(* Relation symbols mentioned by the inference recorded in a source.  Empty
   and strip sources, substitution steps and reflexivity steps contribute
   none. *)
fun relationsFormulaSource source =
    let
      fun fromNormalize (Normalize.Axiom fm) = Formula.relations fm
        | fromNormalize (Normalize.Definition (_,fm)) = Formula.relations fm
        | fromNormalize _ = NameAritySet.empty

      fun fromProof inf =
          case inf of
            Proof.Axiom cl => LiteralSet.relations cl
          | Proof.Assume atm => NameAritySet.singleton (Atom.relation atm)
          | Proof.Subst _ => NameAritySet.empty
          | Proof.Resolve (atm,_,_) => NameAritySet.singleton (Atom.relation atm)
          | Proof.Refl _ => NameAritySet.empty
          | Proof.Equality (lit,_,_) =>
            NameAritySet.singleton (Literal.relation lit)
    in
      case source of
        NormalizeFormulaSource {inference, ...} => fromNormalize inference
      | ProofFormulaSource {inference, ...} => fromProof inference
      | _ => NameAritySet.empty
    end;

(* Free variables of the inference recorded in a source; only proof sources
   contribute any. *)
fun freeVarsFormulaSource (ProofFormulaSource {inference, ...}) =
    (case inference of
       Proof.Axiom cl => LiteralSet.freeVars cl
     | Proof.Assume atm => Atom.freeVars atm
     | Proof.Subst (sub,_) => Subst.freeVars sub
     | Proof.Resolve (atm,_,_) => Atom.freeVars atm
     | Proof.Refl tm => Term.freeVars tm
     | Proof.Equality (lit,_,tm) =>
       NameSet.union (Literal.freeVars lit) (Term.freeVars tm))
  | freeVarsFormulaSource _ = NameSet.empty;

(* Printing of formula sources as TPTP source annotations, in one of the
   forms
     inference(name, [info], parents)
     introduced(tautology, [name, [info]])
   The second form is used for proof steps that have no parents. *)
local
  (* The two TPTP source keywords used below. *)
  val GEN_INFERENCE = "inference"
  and GEN_INTRODUCED = "introduced";

  (* A strip source records its inference name directly as a string. *)
  fun nameStrip inf = inf;

  (* Strip sources print no extra inference information. *)
  fun ppStrip mapping inf = Print.skip;

  (* Inference names for normalization steps; axioms and definitions are
     both reported as "canonicalize". *)
  fun nameNormalize inf =
      case inf of
        Normalize.Axiom _ => "canonicalize"
      | Normalize.Definition _ => "canonicalize"
      | Normalize.Simplify _ => "simplify"
      | Normalize.Conjunct _ => "conjunct"
      | Normalize.Specialize _ => "specialize"
      | Normalize.Skolemize _ => "skolemize"
      | Normalize.Clausify _ => "clausify";

  (* Normalization steps print no extra inference information. *)
  fun ppNormalize mapping inf = Print.skip;

  (* Inference names for kernel proof steps; axioms are reported as
     "canonicalize". *)
  fun nameProof inf =
      case inf of
        Proof.Axiom _ => "canonicalize"
      | Proof.Assume _ => "assume"
      | Proof.Subst _ => "subst"
      | Proof.Resolve _ => "resolve"
      | Proof.Refl _ => "refl"
      | Proof.Equality _ => "equality";

  (* Terms inside inference information are wrapped as $fot(...); atoms and
     literals as $cnf(...).  Equations are printed in prefix form as
     $equal(a,b), and negative literals get a "~ " prefix. *)
  local
    fun ppTermInf mapping = ppTerm mapping;

    fun ppAtomInf mapping atm =
        case total Atom.destEq atm of
          SOME (a,b) => ppAtom mapping (Name.fromString "$equal", [a,b])
        | NONE => ppAtom mapping atm;

    fun ppLiteralInf mapping (pol,atm) =
        Print.sequence
          (if pol then Print.skip else Print.addString "~ ")
          (ppAtomInf mapping atm);
  in
    fun ppProofTerm mapping =
        Print.ppBracket "$fot(" ")" (ppTermInf mapping);

    fun ppProofAtom mapping =
        Print.ppBracket "$cnf(" ")" (ppAtomInf mapping);

    fun ppProofLiteral mapping =
        Print.ppBracket "$cnf(" ")" (ppLiteralInf mapping);
  end;

  val ppProofVar = ppVar;

  val ppProofPath = Term.ppPath;

  (* Bracketed extra information for a proof step:
     - Assume and Resolve steps print their atom;
     - Refl steps print their term;
     - Equality steps print their literal, path and term.
     Axiom and Subst steps print just the empty brackets. *)
  fun ppProof mapping inf =
      Print.blockProgram Print.Inconsistent 1
        [Print.addString "[",
         (case inf of
            Proof.Axiom _ => Print.skip
          | Proof.Assume atm => ppProofAtom mapping atm
          | Proof.Subst _ => Print.skip
          | Proof.Resolve (atm,_,_) => ppProofAtom mapping atm
          | Proof.Refl tm => ppProofTerm mapping tm
          | Proof.Equality (lit,path,tm) =>
            Print.program
              [ppProofLiteral mapping lit,
               Print.addString ",",
               Print.addBreak 1,
               ppProofPath path,
               Print.addString ",",
               Print.addBreak 1,
               ppProofTerm mapping tm]),
         Print.addString "]"];

  val ppParent = ppFormulaName;

  (* A substitution, printed as a list of bind(Var,$fot(term)) entries. *)
  fun ppProofSubst mapping =
      Print.ppMap Subst.toList
        (Print.ppList
           (Print.ppBracket "bind(" ")"
              (Print.ppOp2 "," (ppProofVar mapping)
                 (ppProofTerm mapping))));

  (* A parent name, followed by " : <substitution>" when the substitution
     is non-empty. *)
  fun ppProofParent mapping (p,s) =
      if Subst.null s then ppParent p
      else Print.ppOp2 " :" ppParent (ppProofSubst mapping) (p,s);
in
  (* Print the source annotation of a formula; NoFormulaSource prints
     nothing.  For proof sources, every parent is paired with the step's
     substitution, which is empty unless the step is a Subst. *)
  fun ppFormulaSource mapping source =
      case source of
        NoFormulaSource => Print.skip
      | StripFormulaSource {inference,parents} =>
        let
          val gen = GEN_INFERENCE

          val name = nameStrip inference
        in
          Print.blockProgram Print.Inconsistent (size gen + 1)
            [Print.addString gen,
             Print.addString "(",
             Print.addString name,
             Print.addString ",",
             Print.addBreak 1,
             Print.ppBracket "[" "]" (ppStrip mapping) inference,
             Print.addString ",",
             Print.addBreak 1,
             Print.ppList ppParent parents,
             Print.addString ")"]
        end
      | NormalizeFormulaSource {inference,parents} =>
        let
          val gen = GEN_INFERENCE

          val name = nameNormalize inference
        in
          Print.blockProgram Print.Inconsistent (size gen + 1)
            [Print.addString gen,
             Print.addString "(",
             Print.addString name,
             Print.addString ",",
             Print.addBreak 1,
             Print.ppBracket "[" "]" (ppNormalize mapping) inference,
             Print.addString ",",
             Print.addBreak 1,
             Print.ppList ppParent parents,
             Print.addString ")"]
        end
      | ProofFormulaSource {inference,parents} =>
        let
          (* A step with no parents is printed as an introduced tautology. *)
          val isTaut = null parents

          val gen = if isTaut then GEN_INTRODUCED else GEN_INFERENCE

          val name = nameProof inference

          val parents =
              let
                val sub =
                    case inference of
                      Proof.Subst (s,_) => s
                    | _ => Subst.empty
              in
                map (fn parent => (parent,sub)) parents
              end
        in
          Print.blockProgram Print.Inconsistent (size gen + 1)
            ([Print.addString gen,
              Print.addString "("] @
             (if isTaut then
                [Print.addString "tautology",
                 Print.addString ",",
                 Print.addBreak 1,
                 Print.blockProgram Print.Inconsistent 1
                   [Print.addString "[",
                    Print.addString name,
                    Print.addString ",",
                    Print.addBreak 1,
                    ppProof mapping inference,
                    Print.addString "]"]]
              else
                [Print.addString name,
                 Print.addString ",",
                 Print.addBreak 1,
                 ppProof mapping inference,
                 Print.addString ",",
                 Print.addBreak 1,
                 Print.ppList (ppProofParent mapping) parents]) @
             [Print.addString ")"])
        end
end;

(* ------------------------------------------------------------------------- *)
(* TPTP formulas.                                                            *)
(* ------------------------------------------------------------------------- *)

(* A TPTP annotated formula. *)
datatype formula =
    Formula of
      {name : formulaName,
       role : role,
       body : formulaBody,
       source : formulaSource};

(* Field accessors. *)
fun nameFormula (Formula fields) = #name fields;

fun roleFormula (Formula fields) = #role fields;

fun bodyFormula (Formula fields) = #body fields;

fun sourceFormula (Formula fields) = #source fields;

(* The clause of a CNF formula, or the formula of a FOF formula.  Both
   raise Error when the body is of the other kind. *)
fun destCnfFormula (Formula {body,...}) = destCnfFormulaBody body;

val isCnfFormula = can destCnfFormula;

fun destFofFormula (Formula {body,...}) = destFofFormulaBody body;

val isFofFormula = can destFofFormula;

(* Symbols and free variables of a formula: those of its body together with
   those of its source annotation. *)
fun functionsFormula (Formula {body,source,...}) =
    NameAritySet.union (formulaBodyFunctions body) (functionsFormulaSource source);

fun relationsFormula (Formula {body,source,...}) =
    NameAritySet.union (formulaBodyRelations body) (relationsFormulaSource source);

fun freeVarsFormula (Formula {body,source,...}) =
    NameSet.union (formulaBodyFreeVars body) (freeVarsFormulaSource source);

(* Unions over a list of formulas. *)
fun freeVarsListFormula fms =
    List.foldl (fn (fm,vs) => NameSet.union vs (freeVarsFormula fm))
      NameSet.empty fms;

fun formulasFunctions fms =
    foldl (fn (fm,acc) => NameAritySet.union (functionsFormula fm) acc)
      NameAritySet.empty fms;

fun formulasRelations fms =
    foldl (fn (fm,acc) => NameAritySet.union (relationsFormula fm) acc)
      NameAritySet.empty fms;

(* A CNF formula is a conjecture when its role is negated_conjecture; a FOF
   formula is one when its role is conjecture. *)
fun isCnfConjectureFormula (Formula {role, body = CnfFormulaBody _, ...}) =
    isCnfConjectureRole role
  | isCnfConjectureFormula _ = false;

fun isFofConjectureFormula (Formula {role, body = FofFormulaBody _, ...}) =
    isFofConjectureRole role
  | isFofConjectureFormula _ = false;

fun isConjectureFormula fm =
    isCnfConjectureFormula fm orelse isFofConjectureFormula fm;

(* Parsing and pretty-printing *)

(* Print an annotated formula as
     cnf(name, role, (body), source).    or    fof(...).
   The source part is omitted when there is no source. *)
fun ppFormula mapping (Formula {name,role,body,source}) =
    let
      val keyword =
          (case body of
             CnfFormulaBody _ => "cnf"
           | FofFormulaBody _ => "fof")

      fun separator () = [Print.addString ",", Print.addBreak 1]

      val ppBody =
          Print.blockProgram Print.Consistent 1
            [Print.addString "(",
             ppFormulaBody mapping body,
             Print.addString ")"]

      val sourcePart =
          if isNoFormulaSource source then []
          else separator () @ [ppFormulaSource mapping source]
    in
      Print.blockProgram Print.Inconsistent (size keyword + 1)
        ([Print.addString keyword,
          Print.addString "(",
          ppFormulaName name] @
         separator () @ [ppRole role] @
         separator () @ [ppBody] @
         sourcePart @
         [Print.addString ")."])
    end;

(* Render an annotated formula as a string using ppFormula. *)
fun formulaToString mapping = Print.toString (ppFormula mapping);

local
  open Parse;

  infixr 9 >>++
  infixr 8 ++
  infixr 7 >>
  infixr 6 ||

  (* Token-level parsers.  someAlphaNum p accepts an AlphaNum token whose
     text satisfies p and returns that text. *)
  fun someAlphaNum p =
      maybe (fn AlphaNum s => if p s then SOME s else NONE | _ => NONE);

  (* Accept exactly the word s. *)
  fun alphaNumParser s = someAlphaNum (equal s) >> K ();

  (* Words starting with a lower-case or upper-case letter.  AlphaNum tokens
     are non-empty, so String.sub (s,0) is safe. *)
  val lowerParser = someAlphaNum (fn s => Char.isLower (String.sub (s,0)));

  val upperParser = someAlphaNum (fn s => Char.isUpper (String.sub (s,0)));

  val stringParser = lowerParser || upperParser;

  (* A word made only of decimal digits. *)
  val numberParser = someAlphaNum (List.all Char.isDigit o explode);

  (* somePunct p accepts a Punct token whose character satisfies p. *)
  fun somePunct p =
      maybe (fn Punct c => if p c then SOME c else NONE | _ => NONE);

  (* Accept exactly the punctuation character c. *)
  fun punctParser c = somePunct (equal c) >> K ();

  (* The contents of a single-quoted token. *)
  val quoteParser = maybe (fn Quote s => SOME s | _ => NONE);

  (* Parse the characters of s as consecutive punctuation tokens; for
     example, symbolParser "<=>" accepts the tokens <, = and >.  Raises Bug
     for the empty string. *)
  local
    fun chain parsers =
        case parsers of
          [] => raise Bug "symbolParser"
        | [only] => only
        | first :: rest => (first ++ chain rest) >> K ();
  in
    fun symbolParser s = chain (List.map punctParser (String.explode s));
  end;

  (* "$word": a "$" token followed by a word, returned with the "$" prefix. *)
  val definedParser =
      punctParser #"$" ++ someAlphaNum (K true) >> (fn ((),s) => "$" ^ s);

  (* "$$word", returned with the "$$" prefix. *)
  val systemParser =
      punctParser #"$" ++ punctParser #"$" ++ someAlphaNum (K true) >>
      (fn ((),((),s)) => "$$" ^ s);

  (* A formula name: a word, a number or a quoted string. *)
  val nameParser =
      (stringParser || numberParser || quoteParser) >> FormulaName;

  (* A role keyword; unknown keywords become OtherRole. *)
  val roleParser = lowerParser >> fromStringRole;

  local
    fun isProposition s = isHdTlString Char.isLower isAlphaNum s;
  in
    val propositionParser =
        someAlphaNum isProposition ||
        definedParser ||
        systemParser ||
        quoteParser;
  end;

  local
    fun isFunction s = isHdTlString Char.isLower isAlphaNum s;
  in
    val functionParser =
        someAlphaNum isFunction ||
        definedParser ||
        systemParser ||
        quoteParser;
  end;

  local
    fun isConstant s = isHdTlString Char.isLower isAlphaNum s;
  in
    val constantParser =
        someAlphaNum isConstant ||
        definedParser ||
        numberParser ||
        systemParser ||
        quoteParser;
  end;

  val varParser = upperParser;

  val varListParser =
      (punctParser #"[" ++ varParser ++
       many ((punctParser #"," ++ varParser) >> snd) ++
       punctParser #"]") >>
      (fn ((),(h,(t,()))) => h :: t);

  fun mkVarName mapping v = varFromTptp mapping v;

  fun mkVar mapping v =
      let
        val v = mkVarName mapping v
      in
        Term.Var v
      end

  fun mkFn mapping (f,tms) =
      let
        val f = fnFromTptp mapping (f, length tms)
      in
        Term.Fn (f,tms)
      end;

  fun mkConst mapping c = mkFn mapping (c,[]);

  fun mkAtom mapping (r,tms) =
      let
        val r = relFromTptp mapping (r, length tms)
      in
        (r,tms)
      end;

  fun termParser mapping input =
      let
        val fnP = functionArgumentsParser mapping >> mkFn mapping
        val nonFnP = nonFunctionArgumentsTermParser mapping
      in
        fnP || nonFnP
      end input

  and functionArgumentsParser mapping input =
      let
        val commaTmP = (punctParser #"," ++ termParser mapping) >> snd
      in
        (functionParser ++ punctParser #"(" ++ termParser mapping ++
         many commaTmP ++ punctParser #")") >>
        (fn (f,((),(t,(ts,())))) => (f, t :: ts))
      end input

  and nonFunctionArgumentsTermParser mapping input =
      let
        val varP = varParser >> mkVar mapping
        val constP = constantParser >> mkConst mapping
      in
        varP || constP
      end input;

  fun binaryAtomParser mapping tm input =
      let
        val eqP =
            (punctParser #"=" ++ termParser mapping) >>
            (fn ((),r) => (true,("$equal",[tm,r])))

        val neqP =
            (symbolParser "!=" ++ termParser mapping) >>
            (fn ((),r) => (false,("$equal",[tm,r])))
      in
        eqP || neqP
      end input;

  fun maybeBinaryAtomParser mapping (s,tms) input =
      let
        val tm = mkFn mapping (s,tms)
      in
        optional (binaryAtomParser mapping tm) >>
        (fn SOME lit => lit
          | NONE => (true,(s,tms)))
      end input;

  fun literalAtomParser mapping input =
      let
        val fnP =
            functionArgumentsParser mapping >>++
            maybeBinaryAtomParser mapping

        val nonFnP =
            nonFunctionArgumentsTermParser mapping >>++
            binaryAtomParser mapping

        val propP = propositionParser >> (fn s => (true,(s,[])))
      in
        fnP || nonFnP || propP
      end input;

  fun atomParser mapping input =
      let
        fun mk (pol,rel) =
          case rel of
            ("$true",[]) => Boolean pol
          | ("$false",[]) => Boolean (not pol)
          | ("$equal",[l,r]) => Literal (pol, Atom.mkEq (l,r))
          | (r,tms) => Literal (pol, mkAtom mapping (r,tms))
      in
        literalAtomParser mapping >> mk
      end input;

  fun literalParser mapping input =
      let
        val negP =
            (punctParser #"~" ++ atomParser mapping) >>
            (negateLiteral o snd)

        val posP = atomParser mapping
      in
        negP || posP
      end input;

  fun disjunctionParser mapping input =
      let
        val orLitP = (punctParser #"|" ++ literalParser mapping) >> snd
      in
        (literalParser mapping ++ many orLitP) >> (fn (h,t) => h :: t)
      end input;

  fun clauseParser mapping input =
      let
        val disjP = disjunctionParser mapping

        val bracketDisjP =
            (punctParser #"(" ++ disjP ++ punctParser #")") >>
            (fn ((),(c,())) => c)
      in
        bracketDisjP || disjP
      end input;

  val binaryConnectiveParser =
      (symbolParser "<=>" >> K Formula.Iff) ||
      (symbolParser "=>" >> K Formula.Imp) ||
      (symbolParser "<=" >> K (fn (f,g) => Formula.Imp (g,f))) ||
      (symbolParser "<~>" >> K (Formula.Not o Formula.Iff)) ||
      (symbolParser "~|" >> K (Formula.Not o Formula.Or)) ||
      (symbolParser "~&" >> K (Formula.Not o Formula.And));

  val quantifierParser =
      (punctParser #"!" >> K Formula.listMkForall) ||
      (punctParser #"?" >> K Formula.listMkExists);

  fun fofFormulaParser mapping input =
      let
        fun mk (f,NONE) = f
          | mk (f, SOME t) = t f
      in
        (unitaryFormulaParser mapping ++
         optional (binaryFormulaParser mapping)) >> mk
      end input

  and binaryFormulaParser mapping input =
      let
        val nonAssocP = nonAssocBinaryFormulaParser mapping

        val assocP = assocBinaryFormulaParser mapping
      in
        nonAssocP || assocP
      end input

  and nonAssocBinaryFormulaParser mapping input =
      let
        fun mk (c,g) f = c (f,g)
      in
        (binaryConnectiveParser ++ unitaryFormulaParser mapping) >> mk
      end input

  and assocBinaryFormulaParser mapping input =
      let
        val orP = orFormulaParser mapping

        val andP = andFormulaParser mapping
      in
        orP || andP
      end input

  and orFormulaParser mapping input =
      let
        val orFmP = (punctParser #"|" ++ unitaryFormulaParser mapping) >> snd
      in
        atLeastOne orFmP >>
        (fn fs => fn f => Formula.listMkDisj (f :: fs))
      end input

  and andFormulaParser mapping input =
      let
        val andFmP = (punctParser #"&" ++ unitaryFormulaParser mapping) >> snd
      in
        atLeastOne andFmP >>
        (fn fs => fn f => Formula.listMkConj (f :: fs))
      end input

  and unitaryFormulaParser mapping input =
      let
        val quantP = quantifiedFormulaParser mapping

        val unaryP = unaryFormulaParser mapping

        val brackP =
            (punctParser #"(" ++ fofFormulaParser mapping ++
             punctParser #")") >>
            (fn ((),(f,())) => f)

        val atomP =
            atomParser mapping >>
            (fn Boolean b => Formula.mkBoolean b
              | Literal l => Literal.toFormula l)
      in
        quantP ||
        unaryP ||
        brackP ||
        atomP
      end input

  and quantifiedFormulaParser mapping input =
      let
        fun mk (q,(vs,((),f))) = q (map (mkVarName mapping) vs, f)
      in
        (quantifierParser ++ varListParser ++ punctParser #":" ++
         unitaryFormulaParser mapping) >> mk
      end input

  and unaryFormulaParser mapping input =
      let
        fun mk (c,f) = c f
      in
        (unaryConnectiveParser ++ unitaryFormulaParser mapping) >> mk
      end input

  and unaryConnectiveParser input =
      (punctParser #"~" >> K Formula.Not) input;

  fun cnfParser mapping input =
      let
        fun mk ((),((),(name,((),(role,((),(cl,((),())))))))) =
            let
              val body = CnfFormulaBody cl
              val source = NoFormulaSource
            in
              Formula
                {name = name,
                 role = role,
                 body = body,
                 source = source}
            end
      in
        (alphaNumParser "cnf" ++ punctParser #"(" ++
         nameParser ++ punctParser #"," ++
         roleParser ++ punctParser #"," ++
         clauseParser mapping ++ punctParser #")" ++
         punctParser #".") >> mk
      end input;

  fun fofParser mapping input =
      let
        fun mk ((),((),(name,((),(role,((),(fm,((),())))))))) =
            let
              val body = FofFormulaBody fm
              val source = NoFormulaSource
            in
              Formula
                {name = name,
                 role = role,
                 body = body,
                 source = source}
            end
      in
        (alphaNumParser "fof" ++ punctParser #"(" ++
         nameParser ++ punctParser #"," ++
         roleParser ++ punctParser #"," ++
         fofFormulaParser mapping ++ punctParser #")" ++
         punctParser #".") >> mk
      end input;
in
  fun formulaParser mapping input =
      let
        val cnfP = cnfParser mapping

        val fofP = fofParser mapping
      in
        cnfP || fofP
      end input;
end;

(* ------------------------------------------------------------------------- *)
(* Include declarations.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Pretty-print an include directive as include('file'). *)
fun ppInclude i =
    let
      val pieces = map Print.addString ["include('", i, "')."]
    in
      Print.blockProgram Print.Inconsistent 2 pieces
    end;

(* Render an include directive as a single string. *)
fun includeToString inc = Print.toString ppInclude inc;

(* Parser for  include('file').  — yields the quoted file name. *)
local
  open Parse;

  infixr 9 >>++
  infixr 8 ++
  infixr 7 >>
  infixr 6 ||

  (* Accept exactly the given token. *)
  fun token t = some (equal t);

  val quotedNameParser = maybe (fn Quote s => SOME s | _ => NONE);
in
  val includeParser =
      (token (AlphaNum "include") ++
       token (Punct #"(") ++
       quotedNameParser ++
       token (Punct #")") ++
       token (Punct #".")) >>
      (fn (_,(_,(file,_))) => file);
end;

(* ------------------------------------------------------------------------- *)
(* Parsing TPTP files.                                                       *)
(* ------------------------------------------------------------------------- *)

(* A top-level TPTP declaration: an include directive (carrying the file
   name) or an annotated formula. *)
datatype declaration =
    IncludeDeclaration of string
  | FormulaDeclaration of formula;

(* Split declarations into (include file names, formulas), keeping each
   list in the original declaration order. *)
fun partitionDeclarations decls =
    let
      fun step (decl,(incs,fms)) =
          case decl of
            FormulaDeclaration f => (incs, f :: fms)
          | IncludeDeclaration i => (i :: incs, fms)
    in
      (* Folding from the right and consing preserves the input order. *)
      List.foldr step ([],[]) decls
    end;

(* Parse a character stream into a stream of TPTP declarations. *)
local
  open Parse;

  infixr 9 >>++
  infixr 8 ++
  infixr 7 >>
  infixr 6 ||

  (* An include directive or a cnf/fof annotated formula. *)
  fun declarationParser mapping =
      let
        val includeP = includeParser >> IncludeDeclaration
        val formulaP = formulaParser mapping >> FormulaDeclaration
      in
        includeP || formulaP
      end;

  (* Two passes over the whole input: characters to lexer tokens, then
     tokens to parser results. *)
  fun parseChars parser chars =
      let
        val lexed = Parse.everything (lexer >> singleton) chars
        val parseAll = Parse.everything (parser >> singleton)
      in
        parseAll lexed
      end;
in
  fun parseDeclaration mapping chars =
      parseChars (declarationParser mapping) chars;
end;

(* ------------------------------------------------------------------------- *)
(* Clause information.                                                       *)
(* ------------------------------------------------------------------------- *)

(* Where a clause given to the prover came from: a cnf input formula (its
   name and original literal list) or a Normalize theorem derived from a fof
   formula. *)
datatype clauseSource =
    CnfClauseSource of formulaName * literal list
  | FofClauseSource of Normalize.thm;

(* Per-clause information, keyed by the clause's literal set. *)
type 'a clauseInfo = 'a LiteralSetMap.map;

type clauseNames = formulaName clauseInfo;

type clauseRoles = role clauseInfo;

type clauseSources = clauseSource clauseInfo;

(* Empty clause-information maps. *)
val noClauseNames : clauseNames = LiteralSetMap.new ();

val noClauseRoles : clauseRoles = LiteralSetMap.new ();

val noClauseSources : clauseSources = LiteralSetMap.new ();

(* The set of every formula name recorded in a clause-name map. *)
fun allClauseNames (names : clauseNames) : formulaNameSet =
    LiteralSetMap.foldl
      (fn (_,name,acc) => addFormulaNameSet acc name)
      emptyFormulaNameSet names;

(* ------------------------------------------------------------------------- *)
(* Comments.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* Turn a line of text into a TPTP line comment: "%" for an empty line,
   otherwise "% " followed by the text. *)
fun mkLineComment line =
    if line = "" then "%" else "% " ^ line;

(* Inverse of mkLineComment on a line's characters: strip a leading "%" and
   at most one following space.  An empty line is accepted and yields "".
   Anything else raises Error. *)
fun destLineComment cs =
    case cs of
      [] => ""
    | #"%" :: rest =>
      (case rest of
         #" " :: text => implode text
       | _ => implode rest)
    | _ => raise Error "Tptp.destLineComment";

fun isLineComment cs = can destLineComment cs;

(* ------------------------------------------------------------------------- *)
(* TPTP problems.                                                            *)
(* ------------------------------------------------------------------------- *)

(* Comment lines of a problem, stored without their "%" marker. *)
type comments = string list;

(* File names from include directives. *)
type includes = string list;

(* A TPTP problem: comment lines, included file names and formulas. *)
datatype problem =
    Problem of
      {comments : comments,
       includes : includes,
       formulas : formula list};

(* Queries over the formulas of a problem. *)
local
  fun formulasOf (Problem {formulas,...}) = formulas;
in
  fun hasCnfConjecture prob =
      List.exists isCnfConjectureFormula (formulasOf prob);

  fun hasFofConjecture prob =
      List.exists isFofConjectureFormula (formulasOf prob);

  fun hasConjecture prob =
      List.exists isConjectureFormula (formulasOf prob);

  fun freeVars prob = freeVarsListFormula (formulasOf prob);
end;

(* Build a TPTP problem from a clause problem.  Clauses without a recorded
   name get fresh numeric names; clauses without a recorded role get
   AxiomRole (axioms) or NegatedConjectureRole (conjecture). *)
local
  (* Find the first decimal name "k" with k >= start that is not in avoid.
     Return it, k, and avoid extended with it. *)
  fun bump start avoid =
      let
        fun search k =
            let
              val candidate = FormulaName (Int.toString k)
            in
              if not (memberFormulaNameSet candidate avoid) then
                (candidate, k, addFormulaNameSet avoid candidate)
              else search (k + 1)
            end
      in
        search start
      end;

  (* Turn one clause into a cnf formula, threading the (counter, avoid set)
     state used to generate fresh names. *)
  fun fromClause defaultRole names roles cl (n,avoid) =
      let
        val (name,n',avoid') =
            case LiteralSetMap.peek names cl of
              NONE => bump n avoid
            | SOME known => (known,n,avoid)

        val role =
            case LiteralSetMap.peek roles cl of
              NONE => defaultRole
            | SOME r => r

        val formula =
            Formula
              {name = name,
               role = role,
               body = CnfFormulaBody (clauseFromLiteralSet cl),
               source = NoFormulaSource}
      in
        (formula,(n',avoid'))
      end;
in
  fun mkProblem {comments,includes,names,roles,problem} =
      let
        val {axioms,conjecture} = problem

        fun convert defaultRole cls state =
            maps (fromClause defaultRole names roles) cls state

        (* Fresh names must avoid every name already assigned to a clause. *)
        val (axiomFormulas,state) =
            convert AxiomRole axioms (0, allClauseNames names)

        val (conjectureFormulas,_) =
            convert NegatedConjectureRole conjecture state
      in
        Problem
          {comments = comments,
           includes = includes,
           formulas = axiomFormulas @ conjectureFormulas}
      end;
end;

(* Accumulated state while normalizing a TPTP problem: the clause problem
   built so far and the recorded source of each clause. *)
type normalization =
     {problem : Problem.problem,
      sources : clauseSources};

val initialNormalization : normalization =
    {problem = {axioms = [], conjecture = []},
     sources = noClauseSources};

(* The goal of a TPTP problem: none, cnf conjecture clauses, or fof
   conjecture formulas, each paired with its formula name. *)
datatype problemGoal =
    NoGoal
  | CnfGoal of (formulaName * clause) list
  | FofGoal of (formulaName * Formula.formula) list;

(* Extracting the goal of a TPTP problem (goal) and normalizing the problem
   to clause problems for refutation (normalize). *)
local
  (* Fold step: file one formula under cnf/fof axioms or cnf/fof goals,
     according to its body kind and role.  The lists grow in reverse. *)
  fun partitionFormula (formula,(cnfAxioms,fofAxioms,cnfGoals,fofGoals)) =
      let
        val Formula {name,role,body,...} = formula
      in
        case body of
          CnfFormulaBody cl =>
          if isCnfConjectureRole role then
            let
              val cnfGoals = (name,cl) :: cnfGoals
            in
              (cnfAxioms,fofAxioms,cnfGoals,fofGoals)
            end
          else
            let
              val cnfAxioms = (name,cl) :: cnfAxioms
            in
              (cnfAxioms,fofAxioms,cnfGoals,fofGoals)
            end
        | FofFormulaBody fm =>
          if isFofConjectureRole role then
            let
              val fofGoals = (name,fm) :: fofGoals
            in
              (cnfAxioms,fofAxioms,cnfGoals,fofGoals)
            end
          else
            let
              val fofAxioms = (name,fm) :: fofAxioms
            in
              (cnfAxioms,fofAxioms,cnfGoals,fofGoals)
            end
      end;

  (* Partition formulas, restoring file order.  A problem may have cnf
     goals or fof goals but not both; having both raises Error. *)
  fun partitionFormulas fms =
      let
        val (cnfAxioms,fofAxioms,cnfGoals,fofGoals) =
            List.foldl partitionFormula ([],[],[],[]) fms

        val goal =
            case (rev cnfGoals, rev fofGoals) of
              ([],[]) => NoGoal
            | (cnfGoals,[]) => CnfGoal cnfGoals
            | ([],fofGoals) => FofGoal fofGoals
            | (_ :: _, _ :: _) =>
              raise Error "TPTP problem has both cnf and fof conjecture formulas"
      in
        {cnfAxioms = rev cnfAxioms,
         fofAxioms = rev fofAxioms,
         goal = goal}
      end;

  (* Add (clause, source) pairs to a normalization.  Clauses with a
     conjecture role (per isCnfConjectureRole) go to the conjecture list,
     the rest to the axioms.  New clauses are prepended; normProblem
     reverses the final lists. *)
  fun addClauses role clauses acc : normalization =
      let
        fun addClause (cl_src,sources) =
            LiteralSetMap.insert sources cl_src

        val {problem,sources} : normalization = acc
        val {axioms,conjecture} = problem

        val cls = map fst clauses
        val (axioms,conjecture) =
            if isCnfConjectureRole role then (axioms, cls @ conjecture)
            else (cls @ axioms, conjecture)

        val problem = {axioms = axioms, conjecture = conjecture}
        and sources = List.foldl addClause sources clauses
      in
        {problem = problem,
         sources = sources}
      end;

  (* Add one cnf formula.  A clause containing a $true literal is skipped.
     Otherwise only the entries that destLiteral accepts are kept (presumably
     this drops $false literals) and the clause is recorded with its
     CnfClauseSource.  The clausification state cnf is passed through. *)
  fun addCnf role ((name,clause),(norm,cnf)) =
      if List.exists (equalBooleanLiteral true) clause then (norm,cnf)
      else
        let
          val cl = List.mapPartial (total destLiteral) clause
          val cl = LiteralSet.fromList cl

          val src = CnfClauseSource (name,clause)

          val norm = addClauses role [(cl,src)] norm
        in
          (norm,cnf)
        end;

  val addCnfAxiom = addCnf AxiomRole;

  val addCnfGoal = addCnf NegatedConjectureRole;

  (* Clausify theorem th with Normalize.addCnf, threading the clausification
     state, and record each resulting clause with a FofClauseSource. *)
  fun addFof role (th,(norm,cnf)) =
      let
        fun sourcify (cl,inf) = (cl, FofClauseSource inf)

        val (clauses,cnf) = Normalize.addCnf th cnf
        val clauses = map sourcify clauses
        val norm = addClauses role clauses norm
      in
        (norm,cnf)
      end;

  fun addFofAxiom ((_,fm),acc) =
      addFof AxiomRole (Normalize.mkAxiom fm, acc);

  (* Finish a normalization: restore insertion order of the clause lists
     and attach the subgoal, a (formula, parent names) pair. *)
  fun normProblem subgoal (norm,_) =
      let
        val {problem,sources} = norm
        val {axioms,conjecture} = problem
        val problem = {axioms = rev axioms, conjecture = rev conjecture}
      in
        {subgoal = subgoal,
         problem = problem,
         sources = sources}
      end;

  val normProblemFalse = normProblem (Formula.False,[]);

  (* For each fof goal, split it with Formula.splitGoal (an empty split
     becomes [True]).  For each generalized subgoal, add its negation as a
     negated conjecture to the shared axiom state acc and produce a separate
     normalized problem whose parent is the goal's name. *)
  fun splitProblem acc =
      let
        fun mk parents subgoal =
            let
              val subgoal = Formula.generalize subgoal

              val th = Normalize.mkAxiom (Formula.Not subgoal)

              val acc = addFof NegatedConjectureRole (th,acc)
            in
              normProblem (subgoal,parents) acc
            end

        fun split (name,goal) =
            let
              val subgoals = Formula.splitGoal goal
              val subgoals =
                  if null subgoals then [Formula.True] else subgoals

              val parents = [name]
            in
              map (mk parents) subgoals
            end
      in
        fn goals => List.concat (map split goals)
      end;

  (* Conjunction of the generalized clauses / formulas of a named list. *)
  fun clausesToGoal cls =
      let
        val cls = map (Formula.generalize o clauseToFormula o snd) cls
      in
        Formula.listMkConj cls
      end;

  fun formulasToGoal fms =
      let
        val fms = map (Formula.generalize o snd) fms
      in
        Formula.listMkConj fms
      end;
in
  (* The whole problem as one formula:
       cnfAxioms ==> fofAxioms ==> G
     where G is False when there is no goal, ~(cnf goals) for cnf goals,
     and the conjunction of the fof goals otherwise.  Empty axiom groups
     are omitted. *)
  fun goal (Problem {formulas,...}) =
      let
        val {cnfAxioms,fofAxioms,goal} = partitionFormulas formulas

        val fm =
            case goal of
              NoGoal => Formula.False
            | CnfGoal cls => Formula.Imp (clausesToGoal cls, Formula.False)
            | FofGoal goals => formulasToGoal goals

        val fm =
            if null fofAxioms then fm
            else Formula.Imp (formulasToGoal fofAxioms, fm)

        val fm =
            if null cnfAxioms then fm
            else Formula.Imp (clausesToGoal cnfAxioms, fm)
      in
        fm
      end;

  (* Normalize to clause problems for refutation.  With no goal or cnf
     goals there is a single problem with subgoal False; fof goals produce
     one problem per split subgoal. *)
  fun normalize (Problem {formulas,...}) =
      let
        val {cnfAxioms,fofAxioms,goal} = partitionFormulas formulas

        val acc = (initialNormalization, Normalize.initialCnf)
        val acc = List.foldl addCnfAxiom acc cnfAxioms
        val acc = List.foldl addFofAxiom acc fofAxioms
      in
        case goal of
          NoGoal => [normProblemFalse acc]
        | CnfGoal cls => [normProblemFalse (List.foldl addCnfGoal acc cls)]
        | FofGoal goals => splitProblem acc goals
      end;
end;

(* Reading TPTP problem files. *)
local
  (* States of the block-comment stripper: outside a comment, just after a
     '/', inside a comment, and just after a '*' inside a comment. *)
  datatype blockComment =
      OutsideBlockComment
    | EnteringBlockComment
    | InsideBlockComment
    | LeavingBlockComment;

  (* Collect the leading run of line comments, decoded by destLineComment,
     as the problem's comments.  At the first non-comment line, filter any
     remaining line comments out of the rest of the stream.  Stream elements
     are presumably each line's characters; confirm against
     Parse.initialize. *)
  fun stripLineComments acc strm =
      case strm of
        Stream.Nil => (rev acc, Stream.Nil)
      | Stream.Cons (line,rest) =>
        case total destLineComment line of
          SOME s => stripLineComments (s :: acc) (rest ())
        | NONE => (rev acc, Stream.filter (not o isLineComment) strm);

  (* One step of the character state machine: return the characters to
     emit and the next state.  A '/' not followed by '*' is emitted back
     unchanged. *)
  fun advanceBlockComment c state =
      case state of
        OutsideBlockComment =>
        if c = #"/" then (Stream.Nil, EnteringBlockComment)
        else (Stream.singleton c, OutsideBlockComment)
      | EnteringBlockComment =>
        if c = #"*" then (Stream.Nil, InsideBlockComment)
        else if c = #"/" then (Stream.singleton #"/", EnteringBlockComment)
        else (Stream.fromList [#"/",c], OutsideBlockComment)
      | InsideBlockComment =>
        if c = #"*" then (Stream.Nil, LeavingBlockComment)
        else (Stream.Nil, InsideBlockComment)
      | LeavingBlockComment =>
        if c = #"/" then (Stream.Nil, OutsideBlockComment)
        else if c = #"*" then (Stream.Nil, LeavingBlockComment)
        else (Stream.Nil, InsideBlockComment);

  (* At end of input, flush a pending '/'.  Ending inside a block comment
     raises Error. *)
  fun eofBlockComment state =
      case state of
        OutsideBlockComment => Stream.Nil
      | EnteringBlockComment => Stream.singleton #"/"
      | _ => raise Error "EOF inside a block comment";

  val stripBlockComments =
      Stream.mapsConcat advanceBlockComment eofBlockComment
        OutsideBlockComment;
in
  (* Read and parse a TPTP problem file.  Leading line comments become the
     problem's comments; later line comments and all block comments are
     discarded.  Parse failures and other Errors are re-raised as Error,
     prefixed with the file name and an estimated location. *)
  fun read {mapping,filename} =
      let
        (* Estimating parse error line numbers *)

        val lines = Stream.fromTextFile {filename = filename}

        val {chars,parseErrorLocation} = Parse.initialize {lines = lines}
      in
        (let
           (* The character stream *)

           val (comments,chars) = stripLineComments [] chars

           val chars = Parse.everything Parse.any chars

           val chars = stripBlockComments chars

           (* The declaration stream *)

           val declarations = Stream.toList (parseDeclaration mapping chars)

           val (includes,formulas) = partitionDeclarations declarations
         in
           Problem
             {comments = comments,
              includes = includes,
              formulas = formulas}
         end
         handle Parse.NoParse => raise Error "parse error")
        handle Error err =>
          raise Error ("error in TPTP file \"" ^ filename ^ "\" " ^
                       parseErrorLocation () ^ "\n" ^ err)
      end;
end;

(* Writing TPTP problem files: comments, then includes, then formulas, each
   group preceded by a newline unless nothing was written before it. *)
local
  val newline = Stream.singleton "\n";

  (* No separating newline is emitted before the first item in the file. *)
  fun spacer top = if top then Stream.Nil else newline;

  fun mkComment comment = mkLineComment comment ^ "\n";

  fun mkInclude inc = includeToString inc ^ "\n";

  (* Each formula is preceded by a spacer and followed by a newline; only
     the first can be "top". *)
  fun formulaStream mapping top formulas =
      case formulas of
        [] => Stream.Nil
      | fm :: rest =>
        let
          val chunk =
              Stream.concatList
                [spacer top,
                 Stream.singleton (formulaToString mapping fm),
                 newline]
        in
          Stream.append chunk (fn () => formulaStream mapping false rest)
        end;
in
  fun write {problem,mapping,filename} =
      let
        val Problem {comments,includes,formulas} = problem

        val noComments = null comments
        val noHeader = noComments andalso null includes

        val commentLines = Stream.map mkComment (Stream.fromList comments)
        val includeLines = Stream.map mkInclude (Stream.fromList includes)
      in
        Stream.toTextFile
          {filename = filename}
          (Stream.concatList
             [commentLines,
              spacer noComments,
              includeLines,
              formulaStream mapping noHeader formulas])
      end;
end;

(* Prove a TPTP problem file by refuting every normalized clause problem. *)
local
  (* Run resolution with default settings; true iff a contradiction is
     found. *)
  fun refute {axioms,conjecture} =
      let
        val problem =
            {axioms = map Thm.axiom axioms,
             conjecture = map Thm.axiom conjecture}

        val state = Resolution.new Resolution.default problem
      in
        case Resolution.loop state of
          Resolution.Satisfiable _ => false
        | Resolution.Contradiction _ => true
      end;
in
  fun prove filename =
      List.all refute (map #problem (normalize (read filename)));
end;

(* ------------------------------------------------------------------------- *)
(* TSTP proofs.                                                              *)
(* ------------------------------------------------------------------------- *)

local
  (* Return a function that, given a counter i, finds the first name
     prefix ^ Int.toString k with k >= i that is not in avoid.  It returns
     that name and k + 1 as the next counter value. *)
  fun newName avoid prefix =
      let
        fun bump i =
            let
              val candidate = FormulaName (prefix ^ Int.toString i)
              val next = i + 1
            in
              if not (memberFormulaNameSet candidate avoid) then
                (candidate,next)
              else bump next
            end
      in
        bump
      end;

  (* The recorded source of clause cl; a missing entry is an internal Bug. *)
  fun lookupClauseSource sources cl =
      case LiteralSetMap.peek sources cl of
        NONE => raise Bug "Tptp.lookupClauseSource"
      | SOME found => found;

  (* The name assigned to formula fm; a missing entry is an internal Bug. *)
  fun lookupFormulaName fmNames fm =
      case FormulaMap.peek fmNames fm of
        NONE => raise Bug "Tptp.lookupFormulaName"
      | SOME found => found;

  (* The name assigned to clause cl; a missing entry is an internal Bug. *)
  fun lookupClauseName clNames cl =
      case LiteralSetMap.peek clNames cl of
        NONE => raise Bug "Tptp.lookupClauseName"
      | SOME found => found;

  (* The name of the input formula a clause came from: the cnf formula's own
     name, or, for a fof-derived clause, the name assigned to the formula of
     its Normalize theorem. *)
  fun lookupClauseSourceName sources fmNames cl =
      case lookupClauseSource sources cl of
        FofClauseSource th =>
        lookupFormulaName fmNames (fst (Normalize.destThm th))
      | CnfClauseSource (name,_) => name;

  (* Fold step over proof steps: for each axiom step, record the cnf formula
     name it came from, or collect the Normalize theorem it came from.
     Other inference steps contribute nothing. *)
  fun collectProofDeps sources ((_,inf),(names,ths)) =
      case inf of
        Proof.Axiom cl =>
        (case lookupClauseSource sources cl of
           CnfClauseSource (name,_) => (addFormulaNameSet names name, ths)
         | FofClauseSource th => (names, th :: ths))
      | _ => (names,ths);

  (* Fold step over normalization steps: collect fof axioms used, and the
     definitions introduced keyed by their name. *)
  fun collectNormalizeDeps ((_,inf,_),(fofs,defs)) =
      case inf of
        Normalize.Axiom fm => (FormulaSet.add fofs fm, defs)
      | Normalize.Definition n_d => (fofs, StringMap.insert defs n_d)
      | _ => (fofs,defs);

  (* For one subgoal refutation: extract the proof, replay the normalization
     of the fof theorems it uses, and accumulate the cnf formula names, fof
     axioms and definitions it depends on.  Returns the expanded subgoal
     proof together with the updated accumulators. *)
  fun collectSubgoalProofDeps subgoalProof (names,fofs,defs) =
      let
        val {subgoal,sources,refutation} = subgoalProof
        val (_,parents) = subgoal

        val namesWithParents = addListFormulaNameSet names parents

        val proof = Proof.proof refutation

        val (names',thms) =
            List.foldl (collectProofDeps sources)
              (namesWithParents,[]) proof

        (* Theorems were collected in reverse proof order. *)
        val normalization = Normalize.proveThms (rev thms)

        val (fofs',defs') =
            List.foldl collectNormalizeDeps (fofs,defs) normalization
      in
        ({subgoal = subgoal,
          normalization = normalization,
          sources = sources,
          proof = proof},
         (names',fofs',defs'))
      end;

  (* Fold step over the problem's formulas.  Every name goes into the avoid
     set.  A formula is kept if its name is among the needed names, or if it
     is a fof formula whose body is among the needed fof axioms; in the
     latter case the body is also mapped to the name. *)
  fun addProblemFormula names fofs (formula,(avoid,formulas,fmNames)) =
      let
        val name = nameFormula formula
        val avoid' = addFormulaNameSet avoid name
        val skip = (formulas, fmNames)

        val (formulas',fmNames') =
            if memberFormulaNameSet name names then
              (formula :: formulas, fmNames)
            else
              case bodyFormula formula of
                FofFormulaBody fm =>
                if FormulaSet.member fm fofs then
                  (formula :: formulas, FormulaMap.insert fmNames (fm,name))
                else skip
              | CnfFormulaBody _ => skip
      in
        (avoid',formulas',fmNames')
      end;

  (* Fold step over definitions: emit a fresh "definition_N" fof formula
     with DefinitionRole and map the definition to its new name. *)
  fun addDefinitionFormula avoid (_,def,(formulas,i,fmNames)) =
      let
        val (name,next) = newName avoid "definition_" i

        val formula =
            Formula
              {name = name,
               role = DefinitionRole,
               body = FofFormulaBody def,
               source = NoFormulaSource}
      in
        (formula :: formulas, next, FormulaMap.insert fmNames (def,name))
      end;

  (* Number a subgoal proof.  When the subgoal has parent formulas, also emit
     a "subgoal_N" formula derived from them by "strip" and record it as the
     subgoal.  A fresh name is drawn, and the counter advanced, even when
     there are no parents. *)
  fun addSubgoalFormula avoid subgoalProof (formulas,i) =
      let
        val {subgoal,normalization,sources,proof} = subgoalProof
        val (fm,pars) = subgoal

        val (name,next) = newName avoid "subgoal_" i
        val number = next - 1

        val (subgoal',formulas') =
            case pars of
              [] => (NONE,formulas)
            | _ =>
              let
                val stripped =
                    Formula
                      {name = name,
                       role = PlainRole,
                       body = FofFormulaBody fm,
                       source =
                         StripFormulaSource
                           {inference = "strip",
                            parents = pars}}
              in
                (SOME (name,fm), stripped :: formulas)
              end
      in
        ({number = number,
          subgoal = subgoal',
          normalization = normalization,
          sources = sources,
          proof = proof},
         (formulas',next))
      end;

  (* Source annotation for a normalization step.  Its parents are the named
     formulas fms, preceded by the formula of an axiom or definition
     inference. *)
  fun mkNormalizeFormulaSource fmNames inference fms =
      let
        val extra =
            case inference of
              Normalize.Axiom fm => [fm]
            | Normalize.Definition (_,fm) => [fm]
            | _ => []
      in
        NormalizeFormulaSource
          {inference = inference,
           parents = map (lookupFormulaName fmNames) (extra @ fms)}
      end;

  (* Source annotation for a proof step.  An axiom's parent is the input
     formula its clause came from; any other inference's parents are the
     names of its parent theorems' clauses. *)
  fun mkProofFormulaSource sources fmNames clNames inference =
      let
        val parents =
            case inference of
              Proof.Axiom cl => [lookupClauseSourceName sources fmNames cl]
            | _ =>
              map (lookupClauseName clNames o Thm.clause)
                (Proof.parents inference)
      in
        ProofFormulaSource {inference = inference, parents = parents}
      end;

  (* Fold step over normalization steps: emit a fresh prefix-named fof
     formula for fm.  Its source's parents are looked up before fm itself
     is added to the name map. *)
  fun addNormalizeFormula avoid prefix ((fm,inf,fms),(formulas,i,fmNames)) =
      let
        val (name,next) = newName avoid prefix i

        val formula =
            Formula
              {name = name,
               role = PlainRole,
               body = FofFormulaBody fm,
               source = mkNormalizeFormulaSource fmNames inf fms}
      in
        (formula :: formulas, next, FormulaMap.insert fmNames (fm,name))
      end;

  (* If inf is an axiom step whose clause is exactly a cnf input clause with
     no boolean literals, return that input formula's name so it can be
     reused.  Otherwise NONE.  The second argument is unused. *)
  fun isSameClause sources _ inf =
      case inf of
        Proof.Axiom cl =>
        (case lookupClauseSource sources cl of
           CnfClauseSource (name,lits) =>
           if List.exists isBooleanLiteral lits then NONE else SOME name
         | FofClauseSource _ => NONE)
      | _ => NONE;

  (* Fold step over proof steps.  A clause identical to a cnf input clause
     reuses that input's name.  Otherwise a fresh prefix-named cnf formula
     is emitted, with its source's parents looked up before the clause is
     named. *)
  fun addProofFormula avoid sources fmNames prefix
                      ((th,inf),(formulas,i,clNames)) =
      let
        val cl = Thm.clause th
      in
        case isSameClause sources formulas inf of
          SOME existing =>
          (formulas, i, LiteralSetMap.insert clNames (cl,existing))
        | NONE =>
          let
            val (name,next) = newName avoid prefix i

            val formula =
                Formula
                  {name = name,
                   role = PlainRole,
                   body = CnfFormulaBody (clauseFromLiteralSet cl),
                   source = mkProofFormulaSource sources fmNames clNames inf}
          in
            (formula :: formulas, next, LiteralSetMap.insert clNames (cl,name))
          end
      end;

  (* Fold step that appends all formulas for one subgoal proof, in order:
       1. the negated subgoal, if there is one (prefix "negate_<n>_");
       2. the normalization steps (prefix "normalize_<n>_");
       3. the refutation steps (prefix "refute_<n>_").
     Formula names added in steps 1-2 are threaded into the fmNames used by
     step 3.  Clause names start from an empty map for each subgoal. *)
  fun addSubgoalProofFormulas avoid fmNames (subgoalProof,formulas) =
      let
        val {number,subgoal,normalization,sources,proof} = subgoalProof

        val tag = Int.toString number

        (* Add Not goal as a plain FOF formula whose source is the "negate"
           inference applied to the subgoal's name. *)
        fun addNegation NONE acc = acc
          | addNegation (SOME (goalName,goal)) (fms,names) =
            let
              val (negName,_) = newName avoid ("negate_" ^ tag ^ "_") 0

              val negGoal = Formula.Not goal

              val negFormula =
                  Formula
                    {name = negName,
                     role = PlainRole,
                     body = FofFormulaBody negGoal,
                     source =
                       StripFormulaSource
                         {inference = "negate",
                          parents = [goalName]}}
            in
              (negFormula :: fms, FormulaMap.insert names (negGoal,negName))
            end

        val (formulas,fmNames) = addNegation subgoal (formulas,fmNames)

        val (formulas,_,fmNames) =
            List.foldl (addNormalizeFormula avoid ("normalize_" ^ tag ^ "_"))
              (formulas,0,fmNames) normalization

        val emptyClNames : formulaName LiteralSetMap.map = LiteralSetMap.new ()

        val (formulas,_,_) =
            List.foldl
              (addProofFormula avoid sources fmNames ("refute_" ^ tag ^ "_"))
              (formulas,0,emptyClNames) proof
      in
        formulas
      end;
in
  (* Convert a problem together with its subgoal proofs into a list of TPTP
     formulas.
     First, the dependencies of the proofs (formula names, FOF formulas,
     definitions) are collected.  Then the accumulated list is built from the
     problem formulas (via addProblemFormula), the definitions, the subgoal
     formulas, and finally each subgoal's proof formulas.  The accumulated
     list is reversed before being returned. *)
  fun fromProof {problem,proofs} =
      let
        val noNames = emptyFormulaNameSet
        and noFofs = FormulaSet.empty
        and noDefs : Formula.formula StringMap.map = StringMap.new ()

        val (proofs,(names,fofs,defs)) =
            maps collectSubgoalProofDeps proofs (noNames,noFofs,noDefs)

        val Problem {formulas = problemFormulas, ...} = problem

        val noFmNames : formulaName FormulaMap.map = FormulaMap.new ()

        val (avoid,formulas,fmNames) =
            List.foldl (addProblemFormula names fofs)
              (emptyFormulaNameSet,[],noFmNames) problemFormulas

        val (formulas,_,fmNames) =
            StringMap.foldl (addDefinitionFormula avoid)
              (formulas,0,fmNames) defs

        val (proofs,(formulas,_)) =
            maps (addSubgoalFormula avoid) proofs (formulas,0)
      in
        rev
          (List.foldl (addSubgoalProofFormulas avoid fmNames) formulas proofs)
      end
(*MetisDebug
      handle Error err => raise Bug ("Tptp.fromProof: shouldn't fail:\n" ^ err);
*)
end;

end
end;
print_depth 10;