--- a/src/Pure/type_infer.ML Mon Sep 13 12:42:08 2010 +0200
+++ b/src/Pure/type_infer.ML Mon Sep 13 13:20:18 2010 +0200
@@ -1,18 +1,22 @@
(* Title: Pure/type_infer.ML
Author: Stefan Berghofer and Markus Wenzel, TU Muenchen
-Simple type inference.
+Representation of type-inference problems. Simple type inference.
*)
signature TYPE_INFER =
sig
- val anyT: sort -> typ
val is_param: indexname -> bool
val is_paramT: typ -> bool
val param: int -> string * sort -> typ
+ val anyT: sort -> typ
val paramify_vars: typ -> typ
val paramify_dummies: typ -> int -> typ * int
- val fixate_params: Proof.context -> term list -> term list
+ val deref: typ Vartab.table -> typ -> typ
+ val finish: Proof.context -> typ Vartab.table -> typ list * term list -> typ list * term list
+ val fixate: Proof.context -> term list -> term list
+ val prepare: Proof.context -> (string -> typ option) -> (string * int -> typ option) ->
+ term list -> int * term list
val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
term list -> term list
end;
@@ -20,12 +24,8 @@
structure Type_Infer: TYPE_INFER =
struct
-
(** type parameters and constraints **)
-fun anyT S = TFree ("'_dummy_", S);
-
-
(* type inference parameters -- may get instantiated *)
fun is_param (x, _: int) = String.isPrefix "?" x;
@@ -37,6 +37,11 @@
fun mk_param i S = TVar (("?'a", i), S);
+
+(* pre-stage parameters *)
+
+fun anyT S = TFree ("'_dummy_", S);
+
val paramify_vars =
Same.commit
(Term_Subst.map_atypsT_same
@@ -54,19 +59,6 @@
| paramify T maxidx = (T, maxidx);
in paramify end;
-fun fixate_params ctxt ts =
- let
- fun subst_param (xi, S) (inst, used) =
- if is_param xi then
- let
- val [a] = Name.invents used Name.aT 1;
- val used' = Name.declare a used;
- in (((xi, S), TFree (a, S)) :: inst, used') end
- else (inst, used);
- val used = (fold o fold_types) Term.declare_typ_names ts (Variable.names_of ctxt);
- val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], used);
- in (map o map_types) (Term_Subst.instantiateT inst) ts end;
-
(** prepare types/terms: create inference parameters **)
@@ -156,7 +148,7 @@
-(** finish types/terms: standardize remaining parameters **)
+(** results **)
(* dereferenced views *)
@@ -179,7 +171,7 @@
| TVar ((x, i), _) => if is_param (x, i) then I else Name.declare x);
-(* finish *)
+(* finish -- standardize remaining parameters *)
fun finish ctxt tye (Ts, ts) =
let
@@ -200,6 +192,22 @@
in (map finish_typ Ts, map (Type.strip_constraints o Term.map_types finish_typ) ts) end;
+(* fixate -- introduce fresh type variables *)
+
+fun fixate ctxt ts =
+ let
+ fun subst_param (xi, S) (inst, used) =
+ if is_param xi then
+ let
+ val [a] = Name.invents used Name.aT 1;
+ val used' = Name.declare a used;
+ in (((xi, S), TFree (a, S)) :: inst, used') end
+ else (inst, used);
+ val used = (fold o fold_types) Term.declare_typ_names ts (Variable.names_of ctxt);
+ val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], used);
+ in (map o map_types) (Term_Subst.instantiateT inst) ts end;
+
+
(** order-sorted unification of types **)
@@ -271,7 +279,7 @@
-(** type inference **)
+(** simple type inference **)
(* infer *)
@@ -323,25 +331,27 @@
in inf [] end;
-(* infer_types *)
+(* main interfaces *)
-fun infer_types ctxt const_type var_type raw_ts =
+fun prepare ctxt const_type var_type raw_ts =
let
- (*constrain vars*)
val get_type = the_default dummyT o var_type;
val constrain_vars = Term.map_aterms
(fn Free (x, T) => Type.constraint T (Free (x, get_type (x, ~1)))
| Var (xi, T) => Type.constraint T (Var (xi, get_type xi))
| t => t);
- (*convert to preterms*)
val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
val (ts', (_, _, idx)) =
fold_map (prepare_term const_type o constrain_vars) ts
(Vartab.empty, Vartab.empty, 0);
+ in (idx, ts') end;
- (*do type inference*)
- val (tye, _) = fold (snd oo infer ctxt) ts' (Vartab.empty, idx);
- in #2 (finish ctxt tye ([], ts')) end;
+fun infer_types ctxt const_type var_type raw_ts =
+ let
+ val (idx, ts) = prepare ctxt const_type var_type raw_ts;
+ val (tye, _) = fold (snd oo infer ctxt) ts (Vartab.empty, idx);
+ val (_, ts') = finish ctxt tye ([], ts);
+ in ts' end;
end;