src/Pure/General/table.ML
changeset 15665 7e7412fffc0c
parent 15574 b1d1b5bfc464
child 15761 c9561302c74a
--- a/src/Pure/General/table.ML	Thu Apr 07 09:26:29 2005 +0200
+++ b/src/Pure/General/table.ML	Thu Apr 07 09:26:40 2005 +0200
@@ -31,11 +31,12 @@
   val lookup: 'a table * key -> 'a option
   val update: (key * 'a) * 'a table -> 'a table
   val update_new: (key * 'a) * 'a table -> 'a table                    (*exception DUP*)
+  val map_entry: key -> ('a -> 'a) -> 'a table -> 'a table
   val make: (key * 'a) list -> 'a table                                (*exception DUPS*)
   val extend: 'a table * (key * 'a) list -> 'a table                   (*exception DUPS*)
   val join: ('a * 'a -> 'a option) -> 'a table * 'a table -> 'a table  (*exception DUPS*)
   val merge: ('a * 'a -> bool) -> 'a table * 'a table -> 'a table      (*exception DUPS*)
-  val delete: key -> 'a table -> 'a table    (* exception UNDEF *)
+  val delete: key -> 'a table -> 'a table                              (*exception UNDEF*)
   val lookup_multi: 'a list table * key -> 'a list
   val update_multi: (key * 'a) * 'a list table -> 'a list table
   val make_multi: (key * 'a) list -> 'a list table
@@ -90,12 +91,13 @@
 fun exists P tab = foldl_table (fn (false, e) => P e | (b, _) => b) (false, tab);
 
 fun min_key Empty = NONE
-  | min_key (Branch2 (left, (k, _), _)) = SOME (getOpt (min_key left,k))
-  | min_key (Branch3 (left, (k, _), _, _, _)) = SOME (getOpt (min_key left,k));
+  | min_key (Branch2 (left, (k, _), _)) = SOME (getOpt (min_key left, k))
+  | min_key (Branch3 (left, (k, _), _, _, _)) = SOME (getOpt (min_key left, k));
 
 fun max_key Empty = NONE
-  | max_key (Branch2 (_, (k, _), right)) = SOME (getOpt (max_key right,k))
-  | max_key (Branch3 (_, _, _, (k,_), right)) = SOME (getOpt (max_key right,k));
+  | max_key (Branch2 (_, (k, _), right)) = SOME (getOpt (max_key right, k))
+  | max_key (Branch3 (_, _, _, (k,_), right)) = SOME (getOpt (max_key right, k));
+
 
 (* lookup *)
 
@@ -116,57 +118,67 @@
           | GREATER => lookup (right, key)));
 
 
-(* update *)
+(* updates *)
 
-fun compare (k1, _) (k2, _) = Key.ord (k1, k2);
+local
+
+exception SAME;
 
 datatype 'a growth =
   Stay of 'a table |
   Sprout of 'a table * (key * 'a) * 'a table;
 
-fun insert pair Empty = Sprout (Empty, pair, Empty)
-  | insert pair (Branch2 (left, p, right)) =
-      (case compare pair p of
-        LESS =>
-          (case insert pair left of
-            Stay left' => Stay (Branch2 (left', p, right))
-          | Sprout (left1, q, left2) => Stay (Branch3 (left1, q, left2, p, right)))
-      | EQUAL => Stay (Branch2 (left, pair, right))
-      | GREATER =>
-          (case insert pair right of
-            Stay right' => Stay (Branch2 (left, p, right'))
-          | Sprout (right1, q, right2) =>
-              Stay (Branch3 (left, p, right1, q, right2))))
-  | insert pair (Branch3 (left, p1, mid, p2, right)) =
-      (case compare pair p1 of
-        LESS =>
-          (case insert pair left of
-            Stay left' => Stay (Branch3 (left', p1, mid, p2, right))
-          | Sprout (left1, q, left2) =>
-              Sprout (Branch2 (left1, q, left2), p1, Branch2 (mid, p2, right)))
-      | EQUAL => Stay (Branch3 (left, pair, mid, p2, right))
-      | GREATER =>
-          (case compare pair p2 of
+fun modify key f tab =
+  let
+    fun modfy Empty = Sprout (Empty, (key, f NONE), Empty)
+      | modfy (Branch2 (left, p as (k, x), right)) =
+          (case Key.ord (key, k) of
+            LESS =>
+              (case modfy left of
+                Stay left' => Stay (Branch2 (left', p, right))
+              | Sprout (left1, q, left2) => Stay (Branch3 (left1, q, left2, p, right)))
+          | EQUAL => Stay (Branch2 (left, (k, f (SOME x)), right))
+          | GREATER =>
+              (case modfy right of
+                Stay right' => Stay (Branch2 (left, p, right'))
+              | Sprout (right1, q, right2) =>
+                  Stay (Branch3 (left, p, right1, q, right2))))
+      | modfy (Branch3 (left, p1 as (k1, x1), mid, p2 as (k2, x2), right)) =
+          (case Key.ord (key, k1) of
             LESS =>
-              (case insert pair mid of
-                Stay mid' => Stay (Branch3 (left, p1, mid', p2, right))
-              | Sprout (mid1, q, mid2) =>
-                  Sprout (Branch2 (left, p1, mid1), q, Branch2 (mid2, p2, right)))
-          | EQUAL => Stay (Branch3 (left, p1, mid, pair, right))
+              (case modfy left of
+                Stay left' => Stay (Branch3 (left', p1, mid, p2, right))
+              | Sprout (left1, q, left2) =>
+                  Sprout (Branch2 (left1, q, left2), p1, Branch2 (mid, p2, right)))
+          | EQUAL => Stay (Branch3 (left, (k1, f (SOME x1)), mid, p2, right))
           | GREATER =>
-              (case insert pair right of
-                Stay right' => Stay (Branch3 (left, p1, mid, p2, right'))
-              | Sprout (right1, q, right2) =>
-                  Sprout (Branch2 (left, p1, mid), p2, Branch2 (right1, q, right2)))));
+              (case Key.ord (key, k2) of
+                LESS =>
+                  (case modfy mid of
+                    Stay mid' => Stay (Branch3 (left, p1, mid', p2, right))
+                  | Sprout (mid1, q, mid2) =>
+                      Sprout (Branch2 (left, p1, mid1), q, Branch2 (mid2, p2, right)))
+              | EQUAL => Stay (Branch3 (left, p1, mid, (k2, f (SOME x2)), right))
+              | GREATER =>
+                  (case modfy right of
+                    Stay right' => Stay (Branch3 (left, p1, mid, p2, right'))
+                  | Sprout (right1, q, right2) =>
+                      Sprout (Branch2 (left, p1, mid), p2, Branch2 (right1, q, right2)))));
 
-fun update (pair, tab) =
-  (case insert pair tab of
-    Stay tab => tab
-  | Sprout br => Branch2 br);
+  in
+    (case modfy tab of
+      Stay tab' => tab'
+    | Sprout br => Branch2 br)
+    handle SAME => tab
+  end;
 
-fun update_new (pair as (key, _), tab) =
-  if is_none (lookup (tab, key)) then update (pair, tab)
-  else raise DUP key;
+in
+
+fun update ((k, x), tab) = modify k (fn _ => x) tab;
+fun update_new ((k, x), tab) = modify k (fn NONE => x | SOME _ => raise DUP k) tab;
+fun map_entry k f = modify k (fn NONE => raise SAME | SOME x => f x);
+
+end;
 
 
 (* extend and make *)
@@ -188,26 +200,28 @@
 
 (* delete *)
 
-fun compare' NONE (k2, _) = LESS
-  | compare' (SOME k1) (k2, _) = Key.ord (k1, k2);
+exception UNDEF of key;
+
+local
+
+fun compare NONE (k2, _) = LESS
+  | compare (SOME k1) (k2, _) = Key.ord (k1, k2);
 
 fun if_eq EQUAL x y = x
   | if_eq _ x y = y;
 
-exception UNDEF of key;
-
 fun del (SOME k) Empty = raise UNDEF k
   | del NONE (Branch2 (Empty, p, Empty)) = (p, (true, Empty))
   | del NONE (Branch3 (Empty, p, Empty, q, Empty)) =
       (p, (false, Branch2 (Empty, q, Empty)))
-  | del k (Branch2 (Empty, p, Empty)) = (case compare' k p of
+  | del k (Branch2 (Empty, p, Empty)) = (case compare k p of
       EQUAL => (p, (true, Empty)) | _ => raise UNDEF (valOf k))
-  | del k (Branch3 (Empty, p, Empty, q, Empty)) = (case compare' k p of
+  | del k (Branch3 (Empty, p, Empty, q, Empty)) = (case compare k p of
       EQUAL => (p, (false, Branch2 (Empty, q, Empty)))
-    | _ => (case compare' k q of
+    | _ => (case compare k q of
         EQUAL => (q, (false, Branch2 (Empty, p, Empty)))
       | _ => raise UNDEF (valOf k)))
-  | del k (Branch2 (l, p, r)) = (case compare' k p of
+  | del k (Branch2 (l, p, r)) = (case compare k p of
       LESS => (case del k l of
         (p', (false, l')) => (p', (false, Branch2 (l', p, r)))
       | (p', (true, l')) => (p', case r of
@@ -222,8 +236,8 @@
             (true, Branch3 (ll, lp, lr, if_eq ord p' p, r'))
         | Branch3 (ll, lp, lm, lq, lr) => (false, Branch2
             (Branch2 (ll, lp, lm), lq, Branch2 (lr, if_eq ord p' p, r'))))))
-  | del k (Branch3 (l, p, m, q, r)) = (case compare' k q of
-      LESS => (case compare' k p of
+  | del k (Branch3 (l, p, m, q, r)) = (case compare k q of
+      LESS => (case compare k p of
         LESS => (case del k l of
           (p', (false, l')) => (p', (false, Branch3 (l', p, m, q, r)))
         | (p', (true, l')) => (p', (false, case (m, r) of
@@ -259,8 +273,12 @@
             Branch3 (Branch2 (ll, lp, lm), lq, Branch2 (lr, p, ml), mp,
               Branch2 (mr, if_eq ord q' q, r'))))));
 
+in
+
 fun delete k t = snd (snd (del (SOME k) t));
 
+end;
+
 
 (* join and merge *)