src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML
changeset 41138 eb80538166b6
parent 41137 8b634031b2a5
child 41140 9c68004b8c9d
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML	Wed Dec 15 11:26:28 2010 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML	Wed Dec 15 11:26:28 2010 +0100
@@ -15,12 +15,13 @@
     Tags of bool |
     Preds of bool |
     Const_Args |
-    Overload_Args |
     No_Types
 
+  val precise_overloaded_args : bool Unsynchronized.ref
   val fact_prefix : string
   val conjecture_prefix : string
   val is_fully_typed : type_system -> bool
+  val types_dangerous_types : type_system -> bool
   val num_atp_type_args : theory -> type_system -> string -> int
   val translate_atp_fact :
     Proof.context -> (string * 'a) * thm
@@ -31,13 +32,17 @@
     -> string problem * string Symtab.table * int * (string * 'a) list vector
 end;
 
-structure Sledgehammer_ATP_Translate : SLEDGEHAMMER_ATP_TRANSLATE =
+structure Sledgehammer_ATP_Translate (*### : SLEDGEHAMMER_ATP_TRANSLATE *) =
 struct
 
 open ATP_Problem
 open Metis_Translate
 open Sledgehammer_Util
 
+(* FIXME: Remove references once appropriate defaults have been determined
+   empirically. *)
+val precise_overloaded_args = Unsynchronized.ref false
+
 val fact_prefix = "fact_"
 val conjecture_prefix = "conj_"
 val helper_prefix = "help_"
@@ -58,26 +63,29 @@
   Tags of bool |
   Preds of bool |
   Const_Args |
-  Overload_Args |
   No_Types
 
 fun is_fully_typed (Tags full_types) = full_types
   | is_fully_typed (Preds full_types) = full_types
   | is_fully_typed _ = false
 
+fun types_dangerous_types (Tags _) = true
+  | types_dangerous_types (Preds _) = true
+  | types_dangerous_types _ = false
+
 (* This is an approximation. If it returns "true" for a constant that isn't
    overloaded (i.e., that has one uniform definition), needless clutter is
    generated; if it returns "false" for an overloaded constant, the ATP gets a
    license to do unsound reasoning if the type system is "overloaded_args". *)
 fun is_overloaded thy s =
+  not (!precise_overloaded_args) orelse
   length (Defs.specifications_of (Theory.defs_of thy) s) > 1
 
 fun needs_type_args thy type_sys s =
   case type_sys of
-    Tags full_types => not full_types
-  | Preds full_types => not full_types
-  | Const_Args => true
-  | Overload_Args => is_overloaded thy s
+    Tags full_types => not full_types andalso is_overloaded thy s
+  | Preds full_types => is_overloaded thy s (* FIXME: could be more precise *)
+  | Const_Args => is_overloaded thy s
   | No_Types => false
 
 fun num_atp_type_args thy type_sys s =
@@ -319,7 +327,7 @@
      (conjectures, facts, helper_facts, class_rel_clauses, arity_clauses))
   end
 
-fun wrap_type ty t = ATerm ((type_wrapper_name, type_wrapper_name), [ty, t])
+fun tag_with_type ty t = ATerm (`I type_tag_name, [ty, t])
 
 fun fo_term_for_combtyp (CombTVar name) = ATerm (name, [])
   | fo_term_for_combtyp (CombTFree name) = ATerm (name, [])
@@ -333,8 +341,44 @@
 
 fun formula_for_fo_literal (pos, t) = AAtom t |> not pos ? mk_anot
 
-fun fo_term_for_combterm thy type_sys =
+(* Finite types such as "unit", "bool", "bool * bool", and "bool => bool" are
+   considered dangerous because their "exhaust" properties can easily lead to
+   unsound ATP proofs. The checks below are an (unsound) approximation of
+   finiteness. *)
+
+fun is_dtyp_dangerous _ (Datatype_Aux.DtTFree _) = true
+  | is_dtyp_dangerous ctxt (Datatype_Aux.DtType (s, Us)) =
+    is_type_constr_dangerous ctxt s andalso forall (is_dtyp_dangerous ctxt) Us
+  | is_dtyp_dangerous _ (Datatype_Aux.DtRec _) = false
+and is_type_dangerous ctxt (Type (s, Ts)) =
+    is_type_constr_dangerous ctxt s andalso forall (is_type_dangerous ctxt) Ts
+  | is_type_dangerous ctxt _ = false
+and is_type_constr_dangerous ctxt s =
+  let val thy = ProofContext.theory_of ctxt in
+    case Datatype_Data.get_info thy s of
+      SOME {descr, ...} =>
+      forall (fn (_, (_, _, constrs)) =>
+                 forall (forall (is_dtyp_dangerous ctxt) o snd) constrs) descr
+    | NONE =>
+      case Typedef.get_info ctxt s of
+        ({rep_type, ...}, _) :: _ => is_type_dangerous ctxt rep_type
+      | [] => true
+  end
+
+fun is_combtyp_dangerous ctxt (CombType ((s, _), tys)) =
+    (case strip_prefix_and_unascii type_const_prefix s of
+       SOME s' => forall (is_combtyp_dangerous ctxt) tys andalso
+                  is_type_constr_dangerous ctxt (invert_const s')
+     | NONE => false)
+  | is_combtyp_dangerous _ _ = false
+
+fun should_tag_with_type ctxt (Tags full_types) ty =
+    full_types orelse is_combtyp_dangerous ctxt ty
+  | should_tag_with_type _ _ _ = false
+
+fun fo_term_for_combterm ctxt type_sys =
   let
+    val thy = ProofContext.theory_of ctxt
     fun aux top_level u =
       let
         val (head, args) = strip_combterm_comb u
@@ -364,31 +408,32 @@
                 end)
           | CombVar (name, _) => (name, [])
           | CombApp _ => raise Fail "impossible \"CombApp\""
-        val t = ATerm (x, map fo_term_for_combtyp ty_args @
-                          map (aux false) args)
+        val t =
+          ATerm (x, map fo_term_for_combtyp ty_args @ map (aux false) args)
+        val ty = combtyp_of u
     in
-      t |> (if type_sys = Tags true then
-              wrap_type (fo_term_for_combtyp (combtyp_of u))
+      t |> (if should_tag_with_type ctxt type_sys ty then
+              tag_with_type (fo_term_for_combtyp ty)
             else
               I)
     end
   in aux true end
 
-fun formula_for_combformula thy type_sys =
+fun formula_for_combformula ctxt type_sys =
   let
     fun aux (AQuant (q, xs, phi)) = AQuant (q, xs, aux phi)
       | aux (AConn (c, phis)) = AConn (c, map aux phis)
-      | aux (AAtom tm) = AAtom (fo_term_for_combterm thy type_sys tm)
+      | aux (AAtom tm) = AAtom (fo_term_for_combterm ctxt type_sys tm)
   in aux end
 
-fun formula_for_fact thy type_sys
+fun formula_for_fact ctxt type_sys
                      ({combformula, ctypes_sorts, ...} : translated_formula) =
   mk_ahorn (map (formula_for_fo_literal o fo_literal_for_type_literal)
                 (atp_type_literals_for_types type_sys ctypes_sorts))
-           (formula_for_combformula thy type_sys combformula)
+           (formula_for_combformula ctxt type_sys combformula)
 
-fun problem_line_for_fact thy prefix type_sys (formula as {name, kind, ...}) =
-  Fof (prefix ^ ascii_of name, kind, formula_for_fact thy type_sys formula)
+fun problem_line_for_fact ctxt prefix type_sys (formula as {name, kind, ...}) =
+  Fof (prefix ^ ascii_of name, kind, formula_for_fact ctxt type_sys formula)
 
 fun problem_line_for_class_rel_clause (ClassRelClause {name, subclass,
                                                        superclass, ...}) =
@@ -411,10 +456,10 @@
                 (formula_for_fo_literal
                      (fo_literal_for_arity_literal conclLit)))
 
-fun problem_line_for_conjecture thy type_sys
+fun problem_line_for_conjecture ctxt type_sys
         ({name, kind, combformula, ...} : translated_formula) =
   Fof (conjecture_prefix ^ name, kind,
-       formula_for_combformula thy type_sys combformula)
+       formula_for_combformula ctxt type_sys combformula)
 
 fun free_type_literals_for_conjecture type_sys
         ({ctypes_sorts, ...} : translated_formula) =
@@ -445,7 +490,7 @@
                 max_arity = Int.max (n, max_arity),
                 sub_level = sub_level orelse not top_level})
      end)
-  #> fold (consider_term (top_level andalso s = type_wrapper_name)) ts
+  #> fold (consider_term (top_level andalso s = type_tag_name)) ts
 fun consider_formula (AQuant (_, _, phi)) = consider_formula phi
   | consider_formula (AConn (_, phis)) = fold consider_formula phis
   | consider_formula (AAtom tm) = consider_term true tm
@@ -458,7 +503,7 @@
   else SOME (Symtab.empty |> consider_problem problem)
 
 fun min_arity_of thy type_sys NONE s =
-    (if s = "equal" orelse s = type_wrapper_name orelse
+    (if s = "equal" orelse s = type_tag_name orelse
         String.isPrefix type_const_prefix s orelse
         String.isPrefix class_prefix s then
        16383 (* large number *)
@@ -471,25 +516,29 @@
     | NONE => 0
 
 fun full_type_of (ATerm ((s, _), [ty, _])) =
-    if s = type_wrapper_name then ty else raise Fail "expected type wrapper"
-  | full_type_of _ = raise Fail "expected type wrapper"
+    if s = type_tag_name then SOME ty else NONE
+  | full_type_of _ = NONE
 
 fun list_hAPP_rev _ t1 [] = t1
   | list_hAPP_rev NONE t1 (t2 :: ts2) =
     ATerm (`I "hAPP", [list_hAPP_rev NONE t1 ts2, t2])
   | list_hAPP_rev (SOME ty) t1 (t2 :: ts2) =
-    let val ty' = ATerm (`make_fixed_type_const @{type_name fun},
-                         [full_type_of t2, ty]) in
-      ATerm (`I "hAPP", [wrap_type ty' (list_hAPP_rev (SOME ty') t1 ts2), t2])
-    end
+    case full_type_of t2 of
+      SOME ty2 =>
+      let val ty' = ATerm (`make_fixed_type_const @{type_name fun},
+                           [ty2, ty]) in
+        ATerm (`I "hAPP",
+               [tag_with_type ty' (list_hAPP_rev (SOME ty') t1 ts2), t2])
+      end
+    | NONE => list_hAPP_rev NONE t1 (t2 :: ts2)
 
 fun repair_applications_in_term thy type_sys const_tab =
   let
     fun aux opt_ty (ATerm (name as (s, _), ts)) =
-      if s = type_wrapper_name then
+      if s = type_tag_name then
         case ts of
           [t1, t2] => ATerm (name, [aux NONE t1, aux (SOME t1) t2])
-        | _ => raise Fail "malformed type wrapper"
+        | _ => raise Fail "malformed type tag"
       else
         let
           val ts = map (aux NONE) ts
@@ -513,11 +562,11 @@
     | NONE => false
 
 fun repair_predicates_in_term const_tab (t as ATerm ((s, _), ts)) =
-  if s = type_wrapper_name then
+  if s = type_tag_name then
     case ts of
       [_, t' as ATerm ((s', _), _)] =>
       if is_predicate const_tab s' then t' else boolify t
-    | _ => raise Fail "malformed type wrapper"
+    | _ => raise Fail "malformed type tag"
   else
     t |> not (is_predicate const_tab s) ? boolify
 
@@ -561,11 +610,11 @@
     val (fact_names, (conjectures, facts, helper_facts, class_rel_clauses,
                       arity_clauses)) =
       translate_formulas ctxt type_sys hyp_ts concl_t facts
-    val fact_lines = map (problem_line_for_fact thy fact_prefix type_sys) facts
+    val fact_lines = map (problem_line_for_fact ctxt fact_prefix type_sys) facts
     val helper_lines =
-      map (problem_line_for_fact thy helper_prefix type_sys) helper_facts
+      map (problem_line_for_fact ctxt helper_prefix type_sys) helper_facts
     val conjecture_lines =
-      map (problem_line_for_conjecture thy type_sys) conjectures
+      map (problem_line_for_conjecture ctxt type_sys) conjectures
     val tfree_lines = problem_lines_for_free_types type_sys conjectures
     val class_rel_lines =
       map problem_line_for_class_rel_clause class_rel_clauses