interface for wrappers;
authorwenzelm
Mon, 03 Dec 2001 21:02:26 +0100
changeset 12348 c3f34d7c50f8
parent 12347 6ee66b76d813
child 12349 94e812f9683e
interface for wrappers;
src/Pure/Isar/rule_context.ML
--- a/src/Pure/Isar/rule_context.ML	Mon Dec 03 21:02:08 2001 +0100
+++ b/src/Pure/Isar/rule_context.ML	Mon Dec 03 21:02:26 2001 +0100
@@ -11,8 +11,17 @@
 sig
   type netpair
   type T
+  val Snetpair: Proof.context -> netpair
+  val netpair: Proof.context -> netpair
+  val netpairs: Proof.context -> netpair list
+  val orderlist: ((int * int) * 'a) list -> 'a list
+  val orderlist_no_weight: ((int * int) * 'a) list -> 'a list
   val print_global_rules: theory -> unit
   val print_local_rules: Proof.context -> unit
+  val addSWrapper: ((int -> tactic) -> int -> tactic) -> theory -> theory
+  val addWrapper: ((int -> tactic) -> int -> tactic) -> theory -> theory
+  val Swrap: Proof.context -> (int -> tactic) -> int -> tactic
+  val wrap: Proof.context -> (int -> tactic) -> int -> tactic
   val intro_bang_global: int option -> theory attribute
   val elim_bang_global: int option -> theory attribute
   val dest_bang_global: int option -> theory attribute
@@ -33,8 +42,6 @@
   val elim_query_local: int option -> Proof.context attribute
   val dest_query_local: int option -> Proof.context attribute
   val rule_del_local: Proof.context attribute
-  val netpairs: Proof.context -> netpair list
-  val orderlist: ((int * int) * 'a) list -> 'a list
   val setup: (theory -> theory) list
 end;
 
@@ -65,7 +72,7 @@
 val rule_indexes = distinct (map #1 rule_kinds);
 
 
-(* netpairs *)
+(* raw data *)
 
 type netpair = ((int * int) * (bool * thm)) Net.net * ((int * int) * (bool * thm)) Net.net;
 val empty_netpairs: netpair list = replicate (length rule_indexes) (Net.empty, Net.empty);
@@ -74,7 +81,8 @@
  {next: int,
   rules: (int * ((int * bool) * thm)) list,
   netpairs: netpair list,
-  wrappers: ((bool * ((int -> tactic) -> int -> tactic)) * stamp) list};
+  wrappers: (((int -> tactic) -> int -> tactic) * stamp) list *
+    (((int -> tactic) -> int -> tactic) * stamp) list};
 
 fun make_rules next rules netpairs wrappers =
   Rules {next = next, rules = rules, netpairs = netpairs, wrappers = wrappers};
@@ -88,7 +96,7 @@
 fun del_rule th (rs as Rules {next, rules, netpairs, wrappers}) =
   let
     fun eq_th (_, (_, th')) = Thm.eq_thm (th, th');
-    fun del b netpair = delete_tagged_brl ((b, th), netpair);
+    fun del b netpair = delete_tagged_brl ((b, th), netpair) handle Net.DELETE => netpair;
   in
     if not (exists eq_th rules) then rs
     else make_rules next (filter_out eq_th rules) (map (del false o del true) netpairs) wrappers
@@ -98,7 +106,8 @@
   let
     fun prt_kind (i, b) =
       Pretty.big_list (the (assoc (kind_names, (i, b))) ^ ":")
-        (mapfilter (fn (_, (k, th)) => if k = (i, b) then Some (prt x th) else None) rules);
+        (mapfilter (fn (_, (k, th)) => if k = (i, b) then Some (prt x th) else None)
+          (sort (int_ord o pairself fst) rules));
   in Pretty.writeln (Pretty.chunks (map prt_kind rule_kinds)) end;
 
 
@@ -109,14 +118,14 @@
   val name = "Isar/rule_context";
   type T = T;
 
-  val empty = make_rules ~1 [] empty_netpairs [];
+  val empty = make_rules ~1 [] empty_netpairs ([], []);
   val copy = I;
   val prep_ext = I;
 
-  fun merge (Rules {rules = rules1, wrappers = wrappers1, ...},
-      Rules {rules = rules2, wrappers = wrappers2, ...}) =
+  fun merge (Rules {rules = rules1, wrappers = (ws1, ws1'), ...},
+      Rules {rules = rules2, wrappers = (ws2, ws2'), ...}) =
     let
-      val wrappers = gen_merge_lists' eq_snd wrappers1 wrappers2;
+      val wrappers = (gen_merge_lists' eq_snd ws1 ws2, gen_merge_lists' eq_snd ws1' ws2');
       val rules = gen_merge_lists' (fn ((_, (k1, th1)), (_, (k2, th2))) =>
           k1 = k2 andalso Thm.eq_thm (th1, th2)) rules1 rules2;
       val next = ~ (length rules);
@@ -142,6 +151,37 @@
 structure LocalRules = ProofDataFun(LocalRulesArgs);
 val print_local_rules = LocalRules.print;
 
+fun netpairs ctxt = let val Rules {netpairs, ...} = LocalRules.get ctxt in netpairs end;
+val Snetpair = hd o netpairs;
+val netpair = hd o tl o netpairs;
+
+
+fun untaglist [] = []
+  | untaglist [(k : int * int, x)] = [x]
+  | untaglist ((k, x) :: (rest as (k', x') :: _)) =
+      if k = k' then untaglist rest
+      else x :: untaglist rest;
+
+fun orderlist brls = untaglist (sort (prod_ord int_ord int_ord o pairself fst) brls);
+fun orderlist_no_weight brls = untaglist (sort (int_ord o pairself (snd o fst)) brls);
+
+
+(* wrappers *)
+
+fun gen_add_wrapper upd w = GlobalRules.map (fn (rs as Rules {next, rules, netpairs, wrappers}) =>
+  make_rules next rules netpairs (upd (fn ws => (w, stamp ()) :: ws) wrappers));
+
+val addSWrapper = gen_add_wrapper Library.apfst;
+val addWrapper = gen_add_wrapper Library.apsnd;
+
+
+fun gen_wrap which ctxt =
+  let val Rules {wrappers, ...} = LocalRules.get ctxt
+  in fn tac => foldr (fn ((w, _), t) => w t) (which wrappers, tac) end;
+
+val Swrap = gen_wrap #1;
+val wrap = gen_wrap #2;
+
 
 
 (** attributes **)
@@ -186,7 +226,7 @@
 (* concrete syntax *)
 
 fun add_args a b c x = Attrib.syntax
-  (Scan.lift (Scan.option (Args.bracks Args.nat) --
+  (Scan.lift (Scan.option Args.nat --
     (Args.bang >> K a || Args.query >> K c || Scan.succeed b) >> op |>)) x;
 
 fun del_args att = Attrib.syntax (Scan.lift Args.del >> K att);
@@ -210,15 +250,6 @@
 
 
 
-(** retrieving rules **)
-
-fun netpairs ctxt = let val Rules {netpairs, ...} = LocalRules.get ctxt in netpairs end;
-
-fun orderlist brls =
-  map snd (sort (prod_ord int_ord int_ord o pairself fst) brls);
-
-
-
 (** theory setup **)
 
 val setup =