src/Pure/type_infer_context.ML
changeset 42405 13ecdb3057d8
child 43278 1fbdcebb364b
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/Pure/type_infer_context.ML	Tue Apr 19 20:47:02 2011 +0200
     1.3 @@ -0,0 +1,267 @@
     1.4 +(*  Title:      Pure/type_infer_context.ML
     1.5 +    Author:     Stefan Berghofer and Markus Wenzel, TU Muenchen
     1.6 +
     1.7 +Type-inference preparation and standard type inference.
     1.8 +*)
     1.9 +
    1.10 +signature TYPE_INFER_CONTEXT =
    1.11 +sig
    1.12 +  val const_sorts: bool Config.T
    1.13 +  val prepare: Proof.context -> term list -> int * term list
    1.14 +  val infer_types: Proof.context -> term list -> term list
    1.15 +end;
    1.16 +
    1.17 +structure Type_Infer_Context: TYPE_INFER_CONTEXT =
    1.18 +struct
    1.19 +
    1.20 +(** prepare types/terms: create inference parameters **)
    1.21 +
    1.22 +(* constraints *)
    1.23 +
    1.24 +val const_sorts = Config.bool (Config.declare "const_sorts" (K (Config.Bool true)));
    1.25 +
    1.26 +fun const_type ctxt =
    1.27 +  try ((not (Config.get ctxt const_sorts) ? Type.strip_sorts) o
    1.28 +    Consts.the_constraint (Proof_Context.consts_of ctxt));
    1.29 +
    1.30 +fun var_type ctxt = the_default dummyT o Proof_Context.def_type ctxt;
    1.31 +
    1.32 +
    1.33 +(* prepare_typ *)
    1.34 +
    1.35 +fun prepare_typ typ params_idx =
    1.36 +  let
    1.37 +    val (params', idx) = fold_atyps
    1.38 +      (fn TVar (xi, S) =>
    1.39 +          (fn ps_idx as (ps, idx) =>
    1.40 +            if Type_Infer.is_param xi andalso not (Vartab.defined ps xi)
    1.41 +            then (Vartab.update (xi, Type_Infer.mk_param idx S) ps, idx + 1) else ps_idx)
    1.42 +        | _ => I) typ params_idx;
    1.43 +
    1.44 +    fun prepare (T as Type (a, Ts)) idx =
    1.45 +          if T = dummyT then (Type_Infer.mk_param idx [], idx + 1)
    1.46 +          else
    1.47 +            let val (Ts', idx') = fold_map prepare Ts idx
    1.48 +            in (Type (a, Ts'), idx') end
    1.49 +      | prepare (T as TVar (xi, _)) idx =
    1.50 +          (case Vartab.lookup params' xi of
    1.51 +            NONE => T
    1.52 +          | SOME p => p, idx)
    1.53 +      | prepare (TFree ("'_dummy_", S)) idx = (Type_Infer.mk_param idx S, idx + 1)
    1.54 +      | prepare (T as TFree _) idx = (T, idx);
    1.55 +
    1.56 +    val (typ', idx') = prepare typ idx;
    1.57 +  in (typ', (params', idx')) end;
    1.58 +
    1.59 +
    1.60 +(* prepare_term *)
    1.61 +
    1.62 +fun prepare_term ctxt tm (vparams, params, idx) =
    1.63 +  let
    1.64 +    fun add_vparm xi (ps_idx as (ps, idx)) =
    1.65 +      if not (Vartab.defined ps xi) then
    1.66 +        (Vartab.update (xi, Type_Infer.mk_param idx []) ps, idx + 1)
    1.67 +      else ps_idx;
    1.68 +
    1.69 +    val (vparams', idx') = fold_aterms
    1.70 +      (fn Var (_, Type ("_polymorphic_", _)) => I
    1.71 +        | Var (xi, _) => add_vparm xi
    1.72 +        | Free (x, _) => add_vparm (x, ~1)
    1.73 +        | _ => I)
    1.74 +      tm (vparams, idx);
    1.75 +    fun var_param xi = the (Vartab.lookup vparams' xi);
    1.76 +
    1.77 +    fun polyT_of T idx =
    1.78 +      apsnd snd (prepare_typ (Type_Infer.paramify_vars T) (Vartab.empty, idx));
    1.79 +
    1.80 +    fun constraint T t ps =
    1.81 +      if T = dummyT then (t, ps)
    1.82 +      else
    1.83 +        let val (T', ps') = prepare_typ T ps
    1.84 +        in (Type.constraint T' t, ps') end;
    1.85 +
    1.86 +    fun prepare (Const ("_type_constraint_", T) $ t) ps_idx =
    1.87 +          let
    1.88 +            fun err () =
    1.89 +              error ("Malformed internal type constraint: " ^ Syntax.string_of_typ ctxt T);
    1.90 +            val A = (case T of Type ("fun", [A, B]) => if A = B then A else err () | _ => err ());
    1.91 +            val (A', ps_idx') = prepare_typ A ps_idx;
    1.92 +            val (t', ps_idx'') = prepare t ps_idx';
    1.93 +          in (Const ("_type_constraint_", A' --> A') $ t', ps_idx'') end
    1.94 +      | prepare (Const (c, T)) (ps, idx) =
    1.95 +          (case const_type ctxt c of
    1.96 +            SOME U =>
    1.97 +              let val (U', idx') = polyT_of U idx
    1.98 +              in constraint T (Const (c, U')) (ps, idx') end
    1.99 +          | NONE => error ("Undeclared constant: " ^ quote c))
   1.100 +      | prepare (Var (xi, Type ("_polymorphic_", [T]))) (ps, idx) =
   1.101 +          let val (T', idx') = polyT_of T idx
   1.102 +          in (Var (xi, T'), (ps, idx')) end
   1.103 +      | prepare (Var (xi, T)) ps_idx = constraint T (Var (xi, var_param xi)) ps_idx
   1.104 +      | prepare (Free (x, T)) ps_idx = constraint T (Free (x, var_param (x, ~1))) ps_idx
   1.105 +      | prepare (Bound i) ps_idx = (Bound i, ps_idx)
   1.106 +      | prepare (Abs (x, T, t)) ps_idx =
   1.107 +          let
   1.108 +            val (T', ps_idx') = prepare_typ T ps_idx;
   1.109 +            val (t', ps_idx'') = prepare t ps_idx';
   1.110 +          in (Abs (x, T', t'), ps_idx'') end
   1.111 +      | prepare (t $ u) ps_idx =
   1.112 +          let
   1.113 +            val (t', ps_idx') = prepare t ps_idx;
   1.114 +            val (u', ps_idx'') = prepare u ps_idx';
   1.115 +          in (t' $ u', ps_idx'') end;
   1.116 +
   1.117 +    val (tm', (params', idx'')) = prepare tm (params, idx');
   1.118 +  in (tm', (vparams', params', idx'')) end;
   1.119 +
   1.120 +
   1.121 +
   1.122 +(** order-sorted unification of types **)
   1.123 +
   1.124 +exception NO_UNIFIER of string * typ Vartab.table;
   1.125 +
   1.126 +fun unify ctxt =
   1.127 +  let
   1.128 +    val thy = Proof_Context.theory_of ctxt;
   1.129 +    val arity_sorts = Type.arity_sorts (Context.pretty ctxt) (Sign.tsig_of thy);
   1.130 +
   1.131 +
   1.132 +    (* adjust sorts of parameters *)
   1.133 +
   1.134 +    fun not_of_sort x S' S =
   1.135 +      "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
   1.136 +        Syntax.string_of_sort ctxt S;
   1.137 +
   1.138 +    fun meet (_, []) tye_idx = tye_idx
   1.139 +      | meet (Type (a, Ts), S) (tye_idx as (tye, _)) =
   1.140 +          meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
   1.141 +      | meet (TFree (x, S'), S) (tye_idx as (tye, _)) =
   1.142 +          if Sign.subsort thy (S', S) then tye_idx
   1.143 +          else raise NO_UNIFIER (not_of_sort x S' S, tye)
   1.144 +      | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
   1.145 +          if Sign.subsort thy (S', S) then tye_idx
   1.146 +          else if Type_Infer.is_param xi then
   1.147 +            (Vartab.update_new
   1.148 +              (xi, Type_Infer.mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
   1.149 +          else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
   1.150 +    and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
   1.151 +          meets (Ts, Ss) (meet (Type_Infer.deref tye T, S) tye_idx)
   1.152 +      | meets _ tye_idx = tye_idx;
   1.153 +
   1.154 +
   1.155 +    (* occurs check and assignment *)
   1.156 +
   1.157 +    fun occurs_check tye xi (TVar (xi', _)) =
   1.158 +          if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
   1.159 +          else
   1.160 +            (case Vartab.lookup tye xi' of
   1.161 +              NONE => ()
   1.162 +            | SOME T => occurs_check tye xi T)
   1.163 +      | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts
   1.164 +      | occurs_check _ _ _ = ();
   1.165 +
   1.166 +    fun assign xi (T as TVar (xi', _)) S env =
   1.167 +          if xi = xi' then env
   1.168 +          else env |> meet (T, S) |>> Vartab.update_new (xi, T)
   1.169 +      | assign xi T S (env as (tye, _)) =
   1.170 +          (occurs_check tye xi T; env |> meet (T, S) |>> Vartab.update_new (xi, T));
   1.171 +
   1.172 +
   1.173 +    (* unification *)
   1.174 +
   1.175 +    fun show_tycon (a, Ts) =
   1.176 +      quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
   1.177 +
   1.178 +    fun unif (T1, T2) (env as (tye, _)) =
   1.179 +      (case pairself (`Type_Infer.is_paramT o Type_Infer.deref tye) (T1, T2) of
   1.180 +        ((true, TVar (xi, S)), (_, T)) => assign xi T S env
   1.181 +      | ((_, T), (true, TVar (xi, S))) => assign xi T S env
   1.182 +      | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
   1.183 +          if a <> b then
   1.184 +            raise NO_UNIFIER
   1.185 +              ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
   1.186 +          else fold unif (Ts ~~ Us) env
   1.187 +      | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye));
   1.188 +
   1.189 +  in unif end;
   1.190 +
   1.191 +
   1.192 +
   1.193 +(** simple type inference **)
   1.194 +
   1.195 +(* infer *)
   1.196 +
   1.197 +fun infer ctxt =
   1.198 +  let
   1.199 +    (* errors *)
   1.200 +
   1.201 +    fun prep_output tye bs ts Ts =
   1.202 +      let
   1.203 +        val (Ts_bTs', ts') = Type_Infer.finish ctxt tye (Ts @ map snd bs, ts);
   1.204 +        val (Ts', Ts'') = chop (length Ts) Ts_bTs';
   1.205 +        fun prep t =
   1.206 +          let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
   1.207 +          in Term.subst_bounds (map Syntax_Trans.mark_boundT xs, t) end;
   1.208 +      in (map prep ts', Ts') end;
   1.209 +
   1.210 +    fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i);
   1.211 +
   1.212 +    fun unif_failed msg =
   1.213 +      "Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";
   1.214 +
   1.215 +    fun err_appl msg tye bs t T u U =
   1.216 +      let val ([t', u'], [T', U']) = prep_output tye bs [t, u] [T, U]
   1.217 +      in error (unif_failed msg ^ Type.appl_error ctxt t' T' u' U' ^ "\n") end;
   1.218 +
   1.219 +
   1.220 +    (* main *)
   1.221 +
   1.222 +    fun inf _ (Const (_, T)) tye_idx = (T, tye_idx)
   1.223 +      | inf _ (Free (_, T)) tye_idx = (T, tye_idx)
   1.224 +      | inf _ (Var (_, T)) tye_idx = (T, tye_idx)
   1.225 +      | inf bs (Bound i) tye_idx =
   1.226 +          (snd (nth bs i handle Subscript => err_loose i), tye_idx)
   1.227 +      | inf bs (Abs (x, T, t)) tye_idx =
   1.228 +          let val (U, tye_idx') = inf ((x, T) :: bs) t tye_idx
   1.229 +          in (T --> U, tye_idx') end
   1.230 +      | inf bs (t $ u) tye_idx =
   1.231 +          let
   1.232 +            val (T, tye_idx') = inf bs t tye_idx;
   1.233 +            val (U, (tye, idx)) = inf bs u tye_idx';
   1.234 +            val V = Type_Infer.mk_param idx [];
   1.235 +            val tye_idx'' = unify ctxt (U --> V, T) (tye, idx + 1)
   1.236 +              handle NO_UNIFIER (msg, tye') => err_appl msg tye' bs t T u U;
   1.237 +          in (V, tye_idx'') end;
   1.238 +
   1.239 +  in inf [] end;
   1.240 +
   1.241 +
   1.242 +(* main interfaces *)
   1.243 +
   1.244 +fun prepare ctxt raw_ts =
   1.245 +  let
   1.246 +    val constrain_vars = Term.map_aterms
   1.247 +      (fn Free (x, T) => Type.constraint T (Free (x, var_type ctxt (x, ~1)))
   1.248 +        | Var (xi, T) => Type.constraint T (Var (xi, var_type ctxt xi))
   1.249 +        | t => t);
   1.250 +
   1.251 +    val ts = burrow_types (Syntax.check_typs ctxt) raw_ts;
   1.252 +    val idx = Type_Infer.param_maxidx_of ts + 1;
   1.253 +    val (ts', (_, _, idx')) =
   1.254 +      fold_map (prepare_term ctxt o constrain_vars) ts
   1.255 +        (Vartab.empty, Vartab.empty, idx);
   1.256 +  in (idx', ts') end;
   1.257 +
   1.258 +fun infer_types ctxt raw_ts =
   1.259 +  let
   1.260 +    val (idx, ts) = prepare ctxt raw_ts;
   1.261 +    val (tye, _) = fold (snd oo infer ctxt) ts (Vartab.empty, idx);
   1.262 +    val (_, ts') = Type_Infer.finish ctxt tye ([], ts);
   1.263 +  in ts' end;
   1.264 +
   1.265 +val _ =
   1.266 +  Context.>>
   1.267 +    (Syntax.add_term_check 0 "standard"
   1.268 +      (fn ctxt => infer_types ctxt #> map (Proof_Context.expand_abbrevs ctxt)));
   1.269 +
   1.270 +end;