src/Pure/General/table.ML
author wenzelm
Thu, 07 Apr 2005 09:26:40 +0200
changeset 15665 7e7412fffc0c
parent 15574 b1d1b5bfc464
child 15761 c9561302c74a
permissions -rw-r--r--
tuned updates, added map_entry;

(*  Title:      Pure/General/table.ML
    ID:         $Id$
    Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen

Generic tables and tables indexed by strings.  Efficient purely
functional implementation using balanced 2-3 trees.
*)

(*Parameter of TableFun: a key type plus the comparison used to keep the
  entries of a table sorted.*)
signature KEY =
sig
  type key
  val ord : key * key -> order   (*LESS / EQUAL / GREATER*)
end;

signature TABLE =
sig
  (*types*)
  type key
  type 'a table

  (*errors*)
  exception DUP of key          (*update_new: key already present*)
  exception DUPS of key list    (*make / extend / join / merge: clashing keys*)
  exception UNDEF of key        (*delete: key absent*)

  (*construction, traversal, inspection*)
  val empty : 'a table
  val is_empty : 'a table -> bool
  val map : ('a -> 'b) -> 'a table -> 'b table
  val foldl : ('a * (key * 'b) -> 'a) -> 'a * 'b table -> 'a
  val dest : 'a table -> (key * 'a) list
  val keys : 'a table -> key list
  val min_key : 'a table -> key option
  val max_key : 'a table -> key option
  val exists : (key * 'a -> bool) -> 'a table -> bool
  val lookup : 'a table * key -> 'a option

  (*single-entry updates*)
  val update : (key * 'a) * 'a table -> 'a table
  val update_new : (key * 'a) * 'a table -> 'a table                     (*exception DUP*)
  val map_entry : key -> ('a -> 'a) -> 'a table -> 'a table              (*absent key: table unchanged*)

  (*bulk operations*)
  val make : (key * 'a) list -> 'a table                                 (*exception DUPS*)
  val extend : 'a table * (key * 'a) list -> 'a table                    (*exception DUPS*)
  val join : ('a * 'a -> 'a option) -> 'a table * 'a table -> 'a table   (*exception DUPS*)
  val merge : ('a * 'a -> bool) -> 'a table * 'a table -> 'a table       (*exception DUPS*)
  val delete : key -> 'a table -> 'a table                               (*exception UNDEF*)

  (*tables holding a list of values per key*)
  val lookup_multi : 'a list table * key -> 'a list
  val update_multi : (key * 'a) * 'a list table -> 'a list table
  val make_multi : (key * 'a) list -> 'a list table
  val dest_multi : 'a list table -> (key * 'a) list
  val merge_multi : ('a * 'a -> bool) ->
    'a list table * 'a list table -> 'a list table     (*exception DUPS*)
  val merge_multi' : ('a * 'a -> bool) ->
    'a list table * 'a list table -> 'a list table     (*exception DUPS*)
end;

functor TableFun(Key: KEY): TABLE =
struct


(* datatype table *)

type key = Key.key;

(*2-3 tree.  A Branch2 node holds one entry between two subtrees and a
  Branch3 node holds two entries between three subtrees.  Keys increase
  from left to right across every node, and all Empty leaves lie at the
  same depth.  This is the balance invariant that the update and delete
  code below maintains.*)
datatype 'a table =
  Empty |
  Branch2 of 'a table * (key * 'a) * 'a table |
  Branch3 of 'a table * (key * 'a) * 'a table * (key * 'a) * 'a table;

exception DUP of key;
exception DUPS of key list;


(* empty *)

val empty = Empty;

fun is_empty Empty = true
  | is_empty _ = false;


(* map and fold combinators *)

(*apply f to every value; keys and tree shape stay the same*)
fun map_table _ Empty = Empty
  | map_table f (Branch2 (left, (k, x), right)) =
      Branch2 (map_table f left, (k, f x), map_table f right)
  | map_table f (Branch3 (left, (k1, x1), mid, (k2, x2), right)) =
      Branch3 (map_table f left, (k1, f x1), map_table f mid, (k2, f x2), map_table f right);

(*left fold over all entries in ascending key order*)
fun foldl_table _ (x, Empty) = x
  | foldl_table f (x, Branch2 (left, p, right)) =
      foldl_table f (f (foldl_table f (x, left), p), right)
  | foldl_table f (x, Branch3 (left, p1, mid, p2, right)) =
      foldl_table f (f (foldl_table f (f (foldl_table f (x, left), p1), mid), p2), right);

(*entries / keys in ascending key order: the fold conses them on in
  ascending order, so the accumulated list is reversed at the end*)
fun dest tab = rev (foldl_table (fn (rev_ps, p) => p :: rev_ps) ([], tab));
fun keys tab = rev (foldl_table (fn (rev_ks, (k, _)) => k :: rev_ks) ([], tab));

(*does some entry satisfy P?  Entries are tested in ascending key order.
  orelse short-circuits, so traversal stops at the first entry for which
  P holds; the rest of the tree is not visited.*)
fun exists P tab =
  let
    fun ex Empty = false
      | ex (Branch2 (left, e, right)) = ex left orelse P e orelse ex right
      | ex (Branch3 (left, e1, mid, e2, right)) =
          ex left orelse P e1 orelse ex mid orelse P e2 orelse ex right;
  in ex tab end;

(*smallest / largest key, found by following the leftmost / rightmost
  spine; NONE only for the empty table*)
fun min_key Empty = NONE
  | min_key (Branch2 (left, (k, _), _)) = SOME (getOpt (min_key left, k))
  | min_key (Branch3 (left, (k, _), _, _, _)) = SOME (getOpt (min_key left, k));

fun max_key Empty = NONE
  | max_key (Branch2 (_, (k, _), right)) = SOME (getOpt (max_key right, k))
  | max_key (Branch3 (_, _, _, (k,_), right)) = SOME (getOpt (max_key right, k));


(* lookup *)

(*descend from the root, comparing key with the node's entries to choose
  a subtree; at a Branch3 node k1 is tested before k2*)
fun lookup (Empty, _) = NONE
  | lookup (Branch2 (left, (k, x), right), key) =
      (case Key.ord (key, k) of
        LESS => lookup (left, key)
      | EQUAL => SOME x
      | GREATER => lookup (right, key))
  | lookup (Branch3 (left, (k1, x1), mid, (k2, x2), right), key) =
      (case Key.ord (key, k1) of
        LESS => lookup (left, key)
      | EQUAL => SOME x1
      | GREATER =>
          (case Key.ord (key, k2) of
            LESS => lookup (mid, key)
          | EQUAL => SOME x2
          | GREATER => lookup (right, key)));


(* updates *)

local

(*raised by a modification function to mean "leave the table unchanged".
  It is local to this section, so client functions cannot raise it.*)
exception SAME;

(*result of modifying a subtree:
  Stay t           -- t has the same height as before;
  Sprout (l, e, r) -- the subtree split around entry e into two subtrees
                      of the old height, and the parent must absorb e*)
datatype 'a growth =
  Stay of 'a table |
  Sprout of 'a table * (key * 'a) * 'a table;

(*modify key f tab: if key is absent, insert it with value f NONE; if it
  is present with value x, replace x by f (SOME x).  If f raises SAME,
  tab itself is returned.  A Sprout at the root becomes a new Branch2
  root, which is the only place where the tree grows in height.*)
fun modify key f tab =
  let
    fun modfy Empty = Sprout (Empty, (key, f NONE), Empty)
      | modfy (Branch2 (left, p as (k, x), right)) =
          (case Key.ord (key, k) of
            LESS =>
              (case modfy left of
                Stay left' => Stay (Branch2 (left', p, right))
              | Sprout (left1, q, left2) => Stay (Branch3 (left1, q, left2, p, right)))
          | EQUAL => Stay (Branch2 (left, (k, f (SOME x)), right))
          | GREATER =>
              (case modfy right of
                Stay right' => Stay (Branch2 (left, p, right'))
              | Sprout (right1, q, right2) =>
                  Stay (Branch3 (left, p, right1, q, right2))))
      (*a Branch3 node has no room left: a sprouting child splits it*)
      | modfy (Branch3 (left, p1 as (k1, x1), mid, p2 as (k2, x2), right)) =
          (case Key.ord (key, k1) of
            LESS =>
              (case modfy left of
                Stay left' => Stay (Branch3 (left', p1, mid, p2, right))
              | Sprout (left1, q, left2) =>
                  Sprout (Branch2 (left1, q, left2), p1, Branch2 (mid, p2, right)))
          | EQUAL => Stay (Branch3 (left, (k1, f (SOME x1)), mid, p2, right))
          | GREATER =>
              (case Key.ord (key, k2) of
                LESS =>
                  (case modfy mid of
                    Stay mid' => Stay (Branch3 (left, p1, mid', p2, right))
                  | Sprout (mid1, q, mid2) =>
                      Sprout (Branch2 (left, p1, mid1), q, Branch2 (mid2, p2, right)))
              | EQUAL => Stay (Branch3 (left, p1, mid, (k2, f (SOME x2)), right))
              | GREATER =>
                  (case modfy right of
                    Stay right' => Stay (Branch3 (left, p1, mid, p2, right'))
                  | Sprout (right1, q, right2) =>
                      Sprout (Branch2 (left, p1, mid), p2, Branch2 (right1, q, right2)))));

  in
    (case modfy tab of
      Stay tab' => tab'
    | Sprout br => Branch2 br)
    handle SAME => tab
  end;

in

(*insert k with value x, overwriting any existing value*)
fun update ((k, x), tab) = modify k (fn _ => x) tab;

(*insert k with value x; raise DUP k if k is already present*)
fun update_new ((k, x), tab) = modify k (fn NONE => x | SOME _ => raise DUP k) tab;

(*replace the value stored under k by f applied to it; a table without k
  is returned unchanged*)
fun map_entry k f = modify k (fn NONE => raise SAME | SOME x => f x);

end;


(* extend and make *)

(*add pairs to table.  Every key that is already present (in table or
  earlier in pairs) is collected.  If there are any such keys, DUPS is
  raised with all of them, in order of occurrence, and no partial table is
  returned.*)
fun extend (table, pairs) =
  let
    fun add ((tab, dups), (key, x)) =
      (case lookup (tab, key) of
        NONE => (update ((key, x), tab), dups)
      | _ => (tab, key :: dups));
  in
    (case Library.foldl add ((table, []), pairs) of
      (table', []) => table'
    | (_, dups) => raise DUPS (rev dups))
  end;

fun make pairs = extend (empty, pairs);


(* delete *)

exception UNDEF of key;

local

(*NONE compares below every entry, so del NONE removes the minimum entry*)
fun compare NONE _ = LESS
  | compare (SOME k1) (k2, _) = Key.ord (k1, k2);

(*x if the key was found at the current entry (EQUAL), otherwise y*)
fun if_eq EQUAL x _ = x
  | if_eq _ _ y = y;

(*del k t = (e, (shrunk, t')).  Here t' is t without entry e, where e is
  the entry with key k, or the minimum entry if k = NONE.  shrunk is true
  iff t' is one level lower than t; the caller then repairs the height by
  merging with a 2-node sibling or borrowing from a 3-node sibling.  If the
  key sits at an internal node, the minimum entry of the subtree to its
  right is removed instead (del NONE) and moved into its place.  Raises
  UNDEF if the key is absent.  del NONE is only applied to non-empty
  subtrees.*)
fun del (SOME k) Empty = raise UNDEF k
  | del NONE (Branch2 (Empty, p, Empty)) = (p, (true, Empty))
  | del NONE (Branch3 (Empty, p, Empty, q, Empty)) =
      (p, (false, Branch2 (Empty, q, Empty)))
  | del k (Branch2 (Empty, p, Empty)) = (case compare k p of
      EQUAL => (p, (true, Empty)) | _ => raise UNDEF (valOf k))
  | del k (Branch3 (Empty, p, Empty, q, Empty)) = (case compare k p of
      EQUAL => (p, (false, Branch2 (Empty, q, Empty)))
    | _ => (case compare k q of
        EQUAL => (q, (false, Branch2 (Empty, p, Empty)))
      | _ => raise UNDEF (valOf k)))
  (*internal 2-node: recurse left or right, then rebalance if it shrank*)
  | del k (Branch2 (l, p, r)) = (case compare k p of
      LESS => (case del k l of
        (p', (false, l')) => (p', (false, Branch2 (l', p, r)))
      | (p', (true, l')) => (p', case r of
          Branch2 (rl, rp, rr) =>
            (true, Branch3 (l', p, rl, rp, rr))
        | Branch3 (rl, rp, rm, rq, rr) => (false, Branch2
            (Branch2 (l', p, rl), rp, Branch2 (rm, rq, rr)))))
    | ord => (case del (if_eq ord NONE k) r of
        (p', (false, r')) => (p', (false, Branch2 (l, if_eq ord p' p, r')))
      | (p', (true, r')) => (p', case l of
          Branch2 (ll, lp, lr) =>
            (true, Branch3 (ll, lp, lr, if_eq ord p' p, r'))
        | Branch3 (ll, lp, lm, lq, lr) => (false, Branch2
            (Branch2 (ll, lp, lm), lq, Branch2 (lr, if_eq ord p' p, r'))))))
  (*internal 3-node: it can always absorb a shrunk child, so it never
    reports shrinking itself*)
  | del k (Branch3 (l, p, m, q, r)) = (case compare k q of
      LESS => (case compare k p of
        LESS => (case del k l of
          (p', (false, l')) => (p', (false, Branch3 (l', p, m, q, r)))
        | (p', (true, l')) => (p', (false, case (m, r) of
            (Branch2 (ml, mp, mr), Branch2 _) =>
              Branch2 (Branch3 (l', p, ml, mp, mr), q, r)
          | (Branch3 (ml, mp, mm, mq, mr), _) =>
              Branch3 (Branch2 (l', p, ml), mp, Branch2 (mm, mq, mr), q, r)
          | (Branch2 (ml, mp, mr), Branch3 (rl, rp, rm, rq, rr)) =>
              Branch3 (Branch2 (l', p, ml), mp, Branch2 (mr, q, rl), rp,
                Branch2 (rm, rq, rr)))))
      | ord => (case del (if_eq ord NONE k) m of
          (p', (false, m')) =>
            (p', (false, Branch3 (l, if_eq ord p' p, m', q, r)))
        | (p', (true, m')) => (p', (false, case (l, r) of
            (Branch2 (ll, lp, lr), Branch2 _) =>
              Branch2 (Branch3 (ll, lp, lr, if_eq ord p' p, m'), q, r)
          | (Branch3 (ll, lp, lm, lq, lr), _) =>
              Branch3 (Branch2 (ll, lp, lm), lq,
                Branch2 (lr, if_eq ord p' p, m'), q, r)
          | (_, Branch3 (rl, rp, rm, rq, rr)) =>
              Branch3 (l, if_eq ord p' p, Branch2 (m', q, rl), rp,
                Branch2 (rm, rq, rr))))))
    | ord => (case del (if_eq ord NONE k) r of
        (q', (false, r')) =>
          (q', (false, Branch3 (l, p, m, if_eq ord q' q, r')))
      | (q', (true, r')) => (q', (false, case (l, m) of
          (Branch2 _, Branch2 (ml, mp, mr)) =>
            Branch2 (l, p, Branch3 (ml, mp, mr, if_eq ord q' q, r'))
        | (_, Branch3 (ml, mp, mm, mq, mr)) =>
            Branch3 (l, p, Branch2 (ml, mp, mm), mq,
              Branch2 (mr, if_eq ord q' q, r'))
        | (Branch3 (ll, lp, lm, lq, lr), Branch2 (ml, mp, mr)) =>
            Branch3 (Branch2 (ll, lp, lm), lq, Branch2 (lr, p, ml), mp,
              Branch2 (mr, if_eq ord q' q, r'))))));

in

(*remove the entry for k; raise UNDEF k if there is none*)
fun delete k t = snd (snd (del (SOME k) t));

end;


(* join and merge *)

(*add the entries of table2 to table1.  When a key occurs in both tables,
  f (value in table1, value in table2) decides what happens: SOME z stores
  z, and NONE marks the key as a clash.  If there are any clashes, DUPS is
  raised with all clashing keys (in table2's key order).*)
fun join f (table1, table2) =
  let
    fun add ((tab, dups), (key, x)) =
      (case lookup (tab, key) of
        NONE => (update ((key, x), tab), dups)
      | SOME y =>
          (case f (y, x) of
            SOME z => (update ((key, z), tab), dups)
          | NONE => (tab, key :: dups)));
  in
    (case foldl_table add ((table1, []), table2) of
      (table', []) => table'
    | (_, dups) => raise DUPS (rev dups))
  end;

(*join that keeps the table1 value when eq says both values agree, and
  treats disagreement as a clash*)
fun merge eq tabs = join (fn (y, x) => if eq (y, x) then SOME y else NONE) tabs;


(* tables with multiple entries per key (preserving order) *)

(*the list stored under the key, [] if none*)
fun lookup_multi tab_key = getOpt (lookup tab_key, []);

(*prepend x to the list stored under key*)
fun update_multi ((key, x), tab) = update ((key, x :: lookup_multi (tab, key)), tab);

(*foldr visits pairs right to left and update_multi prepends, so each
  key's list keeps the order of pairs*)
fun make_multi pairs = foldr update_multi empty pairs;

(*flatten to (key, value) pairs, keys ascending, each key's list in order*)
fun dest_multi tab = List.concat (map (fn (key, xs) => map (Library.pair key) xs) (dest tab));

(*join whose per-key lists are combined by gen_merge_lists /
  gen_merge_lists' (defined elsewhere).  These never signal a clash.*)
fun merge_multi eq tabs = join (fn (xs, xs') => SOME (gen_merge_lists eq xs xs')) tabs;
fun merge_multi' eq tabs = join (fn (xs, xs') => SOME (gen_merge_lists' eq xs xs')) tabs;


(*final declarations of this structure!  Until this point map and foldl
  still denote the list functions (dest_multi above uses the list map), so
  these bindings must come last.*)
val map = map_table;
val foldl = foldl_table;

end;


(*tables indexed by strings*)
(*string-keyed tables: TableFun applied to an explicit key structure*)
structure Symtab =
  TableFun (struct
    type key = string
    val ord = string_ord
  end);