yet another re-implementation:
authorwenzelm
Sat, 20 May 2006 23:37:02 +0200
changeset 19692 bad13b32c0f3
parent 19691 dd9ccb370f52
child 19693 ab816ca8df06
yet another re-implementation: . maintain explicit mapping from unspecified to specified consts (no dependency graph, no termination check, but direct reduction of specifications); . more precise checking of LHS patterns -- specialized patterns (e.g. 'a => 'a instead of general 'a => 'b) impose global restrictions;
src/Pure/defs.ML
--- a/src/Pure/defs.ML	Sat May 20 23:37:02 2006 +0200
+++ b/src/Pure/defs.ML	Sat May 20 23:37:02 2006 +0200
@@ -2,9 +2,9 @@
     ID:         $Id$
     Author:     Makarius
 
-Global well-formedness checks for constant definitions.  Covers
-dependencies of simple sub-structural overloading, where type
-arguments are approximated by the outermost type constructor.
+Global well-formedness checks for constant definitions.  Covers plain
+definitions and simple sub-structural overloading (depending on a
+single type argument).
 *)
 
 signature DEFS =
@@ -13,79 +13,28 @@
   val specifications_of: T -> string ->
    (serial * {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list}) list
   val empty: T
-  val merge: T * T -> T
-  val define: (string * typ -> typ list) ->
+  val merge: Pretty.pp -> T * T -> T
+  val define: Pretty.pp -> Consts.T ->
     bool -> bool -> string -> string -> string * typ -> (string * typ) list -> T -> T
 end
 
 structure Defs: DEFS =
 struct
 
-(* dependency items *)
-
-(*
-  Constant c covers all instances of c
-
-  Instance (c, a) covers all instances of applications (c, [Type (a, _)])
-
-  Different Constant/Constant or Instance/Instance items represent
-  disjoint sets of instances.  The set Constant c subsumes any
-  Instance (c, a) -- dependencies are propagated accordingly.
-*)
-
-datatype item =
-  Constant of string |
-  Instance of string * string;
+(* consts with type arguments *)
 
-fun make_item (c, [Type (a, _)]) = Instance (c, a)
-  | make_item (c, _) = Constant c;
-
-fun pretty_item (Constant c) = Pretty.str (quote c)
-  | pretty_item (Instance (c, a)) = Pretty.str (quote c ^ " (type " ^ quote a ^ ")");
-
-fun item_ord (Constant c, Constant c') = fast_string_ord (c, c')
-  | item_ord (Instance ca, Instance ca') = prod_ord fast_string_ord fast_string_ord (ca, ca')
-  | item_ord (Constant _, Instance _) = LESS
-  | item_ord (Instance _, Constant _) = GREATER;
-
-structure Items = GraphFun(type key = item val ord = item_ord);
-
-fun propagate_deps insts deps =
+fun print_const pp (c, args) =
   let
-    fun inst_item (Constant c) = Symtab.lookup_list insts c
-      | inst_item (Instance _) = [];
-    fun inst_edge i j =
-      fold Items.add_edge_acyclic (tl (product (i :: inst_item i) (j :: inst_item j)));
-  in Items.fold (fn (i, (_, (_, js))) => fold (inst_edge i) js) deps deps end;
+    val prt_args =
+      if null args then []
+      else [Pretty.brk 1, Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)];
+  in Pretty.string_of (Pretty.block (Pretty.str c :: prt_args)) end;
 
 
-(* specifications *)
+(* source specs *)
 
 type spec = {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list};
 
-datatype T = Defs of
- {specs: (bool * spec Inttab.table) Symtab.table,
-  insts: item list Symtab.table,
-  deps: unit Items.T};
-
-fun no_overloading_of (Defs {specs, ...}) c =
-  (case Symtab.lookup specs c of
-    SOME (b, _) => b
-  | NONE => false);
-
-fun specifications_of (Defs {specs, ...}) c =
-  (case Symtab.lookup specs c of
-    SOME (_, sps) => Inttab.dest sps
-  | NONE => []);
-
-fun make_defs (specs, insts, deps) = Defs {specs = specs, insts = insts, deps = deps};
-fun map_defs f (Defs {specs, insts, deps}) = make_defs (f (specs, insts, deps));
-
-val empty = make_defs (Symtab.empty, Symtab.empty, Items.empty);
-
-
-(* disjoint specs *)
-
 fun disjoint_types T U =
   (Type.raw_unify (T, Logic.incr_tvar (maxidx_of_typ T + 1) U) Vartab.empty; false)
     handle Type.TUNIFY => true;
@@ -97,60 +46,169 @@
         " for constant " ^ quote c));
 
 
+(* patterns *)
+
+datatype pattern = Unknown | Plain | Overloaded;
+
+fun str_of_pattern Overloaded = "overloading"
+  | str_of_pattern _ = "no overloading";
+
+fun merge_pattern c (p1, p2) =
+  if p1 = p2 orelse p2 = Unknown then p1
+  else if p1 = Unknown then p2
+  else error ("Inconsistent type patterns for constant " ^ quote c ^ ":\n" ^
+    str_of_pattern p1 ^ " versus " ^ str_of_pattern p2);
+
+fun plain_args args =
+  forall Term.is_TVar args andalso not (has_duplicates (op =) args);
+
+fun the_pattern _ name (c, [Type (a, args)]) =
+      (Overloaded, if plain_args args then [] else [(a, (args, name))])
+  | the_pattern prt _ (c, args) =
+      if plain_args args then (Plain, [])
+      else error ("Illegal type pattern for constant " ^ prt (c, args));
+
+
+(* datatype defs *)
+
+type def =
+ {specs: spec Inttab.table,
+  pattern: pattern,
+  restricts: (string * (typ list * string)) list,
+  reducts: (typ list * (string * typ list) list) list};
+
+fun make_def (specs, pattern, restricts, reducts) =
+  {specs = specs, pattern = pattern, restricts = restricts, reducts = reducts}: def;
+
+fun map_def f ({specs, pattern, restricts, reducts}: def) =
+  make_def (f (specs, pattern, restricts, reducts));
+
+fun default_def (pattern, restricts) = make_def (Inttab.empty, pattern, restricts, []);
+
+datatype T = Defs of def Symtab.table;
+val empty = Defs Symtab.empty;
+
+fun lookup_list which (Defs defs) c =
+  (case Symtab.lookup defs c of
+    SOME def => which def
+  | NONE => []);
+
+val specifications_of = lookup_list (Inttab.dest o #specs);
+val restricts_of = lookup_list #restricts;
+val reducts_of = lookup_list #reducts;
+
+
+(* normalize defs *)
+
+fun matcher arg =
+  Option.map Envir.typ_subst_TVars
+    (SOME (Type.raw_matches arg Vartab.empty) handle Type.TYPE_MATCH => NONE);
+
+fun restriction prt defs (c, args) =
+  (case args of
+    [Type (a, Us)] =>
+      (case AList.lookup (op =) (restricts_of defs c) a of
+        SOME (Ts, name) =>
+          if is_some (matcher (Ts, Us)) then ()
+          else error ("Occurrence of overloaded constant " ^ prt (c, args) ^
+            "\nviolates restriction " ^ prt (c, Ts) ^ "\nimposed by " ^ quote name)
+      | NONE => ())
+  | _ => ());
+
+fun reduction defs deps =
+  let
+    fun reduct Us (Ts, rhs) =
+      (case matcher (Ts, Us) of
+        NONE => NONE
+      | SOME subst => SOME (map (apsnd (map subst)) rhs));
+    fun reducts (d, Us) = get_first (reduct Us) (reducts_of defs d);
+
+    fun add (NONE, dp) = insert (op =) dp
+      | add (SOME dps, _) = fold (insert (op =)) dps;
+    val deps' = map (`reducts) deps;
+  in
+    if forall (is_none o #1) deps' then NONE
+    else SOME (fold_rev add deps' [])
+  end;
+
+fun normalize prt defs (c, args) deps =
+  let
+    val reds = reduction defs deps;
+    val deps' = the_default deps reds;
+    val _ = List.app (restriction prt defs) ((c, args) :: deps');
+    val _ = deps' |> List.app (fn (c', args') =>
+      if c' = c andalso is_some (matcher (args, args')) then
+        error ("Circular dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (c, args'))
+      else ());
+  in reds end;
+
+
+(* dependencies *)
+
+fun normalize_deps prt defs0 (Defs defs) =
+  let
+    fun norm const deps = perhaps (normalize prt defs0 const) deps;
+    fun norm_update (c, {reducts, ...}) =
+      let val reducts' = reducts |> map (fn (args, deps) => (args, norm (c, args) deps)) in
+        if reducts = reducts' then I
+        else Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
+          (specs, pattern, restricts, reducts')))
+      end;
+  in Defs (Symtab.fold norm_update defs defs) end;
+
+fun dependencies prt (c, args) pat deps (Defs defs) =
+  let
+    val deps' = perhaps (normalize prt (Defs defs) (c, args)) deps;
+    val defs' = defs
+      |> Symtab.default (c, default_def pat)
+      |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
+        let
+          val pattern' = merge_pattern c (pattern, #1 pat);
+          val restricts' = Library.merge (op =) (restricts, #2 pat);
+          val reducts' = insert (op =) (args, deps') reducts;
+        in (specs, pattern', restricts', reducts') end));
+  in normalize_deps prt (Defs defs') (Defs defs') end;
+
+
 (* merge *)
 
-fun cycle_msg css =
+fun join_specs c ({specs = specs1, pattern, restricts, reducts}, {specs = specs2, ...}: def) =
   let
-    fun prt_cycle items = Pretty.block (flat
-      (separate [Pretty.str " ->", Pretty.brk 1] (map (single o pretty_item) items)));
-  in Pretty.string_of (Pretty.big_list "Cyclic dependency of constants:" (map prt_cycle css)) end;
+    val specs' =
+      Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1;
+  in make_def (specs', pattern, restricts, reducts) end;
 
-fun merge
-   (Defs {specs = specs1, insts = insts1, deps = deps1},
-    Defs {specs = specs2, insts = insts2, deps = deps2}) =
+fun merge pp (Defs defs1, Defs defs2) =
   let
-    val specs' = (specs1, specs2) |> Symtab.join (fn c => fn ((b, sps1), (_, sps2)) =>
-      (b, Inttab.fold (fn sp2 => (disjoint_specs c sp2 sps1; Inttab.update sp2)) sps2 sps1));
-    val insts' = Symtab.merge_list (op =) (insts1, insts2);
-    val items' = propagate_deps insts' (Items.merge_acyclic (K true) (deps1, deps2))
-      handle Items.CYCLES cycles => error (cycle_msg cycles);
-  in make_defs (specs', insts', items') end;
+    fun add_deps (c, args) pat deps defs =
+      if AList.defined (op =) (reducts_of defs c) args then defs
+      else dependencies (print_const pp) (c, args) pat deps defs;
+    fun add_def (c, {pattern, restricts, reducts, ...}) =
+      fold (fn (args, deps) => add_deps (c, args) (pattern, restricts) deps) reducts;
+  in Defs (Symtab.join join_specs (defs1, defs2)) |> Symtab.fold add_def defs2 end;
 
 
 (* define *)
 
-fun pure_args args = forall Term.is_TVar args andalso not (has_duplicates (op =) args);
-
-fun define const_typargs unchecked is_def module name lhs rhs defs = defs
-    |> map_defs (fn (specs, insts, deps) =>
+fun define pp consts unchecked is_def module name lhs rhs (Defs defs) =
   let
-    val (c, T) = lhs;
-    val args = const_typargs lhs;
-    val no_overloading = pure_args args;
-    val rec_args = (case args of [Type (_, Ts)] => if pure_args Ts then Ts else [] | _ => []);
+    val prt = print_const pp;
+    fun typargs const = (#1 const, Consts.typargs consts const);
 
-    val lhs' = make_item (c, args);
-    val rhs' =
-      if unchecked then []
-      else rhs |> map_filter (fn (c', T') =>
-        let val args' = const_typargs (c', T') in
-          if gen_subset (op =) (args', rec_args) then NONE
-          else SOME (make_item (c', if no_overloading_of defs c' then [] else args'))
-        end);
+    val (c, args) = typargs lhs;
+    val pat =
+      if unchecked then (Unknown, [])
+      else the_pattern prt name (c, args);
+    val spec =
+      (serial (), {is_def = is_def, module = module, name = name, lhs = #2 lhs, rhs = rhs});
 
-    val spec = (serial (), {is_def = is_def, module = module, name = name, lhs = T, rhs = rhs});
-    val specs' = specs
-      |> Symtab.default (c, (false, Inttab.empty))
-      |> Symtab.map_entry c (fn (_, sps) =>
-          (disjoint_specs c spec sps; (no_overloading, Inttab.update spec sps)));
-    val insts' = insts |> fold (fn i as Instance (c, _) =>
-        Symtab.insert_list (op =) (c, i) | _ => I) (lhs' :: rhs');
-    val deps' = deps
-      |> fold (Items.default_node o rpair ()) (lhs' :: rhs')
-      |> Items.add_deps_acyclic (lhs', rhs')
-      |> propagate_deps insts'
-      handle Items.CYCLES cycles => error (cycle_msg cycles);
-
-  in (specs', insts', deps') end);
+    val defs' = defs
+      |> Symtab.default (c, default_def pat)
+      |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
+        let
+          val _ = disjoint_specs c spec specs;
+          val specs' = Inttab.update spec specs;
+        in (specs', pattern, restricts, reducts) end));
+  in Defs defs' |> (if unchecked then I else dependencies prt (c, args) pat (map typargs rhs)) end;
 
 end;