src/HOL/Probability/SPMF.thy
author hoelzl
Fri Sep 16 13:56:51 2016 +0200 (2016-09-16)
changeset 63886 685fb01256af
parent 63626 44ce6b524ff3
child 64240 eabf80376aab
permissions -rw-r--r--
move Henstock-Kurzweil integration after Lebesgue_Measure; replace content by abbreviation measure lborel
     1 (* Author: Andreas Lochbihler, ETH Zurich *)
     2 
     3 section \<open>Discrete subprobability distribution\<close>
     4 
     5 theory SPMF imports
     6   Probability_Mass_Function
     7   "~~/src/HOL/Library/Complete_Partial_Order2"
     8   "~~/src/HOL/Library/Rewrite"
     9 begin
    10 
    11 subsection \<open>Auxiliary material\<close>
    12 
    13 lemma cSUP_singleton [simp]: "(SUP x:{x}. f x :: _ :: conditionally_complete_lattice) = f x" (* SUP over a singleton is its only value; no boundedness premise is needed *)
    14 by (metis cSup_singleton image_empty image_insert)
    15 
    16 subsubsection \<open>More about extended reals\<close>
    17 
    18 lemma [simp]: (* a max with 0 is absorbed by ennreal, which already sends negative reals to 0 (cf. ennreal_lt_0) *)
    19   shows ennreal_max_0: "ennreal (max 0 x) = ennreal x"
    20   and ennreal_max_0': "ennreal (max x 0) = ennreal x"
    21 by(simp_all add: max_def ennreal_eq_0_iff)
    22 
    23 lemma ennreal_enn2real_if: "ennreal (enn2real r) = (if r = \<top> then 0 else r)" (* the round trip through real is the identity except at top, which becomes 0 *)
    24 by(auto intro!: ennreal_enn2real simp add: less_top)
    25 
    26 lemma e2ennreal_0 [simp]: "e2ennreal 0 = 0" (* e2ennreal maps 0 to 0 *)
    27 by(simp add: zero_ennreal_def)
    28 
    29 lemma enn2real_bot [simp]: "enn2real \<bottom> = 0" (* the bottom element converts to the real 0 *)
    30 by(simp add: bot_ennreal_def)
    31 
    32 lemma continuous_at_ennreal[continuous_intros]: "continuous F f \<Longrightarrow> continuous F (\<lambda>x. ennreal (f x))" (* post-composing with ennreal preserves continuity; registered for continuous_intros *)
    33   unfolding continuous_def by auto
    34 
    35 lemma ennreal_Sup: (* ennreal commutes with Sup on a nonempty set, provided the ennreal-side SUP is finite *)
    36   assumes *: "(SUP a:A. ennreal a) \<noteq> \<top>"
    37   and "A \<noteq> {}"
    38   shows "ennreal (Sup A) = (SUP a:A. ennreal a)"
    39 proof (rule continuous_at_Sup_mono) (* ennreal is monotone and continuous; it remains to show A is bounded above *)
    40   obtain r where r: "ennreal r = (SUP a:A. ennreal a)" "r \<ge> 0"
    41     using * by(cases "(SUP a:A. ennreal a)") simp_all
    42   then show "bdd_above A" (* the finite SUP value r bounds A from above *)
    43     by(auto intro!: SUP_upper bdd_aboveI[of _ r] simp add: ennreal_le_iff[symmetric])
    44 qed (auto simp: mono_def continuous_at_imp_continuous_at_within continuous_at_ennreal ennreal_leI assms)
    45 
    46 lemma ennreal_SUP: (* indexed (image) form of ennreal_Sup *)
    47   "\<lbrakk> (SUP a:A. ennreal (f a)) \<noteq> \<top>; A \<noteq> {} \<rbrakk> \<Longrightarrow> ennreal (SUP a:A. f a) = (SUP a:A. ennreal (f a))"
    48 using ennreal_Sup[of "f ` A"] by auto
    49 
    50 lemma ennreal_lt_0: "x < 0 \<Longrightarrow> ennreal x = 0" (* negative reals are sent to 0 *)
    51 by(simp add: ennreal_eq_0_iff)
    52 
    53 subsubsection \<open>More about @{typ "'a option"}\<close>
    54 
    55 lemma None_in_map_option_image [simp]: "None \<in> map_option f ` A \<longleftrightarrow> None \<in> A" (* map_option sends only None to None *)
    56 by auto
    57 
    58 lemma Some_in_map_option_image [simp]: "Some x \<in> map_option f ` A \<longleftrightarrow> (\<exists>y. x = f y \<and> Some y \<in> A)" (* a Some value in the image comes from a Some value in A *)
    59 by(auto intro: rev_image_eqI dest: sym)
    60 
    61 lemma case_option_collapse: "case_option x (\<lambda>_. x) = (\<lambda>_. x)" (* a case split whose branches agree is constant *)
    62 by(simp add: fun_eq_iff split: option.split)
    63 
    64 lemma case_option_id: "case_option None Some = id" (* rebuilding an option from its own constructors is the identity *)
    65 by(rule ext)(simp split: option.split)
    66 
    67 inductive ord_option :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a option \<Rightarrow> 'b option \<Rightarrow> bool" (* lifts ord to options: None is below everything, Some x is below Some y iff ord x y *)
    68   for ord :: "'a \<Rightarrow> 'b \<Rightarrow> bool"
    69 where
    70   None: "ord_option ord None x"
    71 | Some: "ord x y \<Longrightarrow> ord_option ord (Some x) (Some y)"
    72 
    73 inductive_simps ord_option_simps [simp]: (* simp rules for the listed argument shapes *)
    74   "ord_option ord None x"
    75   "ord_option ord x None"
    76   "ord_option ord (Some x) (Some y)"
    77   "ord_option ord (Some x) None"
    78 
    79 inductive_simps ord_option_eq_simps [simp]: (* the same, specialised to the equality relation *)
    80   "ord_option op = None y"
    81   "ord_option op = (Some x) y"
    82 
    83 lemma ord_option_reflI: "(\<And>y. y \<in> set_option x \<Longrightarrow> ord y y) \<Longrightarrow> ord_option ord x x" (* reflexivity needs ord only on the element of x, if any *)
    84 by(cases x) simp_all
    85 
    86 lemma reflp_ord_option: "reflp ord \<Longrightarrow> reflp (ord_option ord)" (* ord_option preserves reflexivity *)
    87 by(simp add: reflp_def ord_option_reflI)
    88 
    89 lemma ord_option_trans: (* transitivity, requiring ord to be transitive only on the contained elements *)
    90   "\<lbrakk> ord_option ord x y; ord_option ord y z;
    91     \<And>a b c. \<lbrakk> a \<in> set_option x; b \<in> set_option y; c \<in> set_option z; ord a b; ord b c \<rbrakk> \<Longrightarrow> ord a c \<rbrakk>
    92   \<Longrightarrow> ord_option ord x z"
    93 by(auto elim!: ord_option.cases)
    94 
    95 lemma transp_ord_option: "transp ord \<Longrightarrow> transp (ord_option ord)" (* ord_option preserves transitivity *)
    96 unfolding transp_def by(blast intro: ord_option_trans)
    97 
    98 lemma antisymP_ord_option: "antisymP ord \<Longrightarrow> antisymP (ord_option ord)" (* ord_option preserves antisymmetry *)
    99 by(auto intro!: antisymI elim!: ord_option.cases dest: antisymD)
   100 
   101 lemma ord_option_chainD: (* the Some-contents of an ord_option chain form an ord chain *)
   102   "Complete_Partial_Order.chain (ord_option ord) Y
   103   \<Longrightarrow> Complete_Partial_Order.chain ord {x. Some x \<in> Y}"
   104 by(rule chainI)(auto dest: chainD)
   105 
   106 definition lub_option :: "('a set \<Rightarrow> 'b) \<Rightarrow> 'a option set \<Rightarrow> 'b option" (* lifts lub to option sets: None if Y contains nothing but None, else Some of lub of the Some-contents *)
   107 where "lub_option lub Y = (if Y \<subseteq> {None} then None else Some (lub {x. Some x \<in> Y}))"
   108 
   109 lemma map_lub_option: "map_option f (lub_option lub Y) = lub_option (f \<circ> lub) Y" (* mapping the result equals composing f after lub *)
   110 by(simp add: lub_option_def)
   111 
   112 lemma lub_option_upper: (* lub_option bounds an ord_option chain from above whenever lub bounds ord chains *)
   113   assumes "Complete_Partial_Order.chain (ord_option ord) Y" "x \<in> Y"
   114   and lub_upper: "\<And>Y x. \<lbrakk> Complete_Partial_Order.chain ord Y; x \<in> Y \<rbrakk> \<Longrightarrow> ord x (lub Y)"
   115   shows "ord_option ord x (lub_option lub Y)"
   116 using assms(1-2)
   117 by(cases x)(auto simp add: lub_option_def intro: lub_upper[OF ord_option_chainD])
   118 
   119 lemma lub_option_least: (* lub_option is below every upper bound of the chain whenever lub has the same property *)
   120   assumes Y: "Complete_Partial_Order.chain (ord_option ord) Y"
   121   and upper: "\<And>x. x \<in> Y \<Longrightarrow> ord_option ord x y"
   122   assumes lub_least: "\<And>Y y. \<lbrakk> Complete_Partial_Order.chain ord Y; \<And>x. x \<in> Y \<Longrightarrow> ord x y \<rbrakk> \<Longrightarrow> ord (lub Y) y"
   123   shows "ord_option ord (lub_option lub Y) y"
   124 using Y
   125 by(cases y)(auto 4 3 simp add: lub_option_def intro: lub_least[OF ord_option_chainD] dest: upper)
   126 
   127 lemma lub_map_option: "lub_option lub (map_option f ` Y) = lub_option (lub \<circ> op ` f) Y" (* map_option on the elements becomes an image taken before lub *)
   128 apply(auto simp add: lub_option_def)
   129 apply(erule notE)
   130 apply(rule arg_cong[where f=lub])
   131 apply(auto intro: rev_image_eqI dest: sym)
   132 done
   133 
   134 lemma ord_option_mono: "\<lbrakk> ord_option A x y; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> ord_option B x y" (* ord_option is monotone in the underlying relation *)
   135 by(auto elim: ord_option.cases)
   136 
   137 lemma ord_option_mono' [mono]: (* HOL-implication form of ord_option_mono, registered as a mono rule *)
   138   "(\<And>x y. A x y \<longrightarrow> B x y) \<Longrightarrow> ord_option A x y \<longrightarrow> ord_option B x y"
   139 by(blast intro: ord_option_mono)
   140 
   141 lemma ord_option_compp: "ord_option (A OO B) = ord_option A OO ord_option B" (* ord_option distributes over relation composition *)
   142 by(auto simp add: fun_eq_iff elim!: ord_option.cases intro: ord_option.intros)
   143 
   144 lemma ord_option_inf: "inf (ord_option A) (ord_option B) = ord_option (inf A B)" (is "?lhs = ?rhs") (* ord_option commutes with relation intersection *)
   145 proof(rule antisym)
   146   show "?lhs \<le> ?rhs" by(auto elim!: ord_option.cases)
   147 qed(auto elim: ord_option_mono)
   148 
   149 lemma ord_option_map2: "ord_option ord x (map_option f y) = ord_option (\<lambda>x y. ord x (f y)) x y" (* pulls a map_option on the right argument into the relation *)
   150 by(auto elim: ord_option.cases)
   151 
   152 lemma ord_option_map1: "ord_option ord (map_option f x) y = ord_option (\<lambda>x y. ord (f x) y) x y" (* pulls a map_option on the left argument into the relation *)
   153 by(auto elim: ord_option.cases)
   154 
   155 lemma option_ord_Some1_iff: "option_ord (Some x) y \<longleftrightarrow> y = Some x" (* in the flat order, Some x is below only itself *)
   156 by(auto simp add: flat_ord_def)
   157 
   158 subsubsection \<open>A relator for sets that treats sets like predicates\<close>
   159 
   160 context includes lifting_syntax (* bundle providing the ===> relator notation used below *)
   161 begin
   162 
   163 definition rel_pred :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a set \<Rightarrow> 'b set \<Rightarrow> bool" (* A and B agree on membership for all R-related elements *)
   164 where "rel_pred R A B = (R ===> op =) (\<lambda>x. x \<in> A) (\<lambda>y. y \<in> B)"
   165 
   166 lemma rel_predI: "(R ===> op =) (\<lambda>x. x \<in> A) (\<lambda>y. y \<in> B) \<Longrightarrow> rel_pred R A B" (* introduction rule read off the definition *)
   167 by(simp add: rel_pred_def)
   168 
   169 lemma rel_predD: "\<lbrakk> rel_pred R A B; R x y \<rbrakk> \<Longrightarrow> x \<in> A \<longleftrightarrow> y \<in> B" (* R-related elements have equal membership *)
   170 by(simp add: rel_pred_def rel_fun_def)
   171 
   172 lemma Collect_parametric: "((A ===> op =) ===> rel_pred A) Collect Collect"
   173   \<comment> \<open>Declare this rule as @{attribute transfer_rule} only locally
   174       because it blows up the search space for @{method transfer}
   175       (in combination with @{thm [source] Collect_transfer})\<close>
   176 by(simp add: rel_funI rel_predI)
   177 
   178 end
   179 
   180 subsubsection \<open>Monotonicity rules\<close>
   181 
   182 lemma monotone_gfp_eadd1: "monotone op \<ge> op \<ge> (\<lambda>x. x + y :: enat)" (* enat addition is monotone in its first argument for the reversed order *)
   183 by(auto intro!: monotoneI)
   184 
   185 lemma monotone_gfp_eadd2: "monotone op \<ge> op \<ge> (\<lambda>y. x + y :: enat)" (* enat addition is monotone in its second argument for the reversed order *)
   186 by(auto intro!: monotoneI)
   187 
   188 lemma mono2mono_gfp_eadd[THEN gfp.mono2mono2, cont_intro, simp]: (* joint monotonicity of enat addition; THEN gfp.mono2mono2 turns it into a mono2mono rule *)
   189   shows monotone_eadd: "monotone (rel_prod op \<ge> op \<ge>) op \<ge> (\<lambda>(x, y). x + y :: enat)"
   190 by(simp add: monotone_gfp_eadd1 monotone_gfp_eadd2)
   191 
   192 lemma eadd_gfp_partial_function_mono [partial_function_mono]: (* the pointwise enat sum of monotone functions is monotone; registered for partial_function *)
   193   "\<lbrakk> monotone (fun_ord op \<ge>) op \<ge> f; monotone (fun_ord op \<ge>) op \<ge> g \<rbrakk>
   194   \<Longrightarrow> monotone (fun_ord op \<ge>) op \<ge> (\<lambda>x. f x + g x :: enat)"
   195 by(rule mono2mono_gfp_eadd)
   196 
   197 lemma mono2mono_ereal[THEN lfp.mono2mono]: (* the ereal embedding is monotone *)
   198   shows monotone_ereal: "monotone op \<le> op \<le> ereal"
   199 by(rule monotoneI) simp
   200 
   201 lemma mono2mono_ennreal[THEN lfp.mono2mono]: (* the ennreal embedding is monotone *)
   202   shows monotone_ennreal: "monotone op \<le> op \<le> ennreal"
   203 by(rule monotoneI)(simp add: ennreal_leI)
   204 
   205 subsubsection \<open>Bijections\<close>
   206 
   207 lemma bi_unique_rel_set_bij_betw: (* a bi-unique relation between A and B yields a bijection that stays inside the relation *)
   208   assumes unique: "bi_unique R"
   209   and rel: "rel_set R A B"
   210   shows "\<exists>f. bij_betw f A B \<and> (\<forall>x\<in>A. R x (f x))"
   211 proof -
   212   from assms obtain f where f: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)" and B: "\<And>x. x \<in> A \<Longrightarrow> f x \<in> B" (* choose an R-partner in B for every element of A *)
   213     apply(atomize_elim)
   214     apply(fold all_conj_distrib)
   215     apply(subst choice_iff[symmetric])
   216     apply(auto dest: rel_setD1)
   217     done
   218   have "inj_on f A" by(rule inj_onI)(auto dest!: f dest: bi_uniqueDl[OF unique]) (* injectivity via bi_uniqueDl *)
   219   moreover have "f ` A = B" using rel (* surjectivity onto B via rel_setD2 and bi_uniqueDr *)
   220     by(auto 4 3 intro: B dest: rel_setD2 f bi_uniqueDr[OF unique])
   221   ultimately have "bij_betw f A B" unfolding bij_betw_def ..
   222   thus ?thesis using f by blast
   223 qed
   224 
   225 lemma bij_betw_rel_setD: "bij_betw f A B \<Longrightarrow> rel_set (\<lambda>x y. y = f x) A B" (* the graph of a bijection relates A and B *)
   226 by(rule rel_setI)(auto dest: bij_betwE bij_betw_imp_surj_on[symmetric])
   227 
   228 subsection \<open>Subprobability mass function\<close>
   229 
   230 type_synonym 'a spmf = "'a option pmf" (* Some x carries the mass of x; mass on None is not counted by spmf or measure_spmf *)
   231 translations (type) "'a spmf" \<leftharpoondown> (type) "'a option pmf" (* print-only translation back to the synonym *)
   232 
   233 definition measure_spmf :: "'a spmf \<Rightarrow> 'a measure" (* the measure on 'a: restrict to the Some values, then strip Some with the *)
   234 where "measure_spmf p = distr (restrict_space (measure_pmf p) (range Some)) (count_space UNIV) the"
   235 
   236 abbreviation spmf :: "'a spmf \<Rightarrow> 'a \<Rightarrow> real" (* probability mass of the value x, i.e. of Some x *)
   237 where "spmf p x \<equiv> pmf p (Some x)"
   238 
   239 lemma space_measure_spmf: "space (measure_spmf p) = UNIV" (* the space is the whole type *)
   240 by(simp add: measure_spmf_def)
   241 
   242 lemma sets_measure_spmf [simp, measurable_cong]: "sets (measure_spmf p) = sets (count_space UNIV)" (* every set is measurable *)
   243 by(simp add: measure_spmf_def)
   244 
   245 lemma measure_spmf_not_bot [simp]: "measure_spmf p \<noteq> \<bottom>" (* proved by comparing spaces: the space of measure_spmf p is UNIV *)
   246 proof
   247   assume "measure_spmf p = \<bottom>"
   248   hence "space (measure_spmf p) = space \<bottom>" by simp
   249   thus False by(simp add: space_measure_spmf)
   250 qed
   251 
   252 lemma measurable_the_measure_pmf_Some [measurable, simp]: (* the is measurable on the Some part, as used by the distr in measure_spmf *)
   253   "the \<in> measurable (restrict_space (measure_pmf p) (range Some)) (count_space UNIV)"
   254 by(auto simp add: measurable_def sets_restrict_space space_restrict_space integral_restrict_space) (* NOTE(review): integral_restrict_space looks unrelated to this measurability goal and may be a leftover - confirm before removing *)
   255 
   256 lemma measurable_spmf_measure1[simp]: "measurable (measure_spmf M) N = UNIV \<rightarrow> space N" (* every function into space N is measurable from measure_spmf M *)
   257 by(auto simp: measurable_def space_measure_spmf)
   258 
   259 lemma measurable_spmf_measure2[simp]: "measurable N (measure_spmf M) = measurable N (count_space UNIV)" (* measurability into measure_spmf M depends only on its sets *)
   260 by(intro measurable_cong_sets) simp_all
   261 
   262 lemma subprob_space_measure_spmf [simp, intro!]: "subprob_space (measure_spmf p)" (* total mass is at most 1 *)
   263 proof
   264   show "emeasure (measure_spmf p) (space (measure_spmf p)) \<le> 1"
   265     by(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space measure_pmf.measure_le_1)
   266 qed(simp add: space_measure_spmf)
   267 
   268 interpretation measure_spmf: subprob_space "measure_spmf p" for p (* subprob_space facts become available with the prefix measure_spmf. *)
   269 by(rule subprob_space_measure_spmf)
   270 
   271 lemma finite_measure_spmf [simp]: "finite_measure (measure_spmf p)" (* discharged by the locale structure set up by the interpretation *)
   272 by unfold_locales
   273 
   274 lemma spmf_conv_measure_spmf: "spmf p x = measure (measure_spmf p) {x}" (* spmf is the measure of a singleton *)
   275 by(auto simp add: measure_spmf_def measure_distr measure_restrict_space pmf.rep_eq space_restrict_space intro: arg_cong2[where f=measure])
   276 
   277 lemma emeasure_measure_spmf_conv_measure_pmf: (* emeasure of A equals the pmf emeasure of Some ` A *)
   278   "emeasure (measure_spmf p) A = emeasure (measure_pmf p) (Some ` A)"
   279 by(auto simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])
   280 
   281 lemma measure_measure_spmf_conv_measure_pmf: (* real-valued version of emeasure_measure_spmf_conv_measure_pmf *)
   282   "measure (measure_spmf p) A = measure (measure_pmf p) (Some ` A)"
   283 using emeasure_measure_spmf_conv_measure_pmf[of p A]
   284 by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure)
   285 
   286 lemma emeasure_spmf_map_pmf_Some [simp]: (* measure_spmf of a pmf embedded via Some has the original emeasure *)
   287   "emeasure (measure_spmf (map_pmf Some p)) A = emeasure (measure_pmf p) A"
   288 by(auto simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])
   289 
   290 lemma measure_spmf_map_pmf_Some [simp]: (* real-valued version of emeasure_spmf_map_pmf_Some *)
   291   "measure (measure_spmf (map_pmf Some p)) A = measure (measure_pmf p) A"
   292 using emeasure_spmf_map_pmf_Some[of p A] by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure)
   293 
   294 lemma nn_integral_measure_spmf: "(\<integral>\<^sup>+ x. f x \<partial>measure_spmf p) = \<integral>\<^sup>+ x. ennreal (spmf p x) * f x \<partial>count_space UNIV" (* nonnegative integral as a weighted sum over all values *)
   295   (is "?lhs = ?rhs")
   296 proof -
   297   have "?lhs = \<integral>\<^sup>+ x. pmf p x * f (the x) \<partial>count_space (range Some)"
   298     by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space nn_integral_measure_pmf nn_integral_count_space_indicator ac_simps times_ereal.simps(1)[symmetric] del: times_ereal.simps(1)) (* NOTE(review): the times_ereal rules are ereal facts in an ennreal goal and may be a leftover - confirm whether still needed *)
   299   also have "\<dots> = \<integral>\<^sup>+ x. ennreal (spmf p (the x)) * f (the x) \<partial>count_space (range Some)"
   300     by(rule nn_integral_cong) auto
   301   also have "\<dots> = \<integral>\<^sup>+ x. spmf p (the (Some x)) * f (the (Some x)) \<partial>count_space UNIV" (* reindex along the bijection Some from UNIV onto range Some *)
   302     by(rule nn_integral_bij_count_space[symmetric])(simp add: bij_betw_def)
   303   also have "\<dots> = ?rhs" by simp
   304   finally show ?thesis .
   305 qed
   306 
   307 lemma integral_measure_spmf: (* Bochner-integral counterpart of nn_integral_measure_spmf, for integrable f *)
   308   assumes "integrable (measure_spmf p) f"
   309   shows "(\<integral> x. f x \<partial>measure_spmf p) = \<integral> x. spmf p x * f x \<partial>count_space UNIV"
   310 proof -
   311   have "integrable (count_space UNIV) (\<lambda>x. spmf p x * f x)" (* integrability carries over to the weighted counting-measure integrand *)
   312     using assms by(simp add: integrable_iff_bounded nn_integral_measure_spmf abs_mult ennreal_mult'')
   313   then show ?thesis using assms
   314     by(simp add: real_lebesgue_integral_def nn_integral_measure_spmf ennreal_mult'[symmetric])
   315 qed
   316 
   317 lemma emeasure_spmf_single: "emeasure (measure_spmf p) {x} = spmf p x" (* emeasure form of spmf_conv_measure_spmf *)
   318 by(simp add: measure_spmf.emeasure_eq_measure spmf_conv_measure_spmf)
   319 
   320 lemma measurable_measure_spmf[measurable]: (* any family of measure_spmf is measurable into the subprobability algebra *)
   321   "(\<lambda>x. measure_spmf (M x)) \<in> measurable (count_space UNIV) (subprob_algebra (count_space UNIV))"
   322 by (auto simp: space_subprob_algebra)
   323 
   324 lemma nn_integral_measure_spmf_conv_measure_pmf: (* integral over measure_spmf as the integral of f o the over the Some part of the pmf *)
   325   assumes [measurable]: "f \<in> borel_measurable (count_space UNIV)"
   326   shows "nn_integral (measure_spmf p) f = nn_integral (restrict_space (measure_pmf p) (range Some)) (f \<circ> the)"
   327 by(simp add: measure_spmf_def nn_integral_distr o_def)
   328 
   329 lemma measure_spmf_in_space_subprob_algebra [simp]: (* measure_spmf p is a point of the subprobability algebra *)
   330   "measure_spmf p \<in> space (subprob_algebra (count_space UNIV))"
   331 by(simp add: space_subprob_algebra)
   332 
   333 lemma nn_integral_spmf_neq_top: "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV) \<noteq> \<top>" (* the total spmf mass is finite *)
   334 using nn_integral_measure_spmf[where f="\<lambda>_. 1", of p, symmetric] by simp
   335 
   336 lemma SUP_spmf_neq_top': "(SUP p:Y. ennreal (spmf p x)) \<noteq> \<top>" (* bounded by 1, hence finite *)
   337 proof(rule neq_top_trans)
   338   show "(SUP p:Y. ennreal (spmf p x)) \<le> 1" by(rule SUP_least)(simp add: pmf_le_1)
   339 qed simp
   340 
   341 lemma SUP_spmf_neq_top: "(SUP i. ennreal (spmf (Y i) x)) \<noteq> \<top>" (* sequence version of SUP_spmf_neq_top' *)
   342 proof(rule neq_top_trans)
   343   show "(SUP i. ennreal (spmf (Y i) x)) \<le> 1" by(rule SUP_least)(simp add: pmf_le_1)
   344 qed simp
   345 
   346 lemma SUP_emeasure_spmf_neq_top: "(SUP p:Y. emeasure (measure_spmf p) A) \<noteq> \<top>" (* every emeasure is at most 1, so the SUP is finite *)
   347 proof(rule neq_top_trans)
   348   show "(SUP p:Y. emeasure (measure_spmf p) A) \<le> 1"
   349     by(rule SUP_least)(simp add: measure_spmf.subprob_emeasure_le_1)
   350 qed simp
   351 
   352 subsection \<open>Support\<close>
   353 
   354 definition set_spmf :: "'a spmf \<Rightarrow> 'a set" (* support: the values x such that Some x is in the pmf support *)
   355 where "set_spmf p = set_pmf p \<bind> set_option"
   356 
   357 lemma set_spmf_rep_eq: "set_spmf p = {x. measure (measure_spmf p) {x} \<noteq> 0}" (* support as the values with nonzero singleton measure *)
   358 proof -
   359   have "\<And>x :: 'a. the -` {x} \<inter> range Some = {Some x}" by auto
   360   then show ?thesis
   361     by(auto simp add: set_spmf_def set_pmf.rep_eq measure_spmf_def measure_distr measure_restrict_space space_restrict_space intro: rev_image_eqI)
   362 qed
   363 
   364 lemma in_set_spmf: "x \<in> set_spmf p \<longleftrightarrow> Some x \<in> set_pmf p" (* membership in the support via the underlying pmf *)
   365 by(simp add: set_spmf_def)
   366 
   367 lemma AE_measure_spmf_iff [simp]: "(AE x in measure_spmf p. P x) \<longleftrightarrow> (\<forall>x\<in>set_spmf p. P x)" (* almost everywhere means everywhere on the support *)
   368 by(auto 4 3 simp add: measure_spmf_def AE_distr_iff AE_restrict_space_iff AE_measure_pmf_iff set_spmf_def cong del: AE_cong)
   369 
   370 lemma spmf_eq_0_set_spmf: "spmf p x = 0 \<longleftrightarrow> x \<notin> set_spmf p" (* the mass is zero exactly outside the support *)
   371 by(auto simp add: pmf_eq_0_set_pmf set_spmf_def intro: rev_image_eqI)
   372 
   373 lemma in_set_spmf_iff_spmf: "x \<in> set_spmf p \<longleftrightarrow> spmf p x \<noteq> 0" (* the support consists of the values with nonzero mass *)
   374 by(auto simp add: set_spmf_def set_pmf_iff intro: rev_image_eqI)
   375 
   376 lemma set_spmf_return_pmf_None [simp]: "set_spmf (return_pmf None) = {}" (* the distribution concentrated on None has empty support *)
   377 by(auto simp add: set_spmf_def)
   378 
   379 lemma countable_set_spmf [simp]: "countable (set_spmf p)" (* the support is countable *)
   380 by(simp add: set_spmf_def bind_UNION)
   381 
   382 lemma spmf_eqI: (* extensionality: spmfs with the same spmf values are equal *)
   383   assumes "\<And>i. spmf p i = spmf q i"
   384   shows "p = q"
   385 proof(rule pmf_eqI)
   386   fix i
   387   show "pmf p i = pmf q i"
   388   proof(cases i) (* Some values follow from the assumption; the mass on None must be recomputed *)
   389     case (Some i')
   390     thus ?thesis by(simp add: assms)
   391   next
   392     case None (* write pmf p None as 1 minus the mass on range Some, then transfer that mass to q *)
   393     have "ennreal (pmf p i) = measure (measure_pmf p) {i}" by(simp add: pmf_def)
   394     also have "{i} = space (measure_pmf p) - range Some"
   395       by(auto simp add: None intro: ccontr)
   396     also have "measure (measure_pmf p) \<dots> = ennreal 1 - measure (measure_pmf p) (range Some)"
   397       by(simp add: measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf)
   398     also have "range Some = (\<Union>x\<in>set_spmf p. {Some x}) \<union> Some ` (- set_spmf p)"
   399       by auto
   400     also have "measure (measure_pmf p) \<dots> = measure (measure_pmf p) (\<Union>x\<in>set_spmf p. {Some x})" (* Some values outside the support have measure 0 *)
   401       by(rule measure_pmf.measure_zero_union)(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff)
   402     also have "ennreal \<dots> = \<integral>\<^sup>+ x. measure (measure_pmf p) {Some x} \<partial>count_space (set_spmf p)"
   403       unfolding measure_pmf.emeasure_eq_measure[symmetric]
   404       by(simp_all add: emeasure_UN_countable disjoint_family_on_def)
   405     also have "\<dots> = \<integral>\<^sup>+ x. spmf p x \<partial>count_space (set_spmf p)" by(simp add: pmf_def)
   406     also have "\<dots> = \<integral>\<^sup>+ x. spmf q x \<partial>count_space (set_spmf p)" by(simp add: assms) (* switch from p to q pointwise *)
   407     also have "set_spmf p = set_spmf q" by(auto simp add: in_set_spmf_iff_spmf assms) (* equal spmf values give equal supports *)
   408     also have "(\<integral>\<^sup>+ x. spmf q x \<partial>count_space (set_spmf q)) = \<integral>\<^sup>+ x. measure (measure_pmf q) {Some x} \<partial>count_space (set_spmf q)" (* from here on, the steps for p are retraced backwards for q *)
   409       by(simp add: pmf_def)
   410     also have "\<dots> = measure (measure_pmf q) (\<Union>x\<in>set_spmf q. {Some x})"
   411       unfolding measure_pmf.emeasure_eq_measure[symmetric]
   412       by(simp_all add: emeasure_UN_countable disjoint_family_on_def)
   413     also have "\<dots> = measure (measure_pmf q) ((\<Union>x\<in>set_spmf q. {Some x}) \<union> Some ` (- set_spmf q))"
   414       by(rule ennreal_cong measure_pmf.measure_zero_union[symmetric])+(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff)
   415     also have "((\<Union>x\<in>set_spmf q. {Some x}) \<union> Some ` (- set_spmf q)) = range Some" by auto
   416     also have "ennreal 1 - measure (measure_pmf q) \<dots> = measure (measure_pmf q) (space (measure_pmf q) - range Some)"
   417       by(simp add: one_ereal_def measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf) (* NOTE(review): one_ereal_def is an ereal fact in an ennreal goal and may be a leftover - confirm before removing *)
   418     also have "space (measure_pmf q) - range Some = {i}"
   419       by(auto simp add: None intro: ccontr)
   420     also have "measure (measure_pmf q) \<dots> = pmf q i" by(simp add: pmf_def)
   421     finally show ?thesis by simp
   422   qed
   423 qed
   424 
   425 lemma integral_measure_spmf_restrict: (* restricting measure_spmf to its support does not change Bochner integrals *)
   426   fixes f ::  "'a \<Rightarrow> 'b :: {banach, second_countable_topology}" shows
   427   "(\<integral> x. f x \<partial>measure_spmf M) = (\<integral> x. f x \<partial>restrict_space (measure_spmf M) (set_spmf M))"
   428 by(auto intro!: integral_cong_AE simp add: integral_restrict_space)
   429 
   430 lemma nn_integral_measure_spmf': (* variant of nn_integral_measure_spmf that sums only over the support *)
   431   "(\<integral>\<^sup>+ x. f x \<partial>measure_spmf p) = \<integral>\<^sup>+ x. ennreal (spmf p x) * f x \<partial>count_space (set_spmf p)"
   432 by(auto simp add: nn_integral_measure_spmf nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator)
   433 
   434 subsection \<open>Functorial structure\<close>
   435 
   436 abbreviation map_spmf :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a spmf \<Rightarrow> 'b spmf" (* applies f under Some and leaves None unchanged *)
   437 where "map_spmf f \<equiv> map_pmf (map_option f)"
   438 
   439 context begin
   440 local_setup \<open>Local_Theory.map_background_naming (Name_Space.mandatory_path "spmf")\<close> (* lemmas in this context get the mandatory prefix spmf., e.g. spmf.map_comp *)
   441 
   442 lemma map_comp: "map_spmf f (map_spmf g p) = map_spmf (f \<circ> g) p" (* functor law: composition *)
   443 by(simp add: pmf.map_comp o_def option.map_comp)
   444 
   445 lemma map_id0: "map_spmf id = id" (* functor law: identity, point-free *)
   446 by(simp add: pmf.map_id option.map_id0)
   447 
   448 lemma map_id [simp]: "map_spmf id p = p" (* functor law: identity, applied *)
   449 by(simp add: map_id0)
   450 
   451 lemma map_ident [simp]: "map_spmf (\<lambda>x. x) p = p" (* the identity written as a lambda *)
   452 by(simp add: id_def[symmetric])
   453 
   454 end
   455 
   456 lemma set_map_spmf [simp]: "set_spmf (map_spmf f p) = f ` set_spmf p" (* the support of a mapped spmf is the image of the support *)
   457 by(simp add: set_spmf_def image_bind bind_image o_def Option.option.set_map)
   458 
   459 lemma map_spmf_cong: (* congruence rule: the functions need only agree on the support *)
   460   "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q \<Longrightarrow> f x = g x \<rbrakk>
   461   \<Longrightarrow> map_spmf f p = map_spmf g q"
   462 by(auto intro: pmf.map_cong option.map_cong simp add: in_set_spmf)
   463 
   464 lemma map_spmf_cong_simp: (* map_spmf_cong stated with a simp_implies premise *)
   465   "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q =simp=> f x = g x \<rbrakk>
   466   \<Longrightarrow> map_spmf f p = map_spmf g q"
   467 unfolding simp_implies_def by(rule map_spmf_cong)
   468 
   469 lemma map_spmf_idI: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> f x = x) \<Longrightarrow> map_spmf f p = p" (* a function that is the identity on the support leaves p unchanged *)
   470 by(rule map_pmf_idI map_option_idI)+(simp add: in_set_spmf)
   471 
   472 lemma emeasure_map_spmf: (* emeasure of a mapped spmf is the emeasure of the preimage *)
   473   "emeasure (measure_spmf (map_spmf f p)) A = emeasure (measure_spmf p) (f -` A)"
   474 by(auto simp add: measure_spmf_def emeasure_distr measurable_restrict_space1 space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure])
   475 
   476 lemma measure_map_spmf: "measure (measure_spmf (map_spmf f p)) A = measure (measure_spmf p) (f -` A)" (* real-valued version of emeasure_map_spmf *)
   477 using emeasure_map_spmf[of f p A] by(simp add: measure_spmf.emeasure_eq_measure)
   478 
   479 lemma measure_map_spmf_conv_distr: (* measure_spmf of a mapped spmf is the push-forward of measure_spmf p along f *)
   480   "measure_spmf (map_spmf f p) = distr (measure_spmf p) (count_space UNIV) f"
   481 by(rule measure_eqI)(simp_all add: emeasure_map_spmf emeasure_distr)
   482 
   483 lemma spmf_map_pmf_Some [simp]: "spmf (map_pmf Some p) i = pmf p i" (* embedding a pmf via Some keeps every point mass *)
   484 by(simp add: pmf_map_inj')
   485 
   486 lemma spmf_map_inj: "\<lbrakk> inj_on f (set_spmf M); x \<in> set_spmf M \<rbrakk> \<Longrightarrow> spmf (map_spmf f M) (f x) = spmf M x" (* a map injective on the support preserves the point masses of support elements *)
   487 by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj, auto simp add: in_set_spmf inj_on_def elim!: option.inj_map_strong[rotated])
   488 
   489 lemma spmf_map_inj': "inj f \<Longrightarrow> spmf (map_spmf f M) (f x) = spmf M x" (* for globally injective f, no support premise is needed *)
   490 by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj'[OF option.inj_map])
   491 
   492 lemma spmf_map_outside: "x \<notin> f ` set_spmf M \<Longrightarrow> spmf (map_spmf f M) x = 0" (* values outside the image of the support get no mass *)
   493 unfolding spmf_eq_0_set_spmf by simp
   494 
   495 lemma ennreal_spmf_map: "ennreal (spmf (map_spmf f p) x) = emeasure (measure_spmf p) (f -` {x})" (* the point mass of a mapped spmf is the emeasure of the preimage of x *)
   496 by(auto simp add: ennreal_pmf_map measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])
   497 
   498 lemma spmf_map: "spmf (map_spmf f p) x = measure (measure_spmf p) (f -` {x})" (* real-valued version of ennreal_spmf_map *)
   499 using ennreal_spmf_map[of f p x] by(simp add: measure_spmf.emeasure_eq_measure)
   500 
   501 lemma ennreal_spmf_map_conv_nn_integral: (* point mass of a mapped spmf as the nonnegative integral of the indicator of the preimage *)
   502   "ennreal (spmf (map_spmf f p) x) = integral\<^sup>N (measure_spmf p) (indicator (f -` {x}))"
   503 by(auto simp add: ennreal_pmf_map measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure])
   504 
   505 subsection \<open>Monad operations\<close>
   506 
   507 subsubsection \<open>Return\<close>
   508 
   509 abbreviation return_spmf :: "'a \<Rightarrow> 'a spmf" (* point distribution on Some x *)
   510 where "return_spmf x \<equiv> return_pmf (Some x)"
   511 
   512 lemma pmf_return_spmf: "pmf (return_spmf x) y = indicator {y} (Some x)" (* instance of pmf_return *)
   513 by(fact pmf_return)
   514 
   515 lemma measure_spmf_return_spmf: "measure_spmf (return_spmf x) = Giry_Monad.return (count_space UNIV) x" (* return_spmf corresponds to the Giry-monad return *)
   516 by(rule measure_eqI)(simp_all add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_def)
   517 
   518 lemma measure_spmf_return_pmf_None [simp]: "measure_spmf (return_pmf None) = null_measure (count_space UNIV)" (* putting all mass on None yields the zero measure *)
   519 by(rule measure_eqI)(auto simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_eq_0_iff)
   520 
   521 lemma set_return_spmf [simp]: "set_spmf (return_spmf x) = {x}" (* the support of return_spmf x is the singleton of x *)
   522 by(auto simp add: set_spmf_def)
   523 
   524 subsubsection \<open>Bind\<close>
   525 
   526 definition bind_spmf :: "'a spmf \<Rightarrow> ('a \<Rightarrow> 'b spmf) \<Rightarrow> 'b spmf" (* monadic bind: None propagates as return_pmf None, Some a' continues with f a' *)
   527 where "bind_spmf x f = bind_pmf x (\<lambda>a. case a of None \<Rightarrow> return_pmf None | Some a' \<Rightarrow> f a')"
   528 
   529 adhoc_overloading Monad_Syntax.bind bind_spmf (* lets the generic monadic bind syntax resolve to bind_spmf *)
   530 
   531 lemma return_None_bind_spmf [simp]: "return_pmf None \<bind> (f :: 'a \<Rightarrow> _) = return_pmf None" (* binding the all-None distribution yields it again *)
   532 by(simp add: bind_spmf_def bind_return_pmf)
   533 
   534 lemma return_bind_spmf [simp]: "return_spmf x \<bind> f = f x" (* monad law: left identity *)
   535 by(simp add: bind_spmf_def bind_return_pmf)
   536 
   537 lemma bind_return_spmf [simp]: "x \<bind> return_spmf = x" (* monad law: right identity *)
   538 proof -
   539   have "\<And>a :: 'a option. (case a of None \<Rightarrow> return_pmf None | Some a' \<Rightarrow> return_spmf a') = return_pmf a" (* the continuation is return_pmf in both cases *)
   540     by(simp split: option.split)
   541   then show ?thesis
   542     by(simp add: bind_spmf_def bind_return_pmf')
   543 qed
   544 
   545 lemma bind_spmf_assoc [simp]: (* monad law: associativity *)
   546   fixes x :: "'a spmf" and f :: "'a \<Rightarrow> 'b spmf" and g :: "'b \<Rightarrow> 'c spmf"
   547   shows "(x \<bind> f) \<bind> g = x \<bind> (\<lambda>y. f y \<bind> g)"
   548 by(auto simp add: bind_spmf_def bind_assoc_pmf fun_eq_iff bind_return_pmf split: option.split intro: arg_cong[where f="bind_pmf x"])
   549 
   550 lemma pmf_bind_spmf_None: "pmf (p \<bind> f) None = pmf p None + \<integral> x. pmf (f x) None \<partial>measure_spmf p" (* None mass after a bind: the None mass of p plus the integral of the None mass of f over measure_spmf p *)
   551   (is "?lhs = ?rhs")
   552 proof -
   553   let ?f = "\<lambda>x. pmf (case x of None \<Rightarrow> return_pmf None | Some x \<Rightarrow> f x) None"
   554   have "?lhs = \<integral> x. ?f x \<partial>measure_pmf p"
   555     by(simp add: bind_spmf_def pmf_bind)
   556   also have "\<dots> = \<integral> x. ?f None * indicator {None} x + ?f x * indicator (range Some) x \<partial>measure_pmf p" (* split the integrand into its None and Some parts *)
   557     by(rule Bochner_Integration.integral_cong)(auto simp add: indicator_def)
   558   also have "\<dots> = (\<integral> x. ?f None * indicator {None} x \<partial>measure_pmf p) + (\<integral> x. ?f x * indicator (range Some) x \<partial>measure_pmf p)"
   559     by(rule Bochner_Integration.integral_add)(auto 4 3 intro: integrable_real_mult_indicator measure_pmf.integrable_const_bound[where B=1] simp add: AE_measure_pmf_iff pmf_le_1)
   560   also have "\<dots> = pmf p None + \<integral> x. indicator (range Some) x * pmf (f (the x)) None \<partial>measure_pmf p" (* the None part integrates to pmf p None *)
   561     by(auto simp add: measure_measure_pmf_finite indicator_eq_0_iff intro!: Bochner_Integration.integral_cong)
   562   also have "\<dots> = ?rhs" unfolding measure_spmf_def (* turn the Some part into an integral over measure_spmf p *)
   563     by(subst integral_distr)(auto simp add: integral_restrict_space)
   564   finally show ?thesis .
   565 qed
   566 
   567 lemma spmf_bind: "spmf (p \<bind> f) y = \<integral> x. spmf (f x) y \<partial>measure_spmf p" (* point mass after a bind as an integral over measure_spmf p *)
   568 unfolding measure_spmf_def
   569 by(subst integral_distr)(auto simp add: bind_spmf_def pmf_bind integral_restrict_space indicator_eq_0_iff intro!: Bochner_Integration.integral_cong split: option.split)
   570 
   571 lemma ennreal_spmf_bind: "ennreal (spmf (p \<bind> f) x) = \<integral>\<^sup>+ y. spmf (f y) x \<partial>measure_spmf p" (* nonnegative-integral version of spmf_bind *)
   572 by(auto simp add: bind_spmf_def ennreal_pmf_bind nn_integral_measure_spmf_conv_measure_pmf nn_integral_restrict_space intro: nn_integral_cong split: split_indicator option.split)
   573 
   (* measure_spmf of a bind of the pmf p with an spmf-valued continuation f equals the
      measure-level bind of measure_pmf p with measure_spmf o f. Proved with
      measure_eqI: equal sigma-algebras (via sets_bind with N = count_space UNIV),
      then equal emeasure on every set A. *)
   574 lemma measure_spmf_bind_pmf: "measure_spmf (p \<bind> f) = measure_pmf p \<bind> measure_spmf \<circ> f"
   575   (is "?lhs = ?rhs")
   576 proof(rule measure_eqI)
   577   show "sets ?lhs = sets ?rhs"
   578     by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf)
   579 next
   580   fix A :: "'a set"
   581   have "emeasure ?lhs A = \<integral>\<^sup>+ x. emeasure (measure_spmf (f x)) A \<partial>measure_pmf p"
   582     by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def)
   583   also have "\<dots> = emeasure ?rhs A"
   584     by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra)
   585   finally show "emeasure ?lhs A = emeasure ?rhs A" .
   586 qed
   587 
   (* measure_spmf of an spmf bind is the measure-level bind of measure_spmf p with
      measure_spmf o f. Same proof pattern as measure_spmf_bind_pmf, with an extra step
      that discards the None outcomes of p. *)
   588 lemma measure_spmf_bind: "measure_spmf (p \<bind> f) = measure_spmf p \<bind> measure_spmf \<circ> f"
   589   (is "?lhs = ?rhs")
   590 proof(rule measure_eqI)
   591   show "sets ?lhs = sets ?rhs"
   592     by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf)
   593 next
   594   fix A :: "'a set"
         (* ?A: the Some-images of A, i.e. the option values whose payload lies in A *)
   595   let ?A = "the -` A \<inter> range Some"
   596   have "emeasure ?lhs A = \<integral>\<^sup>+ x. emeasure (measure_pmf (case x of None \<Rightarrow> return_pmf None | Some x \<Rightarrow> f x)) ?A \<partial>measure_pmf p"
   597     by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def)
         (* outcome None of p contributes nothing, so restrict to range Some *)
   598   also have "\<dots> =  \<integral>\<^sup>+ x. emeasure (measure_pmf (f (the x))) ?A * indicator (range Some) x \<partial>measure_pmf p"
   599     by(rule nn_integral_cong)(auto split: option.split simp add: indicator_def)
   600   also have "\<dots> = \<integral>\<^sup>+ x. emeasure (measure_spmf (f x)) A \<partial>measure_spmf p"
   601     by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space emeasure_distr space_restrict_space emeasure_restrict_space)
   602   also have "\<dots> = emeasure ?rhs A"
   603     by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra)
   604   finally show "emeasure ?lhs A = emeasure ?rhs A" .
   605 qed
   606 
   (* map_spmf distributes into the continuation of a bind_spmf. *)
   607 lemma map_spmf_bind_spmf: "map_spmf f (bind_spmf p g) = bind_spmf p (map_spmf f \<circ> g)"
   608 by(auto simp add: bind_spmf_def map_bind_pmf fun_eq_iff split: option.split intro: arg_cong2[where f=bind_pmf])
   609 
   (* Binding after map_spmf f is binding with the composed continuation g o f. *)
   610 lemma bind_map_spmf: "map_spmf f p \<bind> g = p \<bind> g \<circ> f"
   611 by(simp add: bind_spmf_def bind_map_pmf o_def cong del: option.case_cong_weak)
   612 
   (* Upper bound for the success probability of a bind: if every continuation f y with
      y in the support of p gives x probability at most r (and r is nonnegative), then
      the bind gives x probability at most r as well. *)
   613 lemma spmf_bind_leI:
   614   assumes "\<And>y. y \<in> set_spmf p \<Longrightarrow> spmf (f y) x \<le> r"
   615   and "0 \<le> r"
   616   shows "spmf (bind_spmf p f) x \<le> r"
   617 proof -
   618   have "ennreal (spmf (bind_spmf p f) x) = \<integral>\<^sup>+ y. spmf (f y) x \<partial>measure_spmf p" by(rule ennreal_spmf_bind)
         (* the pointwise bound from the first assumption holds almost everywhere *)
   619   also have "\<dots> \<le> \<integral>\<^sup>+ y. r \<partial>measure_spmf p" by(rule nn_integral_mono_AE)(simp add: assms)
         (* integrating the constant r scales it by the total mass, which is at most 1 *)
   620   also have "\<dots> \<le> r" using assms measure_spmf.emeasure_space_le_1
   621     by(auto simp add: measure_spmf.emeasure_eq_measure intro!: mult_left_le)
   622   finally show ?thesis using assms(2) by(simp)
   623 qed
   624 
   (* map_spmf expressed as a bind whose continuation is return_spmf o f. *)
   625 lemma map_spmf_conv_bind_spmf: "map_spmf f p = (p \<bind> (\<lambda>x. return_spmf (f x)))"
   626 by(simp add: map_pmf_def bind_spmf_def)(rule bind_pmf_cong, simp_all split: option.split)
   627 
   (* Congruence rule for bind_spmf: the continuations only need to agree on the
      support of q. *)
   628 lemma bind_spmf_cong:
   629   "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q \<Longrightarrow> f x = g x \<rbrakk>
   630   \<Longrightarrow> bind_spmf p f = bind_spmf q g"
   631 by(auto simp add: bind_spmf_def in_set_spmf intro: bind_pmf_cong option.case_cong)
   632 
   (* Variant of bind_spmf_cong using =simp=> (simp_implies) for the support premise;
      derived from bind_spmf_cong. *)
   633 lemma bind_spmf_cong_simp:
   634   "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q =simp=> f x = g x \<rbrakk>
   635   \<Longrightarrow> bind_spmf p f = bind_spmf q g"
   636 by(simp add: simp_implies_def cong: bind_spmf_cong)
   637 
   (* Support of a bind: the set-monad bind of the support of M with set_spmf o f. *)
   638 lemma set_bind_spmf: "set_spmf (M \<bind> f) = set_spmf M \<bind> (set_spmf \<circ> f)"
   639 by(auto simp add: set_spmf_def bind_spmf_def bind_UNION split: option.splits)
   640 
   (* Binding into the constantly failing continuation yields the failing spmf. *)
   641 lemma bind_spmf_const_return_None [simp]: "bind_spmf p (\<lambda>_. return_pmf None) = return_pmf None"
   642 by(simp add: bind_spmf_def case_option_collapse)
   643 
   (* Independent spmf binds commute. Both sides are rewritten into nested bind_pmf
      with the two-argument continuation ?f, which fails if either argument is None;
      bind_commute_pmf then swaps the order. *)
   644 lemma bind_commute_spmf:
   645   "bind_spmf p (\<lambda>x. bind_spmf q (f x)) = bind_spmf q (\<lambda>y. bind_spmf p (\<lambda>x. f x y))"
   646   (is "?lhs = ?rhs")
   647 proof -
   648   let ?f = "\<lambda>x y. case x of None \<Rightarrow> return_pmf None | Some a \<Rightarrow> (case y of None \<Rightarrow> return_pmf None | Some b \<Rightarrow> f a b)"
   649   have "?lhs = p \<bind> (\<lambda>x. q \<bind> ?f x)"
   650     unfolding bind_spmf_def by(rule bind_pmf_cong[OF refl])(simp split: option.split)
   651   also have "\<dots> = q \<bind> (\<lambda>y. p \<bind> (\<lambda>x. ?f x y))" by(rule bind_commute_pmf)
   652   also have "\<dots> = ?rhs" unfolding bind_spmf_def
   653     by(rule bind_pmf_cong[OF refl])(auto split: option.split, metis bind_spmf_const_return_None bind_spmf_def)
   654   finally show ?thesis .
   655 qed
   656 
   657 subsection \<open>Relator\<close>
   658 
   (* Relator on spmfs: rel_pmf lifted through rel_option R on the option values. *)
   659 abbreviation rel_spmf :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a spmf \<Rightarrow> 'b spmf \<Rightarrow> bool"
   660 where "rel_spmf R \<equiv> rel_pmf (rel_option R)"
   661 
   (* rel_pmf is monotone in the element relation (rule form of pmf.rel_mono). *)
   662 lemma rel_pmf_mono:
   663   "\<lbrakk>rel_pmf A f g; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> rel_pmf B f g"
   664 using pmf.rel_mono[of A B] by(simp add: le_fun_def)
   665 
   (* rel_spmf is monotone in the element relation. *)
   666 lemma rel_spmf_mono:
   667   "\<lbrakk>rel_spmf A f g; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> rel_spmf B f g"
   668 apply(erule rel_pmf_mono)
   669 using option.rel_mono[of A B] by(simp add: le_fun_def)
   670 
   (* Strong monotonicity: A x y only needs to imply B x y for x in the support of f
      and y in the support of g. *)
   671 lemma rel_spmf_mono_strong:
   672   "\<lbrakk> rel_spmf A f g; \<And>x y. \<lbrakk> A x y; x \<in> set_spmf f; y \<in> set_spmf g \<rbrakk> \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> rel_spmf B f g"
   673 apply(erule pmf.rel_mono_strong)
   674 apply(erule option.rel_mono_strong)
   675 apply(auto simp add: in_set_spmf)
   676 done
   677 
   (* rel_spmf P relates p to itself when P is reflexive on the support of p. *)
   678 lemma rel_spmf_reflI: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> P x x) \<Longrightarrow> rel_spmf P p p"
   679 by(rule rel_pmf_reflI)(auto simp add: set_spmf_def intro: rel_option_reflI)
   680 
   (* Introduction rule: an spmf coupling pq whose support satisfies P and whose
      marginals are p and q. It is turned into a pmf coupling on option pairs by mapping
      None to (None, None) and Some (a, b) to (Some a, Some b). *)
   681 lemma rel_spmfI [intro?]:
   682   "\<lbrakk> \<And>x y. (x, y) \<in> set_spmf pq \<Longrightarrow> P x y; map_spmf fst pq = p; map_spmf snd pq = q \<rbrakk>
   683   \<Longrightarrow> rel_spmf P p q"
   684 by(rule rel_pmf.intros[where pq="map_pmf (\<lambda>x. case x of None \<Rightarrow> (None, None) | Some (a, b) \<Rightarrow> (Some a, Some b)) pq"])
   685   (auto simp add: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong)
   686 
   (* Elimination rule (converse of rel_spmfI): from rel_spmf P p q obtain an spmf
      coupling pq whose support satisfies P and whose marginals are p and q. The pmf
      coupling of option pairs is converted by sending (Some x, Some y) to Some (x, y)
      and every other pair to None. *)
   687 lemma rel_spmfE [elim?, consumes 1, case_names rel_spmf]:
   688   assumes "rel_spmf P p q"
   689   obtains pq where
   690     "\<And>x y. (x, y) \<in> set_spmf pq \<Longrightarrow> P x y"
   691     "p = map_spmf fst pq"
   692     "q = map_spmf snd pq"
   693 using assms
   694 proof(cases rule: rel_pmf.cases[consumes 1, case_names rel_pmf])
   695   case (rel_pmf pq)
   696   let ?pq = "map_pmf (\<lambda>(a, b). case (a, b) of (Some x, Some y) \<Rightarrow> Some (x, y) | _ \<Rightarrow> None) pq"
   697   have "\<And>x y. (x, y) \<in> set_spmf ?pq \<Longrightarrow> P x y"
   698     by(auto simp add: in_set_spmf split: option.split_asm dest: rel_pmf(1))
   699   moreover
         (* pq never pairs a Some with None, so the first marginal is preserved *)
   700   have "\<And>x. (x, None) \<in> set_pmf pq \<Longrightarrow> x = None" by(auto dest!: rel_pmf(1))
   701   then have "p = map_spmf fst ?pq" using rel_pmf(2)
   702     by(auto simp add: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split)
   703   moreover
         (* symmetrically for the second marginal *)
   704   have "\<And>y. (None, y) \<in> set_pmf pq \<Longrightarrow> y = None" by(auto dest!: rel_pmf(1))
   705   then have "q = map_spmf snd ?pq" using rel_pmf(3)
   706     by(auto simp add: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split)
   707   ultimately show thesis ..
   708 qed
   709 
   (* Characterisation of rel_spmf by spmf couplings; combines rel_spmfI and rel_spmfE. *)
   710 lemma rel_spmf_simps:
   711   "rel_spmf R p q \<longleftrightarrow> (\<exists>pq. (\<forall>(x, y)\<in>set_spmf pq. R x y) \<and> map_spmf fst pq = p \<and> map_spmf snd pq = q)"
   712 by(auto intro: rel_spmfI elim!: rel_spmfE)
   713 
   (* A map_spmf on either argument of rel_spmf can be absorbed into the relation. *)
   714 lemma spmf_rel_map:
   715   shows spmf_rel_map1: "\<And>R f x. rel_spmf R (map_spmf f x) = rel_spmf (\<lambda>x. R (f x)) x"
   716   and spmf_rel_map2: "\<And>R x g y. rel_spmf R x (map_spmf g y) = rel_spmf (\<lambda>x y. R x (g y)) x y"
   717 by(simp_all add: fun_eq_iff pmf.rel_map option.rel_map[abs_def])
   718 
   (* rel_spmf commutes with taking the converse of the relation. *)
   719 lemma spmf_rel_conversep: "rel_spmf R\<inverse>\<inverse> = (rel_spmf R)\<inverse>\<inverse>"
   720 by(simp add: option.rel_conversep pmf.rel_conversep)
   721 
   (* rel_spmf of equality is equality. *)
   722 lemma spmf_rel_eq: "rel_spmf op = = op ="
   723 by(simp add: pmf.rel_eq option.rel_eq)
   724 
   (* Parametricity of the basic spmf operations with respect to rel_spmf, proved by
      transfer_prover. *)
   725 context includes lifting_syntax
   726 begin
   727 
       (* bind_spmf is parametric; registered as a transfer rule *)
   728 lemma bind_spmf_parametric [transfer_rule]:
   729   "(rel_spmf A ===> (A ===> rel_spmf B) ===> rel_spmf B) bind_spmf bind_spmf"
   730 unfolding bind_spmf_def[abs_def] by transfer_prover
   731 
       (* NOTE(review): the next three lemmas are not declared [transfer_rule], unlike
          bind_spmf_parametric and set_spmf_parametric; confirm that this is intended. *)
   732 lemma return_spmf_parametric: "(A ===> rel_spmf A) return_spmf return_spmf"
   733 by transfer_prover
   734 
   735 lemma map_spmf_parametric: "((A ===> B) ===> rel_spmf A ===> rel_spmf B) map_spmf map_spmf"
   736 by transfer_prover
   737 
   738 lemma rel_spmf_parametric:
   739   "((A ===> B ===> op =) ===> rel_spmf A ===> rel_spmf B ===> op =) rel_spmf rel_spmf"
   740 by transfer_prover
   741 
       (* the support set_spmf is parametric; registered as a transfer rule *)
   742 lemma set_spmf_parametric [transfer_rule]:
   743   "(rel_spmf A ===> rel_set A) set_spmf set_spmf"
   744 unfolding set_spmf_def[abs_def] by transfer_prover
   745 
       (* the failing spmf is related to itself for every element relation A *)
   746 lemma return_spmf_None_parametric:
   747   "(rel_spmf A) (return_pmf None) (return_pmf None)"
   748 by simp
   749 
   750 end
   751 
   (* Binds of related spmfs with pointwise related continuations are related; rule
      form of bind_spmf_parametric. *)
   752 lemma rel_spmf_bindI:
   753   "\<lbrakk> rel_spmf R p q; \<And>x y. R x y \<Longrightarrow> rel_spmf P (f x) (g y) \<rbrakk>
   754   \<Longrightarrow> rel_spmf P (p \<bind> f) (q \<bind> g)"
   755 by(fact bind_spmf_parametric[THEN rel_funD, THEN rel_funD, OF _ rel_funI])
   756 
   (* Binding the same p with continuations that are related on the support of p gives
      related results. Instance of rel_spmf_bindI with R the equality restricted to the
      support. *)
   757 lemma rel_spmf_bind_reflI:
   758   "(\<And>x. x \<in> set_spmf p \<Longrightarrow> rel_spmf P (f x) (g x)) \<Longrightarrow> rel_spmf P (p \<bind> f) (p \<bind> g)"
   759 by(rule rel_spmf_bindI[where R="\<lambda>x y. x = y \<and> x \<in> set_spmf p"])(auto intro: rel_spmf_reflI)
   760 
   (* Dirac distributions of related points are related; the coupling is
      return_pmf (x, y). *)
   761 lemma rel_pmf_return_pmfI: "P x y \<Longrightarrow> rel_pmf P (return_pmf x) (return_pmf y)"
   762 by(rule rel_pmf.intros[where pq="return_pmf (x, y)"])(simp_all)
   763 
   (* Parametricity of the probability of a set, for pmfs and spmfs, with sets related
      by rel_pred. *)
   764 context includes lifting_syntax
   765 begin
   766 
   767 text \<open>We do not yet have a relator for @{typ "'a measure"}, so we combine @{const measure} and @{const measure_pmf}\<close>
       (* Proof idea: take the coupling pq, rewrite both measures as measures of map_pmf
          fst/snd pq, and show that the two events coincide almost everywhere on pq. *)
   768 lemma measure_pmf_parametric:
   769   "(rel_pmf A ===> rel_pred A ===> op =) (\<lambda>p. measure (measure_pmf p)) (\<lambda>q. measure (measure_pmf q))"
   770 proof(rule rel_funI)+
   771   fix p q X Y
   772   assume "rel_pmf A p q" and "rel_pred A X Y"
   773   from this(1) obtain pq where A: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> A x y"
   774     and p: "p = map_pmf fst pq" and q: "q = map_pmf snd pq" by cases auto
   775   show "measure p X = measure q Y" unfolding p q measure_map_pmf
   776     by(rule measure_pmf.finite_measure_eq_AE)(auto simp add: AE_measure_pmf_iff dest!: A rel_predD[OF \<open>rel_pred _ _ _\<close>])
   777 qed
   778 
       (* spmf version, reduced to measure_pmf_parametric via
          measure_measure_spmf_conv_measure_pmf *)
   779 lemma measure_spmf_parametric:
   780   "(rel_spmf A ===> rel_pred A ===> op =) (\<lambda>p. measure (measure_spmf p)) (\<lambda>q. measure (measure_spmf q))"
   781 unfolding measure_measure_spmf_conv_measure_pmf[abs_def]
   782 apply(rule rel_funI)+
   783 apply(erule measure_pmf_parametric[THEN rel_funD, THEN rel_funD])
   784 apply(auto simp add: rel_pred_def rel_fun_def elim: option.rel_cases)
   785 done
   786 
   787 end
   788 
   789 subsection \<open>From @{typ "'a pmf"} to @{typ "'a spmf"}\<close>
   790 
   (* Embeds a pmf into the spmfs by wrapping every outcome in Some. *)
   791 definition spmf_of_pmf :: "'a pmf \<Rightarrow> 'a spmf"
   792 where "spmf_of_pmf = map_pmf Some"
   793 
   (* The embedding preserves the support. *)
   794 lemma set_spmf_spmf_of_pmf [simp]: "set_spmf (spmf_of_pmf p) = set_pmf p"
   795 by(auto simp add: spmf_of_pmf_def set_spmf_def bind_image o_def)
   796 
   (* The embedding preserves point probabilities. *)
   797 lemma spmf_spmf_of_pmf [simp]: "spmf (spmf_of_pmf p) x = pmf p x"
   798 by(simp add: spmf_of_pmf_def)
   799 
   (* An embedded pmf never fails: None has probability 0. *)
   800 lemma pmf_spmf_of_pmf_None [simp]: "pmf (spmf_of_pmf p) None = 0"
   801 using ennreal_pmf_map[of Some p None] by(simp add: spmf_of_pmf_def)
   802 
   (* emeasure of the embedded pmf agrees with that of the original pmf on every set. *)
   803 lemma emeasure_spmf_of_pmf [simp]: "emeasure (measure_spmf (spmf_of_pmf p)) A = emeasure (measure_pmf p) A"
   804 by(simp add: emeasure_measure_spmf_conv_measure_pmf spmf_of_pmf_def inj_vimage_image_eq)
   805 
   (* measure_spmf of an embedded pmf is measure_pmf of the original, by measure_eqI. *)
   806 lemma measure_spmf_spmf_of_pmf [simp]: "measure_spmf (spmf_of_pmf p) = measure_pmf p"
   807 by(rule measure_eqI) simp_all
   808 
   (* map_spmf on an embedded pmf is the embedding of map_pmf. *)
   809 lemma map_spmf_of_pmf [simp]: "map_spmf f (spmf_of_pmf p) = spmf_of_pmf (map_pmf f p)"
   810 by(simp add: spmf_of_pmf_def pmf.map_comp o_def)
   811 
   (* rel_spmf on embedded pmfs reduces to rel_pmf on the originals. *)
   812 lemma rel_spmf_spmf_of_pmf [simp]: "rel_spmf R (spmf_of_pmf p) (spmf_of_pmf q) = rel_pmf R p q"
   813 by(simp add: spmf_of_pmf_def pmf.rel_map)
   814 
   (* Embedding a Dirac pmf gives return_spmf. *)
   815 lemma spmf_of_pmf_return_pmf [simp]: "spmf_of_pmf (return_pmf x) = return_spmf x"
   816 by(simp add: spmf_of_pmf_def)
   817 
   (* bind_spmf on an embedded pmf is a plain bind_pmf. *)
   818 lemma bind_spmf_of_pmf [simp]: "bind_spmf (spmf_of_pmf p) f = bind_pmf p f"
   819 by(simp add: spmf_of_pmf_def bind_spmf_def bind_map_pmf)
   820 
   (* Support of a bind_pmf with spmf-valued continuation, derived from set_bind_spmf
      via bind_spmf_of_pmf. *)
   821 lemma set_spmf_bind_pmf: "set_spmf (bind_pmf p f) = Set.bind (set_pmf p) (set_spmf \<circ> f)"
   822 unfolding bind_spmf_of_pmf[symmetric] by(subst set_bind_spmf) simp
   823 
   (* The embedding distributes into the continuation of bind_pmf. *)
   824 lemma spmf_of_pmf_bind: "spmf_of_pmf (bind_pmf p f) = bind_pmf p (\<lambda>x. spmf_of_pmf (f x))"
   825 by(simp add: spmf_of_pmf_def map_bind_pmf)
   826 
   (* Binding a pmf with return_spmf o f is the embedding of map_pmf f p. *)
   827 lemma bind_pmf_return_spmf: "p \<bind> (\<lambda>x. return_spmf (f x)) = spmf_of_pmf (map_pmf f p)"
   828 by(simp add: map_pmf_def spmf_of_pmf_bind)
   829 
   830 subsection \<open>Weight of a subprobability\<close>
   831 
   (* Weight of an spmf: the measure of the whole space under measure_spmf p, i.e. the
      total success probability. *)
   832 abbreviation weight_spmf :: "'a spmf \<Rightarrow> real"
   833 where "weight_spmf p \<equiv> measure (measure_spmf p) (space (measure_spmf p))"
   834 
   (* Unfolding rule for the abbreviation, with the space written as UNIV. *)
   835 lemma weight_spmf_def: "weight_spmf p = measure (measure_spmf p) UNIV"
   836 by(simp add: space_measure_spmf)
   837 
   (* The weight is at most 1, since measure_spmf p is a subprobability measure. *)
   838 lemma weight_spmf_le_1: "weight_spmf p \<le> 1"
   839 by(simp add: measure_spmf.subprob_measure_le_1)
   840 
   (* return_spmf never fails, so its weight is 1. *)
   841 lemma weight_return_spmf [simp]: "weight_spmf (return_spmf x) = 1"
   842 by(simp add: measure_spmf_return_spmf measure_return)
   843 
   (* The always-failing spmf has weight 0. *)
   844 lemma weight_return_pmf_None [simp]: "weight_spmf (return_pmf None) = 0"
   845 by(simp)
   846 
   (* map_spmf preserves the weight. *)
   847 lemma weight_map_spmf [simp]: "weight_spmf (map_spmf f p) = weight_spmf p"
   848 by(simp add: weight_spmf_def measure_map_spmf)
   849 
   (* Embedded pmfs have weight 1, because measure_pmf p is a probability space. *)
   850 lemma weight_spmf_of_pmf [simp]: "weight_spmf (spmf_of_pmf p) = 1"
   851 using measure_pmf.prob_space[of p] by(simp add: spmf_of_pmf_def weight_spmf_def)
   852 
   (* The weight is nonnegative (it is a measure value). *)
   853 lemma weight_spmf_nonneg: "weight_spmf p \<ge> 0"
   854 by(fact measure_nonneg)
   855 
   (* In a finite measure space, a measurable family of weights is integrable because it
      is bounded by 1. *)
   856 lemma (in finite_measure) integrable_weight_spmf [simp]:
   857   "(\<lambda>x. weight_spmf (f x)) \<in> borel_measurable M \<Longrightarrow> integrable M (\<lambda>x. weight_spmf (f x))"
   858 by(rule integrable_const_bound[where B=1])(simp_all add: weight_spmf_nonneg weight_spmf_le_1)
   859 
   (* The weight as the (ennreal) sum of all point success probabilities. *)
   860 lemma weight_spmf_eq_nn_integral_spmf: "weight_spmf p = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
   861 by(simp add: measure_measure_spmf_conv_measure_pmf space_measure_spmf measure_pmf.emeasure_eq_measure[symmetric] nn_integral_pmf[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1)
   862 
   (* Same as weight_spmf_eq_nn_integral_spmf, summing only over the support. *)
   863 lemma weight_spmf_eq_nn_integral_support:
   864   "weight_spmf p = \<integral>\<^sup>+ x. spmf p x \<partial>count_space (set_spmf p)"
   865 unfolding weight_spmf_eq_nn_integral_spmf
   866 by(auto simp add: nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator)
   867 
   (* The failure probability is 1 minus the weight. The weight plus the None mass is
      the total mass of p over all option values, which is 1. *)
   868 lemma pmf_None_eq_weight_spmf: "pmf p None = 1 - weight_spmf p"
   869 proof -
   870   have "weight_spmf p = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV" by(rule weight_spmf_eq_nn_integral_spmf)
         (* reindex the sum over 'a as a sum over range Some *)
   871   also have "\<dots> = \<integral>\<^sup>+ x. ennreal (pmf p x) * indicator (range Some) x \<partial>count_space UNIV"
   872     by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1)
         (* add the None mass as a second indicator summand *)
   873   also have "\<dots> + pmf p None = \<integral>\<^sup>+ x. ennreal (pmf p x) * indicator (range Some) x + ennreal (pmf p None) * indicator {None} x \<partial>count_space UNIV"
   874     by(subst nn_integral_add)(simp_all add: max_def)
   875   also have "\<dots> = \<integral>\<^sup>+ x. pmf p x \<partial>count_space UNIV"
   876     by(rule nn_integral_cong)(auto split: split_indicator)
   877   also have "\<dots> = 1" by (simp add: nn_integral_pmf)
   878   finally show ?thesis by(simp add: ennreal_plus[symmetric] del: ennreal_plus)
   879 qed
   880 
   (* Rearranged form of pmf_None_eq_weight_spmf. *)
   881 lemma weight_spmf_conv_pmf_None: "weight_spmf p = 1 - pmf p None"
   882 by(simp add: pmf_None_eq_weight_spmf)
   883 
   (* Since the weight is nonnegative, weight <= 0 means weight = 0. *)
   884 lemma weight_spmf_le_0: "weight_spmf p \<le> 0 \<longleftrightarrow> weight_spmf p = 0"
   885 by(rule measure_le_0_iff)
   886 
   (* The weight is never negative. *)
   887 lemma weight_spmf_lt_0: "\<not> weight_spmf p < 0"
   888 by(simp add: not_less weight_spmf_nonneg)
   889 
   (* Each point probability is bounded by the weight: a single summand of the
      nonnegative sum in weight_spmf_eq_nn_integral_spmf. *)
   890 lemma spmf_le_weight: "spmf p x \<le> weight_spmf p"
   891 proof -
   892   have "ennreal (spmf p x) \<le> weight_spmf p"
   893     unfolding weight_spmf_eq_nn_integral_spmf by(rule nn_integral_ge_point) simp
   894   then show ?thesis by simp
   895 qed
   896 
   (* Weight 0 characterises the always-failing spmf. *)
   897 lemma weight_spmf_eq_0: "weight_spmf p = 0 \<longleftrightarrow> p = return_pmf None"
   898 by(auto intro!: pmf_eqI simp add: pmf_None_eq_weight_spmf split: split_indicator)(metis not_Some_eq pmf_le_0_iff spmf_le_weight)
   899 
   (* Weight of a bind: the integral of the continuation weights over measure_spmf x;
      follows from measure_spmf_bind. *)
   900 lemma weight_bind_spmf: "weight_spmf (x \<bind> f) = lebesgue_integral (measure_spmf x) (weight_spmf \<circ> f)"
   901 unfolding weight_spmf_def
   902 by(simp add: measure_spmf_bind o_def measure_spmf.measure_bind[where N="count_space UNIV"])
   903 
   (* Related spmfs have equal weight: both are map_spmf images of one coupling. *)
   904 lemma rel_spmf_weightD: "rel_spmf A p q \<Longrightarrow> weight_spmf p = weight_spmf q"
   905 by(erule rel_spmfE) simp
   906 
   (* If f is a bijection between the supports of p and q that preserves point
      probabilities, then p and q are related by the graph of f. *)
   907 lemma rel_spmf_bij_betw:
   908   assumes f: "bij_betw f (set_spmf p) (set_spmf q)"
   909   and eq: "\<And>x. x \<in> set_spmf p \<Longrightarrow> spmf p x = spmf q (f x)"
   910   shows "rel_spmf (\<lambda>x y. f x = y) p q"
   911 proof -
   912   let ?f = "map_option f"
   913 
         (* equal weights: reindex the sum over the support of q along f *)
   914   have weq: "ennreal (weight_spmf p) = ennreal (weight_spmf q)"
   915     unfolding weight_spmf_eq_nn_integral_support
   916     by(subst nn_integral_bij_count_space[OF f, symmetric])(rule nn_integral_cong_AE, simp add: eq AE_count_space)
         (* hence p can fail iff q can fail *)
   917   then have "None \<in> set_pmf p \<longleftrightarrow> None \<in> set_pmf q"
   918     by(simp add: pmf_None_eq_weight_spmf set_pmf_iff)
         (* so map_option f is a bijection between the full pmf supports *)
   919   with f have "bij_betw (map_option f) (set_pmf p) (set_pmf q)"
   920     apply(auto simp add: bij_betw_def in_set_spmf inj_on_def intro: option.expand)
   921     apply(rename_tac [!] x)
   922     apply(case_tac [!] x)
   923     apply(auto iff: in_set_spmf)
   924     done
         (* relate the pmfs by the graph of map_option f, then weaken to rel_option *)
   925   then have "rel_pmf (\<lambda>x y. ?f x = y) p q"
   926     by(rule rel_pmf_bij_betw)(case_tac x, simp_all add: weq[simplified] eq in_set_spmf pmf_None_eq_weight_spmf)
   927   thus ?thesis by(rule pmf.rel_mono_strong)(auto intro!: rel_optionI simp add: Option.is_none_def)
   928 qed
   929 
   930 subsection \<open>From density to spmfs\<close>
   931 
   (* Construction of an spmf from a real-valued density f: Some x gets max 0 (f x),
      and None gets 1 minus enn2real of the total nonnegative integral of f. *)
   932 context fixes f :: "'a \<Rightarrow> real" begin
   933 
   934 definition embed_spmf :: "'a spmf"
   935 where "embed_spmf = embed_pmf (\<lambda>x. case x of None \<Rightarrow> 1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV) | Some x' \<Rightarrow> max 0 (f x'))"
   936 
       (* The lemmas below assume that the total mass of f is at most 1. *)
   937 context
   938   assumes prob: "(\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV) \<le> 1"
   939 begin
   940 
       (* The density passed to embed_pmf in embed_spmf sums to 1. *)
   941 lemma nn_integral_embed_spmf_eq_1:
   942   "(\<integral>\<^sup>+ x. ennreal (case x of None \<Rightarrow> 1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV) | Some x' \<Rightarrow> max 0 (f x')) \<partial>count_space UNIV) = 1"
   943   (is "?lhs = _" is "(\<integral>\<^sup>+ x. ?f x \<partial>?M) = _")
   944 proof -
         (* split into the None point and the range Some part *)
   945   have "?lhs = \<integral>\<^sup>+ x. ?f x * indicator {None} x + ?f x * indicator (range Some) x \<partial>?M"
   946     by(rule nn_integral_cong)(auto split: split_indicator)
         (* NOTE(review): the next proof mentions real_le_ereal_iff and one_ereal_def
            (ereal facts) although the goal is about ennreal; possibly leftovers from an
            ereal version -- confirm before removing them. *)
   947   also have "\<dots> = (1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV)) + \<integral>\<^sup>+ x. ?f x * indicator (range Some) x \<partial>?M"
   948     (is "_ = ?None + ?Some")
   949     by(subst nn_integral_add)(simp_all add: AE_count_space max_def le_diff_eq real_le_ereal_iff one_ereal_def[symmetric] prob split: option.split)
         (* the Some part is the integral of f over 'a, reindexed through embed_measure *)
   950   also have "?Some = \<integral>\<^sup>+ x. ?f x \<partial>count_space (range Some)"
   951     by(simp add: nn_integral_count_space_indicator)
   952   also have "count_space (range Some) = embed_measure (count_space UNIV) Some"
   953     by(simp add: embed_measure_count_space)
   954   also have "(\<integral>\<^sup>+ x. ?f x \<partial>\<dots>) = \<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV"
   955     by(subst nn_integral_embed_measure)(simp_all add: measurable_embed_measure1)
         (* (1 - total) + total = 1, using the assumption prob that the total is at most 1 *)
   956   also have "?None + \<dots> = 1" using prob
   957     by(auto simp add: ennreal_minus[symmetric] ennreal_1[symmetric] ennreal_enn2real_if top_unique simp del: ennreal_1)(simp add: diff_add_self_ennreal)
   958   finally show ?thesis .
   959 qed
   960 
       (* Failure probability of embed_spmf: 1 minus the total mass of f. *)
   961 lemma pmf_embed_spmf_None: "pmf embed_spmf None = 1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV)"
   962 unfolding embed_spmf_def
   963 apply(subst pmf_embed_pmf)
         (* side conditions of pmf_embed_pmf: nonnegativity, then total mass 1 *)
   964   subgoal using prob by(simp add: field_simps enn2real_leI split: option.split)
   965  apply(rule nn_integral_embed_spmf_eq_1)
   966 apply simp
   967 done
   968 
       (* Success probability of embed_spmf at x: the density clamped at 0. *)
   969 lemma spmf_embed_spmf [simp]: "spmf embed_spmf x = max 0 (f x)"
   970 unfolding embed_spmf_def
   971 apply(subst pmf_embed_pmf)
         (* same side conditions as in pmf_embed_spmf_None *)
   972   subgoal using prob by(simp add: field_simps enn2real_leI split: option.split)
   973  apply(rule nn_integral_embed_spmf_eq_1)
   974 apply simp
   975 done
   976 
   977 end
   978 
   979 end
   980 
   (* The zero density yields the always-failing spmf, shown pointwise via spmf_eqI.
      NOTE(review): the (is ...) pattern is unused and zero_ereal_def is an ereal fact in
      an otherwise ennreal development; possibly leftovers -- confirm before removing. *)
   981 lemma embed_spmf_K_0[simp]: "embed_spmf (\<lambda>_. 0) = return_pmf None" (is "?lhs = ?rhs")
   982 by(rule spmf_eqI)(simp add: zero_ereal_def[symmetric])
   983 
   984 subsection \<open>Ordering on spmfs\<close>
   985 
   986 text \<open>
   987   @{const rel_pmf} does not preserve a ccpo structure. Counterexample by Saheb-Djahromi:
   988   Take prefix order over \<open>bool llist\<close> and
   989   the set \<open>range (\<lambda>n :: nat. uniform (llist_n n))\<close> where \<open>llist_n\<close> is the set
   990   of all \<open>llist\<close>s of length \<open>n\<close> and \<open>uniform\<close> returns a uniform distribution over
   991   the given set. The set forms a chain in \<open>ord_pmf lprefix\<close>, but it has no upper bound.
   992   Any upper bound may contain only infinite lists in its support because otherwise it is not greater
   993   than the \<open>n+1\<close>-st element in the chain where \<open>n\<close> is the length of the finite list.
   994   Moreover its support must contain all infinite lists, because otherwise there is a finite list
   995   all of whose finite extensions are not in the support - a contradiction to the upper bound property.
   996   Hence, the support is uncountable, but pmfs have only countable support.
   997 
   998   However, if all chains in the ccpo are finite, then it should preserve the ccpo structure.
   999 \<close>
  1000 
  (* Ordering on spmfs: rel_pmf lifted through ord_option ord on the option values. *)
  1001 abbreviation ord_spmf :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a spmf \<Rightarrow> 'a spmf \<Rightarrow> bool"
  1002 where "ord_spmf ord \<equiv> rel_pmf (ord_option ord)"
  1003 
  (* Interpreting this locale enables the indexed infix notation \<sqsubseteq> for ord_spmf;
     the order is supplied as a structure argument (see the context below). *)
  1004 locale ord_spmf_syntax begin
  1005 notation ord_spmf (infix "\<sqsubseteq>\<index>" 60)
  1006 end
  1007 
  (* map_spmf on the left argument of ord_spmf can be absorbed into the order. *)
  1008 lemma ord_spmf_map_spmf1: "ord_spmf R (map_spmf f p) = ord_spmf (\<lambda>x. R (f x)) p"
  1009 by(simp add: pmf.rel_map[abs_def] ord_option_map1[abs_def])
  1010 
       (* map_spmf on the right argument. *)
  1011 lemma ord_spmf_map_spmf2: "ord_spmf R p (map_spmf f q) = ord_spmf (\<lambda>x y. R x (f y)) p q"
  1012 by(simp add: pmf.rel_map ord_option_map2)
  1013 
       (* the same function mapped over both arguments *)
  1014 lemma ord_spmf_map_spmf12: "ord_spmf R (map_spmf f p) (map_spmf f q) = ord_spmf (\<lambda>x y. R (f x) (f y)) p q"
  1015 by(simp add: pmf.rel_map ord_option_map1[abs_def] ord_option_map2)
  1016 
       (* collection of the three rewrite rules above *)
  1017 lemmas ord_spmf_map_spmf = ord_spmf_map_spmf1 ord_spmf_map_spmf2 ord_spmf_map_spmf12
  1018 
  (* Lemmas about ord_spmf for a fixed element order ord, written with the \<sqsubseteq> syntax
     from ord_spmf_syntax (ord is the structure argument). *)
  1019 context fixes ord :: "'a \<Rightarrow> 'a \<Rightarrow> bool" (structure) begin
  1020 interpretation ord_spmf_syntax .
  1021 
       (* introduction via a coupling whose support is ordered by ord (cf. rel_spmfI) *)
  1022 lemma ord_spmfI:
  1023   "\<lbrakk> \<And>x y. (x, y) \<in> set_spmf pq \<Longrightarrow> ord x y; map_spmf fst pq = p; map_spmf snd pq = q \<rbrakk>
  1024   \<Longrightarrow> p \<sqsubseteq> q"
  1025 by(rule rel_pmf.intros[where pq="map_pmf (\<lambda>x. case x of None \<Rightarrow> (None, None) | Some (a, b) \<Rightarrow> (Some a, Some b)) pq"])
  1026   (auto simp add: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong)
  1027 
       (* the always-failing spmf is below every spmf *)
  1028 lemma ord_spmf_None [simp]: "return_pmf None \<sqsubseteq> x"
  1029 by(rule rel_pmf.intros[where pq="map_pmf (Pair None) x"])(auto simp add: pmf.map_comp o_def)
  1030 
       (* reflexivity, given ord is reflexive on the support of p *)
  1031 lemma ord_spmf_reflI: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> ord x x) \<Longrightarrow> p \<sqsubseteq> p"
  1032 by(rule rel_pmf_reflI ord_option_reflI)+(auto simp add: in_set_spmf)
  1033 
       (* If p and q are below each other for a reflexive, transitive ord, they are related
          by rel_spmf of the symmetric part inf ord ord\<inverse>\<inverse>. *)
  1034 lemma rel_spmf_inf:
  1035   assumes "p \<sqsubseteq> q"
  1036   and "q \<sqsubseteq> p"
  1037   and refl: "reflp ord"
  1038   and trans: "transp ord"
  1039   shows "rel_spmf (inf ord ord\<inverse>\<inverse>) p q"
  1040 proof -
         (* apply rel_pmf_inf at the level of ord_option ord *)
  1041   from \<open>p \<sqsubseteq> q\<close> \<open>q \<sqsubseteq> p\<close>
  1042   have "rel_pmf (inf (ord_option ord) (ord_option ord)\<inverse>\<inverse>) p q"
  1043     by(rule rel_pmf_inf)(blast intro: reflp_ord_option transp_ord_option refl trans)+
         (* the symmetric part of ord_option ord is rel_option of the symmetric part of ord *)
  1044   also have "inf (ord_option ord) (ord_option ord)\<inverse>\<inverse> = rel_option (inf ord ord\<inverse>\<inverse>)"
  1045     by(auto simp add: fun_eq_iff elim: ord_option.cases option.rel_cases)
  1046   finally show ?thesis .
  1047 qed
  1048 
  1049 end
  1050 
  (* p is below return_spmf y iff every element in the support of p is R-below y. *)
  1051 lemma ord_spmf_return_spmf2: "ord_spmf R p (return_spmf y) \<longleftrightarrow> (\<forall>x\<in>set_spmf p. R x y)"
  1052 by(auto simp add: rel_pmf_return_pmf2 in_set_spmf ord_option.simps intro: ccontr)
  1053 
  (* ord_spmf is monotone in the element order. *)
  1054 lemma ord_spmf_mono: "\<lbrakk> ord_spmf A p q; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> ord_spmf B p q"
  1055 by(erule rel_pmf_mono)(erule ord_option_mono)
  1056 
  (* ord_spmf distributes over relation composition OO. *)
  1057 lemma ord_spmf_compp: "ord_spmf (A OO B) = ord_spmf A OO ord_spmf B"
  1058 by(simp add: ord_option_compp pmf.rel_compp)
  1059 
  (* Monotonicity of bind: ordered first arguments and pointwise ordered continuations
     give ordered binds; via rel_pmf_bindI after unfolding bind_spmf. *)
  1060 lemma ord_spmf_bindI:
  1061   assumes pq: "ord_spmf R p q"
  1062   and fg: "\<And>x y. R x y \<Longrightarrow> ord_spmf P (f x) (g y)"
  1063   shows "ord_spmf P (p \<bind> f) (q \<bind> g)"
  1064 unfolding bind_spmf_def using pq
  1065 by(rule rel_pmf_bindI)(auto split: option.split intro: fg)
  1066 
  (* Binding the same p with continuations that are ordered on the support of p gives
     ordered results. Instance of ord_spmf_bindI with R the equality restricted to the
     support. *)
  1067 lemma ord_spmf_bind_reflI:
  1068   "(\<And>x. x \<in> set_spmf p \<Longrightarrow> ord_spmf R (f x) (g x))
  1069   \<Longrightarrow> ord_spmf R (p \<bind> f) (p \<bind> g)"
  1070 by(rule ord_spmf_bindI[where R="\<lambda>x y. x = y \<and> x \<in> set_spmf p"])(auto intro: ord_spmf_reflI)
  1071 
  1072 lemma ord_pmf_increaseI:
  1073   assumes le: "\<And>x. spmf p x \<le> spmf q x"
  1074   and refl: "\<And>x. x \<in> set_spmf p \<Longrightarrow> R x x"
  1075   shows "ord_spmf R p q"
  1076 proof(rule rel_pmf.intros)
  1077   define pq where "pq = embed_pmf
  1078     (\<lambda>(x, y). case x of Some x' \<Rightarrow> (case y of Some y' \<Rightarrow> if x' = y' then spmf p x' else 0 | None \<Rightarrow> 0)
  1079       | None \<Rightarrow> (case y of None \<Rightarrow> pmf q None | Some y' \<Rightarrow> spmf q y' - spmf p y'))"
  1080      (is "_ = embed_pmf ?f")
  1081   have nonneg: "\<And>xy. ?f xy \<ge> 0"
  1082     by(clarsimp simp add: le field_simps split: option.split)
  1083   have integral: "(\<integral>\<^sup>+ xy. ?f xy \<partial>count_space UNIV) = 1" (is "nn_integral ?M _ = _")
  1084   proof -
  1085     have "(\<integral>\<^sup>+ xy. ?f xy \<partial>count_space UNIV) =
  1086       \<integral>\<^sup>+ xy. ennreal (?f xy) * indicator {(None, None)} xy +
  1087              ennreal (?f xy) * indicator (range (\<lambda>x. (None, Some x))) xy +
  1088              ennreal (?f xy) * indicator (range (\<lambda>x. (Some x, Some x))) xy \<partial>?M"
  1089       by(rule nn_integral_cong)(auto split: split_indicator option.splits if_split_asm)
  1090     also have "\<dots> = (\<integral>\<^sup>+ xy. ?f xy * indicator {(None, None)} xy \<partial>?M) +
  1091         (\<integral>\<^sup>+ xy. ennreal (?f xy) * indicator (range (\<lambda>x. (None, Some x))) xy \<partial>?M) +
  1092         (\<integral>\<^sup>+ xy. ennreal (?f xy) * indicator (range (\<lambda>x. (Some x, Some x))) xy \<partial>?M)"
  1093       (is "_ = ?None + ?Some2 + ?Some")
  1094       by(subst nn_integral_add)(simp_all add: nn_integral_add AE_count_space le_diff_eq le split: option.split)
  1095     also have "?None = pmf q None" by simp
  1096     also have "?Some2 = \<integral>\<^sup>+ x. ennreal (spmf q x) - spmf p x \<partial>count_space UNIV"
  1097       by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1 ennreal_minus)
  1098     also have "\<dots> = (\<integral>\<^sup>+ x. spmf q x \<partial>count_space UNIV) - (\<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV)"
  1099       (is "_ = ?Some2' - ?Some2''")
  1100       by(subst nn_integral_diff)(simp_all add: le nn_integral_spmf_neq_top)
  1101     also have "?Some = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
  1102       by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1)
  1103     also have "pmf q None + (?Some2' - ?Some2'') + \<dots> = pmf q None + ?Some2'"
  1104       by(auto simp add: diff_add_self_ennreal le intro!: nn_integral_mono)
  1105     also have "\<dots> = \<integral>\<^sup>+ x. ennreal (pmf q x) * indicator {None} x + ennreal (pmf q x) * indicator (range Some) x \<partial>count_space UNIV"
  1106       by(subst nn_integral_add)(simp_all add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1)
  1107     also have "\<dots> = \<integral>\<^sup>+ x. pmf q x \<partial>count_space UNIV"
  1108       by(rule nn_integral_cong)(auto split: split_indicator)
  1109     also have "\<dots> = 1" by(simp add: nn_integral_pmf)
  1110     finally show ?thesis .
  1111   qed
  1112   note f = nonneg integral
  1113 
  1114   { fix x y
  1115     assume "(x, y) \<in> set_pmf pq"
  1116     hence "?f (x, y) \<noteq> 0" unfolding pq_def by(simp add: set_embed_pmf[OF f])
  1117     then show "ord_option R x y"
  1118       by(simp add: spmf_eq_0_set_spmf refl split: option.split_asm if_split_asm) }
  1119 
  1120   have weight_le: "weight_spmf p \<le> weight_spmf q"
  1121     by(subst ennreal_le_iff[symmetric])(auto simp add: weight_spmf_eq_nn_integral_spmf intro!: nn_integral_mono le)
  1122 
  1123   show "map_pmf fst pq = p"
  1124   proof(rule pmf_eqI)
  1125     fix i
  1126     have "ennreal (pmf (map_pmf fst pq) i) = (\<integral>\<^sup>+ y. pmf pq (i, y) \<partial>count_space UNIV)"
  1127       unfolding pq_def ennreal_pmf_map
  1128       apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric])
  1129       apply(subst pmf_embed_pmf[OF f])
  1130       apply(rule nn_integral_bij_count_space[symmetric])
  1131       apply(auto simp add: bij_betw_def inj_on_def)
  1132       done
  1133     also have "\<dots> = pmf p i"
  1134     proof(cases i)
  1135       case (Some x)
  1136       have "(\<integral>\<^sup>+ y. pmf pq (Some x, y) \<partial>count_space UNIV) = \<integral>\<^sup>+ y. pmf p (Some x) * indicator {Some x} y \<partial>count_space UNIV"
  1137         by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split)
  1138       then show ?thesis using Some by simp
  1139     next
  1140       case None
  1141       have "(\<integral>\<^sup>+ y. pmf pq (None, y) \<partial>count_space UNIV) =
  1142             (\<integral>\<^sup>+ y. ennreal (pmf pq (None, Some (the y))) * indicator (range Some) y +
  1143                    ennreal (pmf pq (None, None)) * indicator {None} y \<partial>count_space UNIV)"
  1144         by(rule nn_integral_cong)(auto split: split_indicator)
  1145       also have "\<dots> = (\<integral>\<^sup>+ y. ennreal (pmf pq (None, Some (the y))) \<partial>count_space (range Some)) + pmf pq (None, None)"
  1146         by(subst nn_integral_add)(simp_all add: nn_integral_count_space_indicator)
  1147       also have "\<dots> = (\<integral>\<^sup>+ y. ennreal (spmf q y) - ennreal (spmf p y) \<partial>count_space UNIV) + pmf q None"
  1148         by(simp add: pq_def pmf_embed_pmf[OF f] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1 ennreal_minus)
  1149       also have "(\<integral>\<^sup>+ y. ennreal (spmf q y) - ennreal (spmf p y) \<partial>count_space UNIV) =
  1150                  (\<integral>\<^sup>+ y. spmf q y \<partial>count_space UNIV) - (\<integral>\<^sup>+ y. spmf p y \<partial>count_space UNIV)"
  1151         by(subst nn_integral_diff)(simp_all add: AE_count_space le nn_integral_spmf_neq_top split: split_indicator)
  1152       also have "\<dots> = pmf p None - pmf q None"
  1153         by(simp add: pmf_None_eq_weight_spmf weight_spmf_eq_nn_integral_spmf[symmetric] ennreal_minus)
  1154       also have "\<dots> = ennreal (pmf p None) - ennreal (pmf q None)" by(simp add: ennreal_minus)
  1155       finally show ?thesis using None weight_le
  1156         by(auto simp add: diff_add_self_ennreal pmf_None_eq_weight_spmf intro: ennreal_leI)
  1157     qed
  1158     finally show "pmf (map_pmf fst pq) i = pmf p i" by simp
  1159   qed
  1160 
  1161   show "map_pmf snd pq = q"
  1162   proof(rule pmf_eqI)
  1163     fix i
  1164     have "ennreal (pmf (map_pmf snd pq) i) = (\<integral>\<^sup>+ x. pmf pq (x, i) \<partial>count_space UNIV)"
  1165       unfolding pq_def ennreal_pmf_map
  1166       apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric])
  1167       apply(subst pmf_embed_pmf[OF f])
  1168       apply(rule nn_integral_bij_count_space[symmetric])
  1169       apply(auto simp add: bij_betw_def inj_on_def)
  1170       done
  1171     also have "\<dots> = ennreal (pmf q i)"
  1172     proof(cases i)
  1173       case None
  1174       have "(\<integral>\<^sup>+ x. pmf pq (x, None) \<partial>count_space UNIV) = \<integral>\<^sup>+ x. pmf q None * indicator {None :: 'a option} x \<partial>count_space UNIV"
  1175         by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split)
  1176       then show ?thesis using None by simp
  1177     next
  1178       case (Some y)
  1179       have "(\<integral>\<^sup>+ x. pmf pq (x, Some y) \<partial>count_space UNIV) =
  1180         (\<integral>\<^sup>+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x +
  1181                ennreal (pmf pq (None, Some y)) * indicator {None} x \<partial>count_space UNIV)"
  1182         by(rule nn_integral_cong)(auto split: split_indicator)
  1183       also have "\<dots> = (\<integral>\<^sup>+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x \<partial>count_space UNIV) + pmf pq (None, Some y)"
  1184         by(subst nn_integral_add)(simp_all)
  1185       also have "\<dots> = (\<integral>\<^sup>+ x. ennreal (spmf p y) * indicator {Some y} x \<partial>count_space UNIV) + (spmf q y - spmf p y)"
  1186         by(auto simp add: pq_def pmf_embed_pmf[OF f] one_ereal_def[symmetric] simp del: nn_integral_indicator_singleton intro!: arg_cong2[where f="op +"] nn_integral_cong split: option.split)
  1187       also have "\<dots> = spmf q y" by(simp add: ennreal_minus[symmetric] le)
  1188       finally show ?thesis using Some by simp
  1189     qed
  1190     finally show "pmf (map_pmf snd pq) i = pmf q i" by simp
  1191   qed
  1192 qed
  1193 
  1194 lemma ord_spmf_eq_leD:
       (* Pointwise monotonicity: if p is below q in the flat order ord_spmf op =,
          then spmf p x <= spmf q x for every x. *)
  1195   assumes "ord_spmf op = p q"
  1196   shows "spmf p x \<le> spmf q x"
  1197 proof(cases "x \<in> set_spmf p")
  1198   case False
         (* x is outside the support of p, so its mass under p is 0 and the bound is trivial. *)
  1199   thus ?thesis by(simp add: in_set_spmf_iff_spmf)
  1200 next
  1201   case True
         (* Unfold the relation into a coupling pq whose support pairs are related by
            ord_option op = and whose two marginals are p and q. *)
  1202   from assms obtain pq
  1203     where pq: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> ord_option op = x y"
  1204     and p: "p = map_pmf fst pq"
  1205     and q: "q = map_pmf snd pq" by cases auto
         (* The mass of Some x under p is the pq-mass of pairs whose first component is
            Some x. Because support pairs satisfy ord_option op =, almost everywhere these
            are just the pair (Some x, Some x). Its mass is bounded by the mass of all pairs
            whose second component is Some x, i.e. by the mass of Some x under q. *)
  1206   have "ennreal (spmf p x) = integral\<^sup>N pq (indicator (fst -` {Some x}))"
  1207     using p by(simp add: ennreal_pmf_map)
  1208   also have "\<dots> = integral\<^sup>N pq (indicator {(Some x, Some x)})"
  1209     by(rule nn_integral_cong_AE)(auto simp add: AE_measure_pmf_iff split: split_indicator dest: pq)
  1210   also have "\<dots> \<le> integral\<^sup>N pq (indicator (snd -` {Some x}))"
  1211     by(rule nn_integral_mono) simp
  1212   also have "\<dots> = ennreal (spmf q x)" using q by(simp add: ennreal_pmf_map)
  1213   finally show ?thesis by simp
  1214 qed
  1215 
  1216 lemma ord_spmf_eqD_set_spmf: "ord_spmf op = p q \<Longrightarrow> set_spmf p \<subseteq> set_spmf q"
       (* Supports grow along ord_spmf op =. An element with positive mass under p has
          at least that mass under q, by ord_spmf_eq_leD. *)
  1217 by(rule subsetI)(drule_tac x=x in ord_spmf_eq_leD, auto simp add: in_set_spmf_iff_spmf)
  1218 
  1219 lemma ord_spmf_eqD_emeasure:
       (* Every event A has at least as much probability under q as under p when p is
          below q. Proved by integrating the pointwise bound ord_spmf_eq_leD against the
          indicator of A. *)
  1220   "ord_spmf op = p q \<Longrightarrow> emeasure (measure_spmf p) A \<le> emeasure (measure_spmf q) A"
  1221 by(auto intro!: nn_integral_mono split: split_indicator dest: ord_spmf_eq_leD simp add: nn_integral_measure_spmf nn_integral_indicator[symmetric])
  1222 
  1223 lemma ord_spmf_eqD_measure_spmf: "ord_spmf op = p q \<Longrightarrow> measure_spmf p \<le> measure_spmf q"
       (* Measure-order form of ord_spmf_eqD_emeasure, obtained through le_measure. *)
  1224   by (subst le_measure) (auto simp: ord_spmf_eqD_emeasure)
  1225 
  1226 subsection \<open>CCPO structure for the flat ccpo @{term "ord_option op ="}\<close>
  1227 
  1228 context fixes Y :: "'a spmf set" begin
  1229 
  1230 definition lub_spmf :: "'a spmf"
       (* Least upper bound of the fixed set Y. It takes the pointwise supremum of
          spmf p x over p in Y in ennreal, converts back with enn2real, and embeds the
          result with embed_spmf. *)
  1231 where "lub_spmf = embed_spmf (\<lambda>x. enn2real (SUP p : Y. ennreal (spmf p x)))"
  1232   \<comment> \<open>We go through @{typ ennreal} to have a sensible definition even if @{term Y} is empty.\<close>
  1233 
  1234 lemma lub_spmf_empty [simp]: "SPMF.lub_spmf {} = return_pmf None"
       (* The lub of the empty set is return_pmf None. The qualified name SPMF.lub_spmf
          lets the set argument be given explicitly; inside this context the bare name
          lub_spmf already refers to the fixed Y.
          NOTE(review): the proof cites bot_ereal_def although lub_spmf is defined via
          ennreal. This is possibly a leftover from an ereal-based version; confirm with
          the prover before removing it. *)
  1235 by(simp add: SPMF.lub_spmf_def bot_ereal_def)
  1236 
  1237 context assumes chain: "Complete_Partial_Order.chain (ord_spmf op =) Y" begin
  1238 
  1239 lemma chain_ord_spmf_eqD: "Complete_Partial_Order.chain (op \<le>) ((\<lambda>p x. ennreal (spmf p x)) ` Y)"
  1240   (is "Complete_Partial_Order.chain _ (?f ` _)")
       (* The chain Y induces a chain of functions x -> ennreal (spmf p x)
          under the pointwise order. *)
  1241 proof(rule chainI)
  1242   fix f g
  1243   assume "f \<in> ?f ` Y" "g \<in> ?f ` Y"
  1244   then obtain p q where f: "f = ?f p" "p \<in> Y" and g: "g = ?f q" "q \<in> Y" by blast
         (* The chain assumption makes p and q comparable. ord_spmf_eq_leD turns either
            direction into pointwise comparability of f and g. *)
  1245   from chain \<open>p \<in> Y\<close> \<open>q \<in> Y\<close> have "ord_spmf op = p q \<or> ord_spmf op = q p" by(rule chainD)
  1246   thus "f \<le> g \<or> g \<le> f"
  1247   proof
  1248     assume "ord_spmf op = p q"
  1249     hence "\<And>x. spmf p x \<le> spmf q x" by(rule ord_spmf_eq_leD)
  1250     hence "f \<le> g" unfolding f g by(auto intro: le_funI)
  1251     thus ?thesis ..
  1252   next
  1253     assume "ord_spmf op = q p"
  1254     hence "\<And>x. spmf q x \<le> spmf p x" by(rule ord_spmf_eq_leD)
  1255     hence "g \<le> f" unfolding f g by(auto intro: le_funI)
  1256     thus ?thesis ..
  1257   qed
  1258 qed
  1259 
  1260 lemma ord_spmf_eq_pmf_None_eq:
       (* Two comparable subdistributions with equal mass on None are equal.
          NOTE(review): the statement does not mention Y, yet it is proved inside the
          context that assumes the chain property of Y. Its exported form may therefore
          carry that chain premise as an extra assumption; confirm before relying on it. *)
  1261   assumes le: "ord_spmf op = p q"
  1262   and None: "pmf p None = pmf q None"
  1263   shows "p = q"
  1264 proof(rule spmf_eqI)
  1265   fix i
  1266   from le have le': "\<And>x. spmf p x \<le> spmf q x" by(rule ord_spmf_eq_leD)
         (* Integrate the nonnegative difference spmf q - spmf p. It equals the difference
            of the weights, (1 - pmf q None) - (1 - pmf p None), which is 0 by the None
            assumption. A zero integral over count_space forces spmf q x <= spmf p x
            everywhere, and antisymmetry with le' finishes the proof. *)
  1267   have "(\<integral>\<^sup>+ x. ennreal (spmf q x) - spmf p x \<partial>count_space UNIV) =
  1268      (\<integral>\<^sup>+ x. spmf q x \<partial>count_space UNIV) - (\<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV)"
  1269     by(subst nn_integral_diff)(simp_all add: AE_count_space le' nn_integral_spmf_neq_top)
  1270   also have "\<dots> = (1 - pmf q None) - (1 - pmf p None)" unfolding pmf_None_eq_weight_spmf
  1271     by(simp add: weight_spmf_eq_nn_integral_spmf[symmetric] ennreal_minus)
  1272   also have "\<dots> = 0" using None by simp
  1273   finally have "\<And>x. spmf q x \<le> spmf p x"
  1274     by(simp add: nn_integral_0_iff_AE AE_count_space ennreal_minus ennreal_eq_0_iff)
  1275   with le' show "spmf p i = spmf q i" by(rule antisym)
  1276 qed
  1277 
  1278 lemma ord_spmf_eqD_pmf_None:
       (* Moving up in ord_spmf op = can only decrease the mass on None.
          Proof: unfold the relation into its coupling, rewrite both None-masses as
          integrals over the coupling (ennreal_pmf_map), and compare them almost
          everywhere on its support.
          NOTE(review): this lemma does not use the chain assumption of the enclosing
          context; confirm whether its exported form carries it as an extra premise. *)
  1279   assumes "ord_spmf op = x y"
  1280   shows "pmf x None \<ge> pmf y None"
  1281 using assms
  1282 apply cases
  1283 apply(clarsimp simp only: ennreal_le_iff[symmetric, OF pmf_nonneg] ennreal_pmf_map)
  1284 apply(fastforce simp add: AE_measure_pmf_iff intro!: nn_integral_mono_AE)
  1285 done
  1286 
  1287 text \<open>
  1288   Chains on @{typ "'a spmf"} maintain countable support.
  1289   Thanks to Johannes Hölzl for the proof idea.
  1290 \<close>
  1291 lemma spmf_chain_countable: "countable (\<Union>p\<in>Y. set_spmf p)"
       (* The union of the supports of the chain Y is countable. The case Y = {} is
          closed by simp at the final qed. *)
  1292 proof(cases "Y = {}")
  1293   case Y: False
  1294   show ?thesis
  1295   proof(cases "\<exists>x\<in>Y. \<forall>y\<in>Y. ord_spmf op = y x")
  1296     case True
           (* Y has a greatest element x, so every support lies inside the countable set_spmf x. *)
  1297     then obtain x where x: "x \<in> Y" and upper: "\<And>y. y \<in> Y \<Longrightarrow> ord_spmf op = y x" by blast
  1298     hence "(\<Union>x\<in>Y. set_spmf x) \<subseteq> set_spmf x" by(auto dest: ord_spmf_eqD_set_spmf)
  1299     thus ?thesis by(rule countable_subset) simp
  1300   next
  1301     case False
           (* No greatest element. Rank chain elements by N p = pmf p None: a strictly
              smaller N means higher in the chain, and equal N means equal elements. *)
  1302     define N :: "'a option pmf \<Rightarrow> real" where "N p = pmf p None" for p
  1303 
  1304     have N_less_imp_le_spmf: "\<lbrakk> x \<in> Y; y \<in> Y; N y < N x \<rbrakk> \<Longrightarrow> ord_spmf op = x y" for x y
  1305       using chainD[OF chain, of x y] ord_spmf_eqD_pmf_None[of x y] ord_spmf_eqD_pmf_None[of y x]
  1306       by (auto simp: N_def)
  1307     have N_eq_imp_eq: "\<lbrakk> x \<in> Y; y \<in> Y; N y = N x \<rbrakk> \<Longrightarrow> x = y" for x y
  1308       using chainD[OF chain, of x y] by(auto simp add: N_def dest: ord_spmf_eq_pmf_None_eq)
  1309 
  1310     have NC: "N ` Y \<noteq> {}" "bdd_below (N ` Y)"
  1311       using \<open>Y \<noteq> {}\<close> by(auto intro!: bdd_belowI[of _ 0] simp: N_def)
           (* The infimum of N over Y is never attained: an element attaining it would be
              greatest in the chain, contradicting the case assumption. *)
  1312     have NC_less: "Inf (N ` Y) < N x" if "x \<in> Y" for x unfolding cInf_less_iff[OF NC]
  1313     proof(rule ccontr)
  1314       assume **: "\<not> (\<exists>y\<in>N ` Y. y < N x)"
  1315       { fix y
  1316         assume "y \<in> Y"
  1317         with ** consider "N x < N y" | "N x = N y" by(auto simp add: not_less le_less)
  1318         hence "ord_spmf op = y x" using \<open>y \<in> Y\<close> \<open>x \<in> Y\<close>
  1319           by cases(auto dest: N_less_imp_le_spmf N_eq_imp_eq intro: ord_spmf_reflI) }
  1320       with False \<open>x \<in> Y\<close> show False by blast
  1321     qed
  1322 
           (* Approximate the infimum by a sequence X n of chain elements. Every x in Y has
              some X i strictly above it (smaller N), so all supports are covered by the
              countable union of the set_spmf (X n). *)
  1323     from NC have "Inf (N ` Y) \<in> closure (N ` Y)" by (intro closure_contains_Inf)
  1324     then obtain X' where "\<And>n. X' n \<in> N ` Y" and X': "X' \<longlonglongrightarrow> Inf (N ` Y)"
  1325       unfolding closure_sequential by auto
  1326     then obtain X where X: "\<And>n. X n \<in> Y" and "X' = (\<lambda>n. N (X n))" unfolding image_iff Bex_def by metis
  1327 
  1328     with X' have seq: "(\<lambda>n. N (X n)) \<longlonglongrightarrow> Inf (N ` Y)" by simp
  1329     have "(\<Union>x \<in> Y. set_spmf x) \<subseteq> (\<Union>n. set_spmf (X n))"
  1330     proof(rule UN_least)
  1331       fix x
  1332       assume "x \<in> Y"
  1333       from order_tendstoD(2)[OF seq NC_less[OF \<open>x \<in> Y\<close>]]
  1334       obtain i where "N (X i) < N x" by (auto simp: eventually_sequentially)
  1335       thus "set_spmf x \<subseteq> (\<Union>n. set_spmf (X n))" using X \<open>x \<in> Y\<close>
  1336         by(blast dest: N_less_imp_le_spmf ord_spmf_eqD_set_spmf)
  1337     qed
  1338     thus ?thesis by(rule countable_subset) simp
  1339   qed
  1340 qed simp
  1341 
  1342 lemma lub_spmf_subprob: "(\<integral>\<^sup>+ x. (SUP p : Y. ennreal (spmf p x)) \<partial>count_space UNIV) \<le> 1"
       (* The pointwise supremum over the chain Y has total mass at most 1.
          spmf_lub_spmf passes this fact to spmf_embed_spmf when unfolding lub_spmf. *)
  1343 proof(cases "Y = {}")
  1344   case True
  1345   thus ?thesis by(simp add: bot_ennreal)
  1346 next
  1347   case False
  1348   let ?B = "\<Union>p\<in>Y. set_spmf p"
  1349   have countable: "countable ?B" by(rule spmf_chain_countable)
         (* Strategy: restrict the integral to the countable union ?B of supports (points
            outside have spmf 0). Then swap SUP and integral by monotone convergence over
            a countable set, using chain_ord_spmf_eqD. Finally bound each integral by the
            weight of p, which is at most 1. *)
  1350 
  1351   have "(\<integral>\<^sup>+ x. (SUP p:Y. ennreal (spmf p x)) \<partial>count_space UNIV) =
  1352         (\<integral>\<^sup>+ x. (SUP p:Y. ennreal (spmf p x) * indicator ?B x) \<partial>count_space UNIV)"
  1353     by(intro nn_integral_cong SUP_cong)(auto split: split_indicator simp add: spmf_eq_0_set_spmf)
  1354   also have "\<dots> = (\<integral>\<^sup>+ x. (SUP p:Y. ennreal (spmf p x)) \<partial>count_space ?B)"
  1355     unfolding ennreal_indicator[symmetric] using False
  1356     by(subst SUP_mult_right_ennreal[symmetric])(simp add: ennreal_indicator nn_integral_count_space_indicator)
  1357   also have "\<dots> = (SUP p:Y. \<integral>\<^sup>+ x. spmf p x \<partial>count_space ?B)" using False _ countable
  1358     by(rule nn_integral_monotone_convergence_SUP_countable)(rule chain_ord_spmf_eqD)
  1359   also have "\<dots> \<le> 1"
  1360   proof(rule SUP_least)
  1361     fix p
  1362     assume "p \<in> Y"
  1363     have "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space ?B) = \<integral>\<^sup>+ x. ennreal (spmf p x) * indicator ?B x \<partial>count_space UNIV"
  1364       by(simp add: nn_integral_count_space_indicator)
  1365     also have "\<dots> = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
  1366       by(rule nn_integral_cong)(auto split: split_indicator simp add: spmf_eq_0_set_spmf \<open>p \<in> Y\<close>)
  1367     also have "\<dots> \<le> 1"
  1368       by(simp add: weight_spmf_eq_nn_integral_spmf[symmetric] weight_spmf_le_1)
  1369     finally show "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space ?B) \<le> 1" .
  1370   qed
  1371   finally show ?thesis .
  1372 qed
  1373 
  1374 lemma spmf_lub_spmf:
       (* For nonempty Y, spmf of the lub at x is the real supremum of spmf p x over Y. *)
  1375   assumes "Y \<noteq> {}"
  1376   shows "spmf lub_spmf x = (SUP p : Y. spmf p x)"
  1377 proof -
  1378   from assms obtain p where "p \<in> Y" by auto
         (* NOTE(review): the obtained p does not appear to be used below; it is a
            candidate for removal, to be confirmed with the prover. *)
         (* Steps: unfold the embedding, with its side condition discharged by
            lub_spmf_subprob. Drop the max 0, since enn2real is nonnegative. Move ennreal
            out of the SUP (the SUP is not top). Finally cancel enn2real after ennreal on
            the nonnegative real SUP. *)
  1379   have "spmf lub_spmf x = max 0 (enn2real (SUP p:Y. ennreal (spmf p x)))" unfolding lub_spmf_def
  1380     by(rule spmf_embed_spmf)(simp del: SUP_eq_top_iff Sup_eq_top_iff add: ennreal_enn2real_if SUP_spmf_neq_top' lub_spmf_subprob)
  1381   also have "\<dots> = enn2real (SUP p:Y. ennreal (spmf p x))"
  1382     by(rule max_absorb2)(simp)
  1383   also have "\<dots> = enn2real (ennreal (SUP p : Y. spmf p x))" using assms
  1384     by(subst ennreal_SUP[symmetric])(simp_all add: SUP_spmf_neq_top' del: SUP_eq_top_iff Sup_eq_top_iff)
  1385   also have "0 \<le> (\<Squnion>p\<in>Y. spmf p x)" using assms
  1386     by(auto intro!: cSUP_upper2 bdd_aboveI[where M=1] simp add: pmf_le_1)
  1387   then have "enn2real (ennreal (SUP p : Y. spmf p x)) = (SUP p : Y. spmf p x)"
  1388     by(rule enn2real_ennreal)
  1389   finally show ?thesis .
  1390 qed
  1391 
  1392 lemma ennreal_spmf_lub_spmf: "Y \<noteq> {} \<Longrightarrow> ennreal (spmf lub_spmf x) = (SUP p:Y. ennreal (spmf p x))"
       (* ennreal-valued form of spmf_lub_spmf: ennreal commutes with the SUP because it
          is not top. *)
  1393 unfolding spmf_lub_spmf by(subst ennreal_SUP)(simp_all add: SUP_spmf_neq_top' del: SUP_eq_top_iff Sup_eq_top_iff)
  1394 
  1395 lemma lub_spmf_upper:
       (* Every element of Y is below lub_spmf. Pointwise, spmf p x is bounded by the
          SUP, which equals spmf lub_spmf x. The other subgoal produced by
          ord_pmf_increaseI is closed by simp at the final qed. *)
  1396   assumes p: "p \<in> Y"
  1397   shows "ord_spmf op = p lub_spmf"
  1398 proof(rule ord_pmf_increaseI)
  1399   fix x
  1400   from p have [simp]: "Y \<noteq> {}" by auto
  1401   from p have "ennreal (spmf p x) \<le> (SUP p:Y. ennreal (spmf p x))" by(rule SUP_upper)
  1402   also have "\<dots> = ennreal (spmf lub_spmf x)" using p
  1403     by(subst spmf_lub_spmf)(auto simp add: ennreal_SUP SUP_spmf_neq_top' simp del: SUP_eq_top_iff Sup_eq_top_iff)
  1404   finally show "spmf p x \<le> spmf lub_spmf x" by simp
  1405 qed simp
  1406 
  1407 lemma lub_spmf_least:
       (* lub_spmf is below every upper bound z of Y. The case Y = {} is closed by simp
          at the final qed. *)
  1408   assumes z: "\<And>x. x \<in> Y \<Longrightarrow> ord_spmf op = x z"
  1409   shows "ord_spmf op = lub_spmf z"
  1410 proof(cases "Y = {}")
  1411   case nonempty: False
  1412   show ?thesis
  1413   proof(rule ord_pmf_increaseI)
  1414     fix x
  1415     from nonempty obtain p where p: "p \<in> Y" by auto
           (* NOTE(review): the fact p does not appear to be used below; confirm with the
              prover before removing it. *)
           (* Pointwise: the SUP of ennreal (spmf p x) over Y is bounded by ennreal (spmf z x),
              since each p in Y is below z (ord_spmf_eq_leD). *)
  1416     have "ennreal (spmf lub_spmf x) = (SUP p:Y. ennreal (spmf p x))"
  1417       by(subst spmf_lub_spmf)(auto simp add: ennreal_SUP SUP_spmf_neq_top' nonempty simp del: SUP_eq_top_iff Sup_eq_top_iff)
  1418     also have "\<dots> \<le> ennreal (spmf z x)" by(rule SUP_least)(simp add: ord_spmf_eq_leD z)
  1419     finally show "spmf lub_spmf x \<le> spmf z x" by simp
  1420   qed simp
  1421 qed simp
  1422 
  1423 lemma set_lub_spmf: "set_spmf lub_spmf = (\<Union>p\<in>Y. set_spmf p)" (is "?lhs = ?rhs")
       (* The support of the lub is the union of the supports. x has positive mass
          under lub_spmf iff it has positive mass under some p in Y (less_SUP_iff).
          The case Y = {} is closed by simp at the final qed. *)
  1424 proof(cases "Y = {}")
  1425   case [simp]: False
  1426   show ?thesis
  1427   proof(rule set_eqI)
  1428     fix x
  1429     have "x \<in> ?lhs \<longleftrightarrow> ennreal (spmf lub_spmf x) > 0"
  1430       by(simp_all add: in_set_spmf_iff_spmf less_le)
  1431     also have "\<dots> \<longleftrightarrow> (\<exists>p\<in>Y. ennreal (spmf p x) > 0)"
  1432       by(simp add: ennreal_spmf_lub_spmf less_SUP_iff)
  1433     also have "\<dots> \<longleftrightarrow> x \<in> ?rhs"
  1434       by(auto simp add: in_set_spmf_iff_spmf less_le)
  1435     finally show "x \<in> ?lhs \<longleftrightarrow> x \<in> ?rhs" .
  1436   qed
  1437 qed simp
  1438 
  1439 lemma emeasure_lub_spmf:
       (* The probability of any event A under lub_spmf is the supremum of its
          probabilities over Y. *)
  1440   assumes Y: "Y \<noteq> {}"
  1441   shows "emeasure (measure_spmf lub_spmf) A = (SUP y:Y. emeasure (measure_spmf y) A)"
  1442   (is "?lhs = ?rhs")
  1443 proof -
         (* Strategy: integrate over the countable support of lub_spmf and push the
            indicator of A inside the SUP. Swap SUP and integral by monotone convergence,
            which needs the functions multiplied by indicator A to still form a chain.
            Finally extend the integrals back to count_space UNIV. *)
  1444   let ?M = "count_space (set_spmf lub_spmf)"
  1445   have "?lhs = \<integral>\<^sup>+ x. ennreal (spmf lub_spmf x) * indicator A x \<partial>?M"
  1446     by(auto simp add: nn_integral_indicator[symmetric] nn_integral_measure_spmf')
  1447   also have "\<dots> = \<integral>\<^sup>+ x. (SUP y:Y. ennreal (spmf y x) * indicator A x) \<partial>?M"
  1448     unfolding ennreal_indicator[symmetric]
  1449     by(simp add: spmf_lub_spmf assms ennreal_SUP[OF SUP_spmf_neq_top'] SUP_mult_right_ennreal)
  1450   also from assms have "\<dots> = (SUP y:Y. \<integral>\<^sup>+ x. ennreal (spmf y x) * indicator A x \<partial>?M)"
  1451   proof(rule nn_integral_monotone_convergence_SUP_countable)
  1452     have "(\<lambda>i x. ennreal (spmf i x) * indicator A x) ` Y = (\<lambda>f x. f x * indicator A x) ` (\<lambda>p x. ennreal (spmf p x)) ` Y"
  1453       by(simp add: image_image)
  1454     also have "Complete_Partial_Order.chain op \<le> \<dots>" using chain_ord_spmf_eqD
  1455       by(rule chain_imageI)(auto simp add: le_fun_def split: split_indicator)
  1456     finally show "Complete_Partial_Order.chain op \<le> ((\<lambda>i x. ennreal (spmf i x) * indicator A x) ` Y)" .
  1457   qed simp
  1458   also have "\<dots> = (SUP y:Y. \<integral>\<^sup>+ x. ennreal (spmf y x) * indicator A x \<partial>count_space UNIV)"
  1459     by(auto simp add: nn_integral_count_space_indicator set_lub_spmf spmf_eq_0_set_spmf split: split_indicator intro!: SUP_cong nn_integral_cong)
  1460   also have "\<dots> = ?rhs"
  1461     by(auto simp add: nn_integral_indicator[symmetric] nn_integral_measure_spmf)
  1462   finally show ?thesis .
  1463 qed
  1464 
  1465 lemma measure_lub_spmf:
       (* Real-valued form of emeasure_lub_spmf. Equality is first shown after
          embedding both sides into ennreal. The right-hand SUP is nonnegative, so the
          embedding can be cancelled. *)
  1466   assumes Y: "Y \<noteq> {}"
  1467   shows "measure (measure_spmf lub_spmf) A = (SUP y:Y. measure (measure_spmf y) A)" (is "?lhs = ?rhs")
  1468 proof -
  1469   have "ennreal ?lhs = ennreal ?rhs"
  1470     using emeasure_lub_spmf[OF assms] SUP_emeasure_spmf_neq_top[of A Y] Y
  1471     unfolding measure_spmf.emeasure_eq_measure by(subst ennreal_SUP)
  1472   moreover have "0 \<le> ?rhs" using Y
  1473     by(auto intro!: cSUP_upper2 bdd_aboveI[where M=1] measure_spmf.subprob_measure_le_1)
  1474   ultimately show ?thesis by(simp)
  1475 qed
  1476 
  1477 lemma weight_lub_spmf:
       (* The weight of the lub is the supremum of the weights. This is an instance of
          measure_lub_spmf after unfolding weight_spmf_def. *)
  1478   assumes Y: "Y \<noteq> {}"
  1479   shows "weight_spmf lub_spmf = (SUP y:Y. weight_spmf y)"
  1480 unfolding weight_spmf_def by(rule measure_lub_spmf) fact
  1481 
  1482 lemma measure_spmf_lub_spmf:
       (* measure_spmf commutes with lub: the measure of the lub is the SUP of the
          measures in the lattice of measures. *)
  1483   assumes Y: "Y \<noteq> {}"
  1484   shows "measure_spmf lub_spmf = (SUP p:Y. measure_spmf p)" (is "?lhs = ?rhs")
  1485 proof(rule measure_eqI)
  1486   from assms obtain p where p: "p \<in> Y" by auto
         (* measure_spmf maps the chain Y to a chain of measures. *)
  1487   from chain have chain': "Complete_Partial_Order.chain op \<le> (measure_spmf ` Y)"
  1488     by(rule chain_imageI)(rule ord_spmf_eqD_measure_spmf)
         (* Both sides have the same sets (sets_SUP) and agree on every emeasure
            (emeasure_SUP_chain together with emeasure_lub_spmf). *)
  1489   show "sets ?lhs = sets ?rhs"
  1490     using Y by (subst sets_SUP) auto
  1491   show "emeasure ?lhs A = emeasure ?rhs A" for A
  1492     using chain' Y p by (subst emeasure_SUP_chain) (auto simp:  emeasure_lub_spmf)
  1493 qed
  1494 
  1495 end
  1496 
  1497 end
  1498 
  1499 lemma partial_function_definitions_spmf: "partial_function_definitions (ord_spmf op =) lub_spmf"
  1500   (is "partial_function_definitions ?R _")
       (* ord_spmf op = together with lub_spmf satisfies the partial_function_definitions
          axioms, proved in order: reflexivity, transitivity, antisymmetry, lub_spmf is
          an upper bound of every chain, and lub_spmf is below every upper bound of a chain. *)
  1501 proof
  1502   fix x show "?R x x" by(simp add: ord_spmf_reflI)
  1503 next
  1504   fix x y z
  1505   assume "?R x y" "?R y z"
  1506   with transp_ord_option[OF transp_equality] show "?R x z" by(rule transp_rel_pmf[THEN transpD])
  1507 next
  1508   fix x y
  1509   assume "?R x y" "?R y x"
  1510   thus "x = y"
  1511     by(rule rel_pmf_antisym)(simp_all add: reflp_ord_option transp_ord_option antisymP_ord_option)
  1512 next
  1513   fix Y x
  1514   assume "Complete_Partial_Order.chain ?R Y" "x \<in> Y"
  1515   then show "?R x (lub_spmf Y)"
  1516     by(rule lub_spmf_upper)
  1517 next
  1518   fix Y z
  1519   assume "Complete_Partial_Order.chain ?R Y" "\<And>x. x \<in> Y \<Longrightarrow> ?R x z"
  1520   then show "?R (lub_spmf Y) z"
  1521     by(cases "Y = {}")(simp_all add: lub_spmf_least)
  1522 qed
  1523 
  1524 lemma ccpo_spmf: "class.ccpo lub_spmf (ord_spmf op =) (mk_less (ord_spmf op =))"
       (* The class.ccpo instance for spmfs, derived from partial_function_definitions_spmf. *)
  1525 by(rule ccpo partial_function_definitions_spmf)+
  1526 
  1527 interpretation spmf: partial_function_definitions "ord_spmf op =" "lub_spmf"
       (* Interpret the ccpo structure under the prefix spmf, rewriting the lub of the
          empty set to return_pmf None in the interpreted facts. *)
  1528   rewrites "lub_spmf {} \<equiv> return_pmf None"
  1529 by(rule partial_function_definitions_spmf) simp
  1530 
       (* Register the mode spmf with the partial_function package, using the fixpoint
          operator, monotonicity predicate, unfolding rule and induction rule provided
          by this interpretation. *)
  1531 declaration \<open>Partial_Function.init "spmf" @{term spmf.fixp_fun}
  1532   @{term spmf.mono_body} @{thm spmf.fixp_rule_uc} @{thm spmf.fixp_induct_uc}
  1533   NONE\<close>
  1534 
       (* Reflexivity of the order as a simp rule; the admissibility rule instantiated
          for the spmf ccpo is added to cont_intro. *)
  1535 declare spmf.leq_refl[simp]
  1536 declare admissible_leI[OF ccpo_spmf, cont_intro]
  1537 
       (* Monotonicity of a functional on spmf-valued functions with respect to
          fun_ord (ord_spmf op =). *)
  1538 abbreviation "mono_spmf \<equiv> monotone (fun_ord (ord_spmf op =)) (ord_spmf op =)"
  1539 
  1540 lemma lub_spmf_const [simp]: "lub_spmf {p} = p"
       (* The lub of a singleton is its element; singletons are chains
          (ccpo.chain_singleton), so spmf_lub_spmf applies. *)
  1541 by(rule spmf_eqI)(simp add: spmf_lub_spmf[OF ccpo.chain_singleton[OF ccpo_spmf]])
  1542 
  1543 lemma bind_spmf_mono':
       (* bind_spmf is monotone in both the distribution and the continuation with respect
          to ord_spmf op =. Proved with rel_pmf_bindI after unfolding bind_spmf_def. *)
  1544   assumes fg: "ord_spmf op = f g"
  1545   and hk: "\<And>x :: 'a. ord_spmf op = (h x) (k x)"
  1546   shows "ord_spmf op = (f \<bind> h) (g \<bind> k)"
  1547 unfolding bind_spmf_def using assms(1)
  1548 by(rule rel_pmf_bindI)(auto split: option.split simp add: hk)
  1549 
  1550 lemma bind_spmf_mono [partial_function_mono]:
       (* partial_function monotonicity rule for bind_spmf: if the distribution B and every
          continuation C y are monotone in f, so is their bind. Reduces to bind_spmf_mono'. *)
  1551   assumes mf: "mono_spmf B" and mg: "\<And>y. mono_spmf (\<lambda>f. C y f)"
  1552   shows "mono_spmf (\<lambda>f. bind_spmf (B f) (\<lambda>y. C y f))"
  1553 proof (rule monotoneI)
  1554   fix f g :: "'a \<Rightarrow> 'b spmf"
  1555   assume fg: "fun_ord (ord_spmf op =) f g"
  1556   with mf have "ord_spmf op = (B f) (B g)" by (rule monotoneD[of _ _ _ f g])
  1557   moreover from mg have "\<And>y'. ord_spmf op = (C y' f) (C y' g)"
  1558     by (rule monotoneD) (rule fg)
  1559   ultimately show "ord_spmf op = (bind_spmf (B f) (\<lambda>y. C y f)) (bind_spmf (B g) (\<lambda>y'. C y' g))"
  1560     by(rule bind_spmf_mono')
  1561 qed
  1562 
  1563 lemma monotone_bind_spmf1: "monotone (ord_spmf op =) (ord_spmf op =) (\<lambda>y. bind_spmf y g)"
       (* bind_spmf is monotone in its first argument for a fixed continuation g. *)
  1564 by(rule monotoneI)(simp add: bind_spmf_mono' ord_spmf_reflI)
  1565 
  1566 lemma monotone_bind_spmf2:
       (* bind_spmf p is monotone in a family of continuations g y indexed by an
          arbitrary order ord, provided each g y x is monotone in y. *)
  1567   assumes g: "\<And>x. monotone ord (ord_spmf op =) (\<lambda>y. g y x)"
  1568   shows "monotone ord (ord_spmf op =) (\<lambda>y. bind_spmf p (g y))"
  1569 by(rule monotoneI)(auto intro: bind_spmf_mono' monotoneD[OF g] ord_spmf_reflI)
  1570 
  1571 lemma bind_lub_spmf:
       (* bind_spmf commutes with the lub of a chain in its first argument. The case
          Y = {} is closed by simp at the final qed. *)
  1572   assumes chain: "Complete_Partial_Order.chain (ord_spmf op =) Y"
  1573   shows "bind_spmf (lub_spmf Y) f = lub_spmf ((\<lambda>p. bind_spmf p f) ` Y)" (is "?lhs = ?rhs")
  1574 proof(cases "Y = {}")
  1575   case Y: False
  1576   show ?thesis
  1577   proof(rule spmf_eqI)
  1578     fix i
           (* chain' feeds monotone convergence below; chain'' lets ennreal_spmf_lub_spmf be
              applied to the image of Y under bind. *)
  1579     have chain': "Complete_Partial_Order.chain op \<le> ((\<lambda>p x. ennreal (spmf p x * spmf (f x) i)) ` Y)"
  1580       using chain by(rule chain_imageI)(auto simp add: le_fun_def dest: ord_spmf_eq_leD intro: mult_right_mono)
  1581     have chain'': "Complete_Partial_Order.chain (ord_spmf op =) ((\<lambda>p. p \<bind> f) ` Y)"
  1582       using chain by(rule chain_imageI)(auto intro!: monotoneI bind_spmf_mono' ord_spmf_reflI)
           (* Pointwise at i: write spmf of the bind as an integral over the support of the
              lub, pull the SUP out of the integrand by monotone convergence, and recognise
              spmf of each individual bind. *)
  1583     let ?M = "count_space (set_spmf (lub_spmf Y))"
  1584     have "ennreal (spmf ?lhs i) = \<integral>\<^sup>+ x. ennreal (spmf (lub_spmf Y) x) * ennreal (spmf (f x) i) \<partial>?M"
  1585       by(auto simp add: ennreal_spmf_lub_spmf ennreal_spmf_bind nn_integral_measure_spmf')
  1586     also have "\<dots> = \<integral>\<^sup>+ x. (SUP p:Y. ennreal (spmf p x * spmf (f x) i)) \<partial>?M"
  1587       by(subst ennreal_spmf_lub_spmf[OF chain Y])(subst SUP_mult_right_ennreal, simp_all add: ennreal_mult Y)
  1588     also have "\<dots> = (SUP p:Y. \<integral>\<^sup>+ x. ennreal (spmf p x * spmf (f x) i) \<partial>?M)"
  1589       using Y chain' by(rule nn_integral_monotone_convergence_SUP_countable) simp
  1590     also have "\<dots> = (SUP p:Y. ennreal (spmf (bind_spmf p f) i))"
  1591       by(auto simp add: ennreal_spmf_bind nn_integral_measure_spmf nn_integral_count_space_indicator set_lub_spmf[OF chain] in_set_spmf_iff_spmf ennreal_mult intro!: SUP_cong nn_integral_cong split: split_indicator)
  1592     also have "\<dots> = ennreal (spmf ?rhs i)" using chain'' by(simp add: ennreal_spmf_lub_spmf Y)
  1593     finally show "spmf ?lhs i = spmf ?rhs i" by simp
  1594   qed
  1595 qed simp
  1596 
  1597 lemma map_lub_spmf:
       (* map_spmf commutes with the lub of a chain. It follows from bind_lub_spmf via
          map_spmf_conv_bind_spmf. *)
  1598   "Complete_Partial_Order.chain (ord_spmf op =) Y
  1599   \<Longrightarrow> map_spmf f (lub_spmf Y) = lub_spmf (map_spmf f ` Y)"
  1600 unfolding map_spmf_conv_bind_spmf[abs_def] by(simp add: bind_lub_spmf o_def)
  1601 
  1602 lemma mcont_bind_spmf1: "mcont lub_spmf (ord_spmf op =) lub_spmf (ord_spmf op =) (\<lambda>y. bind_spmf y f)"
       (* bind_spmf is monotone and continuous in its first argument
          (monotone_bind_spmf1 together with bind_lub_spmf). *)
  1603 using monotone_bind_spmf1 by(rule mcontI)(rule contI, simp add: bind_lub_spmf)
  1604 
  1605 lemma bind_lub_spmf2:
       (* bind_spmf x commutes with the lub in the continuation, for a chain Y in an
          arbitrary order ord mapped monotonically by each g y. The case Y = {} is closed
          by simp at the final qed. *)
  1606   assumes chain: "Complete_Partial_Order.chain ord Y"
  1607   and g: "\<And>y. monotone ord (ord_spmf op =) (g y)"
  1608   shows "bind_spmf x (\<lambda>y. lub_spmf (g y ` Y)) = lub_spmf ((\<lambda>p. bind_spmf x (\<lambda>y. g y p)) ` Y)"
  1609   (is "?lhs = ?rhs")
  1610 proof(cases "Y = {}")
  1611   case Y: False
  1612   show ?thesis
  1613   proof(rule spmf_eqI)
  1614     fix i
           (* Three chain facts: the images g y ` Y (chain'), the integrands used for
              monotone convergence (chain''), and the resulting binds (chain'''). *)
  1615     have chain': "\<And>y. Complete_Partial_Order.chain (ord_spmf op =) (g y ` Y)"
  1616       using chain g[THEN monotoneD] by(rule chain_imageI)
  1617     have chain'': "Complete_Partial_Order.chain op \<le> ((\<lambda>p y. ennreal (spmf x y * spmf (g y p) i)) ` Y)"
  1618       using chain by(rule chain_imageI)(auto simp add: le_fun_def dest: ord_spmf_eq_leD monotoneD[OF g] intro!: mult_left_mono)
  1619     have chain''': "Complete_Partial_Order.chain (ord_spmf op =) ((\<lambda>p. bind_spmf x (\<lambda>y. g y p)) ` Y)"
  1620       using chain by(rule chain_imageI)(rule monotone_bind_spmf2[OF g, THEN monotoneD])
  1621 
           (* Pointwise at i: integrate over the support of x, swap SUP and integral by
              monotone convergence, and recognise spmf of each bind. *)
  1622     have "ennreal (spmf ?lhs i) = \<integral>\<^sup>+ y. (SUP p:Y. ennreal (spmf x y * spmf (g y p) i)) \<partial>count_space (set_spmf x)"
  1623       by(simp add: ennreal_spmf_bind ennreal_spmf_lub_spmf[OF chain'] Y nn_integral_measure_spmf' SUP_mult_left_ennreal ennreal_mult)
  1624     also have "\<dots> = (SUP p:Y. \<integral>\<^sup>+ y. ennreal (spmf x y * spmf (g y p) i) \<partial>count_space (set_spmf x))"
  1625       unfolding nn_integral_measure_spmf' using Y chain''
  1626       by(rule nn_integral_monotone_convergence_SUP_countable) simp
  1627     also have "\<dots> = (SUP p:Y. ennreal (spmf (bind_spmf x (\<lambda>y. g y p)) i))"
  1628       by(simp add: ennreal_spmf_bind nn_integral_measure_spmf' ennreal_mult)
  1629     also have "\<dots> = ennreal (spmf ?rhs i)" using chain'''
  1630       by(auto simp add: ennreal_spmf_lub_spmf Y)
  1631     finally show "spmf ?lhs i = spmf ?rhs i" by simp
  1632   qed
  1633 qed simp
  1634 
  1635 lemma mcont_bind_spmf [cont_intro]:
       (* Joint continuity of bind_spmf, registered as cont_intro. spmf.mcont2mcont'
          reduces it to continuity in the distribution argument (mcont_bind_spmf1) and
          continuity in the parameter of the continuation (via bind_lub_spmf2). *)
  1636   assumes f: "mcont luba orda lub_spmf (ord_spmf op =) f"
  1637   and g: "\<And>y. mcont luba orda lub_spmf (ord_spmf op =) (g y)"
  1638   shows "mcont luba orda lub_spmf (ord_spmf op =) (\<lambda>x. bind_spmf (f x) (\<lambda>y. g y x))"
  1639 proof(rule spmf.mcont2mcont'[OF _ _ f])
  1640   fix z
  1641   show "mcont lub_spmf (ord_spmf op =) lub_spmf (ord_spmf op =) (\<lambda>x. bind_spmf x (\<lambda>y. g y z))"
  1642     by(rule mcont_bind_spmf1)
  1643 next
  1644   fix x
  1645   let ?f = "\<lambda>z. bind_spmf x (\<lambda>y. g y z)"
  1646   have "monotone orda (ord_spmf op =) ?f" using mcont_mono[OF g] by(rule monotone_bind_spmf2)
  1647   moreover have "cont luba orda lub_spmf (ord_spmf op =) ?f"
  1648   proof(rule contI)
  1649     fix Y
  1650     assume chain: "Complete_Partial_Order.chain orda Y" and Y: "Y \<noteq> {}"
  1651     have "bind_spmf x (\<lambda>y. g y (luba Y)) = bind_spmf x (\<lambda>y. lub_spmf (g y ` Y))"
  1652       by(rule bind_spmf_cong)(simp_all add: mcont_contD[OF g chain Y])
  1653     also have "\<dots> = lub_spmf ((\<lambda>p. x \<bind> (\<lambda>y. g y p)) ` Y)" using chain
  1654       by(rule bind_lub_spmf2)(rule mcont_mono[OF g])
  1655     finally show "bind_spmf x (\<lambda>y. g y (luba Y)) = \<dots>" .
  1656   qed
  1657   ultimately show "mcont luba orda lub_spmf (ord_spmf op =) ?f" by(rule mcontI)
  1658 qed
  1659 
  1660 lemma bind_pmf_mono [partial_function_mono]:
       (* partial_function monotonicity rule for bind_pmf, obtained from bind_spmf_mono
          with the constant distribution spmf_of_pmf p. *)
  1661   "(\<And>y. mono_spmf (\<lambda>f. C y f)) \<Longrightarrow> mono_spmf (\<lambda>f. bind_pmf p (\<lambda>x. C x f))"
  1662 using bind_spmf_mono[of "\<lambda>_. spmf_of_pmf p" C] by simp
  1663 
  1664 lemma map_spmf_mono [partial_function_mono]: "mono_spmf B \<Longrightarrow> mono_spmf (\<lambda>g. map_spmf f (B g))"
       (* partial_function monotonicity rule for map_spmf, reduced to bind_spmf_mono. *)
  1665 unfolding map_spmf_conv_bind_spmf by(rule bind_spmf_mono) simp_all
  1666 
  1667 lemma mcont_map_spmf [cont_intro]:
  1668   "mcont luba orda lub_spmf (ord_spmf op =) g
  1669   \<Longrightarrow> mcont luba orda lub_spmf (ord_spmf op =) (\<lambda>x. map_spmf f (g x))"
  1670 unfolding map_spmf_conv_bind_spmf by(rule mcont_bind_spmf) simp_all
  1671 
  (* set_spmf viewed as a map from (spmf, ord_spmf op =) to (set, subset ordering). *)
  (* Monotone: a smaller spmf has a smaller support (ord_spmf_eqD_set_spmf). *)
  1672 lemma monotone_set_spmf: "monotone (ord_spmf op =) op \<subseteq> set_spmf"
  1673 by(rule monotoneI)(rule ord_spmf_eqD_set_spmf)
  1674 
  (* Continuous: the support of a lub_spmf is the union of the supports (set_lub_spmf). *)
  1675 lemma cont_set_spmf: "cont lub_spmf (ord_spmf op =) Union op \<subseteq> set_spmf"
  1676 by(rule contI)(subst set_lub_spmf; simp)
  1677 
  (* Both together give mcont; the THEN mcont2mcont variant is registered as a cont_intro
     rule, while mcont_set_spmf names the plain statement. *)
  1678 lemma mcont2mcont_set_spmf[THEN mcont2mcont, cont_intro]:
  1679   shows mcont_set_spmf: "mcont lub_spmf (ord_spmf op =) Union op \<subseteq> set_spmf"
  1680 by(rule mcontI monotone_set_spmf cont_set_spmf)+
  1681 
  (* For a fixed point x, the probability spmf p x as a function of the spmf p. *)
  (* Monotone w.r.t. ord_spmf op = and \<le> (ord_spmf_eq_leD). *)
  1682 lemma monotone_spmf: "monotone (ord_spmf op =) op \<le> (\<lambda>p. spmf p x)"
  1683 by(rule monotoneI)(simp add: ord_spmf_eq_leD)
  1684 
  (* Continuous: lub_spmf is mapped to Sup (spmf_lub_spmf). *)
  1685 lemma cont_spmf: "cont lub_spmf (ord_spmf op =) Sup op \<le> (\<lambda>p. spmf p x)"
  1686 by(rule contI)(simp add: spmf_lub_spmf)
  1687 
  1688 lemma mcont_spmf: "mcont lub_spmf (ord_spmf op =) Sup op \<le> (\<lambda>p. spmf p x)"
  1689 by(rule mcontI monotone_spmf cont_spmf)+
  1690 
  (* The same for the ennreal embedding of spmf p x (ennreal_spmf_lub_spmf). *)
  1691 lemma cont_ennreal_spmf: "cont lub_spmf (ord_spmf op =) Sup op \<le> (\<lambda>p. ennreal (spmf p x))"
  1692 by(rule contI)(simp add: ennreal_spmf_lub_spmf)
  1693 
  (* mcont for the ennreal version; the THEN mcont2mcont variant is a cont_intro rule. *)
  1694 lemma mcont2mcont_ennreal_spmf [THEN mcont2mcont, cont_intro]:
  1695   shows mcont_ennreal_spmf: "mcont lub_spmf (ord_spmf op =) Sup op \<le> (\<lambda>p. ennreal (spmf p x))"
  1696 by(rule mcontI mono2mono_ennreal monotone_spmf cont_ennreal_spmf)+
  1697 
  (* Change of variables: a nonnegative integral against the image spmf map_spmf f p equals
     the integral against p of the integrand composed with f. *)
  1698 lemma nn_integral_map_spmf [simp]: "nn_integral (measure_spmf (map_spmf f p)) g = nn_integral (measure_spmf p) (g \<circ> f)"
  1699 by(auto 4 3 simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space intro: nn_integral_cong split: split_indicator)
  1700 
  1701 subsubsection \<open>Admissibility of @{term rel_spmf}\<close>
  1702 
  (* If p and q are related by rel_spmf R, then the mass p puts on any set A is at most the
     mass q puts on the R-image of A. The proof moves to the underlying pmfs
     (measure_measure_spmf_conv_measure_pmf) and applies rel_pmf_measureD for rel_option R,
     using Some ` A as the set of options. *)
  1703 lemma rel_spmf_measureD:
  1704   assumes "rel_spmf R p q"
  1705   shows "measure (measure_spmf p) A \<le> measure (measure_spmf q) {y. \<exists>x\<in>A. R x y}" (is "?lhs \<le> ?rhs")
  1706 proof -
  1707   have "?lhs = measure (measure_pmf p) (Some ` A)" by(simp add: measure_measure_spmf_conv_measure_pmf)
  1708   also have "\<dots> \<le> measure (measure_pmf q) {y. \<exists>x\<in>Some ` A. rel_option R x y}"
  1709     using assms by(rule rel_pmf_measureD)
  1710   also have "\<dots> = ?rhs" unfolding measure_measure_spmf_conv_measure_pmf
  1711     by(rule arg_cong2[where f=measure])(auto simp add: option_rel_Some1)
  1712   finally show ?thesis .
  1713 qed
  1714 
  (* The converse of rel_spmf_measureD needs the corresponding converse for pmfs
     (rel_pmf_measureI), which this theory does not prove. The locale takes it as an
     assumption, so every result inside the locale is conditional on it. *)
  1715 locale rel_spmf_characterisation =
  1716   assumes rel_pmf_measureI:
  1717     "\<And>(R :: 'a option \<Rightarrow> 'b option \<Rightarrow> bool) p q.
  1718     (\<And>A. measure (measure_pmf p) A \<le> measure (measure_pmf q) {y. \<exists>x\<in>A. R x y})
  1719     \<Longrightarrow> rel_pmf R p q"
  1720   \<comment> \<open>This assumption is shown to hold in general in the AFP entry \<open>MFMC_Countable\<close>.\<close>
  1721 begin
  1722 
  1723 context fixes R :: "'a \<Rightarrow> 'b \<Rightarrow> bool" begin
  1724 
  (* Sufficient condition for rel_spmf R p q: for every set A, the mass p puts on A is at
     most the mass q puts on the R-image of A (eq1), and q is not heavier than p (eq2).
     Proof: apply the locale assumption to the underlying pmfs. A set A of options is split
     into its Some-part (Some ` A') and its None-part A''. The Some-part is bounded using
     eq1. For the None-part, eq1 at UNIV gives weight_spmf p \<le> weight_spmf q, so with eq2
     the weights coincide, and hence so do the probabilities of None. *)
  1725 lemma rel_spmf_measureI:
  1726   assumes eq1: "\<And>A. measure (measure_spmf p) A \<le> measure (measure_spmf q) {y. \<exists>x\<in>A. R x y}"
  1727   assumes eq2: "weight_spmf q \<le> weight_spmf p"
  1728   shows "rel_spmf R p q"
  1729 proof(rule rel_pmf_measureI)
  1730   fix A :: "'a option set"
  1731   define A' where "A' = the ` (A \<inter> range Some)"
  1732   define A'' where "A'' = A \<inter> {None}"
  1733   have A: "A = Some ` A' \<union> A''" "Some ` A' \<inter> A'' = {}"
  1734     unfolding A'_def A''_def by(auto 4 3 intro: rev_image_eqI)
  1735   have "measure (measure_pmf p) A = measure (measure_pmf p) (Some ` A') + measure (measure_pmf p) A''"
  1736     by(simp add: A measure_pmf.finite_measure_Union)
  1737   also have "measure (measure_pmf p) (Some ` A') = measure (measure_spmf p) A'"
  1738     by(simp add: measure_measure_spmf_conv_measure_pmf)
  1739   also have "\<dots> \<le> measure (measure_spmf q) {y. \<exists>x\<in>A'. R x y}" by(rule eq1)
  1740   also (ord_eq_le_trans[OF _ add_right_mono])
  1741   have "\<dots> = measure (measure_pmf q) {y. \<exists>x\<in>A'. rel_option R (Some x) y}"
  1742     unfolding measure_measure_spmf_conv_measure_pmf
  1743     by(rule arg_cong2[where f=measure])(auto simp add: A'_def option_rel_Some1)
  1744   also
           (* The weights of p and q are equal: eq1 at UNIV bounds one way, eq2 the other. *)
  1745   { have "weight_spmf p \<le> measure (measure_spmf q) {y. \<exists>x. R x y}"
  1746       using eq1[of UNIV] unfolding weight_spmf_def by simp
  1747     also have "\<dots> \<le> weight_spmf q" unfolding weight_spmf_def
  1748       by(rule measure_spmf.finite_measure_mono) simp_all
  1749     finally have "weight_spmf p = weight_spmf q" using eq2 by simp }
  1750   then have "measure (measure_pmf p) A'' = measure (measure_pmf q) (if None \<in> A then {None} else {})"
  1751     unfolding A''_def by(simp add: pmf_None_eq_weight_spmf measure_pmf_single)
  1752   also have "measure (measure_pmf q) {y. \<exists>x\<in>A'. rel_option R (Some x) y} + \<dots> = measure (measure_pmf q) {y. \<exists>x\<in>A. rel_option R x y}"
  1753     by(subst measure_pmf.finite_measure_Union[symmetric])
  1754       (auto 4 3 intro!: arg_cong2[where f=measure] simp add: option_rel_Some1 option_rel_Some2 A'_def intro: rev_bexI elim: option.rel_cases)
  1755   finally show "measure (measure_pmf p) A \<le> \<dots>" .
  1756 qed
  1757 
  (* rel_spmf R is admissible on the product of two spmf orders, so it can serve as the
     invariant of parallel fixpoint induction. For a chain Y of pairs, rel_spmf_measureI is
     applied to the lubs of the two projections. The weight condition follows from
     weight_lub_spmf and rel_spmf_weightD. The measure condition follows from
     measure_lub_spmf, which expresses the measure of a lub as a SUP over the chain. *)
  1758 lemma admissible_rel_spmf:
  1759   "ccpo.admissible (prod_lub lub_spmf lub_spmf) (rel_prod (ord_spmf op =) (ord_spmf op =)) (case_prod (rel_spmf R))"
  1760   (is "ccpo.admissible ?lub ?ord ?P")
  1761 proof(rule ccpo.admissibleI)
  1762   fix Y
  1763   assume chain: "Complete_Partial_Order.chain ?ord Y"
  1764     and Y: "Y \<noteq> {}"
  1765     and R: "\<forall>(p, q) \<in> Y. rel_spmf R p q"
  1766   from R have R: "\<And>p q. (p, q) \<in> Y \<Longrightarrow> rel_spmf R p q" by auto
  1767   have chain1: "Complete_Partial_Order.chain (ord_spmf op =) (fst ` Y)"
  1768     and chain2: "Complete_Partial_Order.chain (ord_spmf op =) (snd ` Y)"
  1769     using chain by(rule chain_imageI; clarsimp)+
  1770   from Y have Y1: "fst ` Y \<noteq> {}" and Y2: "snd ` Y \<noteq> {}" by auto
  1771 
  1772   have "rel_spmf R (lub_spmf (fst ` Y)) (lub_spmf (snd ` Y))"
  1773   proof(rule rel_spmf_measureI)
  1774     show "weight_spmf (lub_spmf (snd ` Y)) \<le> weight_spmf (lub_spmf (fst ` Y))"
  1775       by(auto simp add: weight_lub_spmf chain1 chain2 Y rel_spmf_weightD[OF R, symmetric] intro!: cSUP_least intro: cSUP_upper2[OF bdd_aboveI2[OF weight_spmf_le_1]])
  1776 
  1777     fix A
  1778     have "measure (measure_spmf (lub_spmf (fst ` Y))) A = (SUP y:fst ` Y. measure (measure_spmf y) A)"
  1779       using chain1 Y1 by(rule measure_lub_spmf)
  1780     also have "\<dots> \<le> (SUP y:snd ` Y. measure (measure_spmf y) {y. \<exists>x\<in>A. R x y})" using Y1
  1781       by(rule cSUP_least)(auto intro!: cSUP_upper2[OF bdd_aboveI2[OF measure_spmf.subprob_measure_le_1]] rel_spmf_measureD R)
  1782     also have "\<dots> = measure (measure_spmf (lub_spmf (snd ` Y))) {y. \<exists>x\<in>A. R x y}"
  1783       using chain2 Y2 by(rule measure_lub_spmf[symmetric])
  1784     finally show "measure (measure_spmf (lub_spmf (fst ` Y))) A \<le> \<dots>" .
  1785   qed
  1786   then show "?P (?lub Y)" by(simp add: prod_lub_def)
  1787 qed
  1788 
  (* Admissibility of rel_spmf R along two mcont functions f and g (cont_intro rule). *)
  1789 lemma admissible_rel_spmf_mcont [cont_intro]:
  1790   "\<lbrakk> mcont lub ord lub_spmf (ord_spmf op =) f; mcont lub ord lub_spmf (ord_spmf op =) g \<rbrakk>
  1791   \<Longrightarrow> ccpo.admissible lub ord (\<lambda>x. rel_spmf R (f x) (g x))"
  1792 by(rule admissible_subst[OF admissible_rel_spmf, where f="\<lambda>x. (f x, g x)", simplified])(rule mcont_Pair)
  1793 
  (* lifting_syntax provides the ===> relator notation used below. *)
  1794 context includes lifting_syntax
  1795 begin
  1796 
  (* The least fixpoints (ccpo.fixp over spmfs) of two monotone functionals F and G are
     related by rel_spmf R whenever F and G map rel_spmf R-related arguments to related
     results. Proof by parallel fixpoint induction (parallel_fixp_induct).
     NOTE(review): the \<And>x binders in assumptions f and g look vacuous, because x occurs in
     neither F nor G. *)
  1797 lemma fixp_spmf_parametric':
  1798   assumes f: "\<And>x. monotone (ord_spmf op =) (ord_spmf op =) F"
  1799   and g: "\<And>x. monotone (ord_spmf op =) (ord_spmf op =) G"
  1800   and param: "(rel_spmf R ===> rel_spmf R) F G"
  1801   shows "(rel_spmf R) (ccpo.fixp lub_spmf (ord_spmf op =) F) (ccpo.fixp lub_spmf (ord_spmf op =) G)"
  1802 by(rule parallel_fixp_induct[OF ccpo_spmf ccpo_spmf _ f g])(auto intro: param[THEN rel_funD])
  1803 
  (* Function-space version for spmf.fixp_fun. If F and G are monotone in every argument
     x and related by the lifted relation, their fixpoints are related by A ===> rel_spmf R.
     Admissibility unfolds A ===> rel_spmf R pointwise and uses admissible_rel_spmf_mcont,
     with mcont_call composed with mcont_fst / mcont_snd. *)
  1804 lemma fixp_spmf_parametric:
  1805   assumes f: "\<And>x. mono_spmf (\<lambda>f. F f x)"
  1806   and g: "\<And>x. mono_spmf (\<lambda>f. G f x)"
  1807   and param: "((A ===> rel_spmf R) ===> A ===> rel_spmf R) F G"
  1808   shows "(A ===> rel_spmf R) (spmf.fixp_fun F) (spmf.fixp_fun G)"
  1809 using f g
  1810 proof(rule parallel_fixp_induct_1_1[OF partial_function_definitions_spmf partial_function_definitions_spmf _ _ reflexive reflexive, where P="(A ===> rel_spmf R)"])
  1811   show "ccpo.admissible (prod_lub (fun_lub lub_spmf) (fun_lub lub_spmf)) (rel_prod (fun_ord (ord_spmf op =)) (fun_ord (ord_spmf op =))) (\<lambda>x. (A ===> rel_spmf R) (fst x) (snd x))"
  1812     unfolding rel_fun_def
  1813     apply(rule admissible_all admissible_imp admissible_rel_spmf_mcont)+
  1814     apply(rule spmf.mcont2mcont[OF mcont_call])
  1815      apply(rule mcont_fst)
  1816     apply(rule spmf.mcont2mcont[OF mcont_call])
  1817      apply(rule mcont_snd)
  1818     done
  1819   show "(A ===> rel_spmf R) (\<lambda>_. lub_spmf {}) (\<lambda>_. lub_spmf {})" by auto
  1820   show "(A ===> rel_spmf R) (F f) (G g)" if "(A ===> rel_spmf R) f g" for f g
  1821     using that by(rule rel_funD[OF param])
  1822 qed
  1823 
  1824 end
  1825 
  1826 end
  1827 
  1828 end
  1829 
  1830 subsection \<open>Restrictions on spmfs\<close>
  1831 
  (* Restriction of an spmf to a set A. Mass on elements of A is kept; mass on elements
     outside A (and any mass already on None) ends up on None. *)
  1832 definition restrict_spmf :: "'a spmf \<Rightarrow> 'a set \<Rightarrow> 'a spmf" (infixl "\<upharpoonleft>" 110)
  1833 where "p \<upharpoonleft> A = map_pmf (\<lambda>x. x \<bind> (\<lambda>y. if y \<in> A then Some y else None)) p"
  1834 
  (* The support of a restriction is the intersection of the support with A. *)
  1835 lemma set_restrict_spmf [simp]: "set_spmf (p \<upharpoonleft> A) = set_spmf p \<inter> A"
  1836 by(fastforce simp add: restrict_spmf_def set_spmf_def split: bind_splits if_split_asm)
  1837 
  (* Restricting an image to A is the image of the restriction to the preimage f -` A. *)
  1838 lemma restrict_map_spmf: "map_spmf f p \<upharpoonleft> A = map_spmf f (p \<upharpoonleft> (f -` A))"
  1839 by(simp add: restrict_spmf_def pmf.map_comp o_def map_option_bind bind_map_option if_distrib cong del: if_weak_cong)
  1840 
  (* Successive restrictions restrict to the intersection. *)
  1841 lemma restrict_restrict_spmf [simp]: "p \<upharpoonleft> A \<upharpoonleft> B = p \<upharpoonleft> (A \<inter> B)"
  1842 by(auto simp add: restrict_spmf_def pmf.map_comp o_def intro!: pmf.map_cong bind_option_cong)
  1843 
  (* Restricting to the empty set leaves only return_pmf None. *)
  1844 lemma restrict_spmf_empty [simp]: "p \<upharpoonleft> {} = return_pmf None"
  1845 by(simp add: restrict_spmf_def)
  1846 
  (* Restricting to UNIV changes nothing. *)
  1847 lemma restrict_spmf_UNIV [simp]: "p \<upharpoonleft> UNIV = p"
  1848 by(simp add: restrict_spmf_def)
  1849 
  (* Points outside A have probability 0 after restriction (they are not in the support). *)
  1850 lemma spmf_restrict_spmf_outside [simp]: "x \<notin> A \<Longrightarrow> spmf (p \<upharpoonleft> A) x = 0"
  1851 by(simp add: spmf_eq_0_set_spmf)
  1852 
  (* The extended measure of X under p restricted to A is that of X \<inter> A under p. *)
  1853 lemma emeasure_restrict_spmf [simp]:
  1854   "emeasure (measure_spmf (p \<upharpoonleft> A)) X = emeasure (measure_spmf p) (X \<inter> A)"
  1855 by(auto simp add: restrict_spmf_def measure_spmf_def emeasure_distr measurable_restrict_space1 emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure] split: bind_splits if_split_asm)
  1856 
  (* Real-valued version, derived from the emeasure version via ennreal_inj. *)
  1857 lemma measure_restrict_spmf [simp]:
  1858   "measure (measure_spmf (p \<upharpoonleft> A)) X = measure (measure_spmf p) (X \<inter> A)"
  1859 using emeasure_restrict_spmf[of p A X]
  1860 by(simp only: measure_spmf.emeasure_eq_measure ennreal_inj measure_nonneg)
  1861 
  (* Pointwise description of a restriction. *)
  1862 lemma spmf_restrict_spmf: "spmf (p \<upharpoonleft> A) x = (if x \<in> A then spmf p x else 0)"
  1863 by(simp add: spmf_conv_measure_spmf)
  1864 
  (* Points inside A keep their probability. *)
  1865 lemma spmf_restrict_spmf_inside [simp]: "x \<in> A \<Longrightarrow> spmf (p \<upharpoonleft> A) x = spmf p x"
  1866 by(simp add: spmf_restrict_spmf)
  1867 
  (* After restriction, None carries its original mass plus the mass of everything outside A.
     The proof computes the preimage of {None} under the restricting map as
     {None} \<union> Some ` (- A), a disjoint union.
     NOTE(review): the proof unfolds ereal.inject although the surrounding lemmas work with
     ennreal (ennreal_pmf_map); possibly a leftover from the ereal-based setup. Confirm before
     simplifying. *)
  1868 lemma pmf_restrict_spmf_None: "pmf (p \<upharpoonleft> A) None = pmf p None + measure (measure_spmf p) (- A)"
  1869 proof -
  1870   have [simp]: "None \<notin> Some ` (- A)" by auto
  1871   have "(\<lambda>x. x \<bind> (\<lambda>y. if y \<in> A then Some y else None)) -` {None} = {None} \<union> (Some ` (- A))"
  1872     by(auto split: bind_splits if_split_asm)
  1873   then show ?thesis unfolding ereal.inject[symmetric]
  1874     by(simp add: restrict_spmf_def ennreal_pmf_map emeasure_pmf_single del: ereal.inject)
  1875       (simp add: pmf.rep_eq measure_pmf.finite_measure_Union[symmetric] measure_measure_spmf_conv_measure_pmf measure_pmf.emeasure_eq_measure)
  1876 qed
  1877 
  (* Restricting to a superset of the support changes nothing. *)
  1878 lemma restrict_spmf_trivial: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> x \<in> A) \<Longrightarrow> p \<upharpoonleft> A = p"
  1879 by(rule spmf_eqI)(auto simp add: spmf_restrict_spmf spmf_eq_0_set_spmf)
  1880 
  1881 lemma restrict_spmf_trivial': "set_spmf p \<subseteq> A \<Longrightarrow> p \<upharpoonleft> A = p"
  1882 by(rule restrict_spmf_trivial) blast
  1883 
  (* A Dirac spmf survives restriction iff its point lies in A; otherwise only
     return_pmf None remains. The two case-split simp rules follow. *)
  1884 lemma restrict_return_spmf: "return_spmf x \<upharpoonleft> A = (if x \<in> A then return_spmf x else return_pmf None)"
  1885 by(simp add: restrict_spmf_def)
  1886 
  1887 lemma restrict_return_spmf_inside [simp]: "x \<in> A \<Longrightarrow> return_spmf x \<upharpoonleft> A = return_spmf x"
  1888 by(simp add: restrict_return_spmf)
  1889 
  1890 lemma restrict_return_spmf_outside [simp]: "x \<notin> A \<Longrightarrow> return_spmf x \<upharpoonleft> A = return_pmf None"
  1891 by(simp add: restrict_return_spmf)
  1892 
  (* return_pmf None is unaffected by any restriction. *)
  1893 lemma restrict_spmf_return_pmf_None [simp]: "return_pmf None \<upharpoonleft> A = return_pmf None"
  1894 by(simp add: restrict_spmf_def)
  1895 
  (* Restriction distributes into the continuation of a bind_pmf / bind_spmf. *)
  1896 lemma restrict_bind_pmf: "bind_pmf p g \<upharpoonleft> A = p \<bind> (\<lambda>x. g x \<upharpoonleft> A)"
  1897 by(simp add: restrict_spmf_def map_bind_pmf o_def)
  1898 
  1899 lemma restrict_bind_spmf: "bind_spmf p g \<upharpoonleft> A = p \<bind> (\<lambda>x. g x \<upharpoonleft> A)"
  1900 by(auto simp add: bind_spmf_def restrict_bind_pmf cong del: option.case_cong_weak cong: option.case_cong intro!: bind_pmf_cong split: option.split)
  1901 
  (* Binding a restricted spmf as a pmf: outcomes outside Some ` A are treated like None. *)
  1902 lemma bind_restrict_pmf: "bind_pmf (p \<upharpoonleft> A) g = p \<bind> (\<lambda>x. if x \<in> Some ` A then g x else g None)"
  1903 by(auto simp add: restrict_spmf_def bind_map_pmf fun_eq_iff split: bind_split intro: arg_cong2[where f=bind_pmf])
  1904 
  (* Binding a restricted spmf: the continuation is replaced by return_pmf None outside A. *)
  1905 lemma bind_restrict_spmf: "bind_spmf (p \<upharpoonleft> A) g = p \<bind> (\<lambda>x. if x \<in> A then g x else return_pmf None)"
  1906 by(auto simp add: bind_spmf_def bind_restrict_pmf fun_eq_iff intro: arg_cong2[where f=bind_pmf] split: option.split)
  1907 
  (* Restricting the second component to y and projecting to the first component recovers
     the joint probability of (x, y). *)
  1908 lemma spmf_map_restrict: "spmf (map_spmf fst (p \<upharpoonleft> (snd -` {y}))) x = spmf p (x, y)"
  1909 by(subst spmf_map)(auto intro: arg_cong2[where f=measure] simp add: spmf_conv_measure_spmf)
  1910 
  (* If the restrictions of p to A and of q to B are related by some rel_spmf R, they have
     equal weight (rel_spmf_weightD), i.e. p puts as much mass on A as q puts on B. *)
  1911 lemma measure_eqI_restrict_spmf:
  1912   assumes "rel_spmf R (restrict_spmf p A) (restrict_spmf q B)"
  1913   shows "measure (measure_spmf p) A = measure (measure_spmf q) B"
  1914 proof -
  1915   from assms have "weight_spmf (restrict_spmf p A) = weight_spmf (restrict_spmf q B)" by(rule rel_spmf_weightD)
  1916   thus ?thesis by(simp add: weight_spmf_def)
  1917 qed
  1918 
  1919 subsection \<open>Subprobability distributions of sets\<close>
  1920 
  (* Uniform distribution on a set: pmf_of_set A (as a lossless spmf) when A is finite and
     nonempty, and return_pmf None otherwise. *)
  1921 definition spmf_of_set :: "'a set \<Rightarrow> 'a spmf"
  1922 where
  1923   "spmf_of_set A = (if finite A \<and> A \<noteq> {} then spmf_of_pmf (pmf_of_set A) else return_pmf None)"
  1924 
  (* Pointwise formula (real division; in HOL x / 0 = 0, which covers the case where card A
     is 0). *)
  1925 lemma spmf_of_set: "spmf (spmf_of_set A) x = indicator A x / card A"
  1926 by(auto simp add: spmf_of_set_def)
  1927 
  (* None has mass 1 exactly when A is infinite or empty, and 0 otherwise. *)
  1928 lemma pmf_spmf_of_set_None [simp]: "pmf (spmf_of_set A) None = indicator {A. infinite A \<or> A = {}} A"
  1929 by(simp add: spmf_of_set_def)
  1930 
  (* Support of spmf_of_set. *)
  1931 lemma set_spmf_of_set: "set_spmf (spmf_of_set A) = (if finite A then A else {})"
  1932 by(simp add: spmf_of_set_def)
  1933 
  1934 lemma set_spmf_of_set_finite [simp]: "finite A \<Longrightarrow> set_spmf (spmf_of_set A) = A"
  1935 by(simp add: set_spmf_of_set)
  1936 
  (* Uniform distribution on a singleton is the Dirac spmf. *)
  1937 lemma spmf_of_set_singleton: "spmf_of_set {x} = return_spmf x"
  1938 by(simp add: spmf_of_set_def pmf_of_set_singleton)
  1939 
  (* An injective map sends the uniform distribution on A to the uniform one on f ` A. *)
  1940 lemma map_spmf_of_set_inj_on [simp]:
  1941   "inj_on f A \<Longrightarrow> map_spmf f (spmf_of_set A) = spmf_of_set (f ` A)"
  1942 by(auto simp add: spmf_of_set_def map_pmf_of_set_inj dest: finite_imageD)
  1943 
  (* Folds spmf_of_pmf (pmf_of_set A) into spmf_of_set A. Proofs that unfold
     spmf_of_set_def delete this simp rule to avoid looping back. *)
  1944 lemma spmf_of_pmf_pmf_of_set [simp]:
  1945   "\<lbrakk> finite A; A \<noteq> {} \<rbrakk> \<Longrightarrow> spmf_of_pmf (pmf_of_set A) = spmf_of_set A"
  1946 by(simp add: spmf_of_set_def)
  1947 
  (* Weight is 1 for finite nonempty sets and 0 otherwise; case-split simp rules follow. *)
  1948 lemma weight_spmf_of_set:
  1949   "weight_spmf (spmf_of_set A) = (if finite A \<and> A \<noteq> {} then 1 else 0)"
  1950 by(auto simp only: spmf_of_set_def weight_spmf_of_pmf weight_return_pmf_None split: if_split)
  1951 
  1952 lemma weight_spmf_of_set_finite [simp]: "\<lbrakk> finite A; A \<noteq> {} \<rbrakk> \<Longrightarrow> weight_spmf (spmf_of_set A) = 1"
  1953 by(simp add: weight_spmf_of_set)
  1954 
  1955 lemma weight_spmf_of_set_infinite [simp]: "infinite A \<Longrightarrow> weight_spmf (spmf_of_set A) = 0"
  1956 by(simp add: weight_spmf_of_set)
  1957 
  (* The measure of spmf_of_set A: the uniform measure for finite nonempty A, and the null
     measure on count_space UNIV otherwise. *)
  1958 lemma measure_spmf_spmf_of_set:
  1959   "measure_spmf (spmf_of_set A) = (if finite A \<and> A \<noteq> {} then measure_pmf (pmf_of_set A) else null_measure (count_space UNIV))"
  1960 by(simp add: spmf_of_set_def del: spmf_of_pmf_pmf_of_set)
  1961 
  (* Probability of an event A under the uniform spmf on S: the fraction of S lying in A. *)
  1962 lemma emeasure_spmf_of_set:
  1963   "emeasure (measure_spmf (spmf_of_set S)) A = card (S \<inter> A) / card S"
  1964 by(auto simp add: measure_spmf_spmf_of_set emeasure_pmf_of_set)
  1965 
  1966 lemma measure_spmf_of_set:
  1967   "measure (measure_spmf (spmf_of_set S)) A = card (S \<inter> A) / card S"
  1968 by(auto simp add: measure_spmf_spmf_of_set measure_pmf_of_set)
  1969 
  (* Integrals against the uniform spmf are averages over A. *)
  1970 lemma nn_integral_spmf_of_set: "nn_integral (measure_spmf (spmf_of_set A)) f = setsum f A / card A"
  1971 by(cases "finite A")(auto simp add: spmf_of_set_def nn_integral_pmf_of_set card_gt_0_iff simp del: spmf_of_pmf_pmf_of_set)
  1972 
  1973 lemma integral_spmf_of_set: "integral\<^sup>L (measure_spmf (spmf_of_set A)) f = setsum f A / card A"
  1974 by(clarsimp simp add: spmf_of_set_def integral_pmf_of_set card_gt_0_iff simp del: spmf_of_pmf_pmf_of_set)
  1975 
  (* Counterexample: pmf_of_set does not transfer along arbitrary relations.
     R x y holds iff x = 0 or y = 0. So rel_set R A B holds for A = {0, 1} and
     B = {0, 1, 2}. But in any coupling pq of pmf_of_set A and pmf_of_set B, the mass 2/3
     that pmf_of_set B puts on {1, 2} must sit on pairs whose first component is 0, while
     pmf_of_set A has only mass 1/2 at 0. Hence rel_pmf R fails.
     Compare rel_spmf_of_set, which requires bi_unique R. *)
  1976 notepad begin \<comment> \<open>@{const pmf_of_set} is not fully parametric.\<close>
  1977   define R :: "nat \<Rightarrow> nat \<Rightarrow> bool" where "R x y \<longleftrightarrow> (x \<noteq> 0 \<longrightarrow> y = 0)" for x y
  1978   define A :: "nat set" where "A = {0, 1}"
  1979   define B :: "nat set" where "B = {0, 1, 2}"
         (* The sets themselves are related by rel_set R ... *)
  1980   have "rel_set R A B" unfolding R_def[abs_def] A_def B_def rel_set_def by auto
         (* ... but their uniform distributions are not related by rel_pmf R. *)
  1981   have "\<not> rel_pmf R (pmf_of_set A) (pmf_of_set B)"
  1982   proof
  1983     assume "rel_pmf R (pmf_of_set A) (pmf_of_set B)"
  1984     then obtain pq where pq: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> R x y"
  1985       and 1: "map_pmf fst pq = pmf_of_set A"
  1986       and 2: "map_pmf snd pq = pmf_of_set B"
  1987       by cases auto
           (* Chain of (in)equalities: the mass of {1, 2} under pmf_of_set B (2/3) is bounded by
              the mass of 0 under pmf_of_set A (1/2), which is absurd. *)
  1991     have "2 / 3 = pmf (pmf_of_set B) 1 + pmf (pmf_of_set B) 2" by(simp add: B_def)
  1992     also have "\<dots> = measure (measure_pmf (pmf_of_set B)) ({1} \<union> {2})"
  1993       by(subst measure_pmf.finite_measure_Union)(simp_all add: measure_pmf_single)
  1994     also have "\<dots> = emeasure (measure_pmf pq) (snd -` {2, 1})"
  1995       unfolding 2[symmetric] measure_pmf.emeasure_eq_measure[symmetric] by(simp)
           (* Pairs in the support with second component 1 or 2 must have first component 0. *)
  1996     also have "\<dots> = emeasure (measure_pmf pq) {(0, 2), (0, 1)}"
  1997       by(rule emeasure_eq_AE)(auto simp add: AE_measure_pmf_iff R_def dest!: pq)
  1998     also have "\<dots> \<le> emeasure (measure_pmf pq) (fst -` {0})"
  1999       by(rule emeasure_mono) auto
  2000     also have "\<dots> = emeasure (measure_pmf (pmf_of_set A)) {0}"
  2001       unfolding 1[symmetric] by simp
  2002     also have "\<dots> = pmf (pmf_of_set A) 0"
  2003       by(simp add: measure_pmf_single measure_pmf.emeasure_eq_measure)
  2004     also have "pmf (pmf_of_set A) 0 = 1 / 2" by(simp add: A_def)
  2005     finally show False by(subst (asm) ennreal_le_iff; simp)
  2006   qed
  2007 end
  2008 
  (* A bijection f between finite nonempty sets A and B with R x (f x) on A relates the
     uniform pmfs by rel_pmf R. The coupling is the uniform pmf on the graph AB of f over A.
     Both of its marginals are uniform, by map_pmf_of_set_inj with injectivity of
     (\<lambda>x. (x, f x)) and of f. *)
  2009 lemma rel_pmf_of_set_bij:
  2010   assumes f: "bij_betw f A B"
  2011   and A: "A \<noteq> {}" "finite A"
  2012   and B: "B \<noteq> {}" "finite B"
  2013   and R: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)"
  2014   shows "rel_pmf R (pmf_of_set A) (pmf_of_set B)"
  2015 proof(rule pmf.rel_mono_strong)
  2016   define AB where "AB = (\<lambda>x. (x, f x)) ` A"
  2017   define R' where "R' x y \<longleftrightarrow> (x, y) \<in> AB" for x y
  2018   have "(x, y) \<in> AB" if "(x, y) \<in> set_pmf (pmf_of_set AB)" for x y
  2019     using that by(auto simp add: AB_def A)
  2020   moreover have "map_pmf fst (pmf_of_set AB) = pmf_of_set A"
  2021     by(simp add: AB_def map_pmf_of_set_inj[symmetric] inj_on_def A pmf.map_comp o_def)
  2022   moreover
  2023   from f have [simp]: "inj_on f A" by(rule bij_betw_imp_inj_on)
  2024   from f have [simp]: "f ` A = B" by(rule bij_betw_imp_surj_on)
  2025   have "map_pmf snd (pmf_of_set AB) = pmf_of_set B"
  2026     by(simp add: AB_def map_pmf_of_set_inj[symmetric] inj_on_def A pmf.map_comp o_def)
  2027       (simp add: map_pmf_of_set_inj A)
  2028   ultimately show "rel_pmf (\<lambda>x y. (x, y) \<in> AB) (pmf_of_set A) (pmf_of_set B)" ..
  2029 qed(auto intro: R)
  2030 
  (* spmf version of rel_pmf_of_set_bij. It needs no finiteness or nonemptiness assumptions:
     a bijection preserves both, so both sides fall into the same case of spmf_of_set_def. *)
  2031 lemma rel_spmf_of_set_bij:
  2032   assumes f: "bij_betw f A B"
  2033   and R: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)"
  2034   shows "rel_spmf R (spmf_of_set A) (spmf_of_set B)"
  2035 proof -
  2036   have "finite A \<longleftrightarrow> finite B" using f by(rule bij_betw_finite)
  2037   moreover have "A = {} \<longleftrightarrow> B = {}" using f by(auto dest: bij_betw_empty2 bij_betw_empty1)
  2038   ultimately show ?thesis using assms
  2039     by(auto simp add: spmf_of_set_def simp del: spmf_of_pmf_pmf_of_set intro: rel_pmf_of_set_bij)
  2040 qed
  2041 
  (* lifting_syntax provides the ===> relator notation. *)
  2042 context includes lifting_syntax
  2043 begin
  2044 
  (* Parametricity of spmf_of_set for bi-unique relations: rel_set R A B yields a bijection
     f with R x (f x) on A (bi_unique_rel_set_bij_betw), and rel_spmf_of_set_bij then applies.
     Bi-uniqueness cannot be dropped; see the pmf_of_set counterexample notepad. *)
  2045 lemma rel_spmf_of_set:
  2046   assumes "bi_unique R"
  2047   shows "(rel_set R ===> rel_spmf R) spmf_of_set spmf_of_set"
  2048 proof
  2049   fix A B
  2050   assume R: "rel_set R A B"
  2051   with assms obtain f where "bij_betw f A B" and f: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)"
  2052     by(auto dest: bi_unique_rel_set_bij_betw)
  2053   then show "rel_spmf R (spmf_of_set A) (spmf_of_set B)" by(rule rel_spmf_of_set_bij)
  2054 qed
  2055 
  2056 end
  2057 
  (* Testing membership in A of a uniform sample from a finite nonempty B gives a Bernoulli
     distribution with success probability card (A \<inter> B) / card B. Proof: compare both sides
     pointwise in ennreal; the False case uses card (B - A) = card B - card (B \<inter> A). *)
  2058 lemma map_mem_spmf_of_set:
  2059   assumes "finite B" "B \<noteq> {}"
  2060   shows "map_spmf (\<lambda>x. x \<in> A) (spmf_of_set B) = spmf_of_pmf (bernoulli_pmf (card (A \<inter> B) / card B))"
  2061   (is "?lhs = ?rhs")
  2062 proof(rule spmf_eqI)
  2063   fix i
  2064   have "ennreal (spmf ?lhs i) = card (B \<inter> (\<lambda>x. x \<in> A) -` {i}) / (card B)"
  2065     by(subst ennreal_spmf_map)(simp add: measure_spmf_spmf_of_set assms emeasure_pmf_of_set)
  2066   also have "\<dots> = (if i then card (B \<inter> A) / card B else card (B - A) / card B)"
  2067     by(auto intro: arg_cong[where f=card])
  2068   also have "\<dots> = (if i then card (B \<inter> A) / card B else (card B - card (B \<inter> A)) / card B)"
  2069     by(auto simp add: card_Diff_subset_Int assms)
  2070   also have "\<dots> = ennreal (spmf ?rhs i)"
  2071     by(simp add: assms card_gt_0_iff field_simps card_mono Int_commute of_nat_diff)
  2072   finally show "spmf ?lhs i = spmf ?rhs i" by simp
  2073 qed
  2074 
  (* A fair coin: the uniform spmf on all booleans. *)
  2075 abbreviation coin_spmf :: "bool spmf"
  2076 where "coin_spmf \<equiv> spmf_of_set UNIV"
  2077 
  (* Comparing a fair coin with a constant yields a fair coin again: op = c (op \<longleftrightarrow> c on
     bool) is injective with range UNIV, so map_spmf_of_set_inj_on applies. *)
  2078 lemma map_eq_const_coin_spmf: "map_spmf (op = c) coin_spmf = coin_spmf"
  2079 proof -
  2080   have "inj (op \<longleftrightarrow> c)" "range (op \<longleftrightarrow> c) = UNIV" by(auto intro: inj_onI)
  2081   then show ?thesis by simp
  2082 qed
  2083 
  (* The same fact in bind form, with the comparison in either orientation. *)
  2084 lemma bind_coin_spmf_eq_const: "coin_spmf \<bind> (\<lambda>x :: bool. return_spmf (b = x)) = coin_spmf"
  2085 using map_eq_const_coin_spmf unfolding map_spmf_conv_bind_spmf by simp
  2086 
  2087 lemma bind_coin_spmf_eq_const': "coin_spmf \<bind> (\<lambda>x :: bool. return_spmf (x = b)) = coin_spmf"
  2088 by(rewrite in "_ = \<hole>" bind_coin_spmf_eq_const[symmetric, of b])(auto intro: bind_spmf_cong)
  2089 
  2090 subsection \<open>Losslessness\<close>
  2091 
  (* An spmf is lossless if it has total weight 1. Equivalently, None has probability 0
     (lossless_iff_pmf_None), or None is not in set_pmf p (lossless_iff_set_pmf_None). *)
  2092 definition lossless_spmf :: "'a spmf \<Rightarrow> bool"
  2093 where "lossless_spmf p \<longleftrightarrow> weight_spmf p = 1"
  2094 
  2095 lemma lossless_iff_pmf_None: "lossless_spmf p \<longleftrightarrow> pmf p None = 0"
  2096 by(simp add: lossless_spmf_def pmf_None_eq_weight_spmf)
  2097 
  (* Dirac spmfs are lossless; return_pmf None is not. *)
  2098 lemma lossless_return_spmf [iff]: "lossless_spmf (return_spmf x)"
  2099 by(simp add: lossless_iff_pmf_None)
  2100 
  2101 lemma lossless_return_pmf_None [iff]: "\<not> lossless_spmf (return_pmf None)"
  2102 by(simp add: lossless_iff_pmf_None)
  2103 
  (* Mapping does not change losslessness. *)
  2104 lemma lossless_map_spmf [simp]: "lossless_spmf (map_spmf f p) \<longleftrightarrow> lossless_spmf p"
  2105 by(auto simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf)
  2106 
  (* A bind is lossless iff p is lossless and every continuation on the support of p is. *)
  2107 lemma lossless_bind_spmf [simp]:
  2108   "lossless_spmf (p \<bind> f) \<longleftrightarrow> lossless_spmf p \<and> (\<forall>x\<in>set_spmf p. lossless_spmf (f x))"
  2109 by(simp add: lossless_iff_pmf_None pmf_bind_spmf_None add_nonneg_eq_0_iff integral_nonneg_AE integral_nonneg_eq_0_iff_AE measure_spmf.integrable_const_bound[where B=1] pmf_le_1)
  2110 
  2111 lemma lossless_weight_spmfD: "lossless_spmf p \<Longrightarrow> weight_spmf p = 1"
  2112 by(simp add: lossless_spmf_def)
  2113 
  2114 lemma lossless_iff_set_pmf_None:
  2115   "lossless_spmf p \<longleftrightarrow> None \<notin> set_pmf p"
  2116 by (simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf)
  2110 
  2111 lemma lossless_weight_spmfD: "lossless_spmf p \<Longrightarrow> weight_spmf p = 1"
  2112 by(simp add: lossless_spmf_def)
  2113 
  2114 lemma lossless_iff_set_pmf_None:
  2115   "lossless_spmf p \<longleftrightarrow> None \<notin> set_pmf p"
  2116 by (simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf)
  2117 
  2118 lemma lossless_spmf_of_set [simp]: "lossless_spmf (spmf_of_set A) \<longleftrightarrow> finite A \<and> A \<noteq> {}"
  2119 by(auto simp add: lossless_spmf_def weight_spmf_of_set)
  2120 
  2121 lemma lossless_spmf_spmf_of_spmf [simp]: "lossless_spmf (spmf_of_pmf p)"
  2122 by(simp add: lossless_spmf_def)
  2123 
  2124 lemma lossless_spmf_bind_pmf [simp]:
  2125   "lossless_spmf (bind_pmf p f) \<longleftrightarrow> (\<forall>x\<in>set_pmf p. lossless_spmf (f x))"
  2126 by(simp add: lossless_iff_pmf_None pmf_bind integral_nonneg_AE integral_nonneg_eq_0_iff_AE measure_pmf.integrable_const_bound[where B=1] AE_measure_pmf_iff pmf_le_1)
  2127 
  (* Lossless spmfs are exactly the embedded pmfs. For the forward direction the witness is
     map_pmf the p: since None is not in the support, every outcome has the form Some x. The
     backward direction is lossless_spmf_spmf_of_spmf (closed by auto). *)
  2128 lemma lossless_spmf_conv_spmf_of_pmf: "lossless_spmf p \<longleftrightarrow> (\<exists>p'. p = spmf_of_pmf p')"
  2129 proof
  2130   assume "lossless_spmf p"
  2131   hence *: "\<And>y. y \<in> set_pmf p \<Longrightarrow> \<exists>x. y = Some x"
  2132     by(case_tac y)(simp_all add: lossless_iff_set_pmf_None)
  2133 
  2134   let ?p = "map_pmf the p"
  2135   have "p = spmf_of_pmf ?p"
  2136   proof(rule spmf_eqI)
  2137     fix i
  2138     have "ennreal (pmf (map_pmf the p) i) = \<integral>\<^sup>+ x. indicator (the -` {i}) x \<partial>p" by(simp add: ennreal_pmf_map)
  2139     also have "\<dots> = \<integral>\<^sup>+ x. indicator {i} x \<partial>measure_spmf p" unfolding measure_spmf_def
  2140       by(subst nn_integral_distr)(auto simp add: nn_integral_restrict_space AE_measure_pmf_iff simp del: nn_integral_indicator intro!: nn_integral_cong_AE split: split_indicator dest!: * )
  2141     also have "\<dots> = spmf p i" by(simp add: emeasure_spmf_single)
  2142     finally show "spmf p i = spmf (spmf_of_pmf ?p) i" by simp
  2143   qed
  2144   thus "\<exists>p'. p = spmf_of_pmf p'" ..
  2145 qed auto
  2146 
  (* For a lossless boolean spmf, the probabilities of False and True sum to 1. *)
  2147 lemma spmf_False_conv_True: "lossless_spmf p \<Longrightarrow> spmf p False = 1 - spmf p True"
  2148 by(clarsimp simp add: lossless_spmf_conv_spmf_of_pmf pmf_False_conv_True)
  2149 
  2150 lemma spmf_True_conv_False: "lossless_spmf p \<Longrightarrow> spmf p True = 1 - spmf p False"
  2151 by(simp add: spmf_False_conv_True)
  2152 
  (* A bind equals a Dirac spmf iff p is lossless and every continuation on the support of p
     is that Dirac spmf. *)
  2153 lemma bind_eq_return_spmf:
  2154   "bind_spmf p f = return_spmf x \<longleftrightarrow> (\<forall>y\<in>set_spmf p. f y = return_spmf x) \<and> lossless_spmf p"
  2155 by(auto simp add: bind_spmf_def bind_eq_return_pmf in_set_spmf lossless_iff_pmf_None pmf_eq_0_set_pmf iff del: not_None_eq split: option.split)
  2156 
  (* p is related to a Dirac spmf at x iff p is lossless and every element of its support is
     R-related to x. The mirrored version follows by converse relation (spmf_rel_conversep). *)
  2157 lemma rel_spmf_return_spmf2:
  2158   "rel_spmf R p (return_spmf x) \<longleftrightarrow> lossless_spmf p \<and> (\<forall>a\<in>set_spmf p. R a x)"
  2159 by(auto simp add: lossless_iff_set_pmf_None rel_pmf_return_pmf2 option_rel_Some2 in_set_spmf, metis in_set_spmf not_None_eq)
  2160 
  2161 lemma rel_spmf_return_spmf1:
  2162   "rel_spmf R (return_spmf x) p \<longleftrightarrow> lossless_spmf p \<and> (\<forall>a\<in>set_spmf p. R x a)"
  2163 using rel_spmf_return_spmf2[of "R\<inverse>\<inverse>"] by(simp add: spmf_rel_conversep)
  2164 
  (* One-sided bind rule: if the bound spmf p is lossless and every continuation on its
     support is related to q, then the whole bind is related to q. Proof trick: view q as
     bind_spmf (return_spmf x) (\<lambda>_. q) for an arbitrary x, then apply rel_spmf_bindI with the
     relation "first argument lies in set_spmf p". *)
  2165 lemma rel_spmf_bindI1:
  2166   assumes f: "\<And>x. x \<in> set_spmf p \<Longrightarrow> rel_spmf R (f x) q"
  2167   and p: "lossless_spmf p"
  2168   shows "rel_spmf R (bind_spmf p f) q"
  2169 proof -
  2170   fix x :: 'a
  2171   have "rel_spmf R (bind_spmf p f) (bind_spmf (return_spmf x) (\<lambda>_. q))"
  2172     by(rule rel_spmf_bindI[where R="\<lambda>x _. x \<in> set_spmf p"])(simp_all add: rel_spmf_return_spmf2 p f)
  2173   then show ?thesis by simp
  2174 qed
  2175 
  (* Mirror image of rel_spmf_bindI1 for a bind on the right, via the converse relation. *)
  2176 lemma rel_spmf_bindI2:
  2177   "\<lbrakk> \<And>x. x \<in> set_spmf q \<Longrightarrow> rel_spmf R p (f x); lossless_spmf q \<rbrakk>
  2178   \<Longrightarrow> rel_spmf R p (bind_spmf q f)"
  2179 using rel_spmf_bindI1[of q "conversep R" f p] by(simp add: spmf_rel_conversep)
  2180 
  2181 subsection \<open>Scaling\<close>
  2182 
  (* Scaling an spmf by a factor r: every point probability is multiplied by r, clamped below
     at 0 and capped at inverse (weight_spmf p), so that the total mass stays at most 1
     (scale_spmf_le_1). *)
  2183 definition scale_spmf :: "real \<Rightarrow> 'a spmf \<Rightarrow> 'a spmf"
  2184 where
  2185   "scale_spmf r p = embed_spmf (\<lambda>x. min (inverse (weight_spmf p)) (max 0 r) * spmf p x)"
  2186 
  (* The density used in scale_spmf integrates to at most 1; this is the side condition of
     spmf_embed_spmf used in spmf_scale_spmf. *)
  2187 lemma scale_spmf_le_1:
  2188   "(\<integral>\<^sup>+ x. min (inverse (weight_spmf p)) (max 0 r) * spmf p x \<partial>count_space UNIV) \<le> 1" (is "?lhs \<le> _")
  2189 proof -
  2190   have "?lhs = min (inverse (weight_spmf p)) (max 0 r) * \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
  2191     by(subst nn_integral_cmult[symmetric])(simp_all add: weight_spmf_nonneg max_def min_def ennreal_mult)
  2192   also have "\<dots> \<le> 1" unfolding weight_spmf_eq_nn_integral_spmf[symmetric]
  2193     by(simp add: min_def max_def weight_spmf_nonneg order.strict_iff_order field_simps ennreal_mult[symmetric])
  2194   finally show ?thesis .
  2195 qed
  2196 
  (* Pointwise formula for scale_spmf. The factor is written as
     max 0 (min (inverse w) r), while the definition uses min (inverse w) (max 0 r). The two
     agree because the inverse weight is nonnegative (weight_spmf_nonneg). *)
  2197 lemma spmf_scale_spmf: "spmf (scale_spmf r p) x = max 0 (min (inverse (weight_spmf p)) r) * spmf p x" (is "?lhs = ?rhs")
  2198 unfolding scale_spmf_def
  2199 apply(subst spmf_embed_spmf[OF scale_spmf_le_1])
  2200 apply(simp add: max_def min_def weight_spmf_le_0 field_simps weight_spmf_nonneg not_le order.strict_iff_order)
  2201 apply(metis antisym_conv order_trans weight_spmf_nonneg zero_le_mult_iff zero_le_one)
  2202 done
  2203 
  (* For 0 \<le> x \<le> 1, 1 / x \<le> 1 holds exactly when x is 1 or 0 (in HOL, 1 / 0 = 0).
     NOTE(review): despite the name, the statement is phrased with 1 / x, not inverse x. *)
  2204 lemma real_inverse_le_1_iff: fixes x :: real
  2205   shows "\<lbrakk> 0 \<le> x; x \<le> 1 \<rbrakk> \<Longrightarrow> 1 / x \<le> 1 \<longleftrightarrow> x = 1 \<or> x = 0"
  2206 by auto
  2207 
  (* For factors r \<le> 1 the inverse-weight cap never matters: only clamping at 0 remains. *)
  2208 lemma spmf_scale_spmf': "r \<le> 1 \<Longrightarrow> spmf (scale_spmf r p) x = max 0 r * spmf p x"
  2209 using real_inverse_le_1_iff[OF weight_spmf_nonneg weight_spmf_le_1, of p]
  2210 by(auto simp add: spmf_scale_spmf max_def min_def field_simps)(metis pmf_le_0_iff spmf_le_weight)
  2211 
  (* A nonpositive factor yields return_pmf None. *)
  2212 lemma scale_spmf_neg: "r \<le> 0 \<Longrightarrow> scale_spmf r p = return_pmf None"
  2213 by(rule spmf_eqI)(simp add: spmf_scale_spmf' max_def)
  2214 
  (* Scaling return_pmf None leaves it unchanged. *)
  2215 lemma scale_spmf_return_None [simp]: "scale_spmf r (return_pmf None) = return_pmf None"
  2216 by(rule spmf_eqI)(simp add: spmf_scale_spmf)
  2217 
  2218 lemma scale_spmf_conv_bind_bernoulli:
       (* For r <= 1, scaling by r is the same as flipping a bernoulli_pmf r coin and
          continuing with p on True and failing (return_pmf None) on False. *)
  2219   assumes "r \<le> 1"
  2220   shows "scale_spmf r p = bind_pmf (bernoulli_pmf r) (\<lambda>b. if b then p else return_pmf None)" (is "?lhs = ?rhs")
  2221 proof(rule spmf_eqI)
  2222   fix x
         (* Compare the values in ennreal, where the bind becomes a finite sum over UNIV :: bool. *)
  2223   have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using assms
  2224     unfolding spmf_scale_spmf ennreal_pmf_bind nn_integral_measure_pmf UNIV_bool bernoulli_pmf.rep_eq
  2225     apply(auto simp add: nn_integral_count_space_finite max_def min_def field_simps real_inverse_le_1_iff[OF weight_spmf_nonneg weight_spmf_le_1] weight_spmf_lt_0 not_le ennreal_mult[symmetric])
           (* Discharge the arithmetic subgoals that auto leaves over. *)
  2226     apply (metis pmf_le_0_iff spmf_le_weight)
  2227     apply (metis pmf_le_0_iff spmf_le_weight)
  2228     apply (meson le_divide_eq_1_pos measure_spmf.subprob_measure_le_1 not_less order_trans weight_spmf_le_0)
  2229     by (meson divide_le_0_1_iff less_imp_le order_trans weight_spmf_le_0)
  2230   thus "spmf ?lhs x = spmf ?rhs x" by simp
  2231 qed
  2232 
  2233 lemma nn_integral_spmf: "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space A) = emeasure (measure_spmf p) A"
       (* Integrating spmf p over count_space A gives the mass that measure_spmf p assigns to A. *)
  2234 apply(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space nn_integral_pmf[symmetric])
  2235 apply(rule nn_integral_bij_count_space[where g=Some])
  2236 apply(auto simp add: bij_betw_def)
  2237 done
  2238 
  2239 lemma measure_spmf_scale_spmf: "measure_spmf (scale_spmf r p) = scale_measure (min (inverse (weight_spmf p)) r) (measure_spmf p)"
       (* At the measure level, scale_spmf is scale_measure with factor min (inverse (weight_spmf p)) r. *)
  2240 apply(rule measure_eqI)
  2241  apply simp
  2242 apply(simp add: nn_integral_spmf[symmetric] spmf_scale_spmf)
  2243 apply(subst nn_integral_cmult[symmetric])
  2244 apply(auto simp add: max_def min_def ennreal_mult[symmetric] not_le ennreal_lt_0)
  2245 done
  2246 
  2247 lemma measure_spmf_scale_spmf':
       (* For r <= 1 the factor is just r; the proof splits on whether weight_spmf p > 0. *)
  2248   "r \<le> 1 \<Longrightarrow> measure_spmf (scale_spmf r p) = scale_measure r (measure_spmf p)"
  2249 unfolding measure_spmf_scale_spmf
  2250 apply(cases "weight_spmf p > 0")
  2251  apply(simp add: min.absorb2 field_simps weight_spmf_le_1 mult_le_one)
  2252 apply(clarsimp simp add: weight_spmf_le_0 min_def scale_spmf_neg weight_spmf_eq_0 not_less)
  2253 done
  2254 
  2255 lemma scale_spmf_1 [simp]: "scale_spmf 1 p = p"
       (* Scaling by 1 is the identity. *)
  2256 apply(rule spmf_eqI)
  2257 apply(simp add: spmf_scale_spmf max_def min_def order.strict_iff_order field_simps weight_spmf_nonneg)
  2258 apply(metis antisym_conv divide_le_eq_1 less_imp_le pmf_nonneg spmf_le_weight weight_spmf_nonneg weight_spmf_le_1)
  2259 done
  2260 
  2261 lemma scale_spmf_0 [simp]: "scale_spmf 0 p = return_pmf None"
       (* Scaling by 0 yields return_pmf None. *)
  2262 by(rule spmf_eqI)(simp add: spmf_scale_spmf min_def max_def weight_spmf_le_0)
  2263 
  2264 lemma bind_scale_spmf:
       (* For r <= 1, scaling the first computation of a bind is the same as scaling every
          continuation. Proved by comparing spmf values as ennreal. *)
  2265   assumes r: "r \<le> 1"
  2266   shows "bind_spmf (scale_spmf r p) f = bind_spmf p (\<lambda>x. scale_spmf r (f x))"
  2267   (is "?lhs = ?rhs")
  2268 proof(rule spmf_eqI)
  2269   fix x
  2270   have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using r
  2271     by(simp add: ennreal_spmf_bind measure_spmf_scale_spmf' nn_integral_scale_measure spmf_scale_spmf')
  2272       (simp add: ennreal_mult ennreal_lt_0 nn_integral_cmult max_def min_def)
  2273   thus "spmf ?lhs x = spmf ?rhs x" by simp
  2274 qed
  2275 
  2276 lemma scale_bind_spmf:
       (* For r <= 1, scaling a whole bind is the same as scaling every continuation. *)
  2277   assumes "r \<le> 1"
  2278   shows "scale_spmf r (bind_spmf p f) = bind_spmf p (\<lambda>x. scale_spmf r (f x))"
  2279   (is "?lhs = ?rhs")
  2280 proof(rule spmf_eqI)
  2281   fix x
  2282   have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using assms
  2283     unfolding spmf_scale_spmf'[OF assms]
  2284     by(simp add: ennreal_mult ennreal_spmf_bind spmf_scale_spmf' nn_integral_cmult max_def min_def)
  2285   thus "spmf ?lhs x = spmf ?rhs x" by simp
  2286 qed
  2287 
  2288 lemma bind_spmf_const: "bind_spmf p (\<lambda>x. q) = scale_spmf (weight_spmf p) q" (is "?lhs = ?rhs")
       (* Binding into a constant continuation q only rescales q by the weight of p. *)
  2289 proof(rule spmf_eqI)
  2290   fix x
  2291   have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)"
  2292     using measure_spmf.subprob_measure_le_1[of p "space (measure_spmf p)"]
  2293     by(subst ennreal_spmf_bind)(simp add: spmf_scale_spmf' weight_spmf_le_1 ennreal_mult mult.commute max_def min_def measure_spmf.emeasure_eq_measure)
  2294   thus "spmf ?lhs x = spmf ?rhs x" by simp
  2295 qed
  2296 
  2297 lemma map_scale_spmf: "map_spmf f (scale_spmf r p) = scale_spmf r (map_spmf f p)" (is "?lhs = ?rhs")
       (* map_spmf commutes with scale_spmf. *)
  2298 proof(rule spmf_eqI)
  2299   fix i
  2300   show "spmf ?lhs i = spmf ?rhs i" unfolding spmf_scale_spmf
  2301     by(subst (1 2) spmf_map)(auto simp add: measure_spmf_scale_spmf max_def min_def ennreal_lt_0)
  2302 qed
  2303 
  2304 lemma set_scale_spmf: "set_spmf (scale_spmf r p) = (if r > 0 then set_spmf p else {})"
       (* Scaling keeps the support for r > 0 and empties it otherwise. *)
  2305 apply(auto simp add: in_set_spmf_iff_spmf spmf_scale_spmf)
  2306 apply(simp add: max_def min_def not_le weight_spmf_lt_0 weight_spmf_eq_0 split: if_split_asm)
  2307 done
  2308 
  2309 lemma set_scale_spmf' [simp]: "0 < r \<Longrightarrow> set_spmf (scale_spmf r p) = set_spmf p"
       (* Conditional simp rule form of set_scale_spmf for positive r. *)
  2310 by(simp add: set_scale_spmf)
  2311 
  2312 lemma rel_spmf_scaleI:
       (* Scaling both sides by the same factor preserves rel_spmf. The relation is needed only
          for r > 0; otherwise both sides are return_pmf None by scale_spmf_neg. *)
  2313   assumes "r > 0 \<Longrightarrow> rel_spmf A p q"
  2314   shows "rel_spmf A (scale_spmf r p) (scale_spmf r q)"
  2315 proof(cases "r > 0")
  2316   case True
  2317   from assms[OF this] show ?thesis
  2318     by(rule rel_spmfE)(auto simp add: map_scale_spmf[symmetric] spmf_rel_map True intro: rel_spmf_reflI)
  2319 qed(simp add: not_less scale_spmf_neg)
  2320 
  2321 lemma weight_scale_spmf: "weight_spmf (scale_spmf r p) = min 1 (max 0 r * weight_spmf p)"
       (* Weight of a scaled spmf: the original weight scaled by max 0 r, capped at 1. *)
  2322 proof -
         (* First establish the identity in ennreal, then transfer it back to real. *)
  2323   have "ennreal (weight_spmf (scale_spmf r p)) = min 1 (max 0 r * ennreal (weight_spmf p))"
  2324     unfolding weight_spmf_eq_nn_integral_spmf
  2325     apply(simp add: spmf_scale_spmf ennreal_mult zero_ereal_def[symmetric] nn_integral_cmult)
  2326     apply(auto simp add: weight_spmf_eq_nn_integral_spmf[symmetric] field_simps min_def max_def not_le weight_spmf_lt_0 ennreal_mult[symmetric])
  2327     subgoal by(subst (asm) ennreal_mult[symmetric], meson divide_less_0_1_iff le_less_trans not_le weight_spmf_lt_0, simp+, meson not_le pos_divide_le_eq weight_spmf_le_0)
  2328     subgoal by(cases "r \<ge> 0")(simp_all add: ennreal_mult[symmetric] weight_spmf_nonneg ennreal_lt_0, meson le_less_trans not_le pos_divide_le_eq zero_less_divide_1_iff)
  2329     done
  2330   thus ?thesis by(auto simp add: min_def max_def ennreal_mult[symmetric] split: if_split_asm)
  2331 qed
  2332 
  2333 lemma weight_scale_spmf' [simp]:
       (* For 0 <= r <= 1 no clamping happens: the weight is scaled exactly by r. *)
  2334   "\<lbrakk> 0 \<le> r; r \<le> 1 \<rbrakk> \<Longrightarrow> weight_spmf (scale_spmf r p) = r * weight_spmf p"
  2335 by(simp add: weight_scale_spmf max_def min_def)(metis antisym_conv mult_left_le order_trans weight_spmf_le_1)
  2336 
  2337 lemma pmf_scale_spmf_None:
       (* Failure probability of a scaled spmf: weight_scale_spmf restated via pmf p None. *)
  2338   "pmf (scale_spmf k p) None = 1 - min 1 (max 0 k * (1 - pmf p None))"
  2339 unfolding pmf_None_eq_weight_spmf by(simp add: weight_scale_spmf)
  2340 
  2341 lemma scale_scale_spmf:
       (* Two nested scalings collapse to one. The inner factor r' is replaced by its effective
          clamped value max 0 (min (inverse (weight_spmf p)) r'). *)
  2342   "scale_spmf r (scale_spmf r' p) = scale_spmf (r * max 0 (min (inverse (weight_spmf p)) r')) p"
  2343   (is "?lhs = ?rhs")
  2344 proof(rule spmf_eqI)
  2345   fix i
         (* Arithmetic core: the effective inner factor times the effective outer factor equals the
            effective factor of ?rhs. The outer factor is computed from the inner result's weight
            min 1 (weight_spmf p * max 0 r'), cf. weight_scale_spmf. *)
  2346   have "max 0 (min (1 / weight_spmf p) r') *
  2347     max 0 (min (1 / min 1 (weight_spmf p * max 0 r')) r) =
  2348     max 0 (min (1 / weight_spmf p) (r * max 0 (min (1 / weight_spmf p) r')))"
  2349   proof(cases "weight_spmf p > 0")
  2350     case False
  2351     thus ?thesis by(simp add: not_less weight_spmf_le_0)
  2352   next
  2353     case True
  2354     thus ?thesis by(simp add: field_simps max_def min.absorb_iff2[symmetric])(auto simp add: min_def field_simps zero_le_mult_iff)
  2355   qed
  2356   then show "spmf ?lhs i = spmf ?rhs i"
  2357     by(simp add: spmf_scale_spmf field_simps weight_scale_spmf)
  2358 qed
  2359 
  2360 lemma scale_scale_spmf' [simp]:
       (* For factors in [0, 1], nested scaling simply multiplies the factors. *)
  2361   "\<lbrakk> 0 \<le> r; r \<le> 1; 0 \<le> r'; r' \<le> 1 \<rbrakk>
  2362   \<Longrightarrow> scale_spmf r (scale_spmf r' p) = scale_spmf (r * r') p"
  2363 apply(cases "weight_spmf p > 0")
  2364 apply(auto simp add: scale_scale_spmf min_def max_def field_simps not_le weight_spmf_lt_0 weight_spmf_eq_0 not_less weight_spmf_le_0)
  2365 apply(subgoal_tac "1 = r'")
  2366  apply (metis (no_types) divide_1 eq_iff measure_spmf.subprob_measure_le_1 mult.commute mult_cancel_right1)
  2367 apply(meson eq_iff le_divide_eq_1_pos measure_spmf.subprob_measure_le_1 mult_imp_div_pos_le order.trans)
  2368 done
  2369 
  2370 lemma scale_spmf_eq_same: "scale_spmf r p = p \<longleftrightarrow> weight_spmf p = 0 \<or> r = 1 \<or> r \<ge> 1 \<and> weight_spmf p = 1"
  2371   (is "?lhs \<longleftrightarrow> ?rhs")
       (* Scaling leaves p unchanged iff p has weight 0, or r = 1, or r >= 1 and p has weight 1. *)
  2372 proof
  2373   assume ?lhs
         (* -> : equal spmfs have equal weights; analyse that equation via weight_scale_spmf. *)
  2374   hence "weight_spmf (scale_spmf r p) = weight_spmf p" by simp
  2375   hence *: "min 1 (max 0 r * weight_spmf p) = weight_spmf p" by(simp add: weight_scale_spmf)
  2376   hence **: "weight_spmf p = 0 \<or> r \<ge> 1" by(auto simp add: min_def max_def split: if_split_asm)
  2377   show ?rhs
  2378   proof(cases "weight_spmf p = 0")
  2379     case False
  2380     with ** have "r \<ge> 1" by simp
  2381     with * False have "r = 1 \<or> weight_spmf p = 1" by(simp add: max_def min_def not_le split: if_split_asm)
  2382     with \<open>r \<ge> 1\<close> show ?thesis by simp
  2383   qed simp
       (* <- : the remaining direction is shown pointwise with spmf_eqI. *)
  2384 qed(auto intro!: spmf_eqI simp add: spmf_scale_spmf, metis pmf_le_0_iff spmf_le_weight)
  2385 
  2386 lemma map_const_spmf_of_set:
       (* Mapping a constant function over spmf_of_set A for a finite non-empty A gives a point mass. *)
  2387   "\<lbrakk> finite A; A \<noteq> {} \<rbrakk> \<Longrightarrow> map_spmf (\<lambda>_. c) (spmf_of_set A) = return_spmf c"
  2388 by(simp add: map_spmf_conv_bind_spmf bind_spmf_const)
  2389 
  2390 subsection \<open>Conditional spmfs\<close>
  2391 
  2392 lemma set_pmf_Int_Some: "set_pmf p \<inter> Some ` A = {} \<longleftrightarrow> set_spmf p \<inter> A = {}"
       (* Translates disjointness from set_pmf (over options) to set_spmf. *)
  2393 by(auto simp add: in_set_spmf)
  2394 
  2395 lemma measure_spmf_zero_iff: "measure (measure_spmf p) A = 0 \<longleftrightarrow> set_spmf p \<inter> A = {}"
       (* A set has measure 0 under measure_spmf p iff it is disjoint from the support of p. *)
  2396 unfolding measure_measure_spmf_conv_measure_pmf by(simp add: measure_pmf_zero_iff set_pmf_Int_Some)
  2397 
  2398 definition cond_spmf :: "'a spmf \<Rightarrow> 'a set \<Rightarrow> 'a spmf"
       (* Conditions p on the event A. If the support of p misses A the result is return_pmf None;
          otherwise it is cond_pmf applied to the event Some ` A. *)
  2399 where "cond_spmf p A = (if set_spmf p \<inter> A = {} then return_pmf None else cond_pmf p (Some ` A))"
  2400 
  2401 lemma set_cond_spmf [simp]: "set_spmf (cond_spmf p A) = set_spmf p \<inter> A"
       (* Support of the conditional spmf. *)
  2402 by(auto 4 4 simp add: cond_spmf_def in_set_spmf iff: set_cond_pmf[THEN set_eq_iff[THEN iffD1], THEN spec, rotated])
  2403 
  2404 lemma cond_map_spmf [simp]: "cond_spmf (map_spmf f p) A = map_spmf f (cond_spmf p (f -` A))"
       (* Conditioning after map_spmf f amounts to conditioning on the preimage f -` A. *)
  2405 proof -
  2406   have "map_option f -` Some ` A = Some ` f -` A" by auto
  2407   moreover have "set_pmf p \<inter> map_option f -` Some ` A \<noteq> {}" if "Some x \<in> set_pmf p" "f x \<in> A" for x
  2408     using that by auto
  2409   ultimately show ?thesis by(auto simp add: cond_spmf_def in_set_spmf cond_map_pmf)
  2410 qed
  2411 
  2412 lemma spmf_cond_spmf [simp]:
       (* Pointwise values: inside A renormalise by the measure of A, outside A the value is 0. *)
  2413   "spmf (cond_spmf p A) x = (if x \<in> A then spmf p x / measure (measure_spmf p) A else 0)"
  2414 by(auto simp add: cond_spmf_def pmf_cond set_pmf_Int_Some[symmetric] measure_measure_spmf_conv_measure_pmf measure_pmf_zero_iff)
  2415 
  2416 lemma bind_eq_return_pmf_None:
       (* bind_spmf yields return_pmf None iff every continuation on the support of p does. *)
  2417   "bind_spmf p f = return_pmf None \<longleftrightarrow> (\<forall>x\<in>set_spmf p. f x = return_pmf None)"
  2418 by(auto simp add: bind_spmf_def bind_eq_return_pmf in_set_spmf split: option.splits)
  2419 
  2420 lemma return_pmf_None_eq_bind:
       (* Symmetric orientation of bind_eq_return_pmf_None. *)
  2421   "return_pmf None = bind_spmf p f \<longleftrightarrow> (\<forall>x\<in>set_spmf p. f x = return_pmf None)"
  2422 using bind_eq_return_pmf_None[of p f] by auto
  2423 
  2424 (* Conditional probabilities do not seem to interact nicely with bind. *)
  2425 
  2426 subsection \<open>Product spmf\<close>
  2427 
  2428 definition pair_spmf :: "'a spmf \<Rightarrow> 'b spmf \<Rightarrow> ('a \<times> 'b) spmf"
       (* Independent product of two spmfs: sample both via pair_pmf; return the pair if both
          components succeed (Some) and fail with return_pmf None otherwise. *)
  2429 where "pair_spmf p q = bind_pmf (pair_pmf p q) (\<lambda>xy. case xy of (Some x, Some y) \<Rightarrow> return_spmf (x, y) | _ \<Rightarrow> return_pmf None)"
  2430 
  2431 lemma map_fst_pair_spmf [simp]: "map_spmf fst (pair_spmf p q) = scale_spmf (weight_spmf q) p"
       (* First marginal of the product: p scaled by the weight of q. The proof rewrites the
          right-hand side backwards with bind_spmf_const. *)
  2432 unfolding bind_spmf_const[symmetric]
  2433 apply(simp add: pair_spmf_def map_bind_pmf pair_pmf_def bind_assoc_pmf option.case_distrib)
  2434 apply(subst bind_commute_pmf)
  2435 apply(auto intro!: bind_pmf_cong[OF refl] simp add: bind_return_pmf bind_spmf_def bind_return_pmf' case_option_collapse option.case_distrib[where h="map_spmf _"] option.case_distrib[symmetric] case_option_id split: option.split cong del: option.case_cong_weak)
  2436 done
  2437 
  2438 lemma map_snd_pair_spmf [simp]: "map_spmf snd (pair_spmf p q) = scale_spmf (weight_spmf p) q"
       (* Second marginal of the product: q scaled by the weight of p. Unlike the first marginal,
          no bind_commute_pmf step is needed here. *)
  2439 unfolding bind_spmf_const[symmetric]
  2440   apply(simp add: pair_spmf_def map_bind_pmf pair_pmf_def bind_assoc_pmf option.case_distrib
  2441     cong del: option.case_cong_weak)
  2442 apply(auto intro!: bind_pmf_cong[OF refl] simp add: bind_return_pmf bind_spmf_def bind_return_pmf' case_option_collapse option.case_distrib[where h="map_spmf _"] option.case_distrib[symmetric] case_option_id split: option.split cong del: option.case_cong_weak)
  2443 done
  2444 
  2445 lemma set_pair_spmf [simp]: "set_spmf (pair_spmf p q) = set_spmf p \<times> set_spmf q"
       (* Support of the product is the Cartesian product of the supports. *)
  2446 by(auto 4 3 simp add: pair_spmf_def set_spmf_bind_pmf bind_UNION in_set_spmf intro: rev_bexI split: option.splits)
  2447 
  2448 lemma spmf_pair [simp]: "spmf (pair_spmf p q) (x, y) = spmf p x * spmf q y" (is "?lhs = ?rhs")
       (* Pointwise probabilities of the product multiply. The proof writes the value as an
          iterated integral of the indicator of (x, y) and factors it. *)
  2449 proof -
  2450   have "ennreal ?lhs = \<integral>\<^sup>+ a. \<integral>\<^sup>+ b. indicator {(x, y)} (a, b) \<partial>measure_spmf q \<partial>measure_spmf p"
  2451     unfolding measure_spmf_def pair_spmf_def ennreal_pmf_bind nn_integral_pair_pmf'
  2452     by(auto simp add: zero_ereal_def[symmetric] nn_integral_distr nn_integral_restrict_space nn_integral_multc[symmetric] intro!: nn_integral_cong split: option.split split_indicator)
  2453   also have "\<dots> = \<integral>\<^sup>+ a. (\<integral>\<^sup>+ b. indicator {y} b \<partial>measure_spmf q) * indicator {x} a \<partial>measure_spmf p"
  2454     by(subst nn_integral_multc[symmetric])(auto intro!: nn_integral_cong split: split_indicator)
  2455   also have "\<dots> = ennreal ?rhs" by(simp add: emeasure_spmf_single max_def ennreal_mult mult.commute)
  2456   finally show ?thesis by simp
  2457 qed
  2458 
  2459 lemma pair_map_spmf2: "pair_spmf p (map_spmf f q) = map_spmf (apsnd f) (pair_spmf p q)"
       (* Mapping the second component commutes with pair_spmf. *)
  2460 by(auto simp add: pair_spmf_def pair_map_pmf2 bind_map_pmf map_bind_pmf intro: bind_pmf_cong split: option.split)
  2461 
  2462 lemma pair_map_spmf1: "pair_spmf (map_spmf f p) q = map_spmf (apfst f) (pair_spmf p q)"
       (* Mapping the first component commutes with pair_spmf. *)
  2463 by(auto simp add: pair_spmf_def pair_map_pmf1 bind_map_pmf map_bind_pmf intro: bind_pmf_cong split: option.split)
  2464 
  2465 lemma pair_map_spmf: "pair_spmf (map_spmf f p) (map_spmf g q) = map_spmf (map_prod f g) (pair_spmf p q)"
       (* Combined version for mapping both components at once. *)
  2466 unfolding pair_map_spmf2 pair_map_spmf1 spmf.map_comp by(simp add: apfst_def apsnd_def o_def prod.map_comp)
  2467 
  2468 lemma pair_spmf_alt_def: "pair_spmf p q = bind_spmf p (\<lambda>x. bind_spmf q (\<lambda>y. return_spmf (x, y)))"
       (* Monadic characterisation: sample x from p, then y from q, and return (x, y). *)
  2469 by(auto simp add: pair_spmf_def pair_pmf_def bind_spmf_def bind_assoc_pmf bind_return_pmf split: option.split intro: bind_pmf_cong)
  2470 
  2471 lemma weight_pair_spmf [simp]: "weight_spmf (pair_spmf p q) = weight_spmf p * weight_spmf q"
       (* The weight of the product is the product of the weights. *)
  2472 unfolding pair_spmf_alt_def by(simp add: weight_bind_spmf o_def)
  2473 
  2474 lemma pair_scale_spmf1: (* FIXME: generalise to arbitrary r *)
       (* Scaling the first component scales the product. The r <= 1 restriction is inherited
          from scale_bind_spmf and bind_scale_spmf, which the proof uses. *)
  2475   "r \<le> 1 \<Longrightarrow> pair_spmf (scale_spmf r p) q = scale_spmf r (pair_spmf p q)"
  2476 by(simp add: pair_spmf_alt_def scale_bind_spmf bind_scale_spmf)
  2477 
  2478 lemma pair_scale_spmf2: (* FIXME: generalise to arbitrary r *)
       (* Scaling the second component scales the product; same r <= 1 restriction as above. *)
  2479   "r \<le> 1 \<Longrightarrow> pair_spmf p (scale_spmf r q) = scale_spmf r (pair_spmf p q)"
  2480 by(simp add: pair_spmf_alt_def scale_bind_spmf bind_scale_spmf)
  2481 
  2482 lemma pair_spmf_return_None1 [simp]: "pair_spmf (return_pmf None) p = return_pmf None"
       (* A surely failing first component makes the product fail. *)
  2483 by(rule spmf_eqI)(clarsimp)
  2484 
  2485 lemma pair_spmf_return_None2 [simp]: "pair_spmf p (return_pmf None) = return_pmf None"
       (* A surely failing second component makes the product fail. *)
  2486 by(rule spmf_eqI)(clarsimp)
  2487 
  2488 lemma pair_spmf_return_spmf1: "pair_spmf (return_spmf x) q = map_spmf (Pair x) q"
       (* Pairing with a point mass on the left is a map_spmf. *)
  2489 by(rule spmf_eqI)(auto split: split_indicator simp add: spmf_map_inj' inj_on_def intro: spmf_map_outside)
  2490 
  2491 lemma pair_spmf_return_spmf2: "pair_spmf p (return_spmf y) = map_spmf (\<lambda>x. (x, y)) p"
       (* Pairing with a point mass on the right is a map_spmf. *)
  2492 by(rule spmf_eqI)(auto split: split_indicator simp add: inj_on_def intro!: spmf_map_outside spmf_map_inj'[symmetric])
  2493 
  2494 lemma pair_spmf_return_spmf [simp]: "pair_spmf (return_spmf x) (return_spmf y) = return_spmf (x, y)"
       (* The product of two point masses is the point mass of the pair. *)
  2495 by(simp add: pair_spmf_return_spmf1)
  2496 
  2497 lemma rel_pair_spmf_prod:
       (* Relating two products by rel_prod A B is equivalent to relating the components, each
          scaled by the weight of the other component. These scaled components are exactly the
          marginals given by map_fst_pair_spmf and map_snd_pair_spmf. *)
  2498   "rel_spmf (rel_prod A B) (pair_spmf p q) (pair_spmf p' q') \<longleftrightarrow>
  2499    rel_spmf A (scale_spmf (weight_spmf q) p) (scale_spmf (weight_spmf q') p') \<and>
  2500    rel_spmf B (scale_spmf (weight_spmf p) q) (scale_spmf (weight_spmf p') q')"
  2501   (is "?lhs \<longleftrightarrow> ?rhs" is "_ \<longleftrightarrow> ?A \<and> ?B" is "_ \<longleftrightarrow> rel_spmf _ ?p ?p' \<and> rel_spmf _ ?q ?q'")
  2502 proof(intro iffI conjI)
  2503   assume ?rhs
         (* <- : pair the couplings pq and pq' of the scaled components, then rescale by
            ?r = 1 / (weight_spmf p * weight_spmf q) to recover a coupling of the products. *)
  2504   then obtain pq pq' where p: "map_spmf fst pq = ?p" and p': "map_spmf snd pq = ?p'"
  2505     and q: "map_spmf fst pq' = ?q" and q': "map_spmf snd pq' = ?q'"
  2506     and *: "\<And>x x'. (x, x') \<in> set_spmf pq \<Longrightarrow> A x x'"
  2507     and **: "\<And>y y'. (y, y') \<in> set_spmf pq' \<Longrightarrow> B y y'" by(auto elim!: rel_spmfE)
  2508   let ?f = "\<lambda>((x, x'), (y, y')). ((x, y), (x', y'))"
  2509   let ?r = "1 / (weight_spmf p * weight_spmf q)"
  2510   let ?pq = "scale_spmf ?r (map_spmf ?f (pair_spmf pq pq'))"
  2511 
         (* Auxiliary fact full: for any p, q with non-zero weights wp, wq,
            if 1 / (wp * wq) <= wp * wq then wp = 1 and wq = 1. *)
  2512   { fix p :: "'x spmf" and q :: "'y spmf"
  2513     assume "weight_spmf q \<noteq> 0"
  2514       and "weight_spmf p \<noteq> 0"
  2515       and "1 / (weight_spmf p * weight_spmf q) \<le> weight_spmf p * weight_spmf q"
  2516     hence "1 \<le> (weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q)"
  2517       by(simp add: pos_divide_le_eq order.strict_iff_order weight_spmf_nonneg)
  2518     moreover have "(weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q) \<le> (1 * 1) * (1 * 1)"
  2519       by(intro mult_mono)(simp_all add: weight_spmf_nonneg weight_spmf_le_1)
  2520     ultimately have "(weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q) = 1" by simp
  2521     hence *: "weight_spmf p * weight_spmf q = 1"
  2522       by(metis antisym_conv less_le mult_less_cancel_left1 weight_pair_spmf weight_spmf_le_1 weight_spmf_nonneg)
  2523     hence **: "weight_spmf p = 1" by(metis antisym_conv mult_left_le weight_spmf_le_1 weight_spmf_nonneg)
  2524     moreover from * ** have "weight_spmf q = 1" by simp
  2525     moreover note calculation }
  2526   note full = this
  2527 
  2528   show ?lhs
  2529   proof
  2530     have [simp]: "fst \<circ> ?f = map_prod fst fst" by(simp add: fun_eq_iff)
  2531     have "map_spmf fst ?pq = scale_spmf ?r (pair_spmf ?p ?q)"
  2532       by(simp add: pair_map_spmf[symmetric] p q map_scale_spmf spmf.map_comp)
  2533     also have "\<dots> = pair_spmf p q" using full[of p q]
  2534       by(simp add: pair_scale_spmf1 pair_scale_spmf2 weight_spmf_le_1 weight_spmf_nonneg)
  2535         (auto simp add: scale_scale_spmf max_def min_def field_simps weight_spmf_nonneg weight_spmf_eq_0)
  2536     finally show "map_spmf fst ?pq = \<dots>" .
  2537 
  2538     have [simp]: "snd \<circ> ?f = map_prod snd snd" by(simp add: fun_eq_iff)
  2539     from \<open>?rhs\<close> have eq: "weight_spmf p * weight_spmf q = weight_spmf p' * weight_spmf q'"
  2540       by(auto dest!: rel_spmf_weightD simp add: weight_spmf_le_1 weight_spmf_nonneg)
  2541 
  2542     have "map_spmf snd ?pq = scale_spmf ?r (pair_spmf ?p' ?q')"
  2543       by(simp add: pair_map_spmf[symmetric] p' q' map_scale_spmf spmf.map_comp)
  2544     also have "\<dots> = pair_spmf p' q'" using full[of p' q'] eq
  2545       by(simp add: pair_scale_spmf1 pair_scale_spmf2 weight_spmf_le_1 weight_spmf_nonneg)
  2546         (auto simp add: scale_scale_spmf max_def min_def field_simps weight_spmf_nonneg weight_spmf_eq_0)
  2547     finally show "map_spmf snd ?pq = \<dots>" .
  2548   qed(auto simp add: set_scale_spmf split: if_split_asm dest: * ** )
  2549 next
  2550   assume ?lhs
         (* -> : project the coupling pq of the products onto the first (for ?A) and the second
            (for ?B) components respectively. *)
  2551   then obtain pq where pq: "map_spmf fst pq = pair_spmf p q"
  2552     and pq': "map_spmf snd pq = pair_spmf p' q'"
  2553     and *: "\<And>x y x' y'. ((x, y), (x', y')) \<in> set_spmf pq \<Longrightarrow> A x x' \<and> B y y'"
  2554     by(auto elim: rel_spmfE)
  2555 
  2556   show ?A
  2557   proof
  2558     let ?f = "(\<lambda>((x, y), (x', y')). (x, x'))"
  2559     let ?pq = "map_spmf ?f pq"
  2560     have [simp]: "fst \<circ> ?f = fst \<circ> fst" by(simp add: split_def o_def)
  2561     show "map_spmf fst ?pq = scale_spmf (weight_spmf q) p" using pq
  2562       by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])
  2563 
  2564     have [simp]: "snd \<circ> ?f = fst \<circ> snd" by(simp add: split_def o_def)
  2565     show "map_spmf snd ?pq = scale_spmf (weight_spmf q') p'" using pq'
  2566       by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])
  2567   qed(auto dest: * )
  2568 
  2569   show ?B
  2570   proof
  2571     let ?f = "(\<lambda>((x, y), (x', y')). (y, y'))"
  2572     let ?pq = "map_spmf ?f pq"
  2573     have [simp]: "fst \<circ> ?f = snd \<circ> fst" by(simp add: split_def o_def)
  2574     show "map_spmf fst ?pq = scale_spmf (weight_spmf p) q" using pq
  2575       by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])
  2576 
  2577     have [simp]: "snd \<circ> ?f = snd \<circ> snd" by(simp add: split_def o_def)
  2578     show "map_spmf snd ?pq = scale_spmf (weight_spmf p') q'" using pq'
  2579       by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])
  2580   qed(auto dest: * )
  2581 qed
  2582 
  2583 lemma pair_pair_spmf:
       (* pair_spmf is associative up to re-nesting the tuples. *)
  2584   "pair_spmf (pair_spmf p q) r = map_spmf (\<lambda>(x, (y, z)). ((x, y), z)) (pair_spmf p (pair_spmf q r))"
  2585 by(simp add: pair_spmf_alt_def map_spmf_conv_bind_spmf)
  2586 
  2587 lemma pair_commute_spmf:
       (* pair_spmf is commutative up to swapping the components. *)
  2588   "pair_spmf p q = map_spmf (\<lambda>(y, x). (x, y)) (pair_spmf q p)"
  2589 unfolding pair_spmf_alt_def by(subst bind_commute_spmf)(simp add: map_spmf_conv_bind_spmf)
  2590 
  2591 subsection \<open>Assertions\<close>
  2592 
  2593 definition assert_spmf :: "bool \<Rightarrow> unit spmf"
       (* assert_spmf b returns () if b holds and fails (return_pmf None) otherwise. *)
  2594 where "assert_spmf b = (if b then return_spmf () else return_pmf None)"
  2595 
  2596 lemma assert_spmf_simps [simp]:
       (* Evaluation rules for constant conditions. *)
  2597   "assert_spmf True = return_spmf ()"
  2598   "assert_spmf False = return_pmf None"
  2599 by(simp_all add: assert_spmf_def)
  2600 
  2601 lemma in_set_assert_spmf [simp]: "x \<in> set_spmf (assert_spmf p) \<longleftrightarrow> p"
       (* The result type is unit, so membership depends only on the condition. *)
  2602 by(cases p) simp_all
  2603 
  2604 lemma set_spmf_assert_spmf_eq_empty [simp]: "set_spmf (assert_spmf b) = {} \<longleftrightarrow> \<not> b"
       (* The support is empty exactly when the assertion fails. *)
  2605 by(cases b) simp_all
  2606 
  2607 lemma lossless_assert_spmf [iff]: "lossless_spmf (assert_spmf b) \<longleftrightarrow> b"
       (* An assertion is lossless exactly when its condition holds. *)
  2608 by(cases b) simp_all
  2609 
  2610 subsection \<open>Try\<close>
  2611 
  2612 definition try_spmf :: "'a spmf \<Rightarrow> 'a spmf \<Rightarrow> 'a spmf" ("TRY _ ELSE _" [0,60] 59)
       (* TRY p ELSE q runs p; where p fails (yields None) it continues with q instead.
          Mixfix priorities: p is parsed at 0, q at 60, and the whole term has priority 59. *)
  2613 where "try_spmf p q = bind_pmf p (\<lambda>x. case x of None \<Rightarrow> q | Some y \<Rightarrow> return_spmf y)"
  2614 
  2615 lemma try_spmf_lossless [simp]:
       (* A lossless p never fails, so the fallback q is never used. *)
  2616   assumes "lossless_spmf p"
  2617   shows "TRY p ELSE q = p"
  2618 proof -
  2619   have "TRY p ELSE q = bind_pmf p return_pmf" unfolding try_spmf_def using assms
  2620     by(auto simp add: lossless_iff_set_pmf_None split: option.split intro: bind_pmf_cong)
  2621   thus ?thesis by(simp add: bind_return_pmf')
  2622 qed
  2623 
  2624 lemma try_spmf_return_spmf1: "TRY return_spmf x ELSE q = return_spmf x"
       (* A point mass never fails, so the fallback is discarded. *)
  2625 by(simp add: try_spmf_def bind_return_pmf)
  2626 
  2627 lemma try_spmf_return_None [simp]: "TRY return_pmf None ELSE q = q"
       (* A surely failing p hands over to q. *)
  2628 by(simp add: try_spmf_def bind_return_pmf)
  2629 
  2630 lemma try_spmf_return_pmf_None2 [simp]: "TRY p ELSE return_pmf None = p"
       (* Falling back to return_pmf None is the same as having no fallback. *)
  2631 by(simp add: try_spmf_def option.case_distrib[symmetric] bind_return_pmf' case_option_id)
  2632 
  2633 lemma map_try_spmf: "map_spmf f (try_spmf p q) = try_spmf (map_spmf f p) (map_spmf f q)"
       (* map_spmf distributes over both arguments of TRY. *)
  2634 by(simp add: try_spmf_def map_bind_pmf bind_map_pmf option.case_distrib[where h="map_spmf f"] o_def cong del: option.case_cong_weak)
  2635 
  2636 lemma try_spmf_bind_pmf: "TRY (bind_pmf p f) ELSE q = bind_pmf p (\<lambda>x. TRY (f x) ELSE q)"
       (* TRY can be pushed into the continuation of a bind_pmf. *)
  2637 by(simp add: try_spmf_def bind_assoc_pmf)
  2638 
  2639 lemma try_spmf_bind_spmf_lossless:
       (* For lossless p, TRY can be pushed into the continuation of a bind_spmf. *)
  2640   "lossless_spmf p \<Longrightarrow> TRY (bind_spmf p f) ELSE q = bind_spmf p (\<lambda>x. TRY (f x) ELSE q)"
  2641 by(auto simp add: try_spmf_def bind_spmf_def bind_assoc_pmf bind_return_pmf lossless_iff_set_pmf_None intro!: bind_pmf_cong split: option.split)
  2642 
  2643 lemma try_spmf_bind_out:
       (* try_spmf_bind_spmf_lossless read right-to-left, for pulling TRY out of the continuation. *)
  2644   "lossless_spmf p \<Longrightarrow> bind_spmf p (\<lambda>x. TRY (f x) ELSE q) = TRY (bind_spmf p f) ELSE q"
  2645 by(simp add: try_spmf_bind_spmf_lossless)
  2646 
  2647 lemma lossless_try_spmf [simp]:
       (* TRY p ELSE q is lossless iff p or q is lossless. *)
  2648   "lossless_spmf (TRY p ELSE q) \<longleftrightarrow> lossless_spmf p \<or> lossless_spmf q"
  2649 by(auto simp add: try_spmf_def in_set_spmf lossless_iff_set_pmf_None split: option.splits)
  2650 
  2651 context includes lifting_syntax
  2652 begin
  2653 
  2654 lemma try_spmf_parametric [transfer_rule]:
       (* try_spmf is parametric in the relation on results; the ===> notation comes from lifting_syntax. *)
  2655   "(rel_spmf A ===> rel_spmf A ===> rel_spmf A) try_spmf try_spmf"
  2656 unfolding try_spmf_def[abs_def] by transfer_prover
  2657 
  2658 end
  2659 
  2660 lemma try_spmf_cong:
       (* Congruence rule: the fallbacks need to agree only when p' is not lossless. *)
  2661   "\<lbrakk> p = p'; \<not> lossless_spmf p' \<Longrightarrow> q = q' \<rbrakk> \<Longrightarrow> TRY p ELSE q = TRY p' ELSE q'"
  2662 unfolding try_spmf_def
  2663 by(rule bind_pmf_cong)(auto split: option.split simp add: lossless_iff_set_pmf_None)
  2664 
  2665 lemma rel_spmf_try_spmf:
       (* rel_spmf lifts through TRY; the fallbacks need to be related only when p' is not lossless. *)
  2666   "\<lbrakk> rel_spmf R p p'; \<not> lossless_spmf p' \<Longrightarrow> rel_spmf R q q' \<rbrakk>
  2667   \<Longrightarrow> rel_spmf R (TRY p ELSE q) (TRY p' ELSE q')"
  2668 unfolding try_spmf_def
  2669 apply(rule rel_pmf_bindI[where R="\<lambda>x y. rel_option R x y \<and> x \<in> set_pmf p \<and> y \<in> set_pmf p'"])
  2670  apply(erule pmf.rel_mono_strong; simp)
  2671 apply(auto split: option.split simp add: lossless_iff_set_pmf_None)
  2672 done
  2673 
  2674 lemma spmf_try_spmf:
       (* Pointwise value of TRY: the mass p puts on x, plus the failure probability of p times
          the mass q puts on x. *)
  2675   "spmf (TRY p ELSE q) x = spmf p x + pmf p None * spmf q x"
  2676 proof -
  2677   have "ennreal (spmf (TRY p ELSE q) x) = \<integral>\<^sup>+ y. ennreal (spmf q x) * indicator {None} y + indicator {Some x} y \<partial>measure_pmf p"
  2678     unfolding try_spmf_def ennreal_pmf_bind by(rule nn_integral_cong)(simp split: option.split split_indicator)
  2679   also have "\<dots> = (\<integral>\<^sup>+ y. ennreal (spmf q x) * indicator {None} y \<partial>measure_pmf p) + \<integral>\<^sup>+ y. indicator {Some x} y \<partial>measure_pmf p"
  2680     by(simp add: nn_integral_add)
  2681   also have "\<dots> = ennreal (spmf q x) * pmf p None + spmf p x" by(simp add: emeasure_pmf_single)
  2682   finally show ?thesis by(simp add: ennreal_mult[symmetric] ennreal_plus[symmetric] del: ennreal_plus)
  2683 qed
  2684 
  2685 lemma try_scale_spmf_same [simp]: "lossless_spmf p \<Longrightarrow> TRY scale_spmf k p ELSE p = p"
       (* For lossless p, rescaling p and then falling back to p itself gives back p. *)
  2686 by(rule spmf_eqI)(auto simp add: spmf_try_spmf spmf_scale_spmf pmf_scale_spmf_None lossless_iff_pmf_None weight_spmf_conv_pmf_None min_def max_def field_simps)
  2687 
  2688 lemma pmf_try_spmf_None [simp]: "pmf (TRY p ELSE q) None = pmf p None * pmf q None" (is "?lhs = ?rhs")
       (* TRY p ELSE q fails only if both p and q fail; the failure probabilities multiply. *)
  2689 proof -
  2690   have "?lhs = \<integral> x. pmf q None * indicator {None} x \<partial>measure_pmf p"
  2691     unfolding try_spmf_def pmf_bind by(rule Bochner_Integration.integral_cong)(simp_all split: option.split)
  2692   also have "\<dots> = ?rhs" by(simp add: measure_pmf_single)
  2693   finally show ?thesis .
  2694 qed
  2695 
  2696 lemma try_bind_spmf_lossless2:
       (* With a lossless fallback q, the same fallback can additionally be inserted into every
          continuation of the bind. *)
  2697   "lossless_spmf q \<Longrightarrow> TRY (bind_spmf p f) ELSE q = TRY (p \<bind> (\<lambda>x. TRY (f x) ELSE q)) ELSE q"
  2698 by(rule spmf_eqI)(simp add: spmf_try_spmf pmf_bind_spmf_None spmf_bind field_simps measure_spmf.integrable_const_bound[where B=1] pmf_le_1 lossless_iff_pmf_None)
  2699 
  2700 lemma try_bind_spmf_lossless2':
       (* Variant of try_bind_spmf_lossless2 guarded by NO_MATCH: it does not apply when f is
          already of the form x => try_spmf (g x) (h x). Presumably this keeps repeated
          rewriting from re-wrapping the continuation; NOTE(review): confirm intended use. *)
  2701   fixes f :: "'a \<Rightarrow> 'b spmf" shows
  2702   "\<lbrakk> NO_MATCH (\<lambda>x :: 'a. try_spmf (g x :: 'b spmf) (h x)) f; lossless_spmf q \<rbrakk>
  2703   \<Longrightarrow> TRY (bind_spmf p f) ELSE q = TRY (p \<bind> (\<lambda>x :: 'a. TRY (f x) ELSE q)) ELSE q"
  2704 by(rule try_bind_spmf_lossless2)
  2705 
  2706 lemma try_bind_assert_spmf:
       (* Resolves an assertion at the start of a TRY body by case distinction on b. *)
  2707   "TRY (assert_spmf b \<bind> f) ELSE q = (if b then TRY (f ()) ELSE q else q)"
  2708 by simp
  2709 
  2710 subsection \<open>Miscellaneous\<close>
  2711 
  2712 lemma assumes "rel_spmf (\<lambda>x y. bad1 x = bad2 y \<and> (\<not> bad2 y \<longrightarrow> A x \<longleftrightarrow> B y)) p q" (is "rel_spmf ?A _ _")
       (* Fundamental lemma for spmfs (as used in game-playing proofs). Suppose p and q are
          related so that the bad flags agree, and A and B agree whenever the bad event does not
          occur. Then the bad events have equal probability, and the probabilities of A and B
          differ by at most the probability of the bad event. *)
  2713   shows fundamental_lemma_bad: "measure (measure_spmf p) {x. bad1 x} = measure (measure_spmf q) {y. bad2 y}" (is "?bad")
  2714   and fundamental_lemma: "\<bar>measure (measure_spmf p) {x. A x} - measure (measure_spmf q) {y. B y}\<bar> \<le>
  2715     measure (measure_spmf p) {x. bad1 x}" (is ?fundamental)
  2716 proof -
         (* Fact 1: the good events (A and not bad1, B and not bad2) have equal measure, by
            parametricity of measure_spmf. Fact 2: likewise for the bad events. *)
  2717   have good: "rel_fun ?A op = (\<lambda>x. A x \<and> \<not> bad1 x) (\<lambda>y. B y \<and> \<not> bad2 y)" by(auto simp add: rel_fun_def)
  2718   from assms have 1: "measure (measure_spmf p) {x. A x \<and> \<not> bad1 x} = measure (measure_spmf q) {y. B y \<and> \<not> bad2 y}"
  2719     by(rule measure_spmf_parametric[THEN rel_funD, THEN rel_funD])(rule Collect_parametric[THEN rel_funD, OF good])
  2720 
  2721   have bad: "rel_fun ?A op = bad1 bad2" by(simp add: rel_fun_def)
  2722   show 2: ?bad using assms
  2723     by(rule measure_spmf_parametric[THEN rel_funD, THEN rel_funD])(rule Collect_parametric[THEN rel_funD, OF bad])
  2724 
         (* Split the events A and B into bad and good parts. The good parts cancel by fact 1, and
            the remaining bad parts are bounded by the bad events, which are equal by fact 2. *)
  2725   let ?\<mu>p = "measure (measure_spmf p)" and ?\<mu>q = "measure (measure_spmf q)"
  2726   have "{x. A x \<and> bad1 x} \<union> {x. A x \<and> \<not> bad1 x} = {x. A x}"
  2727     and "{y. B y \<and> bad2 y} \<union> {y. B y \<and> \<not> bad2 y} = {y. B y}" by auto
  2728   then have "\<bar>?\<mu>p {x. A x} - ?\<mu>q {x. B x}\<bar> = \<bar>?\<mu>p ({x. A x \<and> bad1 x} \<union> {x. A x \<and> \<not> bad1 x}) - ?\<mu>q ({y. B y \<and> bad2 y} \<union> {y. B y \<and> \<not> bad2 y})\<bar>"
  2729     by simp
  2730   also have "\<dots> = \<bar>?\<mu>p {x. A x \<and> bad1 x} + ?\<mu>p {x. A x \<and> \<not> bad1 x} - ?\<mu>q {y. B y \<and> bad2 y} - ?\<mu>q {y. B y \<and> \<not> bad2 y}\<bar>"
  2731     by(subst (1 2) measure_Union)(auto)
  2732   also have "\<dots> = \<bar>?\<mu>p {x. A x \<and> bad1 x} - ?\<mu>q {y. B y \<and> bad2 y}\<bar>" using 1 by simp
  2733   also have "\<dots> \<le> max (?\<mu>p {x. A x \<and> bad1 x}) (?\<mu>q {y. B y \<and> bad2 y})"
  2734     by(rule abs_leI)(auto simp add: max_def not_le, simp_all only: add_increasing measure_nonneg mult_2)
  2735   also have "\<dots> \<le> max (?\<mu>p {x. bad1 x}) (?\<mu>q {y. bad2 y})"
  2736     by(rule max.mono; rule measure_spmf.finite_measure_mono; auto)
  2737   also note 2[symmetric]
  2738   finally show ?fundamental by simp
  2739 qed
  2740 
  2741 end