--- a/src/Pure/type_infer.ML Fri May 21 21:24:22 2004 +0200
+++ b/src/Pure/type_infer.ML Fri May 21 21:25:34 2004 +0200
@@ -10,12 +10,18 @@
val anyT: sort -> typ
val logicT: typ
val polymorphicT: typ -> typ
- val infer_types: (term -> Pretty.T) -> (typ -> Pretty.T)
- -> (string -> typ option) -> Sorts.classrel -> Sorts.arities
- -> string list -> bool -> (indexname -> bool) -> term list -> typ list
- -> term list * typ list * (indexname * typ) list
val appl_error: (term -> Pretty.T) -> (typ -> Pretty.T)
-> string -> term -> typ -> term -> typ -> string list
+ val constrain: term -> typ -> term
+ val param: int -> string * sort -> typ
+ val paramify_dummies: int * typ -> int * typ
+ val get_sort: Type.tsig -> (indexname -> sort option) -> (sort -> sort)
+ -> (indexname * sort) list -> indexname -> sort
+ val infer_types: (term -> Pretty.T) -> (typ -> Pretty.T)
+ -> Type.tsig -> (string -> typ option) -> (indexname -> typ option)
+ -> (indexname -> sort option) -> (string -> string) -> (typ -> typ)
+ -> (sort -> sort) -> string list -> bool -> typ list -> term list
+ -> term list * (indexname * typ) list
end;
structure TypeInfer: TYPE_INFER =
@@ -247,7 +253,7 @@
exception NO_UNIFIER of string;
-fun unify classrel arities =
+fun unify classes arities =
let
(* adjust sorts of parameters *)
@@ -260,17 +266,17 @@
fun meet (_, []) = ()
| meet (Link (r as (ref (Param S'))), S) =
- if Sorts.sort_le classrel (S', S) then ()
- else r := mk_param (Sorts.inter_sort classrel (S', S))
+ if Sorts.sort_le classes (S', S) then ()
+ else r := mk_param (Sorts.inter_sort classes (S', S))
| meet (Link (ref T), S) = meet (T, S)
| meet (PType (a, Ts), S) =
- seq2 meet (Ts, Sorts.mg_domain (classrel, arities) a S
+ seq2 meet (Ts, Sorts.mg_domain (classes, arities) a S
handle Sorts.DOMAIN ac => raise NO_UNIFIER (no_domain ac))
| meet (PTFree (x, S'), S) =
- if Sorts.sort_le classrel (S', S) then ()
+ if Sorts.sort_le classes (S', S) then ()
else raise NO_UNIFIER (not_in_sort x S' S)
| meet (PTVar (xi, S'), S) =
- if Sorts.sort_le classrel (S', S) then ()
+ if Sorts.sort_le classes (S', S) then ()
else raise NO_UNIFIER (not_in_sort (Syntax.string_of_vname xi) S' S)
| meet (Param _, _) = sys_error "meet";
@@ -320,7 +326,7 @@
(* infer *) (*DESTRUCTIVE*)
-fun infer prt prT classrel arities =
+fun infer prt prT classes arities =
let
(* errors *)
@@ -365,7 +371,7 @@
(* main *)
- val unif = unify classrel arities;
+ val unif = unify classes arities;
fun inf _ (PConst (_, T)) = T
| inf _ (PFree (_, T)) = T
@@ -389,9 +395,9 @@
in inf [] end;
-(* infer_types *)
+(* basic_infer_types *)
-fun infer_types prt prT const_type classrel arities used freeze is_param ts Ts =
+fun basic_infer_types prt prT const_type classes arities used freeze is_param ts Ts =
let
(*convert to preterms/typs*)
val (Tps, Ts') = pretyps_of (K true) ([], Ts);
@@ -399,7 +405,7 @@
(*run type inference*)
val tTs' = ListPair.map Constraint (ts', Ts');
- val _ = seq (fn t => (infer prt prT classrel arities t; ())) tTs';
+ val _ = seq (fn t => (infer prt prT classes arities t; ())) tTs';
(*collect result unifier*)
fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); None)
@@ -412,9 +418,114 @@
else (fn (x, S) => PTVar ((x, 0), S));
val (final_Ts, final_ts) = typs_terms_of used mk_var "" (Ts', ts');
val final_env = map (apsnd simple_typ_of) env;
- in
- (final_ts, final_Ts, final_env)
- end;
+ in (final_ts, final_Ts, final_env) end;
+
+
+
+(** type inference **)
+
+(* user constraints *)
+
+fun constrain t T =
+ if T = dummyT then t
+ else Const ("_type_constraint_", T) $ t;
+
+
+(* user parameters *)
+
+fun is_param (x, _) = size x > 0 andalso ord x = ord "?";
+fun param i (x, S) = TVar (("?" ^ x, i), S);
+
+fun paramify_dummies (maxidx, TFree ("'_dummy_", S)) =
+ (maxidx + 1, param (maxidx + 1) ("'dummy", S))
+ | paramify_dummies (maxidx, Type (a, Ts)) =
+ let val (maxidx', Ts') = foldl_map paramify_dummies (maxidx, Ts)
+ in (maxidx', Type (a, Ts')) end
+ | paramify_dummies arg = arg;
+
+
+(* decode sort constraints *)
+
+fun get_sort tsig def_sort map_sort raw_env =
+ let
+ fun eq ((xi, S), (xi', S')) =
+ xi = xi' andalso Type.eq_sort tsig (S, S');
+
+ val env = gen_distinct eq (map (apsnd map_sort) raw_env);
+ val _ =
+ (case gen_duplicates eq_fst env of
+ [] => ()
+ | dups => error ("Inconsistent sort constraints for type variable(s) " ^
+ commas (map (quote o Syntax.string_of_vname' o fst) dups)));
+
+ fun get xi =
+ (case (assoc (env, xi), def_sort xi) of
+ (None, None) => Type.defaultS tsig
+ | (None, Some S) => S
+ | (Some S, None) => S
+ | (Some S, Some S') =>
+ if Type.eq_sort tsig (S, S') then S'
+ else error ("Sort constraint inconsistent with default for type variable " ^
+ quote (Syntax.string_of_vname' xi)));
+ in get end;
+(* decode_types -- transform parse tree into raw term *)
+
+fun decode_types tsig is_const def_type def_sort map_const map_type map_sort tm =
+ let
+ fun get_type xi = if_none (def_type xi) dummyT;
+ fun is_free x = is_some (def_type (x, ~1));
+ val raw_env = Syntax.raw_term_sorts tm;
+ val sort_of = get_sort tsig def_sort map_sort raw_env;
+
+ val certT = Type.cert_typ tsig o map_type;
+ fun decodeT t = certT (Syntax.typ_of_term sort_of map_sort t);
+
+ fun decode (Const ("_constrain", _) $ t $ typ) =
+ constrain (decode t) (decodeT typ)
+ | decode (Const ("_constrainAbs", _) $ (Abs (x, T, t)) $ typ) =
+ if T = dummyT then Abs (x, decodeT typ, decode t)
+ else constrain (Abs (x, certT T, decode t)) (decodeT typ --> dummyT)
+ | decode (Abs (x, T, t)) = Abs (x, certT T, decode t)
+ | decode (t $ u) = decode t $ decode u
+ | decode (Free (x, T)) =
+ let val c = map_const x in
+ if not (is_free x) andalso (is_const c orelse NameSpace.is_qualified c) then
+ Const (c, certT T)
+ else if T = dummyT then Free (x, get_type (x, ~1))
+ else constrain (Free (x, certT T)) (get_type (x, ~1))
+ end
+ | decode (Var (xi, T)) =
+ if T = dummyT then Var (xi, get_type xi)
+ else constrain (Var (xi, certT T)) (get_type xi)
+ | decode (t as Bound _) = t
+ | decode (Const (c, T)) = Const (map_const c, certT T);
+ in decode tm end;
+
+
+(* infer_types *)
+
+(*Given [T1,...,Tn] and [t1,...,tn], ensure that the type of ti
+ unifies with Ti (for i=1,...,n).
+
+ tsig: type signature
+ const_type: name mapping and signature lookup
+ def_type: partial map from indexnames to types (constrains Frees, Vars)
+ def_sort: partial map from indexnames to sorts (constrains TFrees, TVars)
+ used: list of already used type variables
+ freeze: if true then generated parameters are turned into TFrees, else TVars*)
+
+fun infer_types prt prT tsig const_type def_type def_sort
+ map_const map_type map_sort used freeze pat_Ts raw_ts =
+ let
+ val {classes, arities, ...} = Type.rep_tsig tsig;
+ val pat_Ts' = map (Type.cert_typ tsig) pat_Ts;
+ val is_const = is_some o const_type;
+ val raw_ts' =
+ map (decode_types tsig is_const def_type def_sort map_const map_type map_sort) raw_ts;
+ val (ts, Ts, unifier) = basic_infer_types prt prT const_type
+ classes arities used freeze is_param raw_ts' pat_Ts';
+ in (ts, unifier) end;
+
end;