specifications_of: lhs/rhs represented as typargs;
authorwenzelm
Mon, 22 May 2006 22:29:15 +0200
changeset 19697 423af2e013b8
parent 19696 26a268c299d8
child 19698 f48cfaacd92c
specifications_of: lhs/rhs represented as typargs; export pretty_const; export dest; more precise checking of lhs patterns; more precise normalization; misc cleanup;
src/Pure/defs.ML
--- a/src/Pure/defs.ML	Mon May 22 21:27:01 2006 +0200
+++ b/src/Pure/defs.ML	Mon May 22 22:29:15 2006 +0200
@@ -9,9 +9,13 @@
 
 signature DEFS =
 sig
+  val pretty_const: Pretty.pp -> string * typ list -> Pretty.T
   type T
-  val specifications_of: T -> string ->
-   (serial * {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list}) list
+  val specifications_of: T -> string -> (serial * {is_def: bool, module: string, name: string,
+    lhs: typ list, rhs: (string * typ list) list}) list
+  val dest: T ->
+   {restricts: ((string * typ list) * string) list,
+    reducts: ((string * typ list) * (string * typ list) list) list}
   val empty: T
   val merge: Pretty.pp -> T * T -> T
   val define: Pretty.pp -> Consts.T ->
@@ -21,72 +25,47 @@
 structure Defs: DEFS =
 struct
 
-(* consts with type arguments *)
+
+(* type arguments *)
 
-fun print_const pp (c, args) =
+type args = typ list;
+
+fun pretty_const pp (c, args) =
   let
     val prt_args =
       if null args then []
-      else [Pretty.brk 1, Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)];
-  in Pretty.string_of (Pretty.block (Pretty.str c :: prt_args)) end;
-
-
-(* source specs *)
-
-type spec = {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list};
-
-fun disjoint_types T U =
-  (Type.raw_unify (T, Logic.incr_tvar (maxidx_of_typ T + 1) U) Vartab.empty; false)
-    handle Type.TUNIFY => true;
-
-fun disjoint_specs c (i, {lhs = T, name = a, ...}: spec) =
-  Inttab.forall (fn (j, {lhs = U, name = b, ...}: spec) =>
-    i = j orelse not (Type.could_unify (T, U)) orelse disjoint_types T U orelse
-      error ("Type clash in specifications " ^ quote a ^ " and " ^ quote b ^
-        " for constant " ^ quote c));
-
+      else [Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)];
+  in Pretty.block (Pretty.str c :: prt_args) end;
 
-(* patterns *)
-
-datatype pattern = Unknown | Plain | Overloaded;
-
-fun str_of_pattern Overloaded = "overloading"
-  | str_of_pattern _ = "no overloading";
+fun disjoint_args (Ts, Us) =
+  not (Type.could_unifys (Ts, Us)) orelse
+    ((Type.raw_unifys (Ts, map (Logic.incr_tvar (maxidx_of_typs Ts + 1)) Us) Vartab.empty; false)
+      handle Type.TUNIFY => true);
 
-fun merge_pattern c (p1, p2) =
-  if p1 = p2 orelse p2 = Unknown then p1
-  else if p1 = Unknown then p2
-  else error ("Inconsistent type patterns for constant " ^ quote c ^ ":\n" ^
-    str_of_pattern p1 ^ " versus " ^ str_of_pattern p2);
-
-fun plain_args args =
-  forall Term.is_TVar args andalso not (has_duplicates (op =) args);
-
-fun the_pattern _ name (c, [Type (a, args)]) =
-      (Overloaded, if plain_args args then [] else [(a, (args, name))])
-  | the_pattern prt _ (c, args) =
-      if plain_args args then (Plain, [])
-      else error ("Illegal type pattern for constant " ^ prt (c, args));
+fun match_args (Ts, Us) =
+  Option.map Envir.typ_subst_TVars
+    (SOME (Type.raw_matches (Ts, Us) Vartab.empty) handle Type.TYPE_MATCH => NONE);
 
 
 (* datatype defs *)
 
+type spec = {is_def: bool, module: string, name: string, lhs: args, rhs: (string * args) list};
+
 type def =
  {specs: spec Inttab.table,
-  pattern: pattern,
-  restricts: (string * (typ list * string)) list,
-  reducts: (typ list * (string * typ list) list) list};
+  restricts: (args * string) list,
+  reducts: (args * (string * args) list) list};
+
+fun make_def (specs, restricts, reducts) =
+  {specs = specs, restricts = restricts, reducts = reducts}: def;
 
-fun make_def (specs, pattern, restricts, reducts) =
-  {specs = specs, pattern = pattern, restricts = restricts, reducts = reducts}: def;
+fun map_def c f =
+  Symtab.default (c, make_def (Inttab.empty, [], [])) #>
+  Symtab.map_entry c (fn {specs, restricts, reducts}: def =>
+    make_def (f (specs, restricts, reducts)));
 
-fun map_def f ({specs, pattern, restricts, reducts}: def) =
-  make_def (f (specs, pattern, restricts, reducts));
-
-fun default_def (pattern, restricts) = make_def (Inttab.empty, pattern, restricts, []);
 
 datatype T = Defs of def Symtab.table;
-val empty = Defs Symtab.empty;
 
 fun lookup_list which (Defs defs) c =
   (case Symtab.lookup defs c of
@@ -97,31 +76,46 @@
 val restricts_of = lookup_list #restricts;
 val reducts_of = lookup_list #reducts;
 
-
-(* normalize defs *)
+fun dest (Defs defs) =
+  let
+    val restricts = Symtab.fold (fn (c, {restricts, ...}) =>
+      fold (fn (args, name) => cons ((c, args), name)) restricts) defs [];
+    val reducts = Symtab.fold (fn (c, {reducts, ...}) =>
+      fold (fn (args, deps) => cons ((c, args), deps)) reducts) defs [];
+  in {restricts = restricts, reducts = reducts} end;
 
-fun matcher arg =
-  Option.map Envir.typ_subst_TVars
-    (SOME (Type.raw_matches arg Vartab.empty) handle Type.TYPE_MATCH => NONE);
+val empty = Defs Symtab.empty;
+
+
+(* specifications *)
 
-fun restriction prt defs (c, args) =
-  (case args of
-    [Type (a, Us)] =>
-      (case AList.lookup (op =) (restricts_of defs c) a of
-        SOME (Ts, name) =>
-          if is_some (matcher (Ts, Us)) then ()
-          else error ("Occurrence of overloaded constant " ^ prt (c, args) ^
-            "\nviolates restriction " ^ prt (c, Ts) ^ "\nimposed by " ^ quote name)
-      | NONE => ())
-  | _ => ());
+fun disjoint_specs c (i, {lhs = Ts, name = a, ...}: spec) =
+  Inttab.forall (fn (j, {lhs = Us, name = b, ...}: spec) =>
+    i = j orelse disjoint_args (Ts, Us) orelse
+      error ("Type clash in specifications " ^ quote a ^ " and " ^ quote b ^
+        " for constant " ^ quote c));
 
-fun reduction defs deps =
+fun join_specs c ({specs = specs1, restricts, reducts}, {specs = specs2, ...}: def) =
+  let
+    val specs' =
+      Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1;
+  in make_def (specs', restricts, reducts) end;
+
+fun update_specs c spec = map_def c (fn (specs, restricts, reducts) =>
+  (disjoint_specs c spec specs; (Inttab.update spec specs, restricts, reducts)));
+
+
+(* normalization: reduction and well-formedness check *)
+
+local
+
+fun reduction reds_of deps =
   let
     fun reduct Us (Ts, rhs) =
-      (case matcher (Ts, Us) of
+      (case match_args (Ts, Us) of
         NONE => NONE
       | SOME subst => SOME (map (apsnd (map subst)) rhs));
-    fun reducts (d, Us) = get_first (reduct Us) (reducts_of defs d);
+    fun reducts (d: string, Us) = get_first (reduct Us) (reds_of d);
 
     fun add (NONE, dp) = insert (op =) dp
       | add (SOME dps, _) = fold (insert (op =)) dps;
@@ -131,84 +125,114 @@
     else SOME (fold_rev add deps' [])
   end;
 
-fun normalize prt defs (c, args) deps =
+fun reductions reds_of deps =
+  (case reduction reds_of deps of
+    SOME deps' => reductions reds_of deps'
+  | NONE => deps);
+
+fun contained U (Type (_, Ts)) = exists (fn T => T = U orelse contained U T) Ts
+  | contained _ _ = false;
+
+fun wellformed pp rests_of (c, args) (d, Us) =
   let
-    val reds = reduction defs deps;
-    val deps' = the_default deps reds;
-    val _ = List.app (restriction prt defs) ((c, args) :: deps');
-    val _ = deps' |> List.app (fn (c', args') =>
-      if c' = c andalso is_some (matcher (args, args')) then
-        error ("Circular dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (c, args'))
-      else ());
-  in reds end;
+    val prt = Pretty.string_of o pretty_const pp;
+    fun err s1 s2 =
+      error (s1 ^ " dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (d, Us) ^ s2);
+  in
+    exists (fn U => exists (contained U) args) Us orelse
+    (c <> d andalso exists (member (op =) args) Us) orelse
+      (case find_first (fn (Ts, _) => not (disjoint_args (Ts, Us))) (rests_of d) of
+        NONE =>
+          c <> d orelse is_none (match_args (args, Us)) orelse err "Circular" ""
+      | SOME (Ts, name) =>
+          if c = d then err "Circular" ("\n(via " ^ quote name ^ ")")
+          else
+            err "Malformed" ("\n(restriction " ^ prt (d, Ts) ^ " from " ^ quote name ^ ")"))
+  end;
 
-
-(* dependencies *)
-
-fun normalize_deps prt defs0 (Defs defs) =
+fun normalize pp rests_of reds_of (c, args) deps =
   let
-    fun norm const deps = perhaps (normalize prt defs0 const) deps;
-    fun norm_update (c, {reducts, ...}: def) =
-      let val reducts' = reducts |> map (fn (args, deps) => (args, norm (c, args) deps)) in
-        if reducts = reducts' then I
-        else Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
-          (specs, pattern, restricts, reducts')))
-      end;
-  in Defs (Symtab.fold norm_update defs defs) end;
+    val deps' = reductions reds_of deps;
+    val _ = forall (wellformed pp rests_of (c, args)) deps';
+  in deps' end;
 
-fun dependencies prt (c, args) pat deps (Defs defs) =
+fun normalize_all pp (c, args) deps defs =
   let
-    val deps' = perhaps (normalize prt (Defs defs) (c, args)) deps;
+    val norm = normalize pp (restricts_of (Defs defs));
+    val norm_rule = norm (fn c' => if c' = c then [(args, deps)] else []);
+    val norm_defs = norm (reducts_of (Defs defs));
+    fun norm_update (c', {reducts, ...}: def) =
+      let val reducts' = reducts
+        |> map (fn (args', deps') => (args', norm_defs (c', args') (norm_rule (c', args') deps')))
+      in
+        K (reducts <> reducts') ?
+          map_def c' (fn (specs, restricts, reducts) => (specs, restricts, reducts'))
+      end;
+  in Symtab.fold norm_update defs defs end;
+
+in
+
+fun dependencies pp (c, args) restr deps (Defs defs) =
+  let
+    val deps' = normalize pp (restricts_of (Defs defs)) (reducts_of (Defs defs)) (c, args) deps;
     val defs' = defs
-      |> Symtab.default (c, default_def pat)
-      |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
-        let
-          val pattern' = merge_pattern c (pattern, #1 pat);
-          val restricts' = Library.merge (op =) (restricts, #2 pat);
-          val reducts' = insert (op =) (args, deps') reducts;
-        in (specs, pattern', restricts', reducts') end));
-  in normalize_deps prt (Defs defs') (Defs defs') end;
+      |> map_def c (fn (specs, restricts, reducts) =>
+        (specs, Library.merge (op =) (restricts, restr), reducts))
+      |> normalize_all pp (c, args) deps';
+    val deps'' =
+      normalize pp (restricts_of (Defs defs')) (reducts_of (Defs defs')) (c, args) deps';
+    val defs'' = defs'
+      |> map_def c (fn (specs, restricts, reducts) =>
+        (specs, restricts, insert (op =) (args, deps'') reducts));
+  in Defs defs'' end;
+
+end;
 
 
 (* merge *)
 
-fun join_specs c ({specs = specs1, pattern, restricts, reducts}, {specs = specs2, ...}: def) =
-  let
-    val specs' =
-      Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1;
-  in make_def (specs', pattern, restricts, reducts) end;
-
 fun merge pp (Defs defs1, Defs defs2) =
   let
-    fun add_deps (c, args) pat deps defs =
+    fun add_deps (c, args) restr deps defs =
       if AList.defined (op =) (reducts_of defs c) args then defs
-      else dependencies (print_const pp) (c, args) pat deps defs;
-    fun add_def (c, {pattern, restricts, reducts, ...}: def) =
-      fold (fn (args, deps) => add_deps (c, args) (pattern, restricts) deps) reducts;
+      else dependencies pp (c, args) restr deps defs;
+    fun add_def (c, {restricts, reducts, ...}: def) =
+      fold (fn (args, deps) => add_deps (c, args) restricts deps) reducts;
   in Defs (Symtab.join join_specs (defs1, defs2)) |> Symtab.fold add_def defs2 end;
 
+local  (* FIXME *)
+  val merge_aux = merge
+  val acc = Output.time_accumulator "Defs.merge"
+in fun merge pp = acc (merge_aux pp) end;
+
 
 (* define *)
 
+fun plain_args args =
+  forall Term.is_TVar args andalso not (has_duplicates (op =) args);
+
 fun define pp consts unchecked is_def module name lhs rhs (Defs defs) =
   let
-    val prt = print_const pp;
     fun typargs const = (#1 const, Consts.typargs consts const);
-
     val (c, args) = typargs lhs;
-    val pat =
-      if unchecked then (Unknown, [])
-      else the_pattern prt name (c, args);
+    val deps = map typargs rhs;
+    val restr =
+      if plain_args args orelse
+        (case args of [Type (a, rec_args)] => plain_args rec_args | _ => false)
+      then [] else [(args, name)];
     val spec =
-      (serial (), {is_def = is_def, module = module, name = name, lhs = #2 lhs, rhs = rhs});
+      (serial (), {is_def = is_def, module = module, name = name, lhs = args, rhs = deps});
+    val defs' = defs |> update_specs c spec;
+  in Defs defs' |> (if unchecked then I else dependencies pp (c, args) restr deps) end;
+
 
-    val defs' = defs
-      |> Symtab.default (c, default_def pat)
-      |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
-        let
-          val _ = disjoint_specs c spec specs;
-          val specs' = Inttab.update spec specs;
-        in (specs', pattern, restricts, reducts) end));
-  in Defs defs' |> (if unchecked then I else dependencies prt (c, args) pat (map typargs rhs)) end;
+local  (* FIXME *)
+  val define_aux = define
+  val acc = Output.time_accumulator "Defs.define"
+in
+  fun define pp consts unchecked is_def module name lhs rhs =
+    acc (define_aux pp consts unchecked is_def module name lhs rhs)
+end;
+
 
 end;