src/Provers/classical.ML
changeset 42810 2425068fe13a
parent 42807 e639d91d9073
child 42812 dda4aef7cba4
--- a/src/Provers/classical.ML	Sat May 14 18:29:06 2011 +0200
+++ b/src/Provers/classical.ML	Sat May 14 21:42:17 2011 +0200
@@ -37,11 +37,12 @@
 sig
   type claset
   val empty_cs: claset
+  val merge_cs: claset * claset -> claset
   val rep_cs: claset ->
-   {safeIs: thm list,
-    safeEs: thm list,
-    hazIs: thm list,
-    hazEs: thm list,
+   {safeIs: thm Item_Net.T,
+    safeEs: thm Item_Net.T,
+    hazIs: thm Item_Net.T,
+    hazEs: thm Item_Net.T,
     swrappers: (string * (Proof.context -> wrapper)) list,
     uwrappers: (string * (Proof.context -> wrapper)) list,
     safe0_netpair: netpair,
@@ -214,10 +215,10 @@
 
 datatype claset =
   CS of
-   {safeIs         : thm list,                (*safe introduction rules*)
-    safeEs         : thm list,                (*safe elimination rules*)
-    hazIs          : thm list,                (*unsafe introduction rules*)
-    hazEs          : thm list,                (*unsafe elimination rules*)
+   {safeIs         : thm Item_Net.T,          (*safe introduction rules*)
+    safeEs         : thm Item_Net.T,          (*safe elimination rules*)
+    hazIs          : thm Item_Net.T,          (*unsafe introduction rules*)
+    hazEs          : thm Item_Net.T,          (*unsafe elimination rules*)
     swrappers      : (string * (Proof.context -> wrapper)) list, (*for transforming safe_step_tac*)
     uwrappers      : (string * (Proof.context -> wrapper)) list, (*for transforming step_tac*)
     safe0_netpair  : netpair,                 (*nets for trivial cases*)
@@ -244,10 +245,10 @@
 
 val empty_cs =
   CS
-   {safeIs = [],
-    safeEs = [],
-    hazIs = [],
-    hazEs = [],
+   {safeIs = Thm.full_rules,
+    safeEs = Thm.full_rules,
+    hazIs = Thm.full_rules,
+    hazEs = Thm.full_rules,
     swrappers = [],
     uwrappers = [],
     safe0_netpair = empty_netpair,
@@ -294,9 +295,6 @@
 fun delete x = delete_tagged_list (joinrules x);
 fun delete' x = delete_tagged_list (joinrules' x);
 
-val mem_thm = member Thm.eq_thm_prop
-and rem_thm = remove Thm.eq_thm_prop;
-
 fun string_of_thm NONE = Display.string_of_thm_without_context
   | string_of_thm (SOME context) =
       Display.string_of_thm (Context.cases Syntax.init_pretty_global I context);
@@ -312,7 +310,7 @@
   else ();
 
 fun warn_rules context msg rules th =
-  mem_thm rules th andalso (warn_thm context msg th; true);
+  Item_Net.member rules th andalso (warn_thm context msg th; true);
 
 fun warn_claset context th (CS {safeIs, safeEs, hazIs, hazEs, ...}) =
   warn_rules context "Rule already declared as safe introduction (intro!)\n" safeIs th orelse
@@ -332,12 +330,12 @@
       val th' = flat_rule th;
       val (safe0_rls, safep_rls) = (*0 subgoals vs 1 or more*)
         List.partition Thm.no_prems [th'];
-      val nI = length safeIs + 1;
-      val nE = length safeEs;
+      val nI = Item_Net.length safeIs + 1;
+      val nE = Item_Net.length safeEs;
       val _ = warn_claset context th cs;
     in
       CS
-       {safeIs  = th::safeIs,
+       {safeIs = Item_Net.update th safeIs,
         safe0_netpair = insert (nI,nE) (safe0_rls, []) safe0_netpair,
         safep_netpair = insert (nI,nE) (safep_rls, []) safep_netpair,
         safeEs = safeEs,
@@ -361,12 +359,12 @@
       val th' = classical_rule (flat_rule th);
       val (safe0_rls, safep_rls) = (*0 subgoals vs 1 or more*)
         List.partition (fn rl => nprems_of rl=1) [th'];
-      val nI = length safeIs;
-      val nE = length safeEs + 1;
+      val nI = Item_Net.length safeIs;
+      val nE = Item_Net.length safeEs + 1;
       val _ = warn_claset context th cs;
     in
       CS
-       {safeEs  = th::safeEs,
+       {safeEs = Item_Net.update th safeEs,
         safe0_netpair = insert (nI,nE) ([], safe0_rls) safe0_netpair,
         safep_netpair = insert (nI,nE) ([], safep_rls) safep_netpair,
         safeIs = safeIs,
@@ -391,12 +389,12 @@
   else
     let
       val th' = flat_rule th;
-      val nI = length hazIs + 1;
-      val nE = length hazEs;
+      val nI = Item_Net.length hazIs + 1;
+      val nE = Item_Net.length hazEs;
       val _ = warn_claset context th cs;
     in
       CS
-       {hazIs = th :: hazIs,
+       {hazIs = Item_Net.update th hazIs,
         haz_netpair = insert (nI, nE) ([th'], []) haz_netpair,
         dup_netpair = insert (nI, nE) ([dup_intr th'], []) dup_netpair,
         safeIs = safeIs,
@@ -420,12 +418,12 @@
   else
     let
       val th' = classical_rule (flat_rule th);
-      val nI = length hazIs;
-      val nE = length hazEs + 1;
+      val nI = Item_Net.length hazIs;
+      val nE = Item_Net.length hazEs + 1;
       val _ = warn_claset context th cs;
     in
       CS
-       {hazEs = th :: hazEs,
+       {hazEs = Item_Net.update th hazEs,
         haz_netpair = insert (nI, nE) ([], [th']) haz_netpair,
         dup_netpair = insert (nI, nE) ([], [dup_elim th']) dup_netpair,
         safeIs = safeIs,
@@ -450,7 +448,7 @@
 fun delSI th
     (cs as CS {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers,
       safe0_netpair, safep_netpair, haz_netpair, dup_netpair, xtra_netpair}) =
-  if mem_thm safeIs th then
+  if Item_Net.member safeIs th then
     let
       val th' = flat_rule th;
       val (safe0_rls, safep_rls) = List.partition Thm.no_prems [th'];
@@ -458,7 +456,7 @@
       CS
        {safe0_netpair = delete (safe0_rls, []) safe0_netpair,
         safep_netpair = delete (safep_rls, []) safep_netpair,
-        safeIs = rem_thm th safeIs,
+        safeIs = Item_Net.remove th safeIs,
         safeEs = safeEs,
         hazIs = hazIs,
         hazEs = hazEs,
@@ -473,7 +471,7 @@
 fun delSE th
     (cs as CS {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers,
       safe0_netpair, safep_netpair, haz_netpair, dup_netpair, xtra_netpair}) =
-  if mem_thm safeEs th then
+  if Item_Net.member safeEs th then
     let
       val th' = classical_rule (flat_rule th);
       val (safe0_rls, safep_rls) = List.partition (fn rl => nprems_of rl = 1) [th'];
@@ -482,7 +480,7 @@
        {safe0_netpair = delete ([], safe0_rls) safe0_netpair,
         safep_netpair = delete ([], safep_rls) safep_netpair,
         safeIs = safeIs,
-        safeEs = rem_thm th safeEs,
+        safeEs = Item_Net.remove th safeEs,
         hazIs = hazIs,
         hazEs = hazEs,
         swrappers = swrappers,
@@ -496,14 +494,14 @@
 fun delI context th
     (cs as CS {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers,
       safe0_netpair, safep_netpair, haz_netpair, dup_netpair, xtra_netpair}) =
-  if mem_thm hazIs th then
+  if Item_Net.member hazIs th then
     let val th' = flat_rule th in
       CS
        {haz_netpair = delete ([th'], []) haz_netpair,
         dup_netpair = delete ([dup_intr th'], []) dup_netpair,
         safeIs = safeIs,
         safeEs = safeEs,
-        hazIs = rem_thm th hazIs,
+        hazIs = Item_Net.remove th hazIs,
         hazEs = hazEs,
         swrappers = swrappers,
         uwrappers = uwrappers,
@@ -518,7 +516,7 @@
 fun delE th
     (cs as CS {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers,
       safe0_netpair, safep_netpair, haz_netpair, dup_netpair, xtra_netpair}) =
-  if mem_thm hazEs th then
+  if Item_Net.member hazEs th then
     let val th' = classical_rule (flat_rule th) in
       CS
        {haz_netpair = delete ([], [th']) haz_netpair,
@@ -526,7 +524,7 @@
         safeIs = safeIs,
         safeEs = safeEs,
         hazIs = hazIs,
-        hazEs = rem_thm th hazEs,
+        hazEs = Item_Net.remove th hazEs,
         swrappers = swrappers,
         uwrappers = uwrappers,
         safe0_netpair = safe0_netpair,
@@ -538,9 +536,9 @@
 (*Delete ALL occurrences of "th" in the claset (perhaps from several lists)*)
 fun delrule context th (cs as CS {safeIs, safeEs, hazIs, hazEs, ...}) =
   let val th' = Tactic.make_elim th in
-    if mem_thm safeIs th orelse mem_thm safeEs th orelse
-      mem_thm hazIs th orelse mem_thm hazEs th orelse
-      mem_thm safeEs th' orelse mem_thm hazEs th'
+    if Item_Net.member safeIs th orelse Item_Net.member safeEs th orelse
+      Item_Net.member hazIs th orelse Item_Net.member hazEs th orelse
+      Item_Net.member safeEs th' orelse Item_Net.member hazEs th'
     then delSI th (delSE th (delI context th (delE th (delSE th' (delE th' cs)))))
     else (warn_thm context "Undeclared classical rule\n" th; cs)
   end;
@@ -570,28 +568,24 @@
 
 (* merge_cs *)
 
-(*Merge works by adding all new rules of the 2nd claset into the 1st claset.
-  Merging the term nets may look more efficient, but the rather delicate
-  treatment of priority might get muddled up.*)
+(*Merge works by adding all new rules of the 2nd claset into the 1st claset,
+  in order to preserve priorities reliably.*)
+
+fun merge_thms add thms1 thms2 =
+  fold_rev (fn thm => if Item_Net.member thms1 thm then I else add thm) (Item_Net.content thms2);
+
 fun merge_cs (cs as CS {safeIs, safeEs, hazIs, hazEs, ...},
     cs' as CS {safeIs = safeIs2, safeEs = safeEs2, hazIs = hazIs2, hazEs = hazEs2,
       swrappers, uwrappers, ...}) =
   if pointer_eq (cs, cs') then cs
   else
-    let
-      val safeIs' = fold rem_thm safeIs safeIs2;
-      val safeEs' = fold rem_thm safeEs safeEs2;
-      val hazIs' = fold rem_thm hazIs hazIs2;
-      val hazEs' = fold rem_thm hazEs hazEs2;
-    in
-      cs
-      |> fold_rev (addSI NONE NONE) safeIs'
-      |> fold_rev (addSE NONE NONE) safeEs'
-      |> fold_rev (addI NONE NONE) hazIs'
-      |> fold_rev (addE NONE NONE) hazEs'
-      |> map_swrappers (fn ws => AList.merge (op =) (K true) (ws, swrappers))
-      |> map_uwrappers (fn ws => AList.merge (op =) (K true) (ws, uwrappers))
-    end;
+    cs
+    |> merge_thms (addSI NONE NONE) safeIs safeIs2
+    |> merge_thms (addSE NONE NONE) safeEs safeEs2
+    |> merge_thms (addI NONE NONE) hazIs hazIs2
+    |> merge_thms (addE NONE NONE) hazEs hazEs2
+    |> map_swrappers (fn ws => AList.merge (op =) (K true) (ws, swrappers))
+    |> map_uwrappers (fn ws => AList.merge (op =) (K true) (ws, uwrappers));
 
 
 (* data *)
@@ -617,7 +611,7 @@
 fun print_claset ctxt =
   let
     val {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers, ...} = rep_claset_of ctxt;
-    val pretty_thms = map (Display.pretty_thm ctxt);
+    val pretty_thms = map (Display.pretty_thm ctxt) o Item_Net.content;
   in
     [Pretty.big_list "safe introduction rules (intro!):" (pretty_thms safeIs),
       Pretty.big_list "introduction rules (intro):" (pretty_thms hazIs),