improving code generation for multisets; adding exhaustive quickcheck generators for multisets
authorbulwahn
Tue, 10 Jan 2012 10:17:09 +0100
changeset 46168 bef8c811df20
parent 46167 25eba8a5d7d0
child 46169 321abd584588
improving code generation for multisets; adding exhaustive quickcheck generators for multisets
src/HOL/Library/Multiset.thy
--- a/src/HOL/Library/Multiset.thy	Tue Jan 10 10:17:07 2012 +0100
+++ b/src/HOL/Library/Multiset.thy	Tue Jan 10 10:17:09 2012 +0100
@@ -5,7 +5,7 @@
 header {* (Finite) multisets *}
 
 theory Multiset
-imports Main
+imports Main AList
 begin
 
 subsection {* The type of multisets *}
@@ -1041,7 +1041,81 @@
   by (cases "i = j") (simp_all add: multiset_of_update nth_mem_multiset_of)
 
 
-subsubsection {* Association lists -- including rudimentary code generation *}
+subsubsection {* Association lists -- including code generation *}
+
+text {* Preliminaries *}
+
+text {* Raw operations on lists *}
+
+definition join_raw :: "('key \<Rightarrow> 'val \<times> 'val \<Rightarrow> 'val) \<Rightarrow> ('key \<times> 'val) list \<Rightarrow> ('key \<times> 'val) list \<Rightarrow> ('key \<times> 'val) list"
+where
+  "join_raw f xs ys = foldr (\<lambda>(k, v). map_default k v (%v'. f k (v', v))) ys xs"
+
+lemma join_raw_Nil [simp]:
+  "join_raw f xs [] = xs"
+by (simp add: join_raw_def)
+
+lemma join_raw_Cons [simp]:
+  "join_raw f xs ((k, v) # ys) = map_default k v (%v'. f k (v', v)) (join_raw f xs ys)"
+by (simp add: join_raw_def)
+
+lemma map_of_join_raw:
+  assumes "distinct (map fst ys)"
+  shows "map_of (join_raw f xs ys) x = (case map_of xs x of None => map_of ys x | Some v => (case map_of ys x of None => Some v | Some v' => Some (f x (v, v'))))"
+using assms
+apply (induct ys)
+apply (auto simp add: map_of_map_default split: option.split)
+apply (metis map_of_eq_None_iff option.simps(2) weak_map_of_SomeI)
+by (metis Some_eq_map_of_iff map_of_eq_None_iff option.simps(2))
+
+lemma distinct_join_raw:
+  assumes "distinct (map fst xs)"
+  shows "distinct (map fst (join_raw f xs ys))"
+using assms
+proof (induct ys)
+  case (Cons y ys)
+  thus ?case by (cases y) (simp add: distinct_map_default)
+qed auto
+
+definition
+  "subtract_entries_raw xs ys = foldr (%(k, v). AList_Impl.map_entry k (%v'. v' - v)) ys xs"
+
+lemma map_of_subtract_entries_raw:
+  "distinct (map fst ys) ==> map_of (subtract_entries_raw xs ys) x = (case map_of xs x of None => None | Some v => (case map_of ys x of None => Some v | Some v' => Some (v - v')))"
+unfolding subtract_entries_raw_def
+apply (induct ys)
+apply auto
+apply (simp split: option.split)
+apply (simp add: map_of_map_entry)
+apply (auto split: option.split)
+apply (metis map_of_eq_None_iff option.simps(3) option.simps(4))
+by (metis map_of_eq_None_iff option.simps(4) option.simps(5))
+
+lemma distinct_subtract_entries_raw:
+  assumes "distinct (map fst xs)"
+  shows "distinct (map fst (subtract_entries_raw xs ys))"
+using assms
+unfolding subtract_entries_raw_def by (induct ys) (auto simp add: distinct_map_entry)
+
+text {* Operations on alists *}
+
+definition join
+where
+  "join f xs ys = AList.Alist (join_raw f (AList.impl_of xs) (AList.impl_of ys))" 
+
+lemma [code abstract]:
+  "AList.impl_of (join f xs ys) = join_raw f (AList.impl_of xs) (AList.impl_of ys)"
+unfolding join_def by (simp add: Alist_inverse distinct_join_raw)
+
+definition subtract_entries
+where
+  "subtract_entries xs ys = AList.Alist (subtract_entries_raw (AList.impl_of xs) (AList.impl_of ys))"
+
+lemma [code abstract]:
+  "AList.impl_of (subtract_entries xs ys) = subtract_entries_raw (AList.impl_of xs) (AList.impl_of ys)"
+unfolding subtract_entries_def by (simp add: Alist_inverse distinct_subtract_entries_raw)
+
+text {* Implementing multisets by means of association lists *}
 
 definition count_of :: "('a \<times> nat) list \<Rightarrow> 'a \<Rightarrow> nat" where
   "count_of xs x = (case map_of xs x of None \<Rightarrow> 0 | Some n \<Rightarrow> n)"
@@ -1074,32 +1148,55 @@
   by (induct xs) (simp_all add: count_of_def)
 
 lemma count_of_filter:
-  "count_of (filter (P \<circ> fst) xs) x = (if P x then count_of xs x else 0)"
+  "count_of (List.filter (P \<circ> fst) xs) x = (if P x then count_of xs x else 0)"
   by (induct xs) auto
 
-definition Bag :: "('a \<times> nat) list \<Rightarrow> 'a multiset" where
-  "Bag xs = Abs_multiset (count_of xs)"
+lemma count_of_map_default [simp]:
+  "count_of (map_default x b (%x. x + b) xs) y = (if x = y then count_of xs x + b else count_of xs y)"
+unfolding count_of_def by (simp add: map_of_map_default split: option.split)
+
+lemma count_of_join_raw:
+  "distinct (map fst ys) ==> count_of xs x + count_of ys x = count_of (join_raw (%x (x, y). x + y) xs ys) x"
+unfolding count_of_def by (simp add: map_of_join_raw split: option.split)
+
+lemma count_of_subtract_entries_raw:
+  "distinct (map fst ys) ==> count_of xs x - count_of ys x = count_of (subtract_entries_raw xs ys) x"
+unfolding count_of_def by (simp add: map_of_subtract_entries_raw split: option.split)
+
+text {* Code equations for multiset operations *}
+
+definition Bag :: "('a, nat) alist \<Rightarrow> 'a multiset" where
+  "Bag xs = Abs_multiset (count_of (AList.impl_of xs))"
 
 code_datatype Bag
 
 lemma count_Bag [simp, code]:
-  "count (Bag xs) = count_of xs"
+  "count (Bag xs) = count_of (AList.impl_of xs)"
   by (simp add: Bag_def count_of_multiset Abs_multiset_inverse)
 
 lemma Mempty_Bag [code]:
-  "{#} = Bag []"
-  by (simp add: multiset_eq_iff)
+  "{#} = Bag (Alist [])"
+  by (simp add: multiset_eq_iff alist.Alist_inverse)
   
 lemma single_Bag [code]:
-  "{#x#} = Bag [(x, 1)]"
-  by (simp add: multiset_eq_iff)
+  "{#x#} = Bag (Alist [(x, 1)])"
+  by (simp add: multiset_eq_iff alist.Alist_inverse)
+
+lemma union_Bag [code]:
+  "Bag xs + Bag ys = Bag (join (\<lambda>x (n1, n2). n1 + n2) xs ys)"
+by (rule multiset_eqI) (simp add: count_of_join_raw alist.Alist_inverse distinct_join_raw join_def)
+
+lemma minus_Bag [code]:
+  "Bag xs - Bag ys = Bag (subtract_entries xs ys)"
+by (rule multiset_eqI)
+  (simp add: count_of_subtract_entries_raw alist.Alist_inverse distinct_subtract_entries_raw subtract_entries_def)
 
 lemma filter_Bag [code]:
-  "Multiset.filter P (Bag xs) = Bag (filter (P \<circ> fst) xs)"
-  by (rule multiset_eqI) (simp add: count_of_filter)
+  "Multiset.filter P (Bag xs) = Bag (AList.filter (P \<circ> fst) xs)"
+by (rule multiset_eqI) (simp add: count_of_filter impl_of_filter)
 
 lemma mset_less_eq_Bag [code]:
-  "Bag xs \<le> A \<longleftrightarrow> (\<forall>(x, n) \<in> set xs. count_of xs x \<le> count A x)"
+  "Bag xs \<le> A \<longleftrightarrow> (\<forall>(x, n) \<in> set (AList.impl_of xs). count_of (AList.impl_of xs) x \<le> count A x)"
     (is "?lhs \<longleftrightarrow> ?rhs")
 proof
   assume ?lhs then show ?rhs
@@ -1109,8 +1206,8 @@
   show ?lhs
   proof (rule mset_less_eqI)
     fix x
-    from `?rhs` have "count_of xs x \<le> count A x"
-      by (cases "x \<in> fst ` set xs") (auto simp add: count_of_empty)
+    from `?rhs` have "count_of (AList.impl_of xs) x \<le> count A x"
+      by (cases "x \<in> fst ` set (AList.impl_of xs)") (auto simp add: count_of_empty)
     then show "count (Bag xs) x \<le> count A x"
       by (simp add: mset_le_def count_Bag)
   qed
@@ -1127,12 +1224,10 @@
 
 end
 
-lemma [code nbe]:
-  "HOL.equal (A :: 'a::equal multiset) A \<longleftrightarrow> True"
-  by (fact equal_refl)
+text {* Quickcheck generators *}
 
 definition (in term_syntax)
-  bagify :: "('a\<Colon>typerep \<times> nat) list \<times> (unit \<Rightarrow> Code_Evaluation.term)
+  bagify :: "('a\<Colon>typerep, nat) alist \<times> (unit \<Rightarrow> Code_Evaluation.term)
     \<Rightarrow> 'a multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
   [code_unfold]: "bagify xs = Code_Evaluation.valtermify Bag {\<cdot>} xs"
 
@@ -1152,6 +1247,28 @@
 no_notation fcomp (infixl "\<circ>>" 60)
 no_notation scomp (infixl "\<circ>\<rightarrow>" 60)
 
+instantiation multiset :: (exhaustive) exhaustive
+begin
+
+definition exhaustive_multiset :: "('a multiset => (bool * term list) option) => code_numeral => (bool * term list) option"
+where
+  "exhaustive_multiset f i = Quickcheck_Exhaustive.exhaustive (%xs. f (Bag xs)) i"
+
+instance ..
+
+end
+
+instantiation multiset :: (full_exhaustive) full_exhaustive
+begin
+
+definition full_exhaustive_multiset :: "('a multiset * (unit => term) => (bool * term list) option) => code_numeral => (bool * term list) option"
+where
+  "full_exhaustive_multiset f i = Quickcheck_Exhaustive.full_exhaustive (%xs. f (bagify xs)) i"
+
+instance ..
+
+end
+
 hide_const (open) bagify