src/Pure/defs.ML
changeset 19697 423af2e013b8
parent 19695 7706aeac6cf1
child 19701 c07c31ac689b
equal deleted inserted replaced
19696:26a268c299d8 19697:423af2e013b8
     7 single type argument).
     7 single type argument).
     8 *)
     8 *)
     9 
     9 
    10 signature DEFS =
    10 signature DEFS =
    11 sig
    11 sig
       
    12   val pretty_const: Pretty.pp -> string * typ list -> Pretty.T
    12   type T
    13   type T
    13   val specifications_of: T -> string ->
    14   val specifications_of: T -> string -> (serial * {is_def: bool, module: string, name: string,
    14    (serial * {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list}) list
    15     lhs: typ list, rhs: (string * typ list) list}) list
       
    16   val dest: T ->
       
    17    {restricts: ((string * typ list) * string) list,
       
    18     reducts: ((string * typ list) * (string * typ list) list) list}
    15   val empty: T
    19   val empty: T
    16   val merge: Pretty.pp -> T * T -> T
    20   val merge: Pretty.pp -> T * T -> T
    17   val define: Pretty.pp -> Consts.T ->
    21   val define: Pretty.pp -> Consts.T ->
    18     bool -> bool -> string -> string -> string * typ -> (string * typ) list -> T -> T
    22     bool -> bool -> string -> string -> string * typ -> (string * typ) list -> T -> T
    19 end
    23 end
    20 
    24 
    21 structure Defs: DEFS =
    25 structure Defs: DEFS =
    22 struct
    26 struct
    23 
    27 
    24 (* consts with type arguments *)
    28 
    25 
    29 (* type arguments *)
    26 fun print_const pp (c, args) =
    30 
       
    31 type args = typ list;
       
    32 
       
    33 fun pretty_const pp (c, args) =
    27   let
    34   let
    28     val prt_args =
    35     val prt_args =
    29       if null args then []
    36       if null args then []
    30       else [Pretty.brk 1, Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)];
    37       else [Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)];
    31   in Pretty.string_of (Pretty.block (Pretty.str c :: prt_args)) end;
    38   in Pretty.block (Pretty.str c :: prt_args) end;
    32 
    39 
    33 
    40 fun disjoint_args (Ts, Us) =
    34 (* source specs *)
    41   not (Type.could_unifys (Ts, Us)) orelse
    35 
    42     ((Type.raw_unifys (Ts, map (Logic.incr_tvar (maxidx_of_typs Ts + 1)) Us) Vartab.empty; false)
    36 type spec = {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list};
    43       handle Type.TUNIFY => true);
    37 
    44 
    38 fun disjoint_types T U =
    45 fun match_args (Ts, Us) =
    39   (Type.raw_unify (T, Logic.incr_tvar (maxidx_of_typ T + 1) U) Vartab.empty; false)
    46   Option.map Envir.typ_subst_TVars
    40     handle Type.TUNIFY => true;
    47     (SOME (Type.raw_matches (Ts, Us) Vartab.empty) handle Type.TYPE_MATCH => NONE);
    41 
       
    42 fun disjoint_specs c (i, {lhs = T, name = a, ...}: spec) =
       
    43   Inttab.forall (fn (j, {lhs = U, name = b, ...}: spec) =>
       
    44     i = j orelse not (Type.could_unify (T, U)) orelse disjoint_types T U orelse
       
    45       error ("Type clash in specifications " ^ quote a ^ " and " ^ quote b ^
       
    46         " for constant " ^ quote c));
       
    47 
       
    48 
       
    49 (* patterns *)
       
    50 
       
    51 datatype pattern = Unknown | Plain | Overloaded;
       
    52 
       
    53 fun str_of_pattern Overloaded = "overloading"
       
    54   | str_of_pattern _ = "no overloading";
       
    55 
       
    56 fun merge_pattern c (p1, p2) =
       
    57   if p1 = p2 orelse p2 = Unknown then p1
       
    58   else if p1 = Unknown then p2
       
    59   else error ("Inconsistent type patterns for constant " ^ quote c ^ ":\n" ^
       
    60     str_of_pattern p1 ^ " versus " ^ str_of_pattern p2);
       
    61 
       
    62 fun plain_args args =
       
    63   forall Term.is_TVar args andalso not (has_duplicates (op =) args);
       
    64 
       
    65 fun the_pattern _ name (c, [Type (a, args)]) =
       
    66       (Overloaded, if plain_args args then [] else [(a, (args, name))])
       
    67   | the_pattern prt _ (c, args) =
       
    68       if plain_args args then (Plain, [])
       
    69       else error ("Illegal type pattern for constant " ^ prt (c, args));
       
    70 
    48 
    71 
    49 
    72 (* datatype defs *)
    50 (* datatype defs *)
       
    51 
       
    52 type spec = {is_def: bool, module: string, name: string, lhs: args, rhs: (string * args) list};
    73 
    53 
    74 type def =
    54 type def =
    75  {specs: spec Inttab.table,
    55  {specs: spec Inttab.table,
    76   pattern: pattern,
    56   restricts: (args * string) list,
    77   restricts: (string * (typ list * string)) list,
    57   reducts: (args * (string * args) list) list};
    78   reducts: (typ list * (string * typ list) list) list};
    58 
    79 
    59 fun make_def (specs, restricts, reducts) =
    80 fun make_def (specs, pattern, restricts, reducts) =
    60   {specs = specs, restricts = restricts, reducts = reducts}: def;
    81   {specs = specs, pattern = pattern, restricts = restricts, reducts = reducts}: def;
    61 
    82 
    62 fun map_def c f =
    83 fun map_def f ({specs, pattern, restricts, reducts}: def) =
    63   Symtab.default (c, make_def (Inttab.empty, [], [])) #>
    84   make_def (f (specs, pattern, restricts, reducts));
    64   Symtab.map_entry c (fn {specs, restricts, reducts}: def =>
    85 
    65     make_def (f (specs, restricts, reducts)));
    86 fun default_def (pattern, restricts) = make_def (Inttab.empty, pattern, restricts, []);
    66 
    87 
    67 
    88 datatype T = Defs of def Symtab.table;
    68 datatype T = Defs of def Symtab.table;
    89 val empty = Defs Symtab.empty;
       
    90 
    69 
    91 fun lookup_list which (Defs defs) c =
    70 fun lookup_list which (Defs defs) c =
    92   (case Symtab.lookup defs c of
    71   (case Symtab.lookup defs c of
    93     SOME def => which def
    72     SOME def => which def
    94   | NONE => []);
    73   | NONE => []);
    95 
    74 
    96 val specifications_of = lookup_list (Inttab.dest o #specs);
    75 val specifications_of = lookup_list (Inttab.dest o #specs);
    97 val restricts_of = lookup_list #restricts;
    76 val restricts_of = lookup_list #restricts;
    98 val reducts_of = lookup_list #reducts;
    77 val reducts_of = lookup_list #reducts;
    99 
    78 
   100 
    79 fun dest (Defs defs) =
   101 (* normalize defs *)
    80   let
   102 
    81     val restricts = Symtab.fold (fn (c, {restricts, ...}) =>
   103 fun matcher arg =
    82       fold (fn (args, name) => cons ((c, args), name)) restricts) defs [];
   104   Option.map Envir.typ_subst_TVars
    83     val reducts = Symtab.fold (fn (c, {reducts, ...}) =>
   105     (SOME (Type.raw_matches arg Vartab.empty) handle Type.TYPE_MATCH => NONE);
    84       fold (fn (args, deps) => cons ((c, args), deps)) reducts) defs [];
   106 
    85   in {restricts = restricts, reducts = reducts} end;
   107 fun restriction prt defs (c, args) =
    86 
   108   (case args of
    87 val empty = Defs Symtab.empty;
   109     [Type (a, Us)] =>
    88 
   110       (case AList.lookup (op =) (restricts_of defs c) a of
    89 
   111         SOME (Ts, name) =>
    90 (* specifications *)
   112           if is_some (matcher (Ts, Us)) then ()
    91 
   113           else error ("Occurrence of overloaded constant " ^ prt (c, args) ^
    92 fun disjoint_specs c (i, {lhs = Ts, name = a, ...}: spec) =
   114             "\nviolates restriction " ^ prt (c, Ts) ^ "\nimposed by " ^ quote name)
    93   Inttab.forall (fn (j, {lhs = Us, name = b, ...}: spec) =>
   115       | NONE => ())
    94     i = j orelse disjoint_args (Ts, Us) orelse
   116   | _ => ());
    95       error ("Type clash in specifications " ^ quote a ^ " and " ^ quote b ^
   117 
    96         " for constant " ^ quote c));
   118 fun reduction defs deps =
    97 
       
    98 fun join_specs c ({specs = specs1, restricts, reducts}, {specs = specs2, ...}: def) =
       
    99   let
       
   100     val specs' =
       
   101       Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1;
       
   102   in make_def (specs', restricts, reducts) end;
       
   103 
       
   104 fun update_specs c spec = map_def c (fn (specs, restricts, reducts) =>
       
   105   (disjoint_specs c spec specs; (Inttab.update spec specs, restricts, reducts)));
       
   106 
       
   107 
       
   108 (* normalization: reduction and well-formedness check *)
       
   109 
       
   110 local
       
   111 
       
   112 fun reduction reds_of deps =
   119   let
   113   let
   120     fun reduct Us (Ts, rhs) =
   114     fun reduct Us (Ts, rhs) =
   121       (case matcher (Ts, Us) of
   115       (case match_args (Ts, Us) of
   122         NONE => NONE
   116         NONE => NONE
   123       | SOME subst => SOME (map (apsnd (map subst)) rhs));
   117       | SOME subst => SOME (map (apsnd (map subst)) rhs));
   124     fun reducts (d, Us) = get_first (reduct Us) (reducts_of defs d);
   118     fun reducts (d: string, Us) = get_first (reduct Us) (reds_of d);
   125 
   119 
   126     fun add (NONE, dp) = insert (op =) dp
   120     fun add (NONE, dp) = insert (op =) dp
   127       | add (SOME dps, _) = fold (insert (op =)) dps;
   121       | add (SOME dps, _) = fold (insert (op =)) dps;
   128     val deps' = map (`reducts) deps;
   122     val deps' = map (`reducts) deps;
   129   in
   123   in
   130     if forall (is_none o #1) deps' then NONE
   124     if forall (is_none o #1) deps' then NONE
   131     else SOME (fold_rev add deps' [])
   125     else SOME (fold_rev add deps' [])
   132   end;
   126   end;
   133 
   127 
   134 fun normalize prt defs (c, args) deps =
   128 fun reductions reds_of deps =
   135   let
   129   (case reduction reds_of deps of
   136     val reds = reduction defs deps;
   130     SOME deps' => reductions reds_of deps'
   137     val deps' = the_default deps reds;
   131   | NONE => deps);
   138     val _ = List.app (restriction prt defs) ((c, args) :: deps');
   132 
   139     val _ = deps' |> List.app (fn (c', args') =>
   133 fun contained U (Type (_, Ts)) = exists (fn T => T = U orelse contained U T) Ts
   140       if c' = c andalso is_some (matcher (args, args')) then
   134   | contained _ _ = false;
   141         error ("Circular dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (c, args'))
   135 
   142       else ());
   136 fun wellformed pp rests_of (c, args) (d, Us) =
   143   in reds end;
   137   let
   144 
   138     val prt = Pretty.string_of o pretty_const pp;
   145 
   139     fun err s1 s2 =
   146 (* dependencies *)
   140       error (s1 ^ " dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (d, Us) ^ s2);
   147 
   141   in
   148 fun normalize_deps prt defs0 (Defs defs) =
   142     exists (fn U => exists (contained U) args) Us orelse
   149   let
   143     (c <> d andalso exists (member (op =) args) Us) orelse
   150     fun norm const deps = perhaps (normalize prt defs0 const) deps;
   144       (case find_first (fn (Ts, _) => not (disjoint_args (Ts, Us))) (rests_of d) of
   151     fun norm_update (c, {reducts, ...}: def) =
   145         NONE =>
   152       let val reducts' = reducts |> map (fn (args, deps) => (args, norm (c, args) deps)) in
   146           c <> d orelse is_none (match_args (args, Us)) orelse err "Circular" ""
   153         if reducts = reducts' then I
   147       | SOME (Ts, name) =>
   154         else Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
   148           if c = d then err "Circular" ("\n(via " ^ quote name ^ ")")
   155           (specs, pattern, restricts, reducts')))
   149           else
       
   150             err "Malformed" ("\n(restriction " ^ prt (d, Ts) ^ " from " ^ quote name ^ ")"))
       
   151   end;
       
   152 
       
   153 fun normalize pp rests_of reds_of (c, args) deps =
       
   154   let
       
   155     val deps' = reductions reds_of deps;
       
   156     val _ = forall (wellformed pp rests_of (c, args)) deps';
       
   157   in deps' end;
       
   158 
       
   159 fun normalize_all pp (c, args) deps defs =
       
   160   let
       
   161     val norm = normalize pp (restricts_of (Defs defs));
       
   162     val norm_rule = norm (fn c' => if c' = c then [(args, deps)] else []);
       
   163     val norm_defs = norm (reducts_of (Defs defs));
       
   164     fun norm_update (c', {reducts, ...}: def) =
       
   165       let val reducts' = reducts
       
   166         |> map (fn (args', deps') => (args', norm_defs (c', args') (norm_rule (c', args') deps')))
       
   167       in
       
   168         K (reducts <> reducts') ?
       
   169           map_def c' (fn (specs, restricts, reducts) => (specs, restricts, reducts'))
   156       end;
   170       end;
   157   in Defs (Symtab.fold norm_update defs defs) end;
   171   in Symtab.fold norm_update defs defs end;
   158 
   172 
   159 fun dependencies prt (c, args) pat deps (Defs defs) =
   173 in
   160   let
   174 
   161     val deps' = perhaps (normalize prt (Defs defs) (c, args)) deps;
   175 fun dependencies pp (c, args) restr deps (Defs defs) =
       
   176   let
       
   177     val deps' = normalize pp (restricts_of (Defs defs)) (reducts_of (Defs defs)) (c, args) deps;
   162     val defs' = defs
   178     val defs' = defs
   163       |> Symtab.default (c, default_def pat)
   179       |> map_def c (fn (specs, restricts, reducts) =>
   164       |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
   180         (specs, Library.merge (op =) (restricts, restr), reducts))
   165         let
   181       |> normalize_all pp (c, args) deps';
   166           val pattern' = merge_pattern c (pattern, #1 pat);
   182     val deps'' =
   167           val restricts' = Library.merge (op =) (restricts, #2 pat);
   183       normalize pp (restricts_of (Defs defs')) (reducts_of (Defs defs')) (c, args) deps';
   168           val reducts' = insert (op =) (args, deps') reducts;
   184     val defs'' = defs'
   169         in (specs, pattern', restricts', reducts') end));
   185       |> map_def c (fn (specs, restricts, reducts) =>
   170   in normalize_deps prt (Defs defs') (Defs defs') end;
   186         (specs, restricts, insert (op =) (args, deps'') reducts));
       
   187   in Defs defs'' end;
       
   188 
       
   189 end;
   171 
   190 
   172 
   191 
   173 (* merge *)
   192 (* merge *)
   174 
   193 
   175 fun join_specs c ({specs = specs1, pattern, restricts, reducts}, {specs = specs2, ...}: def) =
       
   176   let
       
   177     val specs' =
       
   178       Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1;
       
   179   in make_def (specs', pattern, restricts, reducts) end;
       
   180 
       
   181 fun merge pp (Defs defs1, Defs defs2) =
   194 fun merge pp (Defs defs1, Defs defs2) =
   182   let
   195   let
   183     fun add_deps (c, args) pat deps defs =
   196     fun add_deps (c, args) restr deps defs =
   184       if AList.defined (op =) (reducts_of defs c) args then defs
   197       if AList.defined (op =) (reducts_of defs c) args then defs
   185       else dependencies (print_const pp) (c, args) pat deps defs;
   198       else dependencies pp (c, args) restr deps defs;
   186     fun add_def (c, {pattern, restricts, reducts, ...}: def) =
   199     fun add_def (c, {restricts, reducts, ...}: def) =
   187       fold (fn (args, deps) => add_deps (c, args) (pattern, restricts) deps) reducts;
   200       fold (fn (args, deps) => add_deps (c, args) restricts deps) reducts;
   188   in Defs (Symtab.join join_specs (defs1, defs2)) |> Symtab.fold add_def defs2 end;
   201   in Defs (Symtab.join join_specs (defs1, defs2)) |> Symtab.fold add_def defs2 end;
   189 
   202 
       
   203 local  (* FIXME *)
       
   204   val merge_aux = merge
       
   205   val acc = Output.time_accumulator "Defs.merge"
       
   206 in fun merge pp = acc (merge_aux pp) end;
       
   207 
   190 
   208 
   191 (* define *)
   209 (* define *)
   192 
   210 
       
   211 fun plain_args args =
       
   212   forall Term.is_TVar args andalso not (has_duplicates (op =) args);
       
   213 
   193 fun define pp consts unchecked is_def module name lhs rhs (Defs defs) =
   214 fun define pp consts unchecked is_def module name lhs rhs (Defs defs) =
   194   let
   215   let
   195     val prt = print_const pp;
       
   196     fun typargs const = (#1 const, Consts.typargs consts const);
   216     fun typargs const = (#1 const, Consts.typargs consts const);
   197 
       
   198     val (c, args) = typargs lhs;
   217     val (c, args) = typargs lhs;
   199     val pat =
   218     val deps = map typargs rhs;
   200       if unchecked then (Unknown, [])
   219     val restr =
   201       else the_pattern prt name (c, args);
   220       if plain_args args orelse
       
   221         (case args of [Type (a, rec_args)] => plain_args rec_args | _ => false)
       
   222       then [] else [(args, name)];
   202     val spec =
   223     val spec =
   203       (serial (), {is_def = is_def, module = module, name = name, lhs = #2 lhs, rhs = rhs});
   224       (serial (), {is_def = is_def, module = module, name = name, lhs = args, rhs = deps});
   204 
   225     val defs' = defs |> update_specs c spec;
   205     val defs' = defs
   226   in Defs defs' |> (if unchecked then I else dependencies pp (c, args) restr deps) end;
   206       |> Symtab.default (c, default_def pat)
   227 
   207       |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) =>
   228 
   208         let
   229 local  (* FIXME *)
   209           val _ = disjoint_specs c spec specs;
   230   val define_aux = define
   210           val specs' = Inttab.update spec specs;
   231   val acc = Output.time_accumulator "Defs.define"
   211         in (specs', pattern, restricts, reducts) end));
   232 in
   212   in Defs defs' |> (if unchecked then I else dependencies prt (c, args) pat (map typargs rhs)) end;
   233   fun define pp consts unchecked is_def module name lhs rhs =
   213 
   234     acc (define_aux pp consts unchecked is_def module name lhs rhs)
   214 end;
   235 end;
       
   236 
       
   237 
       
   238 end;