moved get_sort to sign.ML;
authorwenzelm
Sun Apr 15 14:31:57 2007 +0200 (2007-04-15)
changeset 226987e6412e8d64b
parent 22697 92f8e9a8df78
child 22699 938c1011ac94
moved get_sort to sign.ML;
moved decode_types to Syntax/type_ext.ML;
moved mixfixT to Syntax/mixfix.ML;
proper infer_types, without decode/name lookup;
tuned;
src/Pure/type_infer.ML
     1.1 --- a/src/Pure/type_infer.ML	Sun Apr 15 14:31:56 2007 +0200
     1.2 +++ b/src/Pure/type_infer.ML	Sun Apr 15 14:31:57 2007 +0200
     1.3 @@ -2,62 +2,56 @@
     1.4      ID:         $Id$
     1.5      Author:     Stefan Berghofer and Markus Wenzel, TU Muenchen
     1.6  
     1.7 -Type inference.
     1.8 +Simple type inference.
     1.9  *)
    1.10  
    1.11  signature TYPE_INFER =
    1.12  sig
    1.13    val anyT: sort -> typ
    1.14    val logicT: typ
    1.15 -  val mixfixT: Syntax.mixfix -> typ
    1.16    val polymorphicT: typ -> typ
    1.17 -  val appl_error: Pretty.pp -> string -> term -> typ -> term -> typ -> string list
    1.18    val constrain: term -> typ -> term
    1.19    val param: int -> string * sort -> typ
    1.20    val paramify_dummies: typ -> int -> typ * int
    1.21 -  val get_sort: Type.tsig -> (indexname -> sort option) -> (sort -> sort)
    1.22 -    -> (indexname * sort) list -> indexname -> sort
    1.23 -  val infer_types: Pretty.pp
    1.24 -    -> Type.tsig -> (string -> typ option) -> (indexname -> typ option)
    1.25 -    -> (indexname -> sort option) -> (string -> string) -> (typ -> typ)
    1.26 -    -> (sort -> sort) -> Name.context -> bool -> typ list -> term list
    1.27 -    -> term list * (indexname * typ) list
    1.28 +  val appl_error: Pretty.pp -> string -> term -> typ -> term -> typ -> string list
    1.29 +  val infer_types: Pretty.pp -> Type.tsig ->
    1.30 +    (string -> typ option) -> (indexname -> typ option) ->
    1.31 +    Name.context -> bool -> (term * typ) list -> (term * typ) list * (indexname * typ) list
    1.32  end;
    1.33  
    1.34  structure TypeInfer: TYPE_INFER =
    1.35  struct
    1.36  
    1.37  
    1.38 -(** term encodings **)
    1.39 +(** type parameters and constraints **)
    1.40  
    1.41 -(*
    1.42 -  Flavours of term encodings:
    1.43 +fun anyT S = TFree ("'_dummy_", S);
    1.44 +val logicT = anyT [];
    1.45  
    1.46 -    parse trees (type term):
    1.47 -      A very complicated structure produced by the syntax module's
    1.48 -      read functions.  Encodes types and sorts as terms; may contain
    1.49 -      explicit constraints and partial typing information (where
    1.50 -      dummies serve as wildcards).
    1.51 +(*indicate polymorphic Vars*)
    1.52 +fun polymorphicT T = Type ("_polymorphic_", [T]);
    1.53  
    1.54 -      Parse trees are INTERNAL! Users should never encounter them,
    1.55 -      except in parse / print translation functions.
    1.56 +fun constrain t T =
    1.57 +  if T = dummyT then t
    1.58 +  else Const ("_type_constraint_", T --> T) $ t;
    1.59 +
    1.60 +
    1.61 +(* user parameters *)
    1.62  
    1.63 -    raw terms (type term):
    1.64 -      Provide the user interface to type inferences.  They may contain
    1.65 -      partial type information (dummies are wildcards) or explicit
    1.66 -      type constraints (introduced via constrain: term -> typ ->
    1.67 -      term).
    1.68 +fun is_param_default (x, _) = size x > 0 andalso ord x = ord "?";
    1.69 +fun param i (x, S) = TVar (("?" ^ x, i), S);
    1.70 +
    1.71 +val paramify_dummies =
    1.72 +  let
    1.73 +    fun dummy S maxidx = (param (maxidx + 1) ("'dummy", S), maxidx + 1);
    1.74  
    1.75 -      The type inference function also lets users specify a certain
    1.76 -      subset of TVars to be treated as non-rigid inference parameters.
    1.77 -
    1.78 -    preterms (type preterm):
    1.79 -      The internal representation for type inference.
    1.80 -
    1.81 -    well-typed term (type term):
    1.82 -      Fully typed lambda terms to be accepted by appropriate
    1.83 -      certification functions.
    1.84 -*)
    1.85 +    fun paramify (TFree ("'_dummy_", S)) maxidx = dummy S maxidx
    1.86 +      | paramify (Type ("dummy", _)) maxidx = dummy [] maxidx
    1.87 +      | paramify (Type (a, Ts)) maxidx =
    1.88 +          let val (Ts', maxidx') = fold_map paramify Ts maxidx
    1.89 +          in (Type (a, Ts'), maxidx') end
    1.90 +      | paramify T maxidx = (T, maxidx);
    1.91 +  in paramify end;
    1.92  
    1.93  
    1.94  
    1.95 @@ -103,16 +97,6 @@
    1.96  
    1.97  (* pretyp_of *)
    1.98  
    1.99 -fun anyT S = TFree ("'_dummy_", S);
   1.100 -val logicT = anyT [];
   1.101 -
   1.102 -fun mixfixT (Binder _) = (logicT --> logicT) --> logicT
   1.103 -  | mixfixT mx = replicate (Syntax.mixfix_args mx) logicT ---> logicT;
   1.104 -
   1.105 -
   1.106 -(*indicate polymorphic Vars*)
   1.107 -fun polymorphicT T = Type ("_polymorphic_", [T]);
   1.108 -
   1.109  fun pretyp_of is_param typ params =
   1.110    let
   1.111      val params' = fold_atyps
   1.112 @@ -151,11 +135,10 @@
   1.113        tm vparams;
   1.114      fun var_param xi = the (Vartab.lookup vparams' xi);
   1.115  
   1.116 -
   1.117      val preT_of = pretyp_of is_param;
   1.118      fun polyT_of T = fst (pretyp_of (K true) T Vartab.empty);
   1.119  
   1.120 -    fun constrain T t ps =
   1.121 +    fun constraint T t ps =
   1.122        if T = dummyT then (t, ps)
   1.123        else
   1.124          let val (T', ps') = preT_of T ps
   1.125 @@ -163,13 +146,13 @@
   1.126  
   1.127      fun pre_of (Const (c, T)) ps =
   1.128            (case const_type c of
   1.129 -            SOME U => constrain T (PConst (c, polyT_of U)) ps
   1.130 +            SOME U => constraint T (PConst (c, polyT_of U)) ps
   1.131            | NONE => raise TYPE ("No such constant: " ^ quote c, [], []))
   1.132        | pre_of (Var (xi, Type ("_polymorphic_", [T]))) ps = (PVar (xi, polyT_of T), ps)
   1.133 -      | pre_of (Var (xi, T)) ps = constrain T (PVar (xi, var_param xi)) ps
   1.134 -      | pre_of (Free (x, T)) ps = constrain T (PFree (x, var_param (x, ~1))) ps
   1.135 +      | pre_of (Var (xi, T)) ps = constraint T (PVar (xi, var_param xi)) ps
   1.136 +      | pre_of (Free (x, T)) ps = constraint T (PFree (x, var_param (x, ~1))) ps
   1.137        | pre_of (Const ("_type_constraint_", Type ("fun", [T, _])) $ t) ps =
   1.138 -          uncurry (constrain T) (pre_of t ps)
   1.139 +          pre_of t ps |-> constraint T
   1.140        | pre_of (Bound i) ps = (PBound i, ps)
   1.141        | pre_of (Abs (x, T, t)) ps =
   1.142            let
   1.143 @@ -250,7 +233,6 @@
   1.144  
   1.145  exception NO_UNIFIER of string;
   1.146  
   1.147 -
   1.148  fun unify pp tsig =
   1.149    let
   1.150  
   1.151 @@ -310,6 +292,8 @@
   1.152  
   1.153  (** type inference **)
   1.154  
   1.155 +(* appl_error *)
   1.156 +
   1.157  fun appl_error pp why t T u U =
   1.158   ["Type error in application: " ^ why,
   1.159    "",
   1.160 @@ -392,13 +376,27 @@
   1.161    in inf [] end;
   1.162  
   1.163  
   1.164 -(* basic_infer_types *)
   1.165 +(* infer_types *)
   1.166  
   1.167 -fun basic_infer_types pp tsig const_type used freeze is_param ts Ts =
   1.168 +fun infer_types pp tsig const_type var_type used freeze args =
   1.169    let
   1.170 +    (*certify types*)
   1.171 +    val certT = Type.cert_typ tsig;
   1.172 +    val (raw_ts, raw_Ts) = split_list args;
   1.173 +    val ts = map (Term.map_types certT) raw_ts;
   1.174 +    val Ts = map certT raw_Ts;
   1.175 +
   1.176 +    (*constrain vars*)
   1.177 +    val get_type = the_default dummyT o var_type;
   1.178 +    val constrain_vars = Term.map_aterms
   1.179 +      (fn Free (x, T) => constrain (Free (x, get_type (x, ~1))) T
   1.180 +        | Var (xi, T) => constrain (Var (xi, get_type xi)) T
   1.181 +        | t => t);
   1.182 +
   1.183      (*convert to preterms/typs*)
   1.184      val (Ts', Tps) = fold_map (pretyp_of (K true)) Ts Vartab.empty;
   1.185 -    val (ts', (vps, ps)) = fold_map (preterm_of const_type is_param) ts (Vartab.empty, Tps);
   1.186 +    val (ts', (vps, ps)) =
   1.187 +      fold_map (preterm_of const_type is_param_default o constrain_vars) ts (Vartab.empty, Tps);
   1.188  
   1.189      (*run type inference*)
   1.190      val tTs' = ListPair.map Constraint (ts', Ts');
   1.191 @@ -415,118 +413,6 @@
   1.192        else (fn (x, S) => PTVar ((x, 0), S));
   1.193      val (final_Ts, final_ts) = typs_terms_of used mk_var "" (Ts', ts');
   1.194      val final_env = map (apsnd simple_typ_of) env;
   1.195 -  in (final_ts, final_Ts, final_env) end;
   1.196 -
   1.197 -
   1.198 -
   1.199 -(** type inference **)
   1.200 -
   1.201 -(* user constraints *)
   1.202 -
   1.203 -fun constrain t T =
   1.204 -  if T = dummyT then t
   1.205 -  else Const ("_type_constraint_", T --> T) $ t;
   1.206 -
   1.207 -
   1.208 -(* user parameters *)
   1.209 -
   1.210 -fun is_param (x, _) = size x > 0 andalso ord x = ord "?";
   1.211 -fun param i (x, S) = TVar (("?" ^ x, i), S);
   1.212 -
   1.213 -val paramify_dummies =
   1.214 -  let
   1.215 -    fun dummy S maxidx = (param (maxidx + 1) ("'dummy", S), maxidx + 1);
   1.216 -
   1.217 -    fun paramify (TFree ("'_dummy_", S)) maxidx = dummy S maxidx
   1.218 -      | paramify (Type ("dummy", _)) maxidx = dummy [] maxidx
   1.219 -      | paramify (Type (a, Ts)) maxidx =
   1.220 -          let val (Ts', maxidx') = fold_map paramify Ts maxidx
   1.221 -          in (Type (a, Ts'), maxidx') end
   1.222 -      | paramify T maxidx = (T, maxidx);
   1.223 -  in paramify end;
   1.224 -
   1.225 -
   1.226 -(* get sort constraints *)
   1.227 -
   1.228 -fun get_sort tsig def_sort map_sort raw_env =
   1.229 -  let
   1.230 -    fun eq ((xi, S), (xi', S')) =
   1.231 -      Term.eq_ix (xi, xi') andalso Type.eq_sort tsig (S, S');
   1.232 -
   1.233 -    val env = distinct eq (map (apsnd map_sort) raw_env);
   1.234 -    val _ = (case duplicates (eq_fst (op =)) env of [] => ()
   1.235 -      | dups => error ("Inconsistent sort constraints for type variable(s) "
   1.236 -          ^ commas_quote (map (Term.string_of_vname' o fst) dups)));
   1.237 -
   1.238 -    fun get xi =
   1.239 -      (case (AList.lookup (op =) env xi, def_sort xi) of
   1.240 -        (NONE, NONE) => Type.defaultS tsig
   1.241 -      | (NONE, SOME S) => S
   1.242 -      | (SOME S, NONE) => S
   1.243 -      | (SOME S, SOME S') =>
   1.244 -          if Type.eq_sort tsig (S, S') then S'
   1.245 -          else error ("Sort constraint inconsistent with default for type variable " ^
   1.246 -            quote (Term.string_of_vname' xi)));
   1.247 -  in get end;
   1.248 -
   1.249 -
   1.250 -(* decode_types -- transform parse tree into raw term *)
   1.251 -
   1.252 -fun decode_types tsig is_const def_type def_sort map_const map_type map_sort tm =
   1.253 -  let
   1.254 -    fun get_type xi = the_default dummyT (def_type xi);
   1.255 -    fun is_free x = is_some (def_type (x, ~1));
   1.256 -    val raw_env = Syntax.raw_term_sorts tm;
   1.257 -    val sort_of = get_sort tsig def_sort map_sort raw_env;
   1.258 -
   1.259 -    val certT = Type.cert_typ tsig o map_type;
   1.260 -    fun decodeT t = certT (Syntax.typ_of_term sort_of map_sort t);
   1.261 -
   1.262 -    fun decode (Const ("_constrain", _) $ t $ typ) =
   1.263 -          constrain (decode t) (decodeT typ)
   1.264 -      | decode (Const ("_constrainAbs", _) $ (Abs (x, T, t)) $ typ) =
   1.265 -          if T = dummyT then Abs (x, decodeT typ, decode t)
   1.266 -          else constrain (Abs (x, certT T, decode t)) (decodeT typ --> dummyT)
   1.267 -      | decode (Abs (x, T, t)) = Abs (x, certT T, decode t)
   1.268 -      | decode (t $ u) = decode t $ decode u
   1.269 -      | decode (Const (x, T)) =
   1.270 -          let val c = (case try (unprefix Syntax.constN) x of SOME c => c | NONE => map_const x)
   1.271 -          in Const (c, certT T) end
   1.272 -      | decode (Free (x, T)) =
   1.273 -          let val c = map_const x in
   1.274 -            if not (is_free x) andalso (is_const c orelse NameSpace.is_qualified c) then
   1.275 -              Const (c, certT T)
   1.276 -            else if T = dummyT then Free (x, get_type (x, ~1))
   1.277 -            else constrain (Free (x, certT T)) (get_type (x, ~1))
   1.278 -          end
   1.279 -      | decode (Var (xi, T)) =
   1.280 -          if T = dummyT then Var (xi, get_type xi)
   1.281 -          else constrain (Var (xi, certT T)) (get_type xi)
   1.282 -      | decode (t as Bound _) = t;
   1.283 -  in decode tm end;
   1.284 -
   1.285 -
   1.286 -(* infer_types *)
   1.287 -
   1.288 -(*Given [T1,...,Tn] and [t1,...,tn], ensure that the type of ti
   1.289 -  unifies with Ti (for i=1,...,n).
   1.290 -
   1.291 -  tsig: type signature
   1.292 -  const_type: name mapping and signature lookup
   1.293 -  def_type: partial map from indexnames to types (constrains Frees and Vars)
   1.294 -  def_sort: partial map from indexnames to sorts (constrains TFrees and TVars)
   1.295 -  used: context of already used type variables
   1.296 -  freeze: if true then generated parameters are turned into TFrees, else TVars*)
   1.297 -
   1.298 -fun infer_types pp tsig const_type def_type def_sort
   1.299 -    map_const map_type map_sort used freeze pat_Ts raw_ts =
   1.300 -  let
   1.301 -    val pat_Ts' = map (Type.cert_typ tsig) pat_Ts;
   1.302 -    val is_const = is_some o const_type;
   1.303 -    val raw_ts' =
   1.304 -      map (decode_types tsig is_const def_type def_sort map_const map_type map_sort) raw_ts;
   1.305 -    val (ts, Ts, unifier) =
   1.306 -      basic_infer_types pp tsig const_type used freeze is_param raw_ts' pat_Ts';
   1.307 -  in (ts, unifier) end;
   1.308 +  in (final_ts ~~ final_Ts, final_env) end;
   1.309  
   1.310  end;