slightly more efficient claset operations, using Item_Net to maintain rules in canonical order;
authorwenzelm
Sat, 14 May 2011 21:42:17 +0200
changeset 42810 2425068fe13a
parent 42808 30870aee8a3f
child 42811 c5146d5fc54c
slightly more efficient claset operations, using Item_Net to maintain rules in canonical order;
src/HOL/Tools/Sledgehammer/sledgehammer_filter.ML
src/Provers/classical.ML
src/Pure/item_net.ML
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_filter.ML	Sat May 14 18:29:06 2011 +0200
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_filter.ML	Sat May 14 21:42:17 2011 +0200
@@ -799,8 +799,8 @@
 fun clasimpset_rules_of ctxt =
   let
     val {safeIs, safeEs, hazIs, hazEs, ...} = ctxt |> claset_of |> rep_cs
-    val intros = safeIs @ hazIs
-    val elims = map Classical.classical_rule (safeEs @ hazEs)
+    val intros = Item_Net.content safeIs @ Item_Net.content hazIs
+    val elims = map Classical.classical_rule (Item_Net.content safeEs @ Item_Net.content hazEs)
     val simps = ctxt |> simpset_of |> dest_ss |> #simps
   in
     (mk_fact_table I I intros,
--- a/src/Provers/classical.ML	Sat May 14 18:29:06 2011 +0200
+++ b/src/Provers/classical.ML	Sat May 14 21:42:17 2011 +0200
@@ -37,11 +37,12 @@
 sig
   type claset
   val empty_cs: claset
+  val merge_cs: claset * claset -> claset
   val rep_cs: claset ->
-   {safeIs: thm list,
-    safeEs: thm list,
-    hazIs: thm list,
-    hazEs: thm list,
+   {safeIs: thm Item_Net.T,
+    safeEs: thm Item_Net.T,
+    hazIs: thm Item_Net.T,
+    hazEs: thm Item_Net.T,
     swrappers: (string * (Proof.context -> wrapper)) list,
     uwrappers: (string * (Proof.context -> wrapper)) list,
     safe0_netpair: netpair,
@@ -214,10 +215,10 @@
 
 datatype claset =
   CS of
-   {safeIs         : thm list,                (*safe introduction rules*)
-    safeEs         : thm list,                (*safe elimination rules*)
-    hazIs          : thm list,                (*unsafe introduction rules*)
-    hazEs          : thm list,                (*unsafe elimination rules*)
+   {safeIs         : thm Item_Net.T,          (*safe introduction rules*)
+    safeEs         : thm Item_Net.T,          (*safe elimination rules*)
+    hazIs          : thm Item_Net.T,          (*unsafe introduction rules*)
+    hazEs          : thm Item_Net.T,          (*unsafe elimination rules*)
     swrappers      : (string * (Proof.context -> wrapper)) list, (*for transforming safe_step_tac*)
     uwrappers      : (string * (Proof.context -> wrapper)) list, (*for transforming step_tac*)
     safe0_netpair  : netpair,                 (*nets for trivial cases*)
@@ -244,10 +245,10 @@
 
 val empty_cs =
   CS
-   {safeIs = [],
-    safeEs = [],
-    hazIs = [],
-    hazEs = [],
+   {safeIs = Thm.full_rules,
+    safeEs = Thm.full_rules,
+    hazIs = Thm.full_rules,
+    hazEs = Thm.full_rules,
     swrappers = [],
     uwrappers = [],
     safe0_netpair = empty_netpair,
@@ -294,9 +295,6 @@
 fun delete x = delete_tagged_list (joinrules x);
 fun delete' x = delete_tagged_list (joinrules' x);
 
-val mem_thm = member Thm.eq_thm_prop
-and rem_thm = remove Thm.eq_thm_prop;
-
 fun string_of_thm NONE = Display.string_of_thm_without_context
   | string_of_thm (SOME context) =
       Display.string_of_thm (Context.cases Syntax.init_pretty_global I context);
@@ -312,7 +310,7 @@
   else ();
 
 fun warn_rules context msg rules th =
-  mem_thm rules th andalso (warn_thm context msg th; true);
+  Item_Net.member rules th andalso (warn_thm context msg th; true);
 
 fun warn_claset context th (CS {safeIs, safeEs, hazIs, hazEs, ...}) =
   warn_rules context "Rule already declared as safe introduction (intro!)\n" safeIs th orelse
@@ -332,12 +330,12 @@
       val th' = flat_rule th;
       val (safe0_rls, safep_rls) = (*0 subgoals vs 1 or more*)
         List.partition Thm.no_prems [th'];
-      val nI = length safeIs + 1;
-      val nE = length safeEs;
+      val nI = Item_Net.length safeIs + 1;
+      val nE = Item_Net.length safeEs;
       val _ = warn_claset context th cs;
     in
       CS
-       {safeIs  = th::safeIs,
+       {safeIs = Item_Net.update th safeIs,
         safe0_netpair = insert (nI,nE) (safe0_rls, []) safe0_netpair,
         safep_netpair = insert (nI,nE) (safep_rls, []) safep_netpair,
         safeEs = safeEs,
@@ -361,12 +359,12 @@
       val th' = classical_rule (flat_rule th);
       val (safe0_rls, safep_rls) = (*0 subgoals vs 1 or more*)
         List.partition (fn rl => nprems_of rl=1) [th'];
-      val nI = length safeIs;
-      val nE = length safeEs + 1;
+      val nI = Item_Net.length safeIs;
+      val nE = Item_Net.length safeEs + 1;
       val _ = warn_claset context th cs;
     in
       CS
-       {safeEs  = th::safeEs,
+       {safeEs = Item_Net.update th safeEs,
         safe0_netpair = insert (nI,nE) ([], safe0_rls) safe0_netpair,
         safep_netpair = insert (nI,nE) ([], safep_rls) safep_netpair,
         safeIs = safeIs,
@@ -391,12 +389,12 @@
   else
     let
       val th' = flat_rule th;
-      val nI = length hazIs + 1;
-      val nE = length hazEs;
+      val nI = Item_Net.length hazIs + 1;
+      val nE = Item_Net.length hazEs;
       val _ = warn_claset context th cs;
     in
       CS
-       {hazIs = th :: hazIs,
+       {hazIs = Item_Net.update th hazIs,
         haz_netpair = insert (nI, nE) ([th'], []) haz_netpair,
         dup_netpair = insert (nI, nE) ([dup_intr th'], []) dup_netpair,
         safeIs = safeIs,
@@ -420,12 +418,12 @@
   else
     let
       val th' = classical_rule (flat_rule th);
-      val nI = length hazIs;
-      val nE = length hazEs + 1;
+      val nI = Item_Net.length hazIs;
+      val nE = Item_Net.length hazEs + 1;
       val _ = warn_claset context th cs;
     in
       CS
-       {hazEs = th :: hazEs,
+       {hazEs = Item_Net.update th hazEs,
         haz_netpair = insert (nI, nE) ([], [th']) haz_netpair,
         dup_netpair = insert (nI, nE) ([], [dup_elim th']) dup_netpair,
         safeIs = safeIs,
@@ -450,7 +448,7 @@
 fun delSI th
     (cs as CS {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers,
       safe0_netpair, safep_netpair, haz_netpair, dup_netpair, xtra_netpair}) =
-  if mem_thm safeIs th then
+  if Item_Net.member safeIs th then
     let
       val th' = flat_rule th;
       val (safe0_rls, safep_rls) = List.partition Thm.no_prems [th'];
@@ -458,7 +456,7 @@
       CS
        {safe0_netpair = delete (safe0_rls, []) safe0_netpair,
         safep_netpair = delete (safep_rls, []) safep_netpair,
-        safeIs = rem_thm th safeIs,
+        safeIs = Item_Net.remove th safeIs,
         safeEs = safeEs,
         hazIs = hazIs,
         hazEs = hazEs,
@@ -473,7 +471,7 @@
 fun delSE th
     (cs as CS {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers,
       safe0_netpair, safep_netpair, haz_netpair, dup_netpair, xtra_netpair}) =
-  if mem_thm safeEs th then
+  if Item_Net.member safeEs th then
     let
       val th' = classical_rule (flat_rule th);
       val (safe0_rls, safep_rls) = List.partition (fn rl => nprems_of rl = 1) [th'];
@@ -482,7 +480,7 @@
        {safe0_netpair = delete ([], safe0_rls) safe0_netpair,
         safep_netpair = delete ([], safep_rls) safep_netpair,
         safeIs = safeIs,
-        safeEs = rem_thm th safeEs,
+        safeEs = Item_Net.remove th safeEs,
         hazIs = hazIs,
         hazEs = hazEs,
         swrappers = swrappers,
@@ -496,14 +494,14 @@
 fun delI context th
     (cs as CS {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers,
       safe0_netpair, safep_netpair, haz_netpair, dup_netpair, xtra_netpair}) =
-  if mem_thm hazIs th then
+  if Item_Net.member hazIs th then
     let val th' = flat_rule th in
       CS
        {haz_netpair = delete ([th'], []) haz_netpair,
         dup_netpair = delete ([dup_intr th'], []) dup_netpair,
         safeIs = safeIs,
         safeEs = safeEs,
-        hazIs = rem_thm th hazIs,
+        hazIs = Item_Net.remove th hazIs,
         hazEs = hazEs,
         swrappers = swrappers,
         uwrappers = uwrappers,
@@ -518,7 +516,7 @@
 fun delE th
     (cs as CS {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers,
       safe0_netpair, safep_netpair, haz_netpair, dup_netpair, xtra_netpair}) =
-  if mem_thm hazEs th then
+  if Item_Net.member hazEs th then
     let val th' = classical_rule (flat_rule th) in
       CS
        {haz_netpair = delete ([], [th']) haz_netpair,
@@ -526,7 +524,7 @@
         safeIs = safeIs,
         safeEs = safeEs,
         hazIs = hazIs,
-        hazEs = rem_thm th hazEs,
+        hazEs = Item_Net.remove th hazEs,
         swrappers = swrappers,
         uwrappers = uwrappers,
         safe0_netpair = safe0_netpair,
@@ -538,9 +536,9 @@
 (*Delete ALL occurrences of "th" in the claset (perhaps from several lists)*)
 fun delrule context th (cs as CS {safeIs, safeEs, hazIs, hazEs, ...}) =
   let val th' = Tactic.make_elim th in
-    if mem_thm safeIs th orelse mem_thm safeEs th orelse
-      mem_thm hazIs th orelse mem_thm hazEs th orelse
-      mem_thm safeEs th' orelse mem_thm hazEs th'
+    if Item_Net.member safeIs th orelse Item_Net.member safeEs th orelse
+      Item_Net.member hazIs th orelse Item_Net.member hazEs th orelse
+      Item_Net.member safeEs th' orelse Item_Net.member hazEs th'
     then delSI th (delSE th (delI context th (delE th (delSE th' (delE th' cs)))))
     else (warn_thm context "Undeclared classical rule\n" th; cs)
   end;
@@ -570,28 +568,24 @@
 
 (* merge_cs *)
 
-(*Merge works by adding all new rules of the 2nd claset into the 1st claset.
-  Merging the term nets may look more efficient, but the rather delicate
-  treatment of priority might get muddled up.*)
+(*Merge works by adding all new rules of the 2nd claset into the 1st claset,
+  in order to preserve priorities reliably.*)
+
+fun merge_thms add thms1 thms2 =
+  fold_rev (fn thm => if Item_Net.member thms1 thm then I else add thm) (Item_Net.content thms2);
+
 fun merge_cs (cs as CS {safeIs, safeEs, hazIs, hazEs, ...},
     cs' as CS {safeIs = safeIs2, safeEs = safeEs2, hazIs = hazIs2, hazEs = hazEs2,
       swrappers, uwrappers, ...}) =
   if pointer_eq (cs, cs') then cs
   else
-    let
-      val safeIs' = fold rem_thm safeIs safeIs2;
-      val safeEs' = fold rem_thm safeEs safeEs2;
-      val hazIs' = fold rem_thm hazIs hazIs2;
-      val hazEs' = fold rem_thm hazEs hazEs2;
-    in
-      cs
-      |> fold_rev (addSI NONE NONE) safeIs'
-      |> fold_rev (addSE NONE NONE) safeEs'
-      |> fold_rev (addI NONE NONE) hazIs'
-      |> fold_rev (addE NONE NONE) hazEs'
-      |> map_swrappers (fn ws => AList.merge (op =) (K true) (ws, swrappers))
-      |> map_uwrappers (fn ws => AList.merge (op =) (K true) (ws, uwrappers))
-    end;
+    cs
+    |> merge_thms (addSI NONE NONE) safeIs safeIs2
+    |> merge_thms (addSE NONE NONE) safeEs safeEs2
+    |> merge_thms (addI NONE NONE) hazIs hazIs2
+    |> merge_thms (addE NONE NONE) hazEs hazEs2
+    |> map_swrappers (fn ws => AList.merge (op =) (K true) (ws, swrappers))
+    |> map_uwrappers (fn ws => AList.merge (op =) (K true) (ws, uwrappers));
 
 
 (* data *)
@@ -617,7 +611,7 @@
 fun print_claset ctxt =
   let
     val {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers, ...} = rep_claset_of ctxt;
-    val pretty_thms = map (Display.pretty_thm ctxt);
+    val pretty_thms = map (Display.pretty_thm ctxt) o Item_Net.content;
   in
     [Pretty.big_list "safe introduction rules (intro!):" (pretty_thms safeIs),
       Pretty.big_list "introduction rules (intro):" (pretty_thms hazIs),
--- a/src/Pure/item_net.ML	Sat May 14 18:29:06 2011 +0200
+++ b/src/Pure/item_net.ML	Sat May 14 21:42:17 2011 +0200
@@ -10,6 +10,7 @@
   type 'a T
   val init: ('a * 'a -> bool) -> ('a -> term list) -> 'a T
   val content: 'a T -> 'a list
+  val length: 'a T -> int
   val retrieve: 'a T -> term -> 'a list
   val member: 'a T -> 'a -> bool
   val merge: 'a T * 'a T -> 'a T
@@ -36,6 +37,7 @@
 fun init eq index = mk_items eq index [] ~1 Net.empty;
 
 fun content (Items {content, ...}) = content;
+fun length items = List.length (content items);
 fun retrieve (Items {net, ...}) = order_list o Net.unify_term net;