--- a/src/HOL/Library/AList_Mapping.thy Tue May 31 12:24:43 2016 +0200
+++ b/src/HOL/Library/AList_Mapping.thy Tue May 31 13:02:44 2016 +0200
@@ -63,9 +63,43 @@
by (auto intro!: map_of_eqI) (auto dest!: map_of_eq_dom intro: aux)
qed
+lemma map_values_Mapping [code]:
+ fixes f :: "'a \<Rightarrow> 'b" and xs :: "('c \<times> 'a) list"
+ shows "Mapping.map_values f (Mapping xs) = Mapping (map (\<lambda>(x,y). (x, f y)) xs)"
+proof (transfer, rule ext, goal_cases)
+ case (1 f xs x)
+ thus ?case by (induction xs) auto
+qed
+
+lemma combine_code [code]:
+ "Mapping.combine f (Mapping xs) (Mapping ys) =
+ Mapping.tabulate (remdups (map fst xs @ map fst ys))
+ (\<lambda>x. the (combine_options f (map_of xs x) (map_of ys x)))"
+proof (transfer, rule ext, rule sym, goal_cases)
+ case (1 f xs ys x)
+ show ?case
+ by (cases "map_of xs x"; cases "map_of ys x"; simp)
+ (force simp: map_of_eq_None_iff combine_options_altdef option.the_def o_def image_iff
+ dest: map_of_SomeD split: option.splits)+
+qed
+
+(* TODO: Move? *)
+lemma map_of_filter_distinct:
+ assumes "distinct (map fst xs)"
+ shows "map_of (filter P xs) x =
+ (case map_of xs x of None \<Rightarrow> None | Some y \<Rightarrow> if P (x,y) then Some y else None)"
+ using assms
+ by (auto simp: map_of_eq_None_iff filter_map distinct_map_filter dest: map_of_SomeD
+ simp del: map_of_eq_Some_iff intro!: map_of_is_SomeI split: option.splits)
+(* END TODO *)
+
+lemma filter_Mapping [code]:
+ "Mapping.filter P (Mapping xs) = Mapping (filter (\<lambda>(k,v). P k v) (AList.clearjunk xs))"
+ by (transfer, rule ext)
+ (subst map_of_filter_distinct, simp_all add: map_of_clearjunk split: option.split)
+
lemma [code nbe]:
"HOL.equal (x :: ('a, 'b) mapping) x \<longleftrightarrow> True"
by (fact equal_refl)
end
-
--- a/src/HOL/Library/Mapping.thy Tue May 31 12:24:43 2016 +0200
+++ b/src/HOL/Library/Mapping.thy Tue May 31 13:02:44 2016 +0200
@@ -88,6 +88,18 @@
"((A ===> B) ===> (C ===> D) ===> (B ===> rel_option C) ===> A ===> rel_option D)
(\<lambda>f g m. (map_option g \<circ> m \<circ> f)) (\<lambda>f g m. (map_option g \<circ> m \<circ> f))"
by transfer_prover
+
+lemma combine_with_key_parametric:
+ shows "((A ===> B ===> B ===> B) ===> (A ===> rel_option B) ===> (A ===> rel_option B) ===>
+ (A ===> rel_option B)) (\<lambda>f m1 m2 x. combine_options (f x) (m1 x) (m2 x))
+ (\<lambda>f m1 m2 x. combine_options (f x) (m1 x) (m2 x))"
+ unfolding combine_options_def by transfer_prover
+
+lemma combine_parametric:
+ shows "((B ===> B ===> B) ===> (A ===> rel_option B) ===> (A ===> rel_option B) ===>
+ (A ===> rel_option B)) (\<lambda>f m1 m2 x. combine_options f (m1 x) (m2 x))
+ (\<lambda>f m1 m2 x. combine_options f (m1 x) (m2 x))"
+ unfolding combine_options_def by transfer_prover
end
@@ -106,6 +118,8 @@
lift_definition lookup :: "('a, 'b) mapping \<Rightarrow> 'a \<Rightarrow> 'b option"
is "\<lambda>m k. m k" parametric lookup_parametric .
+definition "lookup_default d m k = (case Mapping.lookup m k of None \<Rightarrow> d | Some v \<Rightarrow> v)"
+
declare [[code drop: Mapping.lookup]]
setup \<open>Code.add_default_eqn @{thm Mapping.lookup.abs_eq}\<close> \<comment> \<open>FIXME lifting\<close>
@@ -115,6 +129,9 @@
lift_definition delete :: "'a \<Rightarrow> ('a, 'b) mapping \<Rightarrow> ('a, 'b) mapping"
is "\<lambda>k m. m(k := None)" parametric delete_parametric .
+lift_definition filter :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> ('a, 'b) mapping \<Rightarrow> ('a, 'b) mapping"
+ is "\<lambda>P m k. case m k of None \<Rightarrow> None | Some v \<Rightarrow> if P k v then Some v else None" .
+
lift_definition keys :: "('a, 'b) mapping \<Rightarrow> 'a set"
is dom parametric dom_parametric .
@@ -126,6 +143,20 @@
lift_definition map :: "('c \<Rightarrow> 'a) \<Rightarrow> ('b \<Rightarrow> 'd) \<Rightarrow> ('a, 'b) mapping \<Rightarrow> ('c, 'd) mapping"
is "\<lambda>f g m. (map_option g \<circ> m \<circ> f)" parametric map_parametric .
+
+lift_definition map_values :: "('c \<Rightarrow> 'a \<Rightarrow> 'b) \<Rightarrow> ('c, 'a) mapping \<Rightarrow> ('c, 'b) mapping"
+ is "\<lambda>f m x. map_option (f x) (m x)" .
+
+lift_definition combine_with_key ::
+ "('a \<Rightarrow> 'b \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> ('a,'b) mapping \<Rightarrow> ('a,'b) mapping \<Rightarrow> ('a,'b) mapping"
+ is "\<lambda>f m1 m2 x. combine_options (f x) (m1 x) (m2 x)" parametric combine_with_key_parametric .
+
+lift_definition combine ::
+ "('b \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> ('a,'b) mapping \<Rightarrow> ('a,'b) mapping \<Rightarrow> ('a,'b) mapping"
+ is "\<lambda>f m1 m2 x. combine_options f (m1 x) (m2 x)" parametric combine_parametric .
+
+definition All_mapping where
+ "All_mapping m P \<longleftrightarrow> (\<forall>x. case Mapping.lookup m x of None \<Rightarrow> True | Some y \<Rightarrow> P x y)"
declare [[code drop: map]]
@@ -217,10 +248,80 @@
"k \<noteq> k' \<Longrightarrow> lookup (update k v m) k' = lookup m k'"
by transfer simp
+lemma lookup_update':
+ "Mapping.lookup (update k v m) k' = (if k = k' then Some v else lookup m k')"
+ by (auto simp: lookup_update lookup_update_neq)
+
lemma lookup_empty:
"lookup empty k = None"
by transfer simp
+lemma lookup_filter:
+ "lookup (filter P m) k =
+ (case lookup m k of None \<Rightarrow> None | Some v \<Rightarrow> if P k v then Some v else None)"
+ by transfer simp_all
+
+lemma lookup_map_values:
+ "lookup (map_values f m) k = map_option (f k) (lookup m k)"
+ by transfer simp_all
+
+lemma lookup_default_empty: "lookup_default d empty k = d"
+ by (simp add: lookup_default_def lookup_empty)
+
+lemma lookup_default_update:
+ "lookup_default d (update k v m) k = v"
+ by (simp add: lookup_default_def lookup_update)
+
+lemma lookup_default_update_neq:
+ "k \<noteq> k' \<Longrightarrow> lookup_default d (update k v m) k' = lookup_default d m k'"
+ by (simp add: lookup_default_def lookup_update_neq)
+
+lemma lookup_default_update':
+ "lookup_default d (update k v m) k' = (if k = k' then v else lookup_default d m k')"
+ by (auto simp: lookup_default_update lookup_default_update_neq)
+
+lemma lookup_default_filter:
+ "lookup_default d (filter P m) k =
+ (if P k (lookup_default d m k) then lookup_default d m k else d)"
+ by (simp add: lookup_default_def lookup_filter split: option.splits)
+
+lemma lookup_default_map_values:
+ "lookup_default (f k d) (map_values f m) k = f k (lookup_default d m k)"
+ by (simp add: lookup_default_def lookup_map_values split: option.splits)
+
+lemma lookup_combine_with_key:
+ "Mapping.lookup (combine_with_key f m1 m2) x =
+ combine_options (f x) (Mapping.lookup m1 x) (Mapping.lookup m2 x)"
+ by transfer (auto split: option.splits)
+
+lemma combine_altdef: "combine f m1 m2 = combine_with_key (\<lambda>_. f) m1 m2"
+ by transfer' (rule refl)
+
+lemma lookup_combine:
+ "Mapping.lookup (combine f m1 m2) x =
+ combine_options f (Mapping.lookup m1 x) (Mapping.lookup m2 x)"
+ by transfer (auto split: option.splits)
+
+lemma lookup_default_neutral_combine_with_key:
+ assumes "\<And>x. f k d x = x" "\<And>x. f k x d = x"
+ shows "Mapping.lookup_default d (combine_with_key f m1 m2) k =
+ f k (Mapping.lookup_default d m1 k) (Mapping.lookup_default d m2 k)"
+ by (auto simp: lookup_default_def lookup_combine_with_key assms split: option.splits)
+
+lemma lookup_default_neutral_combine:
+ assumes "\<And>x. f d x = x" "\<And>x. f x d = x"
+ shows "Mapping.lookup_default d (combine f m1 m2) x =
+ f (Mapping.lookup_default d m1 x) (Mapping.lookup_default d m2 x)"
+ by (auto simp: lookup_default_def lookup_combine assms split: option.splits)
+
+lemma lookup_tabulate:
+ assumes "distinct xs"
+ shows "Mapping.lookup (Mapping.tabulate xs f) x = (if x \<in> set xs then Some (f x) else None)"
+ using assms by transfer (auto simp: map_of_eq_None_iff o_def dest!: map_of_SomeD)
+
+lemma lookup_of_alist: "Mapping.lookup (Mapping.of_alist xs) k = map_of xs k"
+ by transfer simp_all
+
lemma keys_is_none_rep [code_unfold]:
"k \<in> keys m \<longleftrightarrow> \<not> (Option.is_none (lookup m k))"
by transfer (auto simp add: Option.is_none_def)
@@ -247,6 +348,13 @@
"k \<notin> keys m \<Longrightarrow> replace k v m = m"
"k \<in> keys m \<Longrightarrow> replace k v m = update k v m"
by (transfer, auto simp add: replace_def fun_upd_twist)+
+
+lemma map_values_update: "map_values f (update k v m) = update k (f k v) (map_values f m)"
+ by transfer (simp_all add: fun_eq_iff)
+
+lemma size_mono:
+ "finite (keys m') \<Longrightarrow> keys m \<subseteq> keys m' \<Longrightarrow> size m \<le> size m'"
+ unfolding size_def by (auto intro: card_mono)
lemma size_empty [simp]:
"size empty = 0"
@@ -265,6 +373,13 @@
"size (tabulate ks f) = length (remdups ks)"
unfolding size_def by transfer (auto simp add: map_of_map_restrict card_set comp_def)
+lemma keys_filter: "keys (filter P m) \<subseteq> keys m"
+ by transfer (auto split: option.splits)
+
+lemma size_filter: "finite (keys m) \<Longrightarrow> size (filter P m) \<le> size m"
+ by (intro size_mono keys_filter)
+
+
lemma bulkload_tabulate:
"bulkload xs = tabulate [0..<length xs] (nth xs)"
by transfer (auto simp add: map_of_map_restrict)
@@ -293,6 +408,10 @@
"is_empty (map_entry k f m) \<longleftrightarrow> is_empty m"
unfolding is_empty_def by transfer (auto split: option.split)
+lemma is_empty_map_values [simp]:
+ "is_empty (map_values f m) \<longleftrightarrow> is_empty m"
+ unfolding is_empty_def by transfer (auto simp: fun_eq_iff)
+
lemma is_empty_map_default [simp]:
"\<not> is_empty (map_default k v f m)"
by (simp add: map_default_def)
@@ -329,10 +448,24 @@
"keys (map_default k v f m) = insert k (keys m)"
by (simp add: map_default_def)
+lemma keys_map_values [simp]:
+ "keys (map_values f m) = keys m"
+ by transfer (simp_all add: dom_def)
+
+lemma keys_combine_with_key [simp]:
+ "Mapping.keys (combine_with_key f m1 m2) = Mapping.keys m1 \<union> Mapping.keys m2"
+ by transfer (auto simp: dom_def combine_options_def split: option.splits)
+
+lemma keys_combine [simp]: "Mapping.keys (combine f m1 m2) = Mapping.keys m1 \<union> Mapping.keys m2"
+ by (simp add: combine_altdef)
+
lemma keys_tabulate [simp]:
"keys (tabulate ks f) = set ks"
by transfer (simp add: map_of_map_restrict o_def)
+lemma keys_of_alist [simp]: "keys (of_alist xs) = set (List.map fst xs)"
+ by transfer (simp_all add: dom_map_of_conv_image_fst)
+
lemma keys_bulkload [simp]:
"keys (bulkload xs) = {0..<length xs}"
by (simp add: bulkload_tabulate)
@@ -407,11 +540,91 @@
by simp
qed
+lemma All_mapping_mono:
+ "(\<And>k v. k \<in> keys m \<Longrightarrow> P k v \<Longrightarrow> Q k v) \<Longrightarrow> All_mapping m P \<Longrightarrow> All_mapping m Q"
+ unfolding All_mapping_def by transfer (auto simp: All_mapping_def dom_def split: option.splits)
-subsection \<open>Code generator setup\<close>
+lemma All_mapping_empty [simp]: "All_mapping Mapping.empty P"
+ by (auto simp: All_mapping_def lookup_empty)
+
+lemma All_mapping_update_iff:
+ "All_mapping (Mapping.update k v m) P \<longleftrightarrow> P k v \<and> All_mapping m (\<lambda>k' v'. k = k' \<or> P k' v')"
+ unfolding All_mapping_def
+proof safe
+ assume "\<forall>x. case Mapping.lookup (Mapping.update k v m) x of None \<Rightarrow> True | Some y \<Rightarrow> P x y"
+ hence A: "case Mapping.lookup (Mapping.update k v m) x of None \<Rightarrow> True | Some y \<Rightarrow> P x y" for x
+ by blast
+ from A[of k] show "P k v" by (simp add: lookup_update)
+ show "case Mapping.lookup m x of None \<Rightarrow> True | Some v' \<Rightarrow> k = x \<or> P x v'" for x
+ using A[of x] by (auto simp add: lookup_update' split: if_splits option.splits)
+next
+ assume "P k v"
+ assume "\<forall>x. case Mapping.lookup m x of None \<Rightarrow> True | Some v' \<Rightarrow> k = x \<or> P x v'"
+ hence A: "case Mapping.lookup m x of None \<Rightarrow> True | Some v' \<Rightarrow> k = x \<or> P x v'" for x by blast
+ show "case Mapping.lookup (Mapping.update k v m) x of None \<Rightarrow> True | Some xa \<Rightarrow> P x xa" for x
+ using \<open>P k v\<close> A[of x] by (auto simp: lookup_update' split: option.splits)
+qed
+
+lemma All_mapping_update:
+ "P k v \<Longrightarrow> All_mapping m (\<lambda>k' v'. k = k' \<or> P k' v') \<Longrightarrow> All_mapping (Mapping.update k v m) P"
+ by (simp add: All_mapping_update_iff)
+
+lemma All_mapping_filter_iff:
+ "All_mapping (filter P m) Q \<longleftrightarrow> All_mapping m (\<lambda>k v. P k v \<longrightarrow> Q k v)"
+ by (auto simp: All_mapping_def lookup_filter split: option.splits)
+
+lemma All_mapping_filter:
+ "All_mapping m Q \<Longrightarrow> All_mapping (filter P m) Q"
+ by (auto simp: All_mapping_filter_iff intro: All_mapping_mono)
-hide_const (open) empty is_empty rep lookup update delete ordered_keys keys size
- replace default map_entry map_default tabulate bulkload map of_alist
+lemma All_mapping_map_values:
+ "All_mapping (map_values f m) P \<longleftrightarrow> All_mapping m (\<lambda>k v. P k (f k v))"
+ by (auto simp: All_mapping_def lookup_map_values split: option.splits)
+
+lemma All_mapping_tabulate:
+ "(\<forall>x\<in>set xs. P x (f x)) \<Longrightarrow> All_mapping (Mapping.tabulate xs f) P"
+ unfolding All_mapping_def
+ by (intro allI, transfer) (auto split: option.split dest!: map_of_SomeD)
+
+lemma All_mapping_alist:
+ "(\<And>k v. (k, v) \<in> set xs \<Longrightarrow> P k v) \<Longrightarrow> All_mapping (Mapping.of_alist xs) P"
+ by (auto simp: All_mapping_def lookup_of_alist dest!: map_of_SomeD split: option.splits)
+
+
+lemma combine_empty [simp]:
+ "combine f Mapping.empty y = y" "combine f y Mapping.empty = y"
+ by (transfer, force)+
+
+lemma (in abel_semigroup) comm_monoid_set_combine: "comm_monoid_set (combine f) Mapping.empty"
+ by standard (transfer fixing: f, simp add: combine_options_ac[of f] ac_simps)+
+
+locale combine_mapping_abel_semigroup = abel_semigroup
+begin
+
+sublocale combine: comm_monoid_set "combine f" Mapping.empty
+ by (rule comm_monoid_set_combine)
+
+lemma fold_combine_code:
+ "combine.F g (set xs) = foldr (\<lambda>x. combine f (g x)) (remdups xs) Mapping.empty"
+proof -
+ have "combine.F g (set xs) = foldr (\<lambda>x. combine f (g x)) xs Mapping.empty"
+ if "distinct xs" for xs
+ using that by (induction xs) simp_all
+ from this[of "remdups xs"] show ?thesis by simp
+qed
+
+lemma keys_fold_combine:
+ assumes "finite A"
+ shows "Mapping.keys (combine.F g A) = (\<Union>x\<in>A. Mapping.keys (g x))"
+ using assms by (induction A rule: finite_induct) simp_all
end
+
+subsection \<open>Code generator setup\<close>
+
+hide_const (open) empty is_empty rep lookup lookup_default filter update delete ordered_keys
+ keys size replace default map_entry map_default tabulate bulkload map map_values combine of_alist
+
+end
+
--- a/src/HOL/Library/RBT.thy Tue May 31 12:24:43 2016 +0200
+++ b/src/HOL/Library/RBT.thy Tue May 31 13:02:44 2016 +0200
@@ -67,12 +67,22 @@
lift_definition foldi :: "('c \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> 'b \<Rightarrow> 'c \<Rightarrow> 'c) \<Rightarrow> ('a :: linorder, 'b) rbt \<Rightarrow> 'c \<Rightarrow> 'c"
is RBT_Impl.foldi .
+
+lift_definition combine_with_key :: "('a \<Rightarrow> 'b \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> ('a::linorder, 'b) rbt \<Rightarrow> ('a, 'b) rbt \<Rightarrow> ('a, 'b) rbt"
+ is RBT_Impl.rbt_union_with_key by (rule is_rbt_rbt_unionwk)
+
+lift_definition combine :: "('b \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> ('a::linorder, 'b) rbt \<Rightarrow> ('a, 'b) rbt \<Rightarrow> ('a, 'b) rbt"
+ is RBT_Impl.rbt_union_with by (rule rbt_unionw_is_rbt)
subsection \<open>Derived operations\<close>
definition is_empty :: "('a::linorder, 'b) rbt \<Rightarrow> bool" where
[code]: "is_empty t = (case impl_of t of RBT_Impl.Empty \<Rightarrow> True | _ \<Rightarrow> False)"
+(* TODO: Is deleting more efficient than re-building the tree?
+ (Probably more difficult to prove though *)
+definition filter :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> ('a::linorder, 'b) rbt \<Rightarrow> ('a, 'b) rbt" where
+ [code]: "filter P t = fold (\<lambda>k v t. if P k v then insert k v t else t) t empty"
subsection \<open>Abstract lookup properties\<close>
@@ -128,6 +138,17 @@
"lookup (map f t) k = map_option (f k) (lookup t k)"
by transfer (rule rbt_lookup_map)
+lemma lookup_combine_with_key [simp]:
+ "lookup (combine_with_key f t1 t2) k = combine_options (f k) (lookup t1 k) (lookup t2 k)"
+ by transfer (simp_all add: combine_options_def rbt_lookup_rbt_unionwk)
+
+lemma combine_altdef: "combine f t1 t2 = combine_with_key (\<lambda>_. f) t1 t2"
+ by transfer (simp add: rbt_union_with_def)
+
+lemma lookup_combine [simp]:
+ "lookup (combine f t1 t2) k = combine_options f (lookup t1 k) (lookup t2 k)"
+ by (simp add: combine_altdef)
+
lemma fold_fold:
"fold f t = List.fold (case_prod f) (entries t)"
by transfer (rule RBT_Impl.fold_def)
@@ -182,6 +203,26 @@
"keys t = List.map fst (entries t)"
by transfer (simp add: RBT_Impl.keys_def)
+context
+begin
+
+private lemma lookup_filter_aux:
+ assumes "distinct (List.map fst xs)"
+ shows "lookup (List.fold (\<lambda>(k, v) t. if P k v then insert k v t else t) xs t) k =
+ (case map_of xs k of
+ None \<Rightarrow> lookup t k
+ | Some v \<Rightarrow> if P k v then Some v else lookup t k)"
+ using assms by (induction xs arbitrary: t) (force split: option.splits)+
+
+lemma lookup_filter:
+ "lookup (filter P t) k =
+ (case lookup t k of None \<Rightarrow> None | Some v \<Rightarrow> if P k v then Some v else None)"
+ unfolding filter_def using lookup_filter_aux[of "entries t" P empty k]
+ by (simp add: fold_fold distinct_entries split: option.splits)
+
+end
+
+
subsection \<open>Quickcheck generators\<close>
quickcheck_generator rbt predicate: is_rbt constructors: empty, insert
--- a/src/HOL/Library/RBT_Mapping.thy Tue May 31 12:24:43 2016 +0200
+++ b/src/HOL/Library/RBT_Mapping.thy Tue May 31 13:02:44 2016 +0200
@@ -77,6 +77,24 @@
end
+lemma map_values_Mapping [code]:
+ "Mapping.map_values f (Mapping t) = Mapping (RBT.map f t)"
+ by (transfer fixing: t) (auto simp: fun_eq_iff)
+
+lemma filter_Mapping [code]:
+ "Mapping.filter P (Mapping t) = Mapping (RBT.filter P t)"
+ by (transfer' fixing: P t) (simp add: RBT.lookup_filter fun_eq_iff)
+
+lemma combine_with_key_Mapping [code]:
+ "Mapping.combine_with_key f (Mapping t1) (Mapping t2) =
+ Mapping (RBT.combine_with_key f t1 t2)"
+ by (transfer fixing: f t1 t2) (simp_all add: fun_eq_iff)
+
+lemma combine_Mapping [code]:
+ "Mapping.combine f (Mapping t1) (Mapping t2) =
+ Mapping (RBT.combine f t1 t2)"
+ by (transfer fixing: f t1 t2) (simp_all add: fun_eq_iff)
+
lemma equal_Mapping [code]:
"HOL.equal (Mapping t1) (Mapping t2) \<longleftrightarrow> RBT.entries t1 = RBT.entries t2"
by (transfer fixing: t1 t2) (simp add: entries_lookup)
--- a/src/HOL/Library/RBT_Set.thy Tue May 31 12:24:43 2016 +0200
+++ b/src/HOL/Library/RBT_Set.thy Tue May 31 13:02:44 2016 +0200
@@ -838,10 +838,10 @@
lemma Bleast_code [code]:
"Bleast (Set t) P =
- (case filter P (RBT.keys t) of
+ (case List.filter P (RBT.keys t) of
x # xs \<Rightarrow> x
| [] \<Rightarrow> abort_Bleast (Set t) P)"
-proof (cases "filter P (RBT.keys t)")
+proof (cases "List.filter P (RBT.keys t)")
case Nil
thus ?thesis by (simp add: Bleast_def abort_Bleast_def)
next
--- a/src/HOL/Option.thy Tue May 31 12:24:43 2016 +0200
+++ b/src/HOL/Option.thy Tue May 31 13:02:44 2016 +0200
@@ -136,6 +136,43 @@
| _ \<Rightarrow> False)"
by (auto split: prod.split option.split)
+
+definition combine_options :: "('a \<Rightarrow> 'a \<Rightarrow> 'a) \<Rightarrow> 'a option \<Rightarrow> 'a option \<Rightarrow> 'a option"
+ where "combine_options f x y =
+ (case x of None \<Rightarrow> y | Some x \<Rightarrow> (case y of None \<Rightarrow> Some x | Some y \<Rightarrow> Some (f x y)))"
+
+lemma combine_options_simps [simp]:
+ "combine_options f None y = y"
+ "combine_options f x None = x"
+ "combine_options f (Some a) (Some b) = Some (f a b)"
+ by (simp_all add: combine_options_def split: option.splits)
+
+lemma combine_options_cases [case_names None1 None2 Some]:
+ "(x = None \<Longrightarrow> P x y) \<Longrightarrow> (y = None \<Longrightarrow> P x y) \<Longrightarrow>
+ (\<And>a b. x = Some a \<Longrightarrow> y = Some b \<Longrightarrow> P x y) \<Longrightarrow> P x y"
+ by (cases x; cases y) simp_all
+
+lemma combine_options_commute:
+ "(\<And>x y. f x y = f y x) \<Longrightarrow> combine_options f x y = combine_options f y x"
+ using combine_options_cases[of x ]
+ by (induction x y rule: combine_options_cases) simp_all
+
+lemma combine_options_assoc:
+ "(\<And>x y z. f (f x y) z = f x (f y z)) \<Longrightarrow>
+ combine_options f (combine_options f x y) z =
+ combine_options f x (combine_options f y z)"
+ by (auto simp: combine_options_def split: option.splits)
+
+lemma combine_options_left_commute:
+ "(\<And>x y. f x y = f y x) \<Longrightarrow> (\<And>x y z. f (f x y) z = f x (f y z)) \<Longrightarrow>
+ combine_options f y (combine_options f x z) =
+ combine_options f x (combine_options f y z)"
+ by (auto simp: combine_options_def split: option.splits)
+
+lemmas combine_options_ac =
+ combine_options_commute combine_options_assoc combine_options_left_commute
+
+
context
begin
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Probability/PMF_Impl.thy Tue May 31 13:02:44 2016 +0200
@@ -0,0 +1,478 @@
+(* Title: HOL/Probability/PMF_Impl.thy
+ Author: Manuel Eberl, TU München
+
+ An implementation of PMFs using Mappings, which are implemented with association lists
+ by default. Also includes Quickcheck setup for PMFs.
+*)
+
+theory PMF_Impl
+imports Probability_Mass_Function "~~/src/HOL/Library/AList_Mapping"
+begin
+
+definition pmf_of_mapping :: "('a, real) mapping \<Rightarrow> 'a pmf" where
+ "pmf_of_mapping m = embed_pmf (Mapping.lookup_default 0 m)"
+
+lemma nn_integral_lookup_default:
+ fixes m :: "('a, real) mapping"
+ assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ x. x \<ge> 0)"
+ shows "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) =
+ ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)"
+proof -
+ have "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) =
+ (\<Sum>x\<in>Mapping.keys m. ennreal (Mapping.lookup_default 0 m x))" using assms
+ by (subst nn_integral_count_space'[of "Mapping.keys m"])
+ (auto simp: Mapping.lookup_default_def keys_is_none_rep Option.is_none_def)
+ also from assms have "\<dots> = ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)"
+ by (intro setsum_ennreal)
+ (auto simp: Mapping.lookup_default_def All_mapping_def split: option.splits)
+ finally show ?thesis .
+qed
+
+lemma pmf_of_mapping:
+ assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ p. p \<ge> 0)"
+ assumes "(\<Sum>x\<in>Mapping.keys m. Mapping.lookup_default 0 m x) = 1"
+ shows "pmf (pmf_of_mapping m) x = Mapping.lookup_default 0 m x"
+ unfolding pmf_of_mapping_def
+proof (intro pmf_embed_pmf)
+ from assms show "(\<integral>\<^sup>+x. ennreal (Mapping.lookup_default 0 m x) \<partial>count_space UNIV) = 1"
+ by (subst nn_integral_lookup_default) (simp_all)
+qed (insert assms, simp add: All_mapping_def Mapping.lookup_default_def split: option.splits)
+
+lemma pmf_of_set_pmf_of_mapping:
+ assumes "A \<noteq> {}" "set xs = A" "distinct xs"
+ shows "pmf_of_set A = pmf_of_mapping (Mapping.tabulate xs (\<lambda>_. 1 / real (length xs)))"
+ (is "?lhs = ?rhs")
+ by (rule pmf_eqI, subst pmf_of_mapping)
+ (insert assms, auto intro!: All_mapping_tabulate
+ simp: Mapping.lookup_default_def lookup_tabulate distinct_card)
+
+lift_definition mapping_of_pmf :: "'a pmf \<Rightarrow> ('a, real) mapping" is
+ "\<lambda>p x. if pmf p x = 0 then None else Some (pmf p x)" .
+
+lemma lookup_default_mapping_of_pmf:
+ "Mapping.lookup_default 0 (mapping_of_pmf p) x = pmf p x"
+ by (simp add: mapping_of_pmf.abs_eq lookup_default_def Mapping.lookup.abs_eq)
+
+context
+begin
+
+interpretation pmf_as_function .
+
+lemma nn_integral_pmf_eq_1: "(\<integral>\<^sup>+ x. ennreal (pmf p x) \<partial>count_space UNIV) = 1"
+ by transfer simp_all
+end
+
+lemma pmf_of_mapping_mapping_of_pmf [code abstype]:
+ "pmf_of_mapping (mapping_of_pmf p) = p"
+ unfolding pmf_of_mapping_def
+ by (rule pmf_eqI, subst pmf_embed_pmf)
+ (insert nn_integral_pmf_eq_1[of p],
+ auto simp: lookup_default_mapping_of_pmf split: option.splits)
+
+lemma mapping_of_pmfI:
+ assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup m x = Some (pmf p x)"
+ assumes "Mapping.keys m = set_pmf p"
+ shows "mapping_of_pmf p = m"
+ using assms by transfer (rule ext, auto simp: set_pmf_eq)
+
+lemma mapping_of_pmfI':
+ assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup_default 0 m x = pmf p x"
+ assumes "Mapping.keys m = set_pmf p"
+ shows "mapping_of_pmf p = m"
+ using assms unfolding Mapping.lookup_default_def
+ by transfer (rule ext, force simp: set_pmf_eq)
+
+lemma return_pmf_code [code abstract]:
+ "mapping_of_pmf (return_pmf x) = Mapping.update x 1 Mapping.empty"
+ by (intro mapping_of_pmfI) (auto simp: lookup_update')
+
+lemma pmf_of_set_code_aux:
+ assumes "A \<noteq> {}" "set xs = A" "distinct xs"
+ shows "mapping_of_pmf (pmf_of_set A) = Mapping.tabulate xs (\<lambda>_. 1 / real (length xs))"
+ using assms
+ by (intro mapping_of_pmfI, subst pmf_of_set)
+ (auto simp: lookup_tabulate distinct_card)
+
+definition pmf_of_set_impl where
+ "pmf_of_set_impl A = mapping_of_pmf (pmf_of_set A)"
+
+lemma pmf_of_set_impl_code [code]:
+ "pmf_of_set_impl (set xs) =
+ (if xs = [] then
+ Code.abort (STR ''pmf_of_set of empty set'') (\<lambda>_. mapping_of_pmf (pmf_of_set (set xs)))
+ else let xs' = remdups xs; p = 1 / real (length xs') in
+ Mapping.tabulate xs' (\<lambda>_. p))"
+ unfolding pmf_of_set_impl_def
+ using pmf_of_set_code_aux[of "set xs" "remdups xs"] by (simp add: Let_def)
+
+lemma pmf_of_set_code [code abstract]:
+ "mapping_of_pmf (pmf_of_set A) = pmf_of_set_impl A"
+ by (simp add: pmf_of_set_impl_def)
+
+
+lemma pmf_of_multiset_pmf_of_mapping:
+ assumes "A \<noteq> {#}" "set xs = set_mset A" "distinct xs"
+ shows "mapping_of_pmf (pmf_of_multiset A) = Mapping.tabulate xs (\<lambda>x. count A x / real (size A))"
+ using assms by (intro mapping_of_pmfI) (auto simp: lookup_tabulate)
+
+definition pmf_of_multiset_impl where
+ "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)"
+
+lemma pmf_of_multiset_impl_code [code]:
+ "pmf_of_multiset_impl (mset xs) =
+ (if xs = [] then
+ Code.abort (STR ''pmf_of_multiset of empty multiset'')
+ (\<lambda>_. mapping_of_pmf (pmf_of_multiset (mset xs)))
+ else let xs' = remdups xs; p = 1 / real (length xs) in
+ Mapping.tabulate xs' (\<lambda>x. real (count (mset xs) x) * p))"
+ using pmf_of_multiset_pmf_of_mapping[of "mset xs" "remdups xs"]
+ by (simp add: pmf_of_multiset_impl_def)
+
+lemma pmf_of_multiset_code [code abstract]:
+ "mapping_of_pmf (pmf_of_multiset A) = pmf_of_multiset_impl A"
+ by (simp add: pmf_of_multiset_impl_def)
+
+lemma bernoulli_pmf_code [code abstract]:
+ "mapping_of_pmf (bernoulli_pmf p) =
+ (if p \<le> 0 then Mapping.update False 1 Mapping.empty
+ else if p \<ge> 1 then Mapping.update True 1 Mapping.empty
+ else Mapping.update False (1 - p) (Mapping.update True p Mapping.empty))"
+ by (intro mapping_of_pmfI) (auto simp: bernoulli_pmf.rep_eq lookup_update' set_pmf_eq)
+
+
+
+
+lemma pmf_code [code]: "pmf p x = Mapping.lookup_default 0 (mapping_of_pmf p) x"
+ unfolding mapping_of_pmf_def Mapping.lookup_default_def
+ by (auto split: option.splits simp: id_def Mapping.lookup.abs_eq)
+
+lemma set_pmf_code [code]: "set_pmf p = Mapping.keys (mapping_of_pmf p)"
+ by transfer (auto simp: dom_def set_pmf_eq)
+
+lemma keys_mapping_of_pmf [simp]: "Mapping.keys (mapping_of_pmf p) = set_pmf p"
+ by transfer (auto simp: dom_def set_pmf_eq)
+
+
+
+(* This is necessary since we want something the guarantees finiteness, but simply using
+ "finite" restricts the code equations to types where finiteness of the universe can
+ be decided. This simply fails when finiteness is not clear *)
+definition is_list_set where "is_list_set A = finite A"
+
+lemma is_list_set_code [code]: "is_list_set (set xs) = True"
+ by (simp add: is_list_set_def)
+
+definition fold_combine_plus where
+ "fold_combine_plus = comm_monoid_set.F (Mapping.combine (op + :: real \<Rightarrow> _)) Mapping.empty"
+
+context
+begin
+
+interpretation fold_combine_plus: combine_mapping_abel_semigroup "op + :: real \<Rightarrow> _"
+ by unfold_locales (simp_all add: add_ac)
+
+qualified lemma lookup_default_fold_combine_plus:
+ fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping"
+ assumes "finite A"
+ shows "Mapping.lookup_default 0 (fold_combine_plus f A) x =
+ (\<Sum>y\<in>A. Mapping.lookup_default 0 (f y) x)"
+ unfolding fold_combine_plus_def using assms
+ by (induction A rule: finite_induct)
+ (simp_all add: lookup_default_empty lookup_default_neutral_combine)
+
+qualified lemma keys_fold_combine_plus:
+ "finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))"
+ by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine)
+
+qualified lemma fold_combine_plus_code [code]:
+ "fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine op+ (g x)) (remdups xs) Mapping.empty"
+ by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code)
+
+private lemma lookup_default_0_map_values:
+ assumes "f 0 = 0"
+ shows "Mapping.lookup_default 0 (Mapping.map_values f m) x = f (Mapping.lookup_default 0 m x)"
+ unfolding Mapping.lookup_default_def
+ using assms by transfer (auto split: option.splits)
+
+qualified lemma mapping_of_bind_pmf:
+ assumes "finite (set_pmf p)"
+ shows "mapping_of_pmf (bind_pmf p f) =
+ fold_combine_plus (\<lambda>x. Mapping.map_values (op * (pmf p x))
+ (mapping_of_pmf (f x))) (set_pmf p)"
+ using assms
+ by (intro mapping_of_pmfI')
+ (auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus
+ pmf_bind integral_measure_pmf lookup_default_0_map_values
+ lookup_default_mapping_of_pmf mult_ac)
+
+lemma bind_pmf_code [code abstract]:
+ "mapping_of_pmf (bind_pmf p f) =
+ (let A = set_pmf p in if is_list_set A then
+ fold_combine_plus (\<lambda>x. Mapping.map_values (op * (pmf p x)) (mapping_of_pmf (f x))) A
+ else
+ Code.abort (STR ''bind_pmf with infinite support.'') (\<lambda>_. mapping_of_pmf (bind_pmf p f)))"
+ using mapping_of_bind_pmf[of p f] by (auto simp: Let_def is_list_set_def)
+
+end
+
+hide_const (open) is_list_set fold_combine_plus
+
+
+lift_definition cond_pmf_impl :: "'a pmf \<Rightarrow> 'a set \<Rightarrow> ('a, real) mapping option" is
+ "\<lambda>p A. if A \<inter> set_pmf p = {} then None else
+ Some (\<lambda>x. if x \<in> A \<inter> set_pmf p then Some (pmf p x / measure_pmf.prob p A) else None)" .
+
+lemma cond_pmf_impl_code [code]:
+ "cond_pmf_impl p (set xs) = (
+ let B = set_pmf p;
+ xs' = remdups (filter (\<lambda>x. x \<in> B) xs);
+ prob = listsum (map (pmf p) xs')
+ in if prob = 0 then
+ None
+ else
+ Some (Mapping.map_values (\<lambda>y. y / prob)
+ (Mapping.filter (\<lambda>k _. k \<in> set xs') (mapping_of_pmf p))))"
+proof -
+ define xs' where "xs' = remdups (filter (\<lambda>x. x \<in> set_pmf p) xs)"
+ have xs': "set xs' = set xs \<inter> set_pmf p" "distinct xs'" by (auto simp: xs'_def)
+ define prob where "prob = listsum (map (pmf p) xs')"
+ have "prob = (\<Sum>x\<in>set xs'. pmf p x)"
+ unfolding prob_def by (rule listsum_distinct_conv_setsum_set) (simp_all add: xs'_def)
+ also note xs'(1)
+ also have "(\<Sum>x\<in>set xs \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>set xs. pmf p x)"
+ by (intro setsum.mono_neutral_left) (auto simp: set_pmf_eq)
+ finally have prob1: "prob = (\<Sum>x\<in>set xs. pmf p x)" .
+ hence prob2: "prob = measure_pmf.prob p (set xs)"
+ by (subst measure_measure_pmf_finite) simp_all
+ have prob3: "prob = 0 \<longleftrightarrow> set xs \<inter> set_pmf p = {}"
+ by (subst prob1, subst setsum_nonneg_eq_0_iff) (auto simp: set_pmf_eq)
+
+ show ?thesis
+ proof (cases "prob = 0")
+ case True
+ hence "set xs \<inter> set_pmf p = {}" by (subst (asm) prob3)
+ with True show ?thesis by (simp add: Let_def prob_def xs'_def cond_pmf_impl.abs_eq)
+ next
+ case False
+ hence A: "set xs' \<noteq> {}" unfolding xs' by (subst (asm) prob3) auto
+ with xs' prob3 have prob_nz: "prob \<noteq> 0" by auto
+ fix x
+ have "cond_pmf_impl p (set xs) =
+ Some (mapping.Mapping (\<lambda>x. if x \<in> set xs' then
+ Some (pmf p x / measure_pmf.prob p (set xs)) else None))"
+ (is "_ = Some ?m")
+ using A unfolding xs'_def by transfer auto
+ also have "?m = Mapping.map_values (\<lambda>y. y / prob)
+ (Mapping.filter (\<lambda>k _. k \<in> set xs') (mapping_of_pmf p))"
+ unfolding prob2 [symmetric] xs' using xs' prob_nz
+ by transfer (rule ext, simp add: set_pmf_eq)
+ finally show ?thesis using False by (simp add: Let_def prob_def xs'_def)
+ qed
+qed
+
+lemma cond_pmf_code [code abstract]:
+ "mapping_of_pmf (cond_pmf p A) =
+ (case cond_pmf_impl p A of
+ None \<Rightarrow> Code.abort (STR ''cond_pmf with set of probability 0'')
+ (\<lambda>_. mapping_of_pmf (cond_pmf p A))
+ | Some m \<Rightarrow> m)"
+proof (cases "cond_pmf_impl p A")
+ case (Some m)
+ hence A: "set_pmf p \<inter> A \<noteq> {}" by transfer (auto split: if_splits)
+ from Some have B: "Mapping.keys m = set_pmf (cond_pmf p A)"
+ by (subst set_cond_pmf[OF A], transfer) (auto split: if_splits)
+ with Some A have "mapping_of_pmf (cond_pmf p A) = m"
+ by (intro mapping_of_pmfI[OF _ B], transfer) (auto split: if_splits simp: pmf_cond)
+ with Some show ?thesis by simp
+qed simp_all
+
+
+lemma binomial_pmf_code [code abstract]:
+ "mapping_of_pmf (binomial_pmf n p) = (
+ if p < 0 \<or> p > 1 then
+ Code.abort (STR ''binomial_pmf with invalid probability'') (\<lambda>_. mapping_of_pmf (binomial_pmf n p))
+ else if p = 0 then Mapping.update 0 1 Mapping.empty
+ else if p = 1 then Mapping.update n 1 Mapping.empty
+ else Mapping.tabulate [0..<Suc n] (\<lambda>k. real (n choose k) * p ^ k * (1 - p) ^ (n - k)))"
+ by (cases "p < 0 \<or> p > 1")
+ (simp, intro mapping_of_pmfI,
+ auto simp: lookup_update' lookup_empty set_pmf_binomial_eq lookup_tabulate split: if_splits)
+
+lemma pred_pmf_code [code]:
+ "pred_pmf P p = (\<forall>x\<in>set_pmf p. P x)"
+ by (auto simp: pred_pmf_def)
+
+
+definition pmf_integral where
+ "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)"
+
+definition pmf_set_integral where
+ "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)"
+
+definition pmf_prob where
+ "pmf_prob p A = measure_pmf.prob p A"
+
+lemma pmf_integral_pmf_set_integral [code]:
+ "pmf_integral p f = pmf_set_integral p f (set_pmf p)"
+ unfolding pmf_integral_def pmf_set_integral_def
+ by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff)
+
+lemma pmf_set_integral_code [code]:
+ "pmf_set_integral p f (set xs) = listsum (map (\<lambda>x. pmf p x * f x) (remdups xs))"
+proof -
+ have "listsum (map (\<lambda>x. pmf p x * f x) (remdups xs)) = (\<Sum>x\<in>set xs. pmf p x * f x)"
+ by (subst listsum_distinct_conv_setsum_set) simp_all
+ also have "\<dots> = pmf_set_integral p f (set xs)" unfolding pmf_set_integral_def
+ by (subst integral_measure_pmf[of "set xs"])
+ (auto simp: indicator_def mult_ac split: if_splits)
+ finally show ?thesis ..
+qed
+
+lemma pmf_prob_code [code]:
+ "pmf_prob p (set xs) = listsum (map (pmf p) (remdups xs))"
+proof -
+ have "pmf_prob p (set xs) = pmf_set_integral p (\<lambda>_. 1) (set xs)"
+ unfolding pmf_prob_def pmf_set_integral_def by simp
+ also have "\<dots> = listsum (map (pmf p) (remdups xs))"
+ unfolding pmf_set_integral_code by simp
+ finally show ?thesis .
+qed
+
+lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p"
+ by (intro ext) (simp add: pmf_prob_def)
+
+(* Why does this not work without parameters? *)
+lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p"
+ by (intro ext) (simp add: pmf_integral_def)
+
+lemma mapping_of_pmf_pmf_of_list:
+ assumes "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0" "listsum (map snd xs) = 1"
+ shows "mapping_of_pmf (pmf_of_list xs) =
+ Mapping.tabulate (remdups (map fst xs))
+ (\<lambda>x. listsum (map snd (filter (\<lambda>z. fst z = x) xs)))"
+proof -
+ from assms have wf: "pmf_of_list_wf xs" by (intro pmf_of_list_wfI) force
+ moreover from this assms have "set_pmf (pmf_of_list xs) = fst ` set xs"
+ by (intro set_pmf_of_list_eq) auto
+ ultimately show ?thesis
+ by (intro mapping_of_pmfI) (auto simp: lookup_tabulate pmf_pmf_of_list)
+qed
+
+lemma mapping_of_pmf_pmf_of_list':
+ assumes "pmf_of_list_wf xs"
+ defines "xs' \<equiv> filter (\<lambda>z. snd z \<noteq> 0) xs"
+ shows "mapping_of_pmf (pmf_of_list xs) =
+ Mapping.tabulate (remdups (map fst xs'))
+ (\<lambda>x. listsum (map snd (filter (\<lambda>z. fst z = x) xs')))" (is "_ = ?rhs")
+proof -
+ have wf: "pmf_of_list_wf xs'" unfolding xs'_def by (rule pmf_of_list_remove_zeros) fact
+ have pos: "\<forall>x\<in>snd`set xs'. x > 0" using assms(1) unfolding xs'_def
+ by (force simp: pmf_of_list_wf_def)
+ from assms have "pmf_of_list xs = pmf_of_list xs'"
+ unfolding xs'_def by (subst pmf_of_list_remove_zeros) simp_all
+ also from wf pos have "mapping_of_pmf \<dots> = ?rhs"
+ by (intro mapping_of_pmf_pmf_of_list) (auto simp: pmf_of_list_wf_def)
+ finally show ?thesis .
+qed
+
+lemma pmf_of_list_wf_code [code]:
+ "pmf_of_list_wf xs \<longleftrightarrow> list_all (\<lambda>z. snd z \<ge> 0) xs \<and> listsum (map snd xs) = 1"
+ by (auto simp add: pmf_of_list_wf_def list_all_def)
+
+lemma pmf_of_list_code [code abstract]:
+ "mapping_of_pmf (pmf_of_list xs) = (
+ if pmf_of_list_wf xs then
+ let xs' = filter (\<lambda>z. snd z \<noteq> 0) xs
+ in Mapping.tabulate (remdups (map fst xs'))
+ (\<lambda>x. listsum (map snd (filter (\<lambda>z. fst z = x) xs')))
+ else
+ Code.abort (STR ''Invalid list for pmf_of_list'') (\<lambda>_. mapping_of_pmf (pmf_of_list xs)))"
+ using mapping_of_pmf_pmf_of_list'[of xs] by (simp add: Let_def)
+
+
+lemma mapping_of_pmf_eq_iff [simp]:
+ "mapping_of_pmf p = mapping_of_pmf q \<longleftrightarrow> p = (q :: 'a pmf)"
+proof (transfer, intro iffI pmf_eqI)
+ fix p q :: "'a pmf" and x :: 'a
+ assume "(\<lambda>x. if pmf p x = 0 then None else Some (pmf p x)) =
+ (\<lambda>x. if pmf q x = 0 then None else Some (pmf q x))"
+ hence "(if pmf p x = 0 then None else Some (pmf p x)) =
+ (if pmf q x = 0 then None else Some (pmf q x))" for x
+ by (simp add: fun_eq_iff)
+ from this[of x] show "pmf p x = pmf q x" by (auto split: if_splits)
+qed (simp_all cong: if_cong)
+
+definition "pmf_of_alist xs = embed_pmf (\<lambda>x. case map_of xs x of Some p \<Rightarrow> p | None \<Rightarrow> 0)"
+
+lemma pmf_of_mapping_Mapping [code_post]:
+ "pmf_of_mapping (Mapping xs) = pmf_of_alist xs"
+ unfolding pmf_of_mapping_def Mapping.lookup_default_def [abs_def] pmf_of_alist_def
+ by transfer simp_all
+
+
+instantiation pmf :: (equal) equal
+begin
+
+definition "equal_pmf p q = (mapping_of_pmf p = mapping_of_pmf (q :: 'a pmf))"
+
+instance by standard (simp add: equal_pmf_def)
+end
+
+
+definition (in term_syntax)
+ pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow>
+ 'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow>
+ 'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
+ [code_unfold]: "pmfify A x =
+ Code_Evaluation.valtermify pmf_of_multiset {\<cdot>}
+ (Code_Evaluation.valtermify (op +) {\<cdot>} A {\<cdot>}
+ (Code_Evaluation.valtermify single {\<cdot>} x))"
+
+
+notation fcomp (infixl "\<circ>>" 60)
+notation scomp (infixl "\<circ>\<rightarrow>" 60)
+
+instantiation pmf :: (random) random
+begin
+
+definition
+ "Quickcheck_Random.random i =
+ Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>A.
+ Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>x. Pair (pmfify A x)))"
+
+instance ..
+
+end
+
+no_notation fcomp (infixl "\<circ>>" 60)
+no_notation scomp (infixl "\<circ>\<rightarrow>" 60)
+
+(*
+instantiation pmf :: (exhaustive) exhaustive
+begin
+
+definition exhaustive_pmf :: "('a pmf \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
+where
+ "exhaustive_pmf f i =
+ Quickcheck_Exhaustive.exhaustive (\<lambda>A.
+ Quickcheck_Exhaustive.exhaustive (\<lambda>x. f (pmf_of_multiset (A + {#x#}))) i) i"
+
+instance ..
+
+end
+*)
+
+instantiation pmf :: (full_exhaustive) full_exhaustive
+begin
+
+definition full_exhaustive_pmf :: "('a pmf \<times> (unit \<Rightarrow> term) \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
+where
+ "full_exhaustive_pmf f i =
+ Quickcheck_Exhaustive.full_exhaustive (\<lambda>A.
+ Quickcheck_Exhaustive.full_exhaustive (\<lambda>x. f (pmfify A x)) i) i"
+
+instance ..
+
+end
+
+end
\ No newline at end of file
--- a/src/HOL/Probability/Probability.thy Tue May 31 12:24:43 2016 +0200
+++ b/src/HOL/Probability/Probability.thy Tue May 31 13:02:44 2016 +0200
@@ -8,6 +8,7 @@
Complete_Measure
Projective_Limit
Probability_Mass_Function
+ PMF_Impl
Stream_Space
Random_Permutations
Embed_Measure
--- a/src/HOL/Probability/Probability_Mass_Function.thy Tue May 31 12:24:43 2016 +0200
+++ b/src/HOL/Probability/Probability_Mass_Function.thy Tue May 31 13:02:44 2016 +0200
@@ -1787,6 +1787,58 @@
end
+primrec replicate_pmf :: "nat \<Rightarrow> 'a pmf \<Rightarrow> 'a list pmf" where
+ "replicate_pmf 0 _ = return_pmf []"
+| "replicate_pmf (Suc n) p = do {x \<leftarrow> p; xs \<leftarrow> replicate_pmf n p; return_pmf (x#xs)}"
+
+lemma replicate_pmf_1: "replicate_pmf 1 p = map_pmf (\<lambda>x. [x]) p"
+ by (simp add: map_pmf_def bind_return_pmf)
+
+lemma set_replicate_pmf:
+ "set_pmf (replicate_pmf n p) = {xs\<in>lists (set_pmf p). length xs = n}"
+ by (induction n) (auto simp: length_Suc_conv)
+
+lemma replicate_pmf_distrib:
+ "replicate_pmf (m + n) p =
+ do {xs \<leftarrow> replicate_pmf m p; ys \<leftarrow> replicate_pmf n p; return_pmf (xs @ ys)}"
+ by (induction m) (simp_all add: bind_return_pmf bind_return_pmf' bind_assoc_pmf)
+
+lemma power_diff':
+ assumes "b \<le> a"
+ shows "x ^ (a - b) = (if x = 0 \<and> a = b then 1 else x ^ a / (x::'a::field) ^ b)"
+proof (cases "x = 0")
+ case True
+ with assms show ?thesis by (cases "a - b") simp_all
+qed (insert assms, simp_all add: power_diff)
+
+
+lemma binomial_pmf_Suc:
+ assumes "p \<in> {0..1}"
+ shows "binomial_pmf (Suc n) p =
+ do {b \<leftarrow> bernoulli_pmf p;
+ k \<leftarrow> binomial_pmf n p;
+ return_pmf ((if b then 1 else 0) + k)}" (is "_ = ?rhs")
+proof (intro pmf_eqI)
+ fix k
+ have A: "indicator {Suc a} (Suc b) = indicator {a} b" for a b
+ by (simp add: indicator_def)
+ show "pmf (binomial_pmf (Suc n) p) k = pmf ?rhs k"
+ by (cases k; cases "k > n")
+ (insert assms, auto simp: pmf_bind measure_pmf_single A divide_simps algebra_simps
+ not_less less_eq_Suc_le [symmetric] power_diff')
+qed
+
+lemma binomial_pmf_0: "p \<in> {0..1} \<Longrightarrow> binomial_pmf 0 p = return_pmf 0"
+ by (rule pmf_eqI) (simp_all add: indicator_def)
+
+lemma binomial_pmf_altdef:
+ assumes "p \<in> {0..1}"
+ shows "binomial_pmf n p = map_pmf (length \<circ> filter id) (replicate_pmf n (bernoulli_pmf p))"
+ by (induction n)
+ (insert assms, auto simp: binomial_pmf_Suc map_pmf_def bind_return_pmf bind_assoc_pmf
+ bind_return_pmf' binomial_pmf_0 intro!: bind_pmf_cong)
+
+
subsection \<open>PMFs from assiciation lists\<close>
definition pmf_of_list ::" ('a \<times> real) list \<Rightarrow> 'a pmf" where
@@ -1921,4 +1973,52 @@
using assms unfolding pmf_of_list_wf_def Sigma_Algebra.measure_def
by (subst emeasure_pmf_of_list [OF assms], subst enn2real_ennreal) (auto intro!: listsum_nonneg)
+(* TODO Move? *)
+lemma listsum_nonneg_eq_zero_iff:
+ fixes xs :: "'a :: linordered_ab_group_add list"
+ shows "(\<And>x. x \<in> set xs \<Longrightarrow> x \<ge> 0) \<Longrightarrow> listsum xs = 0 \<longleftrightarrow> set xs \<subseteq> {0}"
+proof (induction xs)
+ case (Cons x xs)
+ from Cons.prems have "listsum (x#xs) = 0 \<longleftrightarrow> x = 0 \<and> listsum xs = 0"
+ unfolding listsum_simps by (subst add_nonneg_eq_0_iff) (auto intro: listsum_nonneg)
+ with Cons.IH Cons.prems show ?case by simp
+qed simp_all
+
+lemma listsum_filter_nonzero:
+ "listsum (filter (\<lambda>x. x \<noteq> 0) xs) = listsum xs"
+ by (induction xs) simp_all
+(* END MOVE *)
+
+lemma set_pmf_of_list_eq:
+ assumes "pmf_of_list_wf xs" "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0"
+ shows "set_pmf (pmf_of_list xs) = fst ` set xs"
+proof
+ {
+ fix x assume A: "x \<in> fst ` set xs" and B: "x \<notin> set_pmf (pmf_of_list xs)"
+ then obtain y where y: "(x, y) \<in> set xs" by auto
+ from B have "listsum (map snd [z\<leftarrow>xs. fst z = x]) = 0"
+ by (simp add: pmf_pmf_of_list[OF assms(1)] set_pmf_eq)
+ moreover from y have "y \<in> snd ` {xa \<in> set xs. fst xa = x}" by force
+ ultimately have "y = 0" using assms(1)
+ by (subst (asm) listsum_nonneg_eq_zero_iff) (auto simp: pmf_of_list_wf_def)
+ with assms(2) y have False by force
+ }
+ thus "fst ` set xs \<subseteq> set_pmf (pmf_of_list xs)" by blast
+qed (insert set_pmf_of_list[OF assms(1)], simp_all)
+
+lemma pmf_of_list_remove_zeros:
+ assumes "pmf_of_list_wf xs"
+ defines "xs' \<equiv> filter (\<lambda>z. snd z \<noteq> 0) xs"
+ shows "pmf_of_list_wf xs'" "pmf_of_list xs' = pmf_of_list xs"
+proof -
+ have "map snd [z\<leftarrow>xs . snd z \<noteq> 0] = filter (\<lambda>x. x \<noteq> 0) (map snd xs)"
+ by (induction xs) simp_all
+ with assms(1) show wf: "pmf_of_list_wf xs'"
+ by (auto simp: pmf_of_list_wf_def xs'_def listsum_filter_nonzero)
+ have "listsum (map snd [z\<leftarrow>xs' . fst z = i]) = listsum (map snd [z\<leftarrow>xs . fst z = i])" for i
+ unfolding xs'_def by (induction xs) simp_all
+ with assms(1) wf show "pmf_of_list xs' = pmf_of_list xs"
+ by (intro pmf_eqI) (simp_all add: pmf_pmf_of_list)
+qed
+
end
--- a/src/HOL/Probability/Random_Permutations.thy Tue May 31 12:24:43 2016 +0200
+++ b/src/HOL/Probability/Random_Permutations.thy Tue May 31 13:02:44 2016 +0200
@@ -102,7 +102,11 @@
map_pmf (\<lambda>xs. fold f xs x) (pmf_of_set (permutations_of_set A))"
by (subst fold_random_permutation_foldl [OF assms], intro map_pmf_cong)
(simp_all add: foldl_conv_fold)
-
+
+lemma fold_random_permutation_code [code]:
+ "fold_random_permutation f x (set xs) =
+ map_pmf (foldl (\<lambda>x y. f y x) x) (pmf_of_set (permutations_of_set (set xs)))"
+ by (simp add: fold_random_permutation_foldl)
text \<open>
We now introduce a slightly generalised version of the above fold
@@ -134,7 +138,7 @@
We now show that the recursive definition is equivalent to
a random fold followed by a monadic bind.
\<close>
-lemma fold_bind_random_permutation_altdef:
+lemma fold_bind_random_permutation_altdef [code]:
"fold_bind_random_permutation f g x A = fold_random_permutation f x A \<bind> g"
proof (induction f x A rule: fold_random_permutation.induct [case_names empty infinite remove])
case (remove A f x)