efficient representation of sets: more compact than Table.set;
authorwenzelm
Mon, 27 Mar 2023 19:41:18 +0200
changeset 77722 8faf28a80a7f
parent 77721 7c1cc9ce9340
child 77723 b761c91c2447
efficient representation of sets: more compact than Table.set;
src/Pure/General/set.ML
src/Pure/ROOT.ML
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Pure/General/set.ML	Mon Mar 27 19:41:18 2023 +0200
@@ -0,0 +1,279 @@
+(*  Title:      Pure/General/set.ML
+    Author:     Makarius
+
+Efficient representation of sets (see also Pure/General/table.ML).
+*)
+
+signature SET =
+sig
+  type elem
+  type T
+  val empty: T
+  val build: (T -> T) -> T
+  val is_empty: T -> bool
+  val is_single: T -> bool
+  val fold: (elem -> 'a -> 'a) -> T -> 'a -> 'a
+  val fold_rev: (elem -> 'a -> 'a) -> T -> 'a -> 'a
+  val dest: T -> elem list
+  val exists: (elem -> bool) -> T -> bool
+  val forall: (elem -> bool) -> T -> bool
+  val member: T -> elem -> bool
+  val insert: elem -> T -> T
+  val make: elem list -> T
+  val merge: T * T -> T
+  val remove: elem -> T -> T
+end;
+
+functor Set(Key: KEY): SET =
+struct
+
+(* datatype *)
+
+type elem = Key.key;
+
+datatype T =
+  Empty |
+  Branch2 of T * elem * T |
+  Branch3 of T * elem * T * elem * T;
+
+
+(* empty and single *)
+
+val empty = Empty;
+
+fun build (f: T -> T) = f empty;
+
+fun is_empty Empty = true
+  | is_empty _ = false;
+
+fun is_single (Branch2 (Empty, _, Empty)) = true
+  | is_single _ = false;
+
+
+(* fold combinators *)
+
+fun fold_set f =
+  let
+    fun fold Empty x = x
+      | fold (Branch2 (left, e, right)) x =
+          fold right (f e (fold left x))
+      | fold (Branch3 (left, e1, mid, e2, right)) x =
+          fold right (f e2 (fold mid (f e1 (fold left x))));
+  in fold end;
+
+fun fold_rev_set f =
+  let
+    fun fold Empty x = x
+      | fold (Branch2 (left, e, right)) x =
+          fold left (f e (fold right x))
+      | fold (Branch3 (left, e1, mid, e2, right)) x =
+          fold left (f e1 (fold mid (f e2 (fold right x))));
+  in fold end;
+
+fun dest tab = fold_rev_set cons tab [];
+
+
+(* exists and forall *)
+
+fun exists pred =
+  let
+    fun ex Empty = false
+      | ex (Branch2 (left, e, right)) =
+          ex left orelse pred e orelse ex right
+      | ex (Branch3 (left, e1, mid, e2, right)) =
+          ex left orelse pred e1 orelse ex mid orelse pred e2 orelse ex right;
+  in ex end;
+
+fun forall pred = not o exists (not o pred);
+
+
+(* member *)
+
+fun member set elem =
+  let
+    fun mem Empty = false
+      | mem (Branch2 (left, e, right)) =
+          (case Key.ord (elem, e) of
+            LESS => mem left
+          | EQUAL => true
+          | GREATER => mem right)
+      | mem (Branch3 (left, e1, mid, e2, right)) =
+          (case Key.ord (elem, e1) of
+            LESS => mem left
+          | EQUAL => true
+          | GREATER =>
+              (case Key.ord (elem, e2) of
+                LESS => mem mid
+              | EQUAL => true
+              | GREATER => mem right));
+  in mem set end;
+
+
+(* insert *)
+
+datatype growth = Stay of T | Sprout of T * elem * T;
+
+fun insert elem set =
+  if member set elem then set
+  else
+    let
+      fun ins Empty = Sprout (Empty, elem, Empty)
+        | ins (Branch2 (left, e, right)) =
+            (case Key.ord (elem, e) of
+              LESS =>
+                (case ins left of
+                  Stay left' => Stay (Branch2 (left', e, right))
+                | Sprout (left1, e', left2) => Stay (Branch3 (left1, e', left2, e, right)))
+            | EQUAL => Stay (Branch2 (left, e, right))
+            | GREATER =>
+                (case ins right of
+                  Stay right' => Stay (Branch2 (left, e, right'))
+                | Sprout (right1, e', right2) =>
+                    Stay (Branch3 (left, e, right1, e', right2))))
+        | ins (Branch3 (left, e1, mid, e2, right)) =
+            (case Key.ord (elem, e1) of
+              LESS =>
+                (case ins left of
+                  Stay left' => Stay (Branch3 (left', e1, mid, e2, right))
+                | Sprout (left1, e', left2) =>
+                    Sprout (Branch2 (left1, e', left2), e1, Branch2 (mid, e2, right)))
+            | EQUAL => Stay (Branch3 (left, e1, mid, e2, right))
+            | GREATER =>
+                (case Key.ord (elem, e2) of
+                  LESS =>
+                    (case ins mid of
+                      Stay mid' => Stay (Branch3 (left, e1, mid', e2, right))
+                    | Sprout (mid1, e', mid2) =>
+                        Sprout (Branch2 (left, e1, mid1), e', Branch2 (mid2, e2, right)))
+                | EQUAL => Stay (Branch3 (left, e1, mid, e2, right))
+                | GREATER =>
+                    (case ins right of
+                      Stay right' => Stay (Branch3 (left, e1, mid, e2, right'))
+                    | Sprout (right1, e', right2) =>
+                        Sprout (Branch2 (left, e1, mid), e2, Branch2 (right1, e', right2)))));
+    in
+      (case ins set of
+        Stay set' => set'
+      | Sprout br => Branch2 br)
+    end;
+
+fun make elems = build (fold insert elems);
+
+fun merge (set1, set2) =
+  if pointer_eq (set1, set2) then set1
+  else if is_empty set1 then set2
+  else fold_set insert set2 set1;
+
+
+(* remove *)
+
+local
+
+fun compare NONE _ = LESS
+  | compare (SOME e1) e2 = Key.ord (e1, e2);
+
+fun if_eq EQUAL x y = x
+  | if_eq _ x y = y;
+
+exception UNDEF of elem;
+
+(*literal copy from table.ML -- by Stefan Berghofer*)
+fun del (SOME k) Empty = raise UNDEF k
+  | del NONE (Branch2 (Empty, p, Empty)) = (p, (true, Empty))
+  | del NONE (Branch3 (Empty, p, Empty, q, Empty)) =
+      (p, (false, Branch2 (Empty, q, Empty)))
+  | del k (Branch2 (Empty, p, Empty)) =
+      (case compare k p of
+        EQUAL => (p, (true, Empty))
+      | _ => raise UNDEF (the k))
+  | del k (Branch3 (Empty, p, Empty, q, Empty)) =
+      (case compare k p of
+        EQUAL => (p, (false, Branch2 (Empty, q, Empty)))
+      | _ =>
+        (case compare k q of
+          EQUAL => (q, (false, Branch2 (Empty, p, Empty)))
+        | _ => raise UNDEF (the k)))
+  | del k (Branch2 (l, p, r)) =
+      (case compare k p of
+        LESS =>
+          (case del k l of
+            (p', (false, l')) => (p', (false, Branch2 (l', p, r)))
+          | (p', (true, l')) => (p', case r of
+              Branch2 (rl, rp, rr) =>
+                (true, Branch3 (l', p, rl, rp, rr))
+            | Branch3 (rl, rp, rm, rq, rr) => (false, Branch2
+                (Branch2 (l', p, rl), rp, Branch2 (rm, rq, rr)))))
+      | ord =>
+          (case del (if_eq ord NONE k) r of
+            (p', (false, r')) => (p', (false, Branch2 (l, if_eq ord p' p, r')))
+          | (p', (true, r')) => (p', case l of
+              Branch2 (ll, lp, lr) =>
+                (true, Branch3 (ll, lp, lr, if_eq ord p' p, r'))
+            | Branch3 (ll, lp, lm, lq, lr) => (false, Branch2
+                (Branch2 (ll, lp, lm), lq, Branch2 (lr, if_eq ord p' p, r'))))))
+  | del k (Branch3 (l, p, m, q, r)) =
+      (case compare k q of
+        LESS =>
+          (case compare k p of
+            LESS =>
+              (case del k l of
+                (p', (false, l')) => (p', (false, Branch3 (l', p, m, q, r)))
+              | (p', (true, l')) => (p', (false, case (m, r) of
+                  (Branch2 (ml, mp, mr), Branch2 _) =>
+                    Branch2 (Branch3 (l', p, ml, mp, mr), q, r)
+                | (Branch3 (ml, mp, mm, mq, mr), _) =>
+                    Branch3 (Branch2 (l', p, ml), mp, Branch2 (mm, mq, mr), q, r)
+                | (Branch2 (ml, mp, mr), Branch3 (rl, rp, rm, rq, rr)) =>
+                    Branch3 (Branch2 (l', p, ml), mp, Branch2 (mr, q, rl), rp,
+                      Branch2 (rm, rq, rr)))))
+          | ord =>
+              (case del (if_eq ord NONE k) m of
+                (p', (false, m')) =>
+                  (p', (false, Branch3 (l, if_eq ord p' p, m', q, r)))
+              | (p', (true, m')) => (p', (false, case (l, r) of
+                  (Branch2 (ll, lp, lr), Branch2 _) =>
+                    Branch2 (Branch3 (ll, lp, lr, if_eq ord p' p, m'), q, r)
+                | (Branch3 (ll, lp, lm, lq, lr), _) =>
+                    Branch3 (Branch2 (ll, lp, lm), lq,
+                      Branch2 (lr, if_eq ord p' p, m'), q, r)
+                | (_, Branch3 (rl, rp, rm, rq, rr)) =>
+                    Branch3 (l, if_eq ord p' p, Branch2 (m', q, rl), rp,
+                      Branch2 (rm, rq, rr))))))
+      | ord =>
+          (case del (if_eq ord NONE k) r of
+            (q', (false, r')) =>
+              (q', (false, Branch3 (l, p, m, if_eq ord q' q, r')))
+          | (q', (true, r')) => (q', (false, case (l, m) of
+              (Branch2 _, Branch2 (ml, mp, mr)) =>
+                Branch2 (l, p, Branch3 (ml, mp, mr, if_eq ord q' q, r'))
+            | (_, Branch3 (ml, mp, mm, mq, mr)) =>
+                Branch3 (l, p, Branch2 (ml, mp, mm), mq,
+                  Branch2 (mr, if_eq ord q' q, r'))
+            | (Branch3 (ll, lp, lm, lq, lr), Branch2 (ml, mp, mr)) =>
+                Branch3 (Branch2 (ll, lp, lm), lq, Branch2 (lr, p, ml), mp,
+                  Branch2 (mr, if_eq ord q' q, r'))))));
+
+in
+
+fun remove elem set =
+  if member set elem then snd (snd (del (SOME elem) set)) else set;
+
+end;
+
+
+(* ML pretty-printing *)
+
+val _ =
+  ML_system_pp (fn depth => fn _ => fn set =>
+    ML_Pretty.to_polyml
+      (ML_Pretty.enum "," "{" "}" (ML_Pretty.from_polyml o ML_system_pretty) (dest set, depth)));
+
+
+(*final declarations of this structure!*)
+val fold = fold_set;
+val fold_rev = fold_rev_set;
+
+end;
+
+structure Intset = Set(type key = int val ord = int_ord);
+structure Symset = Set(type key = string val ord = fast_string_ord);
--- a/src/Pure/ROOT.ML	Mon Mar 27 16:24:54 2023 +0200
+++ b/src/Pure/ROOT.ML	Mon Mar 27 19:41:18 2023 +0200
@@ -40,6 +40,7 @@
 ML_file "General/print_mode.ML";
 ML_file "General/alist.ML";
 ML_file "General/table.ML";
+ML_file "General/set.ML";
 
 ML_file "General/random.ML";
 ML_file "General/value.ML";