src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML
changeset 42561 23ddc4e3d19c
parent 42560 7bb3796a4975
child 42562 f1d903f789b1
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML	Sun May 01 18:37:24 2011 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML	Sun May 01 18:37:24 2011 +0200
@@ -206,6 +206,7 @@
                                                 quote s)) parse_mangled_type))
     |> fst
 
+val unmangled_const_name = space_explode mangled_type_sep #> hd
 fun unmangled_const s =
   let val ss = space_explode mangled_type_sep s in
     (hd ss, map unmangled_type (tl ss))
@@ -340,12 +341,12 @@
      ctypes_sorts = ctypes_sorts}
   end
 
-fun make_fact ctxt keep_trivial eq_as_iff presimp ((name, _), th) =
-  case (keep_trivial,
-        make_formula ctxt eq_as_iff presimp name Axiom (prop_of th)) of
+fun make_fact ctxt keep_trivial eq_as_iff presimp ((name, _), t) =
+  case (keep_trivial, make_formula ctxt eq_as_iff presimp name Axiom t) of
     (false, {combformula = AAtom (CombConst (("c_True", _), _, _)), ...}) =>
     NONE
   | (_, formula) => SOME formula
+
 fun make_conjecture ctxt ts =
   let val last = length ts - 1 in
     map2 (fn j => make_formula ctxt true true (string_of_int j)
@@ -363,51 +364,63 @@
   | formula_fold f (AConn (_, phis)) = fold (formula_fold f) phis
   | formula_fold f (AAtom tm) = f tm
 
-fun count_term (ATerm ((s, _), tms)) =
-  (if is_atp_variable s then I else Symtab.map_entry s (Integer.add 1))
-  #> fold count_term tms
-fun count_formula x = formula_fold count_term x
-
-val init_counters =
-  metis_helpers |> map fst |> sort_distinct string_ord |> map (rpair 0)
-  |> Symtab.make
-
-(* ### FIXME: do this on repaired combterms *)
-fun get_helper_facts ctxt type_sys formulas =
+fun ti_ti_helper_fact () =
   let
-    val no_dangerous_types = type_system_types_dangerous_types type_sys
-    val ct = init_counters |> fold count_formula formulas
-    fun is_used s = the (Symtab.lookup ct s) > 0
-    fun dub c needs_full_types (th, j) =
-      ((c ^ "_" ^ string_of_int j ^ (if needs_full_types then "ft" else ""),
-        false), th)
-    fun make_facts eq_as_iff = map_filter (make_fact ctxt false eq_as_iff false)
+    fun var s = ATerm (`I s, [])
+    fun tag tm = ATerm (`I type_tag_name, [var "X", tm])
   in
-    (metis_helpers
-     |> filter (is_used o fst)
-     |> maps (fn (c, (needs_full_types, ths)) =>
-                 if needs_full_types andalso not no_dangerous_types then
-                   []
-                 else
-                   ths ~~ (1 upto length ths)
-                   |> map (dub c needs_full_types)
-                   |> make_facts (not needs_full_types)),
-     if type_sys = Tags false then
-       let
-         fun var s = ATerm (`I s, [])
-         fun tag tm = ATerm (`I type_tag_name, [var "X", tm])
-       in
-         [Formula (Fof, helper_prefix ^ ascii_of "ti_ti", Axiom,
-                   AAtom (ATerm (`I "equal",
-                                 [tag (tag (var "Y")), tag (var "Y")]))
-                   |> close_formula_universally, NONE, NONE)]
-       end
-     else
-       [])
+    Formula (Fof, helper_prefix ^ ascii_of "ti_ti", Axiom,
+             AAtom (ATerm (`I "equal", [tag (tag (var "Y")), tag (var "Y")]))
+             |> close_formula_universally, NONE, NONE)
   end
 
+(* FIXME #### : abolish combtyp altogether *)
+fun typ_from_combtyp (CombType ((s, _), tys)) =
+    Type (s |> strip_prefix_and_unascii type_const_prefix |> the
+            |> invert_const,
+          map typ_from_combtyp tys)
+  | typ_from_combtyp (CombTFree (s, _)) =
+    TFree (s |> strip_prefix_and_unascii tfree_prefix |> the, HOLogic.typeS)
+  | typ_from_combtyp (CombTVar (s, _)) =
+    TVar ((s |> strip_prefix_and_unascii tvar_prefix |> the, 0), HOLogic.typeS)
+
+fun helper_facts_for_typed_const ctxt type_sys s (_, _, ty) =
+  case strip_prefix_and_unascii const_prefix s of
+    SOME s'' =>
+    let
+      val thy = Proof_Context.theory_of ctxt
+      val unmangled_s = s'' |> unmangled_const_name
+      (* ### FIXME avoid duplicate names *)
+      fun dub_and_inst c needs_full_types (th, j) =
+        ((c ^ "_" ^ string_of_int j ^ (if needs_full_types then "ft" else ""),
+          false),
+         th |> prop_of
+            |> specialize_type thy (invert_const unmangled_s,
+                                    typ_from_combtyp ty))
+      fun make_facts eq_as_iff =
+        map_filter (make_fact ctxt false eq_as_iff false)
+    in
+      metis_helpers
+      |> maps (fn (metis_s, (needs_full_types, ths)) =>
+                  if metis_s <> unmangled_s orelse
+                     (needs_full_types andalso
+                      not (type_system_types_dangerous_types type_sys)) then
+                    []
+                  else
+                    ths ~~ (1 upto length ths)
+                    |> map (dub_and_inst s needs_full_types)
+                    |> make_facts (not needs_full_types))
+    end
+  | NONE => []
+fun helper_facts_for_const ctxt type_sys (s, xs) =
+  maps (helper_facts_for_typed_const ctxt type_sys s) xs
+fun helper_facts ctxt type_sys typed_const_tab =
+  (Symtab.fold_rev (append o helper_facts_for_const ctxt type_sys)
+                   typed_const_tab [],
+   if type_sys = Tags false then [ti_ti_helper_fact ()] else [])
+
 fun translate_atp_fact ctxt keep_trivial =
-  `(make_fact ctxt keep_trivial true true)
+  `(make_fact ctxt keep_trivial true true o apsnd prop_of)
 
 fun translate_formulas ctxt type_sys hyp_ts concl_t rich_facts =
   let
@@ -435,6 +448,44 @@
     (fact_names |> map single, (conjs, facts, class_rel_clauses, arity_clauses))
   end
 
+val proxy_table =
+  [("c_False", ("c_fFalse", @{const_name Metis.fFalse})),
+   ("c_True", ("c_fTrue", @{const_name Metis.fTrue})),
+   ("c_Not", ("c_fNot", @{const_name Metis.fNot})),
+   ("c_conj", ("c_fconj", @{const_name Metis.fconj})),
+   ("c_disj", ("c_fdisj", @{const_name Metis.fdisj})),
+   ("c_implies", ("c_fimplies", @{const_name Metis.fimplies})),
+   ("equal", ("c_fequal", @{const_name Metis.fequal}))]
+
+fun repair_combterm_consts type_sys =
+  let
+    fun aux top_level (CombApp (tm1, tm2)) =
+        CombApp (aux top_level tm1, aux false tm2)
+      | aux top_level (CombConst (name as (s, _), ty, ty_args)) =
+        (case strip_prefix_and_unascii const_prefix s of
+           NONE => (name, ty_args)
+         | SOME s'' =>
+           let val s'' = invert_const s'' in
+             case type_arg_policy type_sys s'' of
+               No_Type_Args => (name, [])
+             | Mangled_Types => (mangled_const_name ty_args name, [])
+             | Explicit_Type_Args => (name, ty_args)
+           end)
+        |> (fn (name as (s, s'), ty_args) =>
+               case AList.lookup (op =) proxy_table s of
+                 SOME proxy_name =>
+                 if top_level then
+                   (case s of
+                      "c_False" => ("$false", s')
+                    | "c_True" => ("$true", s')
+                    | _ => name, [])
+                  else
+                    (proxy_name, ty_args)
+                | NONE => (name, ty_args))
+        |> (fn (name, ty_args) => CombConst (name, ty, ty_args))
+      | aux _ tm = tm
+  in aux true end
+
 fun tag_with_type ty t = ATerm (`I type_tag_name, [ty, t])
 
 fun fo_term_for_combtyp (CombTVar name) = ATerm (name, [])
@@ -484,44 +535,6 @@
     full_types orelse is_combtyp_dangerous ctxt ty
   | should_tag_with_type _ _ _ = false
 
-val proxy_table =
-  [("c_False", ("c_fFalse", @{const_name Metis.fFalse})),
-   ("c_True", ("c_fTrue", @{const_name Metis.fTrue})),
-   ("c_Not", ("c_fNot", @{const_name Metis.fNot})),
-   ("c_conj", ("c_fconj", @{const_name Metis.fconj})),
-   ("c_disj", ("c_fdisj", @{const_name Metis.fdisj})),
-   ("c_implies", ("c_fimplies", @{const_name Metis.fimplies})),
-   ("equal", ("c_fequal", @{const_name Metis.fequal}))]
-
-fun repair_combterm_consts type_sys =
-  let
-    fun aux top_level (CombApp (tm1, tm2)) =
-        CombApp (aux top_level tm1, aux false tm2)
-      | aux top_level (CombConst (name as (s, _), ty, ty_args)) =
-        (case strip_prefix_and_unascii const_prefix s of
-           NONE => (name, ty_args)
-         | SOME s'' =>
-           let val s'' = invert_const s'' in
-             case type_arg_policy type_sys s'' of
-               No_Type_Args => (name, [])
-             | Mangled_Types => (mangled_const_name ty_args name, [])
-             | Explicit_Type_Args => (name, ty_args)
-           end)
-        |> (fn (name as (s, s'), ty_args) =>
-               case AList.lookup (op =) proxy_table s of
-                 SOME proxy_name =>
-                 if top_level then
-                   (case s of
-                      "c_False" => ("$false", s')
-                    | "c_True" => ("$true", s')
-                    | _ => name, [])
-                  else
-                    (proxy_name, ty_args)
-                | NONE => (name, ty_args))
-        |> (fn (name, ty_args) => CombConst (name, ty, ty_args))
-      | aux _ tm = tm
-  in aux true end
-
 fun pred_combtyp ty =
   case combtyp_from_typ @{typ "unit => bool"} of
     CombType (name, [_, bool_ty]) => CombType (name, [ty, bool_ty])
@@ -688,7 +701,7 @@
   | NONE =>
     case strip_prefix_and_unascii const_prefix s of
       SOME s =>
-      let val s = s |> unmangled_const |> fst |> invert_const in
+      let val s = s |> unmangled_const_name |> invert_const in
         if s = boolify_base then 1
         else if s = explicit_app_base then 2
         else if s = type_pred_base then 1
@@ -769,7 +782,7 @@
   fact_lift (formula_fold (consider_combterm_consts type_sys sym_tab))
 
 (* FIXME: needed? *)
-fun const_table_for_facts type_sys sym_tab facts =
+fun typed_const_table_for_facts type_sys sym_tab facts =
   Symtab.empty |> member (op =) [Many_Typed, Mangled true, Args true] type_sys
                   ? fold (consider_fact_consts type_sys sym_tab) facts
 
@@ -790,6 +803,7 @@
       in
         Decl (sym_decl_prefix ^ ascii_of s, (s, s'),
               arg_tys,
+              (* ### FIXME: put that in typed_const_tab *)
               if is_pred_sym sym_tab s then `I tff_bool_type else res_ty)
       end
     else
@@ -812,9 +826,13 @@
                  NONE, NONE)
       end
   end
-fun problem_lines_for_const ctxt type_sys sym_tab (s, xs) =
+fun problem_lines_for_sym_decl ctxt type_sys sym_tab (s, xs) =
   map2 (problem_line_for_typed_const ctxt type_sys sym_tab s)
        (0 upto length xs - 1) xs
+fun problem_lines_for_sym_decls ctxt type_sys repaired_sym_tab typed_const_tab =
+  Symtab.fold_rev
+      (append o problem_lines_for_sym_decl ctxt type_sys repaired_sym_tab)
+      typed_const_tab []
 
 fun add_tff_types_in_formula (AQuant (_, xs, phi)) =
     union (op =) (map_filter snd xs) #> add_tff_types_in_formula phi
@@ -853,46 +871,36 @@
     val (fact_names, (conjs, facts, class_rel_clauses, arity_clauses)) =
       translate_formulas ctxt type_sys hyp_ts concl_t facts
     val sym_tab = sym_table_for_facts explicit_apply (conjs @ facts)
-    val conjs = map (repair_fact type_sys sym_tab) conjs
-    val facts = map (repair_fact type_sys sym_tab) facts
+    val conjs = conjs |> map (repair_fact type_sys sym_tab)
+    val facts = facts |> map (repair_fact type_sys sym_tab)
+    val repaired_sym_tab = sym_table_for_facts false (conjs @ facts)
+    val typed_const_tab =
+      typed_const_table_for_facts type_sys repaired_sym_tab (conjs @ facts)
+    val sym_decl_lines =
+      problem_lines_for_sym_decls ctxt type_sys repaired_sym_tab typed_const_tab
+    val (helpers, raw_helper_lines) = helper_facts ctxt type_sys typed_const_tab
+    val helpers = helpers |> map (repair_fact type_sys sym_tab)
     (* Reordering these might confuse the proof reconstruction code or the SPASS
        Flotter hack. *)
     val problem =
-      [(type_declsN, []),
-       (sym_declsN, []),
+      [(sym_declsN, sym_decl_lines),
        (factsN, map (formula_line_for_fact ctxt fact_prefix type_sys)
                     (0 upto length facts - 1 ~~ facts)),
        (class_relsN, map formula_line_for_class_rel_clause class_rel_clauses),
        (aritiesN, map formula_line_for_arity_clause arity_clauses),
-       (helpersN, []),
+       (helpersN, map (formula_line_for_fact ctxt helper_prefix type_sys)
+                      (0 upto length helpers - 1 ~~ helpers) @
+                  raw_helper_lines),
        (conjsN, map (formula_line_for_conjecture ctxt type_sys) conjs),
        (free_typesN, formula_lines_for_free_types type_sys (facts @ conjs))]
-    val helper_facts =
-      problem |> maps (map_filter (fn Formula (_, _, _, phi, _, _) => SOME phi
-                                    | _ => NONE) o snd)
-              |> get_helper_facts ctxt type_sys
-              |>> map (repair_fact type_sys sym_tab)
-    val sym_tab = sym_table_for_facts false (conjs @ facts)
-    val const_tab = const_table_for_facts type_sys sym_tab (conjs @ facts)
-    val sym_decl_lines =
-      Symtab.fold_rev (append o problem_lines_for_const ctxt type_sys sym_tab)
-                      const_tab []
-    val helper_lines =
-      helper_facts
-      |>> map (pair 0 #> formula_line_for_fact ctxt helper_prefix type_sys)
-      |> op @
     val problem =
-      problem |> fold (AList.update (op =))
-                      [(sym_declsN, sym_decl_lines),
-                       (helpersN, helper_lines)]
-    val type_decl_lines =
-      if type_sys = Many_Typed then
-        problem |> tff_types_in_problem |> map decl_line_for_tff_type
-      else
-        []
-    val (problem, pool) =
-      problem |> AList.update (op =) (type_declsN, type_decl_lines)
-              |> nice_atp_problem readable_names
+      problem
+      |> (if type_sys = Many_Typed then
+            cons (type_declsN,
+                  map decl_line_for_tff_type (tff_types_in_problem problem))
+          else
+            I)
+    val (problem, pool) = problem |> nice_atp_problem readable_names
   in
     (problem,
      case pool of SOME the_pool => snd the_pool | NONE => Symtab.empty,