(* Theory SPMF
   (imports Probability_Mass_Function, Complete_Partial_Order2, Rewrite) *)
(* Author: Andreas Lochbihler, ETH Zurich *)

section ‹Discrete subprobability distribution›

theory SPMF imports
Probability_Mass_Function
"HOL-Library.Complete_Partial_Order2"
"HOL-Library.Rewrite"
begin

subsection ‹Auxiliary material›

(* The supremum of f over the singleton index set {x} is f x
   (in a conditionally complete lattice). *)
lemma cSUP_singleton [simp]: "(SUP x:{x}. f x :: _ :: conditionally_complete_lattice) = f x"
by (metis cSup_singleton image_empty image_insert)

(* Taking the maximum with 0 (on either side) before the injection into
   ennreal does not change the result. *)
lemma [simp]:
shows ennreal_max_0: "ennreal (max 0 x) = ennreal x"
and ennreal_max_0': "ennreal (max x 0) = ennreal x"
by(simp_all add: max_def ennreal_eq_0_iff)

(* Round trip ennreal -> real -> ennreal: the identity except at top,
   which is sent to 0. *)
lemma ennreal_enn2real_if: "ennreal (enn2real r) = (if r = ⊤ then 0 else r)"
by(auto intro!: ennreal_enn2real simp add: less_top)

(* e2ennreal maps 0 to 0. *)
lemma e2ennreal_0 [simp]: "e2ennreal 0 = 0"
by(simp add: zero_ennreal_def)

(* enn2real maps the bottom element of ennreal to 0. *)
lemma enn2real_bot [simp]: "enn2real ⊥ = 0"
by(simp add: bot_ennreal_def)

(* Continuity is preserved by post-composition with ennreal;
   registered as a continuity introduction rule. *)
lemma continuous_at_ennreal[continuous_intros]: "continuous F f ⟹ continuous F (λx. ennreal (f x))"
unfolding continuous_def by auto

(* ennreal commutes with Sup on a non-empty set of reals, provided the
   supremum of the images is not top.  The proof obtains a real bound r
   from that supremum to show that A is bounded above. *)
lemma ennreal_Sup:
assumes *: "(SUP a:A. ennreal a) ≠ ⊤"
and "A ≠ {}"
shows "ennreal (Sup A) = (SUP a:A. ennreal a)"
proof (rule continuous_at_Sup_mono)
obtain r where r: "ennreal r = (SUP a:A. ennreal a)" "r ≥ 0"
using * by(cases "(SUP a:A. ennreal a)") simp_all
then show "bdd_above A"
by(auto intro!: SUP_upper bdd_aboveI[of _ r] simp add: ennreal_le_iff[symmetric])
qed (auto simp: mono_def continuous_at_imp_continuous_at_within continuous_at_ennreal ennreal_leI assms)

(* Indexed version of ennreal_Sup, obtained by instantiating it with the
   image f ` A. *)
lemma ennreal_SUP:
"⟦ (SUP a:A. ennreal (f a)) ≠ ⊤; A ≠ {} ⟧ ⟹ ennreal (SUP a:A. f a) = (SUP a:A. ennreal (f a))"
using ennreal_Sup[of "f ` A"] by auto

(* Negative reals are mapped to 0 by ennreal. *)
lemma ennreal_lt_0: "x < 0 ⟹ ennreal x = 0"
by(simp add: ennreal_eq_0_iff)

subsubsection ‹More about @{typ "'a option"}›

(* None is in the map_option image of A exactly when None is in A. *)
lemma None_in_map_option_image [simp]: "None ∈ map_option f ` A ⟷ None ∈ A"
by auto

(* Some x is in the map_option image of A exactly when x = f y for some
   y with Some y in A. *)
lemma Some_in_map_option_image [simp]: "Some x ∈ map_option f ` A ⟷ (∃y. x = f y ∧ Some y ∈ A)"
by(auto intro: rev_image_eqI dest: sym)

(* A case distinction on an option whose two branches both return x is
   the constant function x. *)
lemma case_option_collapse: "case_option x (λ_. x) = (λ_. x)"
by(simp add: fun_eq_iff split: option.split)

(* Rebuilding an option from its own constructors is the identity. *)
lemma case_option_id: "case_option None Some = id"
by(rule ext)(simp split: option.split)

(* Lifts a relation ord to options: None is related to every option, and
   Some x is related to Some y whenever ord x y.  No other pairs are
   related. *)
inductive ord_option :: "('a ⇒ 'b ⇒ bool) ⇒ 'a option ⇒ 'b option ⇒ bool"
for ord :: "'a ⇒ 'b ⇒ bool"
where
None: "ord_option ord None x"
| Some: "ord x y ⟹ ord_option ord (Some x) (Some y)"

(* Simplification rules for ord_option on the four constructor patterns
   listed below, generated from the inductive definition. *)
inductive_simps ord_option_simps [simp]:
"ord_option ord None x"
"ord_option ord x None"
"ord_option ord (Some x) (Some y)"
"ord_option ord (Some x) None"

(* Simplification rules for the special case where the lifted relation is
   equality. *)
inductive_simps ord_option_eq_simps [simp]:
"ord_option op = None y"
"ord_option op = (Some x) y"

(* Reflexivity of ord_option at x only needs ord to be reflexive on the
   element contained in x (if any). *)
lemma ord_option_reflI: "(⋀y. y ∈ set_option x ⟹ ord y y) ⟹ ord_option ord x x"
by(cases x) simp_all

(* ord_option preserves reflexivity. *)
lemma reflp_ord_option: "reflp ord ⟹ reflp (ord_option ord)"
by(simp add: reflp_def ord_option_reflI)

(* Transitivity of ord_option along x, y, z; ord only needs to be
   transitive on the elements contained in these three options. *)
lemma ord_option_trans:
"⟦ ord_option ord x y; ord_option ord y z;
⋀a b c. ⟦ a ∈ set_option x; b ∈ set_option y; c ∈ set_option z; ord a b; ord b c ⟧ ⟹ ord a c ⟧
⟹ ord_option ord x z"
by(auto elim!: ord_option.cases)

(* ord_option preserves transitivity. *)
lemma transp_ord_option: "transp ord ⟹ transp (ord_option ord)"
unfolding transp_def by(blast intro: ord_option_trans)

(* ord_option preserves antisymmetry. *)
lemma antisymp_ord_option: "antisymp ord ⟹ antisymp (ord_option ord)"
by(auto intro!: antisympI elim!: ord_option.cases dest: antisympD)

(* If Y is a chain w.r.t. ord_option ord, then the elements x with
   Some x in Y form a chain w.r.t. ord. *)
lemma ord_option_chainD:
"Complete_Partial_Order.chain (ord_option ord) Y
⟹ Complete_Partial_Order.chain ord {x. Some x ∈ Y}"
by(rule chainI)(auto dest: chainD)

(* Lifts a least-upper-bound operation to options: the result is None if
   Y contains nothing but None (or is empty); otherwise it is Some of the
   lub of all x with Some x in Y. *)
definition lub_option :: "('a set ⇒ 'b) ⇒ 'a option set ⇒ 'b option"
where "lub_option lub Y = (if Y ⊆ {None} then None else Some (lub {x. Some x ∈ Y}))"

(* Mapping f over a lifted lub equals lifting the composed operation
   f ∘ lub. *)
lemma map_lub_option: "map_option f (lub_option lub Y) = lub_option (f ∘ lub) Y"
by(simp add: lub_option_def)

(* lub_option lub Y is an upper bound of every element of a chain Y,
   provided lub yields upper bounds for chains of the underlying order. *)
lemma lub_option_upper:
assumes "Complete_Partial_Order.chain (ord_option ord) Y" "x ∈ Y"
and lub_upper: "⋀Y x. ⟦ Complete_Partial_Order.chain ord Y; x ∈ Y ⟧ ⟹ ord x (lub Y)"
shows "ord_option ord x (lub_option lub Y)"
using assms(1-2)
by(cases x)(auto simp add: lub_option_def intro: lub_upper[OF ord_option_chainD])

(* lub_option lub Y is below every upper bound y of a chain Y, provided
   lub yields least upper bounds for chains of the underlying order. *)
lemma lub_option_least:
assumes Y: "Complete_Partial_Order.chain (ord_option ord) Y"
and upper: "⋀x. x ∈ Y ⟹ ord_option ord x y"
assumes lub_least: "⋀Y y. ⟦ Complete_Partial_Order.chain ord Y; ⋀x. x ∈ Y ⟹ ord x y ⟧ ⟹ ord (lub Y) y"
shows "ord_option ord (lub_option lub Y) y"
using Y
by(cases y)(auto 4 3 simp add: lub_option_def intro: lub_least[OF ord_option_chainD] dest: upper)

(* The lifted lub of a map_option image equals the lifted lub, with the
   image under f taken first, of the original set.  The first step unfolds
   lub_option_def; what remains is an equality between the arguments of
   lub, discharged via arg_cong. *)
lemma lub_map_option: "lub_option lub (map_option f ` Y) = lub_option (lub ∘ op ` f) Y"
apply(auto simp add: lub_option_def)
apply(erule notE)
apply(rule arg_cong[where f=lub])
apply(auto intro: rev_image_eqI dest: sym)
done

(* ord_option is monotone in the lifted relation. *)
lemma ord_option_mono: "⟦ ord_option A x y; ⋀x y. A x y ⟹ B x y ⟧ ⟹ ord_option B x y"
by(auto elim: ord_option.cases)

(* Implication form of ord_option_mono, declared with the mono attribute. *)
lemma ord_option_mono' [mono]:
"(⋀x y. A x y ⟶ B x y) ⟹ ord_option A x y ⟶ ord_option B x y"
by(blast intro: ord_option_mono)

(* ord_option distributes over relational composition. *)
lemma ord_option_compp: "ord_option (A OO B) = ord_option A OO ord_option B"
by(auto simp add: fun_eq_iff elim!: ord_option.cases intro: ord_option.intros)

(* ord_option commutes with the infimum of relations; proved by
   antisymmetry of the two inclusions. *)
lemma ord_option_inf: "inf (ord_option A) (ord_option B) = ord_option (inf A B)" (is "?lhs = ?rhs")
proof(rule antisym)
show "?lhs ≤ ?rhs" by(auto elim!: ord_option.cases)
qed(auto elim: ord_option_mono)

(* A map_option on the right argument of ord_option can be pushed into
   the relation. *)
lemma ord_option_map2: "ord_option ord x (map_option f y) = ord_option (λx y. ord x (f y)) x y"
by(auto elim: ord_option.cases)

(* A map_option on the left argument of ord_option can be pushed into
   the relation. *)
lemma ord_option_map1: "ord_option ord (map_option f x) y = ord_option (λx y. ord (f x) y) x y"
by(auto elim: ord_option.cases)

(* In option_ord, Some x is related only to itself. *)
lemma option_ord_Some1_iff: "option_ord (Some x) y ⟷ y = Some x"
by(auto simp add: flat_ord_def)

subsubsection ‹A relator for sets that treats sets like predicates›

context includes lifting_syntax
begin

(* Relates two sets when their membership predicates are related by
   R ===> op =, i.e. R-related elements agree on membership. *)
definition rel_pred :: "('a ⇒ 'b ⇒ bool) ⇒ 'a set ⇒ 'b set ⇒ bool"
where "rel_pred R A B = (R ===> op =) (λx. x ∈ A) (λy. y ∈ B)"

(* Introduction rule for rel_pred (its definition read right to left). *)
lemma rel_predI: "(R ===> op =) (λx. x ∈ A) (λy. y ∈ B) ⟹ rel_pred R A B"
by(simp add: rel_pred_def)

(* Destruction rule for rel_pred: R-related elements agree on membership
   in rel_pred-related sets. *)
lemma rel_predD: "⟦ rel_pred R A B; R x y ⟧ ⟹ x ∈ A ⟷ y ∈ B"
by(simp add: rel_pred_def rel_fun_def)

(* Parametricity of Collect with respect to rel_pred. *)
lemma Collect_parametric: "((A ===> op =) ===> rel_pred A) Collect Collect"
― ‹Declare this rule as @{attribute transfer_rule} only locally
because it blows up the search space for @{method transfer}
(in combination with @{thm [source] Collect_transfer})›
by(simp add: rel_funI rel_predI)

end

subsubsection ‹Monotonicity rules›

(* Addition on enat is monotone in its first argument w.r.t. the
   reversed order op ≥. *)
lemma monotone_gfp_eadd1: "monotone op ≥ op ≥ (λx. x + y :: enat)"
by(auto intro!: monotoneI)

(* Addition on enat is monotone in its second argument w.r.t. the
   reversed order op ≥. *)
lemma monotone_gfp_eadd2: "monotone op ≥ op ≥ (λy. x + y :: enat)"
by(auto intro!: monotoneI)

(* Addition on enat is monotone on pairs w.r.t. the reversed order.  The
   unnamed-shows form follows mono2mono_ereal / mono2mono_ennreal below:
   the plain statement is stored as monotone_eadd, the lemma name holds the
   mono2mono form. *)
lemma mono2mono_gfp_eadd[THEN gfp.mono2mono2, cont_intro, simp]:
shows monotone_eadd: "monotone (rel_prod op ≥ op ≥) op ≥ (λ(x, y). x + y :: enat)"
by(simp add: monotone_gfp_eadd1 monotone_gfp_eadd2)

(* Pointwise sum of two functions that are monotone w.r.t. fun_ord op ≥ is
   again monotone; registered for the partial_function package. *)
lemma eadd_gfp_partial_function_mono [partial_function_mono]:
"⟦ monotone (fun_ord op ≥) op ≥ f; monotone (fun_ord op ≥) op ≥ g ⟧
⟹ monotone (fun_ord op ≥) op ≥ (λx. f x + g x :: enat)"
by(rule mono2mono_gfp_eadd)

(* The embedding ereal is monotone; the plain statement is stored as
   monotone_ereal, the lemma name holds the lfp.mono2mono form. *)
lemma mono2mono_ereal[THEN lfp.mono2mono]:
shows monotone_ereal: "monotone op ≤ op ≤ ereal"
by(rule monotoneI) simp

(* The embedding ennreal is monotone; the plain statement is stored as
   monotone_ennreal, the lemma name holds the lfp.mono2mono form. *)
lemma mono2mono_ennreal[THEN lfp.mono2mono]:
shows monotone_ennreal: "monotone op ≤ op ≤ ennreal"
by(rule monotoneI)(simp add: ennreal_leI)

subsubsection ‹Bijections›

(* If R is bi-unique and relates the sets A and B, there is a bijection
   between A and B that sends each x in A to an R-related element.
   The function f is obtained by choice; injectivity and surjectivity onto
   B both follow from bi-uniqueness. *)
lemma bi_unique_rel_set_bij_betw:
assumes unique: "bi_unique R"
and rel: "rel_set R A B"
shows "∃f. bij_betw f A B ∧ (∀x∈A. R x (f x))"
proof -
from assms obtain f where f: "⋀x. x ∈ A ⟹ R x (f x)" and B: "⋀x. x ∈ A ⟹ f x ∈ B"
apply(atomize_elim)
apply(fold all_conj_distrib)
apply(subst choice_iff[symmetric])
apply(auto dest: rel_setD1)
done
have "inj_on f A" by(rule inj_onI)(auto dest!: f dest: bi_uniqueDl[OF unique])
moreover have "f ` A = B" using rel
by(auto 4 3 intro: B dest: rel_setD2 f bi_uniqueDr[OF unique])
ultimately have "bij_betw f A B" unfolding bij_betw_def ..
thus ?thesis using f by blast
qed

(* A bijection f between A and B relates the two sets via the graph of f. *)
lemma bij_betw_rel_setD: "bij_betw f A B ⟹ rel_set (λx y. y = f x) A B"
by(rule rel_setI)(auto dest: bij_betwE bij_betw_imp_surj_on[symmetric])

subsection ‹Subprobability mass function›

(* A subprobability mass function over 'a is a pmf over 'a option.
   The translation makes the type print as 'a spmf. *)
type_synonym 'a spmf = "'a option pmf"
translations (type) "'a spmf" ↽ (type) "'a option pmf"

(* The measure on 'a induced by an spmf: restrict measure_pmf p to
   range Some and push it forward along `the` into the counting
   sigma-algebra on UNIV. *)
definition measure_spmf :: "'a spmf ⇒ 'a measure"
where "measure_spmf p = distr (restrict_space (measure_pmf p) (range Some)) (count_space UNIV) the"

(* spmf p x abbreviates the pmf of p at Some x. *)
abbreviation spmf :: "'a spmf ⇒ 'a ⇒ real"
where "spmf p x ≡ pmf p (Some x)"

(* The underlying space of measure_spmf p is the whole type. *)
lemma space_measure_spmf: "space (measure_spmf p) = UNIV"
by(simp add: measure_spmf_def)

(* The measurable sets of measure_spmf p are those of count_space UNIV. *)
lemma sets_measure_spmf [simp, measurable_cong]: "sets (measure_spmf p) = sets (count_space UNIV)"
by(simp add: measure_spmf_def)

(* measure_spmf p is never the bottom measure: the two would have to
   share a space, but the space of measure_spmf p is UNIV. *)
lemma measure_spmf_not_bot [simp]: "measure_spmf p ≠ ⊥"
proof
assume "measure_spmf p = ⊥"
hence "space (measure_spmf p) = space ⊥" by simp
thus False by(simp add: space_measure_spmf)
qed

(* `the` is measurable from the restriction of measure_pmf p to
   range Some into count_space UNIV. *)
lemma measurable_the_measure_pmf_Some [measurable, simp]:
"the ∈ measurable (restrict_space (measure_pmf p) (range Some)) (count_space UNIV)"
by(auto simp add: measurable_def sets_restrict_space space_restrict_space integral_restrict_space)

(* Every function from UNIV into the space of N is measurable when the
   domain measure is a measure_spmf. *)
lemma measurable_spmf_measure1[simp]: "measurable (measure_spmf M) N = UNIV → space N"
by(auto simp: measurable_def space_measure_spmf)

(* As a codomain, measure_spmf M can be replaced by count_space UNIV. *)
lemma measurable_spmf_measure2[simp]: "measurable N (measure_spmf M) = measurable N (count_space UNIV)"
by(intro measurable_cong_sets) simp_all

(* measure_spmf p is a subprobability space: its total mass is at most 1
   and its space (UNIV) is non-empty. *)
lemma subprob_space_measure_spmf [simp, intro!]: "subprob_space (measure_spmf p)"
proof
show "emeasure (measure_spmf p) (space (measure_spmf p)) ≤ 1"
by(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space measure_pmf.measure_le_1)
qed(simp add: space_measure_spmf)

(* Makes the facts of the subprob_space locale available under the
   prefix measure_spmf, for every p. *)
interpretation measure_spmf: subprob_space "measure_spmf p" for p
by(rule subprob_space_measure_spmf)

(* measure_spmf p is a finite measure. *)
lemma finite_measure_spmf [simp]: "finite_measure (measure_spmf p)"
by unfold_locales

(* The point mass spmf p x is the measure_spmf-measure of {x}. *)
lemma spmf_conv_measure_spmf: "spmf p x = measure (measure_spmf p) {x}"
by(auto simp add: measure_spmf_def measure_distr measure_restrict_space pmf.rep_eq space_restrict_space intro: arg_cong2[where f=measure])

(* The measure_spmf-emeasure of A is the measure_pmf-emeasure of
   Some ` A. *)
lemma emeasure_measure_spmf_conv_measure_pmf:
"emeasure (measure_spmf p) A = emeasure (measure_pmf p) (Some ` A)"
by(auto simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])

(* Real-valued counterpart of emeasure_measure_spmf_conv_measure_pmf;
   both measures are finite, so emeasure can be replaced by measure. *)
lemma measure_measure_spmf_conv_measure_pmf:
"measure (measure_spmf p) A = measure (measure_pmf p) (Some ` A)"
using emeasure_measure_spmf_conv_measure_pmf[of p A]
by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure)

(* Embedding a pmf via Some does not change the emeasure of any set. *)
lemma emeasure_spmf_map_pmf_Some [simp]:
"emeasure (measure_spmf (map_pmf Some p)) A = emeasure (measure_pmf p) A"
by(auto simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])

(* Real-valued counterpart of emeasure_spmf_map_pmf_Some. *)
lemma measure_spmf_map_pmf_Some [simp]:
"measure (measure_spmf (map_pmf Some p)) A = measure (measure_pmf p) A"
using emeasure_spmf_map_pmf_Some[of p A] by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure)

(* The non-negative integral w.r.t. measure_spmf p is the spmf-weighted
   integral over the counting measure.  The calculation first moves to
   count_space (range Some) and then reindexes along the bijection Some
   between UNIV and range Some. *)
lemma nn_integral_measure_spmf: "(∫⇧+ x. f x ∂measure_spmf p) = ∫⇧+ x. ennreal (spmf p x) * f x ∂count_space UNIV"
(is "?lhs = ?rhs")
proof -
have "?lhs = ∫⇧+ x. pmf p x * f (the x) ∂count_space (range Some)"
by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space nn_integral_measure_pmf nn_integral_count_space_indicator ac_simps times_ereal.simps(1)[symmetric] del: times_ereal.simps(1))
also have "… = ∫⇧+ x. ennreal (spmf p (the x)) * f (the x) ∂count_space (range Some)"
by(rule nn_integral_cong) auto
also have "… = ∫⇧+ x. spmf p (the (Some x)) * f (the (Some x)) ∂count_space UNIV"
by(rule nn_integral_bij_count_space[symmetric])(simp add: bij_betw_def)
also have "… = ?rhs" by simp
finally show ?thesis .
qed

(* Bochner-integral counterpart of nn_integral_measure_spmf for an
   integrable f.  Integrability of the weighted integrand w.r.t. the
   counting measure is established first. *)
lemma integral_measure_spmf:
assumes "integrable (measure_spmf p) f"
shows "(∫ x. f x ∂measure_spmf p) = ∫ x. spmf p x * f x ∂count_space UNIV"
proof -
have "integrable (count_space UNIV) (λx. spmf p x * f x)"
using assms by(simp add: integrable_iff_bounded nn_integral_measure_spmf abs_mult ennreal_mult'')
then show ?thesis using assms
by(simp add: real_lebesgue_integral_def nn_integral_measure_spmf ennreal_mult'[symmetric])
qed

(* The emeasure of a singleton {x} is the point mass spmf p x. *)
lemma emeasure_spmf_single: "emeasure (measure_spmf p) {x} = spmf p x"
by(simp add: measure_spmf.emeasure_eq_measure spmf_conv_measure_spmf)

(* Any family of spmfs, viewed as measures, is measurable from the
   counting space into the subprobability algebra. *)
lemma measurable_measure_spmf[measurable]:
"(λx. measure_spmf (M x)) ∈ measurable (count_space UNIV) (subprob_algebra (count_space UNIV))"
by (auto simp: space_subprob_algebra)

(* Integrating f w.r.t. measure_spmf p equals integrating f ∘ the
   w.r.t. the restriction of measure_pmf p to range Some. *)
lemma nn_integral_measure_spmf_conv_measure_pmf:
assumes [measurable]: "f ∈ borel_measurable (count_space UNIV)"
shows "nn_integral (measure_spmf p) f = nn_integral (restrict_space (measure_pmf p) (range Some)) (f ∘ the)"
by(simp add: measure_spmf_def nn_integral_distr o_def)

(* measure_spmf p is a point of the subprobability algebra over the
   counting space. *)
lemma measure_spmf_in_space_subprob_algebra [simp]:
"measure_spmf p ∈ space (subprob_algebra (count_space UNIV))"
by(simp add: space_subprob_algebra)

(* The total spmf mass is finite; obtained from nn_integral_measure_spmf
   with the constant function 1. *)
lemma nn_integral_spmf_neq_top: "(∫⇧+ x. spmf p x ∂count_space UNIV) ≠ ⊤"
using nn_integral_measure_spmf[where f="λ_. 1", of p, symmetric] by simp

(* The supremum of point masses over a set of spmfs is finite, because
   it is bounded by 1. *)
lemma SUP_spmf_neq_top': "(SUP p:Y. ennreal (spmf p x)) ≠ ⊤"
proof(rule neq_top_trans)
show "(SUP p:Y. ennreal (spmf p x)) ≤ 1" by(rule SUP_least)(simp add: pmf_le_1)
qed simp

(* Variant of SUP_spmf_neq_top' for a family of spmfs indexed by i. *)
lemma SUP_spmf_neq_top: "(SUP i. ennreal (spmf (Y i) x)) ≠ ⊤"
proof(rule neq_top_trans)
show "(SUP i. ennreal (spmf (Y i) x)) ≤ 1" by(rule SUP_least)(simp add: pmf_le_1)
qed simp

(* The supremum of the emeasures of A over a set of spmfs is finite,
   because every subprobability emeasure is bounded by 1. *)
lemma SUP_emeasure_spmf_neq_top: "(SUP p:Y. emeasure (measure_spmf p) A) ≠ ⊤"
proof(rule neq_top_trans)
show "(SUP p:Y. emeasure (measure_spmf p) A) ≤ 1"
by(rule SUP_least)(simp add: measure_spmf.subprob_emeasure_le_1)
qed simp

subsection ‹Support›

(* Support of an spmf: the elements x such that Some x is in the
   support of the underlying pmf (set-monad bind with set_option). *)
definition set_spmf :: "'a spmf ⇒ 'a set"
where "set_spmf p = set_pmf p ⤜ set_option"

(* The support consists of the points whose singleton has non-zero
   measure.  The auxiliary fact computes the preimage of {x} under `the`
   within range Some. *)
lemma set_spmf_rep_eq: "set_spmf p = {x. measure (measure_spmf p) {x} ≠ 0}"
proof -
have "⋀x :: 'a. the -` {x} ∩ range Some = {Some x}" by auto
then show ?thesis
by(auto simp add: set_spmf_def set_pmf.rep_eq measure_spmf_def measure_distr measure_restrict_space space_restrict_space intro: rev_image_eqI)
qed

(* x is in the support of p exactly when Some x is in the support of the
   underlying pmf. *)
lemma in_set_spmf: "x ∈ set_spmf p ⟷ Some x ∈ set_pmf p"
by(simp add: set_spmf_def)

(* A property holds almost everywhere w.r.t. measure_spmf p exactly when
   it holds on the whole support of p. *)
lemma AE_measure_spmf_iff [simp]: "(AE x in measure_spmf p. P x) ⟷ (∀x∈set_spmf p. P x)"
by(auto 4 3 simp add: measure_spmf_def AE_distr_iff AE_restrict_space_iff AE_measure_pmf_iff set_spmf_def cong del: AE_cong)

(* The point mass at x is zero exactly when x is outside the support. *)
lemma spmf_eq_0_set_spmf: "spmf p x = 0 ⟷ x ∉ set_spmf p"
by(auto simp add: pmf_eq_0_set_pmf set_spmf_def intro: rev_image_eqI)

(* Support membership expressed as a non-zero point mass. *)
lemma in_set_spmf_iff_spmf: "x ∈ set_spmf p ⟷ spmf p x ≠ 0"
by(auto simp add: set_spmf_def set_pmf_iff intro: rev_image_eqI)

(* The spmf that always yields None has empty support. *)
lemma set_spmf_return_pmf_None [simp]: "set_spmf (return_pmf None) = {}"
by(auto simp add: set_spmf_def)

(* The support of an spmf is countable. *)
lemma countable_set_spmf [simp]: "countable (set_spmf p)"
by(simp add: set_spmf_def bind_UNION)

(* Two spmfs are equal if they agree on all point masses spmf _ i.
   The Some case is immediate from the assumption.  For None, the mass is
   expressed as 1 minus the mass of range Some, which is rewritten as a
   countable sum of spmf values over the support; the same calculation is
   then run backwards for q. *)
lemma spmf_eqI:
assumes "⋀i. spmf p i = spmf q i"
shows "p = q"
proof(rule pmf_eqI)
fix i
show "pmf p i = pmf q i"
proof(cases i)
case (Some i')
thus ?thesis by(simp add: assms)
next
case None
have "ennreal (pmf p i) = measure (measure_pmf p) {i}" by(simp add: pmf_def)
also have "{i} = space (measure_pmf p) - range Some"
by(auto simp add: None intro: ccontr)
also have "measure (measure_pmf p) … = ennreal 1 - measure (measure_pmf p) (range Some)"
by(simp add: measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf)
also have "range Some = (⋃x∈set_spmf p. {Some x}) ∪ Some ` (- set_spmf p)"
by auto
also have "measure (measure_pmf p) … = measure (measure_pmf p) (⋃x∈set_spmf p. {Some x})"
by(rule measure_pmf.measure_zero_union)(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff)
also have "ennreal … = ∫⇧+ x. measure (measure_pmf p) {Some x} ∂count_space (set_spmf p)"
unfolding measure_pmf.emeasure_eq_measure[symmetric]
by(simp_all add: emeasure_UN_countable disjoint_family_on_def)
also have "… = ∫⇧+ x. spmf p x ∂count_space (set_spmf p)" by(simp add: pmf_def)
also have "… = ∫⇧+ x. spmf q x ∂count_space (set_spmf p)" by(simp add: assms)
also have "set_spmf p = set_spmf q" by(auto simp add: in_set_spmf_iff_spmf assms)
also have "(∫⇧+ x. spmf q x ∂count_space (set_spmf q)) = ∫⇧+ x. measure (measure_pmf q) {Some x} ∂count_space (set_spmf q)"
by(simp add: pmf_def)
also have "… = measure (measure_pmf q) (⋃x∈set_spmf q. {Some x})"
unfolding measure_pmf.emeasure_eq_measure[symmetric]
by(simp_all add: emeasure_UN_countable disjoint_family_on_def)
also have "… = measure (measure_pmf q) ((⋃x∈set_spmf q. {Some x}) ∪ Some ` (- set_spmf q))"
by(rule ennreal_cong measure_pmf.measure_zero_union[symmetric])+(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff)
also have "((⋃x∈set_spmf q. {Some x}) ∪ Some ` (- set_spmf q)) = range Some" by auto
also have "ennreal 1 - measure (measure_pmf q) … = measure (measure_pmf q) (space (measure_pmf q) - range Some)"
by(simp add: one_ereal_def measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf)
also have "space (measure_pmf q) - range Some = {i}"
by(auto simp add: None intro: ccontr)
also have "measure (measure_pmf q) … = pmf q i" by(simp add: pmf_def)
finally show ?thesis by simp
qed
qed

(* The integral w.r.t. measure_spmf M is unchanged when the measure is
   restricted to the support of M. *)
lemma integral_measure_spmf_restrict:
fixes f ::  "'a ⇒ 'b :: {banach, second_countable_topology}" shows
"(∫ x. f x ∂measure_spmf M) = (∫ x. f x ∂restrict_space (measure_spmf M) (set_spmf M))"
by(auto intro!: integral_cong_AE simp add: integral_restrict_space)

(* Variant of nn_integral_measure_spmf that integrates over the counting
   measure on the support only. *)
lemma nn_integral_measure_spmf':
"(∫⇧+ x. f x ∂measure_spmf p) = ∫⇧+ x. ennreal (spmf p x) * f x ∂count_space (set_spmf p)"
by(auto simp add: nn_integral_measure_spmf nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator)

subsection ‹Functorial structure›

(* Functorial map on spmfs: map_pmf composed with map_option. *)
abbreviation map_spmf :: "('a ⇒ 'b) ⇒ 'a spmf ⇒ 'b spmf"
where "map_spmf f ≡ map_pmf (map_option f)"

context begin
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "spmf")›

(* map_spmf respects composition. *)
lemma map_comp: "map_spmf f (map_spmf g p) = map_spmf (f ∘ g) p"
by(simp add: pmf.map_comp o_def option.map_comp)

(* map_spmf id is the identity function. *)
lemma map_id0: "map_spmf id = id"
by(simp add: pmf.map_id option.map_id0)

(* Applied form of map_id0. *)
lemma map_id [simp]: "map_spmf id p = p"
by(simp add: map_id0)

(* Same as map_id, with the identity written as a lambda abstraction. *)
lemma map_ident [simp]: "map_spmf (λx. x) p = p"
by(simp add: id_def[symmetric])

end

(* The support of a mapped spmf is the image of the support. *)
lemma set_map_spmf [simp]: "set_spmf (map_spmf f p) = f ` set_spmf p"
by(simp add: set_spmf_def image_bind bind_image o_def Option.option.set_map)

(* Congruence rule for map_spmf: the functions only need to agree on the
   support. *)
lemma map_spmf_cong:
"⟦ p = q; ⋀x. x ∈ set_spmf q ⟹ f x = g x ⟧
⟹ map_spmf f p = map_spmf g q"
by(auto intro: pmf.map_cong option.map_cong simp add: in_set_spmf)

(* map_spmf_cong stated with =simp=> so that the simplifier may use the
   support assumption while rewriting f x. *)
lemma map_spmf_cong_simp:
"⟦ p = q; ⋀x. x ∈ set_spmf q =simp=> f x = g x ⟧
⟹ map_spmf f p = map_spmf g q"
unfolding simp_implies_def by(rule map_spmf_cong)

(* Mapping a function that is the identity on the support leaves the
   spmf unchanged. *)
lemma map_spmf_idI: "(⋀x. x ∈ set_spmf p ⟹ f x = x) ⟹ map_spmf f p = p"
by(rule map_pmf_idI map_option_idI)+(simp add: in_set_spmf)

(* The emeasure of A under a mapped spmf is the emeasure of the preimage
   of A under f. *)
lemma emeasure_map_spmf:
"emeasure (measure_spmf (map_spmf f p)) A = emeasure (measure_spmf p) (f -` A)"
by(auto simp add: measure_spmf_def emeasure_distr measurable_restrict_space1 space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure])

(* Real-valued counterpart of emeasure_map_spmf. *)
lemma measure_map_spmf: "measure (measure_spmf (map_spmf f p)) A = measure (measure_spmf p) (f -` A)"
using emeasure_map_spmf[of f p A] by(simp add: measure_spmf.emeasure_eq_measure)

(* The measure of a mapped spmf is the push-forward (distr) of the
   original measure along f. *)
lemma measure_map_spmf_conv_distr:
"measure_spmf (map_spmf f p) = distr (measure_spmf p) (count_space UNIV) f"
by(rule measure_eqI)(simp_all add: emeasure_map_spmf emeasure_distr)

(* Embedding a pmf via Some preserves point masses (Some is injective). *)
lemma spmf_map_pmf_Some [simp]: "spmf (map_pmf Some p) i = pmf p i"
by(simp add: pmf_map_inj')

(* Mapping a function that is injective on the support preserves the
   point mass of every support element. *)
lemma spmf_map_inj: "⟦ inj_on f (set_spmf M); x ∈ set_spmf M ⟧ ⟹ spmf (map_spmf f M) (f x) = spmf M x"
by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj, auto simp add: in_set_spmf inj_on_def elim!: option.inj_map_strong[rotated])

(* Mapping a globally injective function preserves all point masses. *)
lemma spmf_map_inj': "inj f ⟹ spmf (map_spmf f M) (f x) = spmf M x"
by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj'[OF option.inj_map])

(* Points outside the image of the support have mass zero in the mapped
   spmf. *)
lemma spmf_map_outside: "x ∉ f ` set_spmf M ⟹ spmf (map_spmf f M) x = 0"
unfolding spmf_eq_0_set_spmf by simp

(* The point mass of a mapped spmf at x is the emeasure of the preimage
   of {x} under f. *)
lemma ennreal_spmf_map: "ennreal (spmf (map_spmf f p) x) = emeasure (measure_spmf p) (f -` {x})"
by(auto simp add: ennreal_pmf_map measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure])

(* Real-valued counterpart of ennreal_spmf_map. *)
lemma spmf_map: "spmf (map_spmf f p) x = measure (measure_spmf p) (f -` {x})"
using ennreal_spmf_map[of f p x] by(simp add: measure_spmf.emeasure_eq_measure)

(* The point mass of a mapped spmf at x, written as the integral of the
   indicator of the preimage of {x}. *)
lemma ennreal_spmf_map_conv_nn_integral:
"ennreal (spmf (map_spmf f p) x) = integral⇧N (measure_spmf p) (indicator (f -` {x}))"
by(auto simp add: ennreal_pmf_map measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure])

subsubsection ‹Return›

(* Monadic return on spmfs: the point distribution on Some x. *)
abbreviation return_spmf :: "'a ⇒ 'a spmf"
where "return_spmf x ≡ return_pmf (Some x)"

(* The pmf of return_spmf x at y is the indicator of {y} evaluated at
   Some x; an instance of pmf_return. *)
lemma pmf_return_spmf: "pmf (return_spmf x) y = indicator {y} (Some x)"
by(fact pmf_return)

(* The measure of return_spmf x is the Giry-monad return at x over the
   counting space. *)
lemma measure_spmf_return_spmf: "measure_spmf (return_spmf x) = Giry_Monad.return (count_space UNIV) x"
by(rule measure_eqI)(simp_all add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_def)

(* The measure of the spmf that always yields None is the null measure. *)
lemma measure_spmf_return_pmf_None [simp]: "measure_spmf (return_pmf None) = null_measure (count_space UNIV)"
by(rule measure_eqI)(auto simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_eq_0_iff)

(* The support of return_spmf x is {x}. *)
lemma set_return_spmf [simp]: "set_spmf (return_spmf x) = {x}"
by(auto simp add: set_spmf_def)

subsubsection ‹Bind›

(* Monadic bind on spmfs, defined via bind_pmf: a None outcome is
   propagated as return_pmf None, a Some a' outcome continues with f a'. *)
definition bind_spmf :: "'a spmf ⇒ ('a ⇒ 'b spmf) ⇒ 'b spmf"
where "bind_spmf x f = bind_pmf x (λa. case a of None ⇒ return_pmf None | Some a' ⇒ f a')"

(* Register bind_spmf for the overloaded monad syntax; the lemmas below
   write bind_spmf with the ⤜ notation. *)
adhoc_overloading Monad_Syntax.bind bind_spmf

(* Binding from the always-None spmf yields the always-None spmf.  The
   type annotation on f selects bind_spmf rather than bind_pmf. *)
lemma return_None_bind_spmf [simp]: "return_pmf None ⤜ (f :: 'a ⇒ _) = return_pmf None"
by(simp add: bind_spmf_def bind_return_pmf)

(* Left identity law of the spmf monad. *)
lemma return_bind_spmf [simp]: "return_spmf x ⤜ f = f x"
by(simp add: bind_spmf_def bind_return_pmf)

(* Right identity law of the spmf monad.  The auxiliary fact shows that
   the case distinction in bind_spmf_def collapses to return_pmf, after
   which the right identity of bind_pmf applies. *)
lemma bind_return_spmf [simp]: "x ⤜ return_spmf = x"
proof -
have "⋀a :: 'a option. (case a of None ⇒ return_pmf None | Some a' ⇒ return_spmf a') = return_pmf a"
by(simp split: option.split)
then show ?thesis
by(simp add: bind_spmf_def bind_return_pmf')
qed

(* Associativity law of the spmf monad. *)
lemma bind_spmf_assoc [simp]:
fixes x :: "'a spmf" and f :: "'a ⇒ 'b spmf" and g :: "'b ⇒ 'c spmf"
shows "(x ⤜ f) ⤜ g = x ⤜ (λy. f y ⤜ g)"
by(auto simp add: bind_spmf_def bind_assoc_pmf fun_eq_iff bind_return_pmf split: option.split intro: arg_cong[where f="bind_pmf x"])

(* The None mass of a bind is the None mass of p plus the integral of
   the None masses of the continuations.  The calculation splits the
   integrand on {None} versus range Some and then transfers the second
   part to measure_spmf p. *)
lemma pmf_bind_spmf_None: "pmf (p ⤜ f) None = pmf p None + ∫ x. pmf (f x) None ∂measure_spmf p"
(is "?lhs = ?rhs")
proof -
let ?f = "λx. pmf (case x of None ⇒ return_pmf None | Some x ⇒ f x) None"
have "?lhs = ∫ x. ?f x ∂measure_pmf p"
by(simp add: bind_spmf_def pmf_bind)
also have "… = ∫ x. ?f None * indicator {None} x + ?f x * indicator (range Some) x ∂measure_pmf p"
by(rule Bochner_Integration.integral_cong)(auto simp add: indicator_def)
also have "… = (∫ x. ?f None * indicator {None} x ∂measure_pmf p) + (∫ x. ?f x * indicator (range Some) x ∂measure_pmf p)"
by(rule Bochner_Integration.integral_add)(auto 4 3 intro: integrable_real_mult_indicator measure_pmf.integrable_const_bound[where B=1] simp add: AE_measure_pmf_iff pmf_le_1)
also have "… = pmf p None + ∫ x. indicator (range Some) x * pmf (f (the x)) None ∂measure_pmf p"
by(auto simp add: measure_measure_pmf_finite indicator_eq_0_iff intro!: Bochner_Integration.integral_cong)
also have "… = ?rhs" unfolding measure_spmf_def
by(subst integral_distr)(auto simp add: integral_restrict_space)
finally show ?thesis .
qed

(* The point mass of a bind at y is the integral over p of the point
   masses of the continuations at y. *)
lemma spmf_bind: "spmf (p ⤜ f) y = ∫ x. spmf (f x) y ∂measure_spmf p"
unfolding measure_spmf_def
by(subst integral_distr)(auto simp add: bind_spmf_def pmf_bind integral_restrict_space indicator_eq_0_iff intro!: Bochner_Integration.integral_cong split: option.split)

(* ennreal / non-negative-integral version of spmf_bind. *)
lemma ennreal_spmf_bind: "ennreal (spmf (p ⤜ f) x) = ∫⇧+ y. spmf (f y) x ∂measure_spmf p"
by(auto simp add: bind_spmf_def ennreal_pmf_bind nn_integral_measure_spmf_conv_measure_pmf nn_integral_restrict_space intro: nn_integral_cong split: split_indicator option.split)

lemma measure_spmf_bind_pmf: "measure_spmf (p ⤜ f) = measure_pmf p ⤜ measure_spmf ∘ f"
(is "?lhs = ?rhs")
proof(rule measure_eqI)
show "sets ?lhs = sets ?rhs"
by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf)
next
fix A :: "'a set"
have "emeasure ?lhs A = ∫⇧+ x. emeasure (measure_spmf (f x)) A ∂measure_pmf p"
by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def)
also have "… = emeasure ?rhs A"
by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra)
finally show "emeasure ?lhs A = emeasure ?rhs A" .
qed

lemma measure_spmf_bind: "measure_spmf (p ⤜ f) = measure_spmf p ⤜ measure_spmf ∘ f"
(is "?lhs = ?rhs")
proof(rule measure_eqI)
show "sets ?lhs = sets ?rhs"
by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf)
next
fix A :: "'a set"
let ?A = "the -` A ∩ range Some"
have "emeasure ?lhs A = ∫⇧+ x. emeasure (measure_pmf (case x of None ⇒ return_pmf None | Some x ⇒ f x)) ?A ∂measure_pmf p"
by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def)
also have "… =  ∫⇧+ x. emeasure (measure_pmf (f (the x))) ?A * indicator (range Some) x ∂measure_pmf p"
by(rule nn_integral_cong)(auto split: option.split simp add: indicator_def)
also have "… = ∫⇧+ x. emeasure (measure_spmf (f x)) A ∂measure_spmf p"
by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space emeasure_distr space_restrict_space emeasure_restrict_space)
also have "… = emeasure ?rhs A"
by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra)
finally show "emeasure ?lhs A = emeasure ?rhs A" .
qed

(* map_spmf can be pushed into the continuation of a bind. *)
lemma map_spmf_bind_spmf: "map_spmf f (bind_spmf p g) = bind_spmf p (map_spmf f ∘ g)"
by(auto simp add: bind_spmf_def map_bind_pmf fun_eq_iff split: option.split intro: arg_cong2[where f=bind_pmf])

(* Binding from a mapped spmf equals binding with the continuation
   precomposed with f. *)
lemma bind_map_spmf: "map_spmf f p ⤜ g = p ⤜ g ∘ f"
by(simp add: bind_spmf_def bind_map_pmf o_def cong del: option.case_cong_weak)

(* If every continuation on the support of p has point mass at most r at
   x (with r non-negative), so does the bind.  The integral of the
   constant r is bounded by r because measure_spmf p has total mass at
   most 1. *)
lemma spmf_bind_leI:
assumes "⋀y. y ∈ set_spmf p ⟹ spmf (f y) x ≤ r"
and "0 ≤ r"
shows "spmf (bind_spmf p f) x ≤ r"
proof -
have "ennreal (spmf (bind_spmf p f) x) = ∫⇧+ y. spmf (f y) x ∂measure_spmf p" by(rule ennreal_spmf_bind)
also have "… ≤ ∫⇧+ y. r ∂measure_spmf p" by(rule nn_integral_mono_AE)(simp add: assms)
also have "… ≤ r" using assms measure_spmf.emeasure_space_le_1
by(auto simp add: measure_spmf.emeasure_eq_measure intro!: mult_left_le)
finally show ?thesis using assms(2) by(simp)
qed

(* map_spmf expressed as a bind followed by return_spmf. *)
lemma map_spmf_conv_bind_spmf: "map_spmf f p = (p ⤜ (λx. return_spmf (f x)))"
by(simp add: map_pmf_def bind_spmf_def)(rule bind_pmf_cong, simp_all split: option.split)

(* Congruence rule for bind_spmf: the continuations only need to agree
   on the support. *)
lemma bind_spmf_cong:
"⟦ p = q; ⋀x. x ∈ set_spmf q ⟹ f x = g x ⟧
⟹ bind_spmf p f = bind_spmf q g"
by(auto simp add: bind_spmf_def in_set_spmf intro: bind_pmf_cong option.case_cong)

(* bind_spmf_cong stated with =simp=> so that the simplifier may use the
   support assumption while rewriting f x. *)
lemma bind_spmf_cong_simp:
"⟦ p = q; ⋀x. x ∈ set_spmf q =simp=> f x = g x ⟧
⟹ bind_spmf p f = bind_spmf q g"
by(simp add: simp_implies_def cong: bind_spmf_cong)

(* The support of a bind is the set-monad bind of the supports. *)
lemma set_bind_spmf: "set_spmf (M ⤜ f) = set_spmf M ⤜ (set_spmf ∘ f)"
by(auto simp add: set_spmf_def bind_spmf_def bind_UNION split: option.splits)

(* Binding into the always-None spmf yields the always-None spmf. *)
lemma bind_spmf_const_return_None [simp]: "bind_spmf p (λ_. return_pmf None) = return_pmf None"
by(simp add: bind_spmf_def case_option_collapse)

(* Two independent binds commute.  Both sides are rewritten to nested
   bind_pmf with a combined case distinction ?f, so that commutativity of
   bind_pmf can be applied. *)
lemma bind_commute_spmf:
"bind_spmf p (λx. bind_spmf q (f x)) = bind_spmf q (λy. bind_spmf p (λx. f x y))"
(is "?lhs = ?rhs")
proof -
let ?f = "λx y. case x of None ⇒ return_pmf None | Some a ⇒ (case y of None ⇒ return_pmf None | Some b ⇒ f a b)"
have "?lhs = p ⤜ (λx. q ⤜ ?f x)"
unfolding bind_spmf_def by(rule bind_pmf_cong[OF refl])(simp split: option.split)
also have "… = q ⤜ (λy. p ⤜ (λx. ?f x y))" by(rule bind_commute_pmf)
also have "… = ?rhs" unfolding bind_spmf_def
by(rule bind_pmf_cong[OF refl])(auto split: option.split, metis bind_spmf_const_return_None bind_spmf_def)
finally show ?thesis .
qed

subsection ‹Relator›

(* Relator on spmfs: rel_pmf composed with rel_option. *)
abbreviation rel_spmf :: "('a ⇒ 'b ⇒ bool) ⇒ 'a spmf ⇒ 'b spmf ⇒ bool"
where "rel_spmf R ≡ rel_pmf (rel_option R)"

(* rel_pmf is monotone in the relation (rule form of pmf.rel_mono). *)
lemma rel_pmf_mono:
"⟦rel_pmf A f g; ⋀x y. A x y ⟹ B x y ⟧ ⟹ rel_pmf B f g"
using pmf.rel_mono[of A B] by(simp add: le_fun_def)

(* rel_spmf is monotone in the relation. *)
lemma rel_spmf_mono:
"⟦rel_spmf A f g; ⋀x y. A x y ⟹ B x y ⟧ ⟹ rel_spmf B f g"
apply(erule rel_pmf_mono)
using option.rel_mono[of A B] by(simp add: le_fun_def)

(* Strong monotonicity: the implication A ⟹ B is only required on the
   supports of the two spmfs. *)
lemma rel_spmf_mono_strong:
"⟦ rel_spmf A f g; ⋀x y. ⟦ A x y; x ∈ set_spmf f; y ∈ set_spmf g ⟧ ⟹ B x y ⟧ ⟹ rel_spmf B f g"
apply(erule pmf.rel_mono_strong)
apply(erule option.rel_mono_strong)
done

(* rel_spmf P relates p to itself if P is reflexive on the support. *)
lemma rel_spmf_reflI: "(⋀x. x ∈ set_spmf p ⟹ P x x) ⟹ rel_spmf P p p"
by(rule rel_pmf_reflI)(auto simp add: set_spmf_def intro: rel_option_reflI)

(* Introduction rule for rel_spmf via a joint spmf pq on pairs whose
   support lies in P and whose projections are p and q.  The witness for
   rel_pmf turns None into (None, None) and Some (a, b) into
   (Some a, Some b). *)
lemma rel_spmfI [intro?]:
"⟦ ⋀x y. (x, y) ∈ set_spmf pq ⟹ P x y; map_spmf fst pq = p; map_spmf snd pq = q ⟧
⟹ rel_spmf P p q"
by(rule rel_pmf.intros[where pq="map_pmf (λx. case x of None ⇒ (None, None) | Some (a, b) ⇒ (Some a, Some b)) pq"])
(auto simp add: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong)

lemma rel_spmfE [elim?, consumes 1, case_names rel_spmf]:
assumes "rel_spmf P p q"
obtains pq where
"⋀x y. (x, y) ∈ set_spmf pq ⟹ P x y"
"p = map_spmf fst pq"
"q = map_spmf snd pq"
using assms
proof(cases rule: rel_pmf.cases[consumes 1, case_names rel_pmf])
case (rel_pmf pq)
let ?pq = "map_pmf (λ(a, b). case (a, b) of (Some x, Some y) ⇒ Some (x, y) | _ ⇒ None) pq"
have "⋀x y. (x, y) ∈ set_spmf ?pq ⟹ P x y"
by(auto simp add: in_set_spmf split: option.split_asm dest: rel_pmf(1))
moreover
have "⋀x. (x, None) ∈ set_pmf pq ⟹ x = None" by(auto dest!: rel_pmf(1))
then have "p = map_spmf fst ?pq" using rel_pmf(2)
by(auto simp add: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split)
moreover
have "⋀y. (None, y) ∈ set_pmf pq ⟹ y = None" by(auto dest!: rel_pmf(1))
then have "q = map_spmf snd ?pq" using rel_pmf(3)
by(auto simp add: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split)
ultimately show thesis ..
qed

(* Characterisation of rel_spmf by the existence of a joint spmf on
   pairs; combines rel_spmfI and rel_spmfE. *)
lemma rel_spmf_simps:
"rel_spmf R p q ⟷ (∃pq. (∀(x, y)∈set_spmf pq. R x y) ∧ map_spmf fst pq = p ∧ map_spmf snd pq = q)"
by(auto intro: rel_spmfI elim!: rel_spmfE)

(* A map_spmf on either argument of rel_spmf can be pushed into the
   relation. *)
lemma spmf_rel_map:
shows spmf_rel_map1: "⋀R f x. rel_spmf R (map_spmf f x) = rel_spmf (λx. R (f x)) x"
and spmf_rel_map2: "⋀R x g y. rel_spmf R x (map_spmf g y) = rel_spmf (λx y. R x (g y)) x y"
by(simp_all add: fun_eq_iff pmf.rel_map option.rel_map[abs_def])

(* rel_spmf commutes with taking the converse relation. *)
lemma spmf_rel_conversep: "rel_spmf R¯¯ = (rel_spmf R)¯¯"
by(simp add: option.rel_conversep pmf.rel_conversep)

(* rel_spmf of equality is equality. *)
lemma spmf_rel_eq: "rel_spmf op = = op ="
by(simp add: pmf.rel_eq option.rel_eq)

context includes lifting_syntax
begin

(* Parametricity of bind_spmf; registered as a transfer rule. *)
lemma bind_spmf_parametric [transfer_rule]:
"(rel_spmf A ===> (A ===> rel_spmf B) ===> rel_spmf B) bind_spmf bind_spmf"
unfolding bind_spmf_def[abs_def] by transfer_prover

(* Parametricity of return_spmf. *)
lemma return_spmf_parametric: "(A ===> rel_spmf A) return_spmf return_spmf"
by transfer_prover

(* Parametricity of map_spmf. *)
lemma map_spmf_parametric: "((A ===> B) ===> rel_spmf A ===> rel_spmf B) map_spmf map_spmf"
by transfer_prover

(* Parametricity of the relator rel_spmf itself. *)
lemma rel_spmf_parametric:
"((A ===> B ===> op =) ===> rel_spmf A ===> rel_spmf B ===> op =) rel_spmf rel_spmf"
by transfer_prover

(* Parametricity of set_spmf w.r.t. rel_set; registered as a transfer
   rule. *)
lemma set_spmf_parametric [transfer_rule]:
"(rel_spmf A ===> rel_set A) set_spmf set_spmf"
unfolding set_spmf_def[abs_def] by transfer_prover

(* The always-None spmf is related to itself under any rel_spmf A. *)
lemma return_spmf_None_parametric:
"(rel_spmf A) (return_pmf None) (return_pmf None)"
by simp

end

(* Rule form of bind_spmf_parametric: binds are related if the sources
   are related and the continuations map related values to related
   spmfs. *)
lemma rel_spmf_bindI:
"⟦ rel_spmf R p q; ⋀x y. R x y ⟹ rel_spmf P (f x) (g y) ⟧
⟹ rel_spmf P (p ⤜ f) (q ⤜ g)"
by(fact bind_spmf_parametric[THEN rel_funD, THEN rel_funD, OF _ rel_funI])

(* Special case of rel_spmf_bindI with the same source p on both sides;
   the continuations need only be related on the support of p. *)
lemma rel_spmf_bind_reflI:
"(⋀x. x ∈ set_spmf p ⟹ rel_spmf P (f x) (g x)) ⟹ rel_spmf P (p ⤜ f) (p ⤜ g)"
by(rule rel_spmf_bindI[where R="λx y. x = y ∧ x ∈ set_spmf p"])(auto intro: rel_spmf_reflI)

(* Point distributions on P-related values are related by rel_pmf P;
   the witness is the point distribution on the pair. *)
lemma rel_pmf_return_pmfI: "P x y ⟹ rel_pmf P (return_pmf x) (return_pmf y)"
by(rule rel_pmf.intros[where pq="return_pmf (x, y)"])(simp_all)

context includes lifting_syntax
begin

text ‹We do not yet have a relator for @{typ "'a measure"}, so we combine @{const measure} and @{const measure_pmf}›
lemma measure_pmf_parametric:
"(rel_pmf A ===> rel_pred A ===> op =) (λp. measure (measure_pmf p)) (λq. measure (measure_pmf q))"
proof(rule rel_funI)+
fix p q X Y
assume "rel_pmf A p q" and "rel_pred A X Y"
from this(1) obtain pq where A: "⋀x y. (x, y) ∈ set_pmf pq ⟹ A x y"
and p: "p = map_pmf fst pq" and q: "q = map_pmf snd pq" by cases auto
show "measure p X = measure q Y" unfolding p q measure_map_pmf
by(rule measure_pmf.finite_measure_eq_AE)(auto simp add: AE_measure_pmf_iff dest!: A rel_predD[OF ‹rel_pred _ _ _›])
qed

(* Parametricity of measure ∘ measure_spmf w.r.t. rel_pred, reduced to
   measure_pmf_parametric by rewriting the measure of A to the
   measure_pmf-measure of Some ` A. *)
lemma measure_spmf_parametric:
"(rel_spmf A ===> rel_pred A ===> op =) (λp. measure (measure_spmf p)) (λq. measure (measure_spmf q))"
unfolding measure_measure_spmf_conv_measure_pmf[abs_def]
apply(rule rel_funI)+
apply(erule measure_pmf_parametric[THEN rel_funD, THEN rel_funD])
apply(auto simp add: rel_pred_def rel_fun_def elim: option.rel_cases)
done

end

subsection ‹From @{typ "'a pmf"} to @{typ "'a spmf"}›

(* Embeds a pmf into the spmfs by wrapping every outcome in Some. *)
definition spmf_of_pmf :: "'a pmf ⇒ 'a spmf"
where "spmf_of_pmf = map_pmf Some"

(* The embedding preserves the support. *)
lemma set_spmf_spmf_of_pmf [simp]: "set_spmf (spmf_of_pmf p) = set_pmf p"
by(auto simp add: spmf_of_pmf_def set_spmf_def bind_image o_def)

(* The embedding preserves point masses. *)
lemma spmf_spmf_of_pmf [simp]: "spmf (spmf_of_pmf p) x = pmf p x"
by(simp add: spmf_of_pmf_def)

(* An embedded pmf puts no mass on None. *)
lemma pmf_spmf_of_pmf_None [simp]: "pmf (spmf_of_pmf p) None = 0"
using ennreal_pmf_map[of Some p None] by(simp add: spmf_of_pmf_def)

(* The spmf measure of an embedded pmf agrees with the pmf measure on every set A.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma emeasure_spmf_of_pmf [simp]: "emeasure (measure_spmf (spmf_of_pmf p)) A = emeasure (measure_pmf p) A"

(* Equality of the two measures themselves, shown via measure_eqI. *)
lemma measure_spmf_spmf_of_pmf [simp]: "measure_spmf (spmf_of_pmf p) = measure_pmf p"
by(rule measure_eqI) simp_all

(* Mapping commutes with the embedding.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma map_spmf_of_pmf [simp]: "map_spmf f (spmf_of_pmf p) = spmf_of_pmf (map_pmf f p)"

(* On embedded pmfs, rel_spmf R coincides with rel_pmf R.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma rel_spmf_spmf_of_pmf [simp]: "rel_spmf R (spmf_of_pmf p) (spmf_of_pmf q) = rel_pmf R p q"

(* Embedding a point pmf yields the point spmf.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma spmf_of_pmf_return_pmf [simp]: "spmf_of_pmf (return_pmf x) = return_spmf x"

(* Binding from an embedded pmf is the same as a pmf bind.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma bind_spmf_of_pmf [simp]: "bind_spmf (spmf_of_pmf p) f = bind_pmf p f"

(* Support of a pmf bind whose continuation returns spmfs: the union, over the
   support of p, of the spmf supports of f.  Proved by rewriting bind_pmf back
   into bind_spmf (bind_spmf_of_pmf right-to-left) and using set_bind_spmf. *)
lemma set_spmf_bind_pmf: "set_spmf (bind_pmf p f) = Set.bind (set_pmf p) (set_spmf ∘ f)"
unfolding bind_spmf_of_pmf[symmetric] by(subst set_bind_spmf) simp

(* The embedding distributes over bind_pmf.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma spmf_of_pmf_bind: "spmf_of_pmf (bind_pmf p f) = bind_pmf p (λx. spmf_of_pmf (f x))"

(* Binding a pmf into return_spmf ∘ f is the embedding of the mapped pmf.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma bind_pmf_return_spmf: "p ⤜ (λx. return_spmf (f x)) = spmf_of_pmf (map_pmf f p)"

subsection ‹Weight of a subprobability›

(* Weight (total mass) of an spmf: the measure of the whole space of measure_spmf p.
   This is an abbreviation, so it is expanded by the parser, not a defined constant. *)
abbreviation weight_spmf :: "'a spmf ⇒ real"
where "weight_spmf p ≡ measure (measure_spmf p) (space (measure_spmf p))"

(* Characterises the abbreviation weight_spmf as the measure of UNIV.
   The statement had no proof; it follows by rewriting the space of
   measure_spmf p with space_measure_spmf. *)
lemma weight_spmf_def: "weight_spmf p = measure (measure_spmf p) UNIV"
by(simp add: space_measure_spmf)

(* The weight of an spmf never exceeds 1.
   The statement had no proof; it is an instance of the subprobability bound
   measure_spmf.subprob_measure_le_1. *)
lemma weight_spmf_le_1: "weight_spmf p ≤ 1"
by(simp add: measure_spmf.subprob_measure_le_1)

(* A point spmf has weight 1.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma weight_return_spmf [simp]: "weight_spmf (return_spmf x) = 1"

(* The point distribution on None has weight 0. *)
lemma weight_return_pmf_None [simp]: "weight_spmf (return_pmf None) = 0"
by(simp)

(* Mapping a function over an spmf does not change its weight.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma weight_map_spmf [simp]: "weight_spmf (map_spmf f p) = weight_spmf p"

(* An embedded pmf has full weight 1; uses that measure_pmf p is a probability space. *)
lemma weight_spmf_of_pmf [simp]: "weight_spmf (spmf_of_pmf p) = 1"
using measure_pmf.prob_space[of p] by(simp add: spmf_of_pmf_def weight_spmf_def)

(* Weights are non-negative; direct instance of measure_nonneg. *)
lemma weight_spmf_nonneg: "weight_spmf p ≥ 0"
by(fact measure_nonneg)

(* In a finite measure, a measurable family of weights is integrable, because
   the weights are bounded by the constant 1 (integrable_const_bound with B=1). *)
lemma (in finite_measure) integrable_weight_spmf [simp]:
"(λx. weight_spmf (f x)) ∈ borel_measurable M ⟹ integrable M (λx. weight_spmf (f x))"
by(rule integrable_const_bound[where B=1])(simp_all add: weight_spmf_nonneg weight_spmf_le_1)

(* The weight equals the non-negative integral (sum) of all point masses spmf p x
   over the counting measure on UNIV. *)
lemma weight_spmf_eq_nn_integral_spmf: "weight_spmf p = ∫⇧+ x. spmf p x ∂count_space UNIV"
by(simp add: measure_measure_spmf_conv_measure_pmf space_measure_spmf measure_pmf.emeasure_eq_measure[symmetric] nn_integral_pmf[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1)

(* Same as weight_spmf_eq_nn_integral_spmf, but the sum ranges only over the
   support set_spmf p. *)
lemma weight_spmf_eq_nn_integral_support:
"weight_spmf p = ∫⇧+ x. spmf p x ∂count_space (set_spmf p)"
unfolding weight_spmf_eq_nn_integral_spmf
by(auto simp add: nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator)

(* The mass on None is the complement of the weight.
   The calculation splits the total mass 1 of p (as an option pmf) into the part
   on range Some, which is the weight, and the part on None. *)
lemma pmf_None_eq_weight_spmf: "pmf p None = 1 - weight_spmf p"
proof -
have "weight_spmf p = ∫⇧+ x. spmf p x ∂count_space UNIV" by(rule weight_spmf_eq_nn_integral_spmf)
also have "… = ∫⇧+ x. ennreal (pmf p x) * indicator (range Some) x ∂count_space UNIV"
by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1)
(* NOTE(review): the justification for the next step is not visible in this view. *)
also have "… + pmf p None = ∫⇧+ x. ennreal (pmf p x) * indicator (range Some) x + ennreal (pmf p None) * indicator {None} x ∂count_space UNIV"
also have "… = ∫⇧+ x. pmf p x ∂count_space UNIV"
by(rule nn_integral_cong)(auto split: split_indicator)
also have "… = 1" by (simp add: nn_integral_pmf)
finally show ?thesis by(simp add: ennreal_plus[symmetric] del: ennreal_plus)
qed

(* Converse reading of pmf_None_eq_weight_spmf: the weight is 1 minus the mass on None.
   The statement had no proof; rewriting pmf p None with pmf_None_eq_weight_spmf
   leaves an arithmetic identity. *)
lemma weight_spmf_conv_pmf_None: "weight_spmf p = 1 - pmf p None"
by(simp add: pmf_None_eq_weight_spmf)

(* A weight is at most 0 exactly when it is 0; instance of measure_le_0_iff. *)
lemma weight_spmf_le_0: "weight_spmf p ≤ 0 ⟷ weight_spmf p = 0"
by(rule measure_le_0_iff)

(* A weight is never negative.
   The statement had no proof; it is the negated form of weight_spmf_nonneg. *)
lemma weight_spmf_lt_0: "¬ weight_spmf p < 0"
by(simp add: not_less weight_spmf_nonneg)

(* Any single point mass is bounded by the total weight: one summand of the
   non-negative integral is at most the integral (nn_integral_ge_point). *)
lemma spmf_le_weight: "spmf p x ≤ weight_spmf p"
proof -
have "ennreal (spmf p x) ≤ weight_spmf p"
unfolding weight_spmf_eq_nn_integral_spmf by(rule nn_integral_ge_point) simp
then show ?thesis by simp
qed

(* The only spmf of weight 0 is the point distribution on None. *)
lemma weight_spmf_eq_0: "weight_spmf p = 0 ⟷ p = return_pmf None"
by(auto intro!: pmf_eqI simp add: pmf_None_eq_weight_spmf split: split_indicator)(metis not_Some_eq pmf_le_0_iff spmf_le_weight)

(* The weight of a bind is the integral, under measure_spmf x, of the weights
   of the continuation's results. *)
lemma weight_bind_spmf: "weight_spmf (x ⤜ f) = lebesgue_integral (measure_spmf x) (weight_spmf ∘ f)"
unfolding weight_spmf_def
by(simp add: measure_spmf_bind o_def measure_spmf.measure_bind[where N="count_space UNIV"])

(* rel_spmf-related spmfs have the same weight, whatever the relation A is. *)
lemma rel_spmf_weightD: "rel_spmf A p q ⟹ weight_spmf p = weight_spmf q"
by(erule rel_spmfE) simp

(* If f is a bijection between the supports of p and q that preserves point
   masses, then p and q are related by the graph of f.
   Strategy: lift f to options with map_option, show the lifted map is a
   bijection between the pmf supports, apply rel_pmf_bij_betw, then weaken. *)
lemma rel_spmf_bij_betw:
assumes f: "bij_betw f (set_spmf p) (set_spmf q)"
and eq: "⋀x. x ∈ set_spmf p ⟹ spmf p x = spmf q (f x)"
shows "rel_spmf (λx y. f x = y) p q"
proof -
let ?f = "map_option f"

(* Equal weights: reindex the sum over the support along the bijection f. *)
have weq: "ennreal (weight_spmf p) = ennreal (weight_spmf q)"
unfolding weight_spmf_eq_nn_integral_support
by(subst nn_integral_bij_count_space[OF f, symmetric])(rule nn_integral_cong_AE, simp add: eq AE_count_space)
(* NOTE(review): the justification for the next step is not visible in this view. *)
then have "None ∈ set_pmf p ⟷ None ∈ set_pmf q"
with f have "bij_betw (map_option f) (set_pmf p) (set_pmf q)"
apply(auto simp add: bij_betw_def in_set_spmf inj_on_def intro: option.expand)
apply(rename_tac [!] x)
apply(case_tac [!] x)
apply(auto iff: in_set_spmf)
done
then have "rel_pmf (λx y. ?f x = y) p q"
by(rule rel_pmf_bij_betw)(case_tac x, simp_all add: weq[simplified] eq in_set_spmf pmf_None_eq_weight_spmf)
thus ?thesis by(rule pmf.rel_mono_strong)(auto intro!: rel_optionI simp add: Option.is_none_def)
qed

subsection ‹From density to spmfs›

context fixes f :: "'a ⇒ real" begin

(* Builds an spmf from the density f fixed by the enclosing context:
   Some x' gets mass max 0 (f x'), and None gets 1 minus the (real-valued)
   total integral of f over the counting measure. *)
definition embed_spmf :: "'a spmf"
where "embed_spmf = embed_pmf (λx. case x of None ⇒ 1 - enn2real (∫⇧+ x. ennreal (f x) ∂count_space UNIV) | Some x' ⇒ max 0 (f x'))"

context
assumes prob: "(∫⇧+ x. ennreal (f x) ∂count_space UNIV) ≤ 1"
begin

(* Under the context assumption prob (total integral of f is at most 1), the
   density used in embed_spmf integrates to exactly 1.
   The integral is split into the None part and the range-Some part.
   NOTE(review): several "also have" steps below carry no visible justification
   in this view; confirm the proof lines were not dropped. *)
lemma nn_integral_embed_spmf_eq_1:
"(∫⇧+ x. ennreal (case x of None ⇒ 1 - enn2real (∫⇧+ x. ennreal (f x) ∂count_space UNIV) | Some x' ⇒ max 0 (f x')) ∂count_space UNIV) = 1"
(is "?lhs = _" is "(∫⇧+ x. ?f x ∂?M) = _")
proof -
have "?lhs = ∫⇧+ x. ?f x * indicator {None} x + ?f x * indicator (range Some) x ∂?M"
by(rule nn_integral_cong)(auto split: split_indicator)
also have "… = (1 - enn2real (∫⇧+ x. ennreal (f x) ∂count_space UNIV)) + ∫⇧+ x. ?f x * indicator (range Some) x ∂?M"
(is "_ = ?None + ?Some")
also have "?Some = ∫⇧+ x. ?f x ∂count_space (range Some)"
also have "count_space (range Some) = embed_measure (count_space UNIV) Some"
also have "(∫⇧+ x. ?f x ∂…) = ∫⇧+ x. ennreal (f x) ∂count_space UNIV"
also have "?None + … = 1" using prob
finally show ?thesis .
qed

(* Mass of embed_spmf on None.  pmf_embed_pmf needs two side conditions:
   non-negativity of the density (first subgoal, from prob) and total mass 1
   (nn_integral_embed_spmf_eq_1). *)
lemma pmf_embed_spmf_None: "pmf embed_spmf None = 1 - enn2real (∫⇧+ x. ennreal (f x) ∂count_space UNIV)"
unfolding embed_spmf_def
apply(subst pmf_embed_pmf)
subgoal using prob by(simp add: field_simps enn2real_leI split: option.split)
apply(rule nn_integral_embed_spmf_eq_1)
apply simp
done

(* Mass of embed_spmf on Some x is max 0 (f x); same proof pattern as
   pmf_embed_spmf_None. *)
lemma spmf_embed_spmf [simp]: "spmf embed_spmf x = max 0 (f x)"
unfolding embed_spmf_def
apply(subst pmf_embed_pmf)
subgoal using prob by(simp add: field_simps enn2real_leI split: option.split)
apply(rule nn_integral_embed_spmf_eq_1)
apply simp
done

end

end

(* Embedding the constant-zero density gives the point distribution on None.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma embed_spmf_K_0[simp]: "embed_spmf (λ_. 0) = return_pmf None" (is "?lhs = ?rhs")

subsection ‹Ordering on spmfs›

text ‹
@{const rel_pmf} does not preserve a ccpo structure. Counterexample by Saheb-Djahromi:
Take prefix order over ‹bool llist› and
the set ‹range (λn :: nat. uniform (llist_n n))› where ‹llist_n› is the set
of all ‹llist›s of length ‹n› and ‹uniform› returns a uniform distribution over
the given set. The set forms a chain in ‹ord_pmf lprefix›, but it has no upper bound.
Any upper bound may contain only infinite lists in its support because otherwise it is not greater
than the ‹n+1›-st element in the chain where ‹n› is the length of the finite list.
Moreover its support must contain all infinite lists, because otherwise there is a finite list
all of whose finite extensions are not in the support - a contradiction to the upper bound property.
Hence, the support is uncountable, but pmf's only have countable support.

However, if all chains in the ccpo are finite, then it should preserve the ccpo structure.
›

(* Ordering on spmfs induced by an ordering ord on outcomes: rel_pmf lifted
   through ord_option ord. *)
abbreviation ord_spmf :: "('a ⇒ 'a ⇒ bool) ⇒ 'a spmf ⇒ 'a spmf ⇒ bool"
where "ord_spmf ord ≡ rel_pmf (ord_option ord)"

(* Locale that only provides infix notation for ord_spmf with an implicit
   (structure) ordering argument; interpret it locally to enable the syntax. *)
locale ord_spmf_syntax begin
notation ord_spmf (infix "⊑ı" 60)
end

(* Moving map_spmf out of ord_spmf by composing the relation with f:
   on the left argument (1), the right argument (2), or both (12).
   NOTE(review): none of the three statements is followed by a proof in this
   view; confirm the proofs were not dropped. *)
lemma ord_spmf_map_spmf1: "ord_spmf R (map_spmf f p) = ord_spmf (λx. R (f x)) p"

lemma ord_spmf_map_spmf2: "ord_spmf R p (map_spmf f q) = ord_spmf (λx y. R x (f y)) p q"

lemma ord_spmf_map_spmf12: "ord_spmf R (map_spmf f p) (map_spmf f q) = ord_spmf (λx y. R (f x) (f y)) p q"

(* Collects the three rewrite rules under one name. *)
lemmas ord_spmf_map_spmf = ord_spmf_map_spmf1 ord_spmf_map_spmf2 ord_spmf_map_spmf12

context fixes ord :: "'a ⇒ 'a ⇒ bool" (structure) begin
interpretation ord_spmf_syntax .

(* Introduction rule for ord_spmf from a joint spmf pq over pairs: its support
   lies in ord and its projections are p and q.  The option-pair witness for
   rel_pmf sends None to (None, None) and Some (a, b) to (Some a, Some b). *)
lemma ord_spmfI:
"⟦ ⋀x y. (x, y) ∈ set_spmf pq ⟹ ord x y; map_spmf fst pq = p; map_spmf snd pq = q ⟧
⟹ p ⊑ q"
by(rule rel_pmf.intros[where pq="map_pmf (λx. case x of None ⇒ (None, None) | Some (a, b) ⇒ (Some a, Some b)) pq"])
(auto simp add: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong)

(* return_pmf None is below every spmf; the witness pairs None with each
   outcome of x. *)
lemma ord_spmf_None [simp]: "return_pmf None ⊑ x"
by(rule rel_pmf.intros[where pq="map_pmf (Pair None) x"])(auto simp add: pmf.map_comp o_def)

(* Reflexivity of ord_spmf, requiring ord to be reflexive only on the support of p. *)
lemma ord_spmf_reflI: "(⋀x. x ∈ set_spmf p ⟹ ord x x) ⟹ p ⊑ p"
by(rule rel_pmf_reflI ord_option_reflI)+(auto simp add: in_set_spmf)

(* If p and q are below each other for a reflexive, transitive ord, they are
   rel_spmf-related by the symmetric part of ord (inf ord ord¯¯).
   Uses rel_pmf_inf on the option ordering and then identifies the resulting
   relation with rel_option of the symmetric part. *)
lemma rel_spmf_inf:
assumes "p ⊑ q"
and "q ⊑ p"
and refl: "reflp ord"
and trans: "transp ord"
shows "rel_spmf (inf ord ord¯¯) p q"
proof -
from ‹p ⊑ q› ‹q ⊑ p›
have "rel_pmf (inf (ord_option ord) (ord_option ord)¯¯) p q"
by(rule rel_pmf_inf)(blast intro: reflp_ord_option transp_ord_option refl trans)+
also have "inf (ord_option ord) (ord_option ord)¯¯ = rel_option (inf ord ord¯¯)"
by(auto simp add: fun_eq_iff elim: ord_option.cases option.rel_cases)
finally show ?thesis .
qed

end

(* p is below the point spmf on y iff every element of p's support is R-related to y. *)
lemma ord_spmf_return_spmf2: "ord_spmf R p (return_spmf y) ⟷ (∀x∈set_spmf p. R x y)"
by(auto simp add: rel_pmf_return_pmf2 in_set_spmf ord_option.simps intro: ccontr)

(* ord_spmf is monotone in its relation argument. *)
lemma ord_spmf_mono: "⟦ ord_spmf A p q; ⋀x y. A x y ⟹ B x y ⟧ ⟹ ord_spmf B p q"
by(erule rel_pmf_mono)(erule ord_option_mono)

(* ord_spmf distributes over relational composition OO.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma ord_spmf_compp: "ord_spmf (A OO B) = ord_spmf A OO ord_spmf B"

(* Bind rule for ord_spmf, analogous to rel_spmf_bindI: ordered first
   computations and pointwise ordered continuations give ordered binds.
   Proved on the pmf level after unfolding bind_spmf_def. *)
lemma ord_spmf_bindI:
assumes pq: "ord_spmf R p q"
and fg: "⋀x y. R x y ⟹ ord_spmf P (f x) (g y)"
shows "ord_spmf P (p ⤜ f) (q ⤜ g)"
unfolding bind_spmf_def using pq
by(rule rel_pmf_bindI)(auto split: option.split intro: fg)

(* Variant of ord_spmf_bindI with a shared first computation p: continuations
   need to be ordered only on the support of p. *)
lemma ord_spmf_bind_reflI:
"(⋀x. x ∈ set_spmf p ⟹ ord_spmf R (f x) (g x))
⟹ ord_spmf R (p ⤜ f) (p ⤜ g)"
by(rule ord_spmf_bindI[where R="λx y. x = y ∧ x ∈ set_spmf p"])(auto intro: ord_spmf_reflI)

(* If q dominates p pointwise on Some-outcomes and R is reflexive on the
   support of p, then ord_spmf R p q.
   The proof constructs an explicit coupling pq via embed_pmf with density ?f:
     (Some x', Some x')  gets  spmf p x'
     (None, Some y')     gets  spmf q y' - spmf p y'
     (None, None)        gets  pmf q None
     everything else     gets  0
   and then shows that ?f is a valid density, that the support of pq lies in
   ord_option R, and that the two marginals of pq are p and q.
   NOTE(review): several calculation steps below have no visible justification
   in this view (marked inline); confirm the proof lines were not dropped. *)
lemma ord_pmf_increaseI:
assumes le: "⋀x. spmf p x ≤ spmf q x"
and refl: "⋀x. x ∈ set_spmf p ⟹ R x x"
shows "ord_spmf R p q"
proof(rule rel_pmf.intros)
define pq where "pq = embed_pmf
(λ(x, y). case x of Some x' ⇒ (case y of Some y' ⇒ if x' = y' then spmf p x' else 0 | None ⇒ 0)
| None ⇒ (case y of None ⇒ pmf q None | Some y' ⇒ spmf q y' - spmf p y'))"
(is "_ = embed_pmf ?f")
(* Side condition 1 for embed_pmf: the density is non-negative (uses le). *)
have nonneg: "⋀xy. ?f xy ≥ 0"
by(clarsimp simp add: le field_simps split: option.split)
(* Side condition 2 for embed_pmf: the density sums to 1. *)
have integral: "(∫⇧+ xy. ?f xy ∂count_space UNIV) = 1" (is "nn_integral ?M _ = _")
proof -
(* Split the sum over the three regions where ?f can be non-zero. *)
have "(∫⇧+ xy. ?f xy ∂count_space UNIV) =
∫⇧+ xy. ennreal (?f xy) * indicator {(None, None)} xy +
ennreal (?f xy) * indicator (range (λx. (None, Some x))) xy +
ennreal (?f xy) * indicator (range (λx. (Some x, Some x))) xy ∂?M"
by(rule nn_integral_cong)(auto split: split_indicator option.splits if_split_asm)
(* NOTE(review): no visible justification for the next step. *)
also have "… = (∫⇧+ xy. ?f xy * indicator {(None, None)} xy ∂?M) +
(∫⇧+ xy. ennreal (?f xy) * indicator (range (λx. (None, Some x))) xy ∂?M) +
(∫⇧+ xy. ennreal (?f xy) * indicator (range (λx. (Some x, Some x))) xy ∂?M)"
(is "_ = ?None + ?Some2 + ?Some")
also have "?None = pmf q None" by simp
also have "?Some2 = ∫⇧+ x. ennreal (spmf q x) - spmf p x ∂count_space UNIV"
by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1 ennreal_minus)
(* NOTE(review): no visible justification for the next step. *)
also have "… = (∫⇧+ x. spmf q x ∂count_space UNIV) - (∫⇧+ x. spmf p x ∂count_space UNIV)"
(is "_ = ?Some2' - ?Some2''")
also have "?Some = ∫⇧+ x. spmf p x ∂count_space UNIV"
by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1)
(* NOTE(review): no visible justification for the next two steps. *)
also have "pmf q None + (?Some2' - ?Some2'') + … = pmf q None + ?Some2'"
also have "… = ∫⇧+ x. ennreal (pmf q x) * indicator {None} x + ennreal (pmf q x) * indicator (range Some) x ∂count_space UNIV"
also have "… = ∫⇧+ x. pmf q x ∂count_space UNIV"
by(rule nn_integral_cong)(auto split: split_indicator)
also have "… = 1" by(simp add: nn_integral_pmf)
finally show ?thesis .
qed
note f = nonneg integral

(* Support condition: any pair with non-zero density is in ord_option R. *)
{ fix x y
assume "(x, y) ∈ set_pmf pq"
hence "?f (x, y) ≠ 0" unfolding pq_def by(simp add: set_embed_pmf[OF f])
then show "ord_option R x y"
by(simp add: spmf_eq_0_set_spmf refl split: option.split_asm if_split_asm) }

have weight_le: "weight_spmf p ≤ weight_spmf q"
by(subst ennreal_le_iff[symmetric])(auto simp add: weight_spmf_eq_nn_integral_spmf intro!: nn_integral_mono le)

(* First marginal: summing pq over the second component gives p. *)
show "map_pmf fst pq = p"
proof(rule pmf_eqI)
fix i
have "ennreal (pmf (map_pmf fst pq) i) = (∫⇧+ y. pmf pq (i, y) ∂count_space UNIV)"
unfolding pq_def ennreal_pmf_map
apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric])
apply(subst pmf_embed_pmf[OF f])
apply(rule nn_integral_bij_count_space[symmetric])
done
also have "… = pmf p i"
proof(cases i)
case (Some x)
have "(∫⇧+ y. pmf pq (Some x, y) ∂count_space UNIV) = ∫⇧+ y. pmf p (Some x) * indicator {Some x} y ∂count_space UNIV"
by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split)
then show ?thesis using Some by simp
next
case None
have "(∫⇧+ y. pmf pq (None, y) ∂count_space UNIV) =
(∫⇧+ y. ennreal (pmf pq (None, Some (the y))) * indicator (range Some) y +
ennreal (pmf pq (None, None)) * indicator {None} y ∂count_space UNIV)"
by(rule nn_integral_cong)(auto split: split_indicator)
(* NOTE(review): no visible justification for the next step. *)
also have "… = (∫⇧+ y. ennreal (pmf pq (None, Some (the y))) ∂count_space (range Some)) + pmf pq (None, None)"
also have "… = (∫⇧+ y. ennreal (spmf q y) - ennreal (spmf p y) ∂count_space UNIV) + pmf q None"
by(simp add: pq_def pmf_embed_pmf[OF f] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1 ennreal_minus)
also have "(∫⇧+ y. ennreal (spmf q y) - ennreal (spmf p y) ∂count_space UNIV) =
(∫⇧+ y. spmf q y ∂count_space UNIV) - (∫⇧+ y. spmf p y ∂count_space UNIV)"
by(subst nn_integral_diff)(simp_all add: AE_count_space le nn_integral_spmf_neq_top split: split_indicator)
(* NOTE(review): no visible justification for the next step. *)
also have "… = pmf p None - pmf q None"
also have "… = ennreal (pmf p None) - ennreal (pmf q None)" by(simp add: ennreal_minus)
(* NOTE(review): the closing method of this "finally show" is not visible. *)
finally show ?thesis using None weight_le
qed
finally show "pmf (map_pmf fst pq) i = pmf p i" by simp
qed

(* Second marginal: summing pq over the first component gives q. *)
show "map_pmf snd pq = q"
proof(rule pmf_eqI)
fix i
have "ennreal (pmf (map_pmf snd pq) i) = (∫⇧+ x. pmf pq (x, i) ∂count_space UNIV)"
unfolding pq_def ennreal_pmf_map
apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric])
apply(subst pmf_embed_pmf[OF f])
apply(rule nn_integral_bij_count_space[symmetric])
done
also have "… = ennreal (pmf q i)"
proof(cases i)
case None
have "(∫⇧+ x. pmf pq (x, None) ∂count_space UNIV) = ∫⇧+ x. pmf q None * indicator {None :: 'a option} x ∂count_space UNIV"
by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split)
then show ?thesis using None by simp
next
case (Some y)
have "(∫⇧+ x. pmf pq (x, Some y) ∂count_space UNIV) =
(∫⇧+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x +
ennreal (pmf pq (None, Some y)) * indicator {None} x ∂count_space UNIV)"
by(rule nn_integral_cong)(auto split: split_indicator)
(* NOTE(review): no visible justification for the next step. *)
also have "… = (∫⇧+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x ∂count_space UNIV) + pmf pq (None, Some y)"
also have "… = (∫⇧+ x. ennreal (spmf p y) * indicator {Some y} x ∂count_space UNIV) + (spmf q y - spmf p y)"
by(auto simp add: pq_def pmf_embed_pmf[OF f] one_ereal_def[symmetric] simp del: nn_integral_indicator_singleton intro!: arg_cong2[where f="op +"] nn_integral_cong split: option.split)
also have "… = spmf q y" by(simp add: ennreal_minus[symmetric] le)
finally show ?thesis using Some by simp
qed
finally show "pmf (map_pmf snd pq) i = pmf q i" by simp
qed
qed

(* For the flat ordering (ord_spmf op =), p below q implies that every point
   mass of p is at most the corresponding point mass of q.
   In the interesting case the coupling pq is used: the mass of p at x is the
   pq-mass of (Some x, Some x), which is at most the pq-mass of all pairs whose
   second component is Some x. *)
lemma ord_spmf_eq_leD:
assumes "ord_spmf op = p q"
shows "spmf p x ≤ spmf q x"
proof(cases "x ∈ set_spmf p")
case False
(* NOTE(review): no "show" for this case is visible in this view. *)
next
case True
from assms obtain pq
where pq: "⋀x y. (x, y) ∈ set_pmf pq ⟹ ord_option op = x y"
and p: "p = map_pmf fst pq"
and q: "q = map_pmf snd pq" by cases auto
(* NOTE(review): no visible justification for the next step. *)
have "ennreal (spmf p x) = integral⇧N pq (indicator (fst -` {Some x}))"
also have "… = integral⇧N pq (indicator {(Some x, Some x)})"
by(rule nn_integral_cong_AE)(auto simp add: AE_measure_pmf_iff split: split_indicator dest: pq)
also have "… ≤ integral⇧N pq (indicator (snd -` {Some x}))"
by(rule nn_integral_mono) simp
also have "… = ennreal (spmf q x)" using q by(simp add: ennreal_pmf_map)
finally show ?thesis by simp
qed

(* In the flat ordering, a smaller spmf has a smaller support. *)
lemma ord_spmf_eqD_set_spmf: "ord_spmf op = p q ⟹ set_spmf p ⊆ set_spmf q"
by(rule subsetI)(drule_tac x=x in ord_spmf_eq_leD, auto simp add: in_set_spmf_iff_spmf)

(* In the flat ordering, a smaller spmf assigns smaller emeasure to every set A. *)
lemma ord_spmf_eqD_emeasure:
"ord_spmf op = p q ⟹ emeasure (measure_spmf p) A ≤ emeasure (measure_spmf q) A"
by(auto intro!: nn_integral_mono split: split_indicator dest: ord_spmf_eq_leD simp add: nn_integral_measure_spmf nn_integral_indicator[symmetric])

(* The same monotonicity stated in the ordering on measures. *)
lemma ord_spmf_eqD_measure_spmf: "ord_spmf op = p q ⟹ measure_spmf p ≤ measure_spmf q"
by (subst le_measure) (auto simp: ord_spmf_eqD_emeasure)

subsection ‹CCPO structure for the flat ccpo @{term "ord_option op ="}›

context fixes Y :: "'a spmf set" begin

(* Least upper bound candidate for the set Y fixed by the enclosing context:
   the spmf whose density at x is the supremum over Y of the point masses at x,
   computed in ennreal and converted back with enn2real. *)
definition lub_spmf :: "'a spmf"
where "lub_spmf = embed_spmf (λx. enn2real (SUP p : Y. ennreal (spmf p x)))"
― ‹We go through @{typ ennreal} to have a sensible definition even if @{term Y} is empty.›

(* The lub of the empty set is the point distribution on None.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma lub_spmf_empty [simp]: "SPMF.lub_spmf {} = return_pmf None"

context assumes chain: "Complete_Partial_Order.chain (ord_spmf op =) Y" begin

(* A chain of spmfs in the flat ordering yields a chain of point-mass
   functions (as ennreal-valued functions) in the pointwise ordering.
   Each of the two chain alternatives is transferred with ord_spmf_eq_leD. *)
lemma chain_ord_spmf_eqD: "Complete_Partial_Order.chain (op ≤) ((λp x. ennreal (spmf p x)) ` Y)"
(is "Complete_Partial_Order.chain _ (?f ` _)")
proof(rule chainI)
fix f g
assume "f ∈ ?f ` Y" "g ∈ ?f ` Y"
then obtain p q where f: "f = ?f p" "p ∈ Y" and g: "g = ?f q" "q ∈ Y" by blast
from chain ‹p ∈ Y› ‹q ∈ Y› have "ord_spmf op = p q ∨ ord_spmf op = q p" by(rule chainD)
thus "f ≤ g ∨ g ≤ f"
proof
assume "ord_spmf op = p q"
hence "⋀x. spmf p x ≤ spmf q x" by(rule ord_spmf_eq_leD)
hence "f ≤ g" unfolding f g by(auto intro: le_funI)
thus ?thesis ..
next
assume "ord_spmf op = q p"
hence "⋀x. spmf q x ≤ spmf p x" by(rule ord_spmf_eq_leD)
hence "g ≤ f" unfolding f g by(auto intro: le_funI)
thus ?thesis ..
qed
qed

(* If p is below q in the flat ordering and both put the same mass on None,
   then p = q.  The sum of the non-negative differences spmf q x - spmf p x
   is shown to be 0, so every difference vanishes; antisymmetry finishes. *)
lemma ord_spmf_eq_pmf_None_eq:
assumes le: "ord_spmf op = p q"
and None: "pmf p None = pmf q None"
shows "p = q"
proof(rule spmf_eqI)
fix i
from le have le': "⋀x. spmf p x ≤ spmf q x" by(rule ord_spmf_eq_leD)
have "(∫⇧+ x. ennreal (spmf q x) - spmf p x ∂count_space UNIV) =
(∫⇧+ x. spmf q x ∂count_space UNIV) - (∫⇧+ x. spmf p x ∂count_space UNIV)"
by(subst nn_integral_diff)(simp_all add: AE_count_space le' nn_integral_spmf_neq_top)
(* NOTE(review): the closing method of the next step is not visible in this view. *)
also have "… = (1 - pmf q None) - (1 - pmf p None)" unfolding pmf_None_eq_weight_spmf
also have "… = 0" using None by simp
finally have "⋀x. spmf q x ≤ spmf p x"
by(simp add: nn_integral_0_iff_AE AE_count_space ennreal_minus ennreal_eq_0_iff)
with le' show "spmf p i = spmf q i" by(rule antisym)
qed

(* In the flat ordering, the smaller spmf has at least as much mass on None. *)
lemma ord_spmf_eqD_pmf_None:
assumes "ord_spmf op = x y"
shows "pmf x None ≥ pmf y None"
using assms
apply cases
apply(clarsimp simp only: ennreal_le_iff[symmetric, OF pmf_nonneg] ennreal_pmf_map)
apply(fastforce simp add: AE_measure_pmf_iff intro!: nn_integral_mono_AE)
done

text ‹
Chains on @{typ "'a spmf"} maintain countable support.
Thanks to Johannes Hölzl for the proof idea.
›
(* The union of the supports of all members of the chain Y is countable.
   Case 1: Y has a greatest element x; then the union is inside set_spmf x.
   Case 2: no greatest element.  With N p = pmf p None, the infimum of N over Y
   is not attained, so a sequence X n in Y with N (X n) converging to the
   infimum exists; every member of Y is below some X i, hence the union of all
   supports is inside the countable union of the supports of the X n. *)
lemma spmf_chain_countable: "countable (⋃p∈Y. set_spmf p)"
proof(cases "Y = {}")
case Y: False
show ?thesis
proof(cases "∃x∈Y. ∀y∈Y. ord_spmf op = y x")
case True
then obtain x where x: "x ∈ Y" and upper: "⋀y. y ∈ Y ⟹ ord_spmf op = y x" by blast
hence "(⋃x∈Y. set_spmf x) ⊆ set_spmf x" by(auto dest: ord_spmf_eqD_set_spmf)
thus ?thesis by(rule countable_subset) simp
next
case False
define N :: "'a option pmf ⇒ real" where "N p = pmf p None" for p

(* Within the chain, strictly smaller None-mass means larger in the ordering... *)
have N_less_imp_le_spmf: "⟦ x ∈ Y; y ∈ Y; N y < N x ⟧ ⟹ ord_spmf op = x y" for x y
using chainD[OF chain, of x y] ord_spmf_eqD_pmf_None[of x y] ord_spmf_eqD_pmf_None[of y x]
by (auto simp: N_def)
(* ... and equal None-mass means equal. *)
have N_eq_imp_eq: "⟦ x ∈ Y; y ∈ Y; N y = N x ⟧ ⟹ x = y" for x y
using chainD[OF chain, of x y] by(auto simp add: N_def dest: ord_spmf_eq_pmf_None_eq)

have NC: "N ` Y ≠ {}" "bdd_below (N ` Y)"
using ‹Y ≠ {}› by(auto intro!: bdd_belowI[of _ 0] simp: N_def)
(* The infimum is not attained: otherwise the minimiser would be a greatest
   element of Y, contradicting case False. *)
have NC_less: "Inf (N ` Y) < N x" if "x ∈ Y" for x unfolding cInf_less_iff[OF NC]
proof(rule ccontr)
assume **: "¬ (∃y∈N ` Y. y < N x)"
{ fix y
assume "y ∈ Y"
with ** consider "N x < N y" | "N x = N y" by(auto simp add: not_less le_less)
hence "ord_spmf op = y x" using ‹y ∈ Y› ‹x ∈ Y›
by cases(auto dest: N_less_imp_le_spmf N_eq_imp_eq intro: ord_spmf_reflI) }
with False ‹x ∈ Y› show False by blast
qed

(* Pick a sequence in Y whose None-masses converge to the infimum. *)
from NC have "Inf (N ` Y) ∈ closure (N ` Y)" by (intro closure_contains_Inf)
then obtain X' where "⋀n. X' n ∈ N ` Y" and X': "X' ⇢ Inf (N ` Y)"
unfolding closure_sequential by auto
then obtain X where X: "⋀n. X n ∈ Y" and "X' = (λn. N (X n))" unfolding image_iff Bex_def by metis

with X' have seq: "(λn. N (X n)) ⇢ Inf (N ` Y)" by simp
have "(⋃x ∈ Y. set_spmf x) ⊆ (⋃n. set_spmf (X n))"
proof(rule UN_least)
fix x
assume "x ∈ Y"
from order_tendstoD(2)[OF seq NC_less[OF ‹x ∈ Y›]]
obtain i where "N (X i) < N x" by (auto simp: eventually_sequentially)
thus "set_spmf x ⊆ (⋃n. set_spmf (X n))" using X ‹x ∈ Y›
by(blast dest: N_less_imp_le_spmf ord_spmf_eqD_set_spmf)
qed
thus ?thesis by(rule countable_subset) simp
qed
qed simp

(* The pointwise supremum of the point masses over the chain Y sums to at most 1,
   i.e. it is a valid density for embed_spmf.
   For non-empty Y the sum is restricted to the countable joint support ?B,
   supremum and sum are exchanged by monotone convergence, and each member of Y
   sums to at most 1.
   NOTE(review): several steps below have no visible justification in this view
   (marked inline); confirm the proof lines were not dropped. *)
lemma lub_spmf_subprob: "(∫⇧+ x. (SUP p : Y. ennreal (spmf p x)) ∂count_space UNIV) ≤ 1"
proof(cases "Y = {}")
case True
(* NOTE(review): no "show" for this case is visible. *)
next
case False
let ?B = "⋃p∈Y. set_spmf p"
have countable: "countable ?B" by(rule spmf_chain_countable)

have "(∫⇧+ x. (SUP p:Y. ennreal (spmf p x)) ∂count_space UNIV) =
(∫⇧+ x. (SUP p:Y. ennreal (spmf p x) * indicator ?B x) ∂count_space UNIV)"
by(intro nn_integral_cong SUP_cong)(auto split: split_indicator simp add: spmf_eq_0_set_spmf)
(* NOTE(review): the closing method of the next step is not visible. *)
also have "… = (∫⇧+ x. (SUP p:Y. ennreal (spmf p x)) ∂count_space ?B)"
unfolding ennreal_indicator[symmetric] using False
also have "… = (SUP p:Y. ∫⇧+ x. spmf p x ∂count_space ?B)" using False _ countable
by(rule nn_integral_monotone_convergence_SUP_countable)(rule chain_ord_spmf_eqD)
also have "… ≤ 1"
proof(rule SUP_least)
fix p
assume "p ∈ Y"
(* NOTE(review): no visible justification for the next step. *)
have "(∫⇧+ x. spmf p x ∂count_space ?B) = ∫⇧+ x. ennreal (spmf p x) * indicator ?B x ∂count_space UNIV"
also have "… = ∫⇧+ x. spmf p x ∂count_space UNIV"
by(rule nn_integral_cong)(auto split: split_indicator simp add: spmf_eq_0_set_spmf ‹p ∈ Y›)
(* NOTE(review): no visible justification for the next step. *)
also have "… ≤ 1"
finally show "(∫⇧+ x. spmf p x ∂count_space ?B) ≤ 1" .
qed
finally show ?thesis .
qed

(* For a non-empty chain Y, the point mass of lub_spmf at x is the (real)
   supremum over Y of the point masses at x.
   The calculation removes, in order: embed_spmf (spmf_embed_spmf), the max 0,
   and the round trip through ennreal. *)
lemma spmf_lub_spmf:
assumes "Y ≠ {}"
shows "spmf lub_spmf x = (SUP p : Y. spmf p x)"
proof -
from assms obtain p where "p ∈ Y" by auto
have "spmf lub_spmf x = max 0 (enn2real (SUP p:Y. ennreal (spmf p x)))" unfolding lub_spmf_def
by(rule spmf_embed_spmf)(simp del: SUP_eq_top_iff Sup_eq_top_iff add: ennreal_enn2real_if SUP_spmf_neq_top' lub_spmf_subprob)
also have "… = enn2real (SUP p:Y. ennreal (spmf p x))"
by(rule max_absorb2)(simp)
also have "… = enn2real (ennreal (SUP p : Y. spmf p x))" using assms
by(subst ennreal_SUP[symmetric])(simp_all add: SUP_spmf_neq_top' del: SUP_eq_top_iff Sup_eq_top_iff)
also have "0 ≤ (⨆p∈Y. spmf p x)" using assms
by(auto intro!: cSUP_upper2 bdd_aboveI[where M=1] simp add: pmf_le_1)
then have "enn2real (ennreal (SUP p : Y. spmf p x)) = (SUP p : Y. spmf p x)"
by(rule enn2real_ennreal)
finally show ?thesis .
qed

(* ennreal version of spmf_lub_spmf: the supremum is taken in ennreal. *)
lemma ennreal_spmf_lub_spmf: "Y ≠ {} ⟹ ennreal (spmf lub_spmf x) = (SUP p:Y. ennreal (spmf p x))"
unfolding spmf_lub_spmf by(subst ennreal_SUP)(simp_all add: SUP_spmf_neq_top' del: SUP_eq_top_iff Sup_eq_top_iff)

(* lub_spmf is an upper bound of the chain: every member p of Y is below it.
   Via ord_pmf_increaseI it suffices to compare point masses, where p's mass
   is one of the values under the supremum. *)
lemma lub_spmf_upper:
assumes p: "p ∈ Y"
shows "ord_spmf op = p lub_spmf"
proof(rule ord_pmf_increaseI)
fix x
from p have [simp]: "Y ≠ {}" by auto
from p have "ennreal (spmf p x) ≤ (SUP p:Y. ennreal (spmf p x))" by(rule SUP_upper)
also have "… = ennreal (spmf lub_spmf x)" using p
by(subst spmf_lub_spmf)(auto simp add: ennreal_SUP SUP_spmf_neq_top' simp del: SUP_eq_top_iff Sup_eq_top_iff)
finally show "spmf p x ≤ spmf lub_spmf x" by simp
qed simp

(* lub_spmf is below every upper bound z of the chain.  For non-empty Y the
   point-mass supremum is bounded by z's point mass (SUP_least); the empty case
   is closed by simp. *)
lemma lub_spmf_least:
assumes z: "⋀x. x ∈ Y ⟹ ord_spmf op = x z"
shows "ord_spmf op = lub_spmf z"
proof(cases "Y = {}")
case nonempty: False
show ?thesis
proof(rule ord_pmf_increaseI)
fix x
from nonempty obtain p where p: "p ∈ Y" by auto
have "ennreal (spmf lub_spmf x) = (SUP p:Y. ennreal (spmf p x))"
by(subst spmf_lub_spmf)(auto simp add: ennreal_SUP SUP_spmf_neq_top' nonempty simp del: SUP_eq_top_iff Sup_eq_top_iff)
also have "… ≤ ennreal (spmf z x)" by(rule SUP_least)(simp add: ord_spmf_eq_leD z)
finally show "spmf lub_spmf x ≤ spmf z x" by simp
qed simp
qed simp

(* The support of lub_spmf is the union of the supports of the chain members.
   NOTE(review): none of the three equivalence steps in the calculation has a
   visible justification in this view; confirm the proof lines were not dropped. *)
lemma set_lub_spmf: "set_spmf lub_spmf = (⋃p∈Y. set_spmf p)" (is "?lhs = ?rhs")
proof(cases "Y = {}")
case [simp]: False
show ?thesis
proof(rule set_eqI)
fix x
have "x ∈ ?lhs ⟷ ennreal (spmf lub_spmf x) > 0"
also have "… ⟷ (∃p∈Y. ennreal (spmf p x) > 0)"
also have "… ⟷ x ∈ ?rhs"
finally show "x ∈ ?lhs ⟷ x ∈ ?rhs" .
qed
qed simp

(* For a non-empty chain, the emeasure of a set A under lub_spmf is the
   supremum of the emeasures of A under the chain members.
   The emeasure is written as a sum over the support of lub_spmf, the point
   mass is replaced by a supremum, and supremum and sum are exchanged by
   monotone convergence on a countable space.
   NOTE(review): some steps below have no visible justification in this view
   (marked inline). *)
lemma emeasure_lub_spmf:
assumes Y: "Y ≠ {}"
shows "emeasure (measure_spmf lub_spmf) A = (SUP y:Y. emeasure (measure_spmf y) A)"
(is "?lhs = ?rhs")
proof -
let ?M = "count_space (set_spmf lub_spmf)"
(* NOTE(review): no visible justification for the next step. *)
have "?lhs = ∫⇧+ x. ennreal (spmf lub_spmf x) * indicator A x ∂?M"
also have "… = ∫⇧+ x. (SUP y:Y. ennreal (spmf y x) * indicator A x) ∂?M"
unfolding ennreal_indicator[symmetric]
by(simp add: spmf_lub_spmf assms ennreal_SUP[OF SUP_spmf_neq_top'] SUP_mult_right_ennreal)
also from assms have "… = (SUP y:Y. ∫⇧+ x. ennreal (spmf y x) * indicator A x ∂?M)"
proof(rule nn_integral_monotone_convergence_SUP_countable)
(* NOTE(review): no visible justification for the next step. *)
have "(λi x. ennreal (spmf i x) * indicator A x) ` Y = (λf x. f x * indicator A x) ` (λp x. ennreal (spmf p x)) ` Y"
also have "Complete_Partial_Order.chain op ≤ …" using chain_ord_spmf_eqD
by(rule chain_imageI)(auto simp add: le_fun_def split: split_indicator)
finally show "Complete_Partial_Order.chain op ≤ ((λi x. ennreal (spmf i x) * indicator A x) ` Y)" .
qed simp
also have "… = (SUP y:Y. ∫⇧+ x. ennreal (spmf y x) * indicator A x ∂count_space UNIV)"
by(auto simp add: nn_integral_count_space_indicator set_lub_spmf spmf_eq_0_set_spmf split: split_indicator intro!: SUP_cong nn_integral_cong)
(* NOTE(review): no visible justification for the next step. *)
also have "… = ?rhs"
finally show ?thesis .
qed

(* Real-valued version of emeasure_lub_spmf.  Equality is first shown after
   applying ennreal to both sides; non-negativity of the right-hand side then
   allows dropping ennreal. *)
lemma measure_lub_spmf:
assumes Y: "Y ≠ {}"
shows "measure (measure_spmf lub_spmf) A = (SUP y:Y. measure (measure_spmf y) A)" (is "?lhs = ?rhs")
proof -
have "ennreal ?lhs = ennreal ?rhs"
using emeasure_lub_spmf[OF assms] SUP_emeasure_spmf_neq_top[of A Y] Y
unfolding measure_spmf.emeasure_eq_measure by(subst ennreal_SUP)
moreover have "0 ≤ ?rhs" using Y
by(auto intro!: cSUP_upper2 bdd_aboveI[where M=1] measure_spmf.subprob_measure_le_1)
ultimately show ?thesis by(simp)
qed

(* The weight of lub_spmf is the supremum of the weights of the chain members;
   special case of measure_lub_spmf for the set UNIV. *)
lemma weight_lub_spmf:
assumes Y: "Y ≠ {}"
shows "weight_spmf lub_spmf = (SUP y:Y. weight_spmf y)"
unfolding weight_spmf_def by(rule measure_lub_spmf) fact

(* The measure of lub_spmf is the supremum, in the ordering on measures, of
   the measures of the chain members.  Shown by measure_eqI: equal sets, and
   equal emeasures via emeasure_SUP_chain and emeasure_lub_spmf. *)
lemma measure_spmf_lub_spmf:
assumes Y: "Y ≠ {}"
shows "measure_spmf lub_spmf = (SUP p:Y. measure_spmf p)" (is "?lhs = ?rhs")
proof(rule measure_eqI)
from assms obtain p where p: "p ∈ Y" by auto
from chain have chain': "Complete_Partial_Order.chain op ≤ (measure_spmf ` Y)"
by(rule chain_imageI)(rule ord_spmf_eqD_measure_spmf)
show "sets ?lhs = sets ?rhs"
using Y by (subst sets_SUP) auto
show "emeasure ?lhs A = emeasure ?rhs A" for A
using chain' Y p by (subst emeasure_SUP_chain) (auto simp:  emeasure_lub_spmf)
qed

end

end

(* The flat ordering on spmfs with lub_spmf satisfies partial_function_definitions.
   The five goals are, in order: reflexivity, transitivity, antisymmetry,
   lub is an upper bound of a chain, lub is the least upper bound of a chain. *)
lemma partial_function_definitions_spmf: "partial_function_definitions (ord_spmf op =) lub_spmf"
(is "partial_function_definitions ?R _")
proof
fix x show "?R x x" by(simp add: ord_spmf_reflI)
next
fix x y z
assume "?R x y" "?R y z"
with transp_ord_option[OF transp_equality] show "?R x z" by(rule transp_rel_pmf[THEN transpD])
next
fix x y
assume "?R x y" "?R y x"
thus "x = y"
by(rule rel_pmf_antisym)(simp_all add: reflp_ord_option transp_ord_option antisymp_ord_option)
next
fix Y x
assume "Complete_Partial_Order.chain ?R Y" "x ∈ Y"
then show "?R x (lub_spmf Y)"
by(rule lub_spmf_upper)
next
fix Y z
assume "Complete_Partial_Order.chain ?R Y" "⋀x. x ∈ Y ⟹ ?R x z"
then show "?R (lub_spmf Y) z"
by(cases "Y = {}")(simp_all add: lub_spmf_least)
qed

(* The same structure packaged as an instance of the ccpo class predicate,
   with the strict order derived by mk_less. *)
lemma ccpo_spmf: "class.ccpo lub_spmf (ord_spmf op =) (mk_less (ord_spmf op =))"
by(rule ccpo partial_function_definitions_spmf)+

(* Interprets partial_function_definitions for the flat spmf ordering under
   the prefix "spmf"; occurrences of lub_spmf {} in the exported facts are
   rewritten to return_pmf None. *)
interpretation spmf: partial_function_definitions "ord_spmf op =" "lub_spmf"
rewrites "lub_spmf {} ≡ return_pmf None"
by(rule partial_function_definitions_spmf) simp

(* Registers the mode "spmf" with the Partial_Function package, passing the
   fixpoint combinator, the monotonicity predicate and the fixpoint and
   induction rules from the spmf interpretation; the last argument is NONE. *)
declaration ‹Partial_Function.init "spmf" @{term spmf.fixp_fun}
@{term spmf.mono_body} @{thm spmf.fixp_rule_uc} @{thm spmf.fixp_induct_uc}
NONE›

(* Make reflexivity of the spmf ordering available to the simplifier. *)
declare spmf.leq_refl[simp]

(* Monotonicity of a functional from the pointwise-lifted spmf ordering on
   functions (fun_ord) to the spmf ordering. *)
abbreviation "mono_spmf ≡ monotone (fun_ord (ord_spmf op =)) (ord_spmf op =)"

(* The lub of a singleton chain is its element. *)
lemma lub_spmf_const [simp]: "lub_spmf {p} = p"
by(rule spmf_eqI)(simp add: spmf_lub_spmf[OF ccpo.chain_singleton[OF ccpo_spmf]])

(* bind_spmf is monotone in both arguments w.r.t. the flat ordering:
   ordered first computations and pointwise ordered continuations give
   ordered binds. *)
lemma bind_spmf_mono':
assumes fg: "ord_spmf op = f g"
and hk: "⋀x :: 'a. ord_spmf op = (h x) (k x)"
shows "ord_spmf op = (f ⤜ h) (g ⤜ k)"
unfolding bind_spmf_def using assms(1)
by(rule rel_pmf_bindI)(auto split: option.split simp add: hk)

(* Monotonicity rule for bind_spmf, registered with [partial_function_mono]:
   if B and every C y are monotone functionals, so is their bind. *)
lemma bind_spmf_mono [partial_function_mono]:
assumes mf: "mono_spmf B" and mg: "⋀y. mono_spmf (λf. C y f)"
shows "mono_spmf (λf. bind_spmf (B f) (λy. C y f))"
proof (rule monotoneI)
fix f g :: "'a ⇒ 'b spmf"
assume fg: "fun_ord (ord_spmf op =) f g"
with mf have "ord_spmf op = (B f) (B g)" by (rule monotoneD[of _ _ _ f g])
moreover from mg have "⋀y'. ord_spmf op = (C y' f) (C y' g)"
by (rule monotoneD) (rule fg)
ultimately show "ord_spmf op = (bind_spmf (B f) (λy. C y f)) (bind_spmf (B g) (λy'. C y' g))"
by(rule bind_spmf_mono')
qed

(* bind_spmf is monotone in its first argument for a fixed continuation g.
   NOTE(review): no proof follows this statement in this view; confirm it was not dropped. *)
lemma monotone_bind_spmf1: "monotone (ord_spmf op =) (ord_spmf op =) (λy. bind_spmf y g)"

(* bind_spmf is monotone in the continuation: if g is monotone in y for every
   argument x, then so is binding a fixed p with g y. *)
lemma monotone_bind_spmf2:
assumes g: "⋀x. monotone ord (ord_spmf op =) (λy. g y x)"
shows "monotone ord (ord_spmf op =) (λy. bind_spmf p (g y))"
by(rule monotoneI)(auto intro: bind_spmf_mono' monotoneD[OF g] ord_spmf_reflI)

(* bind_spmf commutes with lub_spmf in its first argument: binding the lub of
   a chain equals the lub of the binds.
   For non-empty Y, point masses are compared: the bind is written as a sum
   over the support of the lub, the lub's point mass becomes a supremum, and
   supremum and sum are exchanged by monotone convergence. *)
lemma bind_lub_spmf:
assumes chain: "Complete_Partial_Order.chain (ord_spmf op =) Y"
shows "bind_spmf (lub_spmf Y) f = lub_spmf ((λp. bind_spmf p f) ` Y)" (is "?lhs = ?rhs")
proof(cases "Y = {}")
case Y: False
show ?thesis
proof(rule spmf_eqI)
fix i
(* The integrands form a chain (needed for monotone convergence). *)
have chain': "Complete_Partial_Order.chain op ≤ ((λp x. ennreal (spmf p x * spmf (f x) i)) ` Y)"
using chain by(rule chain_imageI)(auto simp add: le_fun_def dest: ord_spmf_eq_leD intro: mult_right_mono)
(* The image of Y under "bind with f" is again a chain of spmfs. *)
have chain'': "Complete_Partial_Order.chain (ord_spmf op =) ((λp. p ⤜ f) ` Y)"
using chain by(rule chain_imageI)(auto intro!: monotoneI bind_spmf_mono' ord_spmf_reflI)
let ?M = "count_space (set_spmf (lub_spmf Y))"
have "ennreal (spmf ?lhs i) = ∫⇧+ x. ennreal (spmf (lub_spmf Y) x) * ennreal (spmf (f x) i) ∂?M"
by(auto simp add: ennreal_spmf_lub_spmf ennreal_spmf_bind nn_integral_measure_spmf')
also have "… = ∫⇧+ x. (SUP p:Y. ennreal (spmf p x * spmf (f x) i)) ∂?M"
by(subst ennreal_spmf_lub_spmf[OF chain Y])(subst SUP_mult_right_ennreal, simp_all add: ennreal_mult Y)
also have "… = (SUP p:Y. ∫⇧+ x. ennreal (spmf p x * spmf (f x) i) ∂?M)"
using Y chain' by(rule nn_integral_monotone_convergence_SUP_countable) simp
also have "… = (SUP p:Y. ennreal (spmf (bind_spmf p f) i))"
by(auto simp add: ennreal_spmf_bind nn_integral_measure_spmf nn_integral_count_space_indicator set_lub_spmf[OF chain] in_set_spmf_iff_spmf ennreal_mult intro!: SUP_cong nn_integral_cong split: split_indicator)
also have "… = ennreal (spmf ?rhs i)" using chain'' by(simp add: ennreal_spmf_lub_spmf Y)
finally show "spmf ?lhs i = spmf ?rhs i" by simp
qed
qed simp

(* map_spmf commutes with lub_spmf on chains; obtained from bind_lub_spmf by
   expressing map_spmf through bind_spmf. *)
lemma map_lub_spmf:
"Complete_Partial_Order.chain (ord_spmf op =) Y
⟹ map_spmf f (lub_spmf Y) = lub_spmf (map_spmf f ` Y)"
unfolding map_spmf_conv_bind_spmf[abs_def] by(simp add: bind_lub_spmf o_def)

(* bind_spmf is monotone and continuous (mcont) in its first argument:
   monotonicity from monotone_bind_spmf1, continuity from bind_lub_spmf. *)
lemma mcont_bind_spmf1: "mcont lub_spmf (ord_spmf op =) lub_spmf (ord_spmf op =) (λy. bind_spmf y f)"
using monotone_bind_spmf1 by(rule mcontI)(rule contI, simp add: bind_lub_spmf)

(* bind_spmf commutes with least upper bounds in its second argument: for a
   chain Y and continuations g y that are monotone in the chain element,
   binding x with the pointwise lubs equals the lub of the binds.  Y = {} is
   closed by the final simp; otherwise the proof is pointwise via spmf_eqI. *)
lemma bind_lub_spmf2:
assumes chain: "Complete_Partial_Order.chain ord Y"
and g: "⋀y. monotone ord (ord_spmf op =) (g y)"
shows "bind_spmf x (λy. lub_spmf (g y ` Y)) = lub_spmf ((λp. bind_spmf x (λy. g y p)) ` Y)"
(is "?lhs = ?rhs")
proof(cases "Y = {}")
case Y: False
show ?thesis
proof(rule spmf_eqI)
fix i
(* three chain facts: images under g y, the integrands, and the binds *)
have chain': "⋀y. Complete_Partial_Order.chain (ord_spmf op =) (g y ` Y)"
using chain g[THEN monotoneD] by(rule chain_imageI)
have chain'': "Complete_Partial_Order.chain op ≤ ((λp y. ennreal (spmf x y * spmf (g y p) i)) ` Y)"
using chain by(rule chain_imageI)(auto simp add: le_fun_def dest: ord_spmf_eq_leD monotoneD[OF g] intro!: mult_left_mono)
have chain''': "Complete_Partial_Order.chain (ord_spmf op =) ((λp. bind_spmf x (λy. g y p)) ` Y)"
using chain by(rule chain_imageI)(rule monotone_bind_spmf2[OF g, THEN monotoneD])

have "ennreal (spmf ?lhs i) = ∫⇧+ y. (SUP p:Y. ennreal (spmf x y * spmf (g y p) i)) ∂count_space (set_spmf x)"
by(simp add: ennreal_spmf_bind ennreal_spmf_lub_spmf[OF chain'] Y nn_integral_measure_spmf' SUP_mult_left_ennreal ennreal_mult)
(* monotone convergence: swap SUP and integral *)
also have "… = (SUP p:Y. ∫⇧+ y. ennreal (spmf x y * spmf (g y p) i) ∂count_space (set_spmf x))"
unfolding nn_integral_measure_spmf' using Y chain''
by(rule nn_integral_monotone_convergence_SUP_countable) simp
(* fold the integral back into the point mass of the bind *)
also have "… = (SUP p:Y. ennreal (spmf (bind_spmf x (λy. g y p)) i))"
by(simp add: ennreal_spmf_bind nn_integral_measure_spmf' ennreal_mult)
also have "… = ennreal (spmf ?rhs i)" using chain'''
by(auto simp add: ennreal_spmf_lub_spmf Y)
finally show "spmf ?lhs i = spmf ?rhs i" by simp
qed
qed simp

(* bind_spmf is mcont in a parameter x when both the bound distribution f x and
   every continuation g y x are mcont in x.  The two argument positions are
   treated separately through spmf.mcont2mcont': the first uses
   mcont_bind_spmf1, the second uses bind_lub_spmf2. *)
lemma mcont_bind_spmf [cont_intro]:
assumes f: "mcont luba orda lub_spmf (ord_spmf op =) f"
and g: "⋀y. mcont luba orda lub_spmf (ord_spmf op =) (g y)"
shows "mcont luba orda lub_spmf (ord_spmf op =) (λx. bind_spmf (f x) (λy. g y x))"
proof(rule spmf.mcont2mcont'[OF _ _ f])
fix z
show "mcont lub_spmf (ord_spmf op =) lub_spmf (ord_spmf op =) (λx. bind_spmf x (λy. g y z))"
by(rule mcont_bind_spmf1)
next
fix x
let ?f = "λz. bind_spmf x (λy. g y z)"
have "monotone orda (ord_spmf op =) ?f" using mcont_mono[OF g] by(rule monotone_bind_spmf2)
moreover have "cont luba orda lub_spmf (ord_spmf op =) ?f"
proof(rule contI)
fix Y
assume chain: "Complete_Partial_Order.chain orda Y" and Y: "Y ≠ {}"
(* push luba through each g y by continuity, then through the bind *)
have "bind_spmf x (λy. g y (luba Y)) = bind_spmf x (λy. lub_spmf (g y ` Y))"
by(rule bind_spmf_cong)(simp_all add: mcont_contD[OF g chain Y])
also have "… = lub_spmf ((λp. x ⤜ (λy. g y p)) ` Y)" using chain
by(rule bind_lub_spmf2)(rule mcont_mono[OF g])
finally show "bind_spmf x (λy. g y (luba Y)) = …" .
qed
ultimately show "mcont luba orda lub_spmf (ord_spmf op =) ?f" by(rule mcontI)
qed

(* Monotonicity rule (partial_function_mono) for bind_pmf with a pmf p in
   front, derived from bind_spmf_mono by viewing p as spmf_of_pmf p. *)
lemma bind_pmf_mono [partial_function_mono]:
"(⋀y. mono_spmf (λf. C y f)) ⟹ mono_spmf (λf. bind_pmf p (λx. C x f))"
using bind_spmf_mono[of "λ_. spmf_of_pmf p" C] by simp

(* Monotonicity rule (partial_function_mono) for map_spmf, via its bind_spmf form. *)
lemma map_spmf_mono [partial_function_mono]: "mono_spmf B ⟹ mono_spmf (λg. map_spmf f (B g))"
unfolding map_spmf_conv_bind_spmf by(rule bind_spmf_mono) simp_all

(* map_spmf f preserves mcont of its argument; reduced to mcont_bind_spmf. *)
lemma mcont_map_spmf [cont_intro]:
"mcont luba orda lub_spmf (ord_spmf op =) g
⟹ mcont luba orda lub_spmf (ord_spmf op =) (λx. map_spmf f (g x))"
unfolding map_spmf_conv_bind_spmf by(rule mcont_bind_spmf) simp_all

(* set_spmf is monotone from the order ord_spmf op = to subset inclusion. *)
lemma monotone_set_spmf: "monotone (ord_spmf op =) op ⊆ set_spmf"
by(rule monotoneI)(rule ord_spmf_eqD_set_spmf)

(* set_spmf is continuous: it maps lub_spmf of a chain to the union of the supports. *)
lemma cont_set_spmf: "cont lub_spmf (ord_spmf op =) Union op ⊆ set_spmf"
by(rule contI)(subst set_lub_spmf; simp)

(* set_spmf is mcont.  The plain fact is named mcont_set_spmf; the outer name
   stores its mcont2mcont instance, registered as a cont_intro rule. *)
lemma mcont2mcont_set_spmf[THEN mcont2mcont, cont_intro]:
shows mcont_set_spmf: "mcont lub_spmf (ord_spmf op =) Union op ⊆ set_spmf"
by(rule mcontI monotone_set_spmf cont_set_spmf)+

(* The point mass λp. spmf p x is monotone w.r.t. ord_spmf op =. *)
lemma monotone_spmf: "monotone (ord_spmf op =) op ≤ (λp. spmf p x)"
by(rule monotoneI)(simp add: ord_spmf_eq_leD)

(* The point mass λp. spmf p x is continuous: it maps lub_spmf to Sup. *)
lemma cont_spmf: "cont lub_spmf (ord_spmf op =) Sup op ≤ (λp. spmf p x)"
by(rule contI)(simp add: spmf_lub_spmf)

(* The point mass λp. spmf p x is mcont, combining monotone_spmf and cont_spmf. *)
lemma mcont_spmf: "mcont lub_spmf (ord_spmf op =) Sup op ≤ (λp. spmf p x)"
by(rule mcontI monotone_spmf cont_spmf)+

(* The ennreal-valued point mass λp. ennreal (spmf p x) is continuous. *)
lemma cont_ennreal_spmf: "cont lub_spmf (ord_spmf op =) Sup op ≤ (λp. ennreal (spmf p x))"
by(rule contI)(simp add: ennreal_spmf_lub_spmf)

(* The ennreal-valued point mass is mcont.  The plain fact is named
   mcont_ennreal_spmf; its mcont2mcont instance is registered as cont_intro. *)
lemma mcont2mcont_ennreal_spmf [THEN mcont2mcont, cont_intro]:
shows mcont_ennreal_spmf: "mcont lub_spmf (ord_spmf op =) Sup op ≤ (λp. ennreal (spmf p x))"
by(rule mcontI mono2mono_ennreal monotone_spmf cont_ennreal_spmf)+

(* Change of variables: integrating g over map_spmf f p equals integrating g ∘ f over p. *)
lemma nn_integral_map_spmf [simp]: "nn_integral (measure_spmf (map_spmf f p)) g = nn_integral (measure_spmf p) (g ∘ f)"
by(auto 4 3 simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space intro: nn_integral_cong split: split_indicator)

(* Destruction rule for rel_spmf: if p and q are R-related, the measure of any
   set A under p is bounded by the measure under q of the R-image of A.
   Proved by lifting to the underlying pmfs on option (rel_pmf_measureD). *)
lemma rel_spmf_measureD:
assumes "rel_spmf R p q"
shows "measure (measure_spmf p) A ≤ measure (measure_spmf q) {y. ∃x∈A. R x y}" (is "?lhs ≤ ?rhs")
proof -
have "?lhs = measure (measure_pmf p) (Some ` A)" by(simp add: measure_measure_spmf_conv_measure_pmf)
also have "… ≤ measure (measure_pmf q) {y. ∃x∈Some ` A. rel_option R x y}"
using assms by(rule rel_pmf_measureD)
also have "… = ?rhs" unfolding measure_measure_spmf_conv_measure_pmf
by(rule arg_cong2[where f=measure])(auto simp add: option_rel_Some1)
finally show ?thesis .
qed

(* Locale that assumes the measure-based introduction rule for rel_pmf on
   option-valued pmfs.  Inside it: the analogous introduction rule for
   rel_spmf, admissibility of rel_spmf, and parametricity of spmf fixpoints. *)
locale rel_spmf_characterisation =
assumes rel_pmf_measureI:
"⋀(R :: 'a option ⇒ 'b option ⇒ bool) p q.
(⋀A. measure (measure_pmf p) A ≤ measure (measure_pmf q) {y. ∃x∈A. R x y})
⟹ rel_pmf R p q"
― ‹This assumption is shown to hold in general in the AFP entry ‹MFMC_Countable›.›
begin

context fixes R :: "'a ⇒ 'b ⇒ bool" begin

(* Introduction rule for rel_spmf from a measure inequality (eq1) plus a
   weight inequality in the opposite direction (eq2).  An arbitrary set A of
   options is split into its Some-part A' and its None-part A''. *)
lemma rel_spmf_measureI:
assumes eq1: "⋀A. measure (measure_spmf p) A ≤ measure (measure_spmf q) {y. ∃x∈A. R x y}"
assumes eq2: "weight_spmf q ≤ weight_spmf p"
shows "rel_spmf R p q"
proof(rule rel_pmf_measureI)
fix A :: "'a option set"
define A' where "A' = the ` (A ∩ range Some)"
define A'' where "A'' = A ∩ {None}"
have A: "A = Some ` A' ∪ A''" "Some ` A' ∩ A'' = {}"
unfolding A'_def A''_def by(auto 4 3 intro: rev_image_eqI)
have "measure (measure_pmf p) A = measure (measure_pmf p) (Some ` A') + measure (measure_pmf p) A''"
by(subst A(1))(rule measure_pmf.finite_measure_Union, simp_all add: A(2))
also have "measure (measure_pmf p) (Some ` A') = measure (measure_spmf p) A'"
by(simp add: measure_measure_spmf_conv_measure_pmf)
also have "… ≤ measure (measure_spmf q) {y. ∃x∈A'. R x y}" by(rule eq1)
(* combine the equation with the inequality on the left summand *)
also (ord_eq_le_trans[OF _ add_right_mono])
have "… = measure (measure_pmf q) {y. ∃x∈A'. rel_option R (Some x) y}"
unfolding measure_measure_spmf_conv_measure_pmf
by(rule arg_cong2[where f=measure])(auto simp add: A'_def option_rel_Some1)
also
(* eq1 at UNIV together with eq2 forces the two weights to be equal *)
{ have "weight_spmf p ≤ measure (measure_spmf q) {y. ∃x. R x y}"
using eq1[of UNIV] unfolding weight_spmf_def by simp
also have "… ≤ weight_spmf q" unfolding weight_spmf_def
by(rule measure_spmf.finite_measure_mono) simp_all
finally have "weight_spmf p = weight_spmf q" using eq2 by simp }
then have "measure (measure_pmf p) A'' = measure (measure_pmf q) (if None ∈ A then {None} else {})"
unfolding A''_def by(simp add: pmf_None_eq_weight_spmf measure_pmf_single)
also have "measure (measure_pmf q) {y. ∃x∈A'. rel_option R (Some x) y} + … = measure (measure_pmf q) {y. ∃x∈A. rel_option R x y}"
by(subst measure_pmf.finite_measure_Union[symmetric])
(auto 4 3 intro!: arg_cong2[where f=measure] simp add: option_rel_Some1 option_rel_Some2 A'_def intro: rev_bexI elim: option.rel_cases)
finally show "measure (measure_pmf p) A ≤ …" .
qed

(* rel_spmf R is admissible on pairs of spmfs ordered componentwise: it is
   preserved by lubs of non-empty chains of related pairs. *)
lemma admissible_rel_spmf:
"ccpo.admissible (prod_lub lub_spmf lub_spmf) (rel_prod (ord_spmf op =) (ord_spmf op =)) (case_prod (rel_spmf R))"
(is "ccpo.admissible ?lub ?ord ?P")
proof(rule ccpo.admissibleI)
fix Y
assume chain: "Complete_Partial_Order.chain ?ord Y"
and Y: "Y ≠ {}"
and R: "∀(p, q) ∈ Y. rel_spmf R p q"
from R have R: "⋀p q. (p, q) ∈ Y ⟹ rel_spmf R p q" by auto
have chain1: "Complete_Partial_Order.chain (ord_spmf op =) (fst ` Y)"
and chain2: "Complete_Partial_Order.chain (ord_spmf op =) (snd ` Y)"
using chain by(rule chain_imageI; clarsimp)+
from Y have Y1: "fst ` Y ≠ {}" and Y2: "snd ` Y ≠ {}" by auto

have "rel_spmf R (lub_spmf (fst ` Y)) (lub_spmf (snd ` Y))"
proof(rule rel_spmf_measureI)
show "weight_spmf (lub_spmf (snd ` Y)) ≤ weight_spmf (lub_spmf (fst ` Y))"
by(auto simp add: weight_lub_spmf chain1 chain2 Y rel_spmf_weightD[OF R, symmetric] intro!: cSUP_least intro: cSUP_upper2[OF bdd_aboveI2[OF weight_spmf_le_1]])

fix A
have "measure (measure_spmf (lub_spmf (fst ` Y))) A = (SUP y:fst ` Y. measure (measure_spmf y) A)"
using chain1 Y1 by(rule measure_lub_spmf)
also have "… ≤ (SUP y:snd ` Y. measure (measure_spmf y) {y. ∃x∈A. R x y})" using Y1
by(rule cSUP_least)(auto intro!: cSUP_upper2[OF bdd_aboveI2[OF measure_spmf.subprob_measure_le_1]] rel_spmf_measureD R)
also have "… = measure (measure_spmf (lub_spmf (snd ` Y))) {y. ∃x∈A. R x y}"
using chain2 Y2 by(rule measure_lub_spmf[symmetric])
finally show "measure (measure_spmf (lub_spmf (fst ` Y))) A ≤ …" .
qed
then show "?P (?lub Y)" by(simp add: prod_lub_def)
qed

(* Admissibility of λx. rel_spmf R (f x) (g x) for mcont f and g, by
   substituting the pair (f x, g x) into admissible_rel_spmf. *)
lemma admissible_rel_spmf_mcont [cont_intro]:
"⟦ mcont lub ord lub_spmf (ord_spmf op =) f; mcont lub ord lub_spmf (ord_spmf op =) g ⟧
⟹ ccpo.admissible lub ord (λx. rel_spmf R (f x) (g x))"
by(rule admissible_subst[OF admissible_rel_spmf, where f="λx. (f x, g x)", simplified])(rule mcont_Pair)

context includes lifting_syntax
begin

(* Parametricity of the ccpo fixpoint on spmfs for monotone functionals F, G
   that are related by rel_spmf R ===> rel_spmf R. *)
lemma fixp_spmf_parametric':
assumes f: "⋀x. monotone (ord_spmf op =) (ord_spmf op =) F"
and g: "⋀x. monotone (ord_spmf op =) (ord_spmf op =) G"
and param: "(rel_spmf R ===> rel_spmf R) F G"
shows "(rel_spmf R) (ccpo.fixp lub_spmf (ord_spmf op =) F) (ccpo.fixp lub_spmf (ord_spmf op =) G)"
by(rule parallel_fixp_induct[OF ccpo_spmf ccpo_spmf _ f g])(auto intro: param[THEN rel_funD])

(* Parametricity of spmf.fixp_fun (function-valued fixpoints), by parallel
   fixpoint induction with predicate A ===> rel_spmf R. *)
lemma fixp_spmf_parametric:
assumes f: "⋀x. mono_spmf (λf. F f x)"
and g: "⋀x. mono_spmf (λf. G f x)"
and param: "((A ===> rel_spmf R) ===> A ===> rel_spmf R) F G"
shows "(A ===> rel_spmf R) (spmf.fixp_fun F) (spmf.fixp_fun G)"
using f g
proof(rule parallel_fixp_induct_1_1[OF partial_function_definitions_spmf partial_function_definitions_spmf _ _ reflexive reflexive, where P="(A ===> rel_spmf R)"])
show "ccpo.admissible (prod_lub (fun_lub lub_spmf) (fun_lub lub_spmf)) (rel_prod (fun_ord (ord_spmf op =)) (fun_ord (ord_spmf op =))) (λx. (A ===> rel_spmf R) (fst x) (snd x))"
unfolding rel_fun_def
(* peel off the ∀/⟶ structure of rel_fun, leaving the two mcont goals *)
apply(rule admissible_all admissible_imp admissible_rel_spmf_mcont)+
apply(rule spmf.mcont2mcont[OF mcont_call])
apply(rule mcont_fst)
apply(rule spmf.mcont2mcont[OF mcont_call])
apply(rule mcont_snd)
done
show "(A ===> rel_spmf R) (λ_. lub_spmf {}) (λ_. lub_spmf {})" by auto
show "(A ===> rel_spmf R) (F f) (G g)" if "(A ===> rel_spmf R) f g" for f g
using that by(rule rel_funD[OF param])
qed

end

end

end

subsection ‹Restrictions on spmfs›

(* Restriction of an spmf to a set A (infix ↿): every result outside A is
   mapped to None; results in A and None itself are kept. *)
definition restrict_spmf :: "'a spmf ⇒ 'a set ⇒ 'a spmf" (infixl "↿" 110)
where "p ↿ A = map_pmf (λx. x ⤜ (λy. if y ∈ A then Some y else None)) p"

(* The support of a restricted spmf is the original support intersected with A. *)
lemma set_restrict_spmf [simp]: "set_spmf (p ↿ A) = set_spmf p ∩ A"
by(fastforce simp add: restrict_spmf_def set_spmf_def split: bind_splits if_split_asm)

(* Restricting after mapping f equals mapping f after restricting to the preimage f -` A. *)
lemma restrict_map_spmf: "map_spmf f p ↿ A = map_spmf f (p ↿ (f -` A))"
by(simp add: restrict_spmf_def pmf.map_comp o_def map_option_bind bind_map_option if_distrib cong del: if_weak_cong)

(* Two successive restrictions collapse to a restriction to the intersection. *)
lemma restrict_restrict_spmf [simp]: "p ↿ A ↿ B = p ↿ (A ∩ B)"
by(auto simp add: restrict_spmf_def pmf.map_comp o_def intro!: pmf.map_cong bind_option_cong)

(* Restricting to the empty set yields the spmf that always fails. *)
lemma restrict_spmf_empty [simp]: "p ↿ {} = return_pmf None"
by(simp add: restrict_spmf_def)

(* Restricting to UNIV leaves the spmf unchanged. *)
lemma restrict_spmf_UNIV [simp]: "p ↿ UNIV = p"
by(simp add: restrict_spmf_def)

(* Points outside A have probability 0 under the restriction to A. *)
lemma spmf_restrict_spmf_outside [simp]: "x ∉ A ⟹ spmf (p ↿ A) x = 0"
by(simp add: spmf_eq_0_set_spmf)

(* emeasure of X under the restriction to A equals emeasure of X ∩ A under p. *)
lemma emeasure_restrict_spmf [simp]:
"emeasure (measure_spmf (p ↿ A)) X = emeasure (measure_spmf p) (X ∩ A)"
by(auto simp add: restrict_spmf_def measure_spmf_def emeasure_distr measurable_restrict_space1 emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure] split: bind_splits if_split_asm)

(* Real-valued version of emeasure_restrict_spmf, transferred via ennreal injectivity. *)
lemma measure_restrict_spmf [simp]:
"measure (measure_spmf (p ↿ A)) X = measure (measure_spmf p) (X ∩ A)"
using emeasure_restrict_spmf[of p A X]
by(simp only: measure_spmf.emeasure_eq_measure ennreal_inj measure_nonneg)

(* Point mass of a restricted spmf: unchanged inside A, 0 outside. *)
lemma spmf_restrict_spmf: "spmf (p ↿ A) x = (if x ∈ A then spmf p x else 0)"
by(simp add: spmf_conv_measure_spmf)

(* Points inside A keep their probability under the restriction to A. *)
lemma spmf_restrict_spmf_inside [simp]: "x ∈ A ⟹ spmf (p ↿ A) x = spmf p x"
by(simp add: spmf_restrict_spmf)

(* The failure probability after restricting to A is the original failure
   probability plus the mass p assigns to the complement of A.  The proof
   computes the preimage of {None} under the restricting map. *)
lemma pmf_restrict_spmf_None: "pmf (p ↿ A) None = pmf p None + measure (measure_spmf p) (- A)"
proof -
have [simp]: "None ∉ Some ` (- A)" by auto
have "(λx. x ⤜ (λy. if y ∈ A then Some y else None)) -` {None} = {None} ∪ (Some ` (- A))"
by(auto split: bind_splits if_split_asm)
then show ?thesis unfolding ereal.inject[symmetric]
by(simp add: restrict_spmf_def ennreal_pmf_map emeasure_pmf_single del: ereal.inject)
(simp add: pmf.rep_eq measure_pmf.finite_measure_Union[symmetric] measure_measure_spmf_conv_measure_pmf measure_pmf.emeasure_eq_measure)
qed

(* Restriction is the identity when every element of the support lies in A. *)
lemma restrict_spmf_trivial: "(⋀x. x ∈ set_spmf p ⟹ x ∈ A) ⟹ p ↿ A = p"
by(rule spmf_eqI)(auto simp add: spmf_restrict_spmf spmf_eq_0_set_spmf)

(* Subset formulation of restrict_spmf_trivial. *)
lemma restrict_spmf_trivial': "set_spmf p ⊆ A ⟹ p ↿ A = p"
by(rule restrict_spmf_trivial) blast

(* Restricting a point distribution keeps it if x ∈ A and fails otherwise. *)
lemma restrict_return_spmf: "return_spmf x ↿ A = (if x ∈ A then return_spmf x else return_pmf None)"
by(simp add: restrict_spmf_def)

(* Simp instance of restrict_return_spmf for x ∈ A. *)
lemma restrict_return_spmf_inside [simp]: "x ∈ A ⟹ return_spmf x ↿ A = return_spmf x"
by(simp add: restrict_return_spmf)

(* Simp instance of restrict_return_spmf for x ∉ A. *)
lemma restrict_return_spmf_outside [simp]: "x ∉ A ⟹ return_spmf x ↿ A = return_pmf None"
by(simp add: restrict_return_spmf)

(* Restricting the always-failing spmf gives the always-failing spmf. *)
lemma restrict_spmf_return_pmf_None [simp]: "return_pmf None ↿ A = return_pmf None"
by(simp add: restrict_spmf_def)

(* Restriction distributes into the continuation of a bind_pmf. *)
lemma restrict_bind_pmf: "bind_pmf p g ↿ A = p ⤜ (λx. g x ↿ A)"
by(simp add: restrict_spmf_def map_bind_pmf o_def)

(* Restriction distributes into the continuation of a bind_spmf. *)
lemma restrict_bind_spmf: "bind_spmf p g ↿ A = p ⤜ (λx. g x ↿ A)"
by(auto simp add: bind_spmf_def restrict_bind_pmf cong del: option.case_cong_weak cong: option.case_cong intro!: bind_pmf_cong split: option.split)

(* bind_pmf over a restricted spmf: the continuation sees x if x ∈ Some ` A and None otherwise. *)
lemma bind_restrict_pmf: "bind_pmf (p ↿ A) g = p ⤜ (λx. if x ∈ Some ` A then g x else g None)"
by(auto simp add: restrict_spmf_def bind_map_pmf fun_eq_iff split: bind_split intro: arg_cong2[where f=bind_pmf])

(* bind_spmf over a restricted spmf: the continuation runs for x ∈ A and fails otherwise. *)
lemma bind_restrict_spmf: "bind_spmf (p ↿ A) g = p ⤜ (λx. if x ∈ A then g x else return_pmf None)"
by(auto simp add: bind_spmf_def bind_restrict_pmf fun_eq_iff intro: arg_cong2[where f=bind_pmf] split: option.split)

(* The joint mass at (x, y) is recovered by restricting to second component y and projecting to the first. *)
lemma spmf_map_restrict: "spmf (map_spmf fst (p ↿ (snd -` {y}))) x = spmf p (x, y)"
by(subst spmf_map)(auto intro: arg_cong2[where f=measure] simp add: spmf_conv_measure_spmf)

(* If the restrictions of p to A and q to B are related by rel_spmf, then A
   and B have the same measure: related spmfs have equal weight, and the
   weight of a restriction is the measure of the restricting set. *)
lemma measure_eqI_restrict_spmf:
assumes "rel_spmf R (restrict_spmf p A) (restrict_spmf q B)"
shows "measure (measure_spmf p) A = measure (measure_spmf q) B"
proof -
from assms have "weight_spmf (restrict_spmf p A) = weight_spmf (restrict_spmf q B)" by(rule rel_spmf_weightD)
thus ?thesis by(simp add: weight_spmf_def)
qed

subsection ‹Subprobability distributions of sets›

(* spmf built from a set: pmf_of_set A embedded as an spmf when A is finite
   and non-empty, and the always-failing spmf otherwise. *)
definition spmf_of_set :: "'a set ⇒ 'a spmf"
where
"spmf_of_set A = (if finite A ∧ A ≠ {} then spmf_of_pmf (pmf_of_set A) else return_pmf None)"

(* Point mass of spmf_of_set A: indicator of A divided by the cardinality of A. *)
lemma spmf_of_set: "spmf (spmf_of_set A) x = indicator A x / card A"
by(auto simp add: spmf_of_set_def)

(* spmf_of_set A fails with probability 1 if A is infinite or empty, and 0 otherwise. *)
lemma pmf_spmf_of_set_None [simp]: "pmf (spmf_of_set A) None = indicator {A. infinite A ∨ A = {}} A"
by(simp add: spmf_of_set_def)

(* Support of spmf_of_set A: A itself when finite, empty otherwise. *)
lemma set_spmf_of_set: "set_spmf (spmf_of_set A) = (if finite A then A else {})"
by(simp add: spmf_of_set_def)

(* Simp instance of set_spmf_of_set for finite A. *)
lemma set_spmf_of_set_finite [simp]: "finite A ⟹ set_spmf (spmf_of_set A) = A"
by(simp add: set_spmf_of_set)

(* spmf_of_set on a singleton is the point distribution on its element. *)
lemma spmf_of_set_singleton: "spmf_of_set {x} = return_spmf x"
by(simp add: spmf_of_set_def pmf_of_set_singleton)

(* Mapping a function that is injective on A over spmf_of_set A gives spmf_of_set of the image. *)
lemma map_spmf_of_set_inj_on [simp]:
"inj_on f A ⟹ map_spmf f (spmf_of_set A) = spmf_of_set (f ` A)"
by(auto simp add: spmf_of_set_def map_pmf_of_set_inj dest: finite_imageD)

(* Simp rule folding spmf_of_pmf (pmf_of_set A) back into spmf_of_set A for
   finite non-empty A.  Lemmas that unfold spmf_of_set_def later in the file
   delete this rule from the simpset to avoid rewriting back. *)
lemma spmf_of_pmf_pmf_of_set [simp]:
"⟦ finite A; A ≠ {} ⟧ ⟹ spmf_of_pmf (pmf_of_set A) = spmf_of_set A"
by(simp add: spmf_of_set_def)

(* Weight of spmf_of_set A: 1 for finite non-empty A, 0 otherwise. *)
lemma weight_spmf_of_set:
"weight_spmf (spmf_of_set A) = (if finite A ∧ A ≠ {} then 1 else 0)"
by(auto simp only: spmf_of_set_def weight_spmf_of_pmf weight_return_pmf_None split: if_split)

(* Simp instance of weight_spmf_of_set for finite non-empty A. *)
lemma weight_spmf_of_set_finite [simp]: "⟦ finite A; A ≠ {} ⟧ ⟹ weight_spmf (spmf_of_set A) = 1"
by(simp add: weight_spmf_of_set)

(* Simp instance of weight_spmf_of_set for infinite A. *)
lemma weight_spmf_of_set_infinite [simp]: "infinite A ⟹ weight_spmf (spmf_of_set A) = 0"
by(simp add: weight_spmf_of_set)

(* Measure of spmf_of_set A: the uniform pmf measure for finite non-empty A,
   the null measure otherwise.  The folding simp rule spmf_of_pmf_pmf_of_set
   is removed so that unfolding the definition is not undone. *)
lemma measure_spmf_spmf_of_set:
"measure_spmf (spmf_of_set A) = (if finite A ∧ A ≠ {} then measure_pmf (pmf_of_set A) else null_measure (count_space UNIV))"
by(simp add: spmf_of_set_def del: spmf_of_pmf_pmf_of_set)

(* emeasure of A under spmf_of_set S: card (S ∩ A) divided by card S. *)
lemma emeasure_spmf_of_set:
"emeasure (measure_spmf (spmf_of_set S)) A = card (S ∩ A) / card S"
by(auto simp add: measure_spmf_spmf_of_set emeasure_pmf_of_set)

(* Real-valued measure of A under spmf_of_set S: card (S ∩ A) divided by card S. *)
lemma measure_spmf_of_set:
"measure (measure_spmf (spmf_of_set S)) A = card (S ∩ A) / card S"
by(auto simp add: measure_spmf_spmf_of_set measure_pmf_of_set)

(* Non-negative integral w.r.t. spmf_of_set A: the sum of f over A divided by card A. *)
lemma nn_integral_spmf_of_set: "nn_integral (measure_spmf (spmf_of_set A)) f = sum f A / card A"
by(cases "finite A")(auto simp add: spmf_of_set_def nn_integral_pmf_of_set card_gt_0_iff simp del: spmf_of_pmf_pmf_of_set)

(* Lebesgue integral w.r.t. spmf_of_set A: the sum of f over A divided by card A. *)
lemma integral_spmf_of_set: "integral⇧L (measure_spmf (spmf_of_set A)) f = sum f A / card A"
by(clarsimp simp add: spmf_of_set_def integral_pmf_of_set card_gt_0_iff simp del: spmf_of_pmf_pmf_of_set)

(* Counterexample: A = {0, 1} and B = {0, 1, 2} are related by rel_set R for
   R x y ⟷ (x ≠ 0 ⟶ y = 0), yet the uniform pmfs on them are not rel_pmf R
   related.  A coupling would force mass 2/3 (of {1, 2} on the B side) onto
   first component 0, which only has mass 1/2 on the A side. *)
notepad begin ― ‹@{const pmf_of_set} is not fully parametric.›
define R :: "nat ⇒ nat ⇒ bool" where "R x y ⟷ (x ≠ 0 ⟶ y = 0)" for x y
define A :: "nat set" where "A = {0, 1}"
define B :: "nat set" where "B = {0, 1, 2}"
have "rel_set R A B" unfolding R_def[abs_def] A_def B_def rel_set_def by auto
have "¬ rel_pmf R (pmf_of_set A) (pmf_of_set B)"
proof
assume "rel_pmf R (pmf_of_set A) (pmf_of_set B)"
then obtain pq where pq: "⋀x y. (x, y) ∈ set_pmf pq ⟹ R x y"
and 1: "map_pmf fst pq = pmf_of_set A"
and 2: "map_pmf snd pq = pmf_of_set B"
by cases auto
have "pmf (pmf_of_set B) 1 = 1 / 3" by(simp add: B_def)
have "pmf (pmf_of_set B) 2 = 1 / 3" by(simp add: B_def)

have "2 / 3 = pmf (pmf_of_set B) 1 + pmf (pmf_of_set B) 2" by(simp add: B_def)
also have "… = measure (measure_pmf (pmf_of_set B)) ({1} ∪ {2})"
by(subst measure_pmf.finite_measure_Union)(simp_all add: measure_pmf_single)
also have "… = emeasure (measure_pmf pq) (snd -` {2, 1})"
unfolding 2[symmetric] measure_pmf.emeasure_eq_measure[symmetric] by(simp)
(* on the support of pq, second component 1 or 2 forces first component 0 *)
also have "… = emeasure (measure_pmf pq) {(0, 2), (0, 1)}"
by(rule emeasure_eq_AE)(auto simp add: AE_measure_pmf_iff R_def dest!: pq)
also have "… ≤ emeasure (measure_pmf pq) (fst -` {0})"
by(rule emeasure_mono) auto
also have "… = emeasure (measure_pmf (pmf_of_set A)) {0}"
unfolding 1[symmetric] by simp
also have "… = pmf (pmf_of_set A) 0"
by(simp add: measure_pmf_single measure_pmf.emeasure_eq_measure)
also have "pmf (pmf_of_set A) 0 = 1 / 2" by(simp add: A_def)
finally show False by(subst (asm) ennreal_le_iff; simp)
qed
end

(* Uniform pmfs on finite non-empty sets A and B are R-related if there is a
   bijection f between them with R x (f x) on A.  The witness coupling is the
   uniform pmf on the graph AB of f; the final qed step weakens the graph
   relation to R. *)
lemma rel_pmf_of_set_bij:
assumes f: "bij_betw f A B"
and A: "A ≠ {}" "finite A"
and B: "B ≠ {}" "finite B"
and R: "⋀x. x ∈ A ⟹ R x (f x)"
shows "rel_pmf R (pmf_of_set A) (pmf_of_set B)"
proof(rule pmf.rel_mono_strong)
define AB where "AB = (λx. (x, f x)) ` A"
define R' where "R' x y ⟷ (x, y) ∈ AB" for x y
have "(x, y) ∈ AB" if "(x, y) ∈ set_pmf (pmf_of_set AB)" for x y
using that by(auto simp add: AB_def A)
moreover have "map_pmf fst (pmf_of_set AB) = pmf_of_set A"
by(simp add: AB_def map_pmf_of_set_inj[symmetric] inj_on_def A pmf.map_comp o_def)
moreover
from f have [simp]: "inj_on f A" by(rule bij_betw_imp_inj_on)
from f have [simp]: "f ` A = B" by(rule bij_betw_imp_surj_on)
have "map_pmf snd (pmf_of_set AB) = pmf_of_set B"
by(simp add: AB_def map_pmf_of_set_inj[symmetric] inj_on_def A pmf.map_comp o_def)
ultimately show "rel_pmf (λx y. (x, y) ∈ AB) (pmf_of_set A) (pmf_of_set B)" ..
qed(auto intro: R)

(* spmf version of rel_pmf_of_set_bij without finiteness or non-emptiness
   assumptions: a bijection preserves both properties, so both sides take the
   same branch of spmf_of_set_def. *)
lemma rel_spmf_of_set_bij:
assumes f: "bij_betw f A B"
and R: "⋀x. x ∈ A ⟹ R x (f x)"
shows "rel_spmf R (spmf_of_set A) (spmf_of_set B)"
proof -
have "finite A ⟷ finite B" using f by(rule bij_betw_finite)
moreover have "A = {} ⟷ B = {}" using f by(auto dest: bij_betw_empty2 bij_betw_empty1)
ultimately show ?thesis using assms
by(auto simp add: spmf_of_set_def simp del: spmf_of_pmf_pmf_of_set intro: rel_pmf_of_set_bij)
qed

(* Parametricity of spmf_of_set for bi-unique relations, stated with the
   ===> notation from lifting_syntax. *)
context includes lifting_syntax
begin

(* For bi_unique R, rel_set R A B yields a bijection between A and B that
   respects R, so rel_spmf_of_set_bij applies. *)
lemma rel_spmf_of_set:
assumes "bi_unique R"
shows "(rel_set R ===> rel_spmf R) spmf_of_set spmf_of_set"
proof
fix A B
assume R: "rel_set R A B"
with assms obtain f where "bij_betw f A B" and f: "⋀x. x ∈ A ⟹ R x (f x)"
by(auto dest: bi_unique_rel_set_bij_betw)
then show "rel_spmf R (spmf_of_set A) (spmf_of_set B)" by(rule rel_spmf_of_set_bij)
qed

end

(* Testing membership in A on a uniform sample from a finite non-empty B is a
   Bernoulli trial with success probability card (A ∩ B) / card B. *)
lemma map_mem_spmf_of_set:
assumes "finite B" "B ≠ {}"
shows "map_spmf (λx. x ∈ A) (spmf_of_set B) = spmf_of_pmf (bernoulli_pmf (card (A ∩ B) / card B))"
(is "?lhs = ?rhs")
proof(rule spmf_eqI)
fix i
have "ennreal (spmf ?lhs i) = card (B ∩ (λx. x ∈ A) -` {i}) / (card B)"
by(subst ennreal_spmf_map)(simp add: measure_spmf_spmf_of_set assms emeasure_pmf_of_set)
also have "… = (if i then card (B ∩ A) / card B else card (B - A) / card B)"
by(auto intro: arg_cong[where f=card])
(* rewrite card (B - A) as card B - card (B ∩ A); B is finite by assms *)
also have "… = (if i then card (B ∩ A) / card B else (card B - card (B ∩ A)) / card B)"
by(auto simp add: card_Diff_subset_Int assms)
also have "… = ennreal (spmf ?rhs i)"
by(simp add: assms card_gt_0_iff field_simps card_mono Int_commute of_nat_diff)
finally show "spmf ?lhs i = spmf ?rhs i" by simp
qed

(* A fair coin: spmf_of_set over all values of type bool. *)
abbreviation coin_spmf :: "bool spmf"
where "coin_spmf ≡ spmf_of_set UNIV"

(* Comparing a coin flip with a fixed boolean c is again a coin flip, because
   (op ⟷ c) is injective with range UNIV. *)
lemma map_eq_const_coin_spmf: "map_spmf (op = c) coin_spmf = coin_spmf"
proof -
have "inj (op ⟷ c)" "range (op ⟷ c) = UNIV" by(auto intro: inj_onI)
then show ?thesis by simp
qed

(* bind_spmf form of map_eq_const_coin_spmf, with the constant on the left of the equality. *)
lemma bind_coin_spmf_eq_const: "coin_spmf ⤜ (λx :: bool. return_spmf (b = x)) = coin_spmf"
using map_eq_const_coin_spmf unfolding map_spmf_conv_bind_spmf by simp

(* Variant of bind_coin_spmf_eq_const with the constant on the right of the equality. *)
lemma bind_coin_spmf_eq_const': "coin_spmf ⤜ (λx :: bool. return_spmf (x = b)) = coin_spmf"
by(rewrite in "_ = ⌑" bind_coin_spmf_eq_const[symmetric, of b])(auto intro: bind_spmf_cong)

subsection ‹Losslessness›

(* An spmf is lossless when its total weight is 1. *)
definition lossless_spmf :: "'a spmf ⇒ bool"
where "lossless_spmf p ⟷ weight_spmf p = 1"

(* Losslessness is equivalent to failure probability 0. *)
lemma lossless_iff_pmf_None: "lossless_spmf p ⟷ pmf p None = 0"
by(simp add: lossless_spmf_def pmf_None_eq_weight_spmf)

(* A point distribution is lossless. *)
lemma lossless_return_spmf [iff]: "lossless_spmf (return_spmf x)"
by(simp add: lossless_iff_pmf_None)

(* The always-failing spmf is not lossless. *)
lemma lossless_return_pmf_None [iff]: "¬ lossless_spmf (return_pmf None)"
by(simp add: lossless_iff_pmf_None)

(* Mapping a function over an spmf does not change whether it is lossless. *)
lemma lossless_map_spmf [simp]: "lossless_spmf (map_spmf f p) ⟷ lossless_spmf p"
by(auto simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf)

(* A bind is lossless iff the first computation is lossless and the
   continuation is lossless on every element of its support. *)
lemma lossless_bind_spmf [simp]:
"lossless_spmf (p ⤜ f) ⟷ lossless_spmf p ∧ (∀x∈set_spmf p. lossless_spmf (f x))"
by(simp add: lossless_iff_pmf_None pmf_bind_spmf_None add_nonneg_eq_0_iff integral_nonneg_AE integral_nonneg_eq_0_iff_AE measure_spmf.integrable_const_bound[where B=1] pmf_le_1)

(* Destruction rule: a lossless spmf has weight 1. *)
lemma lossless_weight_spmfD: "lossless_spmf p ⟹ weight_spmf p = 1"
by(simp add: lossless_spmf_def)

(* Losslessness is equivalent to None not being in the support of the underlying pmf. *)
lemma lossless_iff_set_pmf_None:
"lossless_spmf p ⟷ None ∉ set_pmf p"
by(simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf)

(* spmf_of_set A is lossless exactly when A is finite and non-empty. *)
lemma lossless_spmf_of_set [simp]: "lossless_spmf (spmf_of_set A) ⟷ finite A ∧ A ≠ {}"
by(auto simp add: lossless_spmf_def weight_spmf_of_set)

(* Every pmf embedded via spmf_of_pmf is lossless. *)
lemma lossless_spmf_spmf_of_spmf [simp]: "lossless_spmf (spmf_of_pmf p)"
by(simp add: lossless_spmf_def)

(* A bind_pmf over a pmf p is lossless iff every continuation on the support of p is lossless. *)
lemma lossless_spmf_bind_pmf [simp]:
"lossless_spmf (bind_pmf p f) ⟷ (∀x∈set_pmf p. lossless_spmf (f x))"
by(simp add: lossless_iff_pmf_None pmf_bind integral_nonneg_AE integral_nonneg_eq_0_iff_AE measure_pmf.integrable_const_bound[where B=1] AE_measure_pmf_iff pmf_le_1)

(* An spmf is lossless iff it is the embedding of some pmf.  For the forward
   direction the witness is map_pmf the p; the converse is closed by auto. *)
lemma lossless_spmf_conv_spmf_of_pmf: "lossless_spmf p ⟷ (∃p'. p = spmf_of_pmf p')"
proof
assume "lossless_spmf p"
(* every element of the support of a lossless p has the form Some x *)
hence *: "⋀y. y ∈ set_pmf p ⟹ ∃x. y = Some x"
by(case_tac y)(simp_all add: lossless_iff_set_pmf_None)

let ?p = "map_pmf the p"
have "p = spmf_of_pmf ?p"
proof(rule spmf_eqI)
fix i
have "ennreal (pmf (map_pmf the p) i) = ∫⇧+ x. indicator (the -` {i}) x ∂p" by(simp add: ennreal_pmf_map)
also have "… = ∫⇧+ x. indicator {i} x ∂measure_spmf p" unfolding measure_spmf_def
by(subst nn_integral_distr)(auto simp add: nn_integral_restrict_space AE_measure_pmf_iff simp del: nn_integral_indicator intro!: nn_integral_cong_AE split: split_indicator dest!: * )
also have "… = spmf p i" by(simp add: emeasure_spmf_single)
finally show "spmf p i = spmf (spmf_of_pmf ?p) i" by simp
qed
thus "∃p'. p = spmf_of_pmf p'" ..
qed auto

(* For a lossless boolean spmf, the mass of False is 1 minus the mass of True. *)
lemma spmf_False_conv_True: "lossless_spmf p ⟹ spmf p False = 1 - spmf p True"
by(clarsimp simp add: lossless_spmf_conv_spmf_of_pmf pmf_False_conv_True)

(* For a lossless boolean spmf, the mass of True is 1 minus the mass of False. *)
lemma spmf_True_conv_False: "lossless_spmf p ⟹ spmf p True = 1 - spmf p False"
by(simp add: spmf_False_conv_True)

(* A bind equals a point distribution on x iff p is lossless and every
   continuation on its support is that point distribution. *)
lemma bind_eq_return_spmf:
"bind_spmf p f = return_spmf x ⟷ (∀y∈set_spmf p. f y = return_spmf x) ∧ lossless_spmf p"
by(auto simp add: bind_spmf_def bind_eq_return_pmf in_set_spmf lossless_iff_pmf_None pmf_eq_0_set_pmf iff del: not_None_eq split: option.split)

(* p is R-related to a point distribution on x iff p is lossless and every
   element of its support is R-related to x. *)
lemma rel_spmf_return_spmf2:
"rel_spmf R p (return_spmf x) ⟷ lossless_spmf p ∧ (∀a∈set_spmf p. R a x)"
by(auto simp add: lossless_iff_set_pmf_None rel_pmf_return_pmf2 option_rel_Some2 in_set_spmf, metis in_set_spmf not_None_eq)

(* Mirror image of rel_spmf_return_spmf2, obtained by instantiating it with the converse relation. *)
lemma rel_spmf_return_spmf1:
"rel_spmf R (return_spmf x) p ⟷ lossless_spmf p ∧ (∀a∈set_spmf p. R x a)"
using rel_spmf_return_spmf2[of "R¯¯"] by(simp add: spmf_rel_conversep)

(* One-sided bind rule: if p is lossless and every f x on its support is
   R-related to q, then the bind is R-related to q.  The proof rewrites q as
   a bind over a dummy point distribution so that rel_spmf_bindI applies. *)
lemma rel_spmf_bindI1:
assumes f: "⋀x. x ∈ set_spmf p ⟹ rel_spmf R (f x) q"
and p: "lossless_spmf p"
shows "rel_spmf R (bind_spmf p f) q"
proof -
fix x :: 'a
have "rel_spmf R (bind_spmf p f) (bind_spmf (return_spmf x) (λ_. q))"
by(rule rel_spmf_bindI[where R="λx _. x ∈ set_spmf p"])(simp_all add: rel_spmf_return_spmf2 p f)
then show ?thesis by simp
qed

(* Mirror image of rel_spmf_bindI1 with the bind on the right, via the converse relation. *)
lemma rel_spmf_bindI2:
"⟦ ⋀x. x ∈ set_spmf q ⟹ rel_spmf R p (f x); lossless_spmf q ⟧
⟹ rel_spmf R p (bind_spmf q f)"
using rel_spmf_bindI1[of q "conversep R" f p] by(simp add: spmf_rel_conversep)

subsection ‹Scaling›

(* Scaling of an spmf by a real factor r.  The factor is clamped below at 0
   and above at inverse (weight_spmf p) before multiplying each point mass. *)
definition scale_spmf :: "real ⇒ 'a spmf ⇒ 'a spmf"
where
"scale_spmf r p = embed_spmf (λx. min (inverse (weight_spmf p)) (max 0 r) * spmf p x)"

(* The scaled point masses used in scale_spmf_def integrate to at most 1; this
   is the side condition needed to apply spmf_embed_spmf. *)
lemma scale_spmf_le_1:
"(∫⇧+ x. min (inverse (weight_spmf p)) (max 0 r) * spmf p x ∂count_space UNIV) ≤ 1" (is "?lhs ≤ _")
proof -
have "?lhs = min (inverse (weight_spmf p)) (max 0 r) * ∫⇧+ x. spmf p x ∂count_space UNIV"
by(subst nn_integral_cmult[symmetric])(simp_all add: weight_spmf_nonneg max_def min_def ennreal_mult)
also have "… ≤ 1" unfolding weight_spmf_eq_nn_integral_spmf[symmetric]
by(simp add: min_def max_def weight_spmf_nonneg order.strict_iff_order field_simps ennreal_mult[symmetric])
finally show ?thesis .
qed

(* Point mass of a scaled spmf: the clamped factor times the original mass. *)
lemma spmf_scale_spmf: "spmf (scale_spmf r p) x = max 0 (min (inverse (weight_spmf p)) r) * spmf p x" (is "?lhs = ?rhs")
unfolding scale_spmf_def
apply(subst spmf_embed_spmf[OF scale_spmf_le_1])
apply(simp add: max_def min_def weight_spmf_le_0 field_simps weight_spmf_nonneg not_le order.strict_iff_order)
apply(metis antisym_conv order_trans weight_spmf_nonneg zero_le_mult_iff zero_le_one)
done

(* Auxiliary fact on reals: for x in [0, 1], 1 / x ≤ 1 holds exactly when x = 1 or x = 0. *)
lemma real_inverse_le_1_iff: fixes x :: real
shows "⟦ 0 ≤ x; x ≤ 1 ⟧ ⟹ 1 / x ≤ 1 ⟷ x = 1 ∨ x = 0"
by auto

(* For r ≤ 1 the upper clamp in spmf_scale_spmf is inactive: the mass is max 0 r times the original. *)
lemma spmf_scale_spmf': "r ≤ 1 ⟹ spmf (scale_spmf r p) x = max 0 r * spmf p x"
using real_inverse_le_1_iff[OF weight_spmf_nonneg weight_spmf_le_1, of p]
by(auto simp add: spmf_scale_spmf max_def min_def field_simps)(metis pmf_le_0_iff spmf_le_weight)

(* Scaling by a non-positive factor yields the always-failing spmf. *)
lemma scale_spmf_neg: "r ≤ 0 ⟹ scale_spmf r p = return_pmf None"
by(rule spmf_eqI)(simp add: spmf_scale_spmf' max_def)

(* Scaling the always-failing spmf leaves it unchanged. *)
lemma scale_spmf_return_None [simp]: "scale_spmf r (return_pmf None) = return_pmf None"
by(rule spmf_eqI)(simp add: spmf_scale_spmf)

(* For r ≤ 1, scaling by r equals flipping a Bernoulli(r) coin and running p
   on success, failing otherwise.  Proved pointwise on ennreal point masses. *)
lemma scale_spmf_conv_bind_bernoulli:
assumes "r ≤ 1"
shows "scale_spmf r p = bind_pmf (bernoulli_pmf r) (λb. if b then p else return_pmf None)" (is "?lhs = ?rhs")
proof(rule spmf_eqI)
fix x
have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using assms
unfolding spmf_scale_spmf ennreal_pmf_bind nn_integral_measure_pmf UNIV_bool bernoulli_pmf.rep_eq
apply(auto simp add: nn_integral_count_space_finite max_def min_def field_simps real_inverse_le_1_iff[OF weight_spmf_nonneg weight_spmf_le_1] weight_spmf_lt_0 not_le ennreal_mult[symmetric])
apply (metis pmf_le_0_iff spmf_le_weight)
apply (metis pmf_le_0_iff spmf_le_weight)
apply (meson le_divide_eq_1_pos measure_spmf.subprob_measure_le_1 not_less order_trans weight_spmf_le_0)
by (meson divide_le_0_1_iff less_imp_le order_trans weight_spmf_le_0)
thus "spmf ?lhs x = spmf ?rhs x" by simp
qed

(* Summing the point masses of p over a set A gives the emeasure of A.  After
   unfolding measure_spmf, the sum over A is reindexed along Some to a sum
   over Some ` A; the last step discharges the bijection side condition. *)
lemma nn_integral_spmf: "(∫⇧+ x. spmf p x ∂count_space A) = emeasure (measure_spmf p) A"
apply(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space nn_integral_pmf[symmetric])
apply(rule nn_integral_bij_count_space[where g=Some])
apply(auto simp add: bij_betw_def)
done

(* The measure of a scaled spmf is the original measure scaled by the factor
   clamped at inverse (weight_spmf p).  After measure_eqI, the emeasure is
   rewritten as a sum of point masses so the constant can be pulled out. *)
lemma measure_spmf_scale_spmf: "measure_spmf (scale_spmf r p) = scale_measure (min (inverse (weight_spmf p)) r) (measure_spmf p)"
apply(rule measure_eqI)
apply simp
apply(simp add: nn_integral_spmf[symmetric] spmf_scale_spmf)
apply(subst nn_integral_cmult[symmetric])
apply(auto simp add: max_def min_def ennreal_mult[symmetric] not_le ennreal_lt_0)
done

(* For r ≤ 1 the clamp is inactive: the measure is scaled by r directly.
   Case split on whether p has positive weight. *)
lemma measure_spmf_scale_spmf':
"r ≤ 1 ⟹ measure_spmf (scale_spmf r p) = scale_measure r (measure_spmf p)"
unfolding measure_spmf_scale_spmf
apply(cases "weight_spmf p > 0")
apply(simp add: min.absorb2 field_simps weight_spmf_le_1 mult_le_one)
apply(clarsimp simp add: weight_spmf_le_0 min_def scale_spmf_neg weight_spmf_eq_0 not_less)
done

(* Scaling by 1 is the identity. *)
lemma scale_spmf_1 [simp]: "scale_spmf 1 p = p"
apply(rule spmf_eqI)
apply(simp add: spmf_scale_spmf max_def min_def order.strict_iff_order field_simps weight_spmf_nonneg)
apply(metis antisym_conv divide_le_eq_1 less_imp_le pmf_nonneg spmf_le_weight weight_spmf_nonneg weight_spmf_le_1)
done

(* Scaling by 0 yields the always-failing spmf. *)
lemma scale_spmf_0 [simp]: "scale_spmf 0 p = return_pmf None"
by(rule spmf_eqI)(simp add: spmf_scale_spmf min_def max_def weight_spmf_le_0)

(* For r ≤ 1, scaling the first argument of a bind equals scaling every continuation. *)
lemma bind_scale_spmf:
assumes r: "r ≤ 1"
shows "bind_spmf (scale_spmf r p) f = bind_spmf p (λx. scale_spmf r (f x))"
(is "?lhs = ?rhs")
proof(rule spmf_eqI)
fix x
have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using r
by(simp add: ennreal_spmf_bind measure_spmf_scale_spmf' nn_integral_scale_measure spmf_scale_spmf')
(simp add: ennreal_mult ennreal_lt_0 nn_integral_cmult max_def min_def)
thus "spmf ?lhs x = spmf ?rhs x" by simp
qed

(* For r ≤ 1, scaling a whole bind equals scaling every continuation. *)
lemma scale_bind_spmf:
assumes "r ≤ 1"
shows "scale_spmf r (bind_spmf p f) = bind_spmf p (λx. scale_spmf r (f x))"
(is "?lhs = ?rhs")
proof(rule spmf_eqI)
fix x
have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using assms
unfolding spmf_scale_spmf'[OF assms]
by(simp add: ennreal_mult ennreal_spmf_bind spmf_scale_spmf' nn_integral_cmult max_def min_def)
thus "spmf ?lhs x = spmf ?rhs x" by simp
qed

(* Binding p with a constant continuation q equals scaling q by the weight of p. *)
lemma bind_spmf_const: "bind_spmf p (λx. q) = scale_spmf (weight_spmf p) q" (is "?lhs = ?rhs")
proof(rule spmf_eqI)
fix x
have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)"
using measure_spmf.subprob_measure_le_1[of p "space (measure_spmf p)"]
by(subst ennreal_spmf_bind)(simp add: spmf_scale_spmf' weight_spmf_le_1 ennreal_mult mult.commute max_def min_def measure_spmf.emeasure_eq_measure)
thus "spmf ?lhs x = spmf ?rhs x" by simp
qed

(* map_spmf commutes with scale_spmf. *)
lemma map_scale_spmf: "map_spmf f (scale_spmf r p) = scale_spmf r (map_spmf f p)" (is "?lhs = ?rhs")
proof(rule spmf_eqI)
fix i
show "spmf ?lhs i = spmf ?rhs i" unfolding spmf_scale_spmf
by(subst (1 2) spmf_map)(auto simp add: measure_spmf_scale_spmf max_def min_def ennreal_lt_0)
qed

(* Support of a scaled spmf: unchanged for positive r, empty otherwise.  The
   first step reduces set membership to non-zero point mass; the second
   handles the remaining clamp case distinctions. *)
lemma set_scale_spmf: "set_spmf (scale_spmf r p) = (if r > 0 then set_spmf p else {})"
apply(auto simp add: in_set_spmf_iff_spmf spmf_scale_spmf)
apply(simp add: max_def min_def not_le weight_spmf_lt_0 weight_spmf_eq_0 split: if_split_asm)
done

(* Simp instance of set_scale_spmf for positive r. *)
lemma set_scale_spmf' [simp]: "0 < r ⟹ set_spmf (scale_spmf r p) = set_spmf p"
by(simp add: set_scale_spmf)

(* Scaling both sides by the same factor preserves rel_spmf.  The relation
   between p and q is only needed for r > 0; for r ≤ 0 both sides collapse to
   return_pmf None by scale_spmf_neg. *)
lemma rel_spmf_scaleI:
assumes "r > 0 ⟹ rel_spmf A p q"
shows "rel_spmf A (scale_spmf r p) (scale_spmf r q)"
proof(cases "r > 0")
case True
from assms[OF this] show ?thesis
by(rule rel_spmfE)(auto simp add: map_scale_spmf[symmetric] spmf_rel_map True intro: rel_spmf_reflI)
qed(simp add: not_less scale_spmf_neg)

(* Weight of a scaled spmf: max 0 r times the original weight, capped at 1.
   Established on ennreal first, then transferred to real. *)
lemma weight_scale_spmf: "weight_spmf (scale_spmf r p) = min 1 (max 0 r * weight_spmf p)"
proof -
have "ennreal (weight_spmf (scale_spmf r p)) = min 1 (max 0 r * ennreal (weight_spmf p))"
unfolding weight_spmf_eq_nn_integral_spmf
apply(simp add: spmf_scale_spmf ennreal_mult zero_ereal_def[symmetric] nn_integral_cmult)
apply(auto simp add: weight_spmf_eq_nn_integral_spmf[symmetric] field_simps min_def max_def not_le weight_spmf_lt_0 ennreal_mult[symmetric])
subgoal by(subst (asm) ennreal_mult[symmetric], meson divide_less_0_1_iff le_less_trans not_le weight_spmf_lt_0, simp+, meson not_le pos_divide_le_eq weight_spmf_le_0)
subgoal by(cases "r ≥ 0")(simp_all add: ennreal_mult[symmetric] weight_spmf_nonneg ennreal_lt_0, meson le_less_trans not_le pos_divide_le_eq zero_less_divide_1_iff)
done
thus ?thesis by(auto simp add: min_def max_def ennreal_mult[symmetric] split: if_split_asm)
qed

(* For r in [0, 1] neither clamp is active: the weight is simply multiplied by r. *)
lemma weight_scale_spmf' [simp]:
"⟦ 0 ≤ r; r ≤ 1 ⟧ ⟹ weight_spmf (scale_spmf r p) = r * weight_spmf p"
by(simp add: weight_scale_spmf max_def min_def)(metis antisym_conv mult_left_le order_trans weight_spmf_le_1)

(* Failure probability of a scaled spmf, expressed via the original failure probability. *)
lemma pmf_scale_spmf_None:
"pmf (scale_spmf k p) None = 1 - min 1 (max 0 k * (1 - pmf p None))"
unfolding pmf_None_eq_weight_spmf by(simp add: weight_scale_spmf)

(* Two successive scalings collapse into one, with the inner factor clamped
   as in spmf_scale_spmf.  The arithmetic identity on the clamped factors is
   proved first by cases on whether p has positive weight, then used to
   equate the point masses. *)
lemma scale_scale_spmf:
"scale_spmf r (scale_spmf r' p) = scale_spmf (r * max 0 (min (inverse (weight_spmf p)) r')) p"
(is "?lhs = ?rhs")
proof(rule spmf_eqI)
fix i
have "max 0 (min (1 / weight_spmf p) r') *
max 0 (min (1 / min 1 (weight_spmf p * max 0 r')) r) =
max 0 (min (1 / weight_spmf p) (r * max 0 (min (1 / weight_spmf p) r')))"
proof(cases "weight_spmf p > 0")
case False
thus ?thesis by(simp add: not_less weight_spmf_le_0)
next
case True
thus ?thesis by(simp add: field_simps max_def min.absorb_iff2[symmetric])(auto simp add: min_def field_simps zero_le_mult_iff)
qed
then show "spmf ?lhs i = spmf ?rhs i"
by(simp add: spmf_scale_spmf field_simps weight_scale_spmf)
qed

(* For factors r and r' in [0, 1], two scalings multiply. *)
lemma scale_scale_spmf' [simp]:
"⟦ 0 ≤ r; r ≤ 1; 0 ≤ r'; r' ≤ 1 ⟧
⟹ scale_spmf r (scale_spmf r' p) = scale_spmf (r * r') p"
apply(cases "weight_spmf p > 0")
apply(auto simp add: scale_scale_spmf min_def max_def field_simps not_le weight_spmf_lt_0 weight_spmf_eq_0 not_less weight_spmf_le_0)
apply(subgoal_tac "1 = r'")
apply (metis (no_types) div_by_1 eq_iff measure_spmf.subprob_measure_le_1 mult.commute mult_cancel_right1)
apply(meson eq_iff le_divide_eq_1_pos measure_spmf.subprob_measure_le_1 mult_imp_div_pos_le order.trans)
done

(* Scaling leaves p unchanged iff p has weight 0, or r = 1, or r ≥ 1 and p has
   weight 1.  The forward direction compares weights via weight_scale_spmf;
   the converse is closed by the final qed method. *)
lemma scale_spmf_eq_same: "scale_spmf r p = p ⟷ weight_spmf p = 0 ∨ r = 1 ∨ r ≥ 1 ∧ weight_spmf p = 1"
(is "?lhs ⟷ ?rhs")
proof
assume ?lhs
hence "weight_spmf (scale_spmf r p) = weight_spmf p" by simp
hence *: "min 1 (max 0 r * weight_spmf p) = weight_spmf p" by(simp add: weight_scale_spmf)
hence **: "weight_spmf p = 0 ∨ r ≥ 1" by(auto simp add: min_def max_def split: if_split_asm)
show ?rhs
proof(cases "weight_spmf p = 0")
case False
with ** have "r ≥ 1" by simp
with * False have "r = 1 ∨ weight_spmf p = 1" by(simp add: max_def min_def not_le split: if_split_asm)
with ‹r ≥ 1› show ?thesis by simp
qed simp
qed(auto intro!: spmf_eqI simp add: spmf_scale_spmf, metis pmf_le_0_iff spmf_le_weight)

(* Mapping a constant function over a uniform spmf on a finite non-empty set
   gives the point distribution on that constant. *)
lemma map_const_spmf_of_set:
"⟦ finite A; A ≠ {} ⟧ ⟹ map_spmf (λ_. c) (spmf_of_set A) = return_spmf c"
by(simp add: map_spmf_conv_bind_spmf bind_spmf_const)

subsection ‹Conditional spmfs›

(* Disjointness of the pmf support from Some ` A is the same as disjointness of the spmf support from A. *)
lemma set_pmf_Int_Some: "set_pmf p ∩ Some ` A = {} ⟷ set_spmf p ∩ A = {}"
by(auto simp add: in_set_spmf)

(* A set has measure 0 under p iff it is disjoint from the support of p. *)
lemma measure_spmf_zero_iff: "measure (measure_spmf p) A = 0 ⟷ set_spmf p ∩ A = {}"
unfolding measure_measure_spmf_conv_measure_pmf by(simp add: measure_pmf_zero_iff set_pmf_Int_Some)

(* Conditioning an spmf on a set A: cond_pmf on Some ` A, or the
   always-failing spmf when A is disjoint from the support. *)
definition cond_spmf :: "'a spmf ⇒ 'a set ⇒ 'a spmf"
where "cond_spmf p A = (if set_spmf p ∩ A = {} then return_pmf None else cond_pmf p (Some ` A))"

(* The support of a conditioned spmf is the original support intersected with A. *)
lemma set_cond_spmf [simp]: "set_spmf (cond_spmf p A) = set_spmf p ∩ A"
by(auto 4 4 simp add: cond_spmf_def in_set_spmf iff: set_cond_pmf[THEN set_eq_iff[THEN iffD1], THEN spec, rotated])

(* Conditioning after mapping f equals mapping f after conditioning on the preimage f -` A. *)
lemma cond_map_spmf [simp]: "cond_spmf (map_spmf f p) A = map_spmf f (cond_spmf p (f -` A))"
proof -
have "map_option f -` Some ` A = Some ` f -` A" by auto
moreover have "set_pmf p ∩ map_option f -` Some ` A ≠ {}" if "Some x ∈ set_pmf p" "f x ∈ A" for x
using that by auto
ultimately show ?thesis by(auto simp add: cond_spmf_def in_set_spmf cond_map_pmf)
qed

(* Point probability of cond_spmf p A: for x in A it is spmf p x divided by
   the measure of A under measure_spmf p; outside A it is 0. *)
lemma spmf_cond_spmf [simp]:
"spmf (cond_spmf p A) x = (if x ∈ A then spmf p x / measure (measure_spmf p) A else 0)"
by(auto simp add: cond_spmf_def pmf_cond set_pmf_Int_Some[symmetric] measure_measure_spmf_conv_measure_pmf measure_pmf_zero_iff)

(* bind_spmf p f equals return_pmf None exactly when f x is return_pmf None
   for every x in set_spmf p. *)
lemma bind_eq_return_pmf_None:
"bind_spmf p f = return_pmf None ⟷ (∀x∈set_spmf p. f x = return_pmf None)"
by(auto simp add: bind_spmf_def bind_eq_return_pmf in_set_spmf split: option.splits)

(* Same characterisation as bind_eq_return_pmf_None, with the two sides of the
   equation swapped. *)
lemma return_pmf_None_eq_bind:
"return_pmf None = bind_spmf p f ⟷ (∀x∈set_spmf p. f x = return_pmf None)"
using bind_eq_return_pmf_None[of p f] by auto

(* Conditional probabilities do not seem to interact nicely with bind. *)

subsection ‹Product spmf›

(* pair_spmf p q binds over pair_pmf p q: when both components are Some the
   pair (x, y) is returned via return_spmf; every other case gives
   return_pmf None. *)
definition pair_spmf :: "'a spmf ⇒ 'b spmf ⇒ ('a × 'b) spmf"
where "pair_spmf p q = bind_pmf (pair_pmf p q) (λxy. case xy of (Some x, Some y) ⇒ return_spmf (x, y) | _ ⇒ return_pmf None)"

(* Projecting pair_spmf p q to the first component gives p scaled by the
   weight of q. *)
lemma map_fst_pair_spmf [simp]: "map_spmf fst (pair_spmf p q) = scale_spmf (weight_spmf q) p"
(* Rewrite the scaling on the right as a bind with a constant continuation. *)
unfolding bind_spmf_const[symmetric]
apply(simp add: pair_spmf_def map_bind_pmf pair_pmf_def bind_assoc_pmf option.case_distrib)
(* Swap the order of the two binds before comparing the continuations. *)
apply(subst bind_commute_pmf)
apply(auto intro!: bind_pmf_cong[OF refl] simp add: bind_return_pmf bind_spmf_def bind_return_pmf' case_option_collapse option.case_distrib[where h="map_spmf _"] option.case_distrib[symmetric] case_option_id split: option.split cong del: option.case_cong_weak)
done

(* Projecting pair_spmf p q to the second component gives q scaled by the
   weight of p.  Unlike the fst version, no bind commutation step is used. *)
lemma map_snd_pair_spmf [simp]: "map_spmf snd (pair_spmf p q) = scale_spmf (weight_spmf p) q"
unfolding bind_spmf_const[symmetric]
apply(simp add: pair_spmf_def map_bind_pmf pair_pmf_def bind_assoc_pmf option.case_distrib
cong del: option.case_cong_weak)
apply(auto intro!: bind_pmf_cong[OF refl] simp add: bind_return_pmf bind_spmf_def bind_return_pmf' case_option_collapse option.case_distrib[where h="map_spmf _"] option.case_distrib[symmetric] case_option_id split: option.split cong del: option.case_cong_weak)
done

(* The set_spmf of a product is the Cartesian product of the component sets. *)
lemma set_pair_spmf [simp]: "set_spmf (pair_spmf p q) = set_spmf p × set_spmf q"
by(auto 4 3 simp add: pair_spmf_def set_spmf_bind_pmf bind_UNION in_set_spmf intro: rev_bexI split: option.splits)

(* The point probability of (x, y) under pair_spmf p q is the product of the
   point probabilities of x under p and y under q. *)
lemma spmf_pair [simp]: "spmf (pair_spmf p q) (x, y) = spmf p x * spmf q y" (is "?lhs = ?rhs")
proof -
(* Write the left-hand side as an iterated integral of the indicator of {(x, y)}. *)
have "ennreal ?lhs = ∫⇧+ a. ∫⇧+ b. indicator {(x, y)} (a, b) ∂measure_spmf q ∂measure_spmf p"
unfolding measure_spmf_def pair_spmf_def ennreal_pmf_bind nn_integral_pair_pmf'
by(auto simp add: zero_ereal_def[symmetric] nn_integral_distr nn_integral_restrict_space nn_integral_multc[symmetric] intro!: nn_integral_cong split: option.split split_indicator)
(* Split the pair indicator into a product of two single-point indicators. *)
also have "… = ∫⇧+ a. (∫⇧+ b. indicator {y} b ∂measure_spmf q) * indicator {x} a ∂measure_spmf p"
by(subst nn_integral_multc[symmetric])(auto intro!: nn_integral_cong split: split_indicator)
also have "… = ennreal ?rhs" by(simp add: emeasure_spmf_single max_def ennreal_mult mult.commute)
finally show ?thesis by simp
qed

(* A map on the second argument of pair_spmf can be pulled out as apsnd. *)
lemma pair_map_spmf2: "pair_spmf p (map_spmf f q) = map_spmf (apsnd f) (pair_spmf p q)"
by(auto simp add: pair_spmf_def pair_map_pmf2 bind_map_pmf map_bind_pmf intro: bind_pmf_cong split: option.split)

(* A map on the first argument of pair_spmf can be pulled out as apfst. *)
lemma pair_map_spmf1: "pair_spmf (map_spmf f p) q = map_spmf (apfst f) (pair_spmf p q)"
by(auto simp add: pair_spmf_def pair_map_pmf1 bind_map_pmf map_bind_pmf intro: bind_pmf_cong split: option.split)

(* Maps on both arguments of pair_spmf combine into a single map_prod.
   Obtained by composing pair_map_spmf2 and pair_map_spmf1. *)
lemma pair_map_spmf: "pair_spmf (map_spmf f p) (map_spmf g q) = map_spmf (map_prod f g) (pair_spmf p q)"
unfolding pair_map_spmf2 pair_map_spmf1 spmf.map_comp by(simp add: apfst_def apsnd_def o_def prod.map_comp)

(* Alternative description of pair_spmf through two nested bind_spmf calls
   that return the pair (x, y). *)
lemma pair_spmf_alt_def: "pair_spmf p q = bind_spmf p (λx. bind_spmf q (λy. return_spmf (x, y)))"
by(auto simp add: pair_spmf_def pair_pmf_def bind_spmf_def bind_assoc_pmf bind_return_pmf split: option.split intro: bind_pmf_cong)

(* The weight of a product is the product of the weights; proved via the
   nested-bind form of pair_spmf. *)
lemma weight_pair_spmf [simp]: "weight_spmf (pair_spmf p q) = weight_spmf p * weight_spmf q"
unfolding pair_spmf_alt_def by(simp add: weight_bind_spmf o_def)

(* Scaling the first argument of pair_spmf by r can be pulled outside,
   under the side condition r ≤ 1. *)
lemma pair_scale_spmf1: (* FIXME: generalise to arbitrary r *)
"r ≤ 1 ⟹ pair_spmf (scale_spmf r p) q = scale_spmf r (pair_spmf p q)"
(* Proof step restored (the statement had no proof): use the nested-bind form
   and push the scaling through bind on both sides. *)
by(simp add: pair_spmf_alt_def scale_bind_spmf bind_scale_spmf)

(* Scaling the second argument of pair_spmf by r can be pulled outside,
   under the side condition r ≤ 1. *)
lemma pair_scale_spmf2: (* FIXME: generalise to arbitrary r *)
"r ≤ 1 ⟹ pair_spmf p (scale_spmf r q) = scale_spmf r (pair_spmf p q)"
(* Proof step restored (the statement had no proof): use the nested-bind form
   and push the scaling through bind on both sides. *)
by(simp add: pair_spmf_alt_def scale_bind_spmf bind_scale_spmf)

(* Pairing with return_pmf None on the left gives return_pmf None. *)
lemma pair_spmf_return_None1 [simp]: "pair_spmf (return_pmf None) p = return_pmf None"
by(rule spmf_eqI)(clarsimp)

(* Pairing with return_pmf None on the right gives return_pmf None. *)
lemma pair_spmf_return_None2 [simp]: "pair_spmf p (return_pmf None) = return_pmf None"
by(rule spmf_eqI)(clarsimp)

(* Pairing return_spmf x on the left with q is the same as mapping Pair x over q. *)
lemma pair_spmf_return_spmf1: "pair_spmf (return_spmf x) q = map_spmf (Pair x) q"
by(rule spmf_eqI)(auto split: split_indicator simp add: spmf_map_inj' inj_on_def intro: spmf_map_outside)

(* Pairing p with return_spmf y on the right is the same as mapping
   (λx. (x, y)) over p. *)
lemma pair_spmf_return_spmf2: "pair_spmf p (return_spmf y) = map_spmf (λx. (x, y)) p"
by(rule spmf_eqI)(auto split: split_indicator simp add: inj_on_def intro!: spmf_map_outside spmf_map_inj'[symmetric])

(* Pairing two return_spmf values gives return_spmf of the pair. *)
lemma pair_spmf_return_spmf [simp]: "pair_spmf (return_spmf x) (return_spmf y) = return_spmf (x, y)"
(* Proof step restored (the statement had no proof): reduce to a map over
   return_spmf y using pair_spmf_return_spmf1. *)
by(simp add: pair_spmf_return_spmf1)

(* Two products are related by rel_prod A B exactly when the first components,
   each scaled by the weight of its partner, are related by A, and likewise the
   second components by B.
   Six proof steps that had no justification have been restored; each is
   marked "restored" below. *)
lemma rel_pair_spmf_prod:
"rel_spmf (rel_prod A B) (pair_spmf p q) (pair_spmf p' q') ⟷
rel_spmf A (scale_spmf (weight_spmf q) p) (scale_spmf (weight_spmf q') p') ∧
rel_spmf B (scale_spmf (weight_spmf p) q) (scale_spmf (weight_spmf p') q')"
(is "?lhs ⟷ ?rhs" is "_ ⟷ ?A ∧ ?B" is "_ ⟷ rel_spmf _ ?p ?p' ∧ rel_spmf _ ?q ?q'")
proof(intro iffI conjI)
assume ?rhs
(* Witness joint distributions for the two component relations. *)
then obtain pq pq' where p: "map_spmf fst pq = ?p" and p': "map_spmf snd pq = ?p'"
and q: "map_spmf fst pq' = ?q" and q': "map_spmf snd pq' = ?q'"
and *: "⋀x x'. (x, x') ∈ set_spmf pq ⟹ A x x'"
and **: "⋀y y'. (y, y') ∈ set_spmf pq' ⟹ B y y'" by(auto elim!: rel_spmfE)
(* Regroup ((x, x'), (y, y')) as ((x, y), (x', y')) and rescale the product of
   the two witnesses by the inverse of the product of the weights. *)
let ?f = "λ((x, x'), (y, y')). ((x, y), (x', y'))"
let ?r = "1 / (weight_spmf p * weight_spmf q)"
let ?pq = "scale_spmf ?r (map_spmf ?f (pair_spmf pq pq'))"

(* Auxiliary fact: if both weights are non-zero and the inverse of their
   product is at most the product itself, then both weights equal 1. *)
{ fix p :: "'x spmf" and q :: "'y spmf"
assume "weight_spmf q ≠ 0"
and "weight_spmf p ≠ 0"
and "1 / (weight_spmf p * weight_spmf q) ≤ weight_spmf p * weight_spmf q"
hence "1 ≤ (weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q)"
(* restored *)
by(simp add: pos_divide_le_eq order.strict_iff_order weight_spmf_nonneg)
moreover have "(weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q) ≤ (1 * 1) * (1 * 1)"
(* restored *)
by(intro mult_mono)(simp_all add: weight_spmf_nonneg weight_spmf_le_1)
ultimately have "(weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q) = 1" by simp
hence *: "weight_spmf p * weight_spmf q = 1"
by(metis antisym_conv less_le mult_less_cancel_left1 weight_pair_spmf weight_spmf_le_1 weight_spmf_nonneg)
hence **: "weight_spmf p = 1" by(metis antisym_conv mult_left_le weight_spmf_le_1 weight_spmf_nonneg)
moreover from * ** have "weight_spmf q = 1" by simp
moreover note calculation }
note full = this

show ?lhs
proof
(* First marginal of the witness is pair_spmf p q. *)
have [simp]: "fst ∘ ?f = map_prod fst fst" by(simp add: fun_eq_iff)
have "map_spmf fst ?pq = scale_spmf ?r (pair_spmf ?p ?q)"
by(simp add: pair_map_spmf[symmetric] p q map_scale_spmf spmf.map_comp)
also have "… = pair_spmf p q" using full[of p q]
by(simp add: pair_scale_spmf1 pair_scale_spmf2 weight_spmf_le_1 weight_spmf_nonneg)
(auto simp add: scale_scale_spmf max_def min_def field_simps weight_spmf_nonneg weight_spmf_eq_0)
finally show "map_spmf fst ?pq = …" .

(* Second marginal of the witness is pair_spmf p' q'; this needs that the
   products of the weights agree on both sides. *)
have [simp]: "snd ∘ ?f = map_prod snd snd" by(simp add: fun_eq_iff)
from ‹?rhs› have eq: "weight_spmf p * weight_spmf q = weight_spmf p' * weight_spmf q'"
by(auto dest!: rel_spmf_weightD simp add: weight_spmf_le_1 weight_spmf_nonneg)

have "map_spmf snd ?pq = scale_spmf ?r (pair_spmf ?p' ?q')"
by(simp add: pair_map_spmf[symmetric] p' q' map_scale_spmf spmf.map_comp)
also have "… = pair_spmf p' q'" using full[of p' q'] eq
by(simp add: pair_scale_spmf1 pair_scale_spmf2 weight_spmf_le_1 weight_spmf_nonneg)
(auto simp add: scale_scale_spmf max_def min_def field_simps weight_spmf_nonneg weight_spmf_eq_0)
finally show "map_spmf snd ?pq = …" .
qed(auto simp add: set_scale_spmf split: if_split_asm dest: * ** )
next
assume ?lhs
(* Witness joint distribution for the relation on the products. *)
then obtain pq where pq: "map_spmf fst pq = pair_spmf p q"
and pq': "map_spmf snd pq = pair_spmf p' q'"
and *: "⋀x y x' y'. ((x, y), (x', y')) ∈ set_spmf pq ⟹ A x x' ∧ B y y'"
by(auto elim: rel_spmfE)

(* Relation on the first components: project the witness to (x, x'). *)
show ?A
proof
let ?f = "(λ((x, y), (x', y')). (x, x'))"
let ?pq = "map_spmf ?f pq"
have [simp]: "fst ∘ ?f = fst ∘ fst" by(simp add: split_def o_def)
show "map_spmf fst ?pq = scale_spmf (weight_spmf q) p" using pq
(* restored *)
by(simp add: spmf.map_comp)

have [simp]: "snd ∘ ?f = fst ∘ snd" by(simp add: split_def o_def)
show "map_spmf snd ?pq = scale_spmf (weight_spmf q') p'" using pq'
(* restored *)
by(simp add: spmf.map_comp)
qed(auto dest: * )

(* Relation on the second components: project the witness to (y, y'). *)
show ?B
proof
let ?f = "(λ((x, y), (x', y')). (y, y'))"
let ?pq = "map_spmf ?f pq"
have [simp]: "fst ∘ ?f = snd ∘ fst" by(simp add: split_def o_def)
show "map_spmf fst ?pq = scale_spmf (weight_spmf p) q" using pq
(* restored *)
by(simp add: spmf.map_comp)

have [simp]: "snd ∘ ?f = snd ∘ snd" by(simp add: split_def o_def)
show "map_spmf snd ?pq = scale_spmf (weight_spmf p') q'" using pq'
(* restored *)
by(simp add: spmf.map_comp)
qed(auto dest: * )
qed

(* Associativity of pair_spmf, up to regrouping (x, (y, z)) as ((x, y), z). *)
lemma pair_pair_spmf:
"pair_spmf (pair_spmf p q) r = map_spmf (λ(x, (y, z)). ((x, y), z)) (pair_spmf p (pair_spmf q r))"
(* Proof step restored (the statement had no proof): unfold both products and
   the map into nested binds. *)
by(simp add: pair_spmf_alt_def map_spmf_conv_bind_spmf)

(* Commutativity of pair_spmf, up to swapping the components of the pair.
   Proved on the nested-bind form by commuting the two binds. *)
lemma pair_commute_spmf:
"pair_spmf p q = map_spmf (λ(y, x). (x, y)) (pair_spmf q p)"
unfolding pair_spmf_alt_def by(subst bind_commute_spmf)(simp add: map_spmf_conv_bind_spmf)

subsection ‹Assertions›

(* assert_spmf b is return_spmf () when b holds and return_pmf None otherwise. *)
definition assert_spmf :: "bool ⇒ unit spmf"
where "assert_spmf b = (if b then return_spmf () else return_pmf None)"

(* Evaluation of assert_spmf on the two Boolean constants. *)
lemma assert_spmf_simps [simp]:
"assert_spmf True = return_spmf ()"
"assert_spmf False = return_pmf None"
(* Proof step restored (the statements had no proof): unfold the definition. *)
by(simp_all add: assert_spmf_def)

(* Membership in set_spmf (assert_spmf p) holds exactly when p holds. *)
lemma in_set_assert_spmf [simp]: "x ∈ set_spmf (assert_spmf p) ⟷ p"
by(cases p) simp_all

(* set_spmf (assert_spmf b) is empty exactly when b is false. *)
lemma set_spmf_assert_spmf_eq_empty [simp]: "set_spmf (assert_spmf b) = {} ⟷ ¬ b"
by(cases b) simp_all

(* assert_spmf b is lossless exactly when b holds. *)
lemma lossless_assert_spmf [iff]: "lossless_spmf (assert_spmf b) ⟷ b"
by(cases b) simp_all

subsection ‹Try›

(* TRY p ELSE q binds over p: the outcome None continues with q, and an outcome
   Some y is passed on as return_spmf y. *)
definition try_spmf :: "'a spmf ⇒ 'a spmf ⇒ 'a spmf" ("TRY _ ELSE _" [0,60] 59)
where "try_spmf p q = bind_pmf p (λx. case x of None ⇒ q | Some y ⇒ return_spmf y)"

(* If p is lossless, the ELSE branch of TRY p ELSE q is irrelevant. *)
lemma try_spmf_lossless [simp]:
assumes "lossless_spmf p"
shows "TRY p ELSE q = p"
proof -
(* With p lossless, None is not in set_pmf p, so the continuation agrees with
   return_pmf on every outcome of p. *)
have "TRY p ELSE q = bind_pmf p return_pmf" unfolding try_spmf_def using assms
by(auto simp add: lossless_iff_set_pmf_None split: option.split intro: bind_pmf_cong)
(* Concluding step restored (the proof block ended without showing the thesis). *)
thus ?thesis by(simp add: bind_return_pmf')
qed

(* TRY on return_spmf x ignores the ELSE branch. *)
lemma try_spmf_return_spmf1: "TRY return_spmf x ELSE q = return_spmf x"
(* Proof step restored (the statement had no proof). *)
by(simp add: try_spmf_def bind_return_pmf)

(* TRY on return_pmf None always takes the ELSE branch. *)
lemma try_spmf_return_None [simp]: "TRY return_pmf None ELSE q = q"
(* Proof step restored (the statement had no proof). *)
by(simp add: try_spmf_def bind_return_pmf)

(* An ELSE branch of return_pmf None leaves p unchanged. *)
lemma try_spmf_return_pmf_None2 [simp]: "TRY p ELSE return_pmf None = p"
by(simp add: try_spmf_def option.case_distrib[symmetric] bind_return_pmf' case_option_id)

(* map_spmf distributes over both branches of try_spmf. *)
lemma map_try_spmf: "map_spmf f (try_spmf p q) = try_spmf (map_spmf f p) (map_spmf f q)"
by(simp add: try_spmf_def map_bind_pmf bind_map_pmf option.case_distrib[where h="map_spmf f"] o_def cong del: option.case_cong_weak)

(* TRY can be pushed inside a bind_pmf on its first argument. *)
lemma try_spmf_bind_pmf: "TRY (bind_pmf p f) ELSE q = bind_pmf p (λx. TRY (f x) ELSE q)"
(* Proof step restored (the statement had no proof): unfold try_spmf and
   reassociate the binds. *)
by(simp add: try_spmf_def bind_assoc_pmf)

(* For lossless p, TRY can be pushed inside a bind_spmf on its first argument. *)
lemma try_spmf_bind_spmf_lossless:
"lossless_spmf p ⟹ TRY (bind_spmf p f) ELSE q = bind_spmf p (λx. TRY (f x) ELSE q)"
by(auto simp add: try_spmf_def bind_spmf_def bind_assoc_pmf bind_return_pmf lossless_iff_set_pmf_None intro!: bind_pmf_cong split: option.split)

(* try_spmf_bind_spmf_lossless read right-to-left: for lossless p, a TRY inside
   the continuation of bind_spmf can be pulled outside. *)
lemma try_spmf_bind_out:
"lossless_spmf p ⟹ bind_spmf p (λx. TRY (f x) ELSE q) = TRY (bind_spmf p f) ELSE q"
(* Proof step restored (the statement had no proof). *)
by(simp add: try_spmf_bind_spmf_lossless)

(* TRY p ELSE q is lossless exactly when p or q is lossless. *)
lemma lossless_try_spmf [simp]:
"lossless_spmf (TRY p ELSE q) ⟷ lossless_spmf p ∨ lossless_spmf q"
by(auto simp add: try_spmf_def in_set_spmf lossless_iff_set_pmf_None split: option.splits)

(* Local context that includes lifting_syntax (presumably for the ===> notation
   used in the statement below). *)
context includes lifting_syntax
begin

(* Transfer rule: try_spmf maps arguments related by rel_spmf A to results
   related by rel_spmf A. *)
lemma try_spmf_parametric [transfer_rule]:
"(rel_spmf A ===> rel_spmf A ===> rel_spmf A) try_spmf try_spmf"
unfolding try_spmf_def[abs_def] by transfer_prover

end

(* Congruence rule for TRY: the ELSE branches only need to agree when the first
   argument is not lossless. *)
lemma try_spmf_cong:
"⟦ p = p'; ¬ lossless_spmf p' ⟹ q = q' ⟧ ⟹ TRY p ELSE q = TRY p' ELSE q'"
unfolding try_spmf_def
by(rule bind_pmf_cong)(auto split: option.split simp add: lossless_iff_set_pmf_None)

(* Relational rule for TRY: related first arguments, plus related ELSE branches
   whenever p' is not lossless, give related results. *)
lemma rel_spmf_try_spmf:
"⟦ rel_spmf R p p'; ¬ lossless_spmf p' ⟹ rel_spmf R q q' ⟧
⟹ rel_spmf R (TRY p ELSE q) (TRY p' ELSE q')"
unfolding try_spmf_def
(* The intermediate relation additionally records membership in set_pmf p and
   set_pmf p'. *)
apply(rule rel_pmf_bindI[where R="λx y. rel_option R x y ∧ x ∈ set_pmf p ∧ y ∈ set_pmf p'"])
apply(erule pmf.rel_mono_strong; simp)
apply(auto split: option.split simp add: lossless_iff_set_pmf_None)
done

(* Point probability of TRY p ELSE q at x: the probability of x under p, plus
   the probability that p yields None times the probability of x under q. *)
lemma spmf_try_spmf:
"spmf (TRY p ELSE q) x = spmf p x + pmf p None * spmf q x"
proof -
(* Write the left-hand side as an integral over p of a sum of two indicator terms. *)
have "ennreal (spmf (TRY p ELSE q) x) = ∫⇧+ y. ennreal (spmf q x) * indicator {None} y + indicator {Some x} y ∂measure_pmf p"
unfolding try_spmf_def ennreal_pmf_bind by(rule nn_integral_cong)(simp split: option.split split_indicator)
also have "… = (∫⇧+ y. ennreal (spmf q x) * indicator {None} y ∂measure_pmf p) + ∫⇧+ y. indicator {Some x} y ∂measure_pmf p"
(* Justification restored (this step had none): split the integral of the sum. *)
by(simp add: nn_integral_add)
also have "… = ennreal (spmf q x) * pmf p None + spmf p x" by(simp add: emeasure_pmf_single)
finally show ?thesis by(simp add: ennreal_mult[symmetric] ennreal_plus[symmetric] del: ennreal_plus)
qed

(* For lossless p, trying a scaled copy of p and falling back to p gives p. *)
lemma try_scale_spmf_same [simp]: "lossless_spmf p ⟹ TRY scale_spmf k p ELSE p = p"
by(rule spmf_eqI)(auto simp add: spmf_try_spmf spmf_scale_spmf pmf_scale_spmf_None lossless_iff_pmf_None weight_spmf_conv_pmf_None min_def max_def field_simps)

(* The probability of None under TRY p ELSE q is the product of the
   probabilities of None under p and under q. *)
lemma pmf_try_spmf_None [simp]: "pmf (TRY p ELSE q) None = pmf p None * pmf q None" (is "?lhs = ?rhs")
proof -
(* Only the None outcome of p contributes to the integral. *)
have "?lhs = ∫ x. pmf q None * indicator {None} x ∂measure_pmf p"
unfolding try_spmf_def pmf_bind by(rule Bochner_Integration.integral_cong)(simp_all split: option.split)
also have "… = ?rhs" by(simp add: measure_pmf_single)
finally show ?thesis .
qed

(* When the fallback q is lossless, an additional TRY ... ELSE q may be wrapped
   around the continuation of the bind without changing the result. *)
lemma try_bind_spmf_lossless2:
"lossless_spmf q ⟹ TRY (bind_spmf p f) ELSE q = TRY (p ⤜ (λx. TRY (f x) ELSE q)) ELSE q"
by(rule spmf_eqI)(simp add: spmf_try_spmf pmf_bind_spmf_None spmf_bind field_simps measure_spmf.integrable_const_bound[where B=1] pmf_le_1 lossless_iff_pmf_None)

(* Variant of try_bind_spmf_lossless2 with an extra NO_MATCH premise on f
   against the pattern λx. try_spmf (g x) (h x).
   NOTE(review): presumably this stops the rule from being re-applied to a
   continuation that already has the TRY form; confirm intended use. *)
lemma try_bind_spmf_lossless2':
fixes f :: "'a ⇒ 'b spmf" shows
"⟦ NO_MATCH (λx :: 'a. try_spmf (g x :: 'b spmf) (h x)) f; lossless_spmf q ⟧
⟹ TRY (bind_spmf p f) ELSE q = TRY (p ⤜ (λx :: 'a. TRY (f x) ELSE q)) ELSE q"
by(rule try_bind_spmf_lossless2)

(* TRY over a bind on assert_spmf b: if b holds the continuation f () is tried,
   otherwise the result is the fallback q. *)
lemma try_bind_assert_spmf:
"TRY (assert_spmf b ⤜ f) ELSE q = (if b then TRY (f ()) ELSE q else q)"
by simp

subsection ‹Miscellaneous›

lemma assumes "rel_spmf (λx y. bad1 x = bad2 y ∧ (¬ bad2 y ⟶ A x ⟷ B y)) p q" (is "rel_spmf ?A _ _")
and fundamental_lemma: "¦measure (measure_spmf p) {x. A x} - measure (measure_spmf q) {y. B y}¦ ≤
measure (measure_spmf p) {x. bad1 x}" (is ?fundamental)
proof -
have good: "rel_fun ?A op = (λx. A x ∧ ¬ bad1 x) (λy. B y ∧ ¬ bad2 y)" by(auto simp add: rel_fun_def)
from assms have 1: "measure (measure_spmf p) {x. A x ∧ ¬ bad1 x} = measure (measure_spmf q) {y. B y ∧ ¬ bad2 y}"
by(rule measure_spmf_parametric[THEN rel_funD, THEN rel_funD])(rule Collect_parametric[THEN rel_funD, OF good])

by(rule measure_spmf_parametric[THEN rel_funD, THEN rel_funD])(rule Collect_parametric[THEN rel_funD, OF bad])

let ?μp = "measure (measure_spmf p)" and ?μq = "measure (measure_spmf q)"
have "{x. A x ∧ bad1 x} ∪ {x. A x ∧ ¬ bad1 x} = {x. A x}"
and "{y. B y ∧ bad2 y} ∪ {y. B y ∧ ¬ bad2 y} = {y. B y}" by auto
then have "¦?μp {x. A x} - ?μq {x. B x}¦ = ¦?μp ({x. A x ∧ bad1 x} ∪ {x. A x ∧ ¬ bad1 x}) - ?μq ({y. B y ∧ bad2 y} ∪ {y. B y ∧ ¬ bad2 y})¦"
by simp
also have "… = ¦?μp {x. A x ∧ bad1 x} + ?μp {x. A x ∧ ¬ bad1 x} - ?μq {y. B y ∧ bad2 y} - ?μq {y. B y ∧ ¬ bad2 y}¦"
by(subst (1 2) measure_Union)(auto)
also have "… = ¦?μp {x. A x ∧ bad1 x} - ?μq {y. B y ∧ bad2 y}¦" using 1 by simp
also have "… ≤ max (?μp {x. A x ∧ bad1 x}) (?μq {y. B y ∧ bad2 y})"