src/Pure/type_infer.ML
changeset 39290 44e4d8dfd6bf
parent 39289 92b50c8bb67b
child 39291 4b632bb847a8
--- a/src/Pure/type_infer.ML	Sun Sep 12 19:55:45 2010 +0200
+++ b/src/Pure/type_infer.ML	Sun Sep 12 20:47:47 2010 +0200
@@ -7,14 +7,12 @@
 signature TYPE_INFER =
 sig
   val anyT: sort -> typ
-  val polymorphicT: typ -> typ
   val is_param: indexname -> bool
   val param: int -> string * sort -> typ
   val paramify_vars: typ -> typ
   val paramify_dummies: typ -> int -> typ * int
   val fixate_params: Name.context -> term list -> term list
-  val infer_types: Pretty.pp -> Type.tsig -> (typ list -> typ list) ->
-    (string -> typ option) -> (indexname -> typ option) -> Name.context -> int ->
+  val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
     term list -> term list
 end;
 
@@ -26,9 +24,6 @@
 
 fun anyT S = TFree ("'_dummy_", S);
 
-(*indicate polymorphic Vars*)
-fun polymorphicT T = Type ("_polymorphic_", [T]);
-
 
 (* type inference parameters -- may get instantiated *)
 
@@ -235,12 +230,14 @@
 
 (* typs_terms_of *)
 
-fun typs_terms_of tye used maxidx (Ts, ts) =
+fun typs_terms_of ctxt tye (Ts, ts) =
   let
-    val used' = fold (add_names tye) ts (fold (add_namesT tye) Ts used);
+    val used = fold (add_names tye) ts (fold (add_namesT tye) Ts (Variable.names_of ctxt));
     val parms = rev (fold (add_parms tye) ts (fold (add_parmsT tye) Ts []));
-    val names = Name.invents used' ("?" ^ Name.aT) (length parms);
+    val names = Name.invents used ("?" ^ Name.aT) (length parms);
     val tab = Inttab.make (parms ~~ names);
+
+    val maxidx = Variable.maxidx_of ctxt;
     fun f i = (the (Inttab.lookup tab i), maxidx + 1);
   in (map (simple_typ_of tye f) Ts, map (simple_term_of tye f) ts) end;
 
@@ -250,27 +247,31 @@
 
 exception NO_UNIFIER of string * pretyp Inttab.table;
 
-fun unify pp tsig =
+fun unify ctxt pp =
   let
+    val thy = ProofContext.theory_of ctxt;
+    val arity_sorts = Type.arity_sorts pp (Sign.tsig_of thy);
+
+
     (* adjust sorts of parameters *)
 
     fun not_of_sort x S' S =
-      "Variable " ^ x ^ "::" ^ Pretty.string_of_sort pp S' ^ " not of sort " ^
-        Pretty.string_of_sort pp S;
+      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
+        Syntax.string_of_sort ctxt S;
 
     fun meet (_, []) tye_idx = tye_idx
       | meet (Param (i, S'), S) (tye_idx as (tye, idx)) =
-          if Type.subsort tsig (S', S) then tye_idx
+          if Sign.subsort thy (S', S) then tye_idx
           else (Inttab.update_new (i,
-            Param (idx, Type.inter_sort tsig (S', S))) tye, idx + 1)
+            Param (idx, Sign.inter_sort thy (S', S))) tye, idx + 1)
       | meet (PType (a, Ts), S) (tye_idx as (tye, _)) =
-          meets (Ts, Type.arity_sorts pp tsig a S
+          meets (Ts, arity_sorts a S
             handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
       | meet (PTFree (x, S'), S) (tye_idx as (tye, _)) =
-          if Type.subsort tsig (S', S) then tye_idx
+          if Sign.subsort thy (S', S) then tye_idx
           else raise NO_UNIFIER (not_of_sort x S' S, tye)
       | meet (PTVar (xi, S'), S) (tye_idx as (tye, _)) =
-          if Type.subsort tsig (S', S) then tye_idx
+          if Sign.subsort thy (S', S) then tye_idx
           else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
     and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
           meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
@@ -298,7 +299,7 @@
     (* unification *)
 
     fun show_tycon (a, Ts) =
-      quote (Pretty.string_of_typ pp (Type (a, replicate (length Ts) dummyT)));
+      quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
 
     fun unif (T1, T2) (tye_idx as (tye, idx)) =
       (case (deref tye T1, deref tye T2) of
@@ -319,13 +320,16 @@
 
 (* infer *)
 
-fun infer pp tsig =
+fun infer ctxt =
   let
+    val pp = Syntax.pp ctxt;
+
+
     (* errors *)
 
     fun prep_output tye bs ts Ts =
       let
-        val (Ts_bTs', ts') = typs_terms_of tye Name.context ~1 (Ts @ map snd bs, ts);
+        val (Ts_bTs', ts') = typs_terms_of ctxt tye (Ts @ map snd bs, ts);
         val (Ts', Ts'') = chop (length Ts) Ts_bTs';
         fun prep t =
           let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
@@ -355,8 +359,6 @@
 
     (* main *)
 
-    val unif = unify pp tsig;
-
     fun inf _ (PConst (_, T)) tye_idx = (T, tye_idx)
       | inf _ (PFree (_, T)) tye_idx = (T, tye_idx)
       | inf _ (PVar (_, T)) tye_idx = (T, tye_idx)
@@ -371,13 +373,13 @@
             val (U, (tye, idx)) = inf bs u tye_idx';
             val V = Param (idx, []);
             val U_to_V = PType ("fun", [U, V]);
-            val tye_idx'' = unif (U_to_V, T) (tye, idx + 1)
+            val tye_idx'' = unify ctxt pp (U_to_V, T) (tye, idx + 1)
               handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U;
           in (V, tye_idx'') end
       | inf bs (Constraint (t, U)) tye_idx =
           let val (T, tye_idx') = inf bs t tye_idx in
             (T,
-             unif (T, U) tye_idx'
+             unify ctxt pp (T, U) tye_idx'
                handle NO_UNIFIER (msg, tye) => err_constraint msg tye bs t T U)
           end;
 
@@ -386,7 +388,7 @@
 
 (* infer_types *)
 
-fun infer_types pp tsig check_typs const_type var_type used maxidx raw_ts =
+fun infer_types ctxt const_type var_type raw_ts =
   let
     (*constrain vars*)
     val get_type = the_default dummyT o var_type;
@@ -396,13 +398,13 @@
         | t => t);
 
     (*convert to preterms*)
-    val ts = burrow_types check_typs raw_ts;
+    val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
     val (ts', (_, _, idx)) =
       fold_map (preterm_of const_type o constrain_vars) ts
       (Vartab.empty, Vartab.empty, 0);
 
     (*do type inference*)
-    val (tye, _) = fold (snd oo infer pp tsig) ts' (Inttab.empty, idx);
-  in #2 (typs_terms_of tye used maxidx ([], ts')) end;
+    val (tye, _) = fold (snd oo infer ctxt) ts' (Inttab.empty, idx);
+  in #2 (typs_terms_of ctxt tye ([], ts')) end;
 
 end;