|
1 (* Title: Pure/type_infer.ML |
|
2 ID: $Id$ |
|
3 Author: Stefan Berghofer and Markus Wenzel, TU Muenchen |
|
4 |
|
5 Type inference. |
|
6 *) |
|
7 |
|
8 signature TYPE_INFER = |
|
9 sig |
|
10 val infer_types: (string -> typ option) -> Sorts.classrel -> Sorts.arities |
|
11 -> string list -> bool -> (indexname -> bool) -> term list -> typ list |
|
12 -> term list * typ list * (indexname * typ) list |
|
13 end; |
|
14 |
|
15 structure TypeInfer: TYPE_INFER = |
|
16 struct |
|
17 |
|
18 |
|
19 (** generic utils **) |
|
20 |
|
21 fun seq2 _ [] [] = () |
|
22 | seq2 f (x :: xs) (y :: ys) = (f x y; seq2 f xs ys) |
|
23 | seq2 _ _ _ = raise LIST "seq2"; |
|
24 |
|
25 fun scan _ (xs, []) = (xs, []) |
|
26 | scan f (xs, y :: ys) = |
|
27 let |
|
28 val (xs', y') = f (xs, y); |
|
29 val (xs'', ys') = scan f (xs', ys); |
|
30 in (xs'', y' :: ys') end; |
|
31 |
|
32 |
|
33 |
|
34 (** term encodings **) |
|
35 |
|
36 (* |
|
37 Flavours of term encodings: |
|
38 |
|
39 parse trees (type term): |
|
40 A very complicated structure produced by the syntax module's |
|
41 read functions. Encodes types and sorts as terms; may contain |
|
42 explicit constraints and partial typing information (where |
|
43 dummyT serves as wildcard). |
|
44 |
|
45 Parse trees are INTERNAL! Users should never encounter them, |
|
46 except in parse / print translation functions. |
|
47 |
|
48 raw terms (type term): |
|
49 Provide the user interface to type inferences. They may contain |
|
50 partial type information (dummyT is wildcard) or explicit type |
|
51 constraints (introduced via constrain: term -> typ -> term). |
|
52 |
|
53 The type inference function also lets users specify a certain |
|
54 subset of TVars to be treated as non-rigid inference parameters. |
|
55 |
|
56 preterms (type preterm): |
|
57 The internal representation for type inference. |
|
58 |
|
59 well-typed term (type term): |
|
60 Fully typed lambda terms to be accepted by appropriate |
|
61 certification functions. |
|
62 *) |
|
63 |
|
64 |
|
65 |
|
66 (** pretyps and preterms **) |
|
67 |
|
68 (*links to parameters may get instantiated, anything else is rigid*) |
|
69 datatype pretyp = |
|
70 PType of string * pretyp list | |
|
71 PTFree of string * sort | |
|
72 PTVar of indexname * sort | |
|
73 Param of sort | |
|
74 Link of pretyp ref; |
|
75 |
|
76 datatype preterm = |
|
77 PConst of string * pretyp | |
|
78 PFree of string * pretyp | |
|
79 PVar of indexname * pretyp | |
|
80 PBound of int | |
|
81 PAbs of string * pretyp * preterm | |
|
82 PAppl of preterm * preterm | |
|
83 Constraint of preterm * pretyp; |
|
84 |
|
85 |
|
86 (* utils *) |
|
87 |
|
88 val mk_param = Link o ref o Param; |
|
89 |
|
90 fun deref (T as Link (ref (Param _))) = T |
|
91 | deref (Link (ref T)) = deref T |
|
92 | deref T = T; |
|
93 |
|
94 fun foldl_pretyps f (x, PConst (_, T)) = f (x, T) |
|
95 | foldl_pretyps f (x, PFree (_, T)) = f (x, T) |
|
96 | foldl_pretyps f (x, PVar (_, T)) = f (x, T) |
|
97 | foldl_pretyps _ (x, PBound _) = x |
|
98 | foldl_pretyps f (x, PAbs (_, T, t)) = foldl_pretyps f (f (x, T), t) |
|
99 | foldl_pretyps f (x, PAppl (t, u)) = foldl_pretyps f (foldl_pretyps f (x, t), u) |
|
100 | foldl_pretyps f (x, Constraint (t, T)) = f (foldl_pretyps f (x, t), T); |
|
101 |
|
102 |
|
103 |
|
104 (** raw typs/terms to pretyps/preterms **) |
|
105 |
|
106 (* pretyp(s)_of *) |
|
107 |
|
108 fun pretyp_of is_param (params, typ) = |
|
109 let |
|
110 fun add_parms (ps, TVar (xi as (x, _), S)) = |
|
111 if is_param xi andalso is_none (assoc (ps, xi)) |
|
112 then (xi, mk_param S) :: ps else ps |
|
113 | add_parms (ps, TFree _) = ps |
|
114 | add_parms (ps, Type (_, Ts)) = foldl add_parms (ps, Ts); |
|
115 |
|
116 val params' = add_parms (params, typ); |
|
117 |
|
118 fun pre_of (TVar (v as (xi, _))) = |
|
119 (case assoc (params', xi) of |
|
120 None => PTVar v |
|
121 | Some p => p) |
|
122 | pre_of (TFree v) = PTFree v |
|
123 | pre_of (T as Type (a, Ts)) = |
|
124 if T = dummyT then mk_param [] |
|
125 else PType (a, map pre_of Ts); |
|
126 in (params', pre_of typ) end; |
|
127 |
|
128 fun pretyps_of is_param = scan (pretyp_of is_param); |
|
129 |
|
130 |
|
131 (* preterm(s)_of *) |
|
132 |
|
133 fun preterm_of const_type is_param ((vparams, params), tm) = |
|
134 let |
|
135 fun add_vparm (ps, xi) = |
|
136 if is_none (assoc (ps, xi)) then |
|
137 (xi, mk_param []) :: ps |
|
138 else ps; |
|
139 |
|
140 fun add_vparms (ps, Var (xi, _)) = add_vparm (ps, xi) |
|
141 | add_vparms (ps, Free (x, _)) = add_vparm (ps, (x, ~1)) |
|
142 | add_vparms (ps, Abs (_, _, t)) = add_vparms (ps, t) |
|
143 | add_vparms (ps, t $ u) = add_vparms (add_vparms (ps, t), u) |
|
144 | add_vparms (ps, _) = ps; |
|
145 |
|
146 val vparams' = add_vparms (vparams, tm); |
|
147 fun var_param xi = the (assoc (vparams', xi)); |
|
148 |
|
149 |
|
150 val preT_of = pretyp_of is_param; |
|
151 |
|
152 fun constrain (ps, t) T = |
|
153 if T = dummyT then (ps, t) |
|
154 else |
|
155 let val (ps', T') = preT_of (ps, T) in |
|
156 (ps', Constraint (t, T')) |
|
157 end; |
|
158 |
|
159 fun pre_of (ps, Const (c, T)) = |
|
160 (case const_type c of |
|
161 Some U => constrain (ps, PConst (c, snd (pretyp_of (K true) ([], U)))) T |
|
162 | None => raise_type ("No such constant: " ^ quote c) [] []) |
|
163 | pre_of (ps, Free (x, T)) = constrain (ps, PFree (x, var_param (x, ~1))) T |
|
164 | pre_of (ps, Var (xi, T)) = constrain (ps, PVar (xi, var_param xi)) T |
|
165 | pre_of (ps, Const ("_type_constraint_", T) $ t) = constrain (pre_of (ps, t)) T |
|
166 | pre_of (ps, Bound i) = (ps, PBound i) |
|
167 | pre_of (ps, Abs (x, T, t)) = |
|
168 let |
|
169 val (ps', T') = preT_of (ps, T); |
|
170 val (ps'', t') = pre_of (ps', t); |
|
171 in (ps'', PAbs (x, T', t')) end |
|
172 | pre_of (ps, t $ u) = |
|
173 let |
|
174 val (ps', t') = pre_of (ps, t); |
|
175 val (ps'', u') = pre_of (ps', u); |
|
176 in (ps'', PAppl (t', u')) end; |
|
177 |
|
178 |
|
179 val (params', tm') = pre_of (params, tm); |
|
180 in |
|
181 ((vparams', params'), tm') |
|
182 end; |
|
183 |
|
184 fun preterms_of const_type is_param = scan (preterm_of const_type is_param); |
|
185 |
|
186 |
|
187 |
|
188 (** pretyps/terms to typs/terms **) |
|
189 |
|
190 (* add_parms *) |
|
191 |
|
192 fun add_parmsT (rs, PType (_, Ts)) = foldl add_parmsT (rs, Ts) |
|
193 | add_parmsT (rs, Link (r as ref (Param _))) = r ins rs |
|
194 | add_parmsT (rs, Link (ref T)) = add_parmsT (rs, T) |
|
195 | add_parmsT (rs, _) = rs; |
|
196 |
|
197 val add_parms = foldl_pretyps add_parmsT; |
|
198 |
|
199 |
|
200 (* add_names *) |
|
201 |
|
202 fun add_namesT (xs, PType (_, Ts)) = foldl add_namesT (xs, Ts) |
|
203 | add_namesT (xs, PTFree (x, _)) = x ins xs |
|
204 | add_namesT (xs, PTVar ((x, _), _)) = x ins xs |
|
205 | add_namesT (xs, Link (ref T)) = add_namesT (xs, T) |
|
206 | add_namesT (xs, Param _) = xs; |
|
207 |
|
208 val add_names = foldl_pretyps add_namesT; |
|
209 |
|
210 |
|
211 (* simple_typ/term_of *) |
|
212 |
|
213 (*deref links, fail on params*) |
|
214 fun simple_typ_of (PType (a, Ts)) = Type (a, map simple_typ_of Ts) |
|
215 | simple_typ_of (PTFree v) = TFree v |
|
216 | simple_typ_of (PTVar v) = TVar v |
|
217 | simple_typ_of (Link (ref T)) = simple_typ_of T |
|
218 | simple_typ_of (Param _) = sys_error "simple_typ_of: illegal Param"; |
|
219 |
|
220 (*convert types, drop constraints*) |
|
221 fun simple_term_of (PConst (c, T)) = Const (c, simple_typ_of T) |
|
222 | simple_term_of (PFree (x, T)) = Free (x, simple_typ_of T) |
|
223 | simple_term_of (PVar (xi, T)) = Var (xi, simple_typ_of T) |
|
224 | simple_term_of (PBound i) = Bound i |
|
225 | simple_term_of (PAbs (x, T, t)) = Abs (x, simple_typ_of T, simple_term_of t) |
|
226 | simple_term_of (PAppl (t, u)) = simple_term_of t $ simple_term_of u |
|
227 | simple_term_of (Constraint (t, _)) = simple_term_of t; |
|
228 |
|
229 |
|
230 (* typs_terms_of *) (*DESTRUCTIVE*) |
|
231 |
|
232 fun typs_terms_of used mk_var prfx (Ts, ts) = |
|
233 let |
|
234 fun elim (r as ref (Param S)) x = r := mk_var (x, S) |
|
235 | elim _ _ = (); |
|
236 |
|
237 val used' = foldl add_names (foldl add_namesT (used, Ts), ts); |
|
238 val parms = rev (foldl add_parms (foldl add_parmsT ([], Ts), ts)); |
|
239 val pre_names = replicate (length parms) (prfx ^ "'"); |
|
240 val names = variantlist (pre_names, prfx ^ "'" :: used'); |
|
241 in |
|
242 seq2 elim parms names; |
|
243 (map simple_typ_of Ts, map simple_term_of ts) |
|
244 end; |
|
245 |
|
246 |
|
247 |
|
248 (** order-sorted unification of types **) (*DESTRUCTIVE*) |
|
249 |
|
250 exception NO_UNIFIER of string; |
|
251 |
|
252 |
|
253 fun unify classrel arities = |
|
254 let |
|
255 |
|
256 (* adjust sorts of parameters *) |
|
257 |
|
258 fun not_in_sort x S' S = |
|
259 "Type variable " ^ x ^ "::" ^ Sorts.str_of_sort S' ^ " not in sort " ^ |
|
260 Sorts.str_of_sort S; |
|
261 |
|
262 fun meet _ [] = () |
|
263 | meet (Link (r as (ref (Param S')))) S = |
|
264 if Sorts.sort_le classrel (S', S) then () |
|
265 else r := mk_param (Sorts.inter_sort classrel (S', S)) |
|
266 | meet (Link (ref T)) S = meet T S |
|
267 | meet (PType (a, Ts)) S = |
|
268 seq2 meet Ts (Sorts.mg_domain classrel arities a S |
|
269 handle TYPE (msg, _, _) => raise NO_UNIFIER msg) |
|
270 | meet (PTFree (x, S')) S = |
|
271 if Sorts.sort_le classrel (S', S) then () |
|
272 else raise NO_UNIFIER (not_in_sort x S' S) |
|
273 | meet (PTVar (xi, S')) S = |
|
274 if Sorts.sort_le classrel (S', S) then () |
|
275 else raise NO_UNIFIER (not_in_sort (Syntax.string_of_vname xi) S' S) |
|
276 | meet (Param _) _ = sys_error "meet"; |
|
277 |
|
278 |
|
279 (* occurs check and assigment *) |
|
280 |
|
281 fun occurs_check r (Link (r' as ref T)) = |
|
282 if r = r' then raise NO_UNIFIER "Occurs check!" |
|
283 else occurs_check r T |
|
284 | occurs_check r (PType (_, Ts)) = seq (occurs_check r) Ts |
|
285 | occurs_check _ _ = (); |
|
286 |
|
287 fun assign r T S = |
|
288 (case deref T of |
|
289 T' as Link (r' as ref (Param _)) => |
|
290 if r = r' then () else (r := T'; meet T' S) |
|
291 | T' => (occurs_check r T'; r := T'; meet T' S)); |
|
292 |
|
293 |
|
294 (* unification *) |
|
295 |
|
296 fun unif (Link (r as ref (Param S))) T = assign r T S |
|
297 | unif T (Link (r as ref (Param S))) = assign r T S |
|
298 | unif (Link (ref T)) U = unif T U |
|
299 | unif T (Link (ref U)) = unif T U |
|
300 | unif (PType (a, Ts)) (PType (b, Us)) = |
|
301 if a <> b then raise NO_UNIFIER ("Clash of " ^ a ^ ", " ^ b ^ "!") |
|
302 else seq2 unif Ts Us |
|
303 | unif T U = if T = U then () else raise NO_UNIFIER "Unification failed!"; |
|
304 |
|
305 in unif end; |
|
306 |
|
307 |
|
308 |
|
309 (** type inference **) |
|
310 |
|
311 (* infer *) (*DESTRUCTIVE*) |
|
312 |
|
313 fun infer classrel arities = |
|
314 let |
|
315 val unif = unify classrel arities; |
|
316 |
|
317 fun err msg1 msg2 bs ts Ts = |
|
318 let |
|
319 val (Ts_bTs', ts') = typs_terms_of [] PTFree "??" (Ts @ map snd bs, ts); |
|
320 val len = length Ts; |
|
321 val Ts' = take (len, Ts_bTs'); |
|
322 val xs = map Free (map fst bs ~~ drop (len, Ts_bTs')); |
|
323 val ts'' = map (fn t => subst_bounds (xs, t)) ts'; |
|
324 in |
|
325 raise_type (msg1 ^ " " ^ msg2) Ts' ts'' |
|
326 end; |
|
327 |
|
328 fun inf _ (PConst (_, T)) = T |
|
329 | inf _ (PFree (_, T)) = T |
|
330 | inf _ (PVar (_, T)) = T |
|
331 | inf bs (PBound i) = snd (nth_elem (i, bs) |
|
332 handle LIST _ => raise_type "Loose bound variable" [] [Bound i]) |
|
333 | inf bs (PAbs (x, T, t)) = PType ("fun", [T, inf ((x, T) :: bs) t]) |
|
334 | inf bs (PAppl (t, u)) = |
|
335 let |
|
336 val T = inf bs t; |
|
337 val U = inf bs u; |
|
338 val V = mk_param []; |
|
339 val U_to_V = PType ("fun", [U, V]); |
|
340 val _ = unif U_to_V T handle NO_UNIFIER msg => |
|
341 err msg "Bad function application." bs [PAppl (t, u)] [U_to_V, U]; |
|
342 in V end |
|
343 | inf bs (Constraint (t, U)) = |
|
344 let val T = inf bs t in |
|
345 unif T U handle NO_UNIFIER msg => |
|
346 err msg "Cannot meet type constraint." bs [t] [T, U]; |
|
347 T |
|
348 end; |
|
349 |
|
350 in inf [] end; |
|
351 |
|
352 |
|
353 (* infer_types *) |
|
354 |
|
355 fun infer_types const_type classrel arities used freeze is_param ts Ts = |
|
356 let |
|
357 (*convert to preterms/typs*) |
|
358 val (Tps, Ts') = pretyps_of (K true) ([], Ts); |
|
359 val ((vps, ps), ts') = preterms_of const_type is_param (([], Tps), ts); |
|
360 |
|
361 (*run type inference*) |
|
362 val tTs' = ListPair.map Constraint (ts', Ts'); |
|
363 val _ = seq (fn t => (infer classrel arities t; ())) tTs'; |
|
364 |
|
365 (*collect result unifier*) |
|
366 fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); None) |
|
367 | ch_var xi_T = Some xi_T; |
|
368 val env = mapfilter ch_var Tps; |
|
369 |
|
370 (*convert back to terms/typs*) |
|
371 val mk_var = |
|
372 if freeze then PTFree |
|
373 else (fn (x, S) => PTVar ((x, 0), S)); |
|
374 val (final_Ts, final_ts) = typs_terms_of used mk_var "" (Ts', ts'); |
|
375 val final_env = map (apsnd simple_typ_of) env; |
|
376 in |
|
377 (final_ts, final_Ts, final_env) |
|
378 end; |
|
379 |
|
380 end; |