src/HOL/Probability/Probability_Mass_Function.thy
changeset 59664 224741ede5ae
parent 59557 ebd8ecacfba6
child 59665 37adca7fd48f
     1.1 --- a/src/HOL/Probability/Probability_Mass_Function.thy	Tue Mar 10 09:49:17 2015 +0100
     1.2 +++ b/src/HOL/Probability/Probability_Mass_Function.thy	Tue Mar 10 10:53:48 2015 +0100
     1.3 @@ -12,6 +12,24 @@
     1.4    "~~/src/HOL/Library/Multiset"
     1.5  begin
     1.6  
     1.7 +lemma AE_emeasure_singleton:
     1.8 +  assumes x: "emeasure M {x} \<noteq> 0" and ae: "AE x in M. P x" shows "P x"
     1.9 +proof -
    1.10 +  from x have x_M: "{x} \<in> sets M"
    1.11 +    by (auto intro: emeasure_notin_sets)
    1.12 +  from ae obtain N where N: "{x\<in>space M. \<not> P x} \<subseteq> N" "emeasure M N = 0" "N \<in> sets M"
    1.13 +    by (auto elim: AE_E)
    1.14 +  { assume "\<not> P x"
    1.15 +    with x_M[THEN sets.sets_into_space] N have "emeasure M {x} \<le> emeasure M N"
    1.16 +      by (intro emeasure_mono) auto
    1.17 +    with x N have False
    1.18 +      by (auto simp: emeasure_le_0_iff) }
    1.19 +  then show "P x" by auto
    1.20 +qed
    1.21 +
    1.22 +lemma AE_measure_singleton: "measure M {x} \<noteq> 0 \<Longrightarrow> AE x in M. P x \<Longrightarrow> P x"
    1.23 +  by (metis AE_emeasure_singleton measure_def emeasure_empty measure_empty)
    1.24 +
    1.25  lemma ereal_divide': "b \<noteq> 0 \<Longrightarrow> ereal (a / b) = ereal a / ereal b"
    1.26    using ereal_divide[of a b] by simp
    1.27  
    1.28 @@ -84,7 +102,7 @@
    1.29      by (auto simp: emeasure_eq_measure)
    1.30  qed (auto intro!: exI[of _ "{x. measure M {x} \<noteq> 0}"] countable_support)
    1.31  
    1.32 -subsection {* PMF as measure *}
    1.33 +subsection \<open> PMF as measure \<close>
    1.34  
    1.35  typedef 'a pmf = "{M :: 'a measure. prob_space M \<and> sets M = UNIV \<and> (AE x in M. measure M {x} \<noteq> 0)}"
    1.36    morphisms measure_pmf Abs_pmf
    1.37 @@ -117,36 +135,8 @@
    1.38  
    1.39  interpretation pmf_as_measure .
    1.40  
    1.41 -lift_definition pmf :: "'a pmf \<Rightarrow> 'a \<Rightarrow> real" is "\<lambda>M x. measure M {x}" .
    1.42 -
    1.43 -lift_definition set_pmf :: "'a pmf \<Rightarrow> 'a set" is "\<lambda>M. {x. measure M {x} \<noteq> 0}" .
    1.44 -
    1.45 -lift_definition map_pmf :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a pmf \<Rightarrow> 'b pmf" is
    1.46 -  "\<lambda>f M. distr M (count_space UNIV) f"
    1.47 -proof safe
    1.48 -  fix M and f :: "'a \<Rightarrow> 'b"
    1.49 -  let ?D = "distr M (count_space UNIV) f"
    1.50 -  assume "prob_space M" and [simp]: "sets M = UNIV" and ae: "AE x in M. measure M {x} \<noteq> 0"
    1.51 -  interpret prob_space M by fact
    1.52 -  from ae have "AE x in M. measure M (f -` {f x}) \<noteq> 0"
    1.53 -  proof eventually_elim
    1.54 -    fix x
    1.55 -    have "measure M {x} \<le> measure M (f -` {f x})"
    1.56 -      by (intro finite_measure_mono) auto
    1.57 -    then show "measure M {x} \<noteq> 0 \<Longrightarrow> measure M (f -` {f x}) \<noteq> 0"
    1.58 -      using measure_nonneg[of M "{x}"] by auto
    1.59 -  qed
    1.60 -  then show "AE x in ?D. measure ?D {x} \<noteq> 0"
    1.61 -    by (simp add: AE_distr_iff measure_distr measurable_def)
    1.62 -qed (auto simp: measurable_def prob_space.prob_space_distr)
    1.63 -
    1.64 -declare [[coercion set_pmf]]
    1.65 -
    1.66 -lemma countable_set_pmf [simp]: "countable (set_pmf p)"
    1.67 -  by transfer (metis prob_space.finite_measure finite_measure.countable_support)
    1.68 -
    1.69  lemma sets_measure_pmf[simp]: "sets (measure_pmf p) = UNIV"
    1.70 -  by transfer metis
    1.71 +  by transfer blast 
    1.72  
    1.73  lemma sets_measure_pmf_count_space[measurable_cong]:
    1.74    "sets (measure_pmf M) = sets (count_space UNIV)"
    1.75 @@ -164,19 +154,38 @@
    1.76  lemma measurable_pmf_measure2[simp]: "measurable N (M :: 'a pmf) = measurable N (count_space UNIV)"
    1.77    by (intro measurable_cong_sets) simp_all
    1.78  
    1.79 -lemma pmf_positive: "x \<in> set_pmf p \<Longrightarrow> 0 < pmf p x"
    1.80 -  by transfer (simp add: less_le measure_nonneg)
    1.81 +lemma measurable_pair_restrict_pmf2:
    1.82 +  assumes "countable A"
    1.83 +  assumes [measurable]: "\<And>y. y \<in> A \<Longrightarrow> (\<lambda>x. f (x, y)) \<in> measurable M L"
    1.84 +  shows "f \<in> measurable (M \<Otimes>\<^sub>M restrict_space (measure_pmf N) A) L" (is "f \<in> measurable ?M _")
    1.85 +proof -
    1.86 +  have [measurable_cong]: "sets (restrict_space (count_space UNIV) A) = sets (count_space A)"
    1.87 +    by (simp add: restrict_count_space)
    1.88  
    1.89 -lemma pmf_nonneg: "0 \<le> pmf p x"
    1.90 -  by transfer (simp add: measure_nonneg)
    1.91 +  show ?thesis
    1.92 +    by (intro measurable_compose_countable'[where f="\<lambda>a b. f (fst b, a)" and g=snd and I=A,
    1.93 +                                            unfolded pair_collapse] assms)
    1.94 +        measurable
    1.95 +qed
    1.96  
    1.97 -lemma pmf_le_1: "pmf p x \<le> 1"
    1.98 -  by (simp add: pmf.rep_eq)
    1.99 +lemma measurable_pair_restrict_pmf1:
   1.100 +  assumes "countable A"
   1.101 +  assumes [measurable]: "\<And>x. x \<in> A \<Longrightarrow> (\<lambda>y. f (x, y)) \<in> measurable N L"
   1.102 +  shows "f \<in> measurable (restrict_space (measure_pmf M) A \<Otimes>\<^sub>M N) L"
   1.103 +proof -
   1.104 +  have [measurable_cong]: "sets (restrict_space (count_space UNIV) A) = sets (count_space A)"
   1.105 +    by (simp add: restrict_count_space)
   1.106  
   1.107 -lemma emeasure_pmf_single:
   1.108 -  fixes M :: "'a pmf"
   1.109 -  shows "emeasure M {x} = pmf M x"
   1.110 -  by transfer (simp add: finite_measure.emeasure_eq_measure[OF prob_space.finite_measure])
   1.111 +  show ?thesis
   1.112 +    by (intro measurable_compose_countable'[where f="\<lambda>a b. f (a, snd b)" and g=fst and I=A,
   1.113 +                                            unfolded pair_collapse] assms)
   1.114 +        measurable
   1.115 +qed
   1.116 +
   1.117 +lift_definition pmf :: "'a pmf \<Rightarrow> 'a \<Rightarrow> real" is "\<lambda>M x. measure M {x}" .
   1.118 +
   1.119 +lift_definition set_pmf :: "'a pmf \<Rightarrow> 'a set" is "\<lambda>M. {x. measure M {x} \<noteq> 0}" .
   1.120 +declare [[coercion set_pmf]]
   1.121  
   1.122  lemma AE_measure_pmf: "AE x in (M::'a pmf). x \<in> M"
   1.123    by transfer simp
   1.124 @@ -187,15 +196,20 @@
   1.125    by transfer (simp add: finite_measure.emeasure_eq_measure[OF prob_space.finite_measure])
   1.126  
   1.127  lemma AE_measure_pmf_iff: "(AE x in measure_pmf M. P x) \<longleftrightarrow> (\<forall>y\<in>M. P y)"
   1.128 -proof -
   1.129 -  { fix y assume y: "y \<in> M" and P: "AE x in M. P x" "\<not> P y"
   1.130 -    with P have "AE x in M. x \<noteq> y"
   1.131 -      by auto
   1.132 -    with y have False
   1.133 -      by (simp add: emeasure_pmf_single_eq_zero_iff AE_iff_measurable[OF _ refl]) }
   1.134 -  then show ?thesis
   1.135 -    using AE_measure_pmf[of M] by auto
   1.136 -qed
   1.137 +  using AE_measure_singleton[of M] AE_measure_pmf[of M]
   1.138 +  by (auto simp: set_pmf.rep_eq)
   1.139 +
   1.140 +lemma countable_set_pmf [simp]: "countable (set_pmf p)"
   1.141 +  by transfer (metis prob_space.finite_measure finite_measure.countable_support)
   1.142 +
   1.143 +lemma pmf_positive: "x \<in> set_pmf p \<Longrightarrow> 0 < pmf p x"
   1.144 +  by transfer (simp add: less_le measure_nonneg)
   1.145 +
   1.146 +lemma pmf_nonneg: "0 \<le> pmf p x"
   1.147 +  by transfer (simp add: measure_nonneg)
   1.148 +
   1.149 +lemma pmf_le_1: "pmf p x \<le> 1"
   1.150 +  by (simp add: pmf.rep_eq)
   1.151  
   1.152  lemma set_pmf_not_empty: "set_pmf M \<noteq> {}"
   1.153    using AE_measure_pmf[of M] by (intro notI) simp
   1.154 @@ -203,6 +217,14 @@
   1.155  lemma set_pmf_iff: "x \<in> set_pmf M \<longleftrightarrow> pmf M x \<noteq> 0"
   1.156    by transfer simp
   1.157  
   1.158 +lemma set_pmf_eq: "set_pmf M = {x. pmf M x \<noteq> 0}"
   1.159 +  by (auto simp: set_pmf_iff)
   1.160 +
   1.161 +lemma emeasure_pmf_single:
   1.162 +  fixes M :: "'a pmf"
   1.163 +  shows "emeasure M {x} = pmf M x"
   1.164 +  by transfer (simp add: finite_measure.emeasure_eq_measure[OF prob_space.finite_measure])
   1.165 +
   1.166  lemma emeasure_measure_pmf_finite: "finite S \<Longrightarrow> emeasure (measure_pmf M) S = (\<Sum>s\<in>S. pmf M s)"
   1.167    by (subst emeasure_eq_setsum_singleton) (auto simp: emeasure_pmf_single)
   1.168  
   1.169 @@ -290,6 +312,155 @@
   1.170  using emeasure_eq_0_AE[where ?P="\<lambda>x. x \<in> A" and M="measure_pmf p"]
   1.171  by(auto simp add: null_sets_def AE_measure_pmf_iff)
   1.172  
   1.173 +lemma measure_subprob: "measure_pmf M \<in> space (subprob_algebra (count_space UNIV))"
   1.174 +  by (simp add: space_subprob_algebra subprob_space_measure_pmf)
   1.175 +
   1.176 +subsection \<open> Monad Interpretation \<close>
   1.177 +
   1.178 +lemma measurable_measure_pmf[measurable]:
   1.179 +  "(\<lambda>x. measure_pmf (M x)) \<in> measurable (count_space UNIV) (subprob_algebra (count_space UNIV))"
   1.180 +  by (auto simp: space_subprob_algebra intro!: prob_space_imp_subprob_space) unfold_locales
   1.181 +
   1.182 +lemma bind_measure_pmf_cong:
   1.183 +  assumes "\<And>x. A x \<in> space (subprob_algebra N)" "\<And>x. B x \<in> space (subprob_algebra N)"
   1.184 +  assumes "\<And>i. i \<in> set_pmf x \<Longrightarrow> A i = B i"
   1.185 +  shows "bind (measure_pmf x) A = bind (measure_pmf x) B"
   1.186 +proof (rule measure_eqI)
   1.187 +  show "sets (measure_pmf x \<guillemotright>= A) = sets (measure_pmf x \<guillemotright>= B)"
   1.188 +    using assms by (subst (1 2) sets_bind) (auto simp: space_subprob_algebra)
   1.189 +next
   1.190 +  fix X assume "X \<in> sets (measure_pmf x \<guillemotright>= A)"
   1.191 +  then have X: "X \<in> sets N"
   1.192 +    using assms by (subst (asm) sets_bind) (auto simp: space_subprob_algebra)
   1.193 +  show "emeasure (measure_pmf x \<guillemotright>= A) X = emeasure (measure_pmf x \<guillemotright>= B) X"
   1.194 +    using assms
   1.195 +    by (subst (1 2) emeasure_bind[where N=N, OF _ _ X])
   1.196 +       (auto intro!: nn_integral_cong_AE simp: AE_measure_pmf_iff)
   1.197 +qed
   1.198 +
   1.199 +lift_definition bind_pmf :: "'a pmf \<Rightarrow> ('a \<Rightarrow> 'b pmf ) \<Rightarrow> 'b pmf" is bind
   1.200 +proof (clarify, intro conjI)
   1.201 +  fix f :: "'a measure" and g :: "'a \<Rightarrow> 'b measure"
   1.202 +  assume "prob_space f"
   1.203 +  then interpret f: prob_space f .
   1.204 +  assume "sets f = UNIV" and ae_f: "AE x in f. measure f {x} \<noteq> 0"
   1.205 +  then have s_f[simp]: "sets f = sets (count_space UNIV)"
   1.206 +    by simp
   1.207 +  assume g: "\<And>x. prob_space (g x) \<and> sets (g x) = UNIV \<and> (AE y in g x. measure (g x) {y} \<noteq> 0)"
   1.208 +  then have g: "\<And>x. prob_space (g x)" and s_g[simp]: "\<And>x. sets (g x) = sets (count_space UNIV)"
   1.209 +    and ae_g: "\<And>x. AE y in g x. measure (g x) {y} \<noteq> 0"
   1.210 +    by auto
   1.211 +
   1.212 +  have [measurable]: "g \<in> measurable f (subprob_algebra (count_space UNIV))"
   1.213 +    by (auto simp: measurable_def space_subprob_algebra prob_space_imp_subprob_space g)
   1.214 +    
   1.215 +  show "prob_space (f \<guillemotright>= g)"
   1.216 +    using g by (intro f.prob_space_bind[where S="count_space UNIV"]) auto
   1.217 +  then interpret fg: prob_space "f \<guillemotright>= g" . 
   1.218 +  show [simp]: "sets (f \<guillemotright>= g) = UNIV"
   1.219 +    using sets_eq_imp_space_eq[OF s_f]
   1.220 +    by (subst sets_bind[where N="count_space UNIV"]) auto
   1.221 +  show "AE x in f \<guillemotright>= g. measure (f \<guillemotright>= g) {x} \<noteq> 0"
   1.222 +    apply (simp add: fg.prob_eq_0 AE_bind[where B="count_space UNIV"])
   1.223 +    using ae_f
   1.224 +    apply eventually_elim
   1.225 +    using ae_g
   1.226 +    apply eventually_elim
   1.227 +    apply (auto dest: AE_measure_singleton)
   1.228 +    done
   1.229 +qed
   1.230 +
   1.231 +lemma ereal_pmf_bind: "pmf (bind_pmf N f) i = (\<integral>\<^sup>+x. pmf (f x) i \<partial>measure_pmf N)"
   1.232 +  unfolding pmf.rep_eq bind_pmf.rep_eq
   1.233 +  by (auto simp: measure_pmf.measure_bind[where N="count_space UNIV"] measure_subprob measure_nonneg
   1.234 +           intro!: nn_integral_eq_integral[symmetric] measure_pmf.integrable_const_bound[where B=1])
   1.235 +
   1.236 +lemma pmf_bind: "pmf (bind_pmf N f) i = (\<integral>x. pmf (f x) i \<partial>measure_pmf N)"
   1.237 +  using ereal_pmf_bind[of N f i]
   1.238 +  by (subst (asm) nn_integral_eq_integral)
   1.239 +     (auto simp: pmf_nonneg pmf_le_1
   1.240 +           intro!: nn_integral_eq_integral[symmetric] measure_pmf.integrable_const_bound[where B=1])
   1.241 +
   1.242 +lemma bind_pmf_const[simp]: "bind_pmf M (\<lambda>x. c) = c"
   1.243 +  by transfer (simp add: bind_const' prob_space_imp_subprob_space)
   1.244 +
   1.245 +lemma set_bind_pmf: "set_pmf (bind_pmf M N) = (\<Union>M\<in>set_pmf M. set_pmf (N M))"
   1.246 +  unfolding set_pmf_eq ereal_eq_0(1)[symmetric] ereal_pmf_bind  
   1.247 +  by (auto simp add: nn_integral_0_iff_AE AE_measure_pmf_iff set_pmf_eq not_le less_le pmf_nonneg)
   1.248 +
   1.249 +lemma bind_pmf_cong:
   1.250 +  assumes "p = q"
   1.251 +  shows "(\<And>x. x \<in> set_pmf q \<Longrightarrow> f x = g x) \<Longrightarrow> bind_pmf p f = bind_pmf q g"
   1.252 +  unfolding `p = q`[symmetric] measure_pmf_inject[symmetric] bind_pmf.rep_eq
   1.253 +  by (auto simp: AE_measure_pmf_iff Pi_iff space_subprob_algebra subprob_space_measure_pmf
   1.254 +                 sets_bind[where N="count_space UNIV"] emeasure_bind[where N="count_space UNIV"]
   1.255 +           intro!: nn_integral_cong_AE measure_eqI)
   1.256 +
   1.257 +lemma bind_pmf_cong_simp:
   1.258 +  "p = q \<Longrightarrow> (\<And>x. x \<in> set_pmf q =simp=> f x = g x) \<Longrightarrow> bind_pmf p f = bind_pmf q g"
   1.259 +  by (simp add: simp_implies_def cong: bind_pmf_cong)
   1.260 +
   1.261 +lemma measure_pmf_bind: "measure_pmf (bind_pmf M f) = (measure_pmf M \<guillemotright>= (\<lambda>x. measure_pmf (f x)))"
   1.262 +  by transfer simp
   1.263 +
   1.264 +lemma nn_integral_bind_pmf[simp]: "(\<integral>\<^sup>+x. f x \<partial>bind_pmf M N) = (\<integral>\<^sup>+x. \<integral>\<^sup>+y. f y \<partial>N x \<partial>M)"
   1.265 +  using measurable_measure_pmf[of N]
   1.266 +  unfolding measure_pmf_bind
   1.267 +  apply (subst (1 3) nn_integral_max_0[symmetric])
   1.268 +  apply (intro nn_integral_bind[where B="count_space UNIV"])
   1.269 +  apply auto
   1.270 +  done
   1.271 +
   1.272 +lemma emeasure_bind_pmf[simp]: "emeasure (bind_pmf M N) X = (\<integral>\<^sup>+x. emeasure (N x) X \<partial>M)"
   1.273 +  using measurable_measure_pmf[of N]
   1.274 +  unfolding measure_pmf_bind
   1.275 +  by (subst emeasure_bind[where N="count_space UNIV"]) auto
   1.276 +                                
   1.277 +lift_definition return_pmf :: "'a \<Rightarrow> 'a pmf" is "return (count_space UNIV)"
   1.278 +  by (auto intro!: prob_space_return simp: AE_return measure_return)
   1.279 +
   1.280 +lemma bind_return_pmf: "bind_pmf (return_pmf x) f = f x"
   1.281 +  by transfer
   1.282 +     (auto intro!: prob_space_imp_subprob_space bind_return[where N="count_space UNIV"]
   1.283 +           simp: space_subprob_algebra)
   1.284 +
   1.285 +lemma set_return_pmf: "set_pmf (return_pmf x) = {x}"
   1.286 +  by transfer (auto simp add: measure_return split: split_indicator)
   1.287 +
   1.288 +lemma bind_return_pmf': "bind_pmf N return_pmf = N"
   1.289 +proof (transfer, clarify)
   1.290 +  fix N :: "'a measure" assume "sets N = UNIV" then show "N \<guillemotright>= return (count_space UNIV) = N"
   1.291 +    by (subst return_sets_cong[where N=N]) (simp_all add: bind_return')
   1.292 +qed
   1.293 +
   1.294 +lemma bind_assoc_pmf: "bind_pmf (bind_pmf A B) C = bind_pmf A (\<lambda>x. bind_pmf (B x) C)"
   1.295 +  by transfer
   1.296 +     (auto intro!: bind_assoc[where N="count_space UNIV" and R="count_space UNIV"]
   1.297 +           simp: measurable_def space_subprob_algebra prob_space_imp_subprob_space)
   1.298 +
   1.299 +definition "map_pmf f M = bind_pmf M (\<lambda>x. return_pmf (f x))"
   1.300 +
   1.301 +lemma map_bind_pmf: "map_pmf f (bind_pmf M g) = bind_pmf M (\<lambda>x. map_pmf f (g x))"
   1.302 +  by (simp add: map_pmf_def bind_assoc_pmf)
   1.303 +
   1.304 +lemma bind_map_pmf: "bind_pmf (map_pmf f M) g = bind_pmf M (\<lambda>x. g (f x))"
   1.305 +  by (simp add: map_pmf_def bind_assoc_pmf bind_return_pmf)
   1.306 +
   1.307 +lemma map_pmf_transfer[transfer_rule]:
   1.308 +  "rel_fun op = (rel_fun cr_pmf cr_pmf) (\<lambda>f M. distr M (count_space UNIV) f) map_pmf"
   1.309 +proof -
   1.310 +  have "rel_fun op = (rel_fun pmf_as_measure.cr_pmf pmf_as_measure.cr_pmf)
   1.311 +     (\<lambda>f M. M \<guillemotright>= (return (count_space UNIV) o f)) map_pmf"
   1.312 +    unfolding map_pmf_def[abs_def] comp_def by transfer_prover 
   1.313 +  then show ?thesis
   1.314 +    by (force simp: rel_fun_def cr_pmf_def bind_return_distr)
   1.315 +qed
   1.316 +
   1.317 +lemma map_pmf_rep_eq:
   1.318 +  "measure_pmf (map_pmf f M) = distr (measure_pmf M) (count_space UNIV) f"
   1.319 +  unfolding map_pmf_def bind_pmf.rep_eq comp_def return_pmf.rep_eq
   1.320 +  using bind_return_distr[of M f "count_space UNIV"] by (simp add: comp_def)
   1.321 +
   1.322  lemma map_pmf_id[simp]: "map_pmf id = id"
   1.323    by (rule, transfer) (auto simp: emeasure_distr measurable_def intro!: measure_eqI)
   1.324  
   1.325 @@ -302,20 +473,23 @@
   1.326  lemma map_pmf_comp: "map_pmf f (map_pmf g M) = map_pmf (\<lambda>x. f (g x)) M"
   1.327    using map_pmf_compose[of f g] by (simp add: comp_def)
   1.328  
   1.329 -lemma map_pmf_cong:
   1.330 -  assumes "p = q"
   1.331 -  shows "(\<And>x. x \<in> set_pmf q \<Longrightarrow> f x = g x) \<Longrightarrow> map_pmf f p = map_pmf g q"
   1.332 -  unfolding `p = q`[symmetric] measure_pmf_inject[symmetric] map_pmf.rep_eq
   1.333 -  by (auto simp add: emeasure_distr AE_measure_pmf_iff intro!: emeasure_eq_AE measure_eqI)
   1.334 +lemma map_pmf_cong: "p = q \<Longrightarrow> (\<And>x. x \<in> set_pmf q \<Longrightarrow> f x = g x) \<Longrightarrow> map_pmf f p = map_pmf g q"
   1.335 +  unfolding map_pmf_def by (rule bind_pmf_cong) auto
   1.336 +
   1.337 +lemma pmf_set_map: "set_pmf \<circ> map_pmf f = op ` f \<circ> set_pmf"
   1.338 +  by (auto simp add: comp_def fun_eq_iff map_pmf_def set_bind_pmf set_return_pmf)
   1.339 +
   1.340 +lemma set_map_pmf: "set_pmf (map_pmf f M) = f`set_pmf M"
   1.341 +  using pmf_set_map[of f] by (auto simp: comp_def fun_eq_iff)
   1.342  
   1.343  lemma emeasure_map_pmf[simp]: "emeasure (map_pmf f M) X = emeasure M (f -` X)"
   1.344 -  unfolding map_pmf.rep_eq by (subst emeasure_distr) auto
   1.345 +  unfolding map_pmf_rep_eq by (subst emeasure_distr) auto
   1.346  
   1.347  lemma nn_integral_map_pmf[simp]: "(\<integral>\<^sup>+x. f x \<partial>map_pmf g M) = (\<integral>\<^sup>+x. f (g x) \<partial>M)"
   1.348 -  unfolding map_pmf.rep_eq by (intro nn_integral_distr) auto
   1.349 +  unfolding map_pmf_rep_eq by (intro nn_integral_distr) auto
   1.350  
   1.351  lemma ereal_pmf_map: "pmf (map_pmf f p) x = (\<integral>\<^sup>+ y. indicator (f -` {x}) y \<partial>measure_pmf p)"
   1.352 -proof(transfer fixing: f x)
   1.353 +proof (transfer fixing: f x)
   1.354    fix p :: "'b measure"
   1.355    presume "prob_space p"
   1.356    then interpret prob_space p .
   1.357 @@ -324,36 +498,6 @@
   1.358      by(simp add: measure_distr measurable_def emeasure_eq_measure)
   1.359  qed simp_all
   1.360  
   1.361 -lemma pmf_set_map: 
   1.362 -  fixes f :: "'a \<Rightarrow> 'b"
   1.363 -  shows "set_pmf \<circ> map_pmf f = op ` f \<circ> set_pmf"
   1.364 -proof (rule, transfer, clarsimp simp add: measure_distr measurable_def)
   1.365 -  fix f :: "'a \<Rightarrow> 'b" and M :: "'a measure"
   1.366 -  assume "prob_space M" and ae: "AE x in M. measure M {x} \<noteq> 0" and [simp]: "sets M = UNIV"
   1.367 -  interpret prob_space M by fact
   1.368 -  show "{x. measure M (f -` {x}) \<noteq> 0} = f ` {x. measure M {x} \<noteq> 0}"
   1.369 -  proof safe
   1.370 -    fix x assume "measure M (f -` {x}) \<noteq> 0"
   1.371 -    moreover have "measure M (f -` {x}) = measure M {y. f y = x \<and> measure M {y} \<noteq> 0}"
   1.372 -      using ae by (intro finite_measure_eq_AE) auto
   1.373 -    ultimately have "{y. f y = x \<and> measure M {y} \<noteq> 0} \<noteq> {}"
   1.374 -      by (metis measure_empty)
   1.375 -    then show "x \<in> f ` {x. measure M {x} \<noteq> 0}"
   1.376 -      by auto
   1.377 -  next
   1.378 -    fix x assume "measure M {x} \<noteq> 0"
   1.379 -    then have "0 < measure M {x}"
   1.380 -      using measure_nonneg[of M "{x}"] by auto
   1.381 -    also have "measure M {x} \<le> measure M (f -` {f x})"
   1.382 -      by (intro finite_measure_mono) auto
   1.383 -    finally show "measure M (f -` {f x}) = 0 \<Longrightarrow> False"
   1.384 -      by simp
   1.385 -  qed
   1.386 -qed
   1.387 -
   1.388 -lemma set_map_pmf: "set_pmf (map_pmf f M) = f`set_pmf M"
   1.389 -  using pmf_set_map[of f] by (auto simp: comp_def fun_eq_iff)
   1.390 -
   1.391  lemma nn_integral_pmf: "(\<integral>\<^sup>+ x. pmf p x \<partial>count_space A) = emeasure (measure_pmf p) A"
   1.392  proof -
   1.393    have "(\<integral>\<^sup>+ x. pmf p x \<partial>count_space A) = (\<integral>\<^sup>+ x. pmf p x \<partial>count_space (A \<inter> set_pmf p))"
   1.394 @@ -367,7 +511,109 @@
   1.395    finally show ?thesis .
   1.396  qed
   1.397  
   1.398 -subsection {* PMFs as function *}
   1.399 +lemma map_return_pmf: "map_pmf f (return_pmf x) = return_pmf (f x)"
   1.400 +  by transfer (simp add: distr_return)
   1.401 +
   1.402 +lemma map_pmf_const[simp]: "map_pmf (\<lambda>_. c) M = return_pmf c"
   1.403 +  by transfer (auto simp: prob_space.distr_const)
   1.404 +
   1.405 +lemma pmf_return: "pmf (return_pmf x) y = indicator {y} x"
   1.406 +  by transfer (simp add: measure_return)
   1.407 +
   1.408 +lemma nn_integral_return_pmf[simp]: "0 \<le> f x \<Longrightarrow> (\<integral>\<^sup>+x. f x \<partial>return_pmf x) = f x"
   1.409 +  unfolding return_pmf.rep_eq by (intro nn_integral_return) auto
   1.410 +
   1.411 +lemma emeasure_return_pmf[simp]: "emeasure (return_pmf x) X = indicator X x"
   1.412 +  unfolding return_pmf.rep_eq by (intro emeasure_return) auto
   1.413 +
   1.414 +lemma return_pmf_inj[simp]: "return_pmf x = return_pmf y \<longleftrightarrow> x = y"
   1.415 +  by (metis insertI1 set_return_pmf singletonD)
   1.416 +
   1.417 +definition "pair_pmf A B = bind_pmf A (\<lambda>x. bind_pmf B (\<lambda>y. return_pmf (x, y)))"
   1.418 +
   1.419 +lemma pmf_pair: "pmf (pair_pmf M N) (a, b) = pmf M a * pmf N b"
   1.420 +  unfolding pair_pmf_def pmf_bind pmf_return
   1.421 +  apply (subst integral_measure_pmf[where A="{b}"])
   1.422 +  apply (auto simp: indicator_eq_0_iff)
   1.423 +  apply (subst integral_measure_pmf[where A="{a}"])
   1.424 +  apply (auto simp: indicator_eq_0_iff setsum_nonneg_eq_0_iff pmf_nonneg)
   1.425 +  done
   1.426 +
   1.427 +lemma set_pair_pmf: "set_pmf (pair_pmf A B) = set_pmf A \<times> set_pmf B"
   1.428 +  unfolding pair_pmf_def set_bind_pmf set_return_pmf by auto
   1.429 +
   1.430 +lemma measure_pmf_in_subprob_space[measurable (raw)]:
   1.431 +  "measure_pmf M \<in> space (subprob_algebra (count_space UNIV))"
   1.432 +  by (simp add: space_subprob_algebra) intro_locales
   1.433 +
   1.434 +lemma nn_integral_pair_pmf': "(\<integral>\<^sup>+x. f x \<partial>pair_pmf A B) = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. f (a, b) \<partial>B \<partial>A)"
   1.435 +proof -
   1.436 +  have "(\<integral>\<^sup>+x. f x \<partial>pair_pmf A B) = (\<integral>\<^sup>+x. max 0 (f x) * indicator (A \<times> B) x \<partial>pair_pmf A B)"
   1.437 +    by (subst nn_integral_max_0[symmetric])
   1.438 +       (auto simp: AE_measure_pmf_iff set_pair_pmf intro!: nn_integral_cong_AE)
   1.439 +  also have "\<dots> = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. max 0 (f (a, b)) * indicator (A \<times> B) (a, b) \<partial>B \<partial>A)"
   1.440 +    by (simp add: pair_pmf_def)
   1.441 +  also have "\<dots> = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. max 0 (f (a, b)) \<partial>B \<partial>A)"
   1.442 +    by (auto intro!: nn_integral_cong_AE simp: AE_measure_pmf_iff)
   1.443 +  finally show ?thesis
   1.444 +    unfolding nn_integral_max_0 .
   1.445 +qed
   1.446 +
   1.447 +lemma bind_pair_pmf:
   1.448 +  assumes M[measurable]: "M \<in> measurable (count_space UNIV \<Otimes>\<^sub>M count_space UNIV) (subprob_algebra N)"
   1.449 +  shows "measure_pmf (pair_pmf A B) \<guillemotright>= M = (measure_pmf A \<guillemotright>= (\<lambda>x. measure_pmf B \<guillemotright>= (\<lambda>y. M (x, y))))"
   1.450 +    (is "?L = ?R")
   1.451 +proof (rule measure_eqI)
   1.452 +  have M'[measurable]: "M \<in> measurable (pair_pmf A B) (subprob_algebra N)"
   1.453 +    using M[THEN measurable_space] by (simp_all add: space_pair_measure)
   1.454 +
   1.455 +  note measurable_bind[where N="count_space UNIV", measurable]
   1.456 +  note measure_pmf_in_subprob_space[simp]
   1.457 +
   1.458 +  have sets_eq_N: "sets ?L = N"
   1.459 +    by (subst sets_bind[OF sets_kernel[OF M']]) auto
   1.460 +  show "sets ?L = sets ?R"
   1.461 +    using measurable_space[OF M]
   1.462 +    by (simp add: sets_eq_N space_pair_measure space_subprob_algebra)
   1.463 +  fix X assume "X \<in> sets ?L"
   1.464 +  then have X[measurable]: "X \<in> sets N"
   1.465 +    unfolding sets_eq_N .
   1.466 +  then show "emeasure ?L X = emeasure ?R X"
   1.467 +    apply (simp add: emeasure_bind[OF _ M' X])
   1.468 +    apply (simp add: nn_integral_bind[where B="count_space UNIV"] pair_pmf_def measure_pmf_bind[of A]
   1.469 +      nn_integral_measure_pmf_finite set_return_pmf emeasure_nonneg pmf_return one_ereal_def[symmetric])
   1.470 +    apply (subst emeasure_bind[OF _ _ X])
   1.471 +    apply measurable
   1.472 +    apply (subst emeasure_bind[OF _ _ X])
   1.473 +    apply measurable
   1.474 +    done
   1.475 +qed
   1.476 +
   1.477 +lemma map_fst_pair_pmf: "map_pmf fst (pair_pmf A B) = A"
   1.478 +  by (simp add: pair_pmf_def map_pmf_def bind_assoc_pmf bind_return_pmf bind_return_pmf')
   1.479 +
   1.480 +lemma map_snd_pair_pmf: "map_pmf snd (pair_pmf A B) = B"
   1.481 +  by (simp add: pair_pmf_def map_pmf_def bind_assoc_pmf bind_return_pmf bind_return_pmf')
   1.482 +
   1.483 +lemma nn_integral_pmf':
   1.484 +  "inj_on f A \<Longrightarrow> (\<integral>\<^sup>+x. pmf p (f x) \<partial>count_space A) = emeasure p (f ` A)"
   1.485 +  by (subst nn_integral_bij_count_space[where g=f and B="f`A"])
   1.486 +     (auto simp: bij_betw_def nn_integral_pmf)
   1.487 +
   1.488 +lemma pmf_le_0_iff[simp]: "pmf M p \<le> 0 \<longleftrightarrow> pmf M p = 0"
   1.489 +  using pmf_nonneg[of M p] by simp
   1.490 +
   1.491 +lemma min_pmf_0[simp]: "min (pmf M p) 0 = 0" "min 0 (pmf M p) = 0"
   1.492 +  using pmf_nonneg[of M p] by simp_all
   1.493 +
   1.494 +lemma pmf_eq_0_set_pmf: "pmf M p = 0 \<longleftrightarrow> p \<notin> set_pmf M"
   1.495 +  unfolding set_pmf_iff by simp
   1.496 +
   1.497 +lemma pmf_map_inj: "inj_on f (set_pmf M) \<Longrightarrow> x \<in> set_pmf M \<Longrightarrow> pmf (map_pmf f M) (f x) = pmf M x"
   1.498 +  by (auto simp: pmf.rep_eq map_pmf_rep_eq measure_distr AE_measure_pmf_iff inj_onD
   1.499 +           intro!: measure_pmf.finite_measure_eq_AE)
   1.500 +
   1.501 +subsection \<open> PMFs as function \<close>
   1.502  
   1.503  context
   1.504    fixes f :: "'a \<Rightarrow> real"
   1.505 @@ -468,8 +714,484 @@
   1.506  lemma pmf_eq_iff: "M = N \<longleftrightarrow> (\<forall>i. pmf M i = pmf N i)"
   1.507    by (auto intro: pmf_eqI)
   1.508  
   1.509 +lemma bind_commute_pmf: "bind_pmf A (\<lambda>x. bind_pmf B (C x)) = bind_pmf B (\<lambda>y. bind_pmf A (\<lambda>x. C x y))"
   1.510 +  unfolding pmf_eq_iff pmf_bind
   1.511 +proof
   1.512 +  fix i
   1.513 +  interpret B: prob_space "restrict_space B B"
   1.514 +    by (intro prob_space_restrict_space measure_pmf.emeasure_eq_1_AE)
   1.515 +       (auto simp: AE_measure_pmf_iff)
   1.516 +  interpret A: prob_space "restrict_space A A"
   1.517 +    by (intro prob_space_restrict_space measure_pmf.emeasure_eq_1_AE)
   1.518 +       (auto simp: AE_measure_pmf_iff)
   1.519 +
   1.520 +  interpret AB: pair_prob_space "restrict_space A A" "restrict_space B B"
   1.521 +    by unfold_locales
   1.522 +
   1.523 +  have "(\<integral> x. \<integral> y. pmf (C x y) i \<partial>B \<partial>A) = (\<integral> x. (\<integral> y. pmf (C x y) i \<partial>restrict_space B B) \<partial>A)"
   1.524 +    by (rule integral_cong) (auto intro!: integral_pmf_restrict)
   1.525 +  also have "\<dots> = (\<integral> x. (\<integral> y. pmf (C x y) i \<partial>restrict_space B B) \<partial>restrict_space A A)"
   1.526 +    by (intro integral_pmf_restrict B.borel_measurable_lebesgue_integral measurable_pair_restrict_pmf2
   1.527 +              countable_set_pmf borel_measurable_count_space)
   1.528 +  also have "\<dots> = (\<integral> y. \<integral> x. pmf (C x y) i \<partial>restrict_space A A \<partial>restrict_space B B)"
   1.529 +    by (rule AB.Fubini_integral[symmetric])
   1.530 +       (auto intro!: AB.integrable_const_bound[where B=1] measurable_pair_restrict_pmf2
   1.531 +             simp: pmf_nonneg pmf_le_1 measurable_restrict_space1)
   1.532 +  also have "\<dots> = (\<integral> y. \<integral> x. pmf (C x y) i \<partial>restrict_space A A \<partial>B)"
   1.533 +    by (intro integral_pmf_restrict[symmetric] A.borel_measurable_lebesgue_integral measurable_pair_restrict_pmf2
   1.534 +              countable_set_pmf borel_measurable_count_space)
   1.535 +  also have "\<dots> = (\<integral> y. \<integral> x. pmf (C x y) i \<partial>A \<partial>B)"
   1.536 +    by (rule integral_cong) (auto intro!: integral_pmf_restrict[symmetric])
   1.537 +  finally show "(\<integral> x. \<integral> y. pmf (C x y) i \<partial>B \<partial>A) = (\<integral> y. \<integral> x. pmf (C x y) i \<partial>A \<partial>B)" .
   1.538 +qed
   1.539 +
   1.540 +lemma pair_map_pmf1: "pair_pmf (map_pmf f A) B = map_pmf (apfst f) (pair_pmf A B)"
   1.541 +proof (safe intro!: pmf_eqI)
   1.542 +  fix a :: "'a" and b :: "'b"
   1.543 +  have [simp]: "\<And>c d. indicator (apfst f -` {(a, b)}) (c, d) = indicator (f -` {a}) c * (indicator {b} d::ereal)"
   1.544 +    by (auto split: split_indicator)
   1.545 +
   1.546 +  have "ereal (pmf (pair_pmf (map_pmf f A) B) (a, b)) =
   1.547 +         ereal (pmf (map_pmf (apfst f) (pair_pmf A B)) (a, b))"
   1.548 +    unfolding pmf_pair ereal_pmf_map
   1.549 +    by (simp add: nn_integral_pair_pmf' max_def emeasure_pmf_single nn_integral_multc pmf_nonneg
   1.550 +                  emeasure_map_pmf[symmetric] del: emeasure_map_pmf)
   1.551 +  then show "pmf (pair_pmf (map_pmf f A) B) (a, b) = pmf (map_pmf (apfst f) (pair_pmf A B)) (a, b)"
   1.552 +    by simp
   1.553 +qed
   1.554 +
   1.555 +lemma pair_map_pmf2: "pair_pmf A (map_pmf f B) = map_pmf (apsnd f) (pair_pmf A B)"
   1.556 +proof (safe intro!: pmf_eqI)
   1.557 +  fix a :: "'a" and b :: "'b"
   1.558 +  have [simp]: "\<And>c d. indicator (apsnd f -` {(a, b)}) (c, d) = indicator {a} c * (indicator (f -` {b}) d::ereal)"
   1.559 +    by (auto split: split_indicator)
   1.560 +
   1.561 +  have "ereal (pmf (pair_pmf A (map_pmf f B)) (a, b)) =
   1.562 +         ereal (pmf (map_pmf (apsnd f) (pair_pmf A B)) (a, b))"
   1.563 +    unfolding pmf_pair ereal_pmf_map
   1.564 +    by (simp add: nn_integral_pair_pmf' max_def emeasure_pmf_single nn_integral_cmult nn_integral_multc pmf_nonneg
   1.565 +                  emeasure_map_pmf[symmetric] del: emeasure_map_pmf)
   1.566 +  then show "pmf (pair_pmf A (map_pmf f B)) (a, b) = pmf (map_pmf (apsnd f) (pair_pmf A B)) (a, b)"
   1.567 +    by simp
   1.568 +qed
   1.569 +
   1.570 +lemma map_pair: "map_pmf (\<lambda>(a, b). (f a, g b)) (pair_pmf A B) = pair_pmf (map_pmf f A) (map_pmf g B)"
   1.571 +  by (simp add: pair_map_pmf2 pair_map_pmf1 map_pmf_comp split_beta')
   1.572 +
   1.573  end
   1.574  
   1.575 +subsection \<open> Conditional Probabilities \<close>
   1.576 +
   1.577 +context
   1.578 +  fixes p :: "'a pmf" and s :: "'a set"
   1.579 +  assumes not_empty: "set_pmf p \<inter> s \<noteq> {}"
   1.580 +begin
   1.581 +
   1.582 +interpretation pmf_as_measure .
   1.583 +
   1.584 +lemma emeasure_measure_pmf_not_zero: "emeasure (measure_pmf p) s \<noteq> 0"
   1.585 +proof
   1.586 +  assume "emeasure (measure_pmf p) s = 0"
   1.587 +  then have "AE x in measure_pmf p. x \<notin> s"
   1.588 +    by (rule AE_I[rotated]) auto
   1.589 +  with not_empty show False
   1.590 +    by (auto simp: AE_measure_pmf_iff)
   1.591 +qed
   1.592 +
   1.593 +lemma measure_measure_pmf_not_zero: "measure (measure_pmf p) s \<noteq> 0"
   1.594 +  using emeasure_measure_pmf_not_zero unfolding measure_pmf.emeasure_eq_measure by simp
   1.595 +
   1.596 +lift_definition cond_pmf :: "'a pmf" is
   1.597 +  "uniform_measure (measure_pmf p) s"
   1.598 +proof (intro conjI)
   1.599 +  show "prob_space (uniform_measure (measure_pmf p) s)"
   1.600 +    by (intro prob_space_uniform_measure) (auto simp: emeasure_measure_pmf_not_zero)
   1.601 +  show "AE x in uniform_measure (measure_pmf p) s. measure (uniform_measure (measure_pmf p) s) {x} \<noteq> 0"
   1.602 +    by (simp add: emeasure_measure_pmf_not_zero measure_measure_pmf_not_zero AE_uniform_measure
   1.603 +                  AE_measure_pmf_iff set_pmf.rep_eq)
   1.604 +qed simp
   1.605 +
   1.606 +lemma pmf_cond: "pmf cond_pmf x = (if x \<in> s then pmf p x / measure p s else 0)"
   1.607 +  by transfer (simp add: emeasure_measure_pmf_not_zero pmf.rep_eq)
   1.608 +
   1.609 +lemma set_cond_pmf: "set_pmf cond_pmf = set_pmf p \<inter> s"
   1.610 +  by (auto simp add: set_pmf_iff pmf_cond measure_measure_pmf_not_zero split: split_if_asm)
   1.611 +
   1.612 +end
   1.613 +
   1.614 +lemma cond_map_pmf:
   1.615 +  assumes "set_pmf p \<inter> f -` s \<noteq> {}"
   1.616 +  shows "cond_pmf (map_pmf f p) s = map_pmf f (cond_pmf p (f -` s))"
   1.617 +proof -
   1.618 +  have *: "set_pmf (map_pmf f p) \<inter> s \<noteq> {}"
   1.619 +    using assms by (simp add: set_map_pmf) auto
   1.620 +  { fix x
   1.621 +    have "ereal (pmf (map_pmf f (cond_pmf p (f -` s))) x) =
   1.622 +      emeasure p (f -` s \<inter> f -` {x}) / emeasure p (f -` s)"
   1.623 +      unfolding ereal_pmf_map cond_pmf.rep_eq[OF assms] by (simp add: nn_integral_uniform_measure)
   1.624 +    also have "f -` s \<inter> f -` {x} = (if x \<in> s then f -` {x} else {})"
   1.625 +      by auto
   1.626 +    also have "emeasure p (if x \<in> s then f -` {x} else {}) / emeasure p (f -` s) =
   1.627 +      ereal (pmf (cond_pmf (map_pmf f p) s) x)"
   1.628 +      using measure_measure_pmf_not_zero[OF *]
   1.629 +      by (simp add: pmf_cond[OF *] ereal_divide' ereal_pmf_map measure_pmf.emeasure_eq_measure[symmetric]
   1.630 +               del: ereal_divide)
   1.631 +    finally have "ereal (pmf (cond_pmf (map_pmf f p) s) x) = ereal (pmf (map_pmf f (cond_pmf p (f -` s))) x)"
   1.632 +      by simp }
   1.633 +  then show ?thesis
   1.634 +    by (intro pmf_eqI) simp
   1.635 +qed
   1.636 +
   1.637 +lemma bind_cond_pmf_cancel:
   1.638 +  assumes in_S: "\<And>x. x \<in> set_pmf p \<Longrightarrow> x \<in> S x" "\<And>x. x \<in> set_pmf q \<Longrightarrow> x \<in> S x"
   1.639 +  assumes S_eq: "\<And>x y. x \<in> S y \<Longrightarrow> S x = S y"
   1.640 +  and same: "\<And>x. measure (measure_pmf p) (S x) = measure (measure_pmf q) (S x)"
   1.641 +  shows "bind_pmf p (\<lambda>x. cond_pmf q (S x)) = q" (is "?lhs = _")
   1.642 +proof (rule pmf_eqI)
   1.643 +  { fix x
   1.644 +    assume "x \<in> set_pmf p"
   1.645 +    hence "set_pmf p \<inter> (S x) \<noteq> {}" using in_S by auto
   1.646 +    hence "measure (measure_pmf p) (S x) \<noteq> 0"
   1.647 +      by(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff)
   1.648 +    with same have "measure (measure_pmf q) (S x) \<noteq> 0" by simp
   1.649 +    hence "set_pmf q \<inter> S x \<noteq> {}"
   1.650 +      by(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff) }
   1.651 +  note [simp] = this
   1.652 +
   1.653 +  fix z
   1.654 +  have pmf_q_z: "z \<notin> S z \<Longrightarrow> pmf q z = 0"
   1.655 +    by(erule contrapos_np)(simp add: pmf_eq_0_set_pmf in_S)
   1.656 +
   1.657 +  have "ereal (pmf ?lhs z) = \<integral>\<^sup>+ x. ereal (pmf (cond_pmf q (S x)) z) \<partial>measure_pmf p"
   1.658 +    by(simp add: ereal_pmf_bind)
   1.659 +  also have "\<dots> = \<integral>\<^sup>+ x. ereal (pmf q z / measure p (S z)) * indicator (S z) x \<partial>measure_pmf p"
   1.660 +    by(rule nn_integral_cong_AE)(auto simp add: AE_measure_pmf_iff pmf_cond same pmf_q_z in_S dest!: S_eq split: split_indicator)
   1.661 +  also have "\<dots> = pmf q z" using pmf_nonneg[of q z]
   1.662 +    by (subst nn_integral_cmult)(auto simp add: measure_nonneg measure_pmf.emeasure_eq_measure same measure_pmf.prob_eq_0 AE_measure_pmf_iff pmf_eq_0_set_pmf in_S)
   1.663 +  finally show "pmf ?lhs z = pmf q z" by simp
   1.664 +qed
   1.665 +
   1.666 +subsection \<open> Relator \<close>
   1.667 +
   1.668 +inductive rel_pmf :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a pmf \<Rightarrow> 'b pmf \<Rightarrow> bool"
   1.669 +for R p q
   1.670 +where
   1.671 +  "\<lbrakk> \<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> R x y; 
   1.672 +     map_pmf fst pq = p; map_pmf snd pq = q \<rbrakk>
   1.673 +  \<Longrightarrow> rel_pmf R p q"
   1.674 +
   1.675 +bnf pmf: "'a pmf" map: map_pmf sets: set_pmf bd : "natLeq" rel: rel_pmf
   1.676 +proof -
   1.677 +  show "map_pmf id = id" by (rule map_pmf_id)
   1.678 +  show "\<And>f g. map_pmf (f \<circ> g) = map_pmf f \<circ> map_pmf g" by (rule map_pmf_compose) 
   1.679 +  show "\<And>f g::'a \<Rightarrow> 'b. \<And>p. (\<And>x. x \<in> set_pmf p \<Longrightarrow> f x = g x) \<Longrightarrow> map_pmf f p = map_pmf g p"
   1.680 +    by (intro map_pmf_cong refl)
   1.681 +
   1.682 +  show "\<And>f::'a \<Rightarrow> 'b. set_pmf \<circ> map_pmf f = op ` f \<circ> set_pmf"
   1.683 +    by (rule pmf_set_map)
   1.684 +
   1.685 +  { fix p :: "'s pmf"
   1.686 +    have "(card_of (set_pmf p), card_of (UNIV :: nat set)) \<in> ordLeq"
   1.687 +      by (rule card_of_ordLeqI[where f="to_nat_on (set_pmf p)"])
   1.688 +         (auto intro: countable_set_pmf)
   1.689 +    also have "(card_of (UNIV :: nat set), natLeq) \<in> ordLeq"
   1.690 +      by (metis Field_natLeq card_of_least natLeq_Well_order)
   1.691 +    finally show "(card_of (set_pmf p), natLeq) \<in> ordLeq" . }
   1.692 +
   1.693 +  show "\<And>R. rel_pmf R =
   1.694 +         (BNF_Def.Grp {x. set_pmf x \<subseteq> {(x, y). R x y}} (map_pmf fst))\<inverse>\<inverse> OO
   1.695 +         BNF_Def.Grp {x. set_pmf x \<subseteq> {(x, y). R x y}} (map_pmf snd)"
   1.696 +     by (auto simp add: fun_eq_iff BNF_Def.Grp_def OO_def rel_pmf.simps)
   1.697 +
   1.698 +  { fix p :: "'a pmf" and f :: "'a \<Rightarrow> 'b" and g x
   1.699 +    assume p: "\<And>z. z \<in> set_pmf p \<Longrightarrow> f z = g z"
   1.700 +      and x: "x \<in> set_pmf p"
   1.701 +    thus "f x = g x" by simp }
   1.702 +
   1.703 +  fix R :: "'a => 'b \<Rightarrow> bool" and S :: "'b \<Rightarrow> 'c \<Rightarrow> bool"
   1.704 +  { fix p q r
   1.705 +    assume pq: "rel_pmf R p q"
   1.706 +      and qr:"rel_pmf S q r"
   1.707 +    from pq obtain pq where pq: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> R x y"
   1.708 +      and p: "p = map_pmf fst pq" and q: "q = map_pmf snd pq" by cases auto
   1.709 +    from qr obtain qr where qr: "\<And>y z. (y, z) \<in> set_pmf qr \<Longrightarrow> S y z"
   1.710 +      and q': "q = map_pmf fst qr" and r: "r = map_pmf snd qr" by cases auto
   1.711 +
   1.712 +    def pr \<equiv> "bind_pmf pq (\<lambda>(x, y). bind_pmf (cond_pmf qr {(y', z). y' = y}) (\<lambda>(y', z). return_pmf (x, z)))"
   1.713 +    have pr_welldefined: "\<And>y. y \<in> q \<Longrightarrow> qr \<inter> {(y', z). y' = y} \<noteq> {}"
   1.714 +      by (force simp: q' set_map_pmf)
   1.715 +
   1.716 +    have "rel_pmf (R OO S) p r"
   1.717 +    proof (rule rel_pmf.intros)
   1.718 +      fix x z assume "(x, z) \<in> pr"
   1.719 +      then have "\<exists>y. (x, y) \<in> pq \<and> (y, z) \<in> qr"
   1.720 +        by (auto simp: q pr_welldefined pr_def set_bind_pmf split_beta set_return_pmf set_cond_pmf set_map_pmf)
   1.721 +      with pq qr show "(R OO S) x z"
   1.722 +        by blast
   1.723 +    next
   1.724 +      have "map_pmf snd pr = map_pmf snd (bind_pmf q (\<lambda>y. cond_pmf qr {(y', z). y' = y}))"
   1.725 +        by (simp add: pr_def q split_beta bind_map_pmf map_pmf_def[symmetric] map_bind_pmf map_return_pmf)
   1.726 +      then show "map_pmf snd pr = r"
   1.727 +        unfolding r q' bind_map_pmf by (subst (asm) bind_cond_pmf_cancel) auto
   1.728 +    qed (simp add: pr_def map_bind_pmf split_beta map_return_pmf map_pmf_def[symmetric] p) }
   1.729 +  then show "rel_pmf R OO rel_pmf S \<le> rel_pmf (R OO S)"
   1.730 +    by(auto simp add: le_fun_def)
   1.731 +qed (fact natLeq_card_order natLeq_cinfinite)+
   1.732 +
   1.733 +lemma rel_pmf_return_pmf1: "rel_pmf R (return_pmf x) M \<longleftrightarrow> (\<forall>a\<in>M. R x a)"
   1.734 +proof safe
   1.735 +  fix a assume "a \<in> M" "rel_pmf R (return_pmf x) M"
   1.736 +  then obtain pq where *: "\<And>a b. (a, b) \<in> set_pmf pq \<Longrightarrow> R a b"
   1.737 +    and eq: "return_pmf x = map_pmf fst pq" "M = map_pmf snd pq"
   1.738 +    by (force elim: rel_pmf.cases)
   1.739 +  moreover have "set_pmf (return_pmf x) = {x}"
   1.740 +    by (simp add: set_return_pmf)
   1.741 +  with `a \<in> M` have "(x, a) \<in> pq"
   1.742 +    by (force simp: eq set_map_pmf)
   1.743 +  with * show "R x a"
   1.744 +    by auto
   1.745 +qed (auto intro!: rel_pmf.intros[where pq="pair_pmf (return_pmf x) M"]
   1.746 +          simp: map_fst_pair_pmf map_snd_pair_pmf set_pair_pmf set_return_pmf)
   1.747 +
   1.748 +lemma rel_pmf_return_pmf2: "rel_pmf R M (return_pmf x) \<longleftrightarrow> (\<forall>a\<in>M. R a x)"
   1.749 +  by (subst pmf.rel_flip[symmetric]) (simp add: rel_pmf_return_pmf1)
   1.750 +
   1.751 +lemma rel_return_pmf[simp]: "rel_pmf R (return_pmf x1) (return_pmf x2) = R x1 x2"
   1.752 +  unfolding rel_pmf_return_pmf2 set_return_pmf by simp
   1.753 +
   1.754 +lemma rel_pmf_False[simp]: "rel_pmf (\<lambda>x y. False) x y = False"
   1.755 +  unfolding pmf.in_rel fun_eq_iff using set_pmf_not_empty by fastforce
   1.756 +
   1.757 +lemma rel_pmf_rel_prod:
   1.758 +  "rel_pmf (rel_prod R S) (pair_pmf A A') (pair_pmf B B') \<longleftrightarrow> rel_pmf R A B \<and> rel_pmf S A' B'"
   1.759 +proof safe
   1.760 +  assume "rel_pmf (rel_prod R S) (pair_pmf A A') (pair_pmf B B')"
   1.761 +  then obtain pq where pq: "\<And>a b c d. ((a, c), (b, d)) \<in> set_pmf pq \<Longrightarrow> R a b \<and> S c d"
   1.762 +    and eq: "map_pmf fst pq = pair_pmf A A'" "map_pmf snd pq = pair_pmf B B'"
   1.763 +    by (force elim: rel_pmf.cases)
   1.764 +  show "rel_pmf R A B"
   1.765 +  proof (rule rel_pmf.intros)
   1.766 +    let ?f = "\<lambda>(a, b). (fst a, fst b)"
   1.767 +    have [simp]: "(\<lambda>x. fst (?f x)) = fst o fst" "(\<lambda>x. snd (?f x)) = fst o snd"
   1.768 +      by auto
   1.769 +
   1.770 +    show "map_pmf fst (map_pmf ?f pq) = A"
   1.771 +      by (simp add: map_pmf_comp pmf.map_comp[symmetric] eq map_fst_pair_pmf)
   1.772 +    show "map_pmf snd (map_pmf ?f pq) = B"
   1.773 +      by (simp add: map_pmf_comp pmf.map_comp[symmetric] eq map_fst_pair_pmf)
   1.774 +
   1.775 +    fix a b assume "(a, b) \<in> set_pmf (map_pmf ?f pq)"
   1.776 +    then obtain c d where "((a, c), (b, d)) \<in> set_pmf pq"
   1.777 +      by (auto simp: set_map_pmf)
   1.778 +    from pq[OF this] show "R a b" ..
   1.779 +  qed
   1.780 +  show "rel_pmf S A' B'"
   1.781 +  proof (rule rel_pmf.intros)
   1.782 +    let ?f = "\<lambda>(a, b). (snd a, snd b)"
   1.783 +    have [simp]: "(\<lambda>x. fst (?f x)) = snd o fst" "(\<lambda>x. snd (?f x)) = snd o snd"
   1.784 +      by auto
   1.785 +
   1.786 +    show "map_pmf fst (map_pmf ?f pq) = A'"
   1.787 +      by (simp add: map_pmf_comp pmf.map_comp[symmetric] eq map_snd_pair_pmf)
   1.788 +    show "map_pmf snd (map_pmf ?f pq) = B'"
   1.789 +      by (simp add: map_pmf_comp pmf.map_comp[symmetric] eq map_snd_pair_pmf)
   1.790 +
   1.791 +    fix c d assume "(c, d) \<in> set_pmf (map_pmf ?f pq)"
   1.792 +    then obtain a b where "((a, c), (b, d)) \<in> set_pmf pq"
   1.793 +      by (auto simp: set_map_pmf)
   1.794 +    from pq[OF this] show "S c d" ..
   1.795 +  qed
   1.796 +next
   1.797 +  assume "rel_pmf R A B" "rel_pmf S A' B'"
   1.798 +  then obtain Rpq Spq
   1.799 +    where Rpq: "\<And>a b. (a, b) \<in> set_pmf Rpq \<Longrightarrow> R a b"
   1.800 +        "map_pmf fst Rpq = A" "map_pmf snd Rpq = B"
   1.801 +      and Spq: "\<And>a b. (a, b) \<in> set_pmf Spq \<Longrightarrow> S a b"
   1.802 +        "map_pmf fst Spq = A'" "map_pmf snd Spq = B'"
   1.803 +    by (force elim: rel_pmf.cases)
   1.804 +
   1.805 +  let ?f = "(\<lambda>((a, c), (b, d)). ((a, b), (c, d)))"
   1.806 +  let ?pq = "map_pmf ?f (pair_pmf Rpq Spq)"
   1.807 +  have [simp]: "(\<lambda>x. fst (?f x)) = (\<lambda>(a, b). (fst a, fst b))" "(\<lambda>x. snd (?f x)) = (\<lambda>(a, b). (snd a, snd b))"
   1.808 +    by auto
   1.809 +
   1.810 +  show "rel_pmf (rel_prod R S) (pair_pmf A A') (pair_pmf B B')"
   1.811 +    by (rule rel_pmf.intros[where pq="?pq"])
   1.812 +       (auto simp: map_snd_pair_pmf map_fst_pair_pmf set_pair_pmf set_map_pmf map_pmf_comp Rpq Spq
   1.813 +                   map_pair)
   1.814 +qed
   1.815 +
   1.816 +lemma rel_pmf_reflI: 
   1.817 +  assumes "\<And>x. x \<in> set_pmf p \<Longrightarrow> P x x"
   1.818 +  shows "rel_pmf P p p"
   1.819 +by(rule rel_pmf.intros[where pq="map_pmf (\<lambda>x. (x, x)) p"])(auto simp add: pmf.map_comp o_def set_map_pmf assms)
   1.820 +
   1.821 +context
   1.822 +begin
   1.823 +
   1.824 +interpretation pmf_as_measure .
   1.825 +
   1.826 +definition "join_pmf M = bind_pmf M (\<lambda>x. x)"
   1.827 +
   1.828 +lemma bind_eq_join_pmf: "bind_pmf M f = join_pmf (map_pmf f M)"
   1.829 +  unfolding join_pmf_def bind_map_pmf ..
   1.830 +
   1.831 +lemma join_eq_bind_pmf: "join_pmf M = bind_pmf M id"
   1.832 +  by (simp add: join_pmf_def id_def)
   1.833 +
   1.834 +lemma pmf_join: "pmf (join_pmf N) i = (\<integral>M. pmf M i \<partial>measure_pmf N)"
   1.835 +  unfolding join_pmf_def pmf_bind ..
   1.836 +
   1.837 +lemma ereal_pmf_join: "ereal (pmf (join_pmf N) i) = (\<integral>\<^sup>+M. pmf M i \<partial>measure_pmf N)"
   1.838 +  unfolding join_pmf_def ereal_pmf_bind ..
   1.839 +
   1.840 +lemma set_pmf_join_pmf: "set_pmf (join_pmf f) = (\<Union>p\<in>set_pmf f. set_pmf p)"
   1.841 +  by (simp add: join_pmf_def set_bind_pmf)
   1.842 +
   1.843 +lemma join_return_pmf: "join_pmf (return_pmf M) = M"
   1.844 +  by (simp add: integral_return pmf_eq_iff pmf_join return_pmf.rep_eq)
   1.845 +
   1.846 +lemma map_join_pmf: "map_pmf f (join_pmf AA) = join_pmf (map_pmf (map_pmf f) AA)"
   1.847 +  by (simp add: join_pmf_def map_pmf_def bind_assoc_pmf bind_return_pmf)
   1.848 +
   1.849 +lemma join_map_return_pmf: "join_pmf (map_pmf return_pmf A) = A"
   1.850 +  by (simp add: join_pmf_def map_pmf_def bind_assoc_pmf bind_return_pmf bind_return_pmf')
   1.851 +
   1.852 +end
   1.853 +
   1.854 +lemma rel_pmf_joinI:
   1.855 +  assumes "rel_pmf (rel_pmf P) p q"
   1.856 +  shows "rel_pmf P (join_pmf p) (join_pmf q)"
   1.857 +proof -
   1.858 +  from assms obtain pq where p: "p = map_pmf fst pq"
   1.859 +    and q: "q = map_pmf snd pq"
   1.860 +    and P: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> rel_pmf P x y"
   1.861 +    by cases auto
   1.862 +  from P obtain PQ 
   1.863 +    where PQ: "\<And>x y a b. \<lbrakk> (x, y) \<in> set_pmf pq; (a, b) \<in> set_pmf (PQ x y) \<rbrakk> \<Longrightarrow> P a b"
   1.864 +    and x: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> map_pmf fst (PQ x y) = x"
   1.865 +    and y: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> map_pmf snd (PQ x y) = y"
   1.866 +    by(metis rel_pmf.simps)
   1.867 +
   1.868 +  let ?r = "bind_pmf pq (\<lambda>(x, y). PQ x y)"
   1.869 +  have "\<And>a b. (a, b) \<in> set_pmf ?r \<Longrightarrow> P a b" by(auto simp add: set_bind_pmf intro: PQ)
   1.870 +  moreover have "map_pmf fst ?r = join_pmf p" "map_pmf snd ?r = join_pmf q"
   1.871 +    by (simp_all add: p q x y join_pmf_def map_bind_pmf bind_map_pmf split_def cong: bind_pmf_cong)
   1.872 +  ultimately show ?thesis ..
   1.873 +qed
   1.874 +
   1.875 +lemma rel_pmf_bindI:
   1.876 +  assumes pq: "rel_pmf R p q"
   1.877 +  and fg: "\<And>x y. R x y \<Longrightarrow> rel_pmf P (f x) (g y)"
   1.878 +  shows "rel_pmf P (bind_pmf p f) (bind_pmf q g)"
   1.879 +  unfolding bind_eq_join_pmf
   1.880 +  by (rule rel_pmf_joinI)
   1.881 +     (auto simp add: pmf.rel_map intro: pmf.rel_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp, OF _ pq] fg)
   1.882 +
   1.883 +text {*
   1.884 +  Proof that @{const rel_pmf} preserves orders.
   1.885 +  Antisymmetry proof follows Thm. 1 in N. Saheb-Djahromi, Cpo's of measures for nondeterminism, 
   1.886 +  Theoretical Computer Science 12(1):19--37, 1980, 
   1.887 +  @{url "http://dx.doi.org/10.1016/0304-3975(80)90003-1"}
   1.888 +*}
   1.889 +
   1.890 +lemma 
   1.891 +  assumes *: "rel_pmf R p q"
   1.892 +  and refl: "reflp R" and trans: "transp R"
   1.893 +  shows measure_Ici: "measure p {y. R x y} \<le> measure q {y. R x y}" (is ?thesis1)
   1.894 +  and measure_Ioi: "measure p {y. R x y \<and> \<not> R y x} \<le> measure q {y. R x y \<and> \<not> R y x}" (is ?thesis2)
   1.895 +proof -
   1.896 +  from * obtain pq
   1.897 +    where pq: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> R x y"
   1.898 +    and p: "p = map_pmf fst pq"
   1.899 +    and q: "q = map_pmf snd pq"
   1.900 +    by cases auto
   1.901 +  show ?thesis1 ?thesis2 unfolding p q map_pmf_rep_eq using refl trans
   1.902 +    by(auto 4 3 simp add: measure_distr reflpD AE_measure_pmf_iff intro!: measure_pmf.finite_measure_mono_AE dest!: pq elim: transpE)
   1.903 +qed
   1.904 +
   1.905 +lemma rel_pmf_inf:
   1.906 +  fixes p q :: "'a pmf"
   1.907 +  assumes 1: "rel_pmf R p q"
   1.908 +  assumes 2: "rel_pmf R q p"
   1.909 +  and refl: "reflp R" and trans: "transp R"
   1.910 +  shows "rel_pmf (inf R R\<inverse>\<inverse>) p q"
   1.911 +proof
   1.912 +  let ?E = "\<lambda>x. {y. R x y \<and> R y x}"
   1.913 +  let ?\<mu>E = "\<lambda>x. measure q (?E x)"
   1.914 +  { fix x
   1.915 +    have "measure p (?E x) = measure p ({y. R x y} - {y. R x y \<and> \<not> R y x})"
   1.916 +      by(auto intro!: arg_cong[where f="measure p"])
   1.917 +    also have "\<dots> = measure p {y. R x y} - measure p {y. R x y \<and> \<not> R y x}"
   1.918 +      by (rule measure_pmf.finite_measure_Diff) auto
   1.919 +    also have "measure p {y. R x y \<and> \<not> R y x} = measure q {y. R x y \<and> \<not> R y x}"
   1.920 +      using 1 2 refl trans by(auto intro!: Orderings.antisym measure_Ioi)
   1.921 +    also have "measure p {y. R x y} = measure q {y. R x y}"
   1.922 +      using 1 2 refl trans by(auto intro!: Orderings.antisym measure_Ici)
   1.923 +    also have "measure q {y. R x y} - measure q {y. R x y \<and> ~ R y x} =
   1.924 +      measure q ({y. R x y} - {y. R x y \<and> \<not> R y x})"
   1.925 +      by(rule measure_pmf.finite_measure_Diff[symmetric]) auto
   1.926 +    also have "\<dots> = ?\<mu>E x"
   1.927 +      by(auto intro!: arg_cong[where f="measure q"])
   1.928 +    also note calculation }
   1.929 +  note eq = this
   1.930 +
   1.931 +  def pq \<equiv> "bind_pmf p (\<lambda>x. bind_pmf (cond_pmf q (?E x)) (\<lambda>y. return_pmf (x, y)))"
   1.932 +
   1.933 +  show "map_pmf fst pq = p"
   1.934 +    by(simp add: pq_def map_bind_pmf map_return_pmf bind_return_pmf')
   1.935 +
   1.936 +  show "map_pmf snd pq = q"
   1.937 +    unfolding pq_def map_bind_pmf map_return_pmf bind_return_pmf' snd_conv
   1.938 +    by(subst bind_cond_pmf_cancel)(auto simp add: reflpD[OF \<open>reflp R\<close>] eq  intro: transpD[OF \<open>transp R\<close>])
   1.939 +
   1.940 +  fix x y
   1.941 +  assume "(x, y) \<in> set_pmf pq"
   1.942 +  moreover
   1.943 +  { assume "x \<in> set_pmf p"
   1.944 +    hence "measure (measure_pmf p) (?E x) \<noteq> 0"
   1.945 +      by(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff intro: reflpD[OF \<open>reflp R\<close>])
   1.946 +    hence "measure (measure_pmf q) (?E x) \<noteq> 0" using eq by simp
   1.947 +    hence "set_pmf q \<inter> {y. R x y \<and> R y x} \<noteq> {}" 
   1.948 +      by(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff) }
   1.949 +  ultimately show "inf R R\<inverse>\<inverse> x y"
   1.950 +    by(auto simp add: pq_def set_bind_pmf set_return_pmf set_cond_pmf)
   1.951 +qed
   1.952 +
   1.953 +lemma rel_pmf_antisym:
   1.954 +  fixes p q :: "'a pmf"
   1.955 +  assumes 1: "rel_pmf R p q"
   1.956 +  assumes 2: "rel_pmf R q p"
   1.957 +  and refl: "reflp R" and trans: "transp R" and antisym: "antisymP R"
   1.958 +  shows "p = q"
   1.959 +proof -
   1.960 +  from 1 2 refl trans have "rel_pmf (inf R R\<inverse>\<inverse>) p q" by(rule rel_pmf_inf)
   1.961 +  also have "inf R R\<inverse>\<inverse> = op ="
   1.962 +    using refl antisym by(auto intro!: ext simp add: reflpD dest: antisymD)
   1.963 +  finally show ?thesis unfolding pmf.rel_eq .
   1.964 +qed
   1.965 +
   1.966 +lemma reflp_rel_pmf: "reflp R \<Longrightarrow> reflp (rel_pmf R)"
   1.967 +by(blast intro: reflpI rel_pmf_reflI reflpD)
   1.968 +
   1.969 +lemma antisymP_rel_pmf:
   1.970 +  "\<lbrakk> reflp R; transp R; antisymP R \<rbrakk>
   1.971 +  \<Longrightarrow> antisymP (rel_pmf R)"
   1.972 +by(rule antisymI)(blast intro: rel_pmf_antisym)
   1.973 +
   1.974 +lemma transp_rel_pmf:
   1.975 +  assumes "transp R"
   1.976 +  shows "transp (rel_pmf R)"
   1.977 +proof (rule transpI)
   1.978 +  fix x y z
   1.979 +  assume "rel_pmf R x y" and "rel_pmf R y z"
   1.980 +  hence "rel_pmf (R OO R) x z" by (simp add: pmf.rel_compp relcompp.relcompI)
   1.981 +  thus "rel_pmf R x z"
   1.982 +    using assms by (metis (no_types) pmf.rel_mono rev_predicate2D transp_relcompp_less_eq)
   1.983 +qed
   1.984 +
   1.985 +subsection \<open> Distributions \<close>
   1.986 +
   1.987  context
   1.988  begin
   1.989  
   1.990 @@ -639,755 +1361,4 @@
   1.991  lemma set_pmf_binomial[simp]: "0 < p \<Longrightarrow> p < 1 \<Longrightarrow> set_pmf (binomial_pmf n p) = {..n}"
   1.992    by (simp add: set_pmf_binomial_eq)
   1.993  
   1.994 -subsection \<open> Monad Interpretation \<close>
   1.995 -
   1.996 -lemma measurable_measure_pmf[measurable]:
   1.997 -  "(\<lambda>x. measure_pmf (M x)) \<in> measurable (count_space UNIV) (subprob_algebra (count_space UNIV))"
   1.998 -  by (auto simp: space_subprob_algebra intro!: prob_space_imp_subprob_space) unfold_locales
   1.999 -
  1.1000 -lemma bind_measure_pmf_cong:
  1.1001 -  assumes "\<And>x. A x \<in> space (subprob_algebra N)" "\<And>x. B x \<in> space (subprob_algebra N)"
  1.1002 -  assumes "\<And>i. i \<in> set_pmf x \<Longrightarrow> A i = B i"
  1.1003 -  shows "bind (measure_pmf x) A = bind (measure_pmf x) B"
  1.1004 -proof (rule measure_eqI)
  1.1005 -  show "sets (measure_pmf x \<guillemotright>= A) = sets (measure_pmf x \<guillemotright>= B)"
  1.1006 -    using assms by (subst (1 2) sets_bind) (auto simp: space_subprob_algebra)
  1.1007 -next
  1.1008 -  fix X assume "X \<in> sets (measure_pmf x \<guillemotright>= A)"
  1.1009 -  then have X: "X \<in> sets N"
  1.1010 -    using assms by (subst (asm) sets_bind) (auto simp: space_subprob_algebra)
  1.1011 -  show "emeasure (measure_pmf x \<guillemotright>= A) X = emeasure (measure_pmf x \<guillemotright>= B) X"
  1.1012 -    using assms
  1.1013 -    by (subst (1 2) emeasure_bind[where N=N, OF _ _ X])
  1.1014 -       (auto intro!: nn_integral_cong_AE simp: AE_measure_pmf_iff)
  1.1015 -qed
  1.1016 -
  1.1017 -context
  1.1018 -begin
  1.1019 -
  1.1020 -interpretation pmf_as_measure .
  1.1021 -
  1.1022 -lift_definition join_pmf :: "'a pmf pmf \<Rightarrow> 'a pmf" is "\<lambda>M. measure_pmf M \<guillemotright>= measure_pmf"
  1.1023 -proof (intro conjI)
  1.1024 -  fix M :: "'a pmf pmf"
  1.1025 -
  1.1026 -  interpret bind: prob_space "measure_pmf M \<guillemotright>= measure_pmf"
  1.1027 -    apply (intro measure_pmf.prob_space_bind[where S="count_space UNIV"] AE_I2)
  1.1028 -    apply (auto intro!: subprob_space_measure_pmf simp: space_subprob_algebra)
  1.1029 -    apply unfold_locales
  1.1030 -    done
  1.1031 -  show "prob_space (measure_pmf M \<guillemotright>= measure_pmf)"
  1.1032 -    by intro_locales
  1.1033 -  show "sets (measure_pmf M \<guillemotright>= measure_pmf) = UNIV"
  1.1034 -    by (subst sets_bind) auto
  1.1035 -  have "AE x in measure_pmf M \<guillemotright>= measure_pmf. emeasure (measure_pmf M \<guillemotright>= measure_pmf) {x} \<noteq> 0"
  1.1036 -    by (auto simp: AE_bind[where B="count_space UNIV"] measure_pmf_in_subprob_algebra
  1.1037 -                   emeasure_bind[where N="count_space UNIV"] AE_measure_pmf_iff nn_integral_0_iff_AE
  1.1038 -                   measure_pmf.emeasure_eq_measure measure_le_0_iff set_pmf_iff pmf.rep_eq)
  1.1039 -  then show "AE x in measure_pmf M \<guillemotright>= measure_pmf. measure (measure_pmf M \<guillemotright>= measure_pmf) {x} \<noteq> 0"
  1.1040 -    unfolding bind.emeasure_eq_measure by simp
  1.1041 -qed
  1.1042 -
  1.1043 -lemma pmf_join: "pmf (join_pmf N) i = (\<integral>M. pmf M i \<partial>measure_pmf N)"
  1.1044 -proof (transfer fixing: N i)
  1.1045 -  have N: "subprob_space (measure_pmf N)"
  1.1046 -    by (rule prob_space_imp_subprob_space) intro_locales
  1.1047 -  show "measure (measure_pmf N \<guillemotright>= measure_pmf) {i} = integral\<^sup>L (measure_pmf N) (\<lambda>M. measure M {i})"
  1.1048 -    using measurable_measure_pmf[of "\<lambda>x. x"]
  1.1049 -    by (intro subprob_space.measure_bind[where N="count_space UNIV", OF N]) auto
  1.1050 -qed (auto simp: Transfer.Rel_def rel_fun_def cr_pmf_def)
  1.1051 -
  1.1052 -lemma ereal_pmf_join: "ereal (pmf (join_pmf N) i) = (\<integral>\<^sup>+M. pmf M i \<partial>measure_pmf N)"
  1.1053 -  unfolding pmf_join
  1.1054 -  by (intro nn_integral_eq_integral[symmetric] measure_pmf.integrable_const_bound[where B=1])
  1.1055 -     (auto simp: pmf_le_1 pmf_nonneg)
  1.1056 -
  1.1057 -lemma set_pmf_join_pmf: "set_pmf (join_pmf f) = (\<Union>p\<in>set_pmf f. set_pmf p)"
  1.1058 -apply(simp add: set_eq_iff set_pmf_iff pmf_join)
  1.1059 -apply(subst integral_nonneg_eq_0_iff_AE)
  1.1060 -apply(auto simp add: pmf_le_1 pmf_nonneg AE_measure_pmf_iff intro!: measure_pmf.integrable_const_bound[where B=1])
  1.1061 -done
  1.1062 -
  1.1063 -lift_definition return_pmf :: "'a \<Rightarrow> 'a pmf" is "return (count_space UNIV)"
  1.1064 -  by (auto intro!: prob_space_return simp: AE_return measure_return)
  1.1065 -
  1.1066 -lemma join_return_pmf: "join_pmf (return_pmf M) = M"
  1.1067 -  by (simp add: integral_return pmf_eq_iff pmf_join return_pmf.rep_eq)
  1.1068 -
  1.1069 -lemma map_return_pmf: "map_pmf f (return_pmf x) = return_pmf (f x)"
  1.1070 -  by transfer (simp add: distr_return)
  1.1071 -
  1.1072 -lemma map_pmf_const[simp]: "map_pmf (\<lambda>_. c) M = return_pmf c"
  1.1073 -  by transfer (auto simp: prob_space.distr_const)
  1.1074 -
  1.1075 -lemma set_return_pmf: "set_pmf (return_pmf x) = {x}"
  1.1076 -  by transfer (auto simp add: measure_return split: split_indicator)
  1.1077 -
  1.1078 -lemma pmf_return: "pmf (return_pmf x) y = indicator {y} x"
  1.1079 -  by transfer (simp add: measure_return)
  1.1080 -
  1.1081 -lemma nn_integral_return_pmf[simp]: "0 \<le> f x \<Longrightarrow> (\<integral>\<^sup>+x. f x \<partial>return_pmf x) = f x"
  1.1082 -  unfolding return_pmf.rep_eq by (intro nn_integral_return) auto
  1.1083 -
  1.1084 -lemma emeasure_return_pmf[simp]: "emeasure (return_pmf x) X = indicator X x"
  1.1085 -  unfolding return_pmf.rep_eq by (intro emeasure_return) auto
  1.1086 -
  1.1087  end
  1.1088 -
  1.1089 -lemma return_pmf_inj[simp]: "return_pmf x = return_pmf y \<longleftrightarrow> x = y"
  1.1090 -  by (metis insertI1 set_return_pmf singletonD)
  1.1091 -
  1.1092 -definition "bind_pmf M f = join_pmf (map_pmf f M)"
  1.1093 -
  1.1094 -lemma (in pmf_as_measure) bind_transfer[transfer_rule]:
  1.1095 -  "rel_fun pmf_as_measure.cr_pmf (rel_fun (rel_fun op = pmf_as_measure.cr_pmf) pmf_as_measure.cr_pmf) op \<guillemotright>= bind_pmf"
  1.1096 -proof (auto simp: pmf_as_measure.cr_pmf_def rel_fun_def bind_pmf_def join_pmf.rep_eq map_pmf.rep_eq)
  1.1097 -  fix M f and g :: "'a \<Rightarrow> 'b pmf" assume "\<forall>x. f x = measure_pmf (g x)"
  1.1098 -  then have f: "f = (\<lambda>x. measure_pmf (g x))"
  1.1099 -    by auto
  1.1100 -  show "measure_pmf M \<guillemotright>= f = distr (measure_pmf M) (count_space UNIV) g \<guillemotright>= measure_pmf"
  1.1101 -    unfolding f by (subst bind_distr[OF _ measurable_measure_pmf]) auto
  1.1102 -qed
  1.1103 -
  1.1104 -lemma ereal_pmf_bind: "pmf (bind_pmf N f) i = (\<integral>\<^sup>+x. pmf (f x) i \<partial>measure_pmf N)"
  1.1105 -  by (auto intro!: nn_integral_distr simp: bind_pmf_def ereal_pmf_join map_pmf.rep_eq)
  1.1106 -
  1.1107 -lemma pmf_bind: "pmf (bind_pmf N f) i = (\<integral>x. pmf (f x) i \<partial>measure_pmf N)"
  1.1108 -  by (auto intro!: integral_distr simp: bind_pmf_def pmf_join map_pmf.rep_eq)
  1.1109 -
  1.1110 -lemma bind_return_pmf: "bind_pmf (return_pmf x) f = f x"
  1.1111 -  unfolding bind_pmf_def map_return_pmf join_return_pmf ..
  1.1112 -
  1.1113 -lemma join_eq_bind_pmf: "join_pmf M = bind_pmf M id"
  1.1114 -  by (simp add: bind_pmf_def)
  1.1115 -
  1.1116 -lemma bind_pmf_const[simp]: "bind_pmf M (\<lambda>x. c) = c"
  1.1117 -  unfolding bind_pmf_def map_pmf_const join_return_pmf ..
  1.1118 -
  1.1119 -lemma set_bind_pmf: "set_pmf (bind_pmf M N) = (\<Union>M\<in>set_pmf M. set_pmf (N M))"
  1.1120 -  apply (simp add: set_eq_iff set_pmf_iff pmf_bind)
  1.1121 -  apply (subst integral_nonneg_eq_0_iff_AE)
  1.1122 -  apply (auto simp: pmf_nonneg pmf_le_1 AE_measure_pmf_iff
  1.1123 -              intro!: measure_pmf.integrable_const_bound[where B=1])
  1.1124 -  done
  1.1125 -
  1.1126 -
  1.1127 -lemma measurable_pair_restrict_pmf2:
  1.1128 -  assumes "countable A"
  1.1129 -  assumes [measurable]: "\<And>y. y \<in> A \<Longrightarrow> (\<lambda>x. f (x, y)) \<in> measurable M L"
  1.1130 -  shows "f \<in> measurable (M \<Otimes>\<^sub>M restrict_space (measure_pmf N) A) L" (is "f \<in> measurable ?M _")
  1.1131 -proof -
  1.1132 -  have [measurable_cong]: "sets (restrict_space (count_space UNIV) A) = sets (count_space A)"
  1.1133 -    by (simp add: restrict_count_space)
  1.1134 -
  1.1135 -  show ?thesis
  1.1136 -    by (intro measurable_compose_countable'[where f="\<lambda>a b. f (fst b, a)" and g=snd and I=A,
  1.1137 -                                            unfolded pair_collapse] assms)
  1.1138 -        measurable
  1.1139 -qed
  1.1140 -
  1.1141 -lemma measurable_pair_restrict_pmf1:
  1.1142 -  assumes "countable A"
  1.1143 -  assumes [measurable]: "\<And>x. x \<in> A \<Longrightarrow> (\<lambda>y. f (x, y)) \<in> measurable N L"
  1.1144 -  shows "f \<in> measurable (restrict_space (measure_pmf M) A \<Otimes>\<^sub>M N) L"
  1.1145 -proof -
  1.1146 -  have [measurable_cong]: "sets (restrict_space (count_space UNIV) A) = sets (count_space A)"
  1.1147 -    by (simp add: restrict_count_space)
  1.1148 -
  1.1149 -  show ?thesis
  1.1150 -    by (intro measurable_compose_countable'[where f="\<lambda>a b. f (a, snd b)" and g=fst and I=A,
  1.1151 -                                            unfolded pair_collapse] assms)
  1.1152 -        measurable
  1.1153 -qed
  1.1154 -                                
  1.1155 -lemma bind_commute_pmf: "bind_pmf A (\<lambda>x. bind_pmf B (C x)) = bind_pmf B (\<lambda>y. bind_pmf A (\<lambda>x. C x y))"
  1.1156 -  unfolding pmf_eq_iff pmf_bind
  1.1157 -proof
  1.1158 -  fix i
  1.1159 -  interpret B: prob_space "restrict_space B B"
  1.1160 -    by (intro prob_space_restrict_space measure_pmf.emeasure_eq_1_AE)
  1.1161 -       (auto simp: AE_measure_pmf_iff)
  1.1162 -  interpret A: prob_space "restrict_space A A"
  1.1163 -    by (intro prob_space_restrict_space measure_pmf.emeasure_eq_1_AE)
  1.1164 -       (auto simp: AE_measure_pmf_iff)
  1.1165 -
  1.1166 -  interpret AB: pair_prob_space "restrict_space A A" "restrict_space B B"
  1.1167 -    by unfold_locales
  1.1168 -
  1.1169 -  have "(\<integral> x. \<integral> y. pmf (C x y) i \<partial>B \<partial>A) = (\<integral> x. (\<integral> y. pmf (C x y) i \<partial>restrict_space B B) \<partial>A)"
  1.1170 -    by (rule integral_cong) (auto intro!: integral_pmf_restrict)
  1.1171 -  also have "\<dots> = (\<integral> x. (\<integral> y. pmf (C x y) i \<partial>restrict_space B B) \<partial>restrict_space A A)"
  1.1172 -    by (intro integral_pmf_restrict B.borel_measurable_lebesgue_integral measurable_pair_restrict_pmf2
  1.1173 -              countable_set_pmf borel_measurable_count_space)
  1.1174 -  also have "\<dots> = (\<integral> y. \<integral> x. pmf (C x y) i \<partial>restrict_space A A \<partial>restrict_space B B)"
  1.1175 -    by (rule AB.Fubini_integral[symmetric])
  1.1176 -       (auto intro!: AB.integrable_const_bound[where B=1] measurable_pair_restrict_pmf2
  1.1177 -             simp: pmf_nonneg pmf_le_1 measurable_restrict_space1)
  1.1178 -  also have "\<dots> = (\<integral> y. \<integral> x. pmf (C x y) i \<partial>restrict_space A A \<partial>B)"
  1.1179 -    by (intro integral_pmf_restrict[symmetric] A.borel_measurable_lebesgue_integral measurable_pair_restrict_pmf2
  1.1180 -              countable_set_pmf borel_measurable_count_space)
  1.1181 -  also have "\<dots> = (\<integral> y. \<integral> x. pmf (C x y) i \<partial>A \<partial>B)"
  1.1182 -    by (rule integral_cong) (auto intro!: integral_pmf_restrict[symmetric])
  1.1183 -  finally show "(\<integral> x. \<integral> y. pmf (C x y) i \<partial>B \<partial>A) = (\<integral> y. \<integral> x. pmf (C x y) i \<partial>A \<partial>B)" .
  1.1184 -qed
  1.1185 -
  1.1186 -
  1.1187 -context
  1.1188 -begin
  1.1189 -
  1.1190 -interpretation pmf_as_measure .
  1.1191 -
  1.1192 -lemma measure_pmf_bind: "measure_pmf (bind_pmf M f) = (measure_pmf M \<guillemotright>= (\<lambda>x. measure_pmf (f x)))"
  1.1193 -  by transfer simp
  1.1194 -
  1.1195 -lemma nn_integral_bind_pmf[simp]: "(\<integral>\<^sup>+x. f x \<partial>bind_pmf M N) = (\<integral>\<^sup>+x. \<integral>\<^sup>+y. f y \<partial>N x \<partial>M)"
  1.1196 -  using measurable_measure_pmf[of N]
  1.1197 -  unfolding measure_pmf_bind
  1.1198 -  apply (subst (1 3) nn_integral_max_0[symmetric])
  1.1199 -  apply (intro nn_integral_bind[where B="count_space UNIV"])
  1.1200 -  apply auto
  1.1201 -  done
  1.1202 -
  1.1203 -lemma emeasure_bind_pmf[simp]: "emeasure (bind_pmf M N) X = (\<integral>\<^sup>+x. emeasure (N x) X \<partial>M)"
  1.1204 -  using measurable_measure_pmf[of N]
  1.1205 -  unfolding measure_pmf_bind
  1.1206 -  by (subst emeasure_bind[where N="count_space UNIV"]) auto
  1.1207 -
  1.1208 -lemma bind_return_pmf': "bind_pmf N return_pmf = N"
  1.1209 -proof (transfer, clarify)
  1.1210 -  fix N :: "'a measure" assume "sets N = UNIV" then show "N \<guillemotright>= return (count_space UNIV) = N"
  1.1211 -    by (subst return_sets_cong[where N=N]) (simp_all add: bind_return')
  1.1212 -qed
  1.1213 -
  1.1214 -lemma bind_return_pmf'': "bind_pmf N (\<lambda>x. return_pmf (f x)) = map_pmf f N"
  1.1215 -proof (transfer, clarify)
  1.1216 -  fix N :: "'b measure" and f :: "'b \<Rightarrow> 'a" assume "prob_space N" "sets N = UNIV"
  1.1217 -  then show "N \<guillemotright>= (\<lambda>x. return (count_space UNIV) (f x)) = distr N (count_space UNIV) f"
  1.1218 -    by (subst bind_return_distr[symmetric])
  1.1219 -       (auto simp: prob_space.not_empty measurable_def comp_def)
  1.1220 -qed
  1.1221 -
  1.1222 -lemma bind_assoc_pmf: "bind_pmf (bind_pmf A B) C = bind_pmf A (\<lambda>x. bind_pmf (B x) C)"
  1.1223 -  by transfer
  1.1224 -     (auto intro!: bind_assoc[where N="count_space UNIV" and R="count_space UNIV"]
  1.1225 -           simp: measurable_def space_subprob_algebra prob_space_imp_subprob_space)
  1.1226 -
  1.1227 -end
  1.1228 -
  1.1229 -lemma map_bind_pmf: "map_pmf f (bind_pmf M g) = bind_pmf M (\<lambda>x. map_pmf f (g x))"
  1.1230 -  unfolding bind_return_pmf''[symmetric] bind_assoc_pmf[of M] ..
  1.1231 -
  1.1232 -lemma bind_map_pmf: "bind_pmf (map_pmf f M) g = bind_pmf M (\<lambda>x. g (f x))"
  1.1233 -  unfolding bind_return_pmf''[symmetric] bind_assoc_pmf bind_return_pmf ..
  1.1234 -
  1.1235 -lemma map_join_pmf: "map_pmf f (join_pmf AA) = join_pmf (map_pmf (map_pmf f) AA)"
  1.1236 -  unfolding bind_pmf_def[symmetric]
  1.1237 -  unfolding bind_return_pmf''[symmetric] join_eq_bind_pmf bind_assoc_pmf
  1.1238 -  by (simp add: bind_return_pmf'')
  1.1239 -
  1.1240 -lemma bind_pmf_cong:
  1.1241 -  "\<lbrakk> p = q; \<And>x. x \<in> set_pmf q \<Longrightarrow> f x = g x \<rbrakk>
  1.1242 -  \<Longrightarrow> bind_pmf p f = bind_pmf q g"
  1.1243 -by(simp add: bind_pmf_def cong: map_pmf_cong)
  1.1244 -
  1.1245 -lemma bind_pmf_cong_simp:
  1.1246 -  "\<lbrakk> p = q; \<And>x. x \<in> set_pmf q =simp=> f x = g x \<rbrakk>
  1.1247 -  \<Longrightarrow> bind_pmf p f = bind_pmf q g"
  1.1248 -by(simp add: simp_implies_def cong: bind_pmf_cong)
  1.1249 -
  1.1250 -definition "pair_pmf A B = bind_pmf A (\<lambda>x. bind_pmf B (\<lambda>y. return_pmf (x, y)))"
  1.1251 -
  1.1252 -lemma pmf_pair: "pmf (pair_pmf M N) (a, b) = pmf M a * pmf N b"
  1.1253 -  unfolding pair_pmf_def pmf_bind pmf_return
  1.1254 -  apply (subst integral_measure_pmf[where A="{b}"])
  1.1255 -  apply (auto simp: indicator_eq_0_iff)
  1.1256 -  apply (subst integral_measure_pmf[where A="{a}"])
  1.1257 -  apply (auto simp: indicator_eq_0_iff setsum_nonneg_eq_0_iff pmf_nonneg)
  1.1258 -  done
  1.1259 -
  1.1260 -lemma set_pair_pmf: "set_pmf (pair_pmf A B) = set_pmf A \<times> set_pmf B"
  1.1261 -  unfolding pair_pmf_def set_bind_pmf set_return_pmf by auto
  1.1262 -
  1.1263 -lemma measure_pmf_in_subprob_space[measurable (raw)]:
  1.1264 -  "measure_pmf M \<in> space (subprob_algebra (count_space UNIV))"
  1.1265 -  by (simp add: space_subprob_algebra) intro_locales
  1.1266 -
  1.1267 -lemma nn_integral_pair_pmf': "(\<integral>\<^sup>+x. f x \<partial>pair_pmf A B) = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. f (a, b) \<partial>B \<partial>A)"
  1.1268 -proof -
  1.1269 -  have "(\<integral>\<^sup>+x. f x \<partial>pair_pmf A B) = (\<integral>\<^sup>+x. max 0 (f x) * indicator (A \<times> B) x \<partial>pair_pmf A B)"
  1.1270 -    by (subst nn_integral_max_0[symmetric])
  1.1271 -       (auto simp: AE_measure_pmf_iff set_pair_pmf intro!: nn_integral_cong_AE)
  1.1272 -  also have "\<dots> = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. max 0 (f (a, b)) * indicator (A \<times> B) (a, b) \<partial>B \<partial>A)"
  1.1273 -    by (simp add: pair_pmf_def)
  1.1274 -  also have "\<dots> = (\<integral>\<^sup>+a. \<integral>\<^sup>+b. max 0 (f (a, b)) \<partial>B \<partial>A)"
  1.1275 -    by (auto intro!: nn_integral_cong_AE simp: AE_measure_pmf_iff)
  1.1276 -  finally show ?thesis
  1.1277 -    unfolding nn_integral_max_0 .
  1.1278 -qed
  1.1279 -
  1.1280 -lemma pair_map_pmf1: "pair_pmf (map_pmf f A) B = map_pmf (apfst f) (pair_pmf A B)"
  1.1281 -proof (safe intro!: pmf_eqI)
  1.1282 -  fix a :: "'a" and b :: "'b"
  1.1283 -  have [simp]: "\<And>c d. indicator (apfst f -` {(a, b)}) (c, d) = indicator (f -` {a}) c * (indicator {b} d::ereal)"
  1.1284 -    by (auto split: split_indicator)
  1.1285 -
  1.1286 -  have "ereal (pmf (pair_pmf (map_pmf f A) B) (a, b)) =
  1.1287 -         ereal (pmf (map_pmf (apfst f) (pair_pmf A B)) (a, b))"
  1.1288 -    unfolding pmf_pair ereal_pmf_map
  1.1289 -    by (simp add: nn_integral_pair_pmf' max_def emeasure_pmf_single nn_integral_multc pmf_nonneg
  1.1290 -                  emeasure_map_pmf[symmetric] del: emeasure_map_pmf)
  1.1291 -  then show "pmf (pair_pmf (map_pmf f A) B) (a, b) = pmf (map_pmf (apfst f) (pair_pmf A B)) (a, b)"
  1.1292 -    by simp
  1.1293 -qed
  1.1294 -
  1.1295 -lemma pair_map_pmf2: "pair_pmf A (map_pmf f B) = map_pmf (apsnd f) (pair_pmf A B)"
  1.1296 -proof (safe intro!: pmf_eqI)
  1.1297 -  fix a :: "'a" and b :: "'b"
  1.1298 -  have [simp]: "\<And>c d. indicator (apsnd f -` {(a, b)}) (c, d) = indicator {a} c * (indicator (f -` {b}) d::ereal)"
  1.1299 -    by (auto split: split_indicator)
  1.1300 -
  1.1301 -  have "ereal (pmf (pair_pmf A (map_pmf f B)) (a, b)) =
  1.1302 -         ereal (pmf (map_pmf (apsnd f) (pair_pmf A B)) (a, b))"
  1.1303 -    unfolding pmf_pair ereal_pmf_map
  1.1304 -    by (simp add: nn_integral_pair_pmf' max_def emeasure_pmf_single nn_integral_cmult nn_integral_multc pmf_nonneg
  1.1305 -                  emeasure_map_pmf[symmetric] del: emeasure_map_pmf)
  1.1306 -  then show "pmf (pair_pmf A (map_pmf f B)) (a, b) = pmf (map_pmf (apsnd f) (pair_pmf A B)) (a, b)"
  1.1307 -    by simp
  1.1308 -qed
  1.1309 -
  1.1310 -lemma map_pair: "map_pmf (\<lambda>(a, b). (f a, g b)) (pair_pmf A B) = pair_pmf (map_pmf f A) (map_pmf g B)"
  1.1311 -  by (simp add: pair_map_pmf2 pair_map_pmf1 map_pmf_comp split_beta')
  1.1312 -
  1.1313 -lemma bind_pair_pmf:
  1.1314 -  assumes M[measurable]: "M \<in> measurable (count_space UNIV \<Otimes>\<^sub>M count_space UNIV) (subprob_algebra N)"
  1.1315 -  shows "measure_pmf (pair_pmf A B) \<guillemotright>= M = (measure_pmf A \<guillemotright>= (\<lambda>x. measure_pmf B \<guillemotright>= (\<lambda>y. M (x, y))))"
  1.1316 -    (is "?L = ?R")
  1.1317 -proof (rule measure_eqI)
  1.1318 -  have M'[measurable]: "M \<in> measurable (pair_pmf A B) (subprob_algebra N)"
  1.1319 -    using M[THEN measurable_space] by (simp_all add: space_pair_measure)
  1.1320 -
  1.1321 -  note measurable_bind[where N="count_space UNIV", measurable]
  1.1322 -  note measure_pmf_in_subprob_space[simp]
  1.1323 -
  1.1324 -  have sets_eq_N: "sets ?L = N"
  1.1325 -    by (subst sets_bind[OF sets_kernel[OF M']]) auto
  1.1326 -  show "sets ?L = sets ?R"
  1.1327 -    using measurable_space[OF M]
  1.1328 -    by (simp add: sets_eq_N space_pair_measure space_subprob_algebra)
  1.1329 -  fix X assume "X \<in> sets ?L"
  1.1330 -  then have X[measurable]: "X \<in> sets N"
  1.1331 -    unfolding sets_eq_N .
  1.1332 -  then show "emeasure ?L X = emeasure ?R X"
  1.1333 -    apply (simp add: emeasure_bind[OF _ M' X])
  1.1334 -    apply (simp add: nn_integral_bind[where B="count_space UNIV"] pair_pmf_def measure_pmf_bind[of A]
  1.1335 -      nn_integral_measure_pmf_finite set_return_pmf emeasure_nonneg pmf_return one_ereal_def[symmetric])
  1.1336 -    apply (subst emeasure_bind[OF _ _ X])
  1.1337 -    apply measurable
  1.1338 -    apply (subst emeasure_bind[OF _ _ X])
  1.1339 -    apply measurable
  1.1340 -    done
  1.1341 -qed
  1.1342 -
  1.1343 -lemma join_map_return_pmf: "join_pmf (map_pmf return_pmf A) = A"
  1.1344 -  unfolding bind_pmf_def[symmetric] bind_return_pmf' ..
  1.1345 -
  1.1346 -lemma map_fst_pair_pmf: "map_pmf fst (pair_pmf A B) = A"
  1.1347 -  by (simp add: pair_pmf_def bind_return_pmf''[symmetric] bind_assoc_pmf bind_return_pmf bind_return_pmf')
  1.1348 -
  1.1349 -lemma map_snd_pair_pmf: "map_pmf snd (pair_pmf A B) = B"
  1.1350 -  by (simp add: pair_pmf_def bind_return_pmf''[symmetric] bind_assoc_pmf bind_return_pmf bind_return_pmf')
  1.1351 -
  1.1352 -lemma nn_integral_pmf':
  1.1353 -  "inj_on f A \<Longrightarrow> (\<integral>\<^sup>+x. pmf p (f x) \<partial>count_space A) = emeasure p (f ` A)"
  1.1354 -  by (subst nn_integral_bij_count_space[where g=f and B="f`A"])
  1.1355 -     (auto simp: bij_betw_def nn_integral_pmf)
  1.1356 -
  1.1357 -lemma pmf_le_0_iff[simp]: "pmf M p \<le> 0 \<longleftrightarrow> pmf M p = 0"
  1.1358 -  using pmf_nonneg[of M p] by simp
  1.1359 -
  1.1360 -lemma min_pmf_0[simp]: "min (pmf M p) 0 = 0" "min 0 (pmf M p) = 0"
  1.1361 -  using pmf_nonneg[of M p] by simp_all
  1.1362 -
  1.1363 -lemma pmf_eq_0_set_pmf: "pmf M p = 0 \<longleftrightarrow> p \<notin> set_pmf M"
  1.1364 -  unfolding set_pmf_iff by simp
  1.1365 -
  1.1366 -lemma pmf_map_inj: "inj_on f (set_pmf M) \<Longrightarrow> x \<in> set_pmf M \<Longrightarrow> pmf (map_pmf f M) (f x) = pmf M x"
  1.1367 -  by (auto simp: pmf.rep_eq map_pmf.rep_eq measure_distr AE_measure_pmf_iff inj_onD
  1.1368 -           intro!: measure_pmf.finite_measure_eq_AE)
  1.1369 -
  1.1370 -subsection \<open> Conditional Probabilities \<close>
  1.1371 -
  1.1372 -context
  1.1373 -  fixes p :: "'a pmf" and s :: "'a set"
  1.1374 -  assumes not_empty: "set_pmf p \<inter> s \<noteq> {}"
  1.1375 -begin
  1.1376 -
  1.1377 -interpretation pmf_as_measure .
  1.1378 -
  1.1379 -lemma emeasure_measure_pmf_not_zero: "emeasure (measure_pmf p) s \<noteq> 0"
  1.1380 -proof
  1.1381 -  assume "emeasure (measure_pmf p) s = 0"
  1.1382 -  then have "AE x in measure_pmf p. x \<notin> s"
  1.1383 -    by (rule AE_I[rotated]) auto
  1.1384 -  with not_empty show False
  1.1385 -    by (auto simp: AE_measure_pmf_iff)
  1.1386 -qed
  1.1387 -
  1.1388 -lemma measure_measure_pmf_not_zero: "measure (measure_pmf p) s \<noteq> 0"
  1.1389 -  using emeasure_measure_pmf_not_zero unfolding measure_pmf.emeasure_eq_measure by simp
  1.1390 -
  1.1391 -lift_definition cond_pmf :: "'a pmf" is
  1.1392 -  "uniform_measure (measure_pmf p) s"
  1.1393 -proof (intro conjI)
  1.1394 -  show "prob_space (uniform_measure (measure_pmf p) s)"
  1.1395 -    by (intro prob_space_uniform_measure) (auto simp: emeasure_measure_pmf_not_zero)
  1.1396 -  show "AE x in uniform_measure (measure_pmf p) s. measure (uniform_measure (measure_pmf p) s) {x} \<noteq> 0"
  1.1397 -    by (simp add: emeasure_measure_pmf_not_zero measure_measure_pmf_not_zero AE_uniform_measure
  1.1398 -                  AE_measure_pmf_iff set_pmf.rep_eq)
  1.1399 -qed simp
  1.1400 -
  1.1401 -lemma pmf_cond: "pmf cond_pmf x = (if x \<in> s then pmf p x / measure p s else 0)"
  1.1402 -  by transfer (simp add: emeasure_measure_pmf_not_zero pmf.rep_eq)
  1.1403 -
  1.1404 -lemma set_cond_pmf: "set_pmf cond_pmf = set_pmf p \<inter> s"
  1.1405 -  by (auto simp add: set_pmf_iff pmf_cond measure_measure_pmf_not_zero split: split_if_asm)
  1.1406 -
  1.1407 -end
  1.1408 -
  1.1409 -lemma cond_map_pmf:
  1.1410 -  assumes "set_pmf p \<inter> f -` s \<noteq> {}"
  1.1411 -  shows "cond_pmf (map_pmf f p) s = map_pmf f (cond_pmf p (f -` s))"
  1.1412 -proof -
  1.1413 -  have *: "set_pmf (map_pmf f p) \<inter> s \<noteq> {}"
  1.1414 -    using assms by (simp add: set_map_pmf) auto
  1.1415 -  { fix x
  1.1416 -    have "ereal (pmf (map_pmf f (cond_pmf p (f -` s))) x) =
  1.1417 -      emeasure p (f -` s \<inter> f -` {x}) / emeasure p (f -` s)"
  1.1418 -      unfolding ereal_pmf_map cond_pmf.rep_eq[OF assms] by (simp add: nn_integral_uniform_measure)
  1.1419 -    also have "f -` s \<inter> f -` {x} = (if x \<in> s then f -` {x} else {})"
  1.1420 -      by auto
  1.1421 -    also have "emeasure p (if x \<in> s then f -` {x} else {}) / emeasure p (f -` s) =
  1.1422 -      ereal (pmf (cond_pmf (map_pmf f p) s) x)"
  1.1423 -      using measure_measure_pmf_not_zero[OF *]
  1.1424 -      by (simp add: pmf_cond[OF *] ereal_divide' ereal_pmf_map measure_pmf.emeasure_eq_measure[symmetric]
  1.1425 -               del: ereal_divide)
  1.1426 -    finally have "ereal (pmf (cond_pmf (map_pmf f p) s) x) = ereal (pmf (map_pmf f (cond_pmf p (f -` s))) x)"
  1.1427 -      by simp }
  1.1428 -  then show ?thesis
  1.1429 -    by (intro pmf_eqI) simp
  1.1430 -qed
  1.1431 -
  1.1432 -lemma bind_cond_pmf_cancel:
  1.1433 -  assumes in_S: "\<And>x. x \<in> set_pmf p \<Longrightarrow> x \<in> S x" "\<And>x. x \<in> set_pmf q \<Longrightarrow> x \<in> S x"
  1.1434 -  assumes S_eq: "\<And>x y. x \<in> S y \<Longrightarrow> S x = S y"
  1.1435 -  and same: "\<And>x. measure (measure_pmf p) (S x) = measure (measure_pmf q) (S x)"
  1.1436 -  shows "bind_pmf p (\<lambda>x. cond_pmf q (S x)) = q" (is "?lhs = _")
  1.1437 -proof (rule pmf_eqI)
  1.1438 -  { fix x
  1.1439 -    assume "x \<in> set_pmf p"
  1.1440 -    hence "set_pmf p \<inter> (S x) \<noteq> {}" using in_S by auto
  1.1441 -    hence "measure (measure_pmf p) (S x) \<noteq> 0"
  1.1442 -      by(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff)
  1.1443 -    with same have "measure (measure_pmf q) (S x) \<noteq> 0" by simp
  1.1444 -    hence "set_pmf q \<inter> S x \<noteq> {}"
  1.1445 -      by(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff) }
  1.1446 -  note [simp] = this
  1.1447 -
  1.1448 -  fix z
  1.1449 -  have pmf_q_z: "z \<notin> S z \<Longrightarrow> pmf q z = 0"
  1.1450 -    by(erule contrapos_np)(simp add: pmf_eq_0_set_pmf in_S)
  1.1451 -
  1.1452 -  have "ereal (pmf ?lhs z) = \<integral>\<^sup>+ x. ereal (pmf (cond_pmf q (S x)) z) \<partial>measure_pmf p"
  1.1453 -    by(simp add: ereal_pmf_bind)
  1.1454 -  also have "\<dots> = \<integral>\<^sup>+ x. ereal (pmf q z / measure p (S z)) * indicator (S z) x \<partial>measure_pmf p"
  1.1455 -    by(rule nn_integral_cong_AE)(auto simp add: AE_measure_pmf_iff pmf_cond same pmf_q_z in_S dest!: S_eq split: split_indicator)
  1.1456 -  also have "\<dots> = pmf q z" using pmf_nonneg[of q z]
  1.1457 -    by (subst nn_integral_cmult)(auto simp add: measure_nonneg measure_pmf.emeasure_eq_measure same measure_pmf.prob_eq_0 AE_measure_pmf_iff pmf_eq_0_set_pmf in_S)
  1.1458 -  finally show "pmf ?lhs z = pmf q z" by simp
  1.1459 -qed
  1.1460 -
  1.1461 -inductive rel_pmf :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a pmf \<Rightarrow> 'b pmf \<Rightarrow> bool"
  1.1462 -for R p q
  1.1463 -where
  1.1464 -  "\<lbrakk> \<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> R x y; 
  1.1465 -     map_pmf fst pq = p; map_pmf snd pq = q \<rbrakk>
  1.1466 -  \<Longrightarrow> rel_pmf R p q"
  1.1467 -
  1.1468 -bnf pmf: "'a pmf" map: map_pmf sets: set_pmf bd : "natLeq" rel: rel_pmf
  1.1469 -proof -
  1.1470 -  show "map_pmf id = id" by (rule map_pmf_id)
  1.1471 -  show "\<And>f g. map_pmf (f \<circ> g) = map_pmf f \<circ> map_pmf g" by (rule map_pmf_compose) 
  1.1472 -  show "\<And>f g::'a \<Rightarrow> 'b. \<And>p. (\<And>x. x \<in> set_pmf p \<Longrightarrow> f x = g x) \<Longrightarrow> map_pmf f p = map_pmf g p"
  1.1473 -    by (intro map_pmf_cong refl)
  1.1474 -
  1.1475 -  show "\<And>f::'a \<Rightarrow> 'b. set_pmf \<circ> map_pmf f = op ` f \<circ> set_pmf"
  1.1476 -    by (rule pmf_set_map)
  1.1477 -
  1.1478 -  { fix p :: "'s pmf"
  1.1479 -    have "(card_of (set_pmf p), card_of (UNIV :: nat set)) \<in> ordLeq"
  1.1480 -      by (rule card_of_ordLeqI[where f="to_nat_on (set_pmf p)"])
  1.1481 -         (auto intro: countable_set_pmf)
  1.1482 -    also have "(card_of (UNIV :: nat set), natLeq) \<in> ordLeq"
  1.1483 -      by (metis Field_natLeq card_of_least natLeq_Well_order)
  1.1484 -    finally show "(card_of (set_pmf p), natLeq) \<in> ordLeq" . }
  1.1485 -
  1.1486 -  show "\<And>R. rel_pmf R =
  1.1487 -         (BNF_Def.Grp {x. set_pmf x \<subseteq> {(x, y). R x y}} (map_pmf fst))\<inverse>\<inverse> OO
  1.1488 -         BNF_Def.Grp {x. set_pmf x \<subseteq> {(x, y). R x y}} (map_pmf snd)"
  1.1489 -     by (auto simp add: fun_eq_iff BNF_Def.Grp_def OO_def rel_pmf.simps)
  1.1490 -
  1.1491 -  { fix p :: "'a pmf" and f :: "'a \<Rightarrow> 'b" and g x
  1.1492 -    assume p: "\<And>z. z \<in> set_pmf p \<Longrightarrow> f z = g z"
  1.1493 -      and x: "x \<in> set_pmf p"
  1.1494 -    thus "f x = g x" by simp }
  1.1495 -
  1.1496 -  fix R :: "'a => 'b \<Rightarrow> bool" and S :: "'b \<Rightarrow> 'c \<Rightarrow> bool"
  1.1497 -  { fix p q r
  1.1498 -    assume pq: "rel_pmf R p q"
  1.1499 -      and qr:"rel_pmf S q r"
  1.1500 -    from pq obtain pq where pq: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> R x y"
  1.1501 -      and p: "p = map_pmf fst pq" and q: "q = map_pmf snd pq" by cases auto
  1.1502 -    from qr obtain qr where qr: "\<And>y z. (y, z) \<in> set_pmf qr \<Longrightarrow> S y z"
  1.1503 -      and q': "q = map_pmf fst qr" and r: "r = map_pmf snd qr" by cases auto
  1.1504 -
  1.1505 -    def pr \<equiv> "bind_pmf pq (\<lambda>(x, y). bind_pmf (cond_pmf qr {(y', z). y' = y}) (\<lambda>(y', z). return_pmf (x, z)))"
  1.1506 -    have pr_welldefined: "\<And>y. y \<in> q \<Longrightarrow> qr \<inter> {(y', z). y' = y} \<noteq> {}"
  1.1507 -      by (force simp: q' set_map_pmf)
  1.1508 -
  1.1509 -    have "rel_pmf (R OO S) p r"
  1.1510 -    proof (rule rel_pmf.intros)
  1.1511 -      fix x z assume "(x, z) \<in> pr"
  1.1512 -      then have "\<exists>y. (x, y) \<in> pq \<and> (y, z) \<in> qr"
  1.1513 -        by (auto simp: q pr_welldefined pr_def set_bind_pmf split_beta set_return_pmf set_cond_pmf set_map_pmf)
  1.1514 -      with pq qr show "(R OO S) x z"
  1.1515 -        by blast
  1.1516 -    next
  1.1517 -      have "map_pmf snd pr = map_pmf snd (bind_pmf q (\<lambda>y. cond_pmf qr {(y', z). y' = y}))"
  1.1518 -        by (simp add: pr_def q split_beta bind_map_pmf bind_return_pmf'' map_bind_pmf map_return_pmf)
  1.1519 -      then show "map_pmf snd pr = r"
  1.1520 -        unfolding r q' bind_map_pmf by (subst (asm) bind_cond_pmf_cancel) auto
  1.1521 -    qed (simp add: pr_def map_bind_pmf split_beta map_return_pmf bind_return_pmf'' p) }
  1.1522 -  then show "rel_pmf R OO rel_pmf S \<le> rel_pmf (R OO S)"
  1.1523 -    by(auto simp add: le_fun_def)
  1.1524 -qed (fact natLeq_card_order natLeq_cinfinite)+
  1.1525 -
  1.1526 -lemma rel_pmf_return_pmf1: "rel_pmf R (return_pmf x) M \<longleftrightarrow> (\<forall>a\<in>M. R x a)"
  1.1527 -proof safe
  1.1528 -  fix a assume "a \<in> M" "rel_pmf R (return_pmf x) M"
  1.1529 -  then obtain pq where *: "\<And>a b. (a, b) \<in> set_pmf pq \<Longrightarrow> R a b"
  1.1530 -    and eq: "return_pmf x = map_pmf fst pq" "M = map_pmf snd pq"
  1.1531 -    by (force elim: rel_pmf.cases)
  1.1532 -  moreover have "set_pmf (return_pmf x) = {x}"
  1.1533 -    by (simp add: set_return_pmf)
  1.1534 -  with `a \<in> M` have "(x, a) \<in> pq"
  1.1535 -    by (force simp: eq set_map_pmf)
  1.1536 -  with * show "R x a"
  1.1537 -    by auto
  1.1538 -qed (auto intro!: rel_pmf.intros[where pq="pair_pmf (return_pmf x) M"]
  1.1539 -          simp: map_fst_pair_pmf map_snd_pair_pmf set_pair_pmf set_return_pmf)
  1.1540 -
  1.1541 -lemma rel_pmf_return_pmf2: "rel_pmf R M (return_pmf x) \<longleftrightarrow> (\<forall>a\<in>M. R a x)"
  1.1542 -  by (subst pmf.rel_flip[symmetric]) (simp add: rel_pmf_return_pmf1)
  1.1543 -
  1.1544 -lemma rel_return_pmf[simp]: "rel_pmf R (return_pmf x1) (return_pmf x2) = R x1 x2"
  1.1545 -  unfolding rel_pmf_return_pmf2 set_return_pmf by simp
  1.1546 -
  1.1547 -lemma rel_pmf_False[simp]: "rel_pmf (\<lambda>x y. False) x y = False"
  1.1548 -  unfolding pmf.in_rel fun_eq_iff using set_pmf_not_empty by fastforce
  1.1549 -
  1.1550 -lemma rel_pmf_rel_prod:
  1.1551 -  "rel_pmf (rel_prod R S) (pair_pmf A A') (pair_pmf B B') \<longleftrightarrow> rel_pmf R A B \<and> rel_pmf S A' B'"
  1.1552 -proof safe
  1.1553 -  assume "rel_pmf (rel_prod R S) (pair_pmf A A') (pair_pmf B B')"
  1.1554 -  then obtain pq where pq: "\<And>a b c d. ((a, c), (b, d)) \<in> set_pmf pq \<Longrightarrow> R a b \<and> S c d"
  1.1555 -    and eq: "map_pmf fst pq = pair_pmf A A'" "map_pmf snd pq = pair_pmf B B'"
  1.1556 -    by (force elim: rel_pmf.cases)
  1.1557 -  show "rel_pmf R A B"
  1.1558 -  proof (rule rel_pmf.intros)
  1.1559 -    let ?f = "\<lambda>(a, b). (fst a, fst b)"
  1.1560 -    have [simp]: "(\<lambda>x. fst (?f x)) = fst o fst" "(\<lambda>x. snd (?f x)) = fst o snd"
  1.1561 -      by auto
  1.1562 -
  1.1563 -    show "map_pmf fst (map_pmf ?f pq) = A"
  1.1564 -      by (simp add: map_pmf_comp pmf.map_comp[symmetric] eq map_fst_pair_pmf)
  1.1565 -    show "map_pmf snd (map_pmf ?f pq) = B"
  1.1566 -      by (simp add: map_pmf_comp pmf.map_comp[symmetric] eq map_fst_pair_pmf)
  1.1567 -
  1.1568 -    fix a b assume "(a, b) \<in> set_pmf (map_pmf ?f pq)"
  1.1569 -    then obtain c d where "((a, c), (b, d)) \<in> set_pmf pq"
  1.1570 -      by (auto simp: set_map_pmf)
  1.1571 -    from pq[OF this] show "R a b" ..
  1.1572 -  qed
  1.1573 -  show "rel_pmf S A' B'"
  1.1574 -  proof (rule rel_pmf.intros)
  1.1575 -    let ?f = "\<lambda>(a, b). (snd a, snd b)"
  1.1576 -    have [simp]: "(\<lambda>x. fst (?f x)) = snd o fst" "(\<lambda>x. snd (?f x)) = snd o snd"
  1.1577 -      by auto
  1.1578 -
  1.1579 -    show "map_pmf fst (map_pmf ?f pq) = A'"
  1.1580 -      by (simp add: map_pmf_comp pmf.map_comp[symmetric] eq map_snd_pair_pmf)
  1.1581 -    show "map_pmf snd (map_pmf ?f pq) = B'"
  1.1582 -      by (simp add: map_pmf_comp pmf.map_comp[symmetric] eq map_snd_pair_pmf)
  1.1583 -
  1.1584 -    fix c d assume "(c, d) \<in> set_pmf (map_pmf ?f pq)"
  1.1585 -    then obtain a b where "((a, c), (b, d)) \<in> set_pmf pq"
  1.1586 -      by (auto simp: set_map_pmf)
  1.1587 -    from pq[OF this] show "S c d" ..
  1.1588 -  qed
  1.1589 -next
  1.1590 -  assume "rel_pmf R A B" "rel_pmf S A' B'"
  1.1591 -  then obtain Rpq Spq
  1.1592 -    where Rpq: "\<And>a b. (a, b) \<in> set_pmf Rpq \<Longrightarrow> R a b"
  1.1593 -        "map_pmf fst Rpq = A" "map_pmf snd Rpq = B"
  1.1594 -      and Spq: "\<And>a b. (a, b) \<in> set_pmf Spq \<Longrightarrow> S a b"
  1.1595 -        "map_pmf fst Spq = A'" "map_pmf snd Spq = B'"
  1.1596 -    by (force elim: rel_pmf.cases)
  1.1597 -
  1.1598 -  let ?f = "(\<lambda>((a, c), (b, d)). ((a, b), (c, d)))"
  1.1599 -  let ?pq = "map_pmf ?f (pair_pmf Rpq Spq)"
  1.1600 -  have [simp]: "(\<lambda>x. fst (?f x)) = (\<lambda>(a, b). (fst a, fst b))" "(\<lambda>x. snd (?f x)) = (\<lambda>(a, b). (snd a, snd b))"
  1.1601 -    by auto
  1.1602 -
  1.1603 -  show "rel_pmf (rel_prod R S) (pair_pmf A A') (pair_pmf B B')"
  1.1604 -    by (rule rel_pmf.intros[where pq="?pq"])
  1.1605 -       (auto simp: map_snd_pair_pmf map_fst_pair_pmf set_pair_pmf set_map_pmf map_pmf_comp Rpq Spq
  1.1606 -                   map_pair)
  1.1607 -qed
  1.1608 -
  1.1609 -lemma rel_pmf_reflI: 
  1.1610 -  assumes "\<And>x. x \<in> set_pmf p \<Longrightarrow> P x x"
  1.1611 -  shows "rel_pmf P p p"
  1.1612 -by(rule rel_pmf.intros[where pq="map_pmf (\<lambda>x. (x, x)) p"])(auto simp add: pmf.map_comp o_def set_map_pmf assms)
  1.1613 -
  1.1614 -lemma rel_pmf_joinI:
  1.1615 -  assumes "rel_pmf (rel_pmf P) p q"
  1.1616 -  shows "rel_pmf P (join_pmf p) (join_pmf q)"
  1.1617 -proof -
  1.1618 -  from assms obtain pq where p: "p = map_pmf fst pq"
  1.1619 -    and q: "q = map_pmf snd pq"
  1.1620 -    and P: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> rel_pmf P x y"
  1.1621 -    by cases auto
  1.1622 -  from P obtain PQ 
  1.1623 -    where PQ: "\<And>x y a b. \<lbrakk> (x, y) \<in> set_pmf pq; (a, b) \<in> set_pmf (PQ x y) \<rbrakk> \<Longrightarrow> P a b"
  1.1624 -    and x: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> map_pmf fst (PQ x y) = x"
  1.1625 -    and y: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> map_pmf snd (PQ x y) = y"
  1.1626 -    by(metis rel_pmf.simps)
  1.1627 -
  1.1628 -  let ?r = "bind_pmf pq (\<lambda>(x, y). PQ x y)"
  1.1629 -  have "\<And>a b. (a, b) \<in> set_pmf ?r \<Longrightarrow> P a b" by(auto simp add: set_bind_pmf intro: PQ)
  1.1630 -  moreover have "map_pmf fst ?r = join_pmf p" "map_pmf snd ?r = join_pmf q"
  1.1631 -    by(simp_all add: bind_pmf_def map_join_pmf pmf.map_comp o_def split_def p q x y cong: pmf.map_cong)
  1.1632 -  ultimately show ?thesis ..
  1.1633 -qed
  1.1634 -
  1.1635 -lemma rel_pmf_bindI:
  1.1636 -  assumes pq: "rel_pmf R p q"
  1.1637 -  and fg: "\<And>x y. R x y \<Longrightarrow> rel_pmf P (f x) (g y)"
  1.1638 -  shows "rel_pmf P (bind_pmf p f) (bind_pmf q g)"
  1.1639 -unfolding bind_pmf_def
  1.1640 -by(rule rel_pmf_joinI)(auto simp add: pmf.rel_map intro: pmf.rel_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp, OF _ pq] fg)
  1.1641 -
  1.1642 -text {*
  1.1643 -  Proof that @{const rel_pmf} preserves orders.
  1.1644 -  Antisymmetry proof follows Thm. 1 in N. Saheb-Djahromi, Cpo's of measures for nondeterminism, 
  1.1645 -  Theoretical Computer Science 12(1):19--37, 1980, 
  1.1646 -  @{url "http://dx.doi.org/10.1016/0304-3975(80)90003-1"}
  1.1647 -*}
  1.1648 -
  1.1649 -lemma 
  1.1650 -  assumes *: "rel_pmf R p q"
  1.1651 -  and refl: "reflp R" and trans: "transp R"
  1.1652 -  shows measure_Ici: "measure p {y. R x y} \<le> measure q {y. R x y}" (is ?thesis1)
  1.1653 -  and measure_Ioi: "measure p {y. R x y \<and> \<not> R y x} \<le> measure q {y. R x y \<and> \<not> R y x}" (is ?thesis2)
  1.1654 -proof -
  1.1655 -  from * obtain pq
  1.1656 -    where pq: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> R x y"
  1.1657 -    and p: "p = map_pmf fst pq"
  1.1658 -    and q: "q = map_pmf snd pq"
  1.1659 -    by cases auto
  1.1660 -  show ?thesis1 ?thesis2 unfolding p q map_pmf.rep_eq using refl trans
  1.1661 -    by(auto 4 3 simp add: measure_distr reflpD AE_measure_pmf_iff intro!: measure_pmf.finite_measure_mono_AE dest!: pq elim: transpE)
  1.1662 -qed
  1.1663 -
  1.1664 -lemma rel_pmf_inf:
  1.1665 -  fixes p q :: "'a pmf"
  1.1666 -  assumes 1: "rel_pmf R p q"
  1.1667 -  assumes 2: "rel_pmf R q p"
  1.1668 -  and refl: "reflp R" and trans: "transp R"
  1.1669 -  shows "rel_pmf (inf R R\<inverse>\<inverse>) p q"
  1.1670 -proof
  1.1671 -  let ?E = "\<lambda>x. {y. R x y \<and> R y x}"
  1.1672 -  let ?\<mu>E = "\<lambda>x. measure q (?E x)"
  1.1673 -  { fix x
  1.1674 -    have "measure p (?E x) = measure p ({y. R x y} - {y. R x y \<and> \<not> R y x})"
  1.1675 -      by(auto intro!: arg_cong[where f="measure p"])
  1.1676 -    also have "\<dots> = measure p {y. R x y} - measure p {y. R x y \<and> \<not> R y x}"
  1.1677 -      by (rule measure_pmf.finite_measure_Diff) auto
  1.1678 -    also have "measure p {y. R x y \<and> \<not> R y x} = measure q {y. R x y \<and> \<not> R y x}"
  1.1679 -      using 1 2 refl trans by(auto intro!: Orderings.antisym measure_Ioi)
  1.1680 -    also have "measure p {y. R x y} = measure q {y. R x y}"
  1.1681 -      using 1 2 refl trans by(auto intro!: Orderings.antisym measure_Ici)
  1.1682 -    also have "measure q {y. R x y} - measure q {y. R x y \<and> ~ R y x} =
  1.1683 -      measure q ({y. R x y} - {y. R x y \<and> \<not> R y x})"
  1.1684 -      by(rule measure_pmf.finite_measure_Diff[symmetric]) auto
  1.1685 -    also have "\<dots> = ?\<mu>E x"
  1.1686 -      by(auto intro!: arg_cong[where f="measure q"])
  1.1687 -    also note calculation }
  1.1688 -  note eq = this
  1.1689 -
  1.1690 -  def pq \<equiv> "bind_pmf p (\<lambda>x. bind_pmf (cond_pmf q (?E x)) (\<lambda>y. return_pmf (x, y)))"
  1.1691 -
  1.1692 -  show "map_pmf fst pq = p"
  1.1693 -    by(simp add: pq_def map_bind_pmf map_return_pmf bind_return_pmf')
  1.1694 -
  1.1695 -  show "map_pmf snd pq = q"
  1.1696 -    unfolding pq_def map_bind_pmf map_return_pmf bind_return_pmf' snd_conv
  1.1697 -    by(subst bind_cond_pmf_cancel)(auto simp add: reflpD[OF \<open>reflp R\<close>] eq  intro: transpD[OF \<open>transp R\<close>])
  1.1698 -
  1.1699 -  fix x y
  1.1700 -  assume "(x, y) \<in> set_pmf pq"
  1.1701 -  moreover
  1.1702 -  { assume "x \<in> set_pmf p"
  1.1703 -    hence "measure (measure_pmf p) (?E x) \<noteq> 0"
  1.1704 -      by(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff intro: reflpD[OF \<open>reflp R\<close>])
  1.1705 -    hence "measure (measure_pmf q) (?E x) \<noteq> 0" using eq by simp
  1.1706 -    hence "set_pmf q \<inter> {y. R x y \<and> R y x} \<noteq> {}" 
  1.1707 -      by(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff) }
  1.1708 -  ultimately show "inf R R\<inverse>\<inverse> x y"
  1.1709 -    by(auto simp add: pq_def set_bind_pmf set_return_pmf set_cond_pmf)
  1.1710 -qed
  1.1711 -
  1.1712 -lemma rel_pmf_antisym:
  1.1713 -  fixes p q :: "'a pmf"
  1.1714 -  assumes 1: "rel_pmf R p q"
  1.1715 -  assumes 2: "rel_pmf R q p"
  1.1716 -  and refl: "reflp R" and trans: "transp R" and antisym: "antisymP R"
  1.1717 -  shows "p = q"
  1.1718 -proof -
  1.1719 -  from 1 2 refl trans have "rel_pmf (inf R R\<inverse>\<inverse>) p q" by(rule rel_pmf_inf)
  1.1720 -  also have "inf R R\<inverse>\<inverse> = op ="
  1.1721 -    using refl antisym by(auto intro!: ext simp add: reflpD dest: antisymD)
  1.1722 -  finally show ?thesis unfolding pmf.rel_eq .
  1.1723 -qed
  1.1724 -
  1.1725 -lemma reflp_rel_pmf: "reflp R \<Longrightarrow> reflp (rel_pmf R)"
  1.1726 -by(blast intro: reflpI rel_pmf_reflI reflpD)
  1.1727 -
  1.1728 -lemma antisymP_rel_pmf:
  1.1729 -  "\<lbrakk> reflp R; transp R; antisymP R \<rbrakk>
  1.1730 -  \<Longrightarrow> antisymP (rel_pmf R)"
  1.1731 -by(rule antisymI)(blast intro: rel_pmf_antisym)
  1.1732 -
  1.1733 -lemma transp_rel_pmf:
  1.1734 -  assumes "transp R"
  1.1735 -  shows "transp (rel_pmf R)"
  1.1736 -proof (rule transpI)
  1.1737 -  fix x y z
  1.1738 -  assume "rel_pmf R x y" and "rel_pmf R y z"
  1.1739 -  hence "rel_pmf (R OO R) x z" by (simp add: pmf.rel_compp relcompp.relcompI)
  1.1740 -  thus "rel_pmf R x z"
  1.1741 -    using assms by (metis (no_types) pmf.rel_mono rev_predicate2D transp_relcompp_less_eq)
  1.1742 -qed
  1.1743 -
  1.1744 -end
  1.1745 -