# HG changeset patch # User berghofe # Date 1248334702 -7200 # Node ID 31cd1ea502aa333151aa83af7cdc82dd6ca9d060 # Parent 228905e0235057c08b5d71d415ab45d8abfec778 Purely functional type inference. diff -r 228905e02350 -r 31cd1ea502aa src/Pure/type_infer.ML --- a/src/Pure/type_infer.ML Wed Jul 22 18:08:45 2009 +0200 +++ b/src/Pure/type_infer.ML Thu Jul 23 09:38:22 2009 +0200 @@ -71,13 +71,12 @@ (** pretyps and preterms **) -(*links to parameters may get instantiated, anything else is rigid*) +(*parameters may get instantiated, anything else is rigid*) datatype pretyp = PType of string * pretyp list | PTFree of string * sort | PTVar of indexname * sort | - Param of sort | - Link of pretyp ref; + Param of int * sort; datatype preterm = PConst of string * pretyp | @@ -91,11 +90,10 @@ (* utils *) -val mk_param = Link o ref o Param; - -fun deref (T as Link (ref (Param _))) = T - | deref (Link (ref T)) = deref T - | deref T = T; +fun deref tye (T as Param (i, S)) = (case Inttab.lookup tye i of + NONE => T + | SOME U => deref tye U) + | deref tye T = T; fun fold_pretyps f (PConst (_, T)) x = f T x | fold_pretyps f (PFree (_, T)) x = f T x @@ -111,46 +109,50 @@ (* pretyp_of *) -fun pretyp_of is_para typ params = +fun pretyp_of is_para typ params_idx = let - val params' = fold_atyps + val (params', idx) = fold_atyps (fn TVar (xi as (x, _), S) => - (fn ps => + (fn ps_idx as (ps, idx) => if is_para xi andalso not (Vartab.defined ps xi) - then Vartab.update (xi, mk_param S) ps else ps) - | _ => I) typ params; + then (Vartab.update (xi, Param (idx, S)) ps, idx + 1) else ps_idx) + | _ => I) typ params_idx; - fun pre_of (TVar (v as (xi, _))) = + fun pre_of (TVar (v as (xi, _))) idx = (case Vartab.lookup params' xi of NONE => PTVar v - | SOME p => p) - | pre_of (TFree ("'_dummy_", S)) = mk_param S - | pre_of (TFree v) = PTFree v - | pre_of (T as Type (a, Ts)) = - if T = dummyT then mk_param [] - else PType (a, map pre_of Ts); - in (pre_of typ, params') end; + | SOME p => p, idx) + | pre_of (TFree ("'_dummy_", S)) idx = (Param (idx, S), idx + 1) + | pre_of (TFree v) idx = (PTFree v, idx) + | pre_of (T as Type (a, Ts)) idx = + if T = dummyT then (Param (idx, []), idx + 1) + else + let val (Ts', idx') = fold_map pre_of Ts idx + in (PType (a, Ts'), idx') end; + + val (ptyp, idx') = pre_of typ idx; + in (ptyp, (params', idx')) end; (* preterm_of *) -fun preterm_of const_type is_para tm (vparams, params) = +fun preterm_of const_type is_para tm (vparams, params, idx) = let - fun add_vparm xi ps = + fun add_vparm xi (ps_idx as (ps, idx)) = if not (Vartab.defined ps xi) then - Vartab.update (xi, mk_param []) ps - else ps; + (Vartab.update (xi, Param (idx, [])) ps, idx + 1) + else ps_idx; - val vparams' = fold_aterms + val (vparams', idx') = fold_aterms (fn Var (_, Type ("_polymorphic_", _)) => I | Var (xi, _) => add_vparm xi | Free (x, _) => add_vparm (x, ~1) | _ => I) - tm vparams; + tm (vparams, idx); fun var_param xi = the (Vartab.lookup vparams' xi); val preT_of = pretyp_of is_para; - fun polyT_of T = fst (pretyp_of (K true) T Vartab.empty); + fun polyT_of T idx = apsnd snd (pretyp_of (K true) T (Vartab.empty, idx)); fun constraint T t ps = if T = dummyT then (t, ps) @@ -158,29 +160,33 @@ let val (T', ps') = preT_of T ps in (Constraint (t, T'), ps') end; - fun pre_of (Const (c, T)) ps = + fun pre_of (Const (c, T)) (ps, idx) = (case const_type c of - SOME U => constraint T (PConst (c, polyT_of U)) ps + SOME U => + let val (pU, idx') = polyT_of U idx + in constraint T (PConst (c, pU)) (ps, idx') end | NONE => raise TYPE ("No such constant: " ^ quote c, [], [])) - | pre_of (Var (xi, Type ("_polymorphic_", [T]))) ps = (PVar (xi, polyT_of T), ps) - | pre_of (Var (xi, T)) ps = constraint T (PVar (xi, var_param xi)) ps - | pre_of (Free (x, T)) ps = constraint T (PFree (x, var_param (x, ~1))) ps - | pre_of (Const ("_type_constraint_", Type ("fun", [T, _])) $ t) ps = - pre_of t ps |-> constraint T - | pre_of (Bound i) ps = (PBound i, ps) - | pre_of (Abs (x, T, t)) ps = + | pre_of (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) = + let val (pT, idx') = polyT_of T idx + in (PVar (xi, pT), (ps, idx')) end + | pre_of (Var (xi, T)) ps_idx = constraint T (PVar (xi, var_param xi)) ps_idx + | pre_of (Free (x, T)) ps_idx = constraint T (PFree (x, var_param (x, ~1))) ps_idx + | pre_of (Const ("_type_constraint_", Type ("fun", [T, _])) $ t) ps_idx = + pre_of t ps_idx |-> constraint T + | pre_of (Bound i) ps_idx = (PBound i, ps_idx) + | pre_of (Abs (x, T, t)) ps_idx = let - val (T', ps') = preT_of T ps; - val (t', ps'') = pre_of t ps'; - in (PAbs (x, T', t'), ps'') end - | pre_of (t $ u) ps = + val (T', ps_idx') = preT_of T ps_idx; + val (t', ps_idx'') = pre_of t ps_idx'; + in (PAbs (x, T', t'), ps_idx'') end + | pre_of (t $ u) ps_idx = let - val (t', ps') = pre_of t ps; - val (u', ps'') = pre_of u ps'; - in (PAppl (t', u'), ps'') end; + val (t', ps_idx') = pre_of t ps_idx; + val (u', ps_idx'') = pre_of u ps_idx'; + in (PAppl (t', u'), ps_idx'') end; - val (tm', params') = pre_of tm params; - in (tm', (vparams', params')) end; + val (tm', (params', idx'')) = pre_of tm (params, idx'); + in (tm', (vparams', params', idx'')) end; @@ -188,62 +194,61 @@ (* add_parms *) -fun add_parmsT (PType (_, Ts)) rs = fold add_parmsT Ts rs - | add_parmsT (Link (r as ref (Param _))) rs = insert (op =) r rs - | add_parmsT (Link (ref T)) rs = add_parmsT T rs - | add_parmsT _ rs = rs; +fun add_parmsT tye T = case deref tye T of + PType (_, Ts) => fold (add_parmsT tye) Ts + | Param (i, _) => insert (op =) i + | _ => I; -val add_parms = fold_pretyps add_parmsT; +fun add_parms tye = fold_pretyps (add_parmsT tye); (* add_names *) -fun add_namesT (PType (_, Ts)) = fold add_namesT Ts - | add_namesT (PTFree (x, _)) = Name.declare x - | add_namesT (PTVar ((x, _), _)) = Name.declare x - | add_namesT (Link (ref T)) = add_namesT T - | add_namesT (Param _) = I; +fun add_namesT tye T = case deref tye T of + PType (_, Ts) => fold (add_namesT tye) Ts + | PTFree (x, _) => Name.declare x + | PTVar ((x, _), _) => Name.declare x + | Param _ => I; -val add_names = fold_pretyps add_namesT; +fun add_names tye = fold_pretyps (add_namesT tye); (* simple_typ/term_of *) -(*deref links, fail on params*) -fun simple_typ_of (PType (a, Ts)) = Type (a, map simple_typ_of Ts) - | simple_typ_of (PTFree v) = TFree v - | simple_typ_of (PTVar v) = TVar v - | simple_typ_of (Link (ref T)) = simple_typ_of T - | simple_typ_of (Param _) = sys_error "simple_typ_of: illegal Param"; +fun simple_typ_of tye f T = case deref tye T of + PType (a, Ts) => Type (a, map (simple_typ_of tye f) Ts) + | PTFree v => TFree v + | PTVar v => TVar v + | Param (i, S) => TVar (f i, S); (*convert types, drop constraints*) -fun simple_term_of (PConst (c, T)) = Const (c, simple_typ_of T) - | simple_term_of (PFree (x, T)) = Free (x, simple_typ_of T) - | simple_term_of (PVar (xi, T)) = Var (xi, simple_typ_of T) - | simple_term_of (PBound i) = Bound i - | simple_term_of (PAbs (x, T, t)) = Abs (x, simple_typ_of T, simple_term_of t) - | simple_term_of (PAppl (t, u)) = simple_term_of t $ simple_term_of u - | simple_term_of (Constraint (t, _)) = simple_term_of t; +fun simple_term_of tye f (PConst (c, T)) = Const (c, simple_typ_of tye f T) + | simple_term_of tye f (PFree (x, T)) = Free (x, simple_typ_of tye f T) + | simple_term_of tye f (PVar (xi, T)) = Var (xi, simple_typ_of tye f T) + | simple_term_of tye f (PBound i) = Bound i + | simple_term_of tye f (PAbs (x, T, t)) = + Abs (x, simple_typ_of tye f T, simple_term_of tye f t) + | simple_term_of tye f (PAppl (t, u)) = + simple_term_of tye f t $ simple_term_of tye f u + | simple_term_of tye f (Constraint (t, _)) = simple_term_of tye f t; -(* typs_terms_of *) (*DESTRUCTIVE*) +(* typs_terms_of *) -fun typs_terms_of used maxidx (Ts, ts) = +fun typs_terms_of tye used maxidx (Ts, ts) = let - fun elim (r as ref (Param S), x) = r := PTVar ((x, maxidx + 1), S) - | elim _ = (); - - val used' = fold add_names ts (fold add_namesT Ts used); - val parms = rev (fold add_parms ts (fold add_parmsT Ts [])); + val used' = fold (add_names tye) ts (fold (add_namesT tye) Ts used); + val parms = rev (fold (add_parms tye) ts (fold (add_parmsT tye) Ts [])); val names = Name.invents used' ("?" ^ Name.aT) (length parms); - val _ = ListPair.app elim (parms, names); - in (map simple_typ_of Ts, map simple_term_of ts) end; + val tab = Inttab.make (parms ~~ names); + fun f i = (the (Inttab.lookup tab i), maxidx + 1); + in (map (simple_typ_of tye f) Ts, map (simple_term_of tye f) ts) end; -(** order-sorted unification of types **) (*DESTRUCTIVE*) +(** order-sorted unification of types **) -exception NO_UNIFIER of string; +exception NO_UNIFIER of string * pretyp Inttab.table; fun unify pp tsig = let @@ -254,49 +259,52 @@ "Variable " ^ x ^ "::" ^ Pretty.string_of_sort pp S' ^ " not of sort " ^ Pretty.string_of_sort pp S; - fun meet (_, []) = () - | meet (Link (r as (ref (Param S'))), S) = - if Type.subsort tsig (S', S) then () - else r := mk_param (Type.inter_sort tsig (S', S)) - | meet (Link (ref T), S) = meet (T, S) - | meet (PType (a, Ts), S) = - ListPair.app meet (Ts, Type.arity_sorts pp tsig a S - handle ERROR msg => raise NO_UNIFIER msg) - | meet (PTFree (x, S'), S) = - if Type.subsort tsig (S', S) then () - else raise NO_UNIFIER (not_of_sort x S' S) - | meet (PTVar (xi, S'), S) = - if Type.subsort tsig (S', S) then () - else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S) - | meet (Param _, _) = sys_error "meet"; + fun meet (_, []) tye_idx = tye_idx + | meet (Param (i, S'), S) (tye_idx as (tye, idx)) = + if Type.subsort tsig (S', S) then tye_idx + else (Inttab.update_new (i, + Param (idx, Type.inter_sort tsig (S', S))) tye, idx + 1) + | meet (PType (a, Ts), S) (tye_idx as (tye, _)) = + meets (Ts, Type.arity_sorts pp tsig a S + handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx + | meet (PTFree (x, S'), S) (tye_idx as (tye, _)) = + if Type.subsort tsig (S', S) then tye_idx + else raise NO_UNIFIER (not_of_sort x S' S, tye) + | meet (PTVar (xi, S'), S) (tye_idx as (tye, _)) = + if Type.subsort tsig (S', S) then tye_idx + else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye) + and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) = + meets (Ts, Ss) (meet (deref tye T, S) tye_idx) + | meets _ tye_idx = tye_idx; (* occurs check and assigment *) - fun occurs_check r (Link (r' as ref T)) = - if r = r' then raise NO_UNIFIER "Occurs check!" - else occurs_check r T - | occurs_check r (PType (_, Ts)) = List.app (occurs_check r) Ts - | occurs_check _ _ = (); + fun occurs_check tye i (Param (i', S)) = + if i = i' then raise NO_UNIFIER ("Occurs check!", tye) + else (case Inttab.lookup tye i' of + NONE => () + | SOME T => occurs_check tye i T) + | occurs_check tye i (PType (_, Ts)) = List.app (occurs_check tye i) Ts + | occurs_check _ _ _ = (); - fun assign r T S = - (case deref T of - T' as Link (r' as ref (Param _)) => - if r = r' then () else (meet (T', S); r := T') - | T' => (occurs_check r T'; meet (T', S); r := T')); + fun assign i (T as Param (i', _)) S (tye_idx as (tye, idx)) = + if i = i' then tye_idx + else meet (T, S) (Inttab.update_new (i, T) tye, idx) + | assign i T S (tye, idx) = + (occurs_check tye i T; meet (T, S) (Inttab.update_new (i, T) tye, idx)); (* unification *) - fun unif (Link (r as ref (Param S)), T) = assign r T S - | unif (T, Link (r as ref (Param S))) = assign r T S - | unif (Link (ref T), U) = unif (T, U) - | unif (T, Link (ref U)) = unif (T, U) - | unif (PType (a, Ts), PType (b, Us)) = + fun unif (T1, T2) (tye_idx as (tye, idx)) = case (deref tye T1, deref tye T2) of + (Param (i, S), T) => assign i T S tye_idx + | (T, Param (i, S)) => assign i T S tye_idx + | (PType (a, Ts), PType (b, Us)) => if a <> b then - raise NO_UNIFIER ("Clash of types " ^ quote a ^ " and " ^ quote b) - else ListPair.app unif (Ts, Us) - | unif (T, U) = if T = U then () else raise NO_UNIFIER ""; + raise NO_UNIFIER ("Clash of types " ^ quote a ^ " and " ^ quote b, tye) + else fold unif (Ts ~~ Us) tye_idx + | (T, U) => if T = U then tye_idx else raise NO_UNIFIER ("", tye); in unif end; @@ -318,7 +326,7 @@ ""]; -(* infer *) (*DESTRUCTIVE*) +(* infer *) fun infer pp tsig = let @@ -327,9 +335,9 @@ fun unif_failed msg = "Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n"; - fun prep_output bs ts Ts = + fun prep_output tye bs ts Ts = let - val (Ts_bTs', ts') = typs_terms_of Name.context ~1 (Ts @ map snd bs, ts); + val (Ts_bTs', ts') = typs_terms_of tye Name.context ~1 (Ts @ map snd bs, ts); val (Ts', Ts'') = chop (length Ts) Ts_bTs'; fun prep t = let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts''))) @@ -339,9 +347,9 @@ fun err_loose i = raise TYPE ("Loose bound variable: B." ^ string_of_int i, [], []); - fun err_appl msg bs t T u U = + fun err_appl msg tye bs t T u U = let - val ([t', u'], [T', U']) = prep_output bs [t, u] [T, U]; + val ([t', u'], [T', U']) = prep_output tye bs [t, u] [T, U]; val why = (case T' of Type ("fun", _) => "Incompatible operand type" @@ -349,9 +357,9 @@ val text = unif_failed msg ^ cat_lines (appl_error pp why t' T' u' U'); in raise TYPE (text, [T', U'], [t', u']) end; - fun err_constraint msg bs t T U = + fun err_constraint msg tye bs t T U = let - val ([t'], [T', U']) = prep_output bs [t] [T, U]; + val ([t'], [T', U']) = prep_output tye bs [t] [T, U]; val text = cat_lines [unif_failed msg, "Cannot meet type constraint:", "", @@ -367,23 +375,28 @@ val unif = unify pp tsig; - fun inf _ (PConst (_, T)) = T - | inf _ (PFree (_, T)) = T - | inf _ (PVar (_, T)) = T - | inf bs (PBound i) = snd (nth bs i handle Subscript => err_loose i) - | inf bs (PAbs (x, T, t)) = PType ("fun", [T, inf ((x, T) :: bs) t]) - | inf bs (PAppl (t, u)) = + fun inf _ (PConst (_, T)) tye_idx = (T, tye_idx) + | inf _ (PFree (_, T)) tye_idx = (T, tye_idx) + | inf _ (PVar (_, T)) tye_idx = (T, tye_idx) + | inf bs (PBound i) tye_idx = + (snd (nth bs i handle Subscript => err_loose i), tye_idx) + | inf bs (PAbs (x, T, t)) tye_idx = + let val (U, tye_idx') = inf ((x, T) :: bs) t tye_idx + in (PType ("fun", [T, U]), tye_idx') end + | inf bs (PAppl (t, u)) tye_idx = let - val T = inf bs t; - val U = inf bs u; - val V = mk_param []; + val (T, tye_idx') = inf bs t tye_idx; + val (U, (tye, idx)) = inf bs u tye_idx'; + val V = Param (idx, []); val U_to_V = PType ("fun", [U, V]); - val _ = unif (U_to_V, T) handle NO_UNIFIER msg => err_appl msg bs t T u U; - in V end - | inf bs (Constraint (t, U)) = - let val T = inf bs t in - unif (T, U) handle NO_UNIFIER msg => err_constraint msg bs t T U; - T + val tye_idx'' = unif (U_to_V, T) (tye, idx + 1) + handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U; + in (V, tye_idx'') end + | inf bs (Constraint (t, U)) tye_idx = + let val (T, tye_idx') = inf bs t tye_idx in + (T, + unif (T, U) tye_idx' + handle NO_UNIFIER (msg, tye) => err_constraint msg tye bs t T U) end; in inf [] end; @@ -402,11 +415,12 @@ (*convert to preterms*) val ts = burrow_types check_typs raw_ts; - val (ts', _) = - fold_map (preterm_of const_type is_param o constrain_vars) ts (Vartab.empty, Vartab.empty); + val (ts', (_, _, idx)) = + fold_map (preterm_of const_type is_param o constrain_vars) ts + (Vartab.empty, Vartab.empty, 0); (*do type inference*) - val _ = List.app (ignore o infer pp tsig) ts'; - in #2 (typs_terms_of used maxidx ([], ts')) end; + val (tye, _) = fold (snd oo infer pp tsig) ts' (Inttab.empty, idx); + in #2 (typs_terms_of tye used maxidx ([], ts')) end; end;