src/Pure/type_infer.ML
changeset 14788 9776f0c747c8
parent 14695 9c78044b99c3
child 14828 15d12761ba54
--- a/src/Pure/type_infer.ML	Fri May 21 21:24:22 2004 +0200
+++ b/src/Pure/type_infer.ML	Fri May 21 21:25:34 2004 +0200
@@ -10,12 +10,18 @@
   val anyT: sort -> typ
   val logicT: typ
   val polymorphicT: typ -> typ
-  val infer_types: (term -> Pretty.T) -> (typ -> Pretty.T)
-    -> (string -> typ option) -> Sorts.classrel -> Sorts.arities
-    -> string list -> bool -> (indexname -> bool) -> term list -> typ list
-    -> term list * typ list * (indexname * typ) list
   val appl_error: (term -> Pretty.T) -> (typ -> Pretty.T)
     -> string -> term -> typ -> term -> typ -> string list
+  val constrain: term -> typ -> term
+  val param: int -> string * sort -> typ
+  val paramify_dummies: int * typ -> int * typ
+  val get_sort: Type.tsig -> (indexname -> sort option) -> (sort -> sort)
+    -> (indexname * sort) list -> indexname -> sort
+  val infer_types: (term -> Pretty.T) -> (typ -> Pretty.T)
+    -> Type.tsig -> (string -> typ option) -> (indexname -> typ option)
+    -> (indexname -> sort option) -> (string -> string) -> (typ -> typ)
+    -> (sort -> sort) -> string list -> bool -> typ list -> term list
+    -> term list * (indexname * typ) list
 end;
 
 structure TypeInfer: TYPE_INFER =
@@ -247,7 +253,7 @@
 exception NO_UNIFIER of string;
 
 
-fun unify classrel arities =
+fun unify classes arities =
   let
 
     (* adjust sorts of parameters *)
@@ -260,17 +266,17 @@
 
     fun meet (_, []) = ()
       | meet (Link (r as (ref (Param S'))), S) =
-          if Sorts.sort_le classrel (S', S) then ()
-          else r := mk_param (Sorts.inter_sort classrel (S', S))
+          if Sorts.sort_le classes (S', S) then ()
+          else r := mk_param (Sorts.inter_sort classes (S', S))
       | meet (Link (ref T), S) = meet (T, S)
       | meet (PType (a, Ts), S) =
-          seq2 meet (Ts, Sorts.mg_domain (classrel, arities) a S
+          seq2 meet (Ts, Sorts.mg_domain (classes, arities) a S
             handle Sorts.DOMAIN ac => raise NO_UNIFIER (no_domain ac))
       | meet (PTFree (x, S'), S) =
-          if Sorts.sort_le classrel (S', S) then ()
+          if Sorts.sort_le classes (S', S) then ()
           else raise NO_UNIFIER (not_in_sort x S' S)
       | meet (PTVar (xi, S'), S) =
-          if Sorts.sort_le classrel (S', S) then ()
+          if Sorts.sort_le classes (S', S) then ()
           else raise NO_UNIFIER (not_in_sort (Syntax.string_of_vname xi) S' S)
       | meet (Param _, _) = sys_error "meet";
 
@@ -320,7 +326,7 @@
 
 (* infer *)                                     (*DESTRUCTIVE*)
 
-fun infer prt prT classrel arities =
+fun infer prt prT classes arities =
   let
     (* errors *)
 
@@ -365,7 +371,7 @@
 
     (* main *)
 
-    val unif = unify classrel arities;
+    val unif = unify classes arities;
 
     fun inf _ (PConst (_, T)) = T
       | inf _ (PFree (_, T)) = T
@@ -389,9 +395,9 @@
   in inf [] end;
 
 
-(* infer_types *)
+(* basic_infer_types *)
 
-fun infer_types prt prT const_type classrel arities used freeze is_param ts Ts =
+fun basic_infer_types prt prT const_type classes arities used freeze is_param ts Ts =
   let
     (*convert to preterms/typs*)
     val (Tps, Ts') = pretyps_of (K true) ([], Ts);
@@ -399,7 +405,7 @@
 
     (*run type inference*)
     val tTs' = ListPair.map Constraint (ts', Ts');
-    val _ = seq (fn t => (infer prt prT classrel arities t; ())) tTs';
+    val _ = seq (fn t => (infer prt prT classes arities t; ())) tTs';
 
     (*collect result unifier*)
     fun ch_var (xi, Link (r as ref (Param S))) = (r := PTVar (xi, S); None)
@@ -412,9 +418,114 @@
       else (fn (x, S) => PTVar ((x, 0), S));
     val (final_Ts, final_ts) = typs_terms_of used mk_var "" (Ts', ts');
     val final_env = map (apsnd simple_typ_of) env;
-  in
-    (final_ts, final_Ts, final_env)
-  end;
+  in (final_ts, final_Ts, final_env) end;
+
+
+
+(** type inference **)
+
+(* user constraints *)
+
+fun constrain t T =
+  if T = dummyT then t
+  else Const ("_type_constraint_", T) $ t;
+
+
+(* user parameters *)
+
+fun is_param (x, _) = size x > 0 andalso ord x = ord "?";
+fun param i (x, S) = TVar (("?" ^ x, i), S);
+
+fun paramify_dummies (maxidx, TFree ("'_dummy_", S)) =
+      (maxidx + 1, param (maxidx + 1) ("'dummy", S))
+  | paramify_dummies (maxidx, Type (a, Ts)) =
+      let val (maxidx', Ts') = foldl_map paramify_dummies (maxidx, Ts)
+      in (maxidx', Type (a, Ts')) end
+  | paramify_dummies arg = arg;
+
+
+(* decode sort constraints *)
+
+fun get_sort tsig def_sort map_sort raw_env =
+  let
+    fun eq ((xi, S), (xi', S')) =
+      xi = xi' andalso Type.eq_sort tsig (S, S');
+
+    val env = gen_distinct eq (map (apsnd map_sort) raw_env);
+    val _ =
+      (case gen_duplicates eq_fst env of
+        [] => ()
+      | dups => error ("Inconsistent sort constraints for type variable(s) " ^
+          commas (map (quote o Syntax.string_of_vname' o fst) dups)));
+
+    fun get xi =
+      (case (assoc (env, xi), def_sort xi) of
+        (None, None) => Type.defaultS tsig
+      | (None, Some S) => S
+      | (Some S, None) => S
+      | (Some S, Some S') =>
+          if Type.eq_sort tsig (S, S') then S'
+          else error ("Sort constraint inconsistent with default for type variable " ^
+            quote (Syntax.string_of_vname' xi)));
+  in get end;
 
 
+(* decode_types -- transform parse tree into raw term *)
+
+fun decode_types tsig is_const def_type def_sort map_const map_type map_sort tm =
+  let
+    fun get_type xi = if_none (def_type xi) dummyT;
+    fun is_free x = is_some (def_type (x, ~1));
+    val raw_env = Syntax.raw_term_sorts tm;
+    val sort_of = get_sort tsig def_sort map_sort raw_env;
+
+    val certT = Type.cert_typ tsig o map_type;
+    fun decodeT t = certT (Syntax.typ_of_term sort_of map_sort t);
+
+    fun decode (Const ("_constrain", _) $ t $ typ) =
+          constrain (decode t) (decodeT typ)
+      | decode (Const ("_constrainAbs", _) $ (Abs (x, T, t)) $ typ) =
+          if T = dummyT then Abs (x, decodeT typ, decode t)
+          else constrain (Abs (x, certT T, decode t)) (decodeT typ --> dummyT)
+      | decode (Abs (x, T, t)) = Abs (x, certT T, decode t)
+      | decode (t $ u) = decode t $ decode u
+      | decode (Free (x, T)) =
+          let val c = map_const x in
+            if not (is_free x) andalso (is_const c orelse NameSpace.is_qualified c) then
+              Const (c, certT T)
+            else if T = dummyT then Free (x, get_type (x, ~1))
+            else constrain (Free (x, certT T)) (get_type (x, ~1))
+          end
+      | decode (Var (xi, T)) =
+          if T = dummyT then Var (xi, get_type xi)
+          else constrain (Var (xi, certT T)) (get_type xi)
+      | decode (t as Bound _) = t
+      | decode (Const (c, T)) = Const (map_const c, certT T);
+  in decode tm end;
+
+
+(* infer_types *)
+
+(*Given [T1,...,Tn] and [t1,...,tn], ensure that the type of ti
+  unifies with Ti (for i=1,...,n).
+
+  tsig: type signature
+  const_type: name mapping and signature lookup
+  def_type: partial map from indexnames to types (constrains Frees, Vars)
+  def_sort: partial map from indexnames to sorts (constrains TFrees, TVars)
+  used: list of already used type variables
+  freeze: if true then generated parameters are turned into TFrees, else TVars*)
+
+fun infer_types prt prT tsig const_type def_type def_sort
+    map_const map_type map_sort used freeze pat_Ts raw_ts =
+  let
+    val {classes, arities, ...} = Type.rep_tsig tsig;
+    val pat_Ts' = map (Type.cert_typ tsig) pat_Ts;
+    val is_const = is_some o const_type;
+    val raw_ts' =
+      map (decode_types tsig is_const def_type def_sort map_const map_type map_sort) raw_ts;
+    val (ts, Ts, unifier) = basic_infer_types prt prT const_type
+      classes arities used freeze is_param raw_ts' pat_Ts';
+  in (ts, unifier) end;
+
 end;