112 let |
112 let |
113 fun add_parms (ps, TVar (xi as (x, _), S)) = |
113 fun add_parms (ps, TVar (xi as (x, _), S)) = |
114 if is_param xi andalso is_none (assoc (ps, xi)) |
114 if is_param xi andalso is_none (assoc (ps, xi)) |
115 then (xi, mk_param S) :: ps else ps |
115 then (xi, mk_param S) :: ps else ps |
116 | add_parms (ps, TFree _) = ps |
116 | add_parms (ps, TFree _) = ps |
117 | add_parms (ps, Type (_, Ts)) = foldl add_parms (ps, Ts); |
117 | add_parms (ps, Type (_, Ts)) = Library.foldl add_parms (ps, Ts); |
118 |
118 |
119 val params' = add_parms (params, typ); |
119 val params' = add_parms (params, typ); |
120 |
120 |
121 fun pre_of (TVar (v as (xi, _))) = |
121 fun pre_of (TVar (v as (xi, _))) = |
122 (case assoc (params', xi) of |
122 (case assoc (params', xi) of |
147 | add_vparms (ps, Abs (_, _, t)) = add_vparms (ps, t) |
147 | add_vparms (ps, Abs (_, _, t)) = add_vparms (ps, t) |
148 | add_vparms (ps, t $ u) = add_vparms (add_vparms (ps, t), u) |
148 | add_vparms (ps, t $ u) = add_vparms (add_vparms (ps, t), u) |
149 | add_vparms (ps, _) = ps; |
149 | add_vparms (ps, _) = ps; |
150 |
150 |
151 val vparams' = add_vparms (vparams, tm); |
151 val vparams' = add_vparms (vparams, tm); |
152 fun var_param xi = the (assoc (vparams', xi)); |
152 fun var_param xi = valOf (assoc (vparams', xi)); |
153 |
153 |
154 |
154 |
155 val preT_of = pretyp_of is_param; |
155 val preT_of = pretyp_of is_param; |
156 |
156 |
157 fun constrain (ps, t) T = |
157 fun constrain (ps, t) T = |
190 |
190 |
191 (** pretyps/terms to typs/terms **) |
191 (** pretyps/terms to typs/terms **) |
192 |
192 |
193 (* add_parms *) |
193 (* add_parms *) |
194 |
194 |
195 fun add_parmsT (rs, PType (_, Ts)) = foldl add_parmsT (rs, Ts) |
195 fun add_parmsT (rs, PType (_, Ts)) = Library.foldl add_parmsT (rs, Ts) |
196 | add_parmsT (rs, Link (r as ref (Param _))) = r ins rs |
196 | add_parmsT (rs, Link (r as ref (Param _))) = r ins rs |
197 | add_parmsT (rs, Link (ref T)) = add_parmsT (rs, T) |
197 | add_parmsT (rs, Link (ref T)) = add_parmsT (rs, T) |
198 | add_parmsT (rs, _) = rs; |
198 | add_parmsT (rs, _) = rs; |
199 |
199 |
200 val add_parms = foldl_pretyps add_parmsT; |
200 val add_parms = foldl_pretyps add_parmsT; |
201 |
201 |
202 |
202 |
203 (* add_names *) |
203 (* add_names *) |
204 |
204 |
205 fun add_namesT (xs, PType (_, Ts)) = foldl add_namesT (xs, Ts) |
205 fun add_namesT (xs, PType (_, Ts)) = Library.foldl add_namesT (xs, Ts) |
206 | add_namesT (xs, PTFree (x, _)) = x ins xs |
206 | add_namesT (xs, PTFree (x, _)) = x ins xs |
207 | add_namesT (xs, PTVar ((x, _), _)) = x ins xs |
207 | add_namesT (xs, PTVar ((x, _), _)) = x ins xs |
208 | add_namesT (xs, Link (ref T)) = add_namesT (xs, T) |
208 | add_namesT (xs, Link (ref T)) = add_namesT (xs, T) |
209 | add_namesT (xs, Param _) = xs; |
209 | add_namesT (xs, Param _) = xs; |
210 |
210 |
235 fun typs_terms_of used mk_var prfx (Ts, ts) = |
235 fun typs_terms_of used mk_var prfx (Ts, ts) = |
236 let |
236 let |
237 fun elim (r as ref (Param S), x) = r := mk_var (x, S) |
237 fun elim (r as ref (Param S), x) = r := mk_var (x, S) |
238 | elim _ = (); |
238 | elim _ = (); |
239 |
239 |
240 val used' = foldl add_names (foldl add_namesT (used, Ts), ts); |
240 val used' = Library.foldl add_names (Library.foldl add_namesT (used, Ts), ts); |
241 val parms = rev (foldl add_parms (foldl add_parmsT ([], Ts), ts)); |
241 val parms = rev (Library.foldl add_parms (Library.foldl add_parmsT ([], Ts), ts)); |
242 val names = Term.invent_names used' (prfx ^ "'a") (length parms); |
242 val names = Term.invent_names used' (prfx ^ "'a") (length parms); |
243 in |
243 in |
244 seq2 elim (parms, names); |
244 seq2 elim (parms, names); |
245 (map simple_typ_of Ts, map simple_term_of ts) |
245 (map simple_typ_of Ts, map simple_term_of ts) |
246 end; |
246 end; |
283 (* occurs check and assigment *) |
283 (* occurs check and assigment *) |
284 |
284 |
285 fun occurs_check r (Link (r' as ref T)) = |
285 fun occurs_check r (Link (r' as ref T)) = |
286 if r = r' then raise NO_UNIFIER "Occurs check!" |
286 if r = r' then raise NO_UNIFIER "Occurs check!" |
287 else occurs_check r T |
287 else occurs_check r T |
288 | occurs_check r (PType (_, Ts)) = seq (occurs_check r) Ts |
288 | occurs_check r (PType (_, Ts)) = List.app (occurs_check r) Ts |
289 | occurs_check _ _ = (); |
289 | occurs_check _ _ = (); |
290 |
290 |
291 fun assign r T S = |
291 fun assign r T S = |
292 (case deref T of |
292 (case deref T of |
293 T' as Link (r' as ref (Param _)) => |
293 T' as Link (r' as ref (Param _)) => |
374 val unif = unify pp classes arities; |
374 val unif = unify pp classes arities; |
375 |
375 |
376 fun inf _ (PConst (_, T)) = T |
376 fun inf _ (PConst (_, T)) = T |
377 | inf _ (PFree (_, T)) = T |
377 | inf _ (PFree (_, T)) = T |
378 | inf _ (PVar (_, T)) = T |
378 | inf _ (PVar (_, T)) = T |
379 | inf bs (PBound i) = snd (nth_elem (i, bs) handle LIST _ => err_loose i) |
379 | inf bs (PBound i) = snd (List.nth (bs, i) handle Subscript => err_loose i) |
380 | inf bs (PAbs (x, T, t)) = PType ("fun", [T, inf ((x, T) :: bs) t]) |
380 | inf bs (PAbs (x, T, t)) = PType ("fun", [T, inf ((x, T) :: bs) t]) |
381 | inf bs (PAppl (t, u)) = |
381 | inf bs (PAppl (t, u)) = |
382 let |
382 let |
383 val T = inf bs t; |
383 val T = inf bs t; |
384 val U = inf bs u; |
384 val U = inf bs u; |
403 val (Tps, Ts') = pretyps_of (K true) ([], Ts); |
403 val (Tps, Ts') = pretyps_of (K true) ([], Ts); |
404 val ((vps, ps), ts') = preterms_of const_type is_param (([], Tps), ts); |
404 val ((vps, ps), ts') = preterms_of const_type is_param (([], Tps), ts); |
405 |
405 |
406 (*run type inference*) |
406 (*run type inference*) |
407 val tTs' = ListPair.map Constraint (ts', Ts'); |
407 val tTs' = ListPair.map Constraint (ts', Ts'); |
408 val _ = seq (fn t => (infer pp classes arities t; ())) tTs'; |
408 val _ = List.app (fn t => (infer pp classes arities t; ())) tTs'; |
409 |
409 |
410 (*collect result unifier*) |
410 (*collect result unifier*) |
411 fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); NONE) |
411 fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); NONE) |
412 | ch_var xi_T = SOME xi_T; |
412 | ch_var xi_T = SOME xi_T; |
413 val env = mapfilter ch_var Tps; |
413 val env = List.mapPartial ch_var Tps; |
414 |
414 |
415 (*convert back to terms/typs*) |
415 (*convert back to terms/typs*) |
416 val mk_var = |
416 val mk_var = |
417 if freeze then PTFree |
417 if freeze then PTFree |
418 else (fn (x, S) => PTVar ((x, 0), S)); |
418 else (fn (x, S) => PTVar ((x, 0), S)); |
470 |
470 |
471 (* decode_types -- transform parse tree into raw term *) |
471 (* decode_types -- transform parse tree into raw term *) |
472 |
472 |
473 fun decode_types tsig is_const def_type def_sort map_const map_type map_sort tm = |
473 fun decode_types tsig is_const def_type def_sort map_const map_type map_sort tm = |
474 let |
474 let |
475 fun get_type xi = if_none (def_type xi) dummyT; |
475 fun get_type xi = getOpt (def_type xi, dummyT); |
476 fun is_free x = is_some (def_type (x, ~1)); |
476 fun is_free x = isSome (def_type (x, ~1)); |
477 val raw_env = Syntax.raw_term_sorts tm; |
477 val raw_env = Syntax.raw_term_sorts tm; |
478 val sort_of = get_sort tsig def_sort map_sort raw_env; |
478 val sort_of = get_sort tsig def_sort map_sort raw_env; |
479 |
479 |
480 val certT = Type.cert_typ tsig o map_type; |
480 val certT = Type.cert_typ tsig o map_type; |
481 fun decodeT t = certT (Syntax.typ_of_term sort_of map_sort t); |
481 fun decodeT t = certT (Syntax.typ_of_term sort_of map_sort t); |
517 fun infer_types pp tsig const_type def_type def_sort |
517 fun infer_types pp tsig const_type def_type def_sort |
518 map_const map_type map_sort used freeze pat_Ts raw_ts = |
518 map_const map_type map_sort used freeze pat_Ts raw_ts = |
519 let |
519 let |
520 val {classes, arities, ...} = Type.rep_tsig tsig; |
520 val {classes, arities, ...} = Type.rep_tsig tsig; |
521 val pat_Ts' = map (Type.cert_typ tsig) pat_Ts; |
521 val pat_Ts' = map (Type.cert_typ tsig) pat_Ts; |
522 val is_const = is_some o const_type; |
522 val is_const = isSome o const_type; |
523 val raw_ts' = |
523 val raw_ts' = |
524 map (decode_types tsig is_const def_type def_sort map_const map_type map_sort) raw_ts; |
524 map (decode_types tsig is_const def_type def_sort map_const map_type map_sort) raw_ts; |
525 val (ts, Ts, unifier) = basic_infer_types pp const_type |
525 val (ts, Ts, unifier) = basic_infer_types pp const_type |
526 classes arities used freeze is_param raw_ts' pat_Ts'; |
526 classes arities used freeze is_param raw_ts' pat_Ts'; |
527 in (ts, unifier) end; |
527 in (ts, unifier) end; |