# HG changeset patch # User wenzelm # Date 1284370555 -7200 # Node ID 27fae73fe7690558e87dee2506a1fa7f611ba448 # Parent 651e5a3e8cfd8773e6cb0a67b6580d74ce87e2ac simplified Type_Infer: eliminated separate datatypes pretyp/preterm -- only assign is_paramT TVars; diff -r 651e5a3e8cfd -r 27fae73fe769 src/Pure/type_infer.ML --- a/src/Pure/type_infer.ML Mon Sep 13 00:10:29 2010 +0200 +++ b/src/Pure/type_infer.ML Mon Sep 13 11:35:55 2010 +0200 @@ -8,6 +8,7 @@ sig val anyT: sort -> typ val is_param: indexname -> bool + val is_paramT: typ -> bool val param: int -> string * sort -> typ val paramify_vars: typ -> typ val paramify_dummies: typ -> int -> typ * int @@ -28,8 +29,14 @@ (* type inference parameters -- may get instantiated *) fun is_param (x, _: int) = String.isPrefix "?" x; + +fun is_paramT (TVar (xi, _)) = is_param xi + | is_paramT _ = false; + fun param i (x, S) = TVar (("?" ^ x, i), S); +fun mk_param i S = TVar (("?'a", i), S); + val paramify_vars = Same.commit (Term_Subst.map_atypsT_same @@ -62,76 +69,42 @@ -(** pretyps and preterms **) - -datatype pretyp = - PType of string * pretyp list | - PTFree of string * sort | - PTVar of indexname * sort | - Param of int * sort; - -datatype preterm = - PConst of string * pretyp | - PFree of string * pretyp | - PVar of indexname * pretyp | - PBound of int | - PAbs of string * pretyp * preterm | - PAppl of preterm * preterm; - - -(* utils *) +(** prepare types/terms: create inference parameters **) -fun deref tye (T as Param (i, S)) = - (case Inttab.lookup tye i of - NONE => T - | SOME U => deref tye U) - | deref tye T = T; +(* prepare_typ *) -fun fold_pretyps f (PConst (_, T)) x = f T x - | fold_pretyps f (PFree (_, T)) x = f T x - | fold_pretyps f (PVar (_, T)) x = f T x - | fold_pretyps _ (PBound _) x = x - | fold_pretyps f (PAbs (_, T, t)) x = fold_pretyps f t (f T x) - | fold_pretyps f (PAppl (t, u)) x = fold_pretyps f u (fold_pretyps f t x); - - - -(** raw typs/terms to pretyps/preterms **) - -(* pretyp_of *) - -fun pretyp_of typ params_idx = +fun prepare_typ typ params_idx = let val (params', idx) = fold_atyps (fn TVar (xi as (x, _), S) => (fn ps_idx as (ps, idx) => if is_param xi andalso not (Vartab.defined ps xi) - then (Vartab.update (xi, Param (idx, S)) ps, idx + 1) else ps_idx) + then (Vartab.update (xi, mk_param idx S) ps, idx + 1) else ps_idx) | _ => I) typ params_idx; - fun pre_of (TVar (v as (xi, _))) idx = + fun prepare (T as Type (a, Ts)) idx = + if T = dummyT then (mk_param idx [], idx + 1) + else + let val (Ts', idx') = fold_map prepare Ts idx + in (Type (a, Ts'), idx') end + | prepare (T as TVar (xi, _)) idx = (case Vartab.lookup params' xi of - NONE => PTVar v + NONE => T | SOME p => p, idx) - | pre_of (TFree ("'_dummy_", S)) idx = (Param (idx, S), idx + 1) - | pre_of (TFree v) idx = (PTFree v, idx) - | pre_of (T as Type (a, Ts)) idx = - if T = dummyT then (Param (idx, []), idx + 1) - else - let val (Ts', idx') = fold_map pre_of Ts idx - in (PType (a, Ts'), idx') end; + | prepare (TFree ("'_dummy_", S)) idx = (mk_param idx S, idx + 1) + | prepare (T as TFree _) idx = (T, idx); - val (ptyp, idx') = pre_of typ idx; - in (ptyp, (params', idx')) end; + val (typ', idx') = prepare typ idx; + in (typ', (params', idx')) end; -(* preterm_of *) +(* prepare_term *) -fun preterm_of const_type tm (vparams, params, idx) = +fun prepare_term const_type tm (vparams, params, idx) = let fun add_vparm xi (ps_idx as (ps, idx)) = if not (Vartab.defined ps xi) then - (Vartab.update (xi, Param (idx, [])) ps, idx + 1) + (Vartab.update (xi, mk_param idx []) ps, idx + 1) else ps_idx; val (vparams', idx') = fold_aterms @@ -142,109 +115,96 @@ tm (vparams, idx); fun var_param xi = the (Vartab.lookup vparams' xi); - fun polyT_of T idx = apsnd snd (pretyp_of (paramify_vars T) (Vartab.empty, idx)); + fun polyT_of T idx = apsnd snd (prepare_typ (paramify_vars T) (Vartab.empty, idx)); fun constraint T t ps = if T = dummyT then (t, ps) else - let val (T', ps') = pretyp_of T ps - in (PAppl (PConst ("_type_constraint_", PType ("fun", [T', T'])), t), ps') end; + let val (T', ps') = prepare_typ T ps + in (Type.constraint T' t, ps') end; - fun pre_of (Const (c, T)) (ps, idx) = + fun prepare (Const ("_type_constraint_", T) $ t) ps_idx = + let + val (T', ps_idx') = prepare_typ T ps_idx; + val (t', ps_idx'') = prepare t ps_idx'; + in (Const ("_type_constraint_", T') $ t', ps_idx'') end + | prepare (Const (c, T)) (ps, idx) = (case const_type c of SOME U => - let val (pU, idx') = polyT_of U idx - in constraint T (PConst (c, pU)) (ps, idx') end + let val (U', idx') = polyT_of U idx + in constraint T (Const (c, U')) (ps, idx') end | NONE => error ("Undeclared constant: " ^ quote c)) - | pre_of (Const ("_type_constraint_", T) $ t) ps_idx = + | prepare (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) = + let val (T', idx') = polyT_of T idx + in (Var (xi, T'), (ps, idx')) end + | prepare (Var (xi, T)) ps_idx = constraint T (Var (xi, var_param xi)) ps_idx + | prepare (Free (x, T)) ps_idx = constraint T (Free (x, var_param (x, ~1))) ps_idx + | prepare (Bound i) ps_idx = (Bound i, ps_idx) + | prepare (Abs (x, T, t)) ps_idx = let - val (T', ps_idx') = pretyp_of T ps_idx; - val (t', ps_idx'') = pre_of t ps_idx'; - in (PAppl (PConst ("_type_constraint_", T'), t'), ps_idx'') end - | pre_of (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) = - let val (pT, idx') = polyT_of T idx - in (PVar (xi, pT), (ps, idx')) end - | pre_of (Var (xi, T)) ps_idx = constraint T (PVar (xi, var_param xi)) ps_idx - | pre_of (Free (x, T)) ps_idx = constraint T (PFree (x, var_param (x, ~1))) ps_idx - | pre_of (Bound i) ps_idx = (PBound i, ps_idx) - | pre_of (Abs (x, T, t)) ps_idx = + val (T', ps_idx') = prepare_typ T ps_idx; + val (t', ps_idx'') = prepare t ps_idx'; + in (Abs (x, T', t'), ps_idx'') end + | prepare (t $ u) ps_idx = let - val (T', ps_idx') = pretyp_of T ps_idx; - val (t', ps_idx'') = pre_of t ps_idx'; - in (PAbs (x, T', t'), ps_idx'') end - | pre_of (t $ u) ps_idx = - let - val (t', ps_idx') = pre_of t ps_idx; - val (u', ps_idx'') = pre_of u ps_idx'; - in (PAppl (t', u'), ps_idx'') end; + val (t', ps_idx') = prepare t ps_idx; + val (u', ps_idx'') = prepare u ps_idx'; + in (t' $ u', ps_idx'') end; - val (tm', (params', idx'')) = pre_of tm (params, idx'); + val (tm', (params', idx'')) = prepare tm (params, idx'); in (tm', (vparams', params', idx'')) end; -(** pretyps/terms to typs/terms **) +(** finish types/terms: standardize remaining parameters **) -(* add_parms *) +(* dereferenced views *) -fun add_parmsT tye T = +fun deref tye (T as TVar (xi, _)) = + (case Vartab.lookup tye xi of + NONE => T + | SOME U => deref tye U) + | deref tye T = T; + +fun add_parms tye T = (case deref tye T of - PType (_, Ts) => fold (add_parmsT tye) Ts - | Param (i, _) => insert (op =) i + Type (_, Ts) => fold (add_parms tye) Ts + | TVar (xi, _) => if is_param xi then insert (op =) xi else I | _ => I); -fun add_parms tye = fold_pretyps (add_parmsT tye); - - -(* add_names *) - -fun add_namesT tye T = +fun add_names tye T = (case deref tye T of - PType (_, Ts) => fold (add_namesT tye) Ts - | PTFree (x, _) => Name.declare x - | PTVar ((x, _), _) => Name.declare x - | Param _ => I); - -fun add_names tye = fold_pretyps (add_namesT tye); + Type (_, Ts) => fold (add_names tye) Ts + | TFree (x, _) => Name.declare x + | TVar ((x, i), _) => if is_param (x, i) then I else Name.declare x); -(* simple_typ/term_of *) - -fun simple_typ_of tye f T = - (case deref tye T of - PType (a, Ts) => Type (a, map (simple_typ_of tye f) Ts) - | PTFree v => TFree v - | PTVar v => TVar v - | Param (i, S) => TVar (f i, S)); +(* finish *) -fun simple_term_of tye f (PConst (c, T)) = Const (c, simple_typ_of tye f T) - | simple_term_of tye f (PFree (x, T)) = Free (x, simple_typ_of tye f T) - | simple_term_of tye f (PVar (xi, T)) = Var (xi, simple_typ_of tye f T) - | simple_term_of tye f (PBound i) = Bound i - | simple_term_of tye f (PAbs (x, T, t)) = - Abs (x, simple_typ_of tye f T, simple_term_of tye f t) - | simple_term_of tye f (PAppl (t, u)) = - simple_term_of tye f t $ simple_term_of tye f u; - +fun finish ctxt tye (Ts, ts) = + let + val used = + (fold o fold_types) (add_names tye) ts (fold (add_names tye) Ts (Variable.names_of ctxt)); + val parms = rev ((fold o fold_types) (add_parms tye) ts (fold (add_parms tye) Ts [])); + val names = Name.invents used ("?" ^ Name.aT) (length parms); + val tab = Vartab.make (parms ~~ names); + val idx = Variable.maxidx_of ctxt + 1; -(* typs_terms_of *) - -fun typs_terms_of ctxt tye (Ts, ts) = - let - val used = fold (add_names tye) ts (fold (add_namesT tye) Ts (Variable.names_of ctxt)); - val parms = rev (fold (add_parms tye) ts (fold (add_parmsT tye) Ts [])); - val names = Name.invents used ("?" ^ Name.aT) (length parms); - val tab = Inttab.make (parms ~~ names); - - val maxidx = Variable.maxidx_of ctxt; - fun f i = (the (Inttab.lookup tab i), maxidx + 1); - in (map (simple_typ_of tye f) Ts, map (Type.strip_constraints o simple_term_of tye f) ts) end; + fun finish_typ T = + (case deref tye T of + Type (a, Ts) => Type (a, map finish_typ Ts) + | U as TFree _ => U + | U as TVar (xi, S) => + (case Vartab.lookup tab xi of + NONE => U + | SOME a => TVar ((a, idx), S))); + in (map finish_typ Ts, map (Type.strip_constraints o Term.map_types finish_typ) ts) end; (** order-sorted unification of types **) -exception NO_UNIFIER of string * pretyp Inttab.table; +exception NO_UNIFIER of string * typ Vartab.table; fun unify ctxt pp = let @@ -259,17 +219,15 @@ Syntax.string_of_sort ctxt S; fun meet (_, []) tye_idx = tye_idx - | meet (Param (i, S'), S) (tye_idx as (tye, idx)) = - if Sign.subsort thy (S', S) then tye_idx - else (Inttab.update_new (i, - Param (idx, Sign.inter_sort thy (S', S))) tye, idx + 1) - | meet (PType (a, Ts), S) (tye_idx as (tye, _)) = + | meet (Type (a, Ts), S) (tye_idx as (tye, _)) = meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx - | meet (PTFree (x, S'), S) (tye_idx as (tye, _)) = + | meet (TFree (x, S'), S) (tye_idx as (tye, _)) = if Sign.subsort thy (S', S) then tye_idx else raise NO_UNIFIER (not_of_sort x S' S, tye) - | meet (PTVar (xi, S'), S) (tye_idx as (tye, _)) = + | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) = if Sign.subsort thy (S', S) then tye_idx + else if is_param xi then + (Vartab.update_new (xi, mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1) else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye) and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) = meets (Ts, Ss) (meet (deref tye T, S) tye_idx) @@ -278,20 +236,20 @@ (* occurs check and assignment *) - fun occurs_check tye i (Param (i', S)) = - if i = i' then raise NO_UNIFIER ("Occurs check!", tye) + fun occurs_check tye xi (TVar (xi', S)) = + if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye) else - (case Inttab.lookup tye i' of + (case Vartab.lookup tye xi' of NONE => () - | SOME T => occurs_check tye i T) - | occurs_check tye i (PType (_, Ts)) = List.app (occurs_check tye i) Ts + | SOME T => occurs_check tye xi T) + | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts | occurs_check _ _ _ = (); - fun assign i (T as Param (i', _)) S tye_idx = - if i = i' then tye_idx - else tye_idx |> meet (T, S) |>> Inttab.update_new (i, T) - | assign i T S (tye_idx as (tye, _)) = - (occurs_check tye i T; tye_idx |> meet (T, S) |>> Inttab.update_new (i, T)); + fun assign xi (T as TVar (xi', _)) S env = + if xi = xi' then env + else env |> meet (T, S) |>> Vartab.update_new (xi, T) + | assign xi T S (env as (tye, _)) = + (occurs_check tye xi T; env |> meet (T, S) |>> Vartab.update_new (xi, T)); (* unification *) @@ -299,16 +257,16 @@ fun show_tycon (a, Ts) = quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT))); - fun unif (T1, T2) (tye_idx as (tye, idx)) = - (case (deref tye T1, deref tye T2) of - (Param (i, S), T) => assign i T S tye_idx - | (T, Param (i, S)) => assign i T S tye_idx - | (PType (a, Ts), PType (b, Us)) => + fun unif (T1, T2) (env as (tye, _)) = + (case pairself (`is_paramT o deref tye) (T1, T2) of + ((true, TVar (xi, S)), (_, T)) => assign xi T S env + | ((_, T), (true, TVar (xi, S))) => assign xi T S env + | ((_, Type (a, Ts)), (_, Type (b, Us))) => if a <> b then raise NO_UNIFIER ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye) - else fold unif (Ts ~~ Us) tye_idx - | (T, U) => if T = U then tye_idx else raise NO_UNIFIER ("", tye)); + else fold unif (Ts ~~ Us) env + | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye)); in unif end; @@ -327,7 +285,7 @@ fun prep_output tye bs ts Ts = let - val (Ts_bTs', ts') = typs_terms_of ctxt tye (Ts @ map snd bs, ts); + val (Ts_bTs', ts') = finish ctxt tye (Ts @ map snd bs, ts); val (Ts', Ts'') = chop (length Ts) Ts_bTs'; fun prep t = let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts''))) @@ -346,21 +304,20 @@ (* main *) - fun inf _ (PConst (_, T)) tye_idx = (T, tye_idx) - | inf _ (PFree (_, T)) tye_idx = (T, tye_idx) - | inf _ (PVar (_, T)) tye_idx = (T, tye_idx) - | inf bs (PBound i) tye_idx = + fun inf _ (Const (_, T)) tye_idx = (T, tye_idx) + | inf _ (Free (_, T)) tye_idx = (T, tye_idx) + | inf _ (Var (_, T)) tye_idx = (T, tye_idx) + | inf bs (Bound i) tye_idx = (snd (nth bs i handle Subscript => err_loose i), tye_idx) - | inf bs (PAbs (x, T, t)) tye_idx = + | inf bs (Abs (x, T, t)) tye_idx = let val (U, tye_idx') = inf ((x, T) :: bs) t tye_idx - in (PType ("fun", [T, U]), tye_idx') end - | inf bs (PAppl (t, u)) tye_idx = + in (T --> U, tye_idx') end + | inf bs (t $ u) tye_idx = let val (T, tye_idx') = inf bs t tye_idx; val (U, (tye, idx)) = inf bs u tye_idx'; - val V = Param (idx, []); - val U_to_V = PType ("fun", [U, V]); - val tye_idx'' = unify ctxt pp (U_to_V, T) (tye, idx + 1) + val V = mk_param idx []; + val tye_idx'' = unify ctxt pp (U --> V, T) (tye, idx + 1) handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U; in (V, tye_idx'') end; @@ -381,11 +338,11 @@ (*convert to preterms*) val ts = burrow_types (Syntax.check_typs ctxt) raw_ts; val (ts', (_, _, idx)) = - fold_map (preterm_of const_type o constrain_vars) ts + fold_map (prepare_term const_type o constrain_vars) ts (Vartab.empty, Vartab.empty, 0); (*do type inference*) - val (tye, _) = fold (snd oo infer ctxt) ts' (Inttab.empty, idx); - in #2 (typs_terms_of ctxt tye ([], ts')) end; + val (tye, _) = fold (snd oo infer ctxt) ts' (Vartab.empty, idx); + in #2 (finish ctxt tye ([], ts')) end; end;