src/Pure/defs.ML
author wenzelm
Sat May 20 23:45:37 2006 +0200 (2006-05-20)
changeset 19695 7706aeac6cf1
parent 19692 bad13b32c0f3
child 19697 423af2e013b8
permissions -rw-r--r--
made smlnj happy;
wenzelm@17707
     1
(*  Title:      Pure/defs.ML
obua@16108
     2
    ID:         $Id$
wenzelm@17707
     3
    Author:     Makarius
obua@16108
     4
wenzelm@19692
     5
Global well-formedness checks for constant definitions.  Covers plain
wenzelm@19692
     6
definitions and simple sub-structural overloading (depending on a
wenzelm@19692
     7
single type argument).
obua@16108
     8
*)
obua@16108
     9
wenzelm@16877
    10
signature DEFS =
sig
  (*table of constant specifications, keyed by constant name*)
  type T

  (*all specifications recorded for the named constant, paired with their
    serial numbers; the empty list if the constant has no entry*)
  val specifications_of: T -> string ->
    (serial *
      {is_def: bool, module: string, name: string, lhs: typ,
       rhs: (string * typ) list}) list

  (*the table without any entries*)
  val empty: T

  (*combine two tables; dependencies contributed by the second table are
    re-checked against the combined table*)
  val merge: Pretty.pp -> T * T -> T

  (*define pp consts unchecked is_def module name lhs rhs defs*)
  val define: Pretty.pp -> Consts.T -> bool -> bool -> string -> string ->
    string * typ -> (string * typ) list -> T -> T
end
obua@16108
    20
wenzelm@17711
    21
structure Defs: DEFS =
wenzelm@17707
    22
struct
obua@16108
    23
wenzelm@19692
    24
(* consts with type arguments *)
wenzelm@19613
    25
wenzelm@19692
    26
(*Render constant c together with its type arguments as a string: the bare
  name when args is empty, otherwise the name followed by a break and the
  parenthesized list of argument types (each frozen via Type.freeze_type
  before printing).*)
fun print_const pp (c, args) =
  let
    fun prt_typ T = Pretty.typ pp (Type.freeze_type T);
    val suffix =
      (case args of
        [] => []
      | _ => [Pretty.brk 1, Pretty.list "(" ")" (map prt_typ args)]);
    val blk = Pretty.block (Pretty.str c :: suffix);
  in Pretty.string_of blk end;
wenzelm@19613
    32
wenzelm@19613
    33
wenzelm@19692
    34
(* source specs *)
wenzelm@16877
    35
wenzelm@19569
    36
(*One source specification of a constant: its is_def flag, the module and
  name it was given under, the type of the constant being specified (lhs),
  and the constants with their types occurring on the right-hand side (rhs).*)
type spec =
 {is_def: bool,
  module: string,
  name: string,
  lhs: typ,
  rhs: (string * typ) list};
wenzelm@19569
    37
wenzelm@17707
    38
(*true iff T and U fail to unify once the schematic type variables of U have
  been shifted above the maximal index of T (so the two are renamed apart)*)
fun disjoint_types T U =
  let
    val U' = Logic.incr_tvar (maxidx_of_typ T + 1) U;
  in
    (Type.raw_unify (T, U') Vartab.empty; false)
      handle Type.TUNIFY => true
  end;
obua@16308
    41
wenzelm@19613
    42
(*Check the spec with serial i, lhs type T and name a against every entry of
  a specification table for constant c.  An entry with a different serial
  whose lhs unifies with T (after renaming apart) is a clash, which raises an
  error naming both specifications.  Evaluates to true when there is no clash.*)
fun disjoint_specs c (i, {lhs = T, name = a, ...}: spec) =
  let
    fun clash b =
      error ("Type clash in specifications " ^ quote a ^ " and " ^ quote b ^
        " for constant " ^ quote c);
    fun compatible (j, {lhs = U, name = b, ...}: spec) =
      if i = j then true
      else if not (Type.could_unify (T, U)) then true
      else disjoint_types T U orelse clash b;
  in Inttab.forall compatible end;
wenzelm@16877
    47
wenzelm@19624
    48
wenzelm@19692
    49
(* patterns *)
wenzelm@19692
    50
wenzelm@19692
    51
(*Shape of the type instances at which a constant is specified:
    Unknown    -- not determined (used for unchecked definitions),
    Plain      -- all type arguments are distinct type variables,
    Overloaded -- a single type argument headed by a type constructor.*)
datatype pattern = Unknown | Plain | Overloaded;

(*human-readable description, used in error messages*)
fun str_of_pattern p =
  (case p of
    Overloaded => "overloading"
  | Plain => "no overloading"
  | Unknown => "no overloading");
wenzelm@19692
    55
wenzelm@19692
    56
(*Combine two patterns recorded for constant c.  Unknown on either side
  yields the other side (the first side when both are Unknown); equal patterns
  are kept; Plain versus Overloaded is reported as an error.*)
fun merge_pattern c (p1, p2) =
  (case (p1, p2) of
    (_, Unknown) => p1
  | (Unknown, _) => p2
  | _ =>
      if p1 = p2 then p1
      else error ("Inconsistent type patterns for constant " ^ quote c ^ ":\n" ^
        str_of_pattern p1 ^ " versus " ^ str_of_pattern p2));
wenzelm@19692
    61
wenzelm@19692
    62
(*type arguments are plain when every one is a TVar and no two are equal*)
fun plain_args args =
  let
    val all_vars = forall Term.is_TVar args;
  in all_vars andalso not (has_duplicates (op =) args) end;
wenzelm@19692
    64
wenzelm@19692
    65
(*Classify the type arguments of constant c, specified under the given name.
  A single type-constructor argument Type (a, Ts) means overloading; unless Ts
  is plain, this also yields the restriction (a, (Ts, name)).  Any other
  argument list must be plain, otherwise an error is raised using prt.*)
fun the_pattern prt name (c, args) =
  (case args of
    [Type (a, Ts)] =>
      let val restricts = if plain_args Ts then [] else [(a, (Ts, name))]
      in (Overloaded, restricts) end
  | _ =>
      if plain_args args then (Plain, [])
      else error ("Illegal type pattern for constant " ^ prt (c, args)));
wenzelm@19692
    70
wenzelm@19692
    71
wenzelm@19692
    72
(* datatype defs *)
wenzelm@19692
    73
wenzelm@19692
    74
(*Per-constant entry:
    specs     -- source specifications, indexed by serial number,
    pattern   -- overloading pattern of the constant,
    restricts -- (type constructor, (argument types, spec name)) restrictions
                 that occurrences of the constant are checked against,
    reducts   -- (type arguments, dependencies on other constants).*)
type def =
 {specs: spec Inttab.table,
  pattern: pattern,
  restricts: (string * (typ list * string)) list,
  reducts: (typ list * (string * typ list) list) list};

fun make_def (sp, pat, rs, rd) : def =
  {specs = sp, pattern = pat, restricts = rs, reducts = rd};

(*rebuild an entry from f applied to its four components*)
fun map_def f (d: def) =
  make_def (f (#specs d, #pattern d, #restricts d, #reducts d));

(*fresh entry: no specs and no reducts yet*)
fun default_def (pat, rs) = make_def (Inttab.empty, pat, rs, []);
wenzelm@19692
    87
wenzelm@19692
    88
(*the table itself: one entry per constant name*)
datatype T = Defs of def Symtab.table;
val empty = Defs Symtab.empty;

(*project a list out of the entry for constant c; [] if c has no entry*)
fun lookup_list which (Defs defs) c =
  the_default [] (Option.map which (Symtab.lookup defs c));

val specifications_of = lookup_list (fn {specs, ...}: def => Inttab.dest specs);
val restricts_of = lookup_list (fn {restricts, ...}: def => restricts);
val reducts_of = lookup_list (fn {reducts, ...}: def => reducts);
wenzelm@19692
    99
wenzelm@19692
   100
wenzelm@19692
   101
(* normalize defs *)
wenzelm@19692
   102
wenzelm@19692
   103
(*If the first type list of arg matches the second (presumably
  (pattern types, instance types) -- see callers), return the resulting
  instantiation as a function on types; NONE when matching fails.*)
fun matcher arg =
  (case (SOME (Type.raw_matches arg Vartab.empty) handle Type.TYPE_MATCH => NONE) of
    NONE => NONE
  | SOME env => SOME (Envir.typ_subst_TVars env));
wenzelm@19692
   106
wenzelm@19692
   107
(*Check an occurrence of constant c at a single type-constructor argument
  Type (a, Us) against the restriction recorded for c at constructor a, if
  any: the restriction's argument types must match Us, otherwise an error
  names the restriction and the spec that imposed it.  Occurrences of any
  other shape are not checked.*)
fun restriction prt defs (c, args as [Type (a, Us)]) =
      (case AList.lookup (op =) (restricts_of defs c) a of
        NONE => ()
      | SOME (Ts, name) =>
          if is_none (matcher (Ts, Us)) then
            error ("Occurrence of overloaded constant " ^ prt (c, args) ^
              "\nviolates restriction " ^ prt (c, Ts) ^ "\nimposed by " ^ quote name)
          else ())
  | restriction _ _ _ = ();
wenzelm@19692
   117
wenzelm@19692
   118
(*One reduction step over deps: each dependency (d, Us) is replaced by the
  instantiated dependencies of the first recorded reduct of d whose type
  arguments match Us, or kept as-is when there is none.  The result is
  collected without duplicates (via insert).  NONE when no dependency could
  be reduced at all.*)
fun reduction defs deps =
  let
    fun instantiate Us (Ts, rhs) =
      Option.map (fn subst => map (apsnd (map subst)) rhs) (matcher (Ts, Us));
    fun reduce_dep (dep as (d, Us)) =
      (get_first (instantiate Us) (reducts_of defs d), dep);
    val reduced = map reduce_dep deps;

    fun collect (SOME dps, _) acc = fold (insert (op =)) dps acc
      | collect (NONE, dp) acc = insert (op =) dp acc;
  in
    if exists (is_some o #1) reduced then SOME (fold_rev collect reduced [])
    else NONE
  end;
wenzelm@19692
   133
wenzelm@19692
   134
(*Reduce deps by one step and check the outcome (or deps itself when nothing
  was reduced):
    - c at args and every dependency must respect recorded restrictions;
    - no dependency on c itself may have type arguments matched by args,
      which is reported as a circular dependency.
  Returns the reduced dependencies, or NONE if nothing was reduced.*)
fun normalize prt defs (c, args) deps =
  let
    val reds = reduction defs deps;
    val current = (case reds of SOME deps' => deps' | NONE => deps);
    fun check_cycle (c', args') =
      if c' <> c orelse is_none (matcher (args, args')) then ()
      else error ("Circular dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (c, args'));
  in
    List.app (restriction prt defs) ((c, args) :: current);
    List.app check_cycle current;
    reds
  end;
wenzelm@19692
   144
wenzelm@19692
   145
wenzelm@19692
   146
(* dependencies *)
wenzelm@19692
   147
wenzelm@19692
   148
(*Re-normalize the dependencies of every reduct in the table, checking them
  against defs0.  Entries whose reducts come out unchanged are left alone.*)
fun normalize_deps prt defs0 (Defs defs) =
  let
    fun norm_reduct c (args, deps) =
      (args, perhaps (normalize prt defs0 (c, args)) deps);
    fun set_reducts rds =
      map_def (fn (specs, pattern, restricts, _) => (specs, pattern, restricts, rds));
    (*the normalization work happens here, before the table transformer is
      returned, so checks run in the same order as Symtab.fold visits entries*)
    fun norm_update (c, {reducts, ...}: def) =
      let val reducts' = map (norm_reduct c) reducts in
        if reducts = reducts' then I
        else Symtab.map_entry c (set_reducts reducts')
      end;
  in Defs (Symtab.fold norm_update defs defs) end;
wenzelm@19692
   158
wenzelm@19692
   159
(*Record that constant c at type arguments args depends on deps, after
  normalizing deps against the current table, and merge the pattern
  pat = (pattern, restrictions) into c's entry (creating a default entry
  first if needed).  Finally re-normalize all recorded dependencies.*)
fun dependencies prt (c, args) (pat as (pattern0, restricts0)) deps (tab as Defs defs) =
  let
    val deps' = perhaps (normalize prt tab (c, args)) deps;
    fun extend (specs, pattern, restricts, reducts) =
      let
        val pattern' = merge_pattern c (pattern, pattern0);
        val restricts' = Library.merge (op =) (restricts, restricts0);
        val reducts' = insert (op =) (args, deps') reducts;
      in (specs, pattern', restricts', reducts') end;
    val defs' = defs
      |> Symtab.default (c, default_def pat)
      |> Symtab.map_entry c (map_def extend);
    val tab' = Defs defs';
  in normalize_deps prt tab' tab' end;
wenzelm@19692
   171
wenzelm@19692
   172
wenzelm@19624
   173
(* merge *)
wenzelm@19624
   174
wenzelm@19692
   175
(*Join two entries for constant c.  The first entry's pattern, restrictions
  and reducts are kept.  Each spec of the second entry is checked against the
  first entry's specs (disjoint_specs) and then added.*)
fun join_specs c (def1: def, def2: def) =
  let
    val {specs = specs1, pattern, restricts, reducts} = def1;
    (*check first, then return the update, preserving the original timing of
      the check relative to Inttab.fold's traversal*)
    fun add_spec spec2 =
      let val _ = disjoint_specs c spec2 specs1
      in Inttab.update spec2 end;
    val specs' = Inttab.fold add_spec (#specs def2) specs1;
  in make_def (specs', pattern, restricts, reducts) end;
wenzelm@16982
   180
wenzelm@19692
   181
(*Merge two tables.  Entries are first joined by join_specs.  Then every
  reduct of the second table whose type arguments are not yet recorded in the
  accumulating table is replayed through dependencies, which re-checks it.*)
fun merge pp (Defs defs1, Defs defs2) =
  let
    val prt = print_const pp;
    fun add_reduct c pat (args, deps) tab =
      if AList.defined (op =) (reducts_of tab c) args then tab
      else dependencies prt (c, args) pat deps tab;
    fun add_def (c, {pattern, restricts, reducts, ...}: def) =
      fold (add_reduct c (pattern, restricts)) reducts;
    val joined = Defs (Symtab.join join_specs (defs1, defs2));
  in Symtab.fold add_def defs2 joined end;
wenzelm@19613
   189
wenzelm@19613
   190
wenzelm@19613
   191
(* define *)
wenzelm@19590
   192
wenzelm@19692
   193
(*Add a specification for constant lhs = (c, T) with right-hand side
  constants rhs.  The new spec must not clash with existing specs of c.
  Unless unchecked, the type arguments of lhs must form an admissible pattern
  (the_pattern), and the dependencies on rhs are recorded and checked via
  dependencies.  With unchecked, the pattern is Unknown and no dependencies
  are recorded.*)
fun define pp consts unchecked is_def module name lhs rhs (Defs defs) =
  let
    val prt = print_const pp;
    (*pair a constant name with its type arguments according to consts*)
    fun typargs (const as (a, _)) = (a, Consts.typargs consts const);

    (*NOTE(review): keep this order -- serial () is drawn only after lhs has
      been analysed and its pattern computed*)
    val (c, args) = typargs lhs;
    val pat = if unchecked then (Unknown, []) else the_pattern prt name (c, args);
    val spec =
      (serial (), {is_def = is_def, module = module, name = name, lhs = #2 lhs, rhs = rhs});

    fun add_spec (specs, pattern, restricts, reducts) =
      let val _ = disjoint_specs c spec specs
      in (Inttab.update spec specs, pattern, restricts, reducts) end;
    val defs' = defs
      |> Symtab.default (c, default_def pat)
      |> Symtab.map_entry c (map_def add_spec);
    val result = Defs defs';
  in
    if unchecked then result
    else dependencies prt (c, args) pat (map typargs rhs) result
  end;
wenzelm@16877
   213
obua@16108
   214
end;