--- a/src/Pure/type_infer.ML Sun Sep 12 19:55:45 2010 +0200
+++ b/src/Pure/type_infer.ML Sun Sep 12 20:47:47 2010 +0200
@@ -7,14 +7,12 @@
signature TYPE_INFER =
sig
val anyT: sort -> typ
- val polymorphicT: typ -> typ
val is_param: indexname -> bool
val param: int -> string * sort -> typ
val paramify_vars: typ -> typ
val paramify_dummies: typ -> int -> typ * int
val fixate_params: Name.context -> term list -> term list
- val infer_types: Pretty.pp -> Type.tsig -> (typ list -> typ list) ->
- (string -> typ option) -> (indexname -> typ option) -> Name.context -> int ->
+ val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
term list -> term list
end;
@@ -26,9 +24,6 @@
fun anyT S = TFree ("'_dummy_", S);
-(*indicate polymorphic Vars*)
-fun polymorphicT T = Type ("_polymorphic_", [T]);
-
(* type inference parameters -- may get instantiated *)
@@ -235,12 +230,14 @@
(* typs_terms_of *)
-fun typs_terms_of tye used maxidx (Ts, ts) =
+fun typs_terms_of ctxt tye (Ts, ts) =
let
- val used' = fold (add_names tye) ts (fold (add_namesT tye) Ts used);
+ val used = fold (add_names tye) ts (fold (add_namesT tye) Ts (Variable.names_of ctxt));
val parms = rev (fold (add_parms tye) ts (fold (add_parmsT tye) Ts []));
- val names = Name.invents used' ("?" ^ Name.aT) (length parms);
+ val names = Name.invents used ("?" ^ Name.aT) (length parms);
val tab = Inttab.make (parms ~~ names);
+
+ val maxidx = Variable.maxidx_of ctxt;
fun f i = (the (Inttab.lookup tab i), maxidx + 1);
in (map (simple_typ_of tye f) Ts, map (simple_term_of tye f) ts) end;
@@ -250,27 +247,31 @@
exception NO_UNIFIER of string * pretyp Inttab.table;
-fun unify pp tsig =
+fun unify ctxt pp =
let
+ val thy = ProofContext.theory_of ctxt;
+ val arity_sorts = Type.arity_sorts pp (Sign.tsig_of thy);
+
+
(* adjust sorts of parameters *)
fun not_of_sort x S' S =
- "Variable " ^ x ^ "::" ^ Pretty.string_of_sort pp S' ^ " not of sort " ^
- Pretty.string_of_sort pp S;
+ "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
+ Syntax.string_of_sort ctxt S;
fun meet (_, []) tye_idx = tye_idx
| meet (Param (i, S'), S) (tye_idx as (tye, idx)) =
- if Type.subsort tsig (S', S) then tye_idx
+ if Sign.subsort thy (S', S) then tye_idx
else (Inttab.update_new (i,
- Param (idx, Type.inter_sort tsig (S', S))) tye, idx + 1)
+ Param (idx, Sign.inter_sort thy (S', S))) tye, idx + 1)
| meet (PType (a, Ts), S) (tye_idx as (tye, _)) =
- meets (Ts, Type.arity_sorts pp tsig a S
+ meets (Ts, arity_sorts a S
handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
| meet (PTFree (x, S'), S) (tye_idx as (tye, _)) =
- if Type.subsort tsig (S', S) then tye_idx
+ if Sign.subsort thy (S', S) then tye_idx
else raise NO_UNIFIER (not_of_sort x S' S, tye)
| meet (PTVar (xi, S'), S) (tye_idx as (tye, _)) =
- if Type.subsort tsig (S', S) then tye_idx
+ if Sign.subsort thy (S', S) then tye_idx
else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
@@ -298,7 +299,7 @@
(* unification *)
fun show_tycon (a, Ts) =
- quote (Pretty.string_of_typ pp (Type (a, replicate (length Ts) dummyT)));
+ quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
fun unif (T1, T2) (tye_idx as (tye, idx)) =
(case (deref tye T1, deref tye T2) of
@@ -319,13 +320,16 @@
(* infer *)
-fun infer pp tsig =
+fun infer ctxt =
let
+ val pp = Syntax.pp ctxt;
+
+
(* errors *)
fun prep_output tye bs ts Ts =
let
- val (Ts_bTs', ts') = typs_terms_of tye Name.context ~1 (Ts @ map snd bs, ts);
+ val (Ts_bTs', ts') = typs_terms_of ctxt tye (Ts @ map snd bs, ts);
val (Ts', Ts'') = chop (length Ts) Ts_bTs';
fun prep t =
let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
@@ -355,8 +359,6 @@
(* main *)
- val unif = unify pp tsig;
-
fun inf _ (PConst (_, T)) tye_idx = (T, tye_idx)
| inf _ (PFree (_, T)) tye_idx = (T, tye_idx)
| inf _ (PVar (_, T)) tye_idx = (T, tye_idx)
@@ -371,13 +373,13 @@
val (U, (tye, idx)) = inf bs u tye_idx';
val V = Param (idx, []);
val U_to_V = PType ("fun", [U, V]);
- val tye_idx'' = unif (U_to_V, T) (tye, idx + 1)
+ val tye_idx'' = unify ctxt pp (U_to_V, T) (tye, idx + 1)
handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U;
in (V, tye_idx'') end
| inf bs (Constraint (t, U)) tye_idx =
let val (T, tye_idx') = inf bs t tye_idx in
(T,
- unif (T, U) tye_idx'
+ unify ctxt pp (T, U) tye_idx'
handle NO_UNIFIER (msg, tye) => err_constraint msg tye bs t T U)
end;
@@ -386,7 +388,7 @@
(* infer_types *)
-fun infer_types pp tsig check_typs const_type var_type used maxidx raw_ts =
+fun infer_types ctxt const_type var_type raw_ts =
let
(*constrain vars*)
val get_type = the_default dummyT o var_type;
@@ -396,13 +398,13 @@
| t => t);
(*convert to preterms*)
- val ts = burrow_types check_typs raw_ts;
+ val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
val (ts', (_, _, idx)) =
fold_map (preterm_of const_type o constrain_vars) ts
(Vartab.empty, Vartab.empty, 0);
(*do type inference*)
- val (tye, _) = fold (snd oo infer pp tsig) ts' (Inttab.empty, idx);
- in #2 (typs_terms_of tye used maxidx ([], ts')) end;
+ val (tye, _) = fold (snd oo infer ctxt) ts' (Inttab.empty, idx);
+ in #2 (typs_terms_of ctxt tye ([], ts')) end;
end;