improve helper type instantiation code
authorblanchet
Sun, 01 May 2011 18:37:24 +0200
changeset 42563 e70ffe3846d0
parent 42562 f1d903f789b1
child 42564 d40bdf941a9a
improve helper type instantiation code
src/HOL/Tools/Metis/metis_translate.ML
src/HOL/Tools/Sledgehammer/sledgehammer_atp_reconstruct.ML
src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML
--- a/src/HOL/Tools/Metis/metis_translate.ML	Sun May 01 18:37:24 2011 +0200
+++ b/src/HOL/Tools/Metis/metis_translate.ML	Sun May 01 18:37:24 2011 +0200
@@ -43,6 +43,8 @@
   val type_const_prefix: string
   val class_prefix: string
   val new_skolem_const_prefix : string
+  val metis_proxies : (string * (string * string)) list
+  val safe_invert_const: string -> string
   val invert_const: string -> string
   val ascii_of: string -> string
   val unascii_of: string -> string
@@ -101,36 +103,42 @@
 
 fun union_all xss = fold (union (op =)) xss []
 
+val metis_proxies =
+  [("c_False", ("fFalse", @{const_name Metis.fFalse})),
+   ("c_True", ("fTrue", @{const_name Metis.fTrue})),
+   ("c_Not", ("fNot", @{const_name Metis.fNot})),
+   ("c_conj", ("fconj", @{const_name Metis.fconj})),
+   ("c_disj", ("fdisj", @{const_name Metis.fdisj})),
+   ("c_implies", ("fimplies", @{const_name Metis.fimplies})),
+   ("equal", ("fequal", @{const_name Metis.fequal}))]
+
 (* Readable names for the more common symbolic functions. Do not mess with the
    table unless you know what you are doing. *)
 val const_trans_table =
-  Symtab.make [(@{type_name Product_Type.prod}, "prod"),
-               (@{type_name Sum_Type.sum}, "sum"),
-               (@{const_name False}, "False"),
-               (@{const_name True}, "True"),
-               (@{const_name Not}, "Not"),
-               (@{const_name conj}, "conj"),
-               (@{const_name disj}, "disj"),
-               (@{const_name implies}, "implies"),
-               (@{const_name HOL.eq}, "equal"),
-               (@{const_name If}, "If"),
-               (@{const_name Set.member}, "member"),
-               (@{const_name Meson.COMBI}, "COMBI"),
-               (@{const_name Meson.COMBK}, "COMBK"),
-               (@{const_name Meson.COMBB}, "COMBB"),
-               (@{const_name Meson.COMBC}, "COMBC"),
-               (@{const_name Meson.COMBS}, "COMBS"),
-               (@{const_name Metis.fFalse}, "fFalse"),
-               (@{const_name Metis.fTrue}, "fTrue"),
-               (@{const_name Metis.fNot}, "fNot"),
-               (@{const_name Metis.fconj}, "fconj"),
-               (@{const_name Metis.fdisj}, "fdisj"),
-               (@{const_name Metis.fimplies}, "fimplies"),
-               (@{const_name Metis.fequal}, "fequal")]
+  [(@{type_name Product_Type.prod}, "prod"),
+   (@{type_name Sum_Type.sum}, "sum"),
+   (@{const_name False}, "False"),
+   (@{const_name True}, "True"),
+   (@{const_name Not}, "Not"),
+   (@{const_name conj}, "conj"),
+   (@{const_name disj}, "disj"),
+   (@{const_name implies}, "implies"),
+   (@{const_name HOL.eq}, "equal"),
+   (@{const_name If}, "If"),
+   (@{const_name Set.member}, "member"),
+   (@{const_name Meson.COMBI}, "COMBI"),
+   (@{const_name Meson.COMBK}, "COMBK"),
+   (@{const_name Meson.COMBB}, "COMBB"),
+   (@{const_name Meson.COMBC}, "COMBC"),
+   (@{const_name Meson.COMBS}, "COMBS")] @
+   (metis_proxies |> map (swap o snd))
+  |> Symtab.make
 
 (* Invert the table of translations between Isabelle and ATPs. *)
+val const_trans_table_safe_inv =
+  const_trans_table |> Symtab.dest |> map swap |> Symtab.make
 val const_trans_table_inv =
-  const_trans_table |> Symtab.dest |> map swap |> Symtab.make
+  const_trans_table_safe_inv
   |> fold Symtab.update [("fFalse", @{const_name False}),
                          ("fTrue", @{const_name True}),
                          ("fNot", @{const_name Not}),
@@ -139,6 +147,7 @@
                          ("fimplies", @{const_name implies}),
                          ("fequal", @{const_name HOL.eq})]
 
+val safe_invert_const = perhaps (Symtab.lookup const_trans_table_safe_inv)
 val invert_const = perhaps (Symtab.lookup const_trans_table_inv)
 
 (*Escaping of special characters.
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_atp_reconstruct.ML	Sun May 01 18:37:24 2011 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_atp_reconstruct.ML	Sun May 01 18:37:24 2011 +0200
@@ -283,8 +283,8 @@
 
 (* Type variables are given the basic sort "HOL.type". Some will later be
    constrained by information from type literals, or by type inference. *)
-fun type_from_fo_term tfrees (u as ATerm (a, us)) =
-  let val Ts = map (type_from_fo_term tfrees) us in
+fun typ_from_fo_term tfrees (u as ATerm (a, us)) =
+  let val Ts = map (typ_from_fo_term tfrees) us in
     case strip_prefix_and_unascii type_const_prefix a of
       SOME b => Type (invert_const b, Ts)
     | NONE =>
@@ -307,7 +307,7 @@
    type. *)
 fun type_constraint_from_term pos tfrees (u as ATerm (a, us)) =
   case (strip_prefix_and_unascii class_prefix a,
-        map (type_from_fo_term tfrees) us) of
+        map (typ_from_fo_term tfrees) us) of
     (SOME b, [T]) => (pos, b, T)
   | _ => raise FO_TERM [u]
 
@@ -344,7 +344,7 @@
         if a = type_tag_name then
           case us of
             [typ_u, term_u] =>
-            aux (SOME (type_from_fo_term tfrees typ_u)) extra_us term_u
+            aux (SOME (typ_from_fo_term tfrees typ_u)) extra_us term_u
           | _ => raise FO_TERM us
         else if String.isPrefix tff_type_prefix a then
           @{const True} (* ignore TFF type information *)
@@ -380,7 +380,7 @@
                   | NONE =>
                     if num_type_args thy b = length type_us then
                       Sign.const_instance thy
-                          (b, map (type_from_fo_term tfrees) type_us)
+                          (b, map (typ_from_fo_term tfrees) type_us)
                     else
                       HOLogic.typeT
               in list_comb (Const (b, T), term_ts @ extra_ts) end
@@ -428,7 +428,7 @@
   | uncombine_term t = t
 
 (* Update schematic type variables with detected sort constraints. It's not
-   totally clear when this code is necessary. *)
+   totally clear whether this code is necessary. *)
 fun repair_tvar_sorts (t, tvar_tab) =
   let
     fun do_type (Type (a, Ts)) = Type (a, map do_type Ts)
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML	Sun May 01 18:37:24 2011 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML	Sun May 01 18:37:24 2011 +0200
@@ -111,14 +111,13 @@
 
 datatype type_arg_policy = No_Type_Args | Mangled_Types | Explicit_Type_Args
 
+fun general_type_arg_policy Many_Typed = Mangled_Types
+  | general_type_arg_policy (Mangled _) = Mangled_Types
+  | general_type_arg_policy _ = Explicit_Type_Args
+
 fun type_arg_policy type_sys s =
-  if dont_need_type_args type_sys s then
-    No_Type_Args
-  else
-    case type_sys of
-      Many_Typed => Mangled_Types
-    | Mangled _ => Mangled_Types
-    | _ => Explicit_Type_Args
+  if dont_need_type_args type_sys s then No_Type_Args
+  else general_type_arg_policy type_sys
 
 fun num_atp_type_args thy type_sys s =
   if type_arg_policy type_sys s = Explicit_Type_Args then num_type_args thy s
@@ -330,7 +329,7 @@
     val t = t |> Envir.beta_eta_contract
               |> transform_elim_term
               |> Object_Logic.atomize_term thy
-    val need_trueprop = (fastype_of t = HOLogic.boolT)
+    val need_trueprop = (fastype_of t = @{typ bool})
     val t = t |> need_trueprop ? HOLogic.mk_Trueprop
               |> extensionalize_term
               |> presimp ? presimplify_term thy
@@ -367,6 +366,9 @@
   | formula_fold f (AConn (_, phis)) = fold (formula_fold f) phis
   | formula_fold f (AAtom tm) = f tm
 
+type sym_table_info =
+  {pred_sym : bool, min_ary : int, max_ary : int, typ : typ option}
+
 fun ti_ti_helper_fact () =
   let
     fun var s = ATerm (`I s, [])
@@ -377,18 +379,21 @@
              |> close_formula_universally, NONE, NONE)
   end
 
-fun helper_facts_for_typed_const ctxt type_sys s (_, _, T) =
+fun helper_facts_for_sym ctxt type_sys (s, {typ, ...} : sym_table_info) =
   case strip_prefix_and_unascii const_prefix s of
-    SOME s'' =>
+    SOME mangled_s =>
     let
       val thy = Proof_Context.theory_of ctxt
-      val unmangled_s = s'' |> unmangled_const_name
-      (* ### FIXME avoid duplicate names *)
+      val unmangled_s = mangled_s |> unmangled_const_name
       fun dub_and_inst c needs_full_types (th, j) =
         ((c ^ "_" ^ string_of_int j ^ (if needs_full_types then "ft" else ""),
           false),
          th |> prop_of
-            |> specialize_type thy (invert_const unmangled_s, T))
+            |> general_type_arg_policy type_sys = Mangled_Types
+               ? (case typ of
+                    SOME T =>
+                    specialize_type thy (safe_invert_const unmangled_s, T)
+                  | NONE => I))
       fun make_facts eq_as_iff =
         map_filter (make_fact ctxt false eq_as_iff false)
     in
@@ -400,16 +405,12 @@
                     []
                   else
                     ths ~~ (1 upto length ths)
-                    |> map (dub_and_inst s needs_full_types)
+                    |> map (dub_and_inst mangled_s needs_full_types)
                     |> make_facts (not needs_full_types))
     end
   | NONE => []
-fun helper_facts_for_const ctxt type_sys (s, xs) =
-  maps (helper_facts_for_typed_const ctxt type_sys s) xs
-fun helper_facts ctxt type_sys typed_const_tab =
-  (Symtab.fold_rev (append o helper_facts_for_const ctxt type_sys)
-                   typed_const_tab [],
-   if type_sys = Tags false then [ti_ti_helper_fact ()] else [])
+fun helper_facts_for_sym_table ctxt type_sys sym_tab =
+  Symtab.fold_rev (append o helper_facts_for_sym ctxt type_sys) sym_tab []
 
 fun translate_atp_fact ctxt keep_trivial =
   `(make_fact ctxt keep_trivial true true o apsnd prop_of)
@@ -440,15 +441,6 @@
     (fact_names |> map single, (conjs, facts, class_rel_clauses, arity_clauses))
   end
 
-val proxy_table =
-  [("c_False", ("c_fFalse", @{const_name Metis.fFalse})),
-   ("c_True", ("c_fTrue", @{const_name Metis.fTrue})),
-   ("c_Not", ("c_fNot", @{const_name Metis.fNot})),
-   ("c_conj", ("c_fconj", @{const_name Metis.fconj})),
-   ("c_disj", ("c_fdisj", @{const_name Metis.fdisj})),
-   ("c_implies", ("c_fimplies", @{const_name Metis.fimplies})),
-   ("equal", ("c_fequal", @{const_name Metis.fequal}))]
-
 fun repair_combterm_consts type_sys =
   let
     fun aux top_level (CombApp (tm1, tm2)) =
@@ -464,15 +456,15 @@
              | Explicit_Type_Args => (name, ty_args)
            end)
         |> (fn (name as (s, s'), ty_args) =>
-               case AList.lookup (op =) proxy_table s of
-                 SOME proxy_name =>
+               case AList.lookup (op =) metis_proxies s of
+                 SOME proxy_base =>
                  if top_level then
                    (case s of
                       "c_False" => ("$false", s')
                     | "c_True" => ("$true", s')
                     | _ => name, [])
                   else
-                    (proxy_name, ty_args)
+                    (proxy_base |>> prefix const_prefix, ty_args)
                 | NONE => (name, ty_args))
         |> (fn (name, ty_args) => CombConst (name, ty, ty_args))
       | aux _ tm = tm
@@ -516,8 +508,8 @@
   | should_tag_with_type _ _ _ = false
 
 fun type_pred_combatom type_sys T tm =
-  CombApp (CombConst (`make_fixed_const type_pred_base,
-                      T --> HOLogic.boolT, [T]), tm)
+  CombApp (CombConst (`make_fixed_const type_pred_base, T --> @{typ bool}, [T]),
+           tm)
   |> repair_combterm_consts type_sys
   |> AAtom
 
@@ -627,26 +619,25 @@
 
 (** "hBOOL" and "hAPP" **)
 
-type repair_info = {pred_sym: bool, min_arity: int, max_arity: int}
-
 fun add_combterm_to_sym_table explicit_apply =
   let
     fun aux top_level tm =
       let val (head, args) = strip_combterm_comb tm in
         (case head of
-           CombConst ((s, _), _, _) =>
+           CombConst ((s, _), T, _) =>
            if String.isPrefix bound_var_prefix s then
              I
            else
-             let val arity = length args in
+             let val ary = length args in
                Symtab.map_default
                    (s, {pred_sym = true,
-                        min_arity = if explicit_apply then 0 else arity,
-                        max_arity = 0})
-                   (fn {pred_sym, min_arity, max_arity} =>
+                        min_ary = if explicit_apply then 0 else ary,
+                        max_ary = 0, typ = SOME T})
+                   (fn {pred_sym, min_ary, max_ary, typ} =>
                        {pred_sym = pred_sym andalso top_level,
-                        min_arity = Int.min (arity, min_arity),
-                        max_arity = Int.max (arity, max_arity)})
+                        min_ary = Int.min (ary, min_ary),
+                        max_ary = Int.max (ary, max_ary),
+                        typ = if typ = SOME T then typ else NONE})
             end
          | _ => I)
         #> fold (aux false) args
@@ -660,19 +651,19 @@
    that no "hBOOL" is introduced for them. The "hBOOL" entry is needed to ensure
    that no "hAPP"s are introduced for passing arguments to it. *)
 val default_sym_table_entries =
-  [("equal", {pred_sym = true, min_arity = 2, max_arity = 2}),
-   ("$false", {pred_sym = true, min_arity = 0, max_arity = 0}),
-   ("$true", {pred_sym = true, min_arity = 0, max_arity = 0}),
+  [("equal", {pred_sym = true, min_ary = 2, max_ary = 2, typ = NONE}),
    (make_fixed_const boolify_base,
-    {pred_sym = true, min_arity = 1, max_arity = 1})]
+    {pred_sym = true, min_ary = 1, max_ary = 1, typ = NONE})] @
+  (["$false", "$true"]
+   |> map (rpair {pred_sym = true, min_ary = 0, max_ary = 0, typ = NONE}))
 
 fun sym_table_for_facts explicit_apply facts =
-  Symtab.empty |> fold Symtab.default default_sym_table_entries
-               |> fold (add_fact_to_sym_table explicit_apply) facts
+  Symtab.empty |> fold (add_fact_to_sym_table explicit_apply) facts
+               |> fold Symtab.default default_sym_table_entries
 
 fun min_arity_of sym_tab s =
   case Symtab.lookup sym_tab s of
-    SOME ({min_arity, ...} : repair_info) => min_arity
+    SOME ({min_ary, ...} : sym_table_info) => min_ary
   | NONE =>
     case strip_prefix_and_unascii const_prefix s of
       SOME s =>
@@ -690,8 +681,7 @@
    arguments and is used as a predicate. *)
 fun is_pred_sym sym_tab s =
   case Symtab.lookup sym_tab s of
-    SOME {pred_sym, min_arity, max_arity} =>
-    pred_sym andalso min_arity = max_arity
+    SOME {pred_sym, min_ary, max_ary, ...} => pred_sym andalso min_ary = max_ary
   | NONE => false
 
 val boolify_combconst =
@@ -764,10 +754,10 @@
   | strip_and_map_type _ _ _ = raise Fail "unexpected non-function"
 
 fun problem_line_for_typed_const ctxt type_sys sym_tab s j (s', ty_args, T) =
-  let val arity = min_arity_of sym_tab s in
+  let val ary = min_arity_of sym_tab s in
     if type_sys = Many_Typed then
       let
-        val (arg_Ts, res_T) = strip_and_map_type arity mangled_type_name T
+        val (arg_Ts, res_T) = strip_and_map_type ary mangled_type_name T
         val (s, s') = (s, s') |> mangled_const_name ty_args
       in
         Decl (sym_decl_prefix ^ ascii_of s, (s, s'), arg_Ts,
@@ -776,7 +766,7 @@
       end
     else
       let
-        val (arg_Ts, res_T) = strip_and_map_type arity I T
+        val (arg_Ts, res_T) = strip_and_map_type ary I T
         val bounds =
           map (`I o make_bound_var o string_of_int) (1 upto length arg_Ts)
           ~~ map SOME arg_Ts
@@ -797,10 +787,9 @@
 fun problem_lines_for_sym_decl ctxt type_sys sym_tab (s, xs) =
   map2 (problem_line_for_typed_const ctxt type_sys sym_tab s)
        (0 upto length xs - 1) xs
-fun problem_lines_for_sym_decls ctxt type_sys repaired_sym_tab typed_const_tab =
-  Symtab.fold_rev
-      (append o problem_lines_for_sym_decl ctxt type_sys repaired_sym_tab)
-      typed_const_tab []
+fun problem_lines_for_sym_decls ctxt type_sys sym_tab typed_const_tab =
+  Symtab.fold_rev (append o problem_lines_for_sym_decl ctxt type_sys sym_tab)
+                  typed_const_tab []
 
 fun add_tff_types_in_formula (AQuant (_, xs, phi)) =
     union (op =) (map_filter snd xs) #> add_tff_types_in_formula phi
@@ -838,16 +827,16 @@
   let
     val (fact_names, (conjs, facts, class_rel_clauses, arity_clauses)) =
       translate_formulas ctxt type_sys hyp_ts concl_t facts
-    val sym_tab = sym_table_for_facts explicit_apply (conjs @ facts)
+    val sym_tab = conjs @ facts |> sym_table_for_facts explicit_apply
     val conjs = conjs |> map (repair_fact type_sys sym_tab)
     val facts = facts |> map (repair_fact type_sys sym_tab)
-    val repaired_sym_tab = sym_table_for_facts false (conjs @ facts)
+    val sym_tab' = conjs @ facts |> sym_table_for_facts false
     val typed_const_tab =
-      typed_const_table_for_facts type_sys repaired_sym_tab (conjs @ facts)
+      conjs @ facts |> typed_const_table_for_facts type_sys sym_tab'
     val sym_decl_lines =
-      problem_lines_for_sym_decls ctxt type_sys repaired_sym_tab typed_const_tab
-    val (helpers, raw_helper_lines) = helper_facts ctxt type_sys typed_const_tab
-    val helpers = helpers |> map (repair_fact type_sys sym_tab)
+      typed_const_tab |> problem_lines_for_sym_decls ctxt type_sys sym_tab'
+    val helpers = helper_facts_for_sym_table ctxt type_sys sym_tab'
+                  |> map (repair_fact type_sys sym_tab')
     (* Reordering these might confuse the proof reconstruction code or the SPASS
        Flotter hack. *)
     val problem =
@@ -857,8 +846,9 @@
        (class_relsN, map formula_line_for_class_rel_clause class_rel_clauses),
        (aritiesN, map formula_line_for_arity_clause arity_clauses),
        (helpersN, map (formula_line_for_fact ctxt helper_prefix type_sys)
-                      (0 upto length helpers - 1 ~~ helpers) @
-                  raw_helper_lines),
+                      (0 upto length helpers - 1 ~~ helpers)
+                  |> (if type_sys = Tags false then cons (ti_ti_helper_fact ())
+                      else I)),
        (conjsN, map (formula_line_for_conjecture ctxt type_sys) conjs),
        (free_typesN, formula_lines_for_free_types type_sys (facts @ conjs))]
     val problem =