# HG changeset patch # User wenzelm # Date 1085167534 -7200 # Node ID 9776f0c747c865a66d4ae7820fc1a1514d53c3bf # Parent 724ce6e574e34440e15a62cf09a9543b2a1d7ad6 incorporate type inference interface from type.ML; diff -r 724ce6e574e3 -r 9776f0c747c8 src/Pure/type_infer.ML --- a/src/Pure/type_infer.ML Fri May 21 21:24:22 2004 +0200 +++ b/src/Pure/type_infer.ML Fri May 21 21:25:34 2004 +0200 @@ -10,12 +10,18 @@ val anyT: sort -> typ val logicT: typ val polymorphicT: typ -> typ - val infer_types: (term -> Pretty.T) -> (typ -> Pretty.T) - -> (string -> typ option) -> Sorts.classrel -> Sorts.arities - -> string list -> bool -> (indexname -> bool) -> term list -> typ list - -> term list * typ list * (indexname * typ) list val appl_error: (term -> Pretty.T) -> (typ -> Pretty.T) -> string -> term -> typ -> term -> typ -> string list + val constrain: term -> typ -> term + val param: int -> string * sort -> typ + val paramify_dummies: int * typ -> int * typ + val get_sort: Type.tsig -> (indexname -> sort option) -> (sort -> sort) + -> (indexname * sort) list -> indexname -> sort + val infer_types: (term -> Pretty.T) -> (typ -> Pretty.T) + -> Type.tsig -> (string -> typ option) -> (indexname -> typ option) + -> (indexname -> sort option) -> (string -> string) -> (typ -> typ) + -> (sort -> sort) -> string list -> bool -> typ list -> term list + -> term list * (indexname * typ) list end; structure TypeInfer: TYPE_INFER = @@ -247,7 +253,7 @@ exception NO_UNIFIER of string; -fun unify classrel arities = +fun unify classes arities = let (* adjust sorts of parameters *) @@ -260,17 +266,17 @@ fun meet (_, []) = () | meet (Link (r as (ref (Param S'))), S) = - if Sorts.sort_le classrel (S', S) then () - else r := mk_param (Sorts.inter_sort classrel (S', S)) + if Sorts.sort_le classes (S', S) then () + else r := mk_param (Sorts.inter_sort classes (S', S)) | meet (Link (ref T), S) = meet (T, S) | meet (PType (a, Ts), S) = - seq2 meet (Ts, Sorts.mg_domain (classrel, arities) a S + seq2 meet (Ts, Sorts.mg_domain (classes, arities) a S handle Sorts.DOMAIN ac => raise NO_UNIFIER (no_domain ac)) | meet (PTFree (x, S'), S) = - if Sorts.sort_le classrel (S', S) then () + if Sorts.sort_le classes (S', S) then () else raise NO_UNIFIER (not_in_sort x S' S) | meet (PTVar (xi, S'), S) = - if Sorts.sort_le classrel (S', S) then () + if Sorts.sort_le classes (S', S) then () else raise NO_UNIFIER (not_in_sort (Syntax.string_of_vname xi) S' S) | meet (Param _, _) = sys_error "meet"; @@ -320,7 +326,7 @@ (* infer *) (*DESTRUCTIVE*) -fun infer prt prT classrel arities = +fun infer prt prT classes arities = let (* errors *) @@ -365,7 +371,7 @@ (* main *) - val unif = unify classrel arities; + val unif = unify classes arities; fun inf _ (PConst (_, T)) = T | inf _ (PFree (_, T)) = T @@ -389,9 +395,9 @@ in inf [] end; -(* infer_types *) +(* basic_infer_types *) -fun infer_types prt prT const_type classrel arities used freeze is_param ts Ts = +fun basic_infer_types prt prT const_type classes arities used freeze is_param ts Ts = let (*convert to preterms/typs*) val (Tps, Ts') = pretyps_of (K true) ([], Ts); @@ -399,7 +405,7 @@ (*run type inference*) val tTs' = ListPair.map Constraint (ts', Ts'); - val _ = seq (fn t => (infer prt prT classrel arities t; ())) tTs'; + val _ = seq (fn t => (infer prt prT classes arities t; ())) tTs'; (*collect result unifier*) fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); None) @@ -412,9 +418,114 @@ else (fn (x, S) => PTVar ((x, 0), S)); val (final_Ts, final_ts) = typs_terms_of used mk_var "" (Ts', ts'); val final_env = map (apsnd simple_typ_of) env; - in - (final_ts, final_Ts, final_env) - end; + in (final_ts, final_Ts, final_env) end; + + + +(** type inference **) + +(* user constraints *) + +fun constrain t T = + if T = dummyT then t + else Const ("_type_constraint_", T) $ t; + + +(* user parameters *) + +fun is_param (x, _) = size x > 0 andalso ord x = ord "?"; +fun param i (x, S) = TVar (("?" ^ x, i), S); + +fun paramify_dummies (maxidx, TFree ("'_dummy_", S)) = + (maxidx + 1, param (maxidx + 1) ("'dummy", S)) + | paramify_dummies (maxidx, Type (a, Ts)) = + let val (maxidx', Ts') = foldl_map paramify_dummies (maxidx, Ts) + in (maxidx', Type (a, Ts')) end + | paramify_dummies arg = arg; + + +(* decode sort constraints *) + +fun get_sort tsig def_sort map_sort raw_env = + let + fun eq ((xi, S), (xi', S')) = + xi = xi' andalso Type.eq_sort tsig (S, S'); + + val env = gen_distinct eq (map (apsnd map_sort) raw_env); + val _ = + (case gen_duplicates eq_fst env of + [] => () + | dups => error ("Inconsistent sort constraints for type variable(s) " ^ + commas (map (quote o Syntax.string_of_vname' o fst) dups))); + + fun get xi = + (case (assoc (env, xi), def_sort xi) of + (None, None) => Type.defaultS tsig + | (None, Some S) => S + | (Some S, None) => S + | (Some S, Some S') => + if Type.eq_sort tsig (S, S') then S' + else error ("Sort constraint inconsistent with default for type variable " ^ + quote (Syntax.string_of_vname' xi))); + in get end; +(* decode_types -- transform parse tree into raw term *) + +fun decode_types tsig is_const def_type def_sort map_const map_type map_sort tm = + let + fun get_type xi = if_none (def_type xi) dummyT; + fun is_free x = is_some (def_type (x, ~1)); + val raw_env = Syntax.raw_term_sorts tm; + val sort_of = get_sort tsig def_sort map_sort raw_env; + + val certT = Type.cert_typ tsig o map_type; + fun decodeT t = certT (Syntax.typ_of_term sort_of map_sort t); + + fun decode (Const ("_constrain", _) $ t $ typ) = + constrain (decode t) (decodeT typ) + | decode (Const ("_constrainAbs", _) $ (Abs (x, T, t)) $ typ) = + if T = dummyT then Abs (x, decodeT typ, decode t) + else constrain (Abs (x, certT T, decode t)) (decodeT typ --> dummyT) + | decode (Abs (x, T, t)) = Abs (x, certT T, decode t) + | decode (t $ u) = decode t $ decode u + | decode (Free (x, T)) = + let val c = map_const x in + if not (is_free x) andalso (is_const c orelse NameSpace.is_qualified c) then + Const (c, certT T) + else if T = dummyT then Free (x, get_type (x, ~1)) + else constrain (Free (x, certT T)) (get_type (x, ~1)) + end + | decode (Var (xi, T)) = + if T = dummyT then Var (xi, get_type xi) + else constrain (Var (xi, certT T)) (get_type xi) + | decode (t as Bound _) = t + | decode (Const (c, T)) = Const (map_const c, certT T); + in decode tm end; + + +(* infer_types *) + +(*Given [T1,...,Tn] and [t1,...,tn], ensure that the type of ti + unifies with Ti (for i=1,...,n). + + tsig: type signature + const_type: name mapping and signature lookup + def_type: partial map from indexnames to types (constrains Frees, Vars) + def_sort: partial map from indexnames to sorts (constrains TFrees, TVars) + used: list of already used type variables + freeze: if true then generated parameters are turned into TFrees, else TVars*) + +fun infer_types prt prT tsig const_type def_type def_sort + map_const map_type map_sort used freeze pat_Ts raw_ts = + let + val {classes, arities, ...} = Type.rep_tsig tsig; + val pat_Ts' = map (Type.cert_typ tsig) pat_Ts; + val is_const = is_some o const_type; + val raw_ts' = + map (decode_types tsig is_const def_type def_sort map_const map_type map_sort) raw_ts; + val (ts, Ts, unifier) = basic_infer_types prt prT const_type + classes arities used freeze is_param raw_ts' pat_Ts'; + in (ts, unifier) end; + end;