efficient rule declarations in canonical order, for update of netpairs and print operation;
authorwenzelm
Mon, 07 Jul 2025 22:11:44 +0200
changeset 82826 f5fd9b41188a
parent 82820 ae85cd17ffbe
child 82827 b7c1c23058cf
efficient rule declarations in canonical order, for update of netpairs and print operation;
src/Pure/Isar/context_rules.ML
src/Pure/bires.ML
--- a/src/Pure/Isar/context_rules.ML	Sun Jul 06 15:26:59 2025 +0200
+++ b/src/Pure/Isar/context_rules.ML	Mon Jul 07 22:11:44 2025 +0200
@@ -40,68 +40,61 @@
 (* context data *)
 
 datatype rules = Rules of
- {next: int,
-  rules: (int * (Bires.kind * thm)) list,
+ {decls: Bires.decls,
   netpairs: Bires.netpair list,
   wrappers:
     ((Proof.context -> (int -> tactic) -> int -> tactic) * stamp) list *
     ((Proof.context -> (int -> tactic) -> int -> tactic) * stamp) list};
 
-fun make_rules next rules netpairs wrappers =
-  Rules {next = next, rules = rules, netpairs = netpairs, wrappers = wrappers};
+fun make_rules decls netpairs wrappers =
+  Rules {decls = decls, netpairs = netpairs, wrappers = wrappers};
 
-fun add_rule kind opt_weight th (Rules {next, rules, netpairs, wrappers}) =
+fun add_rule kind opt_weight th (rules as Rules {decls, netpairs, wrappers}) =
   let
     val weight = opt_weight |> \<^if_none>\<open>Bires.subgoals_of (Bires.kind_rule kind th)\<close>;
-    val tag = {weight = weight, index = next};
-    val th' = Thm.trim_context th;
-    val rules' = (weight, (kind, th')) :: rules;
-    val netpairs' = netpairs
-      |> Bires.kind_map kind (Bires.insert_tagged_rule (tag, Bires.kind_rule kind th'));
-  in make_rules (next - 1) rules' netpairs' wrappers end;
+    val decl = {kind = kind, tag = Bires.weight_tag weight, implicit = false};
+  in
+    (case Bires.extend_decls (Thm.trim_context th, decl) decls of
+      NONE => rules
+    | SOME (new_rule, decls') =>
+        let val netpairs' = netpairs |> Bires.kind_map kind (Bires.insert_rule new_rule)
+        in make_rules decls' netpairs' wrappers end)
+  end;
 
-fun del_rule0 th (rs as Rules {next, rules, netpairs, wrappers}) =
-  let
-    fun eq_th (_, (_, th')) = Thm.eq_thm_prop (th, th');
-    fun del b netpair = Bires.delete_tagged_rule (b, th) netpair handle Net.DELETE => netpair;
-    val rules' = filter_out eq_th rules;
-    val netpairs' = map (del false o del true) netpairs;
-  in
-    if not (exists eq_th rules) then rs
-    else make_rules next rules' netpairs' wrappers
-  end;
+fun del_rule0 th (rules as Rules {decls, netpairs, wrappers}) =
+  (case Bires.remove_decls th decls of
+    NONE => rules
+  | SOME decls' =>
+      let val netpairs' = map (Bires.remove_rule th) netpairs
+      in make_rules decls' netpairs' wrappers end);
 
 fun del_rule th = del_rule0 th o del_rule0 (Tactic.make_elim th);
 
 structure Data = Generic_Data
 (
   type T = rules;
-  val empty = make_rules ~1 [] Bires.kind_netpairs ([], []);
+  val empty = make_rules Bires.empty_decls Bires.kind_netpairs ([], []);
   fun merge
-    (Rules {rules = rules1, wrappers = (ws1, ws1'), ...},
-      Rules {rules = rules2, wrappers = (ws2, ws2'), ...}) =
+    (Rules {decls = decls1, netpairs = netpairs1, wrappers = (ws1, ws1')},
+     Rules {decls = decls2, netpairs = _, wrappers = (ws2, ws2')}) =
     let
+      val (new_rules, decls) = Bires.merge_decls (decls1, decls2);
+      val netpairs =
+        netpairs1 |> map_index (uncurry (fn i =>
+          new_rules |> fold (fn (th, decl) =>
+            Bires.kind_index (#kind decl) = i ? Bires.insert_rule (th, decl))));
       val wrappers =
-        (Library.merge (eq_snd (op =)) (ws1, ws2), Library.merge (eq_snd (op =)) (ws1', ws2'));
-      val rules = Library.merge (fn ((_, (k1, th1)), (_, (k2, th2))) =>
-          k1 = k2 andalso Thm.eq_thm_prop (th1, th2)) (rules1, rules2);
-      val next = ~ (length rules);
-      val netpairs =
-        fold (fn (index, (weight, (kind, th))) =>
-          Bires.kind_map kind
-            (Bires.insert_tagged_rule ({weight = weight, index = index}, (Bires.kind_elim kind, th))))
-        (next upto ~1 ~~ rules) Bires.kind_netpairs;
-    in make_rules (next - 1) rules netpairs wrappers end;
+       (Library.merge (eq_snd (op =)) (ws1, ws2),
+        Library.merge (eq_snd (op =)) (ws1', ws2'));
+    in make_rules decls netpairs wrappers end;
 );
 
 fun print_rules ctxt =
   let
-    val Rules {rules, ...} = Data.get (Context.Proof ctxt);
+    val Rules {decls, ...} = Data.get (Context.Proof ctxt);
     fun prt_kind kind =
       Pretty.big_list (Bires.kind_title kind ^ ":")
-        (map_filter (fn (_, (kind', th)) =>
-            if kind = kind' then SOME (Thm.pretty_thm_item ctxt th) else NONE)
-          (sort (int_ord o apply2 fst) rules));
+        (Bires.print_decls kind decls |> map (fn (th, _) => Thm.pretty_thm_item ctxt th));
   in Pretty.writeln (Pretty.chunks (map prt_kind Bires.kind_domain)) end;
 
 
@@ -142,8 +135,8 @@
 (* wrappers *)
 
 fun gen_add_wrapper upd w =
-  Context.theory_map (Data.map (fn Rules {next, rules, netpairs, wrappers} =>
-    make_rules next rules netpairs (upd (fn ws => (w, stamp ()) :: ws) wrappers)));
+  Context.theory_map (Data.map (fn Rules {decls, netpairs, wrappers} =>
+    make_rules decls netpairs (upd (fn ws => (w, stamp ()) :: ws) wrappers)));
 
 val addSWrapper = gen_add_wrapper Library.apfst;
 val addWrapper = gen_add_wrapper Library.apsnd;
--- a/src/Pure/bires.ML	Sun Jul 06 15:26:59 2025 +0200
+++ b/src/Pure/bires.ML	Mon Jul 07 22:11:44 2025 +0200
@@ -13,10 +13,12 @@
   val no_subgoals: rule -> bool
 
   type tag = {weight: int, index: int}
-  val tag0_ord: tag ord
+  val tag_weight_ord: tag ord
+  val tag_index_ord: tag ord
   val tag_ord: tag ord
   val weighted_tag_ord: bool -> tag ord
   val tag_order: (tag * 'a) list -> 'a list
+  val weight_tag: int -> tag
 
   type netpair = (tag * rule) Net.net * (tag * rule) Net.net
   val empty_netpair: netpair
@@ -50,6 +52,19 @@
   val kind_map: kind -> ('a -> 'a) -> 'a list -> 'a list
   val kind_rule: kind -> thm -> rule
   val kind_title: kind -> string
+
+  type decl = {kind: kind, tag: tag, implicit: bool}
+  val decl_ord: decl ord
+  val insert_rule: thm * decl -> netpair -> netpair
+  val remove_rule: thm -> netpair -> netpair
+  type decls
+  val has_decls: decls -> thm -> bool
+  val list_decls: (thm * decl -> bool) -> decls -> (thm * decl) list
+  val print_decls: kind -> decls -> (thm * decl) list
+  val merge_decls: decls * decls -> (thm * decl) list * decls
+  val extend_decls: thm * decl -> decls -> ((thm * decl) * decls) option
+  val remove_decls: thm -> decls -> decls option
+  val empty_decls: decls
 end
 
 structure Bires: BIRES =
@@ -74,13 +89,19 @@
 
 type tag = {weight: int, index: int};
 
-val tag0_ord: tag ord = int_ord o apply2 #index;
-val tag_ord: tag ord = int_ord o apply2 #weight ||| tag0_ord;
+val tag_weight_ord: tag ord = int_ord o apply2 #weight;
+val tag_index_ord: tag ord = int_ord o apply2 #index;
 
-fun weighted_tag_ord weighted = if weighted then tag_ord else tag0_ord;
+val tag_ord: tag ord = tag_weight_ord ||| tag_index_ord;
+
+fun weighted_tag_ord weighted = if weighted then tag_ord else tag_index_ord;
 
 fun tag_order list = make_order_list tag_ord NONE list;
 
+fun weight_tag weight : tag = {weight = weight, index = 0};
+
+fun next_tag next ({weight, ...}: tag) = {weight = weight, index = next};
+
 
 (* discrimination nets for intr/elim rules *)
 
@@ -138,6 +159,8 @@
 
 (** Rule kinds and declarations **)
 
+(* kind: intro! / elim! / intro / elim / intro? / elim? *)
+
 datatype kind = Kind of int * bool;
 
 val intro_bang_kind = Kind (0, false);
@@ -172,6 +195,75 @@
   in a ^ " rules " ^ b end;
 
 
+(* rule declarations in canonical order *)
+
+type decl = {kind: kind, tag: tag, implicit: bool};
+
+val decl_ord: decl ord = tag_index_ord o apply2 #tag;
+
+fun decl_equiv (decl1: decl, decl2: decl) =
+  #kind decl1 = #kind decl2 andalso
+  is_equal (tag_weight_ord (#tag decl1, #tag decl2));
+
+fun next_decl next ({kind, tag, implicit}: decl) : decl =
+  {kind = kind, tag = next_tag next tag, implicit = implicit};
+
+fun insert_rule (thm, {kind, tag, ...}: decl) netpair =
+  insert_tagged_rule (tag, kind_rule kind thm) netpair;
+
+fun remove_rule thm =
+  let fun del b netpair = delete_tagged_rule (b, thm) netpair handle Net.DELETE => netpair
+  in del false o del true end;
+
+
+abstype decls = Decls of {next: int, rules: decl list Thmtab.table}
+with
+
+local
+
+fun dest_decls pred (Decls {rules, ...}) =
+  build (rules |> Thmtab.fold (fn (th, ds) => ds |> fold (fn d => pred (th, d) ? cons (th, d))));
+
+fun dup_decls (Decls {rules, ...}) (thm, decl) =
+  member decl_equiv (Thmtab.lookup_list rules thm) decl;
+
+fun add_decls (thm, decl) (Decls {next, rules}) =
+  let
+    val decl' = next_decl next decl;
+    val decls' = Decls {next = next - 1, rules = Thmtab.cons_list (thm, decl') rules};
+  in ((thm, decl'), decls') end;
+
+in
+
+fun has_decls (Decls {rules, ...}) = Thmtab.defined rules;
+
+fun list_decls pred =
+  dest_decls pred #> sort (rev_order o decl_ord o apply2 #2);
+
+fun print_decls kind =
+  dest_decls (fn (_, {kind = kind', implicit, ...}) => kind = kind' andalso not implicit)
+  #> sort (tag_ord o apply2 (#tag o #2));
+
+fun merge_decls (decls1, decls2) =
+  decls1 |> fold_map add_decls (list_decls (not o dup_decls decls1) decls2);
+
+fun extend_decls (thm, decl) decls =
+  if dup_decls decls (thm, decl) then NONE
+  else SOME (add_decls (thm, decl) decls);
+
+fun remove_decls thm (decls as Decls {next, rules}) =
+  if has_decls decls thm
+  then SOME (Decls {next = next, rules = Thmtab.delete thm rules})
+  else NONE;
+
+val empty_decls = Decls {next = ~1, rules = Thmtab.empty};
+
+end;
+
+end;
+
+
+
 (** Simpler version for resolve_tac -- only one net, and no hyps **)
 
 type net = (int * thm) Net.net;