instantiate foralls and release exists in the order described by the topological order
authorblanchet
Mon, 04 Oct 2010 14:36:18 +0200
changeset 39933 e764c5cf01fe
parent 39932 acde1b606b0e
child 39934 9f116d095e5e
instantiate foralls and release exists in the order described by the topological order
src/HOL/Tools/Sledgehammer/metis_tactics.ML
--- a/src/HOL/Tools/Sledgehammer/metis_tactics.ML	Mon Oct 04 14:34:15 2010 +0200
+++ b/src/HOL/Tools/Sledgehammer/metis_tactics.ML	Mon Oct 04 14:36:18 2010 +0200
@@ -57,74 +57,67 @@
    models = []}
 val resolution_params = {active = active_params, waiting = waiting_params}
 
-(* In principle, it should be sufficient to apply "assume_tac" to unify the
-   conclusion with one of the premises. However, in practice, this fails
-   horribly because of the mildly higher-order nature of the unification
-   problems. Typical constraints are of the form "?x a b =?= b", where "a" and
-   "b" are goal parameters. *)
-fun unify_prem_with_concl thy i th =
+fun shuffle_key (((axiom_no, (_, index_no)), _), _) = (index_no, axiom_no)
+fun shuffle_ord p =
+  rev_order (prod_ord int_ord int_ord (pairself shuffle_key p))
+
+val copy_prem = @{lemma "P ==> (P ==> P ==> Q) ==> Q" by fast}
+
+fun copy_prems_tac [] ns i =
+    if forall (curry (op =) 1) ns then all_tac else copy_prems_tac (rev ns) [] i
+  | copy_prems_tac (1 :: ms) ns i =
+    rotate_tac 1 i THEN copy_prems_tac ms (1 :: ns) i
+  | copy_prems_tac (m :: ms) ns i =
+    etac copy_prem i THEN copy_prems_tac ms (m div 2 :: (m + 1) div 2 :: ns) i
+
+fun instantiate_forall_tac thy params t i =
   let
-    val goal = Logic.get_goal (prop_of th) i |> Envir.beta_eta_contract
-    val prem = goal |> Logic.strip_assums_hyp |> the_single
-    val concl = goal |> Logic.strip_assums_concl
-    fun add_types Tp instT =
-      if exists (curry (op =) Tp) instT then instT
-      else Tp :: map (apsnd (typ_subst_atomic [Tp])) instT
-    fun unify_types (T, U) =
-      if T = U then
-        I
-      else case (T, U) of
-        (TVar _, _) => add_types (T, U)
-      | (_, TVar _) => add_types (U, T)
-      | (Type (s, Ts), Type (t, Us)) =>
-        if s = t andalso length Ts = length Us then fold unify_types (Ts ~~ Us)
-        else raise TYPE ("unify_types", [T, U], [])
-      | _ => raise TYPE ("unify_types", [T, U], [])
-    fun pair_untyped_aconv (t1, t2) (u1, u2) =
-      untyped_aconv t1 u1 andalso untyped_aconv t2 u2
-    fun add_terms tp inst =
-      if exists (pair_untyped_aconv tp) inst then inst
-      else tp :: map (apsnd (subst_atomic [tp])) inst
-    fun is_flex t =
-      case strip_comb t of
-        (Var _, args) => forall (is_Bound orf is_Var (*FIXME: orf is_Free*)) args
-      | _ => false
-    fun unify_flex flex rigid =
-      case strip_comb flex of
-        (Var (z as (_, T)), args) =>
-        add_terms (Var z,
-          (* FIXME: reindex bound variables *)
-          fold_rev (curry absdummy) (take (length args) (binder_types T)) rigid)
-      | _ => raise TERM ("unify_flex: expected flex", [flex])
-    fun unify_potential_flex comb atom =
-      if is_flex comb then unify_flex comb atom
-      else if is_Var atom then add_terms (atom, comb)
-      else raise TERM ("unify_terms", [comb, atom])
-    fun unify_terms (t, u) =
-      case (t, u) of
-        (t1 $ t2, u1 $ u2) =>
-        if is_flex t then unify_flex t u
-        else if is_flex u then unify_flex u t
-        else fold unify_terms [(t1, u1), (t2, u2)]
-      | (_ $ _, _) => unify_potential_flex t u
-      | (_, _ $ _) => unify_potential_flex u t
-      | (Var _, _) => add_terms (t, u)
-      | (_, Var _) => add_terms (u, t)
-      | _ => if untyped_aconv t u then I else raise TERM ("unify_terms", [t, u])
+    fun repair (t as (Var ((s, _), _))) =
+        (case find_index (fn ((s', _), _) => s' = s) params of
+           ~1 => t
+         | j => Bound j)
+      | repair (t $ u) = repair t $ repair u
+      | repair t = t
+    val t' = t |> repair |> fold (curry absdummy) (map snd params)
+    fun do_instantiate th =
+      let val var = Term.add_vars (prop_of th) [] |> the_single in
+        th |> Thm.instantiate ([], [(cterm_of thy (Var var), cterm_of thy t')])
+      end
+  in
+    etac @{thm allE} i
+    THEN rotate_tac ~1 i
+    THEN PRIMITIVE do_instantiate
+  end
 
-    val inst = [] |> unify_terms (prem, concl)
-    val _ = trace_msg (fn () => cat_lines (map (fn (t, u) =>
-        Syntax.string_of_term @{context} t ^ " |-> " ^
-        Syntax.string_of_term @{context} u) inst))
-    val instT = fold (fn Tp => unify_types (pairself fastype_of Tp)
-                               handle TERM _ => I) inst []
-    val inst = inst |> map (pairself (subst_atomic_types instT))
-    val cinstT = instT |> map (pairself (ctyp_of thy))
-    val cinst = inst |> map (pairself (cterm_of thy))
-  in th |> Thm.instantiate (cinstT, []) |> Thm.instantiate ([], cinst) end
-  handle Empty => th (* ### FIXME *)
+(*### TODO: fix confusion between ax_no and prem_no *)
+fun release_clusters_tac _ _ _ _ [] = K all_tac
+  | release_clusters_tac thy ax_counts substs params
+                         ((ax_no, cluster_no) :: clusters) =
+    let
+      val n = AList.lookup (op =) ax_counts ax_no |> the
+      fun in_right_cluster s =
+        (s |> Meson_Clausify.cluster_of_zapped_var_name |> fst |> snd |> fst)
+        = cluster_no
+      val alls =
+        substs
+        |> maps (fn (ax_no', (_, (_, tsubst))) =>
+                    if ax_no' = ax_no then
+                      tsubst |> filter (in_right_cluster
+                                        o fst o fst o dest_Var o fst)
+                             |> map snd
+                    else
+                      [])
+      val params' = params
+    in
+      rotate_tac ax_no
+      THEN' EVERY' (map (instantiate_forall_tac thy params) alls)
+(*      THEN' ### *)
+      THEN' rotate_tac (~ ax_no)
+      THEN' release_clusters_tac thy ax_counts substs params' clusters
+   end
 
-val cluster_ord = prod_ord (prod_ord int_ord int_ord) bool_ord
+val cluster_ord =
+  prod_ord (prod_ord int_ord (prod_ord int_ord int_ord)) bool_ord
 
 (* Attempts to derive the theorem "False" from a theorem of the form
    "P1 ==> ... ==> Pn ==> False", where the "Pi"s are to be discharged using the
@@ -148,17 +141,23 @@
                  |> map (Meson.term_pair_of
                          #> pairself (Envir.subst_term_types tyenv))
         in (tyenv, tsubst) end
-      fun subst_info_for_prem assm_no prem =
+      fun subst_info_for_prem prem_no prem =
         case prem of
           _ $ (Const (@{const_name skolem}, _) $ (_ $ t $ num)) =>
           let val ax_no = HOLogic.dest_nat num in
-            (ax_no, (assm_no, match_term (nth axioms ax_no |> snd, t)))
+            (ax_no, (prem_no, match_term (nth axioms ax_no |> snd, t)))
           end
         | _ => raise TERM ("discharge_skolem_premises: Malformed premise",
                            [prem])
       fun cluster_of_var_name skolem s =
-        let val (jj, skolem') = Meson_Clausify.cluster_of_zapped_var_name s in
-          if skolem' = skolem then SOME jj else NONE
+        let
+          val ((ax_no, (cluster_no, _)), skolem') =
+            Meson_Clausify.cluster_of_zapped_var_name s
+        in
+          if skolem' = skolem andalso cluster_no > 0 then
+            SOME (ax_no, cluster_no)
+          else
+            NONE
         end
       fun clusters_in_term skolem t =
         Term.add_var_names t [] |> map_filter (cluster_of_var_name skolem o fst)
@@ -168,10 +167,11 @@
         | [(ax_no, cluster_no)] =>
           SOME ((ax_no, cluster_no),
                 clusters_in_term true t
-                |> cluster_no > 0 ? cons (ax_no, cluster_no - 1))
+                |> cluster_no > 1 ? cons (ax_no, cluster_no - 1))
         | _ => raise TERM ("discharge_skolem_premises: Expected Var", [var])
       val prems = Logic.strip_imp_prems prems_imp_false_prop
-      val substs = map2 subst_info_for_prem (0 upto length prems - 1) prems
+      val substs = prems |> map2 subst_info_for_prem (0 upto length prems - 1)
+                         |> sort (int_ord o pairself fst)
       val depss = maps (map_filter deps_for_term_subst o snd o snd o snd) substs
       val clusters = maps (op ::) depss
       val ordered_clusters =
@@ -182,21 +182,41 @@
         handle Int_Pair_Graph.CYCLES _ =>
                error "Cannot replay Metis proof in Isabelle without axiom of \
                      \choice."
+      val params0 =
+        [] |> fold Term.add_vars (map snd axioms)
+           |> map (`(Meson_Clausify.cluster_of_zapped_var_name o fst o fst))
+           |> filter (fn (((_, (cluster_no, _)), skolem), _) =>
+                         cluster_no = 0 andalso skolem)
+           |> sort shuffle_ord |> map snd
+      val ax_counts =
+        Inttab.empty
+        |> fold (fn (ax_no, _) => Inttab.map_default (ax_no, 0) (Integer.add 1))
+                substs
+        |> Inttab.dest
 (* for debugging:
-      val _ = tracing ("SUBSTS: " ^ PolyML.makestring substs)
-      val _ = tracing ("ORDERED: " ^ PolyML.makestring ordered_clusters)
+      fun string_for_subst_info (ax_no, (prem_no, (tyenv, tsubst))) =
+        "ax: " ^ string_of_int ax_no ^ "; asm: " ^ string_of_int prem_no ^
+        "; tyenv: " ^ PolyML.makestring tyenv ^ "; tsubst: {" ^
+        commas (map ((fn (s, t) => s ^ " |-> " ^ t)
+                     o pairself (Syntax.string_of_term ctxt)) tsubst) ^ "}"
+      val _ = tracing ("SUBSTS:\n" ^ cat_lines (map string_for_subst_info substs))
+      val _ = tracing ("OUTERMOST SKOLEMS: " ^ PolyML.makestring params0)
+      val _ = tracing ("ORDERED CLUSTERS: " ^ PolyML.makestring ordered_clusters)
+      val _ = tracing ("AXIOM COUNT: " ^ PolyML.makestring ax_counts)
 *)
     in
       Goal.prove ctxt [] [] @{prop False}
-          (K (cut_rules_tac (map fst axioms) 1
+          (K (cut_rules_tac (map (fst o nth axioms o fst) ax_counts) 1
               THEN TRY (REPEAT_ALL_NEW (etac @{thm exE}) 1)
-              (* two copies are better than one (FIXME) *)
-              THEN etac @{lemma "P ==> (P ==> P ==> Q) ==> Q" by fast} 1
-              THEN TRY (REPEAT_ALL_NEW (etac @{thm allE}) 1)
+              THEN copy_prems_tac (map snd ax_counts) [] 1
+              THEN print_tac "copied axioms:"
+              THEN release_clusters_tac thy ax_counts substs params0
+                                        ordered_clusters 1
+              THEN print_tac "released axioms:"
               THEN match_tac [prems_imp_false] 1
+              THEN print_tac "applied rule:"
               THEN DETERM_UNTIL_SOLVED
                        (rtac @{thm skolem_COMBK_I} 1
-                        THEN PRIMITIVE (unify_prem_with_concl thy 1)
                         THEN assume_tac 1)))
     end