Continue with closure_solver closure-solver
authorLukas Stevens <mail@lukas-stevens.de>
Tue, 24 Aug 2021 15:21:41 +0000
branchclosure-solver
changeset 74557 fc7fb7dae81e
parent 74556 0bedb59471f8
child 74559 7a35485e8c41
Continue with closure_solver
src/HOL/Tools/closure_tac.ML
src/HOL/Transitive_Closure.thy
--- a/src/HOL/Tools/closure_tac.ML	Mon Aug 23 13:43:02 2021 +0200
+++ b/src/HOL/Tools/closure_tac.ML	Tue Aug 24 15:21:41 2021 +0000
@@ -1,6 +1,9 @@
+type closure_ctor = (Closure_Procedure.inta * Closure_Procedure.inta)
+                    -> Closure_Procedure.closure_atom
+
 signature CLOSURE_TAC_ARGS =
 sig
-  val mk_rel_typ : typ -> typ
+  val field_type_of : typ -> typ
 
   val in_pat : term -> term -> term -> term
   val in_trancl_pat : term -> term -> term -> term
@@ -20,8 +23,13 @@
   val not_reflcl_eq_and_in_conv : conv
   val in_converse_conv : conv
   val not_in_converse_conv : conv
+
+  val decomp : term
+           -> (bool * (term * term) * (closure_ctor * term)) option
 end
 
+val closure_trace_cfg = Attrib.setup_config_bool @{binding "closure_trace"} (K false)
+
 functor Closure_Tac(Closure_Tac_Args : CLOSURE_TAC_ARGS) =
 struct
 
@@ -31,12 +39,6 @@
 fun expect _ (SOME x) = x
   | expect f NONE = f ()
 
-fun find_term t = Library.find_first (fn (t', _) => t' aconv t)
-
-fun list_curry0 f = (fn [] => f, 0)
-fun list_curry1 f = (fn [x] => f x, 1)
-fun list_curry2 f = (fn [x, y] => f x y, 2)
-
 fun dereify_closure_fm r reifytab t =
   let
     fun dereify_term' (App (App (Const "in", t1), t2)) =
@@ -66,6 +68,10 @@
         | strip x = x
   in strip (t, []) end
 
+fun list_curry0 f = (fn [] => f, 0)
+fun list_curry1 f = (fn [x] => f x, 1)
+fun list_curry2 f = (fn [x, y] => f x y, 2)
+
 fun replay_conv convs cvp =
   let
     val convs = convs @
@@ -104,7 +110,7 @@
           replay_prf_trm' assmtab p
           |> Drule.infer_instantiate' ctxt [SOME (Thm.cterm_of ctxt (dereify t))]
       | replay_prf_trm' assmtab (AppP (p1, p2)) =
-          apply2 (replay_prf_trm' assmtab) (p2, p1) |> (op COMP)
+          @{print} (apply2 (replay_prf_trm' assmtab) (p2, p1)) |> (op COMP)
       | replay_prf_trm' assmtab (AbsP (reified_t, p)) =
           let
             val t = dereify reified_t
@@ -164,138 +170,77 @@
     replay_prf_trm (replay_conv convs) dereify ctxt thmtab assmtab
   end
 
-fun partition_matches ctxt dummys pats term_of ts =
-  let
-    val thy = Proof_Context.theory_of ctxt
+  fun closure_tac simp_prems r = Subgoal.FOCUS (fn {prems=prems, context=ctxt, ...} =>
+    let
+      fun these' _ [] = []
+        | these' f (x :: xs) = case f x of NONE => these' f xs | SOME y => (x, y) :: these' f xs
+
+      val field_ty = field_type_of (Term.fastype_of r)
 
-    fun find_match t env =                      
-      Library.get_first (try (fn pat => Pattern.match thy (pat, t) env)) pats
-    
-    fun delete_dummys tab = fold Vartab.delete_safe dummys tab
-    fun filter_matches xs = fold (fn x => fn (mxs, nmxs, env) =>
-        case find_match (term_of x) env of
-          SOME (tyenv', tenv') => (x::mxs, nmxs, (delete_dummys tyenv', delete_dummys tenv'))
-        | NONE => (mxs, x::nmxs, env))
-      xs ([], [], (Vartab.empty, Vartab.empty))
-
-    fun partition xs =
-      case filter_matches xs of
-        ([], _, _) => []
-      | (mxs, nmxs, env) => (env, mxs) :: partition nmxs
-  in
-    partition ts
-  end
+      fun decomp_eq' (@{const Not} $ t) =
+            Option.map (fn (b, t') => (not b, t')) (decomp_eq' t)
+        | decomp_eq' ((Term.Const (@{const_name HOL.eq}, ty) $ x $ y)) =
+            if ty = field_ty --> field_ty --> @{typ bool} then SOME (true, (x, y)) else NONE
+        | decomp_eq' _ = NONE
+      fun decomp_eq (@{const Trueprop} $ t) = decomp_eq' t
+        | decomp_eq _ = NONE
 
-fun reify ctxt pats atom_pats atom_ctor t reifytab =
-  let
-    val thy = Proof_Context.theory_of ctxt
-    fun try_match pat t = try (Pattern.match thy (pat, t)) (Vartab.empty, Vartab.empty)
-    fun get_first_match pats t =
-      Library.get_first (fn (pat, vars, ctor) => Option.map (pair (vars, ctor)) (try_match pat t)) pats 
-    val _ = fold
-  in
-    case get_first_match pats t of
-      SOME ((vars, ctor), env) =>
-        let
-          val ts = map (snd o the o Vartab.lookup (snd env)) vars
-          val (rts, reifytab') = fold_rev (fn t => fn (rts, reifytab) =>
-                                            reify ctxt pats atom_pats atom_ctor t reifytab
-                                          |>> (fn rt => rt :: rts)) ts ([], reifytab)
-        in
-          (ctor rts, reifytab')
-        end
-    | NONE =>
-        case get_first_match atom_pats t of
-          SOME ((vars, ctor), env) =>
+      val prems = filter (fn p => null (Term.add_vars (Thm.prop_of p) [])) (simp_prems @ prems)
+      val decomp_prems = these' (decomp o Thm.prop_of) prems
+                         |> filter (fn (_, (_, _, (_, r'))) => r aconv r')
+      val decomp_eq_prems = these' (decomp_eq o Thm.prop_of) prems
+
+      fun reify (_, (b, (x, y), (ctor, _))) (ps, reifytab) =
+        (Reifytab.get_var x ##>> Reifytab.get_var y) reifytab
+        |>> (fn vp => (b, (ctor (apply2 Int_of_integer vp))) :: ps)
+      val (reified_prems, reifytab) = fold_rev reify decomp_prems ([], Reifytab.empty)
+
+      fun reify_eq (_, (b, (x, y))) (ps, reifytab) =
+        (Reifytab.get_var x ##>> Reifytab.get_var y) reifytab
+        |>> (fn vp => (b, EQ (apply2 Int_of_integer vp)) :: ps)
+      val (reified_eq_prems, reifytab) = fold_rev reify_eq decomp_eq_prems ([], reifytab)
+
+      fun closure_procedure [] _ _ = no_tac
+        | closure_procedure prems reified_prems reifytab =
             let
-              val ts = map (snd o the o Vartab.lookup (snd env)) vars 
-              val (rvs, reifytab') = fold_rev (fn t => fn (rvs, reifytab) => Reifytab.get_var t reifytab
-                                               |>> (fn rv => rv :: rvs)) ts ([], reifytab)
-            in
-              (atom_ctor (ctor (map Int_of_integer rvs)), reifytab')
-            end
-        | NONE => raise TERM ("Can't reify term", [t])
-  end
-
-local
-  val aty = TVar (("'a", 0), ["HOL.type"])
-  val xn = ("x", 0)
-  val yn = ("y", 0)
-  val x = Term.Var (xn, aty)
-  val y = Term.Var (yn, aty)
-  val r = Term.Var (("r", 0), mk_rel_typ aty)
+              val _ = if Config.get ctxt closure_trace_cfg then @{print} prems else prems
 
-  val closure_pats' = [ in_trancl_pat x y r, in_rtrancl_pat x y r
-                      , in_reflcl_pat x y r, in_converse_pat x y r]
-  val closure_pats = closure_pats' @ map HOLogic.mk_not closure_pats'
-  val eq_pat = HOLogic.mk_eq (x, y)
-  fun rem_pats r = [ eq_pat, HOLogic.mk_not eq_pat, in_pat x y r, HOLogic.mk_not (in_pat x y r) ]
-in
-
-  fun reify_closure_conj ctxt r =
-    let
-      val pats = [ (@{term conj} $ x $ y, [xn, yn], (fn [x, y] => And (x, y))) ]
-      val atom_pats = [ (eq_pat, [xn, yn], (fn [x, y] => (true, EQ (x, y))))
-                      , (in_trancl_pat x y r, [xn, yn], (fn [x, y] => (true, InTcl (x, y))))
-                      , (in_rtrancl_pat x y r, [xn, yn], (fn [x, y] => (true, InRtcl (x, y))))
-                      , (in_reflcl_pat x y r, [xn, yn], (fn [x, y] => (true, InReflcl (x, y))))
-                      , (in_converse_pat x y r, [xn, yn], (fn [x, y] => (true, InConv (x, y))))
-                      , (in_pat x y r, [xn, yn], (fn [x, y] => (true, In (x, y))))
-                      ]
-      fun negate_pat (pat, vars, ctor) = (HOLogic.mk_not pat, vars, ctor #>> not)
+              val prems_conj_thm = foldl1 (fn (x, a) => @{thm conjI} OF [x, a]) prems
+                                   |> Conv.fconv_rule Thm.eta_conversion
+              val reified_prems_conj = foldl1 (fn (x, a) => And (x, a)) (map Atom reified_prems)
+    
+              val proof = Closure_Procedure.full_contr_prf reified_prems_conj
+          
+              val assmtab = Termtab.make [(Thm.prop_of prems_conj_thm, prems_conj_thm)]
+              val replay = replay_closure_prf_trm ctxt r reifytab assmtab
+            in
+              case proof of
+                NONE => no_tac
+              | SOME p => SOLVED' (resolve_tac ctxt [replay p]) 1
+            end
     in
-      reify ctxt pats (atom_pats @ map negate_pat atom_pats) Atom
-    end
-
-  fun closure_tac simp_prems = Subgoal.FOCUS (fn {prems=prems, context=ctxt, ...} =>
-    let
-      fun is_atomic p = case Thm.prop_of p of tp $ _ => tp = HOLogic.Trueprop | _ => false
-      val prems = filter (fn p => null (Term.add_vars (Thm.prop_of p) []) andalso is_atomic p)
-                         (simp_prems @ prems)
-      val part_prems = partition_matches ctxt (map (fst o Term.dest_Var) [x, y]) closure_pats
-                                         (HOLogic.dest_Trueprop o Thm.prop_of) prems
-
-      fun closure_tac' (_, []) = no_tac
-        | closure_tac' (env, matched_prems) =
-          let
-            val thy = Proof_Context.theory_of ctxt
-            val r = Vartab.lookup (snd env) (fst (Term.dest_Var r)) |> the |> snd
+      closure_procedure (map fst decomp_prems @ map fst decomp_eq_prems)
+                        (reified_prems @ reified_eq_prems)
+                        reifytab
+    end)
+ 
+  val ad_absurdum_tac = SUBGOAL (fn (A, i) =>
+      case try (HOLogic.dest_Trueprop o Logic.strip_assums_concl) A of
+        SOME (nt $ _) =>
+          if nt = HOLogic.Not
+            then resolve0_tac [@{thm notI}] i
+            else resolve0_tac [@{thm ccontr}] i
+      | _ => resolve0_tac [@{thm ccontr}] i)
 
-            fun try_match pats t =
-              Library.get_first (fn pat => try (Pattern.match thy (pat, t)) env) pats
-            fun try_match_thm pats p = try_match pats (HOLogic.dest_Trueprop (Thm.prop_of p))
-            val prems = matched_prems @
-                        filter (Option.isSome o try_match_thm (rem_pats r)) prems
-
-            val prems_conj_thm = foldl1 (fn (x, a) => @{thm conjI} OF [x, a]) prems
-              |> Conv.fconv_rule Thm.eta_conversion 
-            val prems_conj = Thm.prop_of prems_conj_thm
-            val (reified_prems_conj, reifytab) =
-              reify_closure_conj ctxt r (HOLogic.dest_Trueprop prems_conj) Reifytab.empty
-
-            val proof = Closure_Procedure.full_contr_prf reified_prems_conj
-
-            val assmtab = Termtab.make [(prems_conj, prems_conj_thm)]
-            val replay = replay_closure_prf_trm ctxt r reifytab assmtab
-          in
-            case proof of
-              NONE => no_tac
-            | SOME p => SOLVED' (resolve_tac ctxt [replay p]) 1
-          end
+  fun tac simp_prems ctxt = SUBGOAL (fn (A, i) =>
+    let val goal = Logic.strip_assums_concl A
     in
-      FIRST (map closure_tac' part_prems)
+      if null (Term.add_vars goal []) then
+        case decomp goal of
+          NONE => no_tac
+        | SOME (_, _, (_, r)) => EVERY' [ ad_absurdum_tac, CONVERSION Thm.eta_conversion
+                                        , closure_tac simp_prems r ctxt ] i
+      else no_tac
     end)
-end
-
-val ad_absurdum_tac = SUBGOAL (fn (A, i) =>
-    case try (HOLogic.dest_Trueprop o Logic.strip_assums_concl) A of
-      SOME (nt $ _) =>
-        if nt = HOLogic.Not
-          then resolve0_tac [@{thm notI}] i
-          else resolve0_tac [@{thm ccontr}] i
-    | _ => resolve0_tac [@{thm ccontr}] i)
-
-fun tac simp_prems ctxt =
-  EVERY' [ad_absurdum_tac, CONVERSION Thm.eta_conversion, closure_tac simp_prems ctxt]
 
 end
\ No newline at end of file
--- a/src/HOL/Transitive_Closure.thy	Mon Aug 23 13:43:02 2021 +0200
+++ b/src/HOL/Transitive_Closure.thy	Tue Aug 24 15:21:41 2021 +0000
@@ -1306,9 +1306,12 @@
   using conversep_iff by blast
 
 ML \<open>
-  local
+
+\<close>
+ML \<open>
+ local
     fun mk_closure_type r = Term.fastype_of r --> Term.fastype_of r
-  
+
     fun mk_Id ty = Const (@{const_name Id}, ty)
     fun mk_eq ty = Const (@{const_name HOL.eq}, ty)
     fun mk_reflcl mk_Id r =
@@ -1316,30 +1319,52 @@
       in Const (@{const_name sup}, r_ty --> r_ty --> r_ty) $ r $ mk_Id r_ty end
   in
     structure Closure_Tac_Rel = Closure_Tac(
-      fun mk_rel_typ ty = HOLogic.mk_prodT (ty, ty) |> HOLogic.mk_setT
-
-      fun in_pat x y r = HOLogic.mk_mem (HOLogic.mk_prod (x, y), r)
-      fun in_trancl_pat x y r = in_pat x y (Const (@{const_name trancl}, mk_closure_type r) $ r)
-      fun in_rtrancl_pat x y r = in_pat x y (Const (@{const_name rtrancl}, mk_closure_type r) $ r)
-      fun in_reflcl_pat x y r = in_pat x y (mk_reflcl mk_Id r)
-      fun in_converse_pat x y r = in_pat x y (Const (@{const_name converse}, Term.fastype_of r))
+        val field_type_of = fst o HOLogic.dest_prodT o HOLogic.dest_setT
+  
+        fun in_pat x y r = HOLogic.mk_mem (HOLogic.mk_prod (x, y), r)
+        fun in_trancl_pat x y r = in_pat x y (Const (@{const_name trancl}, mk_closure_type r) $ r)
+        fun in_rtrancl_pat x y r = in_pat x y (Const (@{const_name rtrancl}, mk_closure_type r) $ r)
+        fun in_reflcl_pat x y r = in_pat x y (mk_reflcl mk_Id r)
+        fun in_converse_pat x y r = in_pat x y (Const (@{const_name converse}, Term.fastype_of r))
+  
+        val r_into_trancl_thm = @{thm r_into_trancl}
+        val trancl_trans_thm = @{thm trancl_trans}
+        val eq_in_thm = @{lemma \<open>v = x \<Longrightarrow> w = y \<Longrightarrow> (x, y) \<in> r \<Longrightarrow> (v, w) \<in> r\<close> by simp}
+        val eq_in_trancl_thm = @{lemma \<open>v = x \<Longrightarrow> w = y \<Longrightarrow> (x, y) \<in> r\<^sup>+ \<Longrightarrow> (v, w) \<in> r\<^sup>+\<close> by simp}
+        val rtrancl_refl_thm = @{lemma \<open>x = y \<Longrightarrow> (x, y) \<in> r\<^sup>*\<close> by simp}
+        val not_rtrancl_into_not_trancl_thm = @{thm not_rtrancl_into_not_trancl}
+  
+        val rtrancl_eq_or_trancl_conv = Conv.rewr_conv @{thm eq_reflection[OF rtrancl_eq_or_trancl]}
+        val reflcl_eq_or_in_conv = Conv.rewr_conv @{thm eq_reflection[OF reflcl_eq_or_in]}
+        val not_reflcl_eq_and_in_conv = Conv.rewr_conv @{thm eq_reflection[OF not_reflcl_eq_and_in]}
+        val in_converse_conv = Conv.rewr_conv @{thm eq_reflection[OF converse_iff]}
+        val not_in_converse_conv = Conv.rewr_conv @{thm eq_reflection[OF not_in_converse]}
   
-      val r_into_trancl_thm = @{thm r_into_trancl}
-      val trancl_trans_thm = @{thm trancl_trans}
-      val eq_in_thm = @{lemma \<open>v = x \<Longrightarrow> w = y \<Longrightarrow> (x, y) \<in> r \<Longrightarrow> (v, w) \<in> r\<close> by simp}
-      val eq_in_trancl_thm = @{lemma \<open>v = x \<Longrightarrow> w = y \<Longrightarrow> (x, y) \<in> r\<^sup>+ \<Longrightarrow> (v, w) \<in> r\<^sup>+\<close> by simp}
-      val rtrancl_refl_thm = @{lemma \<open>x = y \<Longrightarrow> (x, y) \<in> r\<^sup>*\<close> by simp}
-      val not_rtrancl_into_not_trancl_thm = @{thm not_rtrancl_into_not_trancl}
-
-      val rtrancl_eq_or_trancl_conv = Conv.rewr_conv @{thm eq_reflection[OF rtrancl_eq_or_trancl]}
-      val reflcl_eq_or_in_conv = Conv.rewr_conv @{thm eq_reflection[OF reflcl_eq_or_in]}
-      val not_reflcl_eq_and_in_conv = Conv.rewr_conv @{thm eq_reflection[OF not_reflcl_eq_and_in]}
-      val in_converse_conv = Conv.rewr_conv @{thm eq_reflection[OF converse_iff]}
-      val not_in_converse_conv = Conv.rewr_conv @{thm eq_reflection[OF not_in_converse]}
-    );
+        fun decomp' (@{const Not} $ t) =
+              Option.map (fn (b, p, t') => (not b, p, t')) (decomp' t)
+          | decomp' t =
+            let
+              local open Closure_Procedure
+              in
+                fun decomp_rel (Term.Const (@{const_name rtrancl}, _) $ r) = (InRtcl, r)
+                  | decomp_rel (Term.Const (@{const_name trancl}, _) $ r) = (InTcl, r)
+                  | decomp_rel (Term.Const (@{const_name converse}, _) $ r) = (InConv, r)
+                  | decomp_rel (Term.Const (@{const_name sup}, _) $ r $ Term.Const (@{const_name Id}, _)) =
+                      (InReflcl, r)
+                  | decomp_rel r = (In, r)
+              end
+            in
+              (* We use dest_prod to ensure that r is a relation (a set of pairs) *)
+              case try (HOLogic.dest_mem #>> HOLogic.dest_prod) (Envir.beta_eta_contract t) of
+                SOME ((x, y), r) => SOME (true, (x, y), decomp_rel r)
+              | NONE => NONE
+            end
+        fun decomp (@{const Trueprop} $ t) = decomp' t
+          | decomp _ = NONE
+      );
 
     structure Closure_Tac_Pred = Closure_Tac(
-      fun mk_rel_typ ty = ty --> ty --> HOLogic.boolT
+      val field_type_of = Term.domain_type
 
       fun in_pat x y r = r $ x $ y
       fun in_trancl_pat x y r = in_pat x y (Const (@{const_name tranclp}, mk_closure_type r) $ r)
@@ -1359,8 +1384,33 @@
       val not_reflcl_eq_and_in_conv = Conv.rewr_conv @{thm eq_reflection[OF not_reflclp_eq_and_in]}
       val in_converse_conv = Conv.rewr_conv @{thm eq_reflection[OF conversep_iff]}
       val not_in_converse_conv = Conv.rewr_conv @{thm eq_reflection[OF not_in_conversep]}
+
+      fun decomp' (@{const Not} $ t) =
+            Option.map (fn (b, p, t') => (not b, p, t')) (decomp' t)
+        | decomp' t =
+          let
+            local open Closure_Procedure
+            in
+              fun decomp_rel (Term.Const (@{const_name rtranclp}, _) $ r) = (InRtcl, r)
+                | decomp_rel (Term.Const (@{const_name tranclp}, _) $ r) = (InTcl, r)
+                | decomp_rel (Term.Const (@{const_name conversep}, _) $ r) = (InConv, r)
+                | decomp_rel (Term.Const (@{const_name sup}, _) $ r $ Term.Const (@{const_name HOL.eq}, _)) =
+                    (InReflcl, r)
+                | decomp_rel r = (In, r)
+            end
+          in
+            case Envir.beta_eta_contract t of
+              r $ x $ y =>
+                try Term.fastype_of r
+                |> Option.mapPartial (fn r_ty =>
+                     if r_ty = field_type_of r_ty --> field_type_of r_ty --> @{typ bool}
+                       then SOME (true, (x, y), decomp_rel r) else NONE)
+            | _ => NONE
+          end
+      fun decomp (@{const Trueprop} $ t) = decomp' t
+        | decomp _ = NONE
     );
-  end
+  end 
 \<close>
 
 setup \<open>
@@ -1373,6 +1423,30 @@
 lemma transp_rtranclp [simp]: "transp R\<^sup>*\<^sup>*"
   by(auto simp add: transp_def)
 
+context
+  fixes r :: "'a \<Rightarrow> 'a \<Rightarrow> bool" and x y
+  assumes t: "r x y"
+begin
+
+ML \<open>
+  val pat = nth (Thm.prems_of @{thm tranclp.r_into_trancl}) 0
+  val r = Pattern.match @{theory} (pat, Thm.prop_of @{thm t}) (Vartab.empty, Vartab.empty)
+\<close>
+
+ML \<open>
+  Thm.match (nth (Thm.cprems_of @{thm tranclp.r_into_trancl}) 0, Thm.cprop_of @{thm t})
+\<close>
+
+ML \<open>
+  Drule.multi_resolve (SOME @{context}) [Conv.fconv_rule Drule.beta_eta_conversion @{thm t}] (Conv.fconv_rule Drule.beta_eta_conversion @{thm tranclp.r_into_trancl})
+  |> Seq.list_of
+\<close>
+thm tranclp.r_into_trancl[OF t]
+end
+lemma "r x y \<Longrightarrow> r\<^sup>*\<^sup>* x y"
+  using [[closure_trace]]
+  apply(simp)
+
 text \<open>Optional methods.\<close>
 
 method_setup rtrancl =