Revert changes to order solver closure-solver
authorLukas Stevens <mail@lukas-stevens.de>
Thu, 23 Sep 2021 14:11:55 +0200
branchclosure-solver
changeset 74717 e0cbf224a970
parent 74716 a744e9a22655
child 74718 3b5e7f919e5e
Revert changes to order solver
src/Provers/order_tac.ML
--- a/src/Provers/order_tac.ML	Thu Sep 23 12:03:29 2021 +0000
+++ b/src/Provers/order_tac.ML	Thu Sep 23 14:11:55 2021 +0200
@@ -78,10 +78,68 @@
   fun expect _ (SOME x) = x
     | expect f NONE = f ()
 
+  fun matches_skeleton t s = t = Term.dummy orelse
+    (case (t, s) of
+      (t0 $ t1, s0 $ s1) => matches_skeleton t0 s0 andalso matches_skeleton t1 s1
+    | _ => t aconv s)
+
+  fun dest_binop t =
+    let
+      val binop_skel = Term.dummy $ Term.dummy $ Term.dummy
+      val not_binop_skel = Logic_Sig.Not $ binop_skel
+    in
+      if matches_skeleton not_binop_skel t
+        then (case t of (_ $ (t1 $ t2 $ t3)) => (false, (t1, t2, t3)))
+        else if matches_skeleton binop_skel t
+          then (case t of (t1 $ t2 $ t3) => (true, (t1, t2, t3)))
+          else raise TERM ("Not a binop literal", [t])
+    end
+
+  fun find_term t = Library.find_first (fn (t', _) => t' aconv t)
+
+  fun reify_order_atom (eq, le, lt) t reifytab =
+    let
+      val (b, (t0, t1, t2)) =
+        (dest_binop t) handle TERM (_, _) => raise TERM ("Can't reify order literal", [t])
+      val binops = [(eq, EQ), (le, LEQ), (lt, LESS)]
+    in
+      case find_term t0 binops of
+        SOME (_, reified_bop) =>
+          reifytab
+          |> Reifytab.get_var t1 ||> Reifytab.get_var t2
+          |> (fn (v1, (v2, vartab')) =>
+               ((b, reified_bop (Int_of_integer v1, Int_of_integer v2)), vartab'))
+          |>> Atom
+      | NONE => raise TERM ("Can't reify order literal", [t])
+    end
+
+  fun reify consts reify_atom t reifytab =
+    let
+      fun reify' (t1 $ t2) reifytab =
+            let
+              val (t0, ts) = strip_comb (t1 $ t2)
+              val consts_of_arity = filter (fn (_, (_, ar)) => length ts = ar) consts
+            in
+              (case find_term t0 consts_of_arity of
+                SOME (_, (reified_op, _)) => fold_map reify' ts reifytab |>> reified_op
+              | NONE => reify_atom (t1 $ t2) reifytab)
+            end
+        | reify' t reifytab = reify_atom t reifytab
+    in
+      reify' t reifytab
+    end
+
   fun list_curry0 f = (fn [] => f, 0)
   fun list_curry1 f = (fn [x] => f x, 1)
   fun list_curry2 f = (fn [x, y] => f x y, 2)
 
+  fun reify_order_conj ord_ops =
+    let
+      val consts = map (apsnd (list_curry2 o curry)) [(Logic_Sig.conj, And), (Logic_Sig.disj, Or)]
+    in   
+      reify consts (reify_order_atom ord_ops)
+    end
+
   fun dereify_term consts reifytab t =
     let
       fun dereify_term' (App (t1, t2)) = (dereify_term' t1) $ (dereify_term' t2)
@@ -146,17 +204,7 @@
             replay_prf_trm' assmtab p
             |> Drule.infer_instantiate' ctxt [SOME (Thm.cterm_of ctxt (dereify t))]
         | replay_prf_trm' assmtab (AppP (p1, p2)) =
-            let
-              val thy = Proof_Context.theory_of ctxt
-              val (thm1, thm2) = apply2 (replay_prf_trm' assmtab) (p1, p2)
-              val prem = hd (Thm.prems_of thm1)
-              val (_, tenv) = Pattern.first_order_match thy (prem, Thm.prop_of thm2)
-                                                            (Vartab.empty, Vartab.empty)
-              val inst = Vartab.dest tenv |> map (apsnd (Thm.cterm_of ctxt o snd))
-              val thm1 = Drule.infer_instantiate ctxt inst thm1
-            in
-              thm2 COMP thm1
-            end
+            apply2 (replay_prf_trm' assmtab) (p2, p1) |> (op COMP)
         | replay_prf_trm' assmtab (AbsP (reified_t, p)) =
             let
               val t = dereify reified_t
@@ -205,134 +253,108 @@
       replay_prf_trm (replay_conv convs) dereify ctxt thmtab assmtab
     end
 
-  fun strip_Not (nt $ t) = if nt = Logic_Sig.Not then t else nt $ t
-    | strip_Not t = t
+  fun is_binop_term t =
+    let
+      fun is_included t = forall (curry (op <>) (t |> fastype_of |> domain_type)) excluded_types
+    in
+      (case dest_binop (Logic_Sig.dest_Trueprop t) of
+        (_, (binop, t1, t2)) =>
+          is_included binop andalso
+          (* Exclude terms with schematic variables since the solver can't deal with them.
+             More specifically, the solver uses Assumption.assume which does not allow schematic
+             variables in the assumed cterm.
+          *)
+          Term.add_var_names (binop $ t1 $ t2) [] = []
+      ) handle TERM (_, _) => false
+    end
 
-  fun limit_not_less [_, _, lt] ctxt decomp_prems =
+  fun partition_matches ctxt term_of pats ys =
+    let
+      val thy = Proof_Context.theory_of ctxt
+
+      fun find_match t env =
+        Library.get_first (try (fn pat => Pattern.match thy (pat, t) env)) pats
+      
+      fun filter_matches xs = fold (fn x => fn (mxs, nmxs, env) =>
+        case find_match (term_of x) env of
+          SOME env' => (x::mxs, nmxs, env')
+        | NONE => (mxs, x::nmxs, env)) xs ([], [], (Vartab.empty, Vartab.empty))
+
+      fun partition xs =
+        case filter_matches xs of
+          ([], _, _) => []
+        | (mxs, nmxs, env) => (env, mxs) :: partition nmxs
+    in
+      partition ys
+    end
+
+  fun limit_not_less [_, _, lt] ctxt prems =
     let
       val thy = Proof_Context.theory_of ctxt
       val trace = Config.get ctxt order_trace_cfg
       val limit = Config.get ctxt order_split_limit_cfg
 
       fun is_not_less_term t =
-        case try Logic_Sig.dest_Trueprop t |> Option.map strip_Not of
-          SOME (binop $ _ $ _) => Pattern.matches thy (lt, binop)
-        | NONE => false
+        (case dest_binop (Logic_Sig.dest_Trueprop t) of
+          (false, (t0, _, _)) => Pattern.matches thy (lt, t0)
+        | _ => false)
+        handle TERM _ => false
 
-      val not_less_prems = filter (is_not_less_term o Thm.prop_of o fst) decomp_prems
+      val not_less_prems = filter (is_not_less_term o Thm.prop_of) prems
       val _ = if trace andalso length not_less_prems > limit
                 then tracing "order split limit exceeded"
                 else ()
      in
-      filter_out (is_not_less_term o Thm.prop_of o fst) decomp_prems @
+      filter_out (is_not_less_term o Thm.prop_of) prems @
       take limit not_less_prems
      end
-
-  fun decomp [eq, le, lt] ctxt t =
-    let
-      fun is_excluded t = exists (fn ty => ty = fastype_of t) excluded_types
-
-      fun decomp'' (binop $ t1 $ t2) =
-            let
-              open Order_Procedure
-              val thy = Proof_Context.theory_of ctxt
-              fun try_match pat = try (Pattern.match thy (pat, binop)) (Vartab.empty, Vartab.empty)
-            in if is_excluded t1 then NONE
-               else case (try_match eq, try_match le, try_match lt) of
-                      (SOME env, _, _) => SOME (true, EQ, (t1, t2), env)
-                    | (_, SOME env, _) => SOME (true, LEQ, (t1, t2), env)
-                    | (_, _, SOME env) => SOME (true, LESS, (t1, t2), env)
-                    | _ => NONE
-            end
-        | decomp'' _ = NONE
-
-        fun decomp' (nt $ t) =
-              if nt = Logic_Sig.Not
-                then decomp'' t |> Option.map (fn (b, c, p, e) => (not b, c, p, e))
-                else decomp'' (nt $ t)
-          | decomp' t = decomp'' t
-
-    in
-      try Logic_Sig.dest_Trueprop t |> Option.mapPartial decomp'
-    end
-
-  fun maximal_envs envs =
-    let
-      fun test_opt p (SOME x) = p x
-        | test_opt _ NONE = false
-
-      fun leq_env (tyenv1, tenv1) (tyenv2, tenv2) =
-        Vartab.forall (fn (v, ty) =>
-          Vartab.lookup tyenv2 v |> test_opt (fn ty2 => ty2 = ty)) tyenv1
-        andalso
-        Vartab.forall (fn (v, (ty, t)) =>
-          Vartab.lookup tenv2 v |> test_opt (fn (ty2, t2) => ty2 = ty andalso t2 aconv t)) tenv1
-
-      fun fold_env (i, env) es = fold_index (fn (i2, env2) => fn es =>
-        if i = i2 then es else if leq_env env env2 then (i, i2) :: es else es) envs es
-      
-      val env_order = fold_index fold_env envs []
-
-      val graph = fold_index (fn (i, env) => fn g => Int_Graph.new_node (i, env) g)
-                             envs Int_Graph.empty
-      val graph = fold Int_Graph.add_edge env_order graph
-
-      val strong_conns = Int_Graph.strong_conn graph
-      val maximals =
-        filter (fn comp => length comp = length (Int_Graph.all_succs graph comp)) strong_conns
-    in
-      map (Int_Graph.all_preds graph) maximals
-    end
       
   fun order_tac raw_order_proc octxt simp_prems =
     Subgoal.FOCUS (fn {prems=prems, context=ctxt, ...} =>
       let
         val trace = Config.get ctxt order_trace_cfg
 
-        fun these' _ [] = []
-          | these' f (x :: xs) = case f x of NONE => these' f xs | SOME y => (x, y) :: these' f xs
-
-        val prems = filter (fn p => null (Term.add_vars (Thm.prop_of p) [])) (simp_prems @ prems)
-        val decomp_prems = these' (decomp (#ops octxt) ctxt o Thm.prop_of) prems
-
-        fun env_of (_, (_, _, _, env)) = env
-        val env_groups = maximal_envs (map env_of decomp_prems)
-        
-        fun order_tac' (_, []) = no_tac
-          | order_tac' (env, decomp_prems) =
-            let
-              val [eq, le, lt] = #ops octxt |> map (Envir.subst_term env) |> map Envir.eta_contract
+        val binop_prems = filter (is_binop_term o Thm.prop_of) (prems @ simp_prems)
+        val strip_binop = (fn (x, _, _) => x) o snd o dest_binop
+        val binop_of = strip_binop o Logic_Sig.dest_Trueprop o Thm.prop_of
 
-              val decomp_prems = case #kind octxt of
-                                   Order => limit_not_less (#ops octxt) ctxt decomp_prems
-                                 | _ => decomp_prems
-      
-              fun reify_prem (_, (b, ctor, (x, y), _)) (ps, reifytab) =
-                (Reifytab.get_var x ##>> Reifytab.get_var y) reifytab
-                |>> (fn vp => (b, ctor (apply2 Int_of_integer vp)) :: ps)
-              val (reified_prems, reifytab) = fold_rev reify_prem decomp_prems ([], Reifytab.empty)
-
-              val _ = if trace then @{print} ([eq, le, lt], reified_prems, prems)
-                               else ([eq, le, lt], reified_prems, prems)
+        (* Due to local_setup, the operators of the order may contain schematic term and type
+           variables. We partition the premises according to distinct instances of those operators.
+         *)
+        val part_prems = partition_matches ctxt binop_of (#ops octxt) binop_prems
+          |> (case #kind octxt of
+                Order => map (fn (env, prems) =>
+                          (env, limit_not_less (#ops octxt) ctxt prems))
+              | _ => I)
+              
+        fun order_tac' (_, []) = no_tac
+          | order_tac' (env, prems) =
+            let
+              val [eq, le, lt] = #ops octxt
+              val subst_contract = Envir.eta_contract o Envir.subst_term env
+              val ord_ops = (subst_contract eq,
+                             subst_contract le,
+                             subst_contract lt)
   
-              val reified_prems_conj = foldl1 (fn (x, a) => And (x, a)) (map Atom reified_prems)
-              val prems_conj_thm = map fst decomp_prems
-                                   |> foldl1 (fn (x, a) => Logic_Sig.conjI OF [x, a])
-                                   |> Conv.fconv_rule Thm.eta_conversion 
+              val _ = if trace then @{print} (ord_ops, prems) else (ord_ops, prems)
+  
+              val prems_conj_thm = foldl1 (fn (x, a) => Logic_Sig.conjI OF [x, a]) prems
+                |> Conv.fconv_rule Thm.eta_conversion 
               val prems_conj = prems_conj_thm |> Thm.prop_of
-
+              val (reified_prems_conj, reifytab) =
+                reify_order_conj ord_ops (Logic_Sig.dest_Trueprop prems_conj) Reifytab.empty
+  
               val proof = raw_order_proc reified_prems_conj
   
               val assmtab = Termtab.make [(prems_conj, prems_conj_thm)]
-              val replay = replay_order_prf_trm (eq, le, lt) octxt ctxt reifytab assmtab
+              val replay = replay_order_prf_trm ord_ops octxt ctxt reifytab assmtab
             in
               case proof of
                 NONE => no_tac
               | SOME p => SOLVED' (resolve_tac ctxt [replay p]) 1
             end
      in
-       map (fn is => ` (env_of o hd) (map (nth decomp_prems) is) |> order_tac') env_groups
-       |> FIRST
+      FIRST (map order_tac' part_prems)
      end)
 
   val ad_absurdum_tac = SUBGOAL (fn (A, i) =>
@@ -345,8 +367,11 @@
       | NONE => resolve0_tac [Logic_Sig.ccontr] i)
 
   fun tac raw_order_proc octxt simp_prems ctxt =
-        EVERY' [ ad_absurdum_tac, CONVERSION Thm.eta_conversion
-               , order_tac raw_order_proc octxt simp_prems ctxt]
+      EVERY' [
+          ad_absurdum_tac,
+          CONVERSION Thm.eta_conversion,
+          order_tac raw_order_proc octxt simp_prems ctxt
+        ]
   
 end