src/HOL/Probability/Probability_Mass_Function.thy
changeset 62975 1d066f6ab25d
parent 62390 842917225d56
child 63040 eb4ddd18d635
--- a/src/HOL/Probability/Probability_Mass_Function.thy	Thu Apr 14 12:17:44 2016 +0200
+++ b/src/HOL/Probability/Probability_Mass_Function.thy	Thu Apr 14 15:48:11 2016 +0200
@@ -22,16 +22,13 @@
     with x_M[THEN sets.sets_into_space] N have "emeasure M {x} \<le> emeasure M N"
       by (intro emeasure_mono) auto
     with x N have False
-      by (auto simp: emeasure_le_0_iff) }
+      by (auto simp:) }
   then show "P x" by auto
 qed
 
 lemma AE_measure_singleton: "measure M {x} \<noteq> 0 \<Longrightarrow> AE x in M. P x \<Longrightarrow> P x"
   by (metis AE_emeasure_singleton measure_def emeasure_empty measure_empty)
 
-lemma ereal_divide': "b \<noteq> 0 \<Longrightarrow> ereal (a / b) = ereal a / ereal b"
-  using ereal_divide[of a b] by simp
-
 lemma (in finite_measure) AE_support_countable:
   assumes [simp]: "sets M = UNIV"
   shows "(AE x in M. measure M {x} \<noteq> 0) \<longleftrightarrow> (\<exists>S. countable S \<and> (AE x in M. x \<in> S))"
@@ -54,7 +51,7 @@
     by (simp add: emeasure_single_in_space cong: rev_conj_cong)
   with finite_measure_compl[of "{x \<in> space M. x\<in>S \<and> emeasure M {x} \<noteq> 0}"]
   have "AE x in M. x \<in> S \<and> emeasure M {x} \<noteq> 0"
-    by (intro AE_I[OF order_refl]) (auto simp: emeasure_eq_measure set_diff_eq cong: conj_cong)
+    by (intro AE_I[OF order_refl]) (auto simp: emeasure_eq_measure measure_nonneg set_diff_eq cong: conj_cong)
   then show "AE x in M. measure M {x} \<noteq> 0"
     by (auto simp: emeasure_eq_measure)
 qed (auto intro!: exI[of _ "{x. measure M {x} \<noteq> 0}"] countable_support)
@@ -153,7 +150,7 @@
 lemma emeasure_pmf_single_eq_zero_iff:
   fixes M :: "'a pmf"
   shows "emeasure M {y} = 0 \<longleftrightarrow> y \<notin> M"
-  by transfer (simp add: finite_measure.emeasure_eq_measure[OF prob_space.finite_measure])
+  unfolding set_pmf.rep_eq by (simp add: measure_pmf.emeasure_eq_measure)
 
 lemma AE_measure_pmf_iff: "(AE x in measure_pmf M. P x) \<longleftrightarrow> (\<forall>y\<in>M. P y)"
   using AE_measure_singleton[of M] AE_measure_pmf[of M]
@@ -166,10 +163,10 @@
   by transfer (metis prob_space.finite_measure finite_measure.countable_support)
 
 lemma pmf_positive: "x \<in> set_pmf p \<Longrightarrow> 0 < pmf p x"
-  by transfer (simp add: less_le measure_nonneg)
+  by transfer (simp add: less_le)
 
-lemma pmf_nonneg: "0 \<le> pmf p x"
-  by transfer (simp add: measure_nonneg)
+lemma pmf_nonneg[simp]: "0 \<le> pmf p x"
+  by transfer simp
 
 lemma pmf_le_1: "pmf p x \<le> 1"
   by (simp add: pmf.rep_eq)
@@ -180,6 +177,9 @@
 lemma set_pmf_iff: "x \<in> set_pmf M \<longleftrightarrow> pmf M x \<noteq> 0"
   by transfer simp
 
+lemma pmf_positive_iff: "0 < pmf p x \<longleftrightarrow> x \<in> set_pmf p"
+  unfolding less_le by (simp add: set_pmf_iff)
+
 lemma set_pmf_eq: "set_pmf M = {x. pmf M x \<noteq> 0}"
   by (auto simp: set_pmf_iff)
 
@@ -189,16 +189,17 @@
   by transfer (simp add: finite_measure.emeasure_eq_measure[OF prob_space.finite_measure])
 
 lemma measure_pmf_single: "measure (measure_pmf M) {x} = pmf M x"
-using emeasure_pmf_single[of M x] by(simp add: measure_pmf.emeasure_eq_measure)
+  using emeasure_pmf_single[of M x] by(simp add: measure_pmf.emeasure_eq_measure pmf_nonneg measure_nonneg)
 
 lemma emeasure_measure_pmf_finite: "finite S \<Longrightarrow> emeasure (measure_pmf M) S = (\<Sum>s\<in>S. pmf M s)"
-  by (subst emeasure_eq_setsum_singleton) (auto simp: emeasure_pmf_single)
+  by (subst emeasure_eq_setsum_singleton) (auto simp: emeasure_pmf_single pmf_nonneg)
 
 lemma measure_measure_pmf_finite: "finite S \<Longrightarrow> measure (measure_pmf M) S = setsum (pmf M) S"
-  using emeasure_measure_pmf_finite[of S M] by(simp add: measure_pmf.emeasure_eq_measure)
+  using emeasure_measure_pmf_finite[of S M]
+  by (simp add: measure_pmf.emeasure_eq_measure measure_nonneg setsum_nonneg pmf_nonneg)
 
 lemma nn_integral_measure_pmf_support:
-  fixes f :: "'a \<Rightarrow> ereal"
+  fixes f :: "'a \<Rightarrow> ennreal"
   assumes f: "finite A" and nn: "\<And>x. x \<in> A \<Longrightarrow> 0 \<le> f x" "\<And>x. x \<in> set_pmf M \<Longrightarrow> x \<notin> A \<Longrightarrow> f x = 0"
   shows "(\<integral>\<^sup>+x. f x \<partial>measure_pmf M) = (\<Sum>x\<in>A. f x * pmf M x)"
 proof -
@@ -211,14 +212,15 @@
 qed
 
 lemma nn_integral_measure_pmf_finite:
-  fixes f :: "'a \<Rightarrow> ereal"
+  fixes f :: "'a \<Rightarrow> ennreal"
   assumes f: "finite (set_pmf M)" and nn: "\<And>x. x \<in> set_pmf M \<Longrightarrow> 0 \<le> f x"
   shows "(\<integral>\<^sup>+x. f x \<partial>measure_pmf M) = (\<Sum>x\<in>set_pmf M. f x * pmf M x)"
   using assms by (intro nn_integral_measure_pmf_support) auto
+
 lemma integrable_measure_pmf_finite:
   fixes f :: "'a \<Rightarrow> 'b::{banach, second_countable_topology}"
   shows "finite (set_pmf M) \<Longrightarrow> integrable M f"
-  by (auto intro!: integrableI_bounded simp: nn_integral_measure_pmf_finite)
+  by (auto intro!: integrableI_bounded simp: nn_integral_measure_pmf_finite ennreal_mult_less_top)
 
 lemma integral_measure_pmf:
   assumes [simp]: "finite A" and "\<And>a. a \<in> set_pmf M \<Longrightarrow> f a \<noteq> 0 \<Longrightarrow> a \<in> A"
@@ -227,7 +229,8 @@
   have "(\<integral>x. f x \<partial>measure_pmf M) = (\<integral>x. f x * indicator A x \<partial>measure_pmf M)"
     using assms(2) by (intro integral_cong_AE) (auto split: split_indicator simp: AE_measure_pmf_iff)
   also have "\<dots> = (\<Sum>a\<in>A. f a * pmf M a)"
-    by (subst integral_indicator_finite_real) (auto simp: measure_def emeasure_measure_pmf_finite)
+    by (subst integral_indicator_finite_real)
+       (auto simp: measure_def emeasure_measure_pmf_finite pmf_nonneg)
   finally show ?thesis .
 qed
 
@@ -254,7 +257,7 @@
   also have "\<dots> = emeasure M X"
     by (auto intro!: emeasure_eq_AE simp: AE_measure_pmf_iff)
   finally show ?thesis
-    by (simp add: measure_pmf.emeasure_eq_measure)
+    by (simp add: measure_pmf.emeasure_eq_measure measure_nonneg integral_nonneg pmf_nonneg)
 qed
 
 lemma integral_pmf_restrict:
@@ -336,23 +339,29 @@
     done
 qed
 
-lemma ereal_pmf_bind: "pmf (bind_pmf N f) i = (\<integral>\<^sup>+x. pmf (f x) i \<partial>measure_pmf N)"
+lemma ennreal_pmf_bind: "pmf (bind_pmf N f) i = (\<integral>\<^sup>+x. pmf (f x) i \<partial>measure_pmf N)"
   unfolding pmf.rep_eq bind_pmf.rep_eq
   by (auto simp: measure_pmf.measure_bind[where N="count_space UNIV"] measure_subprob measure_nonneg
            intro!: nn_integral_eq_integral[symmetric] measure_pmf.integrable_const_bound[where B=1])
 
 lemma pmf_bind: "pmf (bind_pmf N f) i = (\<integral>x. pmf (f x) i \<partial>measure_pmf N)"
-  using ereal_pmf_bind[of N f i]
+  using ennreal_pmf_bind[of N f i]
   by (subst (asm) nn_integral_eq_integral)
-     (auto simp: pmf_nonneg pmf_le_1
+     (auto simp: pmf_nonneg pmf_le_1 pmf_nonneg integral_nonneg
            intro!: nn_integral_eq_integral[symmetric] measure_pmf.integrable_const_bound[where B=1])
 
 lemma bind_pmf_const[simp]: "bind_pmf M (\<lambda>x. c) = c"
   by transfer (simp add: bind_const' prob_space_imp_subprob_space)
 
 lemma set_bind_pmf[simp]: "set_pmf (bind_pmf M N) = (\<Union>M\<in>set_pmf M. set_pmf (N M))"
-  unfolding set_pmf_eq ereal_eq_0(1)[symmetric] ereal_pmf_bind
-  by (auto simp add: nn_integral_0_iff_AE AE_measure_pmf_iff set_pmf_eq not_le less_le pmf_nonneg)
+proof -
+  have "set_pmf (bind_pmf M N) = {x. ennreal (pmf (bind_pmf M N) x) \<noteq> 0}"
+    by (simp add: set_pmf_eq pmf_nonneg)
+  also have "\<dots> = (\<Union>M\<in>set_pmf M. set_pmf (N M))"
+    unfolding ennreal_pmf_bind
+    by (subst nn_integral_0_iff_AE) (auto simp: AE_measure_pmf_iff pmf_nonneg set_pmf_eq)
+  finally show ?thesis .
+qed
 
 lemma bind_pmf_cong:
   assumes "p = q"
@@ -372,7 +381,6 @@
 lemma nn_integral_bind_pmf[simp]: "(\<integral>\<^sup>+x. f x \<partial>bind_pmf M N) = (\<integral>\<^sup>+x. \<integral>\<^sup>+y. f y \<partial>N x \<partial>M)"
   using measurable_measure_pmf[of N]
   unfolding measure_pmf_bind
-  apply (subst (1 3) nn_integral_max_0[symmetric])
   apply (intro nn_integral_bind[where B="count_space UNIV"])
   apply auto
   done
@@ -452,18 +460,28 @@
   unfolding map_pmf_rep_eq by (subst emeasure_distr) auto
 
 lemma measure_map_pmf[simp]: "measure (map_pmf f M) X = measure M (f -` X)"
-using emeasure_map_pmf[of f M X] by(simp add: measure_pmf.emeasure_eq_measure)
+using emeasure_map_pmf[of f M X] by(simp add: measure_pmf.emeasure_eq_measure measure_nonneg)
 
 lemma nn_integral_map_pmf[simp]: "(\<integral>\<^sup>+x. f x \<partial>map_pmf g M) = (\<integral>\<^sup>+x. f (g x) \<partial>M)"
   unfolding map_pmf_rep_eq by (intro nn_integral_distr) auto
 
-lemma ereal_pmf_map: "pmf (map_pmf f p) x = (\<integral>\<^sup>+ y. indicator (f -` {x}) y \<partial>measure_pmf p)"
+lemma ennreal_pmf_map: "pmf (map_pmf f p) x = (\<integral>\<^sup>+ y. indicator (f -` {x}) y \<partial>measure_pmf p)"
 proof (transfer fixing: f x)
   fix p :: "'b measure"
   presume "prob_space p"
   then interpret prob_space p .
   presume "sets p = UNIV"
-  then show "ereal (measure (distr p (count_space UNIV) f) {x}) = integral\<^sup>N p (indicator (f -` {x}))"
+  then show "ennreal (measure (distr p (count_space UNIV) f) {x}) = integral\<^sup>N p (indicator (f -` {x}))"
+    by(simp add: measure_distr measurable_def emeasure_eq_measure)
+qed simp_all
+
+lemma pmf_map: "pmf (map_pmf f p) x = measure p (f -` {x})"
+proof (transfer fixing: f x)
+  fix p :: "'b measure"
+  presume "prob_space p"
+  then interpret prob_space p .
+  presume "sets p = UNIV"
+  then show "measure (distr p (count_space UNIV) f) {x} = measure p (f -` {x})"
     by(simp add: measure_distr measurable_def emeasure_eq_measure)
 qed simp_all
 
@@ -480,6 +498,11 @@
   finally show ?thesis .
 qed
 
+lemma integral_map_pmf[simp]:
+  fixes f :: "'a \<Rightarrow> 'b::{banach, second_countable_topology}"
+  shows "integral\<^sup>L (map_pmf g p) f = integral\<^sup>L p (\<lambda>x. f (g x))"
+  by (simp add: integral_distr map_pmf_rep_eq)
+
 lemma map_return_pmf [simp]: "map_pmf f (return_pmf x) = return_pmf (f x)"
   by transfer (simp add: distr_return)
 
@@ -529,15 +552,13 @@
 
 lemma nn_integral_pair_pmf': "(\<integral>\<^sup>+x. f x \<partial>pair_pmf A B) = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. f (a, b) \<partial>B \<partial>A)"
 proof -
-  have "(\<integral>\<^sup>+x. f x \<partial>pair_pmf A B) = (\<integral>\<^sup>+x. max 0 (f x) * indicator (A \<times> B) x \<partial>pair_pmf A B)"
-    by (subst nn_integral_max_0[symmetric])
-       (auto simp: AE_measure_pmf_iff intro!: nn_integral_cong_AE)
-  also have "\<dots> = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. max 0 (f (a, b)) * indicator (A \<times> B) (a, b) \<partial>B \<partial>A)"
+  have "(\<integral>\<^sup>+x. f x \<partial>pair_pmf A B) = (\<integral>\<^sup>+x. f x * indicator (A \<times> B) x \<partial>pair_pmf A B)"
+    by (auto simp: AE_measure_pmf_iff intro!: nn_integral_cong_AE)
+  also have "\<dots> = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. f (a, b) * indicator (A \<times> B) (a, b) \<partial>B \<partial>A)"
     by (simp add: pair_pmf_def)
-  also have "\<dots> = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. max 0 (f (a, b)) \<partial>B \<partial>A)"
+  also have "\<dots> = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. f (a, b) \<partial>B \<partial>A)"
     by (auto intro!: nn_integral_cong_AE simp: AE_measure_pmf_iff)
-  finally show ?thesis
-    unfolding nn_integral_max_0 .
+  finally show ?thesis .
 qed
 
 lemma bind_pair_pmf:
@@ -562,7 +583,7 @@
   then show "emeasure ?L X = emeasure ?R X"
     apply (simp add: emeasure_bind[OF _ M' X])
     apply (simp add: nn_integral_bind[where B="count_space UNIV"] pair_pmf_def measure_pmf_bind[of A]
-                     nn_integral_measure_pmf_finite emeasure_nonneg one_ereal_def[symmetric])
+                     nn_integral_measure_pmf_finite)
     apply (subst emeasure_bind[OF _ _ X])
     apply measurable
     apply (subst emeasure_bind[OF _ _ X])
@@ -582,10 +603,10 @@
      (auto simp: bij_betw_def nn_integral_pmf)
 
 lemma pmf_le_0_iff[simp]: "pmf M p \<le> 0 \<longleftrightarrow> pmf M p = 0"
-  using pmf_nonneg[of M p] by simp
+  using pmf_nonneg[of M p] by arith
 
 lemma min_pmf_0[simp]: "min (pmf M p) 0 = 0" "min 0 (pmf M p) = 0"
-  using pmf_nonneg[of M p] by simp_all
+  using pmf_nonneg[of M p] by arith+
 
 lemma pmf_eq_0_set_pmf: "pmf M p = 0 \<longleftrightarrow> p \<notin> set_pmf M"
   unfolding set_pmf_iff by simp
@@ -612,22 +633,22 @@
   assumes prob: "(\<integral>\<^sup>+x. f x \<partial>count_space UNIV) = 1"
 begin
 
-lift_definition embed_pmf :: "'a pmf" is "density (count_space UNIV) (ereal \<circ> f)"
+lift_definition embed_pmf :: "'a pmf" is "density (count_space UNIV) (ennreal \<circ> f)"
 proof (intro conjI)
-  have *[simp]: "\<And>x y. ereal (f y) * indicator {x} y = ereal (f x) * indicator {x} y"
+  have *[simp]: "\<And>x y. ennreal (f y) * indicator {x} y = ennreal (f x) * indicator {x} y"
     by (simp split: split_indicator)
-  show "AE x in density (count_space UNIV) (ereal \<circ> f).
-    measure (density (count_space UNIV) (ereal \<circ> f)) {x} \<noteq> 0"
+  show "AE x in density (count_space UNIV) (ennreal \<circ> f).
+    measure (density (count_space UNIV) (ennreal \<circ> f)) {x} \<noteq> 0"
     by (simp add: AE_density nonneg measure_def emeasure_density max_def)
-  show "prob_space (density (count_space UNIV) (ereal \<circ> f))"
+  show "prob_space (density (count_space UNIV) (ennreal \<circ> f))"
     by standard (simp add: emeasure_density prob)
 qed simp
 
 lemma pmf_embed_pmf: "pmf embed_pmf x = f x"
 proof transfer
-  have *[simp]: "\<And>x y. ereal (f y) * indicator {x} y = ereal (f x) * indicator {x} y"
+  have *[simp]: "\<And>x y. ennreal (f y) * indicator {x} y = ennreal (f x) * indicator {x} y"
     by (simp split: split_indicator)
-  fix x show "measure (density (count_space UNIV) (ereal \<circ> f)) {x} = f x"
+  fix x show "measure (density (count_space UNIV) (ennreal \<circ> f)) {x} = f x"
     by transfer (simp add: measure_def emeasure_density nonneg max_def)
 qed
 
@@ -637,17 +658,17 @@
 end
 
 lemma embed_pmf_transfer:
-  "rel_fun (eq_onp (\<lambda>f. (\<forall>x. 0 \<le> f x) \<and> (\<integral>\<^sup>+x. ereal (f x) \<partial>count_space UNIV) = 1)) pmf_as_measure.cr_pmf (\<lambda>f. density (count_space UNIV) (ereal \<circ> f)) embed_pmf"
+  "rel_fun (eq_onp (\<lambda>f. (\<forall>x. 0 \<le> f x) \<and> (\<integral>\<^sup>+x. ennreal (f x) \<partial>count_space UNIV) = 1)) pmf_as_measure.cr_pmf (\<lambda>f. density (count_space UNIV) (ennreal \<circ> f)) embed_pmf"
   by (auto simp: rel_fun_def eq_onp_def embed_pmf.transfer)
 
 lemma measure_pmf_eq_density: "measure_pmf p = density (count_space UNIV) (pmf p)"
 proof (transfer, elim conjE)
   fix M :: "'a measure" assume [simp]: "sets M = UNIV" and ae: "AE x in M. measure M {x} \<noteq> 0"
   assume "prob_space M" then interpret prob_space M .
-  show "M = density (count_space UNIV) (\<lambda>x. ereal (measure M {x}))"
+  show "M = density (count_space UNIV) (\<lambda>x. ennreal (measure M {x}))"
   proof (rule measure_eqI)
     fix A :: "'a set"
-    have "(\<integral>\<^sup>+ x. ereal (measure M {x}) * indicator A x \<partial>count_space UNIV) =
+    have "(\<integral>\<^sup>+ x. ennreal (measure M {x}) * indicator A x \<partial>count_space UNIV) =
       (\<integral>\<^sup>+ x. emeasure M {x} * indicator (A \<inter> {x. measure M {x} \<noteq> 0}) x \<partial>count_space UNIV)"
       by (auto intro!: nn_integral_cong simp: emeasure_eq_measure split: split_indicator)
     also have "\<dots> = (\<integral>\<^sup>+ x. emeasure M {x} \<partial>count_space (A \<inter> {x. measure M {x} \<noteq> 0}))"
@@ -657,19 +678,19 @@
          (auto simp: disjoint_family_on_def)
     also have "\<dots> = emeasure M A"
       using ae by (intro emeasure_eq_AE) auto
-    finally show " emeasure M A = emeasure (density (count_space UNIV) (\<lambda>x. ereal (measure M {x}))) A"
+    finally show " emeasure M A = emeasure (density (count_space UNIV) (\<lambda>x. ennreal (measure M {x}))) A"
       using emeasure_space_1 by (simp add: emeasure_density)
   qed simp
 qed
 
 lemma td_pmf_embed_pmf:
-  "type_definition pmf embed_pmf {f::'a \<Rightarrow> real. (\<forall>x. 0 \<le> f x) \<and> (\<integral>\<^sup>+x. ereal (f x) \<partial>count_space UNIV) = 1}"
+  "type_definition pmf embed_pmf {f::'a \<Rightarrow> real. (\<forall>x. 0 \<le> f x) \<and> (\<integral>\<^sup>+x. ennreal (f x) \<partial>count_space UNIV) = 1}"
   unfolding type_definition_def
 proof safe
   fix p :: "'a pmf"
   have "(\<integral>\<^sup>+ x. 1 \<partial>measure_pmf p) = 1"
     using measure_pmf.emeasure_space_1[of p] by simp
-  then show *: "(\<integral>\<^sup>+ x. ereal (pmf p x) \<partial>count_space UNIV) = 1"
+  then show *: "(\<integral>\<^sup>+ x. ennreal (pmf p x) \<partial>count_space UNIV) = 1"
     by (simp add: measure_pmf_eq_density nn_integral_density pmf_nonneg del: nn_integral_const)
 
   show "embed_pmf (pmf p) = p"
@@ -683,7 +704,7 @@
 
 end
 
-lemma nn_integral_measure_pmf: "(\<integral>\<^sup>+ x. f x \<partial>measure_pmf p) = \<integral>\<^sup>+ x. ereal (pmf p x) * f x \<partial>count_space UNIV"
+lemma nn_integral_measure_pmf: "(\<integral>\<^sup>+ x. f x \<partial>measure_pmf p) = \<integral>\<^sup>+ x. ennreal (pmf p x) * f x \<partial>count_space UNIV"
 by(simp add: measure_pmf_eq_density nn_integral_density pmf_nonneg)
 
 locale pmf_as_function
@@ -745,31 +766,31 @@
 lemma pair_map_pmf1: "pair_pmf (map_pmf f A) B = map_pmf (apfst f) (pair_pmf A B)"
 proof (safe intro!: pmf_eqI)
   fix a :: "'a" and b :: "'b"
-  have [simp]: "\<And>c d. indicator (apfst f -` {(a, b)}) (c, d) = indicator (f -` {a}) c * (indicator {b} d::ereal)"
+  have [simp]: "\<And>c d. indicator (apfst f -` {(a, b)}) (c, d) = indicator (f -` {a}) c * (indicator {b} d::ennreal)"
     by (auto split: split_indicator)
 
-  have "ereal (pmf (pair_pmf (map_pmf f A) B) (a, b)) =
-         ereal (pmf (map_pmf (apfst f) (pair_pmf A B)) (a, b))"
-    unfolding pmf_pair ereal_pmf_map
+  have "ennreal (pmf (pair_pmf (map_pmf f A) B) (a, b)) =
+         ennreal (pmf (map_pmf (apfst f) (pair_pmf A B)) (a, b))"
+    unfolding pmf_pair ennreal_pmf_map
     by (simp add: nn_integral_pair_pmf' max_def emeasure_pmf_single nn_integral_multc pmf_nonneg
-                  emeasure_map_pmf[symmetric] del: emeasure_map_pmf)
+                  emeasure_map_pmf[symmetric] ennreal_mult del: emeasure_map_pmf)
   then show "pmf (pair_pmf (map_pmf f A) B) (a, b) = pmf (map_pmf (apfst f) (pair_pmf A B)) (a, b)"
-    by simp
+    by (simp add: pmf_nonneg)
 qed
 
 lemma pair_map_pmf2: "pair_pmf A (map_pmf f B) = map_pmf (apsnd f) (pair_pmf A B)"
 proof (safe intro!: pmf_eqI)
   fix a :: "'a" and b :: "'b"
-  have [simp]: "\<And>c d. indicator (apsnd f -` {(a, b)}) (c, d) = indicator {a} c * (indicator (f -` {b}) d::ereal)"
+  have [simp]: "\<And>c d. indicator (apsnd f -` {(a, b)}) (c, d) = indicator {a} c * (indicator (f -` {b}) d::ennreal)"
     by (auto split: split_indicator)
 
-  have "ereal (pmf (pair_pmf A (map_pmf f B)) (a, b)) =
-         ereal (pmf (map_pmf (apsnd f) (pair_pmf A B)) (a, b))"
-    unfolding pmf_pair ereal_pmf_map
+  have "ennreal (pmf (pair_pmf A (map_pmf f B)) (a, b)) =
+         ennreal (pmf (map_pmf (apsnd f) (pair_pmf A B)) (a, b))"
+    unfolding pmf_pair ennreal_pmf_map
     by (simp add: nn_integral_pair_pmf' max_def emeasure_pmf_single nn_integral_cmult nn_integral_multc pmf_nonneg
-                  emeasure_map_pmf[symmetric] del: emeasure_map_pmf)
+                  emeasure_map_pmf[symmetric] ennreal_mult del: emeasure_map_pmf)
   then show "pmf (pair_pmf A (map_pmf f B)) (a, b) = pmf (map_pmf (apsnd f) (pair_pmf A B)) (a, b)"
-    by simp
+    by (simp add: pmf_nonneg)
 qed
 
 lemma map_pair: "map_pmf (\<lambda>(a, b). (f a, g b)) (pair_pmf A B) = pair_pmf (map_pmf f A) (map_pmf g B)"
@@ -794,7 +815,7 @@
   fix i
   assume x: "set_pmf p \<subseteq> {x}"
   hence *: "set_pmf p = {x}" using set_pmf_not_empty[of p] by auto
-  have "ereal (pmf p x) = \<integral>\<^sup>+ i. indicator {x} i \<partial>p" by(simp add: emeasure_pmf_single)
+  have "ennreal (pmf p x) = \<integral>\<^sup>+ i. indicator {x} i \<partial>p" by(simp add: emeasure_pmf_single)
   also have "\<dots> = \<integral>\<^sup>+ i. 1 \<partial>p" by(rule nn_integral_cong_AE)(simp add: AE_measure_pmf_iff * )
   also have "\<dots> = 1" by simp
   finally show "pmf p i = pmf (return_pmf x) i" using x
@@ -817,11 +838,14 @@
   show ?lhs
   proof(rule pmf_eqI)
     fix i
-    have "ereal (pmf (bind_pmf p f) i) = \<integral>\<^sup>+ y. ereal (pmf (f y) i) \<partial>p" by(simp add: ereal_pmf_bind)
-    also have "\<dots> = \<integral>\<^sup>+ y. ereal (pmf (return_pmf x) i) \<partial>p"
+    have "ennreal (pmf (bind_pmf p f) i) = \<integral>\<^sup>+ y. ennreal (pmf (f y) i) \<partial>p"
+      by (simp add: ennreal_pmf_bind)
+    also have "\<dots> = \<integral>\<^sup>+ y. ennreal (pmf (return_pmf x) i) \<partial>p"
       by(rule nn_integral_cong_AE)(simp add: AE_measure_pmf_iff * )
-    also have "\<dots> = ereal (pmf (return_pmf x) i)" by simp
-    finally show "pmf (bind_pmf p f) i = pmf (return_pmf x) i" by simp
+    also have "\<dots> = ennreal (pmf (return_pmf x) i)"
+      by simp
+    finally show "pmf (bind_pmf p f) i = pmf (return_pmf x) i"
+      by (simp add: pmf_nonneg)
   qed
 qed
 
@@ -860,7 +884,7 @@
 qed
 
 lemma measure_measure_pmf_not_zero: "measure (measure_pmf p) s \<noteq> 0"
-  using emeasure_measure_pmf_not_zero unfolding measure_pmf.emeasure_eq_measure by simp
+  using emeasure_measure_pmf_not_zero by (simp add: measure_pmf.emeasure_eq_measure measure_nonneg)
 
 lift_definition cond_pmf :: "'a pmf" is
   "uniform_measure (measure_pmf p) s"
@@ -869,7 +893,7 @@
     by (intro prob_space_uniform_measure) (auto simp: emeasure_measure_pmf_not_zero)
   show "AE x in uniform_measure (measure_pmf p) s. measure (uniform_measure (measure_pmf p) s) {x} \<noteq> 0"
     by (simp add: emeasure_measure_pmf_not_zero measure_measure_pmf_not_zero AE_uniform_measure
-                  AE_measure_pmf_iff set_pmf.rep_eq)
+                  AE_measure_pmf_iff set_pmf.rep_eq less_top[symmetric])
 qed simp
 
 lemma pmf_cond: "pmf cond_pmf x = (if x \<in> s then pmf p x / measure p s else 0)"
@@ -887,20 +911,20 @@
   have *: "set_pmf (map_pmf f p) \<inter> s \<noteq> {}"
     using assms by auto
   { fix x
-    have "ereal (pmf (map_pmf f (cond_pmf p (f -` s))) x) =
+    have "ennreal (pmf (map_pmf f (cond_pmf p (f -` s))) x) =
       emeasure p (f -` s \<inter> f -` {x}) / emeasure p (f -` s)"
-      unfolding ereal_pmf_map cond_pmf.rep_eq[OF assms] by (simp add: nn_integral_uniform_measure)
+      unfolding ennreal_pmf_map cond_pmf.rep_eq[OF assms] by (simp add: nn_integral_uniform_measure)
     also have "f -` s \<inter> f -` {x} = (if x \<in> s then f -` {x} else {})"
       by auto
     also have "emeasure p (if x \<in> s then f -` {x} else {}) / emeasure p (f -` s) =
-      ereal (pmf (cond_pmf (map_pmf f p) s) x)"
+      ennreal (pmf (cond_pmf (map_pmf f p) s) x)"
       using measure_measure_pmf_not_zero[OF *]
-      by (simp add: pmf_cond[OF *] ereal_divide' ereal_pmf_map measure_pmf.emeasure_eq_measure[symmetric]
-               del: ereal_divide)
-    finally have "ereal (pmf (cond_pmf (map_pmf f p) s) x) = ereal (pmf (map_pmf f (cond_pmf p (f -` s))) x)"
+      by (simp add: pmf_cond[OF *] ennreal_pmf_map measure_pmf.emeasure_eq_measure
+                    divide_ennreal pmf_nonneg measure_nonneg zero_less_measure_iff pmf_map)
+    finally have "ennreal (pmf (cond_pmf (map_pmf f p) s) x) = ennreal (pmf (map_pmf f (cond_pmf p (f -` s))) x)"
       by simp }
   then show ?thesis
-    by (intro pmf_eqI) simp
+    by (intro pmf_eqI) (simp add: pmf_nonneg)
 qed
 
 lemma bind_cond_pmf_cancel:
@@ -910,16 +934,18 @@
   shows "bind_pmf p (\<lambda>x. cond_pmf q {y. R x y}) = q"
 proof (rule pmf_eqI)
   fix i
-  have "ereal (pmf (bind_pmf p (\<lambda>x. cond_pmf q {y. R x y})) i) =
-    (\<integral>\<^sup>+x. ereal (pmf q i / measure p {x. R x i}) * ereal (indicator {x. R x i} x) \<partial>p)"
-    by (auto simp add: ereal_pmf_bind AE_measure_pmf_iff pmf_cond pmf_eq_0_set_pmf intro!: nn_integral_cong_AE)
+  have "ennreal (pmf (bind_pmf p (\<lambda>x. cond_pmf q {y. R x y})) i) =
+    (\<integral>\<^sup>+x. ennreal (pmf q i / measure p {x. R x i}) * ennreal (indicator {x. R x i} x) \<partial>p)"
+    by (auto simp add: ennreal_pmf_bind AE_measure_pmf_iff pmf_cond pmf_eq_0_set_pmf pmf_nonneg measure_nonneg
+             intro!: nn_integral_cong_AE)
   also have "\<dots> = (pmf q i * measure p {x. R x i}) / measure p {x. R x i}"
-    by (simp add: pmf_nonneg measure_nonneg zero_ereal_def[symmetric] ereal_indicator
-                  nn_integral_cmult measure_pmf.emeasure_eq_measure)
+    by (simp add: pmf_nonneg measure_nonneg zero_ennreal_def[symmetric] ennreal_indicator
+                  nn_integral_cmult measure_pmf.emeasure_eq_measure ennreal_mult[symmetric])
   also have "\<dots> = pmf q i"
-    by (cases "pmf q i = 0") (simp_all add: pmf_eq_0_set_pmf measure_measure_pmf_not_zero)
+    by (cases "pmf q i = 0")
+       (simp_all add: pmf_eq_0_set_pmf measure_measure_pmf_not_zero pmf_nonneg)
   finally show "pmf (bind_pmf p (\<lambda>x. cond_pmf q {y. R x y})) i = pmf q i"
-    by simp
+    by (simp add: pmf_nonneg)
 qed
 
 subsection \<open> Relator \<close>
@@ -1277,8 +1303,8 @@
 lemma pmf_join: "pmf (join_pmf N) i = (\<integral>M. pmf M i \<partial>measure_pmf N)"
   unfolding join_pmf_def pmf_bind ..
 
-lemma ereal_pmf_join: "ereal (pmf (join_pmf N) i) = (\<integral>\<^sup>+M. pmf M i \<partial>measure_pmf N)"
-  unfolding join_pmf_def ereal_pmf_bind ..
+lemma ennreal_pmf_join: "ennreal (pmf (join_pmf N) i) = (\<integral>\<^sup>+M. pmf M i \<partial>measure_pmf N)"
+  unfolding join_pmf_def ennreal_pmf_bind ..
 
 lemma set_pmf_join_pmf[simp]: "set_pmf (join_pmf f) = (\<Union>p\<in>set_pmf f. set_pmf p)"
   by (simp add: join_pmf_def)
@@ -1430,7 +1456,7 @@
 lemma pmf_bernoulli_False[simp]: "0 \<le> p \<Longrightarrow> p \<le> 1 \<Longrightarrow> pmf (bernoulli_pmf p) False = 1 - p"
   by transfer simp
 
-lemma set_pmf_bernoulli: "0 < p \<Longrightarrow> p < 1 \<Longrightarrow> set_pmf (bernoulli_pmf p) = UNIV"
+lemma set_pmf_bernoulli[simp]: "0 < p \<Longrightarrow> p < 1 \<Longrightarrow> set_pmf (bernoulli_pmf p) = UNIV"
   by (auto simp add: set_pmf_iff UNIV_bool)
 
 lemma nn_integral_bernoulli_pmf[simp]:
@@ -1448,7 +1474,10 @@
 by(cases x) simp_all
 
 lemma measure_pmf_bernoulli_half: "measure_pmf (bernoulli_pmf (1 / 2)) = uniform_count_measure UNIV"
-by(rule measure_eqI)(simp_all add: nn_integral_pmf[symmetric] emeasure_uniform_count_measure nn_integral_count_space_finite sets_uniform_count_measure)
+  by (rule measure_eqI)
+     (simp_all add: nn_integral_pmf[symmetric] emeasure_uniform_count_measure ennreal_divide_numeral[symmetric]
+                    nn_integral_count_space_finite sets_uniform_count_measure divide_ennreal_def mult_ac
+                    ennreal_of_nat_eq_real_of_nat)
 
 subsubsection \<open> Geometric Distribution \<close>
 
@@ -1458,9 +1487,9 @@
 
 lift_definition geometric_pmf :: "nat pmf" is "\<lambda>n. (1 - p)^n * p"
 proof
-  have "(\<Sum>i. ereal (p * (1 - p) ^ i)) = ereal (p * (1 / (1 - (1 - p))))"
-    by (intro sums_suminf_ereal sums_mult geometric_sums) auto
-  then show "(\<integral>\<^sup>+ x. ereal ((1 - p)^x * p) \<partial>count_space UNIV) = 1"
+  have "(\<Sum>i. ennreal (p * (1 - p) ^ i)) = ennreal (p * (1 / (1 - (1 - p))))"
+    by (intro suminf_ennreal_eq sums_mult geometric_sums) auto
+  then show "(\<integral>\<^sup>+ x. ennreal ((1 - p)^x * p) \<partial>count_space UNIV) = 1"
     by (simp add: nn_integral_count_space_nat field_simps)
 qed simp
 
@@ -1480,7 +1509,7 @@
 
 lift_definition pmf_of_multiset :: "'a pmf" is "\<lambda>x. count M x / size M"
 proof
-  show "(\<integral>\<^sup>+ x. ereal (real (count M x) / real (size M)) \<partial>count_space UNIV) = 1"
+  show "(\<integral>\<^sup>+ x. ennreal (real (count M x) / real (size M)) \<partial>count_space UNIV) = 1"
     using M_not_empty
     by (simp add: zero_less_divide_iff nn_integral_count_space nonempty_has_size
                   setsum_divide_distrib[symmetric])
@@ -1503,8 +1532,10 @@
 
 lift_definition pmf_of_set :: "'a pmf" is "\<lambda>x. indicator S x / card S"
 proof
-  show "(\<integral>\<^sup>+ x. ereal (indicator S x / real (card S)) \<partial>count_space UNIV) = 1"
-    using S_not_empty S_finite by (subst nn_integral_count_space'[of S]) auto
+  show "(\<integral>\<^sup>+ x. ennreal (indicator S x / real (card S)) \<partial>count_space UNIV) = 1"
+    using S_not_empty S_finite
+    by (subst nn_integral_count_space'[of S])
+       (auto simp: ennreal_of_nat_eq_real_of_nat ennreal_mult[symmetric])
 qed simp
 
 lemma pmf_of_set[simp]: "pmf pmf_of_set x = indicator S x / card S"
@@ -1516,38 +1547,22 @@
 lemma emeasure_pmf_of_set_space[simp]: "emeasure pmf_of_set S = 1"
   by (rule measure_pmf.emeasure_eq_1_AE) (auto simp: AE_measure_pmf_iff)
 
-lemma nn_integral_pmf_of_set':
-  "(\<And>x. x \<in> S \<Longrightarrow> f x \<ge> 0) \<Longrightarrow> nn_integral (measure_pmf pmf_of_set) f = setsum f S / card S"
-apply(subst nn_integral_measure_pmf_finite, simp_all add: S_finite)
-apply(simp add: setsum_ereal_left_distrib[symmetric])
-apply(subst ereal_divide', simp add: S_not_empty S_finite)
-apply(simp add: ereal_times_divide_eq one_ereal_def[symmetric])
-done
+lemma nn_integral_pmf_of_set: "nn_integral (measure_pmf pmf_of_set) f = setsum f S / card S"
+  by (subst nn_integral_measure_pmf_finite)
+     (simp_all add: setsum_left_distrib[symmetric] card_gt_0_iff S_not_empty S_finite divide_ennreal_def
+                divide_ennreal[symmetric] ennreal_of_nat_eq_real_of_nat[symmetric] ennreal_times_divide)
 
-lemma nn_integral_pmf_of_set:
-  "nn_integral (measure_pmf pmf_of_set) f = setsum (max 0 \<circ> f) S / card S"
-apply(subst nn_integral_max_0[symmetric])
-apply(subst nn_integral_pmf_of_set')
-apply simp_all
-done
+lemma integral_pmf_of_set: "integral\<^sup>L (measure_pmf pmf_of_set) f = setsum f S / card S"
+  by (subst integral_measure_pmf[of S]) (auto simp: S_finite setsum_divide_distrib)
 
-lemma integral_pmf_of_set:
-  "integral\<^sup>L (measure_pmf pmf_of_set) f = setsum f S / card S"
-apply(simp add: real_lebesgue_integral_def integrable_measure_pmf_finite nn_integral_pmf_of_set S_finite)
-apply(subst real_of_ereal_minus')
- apply(simp add: ereal_max_0 S_finite del: ereal_max)
-apply(simp add: ereal_max_0 S_finite S_not_empty del: ereal_max)
-apply(simp add: field_simps S_finite S_not_empty)
-apply(subst setsum.distrib[symmetric])
-apply(rule setsum.cong; simp_all)
-done
+lemma emeasure_pmf_of_set: "emeasure (measure_pmf pmf_of_set) A = card (S \<inter> A) / card S"
+  by (subst nn_integral_indicator[symmetric], simp)
+     (simp add: S_finite S_not_empty card_gt_0_iff indicator_def setsum.If_cases divide_ennreal
+                ennreal_of_nat_eq_real_of_nat nn_integral_pmf_of_set)
 
-lemma emeasure_pmf_of_set:
-  "emeasure (measure_pmf pmf_of_set) A = card (S \<inter> A) / card S"
-apply(subst nn_integral_indicator[symmetric], simp)
-apply(subst nn_integral_pmf_of_set)
-apply(simp add: o_def max_def ereal_indicator[symmetric] S_not_empty S_finite real_of_nat_indicator[symmetric] of_nat_setsum[symmetric] setsum_indicator_eq_card del: of_nat_setsum)
-done
+lemma measure_pmf_of_set: "measure (measure_pmf pmf_of_set) A = card (S \<inter> A) / card S"
+  using emeasure_pmf_of_set[OF assms, of A]
+  by (simp add: measure_nonneg measure_pmf.emeasure_eq_measure)
 
 end
 
@@ -1574,15 +1589,7 @@
 qed
 
 lemma bernoulli_pmf_half_conv_pmf_of_set: "bernoulli_pmf (1 / 2) = pmf_of_set UNIV"
-by(rule pmf_eqI) simp_all
-
-
-
-lemma measure_pmf_of_set:
-  assumes "S \<noteq> {}" "finite S"
-  shows "measure (measure_pmf (pmf_of_set S)) A = card (S \<inter> A) / card S"
-using emeasure_pmf_of_set[OF assms, of A]
-unfolding measure_pmf.emeasure_eq_measure by simp
+  by (rule pmf_eqI) simp_all
 
 subsubsection \<open> Poisson Distribution \<close>
 
@@ -1596,14 +1603,14 @@
     by (simp add: field_simps divide_inverse [symmetric])
   have "(\<integral>\<^sup>+(x::nat). rate ^ x / fact x * exp (-rate) \<partial>count_space UNIV) =
           exp (-rate) * (\<integral>\<^sup>+(x::nat). rate ^ x / fact x \<partial>count_space UNIV)"
-    by (simp add: field_simps nn_integral_cmult[symmetric])
+    by (simp add: field_simps nn_integral_cmult[symmetric] ennreal_mult'[symmetric])
   also from rate_pos have "(\<integral>\<^sup>+(x::nat). rate ^ x / fact x \<partial>count_space UNIV) = (\<Sum>x. rate ^ x / fact x)"
-    by (simp_all add: nn_integral_count_space_nat suminf_ereal summable suminf_ereal_finite)
+    by (simp_all add: nn_integral_count_space_nat suminf_ennreal summable ennreal_suminf_neq_top)
   also have "... = exp rate" unfolding exp_def
     by (simp add: field_simps divide_inverse [symmetric])
-  also have "ereal (exp (-rate)) * ereal (exp rate) = 1"
-    by (simp add: mult_exp_exp)
-  finally show "(\<integral>\<^sup>+ x. ereal (rate ^ x / (fact x) * exp (- rate)) \<partial>count_space UNIV) = 1" .
+  also have "ennreal (exp (-rate)) * ennreal (exp rate) = 1"
+    by (simp add: mult_exp_exp ennreal_mult[symmetric])
+  finally show "(\<integral>\<^sup>+ x. ennreal (rate ^ x / (fact x) * exp (- rate)) \<partial>count_space UNIV) = 1" .
 qed (simp add: rate_pos[THEN less_imp_le])
 
 lemma pmf_poisson[simp]: "pmf poisson_pmf k = rate ^ k / fact k * exp (-rate)"
@@ -1622,12 +1629,12 @@
 
 lift_definition binomial_pmf :: "nat pmf" is "\<lambda>k. (n choose k) * p^k * (1 - p)^(n - k)"
 proof
-  have "(\<integral>\<^sup>+k. ereal (real (n choose k) * p ^ k * (1 - p) ^ (n - k)) \<partial>count_space UNIV) =
-    ereal (\<Sum>k\<le>n. real (n choose k) * p ^ k * (1 - p) ^ (n - k))"
+  have "(\<integral>\<^sup>+k. ennreal (real (n choose k) * p ^ k * (1 - p) ^ (n - k)) \<partial>count_space UNIV) =
+    ennreal (\<Sum>k\<le>n. real (n choose k) * p ^ k * (1 - p) ^ (n - k))"
     using p_le_1 p_nonneg by (subst nn_integral_count_space') auto
   also have "(\<Sum>k\<le>n. real (n choose k) * p ^ k * (1 - p) ^ (n - k)) = (p + (1 - p)) ^ n"
     by (subst binomial_ring) (simp add: atLeast0AtMost)
-  finally show "(\<integral>\<^sup>+ x. ereal (real (n choose x) * p ^ x * (1 - p) ^ (n - x)) \<partial>count_space UNIV) = 1"
+  finally show "(\<integral>\<^sup>+ x. ennreal (real (n choose x) * p ^ x * (1 - p) ^ (n - x)) \<partial>count_space UNIV) = 1"
     by simp
 qed (insert p_nonneg p_le_1, simp)