performance tuning;
authorwenzelm
Mon, 10 Apr 2023 13:43:11 +0200
changeset 77800 9a30b76a6f60
parent 77799 3fb2c47a7605
child 77801 e7cf427f8b2a
performance tuning;
src/Pure/General/set.ML
src/Pure/General/table.ML
--- a/src/Pure/General/set.ML	Sun Apr 09 23:09:24 2023 +0200
+++ b/src/Pure/General/set.ML	Mon Apr 10 13:43:11 2023 +0200
@@ -46,7 +46,8 @@
   Leaf2 of elem * elem |
   Leaf3 of elem * elem * elem |
   Branch2 of T * elem * T |
-  Branch3 of T * elem * T * elem * T;
+  Branch3 of T * elem * T * elem * T |
+  Size of int * T;
 
 (*literal copy from table.ML*)
 fun make2 (Empty, e, Empty) = Leaf1 e
@@ -76,19 +77,24 @@
   | unmake (Leaf2 (e1, e2)) = Branch3 (Empty, e1, Empty, e2, Empty)
   | unmake (Leaf3 (e1, e2, e3)) =
       Branch2 (Branch2 (Empty, e1, Empty), e2, Branch2 (Empty, e3, Empty))
+  | unmake (Size (_, arg)) = arg
   | unmake arg = arg;
 
 
 (* size *)
 
 (*literal copy from table.ML*)
+fun make_size m arg = if m > 12 then Size (m, arg) else arg;
+
+(*literal copy from table.ML*)
 local
   fun count Empty n = n
     | count (Leaf1 _) n = n + 1
     | count (Leaf2 _) n = n + 2
     | count (Leaf3 _) n = n + 3
     | count (Branch2 (left, _, right)) n = count right (count left (n + 1))
-    | count (Branch3 (left, _, mid, _, right)) n = count right (count mid (count left (n + 2)));
+    | count (Branch3 (left, _, mid, _, right)) n = count right (count mid (count left (n + 2)))
+    | count (Size (m, _)) n = m + n;
 in
   val size = Integer.build o count;
 end;
@@ -100,7 +106,9 @@
 
 fun build (f: T -> T) = f empty;
 
+(*literal copy from table.ML*)
 fun is_empty Empty = true
+  | is_empty (Size (_, arg)) = is_empty arg
   | is_empty _ = false;
 
 
@@ -115,7 +123,8 @@
       | fold (Branch2 (left, e, right)) x =
           fold right (f e (fold left x))
       | fold (Branch3 (left, e1, mid, e2, right)) x =
-          fold right (f e2 (fold mid (f e1 (fold left x))));
+          fold right (f e2 (fold mid (f e1 (fold left x))))
+      | fold (Size (_, arg)) x = fold arg x;
   in fold end;
 
 fun fold_rev_set f =
@@ -127,7 +136,8 @@
       | fold_rev (Branch2 (left, e, right)) x =
           fold_rev left (f e (fold_rev right x))
       | fold_rev (Branch3 (left, e1, mid, e2, right)) x =
-          fold_rev left (f e1 (fold_rev mid (f e2 (fold_rev right x))));
+          fold_rev left (f e1 (fold_rev mid (f e2 (fold_rev right x))))
+      | fold_rev (Size (_, arg)) x = fold_rev arg x;
   in fold_rev end;
 
 val dest = Library.build o fold_rev_set cons;
@@ -144,7 +154,8 @@
       | ex (Branch2 (left, e, right)) =
           ex left orelse pred e orelse ex right
       | ex (Branch3 (left, e1, mid, e2, right)) =
-          ex left orelse pred e1 orelse ex mid orelse pred e2 orelse ex right;
+          ex left orelse pred e1 orelse ex mid orelse pred e2 orelse ex right
+      | ex (Size (_, arg)) = ex arg;
   in ex end;
 
 fun forall pred = not o exists (not o pred);
@@ -186,7 +197,8 @@
                       | some => some)
                   | some => some)
               | some => some)
-          | some => some);
+          | some => some)
+      | get (Size (_, arg)) = get arg;
   in get end;
 
 
@@ -222,7 +234,8 @@
               (case elem_ord e2 of
                 LESS => mem mid
               | EQUAL => true
-              | GREATER => mem right));
+              | GREATER => mem right))
+      | mem (Size (_, arg)) = mem arg;
   in mem set end;
 
 
@@ -285,11 +298,13 @@
                     (case ins right of
                       Stay right' => Stay (make3 (left, e1, mid, e2, right'))
                     | Sprout (right1, e', right2) =>
-                        Sprout (make2 (left, e1, mid), e2, make2 (right1, e', right2)))));
+                        Sprout (make2 (left, e1, mid), e2, make2 (right1, e', right2)))))
+        | ins (Size (_, arg)) = ins arg;
     in
-      (case ins set of
-        Stay set' => set'
-      | Sprout br => make2 br)
+      make_size (size set + 1)
+        (case ins set of
+          Stay set' => set'
+        | Sprout br => make2 br)
     end;
 
 fun make elems = build (fold insert elems);
@@ -389,12 +404,15 @@
                   make2 (mr, if_equal ord q' q, r'))
             | (Branch3 (ll, lp, lm, lq, lr), Branch2 (ml, mp, mr)) =>
                 make3 (make2 (ll, lp, lm), lq, make2 (lr, p, ml), mp,
-                  make2 (mr, if_equal ord q' q, r'))))));
+                  make2 (mr, if_equal ord q' q, r'))))))
+  | del k (Size (_, arg)) = del k arg;
 
 in
 
 fun remove elem set =
-  if member set elem then snd (snd (del (SOME elem) set)) else set;
+  if member set elem
+  then make_size (size set - 1) (snd (snd (del (SOME elem) set)))
+  else set;
 
 val subtract = fold_set remove;
 
--- a/src/Pure/General/table.ML	Sun Apr 09 23:09:24 2023 +0200
+++ b/src/Pure/General/table.ML	Mon Apr 10 13:43:11 2023 +0200
@@ -83,7 +83,8 @@
   Leaf2 of (key * 'a) * (key * 'a) |
   Leaf3 of (key * 'a) * (key * 'a) * (key * 'a) |
   Branch2 of 'a table * (key * 'a) * 'a table |
-  Branch3 of 'a table * (key * 'a) * 'a table * (key * 'a) * 'a table;
+  Branch3 of 'a table * (key * 'a) * 'a table * (key * 'a) * 'a table |
+  Size of int * 'a table;
 
 (*literal copy from set.ML*)
 fun make2 (Empty, e, Empty) = Leaf1 e
@@ -113,11 +114,15 @@
   | unmake (Leaf2 (e1, e2)) = Branch3 (Empty, e1, Empty, e2, Empty)
   | unmake (Leaf3 (e1, e2, e3)) =
       Branch2 (Branch2 (Empty, e1, Empty), e2, Branch2 (Empty, e3, Empty))
+  | unmake (Size (_, arg)) = arg
   | unmake arg = arg;
 
 
 (* size *)
 
+(*literal copy from set.ML*)
+fun make_size m arg = if m > 12 then Size (m, arg) else arg;
+
 local
   (*literal copy from set.ML*)
   fun count Empty n = n
@@ -125,7 +130,8 @@
     | count (Leaf2 _) n = n + 2
     | count (Leaf3 _) n = n + 3
     | count (Branch2 (left, _, right)) n = count right (count left (n + 1))
-    | count (Branch3 (left, _, mid, _, right)) n = count right (count mid (count left (n + 2)));
+    | count (Branch3 (left, _, mid, _, right)) n = count right (count mid (count left (n + 2)))
+    | count (Size (m, _)) n = m + n;
 in
   fun size tab = Integer.build (count tab);
 end;
@@ -137,7 +143,9 @@
 
 fun build (f: 'a table -> 'a table) = f empty;
 
+(*literal copy from set.ML*)
 fun is_empty Empty = true
+  | is_empty (Size (_, arg)) = is_empty arg
   | is_empty _ = false;
 
 
@@ -153,7 +161,8 @@
       | map (Branch2 (left, (k, x), right)) =
           Branch2 (map left, (k, f k x), map right)
       | map (Branch3 (left, (k1, x1), mid, (k2, x2), right)) =
-          Branch3 (map left, (k1, f k1 x1), map mid, (k2, f k2 x2), map right);
+          Branch3 (map left, (k1, f k1 x1), map mid, (k2, f k2 x2), map right)
+      | map (Size (m, arg)) = Size (m, map arg);
   in map end;
 
 fun fold_table f =
@@ -165,7 +174,8 @@
       | fold (Branch2 (left, e, right)) x =
           fold right (f e (fold left x))
       | fold (Branch3 (left, e1, mid, e2, right)) x =
-          fold right (f e2 (fold mid (f e1 (fold left x))));
+          fold right (f e2 (fold mid (f e1 (fold left x))))
+      | fold (Size (_, arg)) x = fold arg x;
   in fold end;
 
 fun fold_rev_table f =
@@ -177,7 +187,8 @@
       | fold_rev (Branch2 (left, e, right)) x =
           fold_rev left (f e (fold_rev right x))
       | fold_rev (Branch3 (left, e1, mid, e2, right)) x =
-          fold_rev left (f e1 (fold_rev mid (f e2 (fold_rev right x))));
+          fold_rev left (f e1 (fold_rev mid (f e2 (fold_rev right x))))
+      | fold_rev (Size (_, arg)) x = fold_rev arg x;
   in fold_rev end;
 
 fun dest tab = Library.build (fold_rev_table cons tab);
@@ -193,7 +204,8 @@
   | min (Branch2 (Empty, e, _)) = SOME e
   | min (Branch3 (Empty, e, _, _, _)) = SOME e
   | min (Branch2 (left, _, _)) = min left
-  | min (Branch3 (left, _, _, _, _)) = min left;
+  | min (Branch3 (left, _, _, _, _)) = min left
+  | min (Size (_, arg)) = min arg;
 
 fun max Empty = NONE
   | max (Leaf1 e) = SOME e
@@ -202,7 +214,8 @@
   | max (Branch2 (_, e, Empty)) = SOME e
   | max (Branch3 (_, _, _, e, Empty)) = SOME e
   | max (Branch2 (_, _, right)) = max right
-  | max (Branch3 (_, _, _, _, right)) = max right;
+  | max (Branch3 (_, _, _, _, right)) = max right
+  | max (Size (_, arg)) = max arg;
 
 
 (* exists and forall *)
@@ -216,7 +229,8 @@
       | ex (Branch2 (left, e, right)) =
           ex left orelse pred e orelse ex right
       | ex (Branch3 (left, e1, mid, e2, right)) =
-          ex left orelse pred e1 orelse ex mid orelse pred e2 orelse ex right;
+          ex left orelse pred e1 orelse ex mid orelse pred e2 orelse ex right
+      | ex (Size (_, arg)) = ex arg;
   in ex end;
 
 fun forall pred = not o exists (not o pred);
@@ -258,7 +272,8 @@
                       | some => some)
                   | some => some)
               | some => some)
-          | some => some);
+          | some => some)
+      | get (Size (_, arg)) = get arg;
   in get end;
 
 
@@ -295,7 +310,8 @@
               (case key_ord k2 of
                 LESS => look mid
               | EQUAL => SOME x2
-              | GREATER => look right));
+              | GREATER => look right))
+      | look (Size (_, arg)) = look arg;
   in look tab end;
 
 fun lookup_key tab key =
@@ -329,7 +345,8 @@
               (case key_ord k2 of
                 LESS => look mid
               | EQUAL => SOME (k2, x2)
-              | GREATER => look right));
+              | GREATER => look right))
+      | look (Size (_, arg)) = look arg;
   in look tab end;
 
 fun defined tab key =
@@ -362,7 +379,8 @@
               (case key_ord k2 of
                 LESS => def mid
               | EQUAL => true
-              | GREATER => def right));
+              | GREATER => def right))
+      | def (Size (_, arg)) = def arg;
   in def tab end;
 
 
@@ -378,7 +396,11 @@
   let
     fun key_ord k = Key.ord (key, k);
 
-    fun modfy Empty = Sprout (Empty, (key, f NONE), Empty)
+    val inc = Unsynchronized.ref 0;
+    fun insert () = f NONE before ignore (Unsynchronized.inc inc);
+    fun update x = f (SOME x);
+
+    fun modfy Empty = Sprout (Empty, (key, insert ()), Empty)
       | modfy (t as Leaf1 _) = modfy (unmake t)
       | modfy (t as Leaf2 _) = modfy (unmake t)
       | modfy (t as Leaf3 _) = modfy (unmake t)
@@ -388,7 +410,7 @@
               (case modfy left of
                 Stay left' => Stay (make2 (left', p, right))
               | Sprout (left1, q, left2) => Stay (make3 (left1, q, left2, p, right)))
-          | EQUAL => Stay (make2 (left, (k, f (SOME x)), right))
+          | EQUAL => Stay (make2 (left, (k, update x), right))
           | GREATER =>
               (case modfy right of
                 Stay right' => Stay (make2 (left, p, right'))
@@ -401,7 +423,7 @@
                 Stay left' => Stay (make3 (left', p1, mid, p2, right))
               | Sprout (left1, q, left2) =>
                   Sprout (make2 (left1, q, left2), p1, make2 (mid, p2, right)))
-          | EQUAL => Stay (make3 (left, (k1, f (SOME x1)), mid, p2, right))
+          | EQUAL => Stay (make3 (left, (k1, update x1), mid, p2, right))
           | GREATER =>
               (case key_ord k2 of
                 LESS =>
@@ -409,19 +431,21 @@
                     Stay mid' => Stay (make3 (left, p1, mid', p2, right))
                   | Sprout (mid1, q, mid2) =>
                       Sprout (make2 (left, p1, mid1), q, make2 (mid2, p2, right)))
-              | EQUAL => Stay (make3 (left, p1, mid, (k2, f (SOME x2)), right))
+              | EQUAL => Stay (make3 (left, p1, mid, (k2, update x2), right))
               | GREATER =>
                   (case modfy right of
                     Stay right' => Stay (make3 (left, p1, mid, p2, right'))
                   | Sprout (right1, q, right2) =>
-                      Sprout (make2 (left, p1, mid), p2, make2 (right1, q, right2)))));
+                      Sprout (make2 (left, p1, mid), p2, make2 (right1, q, right2)))))
+      | modfy (Size (_, arg)) = modfy arg;
 
+    val tab' =
+      (case modfy tab of
+        Stay tab' => tab'
+      | Sprout br => make2 br);
   in
-    (case modfy tab of
-      Stay tab' => tab'
-    | Sprout br => make2 br)
-    handle SAME => tab
-  end;
+    make_size (size tab + !inc) tab'
+  end handle SAME => tab;
 
 fun update (key, x) tab = modify key (fn _ => x) tab;
 fun update_new (key, x) tab = modify key (fn NONE => x | SOME _ => raise DUP key) tab;
@@ -516,11 +540,11 @@
                   make2 (mr, if_equal ord q' q, r'))
             | (Branch3 (ll, lp, lm, lq, lr), Branch2 (ml, mp, mr)) =>
                 make3 (make2 (ll, lp, lm), lq, make2 (lr, p, ml), mp,
-                  make2 (mr, if_equal ord q' q, r'))))));
-
+                  make2 (mr, if_equal ord q' q, r'))))))
+  | del k (Size (_, arg)) = del k arg;
 in
 
-fun delete key tab = snd (snd (del (SOME key) tab));
+fun delete key tab = make_size (size tab - 1) (snd (snd (del (SOME key) tab)));
 fun delete_safe key tab = if defined tab key then delete key tab else tab;
 
 end;