src/Pure/defs.ML
changeset 19692 bad13b32c0f3
parent 19628 de019ddcd89e
child 19695 7706aeac6cf1
     1.1 --- a/src/Pure/defs.ML	Sat May 20 23:37:02 2006 +0200
     1.2 +++ b/src/Pure/defs.ML	Sat May 20 23:37:02 2006 +0200
     1.3 @@ -2,9 +2,9 @@
     1.4      ID:         $Id$
     1.5      Author:     Makarius
     1.6  
     1.7 -Global well-formedness checks for constant definitions.  Covers
     1.8 -dependencies of simple sub-structural overloading, where type
     1.9 -arguments are approximated by the outermost type constructor.
    1.10 +Global well-formedness checks for constant definitions.  Covers plain
    1.11 +definitions and simple sub-structural overloading (depending on a
    1.12 +single type argument).
    1.13  *)
    1.14  
    1.15  signature DEFS =
    1.16 @@ -13,79 +13,28 @@
    1.17    val specifications_of: T -> string ->
    1.18     (serial * {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list}) list
    1.19    val empty: T
    1.20 -  val merge: T * T -> T
    1.21 -  val define: (string * typ -> typ list) ->
    1.22 +  val merge: Pretty.pp -> T * T -> T
    1.23 +  val define: Pretty.pp -> Consts.T ->
    1.24      bool -> bool -> string -> string -> string * typ -> (string * typ) list -> T -> T
    1.25  end
    1.26  
    1.27  structure Defs: DEFS =
    1.28  struct
    1.29  
    1.30 -(* dependency items *)
    1.31 -
    1.32 -(*
    1.33 -  Constant c covers all instances of c
    1.34 -
    1.35 -  Instance (c, a) covers all instances of applications (c, [Type (a, _)])
    1.36 -
    1.37 -  Different Constant/Constant or Instance/Instance items represent
    1.38 -  disjoint sets of instances.  The set Constant c subsumes any
    1.39 -  Instance (c, a) -- dependencies are propagated accordingly.
    1.40 -*)
    1.41 -
    1.42 -datatype item =
    1.43 -  Constant of string |
    1.44 -  Instance of string * string;
    1.45 +(* consts with type arguments *)
    1.46  
    1.47 -fun make_item (c, [Type (a, _)]) = Instance (c, a)
    1.48 -  | make_item (c, _) = Constant c;
    1.49 -
    1.50 -fun pretty_item (Constant c) = Pretty.str (quote c)
    1.51 -  | pretty_item (Instance (c, a)) = Pretty.str (quote c ^ " (type " ^ quote a ^ ")");
    1.52 -
    1.53 -fun item_ord (Constant c, Constant c') = fast_string_ord (c, c')
    1.54 -  | item_ord (Instance ca, Instance ca') = prod_ord fast_string_ord fast_string_ord (ca, ca')
    1.55 -  | item_ord (Constant _, Instance _) = LESS
    1.56 -  | item_ord (Instance _, Constant _) = GREATER;
    1.57 -
    1.58 -structure Items = GraphFun(type key = item val ord = item_ord);
    1.59 -
    1.60 -fun propagate_deps insts deps =
    1.61 +fun print_const pp (c, args) =
    1.62    let
    1.63 -    fun inst_item (Constant c) = Symtab.lookup_list insts c
    1.64 -      | inst_item (Instance _) = [];
    1.65 -    fun inst_edge i j =
    1.66 -      fold Items.add_edge_acyclic (tl (product (i :: inst_item i) (j :: inst_item j)));
    1.67 -  in Items.fold (fn (i, (_, (_, js))) => fold (inst_edge i) js) deps deps end;
    1.68 +    val prt_args =
    1.69 +      if null args then []
    1.70 +      else [Pretty.brk 1, Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)];
    1.71 +  in Pretty.string_of (Pretty.block (Pretty.str c :: prt_args)) end;
    1.72  
    1.73  
    1.74 -(* specifications *)
    1.75 +(* source specs *)
    1.76  
    1.77  type spec = {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list};
    1.78  
    1.79 -datatype T = Defs of
    1.80 - {specs: (bool * spec Inttab.table) Symtab.table,
    1.81 -  insts: item list Symtab.table,
    1.82 -  deps: unit Items.T};
    1.83 -
    1.84 -fun no_overloading_of (Defs {specs, ...}) c =
    1.85 -  (case Symtab.lookup specs c of
    1.86 -    SOME (b, _) => b
    1.87 -  | NONE => false);
    1.88 -
    1.89 -fun specifications_of (Defs {specs, ...}) c =
    1.90 -  (case Symtab.lookup specs c of
    1.91 -    SOME (_, sps) => Inttab.dest sps
    1.92 -  | NONE => []);
    1.93 -
    1.94 -fun make_defs (specs, insts, deps) = Defs {specs = specs, insts = insts, deps = deps};
    1.95 -fun map_defs f (Defs {specs, insts, deps}) = make_defs (f (specs, insts, deps));
    1.96 -
    1.97 -val empty = make_defs (Symtab.empty, Symtab.empty, Items.empty);
    1.98 -
    1.99 -
   1.100 -(* disjoint specs *)
   1.101 -
   1.102  fun disjoint_types T U =
   1.103    (Type.raw_unify (T, Logic.incr_tvar (maxidx_of_typ T + 1) U) Vartab.empty; false)
   1.104      handle Type.TUNIFY => true;
   1.105 @@ -97,60 +46,169 @@
   1.106          " for constant " ^ quote c));
   1.107  
   1.108  
   1.109 +(* patterns *)
   1.110 +
   1.111 +datatype pattern = Unknown | Plain | Overloaded;
   1.112 +
   1.113 +fun str_of_pattern Overloaded = "overloading"
   1.114 +  | str_of_pattern _ = "no overloading";
   1.115 +
   1.116 +fun merge_pattern c (p1, p2) =
   1.117 +  if p1 = p2 orelse p2 = Unknown then p1
   1.118 +  else if p1 = Unknown then p2
   1.119 +  else error ("Inconsistent type patterns for constant " ^ quote c ^ ":\n" ^
   1.120 +    str_of_pattern p1 ^ " versus " ^ str_of_pattern p2);
   1.121 +
   1.122 +fun plain_args args =
   1.123 +  forall Term.is_TVar args andalso not (has_duplicates (op =) args);
   1.124 +
   1.125 +fun the_pattern _ name (c, [Type (a, args)]) =
   1.126 +      (Overloaded, if plain_args args then [] else [(a, (args, name))])
   1.127 +  | the_pattern prt _ (c, args) =
   1.128 +      if plain_args args then (Plain, [])
   1.129 +      else error ("Illegal type pattern for constant " ^ prt (c, args));
   1.130 +
   1.131 +
   1.132 +(* datatype defs *)
   1.133 +
   1.134 +type def =
   1.135 + {specs: spec Inttab.table,
   1.136 +  pattern: pattern,
   1.137 +  restricts: (string * (typ list * string)) list,
   1.138 +  reducts: (typ list * (string * typ list) list) list};
   1.139 +
   1.140 +fun make_def (specs, pattern, restricts, reducts) =
   1.141 +  {specs = specs, pattern = pattern, restricts = restricts, reducts = reducts}: def;
   1.142 +
   1.143 +fun map_def f ({specs, pattern, restricts, reducts}: def) =
   1.144 +  make_def (f (specs, pattern, restricts, reducts));
   1.145 +
   1.146 +fun default_def (pattern, restricts) = make_def (Inttab.empty, pattern, restricts, []);
   1.147 +
   1.148 +datatype T = Defs of def Symtab.table;
   1.149 +val empty = Defs Symtab.empty;
   1.150 +
   1.151 +fun lookup_list which (Defs defs) c =
   1.152 +  (case Symtab.lookup defs c of
   1.153 +    SOME def => which def
   1.154 +  | NONE => []);
   1.155 +
   1.156 +val specifications_of = lookup_list (Inttab.dest o #specs);
   1.157 +val restricts_of = lookup_list #restricts;
   1.158 +val reducts_of = lookup_list #reducts;
   1.159 +
   1.160 +
   1.161 +(* normalize defs *)
   1.162 +
   1.163 +fun matcher arg =
   1.164 +  Option.map Envir.typ_subst_TVars
   1.165 +    (SOME (Type.raw_matches arg Vartab.empty) handle Type.TYPE_MATCH => NONE);
   1.166 +
   1.167 +fun restriction prt defs (c, args) =
   1.168 +  (case args of
   1.169 +    [Type (a, Us)] =>
   1.170 +      (case AList.lookup (op =) (restricts_of defs c) a of
   1.171 +        SOME (Ts, name) =>
   1.172 +          if is_some (matcher (Ts, Us)) then ()
   1.173 +          else error ("Occurrence of overloaded constant " ^ prt (c, args) ^
   1.174 +            "\nviolates restriction " ^ prt (c, Ts) ^ "\nimposed by " ^ quote name)
   1.175 +      | NONE => ())
   1.176 +  | _ => ());
   1.177 +
   1.178 +fun reduction defs deps =
   1.179 +  let
   1.180 +    fun reduct Us (Ts, rhs) =
   1.181 +      (case matcher (Ts, Us) of
   1.182 +        NONE => NONE
   1.183 +      | SOME subst => SOME (map (apsnd (map subst)) rhs));
   1.184 +    fun reducts (d, Us) = get_first (reduct Us) (reducts_of defs d);
   1.185 +
   1.186 +    fun add (NONE, dp) = insert (op =) dp
   1.187 +      | add (SOME dps, _) = fold (insert (op =)) dps;
   1.188 +    val deps' = map (`reducts) deps;
   1.189 +  in
   1.190 +    if forall (is_none o #1) deps' then NONE
   1.191 +    else SOME (fold_rev add deps' [])
   1.192 +  end;
   1.193 +
   1.194 +fun normalize prt defs (c, args) deps =
   1.195 +  let
   1.196 +    val reds = reduction defs deps;
   1.197 +    val deps' = the_default deps reds;
   1.198 +    val _ = List.app (restriction prt defs) ((c, args) :: deps');
   1.199 +    val _ = deps' |> List.app (fn (c', args') =>
   1.200 +      if c' = c andalso is_some (matcher (args, args')) then
   1.201 +        error ("Circular dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (c, args'))
   1.202 +      else ());
   1.203 +  in reds end;
   1.204 +
   1.205 +
   1.206 +(* dependencies *)
   1.207 +
   1.208 +fun normalize_deps prt defs0 (Defs defs) =
   1.209 +  let
   1.210 +    fun norm const deps = perhaps (normalize prt defs0 const) deps;
   1.211 +    fun norm_update (c, {reducts, ...}) =
   1.212 +      let val reducts' = reducts |> map (fn (args, deps) => (args, norm (c, args) deps)) in
   1.213 +        if reducts = reducts' then I
   1.214 +        else Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
   1.215 +          (specs, pattern, restricts, reducts')))
   1.216 +      end;
   1.217 +  in Defs (Symtab.fold norm_update defs defs) end;
   1.218 +
   1.219 +fun dependencies prt (c, args) pat deps (Defs defs) =
   1.220 +  let
   1.221 +    val deps' = perhaps (normalize prt (Defs defs) (c, args)) deps;
   1.222 +    val defs' = defs
   1.223 +      |> Symtab.default (c, default_def pat)
   1.224 +      |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
   1.225 +        let
   1.226 +          val pattern' = merge_pattern c (pattern, #1 pat);
   1.227 +          val restricts' = Library.merge (op =) (restricts, #2 pat);
   1.228 +          val reducts' = insert (op =) (args, deps') reducts;
   1.229 +        in (specs, pattern', restricts', reducts') end));
   1.230 +  in normalize_deps prt (Defs defs') (Defs defs') end;
   1.231 +
   1.232 +
   1.233  (* merge *)
   1.234  
   1.235 -fun cycle_msg css =
   1.236 +fun join_specs c ({specs = specs1, pattern, restricts, reducts}, {specs = specs2, ...}: def) =
   1.237    let
   1.238 -    fun prt_cycle items = Pretty.block (flat
   1.239 -      (separate [Pretty.str " ->", Pretty.brk 1] (map (single o pretty_item) items)));
   1.240 -  in Pretty.string_of (Pretty.big_list "Cyclic dependency of constants:" (map prt_cycle css)) end;
   1.241 +    val specs' =
   1.242 +      Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1;
   1.243 +  in make_def (specs', pattern, restricts, reducts) end;
   1.244  
   1.245 -fun merge
   1.246 -   (Defs {specs = specs1, insts = insts1, deps = deps1},
   1.247 -    Defs {specs = specs2, insts = insts2, deps = deps2}) =
   1.248 +fun merge pp (Defs defs1, Defs defs2) =
   1.249    let
   1.250 -    val specs' = (specs1, specs2) |> Symtab.join (fn c => fn ((b, sps1), (_, sps2)) =>
   1.251 -      (b, Inttab.fold (fn sp2 => (disjoint_specs c sp2 sps1; Inttab.update sp2)) sps2 sps1));
   1.252 -    val insts' = Symtab.merge_list (op =) (insts1, insts2);
   1.253 -    val items' = propagate_deps insts' (Items.merge_acyclic (K true) (deps1, deps2))
   1.254 -      handle Items.CYCLES cycles => error (cycle_msg cycles);
   1.255 -  in make_defs (specs', insts', items') end;
   1.256 +    fun add_deps (c, args) pat deps defs =
   1.257 +      if AList.defined (op =) (reducts_of defs c) args then defs
   1.258 +      else dependencies (print_const pp) (c, args) pat deps defs;
   1.259 +    fun add_def (c, {pattern, restricts, reducts, ...}) =
   1.260 +      fold (fn (args, deps) => add_deps (c, args) (pattern, restricts) deps) reducts;
   1.261 +  in Defs (Symtab.join join_specs (defs1, defs2)) |> Symtab.fold add_def defs2 end;
   1.262  
   1.263  
   1.264  (* define *)
   1.265  
   1.266 -fun pure_args args = forall Term.is_TVar args andalso not (has_duplicates (op =) args);
   1.267 -
   1.268 -fun define const_typargs unchecked is_def module name lhs rhs defs = defs
   1.269 -    |> map_defs (fn (specs, insts, deps) =>
   1.270 +fun define pp consts unchecked is_def module name lhs rhs (Defs defs) =
   1.271    let
   1.272 -    val (c, T) = lhs;
   1.273 -    val args = const_typargs lhs;
   1.274 -    val no_overloading = pure_args args;
   1.275 -    val rec_args = (case args of [Type (_, Ts)] => if pure_args Ts then Ts else [] | _ => []);
   1.276 +    val prt = print_const pp;
   1.277 +    fun typargs const = (#1 const, Consts.typargs consts const);
   1.278  
   1.279 -    val lhs' = make_item (c, args);
   1.280 -    val rhs' =
   1.281 -      if unchecked then []
   1.282 -      else rhs |> map_filter (fn (c', T') =>
   1.283 -        let val args' = const_typargs (c', T') in
   1.284 -          if gen_subset (op =) (args', rec_args) then NONE
   1.285 -          else SOME (make_item (c', if no_overloading_of defs c' then [] else args'))
   1.286 -        end);
   1.287 +    val (c, args) = typargs lhs;
   1.288 +    val pat =
   1.289 +      if unchecked then (Unknown, [])
   1.290 +      else the_pattern prt name (c, args);
   1.291 +    val spec =
   1.292 +      (serial (), {is_def = is_def, module = module, name = name, lhs = #2 lhs, rhs = rhs});
   1.293  
   1.294 -    val spec = (serial (), {is_def = is_def, module = module, name = name, lhs = T, rhs = rhs});
   1.295 -    val specs' = specs
   1.296 -      |> Symtab.default (c, (false, Inttab.empty))
   1.297 -      |> Symtab.map_entry c (fn (_, sps) =>
   1.298 -          (disjoint_specs c spec sps; (no_overloading, Inttab.update spec sps)));
   1.299 -    val insts' = insts |> fold (fn i as Instance (c, _) =>
   1.300 -        Symtab.insert_list (op =) (c, i) | _ => I) (lhs' :: rhs');
   1.301 -    val deps' = deps
   1.302 -      |> fold (Items.default_node o rpair ()) (lhs' :: rhs')
   1.303 -      |> Items.add_deps_acyclic (lhs', rhs')
   1.304 -      |> propagate_deps insts'
   1.305 -      handle Items.CYCLES cycles => error (cycle_msg cycles);
   1.306 -
   1.307 -  in (specs', insts', deps') end);
   1.308 +    val defs' = defs
   1.309 +      |> Symtab.default (c, default_def pat)
   1.310 +      |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
   1.311 +        let
   1.312 +          val _ = disjoint_specs c spec specs;
   1.313 +          val specs' = Inttab.update spec specs;
   1.314 +        in (specs', pattern, restricts, reducts) end));
   1.315 +  in Defs defs' |> (if unchecked then I else dependencies prt (c, args) pat (map typargs rhs)) end;
   1.316  
   1.317  end;