--- a/src/Pure/defs.ML Sat May 20 23:37:02 2006 +0200
+++ b/src/Pure/defs.ML Sat May 20 23:37:02 2006 +0200
@@ -2,9 +2,9 @@
ID: $Id$
Author: Makarius
-Global well-formedness checks for constant definitions. Covers
-dependencies of simple sub-structural overloading, where type
-arguments are approximated by the outermost type constructor.
+Global well-formedness checks for constant definitions. Covers plain
+definitions and simple sub-structural overloading (depending on a
+single type argument).
*)
signature DEFS =
@@ -13,79 +13,28 @@
val specifications_of: T -> string ->
(serial * {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list}) list
val empty: T
- val merge: T * T -> T
- val define: (string * typ -> typ list) ->
+ val merge: Pretty.pp -> T * T -> T
+ val define: Pretty.pp -> Consts.T ->
bool -> bool -> string -> string -> string * typ -> (string * typ) list -> T -> T
end
structure Defs: DEFS =
struct
-(* dependency items *)
-
-(*
- Constant c covers all instances of c
-
- Instance (c, a) covers all instances of applications (c, [Type (a, _)])
-
- Different Constant/Constant or Instance/Instance items represent
- disjoint sets of instances. The set Constant c subsumes any
- Instance (c, a) -- dependencies are propagated accordingly.
-*)
-
-datatype item =
- Constant of string |
- Instance of string * string;
+(* consts with type arguments *)
-fun make_item (c, [Type (a, _)]) = Instance (c, a)
- | make_item (c, _) = Constant c;
-
-fun pretty_item (Constant c) = Pretty.str (quote c)
- | pretty_item (Instance (c, a)) = Pretty.str (quote c ^ " (type " ^ quote a ^ ")");
-
-fun item_ord (Constant c, Constant c') = fast_string_ord (c, c')
- | item_ord (Instance ca, Instance ca') = prod_ord fast_string_ord fast_string_ord (ca, ca')
- | item_ord (Constant _, Instance _) = LESS
- | item_ord (Instance _, Constant _) = GREATER;
-
-structure Items = GraphFun(type key = item val ord = item_ord);
-
-fun propagate_deps insts deps =
+fun print_const pp (c, args) =
let
- fun inst_item (Constant c) = Symtab.lookup_list insts c
- | inst_item (Instance _) = [];
- fun inst_edge i j =
- fold Items.add_edge_acyclic (tl (product (i :: inst_item i) (j :: inst_item j)));
- in Items.fold (fn (i, (_, (_, js))) => fold (inst_edge i) js) deps deps end;
+ val prt_args =
+ if null args then []
+ else [Pretty.brk 1, Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)];
+ in Pretty.string_of (Pretty.block (Pretty.str c :: prt_args)) end;
-(* specifications *)
+(* source specs *)
type spec = {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list};
-datatype T = Defs of
- {specs: (bool * spec Inttab.table) Symtab.table,
- insts: item list Symtab.table,
- deps: unit Items.T};
-
-fun no_overloading_of (Defs {specs, ...}) c =
- (case Symtab.lookup specs c of
- SOME (b, _) => b
- | NONE => false);
-
-fun specifications_of (Defs {specs, ...}) c =
- (case Symtab.lookup specs c of
- SOME (_, sps) => Inttab.dest sps
- | NONE => []);
-
-fun make_defs (specs, insts, deps) = Defs {specs = specs, insts = insts, deps = deps};
-fun map_defs f (Defs {specs, insts, deps}) = make_defs (f (specs, insts, deps));
-
-val empty = make_defs (Symtab.empty, Symtab.empty, Items.empty);
-
-
-(* disjoint specs *)
-
fun disjoint_types T U =
(Type.raw_unify (T, Logic.incr_tvar (maxidx_of_typ T + 1) U) Vartab.empty; false)
handle Type.TUNIFY => true;
@@ -97,60 +46,169 @@
" for constant " ^ quote c));
+(* patterns *)
+
+datatype pattern = Unknown | Plain | Overloaded;
+
+fun str_of_pattern Overloaded = "overloading"
+ | str_of_pattern _ = "no overloading";
+
+fun merge_pattern c (p1, p2) =
+ if p1 = p2 orelse p2 = Unknown then p1
+ else if p1 = Unknown then p2
+ else error ("Inconsistent type patterns for constant " ^ quote c ^ ":\n" ^
+ str_of_pattern p1 ^ " versus " ^ str_of_pattern p2);
+
+fun plain_args args =
+ forall Term.is_TVar args andalso not (has_duplicates (op =) args);
+
+fun the_pattern _ name (c, [Type (a, args)]) =
+ (Overloaded, if plain_args args then [] else [(a, (args, name))])
+ | the_pattern prt _ (c, args) =
+ if plain_args args then (Plain, [])
+ else error ("Illegal type pattern for constant " ^ prt (c, args));
+
+
+(* datatype defs *)
+
+type def =
+ {specs: spec Inttab.table,
+ pattern: pattern,
+ restricts: (string * (typ list * string)) list,
+ reducts: (typ list * (string * typ list) list) list};
+
+fun make_def (specs, pattern, restricts, reducts) =
+ {specs = specs, pattern = pattern, restricts = restricts, reducts = reducts}: def;
+
+fun map_def f ({specs, pattern, restricts, reducts}: def) =
+ make_def (f (specs, pattern, restricts, reducts));
+
+fun default_def (pattern, restricts) = make_def (Inttab.empty, pattern, restricts, []);
+
+datatype T = Defs of def Symtab.table;
+val empty = Defs Symtab.empty;
+
+fun lookup_list which (Defs defs) c =
+ (case Symtab.lookup defs c of
+ SOME def => which def
+ | NONE => []);
+
+val specifications_of = lookup_list (Inttab.dest o #specs);
+val restricts_of = lookup_list #restricts;
+val reducts_of = lookup_list #reducts;
+
+
+(* normalize defs *)
+
+fun matcher arg =
+ Option.map Envir.typ_subst_TVars
+ (SOME (Type.raw_matches arg Vartab.empty) handle Type.TYPE_MATCH => NONE);
+
+fun restriction prt defs (c, args) =
+ (case args of
+ [Type (a, Us)] =>
+ (case AList.lookup (op =) (restricts_of defs c) a of
+ SOME (Ts, name) =>
+ if is_some (matcher (Ts, Us)) then ()
+ else error ("Occurrence of overloaded constant " ^ prt (c, args) ^
+ "\nviolates restriction " ^ prt (c, Ts) ^ "\nimposed by " ^ quote name)
+ | NONE => ())
+ | _ => ());
+
+fun reduction defs deps =
+ let
+ fun reduct Us (Ts, rhs) =
+ (case matcher (Ts, Us) of
+ NONE => NONE
+ | SOME subst => SOME (map (apsnd (map subst)) rhs));
+ fun reducts (d, Us) = get_first (reduct Us) (reducts_of defs d);
+
+ fun add (NONE, dp) = insert (op =) dp
+ | add (SOME dps, _) = fold (insert (op =)) dps;
+ val deps' = map (`reducts) deps;
+ in
+ if forall (is_none o #1) deps' then NONE
+ else SOME (fold_rev add deps' [])
+ end;
+
+fun normalize prt defs (c, args) deps =
+ let
+ val reds = reduction defs deps;
+ val deps' = the_default deps reds;
+ val _ = List.app (restriction prt defs) ((c, args) :: deps');
+ val _ = deps' |> List.app (fn (c', args') =>
+ if c' = c andalso is_some (matcher (args, args')) then
+ error ("Circular dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (c, args'))
+ else ());
+ in reds end;
+
+
+(* dependencies *)
+
+fun normalize_deps prt defs0 (Defs defs) =
+ let
+ fun norm const deps = perhaps (normalize prt defs0 const) deps;
+ fun norm_update (c, {reducts, ...}) =
+ let val reducts' = reducts |> map (fn (args, deps) => (args, norm (c, args) deps)) in
+ if reducts = reducts' then I
+ else Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
+ (specs, pattern, restricts, reducts')))
+ end;
+ in Defs (Symtab.fold norm_update defs defs) end;
+
+fun dependencies prt (c, args) pat deps (Defs defs) =
+ let
+ val deps' = perhaps (normalize prt (Defs defs) (c, args)) deps;
+ val defs' = defs
+ |> Symtab.default (c, default_def pat)
+ |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
+ let
+ val pattern' = merge_pattern c (pattern, #1 pat);
+ val restricts' = Library.merge (op =) (restricts, #2 pat);
+ val reducts' = insert (op =) (args, deps') reducts;
+ in (specs, pattern', restricts', reducts') end));
+ in normalize_deps prt (Defs defs') (Defs defs') end;
+
+
(* merge *)
-fun cycle_msg css =
+fun join_specs c ({specs = specs1, pattern, restricts, reducts}, {specs = specs2, ...}: def) =
let
- fun prt_cycle items = Pretty.block (flat
- (separate [Pretty.str " ->", Pretty.brk 1] (map (single o pretty_item) items)));
- in Pretty.string_of (Pretty.big_list "Cyclic dependency of constants:" (map prt_cycle css)) end;
+ val specs' =
+ Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1;
+ in make_def (specs', pattern, restricts, reducts) end;
-fun merge
- (Defs {specs = specs1, insts = insts1, deps = deps1},
- Defs {specs = specs2, insts = insts2, deps = deps2}) =
+fun merge pp (Defs defs1, Defs defs2) =
let
- val specs' = (specs1, specs2) |> Symtab.join (fn c => fn ((b, sps1), (_, sps2)) =>
- (b, Inttab.fold (fn sp2 => (disjoint_specs c sp2 sps1; Inttab.update sp2)) sps2 sps1));
- val insts' = Symtab.merge_list (op =) (insts1, insts2);
- val items' = propagate_deps insts' (Items.merge_acyclic (K true) (deps1, deps2))
- handle Items.CYCLES cycles => error (cycle_msg cycles);
- in make_defs (specs', insts', items') end;
+ fun add_deps (c, args) pat deps defs =
+ if AList.defined (op =) (reducts_of defs c) args then defs
+ else dependencies (print_const pp) (c, args) pat deps defs;
+ fun add_def (c, {pattern, restricts, reducts, ...}) =
+ fold (fn (args, deps) => add_deps (c, args) (pattern, restricts) deps) reducts;
+ in Defs (Symtab.join join_specs (defs1, defs2)) |> Symtab.fold add_def defs2 end;
(* define *)
-fun pure_args args = forall Term.is_TVar args andalso not (has_duplicates (op =) args);
-
-fun define const_typargs unchecked is_def module name lhs rhs defs = defs
- |> map_defs (fn (specs, insts, deps) =>
+fun define pp consts unchecked is_def module name lhs rhs (Defs defs) =
let
- val (c, T) = lhs;
- val args = const_typargs lhs;
- val no_overloading = pure_args args;
- val rec_args = (case args of [Type (_, Ts)] => if pure_args Ts then Ts else [] | _ => []);
+ val prt = print_const pp;
+ fun typargs const = (#1 const, Consts.typargs consts const);
- val lhs' = make_item (c, args);
- val rhs' =
- if unchecked then []
- else rhs |> map_filter (fn (c', T') =>
- let val args' = const_typargs (c', T') in
- if gen_subset (op =) (args', rec_args) then NONE
- else SOME (make_item (c', if no_overloading_of defs c' then [] else args'))
- end);
+ val (c, args) = typargs lhs;
+ val pat =
+ if unchecked then (Unknown, [])
+ else the_pattern prt name (c, args);
+ val spec =
+ (serial (), {is_def = is_def, module = module, name = name, lhs = #2 lhs, rhs = rhs});
- val spec = (serial (), {is_def = is_def, module = module, name = name, lhs = T, rhs = rhs});
- val specs' = specs
- |> Symtab.default (c, (false, Inttab.empty))
- |> Symtab.map_entry c (fn (_, sps) =>
- (disjoint_specs c spec sps; (no_overloading, Inttab.update spec sps)));
- val insts' = insts |> fold (fn i as Instance (c, _) =>
- Symtab.insert_list (op =) (c, i) | _ => I) (lhs' :: rhs');
- val deps' = deps
- |> fold (Items.default_node o rpair ()) (lhs' :: rhs')
- |> Items.add_deps_acyclic (lhs', rhs')
- |> propagate_deps insts'
- handle Items.CYCLES cycles => error (cycle_msg cycles);
-
- in (specs', insts', deps') end);
+ val defs' = defs
+ |> Symtab.default (c, default_def pat)
+ |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
+ let
+ val _ = disjoint_specs c spec specs;
+ val specs' = Inttab.update spec specs;
+ in (specs', pattern, restricts, reducts) end));
+ in Defs defs' |> (if unchecked then I else dependencies prt (c, args) pat (map typargs rhs)) end;
end;