--- a/src/HOL/Probability/Probability_Mass_Function.thy Tue May 17 08:40:24 2016 +0200
+++ b/src/HOL/Probability/Probability_Mass_Function.thy Tue May 17 17:05:35 2016 +0200
@@ -167,6 +167,12 @@
lemma pmf_nonneg[simp]: "0 \<le> pmf p x"
by transfer simp
+
+lemma pmf_not_neg [simp]: "\<not>pmf p x < 0"
+ by (simp add: not_less pmf_nonneg)
+
+lemma pmf_pos [simp]: "pmf p x \<noteq> 0 \<Longrightarrow> pmf p x > 0"
+ using pmf_nonneg[of p x] by linarith
lemma pmf_le_1: "pmf p x \<le> 1"
by (simp add: pmf.rep_eq)
@@ -183,6 +189,13 @@
lemma set_pmf_eq: "set_pmf M = {x. pmf M x \<noteq> 0}"
by (auto simp: set_pmf_iff)
+lemma set_pmf_eq': "set_pmf p = {x. pmf p x > 0}"
+proof safe
+ fix x assume "x \<in> set_pmf p"
+ hence "pmf p x \<noteq> 0" by (auto simp: set_pmf_eq)
+ with pmf_nonneg[of p x] show "pmf p x > 0" by simp
+qed (auto simp: set_pmf_eq)
+
lemma emeasure_pmf_single:
fixes M :: "'a pmf"
shows "emeasure M {x} = pmf M x"
@@ -198,6 +211,17 @@
using emeasure_measure_pmf_finite[of S M]
by (simp add: measure_pmf.emeasure_eq_measure measure_nonneg setsum_nonneg pmf_nonneg)
+lemma setsum_pmf_eq_1:
+ assumes "finite A" "set_pmf p \<subseteq> A"
+ shows "(\<Sum>x\<in>A. pmf p x) = 1"
+proof -
+ have "(\<Sum>x\<in>A. pmf p x) = measure_pmf.prob p A"
+ by (simp add: measure_measure_pmf_finite assms)
+ also from assms have "\<dots> = 1"
+ by (subst measure_pmf.prob_eq_1) (auto simp: AE_measure_pmf_iff)
+ finally show ?thesis .
+qed
+
lemma nn_integral_measure_pmf_support:
fixes f :: "'a \<Rightarrow> ennreal"
assumes f: "finite A" and nn: "\<And>x. x \<in> A \<Longrightarrow> 0 \<le> f x" "\<And>x. x \<in> set_pmf M \<Longrightarrow> x \<notin> A \<Longrightarrow> f x = 0"
@@ -339,6 +363,8 @@
done
qed
+adhoc_overloading Monad_Syntax.bind bind_pmf
+
lemma ennreal_pmf_bind: "pmf (bind_pmf N f) i = (\<integral>\<^sup>+x. pmf (f x) i \<partial>measure_pmf N)"
unfolding pmf.rep_eq bind_pmf.rep_eq
by (auto simp: measure_pmf.measure_bind[where N="count_space UNIV"] measure_subprob measure_nonneg
@@ -363,7 +389,7 @@
finally show ?thesis .
qed
-lemma bind_pmf_cong:
+lemma bind_pmf_cong [fundef_cong]:
assumes "p = q"
shows "(\<And>x. x \<in> set_pmf q \<Longrightarrow> f x = g x) \<Longrightarrow> bind_pmf p f = bind_pmf q g"
unfolding \<open>p = q\<close>[symmetric] measure_pmf_inject[symmetric] bind_pmf.rep_eq
@@ -518,6 +544,15 @@
lemma emeasure_return_pmf[simp]: "emeasure (return_pmf x) X = indicator X x"
unfolding return_pmf.rep_eq by (intro emeasure_return) auto
+lemma measure_return_pmf [simp]: "measure_pmf.prob (return_pmf x) A = indicator A x"
+proof -
+ have "ennreal (measure_pmf.prob (return_pmf x) A) =
+ emeasure (measure_pmf (return_pmf x)) A"
+ by (simp add: measure_pmf.emeasure_eq_measure)
+ also have "\<dots> = ennreal (indicator A x)" by (simp add: ennreal_indicator)
+ finally show ?thesis by simp
+qed
+
lemma return_pmf_inj[simp]: "return_pmf x = return_pmf y \<longleftrightarrow> x = y"
by (metis insertI1 set_return_pmf singletonD)
@@ -732,6 +767,27 @@
lemma pmf_eq_iff: "M = N \<longleftrightarrow> (\<forall>i. pmf M i = pmf N i)"
by (auto intro: pmf_eqI)
+lemma pmf_neq_exists_less:
+ assumes "M \<noteq> N"
+ shows "\<exists>x. pmf M x < pmf N x"
+proof (rule ccontr)
+ assume "\<not>(\<exists>x. pmf M x < pmf N x)"
+ hence ge: "pmf M x \<ge> pmf N x" for x by (auto simp: not_less)
+ from assms obtain x where "pmf M x \<noteq> pmf N x" by (auto simp: pmf_eq_iff)
+ with ge[of x] have gt: "pmf M x > pmf N x" by simp
+ have "1 = measure (measure_pmf M) UNIV" by simp
+ also have "\<dots> = measure (measure_pmf N) {x} + measure (measure_pmf N) (UNIV - {x})"
+ by (subst measure_pmf.finite_measure_Union [symmetric]) simp_all
+ also from gt have "measure (measure_pmf N) {x} < measure (measure_pmf M) {x}"
+ by (simp add: measure_pmf_single)
+ also have "measure (measure_pmf N) (UNIV - {x}) \<le> measure (measure_pmf M) (UNIV - {x})"
+ by (subst (1 2) integral_pmf [symmetric])
+ (intro integral_mono integrable_pmf, simp_all add: ge)
+ also have "measure (measure_pmf M) {x} + \<dots> = 1"
+ by (subst measure_pmf.finite_measure_Union [symmetric]) simp_all
+ finally show False by simp_all
+qed
+
lemma bind_commute_pmf: "bind_pmf A (\<lambda>x. bind_pmf B (C x)) = bind_pmf B (\<lambda>y. bind_pmf A (\<lambda>x. C x y))"
unfolding pmf_eq_iff pmf_bind
proof
@@ -904,6 +960,9 @@
end
+lemma measure_pmf_posI: "x \<in> set_pmf p \<Longrightarrow> x \<in> A \<Longrightarrow> measure_pmf.prob p A > 0"
+ using measure_measure_pmf_not_zero[of p A] by (subst zero_less_measure_iff) blast
+
lemma cond_map_pmf:
assumes "set_pmf p \<inter> f -` s \<noteq> {}"
shows "cond_pmf (map_pmf f p) s = map_pmf f (cond_pmf p (f -` s))"
@@ -1568,6 +1627,31 @@
end
+lemma map_pmf_of_set:
+ assumes "finite A" "A \<noteq> {}"
+ shows "map_pmf f (pmf_of_set A) = pmf_of_multiset (image_mset f (mset_set A))"
+ (is "?lhs = ?rhs")
+proof (intro pmf_eqI)
+ fix x
+ from assms have "ennreal (pmf ?lhs x) = ennreal (pmf ?rhs x)"
+ by (subst ennreal_pmf_map)
+ (simp_all add: emeasure_pmf_of_set mset_set_empty_iff count_image_mset Int_commute)
+ thus "pmf ?lhs x = pmf ?rhs x" by simp
+qed
+
+lemma pmf_bind_pmf_of_set:
+ assumes "A \<noteq> {}" "finite A"
+ shows "pmf (bind_pmf (pmf_of_set A) f) x =
+ (\<Sum>xa\<in>A. pmf (f xa) x) / real_of_nat (card A)" (is "?lhs = ?rhs")
+proof -
+ from assms have "card A > 0" by auto
+ with assms have "ennreal ?lhs = ennreal ?rhs"
+ by (subst ennreal_pmf_bind)
+ (simp_all add: nn_integral_pmf_of_set max_def pmf_nonneg divide_ennreal [symmetric]
+ setsum_nonneg ennreal_of_nat_eq_real_of_nat)
+ thus ?thesis by (subst (asm) ennreal_inj) (auto intro!: setsum_nonneg divide_nonneg_nonneg)
+qed
+
lemma pmf_of_set_singleton: "pmf_of_set {x} = return_pmf x"
by(rule pmf_eqI)(simp add: indicator_def)
@@ -1590,6 +1674,38 @@
qed
qed
+text \<open>
+ Choosing an element uniformly at random from the union of a disjoint family
+ of finite non-empty sets with the same size is the same as first choosing a set
+ from the family uniformly at random and then choosing an element from the chosen set
+ uniformly at random.
+\<close>
+lemma pmf_of_set_UN:
+ assumes "finite (UNION A f)" "A \<noteq> {}" "\<And>x. x \<in> A \<Longrightarrow> f x \<noteq> {}"
+ "\<And>x. x \<in> A \<Longrightarrow> card (f x) = n" "disjoint_family_on f A"
+ shows "pmf_of_set (UNION A f) = do {x \<leftarrow> pmf_of_set A; pmf_of_set (f x)}"
+ (is "?lhs = ?rhs")
+proof (intro pmf_eqI)
+ fix x
+ from assms have [simp]: "finite A"
+ using infinite_disjoint_family_imp_infinite_UNION[of A f] by blast
+ from assms have "ereal (pmf (pmf_of_set (UNION A f)) x) =
+ ereal (indicator (\<Union>x\<in>A. f x) x / real (card (\<Union>x\<in>A. f x)))"
+ by (subst pmf_of_set) auto
+ also from assms have "card (\<Union>x\<in>A. f x) = card A * n"
+ by (subst card_UN_disjoint) (auto simp: disjoint_family_on_def)
+ also from assms
+ have "indicator (\<Union>x\<in>A. f x) x / real \<dots> =
+ indicator (\<Union>x\<in>A. f x) x / (n * real (card A))"
+ by (simp add: setsum_divide_distrib [symmetric] mult_ac)
+ also from assms have "indicator (\<Union>x\<in>A. f x) x = (\<Sum>y\<in>A. indicator (f y) x)"
+ by (intro indicator_UN_disjoint) simp_all
+ also from assms have "ereal ((\<Sum>y\<in>A. indicator (f y) x) / (real n * real (card A))) =
+ ereal (pmf ?rhs x)"
+ by (subst pmf_bind_pmf_of_set) (simp_all add: setsum_divide_distrib)
+ finally show "pmf ?lhs x = pmf ?rhs x" by simp
+qed
+
lemma bernoulli_pmf_half_conv_pmf_of_set: "bernoulli_pmf (1 / 2) = pmf_of_set UNIV"
by (rule pmf_eqI) simp_all
@@ -1670,4 +1786,139 @@
end
+
+subsection \<open>PMFs from assiciation lists\<close>
+
+definition pmf_of_list ::" ('a \<times> real) list \<Rightarrow> 'a pmf" where
+ "pmf_of_list xs = embed_pmf (\<lambda>x. listsum (map snd (filter (\<lambda>z. fst z = x) xs)))"
+
+definition pmf_of_list_wf where
+ "pmf_of_list_wf xs \<longleftrightarrow> (\<forall>x\<in>set (map snd xs) . x \<ge> 0) \<and> listsum (map snd xs) = 1"
+
+lemma pmf_of_list_wfI:
+ "(\<And>x. x \<in> set (map snd xs) \<Longrightarrow> x \<ge> 0) \<Longrightarrow> listsum (map snd xs) = 1 \<Longrightarrow> pmf_of_list_wf xs"
+ unfolding pmf_of_list_wf_def by simp
+
+context
+begin
+
+private lemma pmf_of_list_aux:
+ assumes "\<And>x. x \<in> set (map snd xs) \<Longrightarrow> x \<ge> 0"
+ assumes "listsum (map snd xs) = 1"
+ shows "(\<integral>\<^sup>+ x. ennreal (listsum (map snd [z\<leftarrow>xs . fst z = x])) \<partial>count_space UNIV) = 1"
+proof -
+ have "(\<integral>\<^sup>+ x. ennreal (listsum (map snd (filter (\<lambda>z. fst z = x) xs))) \<partial>count_space UNIV) =
+ (\<integral>\<^sup>+ x. ennreal (listsum (map (\<lambda>(x',p). indicator {x'} x * p) xs)) \<partial>count_space UNIV)"
+ by (intro nn_integral_cong ennreal_cong, subst listsum_map_filter) (auto intro: listsum_cong)
+ also have "\<dots> = (\<Sum>(x',p)\<leftarrow>xs. (\<integral>\<^sup>+ x. ennreal (indicator {x'} x * p) \<partial>count_space UNIV))"
+ using assms(1)
+ proof (induction xs)
+ case (Cons x xs)
+ from Cons.prems have "snd x \<ge> 0" by simp
+ moreover have "b \<ge> 0" if "(a,b) \<in> set xs" for a b
+ using Cons.prems[of b] that by force
+ ultimately have "(\<integral>\<^sup>+ y. ennreal (\<Sum>(x', p)\<leftarrow>x # xs. indicator {x'} y * p) \<partial>count_space UNIV) =
+ (\<integral>\<^sup>+ y. ennreal (indicator {fst x} y * snd x) +
+ ennreal (\<Sum>(x', p)\<leftarrow>xs. indicator {x'} y * p) \<partial>count_space UNIV)"
+ by (intro nn_integral_cong, subst ennreal_plus [symmetric])
+ (auto simp: case_prod_unfold indicator_def intro!: listsum_nonneg)
+ also have "\<dots> = (\<integral>\<^sup>+ y. ennreal (indicator {fst x} y * snd x) \<partial>count_space UNIV) +
+ (\<integral>\<^sup>+ y. ennreal (\<Sum>(x', p)\<leftarrow>xs. indicator {x'} y * p) \<partial>count_space UNIV)"
+ by (intro nn_integral_add)
+ (force intro!: listsum_nonneg AE_I2 intro: Cons simp: indicator_def)+
+ also have "(\<integral>\<^sup>+ y. ennreal (\<Sum>(x', p)\<leftarrow>xs. indicator {x'} y * p) \<partial>count_space UNIV) =
+ (\<Sum>(x', p)\<leftarrow>xs. (\<integral>\<^sup>+ y. ennreal (indicator {x'} y * p) \<partial>count_space UNIV))"
+ using Cons(1) by (intro Cons) simp_all
+ finally show ?case by (simp add: case_prod_unfold)
+ qed simp
+ also have "\<dots> = (\<Sum>(x',p)\<leftarrow>xs. ennreal p * (\<integral>\<^sup>+ x. indicator {x'} x \<partial>count_space UNIV))"
+ using assms(1)
+ by (intro listsum_cong, simp only: case_prod_unfold, subst nn_integral_cmult [symmetric])
+ (auto intro!: assms(1) simp: max_def times_ereal.simps [symmetric] mult_ac ereal_indicator
+ simp del: times_ereal.simps)+
+ also from assms have "\<dots> = listsum (map snd xs)" by (simp add: case_prod_unfold listsum_ennreal)
+ also have "\<dots> = 1" using assms(2) by simp
+ finally show ?thesis .
+qed
+
+lemma pmf_pmf_of_list:
+ assumes "pmf_of_list_wf xs"
+ shows "pmf (pmf_of_list xs) x = listsum (map snd (filter (\<lambda>z. fst z = x) xs))"
+ using assms pmf_of_list_aux[of xs] unfolding pmf_of_list_def pmf_of_list_wf_def
+ by (subst pmf_embed_pmf) (auto intro!: listsum_nonneg)
+
end
+
+lemma set_pmf_of_list:
+ assumes "pmf_of_list_wf xs"
+ shows "set_pmf (pmf_of_list xs) \<subseteq> set (map fst xs)"
+proof clarify
+ fix x assume A: "x \<in> set_pmf (pmf_of_list xs)"
+ show "x \<in> set (map fst xs)"
+ proof (rule ccontr)
+ assume "x \<notin> set (map fst xs)"
+ hence "[z\<leftarrow>xs . fst z = x] = []" by (auto simp: filter_empty_conv)
+ with A assms show False by (simp add: pmf_pmf_of_list set_pmf_eq)
+ qed
+qed
+
+lemma finite_set_pmf_of_list:
+ assumes "pmf_of_list_wf xs"
+ shows "finite (set_pmf (pmf_of_list xs))"
+ using assms by (rule finite_subset[OF set_pmf_of_list]) simp_all
+
+lemma emeasure_Int_set_pmf:
+ "emeasure (measure_pmf p) (A \<inter> set_pmf p) = emeasure (measure_pmf p) A"
+ by (rule emeasure_eq_AE) (auto simp: AE_measure_pmf_iff)
+
+lemma measure_Int_set_pmf:
+ "measure (measure_pmf p) (A \<inter> set_pmf p) = measure (measure_pmf p) A"
+ using emeasure_Int_set_pmf[of p A] by (simp add: Sigma_Algebra.measure_def)
+
+lemma emeasure_pmf_of_list:
+ assumes "pmf_of_list_wf xs"
+ shows "emeasure (pmf_of_list xs) A = ennreal (listsum (map snd (filter (\<lambda>x. fst x \<in> A) xs)))"
+proof -
+ have "emeasure (pmf_of_list xs) A = nn_integral (measure_pmf (pmf_of_list xs)) (indicator A)"
+ by simp
+ also from assms
+ have "\<dots> = (\<Sum>x\<in>set_pmf (pmf_of_list xs) \<inter> A. ennreal (listsum (map snd [z\<leftarrow>xs . fst z = x])))"
+ by (subst nn_integral_measure_pmf_finite) (simp_all add: finite_set_pmf_of_list pmf_pmf_of_list)
+ also from assms
+ have "\<dots> = ennreal (\<Sum>x\<in>set_pmf (pmf_of_list xs) \<inter> A. listsum (map snd [z\<leftarrow>xs . fst z = x]))"
+ by (subst setsum_ennreal) (auto simp: pmf_of_list_wf_def intro!: listsum_nonneg)
+ also have "\<dots> = ennreal (\<Sum>x\<in>set_pmf (pmf_of_list xs) \<inter> A.
+ indicator A x * pmf (pmf_of_list xs) x)" (is "_ = ennreal ?S")
+ using assms by (intro ennreal_cong setsum.cong) (auto simp: pmf_pmf_of_list)
+ also have "?S = (\<Sum>x\<in>set_pmf (pmf_of_list xs). indicator A x * pmf (pmf_of_list xs) x)"
+ using assms by (intro setsum.mono_neutral_left set_pmf_of_list finite_set_pmf_of_list) auto
+ also have "\<dots> = (\<Sum>x\<in>set (map fst xs). indicator A x * pmf (pmf_of_list xs) x)"
+ using assms by (intro setsum.mono_neutral_left set_pmf_of_list) (auto simp: set_pmf_eq)
+ also have "\<dots> = (\<Sum>x\<in>set (map fst xs). indicator A x *
+ listsum (map snd (filter (\<lambda>z. fst z = x) xs)))"
+ using assms by (simp add: pmf_pmf_of_list)
+ also have "\<dots> = (\<Sum>x\<in>set (map fst xs). listsum (map snd (filter (\<lambda>z. fst z = x \<and> x \<in> A) xs)))"
+ by (intro setsum.cong) (auto simp: indicator_def)
+ also have "\<dots> = (\<Sum>x\<in>set (map fst xs). (\<Sum>xa = 0..<length xs.
+ if fst (xs ! xa) = x \<and> x \<in> A then snd (xs ! xa) else 0))"
+ by (intro setsum.cong refl, subst listsum_map_filter, subst listsum_setsum_nth) simp
+ also have "\<dots> = (\<Sum>xa = 0..<length xs. (\<Sum>x\<in>set (map fst xs).
+ if fst (xs ! xa) = x \<and> x \<in> A then snd (xs ! xa) else 0))"
+ by (rule setsum.commute)
+ also have "\<dots> = (\<Sum>xa = 0..<length xs. if fst (xs ! xa) \<in> A then
+ (\<Sum>x\<in>set (map fst xs). if x = fst (xs ! xa) then snd (xs ! xa) else 0) else 0)"
+ by (auto intro!: setsum.cong setsum.neutral)
+ also have "\<dots> = (\<Sum>xa = 0..<length xs. if fst (xs ! xa) \<in> A then snd (xs ! xa) else 0)"
+ by (intro setsum.cong refl) (simp_all add: setsum.delta)
+ also have "\<dots> = listsum (map snd (filter (\<lambda>x. fst x \<in> A) xs))"
+ by (subst listsum_map_filter, subst listsum_setsum_nth) simp_all
+ finally show ?thesis .
+qed
+
+lemma measure_pmf_of_list:
+ assumes "pmf_of_list_wf xs"
+ shows "measure (pmf_of_list xs) A = listsum (map snd (filter (\<lambda>x. fst x \<in> A) xs))"
+ using assms unfolding pmf_of_list_wf_def Sigma_Algebra.measure_def
+ by (subst emeasure_pmf_of_list [OF assms], subst enn2real_ennreal) (auto intro!: listsum_nonneg)
+
+end