diff -r 6e8b0672c6a2 -r e275d581a218 src/Pure/type_infer.ML --- a/src/Pure/type_infer.ML Mon Sep 13 12:42:08 2010 +0200 +++ b/src/Pure/type_infer.ML Mon Sep 13 13:20:18 2010 +0200 @@ -1,18 +1,22 @@ (* Title: Pure/type_infer.ML Author: Stefan Berghofer and Markus Wenzel, TU Muenchen -Simple type inference. +Representation of type-inference problems. Simple type inference. *) signature TYPE_INFER = sig - val anyT: sort -> typ val is_param: indexname -> bool val is_paramT: typ -> bool val param: int -> string * sort -> typ + val anyT: sort -> typ val paramify_vars: typ -> typ val paramify_dummies: typ -> int -> typ * int - val fixate_params: Proof.context -> term list -> term list + val deref: typ Vartab.table -> typ -> typ + val finish: Proof.context -> typ Vartab.table -> typ list * term list -> typ list * term list + val fixate: Proof.context -> term list -> term list + val prepare: Proof.context -> (string -> typ option) -> (string * int -> typ option) -> + term list -> int * term list val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) -> term list -> term list end; @@ -20,12 +24,8 @@ structure Type_Infer: TYPE_INFER = struct - (** type parameters and constraints **) -fun anyT S = TFree ("'_dummy_", S); - - (* type inference parameters -- may get instantiated *) fun is_param (x, _: int) = String.isPrefix "?" x; @@ -37,6 +37,11 @@ fun mk_param i S = TVar (("?'a", i), S); + +(* pre-stage parameters *) + +fun anyT S = TFree ("'_dummy_", S); + val paramify_vars = Same.commit (Term_Subst.map_atypsT_same @@ -54,19 +59,6 @@ | paramify T maxidx = (T, maxidx); in paramify end; -fun fixate_params ctxt ts = - let - fun subst_param (xi, S) (inst, used) = - if is_param xi then - let - val [a] = Name.invents used Name.aT 1; - val used' = Name.declare a used; - in (((xi, S), TFree (a, S)) :: inst, used') end - else (inst, used); - val used = (fold o fold_types) Term.declare_typ_names ts (Variable.names_of ctxt); - val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], used); - in (map o map_types) (Term_Subst.instantiateT inst) ts end; - (** prepare types/terms: create inference parameters **) @@ -156,7 +148,7 @@ -(** finish types/terms: standardize remaining parameters **) +(** results **) (* dereferenced views *) @@ -179,7 +171,7 @@ | TVar ((x, i), _) => if is_param (x, i) then I else Name.declare x); -(* finish *) +(* finish -- standardize remaining parameters *) fun finish ctxt tye (Ts, ts) = let @@ -200,6 +192,22 @@ in (map finish_typ Ts, map (Type.strip_constraints o Term.map_types finish_typ) ts) end; +(* fixate -- introduce fresh type variables *) + +fun fixate ctxt ts = + let + fun subst_param (xi, S) (inst, used) = + if is_param xi then + let + val [a] = Name.invents used Name.aT 1; + val used' = Name.declare a used; + in (((xi, S), TFree (a, S)) :: inst, used') end + else (inst, used); + val used = (fold o fold_types) Term.declare_typ_names ts (Variable.names_of ctxt); + val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], used); + in (map o map_types) (Term_Subst.instantiateT inst) ts end; + + (** order-sorted unification of types **) @@ -271,7 +279,7 @@ -(** type inference **) +(** simple type inference **) (* infer *) @@ -323,25 +331,27 @@ in inf [] end; -(* infer_types *) +(* main interfaces *) -fun infer_types ctxt const_type var_type raw_ts = +fun prepare ctxt const_type var_type raw_ts = let - (*constrain vars*) val get_type = the_default dummyT o var_type; val constrain_vars = Term.map_aterms (fn Free (x, T) => Type.constraint T (Free (x, get_type (x, ~1))) | Var (xi, T) => Type.constraint T (Var (xi, get_type xi)) | t => t); - (*convert to preterms*) val ts = burrow_types (Syntax.check_typs ctxt) raw_ts; val (ts', (_, _, idx)) = fold_map (prepare_term const_type o constrain_vars) ts (Vartab.empty, Vartab.empty, 0); + in (idx, ts') end; - (*do type inference*) - val (tye, _) = fold (snd oo infer ctxt) ts' (Vartab.empty, idx); - in #2 (finish ctxt tye ([], ts')) end; +fun infer_types ctxt const_type var_type raw_ts = + let + val (idx, ts) = prepare ctxt const_type var_type raw_ts; + val (tye, _) = fold (snd oo infer ctxt) ts (Vartab.empty, idx); + val (_, ts') = finish ctxt tye ([], ts); + in ts' end; end;