--- a/src/Pure/type_infer.ML Wed Jun 18 22:32:01 2008 +0200
+++ b/src/Pure/type_infer.ML Wed Jun 18 22:32:02 2008 +0200
@@ -17,8 +17,8 @@
val fixate_params: Name.context -> term list -> term list
val appl_error: Pretty.pp -> string -> term -> typ -> term -> typ -> string list
val infer_types: Pretty.pp -> Type.tsig -> (typ list -> typ list) ->
- (string -> typ option) -> (indexname -> typ option) -> Name.context -> int -> bool option ->
- (term * typ) list -> (term * typ) list * (indexname * typ) list
+ (string -> typ option) -> (indexname -> typ option) -> Name.context -> int ->
+ term list -> term list
end;
structure TypeInfer: TYPE_INFER =
@@ -32,9 +32,7 @@
(*indicate polymorphic Vars*)
fun polymorphicT T = Type ("_polymorphic_", [T]);
-fun constrain T t =
- if T = dummyT then t
- else Const ("_type_constraint_", T --> T) $ t;
+val constrain = Syntax.type_constraint;
(* user parameters *)
@@ -230,18 +228,16 @@
(* typs_terms_of *) (*DESTRUCTIVE*)
-fun typs_terms_of used mk_var prfx (Ts, ts) =
+fun typs_terms_of used maxidx (Ts, ts) =
let
- fun elim (r as ref (Param S), x) = r := mk_var (x, S)
+ fun elim (r as ref (Param S), x) = r := PTVar ((x, maxidx + 1), S)
| elim _ = ();
val used' = fold add_names ts (fold add_namesT Ts used);
val parms = rev (fold add_parms ts (fold add_parmsT Ts []));
- val names = Name.invents used' (prfx ^ Name.aT) (length parms);
- in
- ListPair.app elim (parms, names);
- (map simple_typ_of Ts, map simple_term_of ts)
- end;
+ val names = Name.invents used' ("?" ^ Name.aT) (length parms);
+ val _ = ListPair.app elim (parms, names);
+ in (map simple_typ_of Ts, map simple_term_of ts) end;
@@ -333,11 +329,12 @@
fun prep_output bs ts Ts =
let
- val (Ts_bTs', ts') = typs_terms_of Name.context PTFree "??" (Ts @ map snd bs, ts);
+ val (Ts_bTs', ts') = typs_terms_of Name.context ~1 (Ts @ map snd bs, ts);
val (Ts', Ts'') = chop (length Ts) Ts_bTs';
- val xs = map Free (map fst bs ~~ Ts'');
- val ts'' = map (fn t => subst_bounds (xs, t)) ts';
- in (ts'', Ts') end;
+ fun prep t =
+ let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
+ in Term.subst_bounds (map Syntax.mark_boundT xs, t) end;
+ in (map prep ts', Ts') end;
fun err_loose i =
raise TYPE ("Loose bound variable: B." ^ string_of_int i, [], []);
@@ -394,13 +391,8 @@
(* infer_types *)
-fun infer_types pp tsig check_typs const_type var_type used maxidx freeze_mode args =
+fun infer_types pp tsig check_typs const_type var_type used maxidx raw_ts =
let
- (*check types*)
- val (raw_ts, raw_Ts) = split_list args;
- val ts = burrow_types check_typs raw_ts;
- val Ts = check_typs raw_Ts;
-
(*constrain vars*)
val get_type = the_default dummyT o var_type;
val constrain_vars = Term.map_aterms
@@ -408,26 +400,13 @@
| Var (xi, T) => constrain T (Var (xi, get_type xi))
| t => t);
- (*convert to preterms/typs*)
- val (Ts', Tps) = fold_map (pretyp_of (K true)) Ts Vartab.empty;
+ (*convert to preterms*)
+ val ts = burrow_types check_typs raw_ts;
val (ts', (vps, ps)) =
- fold_map (preterm_of const_type is_param o constrain_vars) ts (Vartab.empty, Tps);
-
- (*run type inference*)
- val tTs' = ListPair.map Constraint (ts', Ts');
- val _ = List.app (fn t => (infer pp tsig t; ())) tTs';
+ fold_map (preterm_of const_type is_param o constrain_vars) ts (Vartab.empty, Vartab.empty);
- (*convert back to terms/typs*)
- val mk_var =
- if the_default false freeze_mode then PTFree
- else (fn (x, S) => PTVar ((x, maxidx + 1), S));
- val prfx = if is_some freeze_mode then "" else "?";
- val (final_Ts, final_ts) = typs_terms_of used mk_var prfx (Ts', ts');
-
- (*collect result unifier*)
- val redundant = fn (xi, TVar (yi, _)) => xi = yi | _ => false;
- val env = filter_out redundant (map (apsnd simple_typ_of) (Vartab.dest Tps));
-
- in (final_ts ~~ final_Ts, env) end;
+ (*do type inference*)
+ val _ = List.app (ignore o infer pp tsig) ts';
+ in #2 (typs_terms_of used maxidx ([], ts')) end;
end;