1 (* Title: Pure/type_infer.ML |
1 (* Title: Pure/type_infer.ML |
2 Author: Stefan Berghofer and Markus Wenzel, TU Muenchen |
2 Author: Stefan Berghofer and Markus Wenzel, TU Muenchen |
3 |
3 |
4 Simple type inference. |
4 Representation of type-inference problems. Simple type inference. |
5 *) |
5 *) |
6 |
6 |
7 signature TYPE_INFER = |
7 signature TYPE_INFER = |
8 sig |
8 sig |
9 val anyT: sort -> typ |
|
10 val is_param: indexname -> bool |
9 val is_param: indexname -> bool |
11 val is_paramT: typ -> bool |
10 val is_paramT: typ -> bool |
12 val param: int -> string * sort -> typ |
11 val param: int -> string * sort -> typ |
|
12 val anyT: sort -> typ |
13 val paramify_vars: typ -> typ |
13 val paramify_vars: typ -> typ |
14 val paramify_dummies: typ -> int -> typ * int |
14 val paramify_dummies: typ -> int -> typ * int |
15 val fixate_params: Proof.context -> term list -> term list |
15 val deref: typ Vartab.table -> typ -> typ |
|
16 val finish: Proof.context -> typ Vartab.table -> typ list * term list -> typ list * term list |
|
17 val fixate: Proof.context -> term list -> term list |
|
18 val prepare: Proof.context -> (string -> typ option) -> (string * int -> typ option) -> |
|
19 term list -> int * term list |
16 val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) -> |
20 val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) -> |
17 term list -> term list |
21 term list -> term list |
18 end; |
22 end; |
19 |
23 |
20 structure Type_Infer: TYPE_INFER = |
24 structure Type_Infer: TYPE_INFER = |
21 struct |
25 struct |
22 |
26 |
23 |
|
24 (** type parameters and constraints **) |
27 (** type parameters and constraints **) |
25 |
|
26 fun anyT S = TFree ("'_dummy_", S); |
|
27 |
|
28 |
28 |
29 (* type inference parameters -- may get instantiated *) |
29 (* type inference parameters -- may get instantiated *) |
30 |
30 |
31 fun is_param (x, _: int) = String.isPrefix "?" x; |
31 fun is_param (x, _: int) = String.isPrefix "?" x; |
32 |
32 |
34 | is_paramT _ = false; |
34 | is_paramT _ = false; |
35 |
35 |
36 fun param i (x, S) = TVar (("?" ^ x, i), S); |
36 fun param i (x, S) = TVar (("?" ^ x, i), S); |
37 |
37 |
38 fun mk_param i S = TVar (("?'a", i), S); |
38 fun mk_param i S = TVar (("?'a", i), S); |
|
39 |
|
40 |
|
41 (* pre-stage parameters *) |
|
42 |
|
43 fun anyT S = TFree ("'_dummy_", S); |
39 |
44 |
40 val paramify_vars = |
45 val paramify_vars = |
41 Same.commit |
46 Same.commit |
42 (Term_Subst.map_atypsT_same |
47 (Term_Subst.map_atypsT_same |
43 (fn TVar ((x, i), S) => (param i (x, S)) | _ => raise Same.SAME)); |
48 (fn TVar ((x, i), S) => (param i (x, S)) | _ => raise Same.SAME)); |
51 | paramify (Type (a, Ts)) maxidx = |
56 | paramify (Type (a, Ts)) maxidx = |
52 let val (Ts', maxidx') = fold_map paramify Ts maxidx |
57 let val (Ts', maxidx') = fold_map paramify Ts maxidx |
53 in (Type (a, Ts'), maxidx') end |
58 in (Type (a, Ts'), maxidx') end |
54 | paramify T maxidx = (T, maxidx); |
59 | paramify T maxidx = (T, maxidx); |
55 in paramify end; |
60 in paramify end; |
56 |
|
57 fun fixate_params ctxt ts = |
|
58 let |
|
59 fun subst_param (xi, S) (inst, used) = |
|
60 if is_param xi then |
|
61 let |
|
62 val [a] = Name.invents used Name.aT 1; |
|
63 val used' = Name.declare a used; |
|
64 in (((xi, S), TFree (a, S)) :: inst, used') end |
|
65 else (inst, used); |
|
66 val used = (fold o fold_types) Term.declare_typ_names ts (Variable.names_of ctxt); |
|
67 val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], used); |
|
68 in (map o map_types) (Term_Subst.instantiateT inst) ts end; |
|
69 |
61 |
70 |
62 |
71 |
63 |
72 (** prepare types/terms: create inference parameters **) |
64 (** prepare types/terms: create inference parameters **) |
73 |
65 |
154 val (tm', (params', idx'')) = prepare tm (params, idx'); |
146 val (tm', (params', idx'')) = prepare tm (params, idx'); |
155 in (tm', (vparams', params', idx'')) end; |
147 in (tm', (vparams', params', idx'')) end; |
156 |
148 |
157 |
149 |
158 |
150 |
159 (** finish types/terms: standardize remaining parameters **) |
151 (** results **) |
160 |
152 |
161 (* dereferenced views *) |
153 (* dereferenced views *) |
162 |
154 |
163 fun deref tye (T as TVar (xi, _)) = |
155 fun deref tye (T as TVar (xi, _)) = |
164 (case Vartab.lookup tye xi of |
156 (case Vartab.lookup tye xi of |
177 Type (_, Ts) => fold (add_names tye) Ts |
169 Type (_, Ts) => fold (add_names tye) Ts |
178 | TFree (x, _) => Name.declare x |
170 | TFree (x, _) => Name.declare x |
179 | TVar ((x, i), _) => if is_param (x, i) then I else Name.declare x); |
171 | TVar ((x, i), _) => if is_param (x, i) then I else Name.declare x); |
180 |
172 |
181 |
173 |
182 (* finish *) |
174 (* finish -- standardize remaining parameters *) |
183 |
175 |
184 fun finish ctxt tye (Ts, ts) = |
176 fun finish ctxt tye (Ts, ts) = |
185 let |
177 let |
186 val used = |
178 val used = |
187 (fold o fold_types) (add_names tye) ts (fold (add_names tye) Ts (Variable.names_of ctxt)); |
179 (fold o fold_types) (add_names tye) ts (fold (add_names tye) Ts (Variable.names_of ctxt)); |
196 | U as TVar (xi, S) => |
188 | U as TVar (xi, S) => |
197 (case Vartab.lookup tab xi of |
189 (case Vartab.lookup tab xi of |
198 NONE => U |
190 NONE => U |
199 | SOME a => TVar ((a, 0), S))); |
191 | SOME a => TVar ((a, 0), S))); |
200 in (map finish_typ Ts, map (Type.strip_constraints o Term.map_types finish_typ) ts) end; |
192 in (map finish_typ Ts, map (Type.strip_constraints o Term.map_types finish_typ) ts) end; |
|
193 |
|
194 |
|
195 (* fixate -- introduce fresh type variables *) |
|
196 |
|
197 fun fixate ctxt ts = |
|
198 let |
|
199 fun subst_param (xi, S) (inst, used) = |
|
200 if is_param xi then |
|
201 let |
|
202 val [a] = Name.invents used Name.aT 1; |
|
203 val used' = Name.declare a used; |
|
204 in (((xi, S), TFree (a, S)) :: inst, used') end |
|
205 else (inst, used); |
|
206 val used = (fold o fold_types) Term.declare_typ_names ts (Variable.names_of ctxt); |
|
207 val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], used); |
|
208 in (map o map_types) (Term_Subst.instantiateT inst) ts end; |
201 |
209 |
202 |
210 |
203 |
211 |
204 (** order-sorted unification of types **) |
212 (** order-sorted unification of types **) |
205 |
213 |
321 in (V, tye_idx'') end; |
329 in (V, tye_idx'') end; |
322 |
330 |
323 in inf [] end; |
331 in inf [] end; |
324 |
332 |
325 |
333 |
326 (* infer_types *) |
334 (* main interfaces *) |
327 |
335 |
328 fun infer_types ctxt const_type var_type raw_ts = |
336 fun prepare ctxt const_type var_type raw_ts = |
329 let |
337 let |
330 (*constrain vars*) |
|
331 val get_type = the_default dummyT o var_type; |
338 val get_type = the_default dummyT o var_type; |
332 val constrain_vars = Term.map_aterms |
339 val constrain_vars = Term.map_aterms |
333 (fn Free (x, T) => Type.constraint T (Free (x, get_type (x, ~1))) |
340 (fn Free (x, T) => Type.constraint T (Free (x, get_type (x, ~1))) |
334 | Var (xi, T) => Type.constraint T (Var (xi, get_type xi)) |
341 | Var (xi, T) => Type.constraint T (Var (xi, get_type xi)) |
335 | t => t); |
342 | t => t); |
336 |
343 |
337 (*convert to preterms*) |
|
338 val ts = burrow_types (Syntax.check_typs ctxt) raw_ts; |
344 val ts = burrow_types (Syntax.check_typs ctxt) raw_ts; |
339 val (ts', (_, _, idx)) = |
345 val (ts', (_, _, idx)) = |
340 fold_map (prepare_term const_type o constrain_vars) ts |
346 fold_map (prepare_term const_type o constrain_vars) ts |
341 (Vartab.empty, Vartab.empty, 0); |
347 (Vartab.empty, Vartab.empty, 0); |
342 |
348 in (idx, ts') end; |
343 (*do type inference*) |
349 |
344 val (tye, _) = fold (snd oo infer ctxt) ts' (Vartab.empty, idx); |
350 fun infer_types ctxt const_type var_type raw_ts = |
345 in #2 (finish ctxt tye ([], ts')) end; |
351 let |
|
352 val (idx, ts) = prepare ctxt const_type var_type raw_ts; |
|
353 val (tye, _) = fold (snd oo infer ctxt) ts (Vartab.empty, idx); |
|
354 val (_, ts') = finish ctxt tye ([], ts); |
|
355 in ts' end; |
346 |
356 |
347 end; |
357 end; |