src/Pure/type_infer.ML
author wenzelm
Sun Sep 12 20:47:47 2010 +0200 (2010-09-12)
changeset 39290 44e4d8dfd6bf
parent 39289 92b50c8bb67b
child 39291 4b632bb847a8
permissions -rw-r--r--
load type_infer.ML later -- proper context for Type_Infer.infer_types;
renamed Type_Infer.polymorphicT to Type.mark_polymorphic;
     1 (*  Title:      Pure/type_infer.ML
     2     Author:     Stefan Berghofer and Markus Wenzel, TU Muenchen
     3 
     4 Simple type inference.
     5 *)
     6 
     (*Exported interface of the type inference module: helpers for inference
       parameters (schematic type variables whose base name starts with "?"),
       replacement of dummy types, and the infer_types entry point.*)
     7 signature TYPE_INFER =
     8 sig
     9   val anyT: sort -> typ
    10   val is_param: indexname -> bool
    11   val param: int -> string * sort -> typ
    12   val paramify_vars: typ -> typ
    13   val paramify_dummies: typ -> int -> typ * int
    14   val fixate_params: Name.context -> term list -> term list
    15   val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
    16     term list -> term list
    17 end;
    18 
    19 structure Type_Infer: TYPE_INFER =
    20 struct
    21 
    22 
    23 (** type parameters and constraints **)
    24 
    (*anyT S: the placeholder type TFree ("'_dummy_", S) of sort S;
      paramify_dummies and pretyp_of in this file replace it by a fresh
      inference parameter.*)
    25 fun anyT S = TFree ("'_dummy_", S);
    26 
    27 
    28 (* type inference parameters -- may get instantiated *)
    29 
    (*is_param (x, i): true iff the base name x starts with "?", which is the
      marker that param puts on inference parameters; the index is ignored.*)
    30 fun is_param (x, _: int) = String.isPrefix "?" x;
    (*param i (x, S): schematic type variable with base name "?" ^ x, index i
      and sort S; the "?" prefix is what is_param tests for.*)
    31 fun param i (x, S) = TVar (("?" ^ x, i), S);
    32 
    (*paramify_vars T: turn each schematic type variable TVar ((x, i), S) that
      Term_Subst.map_atypsT_same visits in T into the inference parameter
      param i (x, S); every other case raises Same.SAME.
      NOTE(review): Same.commit presumably hands back the original type when
      nothing was changed -- confirm against the Same structure.*)
    33 val paramify_vars =
    34   Same.commit
    35     (Term_Subst.map_atypsT_same
    36       (fn TVar ((x, i), S) => (param i (x, S)) | _ => raise Same.SAME));
    37 
    (*paramify_dummies T maxidx: replace every dummy in T by an inference
      parameter and return the new type together with the updated maxidx.
      Dummies are TFree ("'_dummy_", S), whose sort S is kept, and any type
      constructor named "dummy", which gets the empty sort and loses its
      arguments.  Each replacement uses index maxidx + 1 and passes that index
      on, so successive dummies receive increasing indexes.*)
    38 val paramify_dummies =
    39   let
           (*parameter "?'dummy" at the next index, plus that index as new maxidx*)
    40     fun dummy S maxidx = (param (maxidx + 1) ("'dummy", S), maxidx + 1);
    41 
    42     fun paramify (TFree ("'_dummy_", S)) maxidx = dummy S maxidx
    43       | paramify (Type ("dummy", _)) maxidx = dummy [] maxidx
    44       | paramify (Type (a, Ts)) maxidx =
    45           let val (Ts', maxidx') = fold_map paramify Ts maxidx
    46           in (Type (a, Ts'), maxidx') end
    47       | paramify T maxidx = (T, maxidx);
    48   in paramify end;
    49 
    (*fixate_params name_context ts: replace the inference parameters (see
      is_param) among the schematic type variables of ts by free type
      variables of the same sort; other type variables are left alone.
      name_context is first extended by Term.declare_typ_names over all types
      of ts; each parameter then obtains a name from Name.invents (starting
      from Name.aT), which is declared before the next parameter is handled.*)
    50 fun fixate_params name_context ts =
    51   let
    52     fun subst_param (xi, S) (inst, used) =
    53       if is_param xi then
    54         let
                 (*exactly one name is requested, hence the singleton pattern*)
    55           val [a] = Name.invents used Name.aT 1;
    56           val used' = Name.declare a used;
    57         in (((xi, S), TFree (a, S)) :: inst, used') end
    58       else (inst, used);
    59     val name_context' = (fold o fold_types) Term.declare_typ_names ts name_context;
    60     val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], name_context');
    61   in (map o map_types) (Term_Subst.instantiateT inst) ts end;
    62 
    63 
    64 
    65 (** pretyps and preterms **)
    66 
    (*Types as used during inference: type constructor application, free and
      schematic type variables, and Param (i, S) -- an inference parameter
      identified by the integer i, with sort S, which may be bound to another
      pretyp in an Inttab environment (see deref).*)
    67 datatype pretyp =
    68   PType of string * pretyp list |
    69   PTFree of string * sort |
    70   PTVar of indexname * sort |
    71   Param of int * sort;
    72 
    (*Terms as used during inference: constants, free and schematic variables
      carry a pretyp; PAbs carries the name and pretyp of its bound variable;
      PBound is a de-Bruijn index; Constraint (t, T) attaches the type
      constraint T to t.*)
    73 datatype preterm =
    74   PConst of string * pretyp |
    75   PFree of string * pretyp |
    76   PVar of indexname * pretyp |
    77   PBound of int |
    78   PAbs of string * pretyp * preterm |
    79   PAppl of preterm * preterm |
    80   Constraint of preterm * pretyp;
    81 
    82 
    83 (* utils *)
    84 
    (*deref tye T: follow the bindings of Param indexes in the environment tye
      until an unbound Param or some other pretyp is reached, and return it.
      Only the outermost constructor is resolved: arguments of a PType are not
      dereferenced.*)
    85 fun deref tye (T as Param (i, S)) =
    86       (case Inttab.lookup tye i of
    87         NONE => T
    88       | SOME U => deref tye U)
    89   | deref tye T = T;
    90 
    (*fold_pretyps f t x: fold f over the pretyps stored in the preterm t.
      PBound contributes nothing; for PAbs the type of the bound variable is
      visited before the body, for PAppl the function before the argument, and
      for Constraint the term before the constraint type.*)
    91 fun fold_pretyps f (PConst (_, T)) x = f T x
    92   | fold_pretyps f (PFree (_, T)) x = f T x
    93   | fold_pretyps f (PVar (_, T)) x = f T x
    94   | fold_pretyps _ (PBound _) x = x
    95   | fold_pretyps f (PAbs (_, T, t)) x = fold_pretyps f t (f T x)
    96   | fold_pretyps f (PAppl (t, u)) x = fold_pretyps f u (fold_pretyps f t x)
    97   | fold_pretyps f (Constraint (t, T)) x = f T (fold_pretyps f t x);
    98 
    99 
   100 
   101 (** raw typs/terms to pretyps/preterms **)
   102 
   103 (* pretyp_of *)
   104 
   (*pretyp_of typ (params, idx): convert a raw typ into a pretyp.
     params is a Vartab from indexnames of inference parameters to their Param;
     idx is the next unused Param index.  The result is the pretyp together
     with the updated (params, idx).*)
   105 fun pretyp_of typ params_idx =
   106   let
           (*pass 1: allocate a Param for every inference parameter of typ
             (see is_param) that has no table entry yet*)
   107     val (params', idx) = fold_atyps
   108       (fn TVar (xi as (x, _), S) =>
   109           (fn ps_idx as (ps, idx) =>
   110             if is_param xi andalso not (Vartab.defined ps xi)
   111             then (Vartab.update (xi, Param (idx, S)) ps, idx + 1) else ps_idx)
   112         | _ => I) typ params_idx;
   113 
           (*pass 2: structural conversion.  A TVar without table entry stays a
             PTVar; TFree "'_dummy_" and dummyT each yield a fresh Param (the
             latter with empty sort) and advance the index*)
   114     fun pre_of (TVar (v as (xi, _))) idx =
   115           (case Vartab.lookup params' xi of
   116             NONE => PTVar v
   117           | SOME p => p, idx)
   118       | pre_of (TFree ("'_dummy_", S)) idx = (Param (idx, S), idx + 1)
   119       | pre_of (TFree v) idx = (PTFree v, idx)
   120       | pre_of (T as Type (a, Ts)) idx =
   121           if T = dummyT then (Param (idx, []), idx + 1)
   122           else
   123             let val (Ts', idx') = fold_map pre_of Ts idx
   124             in (PType (a, Ts'), idx') end;
   125 
   126     val (ptyp, idx') = pre_of typ idx;
   127   in (ptyp, (params', idx')) end;
   128 
   129 
   130 (* preterm_of *)
   131 
   (*preterm_of const_type tm (vparams, params, idx): convert the raw term tm
     into a preterm.
       const_type: yields the declared type of a constant, NONE if undeclared;
       vparams: Vartab assigning a Param to each term variable
         (Free x is keyed as (x, ~1));
       params: Vartab of type inference parameters, as used by pretyp_of;
       idx: next unused Param index.
     Returns the preterm together with the updated (vparams, params, idx).
     Raises TYPE "Undeclared constant: ..." if const_type yields NONE.*)
   132 fun preterm_of const_type tm (vparams, params, idx) =
   133   let
           (*allocate a Param of empty sort for a variable without table entry*)
   134     fun add_vparm xi (ps_idx as (ps, idx)) =
   135       if not (Vartab.defined ps xi) then
   136         (Vartab.update (xi, Param (idx, [])) ps, idx + 1)
   137       else ps_idx;
   138 
           (*Vars whose type is wrapped in "_polymorphic_" get no table entry;
             they are converted separately in pre_of*)
   139     val (vparams', idx') = fold_aterms
   140       (fn Var (_, Type ("_polymorphic_", _)) => I
   141         | Var (xi, _) => add_vparm xi
   142         | Free (x, _) => add_vparm (x, ~1)
   143         | _ => I)
   144       tm (vparams, idx);
   145     fun var_param xi = the (Vartab.lookup vparams' xi);
   146 
           (*every call starts from an empty parameter table and discards the
             resulting one, so the type variables of T become new Params for
             each occurrence; only the index is passed on*)
   147     fun polyT_of T idx = apsnd snd (pretyp_of (paramify_vars T) (Vartab.empty, idx));
   148 
           (*wrap t into a Constraint, unless T is dummyT*)
   149     fun constraint T t ps =
   150       if T = dummyT then (t, ps)
   151       else
   152         let val (T', ps') = pretyp_of T ps
   153         in (Constraint (t, T'), ps') end;
   154 
   155     fun pre_of (Const (c, T)) (ps, idx) =
   156           (case const_type c of
   157             SOME U =>
   158               let val (pU, idx') = polyT_of U idx
   159               in constraint T (PConst (c, pU)) (ps, idx') end
   160           | NONE => raise TYPE ("Undeclared constant: " ^ quote c, [], []))
   161       | pre_of (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) =
   162           let val (pT, idx') = polyT_of T idx
   163           in (PVar (xi, pT), (ps, idx')) end
   164       | pre_of (Var (xi, T)) ps_idx = constraint T (PVar (xi, var_param xi)) ps_idx
   165       | pre_of (Free (x, T)) ps_idx = constraint T (PFree (x, var_param (x, ~1))) ps_idx
             (*application of the "_type_constraint_" marker: convert t and
               constrain it by the argument type T of the marker*)
   166       | pre_of (Const ("_type_constraint_", Type ("fun", [T, _])) $ t) ps_idx =
   167           pre_of t ps_idx |-> constraint T
   168       | pre_of (Bound i) ps_idx = (PBound i, ps_idx)
   169       | pre_of (Abs (x, T, t)) ps_idx =
   170           let
   171             val (T', ps_idx') = pretyp_of T ps_idx;
   172             val (t', ps_idx'') = pre_of t ps_idx';
   173           in (PAbs (x, T', t'), ps_idx'') end
   174       | pre_of (t $ u) ps_idx =
   175           let
   176             val (t', ps_idx') = pre_of t ps_idx;
   177             val (u', ps_idx'') = pre_of u ps_idx';
   178           in (PAppl (t', u'), ps_idx'') end;
   179 
   180     val (tm', (params', idx'')) = pre_of tm (params, idx');
   181   in (tm', (vparams', params', idx'')) end;
   182 
   183 
   184 
   185 (** pretyps/terms to typs/terms **)
   186 
   187 (* add_parms *)
   188 
   (*add_parmsT tye T: accumulate, via insert (op =), the index of every Param
     that remains in T after dereferencing with tye; PType arguments are
     traversed recursively, other pretyps contribute nothing.*)
   189 fun add_parmsT tye T =
   190   (case deref tye T of
   191     PType (_, Ts) => fold (add_parmsT tye) Ts
   192   | Param (i, _) => insert (op =) i
   193   | _ => I);
   194 
   (*add_parms tye t: add_parmsT tye applied to every pretyp of the preterm t.*)
   195 fun add_parms tye = fold_pretyps (add_parmsT tye);
   196 
   197 
   198 (* add_names *)
   199 
   (*add_namesT tye T: declare, via Name.declare, the base names of all PTFree
     and PTVar found in T after dereferencing with tye; PType arguments are
     traversed recursively, Params contribute nothing.*)
   200 fun add_namesT tye T =
   201   (case deref tye T of
   202     PType (_, Ts) => fold (add_namesT tye) Ts
   203   | PTFree (x, _) => Name.declare x
   204   | PTVar ((x, _), _) => Name.declare x
   205   | Param _ => I);
   206 
   (*add_names tye t: add_namesT tye applied to every pretyp of the preterm t.*)
   207 fun add_names tye = fold_pretyps (add_namesT tye);
   208 
   209 
   210 (* simple_typ/term_of *)
   211 
   (*simple_typ_of tye f T: convert a pretyp back into a typ, dereferencing
     with tye at every level.  A Param (i, S) that remains becomes
     TVar (f i, S), i.e. f maps parameter indexes to indexnames.*)
   212 fun simple_typ_of tye f T =
   213   (case deref tye T of
   214     PType (a, Ts) => Type (a, map (simple_typ_of tye f) Ts)
   215   | PTFree v => TFree v
   216   | PTVar v => TVar v
   217   | Param (i, S) => TVar (f i, S));
   218 
   (*simple_term_of tye f t: convert a preterm back into a term, turning each
     pretyp into a typ via simple_typ_of tye f; Constraint (t, _) is replaced
     by the conversion of t alone.*)
   219 (*convert types, drop constraints*)
   220 fun simple_term_of tye f (PConst (c, T)) = Const (c, simple_typ_of tye f T)
   221   | simple_term_of tye f (PFree (x, T)) = Free (x, simple_typ_of tye f T)
   222   | simple_term_of tye f (PVar (xi, T)) = Var (xi, simple_typ_of tye f T)
   223   | simple_term_of tye f (PBound i) = Bound i
   224   | simple_term_of tye f (PAbs (x, T, t)) =
   225       Abs (x, simple_typ_of tye f T, simple_term_of tye f t)
   226   | simple_term_of tye f (PAppl (t, u)) =
   227       simple_term_of tye f t $ simple_term_of tye f u
   228   | simple_term_of tye f (Constraint (t, _)) = simple_term_of tye f t;
   229 
   230 
   231 (* typs_terms_of *)
   232 
   (*typs_terms_of ctxt tye (Ts, ts): convert the pretyps Ts and preterms ts
     back into ordinary typs and terms under the environment tye.
     Each Param that remains after dereferencing is turned into a type
     variable: its name is invented by Name.invents from "?" ^ Name.aT,
     relative to the names of ctxt plus the type variable names occurring in
     the input, and its index is Variable.maxidx_of ctxt + 1.*)
   233 fun typs_terms_of ctxt tye (Ts, ts) =
   234   let
   235     val used = fold (add_names tye) ts (fold (add_namesT tye) Ts (Variable.names_of ctxt));
   236     val parms = rev (fold (add_parms tye) ts (fold (add_parmsT tye) Ts []));
   237     val names = Name.invents used ("?" ^ Name.aT) (length parms);
   238     val tab = Inttab.make (parms ~~ names);
   239 
   240     val maxidx = Variable.maxidx_of ctxt;
           (*indexname for the Param with index i*)
   241     fun f i = (the (Inttab.lookup tab i), maxidx + 1);
   242   in (map (simple_typ_of tye f) Ts, map (simple_term_of tye f) ts) end;
   243 
   244 
   245 
   246 (** order-sorted unification of types **)
   247 
   (*Raised by unify on failure: an error message (possibly empty) together
     with the type environment at the point of failure.*)
   248 exception NO_UNIFIER of string * pretyp Inttab.table;
   249 
   (*unify ctxt pp: yields a function  (T1, T2) -> (tye, idx) -> (tye', idx')
     that unifies two pretyps.  tye is the environment binding Param indexes
     to pretyps, idx the next unused Param index; both are returned in
     extended form.
     Raises NO_UNIFIER (msg, tye) on a sort mismatch, a failed occurs check,
     a clash of type constructors, or unequal atomic types.*)
   250 fun unify ctxt pp =
   251   let
   252     val thy = ProofContext.theory_of ctxt;
   253     val arity_sorts = Type.arity_sorts pp (Sign.tsig_of thy);
   254 
   255 
   256     (* adjust sorts of parameters *)
   257 
   258     fun not_of_sort x S' S =
   259       "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
   260         Syntax.string_of_sort ctxt S;
   261 
           (*meet (T, S): make T conform to sort S.  The empty sort requires
             nothing.  A Param whose sort is not a subsort of S is bound to a
             fresh Param carrying the intersection of both sorts.  For a type
             constructor the requirement is passed on to its arguments via
             arity_sorts.  Free and schematic type variables are not changed:
             a failed subsort test is an error for them.*)
   262     fun meet (_, []) tye_idx = tye_idx
   263       | meet (Param (i, S'), S) (tye_idx as (tye, idx)) =
   264           if Sign.subsort thy (S', S) then tye_idx
   265           else (Inttab.update_new (i,
   266             Param (idx, Sign.inter_sort thy (S', S))) tye, idx + 1)
   267       | meet (PType (a, Ts), S) (tye_idx as (tye, _)) =
   268           meets (Ts, arity_sorts a S
   269             handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
   270       | meet (PTFree (x, S'), S) (tye_idx as (tye, _)) =
   271           if Sign.subsort thy (S', S) then tye_idx
   272           else raise NO_UNIFIER (not_of_sort x S' S, tye)
   273       | meet (PTVar (xi, S'), S) (tye_idx as (tye, _)) =
   274           if Sign.subsort thy (S', S) then tye_idx
   275           else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
           (*pairwise meet of dereferenced types and sorts; stops as soon as
             either list is exhausted*)
   276     and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
   277           meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
   278       | meets _ tye_idx = tye_idx;
   279 
   280 
   281     (* occurs check and assignment *)
   282 
           (*fails with "Occurs check!" if Param i occurs in the given pretyp,
             looking through the bindings of other Params in tye*)
   283     fun occurs_check tye i (Param (i', S)) =
   284           if i = i' then raise NO_UNIFIER ("Occurs check!", tye)
   285           else
   286             (case Inttab.lookup tye i' of
   287               NONE => ()
   288             | SOME T => occurs_check tye i T)
   289       | occurs_check tye i (PType (_, Ts)) = List.app (occurs_check tye i) Ts
   290       | occurs_check _ _ _ = ();
   291 
           (*assign i T S: bind Param i, which has sort S, to T after adjusting
             T to sort S.  Assigning a Param to itself changes nothing; any T
             that is not a Param is occurs-checked first.*)
   292     fun assign i (T as Param (i', _)) S tye_idx =
   293           if i = i' then tye_idx
   294           else tye_idx |> meet (T, S) |>> Inttab.update_new (i, T)
   295       | assign i T S (tye_idx as (tye, _)) =
   296           (occurs_check tye i T; tye_idx |> meet (T, S) |>> Inttab.update_new (i, T));
   297 
   298 
   299     (* unification *)
   300 
           (*printed form of type constructor a, with dummyT for each argument*)
   301     fun show_tycon (a, Ts) =
   302       quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
   303 
           (*both sides are dereferenced at the top before comparison; the last
             case covers the remaining combinations, which must be equal, and
             fails with an empty message otherwise*)
   304     fun unif (T1, T2) (tye_idx as (tye, idx)) =
   305       (case (deref tye T1, deref tye T2) of
   306         (Param (i, S), T) => assign i T S tye_idx
   307       | (T, Param (i, S)) => assign i T S tye_idx
   308       | (PType (a, Ts), PType (b, Us)) =>
   309           if a <> b then
   310             raise NO_UNIFIER
   311               ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
   312           else fold unif (Ts ~~ Us) tye_idx
   313       | (T, U) => if T = U then tye_idx else raise NO_UNIFIER ("", tye));
   314 
   315   in unif end;
   316 
   317 
   318 
   319 (** type inference **)
   320 
   321 (* infer *)
   322 
   (*infer ctxt: yields a function  t -> (tye, idx) -> (T, (tye', idx'))
     that computes the pretyp T of the preterm t, starting with an empty list
     of binders, while extending the unification state (tye, idx).
     A NO_UNIFIER failure at an application or a constraint is turned into a
     TYPE exception with a message built by Type.appl_error; a PBound index
     without corresponding binder raises TYPE "Loose bound variable: ...".*)
   323 fun infer ctxt =
   324   let
   325     val pp = Syntax.pp ctxt;
   326 
   327 
   328     (* errors *)
   329 
           (*convert preterms ts and pretyps Ts back for error output under
             tye; bs is the list of (name, pretyp) binders, innermost first.
             Bound variables of each term are substituted (Term.subst_bounds)
             by marked variables whose names come from Term.variant_frees
             applied to the binder names.*)
   330     fun prep_output tye bs ts Ts =
   331       let
   332         val (Ts_bTs', ts') = typs_terms_of ctxt tye (Ts @ map snd bs, ts);
   333         val (Ts', Ts'') = chop (length Ts) Ts_bTs';
   334         fun prep t =
   335           let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
   336           in Term.subst_bounds (map Syntax.mark_boundT xs, t) end;
   337       in (map prep ts', Ts') end;
   338 
   339     fun err_loose i =
   340       raise TYPE ("Loose bound variable: B." ^ string_of_int i, [], []);
   341 
   342     fun unif_failed msg =
   343       "Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";
   344 
           (*failure while typing the application of t (of type T) to u (of
             type U)*)
   345     fun err_appl msg tye bs t T u U =
   346       let
   347         val ([t', u'], [T', U']) = prep_output tye bs [t, u] [T, U];
   348         val text = unif_failed msg ^ Type.appl_error pp t' T' u' U' ^ "\n";
   349       in raise TYPE (text, [T', U'], [t', u']) end;
   350 
           (*failure while constraining t (of type T) by U; reported as an
             application of the "_type_constraint_" constant of type U' --> U'*)
   351     fun err_constraint msg tye bs t T U =
   352       let
   353         val ([t'], [T', U']) = prep_output tye bs [t] [T, U];
   354         val text =
   355           unif_failed msg ^
   356             Type.appl_error pp (Const ("_type_constraint_", U' --> U')) (U' --> U') t' T' ^ "\n";
   357       in raise TYPE (text, [T', U'], [t']) end;
   358 
   359 
   360     (* main *)
   361 
           (*inf bs t (tye, idx): bs holds the (name, pretyp) pairs of the
             enclosing abstractions, innermost first*)
   362     fun inf _ (PConst (_, T)) tye_idx = (T, tye_idx)
   363       | inf _ (PFree (_, T)) tye_idx = (T, tye_idx)
   364       | inf _ (PVar (_, T)) tye_idx = (T, tye_idx)
   365       | inf bs (PBound i) tye_idx =
   366           (snd (nth bs i handle Subscript => err_loose i), tye_idx)
   367       | inf bs (PAbs (x, T, t)) tye_idx =
   368           let val (U, tye_idx') = inf ((x, T) :: bs) t tye_idx
   369           in (PType ("fun", [T, U]), tye_idx') end
             (*application: take a fresh Param V as result type and unify
               U => V with the type T of the function part*)
   370       | inf bs (PAppl (t, u)) tye_idx =
   371           let
   372             val (T, tye_idx') = inf bs t tye_idx;
   373             val (U, (tye, idx)) = inf bs u tye_idx';
   374             val V = Param (idx, []);
   375             val U_to_V = PType ("fun", [U, V]);
   376             val tye_idx'' = unify ctxt pp (U_to_V, T) (tye, idx + 1)
   377               handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U;
   378           in (V, tye_idx'') end
             (*constraint: the result is the inferred type T; the constraint U
               only takes effect through the unification state*)
   379       | inf bs (Constraint (t, U)) tye_idx =
   380           let val (T, tye_idx') = inf bs t tye_idx in
   381             (T,
   382              unify ctxt pp (T, U) tye_idx'
   383                handle NO_UNIFIER (msg, tye) => err_constraint msg tye bs t T U)
   384           end;
   385 
   386   in inf [] end;
   387 
   388 
   389 (* infer_types *)
   390 
   (*infer_types ctxt const_type var_type raw_ts: infer the types of a list of
     terms, sharing one unification state across all of them.
       const_type: yields the declared type of a constant; NONE makes
         preterm_of raise TYPE "Undeclared constant: ...";
       var_type: yields the known type of a variable, if any; a Free x is
         looked up as (x, ~1), and a missing entry counts as dummyT.
     Returns the terms converted back by typs_terms_of under the resulting
     type environment.*)
   391 fun infer_types ctxt const_type var_type raw_ts =
   392   let
   393     (*constrain vars*)
   394     val get_type = the_default dummyT o var_type;
   395     val constrain_vars = Term.map_aterms
   396       (fn Free (x, T) => Type.constraint T (Free (x, get_type (x, ~1)))
   397         | Var (xi, T) => Type.constraint T (Var (xi, get_type xi))
   398         | t => t);
   399 
   400     (*convert to preterms*)
           (*the variable table, parameter table and index are threaded through
             all terms, starting from empty tables and index 0*)
   401     val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
   402     val (ts', (_, _, idx)) =
   403       fold_map (preterm_of const_type o constrain_vars) ts
   404       (Vartab.empty, Vartab.empty, 0);
   405 
   406     (*do type inference*)
   407     val (tye, _) = fold (snd oo infer ctxt) ts' (Inttab.empty, idx);
   408   in #2 (typs_terms_of ctxt tye ([], ts')) end;
   409 
   410 end;