src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML
changeset 41140 9c68004b8c9d
parent 41138 eb80538166b6
child 41145 a5ee3b8e5a90
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML	Wed Dec 15 11:26:28 2010 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML	Wed Dec 15 11:26:28 2010 +0100
@@ -20,7 +20,6 @@
   val precise_overloaded_args : bool Unsynchronized.ref
   val fact_prefix : string
   val conjecture_prefix : string
-  val is_fully_typed : type_system -> bool
   val types_dangerous_types : type_system -> bool
   val num_atp_type_args : theory -> type_system -> string -> int
   val translate_atp_fact :
@@ -32,7 +31,7 @@
     -> string problem * string Symtab.table * int * (string * 'a) list vector
 end;
 
-structure Sledgehammer_ATP_Translate (*### : SLEDGEHAMMER_ATP_TRANSLATE *) =
+structure Sledgehammer_ATP_Translate : SLEDGEHAMMER_ATP_TRANSLATE =
 struct
 
 open ATP_Problem
@@ -65,10 +64,6 @@
   Const_Args |
   No_Types
 
-fun is_fully_typed (Tags full_types) = full_types
-  | is_fully_typed (Preds full_types) = full_types
-  | is_fully_typed _ = false
-
 fun types_dangerous_types (Tags _) = true
   | types_dangerous_types (Preds _) = true
   | types_dangerous_types _ = false
@@ -84,7 +79,7 @@
 fun needs_type_args thy type_sys s =
   case type_sys of
     Tags full_types => not full_types andalso is_overloaded thy s
-  | Preds full_types => is_overloaded thy s (* FIXME: could be more precise *)
+  | Preds _ => is_overloaded thy s (* FIXME: could be more precise *)
   | Const_Args => is_overloaded thy s
   | No_Types => false
 
@@ -100,9 +95,11 @@
   | mk_ahorn (phi :: phis) psi =
     AConn (AImplies, [fold (mk_aconn AAnd) phis phi, psi])
 
-fun combformula_for_prop thy =
+fun combformula_for_prop thy eq_as_iff =
   let
-    val do_term = combterm_from_term thy
+    fun do_term bs t ts =
+      combterm_from_term thy bs (Envir.eta_contract t)
+      |>> AAtom ||> union (op =) ts
     fun do_quant bs q s T t' =
       let val s = Name.variant (map fst bs) s in
         do_formula ((s, T) :: bs) t'
@@ -123,9 +120,8 @@
       | @{const HOL.disj} $ t1 $ t2 => do_conn bs AOr t1 t2
       | @{const HOL.implies} $ t1 $ t2 => do_conn bs AImplies t1 t2
       | Const (@{const_name HOL.eq}, Type (_, [@{typ bool}, _])) $ t1 $ t2 =>
-        do_conn bs AIff t1 t2
-      | _ => (fn ts => do_term bs (Envir.eta_contract t)
-                       |>> AAtom ||> union (op =) ts)
+        if eq_as_iff then do_conn bs AIff t1 t2 else do_term bs t
+      | _ => do_term bs t
   in do_formula [] end
 
 val presimplify_term = prop_of o Meson.presimplify oo Skip_Proof.make_thm
@@ -224,7 +220,7 @@
   in perhaps (try aux) end
 
 (* making fact and conjecture formulas *)
-fun make_formula ctxt presimp name kind t =
+fun make_formula ctxt eq_as_iff presimp name kind t =
   let
     val thy = ProofContext.theory_of ctxt
     val t = t |> Envir.beta_eta_contract
@@ -237,66 +233,59 @@
               |> perhaps (try (HOLogic.dest_Trueprop))
               |> introduce_combinators_in_term ctxt kind
               |> kind <> Axiom ? freeze_term
-    val (combformula, ctypes_sorts) = combformula_for_prop thy t []
+    val (combformula, ctypes_sorts) = combformula_for_prop thy eq_as_iff t []
   in
     {name = name, combformula = combformula, kind = kind,
      ctypes_sorts = ctypes_sorts}
   end
 
-fun make_fact ctxt presimp ((name, _), th) =
-  case make_formula ctxt presimp name Axiom (prop_of th) of
+fun make_fact ctxt eq_as_iff presimp ((name, _), th) =
+  case make_formula ctxt eq_as_iff presimp name Axiom (prop_of th) of
     {combformula = AAtom (CombConst (("c_True", _), _, _)), ...} => NONE
   | formula => SOME formula
 fun make_conjecture ctxt ts =
   let val last = length ts - 1 in
-    map2 (fn j => make_formula ctxt true (Int.toString j)
+    map2 (fn j => make_formula ctxt true true (Int.toString j)
                                (if j = last then Conjecture else Hypothesis))
          (0 upto last) ts
   end
 
 (** Helper facts **)
 
-fun count_combterm (CombConst ((s, _), _, _)) =
-    Symtab.map_entry s (Integer.add 1)
-  | count_combterm (CombVar _) = I
-  | count_combterm (CombApp (t1, t2)) = fold count_combterm [t1, t2]
-fun count_combformula (AQuant (_, _, phi)) = count_combformula phi
-  | count_combformula (AConn (_, phis)) = fold count_combformula phis
-  | count_combformula (AAtom tm) = count_combterm tm
-fun count_translated_formula ({combformula, ...} : translated_formula) =
-  count_combformula combformula
-
-val optional_helpers =
-  [(["c_COMBI"], @{thms Meson.COMBI_def}),
-   (["c_COMBK"], @{thms Meson.COMBK_def}),
-   (["c_COMBB"], @{thms Meson.COMBB_def}),
-   (["c_COMBC"], @{thms Meson.COMBC_def}),
-   (["c_COMBS"], @{thms Meson.COMBS_def})]
-val optional_fully_typed_helpers =
-  [(["c_True", "c_False", "c_If"], @{thms True_or_False}),
-   (["c_If"], @{thms if_True if_False})]
-val mandatory_helpers = @{thms Metis.fequal_def}
+fun count_term (ATerm ((s, _), tms)) =
+  (if is_atp_variable s then I
+   else Symtab.map_entry s (Integer.add 1))
+  #> fold count_term tms
+fun count_formula (AQuant (_, _, phi)) = count_formula phi
+  | count_formula (AConn (_, phis)) = fold count_formula phis
+  | count_formula (AAtom tm) = count_term tm
 
 val init_counters =
-  [optional_helpers, optional_fully_typed_helpers] |> maps (maps fst)
-  |> sort_distinct string_ord |> map (rpair 0) |> Symtab.make
+  metis_helpers |> map fst |> sort_distinct string_ord |> map (rpair 0)
+  |> Symtab.make
 
-fun get_helper_facts ctxt is_FO type_sys conjectures facts =
+fun get_helper_facts ctxt type_sys formulas =
   let
-    val ct =
-      fold (fold count_translated_formula) [conjectures, facts] init_counters
-    fun is_needed c = the (Symtab.lookup ct c) > 0
-    fun baptize th = ((Thm.get_name_hint th, false), th)
+    val no_dangerous_types = types_dangerous_types type_sys
+    val ct = init_counters |> fold count_formula formulas
+    fun is_used s = the (Symtab.lookup ct s) > 0
+    fun dub c needs_full_types (th, j) =
+      ((c ^ "_" ^ string_of_int j ^ (if needs_full_types then "ft" else ""),
+        false), th)
+    fun make_facts eq_as_iff = map_filter (make_fact ctxt eq_as_iff false)
   in
-    (optional_helpers
-     |> is_fully_typed type_sys ? append optional_fully_typed_helpers
-     |> maps (fn (ss, ths) =>
-                 if exists is_needed ss then map baptize ths else [])) @
-    (if is_FO then [] else map baptize mandatory_helpers)
-    |> map_filter (make_fact ctxt false)
+    metis_helpers
+    |> filter (is_used o fst)
+    |> maps (fn (c, (needs_full_types, ths)) =>
+                if needs_full_types andalso not no_dangerous_types then
+                  []
+                else
+                  ths ~~ (1 upto length ths)
+                  |> map (dub c needs_full_types)
+                  |> make_facts (not needs_full_types))
   end
 
-fun translate_atp_fact ctxt = `(make_fact ctxt true)
+fun translate_atp_fact ctxt = `(make_fact ctxt true true)
 
 fun translate_formulas ctxt type_sys hyp_ts concl_t rich_facts =
   let
@@ -311,20 +300,18 @@
        boost an ATP's performance (for some reason). *)
     val hyp_ts = hyp_ts |> filter_out (member (op aconv) fact_ts)
     val goal_t = Logic.list_implies (hyp_ts, concl_t)
-    val is_FO = Meson.is_fol_term thy goal_t
     val subs = tfree_classes_of_terms [goal_t]
     val supers = tvar_classes_of_terms fact_ts
     val tycons = type_consts_of_terms thy (goal_t :: fact_ts)
     (* TFrees in the conjecture; TVars in the facts *)
     val conjectures = make_conjecture ctxt (hyp_ts @ [concl_t])
-    val helper_facts = get_helper_facts ctxt is_FO type_sys conjectures facts
     val (supers', arity_clauses) =
       if type_sys = No_Types then ([], [])
       else make_arity_clauses thy tycons supers
     val class_rel_clauses = make_class_rel_clauses thy subs supers'
   in
     (fact_names |> map single |> Vector.fromList,
-     (conjectures, facts, helper_facts, class_rel_clauses, arity_clauses))
+     (conjectures, facts, class_rel_clauses, arity_clauses))
   end
 
 fun tag_with_type ty t = ATerm (`I type_tag_name, [ty, t])
@@ -352,7 +339,7 @@
   | is_dtyp_dangerous _ (Datatype_Aux.DtRec _) = false
 and is_type_dangerous ctxt (Type (s, Ts)) =
     is_type_constr_dangerous ctxt s andalso forall (is_type_dangerous ctxt) Ts
-  | is_type_dangerous ctxt _ = false
+  | is_type_dangerous _ _ = false
 and is_type_constr_dangerous ctxt s =
   let val thy = ProofContext.theory_of ctxt in
     case Datatype_Data.get_info thy s of
@@ -376,6 +363,15 @@
     full_types orelse is_combtyp_dangerous ctxt ty
   | should_tag_with_type _ _ _ = false
 
+val fname_table =
+  [("c_False", (0, ("c_fFalse", @{const_name Metis.fFalse}))),
+   ("c_True", (0, ("c_fTrue", @{const_name Metis.fTrue}))),
+   ("c_Not", (1, ("c_fNot", @{const_name Metis.fNot}))),
+   ("c_conj", (2, ("c_fconj", @{const_name Metis.fconj}))),
+   ("c_disj", (2, ("c_fdisj", @{const_name Metis.fdisj}))),
+   ("c_implies", (2, ("c_fimplies", @{const_name Metis.fimplies}))),
+   ("equal", (2, ("c_fequal", @{const_name Metis.fequal})))]
+
 fun fo_term_for_combterm ctxt type_sys =
   let
     val thy = ProofContext.theory_of ctxt
@@ -385,27 +381,27 @@
         val (x, ty_args) =
           case head of
             CombConst (name as (s, s'), _, ty_args) =>
-            (case strip_prefix_and_unascii const_prefix s of
-               NONE =>
-               if s = "equal" then
-                 if top_level andalso length args = 2 then (name, [])
-                 else (("c_fequal", @{const_name Metis.fequal}), ty_args)
-               else
-                 (name, ty_args)
-             | SOME s'' =>
-               let
-                 val s'' = invert_const s''
-                 val ty_args =
-                   if needs_type_args thy type_sys s'' then ty_args else []
-                in
-                  if top_level then
-                    case s of
-                      "c_False" => (("$false", s'), [])
-                    | "c_True" => (("$true", s'), [])
-                    | _ => (name, ty_args)
-                  else
-                    (name, ty_args)
-                end)
+            (case AList.lookup (op =) fname_table s of
+               SOME (n, fname) =>
+               if top_level andalso length args = n then (name, [])
+               else (fname, ty_args)
+             | NONE =>
+               case strip_prefix_and_unascii const_prefix s of
+                 NONE => (name, ty_args)
+               | SOME s'' =>
+                 let
+                   val s'' = invert_const s''
+                   val ty_args =
+                     if needs_type_args thy type_sys s'' then ty_args else []
+                  in
+                    if top_level then
+                      case s of
+                        "c_False" => (("$false", s'), [])
+                      | "c_True" => (("$true", s'), [])
+                      | _ => (name, ty_args)
+                    else
+                      (name, ty_args)
+                  end)
           | CombVar (name, _) => (name, [])
           | CombApp _ => raise Fail "impossible \"CombApp\""
         val t =
@@ -498,9 +494,14 @@
 fun consider_problem_line (Fof (_, _, phi)) = consider_formula phi
 fun consider_problem problem = fold (fold consider_problem_line o snd) problem
 
+(* needed for helper facts if the problem otherwise does not involve equality *)
+val equal_entry = ("equal", {min_arity = 2, max_arity = 2, sub_level = false})
+
 fun const_table_for_problem explicit_apply problem =
-  if explicit_apply then NONE
-  else SOME (Symtab.empty |> consider_problem problem)
+  if explicit_apply then
+    NONE
+  else
+    SOME (Symtab.empty |> Symtab.update equal_entry |> consider_problem problem)
 
 fun min_arity_of thy type_sys NONE s =
     (if s = "equal" orelse s = type_tag_name orelse
@@ -561,14 +562,14 @@
       not sub_level andalso min_arity = max_arity
     | NONE => false
 
-fun repair_predicates_in_term const_tab (t as ATerm ((s, _), ts)) =
+fun repair_predicates_in_term pred_const_tab (t as ATerm ((s, _), ts)) =
   if s = type_tag_name then
     case ts of
       [_, t' as ATerm ((s', _), _)] =>
-      if is_predicate const_tab s' then t' else boolify t
+      if is_predicate pred_const_tab s' then t' else boolify t
     | _ => raise Fail "malformed type tag"
   else
-    t |> not (is_predicate const_tab s) ? boolify
+    t |> not (is_predicate pred_const_tab s) ? boolify
 
 fun close_universally phi =
   let
@@ -586,33 +587,28 @@
 
 fun repair_formula thy explicit_forall type_sys const_tab =
   let
+    val pred_const_tab = case type_sys of Tags _ => NONE | _ => const_tab
     fun aux (AQuant (q, xs, phi)) = AQuant (q, xs, aux phi)
       | aux (AConn (c, phis)) = AConn (c, map aux phis)
       | aux (AAtom tm) =
         AAtom (tm |> repair_applications_in_term thy type_sys const_tab
-                  |> repair_predicates_in_term const_tab)
+                  |> repair_predicates_in_term pred_const_tab)
   in aux #> explicit_forall ? close_universally end
 
 fun repair_problem_line thy explicit_forall type_sys const_tab
                         (Fof (ident, kind, phi)) =
   Fof (ident, kind, repair_formula thy explicit_forall type_sys const_tab phi)
-fun repair_problem_with_const_table thy =
-  map o apsnd o map ooo repair_problem_line thy
+fun repair_problem thy = map o apsnd o map ooo repair_problem_line thy
 
-fun repair_problem thy explicit_forall type_sys explicit_apply problem =
-  repair_problem_with_const_table thy explicit_forall type_sys
-      (const_table_for_problem explicit_apply problem) problem
+fun dest_Fof (Fof z) = z
 
 fun prepare_atp_problem ctxt readable_names explicit_forall type_sys
                         explicit_apply hyp_ts concl_t facts =
   let
     val thy = ProofContext.theory_of ctxt
-    val (fact_names, (conjectures, facts, helper_facts, class_rel_clauses,
-                      arity_clauses)) =
+    val (fact_names, (conjectures, facts, class_rel_clauses, arity_clauses)) =
       translate_formulas ctxt type_sys hyp_ts concl_t facts
     val fact_lines = map (problem_line_for_fact ctxt fact_prefix type_sys) facts
-    val helper_lines =
-      map (problem_line_for_fact ctxt helper_prefix type_sys) helper_facts
     val conjecture_lines =
       map (problem_line_for_conjecture ctxt type_sys) conjectures
     val tfree_lines = problem_lines_for_free_types type_sys conjectures
@@ -625,11 +621,21 @@
       [("Relevant facts", fact_lines),
        ("Class relationships", class_rel_lines),
        ("Arity declarations", arity_lines),
-       ("Helper facts", helper_lines),
+       ("Helper facts", []),
        ("Conjectures", conjecture_lines),
        ("Type variables", tfree_lines)]
-      |> repair_problem thy explicit_forall type_sys explicit_apply
-    val (problem, pool) = nice_atp_problem readable_names problem
+    val const_tab = const_table_for_problem explicit_apply problem
+    val problem =
+      problem |> repair_problem thy explicit_forall type_sys const_tab
+    val helper_facts =
+      get_helper_facts ctxt type_sys (maps (map (#3 o dest_Fof) o snd) problem)
+    val helper_lines =
+      helper_facts
+      |> map (problem_line_for_fact ctxt helper_prefix type_sys
+              #> repair_problem_line thy explicit_forall type_sys const_tab)
+    val (problem, pool) =
+      problem |> AList.update (op =) ("Helper facts", helper_lines)
+              |> nice_atp_problem readable_names
     val conjecture_offset =
       length fact_lines + length class_rel_lines + length arity_lines
       + length helper_lines