|
1 (* Title: Tools/subtyping.ML |
|
2 Author: Dmitriy Traytel, TU Muenchen |
|
3 |
|
4 Coercive subtyping via subtype constraints. |
|
5 *) |
|
6 |
|
(*Signature listing the variance tags and the type inference entry point.*)
signature SUBTYPING =
sig
  datatype variance =
      COVARIANT
    | CONTRAVARIANT
    | INVARIANT
  val infer_types:
    Proof.context ->
    (string -> typ option) ->
    (indexname -> typ option) ->
    term list -> term list
end;
|
13 |
|
14 structure Subtyping = |
|
15 struct |
|
16 |
|
17 |
|
18 |
|
19 (** coercions data **) |
|
20 |
|
(*Variance tag; a list of these is stored with every registered map function.*)
datatype variance =
    COVARIANT
  | CONTRAVARIANT
  | INVARIANT

(*Record of coercion-related data.*)
datatype data = Data of
 {coes: term Symreltab.table,                   (*coercions table*)
  coes_graph: unit Graph.T,                     (*coercions graph*)
  tmaps: (term * variance list) Symtab.table};  (*map functions*)

(*Build a data record from the triple
  (coercions table, coercions graph, map functions).*)
fun make_data (tab, graph, maps) =
  Data {coes = tab, coes_graph = graph, tmaps = maps};
|
30 |
|
(*Generic context slot holding the coercion data; merging combines the three
  components pairwise.*)
structure Data = Generic_Data
(
  type T = data;
  val empty = make_data (Symreltab.empty, Graph.empty, Symtab.empty);
  val extend = I;
  fun merge (Data d1, Data d2) =
    let
      (*coercion terms are compared up to alpha-conversion*)
      val coes = Symreltab.merge (op aconv) (#coes d1, #coes d2);
      val coes_graph = Graph.merge (op =) (#coes_graph d1, #coes_graph d2);
      val tmaps = Symtab.merge (eq_pair (op aconv) (op =)) (#tmaps d1, #tmaps d2);
    in make_data (coes, coes_graph, tmaps) end;
);
|
43 |
|
(*Apply f to the triple (coes, coes_graph, tmaps) of the stored data.*)
fun map_data f =
  let
    fun update (Data {coes, coes_graph, tmaps}) =
      make_data (f (coes, coes_graph, tmaps));
  in Data.map update end;

(*Update the coercions table only.*)
fun map_coes f =
  map_data (fn (tab, graph, maps) => (f tab, graph, maps));

(*Update the coercions graph only.*)
fun map_coes_graph f =
  map_data (fn (tab, graph, maps) => (tab, f graph, maps));

(*Update coercions table and graph together; f works on the pair.*)
fun map_coes_and_graph f =
  map_data (fn (tab, graph, maps) =>
    (case f (tab, graph) of
      (tab', graph') => (tab', graph', maps)));

(*Update the table of map functions only.*)
fun map_tmaps f =
  map_data (fn (tab, graph, maps) => (tab, graph, f maps));
|
64 |
|
(*Unwrap the data record stored in the given generic context.*)
fun rep_data context =
  (case Data.get context of Data args => args);

(*Selectors for the individual components.*)
fun coes_of context = #coes (rep_data context);
fun coes_graph_of context = #coes_graph (rep_data context);
fun tmaps_of context = #tmaps (rep_data context);
|
70 |
|
71 |
|
72 |
|
73 (** utils **) |
|
74 |
|
val is_param = Type_Infer.is_param;
val is_paramT = Type_Infer.is_paramT;
val deref = Type_Infer.deref;

(*type inference parameter with index i and sort S*)
fun mk_param i S = TVar (("?'a", i), S); (* TODO dup? see src/Pure/type_infer.ML *)

(*name of a type constructor without arguments; undefined on all other types*)
fun nameT (Type (s, [])) = s;

(*type constructor without arguments, built from its name*)
fun t_of s = Type (s, []);

(*sort of a type variable (free or schematic), NONE for constructed types*)
fun sort_of T =
  (case T of
    TFree (_, S) => SOME S
  | TVar (_, S) => SOME S
  | _ => NONE);

(*any constructed type*)
fun is_typeT (Type _) = true
  | is_typeT _ = false;

(*constructed type with at least one argument*)
fun is_compT (Type (_, _ :: _)) = true
  | is_compT _ = false;

fun is_freeT (TFree _) = true
  | is_freeT _ = false;

(*schematic type variable that is not an inference parameter*)
fun is_fixedvarT (TVar (xi, _)) = not (is_param xi)
  | is_fixedvarT _ = false;
|
90 |
|
91 |
|
92 (* unification TODO dup? needed for weak unification *) |
|
93 |
|
(*Unification failure: a message together with the type environment at the
  point of failure.*)
exception NO_UNIFIER of string * typ Vartab.table;

(*Unification on pairs (type environment, next parameter index).
  With weak = true the sorts of parameters are not adjusted, and two type
  constructors without arguments are always accepted, even if their names
  differ.*)
fun unify weak ctxt =
  let
    val thy = ProofContext.theory_of ctxt;
    val pp = Syntax.pp ctxt;
    val arity_sorts = Type.arity_sorts pp (Sign.tsig_of thy);


    (* adjust sorts of parameters *)

    fun not_of_sort x S' S =
      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
        Syntax.string_of_sort ctxt S;

    fun meet (_, []) env = env
      | meet (Type (a, Ts), S) (env as (tye, _)) =
          let
            val Ss = arity_sorts a S
              handle ERROR msg => raise NO_UNIFIER (msg, tye);
          in meets (Ts, Ss) env end
      | meet (TFree (x, S'), S) (env as (tye, _)) =
          if Sign.subsort thy (S', S) then env
          else raise NO_UNIFIER (not_of_sort x S' S, tye)
      | meet (TVar (xi, S'), S) (env as (tye, idx)) =
          if Sign.subsort thy (S', S) then env
          else if Type_Infer.is_param xi then
            (*replace the parameter by a fresh one carrying the intersected sort*)
            (Vartab.update_new (xi, mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
          else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
    and meets (T :: Ts, S :: Ss) (env as (tye, _)) =
          meets (Ts, Ss) (meet (deref tye T, S) env)
      | meets _ env = env;

    (*weak unification leaves sorts alone*)
    fun weak_meet TS env = if weak then env else meet TS env;


    (* occurs check and assignment *)

    fun occurs_check tye xi (TVar (xi', _)) =
          if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
          else
            (case Vartab.lookup tye xi' of
              SOME T => occurs_check tye xi T
            | NONE => ())
      | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts
      | occurs_check _ _ _ = ();

    (*adjust sorts, then record xi := T*)
    fun bind xi T S env =
      env |> weak_meet (T, S) |>> Vartab.update_new (xi, T);

    fun assign xi (T as TVar (xi', _)) S env =
          if xi = xi' then env else bind xi T S env
      | assign xi T S (env as (tye, _)) =
          (occurs_check tye xi T; bind xi T S env);


    (* unification *)

    fun show_tycon (a, Ts) =
      quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));

    fun unif (T1, T2) (env as (tye, _)) =
      let
        val T1' = deref tye T1;
        val T2' = deref tye T2;
      in
        (case ((is_paramT T1', T1'), (is_paramT T2', T2')) of
          ((true, TVar (xi, S)), (_, T)) => assign xi T S env
        | ((_, T), (true, TVar (xi, S))) => assign xi T S env
        | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
            if weak andalso null Ts andalso null Us then env
            else if a <> b then
              raise NO_UNIFIER
                ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
            else fold unif (Ts ~~ Us) env
        | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye))
      end;

  in unif end;

fun weak_unify ctxt = unify true ctxt;
fun strong_unify ctxt = unify false ctxt;
|
166 |
|
167 |
|
168 (* Typ_Graph shortcuts *) |
|
169 |
|
val add_edge = Typ_Graph.add_edge_acyclic;

fun get_preds G T = Typ_Graph.all_preds G [T];
fun get_succs G T = Typ_Graph.all_succs G [T];

(*add node T unless adding it fails; in that case G is returned unchanged*)
fun maybe_new_typnode T G = the_default G (try (Typ_Graph.new_node (T, ())) G);
fun maybe_new_typnodes Ts G = fold maybe_new_typnode Ts G;

(*immediate predecessors of the nodes Ts that are not themselves among Ts*)
fun new_imm_preds G Ts =
  Ts
  |> maps (Typ_Graph.imm_preds G)
  |> distinct (op =)
  |> subtract (op =) Ts;

(*immediate successors of the nodes Ts that are not themselves among Ts*)
fun new_imm_succs G Ts =
  Ts
  |> maps (Typ_Graph.imm_succs G)
  |> distinct (op =)
  |> subtract (op =) Ts;


(* Graph shortcuts *)

(*add node s unless adding it fails; in that case G is returned unchanged*)
fun maybe_new_node s G = the_default G (try (Graph.new_node (s, ())) G);
fun maybe_new_nodes ss G = fold maybe_new_node ss G;
|
185 |
|
186 |
|
187 |
|
188 (** error messages **) |
|
189 |
|
(*Prepare terms ts and types Ts for output: finish them w.r.t. the type
  environment tye and replace loose bounds of the terms by marked free
  variables named after the binders bs.*)
fun prep_output ctxt tye bs ts Ts =
  let
    val binder_Ts = map snd bs;
    val (finished_Ts, ts') = Type_Infer.finish ctxt tye (Ts @ binder_Ts, ts);
    (*first the finished Ts, then the finished binder types*)
    val (Ts', binder_Ts') = chop (length Ts) finished_Ts;
    fun prep t =
      let
        val named = rev (map fst bs ~~ binder_Ts');
        val xs = rev (Term.variant_frees t named);
      in Term.subst_bounds (map Syntax.mark_boundT xs, t) end;
  in (map prep ts', Ts') end;
|
198 |
|
(*Report a loose bound variable with index i.*)
fun err_loose i =
  let val index = string_of_int i
  in error ("Loose bound variable: B." ^ index) end;

(*Common header of all inference errors; a non-empty msg is appended after a colon.*)
fun inf_failed msg =
  let val detail = if msg = "" then "" else ": " ^ msg
  in "Subtype inference failed" ^ detail ^ "\n\n" end;
|
203 |
|
(*Fail with an application error for operator t of type T and operand u of
  type U, after preparing terms and types for output.*)
fun err_appl ctxt msg tye bs t T u U =
  let
    val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U];
    val header = inf_failed msg;
    val body = Type.appl_error (Syntax.pp ctxt) t' T' u' U';
  in error (header ^ body ^ "\n") end;

(*Fail for the application stored in an error pack
  (binders, application, argument type, result type, operand type).*)
fun err_subtype ctxt msg tye (bs, t $ u, U, V, U') =
  let val fun_type = U --> V
  in err_appl ctxt msg tye bs t fun_type u U' end;
|
210 |
|
(*Fail because the types Ts, which have to coincide, cannot be unified.
  The types are finished w.r.t. tye before being printed.*)
fun err_list ctxt msg tye Ts =
  let
    val (_, Ts') = prep_output ctxt tye [] [] Ts;
    val text = cat_lines ([inf_failed msg,
      "Cannot unify a list of types that should be the same,",
      (*spelling fixed: was "suptype"*)
      "according to subtype dependencies:",
      (Pretty.string_of (Pretty.list "[" "]" (map (Pretty.typ (Syntax.pp ctxt)) Ts')))]);
  in
    error text
  end;
|
221 |
|
(*Fail because subtype constraints cannot be satisfied.  Every error pack
  (bs, t $ u, U, _, U') stems from an application t $ u and stands for the
  constraint U' <: U, which is printed together with the application.*)
fun err_bound ctxt msg tye packs =
  let
    val pp = Syntax.pp ctxt;
    val (ts, Ts) = fold
      (fn (bs, t $ u, U, _, U') => fn (ts, Ts) =>
        (*the constraint is U' <: U, so U' is printed on the left of "<:"*)
        let val (t', T') = prep_output ctxt tye bs [t, u] [U', U]
        in (t'::ts, T'::Ts) end)
      packs ([], []);
    val text = cat_lines ([inf_failed msg, "Cannot fulfil subtype constraints:"] @
      (map2 (fn [t, u] => fn [T, U] => Pretty.string_of (
        Pretty.block [
          Pretty.typ pp T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2, Pretty.typ pp U,
          Pretty.brk 3, Pretty.str "from function application", Pretty.brk 2,
          Pretty.block [Pretty.term pp t, Pretty.brk 1, Pretty.term pp u]]))
        ts Ts))
  in
    error text
  end;
|
240 |
|
241 |
|
242 |
|
243 (** constraint generation **) |
|
244 |
|
(*Constraint generation.  The result maps a term and a pair (type
  environment, next parameter index) to the type of the term, the updated
  pair and the list of collected constraints.  A constraint ((U', U), pack)
  is generated per application t $ u, where U' is the type of u and U a
  fresh parameter for the argument type of t.*)
fun generate_constraints ctxt =
  let
    fun gen cs _ (Const (_, T)) tye_idx = (T, tye_idx, cs)
      | gen cs _ (Free (_, T)) tye_idx = (T, tye_idx, cs)
      | gen cs _ (Var (_, T)) tye_idx = (T, tye_idx, cs)
      | gen cs bs (Bound i) tye_idx =
          (snd (nth bs i handle Subscript => err_loose i), tye_idx, cs)
      | gen cs bs (Abs (x, T, t)) tye_idx =
          let val (U, tye_idx', cs') = gen cs ((x, T) :: bs) t tye_idx
          in (T --> U, tye_idx', cs') end
      | gen cs bs (t $ u) tye_idx =
          let
            val (T, tye_idx', cs') = gen cs bs t tye_idx;
            val (U', (tye, idx), cs'') = gen cs' bs u tye_idx';
            val U = mk_param idx [];
            val V = mk_param (idx + 1) [];
            (*on failure report the inferred operand type U'; the fresh
              parameter U carries no information about the operand*)
            val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
              handle NO_UNIFIER (msg, tye') => err_appl ctxt msg tye' bs t T u U';
            val error_pack = (bs, t $ u, U, V, U');
          in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
  in
    gen [] []
  end;
|
268 |
|
269 |
|
270 |
|
271 (** constraint resolution **) |
|
272 |
|
(*Raised when no suitable bound can be computed for a parameter.*)
exception BOUND_ERROR of string;

(*Resolve the subtype constraints cs, a list of ((T, U), error_pack) meaning
  T <: U, starting from the pair tye_idx = (type environment, next parameter
  index).  The constraints are simplified, arranged in a graph and solved;
  the result is the final pair.  Failures are reported via the err_*
  functions.*)
fun process_constraints ctxt cs tye_idx =
  let
    val coes_graph = coes_graph_of (Context.Proof ctxt);
    val tmaps = tmaps_of (Context.Proof ctxt);
    val tsig = Sign.tsig_of (ProofContext.theory_of ctxt);
    val pp = Syntax.pp ctxt;
    val arity_sorts = Type.arity_sorts pp tsig;
    val subsort = Type.subsort tsig;

    (*first component: constraints where f holds for at least one side;
      second component: all remaining ones*)
    fun split_cs f =
      List.partition (fn c =>
        (case pairself f (fst c) of
          (false, false) => false
        | _ => true));


    (* check whether constraint simplification will terminate using weak unification *)

    val _ =
      fold (fn (TU, error_pack) => fn env =>
        weak_unify ctxt TU env
          handle NO_UNIFIER (msg, tye) =>
            err_subtype ctxt ("Weak unification of subtype constraints fails:\n" ^ msg)
              tye error_pack) cs tye_idx;


    (* simplify constraints *)

    fun simplify_constraints cs tye_idx =
      let
        (*both sides have the same constructor a: relate the arguments
          according to the variances of the map function for a*)
        fun contract a Ts Us error_pack done todo tye idx =
          let
            val arg_var =
              (case Symtab.lookup tmaps a of
                (*everything is invariant for unknown constructors*)
                NONE => replicate (length Ts) INVARIANT
              | SOME (_, variances) => variances);
            fun new_constraints (COVARIANT, constraint) (cs, env) =
                  (constraint :: cs, env)
              | new_constraints (CONTRAVARIANT, constraint) (cs, env) =
                  (swap constraint :: cs, env)
              | new_constraints (INVARIANT, constraint) (cs, env) =
                  (cs, strong_unify ctxt constraint env
                    handle NO_UNIFIER (msg, tye) => err_subtype ctxt msg tye error_pack);
            val (raw_new, (tye', idx')) =
              fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx));
            val new = map (fn c => (c, error_pack)) raw_new;
            val test_update = is_compT orf is_freeT orf is_fixedvarT;
            (*only if no new constraints arise, finished constraints touched
              by the unification are scheduled again*)
            val (ch, done') =
              if null new then split_cs (test_update o deref tye') done
              else ([], done);
            val todo' = ch @ todo;
          in
            simplify done' (new @ todo') (tye', idx')
          end
        (*xi is definitely a parameter*)
        and expand varleq xi S a Ts error_pack done todo tye idx =
          let
            val n = length Ts;
            val args = map2 mk_param (idx upto idx + n - 1) (arity_sorts a S);
            val fresh = Type (a, args);
            val given = Type (a, Ts);
            val tye' = Vartab.update_new (xi, fresh) tye;
            val (ch, done') = split_cs (is_compT o deref tye') done;
            val todo' = ch @ todo;
            (*varleq: the parameter was on the left-hand side*)
            val new = if varleq then (fresh, given) else (given, fresh);
          in
            simplify done' ((new, error_pack) :: todo') (tye', idx + n)
          end
        (*TU is a pair of a parameter and a free/fixed variable*)
        and eliminate TU error_pack done todo tye idx =
          let
            val [TVar (xi, S)] = filter is_paramT TU;
            val [T] = filter_out is_paramT TU;
            val SOME S' = sort_of T;
            val test_update = if is_freeT T then is_freeT else is_fixedvarT;
            val tye' = Vartab.update_new (xi, T) tye;
            val (ch, done') = split_cs (test_update o deref tye') done;
            val todo' = ch @ todo;
          in
            if subsort (S', S) (*TODO check this*)
            then simplify done' todo' (tye', idx)
            else err_subtype ctxt "Sort mismatch" tye error_pack
          end
        and simplify done [] tye_idx = (done, tye_idx)
          | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
              (case (deref tye T, deref tye U) of
                (Type (a, []), Type (b, [])) =>
                  if a = b then simplify done todo tye_idx
                  else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
                  else err_subtype ctxt (a ^" is not a subtype of " ^ b) (fst tye_idx) error_pack
              | (Type (a, Ts), Type (b, Us)) =>
                  if a<>b then err_subtype ctxt "Different constructors" (fst tye_idx) error_pack
                  else contract a Ts Us error_pack done todo tye idx
              | (TVar (xi, S), Type (a, Ts as (_::_))) =>
                  expand true xi S a Ts error_pack done todo tye idx
              | (Type (a, Ts as (_::_)), TVar (xi, S)) =>
                  expand false xi S a Ts error_pack done todo tye idx
              | (T, U) =>
                  if T = U then simplify done todo tye_idx
                  else if exists (is_freeT orf is_fixedvarT) [T, U] then
                    if exists is_paramT [T, U]
                    then eliminate [T, U] error_pack done todo tye idx
                    else err_subtype ctxt "Not eliminated free/fixed variables"
                      (fst tye_idx) error_pack
                  else simplify (((T, U), error_pack)::done) todo tye_idx);
      in
        simplify [] cs tye_idx
      end;


    (* do simplification *)

    val (cs', tye_idx') = simplify_constraints cs tye_idx;

    (*error packs of the simplified constraints that have T' on the right
      (lower = true) or on the left (lower = false)*)
    fun find_error_pack lower T' =
      cs'
      |> filter (fn ((T, U), _) => if lower then T' = U else T' = T)
      |> map snd;

    (*unify all members of a non-empty list with its head*)
    fun unify_list (all as T :: Ts) env =
      fold (fn U => fn env' => strong_unify ctxt (T, U) env'
        handle NO_UNIFIER (msg, tye) => err_list ctxt msg tye all)
      Ts env;

    (*styps stands either for supertypes or for subtypes of a type T
      in terms of the subtype-relation (excluding T itself)*)
    fun styps super T =
      let val neighbours = if super then Graph.imm_succs else Graph.imm_preds
      in neighbours coes_graph T handle Graph.UNDEF _ => [] end;

    fun minmax sup (T::Ts) =
      let
        fun adjust T U = if sup then (T, U) else (U, T);
        fun extract T [] = T
          | extract T (U::Us) =
              if Graph.is_edge coes_graph (adjust T U) then extract T Us
              else if Graph.is_edge coes_graph (adjust U T) then extract U Us
              else raise BOUND_ERROR "Uncomparable types in type list";
      in
        t_of (extract T Ts)
      end;

    fun ex_styp_of_sort super T styps_and_sorts =
      let
        fun adjust T U = if super then (T, U) else (U, T);
        fun styp_test U Ts = forall
          (fn T => T = U orelse Graph.is_edge coes_graph (adjust U T)) Ts;
        fun fitting Ts S U = Type.of_sort tsig (t_of U, S) andalso styp_test U Ts
      in
        forall (fn (Ts, S) => exists (fitting Ts S) (T :: styps super T)) styps_and_sorts
      end;

    (* computes the tightest possible, correct assignment for 'a::S
       e.g. in the supremum case (sup = true):
                 ------- 'a::S---
                /        /    \  \
               /        /      \  \
           'b::C1   'c::C2 ...  T1 T2 ...

       sorts - list of sorts [C1, C2, ...]
       T::Ts - non-empty list of base types [T1, T2, ...]
    *)
    fun tightest sup S styps_and_sorts (T::Ts) =
      let
        fun restriction T = Type.of_sort tsig (t_of T, S)
          andalso ex_styp_of_sort (not sup) T styps_and_sorts;
        (*admissible candidates reachable from a single base type*)
        fun fitting T = filter restriction (T :: styps sup T);
      in
        (case fold (inter (op =) o fitting) Ts (fitting T) of
          [] => raise BOUND_ERROR ("No " ^ (if sup then "supremum" else "infimum"))
        | [T] => t_of T
        | Ts => minmax sup Ts)
      end;

    fun build_graph G [] tye_idx = (G, tye_idx)
      | build_graph G ((T, U)::cs) tye_idx =
        if T = U then build_graph G cs tye_idx
        else
          let
            val G' = maybe_new_typnodes [T, U] G;
            val (G'', tye_idx') = (add_edge (T, U) G', tye_idx)
              handle Typ_Graph.CYCLES cycles =>
                let
                  (*the types on a cycle are unified*)
                  val (tye, idx) = fold unify_list cycles tye_idx
                in
                  (*all cycles collapse to one node,
                    because all of them share at least the nodes x and y*)
                  collapse (tye, idx) (distinct (op =) (flat cycles)) G
                end;
          in
            build_graph G'' cs tye_idx'
          end
    and collapse (tye, idx) nodes G = (*nodes non-empty list*)
          let
            val T = hd nodes;
            val P = new_imm_preds G nodes;
            val S = new_imm_succs G nodes;
            (*keep the head as representative and reconnect the neighbours to it*)
            val G' = Typ_Graph.del_nodes (tl nodes) G;
          in
            build_graph G' (map (fn x => (x, T)) P @ map (fn x => (T, x)) S) (tye, idx)
          end;

    (*try to assign the tightest bound to the node key, provided it still
      dereferences to a parameter; lower = true uses the predecessors of key*)
    fun assign_bound lower G key (tye_idx as (tye, _)) =
      if is_paramT (deref tye key) then
        let
          val TVar (xi, S) = deref tye key;
          val get_bound = if lower then get_preds else get_succs;
          val raw_bound = get_bound G key;
          val bound = map (deref tye) raw_bound;
          val not_params = filter_out is_paramT bound;
          fun to_fulfil T =
            (case sort_of T of
              NONE => NONE
            | SOME S =>
                SOME (map nameT (filter_out is_paramT (map (deref tye) (get_bound G T))), S));
          val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound);
          val assignment =
            if null bound orelse null not_params then NONE
            else SOME (tightest lower S styps_and_sorts (map nameT not_params)
              handle BOUND_ERROR msg => err_bound ctxt msg tye (find_error_pack lower key));
          fun assigned T = apfst (Vartab.update (xi, T)) tye_idx;
        in
          (case assignment of
            NONE => tye_idx
          | SOME T =>
              if is_paramT T then tye_idx
              else if lower then (*upper bound check*)
                let
                  val other_bound = map (deref tye) (get_succs G key);
                  val s = nameT T;
                in
                  if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
                  then assigned T
                  else err_bound ctxt ("Assigned simple type " ^ s ^
                    " clashes with the upper bound of variable " ^
                    Syntax.string_of_typ ctxt (TVar(xi, S))) tye (find_error_pack (not lower) key)
                end
              else assigned T)
        end
      else tye_idx;

    val assign_lb = assign_bound true;
    val assign_ub = assign_bound false;

    (*assign lower and upper bounds in turn until the list of nodes that are
      still parameters does not change any more*)
    fun assign_alternating ts' ts G tye_idx =
      if ts' = ts then tye_idx
      else
        let
          val tye_idx' =
            tye_idx
            |> fold (assign_lb G) ts
            |> fold (assign_ub G) ts;
          val remaining = filter (is_paramT o deref (fst tye_idx')) ts;
        in
          assign_alternating ts remaining G tye_idx'
        end;

    (*Unify all weakly connected components of the constraint forest,
      that contain only params. These are the only WCCs that contain
      params anyway.*)
    fun unify_params G (tye_idx as (tye, _)) =
      let
        val to_unify =
          Typ_Graph.maximals G
          |> filter (is_paramT o deref tye)
          |> map (fn T => T :: get_preds G T);
      in
        fold unify_list to_unify tye_idx
      end;

    fun solve_constraints G tye_idx =
      unify_params G (assign_alternating [] (Typ_Graph.keys G) G tye_idx);

    val (G, tye_idx'') = build_graph Typ_Graph.empty (map fst cs') tye_idx';
  in
    solve_constraints G tye_idx''
  end;
|
542 |
|
543 |
|
544 |
|
545 (** coercion insertion **) |
|
546 |
|
(*Insert coercions into the terms ts.  Types are dereferenced w.r.t. the
  type environment tye; wherever the type of an operand differs from the
  argument type of its operator, the operand is wrapped into a coercion
  built from the registered coercions and map functions.*)
fun insert_coercions ctxt tye ts =
  let
    (*dereference parameters in all argument positions of a type*)
    fun deep_deref T =
      (case deref tye T of
        Type (a, Ts) => Type (a, map deep_deref Ts)
      | U => U);

    (*coercion term from the first to the second type of the pair*)
    fun gen_coercion ((Type (a, [])), (Type (b, []))) =
          if a = b
          then Abs (Name.uu, Type (a, []), Bound 0)
          else
            (case Symreltab.lookup (coes_of (Context.Proof ctxt)) (a, b) of
              NONE => raise Fail (a ^ " is not a subtype of " ^ b)
            | SOME co => co)
      | gen_coercion ((Type (a, Ts)), (Type (b, Us))) =
          if a <> b
          then raise Fail ("Different constructors: " ^ a ^ " and " ^ b)
          else
            let
              fun inst t Ts =
                Term.subst_vars
                  (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
              fun sub_co (COVARIANT, TU) = gen_coercion TU
                | sub_co (CONTRAVARIANT, TU) = gen_coercion (swap TU)
                (*explicit failure instead of an anonymous Match exception*)
                | sub_co (INVARIANT, _) =
                    raise Fail ("Invariant argument of map function for " ^ a);
              fun ts_of [] = []
                | ts_of (Type ("fun", [x1, x2])::xs) = x1::x2::(ts_of xs)
                (*explicit failure instead of an anonymous Match exception*)
                | ts_of (T :: _) =
                    raise Fail ("Coercion not of function type: " ^ Syntax.string_of_typ ctxt T);
            in
              (case Symtab.lookup (tmaps_of (Context.Proof ctxt)) a of
                NONE => raise Fail ("No map function for " ^ a ^ " known")
              | SOME tmap =>
                  let
                    (*one coercion per argument, directed by its variance*)
                    val used_coes = map sub_co ((snd tmap) ~~ (Ts ~~ Us));
                  in
                    Term.list_comb
                      (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
                  end)
            end
      | gen_coercion (T, U) =
          if Type.could_unify (T, U)
          then Abs (Name.uu, T, Bound 0)
          else raise Fail ("Cannot generate coercion from "
            ^ Syntax.string_of_typ ctxt T ^ " to " ^ Syntax.string_of_typ ctxt U);

    (*result: term with coercions inserted, paired with its type;
      bs holds the types of the enclosing binders*)
    fun insert _ (Const (c, T)) =
          let val T' = deep_deref T;
          in (Const (c, T'), T') end
      | insert _ (Free (x, T)) =
          let val T' = deep_deref T;
          in (Free (x, T'), T') end
      | insert _ (Var (xi, T)) =
          let val T' = deep_deref T;
          in (Var (xi, T'), T') end
      | insert bs (Bound i) =
          let val T = nth bs i handle Subscript =>
            raise TYPE ("Loose bound variable: B." ^ string_of_int i, [], []);
          in (Bound i, T) end
      | insert bs (Abs (x, T, t)) =
          let
            val T' = deep_deref T;
            val (t', T'') = insert (T'::bs) t;
          in
            (Abs (x, T', t'), T' --> T'')
          end
      | insert bs (t $ u) =
          let
            val (t', Type ("fun", [U, T])) = insert bs t;
            val (u', U') = insert bs u;
          in
            if U <> U'
            then (t' $ (gen_coercion (U', U) $ u'), T)
            else (t' $ u', T)
          end
  in
    map (fst o insert []) ts
  end;
|
622 |
|
623 |
|
624 |
|
625 (** assembling the pipeline **) |
|
626 |
|
(*The complete pipeline: prepare the raw terms, generate the constraints of
  all terms, process them, insert coercions and finish the result.*)
fun infer_types ctxt const_type var_type raw_ts =
  let
    val (idx, ts) = Type_Infer.prepare ctxt const_type var_type raw_ts;

    (*constraints of later terms are put in front of the earlier ones*)
    fun collect t (env, acc) =
      (case generate_constraints ctxt t env of
        (_, env', new) => (env', new @ acc));

    val (tye_idx, constraints) = fold collect ts ((Vartab.empty, idx), []);
    val tye = fst (process_constraints ctxt constraints tye_idx);
    val ts' = insert_coercions ctxt tye ts;
  in
    snd (Type_Infer.finish ctxt tye ([], ts'))
  end;
|
642 |
|
643 |
|
644 |
|
645 (** installation **) |
|
646 |
|
(*infer_types instantiated with the constant constraints and default types
  of the given proof context.*)
fun coercion_infer_types ctxt =
  let
    val const_type = try (Consts.the_constraint (ProofContext.consts_of ctxt));
    val var_type = ProofContext.def_type ctxt;
  in
    infer_types ctxt const_type var_type
  end;
|
651 |
|
local

(*Install f through what: the installed function yields NONE if f leaves the
  list unchanged w.r.t. eq, and the new list with the context otherwise.*)
fun add eq what f =
  let
    fun check xs ctxt =
      let val xs' = f ctxt xs
      in if eq_list eq (xs, xs') then NONE else SOME (xs', ctxt) end;
  in
    Context.>> (what check)
  end;

in

val _ = add (op aconv) (Syntax.add_term_check ~100 "coercions") coercion_infer_types;

end;
|
662 |
|
663 |
|
664 (* interface *) |
|
665 |
|
(*Register the term read from map_fun as map function of the type
  constructor found in its type, together with the variance of every
  argument of that constructor.*)
fun add_type_map map_fun context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) (Syntax.read_term ctxt map_fun);

    fun err_str () = "\n\nthe general type signature for a map function is" ^
      "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [x1, ..., xn]" ^
      "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi)";

    (*variance of each argument: compare the (domain, range) of the i-th
      function argument with the i-th pair of constructor arguments*)
    fun gen_arg_var ([], []) = []
      | gen_arg_var ((dom, ran) :: funs, (src, tgt) :: args) =
          let
            val variance =
              if dom = src andalso ran = tgt then COVARIANT
              else if dom = tgt andalso ran = src then CONTRAVARIANT
              else error ("Functions do not apply to arguments correctly:" ^ err_str ());
          in variance :: gen_arg_var (funs, args) end
      | gen_arg_var (_, _) =
          error ("Different numbers of functions and arguments\n" ^ err_str ());

    (* TODO: This function is only needed to introduce the fun type map
      function: "% f g h . g o h o f". There must be a better solution. *)
    fun balanced (Type (_, [])) (Type (_, [])) = true
      | balanced (Type (a, Ts)) (Type (b, Us)) =
          a = b andalso forall I (map2 balanced Ts Us)
      | balanced (TFree _) (TFree _) = true
      | balanced (TVar _) (TVar _) = true
      | balanced _ _ = false;

    (*strip function arguments until the remaining type is C [..] => C [..];
      pairs collects (domain, range) of the stripped function arguments*)
    fun check_map_fun pairs (Type ("fun", [T as Type (C, Ts), U as Type (_, Us)])) =
          if balanced T U then ((pairs, Ts ~~ Us), C)
          else if C = "fun" then check_map_fun (pairs @ [(hd Ts, hd (tl Ts))]) U
          else error ("Not a proper map function:" ^ err_str ())
      | check_map_fun _ _ = error ("Not a proper map function:" ^ err_str ());

    val (fun_and_arg_pairs, C) = check_map_fun [] (fastype_of t);
    val variances = gen_arg_var fun_and_arg_pairs;
  in
    map_tmaps (Symtab.update (C, (t, variances))) context
  end;
|
705 |
|
(*Register the term read from coercion as coercion between two type
  constructors without arguments.  The coercions graph is extended by the
  new edge and its transitive consequences; for every further new edge a
  composed coercion is stored in the coercions table.*)
fun add_coercion coercion context =
  let
    val ctxt = Context.proof_of context;
    val t = singleton (Variable.polymorphic ctxt) (Syntax.read_term ctxt coercion);

    fun err_coercion () = error ("Bad type for coercion " ^
        Syntax.string_of_term ctxt t ^ ":\n" ^
        Syntax.string_of_typ ctxt (fastype_of t));

    (*a case expression is required here: a handler attached to the
      right-hand side of a val binding does not catch the Bind exception of
      a failing pattern*)
    val (T1, T2) =
      (case fastype_of t of
        Type ("fun", [T1, T2]) => (T1, T2)
      | _ => err_coercion ());

    val a =
      (case T1 of
        Type (x, []) => x
      | _ => err_coercion ());

    val b =
      (case T2 of
        Type (x, []) => x
      | _ => err_coercion ());

    fun coercion_data_update (tab, G) =
      let
        val G' = maybe_new_nodes [a, b] G
        (*a cycle means that there is already a path from b to a*)
        val G'' = Graph.add_edge_trans_acyclic (a, b) G'
          handle Graph.CYCLES _ => error (b ^ " is already a subtype of " ^ a ^
            "!\n\nCannot add coercion of type: " ^ a ^ " => " ^ b);
        (*edges of G'' that are not present in G'*)
        val new_edges =
          flat (Graph.dest G'' |> map (fn (x, ys) => ys |> map_filter (fn y =>
            if Graph.is_edge G' (x, y) then NONE else SOME (x, y))));
        val G_and_new = Graph.add_edge (a, b) G';

        (*compose the coercions stored for the steps of a path from a to b*)
        fun complex_coercion tab G (a, b) =
          let
            val path = hd (Graph.irreducible_paths G (a, b))
            val path' = (fst (split_last path)) ~~ tl path
          in Abs (Name.uu, Type (a, []),
              fold (fn t => fn u => t $ u) (map (the o Symreltab.lookup tab) path') (Bound 0))
          end;

        val tab' = fold
          (fn pair => fn tab => Symreltab.update (pair, complex_coercion tab G_and_new pair) tab)
          (filter (fn pair => pair <> (a, b)) new_edges)
          (Symreltab.update ((a, b), t) tab);
      in
        (tab', G'')
      end;
  in
    map_coes_and_graph coercion_data_update context
  end;
|
757 |
|
(*Attributes "coercion" and "map_function": each parses a term and passes it
  to add_coercion resp. add_type_map, leaving the theorem untouched.*)
val _ =
  let
    fun declaration add =
      Scan.lift Parse.term >>
        (fn t => fn (context, thm) => (add t context, thm));
  in
    Context.>> (Context.map_theory
      (Attrib.setup (Binding.name "coercion") (declaration add_coercion)
        "declaration of new coercions" #>
      Attrib.setup (Binding.name "map_function") (declaration add_type_map)
        "declaration of new map functions"))
  end;
|
765 |
|
766 end; |