src/HOL/Probability/Giry_Monad.thy
changeset 78517 28c1f4f5335f
parent 75607 3c544d64c218
child 81995 d67dadd69d07
--- a/src/HOL/Probability/Giry_Monad.thy	Sat Aug 12 10:09:29 2023 +0100
+++ b/src/HOL/Probability/Giry_Monad.thy	Mon Aug 21 18:38:25 2023 +0100
@@ -31,7 +31,7 @@
 qed
 
 lemma (in subprob_space) emeasure_subprob_space_less_top: "emeasure M A \<noteq> top"
-  using emeasure_finite[of A] .
+  by simp
 
 lemma prob_space_imp_subprob_space:
   "prob_space M \<Longrightarrow> subprob_space M"
@@ -44,10 +44,10 @@
   by (rule subprob_spaceI) (simp_all add: emeasure_space_1 not_empty)
 
 lemma subprob_space_sigma [simp]: "\<Omega> \<noteq> {} \<Longrightarrow> subprob_space (sigma \<Omega> X)"
-by(rule subprob_spaceI)(simp_all add: emeasure_sigma space_measure_of_conv)
+  by(rule subprob_spaceI)(simp_all add: emeasure_sigma space_measure_of_conv)
 
 lemma subprob_space_null_measure: "space M \<noteq> {} \<Longrightarrow> subprob_space (null_measure M)"
-by(simp add: null_measure_def)
+  by(simp add: null_measure_def)
 
 lemma (in subprob_space) subprob_space_distr:
   assumes f: "f \<in> measurable M M'" and "space M' \<noteq> {}" shows "subprob_space (distr M M' f)"
@@ -343,7 +343,7 @@
   assumes [measurable]: "f \<in> measurable M N"
   shows "(\<lambda>M'. distr M' N f) \<in> measurable (subprob_algebra M) (subprob_algebra N)"
 proof (cases "space N = {}")
-  assume not_empty: "space N \<noteq> {}"
+  case False
   show ?thesis
   proof (rule measurable_subprob_algebra)
     fix A assume A: "A \<in> sets N"
@@ -355,8 +355,8 @@
     also have "\<dots>"
       using A by (intro measurable_emeasure_subprob_algebra) simp
     finally show "(\<lambda>M'. emeasure (distr M' N f) A) \<in> borel_measurable (subprob_algebra M)" .
-  qed (auto intro!: subprob_space.subprob_space_distr simp: space_subprob_algebra not_empty cong: measurable_cong_sets)
-qed (insert assms, auto simp: measurable_empty_iff space_subprob_algebra_empty_iff)
+  qed (auto intro!: subprob_space.subprob_space_distr simp: space_subprob_algebra False cong: measurable_cong_sets)
+qed (use assms in \<open>auto simp: measurable_empty_iff space_subprob_algebra_empty_iff\<close>)
 
 lemma emeasure_space_subprob_algebra[measurable]:
   "(\<lambda>a. emeasure a (space a)) \<in> borel_measurable (subprob_algebra N)"
@@ -565,10 +565,7 @@
 
 lemma subprob_space_return_ne:
   assumes "space M \<noteq> {}" shows "subprob_space (return M x)"
-proof
-  show "emeasure (return M x) (space (return M x)) \<le> 1"
-    by (subst emeasure_return) (auto split: split_indicator)
-qed (simp, fact)
+  by (metis assms emeasure_return indicator_simps(2) sets.top space_return subprob_spaceI subprob_space_return zero_le)
 
 lemma measure_return: assumes X: "X \<in> sets M" shows "measure (return M x) X = indicator X x"
   unfolding measure_def emeasure_return[OF X, of x] by (simp split: split_indicator)
@@ -654,8 +651,7 @@
         by (auto simp add: emeasure_distr f_M' cong: measurable_cong_sets)
       also have "\<dots> = (\<integral>\<^sup>+M''. emeasure (g x) (f M'' -` A \<inter> space M) \<partial>?R)"
         apply (subst emeasure_pair_measure_alt)
-        apply (rule measurable_sets[OF _ A])
-        apply (auto simp add: f_M' cong: measurable_cong_sets)
+        apply (force simp add: f_M' cong: measurable_cong_sets intro!: measurable_sets[OF _ A])
         apply (intro nn_integral_cong arg_cong[where f="emeasure (g x)"])
         apply (auto simp: space_subprob_algebra space_pair_measure)
         done
@@ -668,10 +664,10 @@
     qed
   qed
   also have "\<dots>"
-    apply (intro measurable_compose[OF measurable_pair_measure measurable_distr])
-    apply (rule return_measurable)
-    apply measurable
-    done
+  proof (intro measurable_compose[OF measurable_pair_measure measurable_distr])
+    show "return L \<in> L \<rightarrow>\<^sub>M subprob_algebra L"
+      by (rule return_measurable)
+  qed measurable
   finally show ?thesis .
 qed
 
@@ -703,19 +699,15 @@
 
   have *: "\<And>x. fst x \<in> space M \<Longrightarrow> snd x \<in> A (fst x) \<longleftrightarrow> x \<in> (SIGMA x:space M. A x)"
     by (auto simp: fun_eq_iff)
-  have "(\<lambda>(x, y). indicator (A x) y::ennreal) \<in> borel_measurable (M \<Otimes>\<^sub>M N)"
+  have MN: "Measurable.pred (M \<Otimes>\<^sub>M N) (\<lambda>w. w \<in> Sigma (space M) A)"
+    by auto
+  then have "(\<lambda>(x, y). indicator (A x) y::ennreal) \<in> borel_measurable (M \<Otimes>\<^sub>M N)"
     apply measurable
-    apply (subst measurable_cong)
-    apply (rule *)
-    apply (auto simp: space_pair_measure)
-    done
+    by (smt (verit, best) MN measurable_cong mem_Sigma_iff prod.collapse space_pair_measure)
   then have "(\<lambda>x. integral\<^sup>N (L x) (indicator (A x))) \<in> borel_measurable M"
     by (intro nn_integral_measurable_subprob_algebra2[where N=N] L)
   then show "(\<lambda>x. emeasure (L x) (A x)) \<in> borel_measurable M"
-    apply (rule measurable_cong[THEN iffD1, rotated])
-    apply (rule nn_integral_indicator)
-    apply (simp add: subprob_measurableD[OF L] **)
-    done
+    by (smt (verit) "**" L measurable_cong_simp nn_integral_indicator sets_kernel)
 qed
 
 lemma measure_measurable_subprob_algebra2:
@@ -751,7 +743,7 @@
   next
     assume "space (subprob_algebra N) \<noteq> {}"
     with eq show ?thesis
-      by (fastforce simp add: space_subprob_algebra)
+      by (smt (verit) equals0I mem_Collect_eq space_subprob_algebra)
   qed
 qed
 
@@ -807,10 +799,11 @@
   assume [simp]: "space N \<noteq> {}"
   fix M assume M: "M \<in> space (subprob_algebra (subprob_algebra N))"
   then have "(\<integral>\<^sup>+M'. emeasure M' (space N) \<partial>M) \<le> (\<integral>\<^sup>+M'. 1 \<partial>M)"
-    apply (intro nn_integral_mono)
-    apply (auto simp: space_subprob_algebra
-                 dest!: sets_eq_imp_space_eq subprob_space.emeasure_space_le_1)
-    done
+  proof (intro nn_integral_mono)
+    show "\<And>x. \<lbrakk>M \<in> space (subprob_algebra (subprob_algebra N)); x \<in> space M\<rbrakk>
+         \<Longrightarrow> emeasure x (space N) \<le> 1"
+      by (smt (verit) mem_Collect_eq sets_eq_imp_space_eq space_subprob_algebra subprob_space.subprob_emeasure_le_1)
+  qed
   with M show "subprob_space (join M)"
     by (intro subprob_spaceI)
        (auto simp: emeasure_join space_subprob_algebra M dest: subprob_space.emeasure_space_le_1)
@@ -888,7 +881,7 @@
 lemma measurable_join1:
   "\<lbrakk> f \<in> measurable N K; sets M = sets (subprob_algebra N) \<rbrakk>
   \<Longrightarrow> f \<in> measurable (join M) K"
-by(simp add: measurable_def)
+  by(simp add: measurable_def)
 
 lemma
   fixes f :: "_ \<Rightarrow> real"
@@ -1051,12 +1044,15 @@
 lemma join_return':
   assumes "sets N = sets M"
   shows "join (distr M (subprob_algebra N) (return N)) = M"
-apply (rule measure_eqI)
-apply (simp add: assms)
-apply (subgoal_tac "return N \<in> measurable M (subprob_algebra N)")
-apply (simp add: emeasure_join nn_integral_distr measurable_emeasure_subprob_algebra assms)
-apply (subst measurable_cong_sets, rule assms[symmetric], rule refl, rule return_measurable)
-done
+proof (rule measure_eqI)
+  fix A
+  have "return N \<in> measurable M (subprob_algebra N)"
+    using assms by auto
+  moreover
+  assume "A \<in> sets (join (distr M (subprob_algebra N) (return N)))"
+  ultimately show "emeasure (join (distr M (subprob_algebra N) (return N))) A = emeasure M A"
+    by (simp add: emeasure_join nn_integral_distr measurable_emeasure_subprob_algebra assms)
+qed (simp add: assms)
 
 lemma join_distr_distr:
   fixes f :: "'a \<Rightarrow> 'b" and M :: "'a measure measure" and N :: "'b measure"
@@ -1107,7 +1103,7 @@
   by (simp add: bind_def)
 
 lemma sets_bind_empty: "sets M = {} \<Longrightarrow> sets (bind M f) = {{}}"
-  by (auto simp: bind_def)
+  by auto
 
 lemma space_bind_empty: "space M = {} \<Longrightarrow> space (bind M f) = {}"
   by (simp add: bind_def)
@@ -1139,11 +1135,12 @@
 lemma bind_nonempty':
   assumes "f \<in> measurable M (subprob_algebra N)" "x \<in> space M"
   shows "bind M f = join (distr M (subprob_algebra N) f)"
-  using assms
-  apply (subst bind_nonempty, blast)
-  apply (subst subprob_algebra_cong[OF sets_kernel[OF assms(1) someI_ex]], blast)
-  apply (simp add: subprob_algebra_cong[OF sets_kernel[OF assms]])
-  done
+proof -
+  have "join (distr M (subprob_algebra (f (SOME x. x \<in> space M))) f) = join (distr M (subprob_algebra N) f)"
+    by (metis assms someI_ex subprob_algebra_cong subprob_measurableD(2))
+  with assms show ?thesis
+    by (metis bind_nonempty empty_iff)
+qed
 
 lemma bind_nonempty'':
   assumes "f \<in> measurable M (subprob_algebra N)" "space M \<noteq> {}"
@@ -1182,14 +1179,15 @@
   have "(AE x in M \<bind> N. P x) \<longleftrightarrow> (\<integral>\<^sup>+ x. integral\<^sup>N (N x) (indicator {x \<in> space B. \<not> P x}) \<partial>M) = 0"
     by (simp add: AE_iff_nn_integral sets_bind[OF _ M] space_bind[OF _ M] * nn_integral_bind[where B=B]
              del: nn_integral_indicator)
-  also have "\<dots> = (AE x in M. AE y in N x. P y)"
-    apply (subst nn_integral_0_iff_AE)
+  also have "... = (AE x in M. integral\<^sup>N (N x) (indicator {x \<in> space B. \<not> P x}) = 0)"
+  proof (rule nn_integral_0_iff_AE)
+    show "(\<lambda>x. integral\<^sup>N (N x) (indicator {x \<in> space B. \<not> P x})) \<in> borel_measurable M"
     apply (rule measurable_compose[OF N nn_integral_measurable_subprob_algebra])
-    apply measurable
+      by measurable
+  qed
+  also have "\<dots> = (AE x in M. AE y in N x. P y)"
     apply (intro eventually_subst AE_I2)
-    apply (auto simp add: subprob_measurableD(1)[OF N]
-                intro!: AE_iff_measurable[symmetric])
-    done
+    by (auto simp add: subprob_measurableD(1)[OF N] intro!: AE_iff_measurable[symmetric])
   finally show ?thesis .
 qed
 
@@ -1351,13 +1349,14 @@
   assumes N: "N \<in> measurable M (subprob_algebra K)" "space M \<noteq> {}"
   assumes f: "f \<in> measurable K R"
   shows "distr (M \<bind> N) R f = (M \<bind> (\<lambda>x. distr (N x) R f))"
-  unfolding bind_nonempty''[OF N]
-  apply (subst bind_nonempty''[OF measurable_compose[OF N(1) measurable_distr] N(2)])
-  apply (rule f)
-  apply (simp add: join_distr_distr[OF _ f, symmetric])
-  apply (subst distr_distr[OF measurable_distr, OF f N(1)])
-  apply (simp add: comp_def)
-  done
+proof -
+  have "distr (join (distr M (subprob_algebra K) N)) R f =
+       join (distr M (subprob_algebra R) (\<lambda>x. distr (N x) R f))"
+    by (simp add: assms distr_distr[OF measurable_distr] comp_def flip: join_distr_distr)
+  with assms show ?thesis
+    unfolding bind_nonempty''[OF N]
+    by (smt (verit) bind_nonempty sets_distr subprob_algebra_cong)
+qed
 
 lemma bind_distr:
   assumes f[measurable]: "f \<in> measurable M X"
@@ -1393,16 +1392,20 @@
   show "sets (restrict_space (bind M N) X) = sets (bind M (\<lambda>x. restrict_space (N x) X))"
     by (simp add: sets_restrict_space assms(2) sets_bind[OF sets_kernel[OF restrict_space_measurable[OF assms(4,3,1)]]])
   fix A assume "A \<in> sets (restrict_space (M \<bind> N) X)"
-  with X have "A \<in> sets K" "A \<subseteq> X"
+  with X have A: "A \<in> sets K" "A \<subseteq> X"
     by (auto simp: sets_restrict_space)
-  then show "emeasure (restrict_space (M \<bind> N) X) A = emeasure (M \<bind> (\<lambda>x. restrict_space (N x) X)) A"
-    using assms
-    apply (subst emeasure_restrict_space)
-    apply (simp_all add: emeasure_bind[OF assms(2,1)])
-    apply (subst emeasure_bind[OF _ restrict_space_measurable[OF _ _ N]])
-    apply (auto simp: sets_restrict_space emeasure_restrict_space space_subprob_algebra
-                intro!: nn_integral_cong dest!: measurable_space)
+  then have "emeasure (restrict_space (M \<bind> N) X) A = emeasure (M \<bind> N) A"
+    by (simp add: emeasure_restrict_space)
+  also have "\<dots> = \<integral>\<^sup>+ x. emeasure (N x) A \<partial>M"
+    by (metis \<open>A \<in> sets K\<close> N \<open>space M \<noteq> {}\<close> emeasure_bind)
+  also have "... = \<integral>\<^sup>+ x. emeasure (restrict_space (N x) X) A \<partial>M"
+    using A assms by (smt (verit, best) emeasure_restrict_space nn_integral_cong sets.Int_space_eq2 subprob_measurableD(2))
+  also have "\<dots> = emeasure (M \<bind> (\<lambda>x. restrict_space (N x) X)) A"
+    using A assms
+    apply (subst emeasure_bind[OF _ restrict_space_measurable])
+    apply (auto simp: sets_restrict_space)
     done
+  finally show "emeasure (restrict_space (M \<bind> N) X) A = emeasure (M \<bind> (\<lambda>x. restrict_space (N x) X)) A" .
 qed
 
 lemma bind_restrict_space:
@@ -1442,13 +1445,18 @@
      (simp_all add: space_subprob_algebra prob_space.not_empty emeasure_bind_const_prob_space)
 
 lemma bind_return_distr:
-    "space M \<noteq> {} \<Longrightarrow> f \<in> measurable M N \<Longrightarrow> bind M (return N \<circ> f) = distr M N f"
-  apply (simp add: bind_nonempty)
-  apply (subst subprob_algebra_cong)
-  apply (rule sets_return)
-  apply (subst distr_distr[symmetric])
-  apply (auto intro!: return_measurable simp: distr_distr[symmetric] join_return')
-  done
+  assumes "space M \<noteq> {}" "f \<in> measurable M N"
+  shows "bind M (return N \<circ> f) = distr M N f"
+proof -
+  have "bind M (return N \<circ> f)
+      = join (distr M (subprob_algebra (return N (f (SOME x. x \<in> space M)))) (return N \<circ> f))"
+    by (simp add: Giry_Monad.bind_def assms)
+  also have "\<dots> = join (distr M (subprob_algebra N) (return N \<circ> f))"
+    by (metis sets_return subprob_algebra_cong)
+  also have "\<dots> = distr M N f"
+    by (metis assms(2) distr_distr join_return' return_measurable sets_distr)
+  finally show ?thesis .
+qed
 
 lemma bind_return_distr':
   "space M \<noteq> {} \<Longrightarrow> f \<in> measurable M N \<Longrightarrow> bind M (\<lambda>x. return N (f x)) = distr M N f"
@@ -1469,6 +1477,9 @@
                          sets_kernel[OF M2 someI_ex[OF ex_in[OF \<open>space N \<noteq> {}\<close>]]]
   note space_some[simp] = sets_eq_imp_space_eq[OF this(1)] sets_eq_imp_space_eq[OF this(2)]
 
+
+  have *: "(\<lambda>x. distr x (subprob_algebra R) g) \<circ> f \<in> M \<rightarrow>\<^sub>M subprob_algebra (subprob_algebra R)"
+    using M1 M2 measurable_comp measurable_distr by blast
   have "bind M (\<lambda>x. bind (f x) g) =
         join (distr M (subprob_algebra R) (join \<circ> (\<lambda>x. (distr x (subprob_algebra R) g)) \<circ> f))"
     by (simp add: sets_eq_imp_space_eq[OF sets_fx] bind_nonempty o_def
@@ -1478,10 +1489,7 @@
                           (subprob_algebra (subprob_algebra R))
                           (\<lambda>x. distr x (subprob_algebra R) g))
                    (subprob_algebra R) join"
-      apply (subst distr_distr,
-             (blast intro: measurable_comp measurable_distr measurable_join M1 M2)+)+
-      apply (simp add: o_assoc)
-      done
+    by (simp add: distr_distr M1 M2 measurable_distr measurable_join fun.map_comp *)
   also have "join ... = bind (bind M f) g"
       by (simp add: join_assoc join_distr_distr M2 bind_nonempty cong: subprob_algebra_cong)
   finally show ?thesis ..
@@ -1637,7 +1645,7 @@
       using measurable_space[OF g]
     by (auto simp: measurable_restrict_space2_iff prob_algebra_def space_pair_measure Pi_iff
                 intro!: prob_space.prob_space_bind[where S=R] AE_I2)
-qed (insert g, simp)
+qed (use g in simp)
 
 
 lemma measurable_prob_algebra_generated:
@@ -1659,7 +1667,7 @@
       by (intro measurable_cong) auto
     then show "(\<lambda>a. emeasure (K a) \<Omega>) \<in> borel_measurable M" by simp
   qed
-qed (insert subsp, auto)
+qed (use subsp in auto)
 
 lemma in_space_prob_algebra:
   "x \<in> space (prob_algebra M) \<Longrightarrow> emeasure x (space M) = 1"
@@ -1668,13 +1676,7 @@
 
 lemma prob_space_pair:
   assumes "prob_space M" "prob_space N" shows "prob_space (M \<Otimes>\<^sub>M N)"
-proof -
-  interpret M: prob_space M by fact
-  interpret N: prob_space N by fact
-  interpret P: pair_prob_space M N proof qed
-  show ?thesis
-    by unfold_locales
-qed
+  by (metis assms measurable_fst prob_space.distr_pair_fst prob_space_distrD)
 
 lemma measurable_pair_prob[measurable]:
   "f \<in> M \<rightarrow>\<^sub>M prob_algebra N \<Longrightarrow> g \<in> M \<rightarrow>\<^sub>M prob_algebra L \<Longrightarrow> (\<lambda>x. f x \<Otimes>\<^sub>M g x) \<in> M \<rightarrow>\<^sub>M prob_algebra (N \<Otimes>\<^sub>M L)"
@@ -1738,7 +1740,7 @@
   also from assms(3) x have "... = emeasure (distr (density M f') (count_space A) g) {x}"
     by (subst emeasure_distr) simp_all
   finally show "f x = emeasure (distr (density M f') (count_space A) g) {x}" .
-qed (insert assms, auto)
+qed (use assms in auto)
 
 lemma bind_cong_AE:
   assumes "M = N"
@@ -1796,7 +1798,6 @@
     by eventually_elim auto
   thus "y \<in> space M"
     by simp
-
   show "M = return M y"
   proof (rule measure_eqI)
     fix X assume X: "X \<in> sets M"