src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML
changeset 42829 1558741f8a72
parent 42828 8794ec73ec13
child 42830 1068d8fc1331
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML	Tue May 17 15:11:36 2011 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_atp_translate.ML	Tue May 17 15:11:36 2011 +0200
@@ -151,10 +151,6 @@
   | level_of_type_sys (Preds (_, level, _)) = level
   | level_of_type_sys (Tags (_, level, _)) = level
 
-fun depth_of_type_sys (Simple_Types _) = Shallow
-  | depth_of_type_sys (Preds (_, _, depth)) = depth
-  | depth_of_type_sys (Tags (_, _, depth)) = depth
-
 fun is_type_level_virtually_sound level =
   level = All_Types orelse level = Nonmonotonic_Types
 val is_type_sys_virtually_sound =
@@ -203,7 +199,7 @@
   s <> type_pred_base andalso s <> type_tag_name andalso
   (s = @{const_name HOL.eq} orelse level_of_type_sys type_sys = No_Types orelse
    (case type_sys of
-      Tags (_, All_Types, _) => true
+      Tags (_, All_Types, Deep) => true
     | _ => polymorphism_of_type_sys type_sys <> Mangled_Monomorphic andalso
            member (op =) boring_consts s))
 
@@ -213,8 +209,6 @@
   Mangled_Type_Args of bool |
   No_Type_Args
 
-(* FIXME: Find out whether and when erasing the non-result type arguments is
-   sound. *)
 fun general_type_arg_policy type_sys =
   if level_of_type_sys type_sys = No_Types then
     No_Type_Args
@@ -497,13 +491,23 @@
     exists (curry Type.raw_instance T) nonmono_Ts
   | should_encode_type _ _ _ _ = false
 
-fun should_predicate_on_type ctxt nonmono_Ts (Preds (_, level, _)) T =
-    should_encode_type ctxt nonmono_Ts level T
+fun should_predicate_on_type ctxt nonmono_Ts (Preds (_, level, depth)) T =
+    (case depth of
+       Deep => should_encode_type ctxt nonmono_Ts level T
+     | Shallow => error "Not implemented yet.")
   | should_predicate_on_type _ _ _ _ = false
 
-fun should_tag_with_type ctxt nonmono_Ts (Tags (_, level, _)) T =
-    should_encode_type ctxt nonmono_Ts level T
-  | should_tag_with_type _ _ _ _ = false
+datatype tag_site = Top_Level | Eq_Arg | Elsewhere
+
+fun should_tag_with_type _ _ _ Top_Level _ _ = false
+  | should_tag_with_type ctxt nonmono_Ts (Tags (_, level, depth)) site u T =
+    (case depth of
+       Deep => should_encode_type ctxt nonmono_Ts level T
+     | Shallow =>
+       case (site, u) of
+         (Eq_Arg, CombVar _) => should_encode_type ctxt nonmono_Ts level T
+       | _ => false)
+  | should_tag_with_type _ _ _ _ _ _ = false
 
 val homo_infinite_T = @{typ ind} (* any infinite type *)
 
@@ -772,31 +776,33 @@
   |> enforce_type_arg_policy_in_combterm ctxt nonmono_Ts type_sys
   |> AAtom
 
-fun formula_from_combformula ctxt nonmono_Ts type_sys =
+fun tag_with_type ctxt nonmono_Ts type_sys T tm =
+  CombConst (`make_fixed_const type_tag_name, T --> T, [T])
+  |> enforce_type_arg_policy_in_combterm ctxt nonmono_Ts type_sys
+  |> term_from_combterm ctxt nonmono_Ts type_sys Top_Level
+  |> (fn ATerm (s, tms) => ATerm (s, tms @ [tm]))
+and term_from_combterm ctxt nonmono_Ts type_sys site u =
   let
-    fun tag_with_type type_sys T tm =
-      CombConst (`make_fixed_const type_tag_name, T --> T, [T])
-      |> enforce_type_arg_policy_in_combterm ctxt nonmono_Ts type_sys
-      |> do_term true
-      |> (fn ATerm (s, tms) => ATerm (s, tms @ [tm]))
-    and do_term top_level u =
-      let
-        val (head, args) = strip_combterm_comb u
-        val (x, T_args) =
-          case head of
-            CombConst (name, _, T_args) => (name, T_args)
-          | CombVar (name, _) => (name, [])
-          | CombApp _ => raise Fail "impossible \"CombApp\""
-        val t = ATerm (x, map fo_term_from_typ T_args @
-                          map (do_term false) args)
-        val T = combtyp_of u
-      in
-        t |> (if not top_level andalso
-                should_tag_with_type ctxt nonmono_Ts type_sys T then
-                tag_with_type type_sys T
-              else
-                I)
-      end
+    val (head, args) = strip_combterm_comb u
+    val (x as (s, _), T_args) =
+      case head of
+        CombConst (name, _, T_args) => (name, T_args)
+      | CombVar (name, _) => (name, [])
+      | CombApp _ => raise Fail "impossible \"CombApp\""
+    val arg_site = if site = Top_Level andalso s = "equal" then Eq_Arg
+                   else Elsewhere
+    val t = ATerm (x, map fo_term_from_typ T_args @
+                      map (term_from_combterm ctxt nonmono_Ts type_sys arg_site)
+                          args)
+    val T = combtyp_of u
+  in
+    t |> (if should_tag_with_type ctxt nonmono_Ts type_sys site u T then
+            tag_with_type ctxt nonmono_Ts type_sys T
+          else
+            I)
+  end
+and formula_from_combformula ctxt nonmono_Ts type_sys =
+  let
     val do_bound_type =
       case type_sys of
         Simple_Types level =>
@@ -817,7 +823,8 @@
                            | (s, SOME T) => do_out_of_bound_type (s, T)) xs)
                     (do_formula phi))
       | do_formula (AConn (c, phis)) = AConn (c, map do_formula phis)
-      | do_formula (AAtom tm) = AAtom (do_term true tm)
+      | do_formula (AAtom tm) =
+        AAtom (term_from_combterm ctxt nonmono_Ts type_sys Top_Level tm)
   in do_formula end
 
 fun bound_atomic_types type_sys Ts =
@@ -940,22 +947,30 @@
 
 (* This inference is described in section 2.3 of Claessen et al.'s "Sorting it
    out with monotonicity" paper presented at CADE 2011. *)
-fun add_combterm_nonmonotonic_types _ (SOME false) _ = I
-  | add_combterm_nonmonotonic_types ctxt _
+fun add_combterm_nonmonotonic_types _ _  (SOME false) _ = I
+  | add_combterm_nonmonotonic_types ctxt level _
         (CombApp (CombApp (CombConst (("equal", _), Type (_, [T, _]), _), tm1),
                   tm2)) =
     (exists is_var_or_bound_var [tm1, tm2] andalso
-     not (is_type_surely_infinite ctxt T)) ? insert_type I T
-  | add_combterm_nonmonotonic_types _ _ _ = I
-fun add_fact_nonmonotonic_types ctxt ({kind, combformula, ...}
-                                      : translated_formula) =
-  formula_fold (kind <> Conjecture) (add_combterm_nonmonotonic_types ctxt)
-               combformula
+     (case level of
+        Nonmonotonic_Types => not (is_type_surely_infinite ctxt T)
+      | Finite_Types => is_type_surely_finite ctxt T
+      | _ => true)) ? insert_type I T
+  | add_combterm_nonmonotonic_types _ _ _ _ = I
+fun add_fact_nonmonotonic_types ctxt level ({kind, combformula, ...}
+                                            : translated_formula) =
+  formula_fold (kind <> Conjecture)
+               (add_combterm_nonmonotonic_types ctxt level) combformula
 fun add_nonmonotonic_types_for_facts ctxt type_sys facts =
-  level_of_type_sys type_sys = Nonmonotonic_Types
-  ? (fold (add_fact_nonmonotonic_types ctxt) facts
-     (* in case helper "True_or_False" is included *)
-     #> insert_type I @{typ bool})
+  let val level = level_of_type_sys type_sys in
+    (level = Nonmonotonic_Types orelse
+     (case type_sys of
+        Tags (poly, _, Shallow) => poly <> Polymorphic
+      | _ => false))
+    ? (fold (add_fact_nonmonotonic_types ctxt level) facts
+       (* in case helper "True_or_False" is included *)
+       #> insert_type I @{typ bool})
+  end
 
 fun result_type_of_decl (_, _, T, _, ary, _) = chop_fun ary T |> snd
 
@@ -972,8 +987,8 @@
 
 fun is_polymorphic_type T = fold_atyps (fn TVar _ => K true | _ => I) T false
 
-fun formula_line_for_sym_decl ctxt conj_sym_kind nonmono_Ts type_sys n s j
-                              (s', T_args, T, _, ary, in_conj) =
+fun formula_line_for_pred_sym_decl ctxt conj_sym_kind nonmono_Ts type_sys n s j
+                                   (s', T_args, T, _, ary, in_conj) =
   let
     val (kind, maybe_negate) =
       if in_conj then (conj_sym_kind, conj_sym_kind = Conjecture ? mk_anot)
@@ -981,7 +996,7 @@
     val (arg_Ts, res_T) = chop_fun ary T
     val bound_names =
       1 upto length arg_Ts |> map (`I o make_bound_var o string_of_int)
-    val bound_tms =
+    val bounds =
       bound_names ~~ arg_Ts |> map (fn (name, T) => CombConst (name, T, []))
     val bound_Ts =
       arg_Ts |> map (fn T => if n > 1 orelse is_polymorphic_type T then SOME T
@@ -990,7 +1005,7 @@
     Formula (sym_decl_prefix ^ s ^
              (if n > 1 then "_" ^ string_of_int j else ""), kind,
              CombConst ((s, s'), T, T_args)
-             |> fold (curry (CombApp o swap)) bound_tms
+             |> fold (curry (CombApp o swap)) bounds
              |> type_pred_combatom ctxt nonmono_Ts type_sys res_T
              |> mk_aquant AForall (bound_names ~~ bound_Ts)
              |> formula_from_combformula ctxt nonmono_Ts type_sys
@@ -1000,11 +1015,56 @@
              NONE, NONE)
   end
 
+fun formula_lines_for_tag_sym_decl ctxt nonmono_Ts type_sys n s
+                                   (j, (s', T_args, T, _, ary, _)) =
+  let
+    val ident_base =
+      sym_decl_prefix ^ s ^ (if n > 1 then "_" ^ string_of_int j else "")
+    val (arg_Ts, res_T) = chop_fun ary T
+    val bound_names =
+      1 upto length arg_Ts |> map (`I o make_bound_var o string_of_int)
+    val bounds = bound_names |> map (fn name => ATerm (name, []))
+    fun const args = ATerm ((s, s'), map fo_term_from_typ T_args @ args)
+    fun eq tm1 tm2 = ATerm (`I "equal", [tm1, tm2])
+    val should_encode =
+      should_encode_type ctxt nonmono_Ts
+          (if polymorphism_of_type_sys type_sys = Polymorphic then All_Types
+           else Nonmonotonic_Types)
+    val tag_with = tag_with_type ctxt nonmono_Ts type_sys
+    val add_formula_for_res =
+      if should_encode res_T then
+        cons (Formula (ident_base ^ "_res", Axiom,
+                       AAtom (eq (tag_with res_T (const bounds))
+                                 (const bounds))
+                       |> close_formula_universally,
+                       NONE, NONE))
+      else
+        I
+    fun add_formula_for_arg k =
+      let val arg_T = nth arg_Ts k in
+        if should_encode arg_T then
+          case chop k bounds of
+            (bounds1, bound :: bounds2) =>
+            cons (Formula (ident_base ^ "_arg" ^ string_of_int (k + 1), Axiom,
+                           AAtom (eq (const (bounds1 @
+                                  tag_with arg_T bound :: bounds2))
+                                     (const bounds))
+                           |> close_formula_universally,
+                           NONE, NONE))
+          | _ => raise Fail "expected nonempty tail"
+        else
+          I
+      end
+  in
+    [] |> add_formula_for_res
+       |> fold add_formula_for_arg (ary - 1 downto 0)
+  end
+
 fun problem_lines_for_sym_decls ctxt conj_sym_kind nonmono_Ts type_sys
                                 (s, decls) =
   case type_sys of
     Simple_Types level => map (decl_line_for_sym ctxt nonmono_Ts level s) decls
-  | _ =>
+  | Preds _ =>
     let
       val decls =
         case decls of
@@ -1023,9 +1083,17 @@
                          o result_type_of_decl)
     in
       (0 upto length decls - 1, decls)
-      |-> map2 (formula_line_for_sym_decl ctxt conj_sym_kind nonmono_Ts type_sys
-                                          n s)
+      |-> map2 (formula_line_for_pred_sym_decl ctxt conj_sym_kind nonmono_Ts
+                                               type_sys n s)
     end
+  | Tags (_, _, depth) =>
+    (case depth of
+       Deep => []
+     | Shallow =>
+       let val n = length decls in
+         (0 upto n - 1 ~~ decls)
+         |> maps (formula_lines_for_tag_sym_decl ctxt nonmono_Ts type_sys n s)
+       end)
 
 fun problem_lines_for_sym_decl_table ctxt conj_sym_kind nonmono_Ts type_sys
                                      sym_decl_tab =
@@ -1094,7 +1162,7 @@
                       (0 upto length helpers - 1 ~~ helpers)
                   |> (case type_sys of
                         Tags (Polymorphic, level, _) =>
-                        is_type_level_partial level
+                        is_type_level_partial level (* ### FIXME *)
                         ? cons (ti_ti_helper_fact ())
                       | _ => I)),
        (conjsN, map (formula_line_for_conjecture ctxt nonmono_Ts type_sys)