--- a/src/Pure/General/bitset.ML Wed Nov 29 00:07:54 2023 +0100
+++ b/src/Pure/General/bitset.ML Wed Nov 29 11:54:12 2023 +0100
@@ -7,28 +7,30 @@
signature BITSET =
sig
+ structure Key: KEY
+ type elem
type T
val empty: T
val build: (T -> T) -> T
val is_empty: T -> bool
- val fold: (int -> 'a -> 'a) -> T -> 'a -> 'a
- val fold_rev: (int -> 'a -> 'a) -> T -> 'a -> 'a
- val dest: T -> int list
- val min: T -> int option
- val max: T -> int option
- val get_first: (int -> 'a option) -> T -> 'a option
- val exists: (int -> bool) -> T -> bool
- val forall: (int -> bool) -> T -> bool
- val member: T -> int -> bool
+ val fold: (elem -> 'a -> 'a) -> T -> 'a -> 'a
+ val fold_rev: (elem -> 'a -> 'a) -> T -> 'a -> 'a
+ val dest: T -> elem list
+ val min: T -> elem option
+ val max: T -> elem option
+ val get_first: (elem -> 'a option) -> T -> 'a option
+ val exists: (elem -> bool) -> T -> bool
+ val forall: (elem -> bool) -> T -> bool
+ val member: T -> elem -> bool
val subset: T * T -> bool
val eq_set: T * T -> bool
- val insert: int -> T -> T
- val make: int list -> T
+ val insert: elem -> T -> T
+ val make: elem list -> T
val merge: T * T -> T
val merges: T list -> T
- val remove: int -> T -> T
+ val remove: elem -> T -> T
val subtract: T -> T -> T
- val restrict: (int -> bool) -> T -> T
+ val restrict: (elem -> bool) -> T -> T
val inter: T -> T -> T
val union: T -> T -> T
end;
@@ -50,9 +52,6 @@
val mimimum_bit = make_bit min_bit;
val maximum_bit = make_bit max_bit;
-fun make_int m n = if check_bit n then m * word_size + n else raise BAD n;
-fun dest_int x = Integer.div_mod x word_size;
-
fun add_bits v w = Word.orb (v, w);
fun del_bits v w = Word.andb (Word.notb v, w);
fun incl_bits v w = add_bits v w = w;
@@ -78,8 +77,17 @@
(* datatype *)
+structure Key = Inttab.Key;
+type elem = Key.key;
+
+fun make_elem m n : elem = if check_bit n then m * word_size + n else raise BAD n;
+fun dest_elem (x: elem) = Integer.div_mod x word_size;
+
datatype T = Bitset of word Inttab.table;
+
+(* empty *)
+
val empty = Bitset Inttab.empty;
fun build (f: T -> T) = f empty;
@@ -91,11 +99,11 @@
fun fold_set f (Bitset t) =
Inttab.fold (fn (m, w) =>
- (if m < 0 then fold_rev_bits else fold_bits) (f o make_int m) w) t;
+ (if m < 0 then fold_rev_bits else fold_bits) (f o make_elem m) w) t;
fun fold_rev_set f (Bitset t) =
Inttab.fold_rev (fn (m, w) =>
- (if m < 0 then fold_bits else fold_rev_bits) (f o make_int m) w) t;
+ (if m < 0 then fold_bits else fold_rev_bits) (f o make_elem m) w) t;
val dest = Library.build o fold_rev_set cons;
@@ -104,11 +112,11 @@
fun min (Bitset t) =
Inttab.min t |> Option.map (fn (m, w) =>
- make_int m (fold_bits Integer.min w max_bit));
+ make_elem m (fold_bits Integer.min w max_bit));
fun max (Bitset t) =
Inttab.max t |> Option.map (fn (m, w) =>
- make_int m (fold_bits Integer.max w min_bit));
+ make_elem m (fold_bits Integer.max w min_bit));
(* linear search *)
@@ -126,7 +134,7 @@
(* member *)
fun member (Bitset t) x =
- let val (m, n) = dest_int x in
+ let val (m, n) = dest_elem x in
(case Inttab.lookup t m of
NONE => false
| SOME w => incl_bits (make_bit n) w)
@@ -150,7 +158,7 @@
(* insert *)
fun insert x (Bitset t) =
- let val (m, n) = dest_int x
+ let val (m, n) = dest_elem x
in Bitset (Inttab.map_default (m, 0w0) (add_bits (make_bit n)) t) end;
fun make xs = build (fold insert xs);
@@ -174,7 +182,7 @@
(* remove *)
fun remove x (set as Bitset t) =
- let val (m, n) = dest_int x in
+ let val (m, n) = dest_elem x in
(case Inttab.lookup t m of
NONE => set
| SOME w =>