src/Pure/type_infer_context.ML
author wenzelm
Mon Dec 04 22:54:31 2017 +0100 (20 months ago)
changeset 67131 85d10959c2e4
parent 64556 851ae0e7b09c
child 69575 f77cc54f6d47
permissions -rw-r--r--
tuned signature;
     1 (*  Title:      Pure/type_infer_context.ML
     2     Author:     Stefan Berghofer and Markus Wenzel, TU Muenchen
     3 
     4 Type-inference preparation and standard type inference.
     5 *)
     6 
     7 signature TYPE_INFER_CONTEXT =
     8 sig
     9   val const_sorts: bool Config.T
    10   val const_type: Proof.context -> string -> typ option
    11   val prepare_positions: Proof.context -> term list -> term list * (Position.T * typ) list
    12   val prepare: Proof.context -> term list -> int * term list
    13   val infer_types: Proof.context -> term list -> term list
    14 end;
    15 
    16 structure Type_Infer_Context: TYPE_INFER_CONTEXT =
    17 struct
    18 
    19 (** prepare types/terms: create inference parameters **)
    20 
    21 (* constraints *)
    22 
    23 val const_sorts =
    24   Config.bool (Config.declare ("const_sorts", \<^here>) (K (Config.Bool true)));
    25 
    26 fun const_type ctxt =
    27   try ((not (Config.get ctxt const_sorts) ? Type.strip_sorts) o
    28     Consts.the_constraint (Proof_Context.consts_of ctxt));
    29 
    30 fun var_type ctxt = the_default dummyT o Proof_Context.def_type ctxt;
    31 
    32 
    33 (* prepare_typ *)
    34 
    35 fun prepare_typ typ params_idx =
    36   let
    37     val (params', idx) = fold_atyps
    38       (fn TVar (xi, S) =>
    39           (fn ps_idx as (ps, idx) =>
    40             if Type_Infer.is_param xi andalso not (Vartab.defined ps xi)
    41             then (Vartab.update (xi, Type_Infer.mk_param idx S) ps, idx + 1) else ps_idx)
    42         | _ => I) typ params_idx;
    43 
    44     fun prepare (T as Type (a, Ts)) idx =
    45           if T = dummyT then (Type_Infer.mk_param idx [], idx + 1)
    46           else
    47             let val (Ts', idx') = fold_map prepare Ts idx
    48             in (Type (a, Ts'), idx') end
    49       | prepare (T as TVar (xi, _)) idx =
    50           (case Vartab.lookup params' xi of
    51             NONE => T
    52           | SOME p => p, idx)
    53       | prepare (TFree ("'_dummy_", S)) idx = (Type_Infer.mk_param idx S, idx + 1)
    54       | prepare (T as TFree _) idx = (T, idx);
    55 
    56     val (typ', idx') = prepare typ idx;
    57   in (typ', (params', idx')) end;
    58 
    59 
    60 (* prepare_term *)
    61 
    62 fun prepare_term ctxt tm (vparams, params, idx) =
    63   let
    64     fun add_vparm xi (ps_idx as (ps, idx)) =
    65       if not (Vartab.defined ps xi) then
    66         (Vartab.update (xi, Type_Infer.mk_param idx []) ps, idx + 1)
    67       else ps_idx;
    68 
    69     val (vparams', idx') = fold_aterms
    70       (fn Var (_, Type ("_polymorphic_", _)) => I
    71         | Var (xi, _) => add_vparm xi
    72         | Free (x, _) => add_vparm (x, ~1)
    73         | _ => I)
    74       tm (vparams, idx);
    75     fun var_param xi = the (Vartab.lookup vparams' xi);
    76 
    77     fun polyT_of T idx =
    78       apsnd snd (prepare_typ (Type_Infer.paramify_vars T) (Vartab.empty, idx));
    79 
    80     fun constraint T t ps =
    81       if T = dummyT then (t, ps)
    82       else
    83         let val (T', ps') = prepare_typ T ps
    84         in (Type.constraint T' t, ps') end;
    85 
    86     fun prepare (Const ("_type_constraint_", T) $ t) ps_idx =
    87           let
    88             val A = Type.constraint_type ctxt T;
    89             val (A', ps_idx') = prepare_typ A ps_idx;
    90             val (t', ps_idx'') = prepare t ps_idx';
    91           in (Const ("_type_constraint_", A' --> A') $ t', ps_idx'') end
    92       | prepare (Const (c, T)) (ps, idx) =
    93           (case const_type ctxt c of
    94             SOME U =>
    95               let val (U', idx') = polyT_of U idx
    96               in constraint T (Const (c, U')) (ps, idx') end
    97           | NONE => error ("Undeclared constant: " ^ quote c))
    98       | prepare (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) =
    99           let val (T', idx') = polyT_of T idx
   100           in (Var (xi, T'), (ps, idx')) end
   101       | prepare (Var (xi, T)) ps_idx = constraint T (Var (xi, var_param xi)) ps_idx
   102       | prepare (Free (x, T)) ps_idx = constraint T (Free (x, var_param (x, ~1))) ps_idx
   103       | prepare (Bound i) ps_idx = (Bound i, ps_idx)
   104       | prepare (Abs (x, T, t)) ps_idx =
   105           let
   106             val (T', ps_idx') = prepare_typ T ps_idx;
   107             val (t', ps_idx'') = prepare t ps_idx';
   108           in (Abs (x, T', t'), ps_idx'') end
   109       | prepare (t $ u) ps_idx =
   110           let
   111             val (t', ps_idx') = prepare t ps_idx;
   112             val (u', ps_idx'') = prepare u ps_idx';
   113           in (t' $ u', ps_idx'') end;
   114 
   115     val (tm', (params', idx'')) = prepare tm (params, idx');
   116   in (tm', (vparams', params', idx'')) end;
   117 
   118 
   119 (* prepare_positions *)
   120 
   121 fun prepare_positions ctxt tms =
   122   let
   123     fun prepareT (Type (a, Ts)) ps_idx =
   124           let val (Ts', ps_idx') = fold_map prepareT Ts ps_idx
   125           in (Type (a, Ts'), ps_idx') end
   126       | prepareT T (ps, idx) =
   127           (case Term_Position.decode_positionT T of
   128             SOME pos =>
   129               let val U = Type_Infer.mk_param idx []
   130               in (U, ((pos, U) :: ps, idx + 1)) end
   131           | NONE => (T, (ps, idx)));
   132 
   133     fun prepare (Const ("_type_constraint_", T)) ps_idx =
   134           let
   135             val A = Type.constraint_type ctxt T;
   136             val (A', ps_idx') = prepareT A ps_idx;
   137           in (Const ("_type_constraint_", A' --> A'), ps_idx') end
   138       | prepare (Const (c, T)) ps_idx =
   139           let val (T', ps_idx') = prepareT T ps_idx
   140           in (Const (c, T'), ps_idx') end
   141       | prepare (Free (x, T)) ps_idx =
   142           let val (T', ps_idx') = prepareT T ps_idx
   143           in (Free (x, T'), ps_idx') end
   144       | prepare (Var (xi, T)) ps_idx =
   145           let val (T', ps_idx') = prepareT T ps_idx
   146           in (Var (xi, T'), ps_idx') end
   147       | prepare (t as Bound _) ps_idx = (t, ps_idx)
   148       | prepare (Abs (x, T, t)) ps_idx =
   149           let
   150             val (T', ps_idx') = prepareT T ps_idx;
   151             val (t', ps_idx'') = prepare t ps_idx';
   152           in (Abs (x, T', t'), ps_idx'') end
   153       | prepare (t $ u) ps_idx =
   154           let
   155             val (t', ps_idx') = prepare t ps_idx;
   156             val (u', ps_idx'') = prepare u ps_idx';
   157           in (t' $ u', ps_idx'') end;
   158 
   159     val idx = Type_Infer.param_maxidx_of tms + 1;
   160     val (tms', (ps, _)) = fold_map prepare tms ([], idx);
   161   in (tms', ps) end;
   162 
   163 
   164 
   165 (** order-sorted unification of types **)
   166 
   167 exception NO_UNIFIER of string * typ Vartab.table;
   168 
   169 fun unify ctxt =
   170   let
   171     val thy = Proof_Context.theory_of ctxt;
   172     val arity_sorts = Proof_Context.arity_sorts ctxt;
   173 
   174 
   175     (* adjust sorts of parameters *)
   176 
   177     fun not_of_sort x S' S =
   178       "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
   179         Syntax.string_of_sort ctxt S;
   180 
   181     fun meet (_, []) tye_idx = tye_idx
   182       | meet (Type (a, Ts), S) (tye_idx as (tye, _)) =
   183           meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
   184       | meet (TFree (x, S'), S) (tye_idx as (tye, _)) =
   185           if Sign.subsort thy (S', S) then tye_idx
   186           else raise NO_UNIFIER (not_of_sort x S' S, tye)
   187       | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
   188           if Sign.subsort thy (S', S) then tye_idx
   189           else if Type_Infer.is_param xi then
   190             (Vartab.update_new
   191               (xi, Type_Infer.mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
   192           else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
   193     and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
   194           meets (Ts, Ss) (meet (Type_Infer.deref tye T, S) tye_idx)
   195       | meets _ tye_idx = tye_idx;
   196 
   197 
   198     (* occurs check and assignment *)
   199 
   200     fun occurs_check tye xi (TVar (xi', _)) =
   201           if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
   202           else
   203             (case Vartab.lookup tye xi' of
   204               NONE => ()
   205             | SOME T => occurs_check tye xi T)
   206       | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts
   207       | occurs_check _ _ _ = ();
   208 
   209     fun assign xi (T as TVar (xi', _)) S env =
   210           if xi = xi' then env
   211           else env |> meet (T, S) |>> Vartab.update_new (xi, T)
   212       | assign xi T S (env as (tye, _)) =
   213           (occurs_check tye xi T; env |> meet (T, S) |>> Vartab.update_new (xi, T));
   214 
   215 
   216     (* unification *)
   217 
   218     fun show_tycon (a, Ts) =
   219       quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
   220 
   221     fun unif (T1, T2) (env as (tye, _)) =
   222       (case apply2 (`Type_Infer.is_paramT o Type_Infer.deref tye) (T1, T2) of
   223         ((true, TVar (xi, S)), (_, T)) => assign xi T S env
   224       | ((_, T), (true, TVar (xi, S))) => assign xi T S env
   225       | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
   226           if a <> b then
   227             raise NO_UNIFIER
   228               ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
   229           else fold unif (Ts ~~ Us) env
   230       | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye));
   231 
   232   in unif end;
   233 
   234 
   235 
   236 (** simple type inference **)
   237 
   238 (* infer *)
   239 
   240 fun infer ctxt =
   241   let
   242     (* errors *)
   243 
   244     fun prep_output tye bs ts Ts =
   245       let
   246         val (Ts_bTs', ts') = Type_Infer.finish ctxt tye (Ts @ map snd bs, ts);
   247         val (Ts', Ts'') = chop (length Ts) Ts_bTs';
   248         fun prep t =
   249           let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
   250           in Term.subst_bounds (map Syntax_Trans.mark_bound_abs xs, t) end;
   251       in (map prep ts', Ts') end;
   252 
   253     fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i);
   254 
   255     fun unif_failed msg =
   256       "Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";
   257 
   258     fun err_appl msg tye bs t T u U =
   259       let val ([t', u'], [T', U']) = prep_output tye bs [t, u] [T, U]
   260       in error (unif_failed msg ^ Type.appl_error ctxt t' T' u' U' ^ "\n") end;
   261 
   262 
   263     (* main *)
   264 
   265     fun inf _ (Const (_, T)) tye_idx = (T, tye_idx)
   266       | inf _ (Free (_, T)) tye_idx = (T, tye_idx)
   267       | inf _ (Var (_, T)) tye_idx = (T, tye_idx)
   268       | inf bs (Bound i) tye_idx =
   269           (snd (nth bs i handle General.Subscript => err_loose i), tye_idx)
   270       | inf bs (Abs (x, T, t)) tye_idx =
   271           let val (U, tye_idx') = inf ((x, T) :: bs) t tye_idx
   272           in (T --> U, tye_idx') end
   273       | inf bs (t $ u) tye_idx =
   274           let
   275             val (T, tye_idx') = inf bs t tye_idx;
   276             val (U, (tye, idx)) = inf bs u tye_idx';
   277             val V = Type_Infer.mk_param idx [];
   278             val tye_idx'' = unify ctxt (U --> V, T) (tye, idx + 1)
   279               handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U;
   280           in (V, tye_idx'') end;
   281 
   282   in inf [] end;
   283 
   284 
   285 (* main interfaces *)
   286 
   287 fun prepare ctxt raw_ts =
   288   let
   289     val constrain_vars = Term.map_aterms
   290       (fn Free (x, T) => Type.constraint T (Free (x, var_type ctxt (x, ~1)))
   291         | Var (xi, T) => Type.constraint T (Var (xi, var_type ctxt xi))
   292         | t => t);
   293 
   294     val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
   295     val idx = Type_Infer.param_maxidx_of ts + 1;
   296     val (ts', (_, _, idx')) =
   297       fold_map (prepare_term ctxt o constrain_vars) ts
   298         (Vartab.empty, Vartab.empty, idx);
   299   in (idx', ts') end;
   300 
   301 fun infer_types ctxt raw_ts =
   302   let
   303     val (idx, ts) = prepare ctxt raw_ts;
   304     val (tye, _) = fold (snd oo infer ctxt) ts (Vartab.empty, idx);
   305     val (_, ts') = Type_Infer.finish ctxt tye ([], ts);
   306   in ts' end;
   307 
   308 end;