src/HOL/Tools/Sledgehammer/sledgehammer_proof_reconstruct.ML
changeset 36393 be73a2b2443b
parent 36392 c00c57850eb7
child 36395 e73923451f6f
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_proof_reconstruct.ML	Sun Apr 25 10:22:31 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_proof_reconstruct.ML	Sun Apr 25 11:38:46 2010 +0200
@@ -8,6 +8,7 @@
 signature SLEDGEHAMMER_PROOF_RECONSTRUCT =
 sig
   type minimize_command = string list -> string
+  type name_pool = Sledgehammer_FOL_Clause.name_pool
 
   val chained_hint: string
   val invert_const: string -> string
@@ -20,11 +21,11 @@
     minimize_command * string * string vector * thm * int
     -> string * string list
   val isar_proof_text:
-    bool -> int -> bool -> Proof.context
+    name_pool option -> bool -> int -> bool -> Proof.context
     -> minimize_command * string * string vector * thm * int
     -> string * string list
   val proof_text:
-    bool -> bool -> int -> bool -> Proof.context
+    bool -> name_pool option -> bool -> int -> bool -> Proof.context
     -> minimize_command * string * string vector * thm * int
     -> string * string list
 end;
@@ -37,18 +38,23 @@
 
 type minimize_command = string list -> string
 
-val trace_proof_path = Path.basic "sledgehammer_trace_proof"
-
-fun trace_proof_msg f =
-  if !trace then File.append (File.tmp_path trace_proof_path) (f ()) else ();
-
-fun string_of_thm ctxt = PrintMode.setmp [] (Display.string_of_thm ctxt);
-
 fun is_ident_char c = Char.isAlphaNum c orelse c = #"_"
 fun is_head_digit s = Char.isDigit (String.sub (s, 0))
 
 fun is_axiom thm_names line_no = line_no <= Vector.length thm_names
 
+fun ugly_name NONE s = s
+  | ugly_name (SOME the_pool) s =
+    case Symtab.lookup (snd the_pool) s of
+      SOME s' => s'
+    | NONE => s
+
+val trace_path = Path.basic "sledgehammer_proof_trace"
+fun trace_proof_msg f =
+  if !trace then File.append (File.tmp_path trace_path) (f ()) else ();
+
+val string_of_thm = PrintMode.setmp [] o Display.string_of_thm
+
 (**** PARSING OF TSTP FORMAT ****)
 
 (* Syntax trees, either term list or formulae *)
@@ -66,52 +72,60 @@
 val parse_integer = Scan.many1 is_head_digit >> (the o Int.fromString o implode)
 
 (* needed for SPASS's output format *)
-fun fix_bool_literal "true" = "c_True"
-  | fix_bool_literal "false" = "c_False"
-fun fix_symbol "equal" = "c_equal"
-  | fix_symbol s = s
+fun repair_bool_literal "true" = "c_True"
+  | repair_bool_literal "false" = "c_False"
+fun repair_name pool "equal" = "c_equal"
+  | repair_name pool s = ugly_name pool s
 (* Generalized first-order terms, which include file names, numbers, etc. *)
-fun parse_term x =
+(* The "x" argument is not strictly necessary, but without it Poly/ML loops
+   forever at compile time. *)
+fun parse_term pool x =
   (parse_quoted >> atom
    || parse_integer >> SInt
-   || $$ "$" |-- Symbol.scan_id >> (atom o fix_bool_literal)
-   || (Symbol.scan_id >> fix_symbol)
-      -- Scan.optional ($$ "(" |-- parse_terms --| $$ ")") [] >> SBranch
-   || $$ "(" |-- parse_term --| $$ ")"
-   || $$ "[" |-- Scan.optional parse_terms [] --| $$ "]" >> slist_of) x
-and parse_terms x = (parse_term ::: Scan.repeat ($$ "," |-- parse_term)) x
+   || $$ "$" |-- Symbol.scan_id >> (atom o repair_bool_literal)
+   || (Symbol.scan_id >> repair_name pool)
+      -- Scan.optional ($$ "(" |-- parse_terms pool --| $$ ")") [] >> SBranch
+   || $$ "(" |-- parse_term pool --| $$ ")"
+   || $$ "[" |-- Scan.optional (parse_terms pool) [] --| $$ "]" >> slist_of) x
+and parse_terms pool x =
+  (parse_term pool ::: Scan.repeat ($$ "," |-- parse_term pool)) x
 
 fun negate_stree t = SBranch ("c_Not", [t])
 fun equate_strees t1 t2 = SBranch ("c_equal", [t1, t2]);
 
 (* Apply equal or not-equal to a term. *)
-fun do_equal (t, NONE) = t
-  | do_equal (t1, SOME (NONE, t2)) = equate_strees t1 t2
-  | do_equal (t1, SOME (SOME _, t2)) = negate_stree (equate_strees t1 t2)
+fun repair_predicate_term (t, NONE) = t
+  | repair_predicate_term (t1, SOME (NONE, t2)) = equate_strees t1 t2
+  | repair_predicate_term (t1, SOME (SOME _, t2)) =
+    negate_stree (equate_strees t1 t2)
+fun parse_predicate_term pool =
+  parse_term pool -- Scan.option (Scan.option ($$ "!") --| $$ "="
+                                  -- parse_term pool)
+  >> repair_predicate_term
 (*Literals can involve negation, = and !=.*)
-fun parse_literal x =
-  ($$ "~" |-- parse_literal >> negate_stree
-   || (parse_term -- Scan.option (Scan.option ($$ "!") --| $$ "=" -- parse_term)
-       >> do_equal)) x
+fun parse_literal pool x =
+  ($$ "~" |-- parse_literal pool >> negate_stree || parse_predicate_term pool) x
 
-val parse_literals = parse_literal ::: Scan.repeat ($$ "|" |-- parse_literal)
+fun parse_literals pool =
+  parse_literal pool ::: Scan.repeat ($$ "|" |-- parse_literal pool)
 
 (*Clause: a list of literals separated by the disjunction sign*)
-val parse_clause =
-  $$ "(" |-- parse_literals --| $$ ")" || Scan.single parse_literal
+fun parse_clause pool =
+  $$ "(" |-- parse_literals pool --| $$ ")" || Scan.single (parse_literal pool)
 
 fun ints_of_stree (SInt n) = cons n
   | ints_of_stree (SBranch (_, ts)) = fold ints_of_stree ts
 val parse_tstp_annotations =
-  Scan.optional ($$ "," |-- parse_term --| Scan.option ($$ "," |-- parse_terms)
+  Scan.optional ($$ "," |-- parse_term NONE
+                   --| Scan.option ($$ "," |-- parse_terms NONE)
                  >> (fn source => ints_of_stree source [])) []
 
 (* <cnf_annotated> ::= cnf(<name>, <formula_role>, <cnf_formula> <annotations>).
    The <name> could be an identifier, but we assume integers. *)
 fun retuple_tstp_line ((name, ts), deps) = (name, ts, deps)
-val parse_tstp_line =
+fun parse_tstp_line pool =
   (Scan.this_string "cnf" -- $$ "(") |-- parse_integer --| $$ ","
-   --| Symbol.scan_id --| $$ "," -- parse_clause -- parse_tstp_annotations
+   --| Symbol.scan_id --| $$ "," -- parse_clause pool -- parse_tstp_annotations
    --| $$ ")" --| $$ "."
   >> retuple_tstp_line
 
@@ -127,23 +141,26 @@
 
 (* It is not clear why some literals are followed by sequences of stars. We
    ignore them. *)
-val parse_starred_literal = parse_literal --| Scan.repeat ($$ "*" || $$ " ")
+fun parse_starred_predicate_term pool =
+  parse_predicate_term pool --| Scan.repeat ($$ "*" || $$ " ")
 
-val parse_horn_clause =
-  Scan.repeat parse_starred_literal --| $$ "-" --| $$ ">"
-  -- Scan.repeat parse_starred_literal
+fun parse_horn_clause pool =
+  Scan.repeat (parse_starred_predicate_term pool) --| $$ "-" --| $$ ">"
+  -- Scan.repeat (parse_starred_predicate_term pool)
   >> (fn ([], []) => [atom "c_False"]
        | (clauses1, clauses2) => map negate_stree clauses1 @ clauses2)
 
-(* Syntax: <name>[0:<inference><annotations>] || -> <cnf_formula>. *)
+(* Syntax: <name>[0:<inference><annotations>] ||
+           <cnf_formulas> -> <cnf_formulas>. *)
 fun retuple_spass_proof_line ((name, deps), ts) = (name, ts, deps)
-val parse_spass_proof_line =
+fun parse_spass_proof_line pool =
   parse_integer --| $$ "[" --| $$ "0" --| $$ ":" --| Symbol.scan_id
   -- parse_spass_annotations --| $$ "]" --| $$ "|" --| $$ "|"
-  -- parse_horn_clause --| $$ "."
+  -- parse_horn_clause pool --| $$ "."
   >> retuple_spass_proof_line
 
-val parse_proof_line = fst o (parse_tstp_line || parse_spass_proof_line)
+fun parse_proof_line pool = 
+  fst o (parse_tstp_line pool || parse_spass_proof_line pool)
 
 (**** INTERPRETATION OF TSTP SYNTAX TREES ****)
 
@@ -271,7 +288,7 @@
       lits_of_strees ctxt (vt, term_of_stree [] (ProofContext.theory_of ctxt) t :: lits) ts;
 
 (*Update TVars/TFrees with detected sort constraints.*)
-fun fix_sorts vt =
+fun repair_sorts vt =
   let fun tysubst (Type (a, Ts)) = Type (a, map tysubst Ts)
         | tysubst (TVar (xi, s)) = TVar (xi, the_default s (Vartab.lookup vt xi))
         | tysubst (TFree (x, s)) = TFree (x, the_default s (Vartab.lookup vt (x, ~1)))
@@ -285,9 +302,10 @@
 
 (*Interpret a list of syntax trees as a clause, given by "real" literals and sort constraints.
   vt0 holds the initial sort constraints, from the conjecture clauses.*)
-fun clause_of_strees ctxt vt0 ts =
-  let val (vt, dt) = lits_of_strees ctxt (vt0,[]) ts in
-    singleton (Syntax.check_terms ctxt) (TypeInfer.constrain HOLogic.boolT (fix_sorts vt dt))
+fun clause_of_strees ctxt vt ts =
+  let val (vt, dt) = lits_of_strees ctxt (vt, []) ts in
+    dt |> repair_sorts vt |> TypeInfer.constrain HOLogic.boolT
+       |> Syntax.check_term ctxt
   end
 
 fun gen_all_vars t = fold_rev Logic.all (OldTerm.term_vars t) t;
@@ -491,10 +509,10 @@
 fun isar_proof_end 1 = "qed"
   | isar_proof_end _ = "next"
 
-fun isar_proof_from_atp_proof cnfs modulus sorts ctxt goal i thm_names =
+fun isar_proof_from_atp_proof pool modulus sorts ctxt cnfs thm_names goal i =
   let
     val _ = trace_proof_msg (K "\nisar_proof_from_atp_proof: start\n")
-    val tuples = map (parse_proof_line o explode) cnfs
+    val tuples = map (parse_proof_line pool o explode) cnfs
     val _ = trace_proof_msg (fn () =>
       Int.toString (length tuples) ^ " tuples extracted\n")
     val ctxt = ProofContext.set_mode ProofContext.mode_schematic ctxt
@@ -600,7 +618,7 @@
 
 val strip_spaces = strip_spaces_in_list o String.explode
 
-fun isar_proof_text debug modulus sorts ctxt
+fun isar_proof_text pool debug modulus sorts ctxt
                     (minimize_command, proof, thm_names, goal, i) =
   let
     val cnfs = proof |> split_lines |> map strip_spaces |> filter is_proof_line
@@ -608,7 +626,8 @@
       metis_proof_text (minimize_command, proof, thm_names, goal, i)
     val tokens = String.tokens (fn c => c = #" ") one_line_proof
     fun isar_proof_for () =
-      case isar_proof_from_atp_proof cnfs modulus sorts ctxt goal i thm_names of
+      case isar_proof_from_atp_proof pool modulus sorts ctxt cnfs thm_names goal
+                                     i of
         "" => ""
       | isar_proof =>
         "\nStructured proof:\n" ^ Markup.markup Markup.sendback isar_proof
@@ -622,8 +641,8 @@
         |> the_default "Warning: The Isar proof construction failed.\n"
   in (one_line_proof ^ isar_proof, lemma_names) end
 
-fun proof_text isar_proof debug modulus sorts ctxt =
-  if isar_proof then isar_proof_text debug modulus sorts ctxt
+fun proof_text isar_proof pool debug modulus sorts ctxt =
+  if isar_proof then isar_proof_text pool debug modulus sorts ctxt
   else metis_proof_text
 
 end;