(*  Title:      Pure/type_infer_context.ML
    Author:     Stefan Berghofer and Markus Wenzel, TU Muenchen

Type-inference preparation and standard type inference.
*)

signature TYPE_INFER_CONTEXT =
sig
  (*when false, sorts are stripped from declared constant types before inference*)
  val const_sorts: bool Config.T
  (*replace position-encoding types by fresh inference parameters;
    returns the rewritten terms and the (position, parameter) pairs*)
  val prepare_positions: Proof.context -> term list -> term list * (Position.T * typ) list
  (*introduce inference parameters; returns the next free parameter index and the terms*)
  val prepare: Proof.context -> term list -> int * term list
  (*standard type inference on a list of terms*)
  val infer_types: Proof.context -> term list -> term list
end;

structure Type_Infer_Context: TYPE_INFER_CONTEXT =
struct

(** prepare types/terms: create inference parameters **)

(* constraints *)

(*default true: keep the sorts of declared constant types*)
val const_sorts = Config.bool (Config.declare "const_sorts" (K (Config.Bool true)));

(*declared type constraint of a constant, or NONE if the lookup fails;
  sorts are stripped unless const_sorts is enabled*)
fun const_type ctxt =
  try ((not (Config.get ctxt const_sorts) ? Type.strip_sorts) o
    Consts.the_constraint (Proof_Context.consts_of ctxt));

(*default type of a variable from the context, dummyT if there is none*)
fun var_type ctxt = the_default dummyT o Proof_Context.def_type ctxt;


(* prepare_typ *)

(*Replace dummyT and the special TFree '_dummy_ by fresh inference parameters.
  TVars recognized by Type_Infer.is_param are mapped (once, via the params table)
  to a fresh parameter of the same sort.  The index idx is the next unused
  parameter index and is threaded through; result: (typ', (params', idx')).*)
fun prepare_typ typ params_idx =
  let
    val (params', idx) = fold_atyps
      (fn TVar (xi, S) =>
          (fn ps_idx as (ps, idx) =>
            if Type_Infer.is_param xi andalso not (Vartab.defined ps xi)
            then (Vartab.update (xi, Type_Infer.mk_param idx S) ps, idx + 1) else ps_idx)
        | _ => I) typ params_idx;

    fun prepare (T as Type (a, Ts)) idx =
          if T = dummyT then (Type_Infer.mk_param idx [], idx + 1)
          else
            let val (Ts', idx') = fold_map prepare Ts idx
            in (Type (a, Ts'), idx') end
      | prepare (T as TVar (xi, _)) idx =
          (case Vartab.lookup params' xi of
            NONE => T
          | SOME p => p, idx)
      | prepare (TFree ("'_dummy_", S)) idx = (Type_Infer.mk_param idx S, idx + 1)
      | prepare (T as TFree _) idx = (T, idx);

    val (typ', idx') = prepare typ idx;
  in (typ', (params', idx')) end;


(* prepare_term *)

(*Prepare a term for inference.  Every Free and Var (except Vars carrying the
  marker type _polymorphic_) receives one shared type parameter, recorded in
  vparams; constants are given an instance of their declared type via polyT_of;
  original types are kept as explicit constraints via Type.constraint.
  Fails with an error for undeclared constants.*)
fun prepare_term ctxt tm (vparams, params, idx) =
  let
    fun add_vparm xi (ps_idx as (ps, idx)) =
      if not (Vartab.defined ps xi) then
        (Vartab.update (xi, Type_Infer.mk_param idx []) ps, idx + 1)
      else ps_idx;

    val (vparams', idx') = fold_aterms
      (fn Var (_, Type ("_polymorphic_", _)) => I
        | Var (xi, _) => add_vparm xi
        | Free (x, _) => add_vparm (x, ~1)
        | _ => I)
      tm (vparams, idx);
    fun var_param xi = the (Vartab.lookup vparams' xi);

    (*instance of T with parameters numbered from idx; returns (T', idx')*)
    fun polyT_of T idx =
      apsnd snd (prepare_typ (Type_Infer.paramify_vars T) (Vartab.empty, idx));

    (*attach T as a constraint to t, unless T is dummyT*)
    fun constraint T t ps =
      if T = dummyT then (t, ps)
      else
        let val (T', ps') = prepare_typ T ps
        in (Type.constraint T' t, ps') end;

    fun prepare (Const ("_type_constraint_", T) $ t) ps_idx =
          let
            val A = Type.constraint_type ctxt T;
            val (A', ps_idx') = prepare_typ A ps_idx;
            val (t', ps_idx'') = prepare t ps_idx';
          in (Const ("_type_constraint_", A' --> A') $ t', ps_idx'') end
      | prepare (Const (c, T)) (ps, idx) =
          (case const_type ctxt c of
            SOME U =>
              let val (U', idx') = polyT_of U idx
              in constraint T (Const (c, U')) (ps, idx') end
          | NONE => error ("Undeclared constant: " ^ quote c))
      | prepare (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) =
          let val (T', idx') = polyT_of T idx
          in (Var (xi, T'), (ps, idx')) end
      | prepare (Var (xi, T)) ps_idx = constraint T (Var (xi, var_param xi)) ps_idx
      | prepare (Free (x, T)) ps_idx = constraint T (Free (x, var_param (x, ~1))) ps_idx
      | prepare (Bound i) ps_idx = (Bound i, ps_idx)
      | prepare (Abs (x, T, t)) ps_idx =
          let
            val (T', ps_idx') = prepare_typ T ps_idx;
            val (t', ps_idx'') = prepare t ps_idx';
          in (Abs (x, T', t'), ps_idx'') end
      | prepare (t $ u) ps_idx =
          let
            val (t', ps_idx') = prepare t ps_idx;
            val (u', ps_idx'') = prepare u ps_idx';
          in (t' $ u', ps_idx'') end;

    val (tm', (params', idx'')) = prepare tm (params, idx');
  in (tm', (vparams', params', idx'')) end;


(* prepare_positions *)

(*Replace every type that Term_Position.decode_positionT recognizes as an encoded
  position by a fresh parameter, collecting (position, parameter) pairs.
  Parameter numbering starts above Type_Infer.param_maxidx_of tms.*)
fun prepare_positions ctxt tms =
  let
    fun prepareT (Type (a, Ts)) ps_idx =
          let val (Ts', ps_idx') = fold_map prepareT Ts ps_idx
          in (Type (a, Ts'), ps_idx') end
      | prepareT T (ps, idx) =
          (case Term_Position.decode_positionT T of
            SOME pos =>
              let val U = Type_Infer.mk_param idx []
              in (U, ((pos, U) :: ps, idx + 1)) end
          | NONE => (T, (ps, idx)));

    fun prepare (Const ("_type_constraint_", T)) ps_idx =
          let
            val A = Type.constraint_type ctxt T;
            val (A', ps_idx') = prepareT A ps_idx;
          in (Const ("_type_constraint_", A' --> A'), ps_idx') end
      | prepare (Const (c, T)) ps_idx =
          let val (T', ps_idx') = prepareT T ps_idx
          in (Const (c, T'), ps_idx') end
      | prepare (Free (x, T)) ps_idx =
          let val (T', ps_idx') = prepareT T ps_idx
          in (Free (x, T'), ps_idx') end
      | prepare (Var (xi, T)) ps_idx =
          let val (T', ps_idx') = prepareT T ps_idx
          in (Var (xi, T'), ps_idx') end
      | prepare (t as Bound _) ps_idx = (t, ps_idx)
      | prepare (Abs (x, T, t)) ps_idx =
          let
            val (T', ps_idx') = prepareT T ps_idx;
            val (t', ps_idx'') = prepare t ps_idx';
          in (Abs (x, T', t'), ps_idx'') end
      | prepare (t $ u) ps_idx =
          let
            val (t', ps_idx') = prepare t ps_idx;
            val (u', ps_idx'') = prepare u ps_idx';
          in (t' $ u', ps_idx'') end;

    val idx = Type_Infer.param_maxidx_of tms + 1;
    val (tms', (ps, _)) = fold_map prepare tms ([], idx);
  in (tms', ps) end;



(** order-sorted unification of types **)

(*failure message together with the type environment at the point of failure*)
exception NO_UNIFIER of string * typ Vartab.table;

(*unify ctxt (T1, T2) (tye, idx): extend the environment tye so that T1 and T2
  become equal, respecting sorts; idx is the next free parameter index.
  Raises NO_UNIFIER on failure.*)
fun unify ctxt =
  let
    val thy = Proof_Context.theory_of ctxt;
    val arity_sorts = Type.arity_sorts (Context.pretty ctxt) (Sign.tsig_of thy);


    (* adjust sorts of parameters *)

    fun not_of_sort x S' S =
      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
        Syntax.string_of_sort ctxt S;

    (*force a type into sort S: propagate through type-constructor arities,
      check TFrees, and refine parameter TVars by a fresh parameter of the
      intersected sort*)
    fun meet (_, []) tye_idx = tye_idx
      | meet (Type (a, Ts), S) (tye_idx as (tye, _)) =
          meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
      | meet (TFree (x, S'), S) (tye_idx as (tye, _)) =
          if Sign.subsort thy (S', S) then tye_idx
          else raise NO_UNIFIER (not_of_sort x S' S, tye)
      | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
          if Sign.subsort thy (S', S) then tye_idx
          else if Type_Infer.is_param xi then
            (Vartab.update_new
              (xi, Type_Infer.mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
          else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
    and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
          meets (Ts, Ss) (meet (Type_Infer.deref tye T, S) tye_idx)
      | meets _ tye_idx = tye_idx;


    (* occurs check and assignment *)

    fun occurs_check tye xi (TVar (xi', _)) =
          if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
          else
            (case Vartab.lookup tye xi' of
              NONE => ()
            | SOME T => occurs_check tye xi T)
      | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts
      | occurs_check _ _ _ = ();

    (*bind parameter xi (of sort S) to T; binding a variable to itself is a no-op*)
    fun assign xi (T as TVar (xi', _)) S env =
          if xi = xi' then env
          else env |> meet (T, S) |>> Vartab.update_new (xi, T)
      | assign xi T S (env as (tye, _)) =
          (occurs_check tye xi T; env |> meet (T, S) |>> Vartab.update_new (xi, T));


    (* unification *)

    fun show_tycon (a, Ts) =
      quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));

    fun unif (T1, T2) (env as (tye, _)) =
      (case pairself (`Type_Infer.is_paramT o Type_Infer.deref tye) (T1, T2) of
        ((true, TVar (xi, S)), (_, T)) => assign xi T S env
      | ((_, T), (true, TVar (xi, S))) => assign xi T S env
      | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
          if a <> b then
            raise NO_UNIFIER
              ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
          else fold unif (Ts ~~ Us) env
      | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye));

  in unif end;



(** simple type inference **)

(* infer *)

(*infer ctxt t (tye, idx): compute the type of a prepared term, extending the
  environment; application failures are reported via error*)
fun infer ctxt =
  let
    (* errors *)

    (*finish the types and terms of a failing application for display; bound
      variables in bs are replaced by marked variant frees*)
    fun prep_output tye bs ts Ts =
      let
        val (Ts_bTs', ts') = Type_Infer.finish ctxt tye (Ts @ map snd bs, ts);
        val (Ts', Ts'') = chop (length Ts) Ts_bTs';
        fun prep t =
          let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
          in Term.subst_bounds (map Syntax_Trans.mark_boundT xs, t) end;
      in (map prep ts', Ts') end;

    fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i);

    fun unif_failed msg =
      "Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";

    fun err_appl msg tye bs t T u U =
      let val ([t', u'], [T', U']) = prep_output tye bs [t, u] [T, U]
      in error (unif_failed msg ^ Type.appl_error ctxt t' T' u' U' ^ "\n") end;


    (* main *)

    (*bs: types of enclosing abstractions, innermost first*)
    fun inf _ (Const (_, T)) tye_idx = (T, tye_idx)
      | inf _ (Free (_, T)) tye_idx = (T, tye_idx)
      | inf _ (Var (_, T)) tye_idx = (T, tye_idx)
      | inf bs (Bound i) tye_idx =
          (snd (nth bs i handle General.Subscript => err_loose i), tye_idx)
      | inf bs (Abs (x, T, t)) tye_idx =
          let val (U, tye_idx') = inf ((x, T) :: bs) t tye_idx
          in (T --> U, tye_idx') end
      | inf bs (t $ u) tye_idx =
          let
            val (T, tye_idx') = inf bs t tye_idx;
            val (U, (tye, idx)) = inf bs u tye_idx';
            (*fresh parameter for the result type of the application*)
            val V = Type_Infer.mk_param idx [];
            val tye_idx'' = unify ctxt (U --> V, T) (tye, idx + 1)
              handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U;
          in (V, tye_idx'') end;

  in inf [] end;



(* main interfaces *)

(*check the types of raw_ts, constrain variables by their context types and
  introduce inference parameters; returns (next free index, prepared terms)*)
fun prepare ctxt raw_ts =
  let
    val constrain_vars = Term.map_aterms
      (fn Free (x, T) => Type.constraint T (Free (x, var_type ctxt (x, ~1)))
        | Var (xi, T) => Type.constraint T (Var (xi, var_type ctxt xi))
        | t => t);

    val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
    val idx = Type_Infer.param_maxidx_of ts + 1;
    val (ts', (_, _, idx')) =
      fold_map (prepare_term ctxt o constrain_vars) ts
        (Vartab.empty, Vartab.empty, idx);
  in (idx', ts') end;

(*infer all terms against one shared environment, then finish them*)
fun infer_types ctxt raw_ts =
  let
    val (idx, ts) = prepare ctxt raw_ts;
    val (tye, _) = fold (snd oo infer ctxt) ts (Vartab.empty, idx);
    val (_, ts') = Type_Infer.finish ctxt tye ([], ts);
  in ts' end;

end;