src/Pure/type_infer.ML
author wenzelm
Sun Sep 12 22:28:59 2010 +0200 (2010-09-12)
changeset 39292 6f085332c7d3
parent 39291 4b632bb847a8
child 39294 27fae73fe769
permissions -rw-r--r--
Type_Infer.preterm: eliminated separate Constraint;
     1 (*  Title:      Pure/type_infer.ML
     2     Author:     Stefan Berghofer and Markus Wenzel, TU Muenchen
     3 
     4 Simple type inference.
     5 *)
     6 
     7 signature TYPE_INFER =
     8 sig
     9   val anyT: sort -> typ
    10   val is_param: indexname -> bool
    11   val param: int -> string * sort -> typ
    12   val paramify_vars: typ -> typ
    13   val paramify_dummies: typ -> int -> typ * int
    14   val fixate_params: Name.context -> term list -> term list
    15   val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
    16     term list -> term list
    17 end;
    18 
    19 structure Type_Infer: TYPE_INFER =
    20 struct
    21 
    22 
    23 (** type parameters and constraints **)
    24 
    25 fun anyT S = TFree ("'_dummy_", S);
    26 
    27 
    28 (* type inference parameters -- may get instantiated *)
    29 
    30 fun is_param (x, _: int) = String.isPrefix "?" x;
    31 fun param i (x, S) = TVar (("?" ^ x, i), S);
    32 
    33 val paramify_vars =
    34   Same.commit
    35     (Term_Subst.map_atypsT_same
    36       (fn TVar ((x, i), S) => (param i (x, S)) | _ => raise Same.SAME));
    37 
    38 val paramify_dummies =
    39   let
    40     fun dummy S maxidx = (param (maxidx + 1) ("'dummy", S), maxidx + 1);
    41 
    42     fun paramify (TFree ("'_dummy_", S)) maxidx = dummy S maxidx
    43       | paramify (Type ("dummy", _)) maxidx = dummy [] maxidx
    44       | paramify (Type (a, Ts)) maxidx =
    45           let val (Ts', maxidx') = fold_map paramify Ts maxidx
    46           in (Type (a, Ts'), maxidx') end
    47       | paramify T maxidx = (T, maxidx);
    48   in paramify end;
    49 
    50 fun fixate_params name_context ts =
    51   let
    52     fun subst_param (xi, S) (inst, used) =
    53       if is_param xi then
    54         let
    55           val [a] = Name.invents used Name.aT 1;
    56           val used' = Name.declare a used;
    57         in (((xi, S), TFree (a, S)) :: inst, used') end
    58       else (inst, used);
    59     val name_context' = (fold o fold_types) Term.declare_typ_names ts name_context;
    60     val (inst, _) = fold_rev subst_param (fold Term.add_tvars ts []) ([], name_context');
    61   in (map o map_types) (Term_Subst.instantiateT inst) ts end;
    62 
    63 
    64 
    65 (** pretyps and preterms **)
    66 
    67 datatype pretyp =
    68   PType of string * pretyp list |
    69   PTFree of string * sort |
    70   PTVar of indexname * sort |
    71   Param of int * sort;
    72 
    73 datatype preterm =
    74   PConst of string * pretyp |
    75   PFree of string * pretyp |
    76   PVar of indexname * pretyp |
    77   PBound of int |
    78   PAbs of string * pretyp * preterm |
    79   PAppl of preterm * preterm;
    80 
    81 
    82 (* utils *)
    83 
    84 fun deref tye (T as Param (i, S)) =
    85       (case Inttab.lookup tye i of
    86         NONE => T
    87       | SOME U => deref tye U)
    88   | deref tye T = T;
    89 
    90 fun fold_pretyps f (PConst (_, T)) x = f T x
    91   | fold_pretyps f (PFree (_, T)) x = f T x
    92   | fold_pretyps f (PVar (_, T)) x = f T x
    93   | fold_pretyps _ (PBound _) x = x
    94   | fold_pretyps f (PAbs (_, T, t)) x = fold_pretyps f t (f T x)
    95   | fold_pretyps f (PAppl (t, u)) x = fold_pretyps f u (fold_pretyps f t x);
    96 
    97 
    98 
    99 (** raw typs/terms to pretyps/preterms **)
   100 
   101 (* pretyp_of *)
   102 
   103 fun pretyp_of typ params_idx =
   104   let
   105     val (params', idx) = fold_atyps
   106       (fn TVar (xi as (x, _), S) =>
   107           (fn ps_idx as (ps, idx) =>
   108             if is_param xi andalso not (Vartab.defined ps xi)
   109             then (Vartab.update (xi, Param (idx, S)) ps, idx + 1) else ps_idx)
   110         | _ => I) typ params_idx;
   111 
   112     fun pre_of (TVar (v as (xi, _))) idx =
   113           (case Vartab.lookup params' xi of
   114             NONE => PTVar v
   115           | SOME p => p, idx)
   116       | pre_of (TFree ("'_dummy_", S)) idx = (Param (idx, S), idx + 1)
   117       | pre_of (TFree v) idx = (PTFree v, idx)
   118       | pre_of (T as Type (a, Ts)) idx =
   119           if T = dummyT then (Param (idx, []), idx + 1)
   120           else
   121             let val (Ts', idx') = fold_map pre_of Ts idx
   122             in (PType (a, Ts'), idx') end;
   123 
   124     val (ptyp, idx') = pre_of typ idx;
   125   in (ptyp, (params', idx')) end;
   126 
   127 
   128 (* preterm_of *)
   129 
   130 fun preterm_of const_type tm (vparams, params, idx) =
   131   let
   132     fun add_vparm xi (ps_idx as (ps, idx)) =
   133       if not (Vartab.defined ps xi) then
   134         (Vartab.update (xi, Param (idx, [])) ps, idx + 1)
   135       else ps_idx;
   136 
   137     val (vparams', idx') = fold_aterms
   138       (fn Var (_, Type ("_polymorphic_", _)) => I
   139         | Var (xi, _) => add_vparm xi
   140         | Free (x, _) => add_vparm (x, ~1)
   141         | _ => I)
   142       tm (vparams, idx);
   143     fun var_param xi = the (Vartab.lookup vparams' xi);
   144 
   145     fun polyT_of T idx = apsnd snd (pretyp_of (paramify_vars T) (Vartab.empty, idx));
   146 
   147     fun constraint T t ps =
   148       if T = dummyT then (t, ps)
   149       else
   150         let val (T', ps') = pretyp_of T ps
   151         in (PAppl (PConst ("_type_constraint_", PType ("fun", [T', T'])), t), ps') end;
   152 
   153     fun pre_of (Const (c, T)) (ps, idx) =
   154           (case const_type c of
   155             SOME U =>
   156               let val (pU, idx') = polyT_of U idx
   157               in constraint T (PConst (c, pU)) (ps, idx') end
   158           | NONE => error ("Undeclared constant: " ^ quote c))
   159       | pre_of (Const ("_type_constraint_", T) $ t) ps_idx =
   160           let
   161             val (T', ps_idx') = pretyp_of T ps_idx;
   162             val (t', ps_idx'') = pre_of t ps_idx';
   163           in (PAppl (PConst ("_type_constraint_", T'), t'), ps_idx'') end
   164       | pre_of (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) =
   165           let val (pT, idx') = polyT_of T idx
   166           in (PVar (xi, pT), (ps, idx')) end
   167       | pre_of (Var (xi, T)) ps_idx = constraint T (PVar (xi, var_param xi)) ps_idx
   168       | pre_of (Free (x, T)) ps_idx = constraint T (PFree (x, var_param (x, ~1))) ps_idx
   169       | pre_of (Bound i) ps_idx = (PBound i, ps_idx)
   170       | pre_of (Abs (x, T, t)) ps_idx =
   171           let
   172             val (T', ps_idx') = pretyp_of T ps_idx;
   173             val (t', ps_idx'') = pre_of t ps_idx';
   174           in (PAbs (x, T', t'), ps_idx'') end
   175       | pre_of (t $ u) ps_idx =
   176           let
   177             val (t', ps_idx') = pre_of t ps_idx;
   178             val (u', ps_idx'') = pre_of u ps_idx';
   179           in (PAppl (t', u'), ps_idx'') end;
   180 
   181     val (tm', (params', idx'')) = pre_of tm (params, idx');
   182   in (tm', (vparams', params', idx'')) end;
   183 
   184 
   185 
   186 (** pretyps/terms to typs/terms **)
   187 
   188 (* add_parms *)
   189 
   190 fun add_parmsT tye T =
   191   (case deref tye T of
   192     PType (_, Ts) => fold (add_parmsT tye) Ts
   193   | Param (i, _) => insert (op =) i
   194   | _ => I);
   195 
   196 fun add_parms tye = fold_pretyps (add_parmsT tye);
   197 
   198 
   199 (* add_names *)
   200 
   201 fun add_namesT tye T =
   202   (case deref tye T of
   203     PType (_, Ts) => fold (add_namesT tye) Ts
   204   | PTFree (x, _) => Name.declare x
   205   | PTVar ((x, _), _) => Name.declare x
   206   | Param _ => I);
   207 
   208 fun add_names tye = fold_pretyps (add_namesT tye);
   209 
   210 
   211 (* simple_typ/term_of *)
   212 
   213 fun simple_typ_of tye f T =
   214   (case deref tye T of
   215     PType (a, Ts) => Type (a, map (simple_typ_of tye f) Ts)
   216   | PTFree v => TFree v
   217   | PTVar v => TVar v
   218   | Param (i, S) => TVar (f i, S));
   219 
   220 fun simple_term_of tye f (PConst (c, T)) = Const (c, simple_typ_of tye f T)
   221   | simple_term_of tye f (PFree (x, T)) = Free (x, simple_typ_of tye f T)
   222   | simple_term_of tye f (PVar (xi, T)) = Var (xi, simple_typ_of tye f T)
   223   | simple_term_of tye f (PBound i) = Bound i
   224   | simple_term_of tye f (PAbs (x, T, t)) =
   225       Abs (x, simple_typ_of tye f T, simple_term_of tye f t)
   226   | simple_term_of tye f (PAppl (t, u)) =
   227       simple_term_of tye f t $ simple_term_of tye f u;
   228 
   229 
   230 (* typs_terms_of *)
   231 
   232 fun typs_terms_of ctxt tye (Ts, ts) =
   233   let
   234     val used = fold (add_names tye) ts (fold (add_namesT tye) Ts (Variable.names_of ctxt));
   235     val parms = rev (fold (add_parms tye) ts (fold (add_parmsT tye) Ts []));
   236     val names = Name.invents used ("?" ^ Name.aT) (length parms);
   237     val tab = Inttab.make (parms ~~ names);
   238 
   239     val maxidx = Variable.maxidx_of ctxt;
   240     fun f i = (the (Inttab.lookup tab i), maxidx + 1);
   241   in (map (simple_typ_of tye f) Ts, map (Type.strip_constraints o simple_term_of tye f) ts) end;
   242 
   243 
   244 
   245 (** order-sorted unification of types **)
   246 
   247 exception NO_UNIFIER of string * pretyp Inttab.table;
   248 
   249 fun unify ctxt pp =
   250   let
   251     val thy = ProofContext.theory_of ctxt;
   252     val arity_sorts = Type.arity_sorts pp (Sign.tsig_of thy);
   253 
   254 
   255     (* adjust sorts of parameters *)
   256 
   257     fun not_of_sort x S' S =
   258       "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
   259         Syntax.string_of_sort ctxt S;
   260 
   261     fun meet (_, []) tye_idx = tye_idx
   262       | meet (Param (i, S'), S) (tye_idx as (tye, idx)) =
   263           if Sign.subsort thy (S', S) then tye_idx
   264           else (Inttab.update_new (i,
   265             Param (idx, Sign.inter_sort thy (S', S))) tye, idx + 1)
   266       | meet (PType (a, Ts), S) (tye_idx as (tye, _)) =
   267           meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
   268       | meet (PTFree (x, S'), S) (tye_idx as (tye, _)) =
   269           if Sign.subsort thy (S', S) then tye_idx
   270           else raise NO_UNIFIER (not_of_sort x S' S, tye)
   271       | meet (PTVar (xi, S'), S) (tye_idx as (tye, _)) =
   272           if Sign.subsort thy (S', S) then tye_idx
   273           else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
   274     and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
   275           meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
   276       | meets _ tye_idx = tye_idx;
   277 
   278 
   279     (* occurs check and assignment *)
   280 
   281     fun occurs_check tye i (Param (i', S)) =
   282           if i = i' then raise NO_UNIFIER ("Occurs check!", tye)
   283           else
   284             (case Inttab.lookup tye i' of
   285               NONE => ()
   286             | SOME T => occurs_check tye i T)
   287       | occurs_check tye i (PType (_, Ts)) = List.app (occurs_check tye i) Ts
   288       | occurs_check _ _ _ = ();
   289 
   290     fun assign i (T as Param (i', _)) S tye_idx =
   291           if i = i' then tye_idx
   292           else tye_idx |> meet (T, S) |>> Inttab.update_new (i, T)
   293       | assign i T S (tye_idx as (tye, _)) =
   294           (occurs_check tye i T; tye_idx |> meet (T, S) |>> Inttab.update_new (i, T));
   295 
   296 
   297     (* unification *)
   298 
   299     fun show_tycon (a, Ts) =
   300       quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
   301 
   302     fun unif (T1, T2) (tye_idx as (tye, idx)) =
   303       (case (deref tye T1, deref tye T2) of
   304         (Param (i, S), T) => assign i T S tye_idx
   305       | (T, Param (i, S)) => assign i T S tye_idx
   306       | (PType (a, Ts), PType (b, Us)) =>
   307           if a <> b then
   308             raise NO_UNIFIER
   309               ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
   310           else fold unif (Ts ~~ Us) tye_idx
   311       | (T, U) => if T = U then tye_idx else raise NO_UNIFIER ("", tye));
   312 
   313   in unif end;
   314 
   315 
   316 
   317 (** type inference **)
   318 
   319 (* infer *)
   320 
   321 fun infer ctxt =
   322   let
   323     val pp = Syntax.pp ctxt;
   324 
   325 
   326     (* errors *)
   327 
   328     fun prep_output tye bs ts Ts =
   329       let
   330         val (Ts_bTs', ts') = typs_terms_of ctxt tye (Ts @ map snd bs, ts);
   331         val (Ts', Ts'') = chop (length Ts) Ts_bTs';
   332         fun prep t =
   333           let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
   334           in Term.subst_bounds (map Syntax.mark_boundT xs, t) end;
   335       in (map prep ts', Ts') end;
   336 
   337     fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i);
   338 
   339     fun unif_failed msg =
   340       "Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";
   341 
   342     fun err_appl msg tye bs t T u U =
   343       let val ([t', u'], [T', U']) = prep_output tye bs [t, u] [T, U]
   344       in error (unif_failed msg ^ Type.appl_error pp t' T' u' U' ^ "\n") end;
   345 
   346 
   347     (* main *)
   348 
   349     fun inf _ (PConst (_, T)) tye_idx = (T, tye_idx)
   350       | inf _ (PFree (_, T)) tye_idx = (T, tye_idx)
   351       | inf _ (PVar (_, T)) tye_idx = (T, tye_idx)
   352       | inf bs (PBound i) tye_idx =
   353           (snd (nth bs i handle Subscript => err_loose i), tye_idx)
   354       | inf bs (PAbs (x, T, t)) tye_idx =
   355           let val (U, tye_idx') = inf ((x, T) :: bs) t tye_idx
   356           in (PType ("fun", [T, U]), tye_idx') end
   357       | inf bs (PAppl (t, u)) tye_idx =
   358           let
   359             val (T, tye_idx') = inf bs t tye_idx;
   360             val (U, (tye, idx)) = inf bs u tye_idx';
   361             val V = Param (idx, []);
   362             val U_to_V = PType ("fun", [U, V]);
   363             val tye_idx'' = unify ctxt pp (U_to_V, T) (tye, idx + 1)
   364               handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U;
   365           in (V, tye_idx'') end;
   366 
   367   in inf [] end;
   368 
   369 
   370 (* infer_types *)
   371 
   372 fun infer_types ctxt const_type var_type raw_ts =
   373   let
   374     (*constrain vars*)
   375     val get_type = the_default dummyT o var_type;
   376     val constrain_vars = Term.map_aterms
   377       (fn Free (x, T) => Type.constraint T (Free (x, get_type (x, ~1)))
   378         | Var (xi, T) => Type.constraint T (Var (xi, get_type xi))
   379         | t => t);
   380 
   381     (*convert to preterms*)
   382     val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
   383     val (ts', (_, _, idx)) =
   384       fold_map (preterm_of const_type o constrain_vars) ts
   385       (Vartab.empty, Vartab.empty, 0);
   386 
   387     (*do type inference*)
   388     val (tye, _) = fold (snd oo infer ctxt) ts' (Inttab.empty, idx);
   389   in #2 (typs_terms_of ctxt tye ([], ts')) end;
   390 
   391 end;