src/Pure/type_infer.ML
changeset 39296 e275d581a218
parent 39295 6e8b0672c6a2
child 40286 b928e3960446
equal deleted inserted replaced
39295:6e8b0672c6a2 39296:e275d581a218
     1 (*  Title:      Pure/type_infer.ML
     1 (*  Title:      Pure/type_infer.ML
     2     Author:     Stefan Berghofer and Markus Wenzel, TU Muenchen
     2     Author:     Stefan Berghofer and Markus Wenzel, TU Muenchen
     3 
     3 
     4 Simple type inference.
     4 Representation of type-inference problems.  Simple type inference.
     5 *)
     5 *)
     6 
     6 
     7 signature TYPE_INFER =
     7 signature TYPE_INFER =
     8 sig
     8 sig
     9   val anyT: sort -> typ
       
    10   val is_param: indexname -> bool
     9   val is_param: indexname -> bool
    11   val is_paramT: typ -> bool
    10   val is_paramT: typ -> bool
    12   val param: int -> string * sort -> typ
    11   val param: int -> string * sort -> typ
       
    12   val anyT: sort -> typ
    13   val paramify_vars: typ -> typ
    13   val paramify_vars: typ -> typ
    14   val paramify_dummies: typ -> int -> typ * int
    14   val paramify_dummies: typ -> int -> typ * int
    15   val fixate_params: Proof.context -> term list -> term list
    15   val deref: typ Vartab.table -> typ -> typ
       
    16   val finish: Proof.context -> typ Vartab.table -> typ list * term list -> typ list * term list
       
    17   val fixate: Proof.context -> term list -> term list
       
    18   val prepare: Proof.context -> (string -> typ option) -> (string * int -> typ option) ->
       
    19     term list -> int * term list
    16   val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
    20   val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
    17     term list -> term list
    21     term list -> term list
    18 end;
    22 end;
    19 
    23 
    20 structure Type_Infer: TYPE_INFER =
    24 structure Type_Infer: TYPE_INFER =
    21 struct
    25 struct
    22 
    26 
    23 
       
    24 (** type parameters and constraints **)
    27 (** type parameters and constraints **)
    25 
       
    26 fun anyT S = TFree ("'_dummy_", S);
       
    27 
       
    28 
    28 
    29 (* type inference parameters -- may get instantiated *)
    29 (* type inference parameters -- may get instantiated *)
    30 
    30 
    31 fun is_param (x, _: int) = String.isPrefix "?" x;
    31 fun is_param (x, _: int) = String.isPrefix "?" x;
    32 
    32 
    34   | is_paramT _ = false;
    34   | is_paramT _ = false;
    35 
    35 
    36 fun param i (x, S) = TVar (("?" ^ x, i), S);
    36 fun param i (x, S) = TVar (("?" ^ x, i), S);
    37 
    37 
    38 fun mk_param i S = TVar (("?'a", i), S);
    38 fun mk_param i S = TVar (("?'a", i), S);
       
    39 
       
    40 
       
    41 (* pre-stage parameters *)
       
    42 
       
    43 fun anyT S = TFree ("'_dummy_", S);
    39 
    44 
    40 val paramify_vars =
    45 val paramify_vars =
    41   Same.commit
    46   Same.commit
    42     (Term_Subst.map_atypsT_same
    47     (Term_Subst.map_atypsT_same
    43       (fn TVar ((x, i), S) => (param i (x, S)) | _ => raise Same.SAME));
    48       (fn TVar ((x, i), S) => (param i (x, S)) | _ => raise Same.SAME));
    51       | paramify (Type (a, Ts)) maxidx =
    56       | paramify (Type (a, Ts)) maxidx =
    52           let val (Ts', maxidx') = fold_map paramify Ts maxidx
    57           let val (Ts', maxidx') = fold_map paramify Ts maxidx
    53           in (Type (a, Ts'), maxidx') end
    58           in (Type (a, Ts'), maxidx') end
    54       | paramify T maxidx = (T, maxidx);
    59       | paramify T maxidx = (T, maxidx);
    55   in paramify end;
    60   in paramify end;
    56 
       
    57 fun fixate_params ctxt ts =
       
    58   let
       
    59     fun subst_param (xi, S) (inst, used) =
       
    60       if is_param xi then
       
    61         let
       
    62           val [a] = Name.invents used Name.aT 1;
       
    63           val used' = Name.declare a used;
       
    64         in (((xi, S), TFree (a, S)) :: inst, used') end
       
    65       else (inst, used);
       
    66     val used = (fold o fold_types) Term.declare_typ_names ts (Variable.names_of ctxt);
       
    67     val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], used);
       
    68   in (map o map_types) (Term_Subst.instantiateT inst) ts end;
       
    69 
    61 
    70 
    62 
    71 
    63 
    72 (** prepare types/terms: create inference parameters **)
    64 (** prepare types/terms: create inference parameters **)
    73 
    65 
   154     val (tm', (params', idx'')) = prepare tm (params, idx');
   146     val (tm', (params', idx'')) = prepare tm (params, idx');
   155   in (tm', (vparams', params', idx'')) end;
   147   in (tm', (vparams', params', idx'')) end;
   156 
   148 
   157 
   149 
   158 
   150 
   159 (** finish types/terms: standardize remaining parameters **)
   151 (** results **)
   160 
   152 
   161 (* dereferenced views *)
   153 (* dereferenced views *)
   162 
   154 
   163 fun deref tye (T as TVar (xi, _)) =
   155 fun deref tye (T as TVar (xi, _)) =
   164       (case Vartab.lookup tye xi of
   156       (case Vartab.lookup tye xi of
   177     Type (_, Ts) => fold (add_names tye) Ts
   169     Type (_, Ts) => fold (add_names tye) Ts
   178   | TFree (x, _) => Name.declare x
   170   | TFree (x, _) => Name.declare x
   179   | TVar ((x, i), _) => if is_param (x, i) then I else Name.declare x);
   171   | TVar ((x, i), _) => if is_param (x, i) then I else Name.declare x);
   180 
   172 
   181 
   173 
   182 (* finish *)
   174 (* finish -- standardize remaining parameters *)
   183 
   175 
   184 fun finish ctxt tye (Ts, ts) =
   176 fun finish ctxt tye (Ts, ts) =
   185   let
   177   let
   186     val used =
   178     val used =
   187       (fold o fold_types) (add_names tye) ts (fold (add_names tye) Ts (Variable.names_of ctxt));
   179       (fold o fold_types) (add_names tye) ts (fold (add_names tye) Ts (Variable.names_of ctxt));
   196       | U as TVar (xi, S) =>
   188       | U as TVar (xi, S) =>
   197           (case Vartab.lookup tab xi of
   189           (case Vartab.lookup tab xi of
   198             NONE => U
   190             NONE => U
   199           | SOME a => TVar ((a, 0), S)));
   191           | SOME a => TVar ((a, 0), S)));
   200   in (map finish_typ Ts, map (Type.strip_constraints o Term.map_types finish_typ) ts) end;
   192   in (map finish_typ Ts, map (Type.strip_constraints o Term.map_types finish_typ) ts) end;
       
   193 
       
   194 
       
   195 (* fixate -- introduce fresh type variables *)
       
   196 
       
   197 fun fixate ctxt ts =
       
   198   let
       
   199     fun subst_param (xi, S) (inst, used) =
       
   200       if is_param xi then
       
   201         let
       
   202           val [a] = Name.invents used Name.aT 1;
       
   203           val used' = Name.declare a used;
       
   204         in (((xi, S), TFree (a, S)) :: inst, used') end
       
   205       else (inst, used);
       
   206     val used = (fold o fold_types) Term.declare_typ_names ts (Variable.names_of ctxt);
       
   207     val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], used);
       
   208   in (map o map_types) (Term_Subst.instantiateT inst) ts end;
   201 
   209 
   202 
   210 
   203 
   211 
   204 (** order-sorted unification of types **)
   212 (** order-sorted unification of types **)
   205 
   213 
   269 
   277 
   270   in unif end;
   278   in unif end;
   271 
   279 
   272 
   280 
   273 
   281 
   274 (** type inference **)
   282 (** simple type inference **)
   275 
   283 
   276 (* infer *)
   284 (* infer *)
   277 
   285 
   278 fun infer ctxt =
   286 fun infer ctxt =
   279   let
   287   let
   321           in (V, tye_idx'') end;
   329           in (V, tye_idx'') end;
   322 
   330 
   323   in inf [] end;
   331   in inf [] end;
   324 
   332 
   325 
   333 
   326 (* infer_types *)
   334 (* main interfaces *)
   327 
   335 
   328 fun infer_types ctxt const_type var_type raw_ts =
   336 fun prepare ctxt const_type var_type raw_ts =
   329   let
   337   let
   330     (*constrain vars*)
       
   331     val get_type = the_default dummyT o var_type;
   338     val get_type = the_default dummyT o var_type;
   332     val constrain_vars = Term.map_aterms
   339     val constrain_vars = Term.map_aterms
   333       (fn Free (x, T) => Type.constraint T (Free (x, get_type (x, ~1)))
   340       (fn Free (x, T) => Type.constraint T (Free (x, get_type (x, ~1)))
   334         | Var (xi, T) => Type.constraint T (Var (xi, get_type xi))
   341         | Var (xi, T) => Type.constraint T (Var (xi, get_type xi))
   335         | t => t);
   342         | t => t);
   336 
   343 
   337     (*convert to preterms*)
       
   338     val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
   344     val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
   339     val (ts', (_, _, idx)) =
   345     val (ts', (_, _, idx)) =
   340       fold_map (prepare_term const_type o constrain_vars) ts
   346       fold_map (prepare_term const_type o constrain_vars) ts
   341       (Vartab.empty, Vartab.empty, 0);
   347       (Vartab.empty, Vartab.empty, 0);
   342 
   348   in (idx, ts') end;
   343     (*do type inference*)
   349 
   344     val (tye, _) = fold (snd oo infer ctxt) ts' (Vartab.empty, idx);
   350 fun infer_types ctxt const_type var_type raw_ts =
   345   in #2 (finish ctxt tye ([], ts')) end;
   351   let
       
   352     val (idx, ts) = prepare ctxt const_type var_type raw_ts;
       
   353     val (tye, _) = fold (snd oo infer ctxt) ts (Vartab.empty, idx);
       
   354     val (_, ts') = finish ctxt tye ([], ts);
       
   355   in ts' end;
   346 
   356 
   347 end;
   357 end;