src/Pure/type_infer.ML
author wenzelm
Sat Apr 16 15:47:52 2011 +0200 (2011-04-16)
changeset 42360 da8817d01e7c
parent 42284 326f57825e1a
child 42383 0ae4ad40d7b5
permissions -rw-r--r--
modernized structure Proof_Context;
     1 (*  Title:      Pure/type_infer.ML
     2     Author:     Stefan Berghofer and Markus Wenzel, TU Muenchen
     3 
     4 Representation of type-inference problems.  Simple type inference.
     5 *)
     6 
     7 signature TYPE_INFER =
     8 sig
     9   val is_param: indexname -> bool
    10   val is_paramT: typ -> bool
    11   val param: int -> string * sort -> typ
    12   val mk_param: int -> sort -> typ
    13   val anyT: sort -> typ
    14   val paramify_vars: typ -> typ
    15   val paramify_dummies: typ -> int -> typ * int
    16   val deref: typ Vartab.table -> typ -> typ
    17   val finish: Proof.context -> typ Vartab.table -> typ list * term list -> typ list * term list
    18   val fixate: Proof.context -> term list -> term list
    19   val prepare: Proof.context -> (string -> typ option) -> (string * int -> typ option) ->
    20     term list -> int * term list
    21   val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
    22     term list -> term list
    23 end;
    24 
structure Type_Infer: TYPE_INFER =
struct

(** type parameters and constraints **)

(* type inference parameters -- may get instantiated *)

(*An inference parameter is a TVar whose name starts with "?".  Only
  parameters are ever assigned in the substitution built by unification
  (see meet and unif below); all other TVars behave like fixed type
  variables.*)
fun is_param (x, _: int) = String.isPrefix "?" x;

(*true exactly for TVars whose indexname is an inference parameter*)
fun is_paramT (TVar (xi, _)) = is_param xi
  | is_paramT _ = false;

(*parameter at index i, named by prefixing "?" to the given name x*)
fun param i (x, S) = TVar (("?" ^ x, i), S);

(*anonymous parameter: all of these share the name "?'a", so distinct
  parameters must get distinct indices i -- callers thread an idx counter*)
fun mk_param i S = TVar (("?'a", i), S);


(* pre-stage parameters *)

(*placeholder type of sort S; prepare_typ and paramify_dummies replace it
  by a fresh parameter of sort S*)
fun anyT S = TFree ("'_dummy_", S);

(*turn every TVar ((x, i), S) into the parameter ("?" ^ x, i) of the same
  sort; all other atomic types are left alone*)
val paramify_vars =
  Same.commit
    (Term_Subst.map_atypsT_same
      (fn TVar ((x, i), S) => (param i (x, S)) | _ => raise Same.SAME));

(*paramify_dummies T maxidx: replace every anyT placeholder and every
  "dummy" type constructor in T by a fresh parameter named "?'dummy".  The
  parameter gets sort S for anyT and the empty sort for "dummy".  Each new
  parameter takes the next index above maxidx.  The result is the new type
  together with the updated maxidx.*)
val paramify_dummies =
  let
    fun dummy S maxidx = (param (maxidx + 1) ("'dummy", S), maxidx + 1);

    fun paramify (TFree ("'_dummy_", S)) maxidx = dummy S maxidx
      | paramify (Type ("dummy", _)) maxidx = dummy [] maxidx
      | paramify (Type (a, Ts)) maxidx =
          let val (Ts', maxidx') = fold_map paramify Ts maxidx
          in (Type (a, Ts'), maxidx') end
      | paramify T maxidx = (T, maxidx);
  in paramify end;



(** prepare types/terms: create inference parameters **)

(* prepare_typ *)

(*prepare_typ typ (params, idx) = (typ', (params', idx'))
  params maps the indexnames of "?"-named TVars to the anonymous parameter
  that replaces them.  idx is the next unused parameter index.*)
fun prepare_typ typ params_idx =
  let
    (*register a fresh parameter for each "?"-named TVar of typ that is not
      yet in the table; its sort is the one at the first occurrence visited*)
    val (params', idx) = fold_atyps
      (fn TVar (xi as (x, _), S) =>
          (fn ps_idx as (ps, idx) =>
            if is_param xi andalso not (Vartab.defined ps xi)
            then (Vartab.update (xi, mk_param idx S) ps, idx + 1) else ps_idx)
        | _ => I) typ params_idx;

    (*Rebuild typ.  TVars with a table entry become their parameter, so the
      TVar's own sort annotation is dropped.  Each dummyT and each anyT
      placeholder becomes a brand-new parameter.  TVars without an entry and
      ordinary TFrees stay as they are.*)
    fun prepare (T as Type (a, Ts)) idx =
          if T = dummyT then (mk_param idx [], idx + 1)
          else
            let val (Ts', idx') = fold_map prepare Ts idx
            in (Type (a, Ts'), idx') end
      | prepare (T as TVar (xi, _)) idx =
          (case Vartab.lookup params' xi of
            NONE => T
          | SOME p => p, idx)
      | prepare (TFree ("'_dummy_", S)) idx = (mk_param idx S, idx + 1)
      | prepare (T as TFree _) idx = (T, idx);

    val (typ', idx') = prepare typ idx;
  in (typ', (params', idx')) end;


(* prepare_term *)

(*prepare_term ctxt const_type tm (vparams, params, idx) =
    (tm', (vparams', params', idx'))
  Turns tm into an inference problem:
    - every Var xi and Free x gets one parameter in vparams, shared by all
      its occurrences and by later terms that reuse the table.  Frees are
      keyed as (x, ~1);
    - type annotations other than dummyT become explicit type constraints;
    - constants get a fresh instance of their declared type from const_type;
    - a Var typed Type ("_polymorphic_", [T]) gets a fresh instance of T.
  Fails via error on undeclared constants and on malformed
  "_type_constraint_" constants.*)
fun prepare_term ctxt const_type tm (vparams, params, idx) =
  let
    fun add_vparm xi (ps_idx as (ps, idx)) =
      if not (Vartab.defined ps xi) then
        (Vartab.update (xi, mk_param idx []) ps, idx + 1)
      else ps_idx;

    (*allocate the shared Var/Free parameters before anything else; Vars
      marked "_polymorphic_" are excluded*)
    val (vparams', idx') = fold_aterms
      (fn Var (_, Type ("_polymorphic_", _)) => I
        | Var (xi, _) => add_vparm xi
        | Free (x, _) => add_vparm (x, ~1)
        | _ => I)
      tm (vparams, idx);
    fun var_param xi = the (Vartab.lookup vparams' xi);

    (*Fresh instance of T.  All TVars of T become parameters, which are then
      prepared against an empty table; that table is discarded, so the new
      parameters are shared with nothing else.*)
    fun polyT_of T idx = apsnd snd (prepare_typ (paramify_vars T) (Vartab.empty, idx));

    (*attach annotation T to t as a type constraint; dummyT means
      "no constraint"*)
    fun constraint T t ps =
      if T = dummyT then (t, ps)
      else
        let val (T', ps') = prepare_typ T ps
        in (Type.constraint T' t, ps') end;

    (*explicit constraint constant: its type must have the form A --> A*)
    fun prepare (Const ("_type_constraint_", T) $ t) ps_idx =
          let
            fun err () =
              error ("Malformed internal type constraint: " ^ Syntax.string_of_typ ctxt T);
            val A = (case T of Type ("fun", [A, B]) => if A = B then A else err () | _ => err ());
            val (A', ps_idx') = prepare_typ A ps_idx;
            val (t', ps_idx'') = prepare t ps_idx';
          in (Const ("_type_constraint_", A' --> A') $ t', ps_idx'') end
      | prepare (Const (c, T)) (ps, idx) =
          (case const_type c of
            SOME U =>
              let val (U', idx') = polyT_of U idx
              in constraint T (Const (c, U')) (ps, idx') end
          | NONE => error ("Undeclared constant: " ^ quote c))
      (*"_polymorphic_" marker: fresh instance of T instead of a shared
        parameter*)
      | prepare (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) =
          let val (T', idx') = polyT_of T idx
          in (Var (xi, T'), (ps, idx')) end
      | prepare (Var (xi, T)) ps_idx = constraint T (Var (xi, var_param xi)) ps_idx
      | prepare (Free (x, T)) ps_idx = constraint T (Free (x, var_param (x, ~1))) ps_idx
      | prepare (Bound i) ps_idx = (Bound i, ps_idx)
      | prepare (Abs (x, T, t)) ps_idx =
          let
            val (T', ps_idx') = prepare_typ T ps_idx;
            val (t', ps_idx'') = prepare t ps_idx';
          in (Abs (x, T', t'), ps_idx'') end
      | prepare (t $ u) ps_idx =
          let
            val (t', ps_idx') = prepare t ps_idx;
            val (u', ps_idx'') = prepare u ps_idx';
          in (t' $ u', ps_idx'') end;

    val (tm', (params', idx'')) = prepare tm (params, idx');
  in (tm', (vparams', params', idx'')) end;



(** results **)

(* dereferenced views *)

(*Follow the assignments in tye, starting from a TVar, until reaching an
  unassigned TVar or a non-TVar type.  Only the outermost type is
  dereferenced, not its arguments.*)
fun deref tye (T as TVar (xi, _)) =
      (case Vartab.lookup tye xi of
        NONE => T
      | SOME U => deref tye U)
  | deref tye T = T;

(*add the indexnames of the unassigned parameters in T (under tye) to the
  list, without duplicates*)
fun add_parms tye T =
  (case deref tye T of
    Type (_, Ts) => fold (add_parms tye) Ts
  | TVar (xi, _) => if is_param xi then insert (op =) xi else I
  | _ => I);

(*declare the names of TFrees and of non-parameter TVars in T (under tye) as
  used, so that the names invented in finish avoid them*)
fun add_names tye T =
  (case deref tye T of
    Type (_, Ts) => fold (add_names tye) Ts
  | TFree (x, _) => Name.declare x
  | TVar ((x, i), _) => if is_param (x, i) then I else Name.declare x);


(* finish -- standardize remaining parameters *)

(*finish ctxt tye (Ts, ts) does three things:
    - applies the substitution tye throughout Ts and ts;
    - renames every parameter that is still unassigned to a fresh name
      generated from "?" ^ Name.aT, at index 0.  The fresh names avoid the
      names of ctxt and those already used in Ts and ts;
    - strips type constraints from the terms.*)
fun finish ctxt tye (Ts, ts) =
  let
    val used =
      (fold o fold_types) (add_names tye) ts (fold (add_names tye) Ts (Variable.names_of ctxt));
    (*rev: presumably to list parameters in order of first occurrence*)
    val parms = rev ((fold o fold_types) (add_parms tye) ts (fold (add_parms tye) Ts []));
    val names = Name.invents used ("?" ^ Name.aT) (length parms);
    val tab = Vartab.make (parms ~~ names);

    fun finish_typ T =
      (case deref tye T of
        Type (a, Ts) => Type (a, map finish_typ Ts)
      | U as TFree _ => U
      | U as TVar (xi, S) =>
          (case Vartab.lookup tab xi of
            NONE => U
          | SOME a => TVar ((a, 0), S)));
  in (map finish_typ Ts, map (Type.strip_constraints o Term.map_types finish_typ) ts) end;


(* fixate -- introduce fresh type variables *)

(*Replace every parameter TVar in ts by a TFree of the same sort.  The TFree
  gets a fresh name generated from Name.aT, avoiding the names of ctxt and
  those already used in ts.  Non-parameter TVars are kept.*)
fun fixate ctxt ts =
  let
    fun subst_param (xi, S) (inst, used) =
      if is_param xi then
        let
          (*exactly one name is requested*)
          val [a] = Name.invents used Name.aT 1;
          val used' = Name.declare a used;
        in (((xi, S), TFree (a, S)) :: inst, used') end
      else (inst, used);
    val used = (fold o fold_types) Term.declare_typ_names ts (Variable.names_of ctxt);
    val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], used);
  in (map o map_types) (Term_Subst.instantiateT inst) ts end;



(** order-sorted unification of types **)

(*unification failure: a message, which may be empty, and the substitution
  at the point of failure*)
exception NO_UNIFIER of string * typ Vartab.table;

(*unify ctxt pp (T1, T2) (tye, idx): extend the substitution tye so that
  T1 and T2 become equal while respecting sorts.  idx is the next unused
  parameter index; it is advanced when a sort intersection introduces a new
  parameter.  Raises NO_UNIFIER on failure.*)
fun unify ctxt pp =
  let
    val thy = Proof_Context.theory_of ctxt;
    val arity_sorts = Type.arity_sorts pp (Sign.tsig_of thy);


    (* adjust sorts of parameters *)

    fun not_of_sort x S' S =
      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
        Syntax.string_of_sort ctxt S;

    (*meet (T, S): make T belong to sort S.  T is expected to be
      dereferenced already; meets and unif ensure this.
        - The empty sort imposes nothing.
        - A type constructor passes the requirement on to its arguments,
          using the sorts computed by arity_sorts.
        - TFrees and non-parameter TVars must already have a subsort of S.
        - A parameter whose sort is too weak is assigned a fresh parameter
          of the intersected sort.*)
    fun meet (_, []) tye_idx = tye_idx
      | meet (Type (a, Ts), S) (tye_idx as (tye, _)) =
          meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
      | meet (TFree (x, S'), S) (tye_idx as (tye, _)) =
          if Sign.subsort thy (S', S) then tye_idx
          else raise NO_UNIFIER (not_of_sort x S' S, tye)
      | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
          if Sign.subsort thy (S', S) then tye_idx
          else if is_param xi then
            (Vartab.update_new (xi, mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
          else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
    (*pointwise meet of types with sorts; stops when either list runs out*)
    and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
          meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
      | meets _ tye_idx = tye_idx;


    (* occurs check and assignment *)

    (*fail if parameter xi occurs in the given type, looking through the
      assignments in tye*)
    fun occurs_check tye xi (TVar (xi', S)) =
          if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
          else
            (case Vartab.lookup tye xi' of
              NONE => ()
            | SOME T => occurs_check tye xi T)
      | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts
      | occurs_check _ _ _ = ();

    (*assign xi T S env: bind parameter xi, of sort S, to the dereferenced
      type T after enforcing sort S on T via meet.  Binding xi to itself is a
      no-op.  Any other unassigned TVar cannot contain xi, so that case
      skips the occurs check.*)
    fun assign xi (T as TVar (xi', _)) S env =
          if xi = xi' then env
          else env |> meet (T, S) |>> Vartab.update_new (xi, T)
      | assign xi T S (env as (tye, _)) =
          (occurs_check tye xi T; env |> meet (T, S) |>> Vartab.update_new (xi, T));


    (* unification *)

    fun show_tycon (a, Ts) =
      quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));

    (*Dereference both sides first.  A parameter on either side gets
      assigned; the left side is tried first.  Equal type constructors are
      unified argument by argument.  Anything else, such as TFrees or
      non-parameter TVars, must be identical.*)
    fun unif (T1, T2) (env as (tye, _)) =
      (case pairself (`is_paramT o deref tye) (T1, T2) of
        ((true, TVar (xi, S)), (_, T)) => assign xi T S env
      | ((_, T), (true, TVar (xi, S))) => assign xi T S env
      | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
          if a <> b then
            raise NO_UNIFIER
              ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
          else fold unif (Ts ~~ Us) env
      | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye));

  in unif end;



(** simple type inference **)

(* infer *)

(*infer ctxt t (tye, idx) = (T, (tye', idx')): type T of the prepared term
  t, extending the substitution tye.  Each application allocates a fresh
  result parameter.  Fails via error on loose bound variables and on
  unification failures at applications.*)
fun infer ctxt =
  let
    val pp = Syntax.pp ctxt;


    (* errors *)

    (*Finish the types Ts and the terms ts under tye, for display.  The types
      of the binders bs are finished along with them.  Loose bounds in the
      terms are replaced by marked free variables with variant names.*)
    fun prep_output tye bs ts Ts =
      let
        val (Ts_bTs', ts') = finish ctxt tye (Ts @ map snd bs, ts);
        val (Ts', Ts'') = chop (length Ts) Ts_bTs';
        fun prep t =
          let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
          in Term.subst_bounds (map Syntax_Trans.mark_boundT xs, t) end;
      in (map prep ts', Ts') end;

    fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i);

    fun unif_failed msg =
      "Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";

    fun err_appl msg tye bs t T u U =
      let val ([t', u'], [T', U']) = prep_output tye bs [t, u] [T, U]
      in error (unif_failed msg ^ Type.appl_error pp t' T' u' U' ^ "\n") end;


    (* main *)

    (*bs lists the (name, type) pairs of the enclosing binders, innermost
      first, so that Bound i has type snd (nth bs i)*)
    fun inf _ (Const (_, T)) tye_idx = (T, tye_idx)
      | inf _ (Free (_, T)) tye_idx = (T, tye_idx)
      | inf _ (Var (_, T)) tye_idx = (T, tye_idx)
      | inf bs (Bound i) tye_idx =
          (snd (nth bs i handle Subscript => err_loose i), tye_idx)
      | inf bs (Abs (x, T, t)) tye_idx =
          let val (U, tye_idx') = inf ((x, T) :: bs) t tye_idx
          in (T --> U, tye_idx') end
      | inf bs (t $ u) tye_idx =
          let
            val (T, tye_idx') = inf bs t tye_idx;
            val (U, (tye, idx)) = inf bs u tye_idx';
            (*fresh result type V at index idx, then require T = U --> V*)
            val V = mk_param idx [];
            val tye_idx'' = unify ctxt pp (U --> V, T) (tye, idx + 1)
              handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U;
          in (V, tye_idx'') end;

  in inf [] end;


(* main interfaces *)

(*prepare ctxt const_type var_type raw_ts runs three steps:
    - check the types occurring in raw_ts with Syntax.check_typs;
    - give each Var/Free the type supplied by var_type, or dummyT if there
      is none, and turn its own annotation into a type constraint;
    - convert the terms into inference problems that share one set of
      Var/Free parameters.
  Returns the next unused parameter index together with the prepared
  terms.*)
fun prepare ctxt const_type var_type raw_ts =
  let
    val get_type = the_default dummyT o var_type;
    val constrain_vars = Term.map_aterms
      (fn Free (x, T) => Type.constraint T (Free (x, get_type (x, ~1)))
        | Var (xi, T) => Type.constraint T (Var (xi, get_type xi))
        | t => t);

    val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
    val (ts', (_, _, idx)) =
      fold_map (prepare_term ctxt const_type o constrain_vars) ts
        (Vartab.empty, Vartab.empty, 0);
  in (idx, ts') end;

(*Prepare the terms, then run inference on each term in turn, sharing one
  substitution.  Only that substitution is kept; the individual inferred
  types are discarded.  Finally apply finish to the terms.*)
fun infer_types ctxt const_type var_type raw_ts =
  let
    val (idx, ts) = prepare ctxt const_type var_type raw_ts;
    val (tye, _) = fold (snd oo infer ctxt) ts (Vartab.empty, idx);
    val (_, ts') = finish ctxt tye ([], ts);
  in ts' end;

end;