now if this doesn't make SML/NJ happy, nothing will
authorblanchet
Sat, 01 May 2010 10:37:31 +0200
changeset 36605 6f11c9b1fb3e
parent 36604 65a8b49e8948
child 36606 5479681ab465
now if this doesn't make SML/NJ happy, nothing will
src/HOL/Tools/Sledgehammer/sledgehammer_proof_reconstruct.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_proof_reconstruct.ML	Sat May 01 00:23:57 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_proof_reconstruct.ML	Sat May 01 10:37:31 2010 +0200
@@ -13,7 +13,7 @@
   val chained_hint: string
   val invert_const: string -> string
   val invert_type_const: string -> string
-  val num_typargs: theory -> string -> int
+  val num_type_args: theory -> string -> int
   val make_tvar: string -> typ
   val strip_prefix: string -> string -> string option
   val metis_line: int -> int -> string list -> string
@@ -21,11 +21,11 @@
     minimize_command * string * string vector * thm * int
     -> string * string list
   val isar_proof_text:
-    name_pool option * bool * int * Proof.context * int list list
+    name_pool option * bool * bool * int * Proof.context * int list list
     -> minimize_command * string * string vector * thm * int
     -> string * string list
   val proof_text:
-    bool -> name_pool option * bool * int * Proof.context * int list list
+    bool -> name_pool option * bool * bool * int * Proof.context * int list list
     -> minimize_command * string * string vector * thm * int
     -> string * string list
 end;
@@ -35,6 +35,7 @@
 
 open Sledgehammer_Util
 open Sledgehammer_FOL_Clause
+open Sledgehammer_HOL_Clause
 open Sledgehammer_Fact_Preprocessor
 
 type minimize_command = string list -> string
@@ -216,7 +217,7 @@
 
 (**** INTERPRETATION OF TSTP SYNTAX TREES ****)
 
-exception NODE of node
+exception NODE of node list
 
 (*If string s has the prefix s1, return the result of deleting it.*)
 fun strip_prefix s1 s =
@@ -237,16 +238,16 @@
 fun make_tparam s = TypeInfer.param 0 (s, HOLogic.typeS)
 fun make_var (b,T) = Var((b,0),T);
 
-(*Type variables are given the basic sort, HOL.type. Some will later be constrained
-  by information from type literals, or by type inference.*)
-fun type_of_node (u as IntLeaf _) = raise NODE u
-  | type_of_node (u as StrNode (a, us)) =
-    let val Ts = map type_of_node us in
+(* Type variables are given the basic sort "HOL.type". Some will later be
+  constrained by information from type literals, or by type inference. *)
+fun type_from_node (u as IntLeaf _) = raise NODE [u]
+  | type_from_node (u as StrNode (a, us)) =
+    let val Ts = map type_from_node us in
       case strip_prefix tconst_prefix a of
         SOME b => Type (invert_type_const b, Ts)
       | NONE =>
         if not (null us) then
-          raise NODE u  (*only tconsts have type arguments*)
+          raise NODE [u]  (* only "tconst"s have type arguments *)
         else case strip_prefix tfree_prefix a of
           SOME b => TFree ("'" ^ b, HOLogic.typeS)
         | NONE =>
@@ -263,10 +264,8 @@
 fun invert_const c = c |> Symtab.lookup const_trans_table_inv |> the_default c
 
 (*The number of type arguments of a constant, zero if it's monomorphic*)
-fun num_typargs thy s = length (Sign.const_typargs thy (s, Sign.the_const_type thy s));
-
-(*Generates a constant, given its type arguments*)
-fun const_of thy (a,Ts) = Const(a, Sign.const_instance thy (a,Ts));
+fun num_type_args thy s =
+  length (Sign.const_typargs thy (s, Sign.the_const_type thy s))
 
 fun fix_atp_variable_name s =
   let
@@ -285,59 +284,81 @@
     | _ => s
   end
 
-(*First-order translation. No types are known for variables. HOLogic.typeT should allow
-  them to be inferred.*)
-fun term_of_node args thy u =
-  case u of
-    IntLeaf _ => raise NODE u
-  | StrNode ("hBOOL", [u]) => term_of_node [] thy u  (* ignore hBOOL *)
-  | StrNode ("hAPP", [u1, u2]) => term_of_node (u2 :: args) thy u1
-  | StrNode (a, us) =>
-    case strip_prefix const_prefix a of
-      SOME "equal" =>
-      list_comb (Const (@{const_name "op ="}, HOLogic.typeT),
-                 map (term_of_node [] thy) us)
-    | SOME b =>
-      let
-        val c = invert_const b
-        val nterms = length us - num_typargs thy c
-        val ts = map (term_of_node [] thy) (take nterms us @ args)
-        (*Extra args from hAPP come AFTER any arguments given directly to the
-          constant.*)
-        val Ts = map type_of_node (drop nterms us)
-      in list_comb(const_of thy (c, Ts), ts) end
-    | NONE => (*a variable, not a constant*)
-      let
-        val opr =
-          (* a Free variable is typically a Skolem function *)
-          case strip_prefix fixed_var_prefix a of
-            SOME b => Free (b, HOLogic.typeT)
-          | NONE =>
-            case strip_prefix schematic_var_prefix a of
-              SOME b => make_var (b, HOLogic.typeT)
-            | NONE =>
-              (* Variable from the ATP, say "X1" *)
-              make_var (fix_atp_variable_name a, HOLogic.typeT)
-      in list_comb (opr, map (term_of_node [] thy) (us @ args)) end
+(* First-order translation. No types are known for variables. "HOLogic.typeT"
+   should allow them to be inferred.*)
+fun term_from_node thy full_types =
+  let
+    fun aux opt_T args u =
+      case u of
+        IntLeaf _ => raise NODE [u]
+      | StrNode ("hBOOL", [u1]) => aux (SOME @{typ bool}) [] u1
+      | StrNode ("hAPP", [u1, u2]) => aux opt_T (u2 :: args) u1
+      | StrNode ("c_Not", [u1]) => @{const Not} $ aux (SOME @{typ bool}) [] u1
+      | StrNode (a, us) =>
+        if a = type_wrapper_name then
+          case us of
+            [term_u, typ_u] => aux (SOME (type_from_node typ_u)) args term_u
+          | _ => raise NODE us
+        else case strip_prefix const_prefix a of
+          SOME "equal" =>
+          list_comb (Const (@{const_name "op ="}, HOLogic.typeT),
+                     map (aux NONE []) us)
+        | SOME b =>
+          let
+            val c = invert_const b
+            val num_type_args = num_type_args thy c
+            val actual_num_type_args = if full_types then 0 else num_type_args
+            val num_term_args = length us - actual_num_type_args
+            val ts = map (aux NONE []) (take num_term_args us @ args)
+            val t =
+              Const (c, if full_types then
+                          case opt_T of
+                            SOME T => map fastype_of ts ---> T
+                          | NONE =>
+                            if num_type_args = 0 then
+                              Sign.const_instance thy (c, [])
+                            else
+                              raise Fail ("no type information for " ^ quote c)
+                        else
+                          (* Extra args from "hAPP" come after any arguments
+                             given directly to the constant. *)
+                          Sign.const_instance thy (c,
+                                    map type_from_node (drop num_term_args us)))
+          in list_comb (t, ts) end
+        | NONE => (* a free or schematic variable *)
+          let
+            val ts = map (aux NONE []) (us @ args)
+            val T = map fastype_of ts ---> HOLogic.typeT
+            val t =
+              case strip_prefix fixed_var_prefix a of
+                SOME b => Free (b, T)
+              | NONE =>
+                case strip_prefix schematic_var_prefix a of
+                  SOME b => make_var (b, T)
+                | NONE =>
+                  (* Variable from the ATP, say "X1" *)
+                  make_var (fix_atp_variable_name a, T)
+          in list_comb (t, ts) end
+  in aux end
 
 (* Type class literal applied to a type. Returns triple of polarity, class,
    type. *)
-fun constraint_of_node pos (StrNode ("c_Not", [u])) =
-    constraint_of_node (not pos) u
-  | constraint_of_node pos u = case u of
-        IntLeaf _ => raise NODE u
+fun type_constraint_from_node pos (StrNode ("c_Not", [u])) =
+    type_constraint_from_node (not pos) u
+  | type_constraint_from_node pos u = case u of
+        IntLeaf _ => raise NODE [u]
       | StrNode (a, us) =>
-            (case (strip_prefix class_prefix a, map type_of_node us) of
+            (case (strip_prefix class_prefix a, map type_from_node us) of
                  (SOME b, [T]) => (pos, b, T)
-               | _ => raise NODE u)
+               | _ => raise NODE [u])
 
 (** Accumulate type constraints in a clause: negative type literals **)
 
 fun add_var (key, z)  = Vartab.map_default (key, []) (cons z)
 
-fun add_constraint ((false, cl, TFree(a,_)), vt) = add_var ((a,~1),cl) vt
-  | add_constraint ((false, cl, TVar(ix,_)), vt) = add_var (ix,cl) vt
-  | add_constraint (_, vt) = vt;
+fun add_type_constraint (false, cl, TFree (a ,_)) = add_var ((a, ~1), cl)
+  | add_type_constraint (false, cl, TVar (ix, _)) = add_var (ix, cl)
+  | add_type_constraint _ = I
 
 fun is_positive_literal (@{const Not} $ _) = false
   | is_positive_literal t = true
@@ -373,10 +394,16 @@
          |> clause_for_literals thy
 
 (*Accumulate sort constraints in vt, with "real" literals in lits.*)
-fun lits_of_nodes thy (vt, lits) [] = (vt, finish_clause thy lits)
-  | lits_of_nodes thy (vt, lits) (u :: us) =
-    lits_of_nodes thy (add_constraint (constraint_of_node true u, vt), lits) us
-    handle NODE _ => lits_of_nodes thy (vt, term_of_node [] thy u :: lits) us
+fun lits_of_nodes thy full_types (vt, lits) us =
+  case us of
+    [] => (vt, finish_clause thy lits)
+  | (u :: us) =>
+    lits_of_nodes thy full_types
+        (add_type_constraint (type_constraint_from_node true u) vt, lits) us
+    handle NODE _ =>
+           lits_of_nodes thy full_types
+                         (vt, term_from_node thy full_types (SOME @{typ bool})
+                                             [] u :: lits) us
 
 (*Update TVars/TFrees with detected sort constraints.*)
 fun repair_sorts vt =
@@ -394,8 +421,9 @@
   in not (Vartab.is_empty vt) ? do_term end
 
 fun unskolemize_term t =
-  fold forall_of (Term.add_consts t []
-                 |> filter (is_skolem_const_name o fst) |> map Const) t
+  Term.add_consts t []
+  |> filter (is_skolem_const_name o fst) |> map Const
+  |> rpair t |-> fold forall_of
 
 val combinator_table =
   [(@{const_name COMBI}, @{thm COMBI_def_raw}),
@@ -415,12 +443,13 @@
 (* Interpret a list of syntax trees as a clause, given by "real" literals and
    sort constraints. "vt" holds the initial sort constraints, from the
    conjecture clauses. *)
-fun clause_of_nodes ctxt vt us =
-  let val (vt, t) = lits_of_nodes (ProofContext.theory_of ctxt) (vt, []) us in
-    t |> repair_sorts vt
-  end
+fun clause_of_nodes ctxt full_types vt us =
+  let
+    val thy = ProofContext.theory_of ctxt
+    val (vt, t) = lits_of_nodes thy full_types (vt, []) us
+  in repair_sorts vt t end
 fun check_formula ctxt =
-  TypeInfer.constrain HOLogic.boolT
+  TypeInfer.constrain @{typ bool}
   #> Syntax.check_term (ProofContext.set_mode ProofContext.mode_schematic ctxt)
 
 (** Global sort constraints on TFrees (from tfree_tcs) are positive unit
@@ -431,7 +460,7 @@
 fun tfree_constraints_of_clauses vt [] = vt
   | tfree_constraints_of_clauses vt ([lit] :: uss) =
     (tfree_constraints_of_clauses (add_tfree_constraint
-                                          (constraint_of_node true lit) vt) uss
+                                    (type_constraint_from_node true lit) vt) uss
      handle NODE _ => (* Not a positive type constraint? Ignore the literal. *)
      tfree_constraints_of_clauses vt uss)
   | tfree_constraints_of_clauses vt (_ :: uss) =
@@ -446,13 +475,13 @@
 fun clauses_in_lines (Definition (_, u, us)) = u :: us
   | clauses_in_lines (Inference (_, us, _)) = us
 
-fun decode_line vt (Definition (num, u, us)) ctxt =
+fun decode_line full_types vt (Definition (num, u, us)) ctxt =
     let
-      val t1 = clause_of_nodes ctxt vt [u]
+      val t1 = clause_of_nodes ctxt full_types vt [u]
       val vars = snd (strip_comb t1)
       val frees = map unvarify_term vars
       val unvarify_args = subst_atomic (vars ~~ frees)
-      val t2 = clause_of_nodes ctxt vt us
+      val t2 = clause_of_nodes ctxt full_types vt us
       val (t1, t2) =
         HOLogic.eq_const HOLogic.typeT $ t1 $ t2
         |> unvarify_args |> uncombine_term |> check_formula ctxt
@@ -461,19 +490,19 @@
       (Definition (num, t1, t2),
        fold Variable.declare_term (maps OldTerm.term_frees [t1, t2]) ctxt)
     end
-  | decode_line vt (Inference (num, us, deps)) ctxt =
+  | decode_line full_types vt (Inference (num, us, deps)) ctxt =
     let
-      val t = us |> clause_of_nodes ctxt vt
+      val t = us |> clause_of_nodes ctxt full_types vt
                  |> unskolemize_term |> uncombine_term |> check_formula ctxt
     in
       (Inference (num, t, deps),
        fold Variable.declare_term (OldTerm.term_frees t) ctxt)
     end
-fun decode_lines ctxt lines =
+fun decode_lines ctxt full_types lines =
   let
     val vt = tfree_constraints_of_clauses Vartab.empty
                                           (map clauses_in_lines lines)
-  in #1 (fold_map (decode_line vt) lines ctxt) end
+  in #1 (fold_map (decode_line full_types vt) lines ctxt) end
 
 fun aint_inference _ (Definition _) = true
   | aint_inference t (Inference (_, t', _)) = not (t aconv t')
@@ -590,7 +619,13 @@
       "To minimize the number of lemmas, try this command: " ^
       Markup.markup Markup.sendback command ^ ".\n"
 
-fun metis_proof_text (minimize_command, atp_proof, thm_names, goal, i) =
+(* Make SML/NJ happy. *)
+type isar_params =
+  name_pool option * bool * bool * int * Proof.context * int list list
+type other_params = minimize_command * string * string vector * thm * int
+
+fun metis_proof_text ((minimize_command, atp_proof, thm_names, goal, i)
+                      : other_params) =
   let
     val lemmas =
       atp_proof |> extract_clause_numbers_in_atp_proof
@@ -644,20 +679,20 @@
           forall_vars t,
           ByMetis (fold (add_fact_from_dep thm_names) deps ([], [])))
 
-fun proof_from_atp_proof pool ctxt shrink_factor atp_proof conjecture_shape
-                         thm_names frees =
+fun proof_from_atp_proof pool ctxt full_types shrink_factor atp_proof
+                         conjecture_shape thm_names params frees =
   let
     val lines =
       atp_proof ^ "$" (* the $ sign acts as a sentinel *)
       |> parse_proof pool
-      |> decode_lines ctxt
+      |> decode_lines ctxt full_types
       |> rpair [] |-> fold_rev (add_line conjecture_shape thm_names)
       |> rpair [] |-> fold_rev add_nontrivial_line
       |> rpair (0, []) |-> fold_rev (add_desired_line ctxt shrink_factor
                                                conjecture_shape thm_names frees)
       |> snd
   in
-    (if null frees then [] else [Fix frees]) @
+    (if null params then [] else [Fix params]) @
     map2 (step_for_line thm_names) (length lines downto 1) lines
   end
 
@@ -951,17 +986,20 @@
         do_indent 0 ^ (if n <> 1 then "next" else "qed") ^ "\n"
   in do_proof end
 
-fun isar_proof_text (pool, debug, shrink_factor, ctxt, conjecture_shape)
-                    (minimize_command, atp_proof, thm_names, goal, i) =
+fun isar_proof_text ((pool, debug, full_types, shrink_factor, ctxt,
+                      conjecture_shape) : isar_params)
+                    ((minimize_command, atp_proof, thm_names, goal, i)
+                     : other_params) =
   let
     val thy = ProofContext.theory_of ctxt
-    val (frees, hyp_ts, concl_t) = strip_subgoal goal i
+    val (params, hyp_ts, concl_t) = strip_subgoal goal i
+    val frees = fold Term.add_frees (concl_t :: hyp_ts) []
     val n = Logic.count_prems (prop_of goal)
     val (one_line_proof, lemma_names) =
       metis_proof_text (minimize_command, atp_proof, thm_names, goal, i)
     fun isar_proof_for () =
-      case proof_from_atp_proof pool ctxt shrink_factor atp_proof
-                                conjecture_shape thm_names frees
+      case proof_from_atp_proof pool ctxt full_types shrink_factor atp_proof
+                                conjecture_shape thm_names params frees
            |> redirect_proof thy conjecture_shape hyp_ts concl_t
            |> kill_duplicate_assumptions_in_proof
            |> then_chain_proof