specifications_of: lhs/rhs represented as typargs;
authorwenzelm
Mon May 22 22:29:15 2006 +0200 (2006-05-22)
changeset 19697423af2e013b8
parent 19696 26a268c299d8
child 19698 f48cfaacd92c
specifications_of: lhs/rhs represented as typargs;
export pretty_const;
export dest;
more precise checking of lhs patterns;
more precise normalization;
misc cleanup;
src/Pure/defs.ML
     1.1 --- a/src/Pure/defs.ML	Mon May 22 21:27:01 2006 +0200
     1.2 +++ b/src/Pure/defs.ML	Mon May 22 22:29:15 2006 +0200
     1.3 @@ -9,9 +9,13 @@
     1.4  
     1.5  signature DEFS =
     1.6  sig
     1.7 +  val pretty_const: Pretty.pp -> string * typ list -> Pretty.T
     1.8    type T
     1.9 -  val specifications_of: T -> string ->
    1.10 -   (serial * {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list}) list
    1.11 +  val specifications_of: T -> string -> (serial * {is_def: bool, module: string, name: string,
    1.12 +    lhs: typ list, rhs: (string * typ list) list}) list
    1.13 +  val dest: T ->
    1.14 +   {restricts: ((string * typ list) * string) list,
    1.15 +    reducts: ((string * typ list) * (string * typ list) list) list}
    1.16    val empty: T
    1.17    val merge: Pretty.pp -> T * T -> T
    1.18    val define: Pretty.pp -> Consts.T ->
    1.19 @@ -21,72 +25,47 @@
    1.20  structure Defs: DEFS =
    1.21  struct
    1.22  
    1.23 -(* consts with type arguments *)
    1.24 +
    1.25 +(* type arguments *)
    1.26  
    1.27 -fun print_const pp (c, args) =
    1.28 +type args = typ list;
    1.29 +
    1.30 +fun pretty_const pp (c, args) =
    1.31    let
    1.32      val prt_args =
    1.33        if null args then []
    1.34 -      else [Pretty.brk 1, Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)];
    1.35 -  in Pretty.string_of (Pretty.block (Pretty.str c :: prt_args)) end;
    1.36 -
    1.37 -
    1.38 -(* source specs *)
    1.39 -
    1.40 -type spec = {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list};
    1.41 -
    1.42 -fun disjoint_types T U =
    1.43 -  (Type.raw_unify (T, Logic.incr_tvar (maxidx_of_typ T + 1) U) Vartab.empty; false)
    1.44 -    handle Type.TUNIFY => true;
    1.45 -
    1.46 -fun disjoint_specs c (i, {lhs = T, name = a, ...}: spec) =
    1.47 -  Inttab.forall (fn (j, {lhs = U, name = b, ...}: spec) =>
    1.48 -    i = j orelse not (Type.could_unify (T, U)) orelse disjoint_types T U orelse
    1.49 -      error ("Type clash in specifications " ^ quote a ^ " and " ^ quote b ^
    1.50 -        " for constant " ^ quote c));
    1.51 -
    1.52 +      else [Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)];
    1.53 +  in Pretty.block (Pretty.str c :: prt_args) end;
    1.54  
    1.55 -(* patterns *)
    1.56 -
    1.57 -datatype pattern = Unknown | Plain | Overloaded;
    1.58 -
    1.59 -fun str_of_pattern Overloaded = "overloading"
    1.60 -  | str_of_pattern _ = "no overloading";
    1.61 +fun disjoint_args (Ts, Us) =
    1.62 +  not (Type.could_unifys (Ts, Us)) orelse
    1.63 +    ((Type.raw_unifys (Ts, map (Logic.incr_tvar (maxidx_of_typs Ts + 1)) Us) Vartab.empty; false)
    1.64 +      handle Type.TUNIFY => true);
    1.65  
    1.66 -fun merge_pattern c (p1, p2) =
    1.67 -  if p1 = p2 orelse p2 = Unknown then p1
    1.68 -  else if p1 = Unknown then p2
    1.69 -  else error ("Inconsistent type patterns for constant " ^ quote c ^ ":\n" ^
    1.70 -    str_of_pattern p1 ^ " versus " ^ str_of_pattern p2);
    1.71 -
    1.72 -fun plain_args args =
    1.73 -  forall Term.is_TVar args andalso not (has_duplicates (op =) args);
    1.74 -
    1.75 -fun the_pattern _ name (c, [Type (a, args)]) =
    1.76 -      (Overloaded, if plain_args args then [] else [(a, (args, name))])
    1.77 -  | the_pattern prt _ (c, args) =
    1.78 -      if plain_args args then (Plain, [])
    1.79 -      else error ("Illegal type pattern for constant " ^ prt (c, args));
    1.80 +fun match_args (Ts, Us) =
    1.81 +  Option.map Envir.typ_subst_TVars
    1.82 +    (SOME (Type.raw_matches (Ts, Us) Vartab.empty) handle Type.TYPE_MATCH => NONE);
    1.83  
    1.84  
    1.85  (* datatype defs *)
    1.86  
    1.87 +type spec = {is_def: bool, module: string, name: string, lhs: args, rhs: (string * args) list};
    1.88 +
    1.89  type def =
    1.90   {specs: spec Inttab.table,
    1.91 -  pattern: pattern,
    1.92 -  restricts: (string * (typ list * string)) list,
    1.93 -  reducts: (typ list * (string * typ list) list) list};
    1.94 +  restricts: (args * string) list,
    1.95 +  reducts: (args * (string * args) list) list};
    1.96 +
    1.97 +fun make_def (specs, restricts, reducts) =
    1.98 +  {specs = specs, restricts = restricts, reducts = reducts}: def;
    1.99  
   1.100 -fun make_def (specs, pattern, restricts, reducts) =
   1.101 -  {specs = specs, pattern = pattern, restricts = restricts, reducts = reducts}: def;
   1.102 +fun map_def c f =
   1.103 +  Symtab.default (c, make_def (Inttab.empty, [], [])) #>
   1.104 +  Symtab.map_entry c (fn {specs, restricts, reducts}: def =>
   1.105 +    make_def (f (specs, restricts, reducts)));
   1.106  
   1.107 -fun map_def f ({specs, pattern, restricts, reducts}: def) =
   1.108 -  make_def (f (specs, pattern, restricts, reducts));
   1.109 -
   1.110 -fun default_def (pattern, restricts) = make_def (Inttab.empty, pattern, restricts, []);
   1.111  
   1.112  datatype T = Defs of def Symtab.table;
   1.113 -val empty = Defs Symtab.empty;
   1.114  
   1.115  fun lookup_list which (Defs defs) c =
   1.116    (case Symtab.lookup defs c of
   1.117 @@ -97,31 +76,46 @@
   1.118  val restricts_of = lookup_list #restricts;
   1.119  val reducts_of = lookup_list #reducts;
   1.120  
   1.121 -
   1.122 -(* normalize defs *)
   1.123 +fun dest (Defs defs) =
   1.124 +  let
   1.125 +    val restricts = Symtab.fold (fn (c, {restricts, ...}) =>
   1.126 +      fold (fn (args, name) => cons ((c, args), name)) restricts) defs [];
   1.127 +    val reducts = Symtab.fold (fn (c, {reducts, ...}) =>
   1.128 +      fold (fn (args, deps) => cons ((c, args), deps)) reducts) defs [];
   1.129 +  in {restricts = restricts, reducts = reducts} end;
   1.130  
   1.131 -fun matcher arg =
   1.132 -  Option.map Envir.typ_subst_TVars
   1.133 -    (SOME (Type.raw_matches arg Vartab.empty) handle Type.TYPE_MATCH => NONE);
   1.134 +val empty = Defs Symtab.empty;
   1.135 +
   1.136 +
   1.137 +(* specifications *)
   1.138  
   1.139 -fun restriction prt defs (c, args) =
   1.140 -  (case args of
   1.141 -    [Type (a, Us)] =>
   1.142 -      (case AList.lookup (op =) (restricts_of defs c) a of
   1.143 -        SOME (Ts, name) =>
   1.144 -          if is_some (matcher (Ts, Us)) then ()
   1.145 -          else error ("Occurrence of overloaded constant " ^ prt (c, args) ^
   1.146 -            "\nviolates restriction " ^ prt (c, Ts) ^ "\nimposed by " ^ quote name)
   1.147 -      | NONE => ())
   1.148 -  | _ => ());
   1.149 +fun disjoint_specs c (i, {lhs = Ts, name = a, ...}: spec) =
   1.150 +  Inttab.forall (fn (j, {lhs = Us, name = b, ...}: spec) =>
   1.151 +    i = j orelse disjoint_args (Ts, Us) orelse
   1.152 +      error ("Type clash in specifications " ^ quote a ^ " and " ^ quote b ^
   1.153 +        " for constant " ^ quote c));
   1.154  
   1.155 -fun reduction defs deps =
   1.156 +fun join_specs c ({specs = specs1, restricts, reducts}, {specs = specs2, ...}: def) =
   1.157 +  let
   1.158 +    val specs' =
   1.159 +      Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1;
   1.160 +  in make_def (specs', restricts, reducts) end;
   1.161 +
   1.162 +fun update_specs c spec = map_def c (fn (specs, restricts, reducts) =>
   1.163 +  (disjoint_specs c spec specs; (Inttab.update spec specs, restricts, reducts)));
   1.164 +
   1.165 +
   1.166 +(* normalization: reduction and well-formedness check *)
   1.167 +
   1.168 +local
   1.169 +
   1.170 +fun reduction reds_of deps =
   1.171    let
   1.172      fun reduct Us (Ts, rhs) =
   1.173 -      (case matcher (Ts, Us) of
   1.174 +      (case match_args (Ts, Us) of
   1.175          NONE => NONE
   1.176        | SOME subst => SOME (map (apsnd (map subst)) rhs));
   1.177 -    fun reducts (d, Us) = get_first (reduct Us) (reducts_of defs d);
   1.178 +    fun reducts (d: string, Us) = get_first (reduct Us) (reds_of d);
   1.179  
   1.180      fun add (NONE, dp) = insert (op =) dp
   1.181        | add (SOME dps, _) = fold (insert (op =)) dps;
   1.182 @@ -131,84 +125,114 @@
   1.183      else SOME (fold_rev add deps' [])
   1.184    end;
   1.185  
   1.186 -fun normalize prt defs (c, args) deps =
   1.187 +fun reductions reds_of deps =
   1.188 +  (case reduction reds_of deps of
   1.189 +    SOME deps' => reductions reds_of deps'
   1.190 +  | NONE => deps);
   1.191 +
   1.192 +fun contained U (Type (_, Ts)) = exists (fn T => T = U orelse contained U T) Ts
   1.193 +  | contained _ _ = false;
   1.194 +
   1.195 +fun wellformed pp rests_of (c, args) (d, Us) =
   1.196    let
   1.197 -    val reds = reduction defs deps;
   1.198 -    val deps' = the_default deps reds;
   1.199 -    val _ = List.app (restriction prt defs) ((c, args) :: deps');
   1.200 -    val _ = deps' |> List.app (fn (c', args') =>
   1.201 -      if c' = c andalso is_some (matcher (args, args')) then
   1.202 -        error ("Circular dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (c, args'))
   1.203 -      else ());
   1.204 -  in reds end;
   1.205 +    val prt = Pretty.string_of o pretty_const pp;
   1.206 +    fun err s1 s2 =
   1.207 +      error (s1 ^ " dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (d, Us) ^ s2);
   1.208 +  in
   1.209 +    exists (fn U => exists (contained U) args) Us orelse
   1.210 +    (c <> d andalso exists (member (op =) args) Us) orelse
   1.211 +      (case find_first (fn (Ts, _) => not (disjoint_args (Ts, Us))) (rests_of d) of
   1.212 +        NONE =>
   1.213 +          c <> d orelse is_none (match_args (args, Us)) orelse err "Circular" ""
   1.214 +      | SOME (Ts, name) =>
   1.215 +          if c = d then err "Circular" ("\n(via " ^ quote name ^ ")")
   1.216 +          else
   1.217 +            err "Malformed" ("\n(restriction " ^ prt (d, Ts) ^ " from " ^ quote name ^ ")"))
   1.218 +  end;
   1.219  
   1.220 -
   1.221 -(* dependencies *)
   1.222 -
   1.223 -fun normalize_deps prt defs0 (Defs defs) =
   1.224 +fun normalize pp rests_of reds_of (c, args) deps =
   1.225    let
   1.226 -    fun norm const deps = perhaps (normalize prt defs0 const) deps;
   1.227 -    fun norm_update (c, {reducts, ...}: def) =
   1.228 -      let val reducts' = reducts |> map (fn (args, deps) => (args, norm (c, args) deps)) in
   1.229 -        if reducts = reducts' then I
   1.230 -        else Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
   1.231 -          (specs, pattern, restricts, reducts')))
   1.232 -      end;
   1.233 -  in Defs (Symtab.fold norm_update defs defs) end;
   1.234 +    val deps' = reductions reds_of deps;
   1.235 +    val _ = forall (wellformed pp rests_of (c, args)) deps';
   1.236 +  in deps' end;
   1.237  
   1.238 -fun dependencies prt (c, args) pat deps (Defs defs) =
   1.239 +fun normalize_all pp (c, args) deps defs =
   1.240    let
   1.241 -    val deps' = perhaps (normalize prt (Defs defs) (c, args)) deps;
   1.242 +    val norm = normalize pp (restricts_of (Defs defs));
   1.243 +    val norm_rule = norm (fn c' => if c' = c then [(args, deps)] else []);
   1.244 +    val norm_defs = norm (reducts_of (Defs defs));
   1.245 +    fun norm_update (c', {reducts, ...}: def) =
   1.246 +      let val reducts' = reducts
   1.247 +        |> map (fn (args', deps') => (args', norm_defs (c', args') (norm_rule (c', args') deps')))
   1.248 +      in
   1.249 +        K (reducts <> reducts') ?
   1.250 +          map_def c' (fn (specs, restricts, reducts) => (specs, restricts, reducts'))
   1.251 +      end;
   1.252 +  in Symtab.fold norm_update defs defs end;
   1.253 +
   1.254 +in
   1.255 +
   1.256 +fun dependencies pp (c, args) restr deps (Defs defs) =
   1.257 +  let
   1.258 +    val deps' = normalize pp (restricts_of (Defs defs)) (reducts_of (Defs defs)) (c, args) deps;
   1.259      val defs' = defs
   1.260 -      |> Symtab.default (c, default_def pat)
   1.261 -      |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
   1.262 -        let
   1.263 -          val pattern' = merge_pattern c (pattern, #1 pat);
   1.264 -          val restricts' = Library.merge (op =) (restricts, #2 pat);
   1.265 -          val reducts' = insert (op =) (args, deps') reducts;
   1.266 -        in (specs, pattern', restricts', reducts') end));
   1.267 -  in normalize_deps prt (Defs defs') (Defs defs') end;
   1.268 +      |> map_def c (fn (specs, restricts, reducts) =>
   1.269 +        (specs, Library.merge (op =) (restricts, restr), reducts))
   1.270 +      |> normalize_all pp (c, args) deps';
   1.271 +    val deps'' =
   1.272 +      normalize pp (restricts_of (Defs defs')) (reducts_of (Defs defs')) (c, args) deps';
   1.273 +    val defs'' = defs'
   1.274 +      |> map_def c (fn (specs, restricts, reducts) =>
   1.275 +        (specs, restricts, insert (op =) (args, deps'') reducts));
   1.276 +  in Defs defs'' end;
   1.277 +
   1.278 +end;
   1.279  
   1.280  
   1.281  (* merge *)
   1.282  
   1.283 -fun join_specs c ({specs = specs1, pattern, restricts, reducts}, {specs = specs2, ...}: def) =
   1.284 -  let
   1.285 -    val specs' =
   1.286 -      Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1;
   1.287 -  in make_def (specs', pattern, restricts, reducts) end;
   1.288 -
   1.289  fun merge pp (Defs defs1, Defs defs2) =
   1.290    let
   1.291 -    fun add_deps (c, args) pat deps defs =
   1.292 +    fun add_deps (c, args) restr deps defs =
   1.293        if AList.defined (op =) (reducts_of defs c) args then defs
   1.294 -      else dependencies (print_const pp) (c, args) pat deps defs;
   1.295 -    fun add_def (c, {pattern, restricts, reducts, ...}: def) =
   1.296 -      fold (fn (args, deps) => add_deps (c, args) (pattern, restricts) deps) reducts;
   1.297 +      else dependencies pp (c, args) restr deps defs;
   1.298 +    fun add_def (c, {restricts, reducts, ...}: def) =
   1.299 +      fold (fn (args, deps) => add_deps (c, args) restricts deps) reducts;
   1.300    in Defs (Symtab.join join_specs (defs1, defs2)) |> Symtab.fold add_def defs2 end;
   1.301  
   1.302 +local  (* FIXME *)
   1.303 +  val merge_aux = merge
   1.304 +  val acc = Output.time_accumulator "Defs.merge"
   1.305 +in fun merge pp = acc (merge_aux pp) end;
   1.306 +
   1.307  
   1.308  (* define *)
   1.309  
   1.310 +fun plain_args args =
   1.311 +  forall Term.is_TVar args andalso not (has_duplicates (op =) args);
   1.312 +
   1.313  fun define pp consts unchecked is_def module name lhs rhs (Defs defs) =
   1.314    let
   1.315 -    val prt = print_const pp;
   1.316      fun typargs const = (#1 const, Consts.typargs consts const);
   1.317 -
   1.318      val (c, args) = typargs lhs;
   1.319 -    val pat =
   1.320 -      if unchecked then (Unknown, [])
   1.321 -      else the_pattern prt name (c, args);
   1.322 +    val deps = map typargs rhs;
   1.323 +    val restr =
   1.324 +      if plain_args args orelse
   1.325 +        (case args of [Type (a, rec_args)] => plain_args rec_args | _ => false)
   1.326 +      then [] else [(args, name)];
   1.327      val spec =
   1.328 -      (serial (), {is_def = is_def, module = module, name = name, lhs = #2 lhs, rhs = rhs});
   1.329 +      (serial (), {is_def = is_def, module = module, name = name, lhs = args, rhs = deps});
   1.330 +    val defs' = defs |> update_specs c spec;
   1.331 +  in Defs defs' |> (if unchecked then I else dependencies pp (c, args) restr deps) end;
   1.332 +
   1.333  
   1.334 -    val defs' = defs
   1.335 -      |> Symtab.default (c, default_def pat)
   1.336 -      |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
   1.337 -        let
   1.338 -          val _ = disjoint_specs c spec specs;
   1.339 -          val specs' = Inttab.update spec specs;
   1.340 -        in (specs', pattern, restricts, reducts) end));
   1.341 -  in Defs defs' |> (if unchecked then I else dependencies prt (c, args) pat (map typargs rhs)) end;
   1.342 +local  (* FIXME *)
   1.343 +  val define_aux = define
   1.344 +  val acc = Output.time_accumulator "Defs.define"
   1.345 +in
   1.346 +  fun define pp consts unchecked is_def module name lhs rhs =
   1.347 +    acc (define_aux pp consts unchecked is_def module name lhs rhs)
   1.348 +end;
   1.349 +
   1.350  
   1.351  end;