add hook to insert premises in the order solver
authorLukas Stevens <mail@lukas-stevens.de>
Fri, 31 Jan 2025 16:59:12 +0100
changeset 82026 57b4e44f5bc4
parent 82017 9a8d408492a7
child 82027 9c33627cea18
add hook to insert premises in the order solver
NEWS
src/HOL/Orderings.thy
src/Provers/order_tac.ML
--- a/NEWS	Wed Jan 29 21:25:44 2025 +0100
+++ b/NEWS	Fri Jan 31 16:59:12 2025 +0100
@@ -330,6 +330,10 @@
 * Theory "HOL-Library.Adhoc_Overloading" has been moved to Pure. Minor
 INCOMPATIBILITY: need to adjust theory imports.
 
+* Theory "HOL.Orderings":
+  Added experimental support for inserting additional premises when the order solver is called.
+  This can used to e.g. extend the order solver to deal with numerals. 
+  In Isabelle/HOL, hooks can be added with HOL_Base_Order_Tac.declare_insert_prems_hook.
 
 *** ML ***
 
--- a/src/HOL/Orderings.thy	Wed Jan 29 21:25:44 2025 +0100
+++ b/src/HOL/Orderings.thy	Fri Jan 31 16:59:12 2025 +0100
@@ -568,7 +568,7 @@
         Pretty.quote (Syntax.pretty_typ ctxt (type_of t)), Pretty.brk 1]
     fun pretty_order ({kind = kind, ops = ops, ...}, _) =
       Pretty.block ([Pretty.str (@{make_string} kind), Pretty.str ":", Pretty.brk 1]
-                    @ map pretty_term ops)
+                    @ map pretty_term [#eq ops, #le ops, #lt ops])
   in
     Pretty.writeln (Pretty.big_list "order structures:" (map pretty_order orders))
   end
--- a/src/Provers/order_tac.ML	Wed Jan 29 21:25:44 2025 +0100
+++ b/src/Provers/order_tac.ML	Fri Jan 31 16:59:12 2025 +0100
@@ -44,35 +44,78 @@
   val conj_disj_distribR_conv : conv (* (Q \<or> R) \<and> P \<equiv> (Q \<and> P) \<or> (R \<and> P) *)
 end
 
-(* Control tracing output of the solver. *)
-val order_trace_cfg = Attrib.setup_config_bool @{binding "order_trace"} (K false)
-(* In partial orders, literals of the form \<not> x < y will force the order solver to perform case
-   distinctions, which leads to an exponential blowup of the runtime. The split limit controls
-   the number of literals of this form that are passed to the solver. 
- *)
-val order_split_limit_cfg = Attrib.setup_config_int @{binding "order_split_limit"} (K 8)
+signature BASE_ORDER_TAC_BASE =
+sig
+  
+  val order_trace_cfg : bool Config.T
+  val order_split_limit_cfg : int Config.T
+  
+  datatype order_kind = Order | Linorder
+  
+  type order_literal = (bool * Order_Procedure.order_atom)
+  
+  type order_ops = { eq : term, le : term, lt : term }
+  
+  val map_order_ops : (term -> term) -> order_ops -> order_ops
+  
+  type order_context = {
+      kind : order_kind,
+      ops : order_ops,
+      thms : (string * thm) list, conv_thms : (string * thm) list
+    }
+
+end
 
-datatype order_kind = Order | Linorder
-
-type order_literal = (bool * Order_Procedure.order_atom)
+structure Base_Order_Tac_Base : BASE_ORDER_TAC_BASE =
+struct
+  
+  (* Control tracing output of the solver. *)
+  val order_trace_cfg = Attrib.setup_config_bool @{binding "order_trace"} (K false)
+  (* In partial orders, literals of the form \<not> x < y will force the order solver to perform case
+     distinctions, which leads to an exponential blowup of the runtime. The split limit controls
+     the number of literals of this form that are passed to the solver. 
+   *)
+  val order_split_limit_cfg = Attrib.setup_config_int @{binding "order_split_limit"} (K 8)
+  
+  datatype order_kind = Order | Linorder
+  
+  type order_literal = (bool * Order_Procedure.order_atom)
+  
+  type order_ops = { eq : term, le : term, lt : term }
+  
+  fun map_order_ops f {eq, le, lt} = {eq = f eq, le = f le, lt = f lt}
+  
+  type order_context = {
+      kind : order_kind,
+      ops : order_ops,
+      thms : (string * thm) list, conv_thms : (string * thm) list
+    }
 
-type order_context = {
-    kind : order_kind,
-    ops : term list, thms : (string * thm) list, conv_thms : (string * thm) list
-  }
+end
 
 signature BASE_ORDER_TAC =
 sig
+  include BASE_ORDER_TAC_BASE
+   
+  type insert_prems_hook =
+    order_kind -> order_ops -> Proof.context -> (thm * (bool * term * (term * term))) list
+      -> thm list
+
+  val declare_insert_prems_hook :
+    (binding * insert_prems_hook) -> local_theory -> local_theory
+
+  val insert_prems_hook_names : Proof.context -> binding list
 
   val tac :
-        (order_literal Order_Procedure.fm -> Order_Procedure.prf_trm option)
-        -> order_context -> thm list
-        -> Proof.context -> int -> tactic
+    (order_literal Order_Procedure.fm -> Order_Procedure.prf_trm option)
+      -> order_context -> thm list
+      -> Proof.context -> int -> tactic
 end
 
 functor Base_Order_Tac(
   structure Logic_Sig : LOGIC_SIGNATURE; val excluded_types : typ list) : BASE_ORDER_TAC =
 struct
+  open Base_Order_Tac_Base
   open Order_Procedure
 
   fun expect _ (SOME x) = x
@@ -198,50 +241,48 @@
   fun strip_Not (nt $ t) = if nt = Logic_Sig.Not then t else nt $ t
     | strip_Not t = t
 
-  fun limit_not_less [_, _, lt] ctxt decomp_prems =
+  fun limit_not_less lt ctxt decomp_prems =
     let
-      val thy = Proof_Context.theory_of ctxt
       val trace = Config.get ctxt order_trace_cfg
       val limit = Config.get ctxt order_split_limit_cfg
 
       fun is_not_less_term t =
         case try (strip_Not o Logic_Sig.dest_Trueprop) t of
-          SOME (binop $ _ $ _) => Pattern.matches thy (lt, binop)
-        | NONE => false
+          SOME (binop $ _ $ _) => binop = lt
+        | _ => false
 
       val not_less_prems = filter (is_not_less_term o Thm.prop_of o fst) decomp_prems
       val _ = if trace andalso length not_less_prems > limit
                 then tracing "order split limit exceeded"
                 else ()
-     in
+    in
       filter_out (is_not_less_term o Thm.prop_of o fst) decomp_prems @
       take limit not_less_prems
-     end
+    end
 
-  fun decomp [eq, le, lt] ctxt t =
+  fun decomp {eq, le, lt} ctxt t =
     let
-      fun is_excluded t = exists (fn ty => ty = fastype_of t) excluded_types
-
       fun decomp'' (binop $ t1 $ t2) =
             let
+              fun is_excluded t = exists (fn ty => ty = fastype_of t) excluded_types
+
               open Order_Procedure
               val thy = Proof_Context.theory_of ctxt
               fun try_match pat = try (Pattern.match thy (pat, binop)) (Vartab.empty, Vartab.empty)
             in if is_excluded t1 then NONE
                else case (try_match eq, try_match le, try_match lt) of
-                      (SOME env, _, _) => SOME (true, EQ, (t1, t2), env)
-                    | (_, SOME env, _) => SOME (true, LEQ, (t1, t2), env)
-                    | (_, _, SOME env) => SOME (true, LESS, (t1, t2), env)
+                      (SOME env, _, _) => SOME ((true, EQ, (t1, t2)), env)
+                    | (_, SOME env, _) => SOME ((true, LEQ, (t1, t2)), env)
+                    | (_, _, SOME env) => SOME ((true, LESS, (t1, t2)), env)
                     | _ => NONE
             end
         | decomp'' _ = NONE
 
         fun decomp' (nt $ t) =
               if nt = Logic_Sig.Not
-                then decomp'' t |> Option.map (fn (b, c, p, e) => (not b, c, p, e))
+                then decomp'' t |> Option.map (fn ((b, c, p), e) => ((not b, c, p), e))
                 else decomp'' (nt $ t)
           | decomp' t = decomp'' t
-
     in
       try Logic_Sig.dest_Trueprop t |> Option.mapPartial decomp'
     end
@@ -273,33 +314,111 @@
     in
       map (Int_Graph.all_preds graph) maximals
     end
+
+  fun partition_prems octxt ctxt prems =
+    let
+      fun these' _ [] = []
+        | these' f (x :: xs) = case f x of NONE => these' f xs | SOME y => (x, y) :: these' f xs
+      
+      val (decomp_prems, envs) =
+        these' (decomp (#ops octxt) ctxt o Thm.prop_of) prems
+        |> map_split (fn (thm, (l, env)) => ((thm, l), env))
+          
+      val env_groups = maximal_envs envs
+    in
+      map (fn is => (map (nth decomp_prems) is, nth envs (hd is))) env_groups
+    end
+
+  local
+    fun pretty_term_list ctxt =
+      Pretty.list "" "" o map (Syntax.pretty_term (Config.put show_types true ctxt))
+    fun pretty_type_of ctxt t = Pretty.block
+      [ Pretty.str "::", Pretty.brk 1
+      , Pretty.quote (Syntax.pretty_typ ctxt (Term.fastype_of t)) ]
+  in
+    fun pretty_order_kind (okind : order_kind) = Pretty.str (@{make_string} okind)
+    fun pretty_order_ops ctxt ({eq, le, lt} : order_ops) =
+      Pretty.block [pretty_term_list ctxt [eq, le, lt], Pretty.brk 1, pretty_type_of ctxt le]
+  end
+
+  type insert_prems_hook =
+    order_kind -> order_ops -> Proof.context -> (thm * (bool * term * (term * term))) list
+      -> thm list
+
+  structure Insert_Prems_Hook_Data = Generic_Data(
+    type T = (binding * insert_prems_hook) list
+    val empty = []
+    val merge = Library.merge ((op =) o apply2 fst)
+  )
+
+  fun declare_insert_prems_hook (binding, hook) lthy =
+    lthy |> Local_Theory.declaration {syntax = false, pervasive = false, pos = \<^here>}
+      (fn phi => fn context =>
+        let
+          val binding = Morphism.binding phi binding
+        in
+          context
+          |> Insert_Prems_Hook_Data.map (Library.insert ((op =) o apply2 fst) (binding, hook))
+        end)
+
+  val insert_prems_hook_names = Context.Proof #> Insert_Prems_Hook_Data.get #> map fst
+
+  fun eval_insert_prems_hook kind order_ops ctxt decomp_prems (hookN, hook : insert_prems_hook) = 
+    let
+      fun dereify_order_op' (EQ _) = #eq order_ops
+        | dereify_order_op' (LEQ _) = #le order_ops
+        | dereify_order_op' (LESS _) = #lt order_ops
+      fun dereify_order_op oop = (~1, ~1) |> apply2 Int_of_integer |> oop |> dereify_order_op'
+      val decomp_prems =
+        decomp_prems
+        |> map (apsnd (fn (b, oop, (t1, t2)) => (b, dereify_order_op oop, (t1, t2))))
+      fun unzip (acc1, acc2) [] = (rev acc1, rev acc2)
+        | unzip (acc1, acc2) ((thm, NONE) :: ps) = unzip (acc1, thm :: acc2) ps
+        | unzip (acc1, acc2) ((thm, SOME dp) :: ps) = unzip ((thm, dp) :: acc1, acc2) ps
+      val (decomp_extra_prems, invalid_extra_prems) =
+        hook kind order_ops ctxt decomp_prems
+        |> map (swap o ` (decomp order_ops ctxt o Thm.prop_of))
+        |> unzip ([], [])
+
+      val pretty_thm_list = Pretty.list "" "" o map (Thm.pretty_thm ctxt)
+      fun pretty_trace () = 
+        [ ("order kind:", pretty_order_kind kind)
+        , ("order operators:", pretty_order_ops ctxt order_ops)
+        , ("inserted premises:", pretty_thm_list (map fst decomp_extra_prems))
+        , ("invalid premises:", pretty_thm_list invalid_extra_prems)
+        ]
+        |> map (fn (t, pp) => Pretty.block [Pretty.str t, Pretty.brk 1, pp])
+        |> Pretty.big_list ("insert premises hook " ^ Pretty.string_of (Binding.pretty hookN) 
+            ^ " called with the parameters")
+      val trace = Config.get ctxt order_trace_cfg
+      val _ = if trace then tracing (Pretty.string_of (pretty_trace ())) else ()
+    in
+      map (apsnd fst) decomp_extra_prems
+    end
       
   fun order_tac raw_order_proc octxt simp_prems =
     Subgoal.FOCUS (fn {prems=prems, context=ctxt, ...} =>
       let
         val trace = Config.get ctxt order_trace_cfg
-
-        fun these' _ [] = []
-          | these' f (x :: xs) = case f x of NONE => these' f xs | SOME y => (x, y) :: these' f xs
-
-        val prems = simp_prems @ prems
-                    |> filter (fn p => null (Term.add_vars (Thm.prop_of p) []))
-                    |> map (Conv.fconv_rule Thm.eta_conversion)
-        val decomp_prems = these' (decomp (#ops octxt) ctxt o Thm.prop_of) prems
+        
+        fun order_tac' ([], _) = no_tac
+          | order_tac' (decomp_prems, env) =
+            let
+              val (order_ops as {eq, le, lt}) =
+                #ops octxt |> map_order_ops (Envir.eta_contract o Envir.subst_term env)
+                
+              val insert_prems_hooks = Insert_Prems_Hook_Data.get (Context.Proof ctxt)
+              val inserted_decomp_prems =
+                insert_prems_hooks
+                |> maps (eval_insert_prems_hook (#kind octxt) order_ops ctxt decomp_prems)
 
-        fun env_of (_, (_, _, _, env)) = env
-        val env_groups = maximal_envs (map env_of decomp_prems)
-        
-        fun order_tac' (_, []) = no_tac
-          | order_tac' (env, decomp_prems) =
-            let
-              val [eq, le, lt] = #ops octxt |> map (Envir.eta_contract o Envir.subst_term env)
-
-              val decomp_prems = case #kind octxt of
-                                   Order => limit_not_less (#ops octxt) ctxt decomp_prems
-                                 | _ => decomp_prems
+              val decomp_prems = decomp_prems @ inserted_decomp_prems
+              val decomp_prems =
+                case #kind octxt of
+                  Order => limit_not_less lt ctxt decomp_prems
+                | _ => decomp_prems
       
-              fun reify_prem (_, (b, ctor, (x, y), _)) (ps, reifytab) =
+              fun reify_prem (_, (b, ctor, (x, y))) (ps, reifytab) =
                 (Reifytab.get_var x ##>> Reifytab.get_var y) reifytab
                 |>> (fn vp => (b, ctor (apply2 Int_of_integer vp)) :: ps)
               val (reified_prems, reifytab) = fold_rev reify_prem decomp_prems ([], Reifytab.empty)
@@ -312,21 +431,17 @@
               
               val proof = raw_order_proc reified_prems_conj
 
-              val pretty_term_list =
-                Pretty.list "" "" o map (Syntax.pretty_term (Config.put show_types true ctxt))
               val pretty_thm_list = Pretty.list "" "" o map (Thm.pretty_thm ctxt)
-              fun pretty_type_of t = Pretty.block [ Pretty.str "::", Pretty.brk 1,
-                    Pretty.quote (Syntax.pretty_typ ctxt (Term.fastype_of t)) ]
               fun pretty_trace () = 
-                [ ("order kind:", Pretty.str (@{make_string} (#kind octxt)))
-                , ("order operators:", Pretty.block [ pretty_term_list [eq, le, lt], Pretty.brk 1
-                                                     , pretty_type_of le ])
+                [ ("order kind:", pretty_order_kind (#kind octxt))
+                , ("order operators:", pretty_order_ops ctxt order_ops)
                 , ("premises:", pretty_thm_list prems)
                 , ("selected premises:", pretty_thm_list (map fst decomp_prems))
                 , ("reified premises:", Pretty.str (@{make_string} reified_prems))
                 , ("contradiction:", Pretty.str (@{make_string} (Option.isSome proof)))
-                ] |> map (fn (t, pp) => Pretty.block [Pretty.str t, Pretty.brk 1, pp])
-                  |> Pretty.big_list "order solver called with the parameters"
+                ] 
+                |> map (fn (t, pp) => Pretty.block [Pretty.str t, Pretty.brk 1, pp])
+                |> Pretty.big_list "order solver called with the parameters"
               val _ = if trace then tracing (Pretty.string_of (pretty_trace ())) else ()
 
               val assmtab = Termtab.make [(prems_conj, prems_conj_thm)]
@@ -336,9 +451,12 @@
                 NONE => no_tac
               | SOME p => SOLVED' (resolve_tac ctxt [replay p]) 1
             end
+
+        val prems = simp_prems @ prems
+                    |> filter (fn p => null (Term.add_vars (Thm.prop_of p) []))
+                    |> map (Conv.fconv_rule Thm.eta_conversion)
      in
-       map (fn is => ` (env_of o hd) (map (nth decomp_prems) is) |> order_tac') env_groups
-       |> FIRST
+      partition_prems octxt ctxt prems |> map order_tac' |> FIRST
      end)
 
   val ad_absurdum_tac = SUBGOAL (fn (A, i) =>
@@ -355,11 +473,15 @@
 end
 
 functor Order_Tac(structure Base_Tac : BASE_ORDER_TAC) = struct
+  open Base_Tac
 
   fun order_context_eq ({kind = kind1, ops = ops1, ...}, {kind = kind2, ops = ops2, ...}) =
-    kind1 = kind2 andalso eq_list (op aconv) (ops1, ops2)
-
-  fun order_data_eq (x, y) = order_context_eq (fst x, fst y)
+    let
+      fun ops_list ops = [#eq ops, #le ops, #lt ops]
+    in
+      kind1 = kind2 andalso eq_list (op aconv) (apply2 ops_list (ops1, ops2))
+    end
+  val order_data_eq = order_context_eq o apply2 fst
   
   structure Data = Generic_Data(
     type T = (order_context * (order_context -> thm list -> Proof.context -> int -> tactic)) list
@@ -371,7 +493,7 @@
     lthy |> Local_Theory.declaration {syntax = false, pervasive = false, pos = \<^here>}
       (fn phi => fn context =>
         let
-          val ops = map (Morphism.term phi) (#ops octxt)
+          val ops = map_order_ops (Morphism.term phi) (#ops octxt)
           val thms = map (fn (s, thm) => (s, Morphism.thm phi thm)) (#thms octxt)
           val conv_thms = map (fn (s, thm) => (s, Morphism.thm phi thm)) (#conv_thms octxt)
           val octxt' = {kind = kind, ops = ops, thms = thms, conv_thms = conv_thms}
@@ -380,7 +502,7 @@
         end)
 
   fun declare_order {
-      ops = {eq = eq, le = le, lt = lt},
+      ops = ops,
       thms = {
         trans = trans, (* x \<le> y \<Longrightarrow> y \<le> z \<Longrightarrow> x \<le> z *)
         refl = refl, (* x \<le> x *)
@@ -396,7 +518,7 @@
     } =
     declare {
       kind = Order,
-      ops = [eq, le, lt],
+      ops = ops,
       thms = [("trans", trans), ("refl", refl), ("eqD1", eqD1), ("eqD2", eqD2),
               ("antisym", antisym), ("contr", contr)],
       conv_thms = [("less_le", less_le), ("nless_le", nless_le)],
@@ -404,7 +526,7 @@
      }                
 
   fun declare_linorder {
-      ops = {eq = eq, le = le, lt = lt},
+      ops = ops,
       thms = {
         trans = trans, (* x \<le> y \<Longrightarrow> y \<le> z \<Longrightarrow> x \<le> z *)
         refl = refl, (* x \<le> x *)
@@ -421,13 +543,13 @@
     } =
     declare {
       kind = Linorder,
-      ops = [eq, le, lt],
+      ops = ops,
       thms = [("trans", trans), ("refl", refl), ("eqD1", eqD1), ("eqD2", eqD2),
               ("antisym", antisym), ("contr", contr)],
       conv_thms = [("less_le", less_le), ("nless_le", nless_le), ("nle_le", nle_le)],
       raw_proc = Base_Tac.tac Order_Procedure.lo_contr_prf
      }
-  
+
   (* Try to solve the goal by calling the order solver with each of the declared orders. *)      
   fun tac simp_prems ctxt =
     let fun app_tac (octxt, tac0) = CHANGED o tac0 octxt simp_prems ctxt