7 single type argument). |
7 single type argument). |
8 *) |
8 *) |
9 |
9 |
10 signature DEFS = |
10 signature DEFS = |
11 sig |
11 sig |
|
12 val pretty_const: Pretty.pp -> string * typ list -> Pretty.T |
12 type T |
13 type T |
13 val specifications_of: T -> string -> |
14 val specifications_of: T -> string -> (serial * {is_def: bool, module: string, name: string, |
14 (serial * {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list}) list |
15 lhs: typ list, rhs: (string * typ list) list}) list |
|
16 val dest: T -> |
|
17 {restricts: ((string * typ list) * string) list, |
|
18 reducts: ((string * typ list) * (string * typ list) list) list} |
15 val empty: T |
19 val empty: T |
16 val merge: Pretty.pp -> T * T -> T |
20 val merge: Pretty.pp -> T * T -> T |
17 val define: Pretty.pp -> Consts.T -> |
21 val define: Pretty.pp -> Consts.T -> |
18 bool -> bool -> string -> string -> string * typ -> (string * typ) list -> T -> T |
22 bool -> bool -> string -> string -> string * typ -> (string * typ) list -> T -> T |
19 end |
23 end |
20 |
24 |
21 structure Defs: DEFS = |
25 structure Defs: DEFS = |
22 struct |
26 struct |
23 |
27 |
24 (* consts with type arguments *) |
28 |
25 |
29 (* type arguments *) |
26 fun print_const pp (c, args) = |
30 |
|
31 type args = typ list; |
|
32 |
|
33 fun pretty_const pp (c, args) = |
27 let |
34 let |
28 val prt_args = |
35 val prt_args = |
29 if null args then [] |
36 if null args then [] |
30 else [Pretty.brk 1, Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)]; |
37 else [Pretty.list "(" ")" (map (Pretty.typ pp o Type.freeze_type) args)]; |
31 in Pretty.string_of (Pretty.block (Pretty.str c :: prt_args)) end; |
38 in Pretty.block (Pretty.str c :: prt_args) end; |
32 |
39 |
33 |
40 fun disjoint_args (Ts, Us) = |
34 (* source specs *) |
41 not (Type.could_unifys (Ts, Us)) orelse |
35 |
42 ((Type.raw_unifys (Ts, map (Logic.incr_tvar (maxidx_of_typs Ts + 1)) Us) Vartab.empty; false) |
36 type spec = {is_def: bool, module: string, name: string, lhs: typ, rhs: (string * typ) list}; |
43 handle Type.TUNIFY => true); |
37 |
44 |
38 fun disjoint_types T U = |
45 fun match_args (Ts, Us) = |
39 (Type.raw_unify (T, Logic.incr_tvar (maxidx_of_typ T + 1) U) Vartab.empty; false) |
46 Option.map Envir.typ_subst_TVars |
40 handle Type.TUNIFY => true; |
47 (SOME (Type.raw_matches (Ts, Us) Vartab.empty) handle Type.TYPE_MATCH => NONE); |
41 |
|
42 fun disjoint_specs c (i, {lhs = T, name = a, ...}: spec) = |
|
43 Inttab.forall (fn (j, {lhs = U, name = b, ...}: spec) => |
|
44 i = j orelse not (Type.could_unify (T, U)) orelse disjoint_types T U orelse |
|
45 error ("Type clash in specifications " ^ quote a ^ " and " ^ quote b ^ |
|
46 " for constant " ^ quote c)); |
|
47 |
|
48 |
|
49 (* patterns *) |
|
50 |
|
51 datatype pattern = Unknown | Plain | Overloaded; |
|
52 |
|
53 fun str_of_pattern Overloaded = "overloading" |
|
54 | str_of_pattern _ = "no overloading"; |
|
55 |
|
56 fun merge_pattern c (p1, p2) = |
|
57 if p1 = p2 orelse p2 = Unknown then p1 |
|
58 else if p1 = Unknown then p2 |
|
59 else error ("Inconsistent type patterns for constant " ^ quote c ^ ":\n" ^ |
|
60 str_of_pattern p1 ^ " versus " ^ str_of_pattern p2); |
|
61 |
|
62 fun plain_args args = |
|
63 forall Term.is_TVar args andalso not (has_duplicates (op =) args); |
|
64 |
|
65 fun the_pattern _ name (c, [Type (a, args)]) = |
|
66 (Overloaded, if plain_args args then [] else [(a, (args, name))]) |
|
67 | the_pattern prt _ (c, args) = |
|
68 if plain_args args then (Plain, []) |
|
69 else error ("Illegal type pattern for constant " ^ prt (c, args)); |
|
70 |
48 |
71 |
49 |
72 (* datatype defs *) |
50 (* datatype defs *) |
|
51 |
|
52 type spec = {is_def: bool, module: string, name: string, lhs: args, rhs: (string * args) list}; |
73 |
53 |
74 type def = |
54 type def = |
75 {specs: spec Inttab.table, |
55 {specs: spec Inttab.table, |
76 pattern: pattern, |
56 restricts: (args * string) list, |
77 restricts: (string * (typ list * string)) list, |
57 reducts: (args * (string * args) list) list}; |
78 reducts: (typ list * (string * typ list) list) list}; |
58 |
79 |
59 fun make_def (specs, restricts, reducts) = |
80 fun make_def (specs, pattern, restricts, reducts) = |
60 {specs = specs, restricts = restricts, reducts = reducts}: def; |
81 {specs = specs, pattern = pattern, restricts = restricts, reducts = reducts}: def; |
61 |
82 |
62 fun map_def c f = |
83 fun map_def f ({specs, pattern, restricts, reducts}: def) = |
63 Symtab.default (c, make_def (Inttab.empty, [], [])) #> |
84 make_def (f (specs, pattern, restricts, reducts)); |
64 Symtab.map_entry c (fn {specs, restricts, reducts}: def => |
85 |
65 make_def (f (specs, restricts, reducts))); |
86 fun default_def (pattern, restricts) = make_def (Inttab.empty, pattern, restricts, []); |
66 |
87 |
67 |
88 datatype T = Defs of def Symtab.table; |
68 datatype T = Defs of def Symtab.table; |
89 val empty = Defs Symtab.empty; |
|
90 |
69 |
91 fun lookup_list which (Defs defs) c = |
70 fun lookup_list which (Defs defs) c = |
92 (case Symtab.lookup defs c of |
71 (case Symtab.lookup defs c of |
93 SOME def => which def |
72 SOME def => which def |
94 | NONE => []); |
73 | NONE => []); |
95 |
74 |
96 val specifications_of = lookup_list (Inttab.dest o #specs); |
75 val specifications_of = lookup_list (Inttab.dest o #specs); |
97 val restricts_of = lookup_list #restricts; |
76 val restricts_of = lookup_list #restricts; |
98 val reducts_of = lookup_list #reducts; |
77 val reducts_of = lookup_list #reducts; |
99 |
78 |
100 |
79 fun dest (Defs defs) = |
101 (* normalize defs *) |
80 let |
102 |
81 val restricts = Symtab.fold (fn (c, {restricts, ...}) => |
103 fun matcher arg = |
82 fold (fn (args, name) => cons ((c, args), name)) restricts) defs []; |
104 Option.map Envir.typ_subst_TVars |
83 val reducts = Symtab.fold (fn (c, {reducts, ...}) => |
105 (SOME (Type.raw_matches arg Vartab.empty) handle Type.TYPE_MATCH => NONE); |
84 fold (fn (args, deps) => cons ((c, args), deps)) reducts) defs []; |
106 |
85 in {restricts = restricts, reducts = reducts} end; |
107 fun restriction prt defs (c, args) = |
86 |
108 (case args of |
87 val empty = Defs Symtab.empty; |
109 [Type (a, Us)] => |
88 |
110 (case AList.lookup (op =) (restricts_of defs c) a of |
89 |
111 SOME (Ts, name) => |
90 (* specifications *) |
112 if is_some (matcher (Ts, Us)) then () |
91 |
113 else error ("Occurrence of overloaded constant " ^ prt (c, args) ^ |
92 fun disjoint_specs c (i, {lhs = Ts, name = a, ...}: spec) = |
114 "\nviolates restriction " ^ prt (c, Ts) ^ "\nimposed by " ^ quote name) |
93 Inttab.forall (fn (j, {lhs = Us, name = b, ...}: spec) => |
115 | NONE => ()) |
94 i = j orelse disjoint_args (Ts, Us) orelse |
116 | _ => ()); |
95 error ("Type clash in specifications " ^ quote a ^ " and " ^ quote b ^ |
117 |
96 " for constant " ^ quote c)); |
118 fun reduction defs deps = |
97 |
|
98 fun join_specs c ({specs = specs1, restricts, reducts}, {specs = specs2, ...}: def) = |
|
99 let |
|
100 val specs' = |
|
101 Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1; |
|
102 in make_def (specs', restricts, reducts) end; |
|
103 |
|
104 fun update_specs c spec = map_def c (fn (specs, restricts, reducts) => |
|
105 (disjoint_specs c spec specs; (Inttab.update spec specs, restricts, reducts))); |
|
106 |
|
107 |
|
108 (* normalization: reduction and well-formedness check *) |
|
109 |
|
110 local |
|
111 |
|
112 fun reduction reds_of deps = |
119 let |
113 let |
120 fun reduct Us (Ts, rhs) = |
114 fun reduct Us (Ts, rhs) = |
121 (case matcher (Ts, Us) of |
115 (case match_args (Ts, Us) of |
122 NONE => NONE |
116 NONE => NONE |
123 | SOME subst => SOME (map (apsnd (map subst)) rhs)); |
117 | SOME subst => SOME (map (apsnd (map subst)) rhs)); |
124 fun reducts (d, Us) = get_first (reduct Us) (reducts_of defs d); |
118 fun reducts (d: string, Us) = get_first (reduct Us) (reds_of d); |
125 |
119 |
126 fun add (NONE, dp) = insert (op =) dp |
120 fun add (NONE, dp) = insert (op =) dp |
127 | add (SOME dps, _) = fold (insert (op =)) dps; |
121 | add (SOME dps, _) = fold (insert (op =)) dps; |
128 val deps' = map (`reducts) deps; |
122 val deps' = map (`reducts) deps; |
129 in |
123 in |
130 if forall (is_none o #1) deps' then NONE |
124 if forall (is_none o #1) deps' then NONE |
131 else SOME (fold_rev add deps' []) |
125 else SOME (fold_rev add deps' []) |
132 end; |
126 end; |
133 |
127 |
134 fun normalize prt defs (c, args) deps = |
128 fun reductions reds_of deps = |
135 let |
129 (case reduction reds_of deps of |
136 val reds = reduction defs deps; |
130 SOME deps' => reductions reds_of deps' |
137 val deps' = the_default deps reds; |
131 | NONE => deps); |
138 val _ = List.app (restriction prt defs) ((c, args) :: deps'); |
132 |
139 val _ = deps' |> List.app (fn (c', args') => |
133 fun contained U (Type (_, Ts)) = exists (fn T => T = U orelse contained U T) Ts |
140 if c' = c andalso is_some (matcher (args, args')) then |
134 | contained _ _ = false; |
141 error ("Circular dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (c, args')) |
135 |
142 else ()); |
136 fun wellformed pp rests_of (c, args) (d, Us) = |
143 in reds end; |
137 let |
144 |
138 val prt = Pretty.string_of o pretty_const pp; |
145 |
139 fun err s1 s2 = |
146 (* dependencies *) |
140 error (s1 ^ " dependency of constant " ^ prt (c, args) ^ " -> " ^ prt (d, Us) ^ s2); |
147 |
141 in |
148 fun normalize_deps prt defs0 (Defs defs) = |
142 exists (fn U => exists (contained U) args) Us orelse |
149 let |
143 (c <> d andalso exists (member (op =) args) Us) orelse |
150 fun norm const deps = perhaps (normalize prt defs0 const) deps; |
144 (case find_first (fn (Ts, _) => not (disjoint_args (Ts, Us))) (rests_of d) of |
151 fun norm_update (c, {reducts, ...}: def) = |
145 NONE => |
152 let val reducts' = reducts |> map (fn (args, deps) => (args, norm (c, args) deps)) in |
146 c <> d orelse is_none (match_args (args, Us)) orelse err "Circular" "" |
153 if reducts = reducts' then I |
147 | SOME (Ts, name) => |
154 else Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) => |
148 if c = d then err "Circular" ("\n(via " ^ quote name ^ ")") |
155 (specs, pattern, restricts, reducts'))) |
149 else |
|
150 err "Malformed" ("\n(restriction " ^ prt (d, Ts) ^ " from " ^ quote name ^ ")")) |
|
151 end; |
|
152 |
|
153 fun normalize pp rests_of reds_of (c, args) deps = |
|
154 let |
|
155 val deps' = reductions reds_of deps; |
|
156 val _ = forall (wellformed pp rests_of (c, args)) deps'; |
|
157 in deps' end; |
|
158 |
|
159 fun normalize_all pp (c, args) deps defs = |
|
160 let |
|
161 val norm = normalize pp (restricts_of (Defs defs)); |
|
162 val norm_rule = norm (fn c' => if c' = c then [(args, deps)] else []); |
|
163 val norm_defs = norm (reducts_of (Defs defs)); |
|
164 fun norm_update (c', {reducts, ...}: def) = |
|
165 let val reducts' = reducts |
|
166 |> map (fn (args', deps') => (args', norm_defs (c', args') (norm_rule (c', args') deps'))) |
|
167 in |
|
168 K (reducts <> reducts') ? |
|
169 map_def c' (fn (specs, restricts, reducts) => (specs, restricts, reducts')) |
156 end; |
170 end; |
157 in Defs (Symtab.fold norm_update defs defs) end; |
171 in Symtab.fold norm_update defs defs end; |
158 |
172 |
159 fun dependencies prt (c, args) pat deps (Defs defs) = |
173 in |
160 let |
174 |
161 val deps' = perhaps (normalize prt (Defs defs) (c, args)) deps; |
175 fun dependencies pp (c, args) restr deps (Defs defs) = |
|
176 let |
|
177 val deps' = normalize pp (restricts_of (Defs defs)) (reducts_of (Defs defs)) (c, args) deps; |
162 val defs' = defs |
178 val defs' = defs |
163 |> Symtab.default (c, default_def pat) |
179 |> map_def c (fn (specs, restricts, reducts) => |
164 |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) => |
180 (specs, Library.merge (op =) (restricts, restr), reducts)) |
165 let |
181 |> normalize_all pp (c, args) deps'; |
166 val pattern' = merge_pattern c (pattern, #1 pat); |
182 val deps'' = |
167 val restricts' = Library.merge (op =) (restricts, #2 pat); |
183 normalize pp (restricts_of (Defs defs')) (reducts_of (Defs defs')) (c, args) deps'; |
168 val reducts' = insert (op =) (args, deps') reducts; |
184 val defs'' = defs' |
169 in (specs, pattern', restricts', reducts') end)); |
185 |> map_def c (fn (specs, restricts, reducts) => |
170 in normalize_deps prt (Defs defs') (Defs defs') end; |
186 (specs, restricts, insert (op =) (args, deps'') reducts)); |
|
187 in Defs defs'' end; |
|
188 |
|
189 end; |
171 |
190 |
172 |
191 |
173 (* merge *) |
192 (* merge *) |
174 |
193 |
175 fun join_specs c ({specs = specs1, pattern, restricts, reducts}, {specs = specs2, ...}: def) = |
|
176 let |
|
177 val specs' = |
|
178 Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1; |
|
179 in make_def (specs', pattern, restricts, reducts) end; |
|
180 |
|
181 fun merge pp (Defs defs1, Defs defs2) = |
194 fun merge pp (Defs defs1, Defs defs2) = |
182 let |
195 let |
183 fun add_deps (c, args) pat deps defs = |
196 fun add_deps (c, args) restr deps defs = |
184 if AList.defined (op =) (reducts_of defs c) args then defs |
197 if AList.defined (op =) (reducts_of defs c) args then defs |
185 else dependencies (print_const pp) (c, args) pat deps defs; |
198 else dependencies pp (c, args) restr deps defs; |
186 fun add_def (c, {pattern, restricts, reducts, ...}: def) = |
199 fun add_def (c, {restricts, reducts, ...}: def) = |
187 fold (fn (args, deps) => add_deps (c, args) (pattern, restricts) deps) reducts; |
200 fold (fn (args, deps) => add_deps (c, args) restricts deps) reducts; |
188 in Defs (Symtab.join join_specs (defs1, defs2)) |> Symtab.fold add_def defs2 end; |
201 in Defs (Symtab.join join_specs (defs1, defs2)) |> Symtab.fold add_def defs2 end; |
189 |
202 |
|
203 local (* FIXME *) |
|
204 val merge_aux = merge |
|
205 val acc = Output.time_accumulator "Defs.merge" |
|
206 in fun merge pp = acc (merge_aux pp) end; |
|
207 |
190 |
208 |
191 (* define *) |
209 (* define *) |
192 |
210 |
|
211 fun plain_args args = |
|
212 forall Term.is_TVar args andalso not (has_duplicates (op =) args); |
|
213 |
193 fun define pp consts unchecked is_def module name lhs rhs (Defs defs) = |
214 fun define pp consts unchecked is_def module name lhs rhs (Defs defs) = |
194 let |
215 let |
195 val prt = print_const pp; |
|
196 fun typargs const = (#1 const, Consts.typargs consts const); |
216 fun typargs const = (#1 const, Consts.typargs consts const); |
197 |
|
198 val (c, args) = typargs lhs; |
217 val (c, args) = typargs lhs; |
199 val pat = |
218 val deps = map typargs rhs; |
200 if unchecked then (Unknown, []) |
219 val restr = |
201 else the_pattern prt name (c, args); |
220 if plain_args args orelse |
|
221 (case args of [Type (a, rec_args)] => plain_args rec_args | _ => false) |
|
222 then [] else [(args, name)]; |
202 val spec = |
223 val spec = |
203 (serial (), {is_def = is_def, module = module, name = name, lhs = #2 lhs, rhs = rhs}); |
224 (serial (), {is_def = is_def, module = module, name = name, lhs = args, rhs = deps}); |
204 |
225 val defs' = defs |> update_specs c spec; |
205 val defs' = defs |
226 in Defs defs' |> (if unchecked then I else dependencies pp (c, args) restr deps) end; |
206 |> Symtab.default (c, default_def pat) |
227 |
207 |> Symtab.map_entry c (map_def (fn (specs, pattern, restricts, reducts) => |
228 |
208 let |
229 local (* FIXME *) |
209 val _ = disjoint_specs c spec specs; |
230 val define_aux = define |
210 val specs' = Inttab.update spec specs; |
231 val acc = Output.time_accumulator "Defs.define" |
211 in (specs', pattern, restricts, reducts) end)); |
232 in |
212 in Defs defs' |> (if unchecked then I else dependencies prt (c, args) pat (map typargs rhs)) end; |
233 fun define pp consts unchecked is_def module name lhs rhs = |
213 |
234 acc (define_aux pp consts unchecked is_def module name lhs rhs) |
214 end; |
235 end; |
|
236 |
|
237 |
|
238 end; |