reintroduce old "unify_prem_with_concl" code to avoid reaching unification bound + primitive handling for polymorphism
authorblanchet
Mon, 04 Oct 2010 16:24:53 +0200
changeset 39935 56409c11195d
parent 39934 9f116d095e5e
child 39936 8f415cfc2180
reintroduce old "unify_prem_with_concl" code to avoid reaching unification bound + primitive handling for polymorphism
src/HOL/Tools/Sledgehammer/metis_tactics.ML
--- a/src/HOL/Tools/Sledgehammer/metis_tactics.ML	Mon Oct 04 15:05:19 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/metis_tactics.ML	Mon Oct 04 16:24:53 2010 +0200
@@ -57,6 +57,71 @@
    models = []}
 val resolution_params = {active = active_params, waiting = waiting_params}
 
+(* In principle, it should be sufficient to apply "assume_tac" to unify the
+   conclusion with one of the premises. However, in practice, this fails
+   horribly because of the mildly higher-order nature of the unification
+   problems. Typical constraints are of the form "?x a b =?= b", where "a" and
+   "b" are goal parameters. *)
+fun unify_prem_with_concl thy i th =
+  let
+    val goal = Logic.get_goal (prop_of th) i |> Envir.beta_eta_contract
+    val prem = goal |> Logic.strip_assums_hyp |> hd
+    val concl = goal |> Logic.strip_assums_concl
+    fun add_types Tp instT =
+      if exists (curry (op =) Tp) instT then instT
+      else Tp :: map (apsnd (typ_subst_atomic [Tp])) instT
+    fun unify_types (T, U) =
+      if T = U then
+        I
+      else case (T, U) of
+        (TVar _, _) => add_types (T, U)
+      | (_, TVar _) => add_types (U, T)
+      | (Type (s, Ts), Type (t, Us)) =>
+        if s = t andalso length Ts = length Us then fold unify_types (Ts ~~ Us)
+        else raise TYPE ("unify_types", [T, U], [])
+      | _ => raise TYPE ("unify_types", [T, U], [])
+    fun pair_untyped_aconv (t1, t2) (u1, u2) =
+      untyped_aconv t1 u1 andalso untyped_aconv t2 u2
+    fun add_terms tp inst =
+      if exists (pair_untyped_aconv tp) inst then inst
+      else tp :: map (apsnd (subst_atomic [tp])) inst
+    fun is_flex t =
+      case strip_comb t of
+        (Var _, args) => forall (is_Bound orf is_Var) args
+      | _ => false
+    fun unify_flex flex rigid =
+      case strip_comb flex of
+        (Var (z as (_, T)), args) =>
+        add_terms (Var z,
+          fold_rev (curry absdummy) (take (length args) (binder_types T)) rigid)
+      | _ => raise TERM ("unify_flex: expected flex", [flex])
+    fun unify_potential_flex comb atom =
+      if is_flex comb then unify_flex comb atom
+      else if is_Var atom then add_terms (atom, comb)
+      else raise TERM ("unify_terms", [comb, atom])
+    fun unify_terms (t, u) =
+      case (t, u) of
+        (t1 $ t2, u1 $ u2) =>
+        if is_flex t then unify_flex t u
+        else if is_flex u then unify_flex u t
+        else fold unify_terms [(t1, u1), (t2, u2)]
+      | (_ $ _, _) => unify_potential_flex t u
+      | (_, _ $ _) => unify_potential_flex u t
+      | (Var _, _) => add_terms (t, u)
+      | (_, Var _) => add_terms (u, t)
+      | _ => if untyped_aconv t u then I else raise TERM ("unify_terms", [t, u])
+
+    val inst = [] |> unify_terms (prem, concl)
+    val _ = trace_msg (fn () => cat_lines (map (fn (t, u) =>
+        Syntax.string_of_term @{context} t ^ " |-> " ^
+        Syntax.string_of_term @{context} u) inst))
+    val instT = fold (fn Tp => unify_types (pairself fastype_of Tp)
+                               handle TERM _ => I) inst []
+    val inst = inst |> map (pairself (subst_atomic_types instT))
+    val cinstT = instT |> map (pairself (ctyp_of thy))
+    val cinst = inst |> map (pairself (cterm_of thy))
+  in th |> Thm.instantiate (cinstT, []) |> Thm.instantiate ([], cinst) end
+
 fun shuffle_key (((axiom_no, (_, index_no)), _), _) = (index_no, axiom_no)
 fun shuffle_ord p =
   rev_order (prod_ord int_ord int_ord (pairself shuffle_key p))
@@ -94,7 +159,7 @@
   | release_clusters_tac thy ax_counts substs params
                          ((ax_no, cluster_no) :: clusters) =
     let
-      val n = AList.lookup (op =) ax_counts ax_no |> the
+(*      val n = AList.lookup (op =) ax_counts ax_no |> the (*###*) *)
       fun in_right_cluster s =
         (s |> Meson_Clausify.cluster_of_zapped_var_name |> fst |> snd |> fst)
         = cluster_no
@@ -119,16 +184,33 @@
 val cluster_ord =
   prod_ord (prod_ord int_ord (prod_ord int_ord int_ord)) bool_ord
 
+val tysubst_ord =
+  list_ord (prod_ord Term_Ord.fast_indexname_ord
+                     (prod_ord Term_Ord.sort_ord Term_Ord.typ_ord))
+
+structure Int_Tysubst_Table =
+  Table(type key = int * (indexname * (sort * typ)) list
+        val ord = prod_ord int_ord tysubst_ord)
+
 (* Attempts to derive the theorem "False" from a theorem of the form
    "P1 ==> ... ==> Pn ==> False", where the "Pi"s are to be discharged using the
    specified axioms. The axioms have leading "All" and "Ex" quantifiers, which
    must be eliminated first. *)
 fun discharge_skolem_premises ctxt axioms prems_imp_false =
-  case prop_of prems_imp_false of
-    @{prop False} => prems_imp_false
-  | prems_imp_false_prop =>
+  if prop_of prems_imp_false aconv @{prop False} then
+    prems_imp_false
+  else
     let
       val thy = ProofContext.theory_of ctxt
+      (* distinguish variables with same name but different types *)
+      val prems_imp_false' =
+        prems_imp_false |> try (forall_intr_vars o gen_all)
+                        |> the_default prems_imp_false
+      val prems_imp_false =
+        if prop_of prems_imp_false aconv prop_of prems_imp_false' then
+          prems_imp_false
+        else
+          prems_imp_false'
       fun match_term p =
         let
           val (tyenv, tenv) =
@@ -140,7 +222,8 @@
                                       o fst o fst))
                  |> map (Meson.term_pair_of
                          #> pairself (Envir.subst_term_types tyenv))
-        in (tyenv, tsubst) end
+          val tysubst = tyenv |> Vartab.dest
+        in (tysubst, tsubst) end
       fun subst_info_for_prem subgoal_no prem =
         case prem of
           _ $ (Const (@{const_name skolem}, _) $ (_ $ t $ num)) =>
@@ -169,7 +252,7 @@
                 clusters_in_term true t
                 |> cluster_no > 1 ? cons (ax_no, cluster_no - 1))
         | _ => raise TERM ("discharge_skolem_premises: Expected Var", [var])
-      val prems = Logic.strip_imp_prems prems_imp_false_prop
+      val prems = Logic.strip_imp_prems (prop_of prems_imp_false)
       val substs = prems |> map2 subst_info_for_prem (1 upto length prems)
                          |> sort (int_ord o pairself fst)
       val depss = maps (map_filter deps_for_term_subst o snd o snd o snd) substs
@@ -189,37 +272,39 @@
                          cluster_no = 0 andalso skolem)
            |> sort shuffle_ord |> map snd
       val ax_counts =
-        Inttab.empty
-        |> fold (fn (ax_no, _) => Inttab.map_default (ax_no, 0) (Integer.add 1))
-                substs
-        |> Inttab.dest
-(* for debugging:
-      fun string_for_subst_info (ax_no, (subgoal_no, (tyenv, tsubst))) =
+        Int_Tysubst_Table.empty
+        |> fold (fn (ax_no, (_, (tysubst, _))) =>
+                    Int_Tysubst_Table.map_default ((ax_no, tysubst), 0)
+                                                  (Integer.add 1)) substs
+        |> Int_Tysubst_Table.dest
+(* for debugging:###
+*)
+      fun string_for_subst_info (ax_no, (subgoal_no, (tysubst, tsubst))) =
         "ax: " ^ string_of_int ax_no ^ "; asm: " ^ string_of_int subgoal_no ^
-        "; tyenv: " ^ PolyML.makestring tyenv ^ "; tsubst: {" ^
+        "; tysubst: " ^ PolyML.makestring tysubst ^ "; tsubst: {" ^
         commas (map ((fn (s, t) => s ^ " |-> " ^ t)
                      o pairself (Syntax.string_of_term ctxt)) tsubst) ^ "}"
       val _ = tracing ("SUBSTS:\n" ^ cat_lines (map string_for_subst_info substs))
       val _ = tracing ("OUTERMOST SKOLEMS: " ^ PolyML.makestring params0)
       val _ = tracing ("ORDERED CLUSTERS: " ^ PolyML.makestring ordered_clusters)
-      val _ = tracing ("AXIOM COUNT: " ^ PolyML.makestring ax_counts)
-*)
+      val _ = tracing ("AXIOM COUNTS: " ^ PolyML.makestring ax_counts)
       fun rotation_for_subgoal i =
         find_index (fn (_, (subgoal_no, _)) => subgoal_no = i) substs
     in
       Goal.prove ctxt [] [] @{prop False}
-          (K (cut_rules_tac (map (fst o nth axioms o fst) ax_counts) 1
+          (K (cut_rules_tac (map (fst o nth axioms o fst o fst) ax_counts) 1
               THEN TRY (REPEAT_ALL_NEW (etac @{thm exE}) 1)
               THEN copy_prems_tac (map snd ax_counts) [] 1
-              THEN print_tac "copied axioms:"
               THEN release_clusters_tac thy ax_counts substs params0
                                         ordered_clusters 1
               THEN print_tac "released axioms:"
               THEN match_tac [prems_imp_false] 1
               THEN print_tac "applied rule:"
-              THEN ALLGOALS (fn i => rtac @{thm skolem_COMBK_I} i
-                                     THEN rotate_tac (rotation_for_subgoal i) i
-                                     THEN assume_tac i)))
+              THEN ALLGOALS (fn i =>
+                       rtac @{thm skolem_COMBK_I} i
+                       THEN rotate_tac (rotation_for_subgoal i) i
+                       THEN PRIMITIVE (unify_prem_with_concl thy i)
+                       THEN assume_tac i)))
     end
 
 (* Main function to start Metis proof and reconstruction *)
@@ -281,7 +366,7 @@
                   ();
                 case result of
                     (_,ith)::_ =>
-                        (trace_msg (fn () => "Success: " ^ Display.string_of_thm ctxt ith);
+                        (tracing(*###*) ("Success: " ^ Display.string_of_thm ctxt ith);
                          [discharge_skolem_premises ctxt dischargers ith])
                   | _ => (trace_msg (fn () => "Metis: No result"); [])
             end