# HG changeset patch # User wenzelm # Date 1176640317 -7200 # Node ID 7e6412e8d64bad7c874fa7963c10690b8548c8a6 # Parent 92f8e9a8df78696110482bbf751fcf0b1625e292 moved get_sort to sign.ML; moved decode_types to Syntax/type_ext.ML; moved mixfixT to Syntax/mixfix.ML; proper infer_types, without decode/name lookup; tuned; diff -r 92f8e9a8df78 -r 7e6412e8d64b src/Pure/type_infer.ML --- a/src/Pure/type_infer.ML Sun Apr 15 14:31:56 2007 +0200 +++ b/src/Pure/type_infer.ML Sun Apr 15 14:31:57 2007 +0200 @@ -2,62 +2,56 @@ ID: $Id$ Author: Stefan Berghofer and Markus Wenzel, TU Muenchen -Type inference. +Simple type inference. *) signature TYPE_INFER = sig val anyT: sort -> typ val logicT: typ - val mixfixT: Syntax.mixfix -> typ val polymorphicT: typ -> typ - val appl_error: Pretty.pp -> string -> term -> typ -> term -> typ -> string list val constrain: term -> typ -> term val param: int -> string * sort -> typ val paramify_dummies: typ -> int -> typ * int - val get_sort: Type.tsig -> (indexname -> sort option) -> (sort -> sort) - -> (indexname * sort) list -> indexname -> sort - val infer_types: Pretty.pp - -> Type.tsig -> (string -> typ option) -> (indexname -> typ option) - -> (indexname -> sort option) -> (string -> string) -> (typ -> typ) - -> (sort -> sort) -> Name.context -> bool -> typ list -> term list - -> term list * (indexname * typ) list + val appl_error: Pretty.pp -> string -> term -> typ -> term -> typ -> string list + val infer_types: Pretty.pp -> Type.tsig -> + (string -> typ option) -> (indexname -> typ option) -> + Name.context -> bool -> (term * typ) list -> (term * typ) list * (indexname * typ) list end; structure TypeInfer: TYPE_INFER = struct -(** term encodings **) +(** type parameters and constraints **) -(* - Flavours of term encodings: +fun anyT S = TFree ("'_dummy_", S); +val logicT = anyT []; - parse trees (type term): - A very complicated structure produced by the syntax module's - read functions. Encodes types and sorts as terms; may contain - explicit constraints and partial typing information (where - dummies serve as wildcards). +(*indicate polymorphic Vars*) +fun polymorphicT T = Type ("_polymorphic_", [T]); - Parse trees are INTERNAL! Users should never encounter them, - except in parse / print translation functions. +fun constrain t T = + if T = dummyT then t + else Const ("_type_constraint_", T --> T) $ t; + + +(* user parameters *) - raw terms (type term): - Provide the user interface to type inferences. They may contain - partial type information (dummies are wildcards) or explicit - type constraints (introduced via constrain: term -> typ -> - term). +fun is_param_default (x, _) = size x > 0 andalso ord x = ord "?"; +fun param i (x, S) = TVar (("?" ^ x, i), S); + +val paramify_dummies = + let + fun dummy S maxidx = (param (maxidx + 1) ("'dummy", S), maxidx + 1); - The type inference function also lets users specify a certain - subset of TVars to be treated as non-rigid inference parameters. - - preterms (type preterm): - The internal representation for type inference. - - well-typed term (type term): - Fully typed lambda terms to be accepted by appropriate - certification functions. -*) + fun paramify (TFree ("'_dummy_", S)) maxidx = dummy S maxidx + | paramify (Type ("dummy", _)) maxidx = dummy [] maxidx + | paramify (Type (a, Ts)) maxidx = + let val (Ts', maxidx') = fold_map paramify Ts maxidx + in (Type (a, Ts'), maxidx') end + | paramify T maxidx = (T, maxidx); + in paramify end; @@ -103,16 +97,6 @@ (* pretyp_of *) -fun anyT S = TFree ("'_dummy_", S); -val logicT = anyT []; - -fun mixfixT (Binder _) = (logicT --> logicT) --> logicT - | mixfixT mx = replicate (Syntax.mixfix_args mx) logicT ---> logicT; - - -(*indicate polymorphic Vars*) -fun polymorphicT T = Type ("_polymorphic_", [T]); - fun pretyp_of is_param typ params = let val params' = fold_atyps @@ -151,11 +135,10 @@ tm vparams; fun var_param xi = the (Vartab.lookup vparams' xi); - val preT_of = pretyp_of is_param; fun polyT_of T = fst (pretyp_of (K true) T Vartab.empty); - fun constrain T t ps = + fun constraint T t ps = if T = dummyT then (t, ps) else let val (T', ps') = preT_of T ps @@ -163,13 +146,13 @@ fun pre_of (Const (c, T)) ps = (case const_type c of - SOME U => constrain T (PConst (c, polyT_of U)) ps + SOME U => constraint T (PConst (c, polyT_of U)) ps | NONE => raise TYPE ("No such constant: " ^ quote c, [], [])) | pre_of (Var (xi, Type ("_polymorphic_", [T]))) ps = (PVar (xi, polyT_of T), ps) - | pre_of (Var (xi, T)) ps = constrain T (PVar (xi, var_param xi)) ps - | pre_of (Free (x, T)) ps = constrain T (PFree (x, var_param (x, ~1))) ps + | pre_of (Var (xi, T)) ps = constraint T (PVar (xi, var_param xi)) ps + | pre_of (Free (x, T)) ps = constraint T (PFree (x, var_param (x, ~1))) ps | pre_of (Const ("_type_constraint_", Type ("fun", [T, _])) $ t) ps = - uncurry (constrain T) (pre_of t ps) + pre_of t ps |-> constraint T | pre_of (Bound i) ps = (PBound i, ps) | pre_of (Abs (x, T, t)) ps = let @@ -250,7 +233,6 @@ exception NO_UNIFIER of string; - fun unify pp tsig = let @@ -310,6 +292,8 @@ (** type inference **) +(* appl_error *) + fun appl_error pp why t T u U = ["Type error in application: " ^ why, "", @@ -392,13 +376,27 @@ in inf [] end; -(* basic_infer_types *) +(* infer_types *) -fun basic_infer_types pp tsig const_type used freeze is_param ts Ts = +fun infer_types pp tsig const_type var_type used freeze args = let + (*certify types*) + val certT = Type.cert_typ tsig; + val (raw_ts, raw_Ts) = split_list args; + val ts = map (Term.map_types certT) raw_ts; + val Ts = map certT raw_Ts; + + (*constrain vars*) + val get_type = the_default dummyT o var_type; + val constrain_vars = Term.map_aterms + (fn Free (x, T) => constrain (Free (x, get_type (x, ~1))) T + | Var (xi, T) => constrain (Var (xi, get_type xi)) T + | t => t); + (*convert to preterms/typs*) val (Ts', Tps) = fold_map (pretyp_of (K true)) Ts Vartab.empty; - val (ts', (vps, ps)) = fold_map (preterm_of const_type is_param) ts (Vartab.empty, Tps); + val (ts', (vps, ps)) = + fold_map (preterm_of const_type is_param_default o constrain_vars) ts (Vartab.empty, Tps); (*run type inference*) val tTs' = ListPair.map Constraint (ts', Ts'); @@ -415,118 +413,6 @@ else (fn (x, S) => PTVar ((x, 0), S)); val (final_Ts, final_ts) = typs_terms_of used mk_var "" (Ts', ts'); val final_env = map (apsnd simple_typ_of) env; - in (final_ts, final_Ts, final_env) end; - - - -(** type inference **) - -(* user constraints *) - -fun constrain t T = - if T = dummyT then t - else Const ("_type_constraint_", T --> T) $ t; - - -(* user parameters *) - -fun is_param (x, _) = size x > 0 andalso ord x = ord "?"; -fun param i (x, S) = TVar (("?" ^ x, i), S); - -val paramify_dummies = - let - fun dummy S maxidx = (param (maxidx + 1) ("'dummy", S), maxidx + 1); - - fun paramify (TFree ("'_dummy_", S)) maxidx = dummy S maxidx - | paramify (Type ("dummy", _)) maxidx = dummy [] maxidx - | paramify (Type (a, Ts)) maxidx = - let val (Ts', maxidx') = fold_map paramify Ts maxidx - in (Type (a, Ts'), maxidx') end - | paramify T maxidx = (T, maxidx); - in paramify end; - - -(* get sort constraints *) - -fun get_sort tsig def_sort map_sort raw_env = - let - fun eq ((xi, S), (xi', S')) = - Term.eq_ix (xi, xi') andalso Type.eq_sort tsig (S, S'); - - val env = distinct eq (map (apsnd map_sort) raw_env); - val _ = (case duplicates (eq_fst (op =)) env of [] => () - | dups => error ("Inconsistent sort constraints for type variable(s) " - ^ commas_quote (map (Term.string_of_vname' o fst) dups))); - - fun get xi = - (case (AList.lookup (op =) env xi, def_sort xi) of - (NONE, NONE) => Type.defaultS tsig - | (NONE, SOME S) => S - | (SOME S, NONE) => S - | (SOME S, SOME S') => - if Type.eq_sort tsig (S, S') then S' - else error ("Sort constraint inconsistent with default for type variable " ^ - quote (Term.string_of_vname' xi))); - in get end; - - -(* decode_types -- transform parse tree into raw term *) - -fun decode_types tsig is_const def_type def_sort map_const map_type map_sort tm = - let - fun get_type xi = the_default dummyT (def_type xi); - fun is_free x = is_some (def_type (x, ~1)); - val raw_env = Syntax.raw_term_sorts tm; - val sort_of = get_sort tsig def_sort map_sort raw_env; - - val certT = Type.cert_typ tsig o map_type; - fun decodeT t = certT (Syntax.typ_of_term sort_of map_sort t); - - fun decode (Const ("_constrain", _) $ t $ typ) = - constrain (decode t) (decodeT typ) - | decode (Const ("_constrainAbs", _) $ (Abs (x, T, t)) $ typ) = - if T = dummyT then Abs (x, decodeT typ, decode t) - else constrain (Abs (x, certT T, decode t)) (decodeT typ --> dummyT) - | decode (Abs (x, T, t)) = Abs (x, certT T, decode t) - | decode (t $ u) = decode t $ decode u - | decode (Const (x, T)) = - let val c = (case try (unprefix Syntax.constN) x of SOME c => c | NONE => map_const x) - in Const (c, certT T) end - | decode (Free (x, T)) = - let val c = map_const x in - if not (is_free x) andalso (is_const c orelse NameSpace.is_qualified c) then - Const (c, certT T) - else if T = dummyT then Free (x, get_type (x, ~1)) - else constrain (Free (x, certT T)) (get_type (x, ~1)) - end - | decode (Var (xi, T)) = - if T = dummyT then Var (xi, get_type xi) - else constrain (Var (xi, certT T)) (get_type xi) - | decode (t as Bound _) = t; - in decode tm end; - - -(* infer_types *) - -(*Given [T1,...,Tn] and [t1,...,tn], ensure that the type of ti - unifies with Ti (for i=1,...,n). - - tsig: type signature - const_type: name mapping and signature lookup - def_type: partial map from indexnames to types (constrains Frees and Vars) - def_sort: partial map from indexnames to sorts (constrains TFrees and TVars) - used: context of already used type variables - freeze: if true then generated parameters are turned into TFrees, else TVars*) - -fun infer_types pp tsig const_type def_type def_sort - map_const map_type map_sort used freeze pat_Ts raw_ts = - let - val pat_Ts' = map (Type.cert_typ tsig) pat_Ts; - val is_const = is_some o const_type; - val raw_ts' = - map (decode_types tsig is_const def_type def_sort map_const map_type map_sort) raw_ts; - val (ts, Ts, unifier) = - basic_infer_types pp tsig const_type used freeze is_param raw_ts' pat_Ts'; - in (ts, unifier) end; + in (final_ts ~~ final_Ts, final_env) end; end;