--- a/src/Tools/subtyping.ML Tue Dec 21 11:54:35 2010 +0100
+++ b/src/Tools/subtyping.ML Tue Dec 21 01:12:14 2010 +0100
@@ -6,7 +6,7 @@
signature SUBTYPING =
sig
- datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT
+ datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ;
val coercion_enabled: bool Config.T
val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
term list -> term list
@@ -21,7 +21,7 @@
(** coercions data **)
-datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT
+datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ;
datatype data = Data of
{coes: term Symreltab.table, (*coercions table*)
@@ -83,9 +83,11 @@
| sort_of _ = NONE;
val is_typeT = fn (Type _) => true | _ => false;
+val is_stypeT = fn (Type (_, [])) => true | _ => false;
val is_compT = fn (Type (_, _ :: _)) => true | _ => false;
val is_freeT = fn (TFree _) => true | _ => false;
val is_fixedvarT = fn (TVar (xi, _)) => not (Type_Infer.is_param xi) | _ => false;
+val is_funtype = fn (Type ("fun", [_, _])) => true | _ => false;
(* unification *)
@@ -205,10 +207,6 @@
fun unif_failed msg =
"Type unification failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";
-
-fun subtyping_err_appl_msg ctxt msg tye bs t T u U () =
- let val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U]
- in msg ^ Type.appl_error (Syntax.pp ctxt) t' T' u' U' ^ "\n" end;
fun err_appl_msg ctxt msg tye bs t T u U () =
let val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U]
@@ -264,7 +262,7 @@
val U = Type_Infer.mk_param idx [];
val V = Type_Infer.mk_param (idx + 1) [];
val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
- handle NO_UNIFIER (msg, tye') => error (gen_msg err msg);
+ handle NO_UNIFIER (msg, _) => error (gen_msg err msg);
val error_pack = (bs, t $ u, U, V, U');
in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
in
@@ -291,12 +289,15 @@
(case pairself f (fst c) of
(false, false) => apsnd (cons c) (split_cs f cs)
| _ => apfst (cons c) (split_cs f cs));
+
+ fun unify_list (T :: Ts) tye_idx =
+ fold (fn U => fn tye_idx' => strong_unify ctxt (T, U) tye_idx') Ts tye_idx;
(* check whether constraint simplification will terminate using weak unification *)
- val _ = fold (fn (TU, error_pack) => fn tye_idx =>
- weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, tye) =>
+ val _ = fold (fn (TU, _) => fn tye_idx =>
+ weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, _) =>
error (gen_msg err ("weak unification of subtype constraints fails\n" ^ msg))) cs tye_idx;
@@ -315,9 +316,14 @@
(case variance of
COVARIANT => (constraint :: cs, tye_idx)
| CONTRAVARIANT => (swap constraint :: cs, tye_idx)
+ | INVARIANT_TO T => (cs, unify_list [T, fst constraint, snd constraint] tye_idx
+ handle NO_UNIFIER (msg, _) =>
+ err_list ctxt (gen_msg err
+ "failed to unify invariant arguments w.r.t. to the known map function")
+ (fst tye_idx) Ts)
| INVARIANT => (cs, strong_unify ctxt constraint tye_idx
- handle NO_UNIFIER (msg, tye) =>
- error (gen_msg err ("failed to unify invariant arguments\n" ^ msg))));
+ handle NO_UNIFIER (msg, _) =>
+ error (gen_msg err ("failed to unify invariant arguments" ^ msg))));
val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
(fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
val test_update = is_compT orf is_freeT orf is_fixedvarT;
@@ -343,7 +349,7 @@
simplify done' ((new, error_pack) :: todo') (tye', idx + n)
end
(*TU is a pair of a parameter and a free/fixed variable*)
- and eliminate TU error_pack done todo tye idx =
+ and eliminate TU done todo tye idx =
let
val [TVar (xi, S)] = filter Type_Infer.is_paramT TU;
val [T] = filter_out Type_Infer.is_paramT TU;
@@ -376,7 +382,7 @@
if T = U then simplify done todo tye_idx
else if exists (is_freeT orf is_fixedvarT) [T, U] andalso
exists Type_Infer.is_paramT [T, U]
- then eliminate [T, U] error_pack done todo tye idx
+ then eliminate [T, U] done todo tye idx
else if exists (is_freeT orf is_fixedvarT) [T, U]
then error (gen_msg err "not eliminated free/fixed variables")
else simplify (((T, U), error_pack) :: done) todo tye_idx);
@@ -402,9 +408,6 @@
cs'
end;
- fun unify_list (T :: Ts) tye_idx =
- fold (fn U => fn tye_idx' => strong_unify ctxt (T, U) tye_idx') Ts tye_idx;
-
(*styps stands either for supertypes or for subtypes of a type T
in terms of the subtype-relation (excluding T itself)*)
fun styps super T =
@@ -467,7 +470,7 @@
val (tye, idx) =
fold
(fn cycle => fn tye_idx' => (unify_list cycle tye_idx'
- handle NO_UNIFIER (msg, tye) =>
+ handle NO_UNIFIER (msg, _) =>
err_bound ctxt
(gen_msg err ("constraint cycle not unifiable" ^ msg)) (fst tye_idx)
(find_cycle_packs cycle)))
@@ -572,7 +575,7 @@
in
fold
(fn Ts => fn tye_idx' => unify_list Ts tye_idx'
- handle NO_UNIFIER (msg, tye) => err_list ctxt (gen_msg err msg) (fst tye_idx) Ts)
+ handle NO_UNIFIER (msg, _) => err_list ctxt (gen_msg err msg) (fst tye_idx) Ts)
to_unify tye_idx
end;
@@ -605,8 +608,9 @@
fun inst t Ts =
Term.subst_vars
(((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
- fun sub_co (COVARIANT, TU) = gen_coercion ctxt tye TU
- | sub_co (CONTRAVARIANT, TU) = gen_coercion ctxt tye (swap TU);
+ fun sub_co (COVARIANT, TU) = SOME (gen_coercion ctxt tye TU)
+ | sub_co (CONTRAVARIANT, TU) = SOME (gen_coercion ctxt tye (swap TU))
+ | sub_co (INVARIANT_TO T, _) = NONE;
fun ts_of [] = []
| ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
in
@@ -614,7 +618,7 @@
NONE => raise Fail ("No map function for " ^ a ^ " known")
| SOME tmap =>
let
- val used_coes = map sub_co ((snd tmap) ~~ (Ts ~~ Us));
+ val used_coes = map_filter sub_co ((snd tmap) ~~ (Ts ~~ Us));
in
Term.list_comb
(inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
@@ -735,36 +739,39 @@
val ctxt = Context.proof_of context;
val t = singleton (Variable.polymorphic ctxt) raw_t;
- fun err_str () = "\n\nthe general type signature for a map function is" ^
- "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [x1, ..., xn]" ^
+ fun err_str t = "\n\nThe provided function has the type\n" ^
+ Syntax.string_of_typ ctxt (fastype_of t) ^
+ "\n\nThe general type signature of a map function is" ^
+ "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]" ^
"\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi)";
-
+
+ val ((fis, T1), T2) = apfst split_last (strip_type (fastype_of t))
+ handle Empty => error ("Not a proper map function:" ^ err_str t);
+
fun gen_arg_var ([], []) = []
| gen_arg_var ((T, T') :: Ts, (U, U') :: Us) =
- if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
+ if U = U' then
+ if is_stypeT U then INVARIANT_TO U :: gen_arg_var ((T, T') :: Ts, Us)
+ else error ("Invariant xi and yi should be base types:" ^ err_str t)
+ else if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
else if T = U' andalso T' = U then CONTRAVARIANT :: gen_arg_var (Ts, Us)
- else error ("Functions do not apply to arguments correctly:" ^ err_str ())
- | gen_arg_var (_, _) =
- error ("Different numbers of functions and arguments\n" ^ err_str ());
+ else error ("Functions do not apply to arguments correctly:" ^ err_str t)
+ | gen_arg_var (_, Ts) =
+ if forall (op = andf is_stypeT o fst) Ts
+ then map (INVARIANT_TO o fst) Ts
+ else error ("Different numbers of functions and variant arguments\n" ^ err_str t);
- (* TODO: This function is only needed to introde the fun type map
- function: "% f g h . g o h o f". There must be a better solution. *)
- fun balanced (Type (_, [])) (Type (_, [])) = true
- | balanced (Type (a, Ts)) (Type (b, Us)) =
- a = b andalso forall I (map2 balanced Ts Us)
- | balanced (TFree _) (TFree _) = true
- | balanced (TVar _) (TVar _) = true
- | balanced _ _ = false;
+ (*retry flag needed to adjust the type lists, when given a map over type constructor fun*)
+ fun check_map_fun fis (Type (C1, Ts)) (Type (C2, Us)) retry =
+ if C1 = C2 andalso not (null fis) andalso forall is_funtype fis
+ then ((map dest_funT fis, Ts ~~ Us), C1)
+ else error ("Not a proper map function:" ^ err_str t)
+ | check_map_fun fis T1 T2 true =
+ let val (fis', T') = split_last fis
+ in check_map_fun fis' T' (T1 --> T2) false end
+ | check_map_fun _ _ _ _ = error ("Not a proper map function:" ^ err_str t);
- fun check_map_fun (pairs, []) (Type ("fun", [T as Type (C, Ts), U as Type (_, Us)])) =
- if balanced T U
- then ((pairs, Ts ~~ Us), C)
- else if C = "fun"
- then check_map_fun (pairs @ [(hd Ts, hd (tl Ts))], []) U
- else error ("Not a proper map function:" ^ err_str ())
- | check_map_fun _ _ = error ("Not a proper map function:" ^ err_str ());
-
- val res = check_map_fun ([], []) (fastype_of t);
+ val res = check_map_fun fis T1 T2 true;
val res_av = gen_arg_var (fst res);
in
map_tmaps (Symtab.update (snd res, (t, res_av))) context