improving code generation for multisets; adding exhaustive quickcheck generators for multisets
--- a/src/HOL/Library/Multiset.thy Tue Jan 10 10:17:07 2012 +0100
+++ b/src/HOL/Library/Multiset.thy Tue Jan 10 10:17:09 2012 +0100
@@ -5,7 +5,7 @@
header {* (Finite) multisets *}
theory Multiset
-imports Main
+imports Main AList
begin
subsection {* The type of multisets *}
@@ -1041,7 +1041,81 @@
by (cases "i = j") (simp_all add: multiset_of_update nth_mem_multiset_of)
-subsubsection {* Association lists -- including rudimentary code generation *}
+subsubsection {* Association lists -- including code generation *}
+
+text {* Preliminaries *}
+
+text {* Raw operations on lists *}
+
+definition join_raw :: "('key \<Rightarrow> 'val \<times> 'val \<Rightarrow> 'val) \<Rightarrow> ('key \<times> 'val) list \<Rightarrow> ('key \<times> 'val) list \<Rightarrow> ('key \<times> 'val) list"
+where
+ "join_raw f xs ys = foldr (\<lambda>(k, v). map_default k v (%v'. f k (v', v))) ys xs"
+
+lemma join_raw_Nil [simp]:
+ "join_raw f xs [] = xs"
+by (simp add: join_raw_def)
+
+lemma join_raw_Cons [simp]:
+ "join_raw f xs ((k, v) # ys) = map_default k v (%v'. f k (v', v)) (join_raw f xs ys)"
+by (simp add: join_raw_def)
+
+lemma map_of_join_raw:
+ assumes "distinct (map fst ys)"
+ shows "map_of (join_raw f xs ys) x = (case map_of xs x of None => map_of ys x | Some v => (case map_of ys x of None => Some v | Some v' => Some (f x (v, v'))))"
+using assms
+apply (induct ys)
+apply (auto simp add: map_of_map_default split: option.split)
+apply (metis map_of_eq_None_iff option.simps(2) weak_map_of_SomeI)
+by (metis Some_eq_map_of_iff map_of_eq_None_iff option.simps(2))
+
+lemma distinct_join_raw:
+ assumes "distinct (map fst xs)"
+ shows "distinct (map fst (join_raw f xs ys))"
+using assms
+proof (induct ys)
+ case (Cons y ys)
+ thus ?case by (cases y) (simp add: distinct_map_default)
+qed auto
+
+definition
+ "subtract_entries_raw xs ys = foldr (%(k, v). AList_Impl.map_entry k (%v'. v' - v)) ys xs"
+
+lemma map_of_subtract_entries_raw:
+ "distinct (map fst ys) ==> map_of (subtract_entries_raw xs ys) x = (case map_of xs x of None => None | Some v => (case map_of ys x of None => Some v | Some v' => Some (v - v')))"
+unfolding subtract_entries_raw_def
+apply (induct ys)
+apply auto
+apply (simp split: option.split)
+apply (simp add: map_of_map_entry)
+apply (auto split: option.split)
+apply (metis map_of_eq_None_iff option.simps(3) option.simps(4))
+by (metis map_of_eq_None_iff option.simps(4) option.simps(5))
+
+lemma distinct_subtract_entries_raw:
+ assumes "distinct (map fst xs)"
+ shows "distinct (map fst (subtract_entries_raw xs ys))"
+using assms
+unfolding subtract_entries_raw_def by (induct ys) (auto simp add: distinct_map_entry)
+
+text {* Operations on alists *}
+
+definition join
+where
+ "join f xs ys = AList.Alist (join_raw f (AList.impl_of xs) (AList.impl_of ys))"
+
+lemma [code abstract]:
+ "AList.impl_of (join f xs ys) = join_raw f (AList.impl_of xs) (AList.impl_of ys)"
+unfolding join_def by (simp add: Alist_inverse distinct_join_raw)
+
+definition subtract_entries
+where
+ "subtract_entries xs ys = AList.Alist (subtract_entries_raw (AList.impl_of xs) (AList.impl_of ys))"
+
+lemma [code abstract]:
+ "AList.impl_of (subtract_entries xs ys) = subtract_entries_raw (AList.impl_of xs) (AList.impl_of ys)"
+unfolding subtract_entries_def by (simp add: Alist_inverse distinct_subtract_entries_raw)
+
+text {* Implementing multisets by means of association lists *}
definition count_of :: "('a \<times> nat) list \<Rightarrow> 'a \<Rightarrow> nat" where
"count_of xs x = (case map_of xs x of None \<Rightarrow> 0 | Some n \<Rightarrow> n)"
@@ -1074,32 +1148,55 @@
by (induct xs) (simp_all add: count_of_def)
lemma count_of_filter:
- "count_of (filter (P \<circ> fst) xs) x = (if P x then count_of xs x else 0)"
+ "count_of (List.filter (P \<circ> fst) xs) x = (if P x then count_of xs x else 0)"
by (induct xs) auto
-definition Bag :: "('a \<times> nat) list \<Rightarrow> 'a multiset" where
- "Bag xs = Abs_multiset (count_of xs)"
+lemma count_of_map_default [simp]:
+ "count_of (map_default x b (%x. x + b) xs) y = (if x = y then count_of xs x + b else count_of xs y)"
+unfolding count_of_def by (simp add: map_of_map_default split: option.split)
+
+lemma count_of_join_raw:
+ "distinct (map fst ys) ==> count_of xs x + count_of ys x = count_of (join_raw (%x (x, y). x + y) xs ys) x"
+unfolding count_of_def by (simp add: map_of_join_raw split: option.split)
+
+lemma count_of_subtract_entries_raw:
+ "distinct (map fst ys) ==> count_of xs x - count_of ys x = count_of (subtract_entries_raw xs ys) x"
+unfolding count_of_def by (simp add: map_of_subtract_entries_raw split: option.split)
+
+text {* Code equations for multiset operations *}
+
+definition Bag :: "('a, nat) alist \<Rightarrow> 'a multiset" where
+ "Bag xs = Abs_multiset (count_of (AList.impl_of xs))"
code_datatype Bag
lemma count_Bag [simp, code]:
- "count (Bag xs) = count_of xs"
+ "count (Bag xs) = count_of (AList.impl_of xs)"
by (simp add: Bag_def count_of_multiset Abs_multiset_inverse)
lemma Mempty_Bag [code]:
- "{#} = Bag []"
- by (simp add: multiset_eq_iff)
+ "{#} = Bag (Alist [])"
+ by (simp add: multiset_eq_iff alist.Alist_inverse)
lemma single_Bag [code]:
- "{#x#} = Bag [(x, 1)]"
- by (simp add: multiset_eq_iff)
+ "{#x#} = Bag (Alist [(x, 1)])"
+ by (simp add: multiset_eq_iff alist.Alist_inverse)
+
+lemma union_Bag [code]:
+ "Bag xs + Bag ys = Bag (join (\<lambda>x (n1, n2). n1 + n2) xs ys)"
+by (rule multiset_eqI) (simp add: count_of_join_raw alist.Alist_inverse distinct_join_raw join_def)
+
+lemma minus_Bag [code]:
+ "Bag xs - Bag ys = Bag (subtract_entries xs ys)"
+by (rule multiset_eqI)
+ (simp add: count_of_subtract_entries_raw alist.Alist_inverse distinct_subtract_entries_raw subtract_entries_def)
lemma filter_Bag [code]:
- "Multiset.filter P (Bag xs) = Bag (filter (P \<circ> fst) xs)"
- by (rule multiset_eqI) (simp add: count_of_filter)
+ "Multiset.filter P (Bag xs) = Bag (AList.filter (P \<circ> fst) xs)"
+by (rule multiset_eqI) (simp add: count_of_filter impl_of_filter)
lemma mset_less_eq_Bag [code]:
- "Bag xs \<le> A \<longleftrightarrow> (\<forall>(x, n) \<in> set xs. count_of xs x \<le> count A x)"
+ "Bag xs \<le> A \<longleftrightarrow> (\<forall>(x, n) \<in> set (AList.impl_of xs). count_of (AList.impl_of xs) x \<le> count A x)"
(is "?lhs \<longleftrightarrow> ?rhs")
proof
assume ?lhs then show ?rhs
@@ -1109,8 +1206,8 @@
show ?lhs
proof (rule mset_less_eqI)
fix x
- from `?rhs` have "count_of xs x \<le> count A x"
- by (cases "x \<in> fst ` set xs") (auto simp add: count_of_empty)
+ from `?rhs` have "count_of (AList.impl_of xs) x \<le> count A x"
+ by (cases "x \<in> fst ` set (AList.impl_of xs)") (auto simp add: count_of_empty)
then show "count (Bag xs) x \<le> count A x"
by (simp add: mset_le_def count_Bag)
qed
@@ -1127,12 +1224,10 @@
end
-lemma [code nbe]:
- "HOL.equal (A :: 'a::equal multiset) A \<longleftrightarrow> True"
- by (fact equal_refl)
+text {* Quickcheck generators *}
definition (in term_syntax)
- bagify :: "('a\<Colon>typerep \<times> nat) list \<times> (unit \<Rightarrow> Code_Evaluation.term)
+ bagify :: "('a\<Colon>typerep, nat) alist \<times> (unit \<Rightarrow> Code_Evaluation.term)
\<Rightarrow> 'a multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
[code_unfold]: "bagify xs = Code_Evaluation.valtermify Bag {\<cdot>} xs"
@@ -1152,6 +1247,28 @@
no_notation fcomp (infixl "\<circ>>" 60)
no_notation scomp (infixl "\<circ>\<rightarrow>" 60)
+instantiation multiset :: (exhaustive) exhaustive
+begin
+
+definition exhaustive_multiset :: "('a multiset => (bool * term list) option) => code_numeral => (bool * term list) option"
+where
+ "exhaustive_multiset f i = Quickcheck_Exhaustive.exhaustive (%xs. f (Bag xs)) i"
+
+instance ..
+
+end
+
+instantiation multiset :: (full_exhaustive) full_exhaustive
+begin
+
+definition full_exhaustive_multiset :: "('a multiset * (unit => term) => (bool * term list) option) => code_numeral => (bool * term list) option"
+where
+ "full_exhaustive_multiset f i = Quickcheck_Exhaustive.full_exhaustive (%xs. f (bagify xs)) i"
+
+instance ..
+
+end
+
hide_const (open) bagify