296 fun unif (Link (r as ref (Param S))) T = assign r T S |
297 fun unif (Link (r as ref (Param S))) T = assign r T S |
297 | unif T (Link (r as ref (Param S))) = assign r T S |
298 | unif T (Link (r as ref (Param S))) = assign r T S |
298 | unif (Link (ref T)) U = unif T U |
299 | unif (Link (ref T)) U = unif T U |
299 | unif T (Link (ref U)) = unif T U |
300 | unif T (Link (ref U)) = unif T U |
300 | unif (PType (a, Ts)) (PType (b, Us)) = |
301 | unif (PType (a, Ts)) (PType (b, Us)) = |
301 if a <> b then raise NO_UNIFIER ("Clash of " ^ a ^ ", " ^ b ^ "!") |
302 if a <> b then |
|
303 raise NO_UNIFIER ("Clash of types " ^ quote a ^ " and " ^ quote b ^ ".") |
302 else seq2 unif Ts Us |
304 else seq2 unif Ts Us |
303 | unif T U = if T = U then () else raise NO_UNIFIER "Unification failed!"; |
305 | unif T U = if T = U then () else raise NO_UNIFIER ""; |
304 |
306 |
305 in unif end; |
307 in unif end; |
306 |
308 |
307 |
309 |
308 |
310 |
309 (** type inference **) |
311 (** type inference **) |
310 |
312 |
311 (* infer *) (*DESTRUCTIVE*) |
313 (* infer *) (*DESTRUCTIVE*) |
312 |
314 |
(* infer prt prT classrel arities: build a type-inference function over preterms.
   The result (`inf []`) maps a preterm to its pretyp.  Unification is done by
   `unify classrel arities`, which assigns type-parameter refs in place, so
   inference is destructive.
   prt / prT  - pretty-printers for terms / types; used only to build error text.
   Raises (via raise_type) on a loose bound variable, a failed application, or an
   unsatisfiable type constraint. *)
fun infer prt prT classrel arities =
  let
    (* errors *)

    (* Header line of every unification error; msg is the detail from NO_UNIFIER,
       possibly empty. *)
    fun unif_failed msg =
      "Type unification failed" ^ (if msg = "" then "." else ": " ^ msg) ^ "\n";

    val str_of = Pretty.string_of;

    (* Convert the offending types/terms back into printable form.  The types of
       the bound-variable environment bs are appended to Ts so that they are
       converted together; they are then split off again and used to replace loose
       bounds in the terms by Free variables named after their binders.
       NOTE(review): the PTFree "??" arguments presumably name still-unresolved
       type parameters for display - confirm against typs_terms_of. *)
    fun prep_output bs ts Ts =
      let
        val (Ts_bTs', ts') = typs_terms_of [] PTFree "??" (Ts @ map snd bs, ts);
        val len = length Ts;
        val Ts' = take (len, Ts_bTs');
        val xs = map Free (map fst bs ~~ drop (len, Ts_bTs'));
        val ts'' = map (fn t => subst_bounds (xs, t)) ts';
      in (ts'', Ts') end;

    (* De Bruijn index i has no binder in the current environment. *)
    fun err_loose i =
      raise_type ("Loose bound variable: B." ^ string_of_int i) [] [];

    (* Operator t :: T could not be unified with U_to_V, the function type
       expected for an operand u :: U. *)
    fun err_appl msg bs t T U_to_V u U =
      let
        val ([t', u'], [T', U_to_V', U']) = prep_output bs [t, u] [T, U_to_V, U];
        val text = cat_lines
         [unif_failed msg,
          "Type error in application:",
          "",
          str_of (Pretty.block [Pretty.str "operator: ", Pretty.brk 1, prt t',
            Pretty.str " :: ", prT T']),
          str_of (Pretty.block [Pretty.str "expected type:", Pretty.brk 1, prT U_to_V']),
          "",
          str_of (Pretty.block [Pretty.str "operand: ", Pretty.brk 1, prt u',
            Pretty.str " :: ", prT U']), ""];
      in raise_type text [T', U_to_V', U'] [t', u'] end;

    (* Inferred type T of term t could not be unified with constraint type U. *)
    fun err_constraint msg bs t T U =
      let
        val ([t'], [T', U']) = prep_output bs [t] [T, U];
        val text = cat_lines
         [unif_failed msg,
          "Cannot meet type constraint:",
          "",
          str_of (Pretty.block [Pretty.str "term: ", Pretty.brk 1, prt t',
            Pretty.str " :: ", prT T']),
          str_of (Pretty.block [Pretty.str "expected type: ", Pretty.brk 1, prT U']), ""];
      in raise_type text [T', U'] [t'] end;


    (* main *)

    val unif = unify classrel arities;

    (* inf bs t: type of preterm t, where bs lists (name, type) of the enclosing
       binders, innermost first (PBound i looks up element i). *)
    fun inf _ (PConst (_, T)) = T
      | inf _ (PFree (_, T)) = T
      | inf _ (PVar (_, T)) = T
      | inf bs (PBound i) = snd (nth_elem (i, bs) handle LIST _ => err_loose i)
      | inf bs (PAbs (x, T, t)) = PType ("fun", [T, inf ((x, T) :: bs) t])
      | inf bs (PAppl (t, u)) =
          let
            val T = inf bs t;
            val U = inf bs u;
            (* fresh parameter for the result type; the operator must be U => V *)
            val V = mk_param [];
            val U_to_V = PType ("fun", [U, V]);
            val _ = unif U_to_V T handle NO_UNIFIER msg =>
              err_appl msg bs t T U_to_V u U;
          in V end
      | inf bs (Constraint (t, U)) =
          let val T = inf bs t in
            unif T U handle NO_UNIFIER msg => err_constraint msg bs t T U;
            T
          end;

  in inf [] end;
351 |
389 |
352 |
390 |
353 (* infer_types *) |
391 (* infer_types *) |
354 |
392 |
355 fun infer_types const_type classrel arities used freeze is_param ts Ts = |
393 fun infer_types prt prT const_type classrel arities used freeze is_param ts Ts = |
356 let |
394 let |
357 (*convert to preterms/typs*) |
395 (*convert to preterms/typs*) |
358 val (Tps, Ts') = pretyps_of (K true) ([], Ts); |
396 val (Tps, Ts') = pretyps_of (K true) ([], Ts); |
359 val ((vps, ps), ts') = preterms_of const_type is_param (([], Tps), ts); |
397 val ((vps, ps), ts') = preterms_of const_type is_param (([], Tps), ts); |
360 |
398 |
361 (*run type inference*) |
399 (*run type inference*) |
362 val tTs' = ListPair.map Constraint (ts', Ts'); |
400 val tTs' = ListPair.map Constraint (ts', Ts'); |
363 val _ = seq (fn t => (infer classrel arities t; ())) tTs'; |
401 val _ = seq (fn t => (infer prt prT classrel arities t; ())) tTs'; |
364 |
402 |
365 (*collect result unifier*) |
403 (*collect result unifier*) |
366 fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); None) |
404 fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); None) |
367 | ch_var xi_T = Some xi_T; |
405 | ch_var xi_T = Some xi_T; |
368 val env = mapfilter ch_var Tps; |
406 val env = mapfilter ch_var Tps; |