src/HOL/Probability/SPMF.thy
author wenzelm
Wed Jun 22 10:09:20 2016 +0200 (2016-06-22)
changeset 63343 fb5d8a50c641
parent 63333 158ab2239496
child 63540 f8652d0534fa
permissions -rw-r--r--
bundle lifting_syntax;
     1 (* Author: Andreas Lochbihler, ETH Zurich *)
     2 
     3 section \<open>Discrete subprobability distribution\<close>
     4 
     5 theory SPMF imports
     6   Probability_Mass_Function
     7   Embed_Measure
     8   "~~/src/HOL/Library/Complete_Partial_Order2"
     9   "~~/src/HOL/Library/Rewrite"
    10 begin
    11 
    12 subsection \<open>Auxiliary material\<close>
    13 
    (* The supremum of f over the singleton {x} is just f x (any conditionally complete lattice). *)
    14 lemma cSUP_singleton [simp]: "(SUP x:{x}. f x :: _ :: conditionally_complete_lattice) = f x"
    15 by (metis cSup_singleton image_empty image_insert)
    16 
    17 subsubsection \<open>More about extended reals\<close>
    18 
    (* ennreal sends negative reals to 0 (cf. ennreal_lt_0), so clamping with max 0
       beforehand does not change the result. *)
    19 lemma [simp]:
    20   shows ennreal_max_0: "ennreal (max 0 x) = ennreal x"
    21   and ennreal_max_0': "ennreal (max x 0) = ennreal x"
    22 by(simp_all add: max_def ennreal_eq_0_iff)
    23 
    (* Round trip through real: identity on finite values; the top element is sent to 0. *)
    24 lemma ennreal_enn2real_if: "ennreal (enn2real r) = (if r = \<top> then 0 else r)"
    25 by(auto intro!: ennreal_enn2real simp add: less_top)
    26 
    (* Simp rules for the conversions at 0 and at the bottom element. *)
    27 lemma e2ennreal_0 [simp]: "e2ennreal 0 = 0"
    28 by(simp add: zero_ennreal_def)
    29 
    30 lemma enn2real_bot [simp]: "enn2real \<bottom> = 0"
    31 by(simp add: bot_ennreal_def)
    32 
    (* Composing a continuous real-valued function with ennreal keeps it continuous. *)
    33 lemma continuous_at_ennreal[continuous_intros]: "continuous F f \<Longrightarrow> continuous F (\<lambda>x. ennreal (f x))"
    34   unfolding continuous_def by auto
    35 
    (* ennreal commutes with Sup on a nonempty set A, provided the supremum of the
       ennreal images is finite. The proof uses continuous_at_Sup_mono; the side
       conditions (monotonicity, continuity, nonemptiness) are closed by the auto
       after qed, so only boundedness of A is argued explicitly. *)
    36 lemma ennreal_Sup:
    37   assumes *: "(SUP a:A. ennreal a) \<noteq> \<top>"
    38   and "A \<noteq> {}"
    39   shows "ennreal (Sup A) = (SUP a:A. ennreal a)"
    40 proof (rule continuous_at_Sup_mono)
    (* The finite supremum is ennreal r for some r >= 0, and r bounds A from above. *)
    41   obtain r where r: "ennreal r = (SUP a:A. ennreal a)" "r \<ge> 0"
    42     using * by(cases "(SUP a:A. ennreal a)") simp_all
    43   then show "bdd_above A"
    44     by(auto intro!: SUP_upper bdd_aboveI[of _ r] simp add: ennreal_le_iff[symmetric])
    45 qed (auto simp: mono_def continuous_at_imp_continuous_at_within continuous_at_ennreal ennreal_leI assms)
    46 
    (* Indexed form of ennreal_Sup, obtained by instantiating it with the image f ` A. *)
    47 lemma ennreal_SUP:
    48   "\<lbrakk> (SUP a:A. ennreal (f a)) \<noteq> \<top>; A \<noteq> {} \<rbrakk> \<Longrightarrow> ennreal (SUP a:A. f a) = (SUP a:A. ennreal (f a))"
    49 using ennreal_Sup[of "f ` A"] by auto
    50 
    (* Negative reals are mapped to 0 by ennreal. *)
    51 lemma ennreal_lt_0: "x < 0 \<Longrightarrow> ennreal x = 0"
    52 by(simp add: ennreal_eq_0_iff)
    53 
    54 subsubsection \<open>More about @{typ "'a option"}\<close>
    55 
    (* Simp rules for membership of None and of Some x in an image under map_option. *)
    56 lemma None_in_map_option_image [simp]: "None \<in> map_option f ` A \<longleftrightarrow> None \<in> A"
    57 by auto
    58 
    59 lemma Some_in_map_option_image [simp]: "Some x \<in> map_option f ` A \<longleftrightarrow> (\<exists>y. x = f y \<and> Some y \<in> A)"
    60 by(auto intro: rev_image_eqI dest: sym)
    61 
    (* A case distinction whose branches all return x is the constant function x. *)
    62 lemma case_option_collapse: "case_option x (\<lambda>_. x) = (\<lambda>_. x)"
    63 by(simp add: fun_eq_iff split: option.split)
    64 
    (* Rebuilding an option from its own cases is the identity. *)
    65 lemma case_option_id: "case_option None Some = id"
    66 by(rule ext)(simp split: option.split)
    67 
    (* Lifts a relation ord to options: None is below everything, and Some x is below
       Some y iff ord x y. These are the only rules, so nothing except None is below None. *)
    68 inductive ord_option :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a option \<Rightarrow> 'b option \<Rightarrow> bool"
    69   for ord :: "'a \<Rightarrow> 'b \<Rightarrow> bool"
    70 where
    71   None: "ord_option ord None x"
    72 | Some: "ord x y \<Longrightarrow> ord_option ord (Some x) (Some y)"
    73 
    (* Case-wise simp rules generated from the inductive definition. *)
    74 inductive_simps ord_option_simps [simp]:
    75   "ord_option ord None x"
    76   "ord_option ord x None"
    77   "ord_option ord (Some x) (Some y)"
    78   "ord_option ord (Some x) None"
    79 
    (* The same, specialised to equality as the underlying relation. *)
    80 inductive_simps ord_option_eq_simps [simp]:
    81   "ord_option op = None y"
    82   "ord_option op = (Some x) y"
    83 
    (* Reflexivity and transitivity of ord_option. The local forms only require the
       property of ord for the elements actually contained in the options involved. *)
    84 lemma ord_option_reflI: "(\<And>y. y \<in> set_option x \<Longrightarrow> ord y y) \<Longrightarrow> ord_option ord x x"
    85 by(cases x) simp_all
    86 
    87 lemma reflp_ord_option: "reflp ord \<Longrightarrow> reflp (ord_option ord)"
    88 by(simp add: reflp_def ord_option_reflI)
    89 
    90 lemma ord_option_trans:
    91   "\<lbrakk> ord_option ord x y; ord_option ord y z;
    92     \<And>a b c. \<lbrakk> a \<in> set_option x; b \<in> set_option y; c \<in> set_option z; ord a b; ord b c \<rbrakk> \<Longrightarrow> ord a c \<rbrakk>
    93   \<Longrightarrow> ord_option ord x z"
    94 by(auto elim!: ord_option.cases)
    95 
    96 lemma transp_ord_option: "transp ord \<Longrightarrow> transp (ord_option ord)"
    97 unfolding transp_def by(blast intro: ord_option_trans)
    98 
    (* Antisymmetry lifts from ord to ord_option. *)
    99 lemma antisymP_ord_option: "antisymP ord \<Longrightarrow> antisymP (ord_option ord)"
   100 by(auto intro!: antisymI elim!: ord_option.cases dest: antisymD)
   101 
    (* The contents of the Some-elements of an ord_option chain form an ord chain. *)
   102 lemma ord_option_chainD:
   103   "Complete_Partial_Order.chain (ord_option ord) Y
   104   \<Longrightarrow> Complete_Partial_Order.chain ord {x. Some x \<in> Y}"
   105 by(rule chainI)(auto dest: chainD)
   106 
   (* Least upper bound on options: None if Y contains no Some-element, otherwise
      Some applied to lub of the set of contents of the Some-elements of Y. *)
   107 definition lub_option :: "('a set \<Rightarrow> 'b) \<Rightarrow> 'a option set \<Rightarrow> 'b option"
   108 where "lub_option lub Y = (if Y \<subseteq> {None} then None else Some (lub {x. Some x \<in> Y}))"
   109 
   (* A map_option applied after lub_option can be pushed into the lub function. *)
   110 lemma map_lub_option: "map_option f (lub_option lub Y) = lub_option (f \<circ> lub) Y"
   111 by(simp add: lub_option_def)
   112 
   (* For ord_option chains, lub_option inherits the upper-bound and
      least-upper-bound properties from lub on ord chains. *)
   113 lemma lub_option_upper:
   114   assumes "Complete_Partial_Order.chain (ord_option ord) Y" "x \<in> Y"
   115   and lub_upper: "\<And>Y x. \<lbrakk> Complete_Partial_Order.chain ord Y; x \<in> Y \<rbrakk> \<Longrightarrow> ord x (lub Y)"
   116   shows "ord_option ord x (lub_option lub Y)"
   117 using assms(1-2)
   118 by(cases x)(auto simp add: lub_option_def intro: lub_upper[OF ord_option_chainD])
   119 
   120 lemma lub_option_least:
   121   assumes Y: "Complete_Partial_Order.chain (ord_option ord) Y"
   122   and upper: "\<And>x. x \<in> Y \<Longrightarrow> ord_option ord x y"
   123   assumes lub_least: "\<And>Y y. \<lbrakk> Complete_Partial_Order.chain ord Y; \<And>x. x \<in> Y \<Longrightarrow> ord x y \<rbrakk> \<Longrightarrow> ord (lub Y) y"
   124   shows "ord_option ord (lub_option lub Y) y"
   125 using Y
   126 by(cases y)(auto 4 3 simp add: lub_option_def intro: lub_least[OF ord_option_chainD] dest: upper)
   127 
   (* Mapping the elements of Y with map_option f becomes an image under f inside lub. *)
   128 lemma lub_map_option: "lub_option lub (map_option f ` Y) = lub_option (lub \<circ> op ` f) Y"
   129 apply(auto simp add: lub_option_def)
   130 apply(erule notE)
   131 apply(rule arg_cong[where f=lub])
   132 apply(auto intro: rev_image_eqI dest: sym)
   133 done
   134 
   (* ord_option is monotone in the underlying relation. The primed variant states
      this with object-level implications and is registered with the [mono] attribute. *)
   135 lemma ord_option_mono: "\<lbrakk> ord_option A x y; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> ord_option B x y"
   136 by(auto elim: ord_option.cases)
   137 
   138 lemma ord_option_mono' [mono]:
   139   "(\<And>x y. A x y \<longrightarrow> B x y) \<Longrightarrow> ord_option A x y \<longrightarrow> ord_option B x y"
   140 by(blast intro: ord_option_mono)
   141 
   (* ord_option distributes over relation composition (OO) and over inf. *)
   142 lemma ord_option_compp: "ord_option (A OO B) = ord_option A OO ord_option B"
   143 by(auto simp add: fun_eq_iff elim!: ord_option.cases intro: ord_option.intros)
   144 
   145 lemma ord_option_inf: "inf (ord_option A) (ord_option B) = ord_option (inf A B)" (is "?lhs = ?rhs")
   146 proof(rule antisym)
   147   show "?lhs \<le> ?rhs" by(auto elim!: ord_option.cases)
   148 qed(auto elim: ord_option_mono)
   149 
   (* A map_option on either side can be absorbed into the relation. *)
   150 lemma ord_option_map2: "ord_option ord x (map_option f y) = ord_option (\<lambda>x y. ord x (f y)) x y"
   151 by(auto elim: ord_option.cases)
   152 
   153 lemma ord_option_map1: "ord_option ord (map_option f x) y = ord_option (\<lambda>x y. ord (f x) y) x y"
   154 by(auto elim: ord_option.cases)
   155 
   (* In the flat order option_ord, Some x is only below itself. *)
   156 lemma option_ord_Some1_iff: "option_ord (Some x) y \<longleftrightarrow> y = Some x"
   157 by(auto simp add: flat_ord_def)
   158 
   159 subsubsection \<open>A relator for sets that treats sets like predicates\<close>
   160 
   (* lifting_syntax is included only inside this context. It provides the ===>
      (rel_fun) notation used below; rel_predD unfolds rel_fun_def. *)
   161 context includes lifting_syntax
   162 begin
   163 
   (* rel_pred R A B: R-related elements agree on membership, i.e. the membership
      predicates of A and B are related by R ===> op =. *)
   164 definition rel_pred :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a set \<Rightarrow> 'b set \<Rightarrow> bool"
   165 where "rel_pred R A B = (R ===> op =) (\<lambda>x. x \<in> A) (\<lambda>y. y \<in> B)"
   166 
   167 lemma rel_predI: "(R ===> op =) (\<lambda>x. x \<in> A) (\<lambda>y. y \<in> B) \<Longrightarrow> rel_pred R A B"
   168 by(simp add: rel_pred_def)
   169 
   170 lemma rel_predD: "\<lbrakk> rel_pred R A B; R x y \<rbrakk> \<Longrightarrow> x \<in> A \<longleftrightarrow> y \<in> B"
   171 by(simp add: rel_pred_def rel_fun_def)
   172 
   173 lemma Collect_parametric: "((A ===> op =) ===> rel_pred A) Collect Collect"
   174   \<comment> \<open>Declare this rule as @{attribute transfer_rule} only locally
   175       because it blows up the search space for @{method transfer}
   176       (in combination with @{thm [source] Collect_transfer})\<close>
   177 by(simp add: rel_funI rel_predI)
   178 
   179 end
   180 
   181 subsubsection \<open>Monotonicity rules\<close>
   182 
   (* Addition on enat is monotone with respect to the reversed order, in each
      argument and jointly. The derived forms are registered via gfp.mono2mono2,
      cont_intro and partial_function_mono. *)
   183 lemma monotone_gfp_eadd1: "monotone op \<ge> op \<ge> (\<lambda>x. x + y :: enat)"
   184 by(auto intro!: monotoneI)
   185 
   186 lemma monotone_gfp_eadd2: "monotone op \<ge> op \<ge> (\<lambda>y. x + y :: enat)"
   187 by(auto intro!: monotoneI)
   188 
   189 lemma mono2mono_gfp_eadd[THEN gfp.mono2mono2, cont_intro, simp]:
   190   shows monotone_eadd: "monotone (rel_prod op \<ge> op \<ge>) op \<ge> (\<lambda>(x, y). x + y :: enat)"
   191 by(simp add: monotone_gfp_eadd1 monotone_gfp_eadd2)
   192 
   193 lemma eadd_gfp_partial_function_mono [partial_function_mono]:
   194   "\<lbrakk> monotone (fun_ord op \<ge>) op \<ge> f; monotone (fun_ord op \<ge>) op \<ge> g \<rbrakk>
   195   \<Longrightarrow> monotone (fun_ord op \<ge>) op \<ge> (\<lambda>x. f x + g x :: enat)"
   196 by(rule mono2mono_gfp_eadd)
   197 
   (* The embeddings ereal and ennreal are monotone; THEN lfp.mono2mono derives
      the mono2mono_* forms from the stated monotone_* facts. *)
   198 lemma mono2mono_ereal[THEN lfp.mono2mono]:
   199   shows monotone_ereal: "monotone op \<le> op \<le> ereal"
   200 by(rule monotoneI) simp
   201 
   202 lemma mono2mono_ennreal[THEN lfp.mono2mono]:
   203   shows monotone_ennreal: "monotone op \<le> op \<le> ennreal"
   204 by(rule monotoneI)(simp add: ennreal_leI)
   205 
   206 subsubsection \<open>Bijections\<close>
   207 
   (* A bi-unique relation R that relates A and B via rel_set induces a bijection
      f from A to B with R x (f x) for all x in A. *)
   208 lemma bi_unique_rel_set_bij_betw:
   209   assumes unique: "bi_unique R"
   210   and rel: "rel_set R A B"
   211   shows "\<exists>f. bij_betw f A B \<and> (\<forall>x\<in>A. R x (f x))"
   212 proof -
   (* Choose an R-partner in B for every x in A (rel_setD1 plus the axiom of choice). *)
   213   from assms obtain f where f: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)" and B: "\<And>x. x \<in> A \<Longrightarrow> f x \<in> B"
   214     apply(atomize_elim)
   215     apply(fold all_conj_distrib)
   216     apply(subst choice_iff[symmetric])
   217     apply(auto dest: rel_setD1)
   218     done
   (* Left-uniqueness gives injectivity; rel_setD2 with right-uniqueness gives surjectivity onto B. *)
   219   have "inj_on f A" by(rule inj_onI)(auto dest!: f dest: bi_uniqueDl[OF unique])
   220   moreover have "f ` A = B" using rel
   221     by(auto 4 3 intro: B dest: rel_setD2 f bi_uniqueDr[OF unique])
   222   ultimately have "bij_betw f A B" unfolding bij_betw_def ..
   223   thus ?thesis using f by blast
   224 qed
   225 
   (* Conversely, the graph of a bijection from A to B relates A and B via rel_set. *)
   226 lemma bij_betw_rel_setD: "bij_betw f A B \<Longrightarrow> rel_set (\<lambda>x y. y = f x) A B"
   227 by(rule rel_setI)(auto dest: bij_betwE bij_betw_imp_surj_on[symmetric])
   228 
   229 subsection \<open>Subprobability mass function\<close>
   230 
   (* A subprobability mass function is a pmf over 'a option. Mass on None is not
      part of measure_spmf (defined below), which only sees the range of Some. *)
   231 type_synonym 'a spmf = "'a option pmf"
   232 translations (type) "'a spmf" \<leftharpoondown> (type) "'a option pmf"
   233 
   (* measure_spmf p: restrict measure_pmf p to range Some, then push it forward
      along the to obtain a measure on 'a with the discrete sigma-algebra. *)
   234 definition measure_spmf :: "'a spmf \<Rightarrow> 'a measure"
   235 where "measure_spmf p = distr (restrict_space (measure_pmf p) (range Some)) (count_space UNIV) the"
   236 
   (* Probability of the value x, i.e. the pmf evaluated at Some x. *)
   237 abbreviation spmf :: "'a spmf \<Rightarrow> 'a \<Rightarrow> real"
   238 where "spmf p x \<equiv> pmf p (Some x)"
   239 
   (* Space, sets and measurability facts for measure_spmf. *)
   240 lemma space_measure_spmf: "space (measure_spmf p) = UNIV"
   241 by(simp add: measure_spmf_def)
   242 
   243 lemma sets_measure_spmf [simp, measurable_cong]: "sets (measure_spmf p) = sets (count_space UNIV)"
   244 by(simp add: measure_spmf_def)
   245 
   246 lemma measure_spmf_not_bot [simp]: "measure_spmf p \<noteq> \<bottom>"
   247 proof
   248   assume "measure_spmf p = \<bottom>"
   249   hence "space (measure_spmf p) = space \<bottom>" by simp
   250   thus False by(simp add: space_measure_spmf)
   251 qed
   252 
   253 lemma measurable_the_measure_pmf_Some [measurable, simp]:
   254   "the \<in> measurable (restrict_space (measure_pmf p) (range Some)) (count_space UNIV)"
   255 by(auto simp add: measurable_def sets_restrict_space space_restrict_space integral_restrict_space)
   256 
   257 lemma measurable_spmf_measure1[simp]: "measurable (measure_spmf M) N = UNIV \<rightarrow> space N"
   258 by(auto simp: measurable_def space_measure_spmf)
   259 
   260 lemma measurable_spmf_measure2[simp]: "measurable N (measure_spmf M) = measurable N (count_space UNIV)"
   261 by(intro measurable_cong_sets) simp_all
   262 
   (* measure_spmf p has total mass at most 1. The interpretation makes the
      subprob_space lemmas available under the prefix measure_spmf. *)
   263 lemma subprob_space_measure_spmf [simp, intro!]: "subprob_space (measure_spmf p)"
   264 proof
   265   show "emeasure (measure_spmf p) (space (measure_spmf p)) \<le> 1"
   266     by(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space measure_pmf.measure_le_1)
   267 qed(simp add: space_measure_spmf)
   268 
   269 interpretation measure_spmf: subprob_space "measure_spmf p" for p
   270 by(rule subprob_space_measure_spmf)
   271 
   (* Conversions between measure_spmf p and measure_pmf p: an event A on 'a
      corresponds to the event Some ` A on 'a option. *)
   272 lemma finite_measure_spmf [simp]: "finite_measure (measure_spmf p)"
   273 by unfold_locales
   274 
   275 lemma spmf_conv_measure_spmf: "spmf p x = measure (measure_spmf p) {x}"
   276 by(auto simp add: measure_spmf_def measure_distr measure_restrict_space pmf.rep_eq space_restrict_space intro: arg_cong2[where f=measure])
   277 
   278 lemma emeasure_measure_spmf_conv_measure_pmf:
   279   "emeasure (measure_spmf p) A = emeasure (measure_pmf p) (Some ` A)"
   280 by(auto simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])
   281 
   282 lemma measure_measure_spmf_conv_measure_pmf:
   283   "measure (measure_spmf p) A = measure (measure_pmf p) (Some ` A)"
   284 using emeasure_measure_spmf_conv_measure_pmf[of p A]
   285 by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure)
   286 
   (* An spmf built with map_pmf Some never fails, so its measure is that of p. *)
   287 lemma emeasure_spmf_map_pmf_Some [simp]:
   288   "emeasure (measure_spmf (map_pmf Some p)) A = emeasure (measure_pmf p) A"
   289 by(auto simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])
   290 
   291 lemma measure_spmf_map_pmf_Some [simp]:
   292   "measure (measure_spmf (map_pmf Some p)) A = measure (measure_pmf p) A"
   293 using emeasure_spmf_map_pmf_Some[of p A] by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure)
   294 
   (* A nonnegative integral against measure_spmf p is the sum over all of 'a of f
      weighted by spmf p x. The proof moves to count_space (range Some) and then
      reindexes along the bijection Some. *)
   295 lemma nn_integral_measure_spmf: "(\<integral>\<^sup>+ x. f x \<partial>measure_spmf p) = \<integral>\<^sup>+ x. ennreal (spmf p x) * f x \<partial>count_space UNIV"
   296   (is "?lhs = ?rhs")
   297 proof -
   (* NOTE(review): this simp call uses the ereal rule times_ereal.simps(1) although
      the integrands here are ennreal; it may be a leftover - confirm before removing. *)
   298   have "?lhs = \<integral>\<^sup>+ x. pmf p x * f (the x) \<partial>count_space (range Some)"
   299     by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space nn_integral_measure_pmf nn_integral_count_space_indicator ac_simps times_ereal.simps(1)[symmetric] del: times_ereal.simps(1))
   300   also have "\<dots> = \<integral>\<^sup>+ x. ennreal (spmf p (the x)) * f (the x) \<partial>count_space (range Some)"
   301     by(rule nn_integral_cong) auto
   302   also have "\<dots> = \<integral>\<^sup>+ x. spmf p (the (Some x)) * f (the (Some x)) \<partial>count_space UNIV"
   303     by(rule nn_integral_bij_count_space[symmetric])(simp add: bij_betw_def)
   304   also have "\<dots> = ?rhs" by simp
   305   finally show ?thesis .
   306 qed
   307 
   (* Lebesgue-integral version. From integrability of f w.r.t. measure_spmf p it first
      derives integrability of the weighted function on count_space UNIV. *)
   308 lemma integral_measure_spmf:
   309   assumes "integrable (measure_spmf p) f"
   310   shows "(\<integral> x. f x \<partial>measure_spmf p) = \<integral> x. spmf p x * f x \<partial>count_space UNIV"
   311 proof -
   312   have "integrable (count_space UNIV) (\<lambda>x. spmf p x * f x)"
   313     using assms by(simp add: integrable_iff_bounded nn_integral_measure_spmf abs_mult ennreal_mult'')
   314   then show ?thesis using assms
   315     by(simp add: real_lebesgue_integral_def nn_integral_measure_spmf ennreal_mult'[symmetric])
   316 qed
   317 
   (* The emeasure of the singleton {x} is spmf p x. *)
   318 lemma emeasure_spmf_single: "emeasure (measure_spmf p) {x} = spmf p x"
   319 by(simp add: measure_spmf.emeasure_eq_measure spmf_conv_measure_spmf)
   320 
   (* Any family of measure_spmf values is measurable into the subprobability algebra. *)
   321 lemma measurable_measure_spmf[measurable]:
   322   "(\<lambda>x. measure_spmf (M x)) \<in> measurable (count_space UNIV) (subprob_algebra (count_space UNIV))"
   323 by (auto simp: space_subprob_algebra)
   324 
   (* Nonnegative integrals over measure_spmf p as integrals of f composed with the
      over the pmf measure restricted to range Some. *)
   325 lemma nn_integral_measure_spmf_conv_measure_pmf:
   326   assumes [measurable]: "f \<in> borel_measurable (count_space UNIV)"
   327   shows "nn_integral (measure_spmf p) f = nn_integral (restrict_space (measure_pmf p) (range Some)) (f \<circ> the)"
   328 by(simp add: measure_spmf_def nn_integral_distr o_def)
   329 
   330 lemma measure_spmf_in_space_subprob_algebra [simp]:
   331   "measure_spmf p \<in> space (subprob_algebra (count_space UNIV))"
   332 by(simp add: space_subprob_algebra)
   333 
   (* Finiteness facts. The total spmf mass is finite, since it is the integral of the
      constant 1 over measure_spmf p. Suprema of spmf values and of event measures
      are bounded by 1 and hence not top (via neq_top_trans). *)
   334 lemma nn_integral_spmf_neq_top: "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV) \<noteq> \<top>"
   335 using nn_integral_measure_spmf[where f="\<lambda>_. 1", of p, symmetric] by simp
   336 
   337 lemma SUP_spmf_neq_top': "(SUP p:Y. ennreal (spmf p x)) \<noteq> \<top>"
   338 proof(rule neq_top_trans)
   339   show "(SUP p:Y. ennreal (spmf p x)) \<le> 1" by(rule SUP_least)(simp add: pmf_le_1)
   340 qed simp
   341 
   342 lemma SUP_spmf_neq_top: "(SUP i. ennreal (spmf (Y i) x)) \<noteq> \<top>"
   343 proof(rule neq_top_trans)
   344   show "(SUP i. ennreal (spmf (Y i) x)) \<le> 1" by(rule SUP_least)(simp add: pmf_le_1)
   345 qed simp
   346 
   347 lemma SUP_emeasure_spmf_neq_top: "(SUP p:Y. emeasure (measure_spmf p) A) \<noteq> \<top>"
   348 proof(rule neq_top_trans)
   349   show "(SUP p:Y. emeasure (measure_spmf p) A) \<le> 1"
   350     by(rule SUP_least)(simp add: measure_spmf.subprob_emeasure_le_1)
   351 qed simp
   352 
   353 subsection \<open>Support\<close>
   354 
   (* Support of an spmf: all x such that Some x is in the support of the pmf
      (see in_set_spmf). *)
   355 definition set_spmf :: "'a spmf \<Rightarrow> 'a set"
   356 where "set_spmf p = set_pmf p \<bind> set_option"
   357 
   (* Equivalent characterisations of the support: via measure_spmf, via set_pmf,
      via nonzero spmf values, and almost-everywhere statements. *)
   358 lemma set_spmf_rep_eq: "set_spmf p = {x. measure (measure_spmf p) {x} \<noteq> 0}"
   359 proof -
   360   have "\<And>x :: 'a. the -` {x} \<inter> range Some = {Some x}" by auto
   361   then show ?thesis
   362     by(auto simp add: set_spmf_def set_pmf.rep_eq measure_spmf_def measure_distr measure_restrict_space space_restrict_space intro: rev_image_eqI)
   363 qed
   364 
   365 lemma in_set_spmf: "x \<in> set_spmf p \<longleftrightarrow> Some x \<in> set_pmf p"
   366 by(simp add: set_spmf_def)
   367 
   368 lemma AE_measure_spmf_iff [simp]: "(AE x in measure_spmf p. P x) \<longleftrightarrow> (\<forall>x\<in>set_spmf p. P x)"
   369 by(auto 4 3 simp add: measure_spmf_def AE_distr_iff AE_restrict_space_iff AE_measure_pmf_iff set_spmf_def cong del: AE_cong)
   370 
   371 lemma spmf_eq_0_set_spmf: "spmf p x = 0 \<longleftrightarrow> x \<notin> set_spmf p"
   372 by(auto simp add: pmf_eq_0_set_pmf set_spmf_def intro: rev_image_eqI)
   373 
   374 lemma in_set_spmf_iff_spmf: "x \<in> set_spmf p \<longleftrightarrow> spmf p x \<noteq> 0"
   375 by(auto simp add: set_spmf_def set_pmf_iff intro: rev_image_eqI)
   376 
   (* The always-failing spmf has empty support; every support is countable. *)
   377 lemma set_spmf_return_pmf_None [simp]: "set_spmf (return_pmf None) = {}"
   378 by(auto simp add: set_spmf_def)
   379 
   380 lemma countable_set_spmf [simp]: "countable (set_spmf p)"
   381 by(simp add: set_spmf_def bind_UNION)
   382 
   (* Extensionality: two spmfs are equal if they agree on every value. The mass of
      None is then determined: it is 1 minus the mass of range Some, and that mass
      is a sum of spmf values over the support. *)
   383 lemma spmf_eqI:
   384   assumes "\<And>i. spmf p i = spmf q i"
   385   shows "p = q"
   386 proof(rule pmf_eqI)
   387   fix i
   388   show "pmf p i = pmf q i"
   389   proof(cases i)
   390     case (Some i')
   391     thus ?thesis by(simp add: assms)
   392   next
   (* For None: write pmf p None as 1 minus the mass of range Some, reduce that to the
      sum of spmf values over the support, rewrite with assms, and fold back for q. *)
   393     case None
   394     have "ennreal (pmf p i) = measure (measure_pmf p) {i}" by(simp add: pmf_def)
   395     also have "{i} = space (measure_pmf p) - range Some"
   396       by(auto simp add: None intro: ccontr)
   397     also have "measure (measure_pmf p) \<dots> = ennreal 1 - measure (measure_pmf p) (range Some)"
   398       by(simp add: measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf)
   399     also have "range Some = (\<Union>x\<in>set_spmf p. {Some x}) \<union> Some ` (- set_spmf p)"
   400       by auto
   401     also have "measure (measure_pmf p) \<dots> = measure (measure_pmf p) (\<Union>x\<in>set_spmf p. {Some x})"
   402       by(rule measure_pmf.measure_zero_union)(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff)
   403     also have "ennreal \<dots> = \<integral>\<^sup>+ x. measure (measure_pmf p) {Some x} \<partial>count_space (set_spmf p)"
   404       unfolding measure_pmf.emeasure_eq_measure[symmetric]
   405       by(simp_all add: emeasure_UN_countable disjoint_family_on_def)
   406     also have "\<dots> = \<integral>\<^sup>+ x. spmf p x \<partial>count_space (set_spmf p)" by(simp add: pmf_def)
   407     also have "\<dots> = \<integral>\<^sup>+ x. spmf q x \<partial>count_space (set_spmf p)" by(simp add: assms)
   408     also have "set_spmf p = set_spmf q" by(auto simp add: in_set_spmf_iff_spmf assms)
   409     also have "(\<integral>\<^sup>+ x. spmf q x \<partial>count_space (set_spmf q)) = \<integral>\<^sup>+ x. measure (measure_pmf q) {Some x} \<partial>count_space (set_spmf q)"
   410       by(simp add: pmf_def)
   411     also have "\<dots> = measure (measure_pmf q) (\<Union>x\<in>set_spmf q. {Some x})"
   412       unfolding measure_pmf.emeasure_eq_measure[symmetric]
   413       by(simp_all add: emeasure_UN_countable disjoint_family_on_def)
   414     also have "\<dots> = measure (measure_pmf q) ((\<Union>x\<in>set_spmf q. {Some x}) \<union> Some ` (- set_spmf q))"
   415       by(rule ennreal_cong measure_pmf.measure_zero_union[symmetric])+(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff)
   416     also have "((\<Union>x\<in>set_spmf q. {Some x}) \<union> Some ` (- set_spmf q)) = range Some" by auto
   (* NOTE(review): one_ereal_def below concerns ereal although this step is stated
      with ennreal; it may be a leftover - confirm before removing. *)
   417     also have "ennreal 1 - measure (measure_pmf q) \<dots> = measure (measure_pmf q) (space (measure_pmf q) - range Some)"
   418       by(simp add: one_ereal_def measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf)
   419     also have "space (measure_pmf q) - range Some = {i}"
   420       by(auto simp add: None intro: ccontr)
   421     also have "measure (measure_pmf q) \<dots> = pmf q i" by(simp add: pmf_def)
   422     finally show ?thesis by simp
   423   qed
   424 qed
   425 
   (* Integrals over measure_spmf depend only on the support set_spmf: restricting
      the space to it, or summing only over it, gives the same value. *)
   426 lemma integral_measure_spmf_restrict:
   427   fixes f ::  "'a \<Rightarrow> 'b :: {banach, second_countable_topology}" shows
   428   "(\<integral> x. f x \<partial>measure_spmf M) = (\<integral> x. f x \<partial>restrict_space (measure_spmf M) (set_spmf M))"
   429 by(auto intro!: integral_cong_AE simp add: integral_restrict_space)
   430 
   431 lemma nn_integral_measure_spmf':
   432   "(\<integral>\<^sup>+ x. f x \<partial>measure_spmf p) = \<integral>\<^sup>+ x. ennreal (spmf p x) * f x \<partial>count_space (set_spmf p)"
   433 by(auto simp add: nn_integral_measure_spmf nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator)
   434 
   435 subsection \<open>Functorial structure\<close>
   436 
   (* map_spmf f applies f under Some; None is left unchanged (map_option). *)
   437 abbreviation map_spmf :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a spmf \<Rightarrow> 'b spmf"
   438 where "map_spmf f \<equiv> map_pmf (map_option f)"
   439 
   (* Functor laws. The local_setup adds the mandatory name prefix spmf, so these
      facts are named spmf.map_comp, spmf.map_id0, spmf.map_id and spmf.map_ident. *)
   440 context begin
   441 local_setup \<open>Local_Theory.map_background_naming (Name_Space.mandatory_path "spmf")\<close>
   442 
   443 lemma map_comp: "map_spmf f (map_spmf g p) = map_spmf (f \<circ> g) p"
   444 by(simp add: pmf.map_comp o_def option.map_comp)
   445 
   446 lemma map_id0: "map_spmf id = id"
   447 by(simp add: pmf.map_id option.map_id0)
   448 
   449 lemma map_id [simp]: "map_spmf id p = p"
   450 by(simp add: map_id0)
   451 
   452 lemma map_ident [simp]: "map_spmf (\<lambda>x. x) p = p"
   453 by(simp add: id_def[symmetric])
   454 
   455 end
   456 
   (* The support of a mapped spmf is the image of the support. *)
   457 lemma set_map_spmf [simp]: "set_spmf (map_spmf f p) = f ` set_spmf p"
   458 by(simp add: set_spmf_def image_bind bind_image o_def Option.option.set_map)
   459 
   (* Congruence rules: map_spmf depends on f only on the support. The _simp variant
      states the premise with =simp=> (simp_implies). *)
   460 lemma map_spmf_cong:
   461   "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q \<Longrightarrow> f x = g x \<rbrakk>
   462   \<Longrightarrow> map_spmf f p = map_spmf g q"
   463 by(auto intro: pmf.map_cong option.map_cong simp add: in_set_spmf)
   464 
   465 lemma map_spmf_cong_simp:
   466   "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q =simp=> f x = g x \<rbrakk>
   467   \<Longrightarrow> map_spmf f p = map_spmf g q"
   468 unfolding simp_implies_def by(rule map_spmf_cong)
   469 
   (* A map that is the identity on the support leaves the spmf unchanged. *)
   470 lemma map_spmf_idI: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> f x = x) \<Longrightarrow> map_spmf f p = p"
   471 by(rule map_pmf_idI map_option_idI)+(simp add: in_set_spmf)
   472 
   (* measure_spmf turns map_spmf f into the pushforward (distr) along f: events
      are measured via their preimage f -` A. *)
   473 lemma emeasure_map_spmf:
   474   "emeasure (measure_spmf (map_spmf f p)) A = emeasure (measure_spmf p) (f -` A)"
   475 by(auto simp add: measure_spmf_def emeasure_distr measurable_restrict_space1 space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure])
   476 
   477 lemma measure_map_spmf: "measure (measure_spmf (map_spmf f p)) A = measure (measure_spmf p) (f -` A)"
   478 using emeasure_map_spmf[of f p A] by(simp add: measure_spmf.emeasure_eq_measure)
   479 
   480 lemma measure_map_spmf_conv_distr:
   481   "measure_spmf (map_spmf f p) = distr (measure_spmf p) (count_space UNIV) f"
   482 by(rule measure_eqI)(simp_all add: emeasure_map_spmf emeasure_distr)
   483 
   (* Point probabilities of mapped spmfs. map_pmf Some p assigns p's probabilities
      to the values. Injective maps, or maps injective on the support, preserve
      point probabilities. Values outside the image of the support get 0. *)
   484 lemma spmf_map_pmf_Some [simp]: "spmf (map_pmf Some p) i = pmf p i"
   485 by(simp add: pmf_map_inj')
   486 
   487 lemma spmf_map_inj: "\<lbrakk> inj_on f (set_spmf M); x \<in> set_spmf M \<rbrakk> \<Longrightarrow> spmf (map_spmf f M) (f x) = spmf M x"
   488 by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj, auto simp add: in_set_spmf inj_on_def elim!: option.inj_map_strong[rotated])
   489 
   490 lemma spmf_map_inj': "inj f \<Longrightarrow> spmf (map_spmf f M) (f x) = spmf M x"
   491 by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj'[OF option.inj_map])
   492 
   493 lemma spmf_map_outside: "x \<notin> f ` set_spmf M \<Longrightarrow> spmf (map_spmf f M) x = 0"
   494 unfolding spmf_eq_0_set_spmf by simp
   495 
   (* In general the probability of x is the measure of the preimage f -` {x},
      stated as an emeasure, a measure, and a nonnegative integral of its indicator. *)
   496 lemma ennreal_spmf_map: "ennreal (spmf (map_spmf f p) x) = emeasure (measure_spmf p) (f -` {x})"
   497 by(auto simp add: ennreal_pmf_map measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])
   498 
   499 lemma spmf_map: "spmf (map_spmf f p) x = measure (measure_spmf p) (f -` {x})"
   500 using ennreal_spmf_map[of f p x] by(simp add: measure_spmf.emeasure_eq_measure)
   501 
   502 lemma ennreal_spmf_map_conv_nn_integral:
   503   "ennreal (spmf (map_spmf f p) x) = integral\<^sup>N (measure_spmf p) (indicator (f -` {x}))"
   504 by(auto simp add: ennreal_pmf_map measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure])
   505 
   506 subsection \<open>Monad operations\<close>
   507 
   508 subsubsection \<open>Return\<close>
   509 
   (* return_spmf x puts all mass on Some x. *)
   510 abbreviation return_spmf :: "'a \<Rightarrow> 'a spmf"
   511 where "return_spmf x \<equiv> return_pmf (Some x)"
   512 
   513 lemma pmf_return_spmf: "pmf (return_spmf x) y = indicator {y} (Some x)"
   514 by(fact pmf_return)
   515 
   (* Its measure is the Giry-monad return at x. *)
   516 lemma measure_spmf_return_spmf: "measure_spmf (return_spmf x) = Giry_Monad.return (count_space UNIV) x"
   517 by(rule measure_eqI)(simp_all add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_def)
   518 
   (* return_pmf None (certain failure) has the null measure on 'a. *)
   519 lemma measure_spmf_return_pmf_None [simp]: "measure_spmf (return_pmf None) = null_measure (count_space UNIV)"
   520 by(rule measure_eqI)(auto simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_eq_0_iff)
   521 
   522 lemma set_return_spmf [simp]: "set_spmf (return_spmf x) = {x}"
   523 by(auto simp add: set_spmf_def)
   524 
   525 subsubsection \<open>Bind\<close>
   526 
   (* Monadic bind on spmfs: an outcome None stays None (failure propagates), and an
      outcome Some a' continues with f a'. adhoc_overloading makes it available via
      the generic bind syntax of Monad_Syntax. *)
   527 definition bind_spmf :: "'a spmf \<Rightarrow> ('a \<Rightarrow> 'b spmf) \<Rightarrow> 'b spmf"
   528 where "bind_spmf x f = bind_pmf x (\<lambda>a. case a of None \<Rightarrow> return_pmf None | Some a' \<Rightarrow> f a')"
   529 
   530 adhoc_overloading Monad_Syntax.bind bind_spmf
   531 
   (* Monad laws: return_pmf None is a left zero; return_spmf is a left and right
      identity; bind is associative. *)
   532 lemma return_None_bind_spmf [simp]: "return_pmf None \<bind> (f :: 'a \<Rightarrow> _) = return_pmf None"
   533 by(simp add: bind_spmf_def bind_return_pmf)
   534 
   535 lemma return_bind_spmf [simp]: "return_spmf x \<bind> f = f x"
   536 by(simp add: bind_spmf_def bind_return_pmf)
   537 
   538 lemma bind_return_spmf [simp]: "x \<bind> return_spmf = x"
   539 proof -
   540   have "\<And>a :: 'a option. (case a of None \<Rightarrow> return_pmf None | Some a' \<Rightarrow> return_spmf a') = return_pmf a"
   541     by(simp split: option.split)
   542   then show ?thesis
   543     by(simp add: bind_spmf_def bind_return_pmf')
   544 qed
   545 
   546 lemma bind_spmf_assoc [simp]:
   547   fixes x :: "'a spmf" and f :: "'a \<Rightarrow> 'b spmf" and g :: "'b \<Rightarrow> 'c spmf"
   548   shows "(x \<bind> f) \<bind> g = x \<bind> (\<lambda>y. f y \<bind> g)"
   549 by(auto simp add: bind_spmf_def bind_assoc_pmf fun_eq_iff bind_return_pmf split: option.split intro: arg_cong[where f="bind_pmf x"])
   550 
   (* Failure probability of a bind: failure of p itself plus the expected failure
      probability of the continuation over measure_spmf p. The proof splits the
      pmf integral into the part at None and the part on range Some. *)
   551 lemma pmf_bind_spmf_None: "pmf (p \<bind> f) None = pmf p None + \<integral> x. pmf (f x) None \<partial>measure_spmf p"
   552   (is "?lhs = ?rhs")
   553 proof -
   554   let ?f = "\<lambda>x. pmf (case x of None \<Rightarrow> return_pmf None | Some x \<Rightarrow> f x) None"
   555   have "?lhs = \<integral> x. ?f x \<partial>measure_pmf p"
   556     by(simp add: bind_spmf_def pmf_bind)
   557   also have "\<dots> = \<integral> x. ?f None * indicator {None} x + ?f x * indicator (range Some) x \<partial>measure_pmf p"
   558     by(rule integral_cong)(auto simp add: indicator_def)
   559   also have "\<dots> = (\<integral> x. ?f None * indicator {None} x \<partial>measure_pmf p) + (\<integral> x. ?f x * indicator (range Some) x \<partial>measure_pmf p)"
   560     by(rule integral_add)(auto 4 3 intro: integrable_real_mult_indicator measure_pmf.integrable_const_bound[where B=1] simp add: AE_measure_pmf_iff pmf_le_1)
   561   also have "\<dots> = pmf p None + \<integral> x. indicator (range Some) x * pmf (f (the x)) None \<partial>measure_pmf p"
   562     by(auto simp add: measure_measure_pmf_finite indicator_eq_0_iff intro!: integral_cong)
   563   also have "\<dots> = ?rhs" unfolding measure_spmf_def
   564     by(subst integral_distr)(auto simp add: integral_restrict_space)
   565   finally show ?thesis .
   566 qed
   567 
   (* Probability of a value after a bind: the expectation of spmf (f x) y over
      measure_spmf p, as a Lebesgue integral and as a nonnegative integral. *)
   568 lemma spmf_bind: "spmf (p \<bind> f) y = \<integral> x. spmf (f x) y \<partial>measure_spmf p"
   569 unfolding measure_spmf_def
   570 by(subst integral_distr)(auto simp add: bind_spmf_def pmf_bind integral_restrict_space indicator_eq_0_iff intro!: integral_cong split: option.split)
   571 
   572 lemma ennreal_spmf_bind: "ennreal (spmf (p \<bind> f) x) = \<integral>\<^sup>+ y. spmf (f y) x \<partial>measure_spmf p"
   573 by(auto simp add: bind_spmf_def ennreal_pmf_bind nn_integral_measure_spmf_conv_measure_pmf nn_integral_restrict_space intro: nn_integral_cong split: split_indicator option.split)
   574 
   (* measure_spmf_bind_pmf: relates measure_spmf of p bind f to the
      measure-monad bind of measure_pmf p, not measure_spmf p, with
      measure_spmf o f. Compare measure_spmf_bind below, which binds over
      measure_spmf p.
      Proof by measure_eqI: first the sigma-algebras agree, then the
      emeasures agree on every set A, via emeasure_bind over count_space. *)
   575 lemma measure_spmf_bind_pmf: "measure_spmf (p \<bind> f) = measure_pmf p \<bind> measure_spmf \<circ> f"
   576   (is "?lhs = ?rhs")
   577 proof(rule measure_eqI)
   578   show "sets ?lhs = sets ?rhs"
   579     by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf)
   580 next
   581   fix A :: "'a set"
   582   have "emeasure ?lhs A = \<integral>\<^sup>+ x. emeasure (measure_spmf (f x)) A \<partial>measure_pmf p"
   583     by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def)
   584   also have "\<dots> = emeasure ?rhs A"
   585     by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra)
   586   finally show "emeasure ?lhs A = emeasure ?rhs A" .
   587 qed
   588 
   (* measure_spmf_bind: measure_spmf commutes with bind, i.e.
      measure_spmf (p bind f) = measure_spmf p bind (measure_spmf o f).
      In the emeasure part, ?A is the Some-preimage of A: the -` A
      intersected with range Some. The None branch of bind_spmf contributes
      nothing, which is expressed by the indicator of range Some before the
      integral is moved onto measure_spmf p. *)
   589 lemma measure_spmf_bind: "measure_spmf (p \<bind> f) = measure_spmf p \<bind> measure_spmf \<circ> f"
   590   (is "?lhs = ?rhs")
   591 proof(rule measure_eqI)
   592   show "sets ?lhs = sets ?rhs"
   593     by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf)
   594 next
   595   fix A :: "'a set"
   596   let ?A = "the -` A \<inter> range Some"
   597   have "emeasure ?lhs A = \<integral>\<^sup>+ x. emeasure (measure_pmf (case x of None \<Rightarrow> return_pmf None | Some x \<Rightarrow> f x)) ?A \<partial>measure_pmf p"
   598     by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def)
   599   also have "\<dots> =  \<integral>\<^sup>+ x. emeasure (measure_pmf (f (the x))) ?A * indicator (range Some) x \<partial>measure_pmf p"
   600     by(rule nn_integral_cong)(auto split: option.split simp add: indicator_def)
   601   also have "\<dots> = \<integral>\<^sup>+ x. emeasure (measure_spmf (f x)) A \<partial>measure_spmf p"
   602     by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space emeasure_distr space_restrict_space emeasure_restrict_space)
   603   also have "\<dots> = emeasure ?rhs A"
   604     by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra)
   605   finally show "emeasure ?lhs A = emeasure ?rhs A" .
   606 qed
   607 
   (* map_spmf_bind_spmf: mapping the result of a bind is the same as mapping
      inside the continuation. *)
   608 lemma map_spmf_bind_spmf: "map_spmf f (bind_spmf p g) = bind_spmf p (map_spmf f \<circ> g)"
   609 by(auto simp add: bind_spmf_def map_bind_pmf fun_eq_iff split: option.split intro: arg_cong2[where f=bind_pmf])
   610 
   (* bind_map_spmf: mapping before a bind equals binding with the composed
      continuation g o f. *)
   611 lemma bind_map_spmf: "map_spmf f p \<bind> g = p \<bind> g \<circ> f"
   612 by(simp add: bind_spmf_def bind_map_pmf o_def cong del: option.case_cong_weak)
   613 
   (* spmf_bind_leI: upper bound for the probability of x after a bind.
      Assumptions: spmf (f y) x <= r for every y in the support of p, and
      r >= 0. Conclusion: spmf (bind_spmf p f) x <= r.
      Proof: rewrite with ennreal_spmf_bind, bound the integrand by r
      almost everywhere, then use that measure_spmf p has total mass at
      most 1 (measure_spmf.emeasure_space_le_1). *)
   614 lemma spmf_bind_leI:
   615   assumes "\<And>y. y \<in> set_spmf p \<Longrightarrow> spmf (f y) x \<le> r"
   616   and "0 \<le> r"
   617   shows "spmf (bind_spmf p f) x \<le> r"
   618 proof -
   619   have "ennreal (spmf (bind_spmf p f) x) = \<integral>\<^sup>+ y. spmf (f y) x \<partial>measure_spmf p" by(rule ennreal_spmf_bind)
   620   also have "\<dots> \<le> \<integral>\<^sup>+ y. r \<partial>measure_spmf p" by(rule nn_integral_mono_AE)(simp add: assms)
   621   also have "\<dots> \<le> r" using assms measure_spmf.emeasure_space_le_1
   622     by(auto simp add: measure_spmf.emeasure_eq_measure intro!: mult_left_le)
   623   finally show ?thesis using assms(2) by(simp)
   624 qed
   625 
   (* map_spmf_conv_bind_spmf: map_spmf expressed as a bind followed by
      return_spmf. *)
   626 lemma map_spmf_conv_bind_spmf: "map_spmf f p = (p \<bind> (\<lambda>x. return_spmf (f x)))"
   627 by(simp add: map_pmf_def bind_spmf_def)(rule bind_pmf_cong, simp_all split: option.split)
   628 
   (* bind_spmf_cong: congruence rule. bind_spmf depends on the continuation
      only at points in the support set_spmf q. *)
   629 lemma bind_spmf_cong:
   630   "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q \<Longrightarrow> f x = g x \<rbrakk>
   631   \<Longrightarrow> bind_spmf p f = bind_spmf q g"
   632 by(auto simp add: bind_spmf_def in_set_spmf intro: bind_pmf_cong option.case_cong)
   633 
   (* bind_spmf_cong_simp: variant of bind_spmf_cong whose premise uses the
      simp-implication =simp=>. Presumably intended as a congruence rule for
      the simplifier; derived from bind_spmf_cong via simp_implies_def. *)
   634 lemma bind_spmf_cong_simp:
   635   "\<lbrakk> p = q; \<And>x. x \<in> set_spmf q =simp=> f x = g x \<rbrakk>
   636   \<Longrightarrow> bind_spmf p f = bind_spmf q g"
   637 by(simp add: simp_implies_def cong: bind_spmf_cong)
   638 
   (* set_bind_spmf: the support of a bind is the set-monad bind of the
      support of M with the supports of f. *)
   639 lemma set_bind_spmf: "set_spmf (M \<bind> f) = set_spmf M \<bind> (set_spmf \<circ> f)"
   640 by(auto simp add: set_spmf_def bind_spmf_def bind_UNION split: option.splits)
   641 
   (* bind_spmf_const_return_None: binding into the continuation that always
      returns return_pmf None yields return_pmf None, whatever p is. *)
   642 lemma bind_spmf_const_return_None [simp]: "bind_spmf p (\<lambda>_. return_pmf None) = return_pmf None"
   643 by(simp add: bind_spmf_def case_option_collapse)
   644 
   (* bind_commute_spmf: two independent binds can be swapped.
      Proof: both sides are rewritten into nested bind_pmf with the helper ?f,
      which returns return_pmf None if either argument is None and f a b
      otherwise. The order of the two bind_pmf is then exchanged with
      bind_commute_pmf. *)
   645 lemma bind_commute_spmf:
   646   "bind_spmf p (\<lambda>x. bind_spmf q (f x)) = bind_spmf q (\<lambda>y. bind_spmf p (\<lambda>x. f x y))"
   647   (is "?lhs = ?rhs")
   648 proof -
   649   let ?f = "\<lambda>x y. case x of None \<Rightarrow> return_pmf None | Some a \<Rightarrow> (case y of None \<Rightarrow> return_pmf None | Some b \<Rightarrow> f a b)"
   650   have "?lhs = p \<bind> (\<lambda>x. q \<bind> ?f x)"
   651     unfolding bind_spmf_def by(rule bind_pmf_cong[OF refl])(simp split: option.split)
   652   also have "\<dots> = q \<bind> (\<lambda>y. p \<bind> (\<lambda>x. ?f x y))" by(rule bind_commute_pmf)
   653   also have "\<dots> = ?rhs" unfolding bind_spmf_def
   654     by(rule bind_pmf_cong[OF refl])(auto split: option.split, metis bind_spmf_const_return_None bind_spmf_def)
   655   finally show ?thesis .
   656 qed
   657 
   658 subsection \<open>Relator\<close>
   659 
   (* rel_spmf R: relator on spmfs, obtained by lifting R to options with
      rel_option and then to pmfs with rel_pmf. *)
   660 abbreviation rel_spmf :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> 'a spmf \<Rightarrow> 'b spmf \<Rightarrow> bool"
   661 where "rel_spmf R \<equiv> rel_pmf (rel_option R)"
   662 
   (* rel_pmf_mono: rel_pmf is monotone in the relation; if A implies B
      pointwise, then rel_pmf A f g implies rel_pmf B f g. *)
   663 lemma rel_pmf_mono:
   664   "\<lbrakk>rel_pmf A f g; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> rel_pmf B f g"
   665 using pmf.rel_mono[of A B] by(simp add: le_fun_def)
   666 
   (* rel_spmf_mono: monotonicity of rel_spmf in the relation, derived from
      rel_pmf_mono and option.rel_mono. *)
   667 lemma rel_spmf_mono:
   668   "\<lbrakk>rel_spmf A f g; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> rel_spmf B f g"
   669 apply(erule rel_pmf_mono)
   670 using option.rel_mono[of A B] by(simp add: le_fun_def)
   671 
   (* rel_spmf_mono_strong: stronger monotonicity. B only has to follow from A
      for pairs whose components lie in the supports of f and g. *)
   672 lemma rel_spmf_mono_strong:
   673   "\<lbrakk> rel_spmf A f g; \<And>x y. \<lbrakk> A x y; x \<in> set_spmf f; y \<in> set_spmf g \<rbrakk> \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> rel_spmf B f g"
   674 apply(erule pmf.rel_mono_strong)
   675 apply(erule option.rel_mono_strong)
   676 apply(auto simp add: in_set_spmf)
   677 done
   678 
   (* rel_spmf_reflI: p is related to itself whenever P is reflexive on the
      support of p. *)
   679 lemma rel_spmf_reflI: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> P x x) \<Longrightarrow> rel_spmf P p p"
   680 by(rule rel_pmf_reflI)(auto simp add: set_spmf_def intro: rel_option_reflI)
   681 
   (* rel_spmfI: introduction rule via an spmf coupling pq. The marginals of
      pq are p and q, and every pair in its support satisfies P.
      The rel_pmf witness on option pairs maps None to (None, None) and
      Some (a, b) to (Some a, Some b). *)
   682 lemma rel_spmfI [intro?]:
   683   "\<lbrakk> \<And>x y. (x, y) \<in> set_spmf pq \<Longrightarrow> P x y; map_spmf fst pq = p; map_spmf snd pq = q \<rbrakk>
   684   \<Longrightarrow> rel_spmf P p q"
   685 by(rule rel_pmf.intros[where pq="map_pmf (\<lambda>x. case x of None \<Rightarrow> (None, None) | Some (a, b) \<Rightarrow> (Some a, Some b)) pq"])
   686   (auto simp add: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong)
   687 
   (* rel_spmfE: elimination rule, the converse of rel_spmfI. From rel_spmf P p q
      we obtain an spmf coupling pq with marginals p and q whose support
      satisfies P.
      Proof: take the rel_pmf coupling over option pairs and collapse it into
      ?pq by sending (Some x, Some y) to Some (x, y) and everything else to
      None. The two auxiliary facts show that None on one side of the
      coupling only pairs with None on the other side, so the marginals are
      preserved. *)
   688 lemma rel_spmfE [elim?, consumes 1, case_names rel_spmf]:
   689   assumes "rel_spmf P p q"
   690   obtains pq where
   691     "\<And>x y. (x, y) \<in> set_spmf pq \<Longrightarrow> P x y"
   692     "p = map_spmf fst pq"
   693     "q = map_spmf snd pq"
   694 using assms
   695 proof(cases rule: rel_pmf.cases[consumes 1, case_names rel_pmf])
   696   case (rel_pmf pq)
   697   let ?pq = "map_pmf (\<lambda>(a, b). case (a, b) of (Some x, Some y) \<Rightarrow> Some (x, y) | _ \<Rightarrow> None) pq"
   698   have "\<And>x y. (x, y) \<in> set_spmf ?pq \<Longrightarrow> P x y"
   699     by(auto simp add: in_set_spmf split: option.split_asm dest: rel_pmf(1))
   700   moreover
   701   have "\<And>x. (x, None) \<in> set_pmf pq \<Longrightarrow> x = None" by(auto dest!: rel_pmf(1))
   702   then have "p = map_spmf fst ?pq" using rel_pmf(2)
   703     by(auto simp add: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split)
   704   moreover
   705   have "\<And>y. (None, y) \<in> set_pmf pq \<Longrightarrow> y = None" by(auto dest!: rel_pmf(1))
   706   then have "q = map_spmf snd ?pq" using rel_pmf(3)
   707     by(auto simp add: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split)
   708   ultimately show thesis ..
   709 qed
   710 
   (* rel_spmf_simps: characterisation of rel_spmf by spmf couplings; combines
      rel_spmfI and rel_spmfE into an equivalence. *)
   711 lemma rel_spmf_simps:
   712   "rel_spmf R p q \<longleftrightarrow> (\<exists>pq. (\<forall>(x, y)\<in>set_spmf pq. R x y) \<and> map_spmf fst pq = p \<and> map_spmf snd pq = q)"
   713 by(auto intro: rel_spmfI elim!: rel_spmfE)
   714 
   (* spmf_rel_map: rel_spmf absorbs map_spmf on the left (spmf_rel_map1) or
      on the right (spmf_rel_map2) by precomposing the relation. *)
   715 lemma spmf_rel_map:
   716   shows spmf_rel_map1: "\<And>R f x. rel_spmf R (map_spmf f x) = rel_spmf (\<lambda>x. R (f x)) x"
   717   and spmf_rel_map2: "\<And>R x g y. rel_spmf R x (map_spmf g y) = rel_spmf (\<lambda>x y. R x (g y)) x y"
   718 by(simp_all add: fun_eq_iff pmf.rel_map option.rel_map[abs_def])
   719 
   (* spmf_rel_conversep: rel_spmf commutes with taking the converse relation. *)
   720 lemma spmf_rel_conversep: "rel_spmf R\<inverse>\<inverse> = (rel_spmf R)\<inverse>\<inverse>"
   721 by(simp add: option.rel_conversep pmf.rel_conversep)
   722 
   (* spmf_rel_eq: rel_spmf applied to equality is equality on spmfs. *)
   723 lemma spmf_rel_eq: "rel_spmf op = = op ="
   724 by(simp add: pmf.rel_eq option.rel_eq)
   725 
   (* Parametricity rules for the basic spmf operations, stated with the
      relator arrow ===> from lifting_syntax. Only bind_spmf_parametric and
      set_spmf_parametric are declared as [transfer_rule] here; the others
      are plain lemmas. Most are proved by transfer_prover. *)
   726 context includes lifting_syntax
   727 begin
   728 
   (* bind_spmf preserves related arguments and related continuations. *)
   729 lemma bind_spmf_parametric [transfer_rule]:
   730   "(rel_spmf A ===> (A ===> rel_spmf B) ===> rel_spmf B) bind_spmf bind_spmf"
   731 unfolding bind_spmf_def[abs_def] by transfer_prover
   732 
   733 lemma return_spmf_parametric: "(A ===> rel_spmf A) return_spmf return_spmf"
   734 by transfer_prover
   735 
   736 lemma map_spmf_parametric: "((A ===> B) ===> rel_spmf A ===> rel_spmf B) map_spmf map_spmf"
   737 by transfer_prover
   738 
   739 lemma rel_spmf_parametric:
   740   "((A ===> B ===> op =) ===> rel_spmf A ===> rel_spmf B ===> op =) rel_spmf rel_spmf"
   741 by transfer_prover
   742 
   (* Related spmfs have related supports, in the sense of rel_set. *)
   743 lemma set_spmf_parametric [transfer_rule]:
   744   "(rel_spmf A ===> rel_set A) set_spmf set_spmf"
   745 unfolding set_spmf_def[abs_def] by transfer_prover
   746 
   (* return_pmf None is related to itself for every relation A. *)
   747 lemma return_spmf_None_parametric:
   748   "(rel_spmf A) (return_pmf None) (return_pmf None)"
   749 by simp
   750 
   751 end
   752 
   (* rel_spmf_bindI: relational bind rule. Related spmfs bound into
      continuations that map R-related inputs to P-related spmfs give
      P-related results. Obtained by instantiating bind_spmf_parametric. *)
   753 lemma rel_spmf_bindI:
   754   "\<lbrakk> rel_spmf R p q; \<And>x y. R x y \<Longrightarrow> rel_spmf P (f x) (g y) \<rbrakk>
   755   \<Longrightarrow> rel_spmf P (p \<bind> f) (q \<bind> g)"
   756 by(fact bind_spmf_parametric[THEN rel_funD, THEN rel_funD, OF _ rel_funI])
   757 
   (* rel_spmf_bind_reflI: the same spmf p on both sides. The continuations
      only need to be related at points of the support of p. *)
   758 lemma rel_spmf_bind_reflI:
   759   "(\<And>x. x \<in> set_spmf p \<Longrightarrow> rel_spmf P (f x) (g x)) \<Longrightarrow> rel_spmf P (p \<bind> f) (p \<bind> g)"
   760 by(rule rel_spmf_bindI[where R="\<lambda>x y. x = y \<and> x \<in> set_spmf p"])(auto intro: rel_spmf_reflI)
   761 
   (* rel_pmf_return_pmfI: two return_pmf are related if their points are;
      the coupling is return_pmf (x, y). *)
   762 lemma rel_pmf_return_pmfI: "P x y \<Longrightarrow> rel_pmf P (return_pmf x) (return_pmf y)"
   763 by(rule rel_pmf.intros[where pq="return_pmf (x, y)"])(simp_all)
   764 
   (* Parametricity of probabilities of events. measure_pmf_parametric:
      related pmfs assign equal probability to rel_pred-related sets, proved
      from a coupling pq of p and q. measure_spmf_parametric: the spmf
      version, reduced to the pmf version via
      measure_measure_spmf_conv_measure_pmf. *)
   765 context includes lifting_syntax
   766 begin
   767 
   768 text \<open>We do not yet have a relator for @{typ "'a measure"}, so we combine @{const measure} and @{const measure_pmf}\<close>
   769 lemma measure_pmf_parametric:
   770   "(rel_pmf A ===> rel_pred A ===> op =) (\<lambda>p. measure (measure_pmf p)) (\<lambda>q. measure (measure_pmf q))"
   771 proof(rule rel_funI)+
   772   fix p q X Y
   773   assume "rel_pmf A p q" and "rel_pred A X Y"
   774   from this(1) obtain pq where A: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> A x y"
   775     and p: "p = map_pmf fst pq" and q: "q = map_pmf snd pq" by cases auto
   776   show "measure p X = measure q Y" unfolding p q measure_map_pmf
   777     by(rule measure_pmf.finite_measure_eq_AE)(auto simp add: AE_measure_pmf_iff dest!: A rel_predD[OF \<open>rel_pred _ _ _\<close>])
   778 qed
   779 
   780 lemma measure_spmf_parametric:
   781   "(rel_spmf A ===> rel_pred A ===> op =) (\<lambda>p. measure (measure_spmf p)) (\<lambda>q. measure (measure_spmf q))"
   782 unfolding measure_measure_spmf_conv_measure_pmf[abs_def]
   783 apply(rule rel_funI)+
   784 apply(erule measure_pmf_parametric[THEN rel_funD, THEN rel_funD])
   785 apply(auto simp add: rel_pred_def rel_fun_def elim: option.rel_cases)
   786 done
   787 
   788 end
   789 
   790 subsection \<open>From @{typ "'a pmf"} to @{typ "'a spmf"}\<close>
   791 
   (* spmf_of_pmf: embeds a pmf into spmfs by wrapping every outcome in Some,
      so the result never yields None (see pmf_spmf_of_pmf_None). *)
   792 definition spmf_of_pmf :: "'a pmf \<Rightarrow> 'a spmf"
   793 where "spmf_of_pmf = map_pmf Some"
   794 
   (* The support of the embedded spmf is the support of the pmf. *)
   795 lemma set_spmf_spmf_of_pmf [simp]: "set_spmf (spmf_of_pmf p) = set_pmf p"
   796 by(auto simp add: spmf_of_pmf_def set_spmf_def bind_image o_def)
   797 
   (* Point probabilities are unchanged by the embedding. *)
   798 lemma spmf_spmf_of_pmf [simp]: "spmf (spmf_of_pmf p) x = pmf p x"
   799 by(simp add: spmf_of_pmf_def)
   800 
   (* An embedded pmf never yields None. *)
   801 lemma pmf_spmf_of_pmf_None [simp]: "pmf (spmf_of_pmf p) None = 0"
   802 using ennreal_pmf_map[of Some p None] by(simp add: spmf_of_pmf_def)
   803 
   (* The measure of any set A is unchanged by the embedding. *)
   804 lemma emeasure_spmf_of_pmf [simp]: "emeasure (measure_spmf (spmf_of_pmf p)) A = emeasure (measure_pmf p) A"
   805 by(simp add: emeasure_measure_spmf_conv_measure_pmf spmf_of_pmf_def inj_vimage_image_eq)
   806 
   (* Hence the whole measure coincides with measure_pmf p, by measure_eqI. *)
   807 lemma measure_spmf_spmf_of_pmf [simp]: "measure_spmf (spmf_of_pmf p) = measure_pmf p"
   808 by(rule measure_eqI) simp_all
   809 
   (* map_spmf on an embedded pmf is the embedding of map_pmf. *)
   810 lemma map_spmf_of_pmf [simp]: "map_spmf f (spmf_of_pmf p) = spmf_of_pmf (map_pmf f p)"
   811 by(simp add: spmf_of_pmf_def pmf.map_comp o_def)
   812 
   (* rel_spmf on embedded pmfs reduces to rel_pmf. *)
   813 lemma rel_spmf_spmf_of_pmf [simp]: "rel_spmf R (spmf_of_pmf p) (spmf_of_pmf q) = rel_pmf R p q"
   814 by(simp add: spmf_of_pmf_def pmf.rel_map)
   815 
   (* Embedding return_pmf x gives return_spmf x. *)
   816 lemma spmf_of_pmf_return_pmf [simp]: "spmf_of_pmf (return_pmf x) = return_spmf x"
   817 by(simp add: spmf_of_pmf_def)
   818 
   (* bind_spmf on an embedded pmf is plain bind_pmf on the original pmf. *)
   819 lemma bind_spmf_of_pmf [simp]: "bind_spmf (spmf_of_pmf p) f = bind_pmf p f"
   820 by(simp add: spmf_of_pmf_def bind_spmf_def bind_map_pmf)
   821 
   (* Support of bind_pmf with an spmf-valued continuation; derived from
      set_bind_spmf through bind_spmf_of_pmf. *)
   822 lemma set_spmf_bind_pmf: "set_spmf (bind_pmf p f) = Set.bind (set_pmf p) (set_spmf \<circ> f)"
   823 unfolding bind_spmf_of_pmf[symmetric] by(subst set_bind_spmf) simp
   824 
   (* spmf_of_pmf distributes into the continuation of bind_pmf. *)
   825 lemma spmf_of_pmf_bind: "spmf_of_pmf (bind_pmf p f) = bind_pmf p (\<lambda>x. spmf_of_pmf (f x))"
   826 by(simp add: spmf_of_pmf_def map_bind_pmf)
   827 
   (* Binding into return_spmf o f equals embedding map_pmf f p. *)
   828 lemma bind_pmf_return_spmf: "p \<bind> (\<lambda>x. return_spmf (f x)) = spmf_of_pmf (map_pmf f p)"
   829 by(simp add: map_pmf_def spmf_of_pmf_bind)
   830 
   831 subsection \<open>Weight of a subprobability\<close>
   832 
   (* weight_spmf p: total mass of p, i.e. the measure of the whole space of
      measure_spmf p. It lies between 0 and 1; see weight_spmf_nonneg and
      weight_spmf_le_1. *)
   833 abbreviation weight_spmf :: "'a spmf \<Rightarrow> real"
   834 where "weight_spmf p \<equiv> measure (measure_spmf p) (space (measure_spmf p))"
   835 
   (* Equivalent form of weight_spmf using UNIV instead of the space, via
      space_measure_spmf. *)
   836 lemma weight_spmf_def: "weight_spmf p = measure (measure_spmf p) UNIV"
   837 by(simp add: space_measure_spmf)
   838 
   (* The weight of an spmf is at most 1. *)
   839 lemma weight_spmf_le_1: "weight_spmf p \<le> 1"
   840 by(simp add: measure_spmf.subprob_measure_le_1)
   841 
   (* return_spmf x has full weight 1. *)
   842 lemma weight_return_spmf [simp]: "weight_spmf (return_spmf x) = 1"
   843 by(simp add: measure_spmf_return_spmf measure_return)
   844 
   (* return_pmf None has weight 0. *)
   845 lemma weight_return_pmf_None [simp]: "weight_spmf (return_pmf None) = 0"
   846 by(simp)
   847 
   (* Mapping does not change the weight. *)
   848 lemma weight_map_spmf [simp]: "weight_spmf (map_spmf f p) = weight_spmf p"
   849 by(simp add: weight_spmf_def measure_map_spmf)
   850 
   (* An embedded pmf has weight 1, because measure_pmf p is a probability
      space. *)
   851 lemma weight_spmf_of_pmf [simp]: "weight_spmf (spmf_of_pmf p) = 1"
   852 using measure_pmf.prob_space[of p] by(simp add: spmf_of_pmf_def weight_spmf_def)
   853 
   (* The weight is nonnegative. *)
   854 lemma weight_spmf_nonneg: "weight_spmf p \<ge> 0"
   855 by(fact measure_nonneg)
   856 
   (* In any finite measure M, a measurable function of weights is
      integrable, since it is bounded by 1 (integrable_const_bound, B = 1). *)
   857 lemma (in finite_measure) integrable_weight_spmf [simp]:
   858   "(\<lambda>x. weight_spmf (f x)) \<in> borel_measurable M \<Longrightarrow> integrable M (\<lambda>x. weight_spmf (f x))"
   859 by(rule integrable_const_bound[where B=1])(simp_all add: weight_spmf_nonneg weight_spmf_le_1)
   860 
   (* The weight, coerced to ennreal, is the sum of all point probabilities,
      written as a nonnegative integral over count_space UNIV. *)
   861 lemma weight_spmf_eq_nn_integral_spmf: "weight_spmf p = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
   862 by(simp add: measure_measure_spmf_conv_measure_pmf space_measure_spmf measure_pmf.emeasure_eq_measure[symmetric] nn_integral_pmf[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1)
   863 
   (* Same as weight_spmf_eq_nn_integral_spmf, but summing only over the
      support of p. *)
   864 lemma weight_spmf_eq_nn_integral_support:
   865   "weight_spmf p = \<integral>\<^sup>+ x. spmf p x \<partial>count_space (set_spmf p)"
   866 unfolding weight_spmf_eq_nn_integral_spmf
   867 by(auto simp add: nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator)
   868 
   (* pmf_None_eq_weight_spmf: the probability of None is the mass missing
      from the weight, pmf p None = 1 - weight_spmf p.
      Proof: write the weight as the sum of pmf p over range Some, add
      pmf p None as the sum over {None}, recombine into the sum of pmf p over
      all options, which is 1 (nn_integral_pmf). *)
   869 lemma pmf_None_eq_weight_spmf: "pmf p None = 1 - weight_spmf p"
   870 proof -
   871   have "weight_spmf p = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV" by(rule weight_spmf_eq_nn_integral_spmf)
   872   also have "\<dots> = \<integral>\<^sup>+ x. ennreal (pmf p x) * indicator (range Some) x \<partial>count_space UNIV"
   873     by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1)
   874   also have "\<dots> + pmf p None = \<integral>\<^sup>+ x. ennreal (pmf p x) * indicator (range Some) x + ennreal (pmf p None) * indicator {None} x \<partial>count_space UNIV"
   875     by(subst nn_integral_add)(simp_all add: max_def)
   876   also have "\<dots> = \<integral>\<^sup>+ x. pmf p x \<partial>count_space UNIV"
   877     by(rule nn_integral_cong)(auto split: split_indicator)
   878   also have "\<dots> = 1" by (simp add: nn_integral_pmf)
   879   finally show ?thesis by(simp add: ennreal_plus[symmetric] del: ennreal_plus)
   880 qed
   881 
   (* Rearranged form of pmf_None_eq_weight_spmf. *)
   882 lemma weight_spmf_conv_pmf_None: "weight_spmf p = 1 - pmf p None"
   883 by(simp add: pmf_None_eq_weight_spmf)
   884 
   (* Since weights are nonnegative, weight <= 0 means weight = 0. *)
   885 lemma weight_spmf_le_0: "weight_spmf p \<le> 0 \<longleftrightarrow> weight_spmf p = 0"
   886 by(rule measure_le_0_iff)
   887 
   (* The weight is never negative. *)
   888 lemma weight_spmf_lt_0: "\<not> weight_spmf p < 0"
   889 by(simp add: not_less weight_spmf_nonneg)
   890 
   (* spmf_le_weight: a single point probability is bounded by the total
      weight. Proof: one summand is at most the whole nonnegative sum
      (nn_integral_ge_point). *)
   891 lemma spmf_le_weight: "spmf p x \<le> weight_spmf p"
   892 proof -
   893   have "ennreal (spmf p x) \<le> weight_spmf p"
   894     unfolding weight_spmf_eq_nn_integral_spmf by(rule nn_integral_ge_point) simp
   895   then show ?thesis by simp
   896 qed
   897 
   (* An spmf has weight 0 exactly when it is return_pmf None. *)
   898 lemma weight_spmf_eq_0: "weight_spmf p = 0 \<longleftrightarrow> p = return_pmf None"
   899 by(auto intro!: pmf_eqI simp add: pmf_None_eq_weight_spmf split: split_indicator)(metis not_Some_eq pmf_le_0_iff spmf_le_weight)
   900 
   (* The weight of a bind is the integral of the continuation's weight over
      measure_spmf x; derived from measure_spmf_bind. *)
   901 lemma weight_bind_spmf: "weight_spmf (x \<bind> f) = lebesgue_integral (measure_spmf x) (weight_spmf \<circ> f)"
   902 unfolding weight_spmf_def
   903 by(simp add: measure_spmf_bind o_def measure_spmf.measure_bind[where N="count_space UNIV"])
   904 
   (* Related spmfs have equal weight: both are marginals of the same
      coupling obtained from rel_spmfE. *)
   905 lemma rel_spmf_weightD: "rel_spmf A p q \<Longrightarrow> weight_spmf p = weight_spmf q"
   906 by(erule rel_spmfE) simp
   907 
   (* rel_spmf_bij_betw: if f is a bijection between the supports of p and q
      that preserves point probabilities, then p and q are related by the
      graph of f.
      Proof:
      - weq: the weights agree, by reindexing the support sum along f with
        nn_integral_bij_count_space.
      - Hence None is in the support of p iff it is in the support of q.
      - So map_option f is a bijection between the pmf supports.
      - rel_pmf_bij_betw yields a coupling, which rel_mono_strong turns into
        the required rel_option relation. *)
   908 lemma rel_spmf_bij_betw:
   909   assumes f: "bij_betw f (set_spmf p) (set_spmf q)"
   910   and eq: "\<And>x. x \<in> set_spmf p \<Longrightarrow> spmf p x = spmf q (f x)"
   911   shows "rel_spmf (\<lambda>x y. f x = y) p q"
   912 proof -
   913   let ?f = "map_option f"
   914 
   915   have weq: "ennreal (weight_spmf p) = ennreal (weight_spmf q)"
   916     unfolding weight_spmf_eq_nn_integral_support
   917     by(subst nn_integral_bij_count_space[OF f, symmetric])(rule nn_integral_cong_AE, simp add: eq AE_count_space)
   918   then have "None \<in> set_pmf p \<longleftrightarrow> None \<in> set_pmf q"
   919     by(simp add: pmf_None_eq_weight_spmf set_pmf_iff)
   920   with f have "bij_betw (map_option f) (set_pmf p) (set_pmf q)"
   921     apply(auto simp add: bij_betw_def in_set_spmf inj_on_def intro: option.expand)
   922     apply(rename_tac [!] x)
   923     apply(case_tac [!] x)
   924     apply(auto iff: in_set_spmf)
   925     done
   926   then have "rel_pmf (\<lambda>x y. ?f x = y) p q"
   927     by(rule rel_pmf_bij_betw)(case_tac x, simp_all add: weq[simplified] eq in_set_spmf pmf_None_eq_weight_spmf)
   928   thus ?thesis by(rule pmf.rel_mono_strong)(auto intro!: rel_optionI simp add: Option.is_none_def)
   929 qed
   930 
   931 subsection \<open>From density to spmfs\<close>
   932 
   (* Construction of an spmf from a real-valued density f.
      embed_spmf gives Some x' the probability max 0 (f x'), i.e. negative
      values are clipped, and gives None the probability
      1 - enn2real (total ennreal mass of f).
      The lemmas in the inner context are proved under the assumption prob,
      which says that this total mass is at most 1. *)
   933 context fixes f :: "'a \<Rightarrow> real" begin
   934 
   935 definition embed_spmf :: "'a spmf"
   936 where "embed_spmf = embed_pmf (\<lambda>x. case x of None \<Rightarrow> 1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV) | Some x' \<Rightarrow> max 0 (f x'))"
   937 
   938 context
   939   assumes prob: "(\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV) \<le> 1"
   940 begin
   941 
   (* The density passed to embed_pmf sums to 1. This is the side condition
      discharged when pmf_embed_pmf is applied below. The proof splits the
      sum into the None part and the Some part and moves the Some part to
      count_space UNIV with embed_measure. *)
   942 lemma nn_integral_embed_spmf_eq_1:
   943   "(\<integral>\<^sup>+ x. ennreal (case x of None \<Rightarrow> 1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV) | Some x' \<Rightarrow> max 0 (f x')) \<partial>count_space UNIV) = 1"
   944   (is "?lhs = _" is "(\<integral>\<^sup>+ x. ?f x \<partial>?M) = _")
   945 proof -
   946   have "?lhs = \<integral>\<^sup>+ x. ?f x * indicator {None} x + ?f x * indicator (range Some) x \<partial>?M"
   947     by(rule nn_integral_cong)(auto split: split_indicator)
   948   also have "\<dots> = (1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV)) + \<integral>\<^sup>+ x. ?f x * indicator (range Some) x \<partial>?M"
   949     (is "_ = ?None + ?Some")
   950     by(subst nn_integral_add)(simp_all add: AE_count_space max_def le_diff_eq real_le_ereal_iff one_ereal_def[symmetric] prob split: option.split)
   951   also have "?Some = \<integral>\<^sup>+ x. ?f x \<partial>count_space (range Some)"
   952     by(simp add: nn_integral_count_space_indicator)
   953   also have "count_space (range Some) = embed_measure (count_space UNIV) Some"
   954     by(simp add: embed_measure_count_space)
   955   also have "(\<integral>\<^sup>+ x. ?f x \<partial>\<dots>) = \<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV"
   956     by(subst nn_integral_embed_measure)(simp_all add: measurable_embed_measure1)
   957   also have "?None + \<dots> = 1" using prob
   958     by(auto simp add: ennreal_minus[symmetric] ennreal_1[symmetric] ennreal_enn2real_if top_unique simp del: ennreal_1)(simp add: diff_add_self_ennreal)
   959   finally show ?thesis .
   960 qed
   961 
   (* Probability of None under embed_spmf. *)
   962 lemma pmf_embed_spmf_None: "pmf embed_spmf None = 1 - enn2real (\<integral>\<^sup>+ x. ennreal (f x) \<partial>count_space UNIV)"
   963 unfolding embed_spmf_def
   964 apply(subst pmf_embed_pmf)
   965   subgoal using prob by(simp add: field_simps enn2real_leI split: option.split)
   966  apply(rule nn_integral_embed_spmf_eq_1)
   967 apply simp
   968 done
   969 
   (* Point probabilities of embed_spmf are the clipped density values. *)
   970 lemma spmf_embed_spmf [simp]: "spmf embed_spmf x = max 0 (f x)"
   971 unfolding embed_spmf_def
   972 apply(subst pmf_embed_pmf)
   973   subgoal using prob by(simp add: field_simps enn2real_leI split: option.split)
   974  apply(rule nn_integral_embed_spmf_eq_1)
   975 apply simp
   976 done
   977 
   978 end
   979 
   980 end
   981 
   (* The zero density gives return_pmf None.
      NOTE(review): the pattern bindings ?lhs and ?rhs are not used by the
      one-line proof. *)
   982 lemma embed_spmf_K_0[simp]: "embed_spmf (\<lambda>_. 0) = return_pmf None" (is "?lhs = ?rhs")
   983 by(rule spmf_eqI)(simp add: zero_ereal_def[symmetric])
   984 
   985 subsection \<open>Ordering on spmfs\<close>
   986 
   987 text \<open>
   988   @{const rel_pmf} does not preserve a ccpo structure. Counterexample by Saheb-Djahromi:
   989   Take prefix order over \<open>bool llist\<close> and
   990   the set \<open>range (\<lambda>n :: nat. uniform (llist_n n))\<close> where \<open>llist_n\<close> is the set
   991   of all \<open>llist\<close>s of length \<open>n\<close> and \<open>uniform\<close> returns a uniform distribution over
   992   the given set. The set forms a chain in \<open>ord_pmf lprefix\<close>, but it has no upper bound.
   993   Any upper bound may contain only infinite lists in its support because otherwise it is not greater
   994   than the \<open>n+1\<close>-st element in the chain where \<open>n\<close> is the length of the finite list.
   995   Moreover its support must contain all infinite lists, because otherwise there is a finite list
   996   all of whose finite extensions are not in the support - a contradiction to the upper bound property.
   997   Hence, the support is uncountable, but pmfs only have countable support.
   998 
   999   However, if all chains in the ccpo are finite, then it should preserve the ccpo structure.
  1000 \<close>
  1001 
  (* ord_spmf ord: ordering on spmfs, obtained by lifting ord to options with
     ord_option and then to pmfs with rel_pmf. *)
  1002 abbreviation ord_spmf :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> 'a spmf \<Rightarrow> 'a spmf \<Rightarrow> bool"
  1003 where "ord_spmf ord \<equiv> rel_pmf (ord_option ord)"
  1004 
  (* Locale that only provides the infix notation for ord_spmf. It is
     activated locally with interpretation ord_spmf_syntax. *)
  1005 locale ord_spmf_syntax begin
  1006 notation ord_spmf (infix "\<sqsubseteq>\<index>" 60)
  1007 end
  1008 
  (* ord_spmf absorbs map_spmf on the left (1), on the right (2), or on both
     sides (12) by precomposing the relation. The three rules are collected
     as ord_spmf_map_spmf. *)
  1009 lemma ord_spmf_map_spmf1: "ord_spmf R (map_spmf f p) = ord_spmf (\<lambda>x. R (f x)) p"
  1010 by(simp add: pmf.rel_map[abs_def] ord_option_map1[abs_def])
  1011 
  1012 lemma ord_spmf_map_spmf2: "ord_spmf R p (map_spmf f q) = ord_spmf (\<lambda>x y. R x (f y)) p q"
  1013 by(simp add: pmf.rel_map ord_option_map2)
  1014 
  1015 lemma ord_spmf_map_spmf12: "ord_spmf R (map_spmf f p) (map_spmf f q) = ord_spmf (\<lambda>x y. R (f x) (f y)) p q"
  1016 by(simp add: pmf.rel_map ord_option_map1[abs_def] ord_option_map2)
  1017 
  1018 lemmas ord_spmf_map_spmf = ord_spmf_map_spmf1 ord_spmf_map_spmf2 ord_spmf_map_spmf12
  1019 
  (* Basic facts about ord_spmf for a fixed ord. The infix notation comes
     from the ord_spmf_syntax interpretation. *)
  1020 context fixes ord :: "'a \<Rightarrow> 'a \<Rightarrow> bool" (structure) begin
  1021 interpretation ord_spmf_syntax .
  1022 
  (* Introduction via an spmf coupling whose support lies in ord; same
     witness construction as in rel_spmfI. *)
  1023 lemma ord_spmfI:
  1024   "\<lbrakk> \<And>x y. (x, y) \<in> set_spmf pq \<Longrightarrow> ord x y; map_spmf fst pq = p; map_spmf snd pq = q \<rbrakk>
  1025   \<Longrightarrow> p \<sqsubseteq> q"
  1026 by(rule rel_pmf.intros[where pq="map_pmf (\<lambda>x. case x of None \<Rightarrow> (None, None) | Some (a, b) \<Rightarrow> (Some a, Some b)) pq"])
  1027   (auto simp add: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong)
  1028 
  (* return_pmf None is below every spmf x. *)
  1029 lemma ord_spmf_None [simp]: "return_pmf None \<sqsubseteq> x"
  1030 by(rule rel_pmf.intros[where pq="map_pmf (Pair None) x"])(auto simp add: pmf.map_comp o_def)
  1031 
  (* Reflexivity only needs ord to be reflexive on the support of p. *)
  1032 lemma ord_spmf_reflI: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> ord x x) \<Longrightarrow> p \<sqsubseteq> p"
  1033 by(rule rel_pmf_reflI ord_option_reflI)+(auto simp add: in_set_spmf)
  1034 
  (* If p and q are ordered both ways and ord is reflexive and transitive,
     then they are related by the symmetric part inf ord ord^-1^-1. *)
  1035 lemma rel_spmf_inf:
  1036   assumes "p \<sqsubseteq> q"
  1037   and "q \<sqsubseteq> p"
  1038   and refl: "reflp ord"
  1039   and trans: "transp ord"
  1040   shows "rel_spmf (inf ord ord\<inverse>\<inverse>) p q"
  1041 proof -
  1042   from \<open>p \<sqsubseteq> q\<close> \<open>q \<sqsubseteq> p\<close>
  1043   have "rel_pmf (inf (ord_option ord) (ord_option ord)\<inverse>\<inverse>) p q"
  1044     by(rule rel_pmf_inf)(blast intro: reflp_ord_option transp_ord_option refl trans)+
  1045   also have "inf (ord_option ord) (ord_option ord)\<inverse>\<inverse> = rel_option (inf ord ord\<inverse>\<inverse>)"
  1046     by(auto simp add: fun_eq_iff elim: ord_option.cases option.rel_cases)
  1047   finally show ?thesis .
  1048 qed
  1049 
  1050 end
  1051 
  (* p is below return_spmf y iff every element of the support of p is
     R-below y. *)
  1052 lemma ord_spmf_return_spmf2: "ord_spmf R p (return_spmf y) \<longleftrightarrow> (\<forall>x\<in>set_spmf p. R x y)"
  1053 by(auto simp add: rel_pmf_return_pmf2 in_set_spmf ord_option.simps intro: ccontr)
  1054 
  (* ord_spmf is monotone in the underlying relation. *)
  1055 lemma ord_spmf_mono: "\<lbrakk> ord_spmf A p q; \<And>x y. A x y \<Longrightarrow> B x y \<rbrakk> \<Longrightarrow> ord_spmf B p q"
  1056 by(erule rel_pmf_mono)(erule ord_option_mono)
  1057 
  (* ord_spmf distributes over relation composition OO. *)
  1058 lemma ord_spmf_compp: "ord_spmf (A OO B) = ord_spmf A OO ord_spmf B"
  1059 by(simp add: ord_option_compp pmf.rel_compp)
  1060 
  (* Monotonicity of bind. Ordered arguments, with continuations mapping
     R-related inputs to P-ordered results, give P-ordered binds. Proved
     with rel_pmf_bindI after unfolding bind_spmf_def. *)
  1061 lemma ord_spmf_bindI:
  1062   assumes pq: "ord_spmf R p q"
  1063   and fg: "\<And>x y. R x y \<Longrightarrow> ord_spmf P (f x) (g y)"
  1064   shows "ord_spmf P (p \<bind> f) (q \<bind> g)"
  1065 unfolding bind_spmf_def using pq
  1066 by(rule rel_pmf_bindI)(auto split: option.split intro: fg)
  1067 
  (* The same spmf p on both sides. The continuations only need to be
     ordered at points of the support of p. *)
  1068 lemma ord_spmf_bind_reflI:
  1069   "(\<And>x. x \<in> set_spmf p \<Longrightarrow> ord_spmf R (f x) (g x))
  1070   \<Longrightarrow> ord_spmf R (p \<bind> f) (p \<bind> g)"
  1071 by(rule ord_spmf_bindI[where R="\<lambda>x y. x = y \<and> x \<in> set_spmf p"])(auto intro: ord_spmf_reflI)
  1072 
  1073 lemma ord_pmf_increaseI:
  1074   assumes le: "\<And>x. spmf p x \<le> spmf q x"
  1075   and refl: "\<And>x. x \<in> set_spmf p \<Longrightarrow> R x x"
  1076   shows "ord_spmf R p q"
  1077 proof(rule rel_pmf.intros)
  1078   define pq where "pq = embed_pmf
  1079     (\<lambda>(x, y). case x of Some x' \<Rightarrow> (case y of Some y' \<Rightarrow> if x' = y' then spmf p x' else 0 | None \<Rightarrow> 0)
  1080       | None \<Rightarrow> (case y of None \<Rightarrow> pmf q None | Some y' \<Rightarrow> spmf q y' - spmf p y'))"
  1081      (is "_ = embed_pmf ?f")
  1082   have nonneg: "\<And>xy. ?f xy \<ge> 0"
  1083     by(clarsimp simp add: le field_simps split: option.split)
  1084   have integral: "(\<integral>\<^sup>+ xy. ?f xy \<partial>count_space UNIV) = 1" (is "nn_integral ?M _ = _")
  1085   proof -
  1086     have "(\<integral>\<^sup>+ xy. ?f xy \<partial>count_space UNIV) =
  1087       \<integral>\<^sup>+ xy. ennreal (?f xy) * indicator {(None, None)} xy +
  1088              ennreal (?f xy) * indicator (range (\<lambda>x. (None, Some x))) xy +
  1089              ennreal (?f xy) * indicator (range (\<lambda>x. (Some x, Some x))) xy \<partial>?M"
  1090       by(rule nn_integral_cong)(auto split: split_indicator option.splits if_split_asm)
  1091     also have "\<dots> = (\<integral>\<^sup>+ xy. ?f xy * indicator {(None, None)} xy \<partial>?M) +
  1092         (\<integral>\<^sup>+ xy. ennreal (?f xy) * indicator (range (\<lambda>x. (None, Some x))) xy \<partial>?M) +
  1093         (\<integral>\<^sup>+ xy. ennreal (?f xy) * indicator (range (\<lambda>x. (Some x, Some x))) xy \<partial>?M)"
  1094       (is "_ = ?None + ?Some2 + ?Some")
  1095       by(subst nn_integral_add)(simp_all add: nn_integral_add AE_count_space le_diff_eq le split: option.split)
  1096     also have "?None = pmf q None" by simp
  1097     also have "?Some2 = \<integral>\<^sup>+ x. ennreal (spmf q x) - spmf p x \<partial>count_space UNIV"
  1098       by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1 ennreal_minus)
  1099     also have "\<dots> = (\<integral>\<^sup>+ x. spmf q x \<partial>count_space UNIV) - (\<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV)"
  1100       (is "_ = ?Some2' - ?Some2''")
  1101       by(subst nn_integral_diff)(simp_all add: le nn_integral_spmf_neq_top)
  1102     also have "?Some = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
  1103       by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1)
  1104     also have "pmf q None + (?Some2' - ?Some2'') + \<dots> = pmf q None + ?Some2'"
  1105       by(auto simp add: diff_add_self_ennreal le intro!: nn_integral_mono)
  1106     also have "\<dots> = \<integral>\<^sup>+ x. ennreal (pmf q x) * indicator {None} x + ennreal (pmf q x) * indicator (range Some) x \<partial>count_space UNIV"
  1107       by(subst nn_integral_add)(simp_all add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1)
  1108     also have "\<dots> = \<integral>\<^sup>+ x. pmf q x \<partial>count_space UNIV"
  1109       by(rule nn_integral_cong)(auto split: split_indicator)
  1110     also have "\<dots> = 1" by(simp add: nn_integral_pmf)
  1111     finally show ?thesis .
  1112   qed
  1113   note f = nonneg integral
  1114 
  1115   { fix x y
  1116     assume "(x, y) \<in> set_pmf pq"
  1117     hence "?f (x, y) \<noteq> 0" unfolding pq_def by(simp add: set_embed_pmf[OF f])
  1118     then show "ord_option R x y"
  1119       by(simp add: spmf_eq_0_set_spmf refl split: option.split_asm if_split_asm) }
  1120 
  1121   have weight_le: "weight_spmf p \<le> weight_spmf q"
  1122     by(subst ennreal_le_iff[symmetric])(auto simp add: weight_spmf_eq_nn_integral_spmf intro!: nn_integral_mono le)
  1123 
  1124   show "map_pmf fst pq = p"
  1125   proof(rule pmf_eqI)
  1126     fix i
  1127     have "ennreal (pmf (map_pmf fst pq) i) = (\<integral>\<^sup>+ y. pmf pq (i, y) \<partial>count_space UNIV)"
  1128       unfolding pq_def ennreal_pmf_map
  1129       apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric])
  1130       apply(subst pmf_embed_pmf[OF f])
  1131       apply(rule nn_integral_bij_count_space[symmetric])
  1132       apply(auto simp add: bij_betw_def inj_on_def)
  1133       done
  1134     also have "\<dots> = pmf p i"
  1135     proof(cases i)
  1136       case (Some x)
  1137       have "(\<integral>\<^sup>+ y. pmf pq (Some x, y) \<partial>count_space UNIV) = \<integral>\<^sup>+ y. pmf p (Some x) * indicator {Some x} y \<partial>count_space UNIV"
  1138         by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split)
  1139       then show ?thesis using Some by simp
  1140     next
  1141       case None
  1142       have "(\<integral>\<^sup>+ y. pmf pq (None, y) \<partial>count_space UNIV) =
  1143             (\<integral>\<^sup>+ y. ennreal (pmf pq (None, Some (the y))) * indicator (range Some) y +
  1144                    ennreal (pmf pq (None, None)) * indicator {None} y \<partial>count_space UNIV)"
  1145         by(rule nn_integral_cong)(auto split: split_indicator)
  1146       also have "\<dots> = (\<integral>\<^sup>+ y. ennreal (pmf pq (None, Some (the y))) \<partial>count_space (range Some)) + pmf pq (None, None)"
  1147         by(subst nn_integral_add)(simp_all add: nn_integral_count_space_indicator)
  1148       also have "\<dots> = (\<integral>\<^sup>+ y. ennreal (spmf q y) - ennreal (spmf p y) \<partial>count_space UNIV) + pmf q None"
  1149         by(simp add: pq_def pmf_embed_pmf[OF f] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1 ennreal_minus)
  1150       also have "(\<integral>\<^sup>+ y. ennreal (spmf q y) - ennreal (spmf p y) \<partial>count_space UNIV) =
  1151                  (\<integral>\<^sup>+ y. spmf q y \<partial>count_space UNIV) - (\<integral>\<^sup>+ y. spmf p y \<partial>count_space UNIV)"
  1152         by(subst nn_integral_diff)(simp_all add: AE_count_space le nn_integral_spmf_neq_top split: split_indicator)
  1153       also have "\<dots> = pmf p None - pmf q None"
  1154         by(simp add: pmf_None_eq_weight_spmf weight_spmf_eq_nn_integral_spmf[symmetric] ennreal_minus)
  1155       also have "\<dots> = ennreal (pmf p None) - ennreal (pmf q None)" by(simp add: ennreal_minus)
  1156       finally show ?thesis using None weight_le
  1157         by(auto simp add: diff_add_self_ennreal pmf_None_eq_weight_spmf intro: ennreal_leI)
  1158     qed
  1159     finally show "pmf (map_pmf fst pq) i = pmf p i" by simp
  1160   qed
  1161 
  1162   show "map_pmf snd pq = q"
  1163   proof(rule pmf_eqI)
  1164     fix i
  1165     have "ennreal (pmf (map_pmf snd pq) i) = (\<integral>\<^sup>+ x. pmf pq (x, i) \<partial>count_space UNIV)"
  1166       unfolding pq_def ennreal_pmf_map
  1167       apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric])
  1168       apply(subst pmf_embed_pmf[OF f])
  1169       apply(rule nn_integral_bij_count_space[symmetric])
  1170       apply(auto simp add: bij_betw_def inj_on_def)
  1171       done
  1172     also have "\<dots> = ennreal (pmf q i)"
  1173     proof(cases i)
  1174       case None
  1175       have "(\<integral>\<^sup>+ x. pmf pq (x, None) \<partial>count_space UNIV) = \<integral>\<^sup>+ x. pmf q None * indicator {None :: 'a option} x \<partial>count_space UNIV"
  1176         by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split)
  1177       then show ?thesis using None by simp
  1178     next
  1179       case (Some y)
  1180       have "(\<integral>\<^sup>+ x. pmf pq (x, Some y) \<partial>count_space UNIV) =
  1181         (\<integral>\<^sup>+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x +
  1182                ennreal (pmf pq (None, Some y)) * indicator {None} x \<partial>count_space UNIV)"
  1183         by(rule nn_integral_cong)(auto split: split_indicator)
  1184       also have "\<dots> = (\<integral>\<^sup>+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x \<partial>count_space UNIV) + pmf pq (None, Some y)"
  1185         by(subst nn_integral_add)(simp_all)
  1186       also have "\<dots> = (\<integral>\<^sup>+ x. ennreal (spmf p y) * indicator {Some y} x \<partial>count_space UNIV) + (spmf q y - spmf p y)"
  1187         by(auto simp add: pq_def pmf_embed_pmf[OF f] one_ereal_def[symmetric] simp del: nn_integral_indicator_singleton intro!: arg_cong2[where f="op +"] nn_integral_cong split: option.split)
  1188       also have "\<dots> = spmf q y" by(simp add: ennreal_minus[symmetric] le)
  1189       finally show ?thesis using Some by simp
  1190     qed
  1191     finally show "pmf (map_pmf snd pq) i = pmf q i" by simp
  1192   qed
  1193 qed
  1194 
(* In the flat order ord_spmf (op =), the smaller spmf assigns at most as much
   probability to every element: spmf p x <= spmf q x. *)
lemma ord_spmf_eq_leD:
  assumes "ord_spmf op = p q"
  shows "spmf p x \<le> spmf q x"
proof(cases "x \<in> set_spmf p")
  case False
  (* x is outside the support of p, so spmf p x = 0. *)
  thus ?thesis by(simp add: in_set_spmf_iff_spmf)
next
  case True
  (* Unfold the relator: obtain a coupling pq whose pairs are related by
     ord_option op = and whose marginals are p and q. *)
  from assms obtain pq
    where pq: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> ord_option op = x y"
    and p: "p = map_pmf fst pq"
    and q: "q = map_pmf snd pq" by cases auto
  (* spmf p x is the pq-mass of the pairs whose first component is Some x. *)
  have "ennreal (spmf p x) = integral\<^sup>N pq (indicator (fst -` {Some x}))"
    using p by(simp add: ennreal_pmf_map)
  (* On the support of pq, a pair with first component Some x must be (Some x, Some x). *)
  also have "\<dots> = integral\<^sup>N pq (indicator {(Some x, Some x)})"
    by(rule nn_integral_cong_AE)(auto simp add: AE_measure_pmf_iff split: split_indicator dest: pq)
  (* That singleton event is contained in the event "second component is Some x". *)
  also have "\<dots> \<le> integral\<^sup>N pq (indicator (snd -` {Some x}))"
    by(rule nn_integral_mono) simp
  also have "\<dots> = ennreal (spmf q x)" using q by(simp add: ennreal_pmf_map)
  finally show ?thesis by simp
qed
  1208     using p by(simp add: ennreal_pmf_map)
  1209   also have "\<dots> = integral\<^sup>N pq (indicator {(Some x, Some x)})"
  1210     by(rule nn_integral_cong_AE)(auto simp add: AE_measure_pmf_iff split: split_indicator dest: pq)
  1211   also have "\<dots> \<le> integral\<^sup>N pq (indicator (snd -` {Some x}))"
  1212     by(rule nn_integral_mono) simp
  1213   also have "\<dots> = ennreal (spmf q x)" using q by(simp add: ennreal_pmf_map)
  1214   finally show ?thesis by simp
  1215 qed
  1216 
  1217 lemma ord_spmf_eqD_set_spmf: "ord_spmf op = p q \<Longrightarrow> set_spmf p \<subseteq> set_spmf q"
  1218 by(rule subsetI)(drule_tac x=x in ord_spmf_eq_leD, auto simp add: in_set_spmf_iff_spmf)
  1219 
  1220 lemma ord_spmf_eqD_emeasure:
  1221   "ord_spmf op = p q \<Longrightarrow> emeasure (measure_spmf p) A \<le> emeasure (measure_spmf q) A"
  1222 by(auto intro!: nn_integral_mono split: split_indicator dest: ord_spmf_eq_leD simp add: nn_integral_measure_spmf nn_integral_indicator[symmetric])
  1223 
  1224 lemma ord_spmf_eqD_measure_spmf: "ord_spmf op = p q \<Longrightarrow> measure_spmf p \<le> measure_spmf q"
  1225   by (subst le_measure) (auto simp: ord_spmf_eqD_emeasure)
  1226 
  1227 subsection \<open>CCPO structure for the flat ccpo @{term "ord_option op ="}\<close>
  1228 
  1229 context fixes Y :: "'a spmf set" begin
  1230 
  1231 definition lub_spmf :: "'a spmf"
  1232 where "lub_spmf = embed_spmf (\<lambda>x. enn2real (SUP p : Y. ennreal (spmf p x)))"
  1233   \<comment> \<open>We go through @{typ ennreal} to have a sensible definition even if @{term Y} is empty.\<close>
  1234 
  1235 lemma lub_spmf_empty [simp]: "SPMF.lub_spmf {} = return_pmf None"
  1236 by(simp add: SPMF.lub_spmf_def bot_ereal_def)
  1237 
  1238 context assumes chain: "Complete_Partial_Order.chain (ord_spmf op =) Y" begin
  1239 
  1240 lemma chain_ord_spmf_eqD: "Complete_Partial_Order.chain (op \<le>) ((\<lambda>p x. ennreal (spmf p x)) ` Y)"
  1241   (is "Complete_Partial_Order.chain _ (?f ` _)")
  1242 proof(rule chainI)
  1243   fix f g
  1244   assume "f \<in> ?f ` Y" "g \<in> ?f ` Y"
  1245   then obtain p q where f: "f = ?f p" "p \<in> Y" and g: "g = ?f q" "q \<in> Y" by blast
  1246   from chain \<open>p \<in> Y\<close> \<open>q \<in> Y\<close> have "ord_spmf op = p q \<or> ord_spmf op = q p" by(rule chainD)
  1247   thus "f \<le> g \<or> g \<le> f"
  1248   proof
  1249     assume "ord_spmf op = p q"
  1250     hence "\<And>x. spmf p x \<le> spmf q x" by(rule ord_spmf_eq_leD)
  1251     hence "f \<le> g" unfolding f g by(auto intro: le_funI)
  1252     thus ?thesis ..
  1253   next
  1254     assume "ord_spmf op = q p"
  1255     hence "\<And>x. spmf q x \<le> spmf p x" by(rule ord_spmf_eq_leD)
  1256     hence "g \<le> f" unfolding f g by(auto intro: le_funI)
  1257     thus ?thesis ..
  1258   qed
  1259 qed
  1260 
  1261 lemma ord_spmf_eq_pmf_None_eq:
  1262   assumes le: "ord_spmf op = p q"
  1263   and None: "pmf p None = pmf q None"
  1264   shows "p = q"
  1265 proof(rule spmf_eqI)
  1266   fix i
  1267   from le have le': "\<And>x. spmf p x \<le> spmf q x" by(rule ord_spmf_eq_leD)
  1268   have "(\<integral>\<^sup>+ x. ennreal (spmf q x) - spmf p x \<partial>count_space UNIV) =
  1269      (\<integral>\<^sup>+ x. spmf q x \<partial>count_space UNIV) - (\<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV)"
  1270     by(subst nn_integral_diff)(simp_all add: AE_count_space le' nn_integral_spmf_neq_top)
  1271   also have "\<dots> = (1 - pmf q None) - (1 - pmf p None)" unfolding pmf_None_eq_weight_spmf
  1272     by(simp add: weight_spmf_eq_nn_integral_spmf[symmetric] ennreal_minus)
  1273   also have "\<dots> = 0" using None by simp
  1274   finally have "\<And>x. spmf q x \<le> spmf p x"
  1275     by(simp add: nn_integral_0_iff_AE AE_count_space ennreal_minus ennreal_eq_0_iff)
  1276   with le' show "spmf p i = spmf q i" by(rule antisym)
  1277 qed
  1278 
  1279 lemma ord_spmf_eqD_pmf_None:
  1280   assumes "ord_spmf op = x y"
  1281   shows "pmf x None \<ge> pmf y None"
  1282 using assms
  1283 apply cases
  1284 apply(clarsimp simp only: ennreal_le_iff[symmetric, OF pmf_nonneg] ennreal_pmf_map)
  1285 apply(fastforce simp add: AE_measure_pmf_iff intro!: nn_integral_mono_AE)
  1286 done
  1287 
  1288 text \<open>
  1289   Chains on @{typ "'a spmf"} maintain countable support.
  1290   Thanks to Johannes Hölzl for the proof idea.
  1291 \<close>
  1292 lemma spmf_chain_countable: "countable (\<Union>p\<in>Y. set_spmf p)"
  1293 proof(cases "Y = {}")
  1294   case Y: False
  1295   show ?thesis
  1296   proof(cases "\<exists>x\<in>Y. \<forall>y\<in>Y. ord_spmf op = y x")
  1297     case True
  1298     then obtain x where x: "x \<in> Y" and upper: "\<And>y. y \<in> Y \<Longrightarrow> ord_spmf op = y x" by blast
  1299     hence "(\<Union>x\<in>Y. set_spmf x) \<subseteq> set_spmf x" by(auto dest: ord_spmf_eqD_set_spmf)
  1300     thus ?thesis by(rule countable_subset) simp
  1301   next
  1302     case False
  1303     define N :: "'a option pmf \<Rightarrow> real" where "N p = pmf p None" for p
  1304 
  1305     have N_less_imp_le_spmf: "\<lbrakk> x \<in> Y; y \<in> Y; N y < N x \<rbrakk> \<Longrightarrow> ord_spmf op = x y" for x y
  1306       using chainD[OF chain, of x y] ord_spmf_eqD_pmf_None[of x y] ord_spmf_eqD_pmf_None[of y x]
  1307       by (auto simp: N_def)
  1308     have N_eq_imp_eq: "\<lbrakk> x \<in> Y; y \<in> Y; N y = N x \<rbrakk> \<Longrightarrow> x = y" for x y
  1309       using chainD[OF chain, of x y] by(auto simp add: N_def dest: ord_spmf_eq_pmf_None_eq)
  1310 
  1311     have NC: "N ` Y \<noteq> {}" "bdd_below (N ` Y)"
  1312       using \<open>Y \<noteq> {}\<close> by(auto intro!: bdd_belowI[of _ 0] simp: N_def)
  1313     have NC_less: "Inf (N ` Y) < N x" if "x \<in> Y" for x unfolding cInf_less_iff[OF NC]
  1314     proof(rule ccontr)
  1315       assume **: "\<not> (\<exists>y\<in>N ` Y. y < N x)"
  1316       { fix y
  1317         assume "y \<in> Y"
  1318         with ** consider "N x < N y" | "N x = N y" by(auto simp add: not_less le_less)
  1319         hence "ord_spmf op = y x" using \<open>y \<in> Y\<close> \<open>x \<in> Y\<close>
  1320           by cases(auto dest: N_less_imp_le_spmf N_eq_imp_eq intro: ord_spmf_reflI) }
  1321       with False \<open>x \<in> Y\<close> show False by blast
  1322     qed
  1323 
  1324     from NC have "Inf (N ` Y) \<in> closure (N ` Y)" by (intro closure_contains_Inf)
  1325     then obtain X' where "\<And>n. X' n \<in> N ` Y" and X': "X' \<longlonglongrightarrow> Inf (N ` Y)"
  1326       unfolding closure_sequential by auto
  1327     then obtain X where X: "\<And>n. X n \<in> Y" and "X' = (\<lambda>n. N (X n))" unfolding image_iff Bex_def by metis
  1328 
  1329     with X' have seq: "(\<lambda>n. N (X n)) \<longlonglongrightarrow> Inf (N ` Y)" by simp
  1330     have "(\<Union>x \<in> Y. set_spmf x) \<subseteq> (\<Union>n. set_spmf (X n))"
  1331     proof(rule UN_least)
  1332       fix x
  1333       assume "x \<in> Y"
  1334       from order_tendstoD(2)[OF seq NC_less[OF \<open>x \<in> Y\<close>]]
  1335       obtain i where "N (X i) < N x" by (auto simp: eventually_sequentially)
  1336       thus "set_spmf x \<subseteq> (\<Union>n. set_spmf (X n))" using X \<open>x \<in> Y\<close>
  1337         by(blast dest: N_less_imp_le_spmf ord_spmf_eqD_set_spmf)
  1338     qed
  1339     thus ?thesis by(rule countable_subset) simp
  1340   qed
  1341 qed simp
  1342 
  1343 lemma lub_spmf_subprob: "(\<integral>\<^sup>+ x. (SUP p : Y. ennreal (spmf p x)) \<partial>count_space UNIV) \<le> 1"
  1344 proof(cases "Y = {}")
  1345   case True
  1346   thus ?thesis by(simp add: bot_ennreal)
  1347 next
  1348   case False
  1349   let ?B = "\<Union>p\<in>Y. set_spmf p"
  1350   have countable: "countable ?B" by(rule spmf_chain_countable)
  1351 
  1352   have "(\<integral>\<^sup>+ x. (SUP p:Y. ennreal (spmf p x)) \<partial>count_space UNIV) =
  1353         (\<integral>\<^sup>+ x. (SUP p:Y. ennreal (spmf p x) * indicator ?B x) \<partial>count_space UNIV)"
  1354     by(intro nn_integral_cong SUP_cong)(auto split: split_indicator simp add: spmf_eq_0_set_spmf)
  1355   also have "\<dots> = (\<integral>\<^sup>+ x. (SUP p:Y. ennreal (spmf p x)) \<partial>count_space ?B)"
  1356     unfolding ennreal_indicator[symmetric] using False
  1357     by(subst SUP_mult_right_ennreal[symmetric])(simp add: ennreal_indicator nn_integral_count_space_indicator)
  1358   also have "\<dots> = (SUP p:Y. \<integral>\<^sup>+ x. spmf p x \<partial>count_space ?B)" using False _ countable
  1359     by(rule nn_integral_monotone_convergence_SUP_countable)(rule chain_ord_spmf_eqD)
  1360   also have "\<dots> \<le> 1"
  1361   proof(rule SUP_least)
  1362     fix p
  1363     assume "p \<in> Y"
  1364     have "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space ?B) = \<integral>\<^sup>+ x. ennreal (spmf p x) * indicator ?B x \<partial>count_space UNIV"
  1365       by(simp add: nn_integral_count_space_indicator)
  1366     also have "\<dots> = \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
  1367       by(rule nn_integral_cong)(auto split: split_indicator simp add: spmf_eq_0_set_spmf \<open>p \<in> Y\<close>)
  1368     also have "\<dots> \<le> 1"
  1369       by(simp add: weight_spmf_eq_nn_integral_spmf[symmetric] weight_spmf_le_1)
  1370     finally show "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space ?B) \<le> 1" .
  1371   qed
  1372   finally show ?thesis .
  1373 qed
  1374 
  1375 lemma spmf_lub_spmf:
  1376   assumes "Y \<noteq> {}"
  1377   shows "spmf lub_spmf x = (SUP p : Y. spmf p x)"
  1378 proof -
  1379   from assms obtain p where "p \<in> Y" by auto
  1380   have "spmf lub_spmf x = max 0 (enn2real (SUP p:Y. ennreal (spmf p x)))" unfolding lub_spmf_def
  1381     by(rule spmf_embed_spmf)(simp del: SUP_eq_top_iff Sup_eq_top_iff add: ennreal_enn2real_if SUP_spmf_neq_top' lub_spmf_subprob)
  1382   also have "\<dots> = enn2real (SUP p:Y. ennreal (spmf p x))"
  1383     by(rule max_absorb2)(simp)
  1384   also have "\<dots> = enn2real (ennreal (SUP p : Y. spmf p x))" using assms
  1385     by(subst ennreal_SUP[symmetric])(simp_all add: SUP_spmf_neq_top' del: SUP_eq_top_iff Sup_eq_top_iff)
  1386   also have "0 \<le> (\<Squnion>p\<in>Y. spmf p x)" using assms
  1387     by(auto intro!: cSUP_upper2 bdd_aboveI[where M=1] simp add: pmf_le_1)
  1388   then have "enn2real (ennreal (SUP p : Y. spmf p x)) = (SUP p : Y. spmf p x)"
  1389     by(rule enn2real_ennreal)
  1390   finally show ?thesis .
  1391 qed
  1392 
  1393 lemma ennreal_spmf_lub_spmf: "Y \<noteq> {} \<Longrightarrow> ennreal (spmf lub_spmf x) = (SUP p:Y. ennreal (spmf p x))"
  1394 unfolding spmf_lub_spmf by(subst ennreal_SUP)(simp_all add: SUP_spmf_neq_top' del: SUP_eq_top_iff Sup_eq_top_iff)
  1395 
  1396 lemma lub_spmf_upper:
  1397   assumes p: "p \<in> Y"
  1398   shows "ord_spmf op = p lub_spmf"
  1399 proof(rule ord_pmf_increaseI)
  1400   fix x
  1401   from p have [simp]: "Y \<noteq> {}" by auto
  1402   from p have "ennreal (spmf p x) \<le> (SUP p:Y. ennreal (spmf p x))" by(rule SUP_upper)
  1403   also have "\<dots> = ennreal (spmf lub_spmf x)" using p
  1404     by(subst spmf_lub_spmf)(auto simp add: ennreal_SUP SUP_spmf_neq_top' simp del: SUP_eq_top_iff Sup_eq_top_iff)
  1405   finally show "spmf p x \<le> spmf lub_spmf x" by simp
  1406 qed simp
  1407 
  1408 lemma lub_spmf_least:
  1409   assumes z: "\<And>x. x \<in> Y \<Longrightarrow> ord_spmf op = x z"
  1410   shows "ord_spmf op = lub_spmf z"
  1411 proof(cases "Y = {}")
  1412   case nonempty: False
  1413   show ?thesis
  1414   proof(rule ord_pmf_increaseI)
  1415     fix x
  1416     from nonempty obtain p where p: "p \<in> Y" by auto
  1417     have "ennreal (spmf lub_spmf x) = (SUP p:Y. ennreal (spmf p x))"
  1418       by(subst spmf_lub_spmf)(auto simp add: ennreal_SUP SUP_spmf_neq_top' nonempty simp del: SUP_eq_top_iff Sup_eq_top_iff)
  1419     also have "\<dots> \<le> ennreal (spmf z x)" by(rule SUP_least)(simp add: ord_spmf_eq_leD z)
  1420     finally show "spmf lub_spmf x \<le> spmf z x" by simp
  1421   qed simp
  1422 qed simp
  1423 
  1424 lemma set_lub_spmf: "set_spmf lub_spmf = (\<Union>p\<in>Y. set_spmf p)" (is "?lhs = ?rhs")
  1425 proof(cases "Y = {}")
  1426   case [simp]: False
  1427   show ?thesis
  1428   proof(rule set_eqI)
  1429     fix x
  1430     have "x \<in> ?lhs \<longleftrightarrow> ennreal (spmf lub_spmf x) > 0"
  1431       by(simp_all add: in_set_spmf_iff_spmf less_le)
  1432     also have "\<dots> \<longleftrightarrow> (\<exists>p\<in>Y. ennreal (spmf p x) > 0)"
  1433       by(simp add: ennreal_spmf_lub_spmf less_SUP_iff)
  1434     also have "\<dots> \<longleftrightarrow> x \<in> ?rhs"
  1435       by(auto simp add: in_set_spmf_iff_spmf less_le)
  1436     finally show "x \<in> ?lhs \<longleftrightarrow> x \<in> ?rhs" .
  1437   qed
  1438 qed simp
  1439 
  1440 lemma emeasure_lub_spmf:
  1441   assumes Y: "Y \<noteq> {}"
  1442   shows "emeasure (measure_spmf lub_spmf) A = (SUP y:Y. emeasure (measure_spmf y) A)"
  1443   (is "?lhs = ?rhs")
  1444 proof -
  1445   let ?M = "count_space (set_spmf lub_spmf)"
  1446   have "?lhs = \<integral>\<^sup>+ x. ennreal (spmf lub_spmf x) * indicator A x \<partial>?M"
  1447     by(auto simp add: nn_integral_indicator[symmetric] nn_integral_measure_spmf')
  1448   also have "\<dots> = \<integral>\<^sup>+ x. (SUP y:Y. ennreal (spmf y x) * indicator A x) \<partial>?M"
  1449     unfolding ennreal_indicator[symmetric]
  1450     by(simp add: spmf_lub_spmf assms ennreal_SUP[OF SUP_spmf_neq_top'] SUP_mult_right_ennreal)
  1451   also from assms have "\<dots> = (SUP y:Y. \<integral>\<^sup>+ x. ennreal (spmf y x) * indicator A x \<partial>?M)"
  1452   proof(rule nn_integral_monotone_convergence_SUP_countable)
  1453     have "(\<lambda>i x. ennreal (spmf i x) * indicator A x) ` Y = (\<lambda>f x. f x * indicator A x) ` (\<lambda>p x. ennreal (spmf p x)) ` Y"
  1454       by(simp add: image_image)
  1455     also have "Complete_Partial_Order.chain op \<le> \<dots>" using chain_ord_spmf_eqD
  1456       by(rule chain_imageI)(auto simp add: le_fun_def split: split_indicator)
  1457     finally show "Complete_Partial_Order.chain op \<le> ((\<lambda>i x. ennreal (spmf i x) * indicator A x) ` Y)" .
  1458   qed simp
  1459   also have "\<dots> = (SUP y:Y. \<integral>\<^sup>+ x. ennreal (spmf y x) * indicator A x \<partial>count_space UNIV)"
  1460     by(auto simp add: nn_integral_count_space_indicator set_lub_spmf spmf_eq_0_set_spmf split: split_indicator intro!: SUP_cong nn_integral_cong)
  1461   also have "\<dots> = ?rhs"
  1462     by(auto simp add: nn_integral_indicator[symmetric] nn_integral_measure_spmf)
  1463   finally show ?thesis .
  1464 qed
  1465 
  1466 lemma measure_lub_spmf:
  1467   assumes Y: "Y \<noteq> {}"
  1468   shows "measure (measure_spmf lub_spmf) A = (SUP y:Y. measure (measure_spmf y) A)" (is "?lhs = ?rhs")
  1469 proof -
  1470   have "ennreal ?lhs = ennreal ?rhs"
  1471     using emeasure_lub_spmf[OF assms] SUP_emeasure_spmf_neq_top[of A Y] Y
  1472     unfolding measure_spmf.emeasure_eq_measure by(subst ennreal_SUP)
  1473   moreover have "0 \<le> ?rhs" using Y
  1474     by(auto intro!: cSUP_upper2 bdd_aboveI[where M=1] measure_spmf.subprob_measure_le_1)
  1475   ultimately show ?thesis by(simp)
  1476 qed
  1477 
  1478 lemma weight_lub_spmf:
  1479   assumes Y: "Y \<noteq> {}"
  1480   shows "weight_spmf lub_spmf = (SUP y:Y. weight_spmf y)"
  1481 unfolding weight_spmf_def by(rule measure_lub_spmf) fact
  1482 
  1483 lemma measure_spmf_lub_spmf:
  1484   assumes Y: "Y \<noteq> {}"
  1485   shows "measure_spmf lub_spmf = (SUP p:Y. measure_spmf p)" (is "?lhs = ?rhs")
  1486 proof(rule measure_eqI)
  1487   from assms obtain p where p: "p \<in> Y" by auto
  1488   from chain have chain': "Complete_Partial_Order.chain op \<le> (measure_spmf ` Y)"
  1489     by(rule chain_imageI)(rule ord_spmf_eqD_measure_spmf)
  1490   show "sets ?lhs = sets ?rhs"
  1491     using Y by (subst sets_SUP) auto
  1492   show "emeasure ?lhs A = emeasure ?rhs A" for A
  1493     using chain' Y p by (subst emeasure_SUP_chain) (auto simp:  emeasure_lub_spmf)
  1494 qed
  1495 
  1496 end
  1497 
  1498 end
  1499 
  1500 lemma partial_function_definitions_spmf: "partial_function_definitions (ord_spmf op =) lub_spmf"
  1501   (is "partial_function_definitions ?R _")
  1502 proof
  1503   fix x show "?R x x" by(simp add: ord_spmf_reflI)
  1504 next
  1505   fix x y z
  1506   assume "?R x y" "?R y z"
  1507   with transp_ord_option[OF transp_equality] show "?R x z" by(rule transp_rel_pmf[THEN transpD])
  1508 next
  1509   fix x y
  1510   assume "?R x y" "?R y x"
  1511   thus "x = y"
  1512     by(rule rel_pmf_antisym)(simp_all add: reflp_ord_option transp_ord_option antisymP_ord_option)
  1513 next
  1514   fix Y x
  1515   assume "Complete_Partial_Order.chain ?R Y" "x \<in> Y"
  1516   then show "?R x (lub_spmf Y)"
  1517     by(rule lub_spmf_upper)
  1518 next
  1519   fix Y z
  1520   assume "Complete_Partial_Order.chain ?R Y" "\<And>x. x \<in> Y \<Longrightarrow> ?R x z"
  1521   then show "?R (lub_spmf Y) z"
  1522     by(cases "Y = {}")(simp_all add: lub_spmf_least)
  1523 qed
  1524 
  1525 lemma ccpo_spmf: "class.ccpo lub_spmf (ord_spmf op =) (mk_less (ord_spmf op =))"
  1526 by(rule ccpo partial_function_definitions_spmf)+
  1527 
  1528 interpretation spmf: partial_function_definitions "ord_spmf op =" "lub_spmf"
  1529   rewrites "lub_spmf {} \<equiv> return_pmf None"
  1530 by(rule partial_function_definitions_spmf) simp
  1531 
  1532 declaration \<open>Partial_Function.init "spmf" @{term spmf.fixp_fun}
  1533   @{term spmf.mono_body} @{thm spmf.fixp_rule_uc} @{thm spmf.fixp_induct_uc}
  1534   NONE\<close>
  1535 
  1536 declare spmf.leq_refl[simp]
  1537 declare admissible_leI[OF ccpo_spmf, cont_intro]
  1538 
  1539 abbreviation "mono_spmf \<equiv> monotone (fun_ord (ord_spmf op =)) (ord_spmf op =)"
  1540 
  1541 lemma lub_spmf_const [simp]: "lub_spmf {p} = p"
  1542 by(rule spmf_eqI)(simp add: spmf_lub_spmf[OF ccpo.chain_singleton[OF ccpo_spmf]])
  1543 
  1544 lemma bind_spmf_mono':
  1545   assumes fg: "ord_spmf op = f g"
  1546   and hk: "\<And>x :: 'a. ord_spmf op = (h x) (k x)"
  1547   shows "ord_spmf op = (f \<bind> h) (g \<bind> k)"
  1548 unfolding bind_spmf_def using assms(1)
  1549 by(rule rel_pmf_bindI)(auto split: option.split simp add: hk)
  1550 
  1551 lemma bind_spmf_mono [partial_function_mono]:
  1552   assumes mf: "mono_spmf B" and mg: "\<And>y. mono_spmf (\<lambda>f. C y f)"
  1553   shows "mono_spmf (\<lambda>f. bind_spmf (B f) (\<lambda>y. C y f))"
  1554 proof (rule monotoneI)
  1555   fix f g :: "'a \<Rightarrow> 'b spmf"
  1556   assume fg: "fun_ord (ord_spmf op =) f g"
  1557   with mf have "ord_spmf op = (B f) (B g)" by (rule monotoneD[of _ _ _ f g])
  1558   moreover from mg have "\<And>y'. ord_spmf op = (C y' f) (C y' g)"
  1559     by (rule monotoneD) (rule fg)
  1560   ultimately show "ord_spmf op = (bind_spmf (B f) (\<lambda>y. C y f)) (bind_spmf (B g) (\<lambda>y'. C y' g))"
  1561     by(rule bind_spmf_mono')
  1562 qed
  1563 
  1564 lemma monotone_bind_spmf1: "monotone (ord_spmf op =) (ord_spmf op =) (\<lambda>y. bind_spmf y g)"
  1565 by(rule monotoneI)(simp add: bind_spmf_mono' ord_spmf_reflI)
  1566 
  1567 lemma monotone_bind_spmf2:
  1568   assumes g: "\<And>x. monotone ord (ord_spmf op =) (\<lambda>y. g y x)"
  1569   shows "monotone ord (ord_spmf op =) (\<lambda>y. bind_spmf p (g y))"
  1570 by(rule monotoneI)(auto intro: bind_spmf_mono' monotoneD[OF g] ord_spmf_reflI)
  1571 
  1572 lemma bind_lub_spmf:
  1573   assumes chain: "Complete_Partial_Order.chain (ord_spmf op =) Y"
  1574   shows "bind_spmf (lub_spmf Y) f = lub_spmf ((\<lambda>p. bind_spmf p f) ` Y)" (is "?lhs = ?rhs")
  1575 proof(cases "Y = {}")
  1576   case Y: False
  1577   show ?thesis
  1578   proof(rule spmf_eqI)
  1579     fix i
  1580     have chain': "Complete_Partial_Order.chain op \<le> ((\<lambda>p x. ennreal (spmf p x * spmf (f x) i)) ` Y)"
  1581       using chain by(rule chain_imageI)(auto simp add: le_fun_def dest: ord_spmf_eq_leD intro: mult_right_mono)
  1582     have chain'': "Complete_Partial_Order.chain (ord_spmf op =) ((\<lambda>p. p \<bind> f) ` Y)"
  1583       using chain by(rule chain_imageI)(auto intro!: monotoneI bind_spmf_mono' ord_spmf_reflI)
  1584     let ?M = "count_space (set_spmf (lub_spmf Y))"
  1585     have "ennreal (spmf ?lhs i) = \<integral>\<^sup>+ x. ennreal (spmf (lub_spmf Y) x) * ennreal (spmf (f x) i) \<partial>?M"
  1586       by(auto simp add: ennreal_spmf_lub_spmf ennreal_spmf_bind nn_integral_measure_spmf')
  1587     also have "\<dots> = \<integral>\<^sup>+ x. (SUP p:Y. ennreal (spmf p x * spmf (f x) i)) \<partial>?M"
  1588       by(subst ennreal_spmf_lub_spmf[OF chain Y])(subst SUP_mult_right_ennreal, simp_all add: ennreal_mult Y)
  1589     also have "\<dots> = (SUP p:Y. \<integral>\<^sup>+ x. ennreal (spmf p x * spmf (f x) i) \<partial>?M)"
  1590       using Y chain' by(rule nn_integral_monotone_convergence_SUP_countable) simp
  1591     also have "\<dots> = (SUP p:Y. ennreal (spmf (bind_spmf p f) i))"
  1592       by(auto simp add: ennreal_spmf_bind nn_integral_measure_spmf nn_integral_count_space_indicator set_lub_spmf[OF chain] in_set_spmf_iff_spmf ennreal_mult intro!: SUP_cong nn_integral_cong split: split_indicator)
  1593     also have "\<dots> = ennreal (spmf ?rhs i)" using chain'' by(simp add: ennreal_spmf_lub_spmf Y)
  1594     finally show "spmf ?lhs i = spmf ?rhs i" by simp
  1595   qed
  1596 qed simp
  1597 
  1598 lemma map_lub_spmf:
  1599   "Complete_Partial_Order.chain (ord_spmf op =) Y
  1600   \<Longrightarrow> map_spmf f (lub_spmf Y) = lub_spmf (map_spmf f ` Y)"
  1601 unfolding map_spmf_conv_bind_spmf[abs_def] by(simp add: bind_lub_spmf o_def)
  1602 
  1603 lemma mcont_bind_spmf1: "mcont lub_spmf (ord_spmf op =) lub_spmf (ord_spmf op =) (\<lambda>y. bind_spmf y f)"
  1604 using monotone_bind_spmf1 by(rule mcontI)(rule contI, simp add: bind_lub_spmf)
  1605 
  1606 lemma bind_lub_spmf2:
  1607   assumes chain: "Complete_Partial_Order.chain ord Y"
  1608   and g: "\<And>y. monotone ord (ord_spmf op =) (g y)"
  1609   shows "bind_spmf x (\<lambda>y. lub_spmf (g y ` Y)) = lub_spmf ((\<lambda>p. bind_spmf x (\<lambda>y. g y p)) ` Y)"
  1610   (is "?lhs = ?rhs")
  1611 proof(cases "Y = {}")
  1612   case Y: False
  1613   show ?thesis
  1614   proof(rule spmf_eqI)
  1615     fix i
  1616     have chain': "\<And>y. Complete_Partial_Order.chain (ord_spmf op =) (g y ` Y)"
  1617       using chain g[THEN monotoneD] by(rule chain_imageI)
  1618     have chain'': "Complete_Partial_Order.chain op \<le> ((\<lambda>p y. ennreal (spmf x y * spmf (g y p) i)) ` Y)"
  1619       using chain by(rule chain_imageI)(auto simp add: le_fun_def dest: ord_spmf_eq_leD monotoneD[OF g] intro!: mult_left_mono)
  1620     have chain''': "Complete_Partial_Order.chain (ord_spmf op =) ((\<lambda>p. bind_spmf x (\<lambda>y. g y p)) ` Y)"
  1621       using chain by(rule chain_imageI)(rule monotone_bind_spmf2[OF g, THEN monotoneD])
  1622 
  1623     have "ennreal (spmf ?lhs i) = \<integral>\<^sup>+ y. (SUP p:Y. ennreal (spmf x y * spmf (g y p) i)) \<partial>count_space (set_spmf x)"
  1624       by(simp add: ennreal_spmf_bind ennreal_spmf_lub_spmf[OF chain'] Y nn_integral_measure_spmf' SUP_mult_left_ennreal ennreal_mult)
  1625     also have "\<dots> = (SUP p:Y. \<integral>\<^sup>+ y. ennreal (spmf x y * spmf (g y p) i) \<partial>count_space (set_spmf x))"
  1626       unfolding nn_integral_measure_spmf' using Y chain''
  1627       by(rule nn_integral_monotone_convergence_SUP_countable) simp
  1628     also have "\<dots> = (SUP p:Y. ennreal (spmf (bind_spmf x (\<lambda>y. g y p)) i))"
  1629       by(simp add: ennreal_spmf_bind nn_integral_measure_spmf' ennreal_mult)
  1630     also have "\<dots> = ennreal (spmf ?rhs i)" using chain'''
  1631       by(auto simp add: ennreal_spmf_lub_spmf Y)
  1632     finally show "spmf ?lhs i = spmf ?rhs i" by simp
  1633   qed
  1634 qed simp
  1635 
  1636 lemma mcont_bind_spmf [cont_intro]:
  1637   assumes f: "mcont luba orda lub_spmf (ord_spmf op =) f"
  1638   and g: "\<And>y. mcont luba orda lub_spmf (ord_spmf op =) (g y)"
  1639   shows "mcont luba orda lub_spmf (ord_spmf op =) (\<lambda>x. bind_spmf (f x) (\<lambda>y. g y x))"
  1640 proof(rule spmf.mcont2mcont'[OF _ _ f])
  1641   fix z
  1642   show "mcont lub_spmf (ord_spmf op =) lub_spmf (ord_spmf op =) (\<lambda>x. bind_spmf x (\<lambda>y. g y z))"
  1643     by(rule mcont_bind_spmf1)
  1644 next
  1645   fix x
  1646   let ?f = "\<lambda>z. bind_spmf x (\<lambda>y. g y z)"
  1647   have "monotone orda (ord_spmf op =) ?f" using mcont_mono[OF g] by(rule monotone_bind_spmf2)
  1648   moreover have "cont luba orda lub_spmf (ord_spmf op =) ?f"
  1649   proof(rule contI)
  1650     fix Y
  1651     assume chain: "Complete_Partial_Order.chain orda Y" and Y: "Y \<noteq> {}"
  1652     have "bind_spmf x (\<lambda>y. g y (luba Y)) = bind_spmf x (\<lambda>y. lub_spmf (g y ` Y))"
  1653       by(rule bind_spmf_cong)(simp_all add: mcont_contD[OF g chain Y])
  1654     also have "\<dots> = lub_spmf ((\<lambda>p. x \<bind> (\<lambda>y. g y p)) ` Y)" using chain
  1655       by(rule bind_lub_spmf2)(rule mcont_mono[OF g])
  1656     finally show "bind_spmf x (\<lambda>y. g y (luba Y)) = \<dots>" .
  1657   qed
  1658   ultimately show "mcont luba orda lub_spmf (ord_spmf op =) ?f" by(rule mcontI)
  1659 qed
  1660 
  1661 lemma bind_pmf_mono [partial_function_mono]:
  1662   "(\<And>y. mono_spmf (\<lambda>f. C y f)) \<Longrightarrow> mono_spmf (\<lambda>f. bind_pmf p (\<lambda>x. C x f))"
  1663 using bind_spmf_mono[of "\<lambda>_. spmf_of_pmf p" C] by simp
  1664 
(* Monotonicity rule for partial_function: map_spmf f is unfolded into a bind
   (map_spmf_conv_bind_spmf) and discharged with bind_spmf_mono. *)
lemma map_spmf_mono [partial_function_mono]: "mono_spmf B \<Longrightarrow> mono_spmf (\<lambda>g. map_spmf f (B g))"
unfolding map_spmf_conv_bind_spmf by(rule bind_spmf_mono) simp_all
  1667 
(* Continuity of map_spmf f composed with an mcont function g; reduced to
   mcont_bind_spmf after unfolding map_spmf_conv_bind_spmf. *)
lemma mcont_map_spmf [cont_intro]:
  "mcont luba orda lub_spmf (ord_spmf op =) g
  \<Longrightarrow> mcont luba orda lub_spmf (ord_spmf op =) (\<lambda>x. map_spmf f (g x))"
unfolding map_spmf_conv_bind_spmf by(rule mcont_bind_spmf) simp_all
  1672 
(* set_spmf is monotone from the spmf ordering ord_spmf (op =) to set inclusion. *)
lemma monotone_set_spmf: "monotone (ord_spmf op =) op \<subseteq> set_spmf"
by(rule monotoneI)(rule ord_spmf_eqD_set_spmf)
  1675 
(* set_spmf is continuous: the support of a lub_spmf is the union of the
   supports (via set_lub_spmf). *)
lemma cont_set_spmf: "cont lub_spmf (ord_spmf op =) Union op \<subseteq> set_spmf"
by(rule contI)(subst set_lub_spmf; simp)
  1678 
(* mcont_set_spmf combines monotone_set_spmf and cont_set_spmf; the THEN
   mcont2mcont variant is registered as a cont_intro rule. *)
lemma mcont2mcont_set_spmf[THEN mcont2mcont, cont_intro]:
  shows mcont_set_spmf: "mcont lub_spmf (ord_spmf op =) Union op \<subseteq> set_spmf"
by(rule mcontI monotone_set_spmf cont_set_spmf)+
  1682 
(* For a fixed point x, the probability spmf p x is monotone in p. *)
lemma monotone_spmf: "monotone (ord_spmf op =) op \<le> (\<lambda>p. spmf p x)"
by(rule monotoneI)(simp add: ord_spmf_eq_leD)
  1685 
(* spmf _ x maps lub_spmf to the real Sup of the point probabilities
   (spmf_lub_spmf). *)
lemma cont_spmf: "cont lub_spmf (ord_spmf op =) Sup op \<le> (\<lambda>p. spmf p x)"
by(rule contI)(simp add: spmf_lub_spmf)
  1688 
(* Combination of monotone_spmf and cont_spmf. *)
lemma mcont_spmf: "mcont lub_spmf (ord_spmf op =) Sup op \<le> (\<lambda>p. spmf p x)"
by(rule mcontI monotone_spmf cont_spmf)+
  1691 
(* Same as cont_spmf, but with the point probability embedded into ennreal
   (via ennreal_spmf_lub_spmf). *)
lemma cont_ennreal_spmf: "cont lub_spmf (ord_spmf op =) Sup op \<le> (\<lambda>p. ennreal (spmf p x))"
by(rule contI)(simp add: ennreal_spmf_lub_spmf)
  1694 
(* mcont version of cont_ennreal_spmf; monotonicity comes from monotone_spmf
   lifted through mono2mono_ennreal. The THEN mcont2mcont variant is a cont_intro rule. *)
lemma mcont2mcont_ennreal_spmf [THEN mcont2mcont, cont_intro]:
  shows mcont_ennreal_spmf: "mcont lub_spmf (ord_spmf op =) Sup op \<le> (\<lambda>p. ennreal (spmf p x))"
by(rule mcontI mono2mono_ennreal monotone_spmf cont_ennreal_spmf)+
  1698 
(* Change of variables for non-negative integrals over map_spmf f p. *)
lemma nn_integral_map_spmf [simp]: "nn_integral (measure_spmf (map_spmf f p)) g = nn_integral (measure_spmf p) (g \<circ> f)"
by(auto 4 3 simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space intro: nn_integral_cong split: split_indicator)
  1701 
  1702 subsubsection \<open>Admissibility of @{term rel_spmf}\<close>
  1703 
(* If p and q are related by rel_spmf R, then the mass p puts on A is bounded by
   the mass q puts on the R-image of A. Proof: move to the underlying option pmfs
   (measure_measure_spmf_conv_measure_pmf), apply rel_pmf_measureD there, and
   translate rel_option R back on Some-values. *)
lemma rel_spmf_measureD:
  assumes "rel_spmf R p q"
  shows "measure (measure_spmf p) A \<le> measure (measure_spmf q) {y. \<exists>x\<in>A. R x y}" (is "?lhs \<le> ?rhs")
proof -
  have "?lhs = measure (measure_pmf p) (Some ` A)" by(simp add: measure_measure_spmf_conv_measure_pmf)
  also have "\<dots> \<le> measure (measure_pmf q) {y. \<exists>x\<in>Some ` A. rel_option R x y}"
    using assms by(rule rel_pmf_measureD)
  also have "\<dots> = ?rhs" unfolding measure_measure_spmf_conv_measure_pmf
    by(rule arg_cong2[where f=measure])(auto simp add: option_rel_Some1)
  finally show ?thesis .
qed
  1715 
(* Locale whose single assumption is the converse of the inequality provided by
   rel_pmf_measureD: rel_pmf R p q follows from the measure inequality for all A.
   Inside the locale this yields a measure characterisation of rel_spmf, its
   admissibility, and parametricity of spmf fixpoints. *)
locale rel_spmf_characterisation =
  assumes rel_pmf_measureI:
    "\<And>(R :: 'a option \<Rightarrow> 'b option \<Rightarrow> bool) p q.
    (\<And>A. measure (measure_pmf p) A \<le> measure (measure_pmf q) {y. \<exists>x\<in>A. R x y})
    \<Longrightarrow> rel_pmf R p q"
  \<comment> \<open>This assumption is shown to hold in general in the AFP entry \<open>MFMC_Countable\<close>.\<close>
begin

context fixes R :: "'a \<Rightarrow> 'b \<Rightarrow> bool" begin

(* Converse of rel_spmf_measureD: the measure inequality for all A together with
   weight_spmf q \<le> weight_spmf p suffices for rel_spmf R p q. The extra weight
   condition accounts for the None outcome, which the inequality on sets of
   Some-values cannot see. *)
lemma rel_spmf_measureI:
  assumes eq1: "\<And>A. measure (measure_spmf p) A \<le> measure (measure_spmf q) {y. \<exists>x\<in>A. R x y}"
  assumes eq2: "weight_spmf q \<le> weight_spmf p"
  shows "rel_spmf R p q"
proof(rule rel_pmf_measureI)
  fix A :: "'a option set"
  (* Split A into its Some-part (as a set A' of values) and its None-part A''. *)
  define A' where "A' = the ` (A \<inter> range Some)"
  define A'' where "A'' = A \<inter> {None}"
  have A: "A = Some ` A' \<union> A''" "Some ` A' \<inter> A'' = {}"
    unfolding A'_def A''_def by(auto 4 3 intro: rev_image_eqI)
  have "measure (measure_pmf p) A = measure (measure_pmf p) (Some ` A') + measure (measure_pmf p) A''"
    by(simp add: A measure_pmf.finite_measure_Union)
  also have "measure (measure_pmf p) (Some ` A') = measure (measure_spmf p) A'"
    by(simp add: measure_measure_spmf_conv_measure_pmf)
  also have "\<dots> \<le> measure (measure_spmf q) {y. \<exists>x\<in>A'. R x y}" by(rule eq1)
  also (ord_eq_le_trans[OF _ add_right_mono])
  have "\<dots> = measure (measure_pmf q) {y. \<exists>x\<in>A'. rel_option R (Some x) y}"
    unfolding measure_measure_spmf_conv_measure_pmf
    by(rule arg_cong2[where f=measure])(auto simp add: A'_def option_rel_Some1)
  (* None-part: eq1 at A = UNIV together with eq2 forces equal weights, and hence
     equal mass on None (pmf_None_eq_weight_spmf). *)
  also
  { have "weight_spmf p \<le> measure (measure_spmf q) {y. \<exists>x. R x y}"
      using eq1[of UNIV] unfolding weight_spmf_def by simp
    also have "\<dots> \<le> weight_spmf q" unfolding weight_spmf_def
      by(rule measure_spmf.finite_measure_mono) simp_all
    finally have "weight_spmf p = weight_spmf q" using eq2 by simp }
  then have "measure (measure_pmf p) A'' = measure (measure_pmf q) (if None \<in> A then {None} else {})"
    unfolding A''_def by(simp add: pmf_None_eq_weight_spmf measure_pmf_single)
  also have "measure (measure_pmf q) {y. \<exists>x\<in>A'. rel_option R (Some x) y} + \<dots> = measure (measure_pmf q) {y. \<exists>x\<in>A. rel_option R x y}"
    by(subst measure_pmf.finite_measure_Union[symmetric])
      (auto 4 3 intro!: arg_cong2[where f=measure] simp add: option_rel_Some1 option_rel_Some2 A'_def intro: rev_bexI elim: option.rel_cases)
  finally show "measure (measure_pmf p) A \<le> \<dots>" .
qed

(* rel_spmf R is admissible for the product ordering of two spmf ccpos, i.e. it is
   preserved by lubs of chains of related pairs. The lubs of both projections are
   shown related via rel_spmf_measureI: weights and measures of lub_spmf are
   suprema (weight_lub_spmf, measure_lub_spmf), and the pointwise bounds come from
   rel_spmf_weightD and rel_spmf_measureD on each chain element. *)
lemma admissible_rel_spmf:
  "ccpo.admissible (prod_lub lub_spmf lub_spmf) (rel_prod (ord_spmf op =) (ord_spmf op =)) (case_prod (rel_spmf R))"
  (is "ccpo.admissible ?lub ?ord ?P")
proof(rule ccpo.admissibleI)
  fix Y
  assume chain: "Complete_Partial_Order.chain ?ord Y"
    and Y: "Y \<noteq> {}"
    and R: "\<forall>(p, q) \<in> Y. rel_spmf R p q"
  from R have R: "\<And>p q. (p, q) \<in> Y \<Longrightarrow> rel_spmf R p q" by auto
  have chain1: "Complete_Partial_Order.chain (ord_spmf op =) (fst ` Y)"
    and chain2: "Complete_Partial_Order.chain (ord_spmf op =) (snd ` Y)"
    using chain by(rule chain_imageI; clarsimp)+
  from Y have Y1: "fst ` Y \<noteq> {}" and Y2: "snd ` Y \<noteq> {}" by auto

  have "rel_spmf R (lub_spmf (fst ` Y)) (lub_spmf (snd ` Y))"
  proof(rule rel_spmf_measureI)
    show "weight_spmf (lub_spmf (snd ` Y)) \<le> weight_spmf (lub_spmf (fst ` Y))"
      by(auto simp add: weight_lub_spmf chain1 chain2 Y rel_spmf_weightD[OF R, symmetric] intro!: cSUP_least intro: cSUP_upper2[OF bdd_aboveI2[OF weight_spmf_le_1]])

    fix A
    have "measure (measure_spmf (lub_spmf (fst ` Y))) A = (SUP y:fst ` Y. measure (measure_spmf y) A)"
      using chain1 Y1 by(rule measure_lub_spmf)
    also have "\<dots> \<le> (SUP y:snd ` Y. measure (measure_spmf y) {y. \<exists>x\<in>A. R x y})" using Y1
      by(rule cSUP_least)(auto intro!: cSUP_upper2[OF bdd_aboveI2[OF measure_spmf.subprob_measure_le_1]] rel_spmf_measureD R)
    also have "\<dots> = measure (measure_spmf (lub_spmf (snd ` Y))) {y. \<exists>x\<in>A. R x y}"
      using chain2 Y2 by(rule measure_lub_spmf[symmetric])
    finally show "measure (measure_spmf (lub_spmf (fst ` Y))) A \<le> \<dots>" .
  qed
  then show "?P (?lub Y)" by(simp add: prod_lub_def)
qed

(* cont_intro form of admissible_rel_spmf: rel_spmf R (f x) (g x) is admissible
   whenever f and g are mcont into the spmf ccpo. *)
lemma admissible_rel_spmf_mcont [cont_intro]:
  "\<lbrakk> mcont lub ord lub_spmf (ord_spmf op =) f; mcont lub ord lub_spmf (ord_spmf op =) g \<rbrakk>
  \<Longrightarrow> ccpo.admissible lub ord (\<lambda>x. rel_spmf R (f x) (g x))"
by(rule admissible_subst[OF admissible_rel_spmf, where f="\<lambda>x. (f x, g x)", simplified])(rule mcont_Pair)

context includes lifting_syntax
begin

(* Parametricity of ccpo.fixp on spmf: fixpoints of monotone, rel_spmf R-related
   functionals F and G are related (parallel fixpoint induction). *)
lemma fixp_spmf_parametric':
  assumes f: "\<And>x. monotone (ord_spmf op =) (ord_spmf op =) F"
  and g: "\<And>x. monotone (ord_spmf op =) (ord_spmf op =) G"
  and param: "(rel_spmf R ===> rel_spmf R) F G"
  shows "(rel_spmf R) (ccpo.fixp lub_spmf (ord_spmf op =) F) (ccpo.fixp lub_spmf (ord_spmf op =) G)"
by(rule parallel_fixp_induct[OF ccpo_spmf ccpo_spmf _ f g])(auto intro: param[THEN rel_funD])

(* Function-space version for spmf.fixp_fun (the fixpoint used by partial_function
   definitions in the spmf monad): related functionals yield
   (A ===> rel_spmf R)-related fixpoints. *)
lemma fixp_spmf_parametric:
  assumes f: "\<And>x. mono_spmf (\<lambda>f. F f x)"
  and g: "\<And>x. mono_spmf (\<lambda>f. G f x)"
  and param: "((A ===> rel_spmf R) ===> A ===> rel_spmf R) F G"
  shows "(A ===> rel_spmf R) (spmf.fixp_fun F) (spmf.fixp_fun G)"
using f g
proof(rule parallel_fixp_induct_1_1[OF partial_function_definitions_spmf partial_function_definitions_spmf _ _ reflexive reflexive, where P="(A ===> rel_spmf R)"])
  show "ccpo.admissible (prod_lub (fun_lub lub_spmf) (fun_lub lub_spmf)) (rel_prod (fun_ord (ord_spmf op =)) (fun_ord (ord_spmf op =))) (\<lambda>x. (A ===> rel_spmf R) (fst x) (snd x))"
    unfolding rel_fun_def
    apply(rule admissible_all admissible_imp admissible_rel_spmf_mcont)+
    apply(rule spmf.mcont2mcont[OF mcont_call])
     apply(rule mcont_fst)
    apply(rule spmf.mcont2mcont[OF mcont_call])
     apply(rule mcont_snd)
    done
  show "(A ===> rel_spmf R) (\<lambda>_. lub_spmf {}) (\<lambda>_. lub_spmf {})" by auto
  show "(A ===> rel_spmf R) (F f) (G g)" if "(A ===> rel_spmf R) f g" for f g
    using that by(rule rel_funD[OF param])
qed

(* closes: context includes lifting_syntax *)
end

(* closes: context fixes R *)
end

(* closes: locale rel_spmf_characterisation *)
end
  1830 
  1831 subsection \<open>Restrictions on spmfs\<close>
  1832 
(* Restriction p \<upharpoonleft> A: the mass of every value in A is kept; the mass of
   values outside A is moved to None (None itself stays None). *)
definition restrict_spmf :: "'a spmf \<Rightarrow> 'a set \<Rightarrow> 'a spmf" (infixl "\<upharpoonleft>" 110)
where "p \<upharpoonleft> A = map_pmf (\<lambda>x. x \<bind> (\<lambda>y. if y \<in> A then Some y else None)) p"
  1835 
(* The support of a restriction is the support intersected with A. *)
lemma set_restrict_spmf [simp]: "set_spmf (p \<upharpoonleft> A) = set_spmf p \<inter> A"
by(fastforce simp add: restrict_spmf_def set_spmf_def split: bind_splits if_split_asm)
  1838 
(* Restricting after map_spmf f equals restricting to the preimage f -` A first. *)
lemma restrict_map_spmf: "map_spmf f p \<upharpoonleft> A = map_spmf f (p \<upharpoonleft> (f -` A))"
by(simp add: restrict_spmf_def pmf.map_comp o_def map_option_bind bind_map_option if_distrib cong del: if_weak_cong)
  1841 
(* Two successive restrictions collapse into one restriction to the intersection. *)
lemma restrict_restrict_spmf [simp]: "p \<upharpoonleft> A \<upharpoonleft> B = p \<upharpoonleft> (A \<inter> B)"
by(auto simp add: restrict_spmf_def pmf.map_comp o_def intro!: pmf.map_cong bind_option_cong)
  1844 
(* Restricting to the empty set yields the empty subdistribution. *)
lemma restrict_spmf_empty [simp]: "p \<upharpoonleft> {} = return_pmf None"
by(simp add: restrict_spmf_def)
  1847 
(* Restricting to UNIV is the identity. *)
lemma restrict_spmf_UNIV [simp]: "p \<upharpoonleft> UNIV = p"
by(simp add: restrict_spmf_def)
  1850 
(* Points outside A get probability 0 after restriction (via the support). *)
lemma spmf_restrict_spmf_outside [simp]: "x \<notin> A \<Longrightarrow> spmf (p \<upharpoonleft> A) x = 0"
by(simp add: spmf_eq_0_set_spmf)
  1853 
(* The measure of the restriction is the original measure intersected with A. *)
lemma emeasure_restrict_spmf [simp]:
  "emeasure (measure_spmf (p \<upharpoonleft> A)) X = emeasure (measure_spmf p) (X \<inter> A)"
by(auto simp add: restrict_spmf_def measure_spmf_def emeasure_distr measurable_restrict_space1 emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure] split: bind_splits if_split_asm)
  1857 
(* Real-valued counterpart of emeasure_restrict_spmf, obtained by injectivity of
   ennreal on non-negative reals. *)
lemma measure_restrict_spmf [simp]:
  "measure (measure_spmf (p \<upharpoonleft> A)) X = measure (measure_spmf p) (X \<inter> A)"
using emeasure_restrict_spmf[of p A X]
by(simp only: measure_spmf.emeasure_eq_measure ennreal_inj measure_nonneg)
  1862 
(* Point probabilities of a restriction: unchanged inside A, zero outside. *)
lemma spmf_restrict_spmf: "spmf (p \<upharpoonleft> A) x = (if x \<in> A then spmf p x else 0)"
by(simp add: spmf_conv_measure_spmf)
  1865 
(* Simp rule for the inside case of spmf_restrict_spmf. *)
lemma spmf_restrict_spmf_inside [simp]: "x \<in> A \<Longrightarrow> spmf (p \<upharpoonleft> A) x = spmf p x"
by(simp add: spmf_restrict_spmf)
  1868 
(* After restriction, None carries its original mass plus the mass p put on -A. *)
lemma pmf_restrict_spmf_None: "pmf (p \<upharpoonleft> A) None = pmf p None + measure (measure_spmf p) (- A)"
proof -
  have [simp]: "None \<notin> Some ` (- A)" by auto
  (* The outcomes mapped to None are None itself and every Some y with y \<notin> A;
     these two parts are disjoint, so their masses add up. *)
  have "(\<lambda>x. x \<bind> (\<lambda>y. if y \<in> A then Some y else None)) -` {None} = {None} \<union> (Some ` (- A))"
    by(auto split: bind_splits if_split_asm)
  then show ?thesis unfolding ereal.inject[symmetric]
    by(simp add: restrict_spmf_def ennreal_pmf_map emeasure_pmf_single del: ereal.inject)
      (simp add: pmf.rep_eq measure_pmf.finite_measure_Union[symmetric] measure_measure_spmf_conv_measure_pmf measure_pmf.emeasure_eq_measure)
qed
  1878 
(* Restriction to a set containing the whole support changes nothing. *)
lemma restrict_spmf_trivial: "(\<And>x. x \<in> set_spmf p \<Longrightarrow> x \<in> A) \<Longrightarrow> p \<upharpoonleft> A = p"
by(rule spmf_eqI)(auto simp add: spmf_restrict_spmf spmf_eq_0_set_spmf)
  1881 
(* Set-inclusion form of restrict_spmf_trivial. *)
lemma restrict_spmf_trivial': "set_spmf p \<subseteq> A \<Longrightarrow> p \<upharpoonleft> A = p"
by(rule restrict_spmf_trivial) blast
  1884 
(* Restricting a point mass keeps it iff the point lies in A. *)
lemma restrict_return_spmf: "return_spmf x \<upharpoonleft> A = (if x \<in> A then return_spmf x else return_pmf None)"
by(simp add: restrict_spmf_def)
  1887 
(* Simp rule: inside case of restrict_return_spmf. *)
lemma restrict_return_spmf_inside [simp]: "x \<in> A \<Longrightarrow> return_spmf x \<upharpoonleft> A = return_spmf x"
by(simp add: restrict_return_spmf)
  1890 
(* Simp rule: outside case of restrict_return_spmf. *)
lemma restrict_return_spmf_outside [simp]: "x \<notin> A \<Longrightarrow> return_spmf x \<upharpoonleft> A = return_pmf None"
by(simp add: restrict_return_spmf)
  1893 
(* The empty subdistribution is invariant under restriction. *)
lemma restrict_spmf_return_pmf_None [simp]: "return_pmf None \<upharpoonleft> A = return_pmf None"
by(simp add: restrict_spmf_def)
  1896 
(* Restriction distributes into the continuation of bind_pmf. *)
lemma restrict_bind_pmf: "bind_pmf p g \<upharpoonleft> A = p \<bind> (\<lambda>x. g x \<upharpoonleft> A)"
by(simp add: restrict_spmf_def map_bind_pmf o_def)
  1899 
(* Restriction distributes into the continuation of bind_spmf (the None branch of
   bind_spmf is handled by restrict_spmf_return_pmf_None via simp). *)
lemma restrict_bind_spmf: "bind_spmf p g \<upharpoonleft> A = p \<bind> (\<lambda>x. g x \<upharpoonleft> A)"
by(auto simp add: bind_spmf_def restrict_bind_pmf cong del: option.case_cong_weak cong: option.case_cong intro!: bind_pmf_cong split: option.split)
  1902 
(* Binding on a restriction = binding on the original, where outcomes not in
   Some ` A are sent to the None-continuation g None. *)
lemma bind_restrict_pmf: "bind_pmf (p \<upharpoonleft> A) g = p \<bind> (\<lambda>x. if x \<in> Some ` A then g x else g None)"
by(auto simp add: restrict_spmf_def bind_map_pmf fun_eq_iff split: bind_split intro: arg_cong2[where f=bind_pmf])
  1905 
(* spmf version of bind_restrict_pmf: values outside A abort with return_pmf None. *)
lemma bind_restrict_spmf: "bind_spmf (p \<upharpoonleft> A) g = p \<bind> (\<lambda>x. if x \<in> A then g x else return_pmf None)"
by(auto simp add: bind_spmf_def bind_restrict_pmf fun_eq_iff intro: arg_cong2[where f=bind_pmf] split: option.split)
  1908 
(* Restricting a joint spmf to second component y and projecting to the first
   component recovers the joint point probability of (x, y). *)
lemma spmf_map_restrict: "spmf (map_spmf fst (p \<upharpoonleft> (snd -` {y}))) x = spmf p (x, y)"
by(subst spmf_map)(auto intro: arg_cong2[where f=measure] simp add: spmf_conv_measure_spmf)
  1911 
(* Related restrictions have equal total weight (rel_spmf_weightD), which by
   weight_spmf_def is exactly the measure of A resp. B. *)
lemma measure_eqI_restrict_spmf:
  assumes "rel_spmf R (restrict_spmf p A) (restrict_spmf q B)"
  shows "measure (measure_spmf p) A = measure (measure_spmf q) B"
proof -
  from assms have "weight_spmf (restrict_spmf p A) = weight_spmf (restrict_spmf q B)" by(rule rel_spmf_weightD)
  thus ?thesis by(simp add: weight_spmf_def)
qed
  1919 
  1920 subsection \<open>Subprobability distributions of sets\<close>
  1921 
(* Uniform subdistribution on A: pmf_of_set A when A is finite and non-empty,
   otherwise the empty subdistribution return_pmf None. *)
definition spmf_of_set :: "'a set \<Rightarrow> 'a spmf"
where
  "spmf_of_set A = (if finite A \<and> A \<noteq> {} then spmf_of_pmf (pmf_of_set A) else return_pmf None)"
  1925 
(* Point probability of spmf_of_set, stated without side conditions on A. *)
lemma spmf_of_set: "spmf (spmf_of_set A) x = indicator A x / card A"
by(auto simp add: spmf_of_set_def)
  1928 
(* None has mass 1 exactly in the degenerate cases (A infinite or empty), else 0. *)
lemma pmf_spmf_of_set_None [simp]: "pmf (spmf_of_set A) None = indicator {A. infinite A \<or> A = {}} A"
by(simp add: spmf_of_set_def)
  1931 
(* Support of spmf_of_set: A itself if A is finite, empty otherwise. *)
lemma set_spmf_of_set: "set_spmf (spmf_of_set A) = (if finite A then A else {})"
by(simp add: spmf_of_set_def)
  1934 
(* Simp rule: finite case of set_spmf_of_set. *)
lemma set_spmf_of_set_finite [simp]: "finite A \<Longrightarrow> set_spmf (spmf_of_set A) = A"
by(simp add: set_spmf_of_set)
  1937 
(* The uniform distribution on a singleton is the point mass. *)
lemma spmf_of_set_singleton: "spmf_of_set {x} = return_spmf x"
by(simp add: spmf_of_set_def pmf_of_set_singleton)
  1940 
(* An injective map sends the uniform distribution on A to the uniform
   distribution on f ` A (finiteness transfers back via finite_imageD). *)
lemma map_spmf_of_set_inj_on [simp]:
  "inj_on f A \<Longrightarrow> map_spmf f (spmf_of_set A) = spmf_of_set (f ` A)"
by(auto simp add: spmf_of_set_def map_pmf_of_set_inj dest: finite_imageD)
  1944 
(* Simp rule folding spmf_of_pmf (pmf_of_set A) into spmf_of_set A. Proofs that
   unfold spmf_of_set_def must delete it from the simpset to avoid looping back. *)
lemma spmf_of_pmf_pmf_of_set [simp]:
  "\<lbrakk> finite A; A \<noteq> {} \<rbrakk> \<Longrightarrow> spmf_of_pmf (pmf_of_set A) = spmf_of_set A"
by(simp add: spmf_of_set_def)
  1948 
(* Total weight of spmf_of_set: 1 in the finite non-empty case, 0 otherwise. *)
lemma weight_spmf_of_set:
  "weight_spmf (spmf_of_set A) = (if finite A \<and> A \<noteq> {} then 1 else 0)"
by(auto simp only: spmf_of_set_def weight_spmf_of_pmf weight_return_pmf_None split: if_split)
  1952 
(* Simp rule: finite non-empty case of weight_spmf_of_set. *)
lemma weight_spmf_of_set_finite [simp]: "\<lbrakk> finite A; A \<noteq> {} \<rbrakk> \<Longrightarrow> weight_spmf (spmf_of_set A) = 1"
by(simp add: weight_spmf_of_set)
  1955 
(* Simp rule: infinite case of weight_spmf_of_set. *)
lemma weight_spmf_of_set_infinite [simp]: "infinite A \<Longrightarrow> weight_spmf (spmf_of_set A) = 0"
by(simp add: weight_spmf_of_set)
  1958 
(* Underlying measure of spmf_of_set: the pmf_of_set measure, or the null measure
   on count_space UNIV in the degenerate cases. *)
lemma measure_spmf_spmf_of_set:
  "measure_spmf (spmf_of_set A) = (if finite A \<and> A \<noteq> {} then measure_pmf (pmf_of_set A) else null_measure (count_space UNIV))"
by(simp add: spmf_of_set_def del: spmf_of_pmf_pmf_of_set)
  1962 
(* Extended measure of A under the uniform distribution on S: |S \<inter> A| / |S|,
   stated without side conditions on S. *)
lemma emeasure_spmf_of_set:
  "emeasure (measure_spmf (spmf_of_set S)) A = card (S \<inter> A) / card S"
by(auto simp add: measure_spmf_spmf_of_set emeasure_pmf_of_set)
  1966 
(* Real-valued counterpart of emeasure_spmf_of_set. *)
lemma measure_spmf_of_set:
  "measure (measure_spmf (spmf_of_set S)) A = card (S \<inter> A) / card S"
by(auto simp add: measure_spmf_spmf_of_set measure_pmf_of_set)
  1970 
(* Non-negative integral w.r.t. spmf_of_set A is the average of f over A
   (case split on finiteness of A). *)
lemma nn_integral_spmf_of_set: "nn_integral (measure_spmf (spmf_of_set A)) f = setsum f A / card A"
by(cases "finite A")(auto simp add: spmf_of_set_def nn_integral_pmf_of_set card_gt_0_iff simp del: spmf_of_pmf_pmf_of_set)
  1973 
(* Lebesgue integral w.r.t. spmf_of_set A is the average of f over A. *)
lemma integral_spmf_of_set: "integral\<^sup>L (measure_spmf (spmf_of_set A)) f = setsum f A / card A"
by(clarsimp simp add: spmf_of_set_def integral_pmf_of_set card_gt_0_iff simp del: spmf_of_pmf_pmf_of_set)
  1976 
(* Counterexample showing that rel_set R A B does not imply
   rel_pmf R (pmf_of_set A) (pmf_of_set B) without extra conditions on R.
   With R x y = (x \<noteq> 0 \<longrightarrow> y = 0), A = {0, 1} and B = {0, 1, 2}:
   - rel_set R A B holds;
   - any coupling pq must pair the B-values 1 and 2 (mass 2/3) with 0 \<in> A;
   - but 0 has mass only 1/2 under pmf_of_set A, a contradiction.
   rel_spmf_of_set later in this section therefore assumes bi_unique R. *)
notepad begin \<comment> \<open>@{const pmf_of_set} is not fully parametric.\<close>
  define R :: "nat \<Rightarrow> nat \<Rightarrow> bool" where "R x y \<longleftrightarrow> (x \<noteq> 0 \<longrightarrow> y = 0)" for x y
  define A :: "nat set" where "A = {0, 1}"
  define B :: "nat set" where "B = {0, 1, 2}"
  have "rel_set R A B" unfolding R_def[abs_def] A_def B_def rel_set_def by auto
  have "\<not> rel_pmf R (pmf_of_set A) (pmf_of_set B)"
  proof
    assume "rel_pmf R (pmf_of_set A) (pmf_of_set B)"
    then obtain pq where pq: "\<And>x y. (x, y) \<in> set_pmf pq \<Longrightarrow> R x y"
      and 1: "map_pmf fst pq = pmf_of_set A"
      and 2: "map_pmf snd pq = pmf_of_set B"
      by cases auto
    (* Illustrative point masses of pmf_of_set B; not referenced later. *)
    have "pmf (pmf_of_set B) 1 = 1 / 3" by(simp add: B_def)
    have "pmf (pmf_of_set B) 2 = 1 / 3" by(simp add: B_def)

    have "2 / 3 = pmf (pmf_of_set B) 1 + pmf (pmf_of_set B) 2" by(simp add: B_def)
    also have "\<dots> = measure (measure_pmf (pmf_of_set B)) ({1} \<union> {2})"
      by(subst measure_pmf.finite_measure_Union)(simp_all add: measure_pmf_single)
    also have "\<dots> = emeasure (measure_pmf pq) (snd -` {2, 1})"
      unfolding 2[symmetric] measure_pmf.emeasure_eq_measure[symmetric] by(simp)
    (* By R, almost surely every pair with second component 1 or 2 has first component 0. *)
    also have "\<dots> = emeasure (measure_pmf pq) {(0, 2), (0, 1)}"
      by(rule emeasure_eq_AE)(auto simp add: AE_measure_pmf_iff R_def dest!: pq)
    also have "\<dots> \<le> emeasure (measure_pmf pq) (fst -` {0})"
      by(rule emeasure_mono) auto
    also have "\<dots> = emeasure (measure_pmf (pmf_of_set A)) {0}"
      unfolding 1[symmetric] by simp
    also have "\<dots> = pmf (pmf_of_set A) 0"
      by(simp add: measure_pmf_single measure_pmf.emeasure_eq_measure)
    also have "pmf (pmf_of_set A) 0 = 1 / 2" by(simp add: A_def)
    finally show False by(subst (asm) ennreal_le_iff; simp)
  qed
end
  2009 
  2010 lemma rel_pmf_of_set_bij:
  2011   assumes f: "bij_betw f A B"
  2012   and A: "A \<noteq> {}" "finite A"
  2013   and B: "B \<noteq> {}" "finite B"
  2014   and R: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)"
  2015   shows "rel_pmf R (pmf_of_set A) (pmf_of_set B)"
  2016 proof(rule pmf.rel_mono_strong)
  2017   define AB where "AB = (\<lambda>x. (x, f x)) ` A"
  2018   define R' where "R' x y \<longleftrightarrow> (x, y) \<in> AB" for x y
  2019   have "(x, y) \<in> AB" if "(x, y) \<in> set_pmf (pmf_of_set AB)" for x y
  2020     using that by(auto simp add: AB_def A)
  2021   moreover have "map_pmf fst (pmf_of_set AB) = pmf_of_set A"
  2022     by(simp add: AB_def map_pmf_of_set_inj[symmetric] inj_on_def A pmf.map_comp o_def)
  2023   moreover
  2024   from f have [simp]: "inj_on f A" by(rule bij_betw_imp_inj_on)
  2025   from f have [simp]: "f ` A = B" by(rule bij_betw_imp_surj_on)
  2026   have "map_pmf snd (pmf_of_set AB) = pmf_of_set B"
  2027     by(simp add: AB_def map_pmf_of_set_inj[symmetric] inj_on_def A pmf.map_comp o_def)
  2028       (simp add: map_pmf_of_set_inj A)
  2029   ultimately show "rel_pmf (\<lambda>x y. (x, y) \<in> AB) (pmf_of_set A) (pmf_of_set B)" ..
  2030 qed(auto intro: R)
  2031 
(* spmf version of rel_pmf_of_set_bij, without finiteness/non-emptiness
   assumptions: a bijection preserves both properties, so either both sides are
   proper uniform distributions or both are return_pmf None. *)
lemma rel_spmf_of_set_bij:
  assumes f: "bij_betw f A B"
  and R: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)"
  shows "rel_spmf R (spmf_of_set A) (spmf_of_set B)"
proof -
  have "finite A \<longleftrightarrow> finite B" using f by(rule bij_betw_finite)
  moreover have "A = {} \<longleftrightarrow> B = {}" using f by(auto dest: bij_betw_empty2 bij_betw_empty1)
  ultimately show ?thesis using assms
    by(auto simp add: spmf_of_set_def simp del: spmf_of_pmf_pmf_of_set intro: rel_pmf_of_set_bij)
qed
  2042 
context includes lifting_syntax
begin

(* Transfer rule for spmf_of_set under a bi-unique relation R: rel_set R-related
   sets give rel_spmf R-related uniform distributions. The bijection needed by
   rel_spmf_of_set_bij is obtained from bi_unique_rel_set_bij_betw. *)
lemma rel_spmf_of_set:
  assumes "bi_unique R"
  shows "(rel_set R ===> rel_spmf R) spmf_of_set spmf_of_set"
proof
  fix A B
  assume R: "rel_set R A B"
  with assms obtain f where "bij_betw f A B" and f: "\<And>x. x \<in> A \<Longrightarrow> R x (f x)"
    by(auto dest: bi_unique_rel_set_bij_betw)
  then show "rel_spmf R (spmf_of_set A) (spmf_of_set B)" by(rule rel_spmf_of_set_bij)
qed

end
  2058 
(* Testing membership in A for a uniform sample from a finite non-empty B gives
   a Bernoulli distribution with parameter |A \<inter> B| / |B|. *)
lemma map_mem_spmf_of_set:
  assumes "finite B" "B \<noteq> {}"
  shows "map_spmf (\<lambda>x. x \<in> A) (spmf_of_set B) = spmf_of_pmf (bernoulli_pmf (card (A \<inter> B) / card B))"
  (is "?lhs = ?rhs")
proof(rule spmf_eqI)
  fix i
  (* Probability of outcome i is the relative size of the preimage of i in B ... *)
  have "ennreal (spmf ?lhs i) = card (B \<inter> (\<lambda>x. x \<in> A) -` {i}) / (card B)"
    by(subst ennreal_spmf_map)(simp add: measure_spmf_spmf_of_set assms emeasure_pmf_of_set)
  (* ... which is B \<inter> A for i = True and B - A for i = False ... *)
  also have "\<dots> = (if i then card (B \<inter> A) / card B else card (B - A) / card B)"
    by(auto intro: arg_cong[where f=card])
  (* ... and card (B - A) = card B - card (B \<inter> A) for finite B. *)
  also have "\<dots> = (if i then card (B \<inter> A) / card B else (card B - card (B \<inter> A)) / card B)"
    by(auto simp add: card_Diff_subset_Int assms)
  also have "\<dots> = ennreal (spmf ?rhs i)"
    by(simp add: assms card_gt_0_iff field_simps card_mono Int_commute of_nat_diff)
  finally show "spmf ?lhs i = spmf ?rhs i" by simp
qed
  2075 
(* A fair coin: the uniform subdistribution on all booleans. *)
abbreviation coin_spmf :: "bool spmf"
where "coin_spmf \<equiv> spmf_of_set UNIV"
  2078 
(* Comparing a fair coin with a fixed c is again a fair coin: op = c (i.e.
   op \<longleftrightarrow> c on bool) is injective and surjective, so map_spmf_of_set_inj_on applies. *)
lemma map_eq_const_coin_spmf: "map_spmf (op = c) coin_spmf = coin_spmf"
proof -
  have "inj (op \<longleftrightarrow> c)" "range (op \<longleftrightarrow> c) = UNIV" by(auto intro: inj_onI)
  then show ?thesis by simp
qed
  2084 
(* Bind formulation of map_eq_const_coin_spmf. *)
lemma bind_coin_spmf_eq_const: "coin_spmf \<bind> (\<lambda>x :: bool. return_spmf (b = x)) = coin_spmf"
using map_eq_const_coin_spmf unfolding map_spmf_conv_bind_spmf by simp
  2087 
(* Symmetric variant (x = b instead of b = x). The right-hand coin_spmf is
   rewritten with bind_coin_spmf_eq_const and both binds are matched by bind_spmf_cong. *)
lemma bind_coin_spmf_eq_const': "coin_spmf \<bind> (\<lambda>x :: bool. return_spmf (x = b)) = coin_spmf"
by(rewrite in "_ = \<hole>" bind_coin_spmf_eq_const[symmetric, of b])(auto intro: bind_spmf_cong)
  2090 
  2091 subsection \<open>Losslessness\<close>
  2092 
(* An spmf is lossless if its total weight is 1, i.e. it is a proper distribution. *)
definition lossless_spmf :: "'a spmf \<Rightarrow> bool"
where "lossless_spmf p \<longleftrightarrow> weight_spmf p = 1"
  2095 
(* Losslessness is equivalent to None having probability 0. *)
lemma lossless_iff_pmf_None: "lossless_spmf p \<longleftrightarrow> pmf p None = 0"
by(simp add: lossless_spmf_def pmf_None_eq_weight_spmf)
  2098 
(* Point masses are lossless. *)
lemma lossless_return_spmf [iff]: "lossless_spmf (return_spmf x)"
by(simp add: lossless_iff_pmf_None)
  2101 
(* The empty subdistribution is not lossless. *)
lemma lossless_return_pmf_None [iff]: "\<not> lossless_spmf (return_pmf None)"
by(simp add: lossless_iff_pmf_None)
  2104 
(* map_spmf neither creates nor removes loss. *)
lemma lossless_map_spmf [simp]: "lossless_spmf (map_spmf f p) \<longleftrightarrow> lossless_spmf p"
by(auto simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf)
  2107 
(* A bind is lossless iff the prior is lossless and every continuation reachable
   from the prior's support is lossless. *)
lemma lossless_bind_spmf [simp]:
  "lossless_spmf (p \<bind> f) \<longleftrightarrow> lossless_spmf p \<and> (\<forall>x\<in>set_spmf p. lossless_spmf (f x))"
by(simp add: lossless_iff_pmf_None pmf_bind_spmf_None add_nonneg_eq_0_iff integral_nonneg_AE integral_nonneg_eq_0_iff_AE measure_spmf.integrable_const_bound[where B=1] pmf_le_1)
  2111 
(* Destruction rule exposing the definition of lossless_spmf. *)
lemma lossless_weight_spmfD: "lossless_spmf p \<Longrightarrow> weight_spmf p = 1"
by(simp add: lossless_spmf_def)
  2114 
(* Support-based characterisation: lossless iff None is not in the support. *)
lemma lossless_iff_set_pmf_None:
  "lossless_spmf p \<longleftrightarrow> None \<notin> set_pmf p"
by (simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf)
  2118 
(* spmf_of_set A is lossless exactly for finite non-empty A. *)
lemma lossless_spmf_of_set [simp]: "lossless_spmf (spmf_of_set A) \<longleftrightarrow> finite A \<and> A \<noteq> {}"
by(auto simp add: lossless_spmf_def weight_spmf_of_set)
  2121 
(* Every embedded pmf is lossless. The lemma name mentions "spmf_of_spmf", but
   the statement is about spmf_of_pmf; the name is kept for compatibility. *)
lemma lossless_spmf_spmf_of_spmf [simp]: "lossless_spmf (spmf_of_pmf p)"
by(simp add: lossless_spmf_def)
  2124 
(* bind_pmf with spmf-valued continuations is lossless iff every continuation
   reachable from the support of p is lossless. *)
lemma lossless_spmf_bind_pmf [simp]:
  "lossless_spmf (bind_pmf p f) \<longleftrightarrow> (\<forall>x\<in>set_pmf p. lossless_spmf (f x))"
by(simp add: lossless_iff_pmf_None pmf_bind integral_nonneg_AE integral_nonneg_eq_0_iff_AE measure_pmf.integrable_const_bound[where B=1] AE_measure_pmf_iff pmf_le_1)
  2128 
(* Lossless spmfs are exactly the embedded pmfs. For the forward direction the
   witness is map_pmf the p, which is well-behaved because p never returns None;
   the backward direction is closed by auto. *)
lemma lossless_spmf_conv_spmf_of_pmf: "lossless_spmf p \<longleftrightarrow> (\<exists>p'. p = spmf_of_pmf p')"
proof
  assume "lossless_spmf p"
  hence *: "\<And>y. y \<in> set_pmf p \<Longrightarrow> \<exists>x. y = Some x"
    by(case_tac y)(simp_all add: lossless_iff_set_pmf_None)

  let ?p = "map_pmf the p"
  have "p = spmf_of_pmf ?p"
  proof(rule spmf_eqI)
    fix i
    have "ennreal (pmf (map_pmf the p) i) = \<integral>\<^sup>+ x. indicator (the -` {i}) x \<partial>p" by(simp add: ennreal_pmf_map)
    (* On the support of p every outcome is Some _, so the preimage of i under
       the agrees with {Some i} almost everywhere. *)
    also have "\<dots> = \<integral>\<^sup>+ x. indicator {i} x \<partial>measure_spmf p" unfolding measure_spmf_def
      by(subst nn_integral_distr)(auto simp add: nn_integral_restrict_space AE_measure_pmf_iff simp del: nn_integral_indicator intro!: nn_integral_cong_AE split: split_indicator dest!: * )
    also have "\<dots> = spmf p i" by(simp add: emeasure_spmf_single)
    finally show "spmf p i = spmf (spmf_of_pmf ?p) i" by simp
  qed
  thus "\<exists>p'. p = spmf_of_pmf p'" ..
qed auto
  2147 
(* For a lossless boolean spmf the two outcomes are complementary; reduced to the
   pmf fact pmf_False_conv_True via lossless_spmf_conv_spmf_of_pmf. *)
lemma spmf_False_conv_True: "lossless_spmf p \<Longrightarrow> spmf p False = 1 - spmf p True"
by(clarsimp simp add: lossless_spmf_conv_spmf_of_pmf pmf_False_conv_True)
  2150 
(* Symmetric form of spmf_False_conv_True. *)
lemma spmf_True_conv_False: "lossless_spmf p \<Longrightarrow> spmf p True = 1 - spmf p False"
by(simp add: spmf_False_conv_True)
  2153 
(* A bind is a point mass on x iff the prior is lossless and every continuation
   reachable from its support is that point mass. *)
lemma bind_eq_return_spmf:
  "bind_spmf p f = return_spmf x \<longleftrightarrow> (\<forall>y\<in>set_spmf p. f y = return_spmf x) \<and> lossless_spmf p"
by(auto simp add: bind_spmf_def bind_eq_return_pmf in_set_spmf lossless_iff_pmf_None pmf_eq_0_set_pmf iff del: not_None_eq split: option.split)
  2157 
(* Relating p to a point mass on x requires p to be lossless and every value in
   its support to be R-related to x. *)
lemma rel_spmf_return_spmf2:
  "rel_spmf R p (return_spmf x) \<longleftrightarrow> lossless_spmf p \<and> (\<forall>a\<in>set_spmf p. R a x)"
by(auto simp add: lossless_iff_set_pmf_None rel_pmf_return_pmf2 option_rel_Some2 in_set_spmf, metis in_set_spmf not_None_eq)
  2161 
(* Mirror image of rel_spmf_return_spmf2, obtained through the converse relation. *)
lemma rel_spmf_return_spmf1:
  "rel_spmf R (return_spmf x) p \<longleftrightarrow> lossless_spmf p \<and> (\<forall>a\<in>set_spmf p. R x a)"
using rel_spmf_return_spmf2[of "R\<inverse>\<inverse>"] by(simp add: spmf_rel_conversep)
  2165 
(* One-sided bind rule: a lossless prior p whose continuations are all related to
   q can be bound on the left. Proof: view q as bind_spmf (return_spmf x) (\<lambda>_. q)
   for a dummy x and apply rel_spmf_bindI with the relation "x \<in> set_spmf p". *)
lemma rel_spmf_bindI1:
  assumes f: "\<And>x. x \<in> set_spmf p \<Longrightarrow> rel_spmf R (f x) q"
  and p: "lossless_spmf p"
  shows "rel_spmf R (bind_spmf p f) q"
proof -
  fix x :: 'a
  have "rel_spmf R (bind_spmf p f) (bind_spmf (return_spmf x) (\<lambda>_. q))"
    by(rule rel_spmf_bindI[where R="\<lambda>x _. x \<in> set_spmf p"])(simp_all add: rel_spmf_return_spmf2 p f)
  then show ?thesis by simp
qed
  2176 
(* Right-hand counterpart of rel_spmf_bindI1, via the converse relation. *)
lemma rel_spmf_bindI2:
  "\<lbrakk> \<And>x. x \<in> set_spmf q \<Longrightarrow> rel_spmf R p (f x); lossless_spmf q \<rbrakk>
  \<Longrightarrow> rel_spmf R p (bind_spmf q f)"
using rel_spmf_bindI1[of q "conversep R" f p] by(simp add: spmf_rel_conversep)
  2181 
  2182 subsection \<open>Scaling\<close>
  2183 
(* scale_spmf r p multiplies every point probability of p by r, where r is first
   clamped below by 0 and above by inverse (weight_spmf p). The upper clamp keeps
   the total mass at most 1 (see scale_spmf_le_1). *)
definition scale_spmf :: "real \<Rightarrow> 'a spmf \<Rightarrow> 'a spmf"
where
  "scale_spmf r p = embed_spmf (\<lambda>x. min (inverse (weight_spmf p)) (max 0 r) * spmf p x)"
  2187 
(* The scaled density in the definition of scale_spmf has total mass at most 1,
   which is the side condition needed by spmf_embed_spmf. *)
lemma scale_spmf_le_1:
  "(\<integral>\<^sup>+ x. min (inverse (weight_spmf p)) (max 0 r) * spmf p x \<partial>count_space UNIV) \<le> 1" (is "?lhs \<le> _")
proof -
  (* Pull the constant factor out of the integral ... *)
  have "?lhs = min (inverse (weight_spmf p)) (max 0 r) * \<integral>\<^sup>+ x. spmf p x \<partial>count_space UNIV"
    by(subst nn_integral_cmult[symmetric])(simp_all add: weight_spmf_nonneg max_def min_def ennreal_mult)
  (* ... the remaining integral is the weight, and the factor is at most its inverse. *)
  also have "\<dots> \<le> 1" unfolding weight_spmf_eq_nn_integral_spmf[symmetric]
    by(simp add: min_def max_def weight_spmf_nonneg order.strict_iff_order field_simps ennreal_mult[symmetric])
  finally show ?thesis .
qed
  2197 
(* Point probabilities of scale_spmf. The clamp is written here as
   max 0 (min (inverse weight) r), while the definition uses
   min (inverse weight) (max 0 r); the apply script shows they agree, using
   non-negativity of the weight. *)
lemma spmf_scale_spmf: "spmf (scale_spmf r p) x = max 0 (min (inverse (weight_spmf p)) r) * spmf p x" (is "?lhs = ?rhs")
unfolding scale_spmf_def
apply(subst spmf_embed_spmf[OF scale_spmf_le_1])
apply(simp add: max_def min_def weight_spmf_le_0 field_simps weight_spmf_nonneg not_le order.strict_iff_order)
apply(metis antisym_conv order_trans weight_spmf_nonneg zero_le_mult_iff zero_le_one)
done
  2204 
(* For 0 \<le> x \<le> 1, 1 / x \<le> 1 holds only at x = 1 and at x = 0 (the latter
   because HOL's division by zero yields 0). Used to analyse the clamp
   inverse (weight_spmf p) in the scaling lemmas below. *)
lemma real_inverse_le_1_iff: fixes x :: real
  shows "\<lbrakk> 0 \<le> x; x \<le> 1 \<rbrakk> \<Longrightarrow> 1 / x \<le> 1 \<longleftrightarrow> x = 1 \<or> x = 0"
by auto
  2208 
(* For r \<le> 1 the upper clamp by inverse weight is irrelevant (the weight is at
   most 1), so the point probabilities are simply scaled by max 0 r. *)
lemma spmf_scale_spmf': "r \<le> 1 \<Longrightarrow> spmf (scale_spmf r p) x = max 0 r * spmf p x"
using real_inverse_le_1_iff[OF weight_spmf_nonneg weight_spmf_le_1, of p]
by(auto simp add: spmf_scale_spmf max_def min_def field_simps)(metis pmf_le_0_iff spmf_le_weight)
  2212 
  (* Scaling by a non-positive factor yields the empty subdistribution. *)
  2213 lemma scale_spmf_neg: "r \<le> 0 \<Longrightarrow> scale_spmf r p = return_pmf None"
  2214 by(rule spmf_eqI)(simp add: spmf_scale_spmf' max_def)
  2215 
  (* Scaling the empty subdistribution leaves it empty, for any factor r. *)
  2216 lemma scale_spmf_return_None [simp]: "scale_spmf r (return_pmf None) = return_pmf None"
  2217 by(rule spmf_eqI)(simp add: spmf_scale_spmf)
  2218 
  (* For r <= 1, scaling p by r is the same as first sampling bernoulli_pmf r
     and then running p on True and failing (return_pmf None) on False. *)
  2219 lemma scale_spmf_conv_bind_bernoulli:
  2220   assumes "r \<le> 1"
  2221   shows "scale_spmf r p = bind_pmf (bernoulli_pmf r) (\<lambda>b. if b then p else return_pmf None)" (is "?lhs = ?rhs")
  2222 proof(rule spmf_eqI)
  2223   fix x
  (* Compare the probabilities in ennreal, where the bind unfolds to an
     integral over the two booleans. *)
  2224   have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using assms
  2225     unfolding spmf_scale_spmf ennreal_pmf_bind nn_integral_measure_pmf UNIV_bool bernoulli_pmf.rep_eq
  2226     apply(auto simp add: nn_integral_count_space_finite max_def min_def field_simps real_inverse_le_1_iff[OF weight_spmf_nonneg weight_spmf_le_1] weight_spmf_lt_0 not_le ennreal_mult[symmetric])
  2227     apply (metis pmf_le_0_iff spmf_le_weight)
  2228     apply (metis pmf_le_0_iff spmf_le_weight)
  2229     apply (meson le_divide_eq_1_pos measure_spmf.subprob_measure_le_1 not_less order_trans weight_spmf_le_0)
  2230     by (meson divide_le_0_1_iff less_imp_le order_trans weight_spmf_le_0)
  2231   thus "spmf ?lhs x = spmf ?rhs x" by simp
  2232 qed
  2233 
  (* Integrating the probability function of p over the counting measure on A
     gives the measure of A under measure_spmf p. *)
  2234 lemma nn_integral_spmf: "(\<integral>\<^sup>+ x. spmf p x \<partial>count_space A) = emeasure (measure_spmf p) A"
  2235 apply(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space nn_integral_pmf[symmetric])
  (* Reindex the counting-measure integral along the bijection Some. *)
  2236 apply(rule nn_integral_bij_count_space[where g=Some])
  2237 apply(auto simp add: bij_betw_def)
  2238 done
  2239 
  (* The measure of a scaled spmf is the original measure scaled by
     min (inverse (weight_spmf p)) r. *)
  2240 lemma measure_spmf_scale_spmf: "measure_spmf (scale_spmf r p) = scale_measure (min (inverse (weight_spmf p)) r) (measure_spmf p)"
  2241 apply(rule measure_eqI)
  2242  apply simp
  2243 apply(simp add: nn_integral_spmf[symmetric] spmf_scale_spmf)
  2244 apply(subst nn_integral_cmult[symmetric])
  2245 apply(auto simp add: max_def min_def ennreal_mult[symmetric] not_le ennreal_lt_0)
  2246 done
  2247 
  (* For r <= 1 the measure is simply scaled by r.  The proof splits on whether
     p has positive weight; the zero-weight case is handled with
     weight_spmf_eq_0 and scale_spmf_neg. *)
  2248 lemma measure_spmf_scale_spmf':
  2249   "r \<le> 1 \<Longrightarrow> measure_spmf (scale_spmf r p) = scale_measure r (measure_spmf p)"
  2250 unfolding measure_spmf_scale_spmf
  2251 apply(cases "weight_spmf p > 0")
  2252  apply(simp add: min.absorb2 field_simps weight_spmf_le_1 mult_le_one)
  2253 apply(clarsimp simp add: weight_spmf_le_0 min_def scale_spmf_neg weight_spmf_eq_0 not_less)
  2254 done
  2255 
  (* Scaling by 1 is the identity. *)
  2256 lemma scale_spmf_1 [simp]: "scale_spmf 1 p = p"
  2257 apply(rule spmf_eqI)
  2258 apply(simp add: spmf_scale_spmf max_def min_def order.strict_iff_order field_simps weight_spmf_nonneg)
  2259 apply(metis antisym_conv divide_le_eq_1 less_imp_le pmf_nonneg spmf_le_weight weight_spmf_nonneg weight_spmf_le_1)
  2260 done
  2261 
  (* Scaling by 0 yields the empty subdistribution. *)
  2262 lemma scale_spmf_0 [simp]: "scale_spmf 0 p = return_pmf None"
  2263 by(rule spmf_eqI)(simp add: spmf_scale_spmf min_def max_def weight_spmf_le_0)
  2264 
  (* For r <= 1, scaling the first argument of bind_spmf is the same as scaling
     every result of the continuation. *)
  2265 lemma bind_scale_spmf:
  2266   assumes r: "r \<le> 1"
  2267   shows "bind_spmf (scale_spmf r p) f = bind_spmf p (\<lambda>x. scale_spmf r (f x))"
  2268   (is "?lhs = ?rhs")
  2269 proof(rule spmf_eqI)
  2270   fix x
  (* Compare the probabilities in ennreal, where the bind becomes an integral. *)
  2271   have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using r
  2272     by(simp add: ennreal_spmf_bind measure_spmf_scale_spmf' nn_integral_scale_measure spmf_scale_spmf')
  2273       (simp add: ennreal_mult ennreal_lt_0 nn_integral_cmult max_def min_def)
  2274   thus "spmf ?lhs x = spmf ?rhs x" by simp
  2275 qed
  2276 
  (* For r <= 1, scaling a bind is the same as scaling every result of the
     continuation. *)
  2277 lemma scale_bind_spmf:
  2278   assumes "r \<le> 1"
  2279   shows "scale_spmf r (bind_spmf p f) = bind_spmf p (\<lambda>x. scale_spmf r (f x))"
  2280   (is "?lhs = ?rhs")
  2281 proof(rule spmf_eqI)
  2282   fix x
  (* Compare in ennreal; the constant factor max 0 r is pulled out of the integral. *)
  2283   have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using assms
  2284     unfolding spmf_scale_spmf'[OF assms]
  2285     by(simp add: ennreal_mult ennreal_spmf_bind spmf_scale_spmf' nn_integral_cmult max_def min_def)
  2286   thus "spmf ?lhs x = spmf ?rhs x" by simp
  2287 qed
  2288 
  (* Binding p into a constant continuation q ignores p's results; only p's
     weight remains, as the scaling factor of q. *)
  2289 lemma bind_spmf_const: "bind_spmf p (\<lambda>x. q) = scale_spmf (weight_spmf p) q" (is "?lhs = ?rhs")
  2290 proof(rule spmf_eqI)
  2291   fix x
  2292   have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)"
  2293     using measure_spmf.subprob_measure_le_1[of p "space (measure_spmf p)"]
  2294     by(subst ennreal_spmf_bind)(simp add: spmf_scale_spmf' weight_spmf_le_1 ennreal_mult mult.commute max_def min_def measure_spmf.emeasure_eq_measure)
  2295   thus "spmf ?lhs x = spmf ?rhs x" by simp
  2296 qed
  2297 
  (* Mapping a function over an spmf commutes with scaling. *)
  2298 lemma map_scale_spmf: "map_spmf f (scale_spmf r p) = scale_spmf r (map_spmf f p)" (is "?lhs = ?rhs")
  2299 proof(rule spmf_eqI)
  2300   fix i
  2301   show "spmf ?lhs i = spmf ?rhs i" unfolding spmf_scale_spmf
  2302     by(subst (1 2) spmf_map)(auto simp add: measure_spmf_scale_spmf max_def min_def ennreal_lt_0)
  2303 qed
  2304 
  (* Positive scaling preserves the support; any other factor empties it. *)
  2305 lemma set_scale_spmf: "set_spmf (scale_spmf r p) = (if r > 0 then set_spmf p else {})"
  2306 apply(auto simp add: in_set_spmf_iff_spmf spmf_scale_spmf)
  2307 apply(simp add: max_def min_def not_le weight_spmf_lt_0 weight_spmf_eq_0 split: if_split_asm)
  2308 done
  2309 
  (* Simp rule for the positive case of set_scale_spmf. *)
  2310 lemma set_scale_spmf' [simp]: "0 < r \<Longrightarrow> set_spmf (scale_spmf r p) = set_spmf p"
  2311 by(simp add: set_scale_spmf)
  2312 
  (* Scaling both sides by the same factor preserves rel_spmf.  The relation
     between p and q is only needed for r > 0, because otherwise both sides
     are the empty subdistribution. *)
  2313 lemma rel_spmf_scaleI:
  2314   assumes "r > 0 \<Longrightarrow> rel_spmf A p q"
  2315   shows "rel_spmf A (scale_spmf r p) (scale_spmf r q)"
  2316 proof(cases "r > 0")
  2317   case True
  (* Obtain a coupling of p and q and push the scaling through its projections. *)
  2318   from assms[OF this] show ?thesis
  2319     by(rule rel_spmfE)(auto simp add: map_scale_spmf[symmetric] spmf_rel_map True intro: rel_spmf_reflI)
  (* Case r <= 0: both sides are return_pmf None by scale_spmf_neg. *)
  2320 qed(simp add: not_less scale_spmf_neg)
  2321 
  (* Weight of a scaled spmf: the non-negative part of r times the weight of p,
     capped at 1. *)
  2322 lemma weight_scale_spmf: "weight_spmf (scale_spmf r p) = min 1 (max 0 r * weight_spmf p)"
  2323 proof -
  (* First established in ennreal, via the integral characterisation of the weight. *)
  2324   have "ennreal (weight_spmf (scale_spmf r p)) = min 1 (max 0 r * ennreal (weight_spmf p))"
  2325     unfolding weight_spmf_eq_nn_integral_spmf
  2326     apply(simp add: spmf_scale_spmf ennreal_mult zero_ereal_def[symmetric] nn_integral_cmult)
  2327     apply(auto simp add: weight_spmf_eq_nn_integral_spmf[symmetric] field_simps min_def max_def not_le weight_spmf_lt_0 ennreal_mult[symmetric])
  2328     subgoal by(subst (asm) ennreal_mult[symmetric], meson divide_less_0_1_iff le_less_trans not_le weight_spmf_lt_0, simp+, meson not_le pos_divide_le_eq weight_spmf_le_0)
  2329     subgoal by(cases "r \<ge> 0")(simp_all add: ennreal_mult[symmetric] weight_spmf_nonneg ennreal_lt_0, meson le_less_trans not_le pos_divide_le_eq zero_less_divide_1_iff)
  2330     done
  2331   thus ?thesis by(auto simp add: min_def max_def ennreal_mult[symmetric] split: if_split_asm)
  2332 qed
  2333 
  (* For r in [0, 1] the weight scales linearly. *)
  2334 lemma weight_scale_spmf' [simp]:
  2335   "\<lbrakk> 0 \<le> r; r \<le> 1 \<rbrakk> \<Longrightarrow> weight_spmf (scale_spmf r p) = r * weight_spmf p"
  2336 by(simp add: weight_scale_spmf max_def min_def)(metis antisym_conv mult_left_le order_trans weight_spmf_le_1)
  2337 
  (* Probability of None after scaling, expressed via the probability of None in p. *)
  2338 lemma pmf_scale_spmf_None:
  2339   "pmf (scale_spmf k p) None = 1 - min 1 (max 0 k * (1 - pmf p None))"
  2340 unfolding pmf_None_eq_weight_spmf by(simp add: weight_scale_spmf)
  2341 
  (* Nested scalings collapse into a single one.  The inner factor r' is
     clamped as in spmf_scale_spmf before it is multiplied with r. *)
  2342 lemma scale_scale_spmf:
  2343   "scale_spmf r (scale_spmf r' p) = scale_spmf (r * max 0 (min (inverse (weight_spmf p)) r')) p"
  2344   (is "?lhs = ?rhs")
  2345 proof(rule spmf_eqI)
  2346   fix i
  (* Arithmetic core: the product of the two clamped factors (the outer one uses
     the weight of scale_spmf r' p given by weight_scale_spmf) equals the
     clamped combined factor. *)
  2347   have "max 0 (min (1 / weight_spmf p) r') *
  2348     max 0 (min (1 / min 1 (weight_spmf p * max 0 r')) r) =
  2349     max 0 (min (1 / weight_spmf p) (r * max 0 (min (1 / weight_spmf p) r')))"
  2350   proof(cases "weight_spmf p > 0")
  2351     case False
  2352     thus ?thesis by(simp add: not_less weight_spmf_le_0)
  2353   next
  2354     case True
  2355     thus ?thesis by(simp add: field_simps max_def min.absorb_iff2[symmetric])(auto simp add: min_def field_simps zero_le_mult_iff)
  2356   qed
  2357   then show "spmf ?lhs i = spmf ?rhs i"
  2358     by(simp add: spmf_scale_spmf field_simps weight_scale_spmf)
  2359 qed
  2360 
  (* For factors in [0, 1], nested scalings simply multiply. *)
  2361 lemma scale_scale_spmf' [simp]:
  2362   "\<lbrakk> 0 \<le> r; r \<le> 1; 0 \<le> r'; r' \<le> 1 \<rbrakk>
  2363   \<Longrightarrow> scale_spmf r (scale_spmf r' p) = scale_spmf (r * r') p"
  2364 apply(cases "weight_spmf p > 0")
  2365 apply(auto simp add: scale_scale_spmf min_def max_def field_simps not_le weight_spmf_lt_0 weight_spmf_eq_0 not_less weight_spmf_le_0)
  2366 apply(subgoal_tac "1 = r'")
  2367  apply (metis (no_types) divide_1 eq_iff measure_spmf.subprob_measure_le_1 mult.commute mult_cancel_right1)
  2368 apply(meson eq_iff le_divide_eq_1_pos measure_spmf.subprob_measure_le_1 mult_imp_div_pos_le order.trans)
  2369 done
  2370 
  (* Scaling leaves p unchanged exactly when p has weight 0, or r = 1, or
     r >= 1 and p has weight 1. *)
  2371 lemma scale_spmf_eq_same: "scale_spmf r p = p \<longleftrightarrow> weight_spmf p = 0 \<or> r = 1 \<or> r \<ge> 1 \<and> weight_spmf p = 1"
  2372   (is "?lhs \<longleftrightarrow> ?rhs")
  2373 proof
  (* Forward direction: compare the weights of both sides. *)
  2374   assume ?lhs
  2375   hence "weight_spmf (scale_spmf r p) = weight_spmf p" by simp
  2376   hence *: "min 1 (max 0 r * weight_spmf p) = weight_spmf p" by(simp add: weight_scale_spmf)
  2377   hence **: "weight_spmf p = 0 \<or> r \<ge> 1" by(auto simp add: min_def max_def split: if_split_asm)
  2378   show ?rhs
  2379   proof(cases "weight_spmf p = 0")
  2380     case False
  2381     with ** have "r \<ge> 1" by simp
  2382     with * False have "r = 1 \<or> weight_spmf p = 1" by(simp add: max_def min_def not_le split: if_split_asm)
  2383     with \<open>r \<ge> 1\<close> show ?thesis by simp
  2384   qed simp
  (* Backward direction: pointwise equality via spmf_eqI. *)
  2385 qed(auto intro!: spmf_eqI simp add: spmf_scale_spmf, metis pmf_le_0_iff spmf_le_weight)
  2386 
  (* Mapping a constant function over spmf_of_set A, for a finite non-empty A,
     yields return_spmf of that constant. *)
  2387 lemma map_const_spmf_of_set:
  2388   "\<lbrakk> finite A; A \<noteq> {} \<rbrakk> \<Longrightarrow> map_spmf (\<lambda>_. c) (spmf_of_set A) = return_spmf c"
  2389 by(simp add: map_spmf_conv_bind_spmf bind_spmf_const)
  2390 
  2391 subsection \<open>Conditional spmfs\<close>
  2392 
  (* The underlying pmf avoids Some ` A exactly when the support of the spmf avoids A. *)
  2393 lemma set_pmf_Int_Some: "set_pmf p \<inter> Some ` A = {} \<longleftrightarrow> set_spmf p \<inter> A = {}"
  2394 by(auto simp add: in_set_spmf)
  2395 
  (* An event has probability 0 exactly when it contains no support point. *)
  2396 lemma measure_spmf_zero_iff: "measure (measure_spmf p) A = 0 \<longleftrightarrow> set_spmf p \<inter> A = {}"
  2397 unfolding measure_measure_spmf_conv_measure_pmf by(simp add: measure_pmf_zero_iff set_pmf_Int_Some)
  2398 
  (* Conditioning p on the event A.  If no support point of p lies in A, the
     result is the empty subdistribution; otherwise the underlying pmf is
     conditioned on the event Some ` A. *)
  2399 definition cond_spmf :: "'a spmf \<Rightarrow> 'a set \<Rightarrow> 'a spmf"
  2400 where "cond_spmf p A = (if set_spmf p \<inter> A = {} then return_pmf None else cond_pmf p (Some ` A))"
  2401 
  (* Conditioning restricts the support to A. *)
  2402 lemma set_cond_spmf [simp]: "set_spmf (cond_spmf p A) = set_spmf p \<inter> A"
  2403 by(auto 4 4 simp add: cond_spmf_def in_set_spmf iff: set_cond_pmf[THEN set_eq_iff[THEN iffD1], THEN spec, rotated])
  2404 
  (* Conditioning commutes with mapping, if the event is replaced by its
     preimage under f. *)
  2405 lemma cond_map_spmf [simp]: "cond_spmf (map_spmf f p) A = map_spmf f (cond_spmf p (f -` A))"
  2406 proof -
  2407   have "map_option f -` Some ` A = Some ` f -` A" by auto
  (* Needed to rule out the empty-intersection branch of cond_spmf_def. *)
  2408   moreover have "set_pmf p \<inter> map_option f -` Some ` A \<noteq> {}" if "Some x \<in> set_pmf p" "f x \<in> A" for x
  2409     using that by auto
  2410   ultimately show ?thesis by(auto simp add: cond_spmf_def in_set_spmf cond_map_pmf)
  2411 qed
  2412 
  (* Probability of x after conditioning on A: inside A it is the original
     probability divided by the measure of A, outside A it is 0. *)
  2413 lemma spmf_cond_spmf [simp]:
  2414   "spmf (cond_spmf p A) x = (if x \<in> A then spmf p x / measure (measure_spmf p) A else 0)"
  2415 by(auto simp add: cond_spmf_def pmf_cond set_pmf_Int_Some[symmetric] measure_measure_spmf_conv_measure_pmf measure_pmf_zero_iff)
  2416 
  (* A bind is return_pmf None iff the continuation is return_pmf None on every
     support point of p. *)
  2417 lemma bind_eq_return_pmf_None:
  2418   "bind_spmf p f = return_pmf None \<longleftrightarrow> (\<forall>x\<in>set_spmf p. f x = return_pmf None)"
  2419 by(auto simp add: bind_spmf_def bind_eq_return_pmf in_set_spmf split: option.splits)
  2420 
  (* Symmetric orientation of bind_eq_return_pmf_None. *)
  2421 lemma return_pmf_None_eq_bind:
  2422   "return_pmf None = bind_spmf p f \<longleftrightarrow> (\<forall>x\<in>set_spmf p. f x = return_pmf None)"
  2423 using bind_eq_return_pmf_None[of p f] by auto
  2424 
  2425 (* Conditioning (cond_spmf) does not seem to interact nicely with bind_spmf, so no lemma relating the two is stated here. *)
  2426 
  2427 subsection \<open>Product spmf\<close>
  2428 
  (* Independent product of two spmfs: draw from p and q independently and
     fail (return_pmf None) if either draw yields None. *)
  2429 definition pair_spmf :: "'a spmf \<Rightarrow> 'b spmf \<Rightarrow> ('a \<times> 'b) spmf"
  2430 where "pair_spmf p q = bind_pmf (pair_pmf p q) (\<lambda>xy. case xy of (Some x, Some y) \<Rightarrow> return_spmf (x, y) | _ \<Rightarrow> return_pmf None)"
  2431 
  (* First marginal of a product: p scaled by the weight of q.  The proof turns
     the scaling back into a constant bind (bind_spmf_const). *)
  2432 lemma map_fst_pair_spmf [simp]: "map_spmf fst (pair_spmf p q) = scale_spmf (weight_spmf q) p"
  2433 unfolding bind_spmf_const[symmetric]
  2434 apply(simp add: pair_spmf_def map_bind_pmf pair_pmf_def bind_assoc_pmf option.case_distrib)
  (* Swap the order of the two samplings. *)
  2435 apply(subst bind_commute_pmf)
  2436 apply(auto intro!: bind_pmf_cong[OF refl] simp add: bind_return_pmf bind_spmf_def bind_return_pmf' case_option_collapse option.case_distrib[where h="map_spmf _"] option.case_distrib[symmetric] case_option_id split: option.split cong del: option.case_cong)
  2437 done
  2438 
  (* Second marginal of a product: q scaled by the weight of p. *)
  2439 lemma map_snd_pair_spmf [simp]: "map_spmf snd (pair_spmf p q) = scale_spmf (weight_spmf p) q"
  2440 unfolding bind_spmf_const[symmetric]
  2441 apply(simp add: pair_spmf_def map_bind_pmf pair_pmf_def bind_assoc_pmf option.case_distrib cong del: option.case_cong)
  2442 apply(auto intro!: bind_pmf_cong[OF refl] simp add: bind_return_pmf bind_spmf_def bind_return_pmf' case_option_collapse option.case_distrib[where h="map_spmf _"] option.case_distrib[symmetric] case_option_id split: option.split cong del: option.case_cong)
  2443 done
  2444 
  (* The support of a product is the Cartesian product of the supports. *)
  2445 lemma set_pair_spmf [simp]: "set_spmf (pair_spmf p q) = set_spmf p \<times> set_spmf q"
  2446 by(auto 4 3 simp add: pair_spmf_def set_spmf_bind_pmf bind_UNION in_set_spmf intro: rev_bexI split: option.splits)
  2447 
  (* Probabilities in a product multiply. *)
  2448 lemma spmf_pair [simp]: "spmf (pair_spmf p q) (x, y) = spmf p x * spmf q y" (is "?lhs = ?rhs")
  2449 proof -
  (* Express the probability as an iterated integral of the indicator of (x, y). *)
  2450   have "ennreal ?lhs = \<integral>\<^sup>+ a. \<integral>\<^sup>+ b. indicator {(x, y)} (a, b) \<partial>measure_spmf q \<partial>measure_spmf p"
  2451     unfolding measure_spmf_def pair_spmf_def ennreal_pmf_bind nn_integral_pair_pmf'
  2452     by(auto simp add: zero_ereal_def[symmetric] nn_integral_distr nn_integral_restrict_space nn_integral_multc[symmetric] intro!: nn_integral_cong split: option.split split_indicator)
  (* Split the indicator of the pair into a product of indicators. *)
  2453   also have "\<dots> = \<integral>\<^sup>+ a. (\<integral>\<^sup>+ b. indicator {y} b \<partial>measure_spmf q) * indicator {x} a \<partial>measure_spmf p"
  2454     by(subst nn_integral_multc[symmetric])(auto intro!: nn_integral_cong split: split_indicator)
  2455   also have "\<dots> = ennreal ?rhs" by(simp add: emeasure_spmf_single max_def ennreal_mult mult.commute)
  2456   finally show ?thesis by simp
  2457 qed
  2458 
  (* Mapping the second component can be pulled out of pair_spmf. *)
  2459 lemma pair_map_spmf2: "pair_spmf p (map_spmf f q) = map_spmf (apsnd f) (pair_spmf p q)"
  2460 by(auto simp add: pair_spmf_def pair_map_pmf2 bind_map_pmf map_bind_pmf intro: bind_pmf_cong split: option.split)
  2461 
  (* Mapping the first component can be pulled out of pair_spmf. *)
  2462 lemma pair_map_spmf1: "pair_spmf (map_spmf f p) q = map_spmf (apfst f) (pair_spmf p q)"
  2463 by(auto simp add: pair_spmf_def pair_map_pmf1 bind_map_pmf map_bind_pmf intro: bind_pmf_cong split: option.split)
  2464 
  (* Mapping both components: combination of pair_map_spmf1 and pair_map_spmf2. *)
  2465 lemma pair_map_spmf: "pair_spmf (map_spmf f p) (map_spmf g q) = map_spmf (map_prod f g) (pair_spmf p q)"
  2466 unfolding pair_map_spmf2 pair_map_spmf1 spmf.map_comp by(simp add: apfst_def apsnd_def o_def prod.map_comp)
  2467 
  (* Monadic characterisation: sample x from p, then y from q, and return (x, y). *)
  2468 lemma pair_spmf_alt_def: "pair_spmf p q = bind_spmf p (\<lambda>x. bind_spmf q (\<lambda>y. return_spmf (x, y)))"
  2469 by(auto simp add: pair_spmf_def pair_pmf_def bind_spmf_def bind_assoc_pmf bind_return_pmf split: option.split intro: bind_pmf_cong)
  2470 
  (* The weight of a product is the product of the weights. *)
  2471 lemma weight_pair_spmf [simp]: "weight_spmf (pair_spmf p q) = weight_spmf p * weight_spmf q"
  2472 unfolding pair_spmf_alt_def by(simp add: weight_bind_spmf o_def)
  2473 
  (* Scaling the first component of a product scales the whole product. *)
  2474 lemma pair_scale_spmf1: (* FIXME: generalise to arbitrary r; the restriction r <= 1 comes from scale_bind_spmf and bind_scale_spmf used in the proof *)
  2475   "r \<le> 1 \<Longrightarrow> pair_spmf (scale_spmf r p) q = scale_spmf r (pair_spmf p q)"
  2476 by(simp add: pair_spmf_alt_def scale_bind_spmf bind_scale_spmf)
  2477 
  (* Scaling the second component of a product scales the whole product. *)
  2478 lemma pair_scale_spmf2: (* FIXME: generalise to arbitrary r; the restriction r <= 1 comes from scale_bind_spmf and bind_scale_spmf used in the proof *)
  2479   "r \<le> 1 \<Longrightarrow> pair_spmf p (scale_spmf r q) = scale_spmf r (pair_spmf p q)"
  2480 by(simp add: pair_spmf_alt_def scale_bind_spmf bind_scale_spmf)
  2481 
  (* A product with an empty first component is empty. *)
  2482 lemma pair_spmf_return_None1 [simp]: "pair_spmf (return_pmf None) p = return_pmf None"
  2483 by(rule spmf_eqI)(clarsimp)
  2484 
  (* A product with an empty second component is empty. *)
  2485 lemma pair_spmf_return_None2 [simp]: "pair_spmf p (return_pmf None) = return_pmf None"
  2486 by(rule spmf_eqI)(clarsimp)
  2487 
  (* Pairing with a deterministic first component is just mapping Pair x over q. *)
  2488 lemma pair_spmf_return_spmf1: "pair_spmf (return_spmf x) q = map_spmf (Pair x) q"
  2489 by(rule spmf_eqI)(auto split: split_indicator simp add: spmf_map_inj' inj_on_def intro: spmf_map_outside)
  2490 
  (* Pairing with a deterministic second component is just mapping over p. *)
  2491 lemma pair_spmf_return_spmf2: "pair_spmf p (return_spmf y) = map_spmf (\<lambda>x. (x, y)) p"
  2492 by(rule spmf_eqI)(auto split: split_indicator simp add: inj_on_def intro!: spmf_map_outside spmf_map_inj'[symmetric])
  2493 
  (* Product of two deterministic spmfs. *)
  2494 lemma pair_spmf_return_spmf [simp]: "pair_spmf (return_spmf x) (return_spmf y) = return_spmf (x, y)"
  2495 by(simp add: pair_spmf_return_spmf1)
  2496 
  (* Characterises rel_spmf (rel_prod A B) on two products: it holds iff the
     first components are A-related after scaling each by the weight of the
     other component, and likewise for the second components and B. *)
  2497 lemma rel_pair_spmf_prod:
  2498   "rel_spmf (rel_prod A B) (pair_spmf p q) (pair_spmf p' q') \<longleftrightarrow>
  2499    rel_spmf A (scale_spmf (weight_spmf q) p) (scale_spmf (weight_spmf q') p') \<and>
  2500    rel_spmf B (scale_spmf (weight_spmf p) q) (scale_spmf (weight_spmf p') q')"
  2501   (is "?lhs \<longleftrightarrow> ?rhs" is "_ \<longleftrightarrow> ?A \<and> ?B" is "_ \<longleftrightarrow> rel_spmf _ ?p ?p' \<and> rel_spmf _ ?q ?q'")
  2502 proof(intro iffI conjI)
  (* Backward direction: from a coupling pq of the first components and a
     coupling pq' of the second components, build a coupling of the products. *)
  2503   assume ?rhs
  2504   then obtain pq pq' where p: "map_spmf fst pq = ?p" and p': "map_spmf snd pq = ?p'"
  2505     and q: "map_spmf fst pq' = ?q" and q': "map_spmf snd pq' = ?q'"
  2506     and *: "\<And>x x'. (x, x') \<in> set_spmf pq \<Longrightarrow> A x x'"
  2507     and **: "\<And>y y'. (y, y') \<in> set_spmf pq' \<Longrightarrow> B y y'" by(auto elim!: rel_spmfE)
  (* Candidate coupling: pair pq with pq', regroup the components with ?f,
     and rescale by ?r. *)
  2508   let ?f = "\<lambda>((x, x'), (y, y')). ((x, y), (x', y'))"
  2509   let ?r = "1 / (weight_spmf p * weight_spmf q)"
  2510   let ?pq = "scale_spmf ?r (map_spmf ?f (pair_spmf pq pq'))"
  2511 
  (* Auxiliary fact (named full below): if both weights are non-zero and
     1 / (weight p * weight q) is at most weight p * weight q, then both
     weights equal 1. *)
  2512   { fix p :: "'x spmf" and q :: "'y spmf"
  2513     assume "weight_spmf q \<noteq> 0"
  2514       and "weight_spmf p \<noteq> 0"
  2515       and "1 / (weight_spmf p * weight_spmf q) \<le> weight_spmf p * weight_spmf q"
  2516     hence "1 \<le> (weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q)"
  2517       by(simp add: pos_divide_le_eq order.strict_iff_order weight_spmf_nonneg)
  2518     moreover have "(weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q) \<le> (1 * 1) * (1 * 1)"
  2519       by(intro mult_mono)(simp_all add: weight_spmf_nonneg weight_spmf_le_1)
  2520     ultimately have "(weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q) = 1" by simp
  2521     hence *: "weight_spmf p * weight_spmf q = 1"
  2522       by(metis antisym_conv less_le mult_less_cancel_left1 weight_pair_spmf weight_spmf_le_1 weight_spmf_nonneg)
  2523     hence "weight_spmf p = 1" by(metis antisym_conv mult_left_le weight_spmf_le_1 weight_spmf_nonneg)
  2524     moreover with * have "weight_spmf q = 1" by simp
  2525     moreover note calculation }
  2526   note full = this
  2527 
  2528   show ?lhs
  2529   proof
  2530     have [simp]: "fst \<circ> ?f = map_prod fst fst" by(simp add: fun_eq_iff)
  2531     have "map_spmf fst ?pq = scale_spmf ?r (pair_spmf ?p ?q)"
  2532       by(simp add: pair_map_spmf[symmetric] p q map_scale_spmf spmf.map_comp)
  2533     also have "\<dots> = pair_spmf p q" using full[of p q]
  2534       by(simp add: pair_scale_spmf1 pair_scale_spmf2 weight_spmf_le_1 weight_spmf_nonneg)
  2535         (auto simp add: scale_scale_spmf max_def min_def field_simps weight_spmf_nonneg weight_spmf_eq_0)
  2536     finally show "map_spmf fst ?pq = \<dots>" .
  2537 
  2538     have [simp]: "snd \<circ> ?f = map_prod snd snd" by(simp add: fun_eq_iff)
  (* The products on both sides have equal weight; derived via rel_spmf_weightD. *)
  2539     from \<open>?rhs\<close> have eq: "weight_spmf p * weight_spmf q = weight_spmf p' * weight_spmf q'"
  2540       by(auto dest!: rel_spmf_weightD simp add: weight_spmf_le_1 weight_spmf_nonneg)
  2541 
  2542     have "map_spmf snd ?pq = scale_spmf ?r (pair_spmf ?p' ?q')"
  2543       by(simp add: pair_map_spmf[symmetric] p' q' map_scale_spmf spmf.map_comp)
  2544     also have "\<dots> = pair_spmf p' q'" using full[of p' q'] eq
  2545       by(simp add: pair_scale_spmf1 pair_scale_spmf2 weight_spmf_le_1 weight_spmf_nonneg)
  2546         (auto simp add: scale_scale_spmf max_def min_def field_simps weight_spmf_nonneg weight_spmf_eq_0)
  2547     finally show "map_spmf snd ?pq = \<dots>" .
  2548   qed(auto simp add: set_scale_spmf split: if_split_asm dest: * ** )
  2549 next
  (* Forward direction: a coupling pq of the two products projects to a
     coupling of the first components and to one of the second components. *)
  2550   assume ?lhs
  2551   then obtain pq where pq: "map_spmf fst pq = pair_spmf p q"
  2552     and pq': "map_spmf snd pq = pair_spmf p' q'"
  2553     and *: "\<And>x y x' y'. ((x, y), (x', y')) \<in> set_spmf pq \<Longrightarrow> A x x' \<and> B y y'"
  2554     by(auto elim: rel_spmfE)
  2555 
  2556   show ?A
  2557   proof
  2558     let ?f = "(\<lambda>((x, y), (x', y')). (x, x'))"
  2559     let ?pq = "map_spmf ?f pq"
  2560     have [simp]: "fst \<circ> ?f = fst \<circ> fst" by(simp add: split_def o_def)
  2561     show "map_spmf fst ?pq = scale_spmf (weight_spmf q) p" using pq
  2562       by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])
  2563 
  2564     have [simp]: "snd \<circ> ?f = fst \<circ> snd" by(simp add: split_def o_def)
  2565     show "map_spmf snd ?pq = scale_spmf (weight_spmf q') p'" using pq'
  2566       by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])
  2567   qed(auto dest: * )
  2568 
  2569   show ?B
  2570   proof
  2571     let ?f = "(\<lambda>((x, y), (x', y')). (y, y'))"
  2572     let ?pq = "map_spmf ?f pq"
  2573     have [simp]: "fst \<circ> ?f = snd \<circ> fst" by(simp add: split_def o_def)
  2574     show "map_spmf fst ?pq = scale_spmf (weight_spmf p) q" using pq
  2575       by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])
  2576 
  2577     have [simp]: "snd \<circ> ?f = snd \<circ> snd" by(simp add: split_def o_def)
  2578     show "map_spmf snd ?pq = scale_spmf (weight_spmf p') q'" using pq'
  2579       by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric])
  2580   qed(auto dest: * )
  2581 qed
  2582 
  (* Associativity of pair_spmf, up to regrouping the tuple components. *)
  2583 lemma pair_pair_spmf:
  2584   "pair_spmf (pair_spmf p q) r = map_spmf (\<lambda>(x, (y, z)). ((x, y), z)) (pair_spmf p (pair_spmf q r))"
  2585 by(simp add: pair_spmf_alt_def map_spmf_conv_bind_spmf)
  2586 
  (* Commutativity of pair_spmf, up to swapping the tuple components. *)
  2587 lemma pair_commute_spmf:
  2588   "pair_spmf p q = map_spmf (\<lambda>(y, x). (x, y)) (pair_spmf q p)"
  2589 unfolding pair_spmf_alt_def by(subst bind_commute_spmf)(simp add: map_spmf_conv_bind_spmf)
  2590 
  2591 subsection \<open>Assertions\<close>
  2592 
  (* assert_spmf b succeeds with () if b holds and fails (return_pmf None) otherwise. *)
  2593 definition assert_spmf :: "bool \<Rightarrow> unit spmf"
  2594 where "assert_spmf b = (if b then return_spmf () else return_pmf None)"
  2595 
  (* Evaluation of assert_spmf on the two boolean constants. *)
  2596 lemma assert_spmf_simps [simp]:
  2597   "assert_spmf True = return_spmf ()"
  2598   "assert_spmf False = return_pmf None"
  2599 by(simp_all add: assert_spmf_def)
  2600 
  (* The support of assert_spmf p is non-empty exactly when p holds. *)
  2601 lemma in_set_assert_spmf [simp]: "x \<in> set_spmf (assert_spmf p) \<longleftrightarrow> p"
  2602 by(cases p) simp_all
  2603 
  (* The support of assert_spmf b is empty exactly when b is false. *)
  2604 lemma set_spmf_assert_spmf_eq_empty [simp]: "set_spmf (assert_spmf b) = {} \<longleftrightarrow> \<not> b"
  2605 by(cases b) simp_all
  2606 
  (* assert_spmf b is lossless exactly when b holds. *)
  2607 lemma lossless_assert_spmf [iff]: "lossless_spmf (assert_spmf b) \<longleftrightarrow> b"
  2608 by(cases b) simp_all
  2609 
  2610 subsection \<open>Try\<close>
  2611 
  (* TRY p ELSE q runs p; if p yields None, q is run instead, and a result
     Some y of p is returned unchanged. *)
  2612 definition try_spmf :: "'a spmf \<Rightarrow> 'a spmf \<Rightarrow> 'a spmf" ("TRY _ ELSE _" [0,60] 59)
  2613 where "try_spmf p q = bind_pmf p (\<lambda>x. case x of None \<Rightarrow> q | Some y \<Rightarrow> return_spmf y)"
  2614 
  (* A lossless p has no None in its support (lossless_iff_set_pmf_None), so
     the fallback q is never used. *)
  2615 lemma try_spmf_lossless [simp]:
  2616   assumes "lossless_spmf p"
  2617   shows "TRY p ELSE q = p"
  2618 proof -
  2619   have "TRY p ELSE q = bind_pmf p return_pmf" unfolding try_spmf_def using assms
  2620     by(auto simp add: lossless_iff_set_pmf_None split: option.split intro: bind_pmf_cong)
  2621   thus ?thesis by(simp add: bind_return_pmf')
  2622 qed
  2623 
  (* A deterministic success is returned without consulting the fallback. *)
  2624 lemma try_spmf_return_spmf1: "TRY return_spmf x ELSE q = return_spmf x"
  2625 by(simp add: try_spmf_def bind_return_pmf)
  2626 
  (* A certain failure always falls back to q. *)
  2627 lemma try_spmf_return_None [simp]: "TRY return_pmf None ELSE q = q"
  2628 by(simp add: try_spmf_def bind_return_pmf)
  2629 
  (* Falling back to certain failure is a no-op. *)
  2630 lemma try_spmf_return_pmf_None2 [simp]: "TRY p ELSE return_pmf None = p"
  2631 by(simp add: try_spmf_def option.case_distrib[symmetric] bind_return_pmf' case_option_id)
  2632 
  (* Mapping distributes over both alternatives of try_spmf. *)
  2633 lemma map_try_spmf: "map_spmf f (try_spmf p q) = try_spmf (map_spmf f p) (map_spmf f q)"
  2634 by(simp add: try_spmf_def map_bind_pmf bind_map_pmf option.case_distrib[where h="map_spmf f"] o_def cong del: option.case_cong_weak)
  2635 
  (* TRY can be pushed into the continuation of a bind_pmf. *)
  2636 lemma try_spmf_bind_pmf: "TRY (bind_pmf p f) ELSE q = bind_pmf p (\<lambda>x. TRY (f x) ELSE q)"
  2637 by(simp add: try_spmf_def bind_assoc_pmf)
  2638 
  (* For a lossless p, TRY can be pushed into the continuation of a bind_spmf. *)
  2639 lemma try_spmf_bind_spmf_lossless:
  2640   "lossless_spmf p \<Longrightarrow> TRY (bind_spmf p f) ELSE q = bind_spmf p (\<lambda>x. TRY (f x) ELSE q)"
  2641 by(auto simp add: try_spmf_def bind_spmf_def bind_assoc_pmf bind_return_pmf lossless_iff_set_pmf_None intro!: bind_pmf_cong split: option.split)
  2642 
  (* Reverse orientation of try_spmf_bind_spmf_lossless: pull TRY out of a bind. *)
  2643 lemma try_spmf_bind_out:
  2644   "lossless_spmf p \<Longrightarrow> bind_spmf p (\<lambda>x. TRY (f x) ELSE q) = TRY (bind_spmf p f) ELSE q"
  2645 by(simp add: try_spmf_bind_spmf_lossless)
  2646 
  (* TRY p ELSE q is lossless iff at least one of p and q is lossless. *)
  2647 lemma lossless_try_spmf [simp]:
  2648   "lossless_spmf (TRY p ELSE q) \<longleftrightarrow> lossless_spmf p \<or> lossless_spmf q"
  2649 by(auto simp add: try_spmf_def in_set_spmf lossless_iff_set_pmf_None split: option.splits)
  2650 
  (* Parametricity (transfer rule) for try_spmf, stated with the ===> notation
     provided by the lifting_syntax bundle. *)
  2651 context includes lifting_syntax
  2652 begin
  2653 
  2654 lemma try_spmf_parametric [transfer_rule]:
  2655   "(rel_spmf A ===> rel_spmf A ===> rel_spmf A) try_spmf try_spmf"
  2656 unfolding try_spmf_def[abs_def] by transfer_prover
  2657 
  2658 end
  2659 
  (* Congruence rule: the fallbacks only need to agree when p' is not lossless. *)
  2660 lemma try_spmf_cong:
  2661   "\<lbrakk> p = p'; \<not> lossless_spmf p' \<Longrightarrow> q = q' \<rbrakk> \<Longrightarrow> TRY p ELSE q = TRY p' ELSE q'"
  2662 unfolding try_spmf_def
  2663 by(rule bind_pmf_cong)(auto split: option.split simp add: lossless_iff_set_pmf_None)
  2664 
  (* rel_spmf is preserved by TRY _ ELSE _; the fallbacks need to be related
     only when p' is not lossless. *)
  2665 lemma rel_spmf_try_spmf:
  2666   "\<lbrakk> rel_spmf R p p'; \<not> lossless_spmf p' \<Longrightarrow> rel_spmf R q q' \<rbrakk>
  2667   \<Longrightarrow> rel_spmf R (TRY p ELSE q) (TRY p' ELSE q')"
  2668 unfolding try_spmf_def
  (* Relate the outcomes by rel_option R, strengthened with support membership
     so that the None case can use the lossless side condition. *)
  2669 apply(rule rel_pmf_bindI[where R="\<lambda>x y. rel_option R x y \<and> x \<in> set_pmf p \<and> y \<in> set_pmf p'"])
  2670  apply(erule pmf.rel_mono_strong; simp)
  2671 apply(auto split: option.split simp add: lossless_iff_set_pmf_None)
  2672 done
  2673 
  (* Probability of x under TRY p ELSE q: either p returns x, or p yields None
     and q returns x. *)
  2674 lemma spmf_try_spmf:
  2675   "spmf (TRY p ELSE q) x = spmf p x + pmf p None * spmf q x"
  2676 proof -
  2677   have "ennreal (spmf (TRY p ELSE q) x) = \<integral>\<^sup>+ y. ennreal (spmf q x) * indicator {None} y + indicator {Some x} y \<partial>measure_pmf p"
  2678     unfolding try_spmf_def ennreal_pmf_bind by(rule nn_integral_cong)(simp split: option.split split_indicator)
  (* Split the integral of the sum into the None part and the Some x part. *)
  2679   also have "\<dots> = (\<integral>\<^sup>+ y. ennreal (spmf q x) * indicator {None} y \<partial>measure_pmf p) + \<integral>\<^sup>+ y. indicator {Some x} y \<partial>measure_pmf p"
  2680     by(simp add: nn_integral_add)
  2681   also have "\<dots> = ennreal (spmf q x) * pmf p None + spmf p x" by(simp add: emeasure_pmf_single)
  2682   finally show ?thesis by(simp add: ennreal_mult[symmetric] ennreal_plus[symmetric] del: ennreal_plus)
  2683 qed
  2684 
  (* For a lossless p, scaling p and falling back to p itself gives back p. *)
  2685 lemma try_scale_spmf_same [simp]: "lossless_spmf p \<Longrightarrow> TRY scale_spmf k p ELSE p = p"
  2686 by(rule spmf_eqI)(auto simp add: spmf_try_spmf spmf_scale_spmf pmf_scale_spmf_None lossless_iff_pmf_None weight_spmf_conv_pmf_None min_def max_def field_simps)
  2687 
  (* TRY p ELSE q yields None only if both p and q do; the probabilities multiply. *)
  2688 lemma pmf_try_spmf_None [simp]: "pmf (TRY p ELSE q) None = pmf p None * pmf q None" (is "?lhs = ?rhs")
  2689 proof -
  2690   have "?lhs = \<integral> x. pmf q None * indicator {None} x \<partial>measure_pmf p"
  2691     unfolding try_spmf_def pmf_bind by(rule integral_cong)(simp_all split: option.split)
  2692   also have "\<dots> = ?rhs" by(simp add: measure_pmf_single)
  2693   finally show ?thesis .
  2694 qed
  2695 
  (* With a lossless fallback q, the fallback may additionally be placed inside
     the continuation of the bind. *)
  2696 lemma try_bind_spmf_lossless2:
  2697   "lossless_spmf q \<Longrightarrow> TRY (bind_spmf p f) ELSE q = TRY (p \<bind> (\<lambda>x. TRY (f x) ELSE q)) ELSE q"
  2698 by(rule spmf_eqI)(simp add: spmf_try_spmf pmf_bind_spmf_None spmf_bind field_simps measure_spmf.integrable_const_bound[where B=1] pmf_le_1 lossless_iff_pmf_None)
  2699 
  (* Guarded variant of try_bind_spmf_lossless2: the NO_MATCH premise makes the
     rule inapplicable when f is already of the form TRY _ ELSE _, presumably
     so that it can be used as a rewrite rule without rewriting its own result
     again. *)
  2700 lemma try_bind_spmf_lossless2':
  2701   fixes f :: "'a \<Rightarrow> 'b spmf" shows
  2702   "\<lbrakk> NO_MATCH (\<lambda>x :: 'a. try_spmf (g x :: 'b spmf) (h x)) f; lossless_spmf q \<rbrakk>
  2703   \<Longrightarrow> TRY (bind_spmf p f) ELSE q = TRY (p \<bind> (\<lambda>x :: 'a. TRY (f x) ELSE q)) ELSE q"
  2704 by(rule try_bind_spmf_lossless2)
  2705 
  (* TRY over a bind starting with an assertion: continue with f () if the
     assertion holds, otherwise use the fallback q. *)
  2706 lemma try_bind_assert_spmf:
  2707   "TRY (assert_spmf b \<bind> f) ELSE q = (if b then TRY (f ()) ELSE q else q)"
  2708 by simp
  2709 
  2710 subsection \<open>Miscellaneous\<close>
  2711 
  (* Fundamental lemma: suppose p and q are related such that the bad events
     bad1 and bad2 coincide and, whenever no bad event happens, A and B agree.
     Then the bad events have the same probability (fundamental_lemma_bad), and
     the probabilities of A and B differ by at most the probability of the bad
     event (fundamental_lemma). *)
  2712 lemma assumes "rel_spmf (\<lambda>x y. bad1 x = bad2 y \<and> (\<not> bad2 y \<longrightarrow> A x \<longleftrightarrow> B y)) p q" (is "rel_spmf ?A _ _")
  2713   shows fundamental_lemma_bad: "measure (measure_spmf p) {x. bad1 x} = measure (measure_spmf q) {y. bad2 y}" (is "?bad")
  2714   and fundamental_lemma: "\<bar>measure (measure_spmf p) {x. A x} - measure (measure_spmf q) {y. B y}\<bar> \<le>
  2715     measure (measure_spmf p) {x. bad1 x}" (is ?fundamental)
  2716 proof -
  (* Step 1: the good parts of A and B (outside the bad events) are related by
     the assumption, hence have equal probability (fact 1). *)
  2717   have good: "rel_fun ?A op = (\<lambda>x. A x \<and> \<not> bad1 x) (\<lambda>y. B y \<and> \<not> bad2 y)" by(auto simp add: rel_fun_def)
  2718   from assms have 1: "measure (measure_spmf p) {x. A x \<and> \<not> bad1 x} = measure (measure_spmf q) {y. B y \<and> \<not> bad2 y}"
  2719     by(rule measure_spmf_parametric[THEN rel_funD, THEN rel_funD])(rule Collect_parametric[THEN rel_funD, OF good])
  2720 
  (* Step 2: the bad events themselves are related, hence equally likely. *)
  2721   have bad: "rel_fun ?A op = bad1 bad2" by(simp add: rel_fun_def)
  2722   show 2: ?bad using assms
  2723     by(rule measure_spmf_parametric[THEN rel_funD, THEN rel_funD])(rule Collect_parametric[THEN rel_funD, OF bad])
  2724 
  (* Step 3: split A and B into their bad and good parts.  The good parts
     cancel by fact 1; the remaining difference is bounded by the larger of the
     two bad probabilities, which coincide by fact 2. *)
  2725   let ?\<mu>p = "measure (measure_spmf p)" and ?\<mu>q = "measure (measure_spmf q)"
  2726   have "{x. A x \<and> bad1 x} \<union> {x. A x \<and> \<not> bad1 x} = {x. A x}"
  2727     and "{y. B y \<and> bad2 y} \<union> {y. B y \<and> \<not> bad2 y} = {y. B y}" by auto
  2728   then have "\<bar>?\<mu>p {x. A x} - ?\<mu>q {x. B x}\<bar> = \<bar>?\<mu>p ({x. A x \<and> bad1 x} \<union> {x. A x \<and> \<not> bad1 x}) - ?\<mu>q ({y. B y \<and> bad2 y} \<union> {y. B y \<and> \<not> bad2 y})\<bar>"
  2729     by simp
  2730   also have "\<dots> = \<bar>?\<mu>p {x. A x \<and> bad1 x} + ?\<mu>p {x. A x \<and> \<not> bad1 x} - ?\<mu>q {y. B y \<and> bad2 y} - ?\<mu>q {y. B y \<and> \<not> bad2 y}\<bar>"
  2731     by(subst (1 2) measure_Union)(auto)
  2732   also have "\<dots> = \<bar>?\<mu>p {x. A x \<and> bad1 x} - ?\<mu>q {y. B y \<and> bad2 y}\<bar>" using 1 by simp
  2733   also have "\<dots> \<le> max (?\<mu>p {x. A x \<and> bad1 x}) (?\<mu>q {y. B y \<and> bad2 y})"
  2734     by(rule abs_leI)(auto simp add: max_def not_le, simp_all only: add_increasing measure_nonneg mult_2)
  2735   also have "\<dots> \<le> max (?\<mu>p {x. bad1 x}) (?\<mu>q {y. bad2 y})"
  2736     by(rule max.mono; rule measure_spmf.finite_measure_mono; auto)
  2737   also note 2[symmetric]
  2738   finally show ?fundamental by simp
  2739 qed
  2740 
  2741 end