--- a/src/HOL/Nitpick_Examples/minipick.ML Fri Sep 23 14:59:29 2011 +0200
+++ b/src/HOL/Nitpick_Examples/minipick.ML Fri Sep 23 16:50:39 2011 +0200
@@ -7,21 +7,8 @@
signature MINIPICK =
sig
- datatype rep = S_Rep | R_Rep
- type styp = Nitpick_Util.styp
-
- val vars_for_bound_var :
- (typ -> int) -> rep -> typ list -> int -> Kodkod.rel_expr list
- val rel_expr_for_bound_var :
- (typ -> int) -> rep -> typ list -> int -> Kodkod.rel_expr
- val decls_for : rep -> (typ -> int) -> typ list -> typ -> Kodkod.decl list
- val false_atom : Kodkod.rel_expr
- val true_atom : Kodkod.rel_expr
- val formula_from_atom : Kodkod.rel_expr -> Kodkod.formula
- val atom_from_formula : Kodkod.formula -> Kodkod.rel_expr
- val kodkod_problem_from_term :
- Proof.context -> (typ -> int) -> term -> Kodkod.problem
- val solve_any_kodkod_problem : theory -> Kodkod.problem list -> string
+ val minipick : Proof.context -> int -> term -> string
+ val minipick_expect : Proof.context -> string -> int -> term -> unit
end;
structure Minipick : MINIPICK =
@@ -33,28 +20,34 @@
open Nitpick_Peephole
open Nitpick_Kodkod
-datatype rep = S_Rep | R_Rep
+datatype rep =
+ S_Rep |
+ R_Rep of bool
-fun check_type ctxt (Type (@{type_name fun}, Ts)) =
- List.app (check_type ctxt) Ts
- | check_type ctxt (Type (@{type_name prod}, Ts)) =
- List.app (check_type ctxt) Ts
- | check_type _ @{typ bool} = ()
- | check_type _ (TFree (_, @{sort "{}"})) = ()
- | check_type _ (TFree (_, @{sort HOL.type})) = ()
- | check_type ctxt T =
- raise NOT_SUPPORTED ("type " ^ quote (Syntax.string_of_typ ctxt T))
+fun check_type ctxt raw_infinite (Type (@{type_name fun}, Ts)) =
+ List.app (check_type ctxt raw_infinite) Ts
+ | check_type ctxt raw_infinite (Type (@{type_name prod}, Ts)) =
+ List.app (check_type ctxt raw_infinite) Ts
+ | check_type _ _ @{typ bool} = ()
+ | check_type _ _ (TFree (_, @{sort "{}"})) = ()
+ | check_type _ _ (TFree (_, @{sort HOL.type})) = ()
+ | check_type ctxt raw_infinite T =
+ if raw_infinite T then ()
+ else raise NOT_SUPPORTED ("type " ^ quote (Syntax.string_of_typ ctxt T))
fun atom_schema_of S_Rep card (Type (@{type_name fun}, [T1, T2])) =
replicate_list (card T1) (atom_schema_of S_Rep card T2)
- | atom_schema_of R_Rep card (Type (@{type_name fun}, [T1, @{typ bool}])) =
+ | atom_schema_of (R_Rep true) card
+ (Type (@{type_name fun}, [T1, @{typ bool}])) =
atom_schema_of S_Rep card T1
- | atom_schema_of R_Rep card (Type (@{type_name fun}, [T1, T2])) =
- atom_schema_of S_Rep card T1 @ atom_schema_of R_Rep card T2
+ | atom_schema_of (rep as R_Rep _) card (Type (@{type_name fun}, [T1, T2])) =
+ atom_schema_of S_Rep card T1 @ atom_schema_of rep card T2
| atom_schema_of _ card (Type (@{type_name prod}, Ts)) =
maps (atom_schema_of S_Rep card) Ts
| atom_schema_of _ card T = [card T]
val arity_of = length ooo atom_schema_of
+val atom_seqs_of = map (AtomSeq o rpair 0) ooo atom_schema_of
+val atom_seq_product_of = foldl1 Product ooo atom_seqs_of
fun index_for_bound_var _ [_] 0 = 0
| index_for_bound_var card (_ :: Ts) 0 =
@@ -68,78 +61,121 @@
map2 (curry DeclOne o pair 1)
(index_seq (index_for_bound_var card (T :: Ts) 0)
(arity_of R card (nth (T :: Ts) 0)))
- (map (AtomSeq o rpair 0) (atom_schema_of R card T))
+ (atom_seqs_of R card T)
val atom_product = foldl1 Product o map Atom
-val false_atom = Atom 0
-val true_atom = Atom 1
+val false_atom_num = 0
+val true_atom_num = 1
+val false_atom = Atom false_atom_num
+val true_atom = Atom true_atom_num
-fun formula_from_atom r = RelEq (r, true_atom)
-fun atom_from_formula f = RelIf (f, true_atom, false_atom)
-
-fun kodkod_formula_from_term ctxt card frees =
+fun kodkod_formula_from_term ctxt total card complete concrete frees =
let
- fun R_rep_from_S_rep (Type (@{type_name fun}, [T1, @{typ bool}])) r =
- let
- val jss = atom_schema_of S_Rep card T1 |> map (rpair 0)
- |> all_combinations
- in
- map2 (fn i => fn js =>
- RelIf (formula_from_atom (Project (r, [Num i])),
- atom_product js, empty_n_ary_rel (length js)))
- (index_seq 0 (length jss)) jss
- |> foldl1 Union
- end
- | R_rep_from_S_rep (Type (@{type_name fun}, [T1, T2])) r =
- let
- val jss = atom_schema_of S_Rep card T1 |> map (rpair 0)
- |> all_combinations
- val arity2 = arity_of S_Rep card T2
- in
- map2 (fn i => fn js =>
- Product (atom_product js,
- Project (r, num_seq (i * arity2) arity2)
- |> R_rep_from_S_rep T2))
- (index_seq 0 (length jss)) jss
- |> foldl1 Union
- end
+ fun F_from_S_rep (SOME false) r = Not (RelEq (r, false_atom))
+ | F_from_S_rep _ r = RelEq (r, true_atom)
+ fun S_rep_from_F NONE f = RelIf (f, true_atom, false_atom)
+ | S_rep_from_F (SOME true) f = RelIf (f, true_atom, None)
+ | S_rep_from_F (SOME false) f = RelIf (Not f, false_atom, None)
+ fun R_rep_from_S_rep (Type (@{type_name fun}, [T1, T2])) r =
+ if total andalso T2 = bool_T then
+ let
+ val jss = atom_schema_of S_Rep card T1 |> map (rpair 0)
+ |> all_combinations
+ in
+ map2 (fn i => fn js =>
+(*
+ RelIf (F_from_S_rep NONE (Project (r, [Num i])),
+ atom_product js, empty_n_ary_rel (length js))
+*)
+ Join (Project (r, [Num i]),
+ atom_product (false_atom_num :: js))
+ ) (index_seq 0 (length jss)) jss
+ |> foldl1 Union
+ end
+ else
+ let
+ val jss = atom_schema_of S_Rep card T1 |> map (rpair 0)
+ |> all_combinations
+ val arity2 = arity_of S_Rep card T2
+ in
+ map2 (fn i => fn js =>
+ Product (atom_product js,
+ Project (r, num_seq (i * arity2) arity2)
+ |> R_rep_from_S_rep T2))
+ (index_seq 0 (length jss)) jss
+ |> foldl1 Union
+ end
| R_rep_from_S_rep _ r = r
fun S_rep_from_R_rep Ts (T as Type (@{type_name fun}, _)) r =
Comprehension (decls_for S_Rep card Ts T,
RelEq (R_rep_from_S_rep T
(rel_expr_for_bound_var card S_Rep (T :: Ts) 0), r))
| S_rep_from_R_rep _ _ r = r
- fun to_F Ts t =
+ fun partial_eq pos Ts (Type (@{type_name fun}, [T1, T2])) t1 t2 =
+ HOLogic.mk_all ("x", T1,
+ HOLogic.eq_const T2 $ (incr_boundvars 1 t1 $ Bound 0)
+ $ (incr_boundvars 1 t2 $ Bound 0))
+ |> to_F (SOME pos) Ts
+ | partial_eq pos Ts T t1 t2 =
+ if pos andalso not (concrete T) then
+ False
+ else
+ (t1, t2) |> pairself (to_R_rep Ts)
+ |> (if pos then Some o Intersect else Lone o Union)
+ and to_F pos Ts t =
(case t of
- @{const Not} $ t1 => Not (to_F Ts t1)
+ @{const Not} $ t1 => Not (to_F (Option.map not pos) Ts t1)
| @{const False} => False
| @{const True} => True
| Const (@{const_name All}, _) $ Abs (_, T, t') =>
- All (decls_for S_Rep card Ts T, to_F (T :: Ts) t')
+ if pos = SOME true andalso not (complete T) then False
+ else All (decls_for S_Rep card Ts T, to_F pos (T :: Ts) t')
| (t0 as Const (@{const_name All}, _)) $ t1 =>
- to_F Ts (t0 $ eta_expand Ts t1 1)
+ to_F pos Ts (t0 $ eta_expand Ts t1 1)
| Const (@{const_name Ex}, _) $ Abs (_, T, t') =>
- Exist (decls_for S_Rep card Ts T, to_F (T :: Ts) t')
+ if pos = SOME false andalso not (complete T) then True
+ else Exist (decls_for S_Rep card Ts T, to_F pos (T :: Ts) t')
| (t0 as Const (@{const_name Ex}, _)) $ t1 =>
- to_F Ts (t0 $ eta_expand Ts t1 1)
- | Const (@{const_name HOL.eq}, _) $ t1 $ t2 =>
- RelEq (to_R_rep Ts t1, to_R_rep Ts t2)
+ to_F pos Ts (t0 $ eta_expand Ts t1 1)
+ | Const (@{const_name HOL.eq}, Type (_, [T, _])) $ t1 $ t2 =>
+ (case pos of
+ NONE => RelEq (to_R_rep Ts t1, to_R_rep Ts t2)
+ | SOME pos => partial_eq pos Ts T t1 t2)
| Const (@{const_name ord_class.less_eq},
Type (@{type_name fun},
- [Type (@{type_name fun}, [_, @{typ bool}]), _]))
+ [Type (@{type_name fun}, [T', @{typ bool}]), _]))
$ t1 $ t2 =>
- Subset (to_R_rep Ts t1, to_R_rep Ts t2)
- | @{const HOL.conj} $ t1 $ t2 => And (to_F Ts t1, to_F Ts t2)
- | @{const HOL.disj} $ t1 $ t2 => Or (to_F Ts t1, to_F Ts t2)
- | @{const HOL.implies} $ t1 $ t2 => Implies (to_F Ts t1, to_F Ts t2)
- | t1 $ t2 => Subset (to_S_rep Ts t2, to_R_rep Ts t1)
- | Free _ => raise SAME ()
- | Term.Var _ => raise SAME ()
- | Bound _ => raise SAME ()
- | Const (s, _) => raise NOT_SUPPORTED ("constant " ^ quote s)
- | _ => raise TERM ("Minipick.kodkod_formula_from_term.to_F", [t]))
- handle SAME () => formula_from_atom (to_R_rep Ts t)
+ (case pos of
+ NONE => Subset (to_R_rep Ts t1, to_R_rep Ts t2)
+ | SOME true =>
+ Subset (Difference (atom_seq_product_of S_Rep card T',
+ Join (to_R_rep Ts t1, false_atom)),
+ Join (to_R_rep Ts t2, true_atom))
+ | SOME false =>
+ Subset (Join (to_R_rep Ts t1, true_atom),
+ Difference (atom_seq_product_of S_Rep card T',
+ Join (to_R_rep Ts t2, false_atom))))
+ | @{const HOL.conj} $ t1 $ t2 => And (to_F pos Ts t1, to_F pos Ts t2)
+ | @{const HOL.disj} $ t1 $ t2 => Or (to_F pos Ts t1, to_F pos Ts t2)
+ | @{const HOL.implies} $ t1 $ t2 =>
+ Implies (to_F (Option.map not pos) Ts t1, to_F pos Ts t2)
+ | t1 $ t2 =>
+ (case pos of
+ NONE => Subset (to_S_rep Ts t2, to_R_rep Ts t1)
+ | SOME pos =>
+ let
+ val kt1 = to_R_rep Ts t1
+ val kt2 = to_S_rep Ts t2
+ val kT = atom_seq_product_of S_Rep card (fastype_of1 (Ts, t2))
+ in
+ if pos then
+ Not (Subset (kt2, Difference (kT, Join (kt1, true_atom))))
+ else
+ Subset (kt2, Difference (kT, Join (kt1, false_atom)))
+ end)
+ | _ => raise SAME ())
+ handle SAME () => F_from_S_rep pos (to_R_rep Ts t)
and to_S_rep Ts t =
case t of
Const (@{const_name Pair}, _) $ t1 $ t2 =>
@@ -160,6 +196,16 @@
| Const (@{const_name snd}, _) => to_S_rep Ts (eta_expand Ts t 1)
| Bound j => rel_expr_for_bound_var card S_Rep Ts j
| _ => S_rep_from_R_rep Ts (fastype_of1 (Ts, t)) (to_R_rep Ts t)
+ and partial_set_op swap1 swap2 op1 op2 Ts t1 t2 =
+ let
+ val kt1 = to_R_rep Ts t1
+ val kt2 = to_R_rep Ts t2
+ val (a11, a21) = (false_atom, true_atom) |> swap1 ? swap
+ val (a12, a22) = (false_atom, true_atom) |> swap2 ? swap
+ in
+ Union (Product (op1 (Join (kt1, a11), Join (kt2, a12)), true_atom),
+ Product (op2 (Join (kt1, a21), Join (kt2, a22)), false_atom))
+ end
and to_R_rep Ts t =
(case t of
@{const Not} => to_R_rep Ts (eta_expand Ts t 1)
@@ -180,15 +226,51 @@
| @{const HOL.implies} $ _ => to_R_rep Ts (eta_expand Ts t 1)
| @{const HOL.implies} => to_R_rep Ts (eta_expand Ts t 2)
| Const (@{const_name bot_class.bot},
- T as Type (@{type_name fun}, [_, @{typ bool}])) =>
- empty_n_ary_rel (arity_of R_Rep card T)
- | Const (@{const_name insert}, _) $ t1 $ t2 =>
- Union (to_S_rep Ts t1, to_R_rep Ts t2)
+ T as Type (@{type_name fun}, [T', @{typ bool}])) =>
+ if total then empty_n_ary_rel (arity_of (R_Rep total) card T)
+ else Product (atom_seq_product_of (R_Rep total) card T', false_atom)
+ | Const (@{const_name top_class.top},
+ T as Type (@{type_name fun}, [T', @{typ bool}])) =>
+ if total then atom_seq_product_of (R_Rep total) card T
+ else Product (atom_seq_product_of (R_Rep total) card T', true_atom)
+ | Const (@{const_name insert}, Type (_, [T, _])) $ t1 $ t2 =>
+ if total then
+ Union (to_S_rep Ts t1, to_R_rep Ts t2)
+ else
+ let
+ val kt1 = to_S_rep Ts t1
+ val kt2 = to_R_rep Ts t2
+ in
+ RelIf (Some kt1,
+ if arity_of S_Rep card T = 1 then
+ Override (kt2, Product (kt1, true_atom))
+ else
+ Union (Difference (kt2, Product (kt1, false_atom)),
+ Product (kt1, true_atom)),
+ Difference (kt2, Product (atom_seq_product_of S_Rep card T,
+ false_atom)))
+ end
| Const (@{const_name insert}, _) $ _ => to_R_rep Ts (eta_expand Ts t 1)
| Const (@{const_name insert}, _) => to_R_rep Ts (eta_expand Ts t 2)
- | Const (@{const_name trancl}, _) $ t1 =>
- if arity_of R_Rep card (fastype_of1 (Ts, t1)) = 2 then
- Closure (to_R_rep Ts t1)
+ | Const (@{const_name trancl},
+ Type (_, [Type (_, [Type (_, [T', _]), _]), _])) $ t1 =>
+ if arity_of S_Rep card T' = 1 then
+ if total then
+ Closure (to_R_rep Ts t1)
+ else
+ let
+ val kt1 = to_R_rep Ts t1
+ val true_core_kt = Closure (Join (kt1, true_atom))
+ val kTx =
+ atom_seq_product_of S_Rep card (HOLogic.mk_prodT (`I T'))
+ val false_mantle_kt =
+ Difference (kTx,
+ Closure (Difference (kTx, Join (kt1, false_atom))))
+ in
+ Union (Product (Difference (false_mantle_kt, true_core_kt),
+ false_atom),
+ Product (true_core_kt, true_atom))
+ end
else
raise NOT_SUPPORTED "transitive closure for function or pair type"
| Const (@{const_name trancl}, _) => to_R_rep Ts (eta_expand Ts t 1)
@@ -196,7 +278,8 @@
Type (@{type_name fun},
[Type (@{type_name fun}, [_, @{typ bool}]), _]))
$ t1 $ t2 =>
- Intersect (to_R_rep Ts t1, to_R_rep Ts t2)
+ if total then Intersect (to_R_rep Ts t1, to_R_rep Ts t2)
+ else partial_set_op true true Intersect Union Ts t1 t2
| Const (@{const_name inf_class.inf}, _) $ _ =>
to_R_rep Ts (eta_expand Ts t 1)
| Const (@{const_name inf_class.inf}, _) =>
@@ -205,7 +288,8 @@
Type (@{type_name fun},
[Type (@{type_name fun}, [_, @{typ bool}]), _]))
$ t1 $ t2 =>
- Union (to_R_rep Ts t1, to_R_rep Ts t2)
+ if total then Union (to_R_rep Ts t1, to_R_rep Ts t2)
+ else partial_set_op true true Union Intersect Ts t1 t2
| Const (@{const_name sup_class.sup}, _) $ _ =>
to_R_rep Ts (eta_expand Ts t 1)
| Const (@{const_name sup_class.sup}, _) =>
@@ -214,7 +298,8 @@
Type (@{type_name fun},
[Type (@{type_name fun}, [_, @{typ bool}]), _]))
$ t1 $ t2 =>
- Difference (to_R_rep Ts t1, to_R_rep Ts t2)
+ if total then Difference (to_R_rep Ts t1, to_R_rep Ts t2)
+ else partial_set_op true false Intersect Union Ts t1 t2
| Const (@{const_name minus_class.minus},
Type (@{type_name fun},
[Type (@{type_name fun}, [_, @{typ bool}]), _])) $ _ =>
@@ -223,40 +308,47 @@
Type (@{type_name fun},
[Type (@{type_name fun}, [_, @{typ bool}]), _])) =>
to_R_rep Ts (eta_expand Ts t 2)
- | Const (@{const_name Pair}, _) $ _ $ _ => raise SAME ()
- | Const (@{const_name Pair}, _) $ _ => raise SAME ()
- | Const (@{const_name Pair}, _) => raise SAME ()
+ | Const (@{const_name Pair}, _) $ _ $ _ => to_S_rep Ts t
+ | Const (@{const_name Pair}, _) $ _ => to_S_rep Ts t
+ | Const (@{const_name Pair}, _) => to_S_rep Ts t
| Const (@{const_name fst}, _) $ _ => raise SAME ()
| Const (@{const_name fst}, _) => raise SAME ()
| Const (@{const_name snd}, _) $ _ => raise SAME ()
| Const (@{const_name snd}, _) => raise SAME ()
- | Const (_, @{typ bool}) => atom_from_formula (to_F Ts t)
+ | @{const False} => false_atom
+ | @{const True} => true_atom
| Free (x as (_, T)) =>
- Rel (arity_of R_Rep card T, find_index (curry (op =) x) frees)
+ Rel (arity_of (R_Rep total) card T, find_index (curry (op =) x) frees)
| Term.Var _ => raise NOT_SUPPORTED "schematic variables"
| Bound _ => raise SAME ()
| Abs (_, T, t') =>
- (case fastype_of1 (T :: Ts, t') of
- @{typ bool} => Comprehension (decls_for S_Rep card Ts T,
- to_F (T :: Ts) t')
- | T' => Comprehension (decls_for S_Rep card Ts T @
- decls_for R_Rep card (T :: Ts) T',
- Subset (rel_expr_for_bound_var card R_Rep
- (T' :: T :: Ts) 0,
- to_R_rep (T :: Ts) t')))
+ (case (total, fastype_of1 (T :: Ts, t')) of
+ (true, @{typ bool}) =>
+ Comprehension (decls_for S_Rep card Ts T, to_F NONE (T :: Ts) t')
+ | (_, T') =>
+ Comprehension (decls_for S_Rep card Ts T @
+ decls_for (R_Rep total) card (T :: Ts) T',
+ Subset (rel_expr_for_bound_var card (R_Rep total)
+ (T' :: T :: Ts) 0,
+ to_R_rep (T :: Ts) t')))
| t1 $ t2 =>
(case fastype_of1 (Ts, t) of
- @{typ bool} => atom_from_formula (to_F Ts t)
+ @{typ bool} =>
+ if total then
+ S_rep_from_F NONE (to_F NONE Ts t)
+ else
+ RelIf (to_F (SOME true) Ts t, true_atom,
+ RelIf (Not (to_F (SOME false) Ts t), false_atom,
+ None))
| T =>
let val T2 = fastype_of1 (Ts, t2) in
case arity_of S_Rep card T2 of
1 => Join (to_S_rep Ts t2, to_R_rep Ts t1)
| arity2 =>
- let val res_arity = arity_of R_Rep card T in
+ let val res_arity = arity_of (R_Rep total) card T in
Project (Intersect
(Product (to_S_rep Ts t2,
- atom_schema_of R_Rep card T
- |> map (AtomSeq o rpair 0) |> foldl1 Product),
+ atom_seq_product_of (R_Rep total) card T),
to_R_rep Ts t1),
num_seq arity2 res_arity)
end
@@ -264,28 +356,30 @@
| _ => raise NOT_SUPPORTED ("term " ^
quote (Syntax.string_of_term ctxt t)))
handle SAME () => R_rep_from_S_rep (fastype_of1 (Ts, t)) (to_S_rep Ts t)
- in to_F [] end
+ in to_F (if total then NONE else SOME true) [] end
-fun bound_for_free card i (s, T) =
- let val js = atom_schema_of R_Rep card T in
+fun bound_for_free total card i (s, T) =
+ let val js = atom_schema_of (R_Rep total) card T in
([((length js, i), s)],
- [TupleSet [], atom_schema_of R_Rep card T |> map (rpair 0)
+ [TupleSet [], atom_schema_of (R_Rep total) card T |> map (rpair 0)
|> tuple_set_from_atom_schema])
end
-fun declarative_axiom_for_rel_expr card Ts (Type (@{type_name fun}, [T1, T2]))
- r =
- if body_type T2 = bool_T then
+fun declarative_axiom_for_rel_expr total card Ts
+ (Type (@{type_name fun}, [T1, T2])) r =
+ if total andalso body_type T2 = bool_T then
True
else
All (decls_for S_Rep card Ts T1,
- declarative_axiom_for_rel_expr card (T1 :: Ts) T2
+ declarative_axiom_for_rel_expr total card (T1 :: Ts) T2
(List.foldl Join r (vars_for_bound_var card S_Rep (T1 :: Ts) 0)))
- | declarative_axiom_for_rel_expr _ _ _ r = One r
-fun declarative_axiom_for_free card i (_, T) =
- declarative_axiom_for_rel_expr card [] T (Rel (arity_of R_Rep card T, i))
+ | declarative_axiom_for_rel_expr total _ _ _ r =
+ (if total then One else Lone) r
+fun declarative_axiom_for_free total card i (_, T) =
+ declarative_axiom_for_rel_expr total card [] T
+ (Rel (arity_of (R_Rep total) card T, i))
-fun kodkod_problem_from_term ctxt raw_card t =
+fun kodkod_problem_from_term ctxt total raw_card raw_infinite t =
let
val thy = ProofContext.theory_of ctxt
fun card (Type (@{type_name fun}, [T1, T2])) =
@@ -293,14 +387,25 @@
| card (Type (@{type_name prod}, [T1, T2])) = card T1 * card T2
| card @{typ bool} = 2
| card T = Int.max (1, raw_card T)
+ fun complete (Type (@{type_name fun}, [T1, T2])) =
+ concrete T1 andalso complete T2
+ | complete (Type (@{type_name prod}, Ts)) = forall complete Ts
+ | complete T = not (raw_infinite T)
+ and concrete (Type (@{type_name fun}, [T1, T2])) =
+ complete T1 andalso concrete T2
+ | concrete (Type (@{type_name prod}, Ts)) = forall concrete Ts
+ | concrete _ = true
val neg_t = @{const Not} $ Object_Logic.atomize_term thy t
- val _ = fold_types (K o check_type ctxt) neg_t ()
+ val _ = fold_types (K o check_type ctxt raw_infinite) neg_t ()
val frees = Term.add_frees neg_t []
- val bounds = map2 (bound_for_free card) (index_seq 0 (length frees)) frees
+ val bounds =
+ map2 (bound_for_free total card) (index_seq 0 (length frees)) frees
val declarative_axioms =
- map2 (declarative_axiom_for_free card) (index_seq 0 (length frees)) frees
- val formula = kodkod_formula_from_term ctxt card frees neg_t
- |> fold_rev (curry And) declarative_axioms
+ map2 (declarative_axiom_for_free total card)
+ (index_seq 0 (length frees)) frees
+ val formula =
+ neg_t |> kodkod_formula_from_term ctxt total card complete concrete frees
+ |> fold_rev (curry And) declarative_axioms
val univ_card = univ_card 0 0 0 bounds formula
in
{comment = "", settings = [], univ_card = univ_card, tuple_assigns = [],
@@ -324,4 +429,28 @@
| Error (s, _) => error ("Kodkod error: " ^ s)
end
+val default_raw_infinite = member (op =) [@{typ nat}, @{typ int}]
+
+fun minipick ctxt n t =
+ let
+ val thy = ProofContext.theory_of ctxt
+ val {total_consts, ...} = Nitpick_Isar.default_params thy []
+ val totals =
+ total_consts |> Option.map single |> the_default [true, false]
+ fun problem_for (total, k) =
+ kodkod_problem_from_term ctxt total (K k) default_raw_infinite t
+ in
+ (totals, 1 upto n)
+ |-> map_product pair
+ |> map problem_for
+ |> solve_any_kodkod_problem (Proof_Context.theory_of ctxt)
+ end
+
+fun minipick_expect ctxt expect n t =
+ if getenv "KODKODI" <> "" then
+ if minipick ctxt n t = expect then ()
+ else error ("\"minipick_expect\" expected " ^ quote expect)
+ else
+ ()
+
end;
--- a/src/Tools/subtyping.ML Fri Sep 23 14:59:29 2011 +0200
+++ b/src/Tools/subtyping.ML Fri Sep 23 16:50:39 2011 +0200
@@ -9,6 +9,8 @@
val coercion_enabled: bool Config.T
val add_type_map: term -> Context.generic -> Context.generic
val add_coercion: term -> Context.generic -> Context.generic
+ val print_coercions: Proof.context -> unit
+ val print_coercion_maps: Proof.context -> unit
val setup: theory -> theory
end;
@@ -20,46 +22,52 @@
datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT | INVARIANT_TO of typ;
datatype data = Data of
- {coes: term Symreltab.table, (*coercions table*)
- coes_graph: unit Graph.T, (*coercions graph*)
+ {coes: (term * ((typ list * typ list) * term list)) Symreltab.table, (*coercions table*)
+ (*full coercions graph - only used at coercion declaration/deletion*)
+ full_graph: int Graph.T,
+ (*coercions graph restricted to base types - for efficiency reasons strored in the context*)
+ coes_graph: int Graph.T,
tmaps: (term * variance list) Symtab.table}; (*map functions*)
-fun make_data (coes, coes_graph, tmaps) =
- Data {coes = coes, coes_graph = coes_graph, tmaps = tmaps};
+fun make_data (coes, full_graph, coes_graph, tmaps) =
+ Data {coes = coes, full_graph = full_graph, coes_graph = coes_graph, tmaps = tmaps};
structure Data = Generic_Data
(
type T = data;
- val empty = make_data (Symreltab.empty, Graph.empty, Symtab.empty);
+ val empty = make_data (Symreltab.empty, Graph.empty, Graph.empty, Symtab.empty);
val extend = I;
fun merge
- (Data {coes = coes1, coes_graph = coes_graph1, tmaps = tmaps1},
- Data {coes = coes2, coes_graph = coes_graph2, tmaps = tmaps2}) =
- make_data (Symreltab.merge (op aconv) (coes1, coes2),
+ (Data {coes = coes1, full_graph = full_graph1, coes_graph = coes_graph1, tmaps = tmaps1},
+ Data {coes = coes2, full_graph = full_graph2, coes_graph = coes_graph2, tmaps = tmaps2}) =
+ make_data (Symreltab.merge (eq_pair (op aconv)
+ (eq_pair (eq_pair (eq_list (op =)) (eq_list (op =))) (eq_list (op aconv))))
+ (coes1, coes2),
+ Graph.merge (op =) (full_graph1, full_graph2),
Graph.merge (op =) (coes_graph1, coes_graph2),
Symtab.merge (eq_pair (op aconv) (op =)) (tmaps1, tmaps2));
);
fun map_data f =
- Data.map (fn Data {coes, coes_graph, tmaps} =>
- make_data (f (coes, coes_graph, tmaps)));
+ Data.map (fn Data {coes, full_graph, coes_graph, tmaps} =>
+ make_data (f (coes, full_graph, coes_graph, tmaps)));
fun map_coes f =
- map_data (fn (coes, coes_graph, tmaps) =>
- (f coes, coes_graph, tmaps));
+ map_data (fn (coes, full_graph, coes_graph, tmaps) =>
+ (f coes, full_graph, coes_graph, tmaps));
fun map_coes_graph f =
- map_data (fn (coes, coes_graph, tmaps) =>
- (coes, f coes_graph, tmaps));
+ map_data (fn (coes, full_graph, coes_graph, tmaps) =>
+ (coes, full_graph, f coes_graph, tmaps));
-fun map_coes_and_graph f =
- map_data (fn (coes, coes_graph, tmaps) =>
- let val (coes', coes_graph') = f (coes, coes_graph);
- in (coes', coes_graph', tmaps) end);
+fun map_coes_and_graphs f =
+ map_data (fn (coes, full_graph, coes_graph, tmaps) =>
+ let val (coes', full_graph', coes_graph') = f (coes, full_graph, coes_graph);
+ in (coes', full_graph', coes_graph', tmaps) end);
fun map_tmaps f =
- map_data (fn (coes, coes_graph, tmaps) =>
- (coes, coes_graph, f tmaps));
+ map_data (fn (coes, full_graph, coes_graph, tmaps) =>
+ (coes, full_graph, coes_graph, f tmaps));
val rep_data = (fn Data args => args) o Data.get o Context.Proof;
@@ -71,6 +79,9 @@
(** utils **)
+fun restrict_graph G =
+ Graph.subgraph (fn key => if Graph.get_node G key = 0 then true else false) G;
+
fun nameT (Type (s, [])) = s;
fun t_of s = Type (s, []);
@@ -86,10 +97,23 @@
val is_funtype = fn (Type ("fun", [_, _])) => true | _ => false;
val is_identity = fn (Abs (_, _, Bound 0)) => true | _ => false;
+fun instantiate t Ts = Term.subst_TVars
+ ((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts) t;
+
+exception COERCION_GEN_ERROR of unit -> string;
+
+fun inst_collect tye err T U =
+ (case (T, Type_Infer.deref tye U) of
+ (TVar (xi, S), U) => [(xi, U)]
+ | (Type (a, Ts), Type (b, Us)) =>
+ if a <> b then raise error (err ()) else inst_collects tye err Ts Us
+ | (_, U') => if T <> U' then error (err ()) else [])
+and inst_collects tye err Ts Us =
+ fold2 (fn T => fn U => fn is => inst_collect tye err T U @ is) Ts Us [];
+
(* unification *)
-exception TYPE_INFERENCE_ERROR of unit -> string;
exception NO_UNIFIER of string * typ Vartab.table;
fun unify weak ctxt =
@@ -166,7 +190,6 @@
(* Typ_Graph shortcuts *)
-val add_edge = Typ_Graph.add_edge_acyclic;
fun get_preds G T = Typ_Graph.all_preds G [T];
fun get_succs G T = Typ_Graph.all_succs G [T];
fun maybe_new_typnode T G = perhaps (try (Typ_Graph.new_node (T, ()))) G;
@@ -179,16 +202,20 @@
(* Graph shortcuts *)
-fun maybe_new_node s G = perhaps (try (Graph.new_node (s, ()))) G
+fun maybe_new_node s G = perhaps (try (Graph.new_node s)) G
fun maybe_new_nodes ss G = fold maybe_new_node ss G
(** error messages **)
+infixr ++> (* lazy error msg composition *)
+
+fun err ++> str = err #> suffix str
+
fun gen_msg err msg =
- err () ^ "\nNow trying to infer coercions:\n\nCoercion inference failed" ^
- (if msg = "" then "" else ": " ^ msg) ^ "\n";
+ err () ^ "\nNow trying to infer coercions globally.\n\nCoercion inference failed" ^
+ (if msg = "" then "" else ":\n" ^ msg) ^ "\n";
fun prep_output ctxt tye bs ts Ts =
let
@@ -212,7 +239,7 @@
let
val (_, Ts') = prep_output ctxt tye [] [] Ts;
val text =
- msg ^ "\n" ^ "Cannot unify a list of types that should be the same:" ^ "\n" ^
+ msg ^ "\nCannot unify a list of types that should be the same:\n" ^
Pretty.string_of (Pretty.list "[" "]" (map (Syntax.pretty_typ ctxt) Ts'));
in
error text
@@ -225,13 +252,14 @@
let val (t', T') = prep_output ctxt tye bs [t, u] [U', U]
in (t' :: ts, T' :: Ts) end)
packs ([], []);
- val text = cat_lines ([msg, "Cannot fulfil subtype constraints:"] @
- (map2 (fn [t, u] => fn [T, U] => Pretty.string_of (
+ val text = msg ^ "\n" ^ Pretty.string_of (
+ Pretty.big_list "Cannot fulfil subtype constraints:"
+ (map2 (fn [t, u] => fn [T, U] =>
Pretty.block [
Syntax.pretty_typ ctxt T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2,
Syntax.pretty_typ ctxt U, Pretty.brk 3,
Pretty.str "from function application", Pretty.brk 2,
- Pretty.block [Syntax.pretty_term ctxt (t $ u)]]))
+ Pretty.block [Syntax.pretty_term ctxt (t $ u)]])
ts Ts))
in
error text
@@ -257,10 +285,11 @@
val (U', (tye, idx), cs'') = gen cs' bs u tye_idx';
val U = Type_Infer.mk_param idx [];
val V = Type_Infer.mk_param (idx + 1) [];
- val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
+ val tye_idx'' = strong_unify ctxt (U --> V, T) (tye, idx + 2)
handle NO_UNIFIER (msg, _) => error (gen_msg err msg);
val error_pack = (bs, t $ u, U, V, U');
- in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
+ in (V, (tye_idx''),
+ ((U', U), error_pack) :: cs'') end;
in
gen [] []
end;
@@ -314,8 +343,8 @@
| INVARIANT_TO T => (cs, unify_list [T, fst constraint, snd constraint] tye_idx
handle NO_UNIFIER (msg, _) =>
err_list ctxt (gen_msg err
- "failed to unify invariant arguments w.r.t. to the known map function")
- (fst tye_idx) Ts)
+ "failed to unify invariant arguments w.r.t. to the known map function" ^ msg)
+ (fst tye_idx) (T :: Ts))
| INVARIANT => (cs, strong_unify ctxt constraint tye_idx
handle NO_UNIFIER (msg, _) =>
error (gen_msg err ("failed to unify invariant arguments" ^ msg))));
@@ -361,10 +390,11 @@
and simplify done [] tye_idx = (done, tye_idx)
| simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
(case (Type_Infer.deref tye T, Type_Infer.deref tye U) of
- (Type (a, []), Type (b, [])) =>
+ (T1 as Type (a, []), T2 as Type (b, [])) =>
if a = b then simplify done todo tye_idx
else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
- else error (gen_msg err (a ^ " is not a subtype of " ^ b))
+ else error (gen_msg err (quote (Syntax.string_of_typ ctxt T1) ^
+ " is not a subtype of " ^ quote (Syntax.string_of_typ ctxt T2)))
| (Type (a, Ts), Type (b, Us)) =>
if a <> b then error (gen_msg err "different constructors")
(fst tye_idx) error_pack
@@ -459,7 +489,7 @@
else
let
val G' = maybe_new_typnodes [T, U] G;
- val (G'', tye_idx') = (add_edge (T, U) G', tye_idx)
+ val (G'', tye_idx') = (Typ_Graph.add_edge_acyclic (T, U) G', tye_idx)
handle Typ_Graph.CYCLES cycles =>
let
val (tye, idx) =
@@ -537,7 +567,8 @@
in
if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
then apfst (Vartab.update (xi, T)) tye_idx
- else err_bound ctxt (gen_msg err ("assigned simple type " ^ s ^
+ else err_bound ctxt (gen_msg err ("assigned base type " ^
+ quote (Syntax.string_of_typ ctxt T) ^
" clashes with the upper bound of variable " ^
Syntax.string_of_typ ctxt (TVar(xi, S)))) tye (find_error_pack (not lower) key)
end
@@ -586,77 +617,105 @@
(** coercion insertion **)
-fun gen_coercion ctxt tye (T1, T2) =
- (case pairself (Type_Infer.deref tye) (T1, T2) of
- ((Type (a, [])), (Type (b, []))) =>
- if a = b
- then Abs (Name.uu, Type (a, []), Bound 0)
- else
- (case Symreltab.lookup (coes_of ctxt) (a, b) of
- NONE => raise Fail (a ^ " is not a subtype of " ^ b)
- | SOME co => co)
- | ((Type (a, Ts)), (Type (b, Us))) =>
- if a <> b
- then raise Fail ("Different constructors: " ^ a ^ " and " ^ b)
- else
- let
- fun inst t Ts =
- Term.subst_vars
- (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
- fun sub_co (COVARIANT, TU) = SOME (gen_coercion ctxt tye TU)
- | sub_co (CONTRAVARIANT, TU) = SOME (gen_coercion ctxt tye (swap TU))
- | sub_co (INVARIANT_TO T, _) = NONE;
- fun ts_of [] = []
- | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
- in
- (case Symtab.lookup (tmaps_of ctxt) a of
- NONE => raise Fail ("No map function for " ^ a ^ " known")
- | SOME tmap =>
- let
- val used_coes = map_filter sub_co ((snd tmap) ~~ (Ts ~~ Us));
- in
- if null (filter (not o is_identity) used_coes)
- then Abs (Name.uu, Type (a, Ts), Bound 0)
- else Term.list_comb
- (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
- end)
- end
- | (T, U) =>
- if Type.could_unify (T, U)
- then Abs (Name.uu, T, Bound 0)
- else raise Fail ("Cannot generate coercion from "
- ^ Syntax.string_of_typ ctxt T ^ " to " ^ Syntax.string_of_typ ctxt U));
+fun gen_coercion ctxt err tye TU =
+ let
+ fun gen (T1, T2) = (case pairself (Type_Infer.deref tye) (T1, T2) of
+ (T1 as (Type (a, [])), T2 as (Type (b, []))) =>
+ if a = b
+ then Abs (Name.uu, Type (a, []), Bound 0)
+ else
+ (case Symreltab.lookup (coes_of ctxt) (a, b) of
+ NONE => raise COERCION_GEN_ERROR (err ++> quote (Syntax.string_of_typ ctxt T1) ^
+ " is not a subtype of " ^ quote (Syntax.string_of_typ ctxt T2))
+ | SOME (co, _) => co)
+ | ((Type (a, Ts)), (Type (b, Us))) =>
+ if a <> b
+ then
+ (case Symreltab.lookup (coes_of ctxt) (a, b) of
+ (*immediate error - cannot fix complex coercion with the global algorithm*)
+ NONE => error (err () ^ "No coercion known for type constructors: " ^
+ quote a ^ " and " ^ quote b)
+ | SOME (co, ((Ts', Us'), _)) =>
+ let
+ val co_before = gen (T1, Type (a, Ts'));
+ val coT = range_type (fastype_of co_before);
+ val insts = inst_collect tye (err ++> "Could not insert complex coercion")
+ (domain_type (fastype_of co)) coT;
+ val co' = Term.subst_TVars insts co;
+ val co_after = gen (Type (b, (map (typ_subst_TVars insts) Us')), T2);
+ in
+ Abs (Name.uu, T1, Library.foldr (op $)
+ (filter (not o is_identity) [co_after, co', co_before], Bound 0))
+ end)
+ else
+ let
+ fun sub_co (COVARIANT, TU) = SOME (gen TU)
+ | sub_co (CONTRAVARIANT, TU) = SOME (gen (swap TU))
+ | sub_co (INVARIANT_TO _, _) = NONE;
+ fun ts_of [] = []
+ | ts_of (Type ("fun", [x1, x2]) :: xs) = x1 :: x2 :: (ts_of xs);
+ in
+ (case Symtab.lookup (tmaps_of ctxt) a of
+ NONE => raise COERCION_GEN_ERROR
+ (err ++> "No map function for " ^ quote a ^ " known")
+ | SOME tmap =>
+ let
+ val used_coes = map_filter sub_co ((snd tmap) ~~ (Ts ~~ Us));
+ in
+ if null (filter (not o is_identity) used_coes)
+ then Abs (Name.uu, Type (a, Ts), Bound 0)
+ else Term.list_comb
+ (instantiate (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
+ end)
+ end
+ | (T, U) =>
+ if Type.could_unify (T, U)
+ then Abs (Name.uu, T, Bound 0)
+ else raise COERCION_GEN_ERROR (err ++> "Cannot generate coercion from " ^
+ quote (Syntax.string_of_typ ctxt T) ^ " to " ^
+ quote (Syntax.string_of_typ ctxt U)));
+ in
+ gen TU
+ end;
-fun insert_coercions ctxt tye ts =
+fun function_of ctxt err tye T =
+ (case Type_Infer.deref tye T of
+ Type (C, Ts) =>
+ (case Symreltab.lookup (coes_of ctxt) (C, "fun") of
+ NONE => error (err () ^ "No complex coercion from " ^ quote C ^ " to fun")
+ | SOME (co, ((Ts', _), _)) =>
+ let
+ val co_before = gen_coercion ctxt err tye (Type (C, Ts), Type (C, Ts'));
+ val coT = range_type (fastype_of co_before);
+ val insts = inst_collect tye (err ++> "Could not insert complex coercion")
+ (domain_type (fastype_of co)) coT;
+ val co' = Term.subst_TVars insts co;
+ in
+ Abs (Name.uu, Type (C, Ts), Library.foldr (op $)
+ (filter (not o is_identity) [co', co_before], Bound 0))
+ end)
+ | T' => error (err () ^ "No complex coercion from " ^
+ quote (Syntax.string_of_typ ctxt T') ^ " to fun"));
+
+fun insert_coercions ctxt (tye, idx) ts =
let
- fun insert _ (Const (c, T)) =
- let val T' = T;
- in (Const (c, T'), T') end
- | insert _ (Free (x, T)) =
- let val T' = T;
- in (Free (x, T'), T') end
- | insert _ (Var (xi, T)) =
- let val T' = T;
- in (Var (xi, T'), T') end
+ fun insert _ (Const (c, T)) = (Const (c, T), T)
+ | insert _ (Free (x, T)) = (Free (x, T), T)
+ | insert _ (Var (xi, T)) = (Var (xi, T), T)
| insert bs (Bound i) =
let val T = nth bs i handle General.Subscript => err_loose i;
in (Bound i, T) end
| insert bs (Abs (x, T, t)) =
- let
- val T' = T;
- val (t', T'') = insert (T' :: bs) t;
- in
- (Abs (x, T', t'), T' --> T'')
- end
+ let val (t', T') = insert (T :: bs) t;
+ in (Abs (x, T, t'), T --> T') end
| insert bs (t $ u) =
let
- val (t', Type ("fun", [U, T])) =
- apsnd (Type_Infer.deref tye) (insert bs t);
+ val (t', Type ("fun", [U, T])) = apsnd (Type_Infer.deref tye) (insert bs t);
val (u', U') = insert bs u;
in
if can (fn TU => strong_unify ctxt TU (tye, 0)) (U, U')
then (t' $ u', T)
- else (t' $ (gen_coercion ctxt tye (U', U) $ u'), T)
+ else (t' $ (gen_coercion ctxt (K "") tye (U', U) $ u'), T)
end
in
map (fst o insert []) ts
@@ -685,24 +744,51 @@
val V = Type_Infer.mk_param idx [];
val (tu, tye_idx'') = (t' $ u', strong_unify ctxt (U --> V, T) (tye, idx + 1))
handle NO_UNIFIER (msg, tye') =>
- raise TYPE_INFERENCE_ERROR (err_appl_msg ctxt msg tye' bs t T u U);
+ let
+ val err = err_appl_msg ctxt msg tye' bs t T u U;
+ val W = Type_Infer.mk_param (idx + 1) [];
+ val (t'', (tye', idx')) =
+ (t', strong_unify ctxt (W --> V, T) (tye, idx + 2))
+ handle NO_UNIFIER _ =>
+ let
+ val err' =
+ err ++> "\nLocal coercion insertion on the operator failed:\n";
+ val co = function_of ctxt err' tye T;
+ val (t'', T'', tye_idx'') = inf bs (co $ t') (tye, idx + 2);
+ in
+ (t'', strong_unify ctxt (W --> V, T'') tye_idx''
+ handle NO_UNIFIER (msg, _) => error (err' () ^ msg))
+ end;
+ val err' = err ++> (if t' aconv t'' then ""
+ else "\nSuccessfully coerced the operand to a function of type:\n" ^
+ Syntax.string_of_typ ctxt
+ (the_single (snd (prep_output ctxt tye' bs [] [W --> V]))) ^ "\n") ^
+ "\nLocal coercion insertion on the operand failed:\n";
+ val co = gen_coercion ctxt err' tye' (U, W);
+ val (u'', U', tye_idx') =
+ inf bs (if is_identity co then u else co $ u) (tye', idx');
+ in
+ (t'' $ u'', strong_unify ctxt (U', W) tye_idx'
+ handle NO_UNIFIER (msg, _) => raise COERCION_GEN_ERROR (err' ++> msg))
+ end;
in (tu, V, tye_idx'') end;
fun infer_single t tye_idx =
- let val (t, _, tye_idx') = inf [] t tye_idx;
+ let val (t, _, tye_idx') = inf [] t tye_idx
in (t, tye_idx') end;
val (ts', (tye, _)) = (fold_map infer_single ts (Vartab.empty, idx)
- handle TYPE_INFERENCE_ERROR err =>
+ handle COERCION_GEN_ERROR err =>
let
fun gen_single t (tye_idx, constraints) =
- let val (_, tye_idx', constraints') = generate_constraints ctxt err t tye_idx
+ let val (_, tye_idx', constraints') =
+ generate_constraints ctxt (err ++> "\n") t tye_idx
in (tye_idx', constraints' @ constraints) end;
val (tye_idx, constraints) = fold gen_single ts ((Vartab.empty, idx), []);
- val (tye, idx) = process_constraints ctxt err constraints tye_idx;
+ val (tye, idx) = process_constraints ctxt (err ++> "\n") constraints tye_idx;
in
- (insert_coercions ctxt tye ts, (tye, idx))
+ (insert_coercions ctxt (tye, idx) ts, (tye, idx))
end);
val (_, ts'') = Type_Infer.finish ctxt tye ([], ts');
@@ -728,11 +814,11 @@
val ctxt = Context.proof_of context;
val t = singleton (Variable.polymorphic ctxt) raw_t;
- fun err_str t = "\n\nThe provided function has the type\n" ^
+ fun err_str t = "\n\nThe provided function has the type:\n" ^
Syntax.string_of_typ ctxt (fastype_of t) ^
- "\n\nThe general type signature of a map function is" ^
+ "\n\nThe general type signature of a map function is:" ^
"\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [y1, ..., yn]" ^
- "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi)";
+ "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi).";
val ((fis, T1), T2) = apfst split_last (strip_type (fastype_of t))
handle Empty => error ("Not a proper map function:" ^ err_str t);
@@ -766,56 +852,150 @@
map_tmaps (Symtab.update (snd res, (t, res_av))) context
end;
+fun transitive_coercion ctxt tab G (a, b) =
+ let
+ fun safe_app t (Abs (x, T', u)) =
+ let
+ val t' = map_types Type_Infer.paramify_vars t;
+ in
+ singleton (coercion_infer_types ctxt) (Abs(x, T', (t' $ u)))
+ end;
+ val path = hd (Graph.irreducible_paths G (a, b));
+ val path' = fst (split_last path) ~~ tl path;
+ val coercions = map (fst o the o Symreltab.lookup tab) path';
+ val trans_co = singleton (Variable.polymorphic ctxt)
+ (fold safe_app coercions (Abs (Name.uu, dummyT, Bound 0)));
+ val (Ts, Us) = pairself (snd o Term.dest_Type) (Term.dest_funT (type_of trans_co))
+ in
+ (trans_co, ((Ts, Us), coercions))
+ end;
+
fun add_coercion raw_t context =
let
val ctxt = Context.proof_of context;
val t = singleton (Variable.polymorphic ctxt) raw_t;
- fun err_coercion () = error ("Bad type for coercion " ^
- Syntax.string_of_term ctxt t ^ ":\n" ^
+ fun err_coercion () = error ("Bad type for a coercion:\n" ^
+ Syntax.string_of_term ctxt t ^ " :: " ^
Syntax.string_of_typ ctxt (fastype_of t));
val (T1, T2) = Term.dest_funT (fastype_of t)
handle TYPE _ => err_coercion ();
- val a =
- (case T1 of
- Type (x, []) => x
- | _ => err_coercion ());
+ val (a, Ts) = Term.dest_Type T1
+ handle TYPE _ => err_coercion ();
+
+ val (b, Us) = Term.dest_Type T2
+ handle TYPE _ => err_coercion ();
- val b =
- (case T2 of
- Type (x, []) => x
- | _ => err_coercion ());
-
- fun coercion_data_update (tab, G) =
+ fun coercion_data_update (tab, G, _) =
let
- val G' = maybe_new_nodes [a, b] G
+ val G' = maybe_new_nodes [(a, length Ts), (b, length Us)] G
val G'' = Graph.add_edge_trans_acyclic (a, b) G'
- handle Graph.CYCLES _ => error (a ^ " is already a subtype of " ^ b ^
- "!\n\nCannot add coercion of type: " ^ a ^ " => " ^ b);
+ handle Graph.CYCLES _ => error (
+ Syntax.string_of_typ ctxt T2 ^ " is already a subtype of " ^
+ Syntax.string_of_typ ctxt T1 ^ "!\n\nCannot add coercion of type: " ^
+ Syntax.string_of_typ ctxt (T1 --> T2));
val new_edges =
flat (Graph.dest G'' |> map (fn (x, ys) => ys |> map_filter (fn y =>
if Graph.is_edge G' (x, y) then NONE else SOME (x, y))));
val G_and_new = Graph.add_edge (a, b) G';
- fun complex_coercion tab G (a, b) =
- let
- val path = hd (Graph.irreducible_paths G (a, b))
- val path' = fst (split_last path) ~~ tl path
- in Abs (Name.uu, Type (a, []),
- fold (fn t => fn u => t $ u) (map (the o Symreltab.lookup tab) path') (Bound 0))
- end;
-
val tab' = fold
- (fn pair => fn tab => Symreltab.update (pair, complex_coercion tab G_and_new pair) tab)
+ (fn pair => fn tab =>
+ Symreltab.update (pair, transitive_coercion ctxt tab G_and_new pair) tab)
(filter (fn pair => pair <> (a, b)) new_edges)
- (Symreltab.update ((a, b), t) tab);
+ (Symreltab.update ((a, b), (t, ((Ts, Us), []))) tab);
in
- (tab', G'')
+ (tab', G'', restrict_graph G'')
end;
in
- map_coes_and_graph coercion_data_update context
+ map_coes_and_graphs coercion_data_update context
+ end;
+
+fun delete_coercion raw_t context =
+ let
+ val ctxt = Context.proof_of context;
+ val t = singleton (Variable.polymorphic ctxt) raw_t;
+
+ fun err_coercion the = error ("Not" ^
+ (if the then " the defined " else " a ") ^ "coercion:\n" ^
+ Syntax.string_of_term ctxt t ^ " :: " ^
+ Syntax.string_of_typ ctxt (fastype_of t));
+
+ val (T1, T2) = Term.dest_funT (fastype_of t)
+ handle TYPE _ => err_coercion false;
+
+ val (a, Ts) = dest_Type T1
+ handle TYPE _ => err_coercion false;
+
+ val (b, Us) = dest_Type T2
+ handle TYPE _ => err_coercion false;
+
+ fun delete_and_insert tab G =
+ let
+ val pairs =
+ Symreltab.fold (fn ((a, b), (_, (_, ts))) => fn pairs =>
+ if member (op aconv) ts t then (a, b) :: pairs else pairs) tab [(a, b)];
+ fun delete pair (G, tab) = (Graph.del_edge pair G, Symreltab.delete_safe pair tab);
+ val (G', tab') = fold delete pairs (G, tab);
+ fun reinsert pair (G, xs) = (case (Graph.irreducible_paths G pair) of
+ [] => (G, xs)
+ | _ => (Graph.add_edge pair G, (pair, transitive_coercion ctxt tab' G' pair) :: xs));
+ val (G'', ins) = fold reinsert pairs (G', []);
+ in
+ (fold Symreltab.update ins tab', G'', restrict_graph G'')
+ end
+
+ fun show_term t = Pretty.block [Syntax.pretty_term ctxt t,
+ Pretty.str " :: ", Syntax.pretty_typ ctxt (fastype_of t)]
+
+ fun coercion_data_update (tab, G, _) =
+ (case Symreltab.lookup tab (a, b) of
+ NONE => err_coercion false
+ | SOME (t', (_, [])) => if t aconv t'
+ then delete_and_insert tab G
+ else err_coercion true
+ | SOME (t', (_, ts)) => if t aconv t'
+ then error ("Cannot delete the automatically derived coercion:\n" ^
+ Syntax.string_of_term ctxt t ^ " :: " ^
+ Syntax.string_of_typ ctxt (fastype_of t) ^
+ Pretty.string_of (Pretty.big_list "\n\nDeleting one of the coercions:"
+ (map show_term ts)) ^
+ "\nwill also remove the transitive coercion.")
+ else err_coercion true);
+ in
+ map_coes_and_graphs coercion_data_update context
+ end;
+
+fun print_coercions ctxt =
+ let
+ fun separate _ [] = ([], [])
+ | separate P (x::xs) = (if P x then apfst else apsnd) (cons x) (separate P xs);
+ val (simple, complex) =
+ separate (fn (_, (_, ((Ts, Us), _))) => null Ts andalso null Us)
+ (Symreltab.dest (coes_of ctxt));
+ fun show_coercion ((a, b), (t, ((Ts, Us), _))) = Pretty.block [
+ Syntax.pretty_typ ctxt (Type (a, Ts)),
+ Pretty.brk 1, Pretty.str "<:", Pretty.brk 1,
+ Syntax.pretty_typ ctxt (Type (b, Us)),
+ Pretty.brk 3, Pretty.block [Pretty.str "using", Pretty.brk 1,
+ Pretty.quote (Syntax.pretty_term ctxt t)]];
+ in
+ Pretty.big_list "Coercions:"
+ [Pretty.big_list "between base types:" (map show_coercion simple),
+ Pretty.big_list "other:" (map show_coercion complex)]
+ |> Pretty.writeln
+ end;
+
+fun print_coercion_maps ctxt =
+ let
+ fun show_map (x, (t, _)) = Pretty.block [
+ Pretty.str x, Pretty.str ":", Pretty.brk 1,
+ Pretty.quote (Syntax.pretty_term ctxt t)];
+ in
+ Pretty.big_list "Coercion maps:" (map show_map (Symtab.dest (tmaps_of ctxt)))
+ |> Pretty.writeln
end;
@@ -826,8 +1006,21 @@
Attrib.setup @{binding coercion}
(Args.term >> (fn t => Thm.declaration_attribute (K (add_coercion t))))
"declaration of new coercions" #>
+ Attrib.setup @{binding coercion_delete}
+ (Args.term >> (fn t => Thm.declaration_attribute (K (delete_coercion t))))
+ "deletion of coercions" #>
Attrib.setup @{binding coercion_map}
(Args.term >> (fn t => Thm.declaration_attribute (K (add_type_map t))))
"declaration of new map functions";
+
+(* outer syntax commands *)
+
+val _ =
+ Outer_Syntax.improper_command "print_coercions" "print all coercions" Keyword.diag
+ (Scan.succeed (Toplevel.keep (print_coercions o Toplevel.context_of)))
+val _ =
+ Outer_Syntax.improper_command "print_coercion_maps" "print all coercion maps" Keyword.diag
+ (Scan.succeed (Toplevel.keep (print_coercion_maps o Toplevel.context_of)))
+
end;