# HG changeset patch # User Andreas Lochbihler # Date 1465369532 -7200 # Node ID ea13f44da888da4a477c43542b8d3321297fb2e9 # Parent 9986559617ee2b12c920e5cfd04f5aa5cfc47d79# Parent af43e35211c860152d1d62b7f644bdee0cf9d563 merged diff -r 9986559617ee -r ea13f44da888 src/HOL/Library/Complete_Partial_Order2.thy --- a/src/HOL/Library/Complete_Partial_Order2.thy Mon Jun 06 22:22:05 2016 +0200 +++ b/src/HOL/Library/Complete_Partial_Order2.thy Wed Jun 08 09:05:32 2016 +0200 @@ -1701,10 +1701,48 @@ \ mcont lub ord lubb ordb (\x. snd (t x))" by(auto intro!: mcontI monotoneI contI dest: mcont_monoD mcont_contD simp add: rel_prod_sel split_beta prod_lub_def image_image) +lemma monotone_Pair: + "\ monotone ord orda f; monotone ord ordb g \ + \ monotone ord (rel_prod orda ordb) (\x. (f x, g x))" +by(simp add: monotone_def) + +lemma cont_Pair: + "\ cont lub ord luba orda f; cont lub ord lubb ordb g \ + \ cont lub ord (prod_lub luba lubb) (rel_prod orda ordb) (\x. (f x, g x))" +by(rule contI)(auto simp add: prod_lub_def image_image dest!: contD) + +lemma mcont_Pair: + "\ mcont lub ord luba orda f; mcont lub ord lubb ordb g \ + \ mcont lub ord (prod_lub luba lubb) (rel_prod orda ordb) (\x. (f x, g x))" +by(rule mcontI)(simp_all add: monotone_Pair mcont_mono cont_Pair) + context partial_function_definitions begin text \Specialised versions of @{thm [source] mcont_call} for admissibility proofs for parallel fixpoint inductions\ lemmas mcont_call_fst [cont_intro] = mcont_call[THEN mcont2mcont, OF mcont_fst] lemmas mcont_call_snd [cont_intro] = mcont_call[THEN mcont2mcont, OF mcont_snd] end +lemma map_option_mono [partial_function_mono]: + "mono_option B \ mono_option (\f. map_option g (B f))" +unfolding map_conv_bind_option by(rule bind_mono) simp_all + +lemma compact_flat_lub [cont_intro]: "compact (flat_lub x) (flat_ord x) y" +using flat_interpretation[THEN ccpo] +proof(rule ccpo.compactI[OF _ ccpo.admissibleI]) + fix A + assume chain: "Complete_Partial_Order.chain (flat_ord x) A" + and A: "A \ {}" + and *: "\z\A. \ flat_ord x y z" + from A obtain z where "z \ A" by blast + with * have z: "\ flat_ord x y z" .. + hence y: "x \ y" "y \ z" by(auto simp add: flat_ord_def) + { assume "\ A \ {x}" + then obtain z' where "z' \ A" "z' \ x" by auto + then have "(THE z. z \ A - {x}) = z'" + by(intro the_equality)(auto dest: chainD[OF chain] simp add: flat_ord_def) + moreover have "z' \ y" using \z' \ A\ * by(auto simp add: flat_ord_def) + ultimately have "y \ (THE z. z \ A - {x})" by simp } + with z show "\ flat_ord x y (flat_lub x A)" by(simp add: flat_ord_def flat_lub_def) +qed + end diff -r 9986559617ee -r ea13f44da888 src/HOL/Probability/Probability.thy --- a/src/HOL/Probability/Probability.thy Mon Jun 06 22:22:05 2016 +0200 +++ b/src/HOL/Probability/Probability.thy Wed Jun 08 09:05:32 2016 +0200 @@ -8,6 +8,7 @@ Complete_Measure Projective_Limit Probability_Mass_Function + SPMF PMF_Impl Stream_Space Random_Permutations diff -r 9986559617ee -r ea13f44da888 src/HOL/Probability/SPMF.thy --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Probability/SPMF.thy Wed Jun 08 09:05:32 2016 +0200 @@ -0,0 +1,2726 @@ +(* Author: Andreas Lochbihler, ETH Zurich *) + +section \Discrete subprobability distribution\ + +theory SPMF imports + Probability_Mass_Function + Embed_Measure + "~~/src/HOL/Library/Complete_Partial_Order2" + "~~/src/HOL/Library/Rewrite" +begin + +subsection \Auxiliary material\ + +lemma cSUP_singleton [simp]: "(SUP x:{x}. f x :: _ :: conditionally_complete_lattice) = f x" +by (metis cSup_singleton image_empty image_insert) + +subsubsection \More about extended reals\ + +lemma [simp]: + shows ennreal_max_0: "ennreal (max 0 x) = ennreal x" + and ennreal_max_0': "ennreal (max x 0) = ennreal x" +by(simp_all add: max_def ennreal_eq_0_iff) + +lemma ennreal_enn2real_if: "ennreal (enn2real r) = (if r = \ then 0 else r)" +by(auto intro!: ennreal_enn2real simp add: less_top) + +lemma e2ennreal_0 [simp]: "e2ennreal 0 = 0" +by(simp add: zero_ennreal_def) + +lemma enn2real_bot [simp]: "enn2real \ = 0" +by(simp add: bot_ennreal_def) + +lemma continuous_at_ennreal[continuous_intros]: "continuous F f \ continuous F (\x. ennreal (f x))" + unfolding continuous_def by auto + +lemma ennreal_Sup: + assumes *: "(SUP a:A. ennreal a) \ \" + and "A \ {}" + shows "ennreal (Sup A) = (SUP a:A. ennreal a)" +proof (rule continuous_at_Sup_mono) + obtain r where r: "ennreal r = (SUP a:A. ennreal a)" "r \ 0" + using * by(cases "(SUP a:A. ennreal a)") simp_all + then show "bdd_above A" + by(auto intro!: SUP_upper bdd_aboveI[of _ r] simp add: ennreal_le_iff[symmetric]) +qed (auto simp: mono_def continuous_at_imp_continuous_at_within continuous_at_ennreal ennreal_leI assms) + +lemma ennreal_SUP: + "\ (SUP a:A. ennreal (f a)) \ \; A \ {} \ \ ennreal (SUP a:A. f a) = (SUP a:A. ennreal (f a))" +using ennreal_Sup[of "f ` A"] by auto + +lemma ennreal_lt_0: "x < 0 \ ennreal x = 0" +by(simp add: ennreal_eq_0_iff) + +subsubsection \More about @{typ "'a option"}\ + +lemma None_in_map_option_image [simp]: "None \ map_option f ` A \ None \ A" +by auto + +lemma Some_in_map_option_image [simp]: "Some x \ map_option f ` A \ (\y. x = f y \ Some y \ A)" +by(auto intro: rev_image_eqI dest: sym) + +lemma case_option_collapse: "case_option x (\_. x) = (\_. x)" +by(simp add: fun_eq_iff split: option.split) + +lemma case_option_id: "case_option None Some = id" +by(rule ext)(simp split: option.split) + +inductive ord_option :: "('a \ 'b \ bool) \ 'a option \ 'b option \ bool" + for ord :: "'a \ 'b \ bool" +where + None: "ord_option ord None x" +| Some: "ord x y \ ord_option ord (Some x) (Some y)" + +inductive_simps ord_option_simps [simp]: + "ord_option ord None x" + "ord_option ord x None" + "ord_option ord (Some x) (Some y)" + "ord_option ord (Some x) None" + +inductive_simps ord_option_eq_simps [simp]: + "ord_option op = None y" + "ord_option op = (Some x) y" + +lemma ord_option_reflI: "(\y. y \ set_option x \ ord y y) \ ord_option ord x x" +by(cases x) simp_all + +lemma reflp_ord_option: "reflp ord \ reflp (ord_option ord)" +by(simp add: reflp_def ord_option_reflI) + +lemma ord_option_trans: + "\ ord_option ord x y; ord_option ord y z; + \a b c. \ a \ set_option x; b \ set_option y; c \ set_option z; ord a b; ord b c \ \ ord a c \ + \ ord_option ord x z" +by(auto elim!: ord_option.cases) + +lemma transp_ord_option: "transp ord \ transp (ord_option ord)" +unfolding transp_def by(blast intro: ord_option_trans) + +lemma antisymP_ord_option: "antisymP ord \ antisymP (ord_option ord)" +by(auto intro!: antisymI elim!: ord_option.cases dest: antisymD) + +lemma ord_option_chainD: + "Complete_Partial_Order.chain (ord_option ord) Y + \ Complete_Partial_Order.chain ord {x. Some x \ Y}" +by(rule chainI)(auto dest: chainD) + +definition lub_option :: "('a set \ 'b) \ 'a option set \ 'b option" +where "lub_option lub Y = (if Y \ {None} then None else Some (lub {x. Some x \ Y}))" + +lemma map_lub_option: "map_option f (lub_option lub Y) = lub_option (f \ lub) Y" +by(simp add: lub_option_def) + +lemma lub_option_upper: + assumes "Complete_Partial_Order.chain (ord_option ord) Y" "x \ Y" + and lub_upper: "\Y x. \ Complete_Partial_Order.chain ord Y; x \ Y \ \ ord x (lub Y)" + shows "ord_option ord x (lub_option lub Y)" +using assms(1-2) +by(cases x)(auto simp add: lub_option_def intro: lub_upper[OF ord_option_chainD]) + +lemma lub_option_least: + assumes Y: "Complete_Partial_Order.chain (ord_option ord) Y" + and upper: "\x. x \ Y \ ord_option ord x y" + assumes lub_least: "\Y y. \ Complete_Partial_Order.chain ord Y; \x. x \ Y \ ord x y \ \ ord (lub Y) y" + shows "ord_option ord (lub_option lub Y) y" +using Y +by(cases y)(auto 4 3 simp add: lub_option_def intro: lub_least[OF ord_option_chainD] dest: upper) + +lemma lub_map_option: "lub_option lub (map_option f ` Y) = lub_option (lub \ op ` f) Y" +apply(auto simp add: lub_option_def) +apply(erule notE) +apply(rule arg_cong[where f=lub]) +apply(auto intro: rev_image_eqI dest: sym) +done + +lemma ord_option_mono: "\ ord_option A x y; \x y. A x y \ B x y \ \ ord_option B x y" +by(auto elim: ord_option.cases) + +lemma ord_option_mono' [mono]: + "(\x y. A x y \ B x y) \ ord_option A x y \ ord_option B x y" +by(blast intro: ord_option_mono) + +lemma ord_option_compp: "ord_option (A OO B) = ord_option A OO ord_option B" +by(auto simp add: fun_eq_iff elim!: ord_option.cases intro: ord_option.intros) + +lemma ord_option_inf: "inf (ord_option A) (ord_option B) = ord_option (inf A B)" (is "?lhs = ?rhs") +proof(rule antisym) + show "?lhs \ ?rhs" by(auto elim!: ord_option.cases) +qed(auto elim: ord_option_mono) + +lemma ord_option_map2: "ord_option ord x (map_option f y) = ord_option (\x y. ord x (f y)) x y" +by(auto elim: ord_option.cases) + +lemma ord_option_map1: "ord_option ord (map_option f x) y = ord_option (\x y. ord (f x) y) x y" +by(auto elim: ord_option.cases) + +lemma option_ord_Some1_iff: "option_ord (Some x) y \ y = Some x" +by(auto simp add: flat_ord_def) + +subsubsection \A relator for sets that treats sets like predicates\ + +context begin interpretation lifting_syntax . +definition rel_pred :: "('a \ 'b \ bool) \ 'a set \ 'b set \ bool" +where "rel_pred R A B = (R ===> op =) (\x. x \ A) (\y. y \ B)" + +lemma rel_predI: "(R ===> op =) (\x. x \ A) (\y. y \ B) \ rel_pred R A B" +by(simp add: rel_pred_def) + +lemma rel_predD: "\ rel_pred R A B; R x y \ \ x \ A \ y \ B" +by(simp add: rel_pred_def rel_fun_def) + +lemma Collect_parametric: "((A ===> op =) ===> rel_pred A) Collect Collect" + -- \Declare this rule as @{attribute transfer_rule} only locally + because it blows up the search space for @{method transfer} + (in combination with @{thm [source] Collect_transfer})\ +by(simp add: rel_funI rel_predI) +end + +subsubsection \Monotonicity rules\ + +lemma monotone_gfp_eadd1: "monotone op \ op \ (\x. x + y :: enat)" +by(auto intro!: monotoneI) + +lemma monotone_gfp_eadd2: "monotone op \ op \ (\y. x + y :: enat)" +by(auto intro!: monotoneI) + +lemma mono2mono_gfp_eadd[THEN gfp.mono2mono2, cont_intro, simp]: + shows monotone_eadd: "monotone (rel_prod op \ op \) op \ (\(x, y). x + y :: enat)" +by(simp add: monotone_gfp_eadd1 monotone_gfp_eadd2) + +lemma eadd_gfp_partial_function_mono [partial_function_mono]: + "\ monotone (fun_ord op \) op \ f; monotone (fun_ord op \) op \ g \ + \ monotone (fun_ord op \) op \ (\x. f x + g x :: enat)" +by(rule mono2mono_gfp_eadd) + +lemma mono2mono_ereal[THEN lfp.mono2mono]: + shows monotone_ereal: "monotone op \ op \ ereal" +by(rule monotoneI) simp + +lemma mono2mono_ennreal[THEN lfp.mono2mono]: + shows monotone_ennreal: "monotone op \ op \ ennreal" +by(rule monotoneI)(simp add: ennreal_leI) + +subsubsection \Bijections\ + +lemma bi_unique_rel_set_bij_betw: + assumes unique: "bi_unique R" + and rel: "rel_set R A B" + shows "\f. bij_betw f A B \ (\x\A. R x (f x))" +proof - + from assms obtain f where f: "\x. x \ A \ R x (f x)" and B: "\x. x \ A \ f x \ B" + apply(atomize_elim) + apply(fold all_conj_distrib) + apply(subst choice_iff[symmetric]) + apply(auto dest: rel_setD1) + done + have "inj_on f A" by(rule inj_onI)(auto dest!: f dest: bi_uniqueDl[OF unique]) + moreover have "f ` A = B" using rel + by(auto 4 3 intro: B dest: rel_setD2 f bi_uniqueDr[OF unique]) + ultimately have "bij_betw f A B" unfolding bij_betw_def .. + thus ?thesis using f by blast +qed + +lemma bij_betw_rel_setD: "bij_betw f A B \ rel_set (\x y. y = f x) A B" +by(rule rel_setI)(auto dest: bij_betwE bij_betw_imp_surj_on[symmetric]) + +subsection {* Subprobability mass function *} + +type_synonym 'a spmf = "'a option pmf" +translations (type) "'a spmf" \ (type) "'a option pmf" + +definition measure_spmf :: "'a spmf \ 'a measure" +where "measure_spmf p = distr (restrict_space (measure_pmf p) (range Some)) (count_space UNIV) the" + +abbreviation spmf :: "'a spmf \ 'a \ real" +where "spmf p x \ pmf p (Some x)" + +lemma space_measure_spmf: "space (measure_spmf p) = UNIV" +by(simp add: measure_spmf_def) + +lemma sets_measure_spmf [simp, measurable_cong]: "sets (measure_spmf p) = sets (count_space UNIV)" +by(simp add: measure_spmf_def) + +lemma measure_spmf_not_bot [simp]: "measure_spmf p \ \" +proof + assume "measure_spmf p = \" + hence "space (measure_spmf p) = space \" by simp + thus False by(simp add: space_measure_spmf) +qed + +lemma measurable_the_measure_pmf_Some [measurable, simp]: + "the \ measurable (restrict_space (measure_pmf p) (range Some)) (count_space UNIV)" +by(auto simp add: measurable_def sets_restrict_space space_restrict_space integral_restrict_space) + +lemma measurable_spmf_measure1[simp]: "measurable (measure_spmf M) N = UNIV \ space N" +by(auto simp: measurable_def space_measure_spmf) + +lemma measurable_spmf_measure2[simp]: "measurable N (measure_spmf M) = measurable N (count_space UNIV)" +by(intro measurable_cong_sets) simp_all + +lemma subprob_space_measure_spmf [simp, intro!]: "subprob_space (measure_spmf p)" +proof + show "emeasure (measure_spmf p) (space (measure_spmf p)) \ 1" + by(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space measure_pmf.measure_le_1) +qed(simp add: space_measure_spmf) + +interpretation measure_spmf: subprob_space "measure_spmf p" for p +by(rule subprob_space_measure_spmf) + +lemma finite_measure_spmf [simp]: "finite_measure (measure_spmf p)" +by unfold_locales + +lemma spmf_conv_measure_spmf: "spmf p x = measure (measure_spmf p) {x}" +by(auto simp add: measure_spmf_def measure_distr measure_restrict_space pmf.rep_eq space_restrict_space intro: arg_cong2[where f=measure]) + +lemma emeasure_measure_spmf_conv_measure_pmf: + "emeasure (measure_spmf p) A = emeasure (measure_pmf p) (Some ` A)" +by(auto simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure]) + +lemma measure_measure_spmf_conv_measure_pmf: + "measure (measure_spmf p) A = measure (measure_pmf p) (Some ` A)" +using emeasure_measure_spmf_conv_measure_pmf[of p A] +by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure) + +lemma emeasure_spmf_map_pmf_Some [simp]: + "emeasure (measure_spmf (map_pmf Some p)) A = emeasure (measure_pmf p) A" +by(auto simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure]) + +lemma measure_spmf_map_pmf_Some [simp]: + "measure (measure_spmf (map_pmf Some p)) A = measure (measure_pmf p) A" +using emeasure_spmf_map_pmf_Some[of p A] by(simp add: measure_spmf.emeasure_eq_measure measure_pmf.emeasure_eq_measure) + +lemma nn_integral_measure_spmf: "(\\<^sup>+ x. f x \measure_spmf p) = \\<^sup>+ x. ennreal (spmf p x) * f x \count_space UNIV" + (is "?lhs = ?rhs") +proof - + have "?lhs = \\<^sup>+ x. pmf p x * f (the x) \count_space (range Some)" + by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space nn_integral_measure_pmf nn_integral_count_space_indicator ac_simps times_ereal.simps(1)[symmetric] del: times_ereal.simps(1)) + also have "\ = \\<^sup>+ x. ennreal (spmf p (the x)) * f (the x) \count_space (range Some)" + by(rule nn_integral_cong) auto + also have "\ = \\<^sup>+ x. spmf p (the (Some x)) * f (the (Some x)) \count_space UNIV" + by(rule nn_integral_bij_count_space[symmetric])(simp add: bij_betw_def) + also have "\ = ?rhs" by simp + finally show ?thesis . +qed + +lemma integral_measure_spmf: + assumes "integrable (measure_spmf p) f" + shows "(\ x. f x \measure_spmf p) = \ x. spmf p x * f x \count_space UNIV" +proof - + have "integrable (count_space UNIV) (\x. spmf p x * f x)" + using assms by(simp add: integrable_iff_bounded nn_integral_measure_spmf abs_mult ennreal_mult'') + then show ?thesis using assms + by(simp add: real_lebesgue_integral_def nn_integral_measure_spmf ennreal_mult'[symmetric]) +qed + +lemma emeasure_spmf_single: "emeasure (measure_spmf p) {x} = spmf p x" +by(simp add: measure_spmf.emeasure_eq_measure spmf_conv_measure_spmf) + +lemma measurable_measure_spmf[measurable]: + "(\x. measure_spmf (M x)) \ measurable (count_space UNIV) (subprob_algebra (count_space UNIV))" +by (auto simp: space_subprob_algebra) + +lemma nn_integral_measure_spmf_conv_measure_pmf: + assumes [measurable]: "f \ borel_measurable (count_space UNIV)" + shows "nn_integral (measure_spmf p) f = nn_integral (restrict_space (measure_pmf p) (range Some)) (f \ the)" +by(simp add: measure_spmf_def nn_integral_distr o_def) + +lemma measure_spmf_in_space_subprob_algebra [simp]: + "measure_spmf p \ space (subprob_algebra (count_space UNIV))" +by(simp add: space_subprob_algebra) + +lemma nn_integral_spmf_neq_top: "(\\<^sup>+ x. spmf p x \count_space UNIV) \ \" +using nn_integral_measure_spmf[where f="\_. 1", of p, symmetric] by simp + +lemma SUP_spmf_neq_top': "(SUP p:Y. ennreal (spmf p x)) \ \" +proof(rule neq_top_trans) + show "(SUP p:Y. ennreal (spmf p x)) \ 1" by(rule SUP_least)(simp add: pmf_le_1) +qed simp + +lemma SUP_spmf_neq_top: "(SUP i. ennreal (spmf (Y i) x)) \ \" +proof(rule neq_top_trans) + show "(SUP i. ennreal (spmf (Y i) x)) \ 1" by(rule SUP_least)(simp add: pmf_le_1) +qed simp + +lemma SUP_emeasure_spmf_neq_top: "(SUP p:Y. emeasure (measure_spmf p) A) \ \" +proof(rule neq_top_trans) + show "(SUP p:Y. emeasure (measure_spmf p) A) \ 1" + by(rule SUP_least)(simp add: measure_spmf.subprob_emeasure_le_1) +qed simp + +subsection {* Support *} + +definition set_spmf :: "'a spmf \ 'a set" +where "set_spmf p = set_pmf p \ set_option" + +lemma set_spmf_rep_eq: "set_spmf p = {x. measure (measure_spmf p) {x} \ 0}" +proof - + have "\x :: 'a. the -` {x} \ range Some = {Some x}" by auto + then show ?thesis + by(auto simp add: set_spmf_def set_pmf.rep_eq measure_spmf_def measure_distr measure_restrict_space space_restrict_space intro: rev_image_eqI) +qed + +lemma in_set_spmf: "x \ set_spmf p \ Some x \ set_pmf p" +by(simp add: set_spmf_def) + +lemma AE_measure_spmf_iff [simp]: "(AE x in measure_spmf p. P x) \ (\x\set_spmf p. P x)" +by(auto 4 3 simp add: measure_spmf_def AE_distr_iff AE_restrict_space_iff AE_measure_pmf_iff set_spmf_def cong del: AE_cong) + +lemma spmf_eq_0_set_spmf: "spmf p x = 0 \ x \ set_spmf p" +by(auto simp add: pmf_eq_0_set_pmf set_spmf_def intro: rev_image_eqI) + +lemma in_set_spmf_iff_spmf: "x \ set_spmf p \ spmf p x \ 0" +by(auto simp add: set_spmf_def set_pmf_iff intro: rev_image_eqI) + +lemma set_spmf_return_pmf_None [simp]: "set_spmf (return_pmf None) = {}" +by(auto simp add: set_spmf_def) + +lemma countable_set_spmf [simp]: "countable (set_spmf p)" +by(simp add: set_spmf_def bind_UNION) + +lemma spmf_eqI: + assumes "\i. spmf p i = spmf q i" + shows "p = q" +proof(rule pmf_eqI) + fix i + show "pmf p i = pmf q i" + proof(cases i) + case (Some i') + thus ?thesis by(simp add: assms) + next + case None + have "ennreal (pmf p i) = measure (measure_pmf p) {i}" by(simp add: pmf_def) + also have "{i} = space (measure_pmf p) - range Some" + by(auto simp add: None intro: ccontr) + also have "measure (measure_pmf p) \ = ennreal 1 - measure (measure_pmf p) (range Some)" + by(simp add: measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf) + also have "range Some = (\x\set_spmf p. {Some x}) \ Some ` (- set_spmf p)" + by auto + also have "measure (measure_pmf p) \ = measure (measure_pmf p) (\x\set_spmf p. {Some x})" + by(rule measure_pmf.measure_zero_union)(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff) + also have "ennreal \ = \\<^sup>+ x. measure (measure_pmf p) {Some x} \count_space (set_spmf p)" + unfolding measure_pmf.emeasure_eq_measure[symmetric] + by(simp_all add: emeasure_UN_countable disjoint_family_on_def) + also have "\ = \\<^sup>+ x. spmf p x \count_space (set_spmf p)" by(simp add: pmf_def) + also have "\ = \\<^sup>+ x. spmf q x \count_space (set_spmf p)" by(simp add: assms) + also have "set_spmf p = set_spmf q" by(auto simp add: in_set_spmf_iff_spmf assms) + also have "(\\<^sup>+ x. spmf q x \count_space (set_spmf q)) = \\<^sup>+ x. measure (measure_pmf q) {Some x} \count_space (set_spmf q)" + by(simp add: pmf_def) + also have "\ = measure (measure_pmf q) (\x\set_spmf q. {Some x})" + unfolding measure_pmf.emeasure_eq_measure[symmetric] + by(simp_all add: emeasure_UN_countable disjoint_family_on_def) + also have "\ = measure (measure_pmf q) ((\x\set_spmf q. {Some x}) \ Some ` (- set_spmf q))" + by(rule ennreal_cong measure_pmf.measure_zero_union[symmetric])+(auto simp add: measure_pmf.prob_eq_0 AE_measure_pmf_iff in_set_spmf_iff_spmf set_pmf_iff) + also have "((\x\set_spmf q. {Some x}) \ Some ` (- set_spmf q)) = range Some" by auto + also have "ennreal 1 - measure (measure_pmf q) \ = measure (measure_pmf q) (space (measure_pmf q) - range Some)" + by(simp add: one_ereal_def measure_pmf.prob_compl ennreal_minus[symmetric] del: space_measure_pmf) + also have "space (measure_pmf q) - range Some = {i}" + by(auto simp add: None intro: ccontr) + also have "measure (measure_pmf q) \ = pmf q i" by(simp add: pmf_def) + finally show ?thesis by simp + qed +qed + +lemma integral_measure_spmf_restrict: + fixes f :: "'a \ 'b :: {banach, second_countable_topology}" shows + "(\ x. f x \measure_spmf M) = (\ x. f x \restrict_space (measure_spmf M) (set_spmf M))" +by(auto intro!: integral_cong_AE simp add: integral_restrict_space) + +lemma nn_integral_measure_spmf': + "(\\<^sup>+ x. f x \measure_spmf p) = \\<^sup>+ x. ennreal (spmf p x) * f x \count_space (set_spmf p)" +by(auto simp add: nn_integral_measure_spmf nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator) + +subsection {* Functorial structure *} + +abbreviation map_spmf :: "('a \ 'b) \ 'a spmf \ 'b spmf" +where "map_spmf f \ map_pmf (map_option f)" + +context begin +local_setup {* Local_Theory.map_background_naming (Name_Space.mandatory_path "spmf") *} + +lemma map_comp: "map_spmf f (map_spmf g p) = map_spmf (f \ g) p" +by(simp add: pmf.map_comp o_def option.map_comp) + +lemma map_id0: "map_spmf id = id" +by(simp add: pmf.map_id option.map_id0) + +lemma map_id [simp]: "map_spmf id p = p" +by(simp add: map_id0) + +lemma map_ident [simp]: "map_spmf (\x. x) p = p" +by(simp add: id_def[symmetric]) + +end + +lemma set_map_spmf [simp]: "set_spmf (map_spmf f p) = f ` set_spmf p" +by(simp add: set_spmf_def image_bind bind_image o_def Option.option.set_map) + +lemma map_spmf_cong: + "\ p = q; \x. x \ set_spmf q \ f x = g x \ + \ map_spmf f p = map_spmf g q" +by(auto intro: pmf.map_cong option.map_cong simp add: in_set_spmf) + +lemma map_spmf_cong_simp: + "\ p = q; \x. x \ set_spmf q =simp=> f x = g x \ + \ map_spmf f p = map_spmf g q" +unfolding simp_implies_def by(rule map_spmf_cong) + +lemma map_spmf_idI: "(\x. x \ set_spmf p \ f x = x) \ map_spmf f p = p" +by(rule map_pmf_idI map_option_idI)+(simp add: in_set_spmf) + +lemma emeasure_map_spmf: + "emeasure (measure_spmf (map_spmf f p)) A = emeasure (measure_spmf p) (f -` A)" +by(auto simp add: measure_spmf_def emeasure_distr measurable_restrict_space1 space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure]) + +lemma measure_map_spmf: "measure (measure_spmf (map_spmf f p)) A = measure (measure_spmf p) (f -` A)" +using emeasure_map_spmf[of f p A] by(simp add: measure_spmf.emeasure_eq_measure) + +lemma measure_map_spmf_conv_distr: + "measure_spmf (map_spmf f p) = distr (measure_spmf p) (count_space UNIV) f" +by(rule measure_eqI)(simp_all add: emeasure_map_spmf emeasure_distr) + +lemma spmf_map_pmf_Some [simp]: "spmf (map_pmf Some p) i = pmf p i" +by(simp add: pmf_map_inj') + +lemma spmf_map_inj: "\ inj_on f (set_spmf M); x \ set_spmf M \ \ spmf (map_spmf f M) (f x) = spmf M x" +by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj, auto simp add: in_set_spmf inj_on_def elim!: option.inj_map_strong[rotated]) + +lemma spmf_map_inj': "inj f \ spmf (map_spmf f M) (f x) = spmf M x" +by(subst option.map(2)[symmetric, where f=f])(rule pmf_map_inj'[OF option.inj_map]) + +lemma spmf_map_outside: "x \ f ` set_spmf M \ spmf (map_spmf f M) x = 0" +unfolding spmf_eq_0_set_spmf by simp + +lemma ennreal_spmf_map: "ennreal (spmf (map_spmf f p) x) = emeasure (measure_spmf p) (f -` {x})" +by(auto simp add: ennreal_pmf_map measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure]) + +lemma spmf_map: "spmf (map_spmf f p) x = measure (measure_spmf p) (f -` {x})" +using ennreal_spmf_map[of f p x] by(simp add: measure_spmf.emeasure_eq_measure) + +lemma ennreal_spmf_map_conv_nn_integral: + "ennreal (spmf (map_spmf f p) x) = integral\<^sup>N (measure_spmf p) (indicator (f -` {x}))" +by(auto simp add: ennreal_pmf_map measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space intro: arg_cong2[where f=emeasure]) + +subsection {* Monad operations *} + +subsubsection {* Return *} + +abbreviation return_spmf :: "'a \ 'a spmf" +where "return_spmf x \ return_pmf (Some x)" + +lemma pmf_return_spmf: "pmf (return_spmf x) y = indicator {y} (Some x)" +by(fact pmf_return) + +lemma measure_spmf_return_spmf: "measure_spmf (return_spmf x) = Giry_Monad.return (count_space UNIV) x" +by(rule measure_eqI)(simp_all add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_def) + +lemma measure_spmf_return_pmf_None [simp]: "measure_spmf (return_pmf None) = null_measure (count_space UNIV)" +by(rule measure_eqI)(auto simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space indicator_eq_0_iff) + +lemma set_return_spmf [simp]: "set_spmf (return_spmf x) = {x}" +by(auto simp add: set_spmf_def) + +subsubsection {* Bind *} + +definition bind_spmf :: "'a spmf \ ('a \ 'b spmf) \ 'b spmf" +where "bind_spmf x f = bind_pmf x (\a. case a of None \ return_pmf None | Some a' \ f a')" + +adhoc_overloading Monad_Syntax.bind bind_spmf + +lemma return_None_bind_spmf [simp]: "return_pmf None \ (f :: 'a \ _) = return_pmf None" +by(simp add: bind_spmf_def bind_return_pmf) + +lemma return_bind_spmf [simp]: "return_spmf x \ f = f x" +by(simp add: bind_spmf_def bind_return_pmf) + +lemma bind_return_spmf [simp]: "x \ return_spmf = x" +proof - + have "\a :: 'a option. (case a of None \ return_pmf None | Some a' \ return_spmf a') = return_pmf a" + by(simp split: option.split) + then show ?thesis + by(simp add: bind_spmf_def bind_return_pmf') +qed + +lemma bind_spmf_assoc [simp]: + fixes x :: "'a spmf" and f :: "'a \ 'b spmf" and g :: "'b \ 'c spmf" + shows "(x \ f) \ g = x \ (\y. f y \ g)" +by(auto simp add: bind_spmf_def bind_assoc_pmf fun_eq_iff bind_return_pmf split: option.split intro: arg_cong[where f="bind_pmf x"]) + +lemma pmf_bind_spmf_None: "pmf (p \ f) None = pmf p None + \ x. pmf (f x) None \measure_spmf p" + (is "?lhs = ?rhs") +proof - + let ?f = "\x. pmf (case x of None \ return_pmf None | Some x \ f x) None" + have "?lhs = \ x. ?f x \measure_pmf p" + by(simp add: bind_spmf_def pmf_bind) + also have "\ = \ x. ?f None * indicator {None} x + ?f x * indicator (range Some) x \measure_pmf p" + by(rule integral_cong)(auto simp add: indicator_def) + also have "\ = (\ x. ?f None * indicator {None} x \measure_pmf p) + (\ x. ?f x * indicator (range Some) x \measure_pmf p)" + by(rule integral_add)(auto 4 3 intro: integrable_real_mult_indicator measure_pmf.integrable_const_bound[where B=1] simp add: AE_measure_pmf_iff pmf_le_1) + also have "\ = pmf p None + \ x. indicator (range Some) x * pmf (f (the x)) None \measure_pmf p" + by(auto simp add: measure_measure_pmf_finite indicator_eq_0_iff intro!: integral_cong) + also have "\ = ?rhs" unfolding measure_spmf_def + by(subst integral_distr)(auto simp add: integral_restrict_space) + finally show ?thesis . +qed + +lemma spmf_bind: "spmf (p \ f) y = \ x. spmf (f x) y \measure_spmf p" +unfolding measure_spmf_def +by(subst integral_distr)(auto simp add: bind_spmf_def pmf_bind integral_restrict_space indicator_eq_0_iff intro!: integral_cong split: option.split) + +lemma ennreal_spmf_bind: "ennreal (spmf (p \ f) x) = \\<^sup>+ y. spmf (f y) x \measure_spmf p" +by(auto simp add: bind_spmf_def ennreal_pmf_bind nn_integral_measure_spmf_conv_measure_pmf nn_integral_restrict_space intro: nn_integral_cong split: split_indicator option.split) + +lemma measure_spmf_bind_pmf: "measure_spmf (p \ f) = measure_pmf p \ measure_spmf \ f" + (is "?lhs = ?rhs") +proof(rule measure_eqI) + show "sets ?lhs = sets ?rhs" + by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf) +next + fix A :: "'a set" + have "emeasure ?lhs A = \\<^sup>+ x. emeasure (measure_spmf (f x)) A \measure_pmf p" + by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def) + also have "\ = emeasure ?rhs A" + by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra) + finally show "emeasure ?lhs A = emeasure ?rhs A" . +qed + +lemma measure_spmf_bind: "measure_spmf (p \ f) = measure_spmf p \ measure_spmf \ f" + (is "?lhs = ?rhs") +proof(rule measure_eqI) + show "sets ?lhs = sets ?rhs" + by(simp add: sets_bind[where N="count_space UNIV"] space_measure_spmf) +next + fix A :: "'a set" + let ?A = "the -` A \ range Some" + have "emeasure ?lhs A = \\<^sup>+ x. emeasure (measure_pmf (case x of None \ return_pmf None | Some x \ f x)) ?A \measure_pmf p" + by(simp add: measure_spmf_def emeasure_distr space_restrict_space emeasure_restrict_space bind_spmf_def) + also have "\ = \\<^sup>+ x. emeasure (measure_pmf (f (the x))) ?A * indicator (range Some) x \measure_pmf p" + by(rule nn_integral_cong)(auto split: option.split simp add: indicator_def) + also have "\ = \\<^sup>+ x. emeasure (measure_spmf (f x)) A \measure_spmf p" + by(simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space emeasure_distr space_restrict_space emeasure_restrict_space) + also have "\ = emeasure ?rhs A" + by(simp add: emeasure_bind[where N="count_space UNIV"] space_measure_spmf space_subprob_algebra) + finally show "emeasure ?lhs A = emeasure ?rhs A" . +qed + +lemma map_spmf_bind_spmf: "map_spmf f (bind_spmf p g) = bind_spmf p (map_spmf f \ g)" +by(auto simp add: bind_spmf_def map_bind_pmf fun_eq_iff split: option.split intro: arg_cong2[where f=bind_pmf]) + +lemma bind_map_spmf: "map_spmf f p \ g = p \ g \ f" +by(simp add: bind_spmf_def bind_map_pmf o_def cong del: option.case_cong_weak) + +lemma spmf_bind_leI: + assumes "\y. y \ set_spmf p \ spmf (f y) x \ r" + and "0 \ r" + shows "spmf (bind_spmf p f) x \ r" +proof - + have "ennreal (spmf (bind_spmf p f) x) = \\<^sup>+ y. spmf (f y) x \measure_spmf p" by(rule ennreal_spmf_bind) + also have "\ \ \\<^sup>+ y. r \measure_spmf p" by(rule nn_integral_mono_AE)(simp add: assms) + also have "\ \ r" using assms measure_spmf.emeasure_space_le_1 + by(auto simp add: measure_spmf.emeasure_eq_measure intro!: mult_left_le) + finally show ?thesis using assms(2) by(simp) +qed + +lemma map_spmf_conv_bind_spmf: "map_spmf f p = (p \ (\x. return_spmf (f x)))" +by(simp add: map_pmf_def bind_spmf_def)(rule bind_pmf_cong, simp_all split: option.split) + +lemma bind_spmf_cong: + "\ p = q; \x. x \ set_spmf q \ f x = g x \ + \ bind_spmf p f = bind_spmf q g" +by(auto simp add: bind_spmf_def in_set_spmf intro: bind_pmf_cong option.case_cong) + +lemma bind_spmf_cong_simp: + "\ p = q; \x. x \ set_spmf q =simp=> f x = g x \ + \ bind_spmf p f = bind_spmf q g" +by(simp add: simp_implies_def cong: bind_spmf_cong) + +lemma set_bind_spmf: "set_spmf (M \ f) = set_spmf M \ (set_spmf \ f)" +by(auto simp add: set_spmf_def bind_spmf_def bind_UNION split: option.splits) + +lemma bind_spmf_const_return_None [simp]: "bind_spmf p (\_. return_pmf None) = return_pmf None" +by(simp add: bind_spmf_def case_option_collapse) + +lemma bind_commute_spmf: + "bind_spmf p (\x. bind_spmf q (f x)) = bind_spmf q (\y. bind_spmf p (\x. f x y))" + (is "?lhs = ?rhs") +proof - + let ?f = "\x y. case x of None \ return_pmf None | Some a \ (case y of None \ return_pmf None | Some b \ f a b)" + have "?lhs = p \ (\x. q \ ?f x)" + unfolding bind_spmf_def by(rule bind_pmf_cong[OF refl])(simp split: option.split) + also have "\ = q \ (\y. p \ (\x. ?f x y))" by(rule bind_commute_pmf) + also have "\ = ?rhs" unfolding bind_spmf_def + by(rule bind_pmf_cong[OF refl])(auto split: option.split, metis bind_spmf_const_return_None bind_spmf_def) + finally show ?thesis . +qed + +subsection {* Relator *} + +abbreviation rel_spmf :: "('a \ 'b \ bool) \ 'a spmf \ 'b spmf \ bool" +where "rel_spmf R \ rel_pmf (rel_option R)" + +lemma rel_pmf_mono: + "\rel_pmf A f g; \x y. A x y \ B x y \ \ rel_pmf B f g" +using pmf.rel_mono[of A B] by(simp add: le_fun_def) + +lemma rel_spmf_mono: + "\rel_spmf A f g; \x y. A x y \ B x y \ \ rel_spmf B f g" +apply(erule rel_pmf_mono) +using option.rel_mono[of A B] by(simp add: le_fun_def) + +lemma rel_spmf_mono_strong: + "\ rel_spmf A f g; \x y. \ A x y; x \ set_spmf f; y \ set_spmf g \ \ B x y \ \ rel_spmf B f g" +apply(erule pmf.rel_mono_strong) +apply(erule option.rel_mono_strong) +apply(auto simp add: in_set_spmf) +done + +lemma rel_spmf_reflI: "(\x. x \ set_spmf p \ P x x) \ rel_spmf P p p" +by(rule rel_pmf_reflI)(auto simp add: set_spmf_def intro: rel_option_reflI) + +lemma rel_spmfI [intro?]: + "\ \x y. (x, y) \ set_spmf pq \ P x y; map_spmf fst pq = p; map_spmf snd pq = q \ + \ rel_spmf P p q" +by(rule rel_pmf.intros[where pq="map_pmf (\x. case x of None \ (None, None) | Some (a, b) \ (Some a, Some b)) pq"]) + (auto simp add: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong) + +lemma rel_spmfE [elim?, consumes 1, case_names rel_spmf]: + assumes "rel_spmf P p q" + obtains pq where + "\x y. (x, y) \ set_spmf pq \ P x y" + "p = map_spmf fst pq" + "q = map_spmf snd pq" +using assms +proof(cases rule: rel_pmf.cases[consumes 1, case_names rel_pmf]) + case (rel_pmf pq) + let ?pq = "map_pmf (\(a, b). case (a, b) of (Some x, Some y) \ Some (x, y) | _ \ None) pq" + have "\x y. (x, y) \ set_spmf ?pq \ P x y" + by(auto simp add: in_set_spmf split: option.split_asm dest: rel_pmf(1)) + moreover + have "\x. (x, None) \ set_pmf pq \ x = None" by(auto dest!: rel_pmf(1)) + then have "p = map_spmf fst ?pq" using rel_pmf(2) + by(auto simp add: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split) + moreover + have "\y. (None, y) \ set_pmf pq \ y = None" by(auto dest!: rel_pmf(1)) + then have "q = map_spmf snd ?pq" using rel_pmf(3) + by(auto simp add: pmf.map_comp split_beta intro!: pmf.map_cong split: option.split) + ultimately show thesis .. +qed + +lemma rel_spmf_simps: + "rel_spmf R p q \ (\pq. (\(x, y)\set_spmf pq. R x y) \ map_spmf fst pq = p \ map_spmf snd pq = q)" +by(auto intro: rel_spmfI elim!: rel_spmfE) + +lemma spmf_rel_map: + shows spmf_rel_map1: "\R f x. rel_spmf R (map_spmf f x) = rel_spmf (\x. R (f x)) x" + and spmf_rel_map2: "\R x g y. rel_spmf R x (map_spmf g y) = rel_spmf (\x y. R x (g y)) x y" +by(simp_all add: fun_eq_iff pmf.rel_map option.rel_map[abs_def]) + +lemma spmf_rel_conversep: "rel_spmf R\\ = (rel_spmf R)\\" +by(simp add: option.rel_conversep pmf.rel_conversep) + +lemma spmf_rel_eq: "rel_spmf op = = op =" +by(simp add: pmf.rel_eq option.rel_eq) + +context begin interpretation lifting_syntax . + +lemma bind_spmf_parametric [transfer_rule]: + "(rel_spmf A ===> (A ===> rel_spmf B) ===> rel_spmf B) bind_spmf bind_spmf" +unfolding bind_spmf_def[abs_def] by transfer_prover + +lemma return_spmf_parametric: "(A ===> rel_spmf A) return_spmf return_spmf" +by transfer_prover + +lemma map_spmf_parametric: "((A ===> B) ===> rel_spmf A ===> rel_spmf B) map_spmf map_spmf" +by transfer_prover + +lemma rel_spmf_parametric: + "((A ===> B ===> op =) ===> rel_spmf A ===> rel_spmf B ===> op =) rel_spmf rel_spmf" +by transfer_prover + +lemma set_spmf_parametric [transfer_rule]: + "(rel_spmf A ===> rel_set A) set_spmf set_spmf" +unfolding set_spmf_def[abs_def] by transfer_prover + +lemma return_spmf_None_parametric: + "(rel_spmf A) (return_pmf None) (return_pmf None)" +by simp + +end + +lemma rel_spmf_bindI: + "\ rel_spmf R p q; \x y. R x y \ rel_spmf P (f x) (g y) \ + \ rel_spmf P (p \ f) (q \ g)" +by(fact bind_spmf_parametric[THEN rel_funD, THEN rel_funD, OF _ rel_funI]) + +lemma rel_spmf_bind_reflI: + "(\x. x \ set_spmf p \ rel_spmf P (f x) (g x)) \ rel_spmf P (p \ f) (p \ g)" +by(rule rel_spmf_bindI[where R="\x y. x = y \ x \ set_spmf p"])(auto intro: rel_spmf_reflI) + +lemma rel_pmf_return_pmfI: "P x y \ rel_pmf P (return_pmf x) (return_pmf y)" +by(rule rel_pmf.intros[where pq="return_pmf (x, y)"])(simp_all) + +context begin interpretation lifting_syntax . +text \We do not yet have a relator for @{typ "'a measure"}, so we combine @{const measure} and @{const measure_pmf}\ +lemma measure_pmf_parametric: + "(rel_pmf A ===> rel_pred A ===> op =) (\p. measure (measure_pmf p)) (\q. measure (measure_pmf q))" +proof(rule rel_funI)+ + fix p q X Y + assume "rel_pmf A p q" and "rel_pred A X Y" + from this(1) obtain pq where A: "\x y. (x, y) \ set_pmf pq \ A x y" + and p: "p = map_pmf fst pq" and q: "q = map_pmf snd pq" by cases auto + show "measure p X = measure q Y" unfolding p q measure_map_pmf + by(rule measure_pmf.finite_measure_eq_AE)(auto simp add: AE_measure_pmf_iff dest!: A rel_predD[OF \rel_pred _ _ _\]) +qed + +lemma measure_spmf_parametric: + "(rel_spmf A ===> rel_pred A ===> op =) (\p. measure (measure_spmf p)) (\q. measure (measure_spmf q))" +unfolding measure_measure_spmf_conv_measure_pmf[abs_def] +apply(rule rel_funI)+ +apply(erule measure_pmf_parametric[THEN rel_funD, THEN rel_funD]) +apply(auto simp add: rel_pred_def rel_fun_def elim: option.rel_cases) +done +end + +subsection {* From @{typ "'a pmf"} to @{typ "'a spmf"} *} + +definition spmf_of_pmf :: "'a pmf \ 'a spmf" +where "spmf_of_pmf = map_pmf Some" + +lemma set_spmf_spmf_of_pmf [simp]: "set_spmf (spmf_of_pmf p) = set_pmf p" +by(auto simp add: spmf_of_pmf_def set_spmf_def bind_image o_def) + +lemma spmf_spmf_of_pmf [simp]: "spmf (spmf_of_pmf p) x = pmf p x" +by(simp add: spmf_of_pmf_def) + +lemma pmf_spmf_of_pmf_None [simp]: "pmf (spmf_of_pmf p) None = 0" +using ennreal_pmf_map[of Some p None] by(simp add: spmf_of_pmf_def) + +lemma emeasure_spmf_of_pmf [simp]: "emeasure (measure_spmf (spmf_of_pmf p)) A = emeasure (measure_pmf p) A" +by(simp add: emeasure_measure_spmf_conv_measure_pmf spmf_of_pmf_def inj_vimage_image_eq) + +lemma measure_spmf_spmf_of_pmf [simp]: "measure_spmf (spmf_of_pmf p) = measure_pmf p" +by(rule measure_eqI) simp_all + +lemma map_spmf_of_pmf [simp]: "map_spmf f (spmf_of_pmf p) = spmf_of_pmf (map_pmf f p)" +by(simp add: spmf_of_pmf_def pmf.map_comp o_def) + +lemma rel_spmf_spmf_of_pmf [simp]: "rel_spmf R (spmf_of_pmf p) (spmf_of_pmf q) = rel_pmf R p q" +by(simp add: spmf_of_pmf_def pmf.rel_map) + +lemma spmf_of_pmf_return_pmf [simp]: "spmf_of_pmf (return_pmf x) = return_spmf x" +by(simp add: spmf_of_pmf_def) + +lemma bind_spmf_of_pmf [simp]: "bind_spmf (spmf_of_pmf p) f = bind_pmf p f" +by(simp add: spmf_of_pmf_def bind_spmf_def bind_map_pmf) + +lemma set_spmf_bind_pmf: "set_spmf (bind_pmf p f) = Set.bind (set_pmf p) (set_spmf \ f)" +unfolding bind_spmf_of_pmf[symmetric] by(subst set_bind_spmf) simp + +lemma spmf_of_pmf_bind: "spmf_of_pmf (bind_pmf p f) = bind_pmf p (\x. spmf_of_pmf (f x))" +by(simp add: spmf_of_pmf_def map_bind_pmf) + +lemma bind_pmf_return_spmf: "p \ (\x. return_spmf (f x)) = spmf_of_pmf (map_pmf f p)" +by(simp add: map_pmf_def spmf_of_pmf_bind) + +subsection {* Weight of a subprobability *} + +abbreviation weight_spmf :: "'a spmf \ real" +where "weight_spmf p \ measure (measure_spmf p) (space (measure_spmf p))" + +lemma weight_spmf_def: "weight_spmf p = measure (measure_spmf p) UNIV" +by(simp add: space_measure_spmf) + +lemma weight_spmf_le_1: "weight_spmf p \ 1" +by(simp add: measure_spmf.subprob_measure_le_1) + +lemma weight_return_spmf [simp]: "weight_spmf (return_spmf x) = 1" +by(simp add: measure_spmf_return_spmf measure_return) + +lemma weight_return_pmf_None [simp]: "weight_spmf (return_pmf None) = 0" +by(simp) + +lemma weight_map_spmf [simp]: "weight_spmf (map_spmf f p) = weight_spmf p" +by(simp add: weight_spmf_def measure_map_spmf) + +lemma weight_spmf_of_pmf [simp]: "weight_spmf (spmf_of_pmf p) = 1" +using measure_pmf.prob_space[of p] by(simp add: spmf_of_pmf_def weight_spmf_def) + +lemma weight_spmf_nonneg: "weight_spmf p \ 0" +by(fact measure_nonneg) + +lemma (in finite_measure) integrable_weight_spmf [simp]: + "(\x. weight_spmf (f x)) \ borel_measurable M \ integrable M (\x. weight_spmf (f x))" +by(rule integrable_const_bound[where B=1])(simp_all add: weight_spmf_nonneg weight_spmf_le_1) + +lemma weight_spmf_eq_nn_integral_spmf: "weight_spmf p = \\<^sup>+ x. spmf p x \count_space UNIV" +by(simp add: measure_measure_spmf_conv_measure_pmf space_measure_spmf measure_pmf.emeasure_eq_measure[symmetric] nn_integral_pmf[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1) + +lemma weight_spmf_eq_nn_integral_support: + "weight_spmf p = \\<^sup>+ x. spmf p x \count_space (set_spmf p)" +unfolding weight_spmf_eq_nn_integral_spmf +by(auto simp add: nn_integral_count_space_indicator in_set_spmf_iff_spmf intro!: nn_integral_cong split: split_indicator) + +lemma pmf_None_eq_weight_spmf: "pmf p None = 1 - weight_spmf p" +proof - + have "weight_spmf p = \\<^sup>+ x. spmf p x \count_space UNIV" by(rule weight_spmf_eq_nn_integral_spmf) + also have "\ = \\<^sup>+ x. ennreal (pmf p x) * indicator (range Some) x \count_space UNIV" + by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1) + also have "\ + pmf p None = \\<^sup>+ x. ennreal (pmf p x) * indicator (range Some) x + ennreal (pmf p None) * indicator {None} x \count_space UNIV" + by(subst nn_integral_add)(simp_all add: max_def) + also have "\ = \\<^sup>+ x. pmf p x \count_space UNIV" + by(rule nn_integral_cong)(auto split: split_indicator) + also have "\ = 1" by (simp add: nn_integral_pmf) + finally show ?thesis by(simp add: ennreal_plus[symmetric] del: ennreal_plus) +qed + +lemma weight_spmf_conv_pmf_None: "weight_spmf p = 1 - pmf p None" +by(simp add: pmf_None_eq_weight_spmf) + +lemma weight_spmf_le_0: "weight_spmf p \ 0 \ weight_spmf p = 0" +by(rule measure_le_0_iff) + +lemma weight_spmf_lt_0: "\ weight_spmf p < 0" +by(simp add: not_less weight_spmf_nonneg) + +lemma spmf_le_weight: "spmf p x \ weight_spmf p" +proof - + have "ennreal (spmf p x) \ weight_spmf p" + unfolding weight_spmf_eq_nn_integral_spmf by(rule nn_integral_ge_point) simp + then show ?thesis by simp +qed + +lemma weight_spmf_eq_0: "weight_spmf p = 0 \ p = return_pmf None" +by(auto intro!: pmf_eqI simp add: pmf_None_eq_weight_spmf split: split_indicator)(metis not_Some_eq pmf_le_0_iff spmf_le_weight) + +lemma weight_bind_spmf: "weight_spmf (x \ f) = lebesgue_integral (measure_spmf x) (weight_spmf \ f)" +unfolding weight_spmf_def +by(simp add: measure_spmf_bind o_def measure_spmf.measure_bind[where N="count_space UNIV"]) + +lemma rel_spmf_weightD: "rel_spmf A p q \ weight_spmf p = weight_spmf q" +by(erule rel_spmfE) simp + +lemma rel_spmf_bij_betw: + assumes f: "bij_betw f (set_spmf p) (set_spmf q)" + and eq: "\x. x \ set_spmf p \ spmf p x = spmf q (f x)" + shows "rel_spmf (\x y. f x = y) p q" +proof - + let ?f = "map_option f" + + have weq: "ennreal (weight_spmf p) = ennreal (weight_spmf q)" + unfolding weight_spmf_eq_nn_integral_support + by(subst nn_integral_bij_count_space[OF f, symmetric])(rule nn_integral_cong_AE, simp add: eq AE_count_space) + then have "None \ set_pmf p \ None \ set_pmf q" + by(simp add: pmf_None_eq_weight_spmf set_pmf_iff) + with f have "bij_betw (map_option f) (set_pmf p) (set_pmf q)" + apply(auto simp add: bij_betw_def in_set_spmf inj_on_def intro: option.expand) + apply(rename_tac [!] x) + apply(case_tac [!] x) + apply(auto iff: in_set_spmf) + done + then have "rel_pmf (\x y. ?f x = y) p q" + by(rule rel_pmf_bij_betw)(case_tac x, simp_all add: weq[simplified] eq in_set_spmf pmf_None_eq_weight_spmf) + thus ?thesis by(rule pmf.rel_mono_strong)(auto intro!: rel_optionI simp add: Option.is_none_def) +qed + +subsection {* From density to spmfs *} + +context fixes f :: "'a \ real" begin + +definition embed_spmf :: "'a spmf" +where "embed_spmf = embed_pmf (\x. case x of None \ 1 - enn2real (\\<^sup>+ x. ennreal (f x) \count_space UNIV) | Some x' \ max 0 (f x'))" + +context + assumes prob: "(\\<^sup>+ x. ennreal (f x) \count_space UNIV) \ 1" +begin + +lemma nn_integral_embed_spmf_eq_1: + "(\\<^sup>+ x. ennreal (case x of None \ 1 - enn2real (\\<^sup>+ x. ennreal (f x) \count_space UNIV) | Some x' \ max 0 (f x')) \count_space UNIV) = 1" + (is "?lhs = _" is "(\\<^sup>+ x. ?f x \?M) = _") +proof - + have "?lhs = \\<^sup>+ x. ?f x * indicator {None} x + ?f x * indicator (range Some) x \?M" + by(rule nn_integral_cong)(auto split: split_indicator) + also have "\ = (1 - enn2real (\\<^sup>+ x. ennreal (f x) \count_space UNIV)) + \\<^sup>+ x. ?f x * indicator (range Some) x \?M" + (is "_ = ?None + ?Some") + by(subst nn_integral_add)(simp_all add: AE_count_space max_def le_diff_eq real_le_ereal_iff one_ereal_def[symmetric] prob split: option.split) + also have "?Some = \\<^sup>+ x. ?f x \count_space (range Some)" + by(simp add: nn_integral_count_space_indicator) + also have "count_space (range Some) = embed_measure (count_space UNIV) Some" + by(simp add: embed_measure_count_space) + also have "(\\<^sup>+ x. ?f x \\) = \\<^sup>+ x. ennreal (f x) \count_space UNIV" + by(subst nn_integral_embed_measure)(simp_all add: measurable_embed_measure1) + also have "?None + \ = 1" using prob + by(auto simp add: ennreal_minus[symmetric] ennreal_1[symmetric] ennreal_enn2real_if top_unique simp del: ennreal_1)(simp add: diff_add_self_ennreal) + finally show ?thesis . +qed + +lemma pmf_embed_spmf_None: "pmf embed_spmf None = 1 - enn2real (\\<^sup>+ x. ennreal (f x) \count_space UNIV)" +unfolding embed_spmf_def +apply(subst pmf_embed_pmf) + subgoal using prob by(simp add: field_simps enn2real_leI split: option.split) + apply(rule nn_integral_embed_spmf_eq_1) +apply simp +done + +lemma spmf_embed_spmf [simp]: "spmf embed_spmf x = max 0 (f x)" +unfolding embed_spmf_def +apply(subst pmf_embed_pmf) + subgoal using prob by(simp add: field_simps enn2real_leI split: option.split) + apply(rule nn_integral_embed_spmf_eq_1) +apply simp +done + +end + +end + +lemma embed_spmf_K_0[simp]: "embed_spmf (\_. 0) = return_pmf None" (is "?lhs = ?rhs") +by(rule spmf_eqI)(simp add: zero_ereal_def[symmetric]) + +subsection {* Ordering on spmfs *} + +text {* + @{const rel_pmf} does not preserve a ccpo structure. Counterexample by Saheb-Djahromi: + Take prefix order over @{text "bool llist"} and + the set @{text "range (\n :: nat. uniform (llist_n n))"} where @{text "llist_n"} is the set + of all @{text llist}s of length @{text n} and @{text uniform} returns a uniform distribution over + the given set. The set forms a chain in @{text "ord_pmf lprefix"}, but it has not an upper bound. + Any upper bound may contain only infinite lists in its support because otherwise it is not greater + than the @{text "n+1"}-st element in the chain where @{text n} is the length of the finite list. + Moreover its support must contain all infinite lists, because otherwise there is a finite list + all of whose finite extensions are not in the support - a contradiction to the upper bound property. + Hence, the support is uncountable, but pmf's only have countable support. + + However, if all chains in the ccpo are finite, then it should preserve the ccpo structure. +*} + +abbreviation ord_spmf :: "('a \ 'a \ bool) \ 'a spmf \ 'a spmf \ bool" +where "ord_spmf ord \ rel_pmf (ord_option ord)" + +locale ord_spmf_syntax begin +notation ord_spmf (infix "\\" 60) +end + +lemma ord_spmf_map_spmf1: "ord_spmf R (map_spmf f p) = ord_spmf (\x. R (f x)) p" +by(simp add: pmf.rel_map[abs_def] ord_option_map1[abs_def]) + +lemma ord_spmf_map_spmf2: "ord_spmf R p (map_spmf f q) = ord_spmf (\x y. R x (f y)) p q" +by(simp add: pmf.rel_map ord_option_map2) + +lemma ord_spmf_map_spmf12: "ord_spmf R (map_spmf f p) (map_spmf f q) = ord_spmf (\x y. R (f x) (f y)) p q" +by(simp add: pmf.rel_map ord_option_map1[abs_def] ord_option_map2) + +lemmas ord_spmf_map_spmf = ord_spmf_map_spmf1 ord_spmf_map_spmf2 ord_spmf_map_spmf12 + +context fixes ord :: "'a \ 'a \ bool" (structure) begin +interpretation ord_spmf_syntax . + +lemma ord_spmfI: + "\ \x y. (x, y) \ set_spmf pq \ ord x y; map_spmf fst pq = p; map_spmf snd pq = q \ + \ p \ q" +by(rule rel_pmf.intros[where pq="map_pmf (\x. case x of None \ (None, None) | Some (a, b) \ (Some a, Some b)) pq"]) + (auto simp add: pmf.map_comp o_def in_set_spmf split: option.splits intro: pmf.map_cong) + +lemma ord_spmf_None [simp]: "return_pmf None \ x" +by(rule rel_pmf.intros[where pq="map_pmf (Pair None) x"])(auto simp add: pmf.map_comp o_def) + +lemma ord_spmf_reflI: "(\x. x \ set_spmf p \ ord x x) \ p \ p" +by(rule rel_pmf_reflI ord_option_reflI)+(auto simp add: in_set_spmf) + +lemma rel_spmf_inf: + assumes "p \ q" + and "q \ p" + and refl: "reflp ord" + and trans: "transp ord" + shows "rel_spmf (inf ord ord\\) p q" +proof - + from \p \ q\ \q \ p\ + have "rel_pmf (inf (ord_option ord) (ord_option ord)\\) p q" + by(rule rel_pmf_inf)(blast intro: reflp_ord_option transp_ord_option refl trans)+ + also have "inf (ord_option ord) (ord_option ord)\\ = rel_option (inf ord ord\\)" + by(auto simp add: fun_eq_iff elim: ord_option.cases option.rel_cases) + finally show ?thesis . +qed + +end + +lemma ord_spmf_return_spmf2: "ord_spmf R p (return_spmf y) \ (\x\set_spmf p. R x y)" +by(auto simp add: rel_pmf_return_pmf2 in_set_spmf ord_option.simps intro: ccontr) + +lemma ord_spmf_mono: "\ ord_spmf A p q; \x y. A x y \ B x y \ \ ord_spmf B p q" +by(erule rel_pmf_mono)(erule ord_option_mono) + +lemma ord_spmf_compp: "ord_spmf (A OO B) = ord_spmf A OO ord_spmf B" +by(simp add: ord_option_compp pmf.rel_compp) + +lemma ord_spmf_bindI: + assumes pq: "ord_spmf R p q" + and fg: "\x y. R x y \ ord_spmf P (f x) (g y)" + shows "ord_spmf P (p \ f) (q \ g)" +unfolding bind_spmf_def using pq +by(rule rel_pmf_bindI)(auto split: option.split intro: fg) + +lemma ord_spmf_bind_reflI: + "(\x. x \ set_spmf p \ ord_spmf R (f x) (g x)) + \ ord_spmf R (p \ f) (p \ g)" +by(rule ord_spmf_bindI[where R="\x y. x = y \ x \ set_spmf p"])(auto intro: ord_spmf_reflI) + +lemma ord_pmf_increaseI: + assumes le: "\x. spmf p x \ spmf q x" + and refl: "\x. x \ set_spmf p \ R x x" + shows "ord_spmf R p q" +proof(rule rel_pmf.intros) + define pq where "pq = embed_pmf + (\(x, y). case x of Some x' \ (case y of Some y' \ if x' = y' then spmf p x' else 0 | None \ 0) + | None \ (case y of None \ pmf q None | Some y' \ spmf q y' - spmf p y'))" + (is "_ = embed_pmf ?f") + have nonneg: "\xy. ?f xy \ 0" + by(clarsimp simp add: le field_simps split: option.split) + have integral: "(\\<^sup>+ xy. ?f xy \count_space UNIV) = 1" (is "nn_integral ?M _ = _") + proof - + have "(\\<^sup>+ xy. ?f xy \count_space UNIV) = + \\<^sup>+ xy. ennreal (?f xy) * indicator {(None, None)} xy + + ennreal (?f xy) * indicator (range (\x. (None, Some x))) xy + + ennreal (?f xy) * indicator (range (\x. (Some x, Some x))) xy \?M" + by(rule nn_integral_cong)(auto split: split_indicator option.splits if_split_asm) + also have "\ = (\\<^sup>+ xy. ?f xy * indicator {(None, None)} xy \?M) + + (\\<^sup>+ xy. ennreal (?f xy) * indicator (range (\x. (None, Some x))) xy \?M) + + (\\<^sup>+ xy. ennreal (?f xy) * indicator (range (\x. (Some x, Some x))) xy \?M)" + (is "_ = ?None + ?Some2 + ?Some") + by(subst nn_integral_add)(simp_all add: nn_integral_add AE_count_space le_diff_eq le split: option.split) + also have "?None = pmf q None" by simp + also have "?Some2 = \\<^sup>+ x. ennreal (spmf q x) - spmf p x \count_space UNIV" + by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1 ennreal_minus) + also have "\ = (\\<^sup>+ x. spmf q x \count_space UNIV) - (\\<^sup>+ x. spmf p x \count_space UNIV)" + (is "_ = ?Some2' - ?Some2''") + by(subst nn_integral_diff)(simp_all add: le nn_integral_spmf_neq_top) + also have "?Some = \\<^sup>+ x. spmf p x \count_space UNIV" + by(simp add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] inj_on_def nn_integral_embed_measure measurable_embed_measure1) + also have "pmf q None + (?Some2' - ?Some2'') + \ = pmf q None + ?Some2'" + by(auto simp add: diff_add_self_ennreal le intro!: nn_integral_mono) + also have "\ = \\<^sup>+ x. ennreal (pmf q x) * indicator {None} x + ennreal (pmf q x) * indicator (range Some) x \count_space UNIV" + by(subst nn_integral_add)(simp_all add: nn_integral_count_space_indicator[symmetric] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1) + also have "\ = \\<^sup>+ x. pmf q x \count_space UNIV" + by(rule nn_integral_cong)(auto split: split_indicator) + also have "\ = 1" by(simp add: nn_integral_pmf) + finally show ?thesis . + qed + note f = nonneg integral + + { fix x y + assume "(x, y) \ set_pmf pq" + hence "?f (x, y) \ 0" unfolding pq_def by(simp add: set_embed_pmf[OF f]) + then show "ord_option R x y" + by(simp add: spmf_eq_0_set_spmf refl split: option.split_asm if_split_asm) } + + have weight_le: "weight_spmf p \ weight_spmf q" + by(subst ennreal_le_iff[symmetric])(auto simp add: weight_spmf_eq_nn_integral_spmf intro!: nn_integral_mono le) + + show "map_pmf fst pq = p" + proof(rule pmf_eqI) + fix i + have "ennreal (pmf (map_pmf fst pq) i) = (\\<^sup>+ y. pmf pq (i, y) \count_space UNIV)" + unfolding pq_def ennreal_pmf_map + apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric]) + apply(subst pmf_embed_pmf[OF f]) + apply(rule nn_integral_bij_count_space[symmetric]) + apply(auto simp add: bij_betw_def inj_on_def) + done + also have "\ = pmf p i" + proof(cases i) + case (Some x) + have "(\\<^sup>+ y. pmf pq (Some x, y) \count_space UNIV) = \\<^sup>+ y. pmf p (Some x) * indicator {Some x} y \count_space UNIV" + by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split) + then show ?thesis using Some by simp + next + case None + have "(\\<^sup>+ y. pmf pq (None, y) \count_space UNIV) = + (\\<^sup>+ y. ennreal (pmf pq (None, Some (the y))) * indicator (range Some) y + + ennreal (pmf pq (None, None)) * indicator {None} y \count_space UNIV)" + by(rule nn_integral_cong)(auto split: split_indicator) + also have "\ = (\\<^sup>+ y. ennreal (pmf pq (None, Some (the y))) \count_space (range Some)) + pmf pq (None, None)" + by(subst nn_integral_add)(simp_all add: nn_integral_count_space_indicator) + also have "\ = (\\<^sup>+ y. ennreal (spmf q y) - ennreal (spmf p y) \count_space UNIV) + pmf q None" + by(simp add: pq_def pmf_embed_pmf[OF f] embed_measure_count_space[symmetric] nn_integral_embed_measure measurable_embed_measure1 ennreal_minus) + also have "(\\<^sup>+ y. ennreal (spmf q y) - ennreal (spmf p y) \count_space UNIV) = + (\\<^sup>+ y. spmf q y \count_space UNIV) - (\\<^sup>+ y. spmf p y \count_space UNIV)" + by(subst nn_integral_diff)(simp_all add: AE_count_space le nn_integral_spmf_neq_top split: split_indicator) + also have "\ = pmf p None - pmf q None" + by(simp add: pmf_None_eq_weight_spmf weight_spmf_eq_nn_integral_spmf[symmetric] ennreal_minus) + also have "\ = ennreal (pmf p None) - ennreal (pmf q None)" by(simp add: ennreal_minus) + finally show ?thesis using None weight_le + by(auto simp add: diff_add_self_ennreal pmf_None_eq_weight_spmf intro: ennreal_leI) + qed + finally show "pmf (map_pmf fst pq) i = pmf p i" by simp + qed + + show "map_pmf snd pq = q" + proof(rule pmf_eqI) + fix i + have "ennreal (pmf (map_pmf snd pq) i) = (\\<^sup>+ x. pmf pq (x, i) \count_space UNIV)" + unfolding pq_def ennreal_pmf_map + apply(simp add: embed_pmf.rep_eq[OF f] o_def emeasure_density nn_integral_count_space_indicator[symmetric]) + apply(subst pmf_embed_pmf[OF f]) + apply(rule nn_integral_bij_count_space[symmetric]) + apply(auto simp add: bij_betw_def inj_on_def) + done + also have "\ = ennreal (pmf q i)" + proof(cases i) + case None + have "(\\<^sup>+ x. pmf pq (x, None) \count_space UNIV) = \\<^sup>+ x. pmf q None * indicator {None :: 'a option} x \count_space UNIV" + by(rule nn_integral_cong)(simp add: pq_def pmf_embed_pmf[OF f] split: option.split) + then show ?thesis using None by simp + next + case (Some y) + have "(\\<^sup>+ x. pmf pq (x, Some y) \count_space UNIV) = + (\\<^sup>+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x + + ennreal (pmf pq (None, Some y)) * indicator {None} x \count_space UNIV)" + by(rule nn_integral_cong)(auto split: split_indicator) + also have "\ = (\\<^sup>+ x. ennreal (pmf pq (x, Some y)) * indicator (range Some) x \count_space UNIV) + pmf pq (None, Some y)" + by(subst nn_integral_add)(simp_all) + also have "\ = (\\<^sup>+ x. ennreal (spmf p y) * indicator {Some y} x \count_space UNIV) + (spmf q y - spmf p y)" + by(auto simp add: pq_def pmf_embed_pmf[OF f] one_ereal_def[symmetric] simp del: nn_integral_indicator_singleton intro!: arg_cong2[where f="op +"] nn_integral_cong split: option.split) + also have "\ = spmf q y" by(simp add: ennreal_minus[symmetric] le) + finally show ?thesis using Some by simp + qed + finally show "pmf (map_pmf snd pq) i = pmf q i" by simp + qed +qed + +lemma ord_spmf_eq_leD: + assumes "ord_spmf op = p q" + shows "spmf p x \ spmf q x" +proof(cases "x \ set_spmf p") + case False + thus ?thesis by(simp add: in_set_spmf_iff_spmf) +next + case True + from assms obtain pq + where pq: "\x y. (x, y) \ set_pmf pq \ ord_option op = x y" + and p: "p = map_pmf fst pq" + and q: "q = map_pmf snd pq" by cases auto + have "ennreal (spmf p x) = integral\<^sup>N pq (indicator (fst -` {Some x}))" + using p by(simp add: ennreal_pmf_map) + also have "\ = integral\<^sup>N pq (indicator {(Some x, Some x)})" + by(rule nn_integral_cong_AE)(auto simp add: AE_measure_pmf_iff split: split_indicator dest: pq) + also have "\ \ integral\<^sup>N pq (indicator (snd -` {Some x}))" + by(rule nn_integral_mono) simp + also have "\ = ennreal (spmf q x)" using q by(simp add: ennreal_pmf_map) + finally show ?thesis by simp +qed + +lemma ord_spmf_eqD_set_spmf: "ord_spmf op = p q \ set_spmf p \ set_spmf q" +by(rule subsetI)(drule_tac x=x in ord_spmf_eq_leD, auto simp add: in_set_spmf_iff_spmf) + +lemma ord_spmf_eqD_emeasure: + "ord_spmf op = p q \ emeasure (measure_spmf p) A \ emeasure (measure_spmf q) A" +by(auto intro!: nn_integral_mono split: split_indicator dest: ord_spmf_eq_leD simp add: nn_integral_measure_spmf nn_integral_indicator[symmetric]) + +lemma ord_spmf_eqD_measure_spmf: "ord_spmf op = p q \ measure_spmf p \ measure_spmf q" +by(rule less_eq_measure.intros)(simp_all add: ord_spmf_eqD_emeasure) + + +subsection {* CCPO structure for the flat ccpo @{term "ord_option op ="} *} + +context fixes Y :: "'a spmf set" begin + +definition lub_spmf :: "'a spmf" +where "lub_spmf = embed_spmf (\x. enn2real (SUP p : Y. ennreal (spmf p x)))" + -- \We go through @{typ ennreal} to have a sensible definition even if @{term Y} is empty.\ + +lemma lub_spmf_empty [simp]: "SPMF.lub_spmf {} = return_pmf None" +by(simp add: SPMF.lub_spmf_def bot_ereal_def) + +context assumes chain: "Complete_Partial_Order.chain (ord_spmf op =) Y" begin + +lemma chain_ord_spmf_eqD: "Complete_Partial_Order.chain (op \) ((\p x. ennreal (spmf p x)) ` Y)" + (is "Complete_Partial_Order.chain _ (?f ` _)") +proof(rule chainI) + fix f g + assume "f \ ?f ` Y" "g \ ?f ` Y" + then obtain p q where f: "f = ?f p" "p \ Y" and g: "g = ?f q" "q \ Y" by blast + from chain \p \ Y\ \q \ Y\ have "ord_spmf op = p q \ ord_spmf op = q p" by(rule chainD) + thus "f \ g \ g \ f" + proof + assume "ord_spmf op = p q" + hence "\x. spmf p x \ spmf q x" by(rule ord_spmf_eq_leD) + hence "f \ g" unfolding f g by(auto intro: le_funI) + thus ?thesis .. + next + assume "ord_spmf op = q p" + hence "\x. spmf q x \ spmf p x" by(rule ord_spmf_eq_leD) + hence "g \ f" unfolding f g by(auto intro: le_funI) + thus ?thesis .. + qed +qed + +lemma ord_spmf_eq_pmf_None_eq: + assumes le: "ord_spmf op = p q" + and None: "pmf p None = pmf q None" + shows "p = q" +proof(rule spmf_eqI) + fix i + from le have le': "\x. spmf p x \ spmf q x" by(rule ord_spmf_eq_leD) + have "(\\<^sup>+ x. ennreal (spmf q x) - spmf p x \count_space UNIV) = + (\\<^sup>+ x. spmf q x \count_space UNIV) - (\\<^sup>+ x. spmf p x \count_space UNIV)" + by(subst nn_integral_diff)(simp_all add: AE_count_space le' nn_integral_spmf_neq_top) + also have "\ = (1 - pmf q None) - (1 - pmf p None)" unfolding pmf_None_eq_weight_spmf + by(simp add: weight_spmf_eq_nn_integral_spmf[symmetric] ennreal_minus) + also have "\ = 0" using None by simp + finally have "\x. spmf q x \ spmf p x" + by(simp add: nn_integral_0_iff_AE AE_count_space ennreal_minus ennreal_eq_0_iff) + with le' show "spmf p i = spmf q i" by(rule antisym) +qed + +lemma ord_spmf_eqD_pmf_None: + assumes "ord_spmf op = x y" + shows "pmf x None \ pmf y None" +using assms +apply cases +apply(clarsimp simp only: ennreal_le_iff[symmetric, OF pmf_nonneg] ennreal_pmf_map) +apply(fastforce simp add: AE_measure_pmf_iff intro!: nn_integral_mono_AE) +done + +text \ + Chains on @{typ "'a spmf"} maintain countable support. + Thanks to Johannes Hölzl for the proof idea. +\ +lemma spmf_chain_countable: "countable (\p\Y. set_spmf p)" +proof(cases "Y = {}") + case Y: False + show ?thesis + proof(cases "\x\Y. \y\Y. ord_spmf op = y x") + case True + then obtain x where x: "x \ Y" and upper: "\y. y \ Y \ ord_spmf op = y x" by blast + hence "(\x\Y. set_spmf x) \ set_spmf x" by(auto dest: ord_spmf_eqD_set_spmf) + thus ?thesis by(rule countable_subset) simp + next + case False + define N :: "'a option pmf \ real" where "N p = pmf p None" for p + + have N_less_imp_le_spmf: "\ x \ Y; y \ Y; N y < N x \ \ ord_spmf op = x y" for x y + using chainD[OF chain, of x y] ord_spmf_eqD_pmf_None[of x y] ord_spmf_eqD_pmf_None[of y x] + by (auto simp: N_def) + have N_eq_imp_eq: "\ x \ Y; y \ Y; N y = N x \ \ x = y" for x y + using chainD[OF chain, of x y] by(auto simp add: N_def dest: ord_spmf_eq_pmf_None_eq) + + have NC: "N ` Y \ {}" "bdd_below (N ` Y)" + using \Y \ {}\ by(auto intro!: bdd_belowI[of _ 0] simp: N_def) + have NC_less: "Inf (N ` Y) < N x" if "x \ Y" for x unfolding cInf_less_iff[OF NC] + proof(rule ccontr) + assume **: "\ (\y\N ` Y. y < N x)" + { fix y + assume "y \ Y" + with ** consider "N x < N y" | "N x = N y" by(auto simp add: not_less le_less) + hence "ord_spmf op = y x" using \y \ Y\ \x \ Y\ + by cases(auto dest: N_less_imp_le_spmf N_eq_imp_eq intro: ord_spmf_reflI) } + with False \x \ Y\ show False by blast + qed + + from NC have "Inf (N ` Y) \ closure (N ` Y)" by (intro closure_contains_Inf) + then obtain X' where "\n. X' n \ N ` Y" and X': "X' \ Inf (N ` Y)" + unfolding closure_sequential by auto + then obtain X where X: "\n. X n \ Y" and "X' = (\n. N (X n))" unfolding image_iff Bex_def by metis + + with X' have seq: "(\n. N (X n)) \ Inf (N ` Y)" by simp + have "(\x \ Y. set_spmf x) \ (\n. set_spmf (X n))" + proof(rule UN_least) + fix x + assume "x \ Y" + from order_tendstoD(2)[OF seq NC_less[OF \x \ Y\]] + obtain i where "N (X i) < N x" by (auto simp: eventually_sequentially) + thus "set_spmf x \ (\n. set_spmf (X n))" using X \x \ Y\ + by(blast dest: N_less_imp_le_spmf ord_spmf_eqD_set_spmf) + qed + thus ?thesis by(rule countable_subset) simp + qed +qed simp + +lemma lub_spmf_subprob: "(\\<^sup>+ x. (SUP p : Y. ennreal (spmf p x)) \count_space UNIV) \ 1" +proof(cases "Y = {}") + case True + thus ?thesis by(simp add: bot_ennreal) +next + case False + let ?B = "\p\Y. set_spmf p" + have countable: "countable ?B" by(rule spmf_chain_countable) + + have "(\\<^sup>+ x. (SUP p:Y. ennreal (spmf p x)) \count_space UNIV) = + (\\<^sup>+ x. (SUP p:Y. ennreal (spmf p x) * indicator ?B x) \count_space UNIV)" + by(intro nn_integral_cong SUP_cong)(auto split: split_indicator simp add: spmf_eq_0_set_spmf) + also have "\ = (\\<^sup>+ x. (SUP p:Y. ennreal (spmf p x)) \count_space ?B)" + unfolding ennreal_indicator[symmetric] using False + by(subst SUP_mult_right_ennreal[symmetric])(simp add: ennreal_indicator nn_integral_count_space_indicator) + also have "\ = (SUP p:Y. \\<^sup>+ x. spmf p x \count_space ?B)" using False _ countable + by(rule nn_integral_monotone_convergence_SUP_countable)(rule chain_ord_spmf_eqD) + also have "\ \ 1" + proof(rule SUP_least) + fix p + assume "p \ Y" + have "(\\<^sup>+ x. spmf p x \count_space ?B) = \\<^sup>+ x. ennreal (spmf p x) * indicator ?B x \count_space UNIV" + by(simp add: nn_integral_count_space_indicator) + also have "\ = \\<^sup>+ x. spmf p x \count_space UNIV" + by(rule nn_integral_cong)(auto split: split_indicator simp add: spmf_eq_0_set_spmf \p \ Y\) + also have "\ \ 1" + by(simp add: weight_spmf_eq_nn_integral_spmf[symmetric] weight_spmf_le_1) + finally show "(\\<^sup>+ x. spmf p x \count_space ?B) \ 1" . + qed + finally show ?thesis . +qed + +lemma spmf_lub_spmf: + assumes "Y \ {}" + shows "spmf lub_spmf x = (SUP p : Y. spmf p x)" +proof - + from assms obtain p where "p \ Y" by auto + have "spmf lub_spmf x = max 0 (enn2real (SUP p:Y. ennreal (spmf p x)))" unfolding lub_spmf_def + by(rule spmf_embed_spmf)(simp del: SUP_eq_top_iff Sup_eq_top_iff add: ennreal_enn2real_if SUP_spmf_neq_top' lub_spmf_subprob) + also have "\ = enn2real (SUP p:Y. ennreal (spmf p x))" + by(rule max_absorb2)(simp) + also have "\ = enn2real (ennreal (SUP p : Y. spmf p x))" using assms + by(subst ennreal_SUP[symmetric])(simp_all add: SUP_spmf_neq_top' del: SUP_eq_top_iff Sup_eq_top_iff) + also have "0 \ (\p\Y. spmf p x)" using assms + by(auto intro!: cSUP_upper2 bdd_aboveI[where M=1] simp add: pmf_le_1) + then have "enn2real (ennreal (SUP p : Y. spmf p x)) = (SUP p : Y. spmf p x)" + by(rule enn2real_ennreal) + finally show ?thesis . +qed + +lemma ennreal_spmf_lub_spmf: "Y \ {} \ ennreal (spmf lub_spmf x) = (SUP p:Y. ennreal (spmf p x))" +unfolding spmf_lub_spmf by(subst ennreal_SUP)(simp_all add: SUP_spmf_neq_top' del: SUP_eq_top_iff Sup_eq_top_iff) + +lemma lub_spmf_upper: + assumes p: "p \ Y" + shows "ord_spmf op = p lub_spmf" +proof(rule ord_pmf_increaseI) + fix x + from p have [simp]: "Y \ {}" by auto + from p have "ennreal (spmf p x) \ (SUP p:Y. ennreal (spmf p x))" by(rule SUP_upper) + also have "\ = ennreal (spmf lub_spmf x)" using p + by(subst spmf_lub_spmf)(auto simp add: ennreal_SUP SUP_spmf_neq_top' simp del: SUP_eq_top_iff Sup_eq_top_iff) + finally show "spmf p x \ spmf lub_spmf x" by simp +qed simp + +lemma lub_spmf_least: + assumes z: "\x. x \ Y \ ord_spmf op = x z" + shows "ord_spmf op = lub_spmf z" +proof(cases "Y = {}") + case nonempty: False + show ?thesis + proof(rule ord_pmf_increaseI) + fix x + from nonempty obtain p where p: "p \ Y" by auto + have "ennreal (spmf lub_spmf x) = (SUP p:Y. ennreal (spmf p x))" + by(subst spmf_lub_spmf)(auto simp add: ennreal_SUP SUP_spmf_neq_top' nonempty simp del: SUP_eq_top_iff Sup_eq_top_iff) + also have "\ \ ennreal (spmf z x)" by(rule SUP_least)(simp add: ord_spmf_eq_leD z) + finally show "spmf lub_spmf x \ spmf z x" by simp + qed simp +qed simp + +lemma set_lub_spmf: "set_spmf lub_spmf = (\p\Y. set_spmf p)" (is "?lhs = ?rhs") +proof(cases "Y = {}") + case [simp]: False + show ?thesis + proof(rule set_eqI) + fix x + have "x \ ?lhs \ ennreal (spmf lub_spmf x) > 0" + by(simp_all add: in_set_spmf_iff_spmf less_le) + also have "\ \ (\p\Y. ennreal (spmf p x) > 0)" + by(simp add: ennreal_spmf_lub_spmf less_SUP_iff) + also have "\ \ x \ ?rhs" + by(auto simp add: in_set_spmf_iff_spmf less_le) + finally show "x \ ?lhs \ x \ ?rhs" . + qed +qed simp + +lemma emeasure_lub_spmf: + assumes Y: "Y \ {}" + shows "emeasure (measure_spmf lub_spmf) A = (SUP y:Y. emeasure (measure_spmf y) A)" + (is "?lhs = ?rhs") +proof - + let ?M = "count_space (set_spmf lub_spmf)" + have "?lhs = \\<^sup>+ x. ennreal (spmf lub_spmf x) * indicator A x \?M" + by(auto simp add: nn_integral_indicator[symmetric] nn_integral_measure_spmf') + also have "\ = \\<^sup>+ x. (SUP y:Y. ennreal (spmf y x) * indicator A x) \?M" + unfolding ennreal_indicator[symmetric] + by(simp add: spmf_lub_spmf assms ennreal_SUP[OF SUP_spmf_neq_top'] SUP_mult_right_ennreal) + also from assms have "\ = (SUP y:Y. \\<^sup>+ x. ennreal (spmf y x) * indicator A x \?M)" + proof(rule nn_integral_monotone_convergence_SUP_countable) + have "(\i x. ennreal (spmf i x) * indicator A x) ` Y = (\f x. f x * indicator A x) ` (\p x. ennreal (spmf p x)) ` Y" + by(simp add: image_image) + also have "Complete_Partial_Order.chain op \ \" using chain_ord_spmf_eqD + by(rule chain_imageI)(auto simp add: le_fun_def split: split_indicator) + finally show "Complete_Partial_Order.chain op \ ((\i x. ennreal (spmf i x) * indicator A x) ` Y)" . + qed simp + also have "\ = (SUP y:Y. \\<^sup>+ x. ennreal (spmf y x) * indicator A x \count_space UNIV)" + by(auto simp add: nn_integral_count_space_indicator set_lub_spmf spmf_eq_0_set_spmf split: split_indicator intro!: SUP_cong nn_integral_cong) + also have "\ = ?rhs" + by(auto simp add: nn_integral_indicator[symmetric] nn_integral_measure_spmf) + finally show ?thesis . +qed + +lemma measure_lub_spmf: + assumes Y: "Y \ {}" + shows "measure (measure_spmf lub_spmf) A = (SUP y:Y. measure (measure_spmf y) A)" (is "?lhs = ?rhs") +proof - + have "ennreal ?lhs = ennreal ?rhs" + using emeasure_lub_spmf[OF assms] SUP_emeasure_spmf_neq_top[of A Y] Y + unfolding measure_spmf.emeasure_eq_measure by(subst ennreal_SUP) + moreover have "0 \ ?rhs" using Y + by(auto intro!: cSUP_upper2 bdd_aboveI[where M=1] measure_spmf.subprob_measure_le_1) + ultimately show ?thesis by(simp) +qed + +lemma weight_lub_spmf: + assumes Y: "Y \ {}" + shows "weight_spmf lub_spmf = (SUP y:Y. weight_spmf y)" +unfolding weight_spmf_def by(rule measure_lub_spmf) fact + +lemma measure_spmf_lub_spmf: + assumes Y: "Y \ {}" + shows "measure_spmf lub_spmf = (SUP p:Y. measure_spmf p)" (is "?lhs = ?rhs") +proof(rule measure_eqI) + from assms obtain p where p: "p \ Y" by auto + from chain have chain': "Complete_Partial_Order.chain op \ (measure_spmf ` Y)" + by(rule chain_imageI)(rule ord_spmf_eqD_measure_spmf) + show "sets ?lhs = sets ?rhs" and "emeasure ?lhs A = emeasure ?rhs A" for A + using chain' Y p by(simp_all add: sets_SUP emeasure_SUP emeasure_lub_spmf) +qed + +end + +end + +lemma partial_function_definitions_spmf: "partial_function_definitions (ord_spmf op =) lub_spmf" + (is "partial_function_definitions ?R _") +proof + fix x show "?R x x" by(simp add: ord_spmf_reflI) +next + fix x y z + assume "?R x y" "?R y z" + with transp_ord_option[OF transp_equality] show "?R x z" by(rule transp_rel_pmf[THEN transpD]) +next + fix x y + assume "?R x y" "?R y x" + thus "x = y" + by(rule rel_pmf_antisym)(simp_all add: reflp_ord_option transp_ord_option antisymP_ord_option) +next + fix Y x + assume "Complete_Partial_Order.chain ?R Y" "x \ Y" + then show "?R x (lub_spmf Y)" + by(rule lub_spmf_upper) +next + fix Y z + assume "Complete_Partial_Order.chain ?R Y" "\x. x \ Y \ ?R x z" + then show "?R (lub_spmf Y) z" + by(cases "Y = {}")(simp_all add: lub_spmf_least) +qed + +lemma ccpo_spmf: "class.ccpo lub_spmf (ord_spmf op =) (mk_less (ord_spmf op =))" +by(rule ccpo partial_function_definitions_spmf)+ + +interpretation spmf: partial_function_definitions "ord_spmf op =" "lub_spmf" + rewrites "lub_spmf {} \ return_pmf None" +by(rule partial_function_definitions_spmf) simp + +declaration {* Partial_Function.init "spmf" @{term spmf.fixp_fun} + @{term spmf.mono_body} @{thm spmf.fixp_rule_uc} @{thm spmf.fixp_induct_uc} + NONE *} + +declare spmf.leq_refl[simp] +declare admissible_leI[OF ccpo_spmf, cont_intro] + +abbreviation "mono_spmf \ monotone (fun_ord (ord_spmf op =)) (ord_spmf op =)" + +lemma lub_spmf_const [simp]: "lub_spmf {p} = p" +by(rule spmf_eqI)(simp add: spmf_lub_spmf[OF ccpo.chain_singleton[OF ccpo_spmf]]) + +lemma bind_spmf_mono': + assumes fg: "ord_spmf op = f g" + and hk: "\x :: 'a. ord_spmf op = (h x) (k x)" + shows "ord_spmf op = (f \ h) (g \ k)" +unfolding bind_spmf_def using assms(1) +by(rule rel_pmf_bindI)(auto split: option.split simp add: hk) + +lemma bind_spmf_mono [partial_function_mono]: + assumes mf: "mono_spmf B" and mg: "\y. mono_spmf (\f. C y f)" + shows "mono_spmf (\f. bind_spmf (B f) (\y. C y f))" +proof (rule monotoneI) + fix f g :: "'a \ 'b spmf" + assume fg: "fun_ord (ord_spmf op =) f g" + with mf have "ord_spmf op = (B f) (B g)" by (rule monotoneD[of _ _ _ f g]) + moreover from mg have "\y'. ord_spmf op = (C y' f) (C y' g)" + by (rule monotoneD) (rule fg) + ultimately show "ord_spmf op = (bind_spmf (B f) (\y. C y f)) (bind_spmf (B g) (\y'. C y' g))" + by(rule bind_spmf_mono') +qed + +lemma monotone_bind_spmf1: "monotone (ord_spmf op =) (ord_spmf op =) (\y. bind_spmf y g)" +by(rule monotoneI)(simp add: bind_spmf_mono' ord_spmf_reflI) + +lemma monotone_bind_spmf2: + assumes g: "\x. monotone ord (ord_spmf op =) (\y. g y x)" + shows "monotone ord (ord_spmf op =) (\y. bind_spmf p (g y))" +by(rule monotoneI)(auto intro: bind_spmf_mono' monotoneD[OF g] ord_spmf_reflI) + +lemma bind_lub_spmf: + assumes chain: "Complete_Partial_Order.chain (ord_spmf op =) Y" + shows "bind_spmf (lub_spmf Y) f = lub_spmf ((\p. bind_spmf p f) ` Y)" (is "?lhs = ?rhs") +proof(cases "Y = {}") + case Y: False + show ?thesis + proof(rule spmf_eqI) + fix i + have chain': "Complete_Partial_Order.chain op \ ((\p x. ennreal (spmf p x * spmf (f x) i)) ` Y)" + using chain by(rule chain_imageI)(auto simp add: le_fun_def dest: ord_spmf_eq_leD intro: mult_right_mono) + have chain'': "Complete_Partial_Order.chain (ord_spmf op =) ((\p. p \ f) ` Y)" + using chain by(rule chain_imageI)(auto intro!: monotoneI bind_spmf_mono' ord_spmf_reflI) + let ?M = "count_space (set_spmf (lub_spmf Y))" + have "ennreal (spmf ?lhs i) = \\<^sup>+ x. ennreal (spmf (lub_spmf Y) x) * ennreal (spmf (f x) i) \?M" + by(auto simp add: ennreal_spmf_lub_spmf ennreal_spmf_bind nn_integral_measure_spmf') + also have "\ = \\<^sup>+ x. (SUP p:Y. ennreal (spmf p x * spmf (f x) i)) \?M" + by(subst ennreal_spmf_lub_spmf[OF chain Y])(subst SUP_mult_right_ennreal, simp_all add: ennreal_mult Y) + also have "\ = (SUP p:Y. \\<^sup>+ x. ennreal (spmf p x * spmf (f x) i) \?M)" + using Y chain' by(rule nn_integral_monotone_convergence_SUP_countable) simp + also have "\ = (SUP p:Y. ennreal (spmf (bind_spmf p f) i))" + by(auto simp add: ennreal_spmf_bind nn_integral_measure_spmf nn_integral_count_space_indicator set_lub_spmf[OF chain] in_set_spmf_iff_spmf ennreal_mult intro!: SUP_cong nn_integral_cong split: split_indicator) + also have "\ = ennreal (spmf ?rhs i)" using chain'' by(simp add: ennreal_spmf_lub_spmf Y) + finally show "spmf ?lhs i = spmf ?rhs i" by simp + qed +qed simp + +lemma map_lub_spmf: + "Complete_Partial_Order.chain (ord_spmf op =) Y + \ map_spmf f (lub_spmf Y) = lub_spmf (map_spmf f ` Y)" +unfolding map_spmf_conv_bind_spmf[abs_def] by(simp add: bind_lub_spmf o_def) + +lemma mcont_bind_spmf1: "mcont lub_spmf (ord_spmf op =) lub_spmf (ord_spmf op =) (\y. bind_spmf y f)" +using monotone_bind_spmf1 by(rule mcontI)(rule contI, simp add: bind_lub_spmf) + +lemma bind_lub_spmf2: + assumes chain: "Complete_Partial_Order.chain ord Y" + and g: "\y. monotone ord (ord_spmf op =) (g y)" + shows "bind_spmf x (\y. lub_spmf (g y ` Y)) = lub_spmf ((\p. bind_spmf x (\y. g y p)) ` Y)" + (is "?lhs = ?rhs") +proof(cases "Y = {}") + case Y: False + show ?thesis + proof(rule spmf_eqI) + fix i + have chain': "\y. Complete_Partial_Order.chain (ord_spmf op =) (g y ` Y)" + using chain g[THEN monotoneD] by(rule chain_imageI) + have chain'': "Complete_Partial_Order.chain op \ ((\p y. ennreal (spmf x y * spmf (g y p) i)) ` Y)" + using chain by(rule chain_imageI)(auto simp add: le_fun_def dest: ord_spmf_eq_leD monotoneD[OF g] intro!: mult_left_mono) + have chain''': "Complete_Partial_Order.chain (ord_spmf op =) ((\p. bind_spmf x (\y. g y p)) ` Y)" + using chain by(rule chain_imageI)(rule monotone_bind_spmf2[OF g, THEN monotoneD]) + + have "ennreal (spmf ?lhs i) = \\<^sup>+ y. (SUP p:Y. ennreal (spmf x y * spmf (g y p) i)) \count_space (set_spmf x)" + by(simp add: ennreal_spmf_bind ennreal_spmf_lub_spmf[OF chain'] Y nn_integral_measure_spmf' SUP_mult_left_ennreal ennreal_mult) + also have "\ = (SUP p:Y. \\<^sup>+ y. ennreal (spmf x y * spmf (g y p) i) \count_space (set_spmf x))" + unfolding nn_integral_measure_spmf' using Y chain'' + by(rule nn_integral_monotone_convergence_SUP_countable) simp + also have "\ = (SUP p:Y. ennreal (spmf (bind_spmf x (\y. g y p)) i))" + by(simp add: ennreal_spmf_bind nn_integral_measure_spmf' ennreal_mult) + also have "\ = ennreal (spmf ?rhs i)" using chain''' + by(auto simp add: ennreal_spmf_lub_spmf Y) + finally show "spmf ?lhs i = spmf ?rhs i" by simp + qed +qed simp + +lemma mcont_bind_spmf [cont_intro]: + assumes f: "mcont luba orda lub_spmf (ord_spmf op =) f" + and g: "\y. mcont luba orda lub_spmf (ord_spmf op =) (g y)" + shows "mcont luba orda lub_spmf (ord_spmf op =) (\x. bind_spmf (f x) (\y. g y x))" +proof(rule spmf.mcont2mcont'[OF _ _ f]) + fix z + show "mcont lub_spmf (ord_spmf op =) lub_spmf (ord_spmf op =) (\x. bind_spmf x (\y. g y z))" + by(rule mcont_bind_spmf1) +next + fix x + let ?f = "\z. bind_spmf x (\y. g y z)" + have "monotone orda (ord_spmf op =) ?f" using mcont_mono[OF g] by(rule monotone_bind_spmf2) + moreover have "cont luba orda lub_spmf (ord_spmf op =) ?f" + proof(rule contI) + fix Y + assume chain: "Complete_Partial_Order.chain orda Y" and Y: "Y \ {}" + have "bind_spmf x (\y. g y (luba Y)) = bind_spmf x (\y. lub_spmf (g y ` Y))" + by(rule bind_spmf_cong)(simp_all add: mcont_contD[OF g chain Y]) + also have "\ = lub_spmf ((\p. x \ (\y. g y p)) ` Y)" using chain + by(rule bind_lub_spmf2)(rule mcont_mono[OF g]) + finally show "bind_spmf x (\y. g y (luba Y)) = \" . + qed + ultimately show "mcont luba orda lub_spmf (ord_spmf op =) ?f" by(rule mcontI) +qed + +lemma bind_pmf_mono [partial_function_mono]: + "(\y. mono_spmf (\f. C y f)) \ mono_spmf (\f. bind_pmf p (\x. C x f))" +using bind_spmf_mono[of "\_. spmf_of_pmf p" C] by simp + +lemma map_spmf_mono [partial_function_mono]: "mono_spmf B \ mono_spmf (\g. map_spmf f (B g))" +unfolding map_spmf_conv_bind_spmf by(rule bind_spmf_mono) simp_all + +lemma mcont_map_spmf [cont_intro]: + "mcont luba orda lub_spmf (ord_spmf op =) g + \ mcont luba orda lub_spmf (ord_spmf op =) (\x. map_spmf f (g x))" +unfolding map_spmf_conv_bind_spmf by(rule mcont_bind_spmf) simp_all + +lemma monotone_set_spmf: "monotone (ord_spmf op =) op \ set_spmf" +by(rule monotoneI)(rule ord_spmf_eqD_set_spmf) + +lemma cont_set_spmf: "cont lub_spmf (ord_spmf op =) Union op \ set_spmf" +by(rule contI)(subst set_lub_spmf; simp) + +lemma mcont2mcont_set_spmf[THEN mcont2mcont, cont_intro]: + shows mcont_set_spmf: "mcont lub_spmf (ord_spmf op =) Union op \ set_spmf" +by(rule mcontI monotone_set_spmf cont_set_spmf)+ + +lemma monotone_spmf: "monotone (ord_spmf op =) op \ (\p. spmf p x)" +by(rule monotoneI)(simp add: ord_spmf_eq_leD) + +lemma cont_spmf: "cont lub_spmf (ord_spmf op =) Sup op \ (\p. spmf p x)" +by(rule contI)(simp add: spmf_lub_spmf) + +lemma mcont_spmf: "mcont lub_spmf (ord_spmf op =) Sup op \ (\p. spmf p x)" +by(rule mcontI monotone_spmf cont_spmf)+ + +lemma cont_ennreal_spmf: "cont lub_spmf (ord_spmf op =) Sup op \ (\p. ennreal (spmf p x))" +by(rule contI)(simp add: ennreal_spmf_lub_spmf) + +lemma mcont2mcont_ennreal_spmf [THEN mcont2mcont, cont_intro]: + shows mcont_ennreal_spmf: "mcont lub_spmf (ord_spmf op =) Sup op \ (\p. ennreal (spmf p x))" +by(rule mcontI mono2mono_ennreal monotone_spmf cont_ennreal_spmf)+ + +lemma nn_integral_map_spmf [simp]: "nn_integral (measure_spmf (map_spmf f p)) g = nn_integral (measure_spmf p) (g \ f)" +by(auto 4 3 simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space intro: nn_integral_cong split: split_indicator) + +subsubsection \Admissibility of @{term rel_spmf}\ + +lemma rel_spmf_measureD: + assumes "rel_spmf R p q" + shows "measure (measure_spmf p) A \ measure (measure_spmf q) {y. \x\A. R x y}" (is "?lhs \ ?rhs") +proof - + have "?lhs = measure (measure_pmf p) (Some ` A)" by(simp add: measure_measure_spmf_conv_measure_pmf) + also have "\ \ measure (measure_pmf q) {y. \x\Some ` A. rel_option R x y}" + using assms by(rule rel_pmf_measureD) + also have "\ = ?rhs" unfolding measure_measure_spmf_conv_measure_pmf + by(rule arg_cong2[where f=measure])(auto simp add: option_rel_Some1) + finally show ?thesis . +qed + +locale rel_spmf_characterisation = + assumes rel_pmf_measureI: + "\(R :: 'a option \ 'b option \ bool) p q. + (\A. measure (measure_pmf p) A \ measure (measure_pmf q) {y. \x\A. R x y}) + \ rel_pmf R p q" + -- \This assumption is shown to hold in general in the AFP entry @{text "MFMC_Countable"}.\ +begin + +context fixes R :: "'a \ 'b \ bool" begin + +lemma rel_spmf_measureI: + assumes eq1: "\A. measure (measure_spmf p) A \ measure (measure_spmf q) {y. \x\A. R x y}" + assumes eq2: "weight_spmf q \ weight_spmf p" + shows "rel_spmf R p q" +proof(rule rel_pmf_measureI) + fix A :: "'a option set" + define A' where "A' = the ` (A \ range Some)" + define A'' where "A'' = A \ {None}" + have A: "A = Some ` A' \ A''" "Some ` A' \ A'' = {}" + unfolding A'_def A''_def by(auto 4 3 intro: rev_image_eqI) + have "measure (measure_pmf p) A = measure (measure_pmf p) (Some ` A') + measure (measure_pmf p) A''" + by(simp add: A measure_pmf.finite_measure_Union) + also have "measure (measure_pmf p) (Some ` A') = measure (measure_spmf p) A'" + by(simp add: measure_measure_spmf_conv_measure_pmf) + also have "\ \ measure (measure_spmf q) {y. \x\A'. R x y}" by(rule eq1) + also (ord_eq_le_trans[OF _ add_right_mono]) + have "\ = measure (measure_pmf q) {y. \x\A'. rel_option R (Some x) y}" + unfolding measure_measure_spmf_conv_measure_pmf + by(rule arg_cong2[where f=measure])(auto simp add: A'_def option_rel_Some1) + also + { have "weight_spmf p \ measure (measure_spmf q) {y. \x. R x y}" + using eq1[of UNIV] unfolding weight_spmf_def by simp + also have "\ \ weight_spmf q" unfolding weight_spmf_def + by(rule measure_spmf.finite_measure_mono) simp_all + finally have "weight_spmf p = weight_spmf q" using eq2 by simp } + then have "measure (measure_pmf p) A'' = measure (measure_pmf q) (if None \ A then {None} else {})" + unfolding A''_def by(simp add: pmf_None_eq_weight_spmf measure_pmf_single) + also have "measure (measure_pmf q) {y. \x\A'. rel_option R (Some x) y} + \ = measure (measure_pmf q) {y. \x\A. rel_option R x y}" + by(subst measure_pmf.finite_measure_Union[symmetric]) + (auto 4 3 intro!: arg_cong2[where f=measure] simp add: option_rel_Some1 option_rel_Some2 A'_def intro: rev_bexI elim: option.rel_cases) + finally show "measure (measure_pmf p) A \ \" . +qed + +lemma admissible_rel_spmf: + "ccpo.admissible (prod_lub lub_spmf lub_spmf) (rel_prod (ord_spmf op =) (ord_spmf op =)) (case_prod (rel_spmf R))" + (is "ccpo.admissible ?lub ?ord ?P") +proof(rule ccpo.admissibleI) + fix Y + assume chain: "Complete_Partial_Order.chain ?ord Y" + and Y: "Y \ {}" + and R: "\(p, q) \ Y. rel_spmf R p q" + from R have R: "\p q. (p, q) \ Y \ rel_spmf R p q" by auto + have chain1: "Complete_Partial_Order.chain (ord_spmf op =) (fst ` Y)" + and chain2: "Complete_Partial_Order.chain (ord_spmf op =) (snd ` Y)" + using chain by(rule chain_imageI; clarsimp)+ + from Y have Y1: "fst ` Y \ {}" and Y2: "snd ` Y \ {}" by auto + + have "rel_spmf R (lub_spmf (fst ` Y)) (lub_spmf (snd ` Y))" + proof(rule rel_spmf_measureI) + show "weight_spmf (lub_spmf (snd ` Y)) \ weight_spmf (lub_spmf (fst ` Y))" + by(auto simp add: weight_lub_spmf chain1 chain2 Y rel_spmf_weightD[OF R, symmetric] intro!: cSUP_least intro: cSUP_upper2[OF bdd_aboveI2[OF weight_spmf_le_1]]) + + fix A + have "measure (measure_spmf (lub_spmf (fst ` Y))) A = (SUP y:fst ` Y. measure (measure_spmf y) A)" + using chain1 Y1 by(rule measure_lub_spmf) + also have "\ \ (SUP y:snd ` Y. measure (measure_spmf y) {y. \x\A. R x y})" using Y1 + by(rule cSUP_least)(auto intro!: cSUP_upper2[OF bdd_aboveI2[OF measure_spmf.subprob_measure_le_1]] rel_spmf_measureD R) + also have "\ = measure (measure_spmf (lub_spmf (snd ` Y))) {y. \x\A. R x y}" + using chain2 Y2 by(rule measure_lub_spmf[symmetric]) + finally show "measure (measure_spmf (lub_spmf (fst ` Y))) A \ \" . + qed + then show "?P (?lub Y)" by(simp add: prod_lub_def) +qed + +lemma admissible_rel_spmf_mcont [cont_intro]: + "\ mcont lub ord lub_spmf (ord_spmf op =) f; mcont lub ord lub_spmf (ord_spmf op =) g \ + \ ccpo.admissible lub ord (\x. rel_spmf R (f x) (g x))" +by(rule admissible_subst[OF admissible_rel_spmf, where f="\x. (f x, g x)", simplified])(rule mcont_Pair) + +context begin interpretation lifting_syntax . + +lemma fixp_spmf_parametric': + assumes f: "\x. monotone (ord_spmf op =) (ord_spmf op =) F" + and g: "\x. monotone (ord_spmf op =) (ord_spmf op =) G" + and param: "(rel_spmf R ===> rel_spmf R) F G" + shows "(rel_spmf R) (ccpo.fixp lub_spmf (ord_spmf op =) F) (ccpo.fixp lub_spmf (ord_spmf op =) G)" +by(rule parallel_fixp_induct[OF ccpo_spmf ccpo_spmf _ f g])(auto intro: param[THEN rel_funD]) + +lemma fixp_spmf_parametric: + assumes f: "\x. mono_spmf (\f. F f x)" + and g: "\x. mono_spmf (\f. G f x)" + and param: "((A ===> rel_spmf R) ===> A ===> rel_spmf R) F G" + shows "(A ===> rel_spmf R) (spmf.fixp_fun F) (spmf.fixp_fun G)" +using f g +proof(rule parallel_fixp_induct_1_1[OF partial_function_definitions_spmf partial_function_definitions_spmf _ _ reflexive reflexive, where P="(A ===> rel_spmf R)"]) + show "ccpo.admissible (prod_lub (fun_lub lub_spmf) (fun_lub lub_spmf)) (rel_prod (fun_ord (ord_spmf op =)) (fun_ord (ord_spmf op =))) (\x. (A ===> rel_spmf R) (fst x) (snd x))" + unfolding rel_fun_def + apply(rule admissible_all admissible_imp admissible_rel_spmf_mcont)+ + apply(rule spmf.mcont2mcont[OF mcont_call]) + apply(rule mcont_fst) + apply(rule spmf.mcont2mcont[OF mcont_call]) + apply(rule mcont_snd) + done + show "(A ===> rel_spmf R) (\_. lub_spmf {}) (\_. lub_spmf {})" by auto + show "(A ===> rel_spmf R) (F f) (G g)" if "(A ===> rel_spmf R) f g" for f g + using that by(rule rel_funD[OF param]) +qed + +end + +end + +end + +subsection {* Restrictions on spmfs *} + +definition restrict_spmf :: "'a spmf \ 'a set \ 'a spmf" (infixl "\" 110) +where "p \ A = map_pmf (\x. x \ (\y. if y \ A then Some y else None)) p" + +lemma set_restrict_spmf [simp]: "set_spmf (p \ A) = set_spmf p \ A" +by(fastforce simp add: restrict_spmf_def set_spmf_def split: bind_splits if_split_asm) + +lemma restrict_map_spmf: "map_spmf f p \ A = map_spmf f (p \ (f -` A))" +by(simp add: restrict_spmf_def pmf.map_comp o_def map_option_bind bind_map_option if_distrib cong del: if_weak_cong) + +lemma restrict_restrict_spmf [simp]: "p \ A \ B = p \ (A \ B)" +by(auto simp add: restrict_spmf_def pmf.map_comp o_def intro!: pmf.map_cong bind_option_cong) + +lemma restrict_spmf_empty [simp]: "p \ {} = return_pmf None" +by(simp add: restrict_spmf_def) + +lemma restrict_spmf_UNIV [simp]: "p \ UNIV = p" +by(simp add: restrict_spmf_def) + +lemma spmf_restrict_spmf_outside [simp]: "x \ A \ spmf (p \ A) x = 0" +by(simp add: spmf_eq_0_set_spmf) + +lemma emeasure_restrict_spmf [simp]: + "emeasure (measure_spmf (p \ A)) X = emeasure (measure_spmf p) (X \ A)" +by(auto simp add: restrict_spmf_def measure_spmf_def emeasure_distr measurable_restrict_space1 emeasure_restrict_space space_restrict_space intro: arg_cong2[where f=emeasure] split: bind_splits if_split_asm) + +lemma measure_restrict_spmf [simp]: + "measure (measure_spmf (p \ A)) X = measure (measure_spmf p) (X \ A)" +using emeasure_restrict_spmf[of p A X] +by(simp only: measure_spmf.emeasure_eq_measure ennreal_inj measure_nonneg) + +lemma spmf_restrict_spmf: "spmf (p \ A) x = (if x \ A then spmf p x else 0)" +by(simp add: spmf_conv_measure_spmf) + +lemma spmf_restrict_spmf_inside [simp]: "x \ A \ spmf (p \ A) x = spmf p x" +by(simp add: spmf_restrict_spmf) + +lemma pmf_restrict_spmf_None: "pmf (p \ A) None = pmf p None + measure (measure_spmf p) (- A)" +proof - + have [simp]: "None \ Some ` (- A)" by auto + have "(\x. x \ (\y. if y \ A then Some y else None)) -` {None} = {None} \ (Some ` (- A))" + by(auto split: bind_splits if_split_asm) + then show ?thesis unfolding ereal.inject[symmetric] + by(simp add: restrict_spmf_def ennreal_pmf_map emeasure_pmf_single del: ereal.inject) + (simp add: pmf.rep_eq measure_pmf.finite_measure_Union[symmetric] measure_measure_spmf_conv_measure_pmf measure_pmf.emeasure_eq_measure) +qed + +lemma restrict_spmf_trivial: "(\x. x \ set_spmf p \ x \ A) \ p \ A = p" +by(rule spmf_eqI)(auto simp add: spmf_restrict_spmf spmf_eq_0_set_spmf) + +lemma restrict_spmf_trivial': "set_spmf p \ A \ p \ A = p" +by(rule restrict_spmf_trivial) blast + +lemma restrict_return_spmf: "return_spmf x \ A = (if x \ A then return_spmf x else return_pmf None)" +by(simp add: restrict_spmf_def) + +lemma restrict_return_spmf_inside [simp]: "x \ A \ return_spmf x \ A = return_spmf x" +by(simp add: restrict_return_spmf) + +lemma restrict_return_spmf_outside [simp]: "x \ A \ return_spmf x \ A = return_pmf None" +by(simp add: restrict_return_spmf) + +lemma restrict_spmf_return_pmf_None [simp]: "return_pmf None \ A = return_pmf None" +by(simp add: restrict_spmf_def) + +lemma restrict_bind_pmf: "bind_pmf p g \ A = p \ (\x. g x \ A)" +by(simp add: restrict_spmf_def map_bind_pmf o_def) + +lemma restrict_bind_spmf: "bind_spmf p g \ A = p \ (\x. g x \ A)" +by(auto simp add: bind_spmf_def restrict_bind_pmf cong del: option.case_cong_weak cong: option.case_cong intro!: bind_pmf_cong split: option.split) + +lemma bind_restrict_pmf: "bind_pmf (p \ A) g = p \ (\x. if x \ Some ` A then g x else g None)" +by(auto simp add: restrict_spmf_def bind_map_pmf fun_eq_iff split: bind_split intro: arg_cong2[where f=bind_pmf]) + +lemma bind_restrict_spmf: "bind_spmf (p \ A) g = p \ (\x. if x \ A then g x else return_pmf None)" +by(auto simp add: bind_spmf_def bind_restrict_pmf fun_eq_iff intro: arg_cong2[where f=bind_pmf] split: option.split) + +lemma spmf_map_restrict: "spmf (map_spmf fst (p \ (snd -` {y}))) x = spmf p (x, y)" +by(subst spmf_map)(auto intro: arg_cong2[where f=measure] simp add: spmf_conv_measure_spmf) + +lemma measure_eqI_restrict_spmf: + assumes "rel_spmf R (restrict_spmf p A) (restrict_spmf q B)" + shows "measure (measure_spmf p) A = measure (measure_spmf q) B" +proof - + from assms have "weight_spmf (restrict_spmf p A) = weight_spmf (restrict_spmf q B)" by(rule rel_spmf_weightD) + thus ?thesis by(simp add: weight_spmf_def) +qed + +subsection {* Subprobability distributions of sets *} + +definition spmf_of_set :: "'a set \ 'a spmf" +where + "spmf_of_set A = (if finite A \ A \ {} then spmf_of_pmf (pmf_of_set A) else return_pmf None)" + +lemma spmf_of_set: "spmf (spmf_of_set A) x = indicator A x / card A" +by(auto simp add: spmf_of_set_def) + +lemma pmf_spmf_of_set_None [simp]: "pmf (spmf_of_set A) None = indicator {A. infinite A \ A = {}} A" +by(simp add: spmf_of_set_def) + +lemma set_spmf_of_set: "set_spmf (spmf_of_set A) = (if finite A then A else {})" +by(simp add: spmf_of_set_def) + +lemma set_spmf_of_set_finite [simp]: "finite A \ set_spmf (spmf_of_set A) = A" +by(simp add: set_spmf_of_set) + +lemma spmf_of_set_singleton: "spmf_of_set {x} = return_spmf x" +by(simp add: spmf_of_set_def pmf_of_set_singleton) + +lemma map_spmf_of_set_inj_on [simp]: + "inj_on f A \ map_spmf f (spmf_of_set A) = spmf_of_set (f ` A)" +by(auto simp add: spmf_of_set_def map_pmf_of_set_inj dest: finite_imageD) + +lemma spmf_of_pmf_pmf_of_set [simp]: + "\ finite A; A \ {} \ \ spmf_of_pmf (pmf_of_set A) = spmf_of_set A" +by(simp add: spmf_of_set_def) + +lemma weight_spmf_of_set: + "weight_spmf (spmf_of_set A) = (if finite A \ A \ {} then 1 else 0)" +by(auto simp only: spmf_of_set_def weight_spmf_of_pmf weight_return_pmf_None split: if_split) + +lemma weight_spmf_of_set_finite [simp]: "\ finite A; A \ {} \ \ weight_spmf (spmf_of_set A) = 1" +by(simp add: weight_spmf_of_set) + +lemma weight_spmf_of_set_infinite [simp]: "infinite A \ weight_spmf (spmf_of_set A) = 0" +by(simp add: weight_spmf_of_set) + +lemma measure_spmf_spmf_of_set: + "measure_spmf (spmf_of_set A) = (if finite A \ A \ {} then measure_pmf (pmf_of_set A) else null_measure (count_space UNIV))" +by(simp add: spmf_of_set_def del: spmf_of_pmf_pmf_of_set) + +lemma emeasure_spmf_of_set: + "emeasure (measure_spmf (spmf_of_set S)) A = card (S \ A) / card S" +by(auto simp add: measure_spmf_spmf_of_set emeasure_pmf_of_set) + +lemma measure_spmf_of_set: + "measure (measure_spmf (spmf_of_set S)) A = card (S \ A) / card S" +by(auto simp add: measure_spmf_spmf_of_set measure_pmf_of_set) + +lemma nn_integral_spmf_of_set: "nn_integral (measure_spmf (spmf_of_set A)) f = setsum f A / card A" +by(cases "finite A")(auto simp add: spmf_of_set_def nn_integral_pmf_of_set card_gt_0_iff simp del: spmf_of_pmf_pmf_of_set) + +lemma integral_spmf_of_set: "integral\<^sup>L (measure_spmf (spmf_of_set A)) f = setsum f A / card A" +by(clarsimp simp add: spmf_of_set_def integral_pmf_of_set card_gt_0_iff simp del: spmf_of_pmf_pmf_of_set) + +notepad begin -- \@{const pmf_of_set} is not fully parametric.\ + define R :: "nat \ nat \ bool" where "R x y \ (x \ 0 \ y = 0)" for x y + define A :: "nat set" where "A = {0, 1}" + define B :: "nat set" where "B = {0, 1, 2}" + have "rel_set R A B" unfolding R_def[abs_def] A_def B_def rel_set_def by auto + have "\ rel_pmf R (pmf_of_set A) (pmf_of_set B)" + proof + assume "rel_pmf R (pmf_of_set A) (pmf_of_set B)" + then obtain pq where pq: "\x y. (x, y) \ set_pmf pq \ R x y" + and 1: "map_pmf fst pq = pmf_of_set A" + and 2: "map_pmf snd pq = pmf_of_set B" + by cases auto + have "pmf (pmf_of_set B) 1 = 1 / 3" by(simp add: B_def) + have "pmf (pmf_of_set B) 2 = 1 / 3" by(simp add: B_def) + + have "2 / 3 = pmf (pmf_of_set B) 1 + pmf (pmf_of_set B) 2" by(simp add: B_def) + also have "\ = measure (measure_pmf (pmf_of_set B)) ({1} \ {2})" + by(subst measure_pmf.finite_measure_Union)(simp_all add: measure_pmf_single) + also have "\ = emeasure (measure_pmf pq) (snd -` {2, 1})" + unfolding 2[symmetric] measure_pmf.emeasure_eq_measure[symmetric] by(simp) + also have "\ = emeasure (measure_pmf pq) {(0, 2), (0, 1)}" + by(rule emeasure_eq_AE)(auto simp add: AE_measure_pmf_iff R_def dest!: pq) + also have "\ \ emeasure (measure_pmf pq) (fst -` {0})" + by(rule emeasure_mono) auto + also have "\ = emeasure (measure_pmf (pmf_of_set A)) {0}" + unfolding 1[symmetric] by simp + also have "\ = pmf (pmf_of_set A) 0" + by(simp add: measure_pmf_single measure_pmf.emeasure_eq_measure) + also have "pmf (pmf_of_set A) 0 = 1 / 2" by(simp add: A_def) + finally show False by(subst (asm) ennreal_le_iff; simp) + qed +end + +lemma rel_pmf_of_set_bij: + assumes f: "bij_betw f A B" + and A: "A \ {}" "finite A" + and B: "B \ {}" "finite B" + and R: "\x. x \ A \ R x (f x)" + shows "rel_pmf R (pmf_of_set A) (pmf_of_set B)" +proof(rule pmf.rel_mono_strong) + define AB where "AB = (\x. (x, f x)) ` A" + define R' where "R' x y \ (x, y) \ AB" for x y + have "(x, y) \ AB" if "(x, y) \ set_pmf (pmf_of_set AB)" for x y + using that by(auto simp add: AB_def A) + moreover have "map_pmf fst (pmf_of_set AB) = pmf_of_set A" + by(simp add: AB_def map_pmf_of_set_inj[symmetric] inj_on_def A pmf.map_comp o_def) + moreover + from f have [simp]: "inj_on f A" by(rule bij_betw_imp_inj_on) + from f have [simp]: "f ` A = B" by(rule bij_betw_imp_surj_on) + have "map_pmf snd (pmf_of_set AB) = pmf_of_set B" + by(simp add: AB_def map_pmf_of_set_inj[symmetric] inj_on_def A pmf.map_comp o_def) + (simp add: map_pmf_of_set_inj A) + ultimately show "rel_pmf (\x y. (x, y) \ AB) (pmf_of_set A) (pmf_of_set B)" .. +qed(auto intro: R) + +lemma rel_spmf_of_set_bij: + assumes f: "bij_betw f A B" + and R: "\x. x \ A \ R x (f x)" + shows "rel_spmf R (spmf_of_set A) (spmf_of_set B)" +proof - + have "finite A \ finite B" using f by(rule bij_betw_finite) + moreover have "A = {} \ B = {}" using f by(auto dest: bij_betw_empty2 bij_betw_empty1) + ultimately show ?thesis using assms + by(auto simp add: spmf_of_set_def simp del: spmf_of_pmf_pmf_of_set intro: rel_pmf_of_set_bij) +qed + +context begin interpretation lifting_syntax . +lemma rel_spmf_of_set: + assumes "bi_unique R" + shows "(rel_set R ===> rel_spmf R) spmf_of_set spmf_of_set" +proof + fix A B + assume R: "rel_set R A B" + with assms obtain f where "bij_betw f A B" and f: "\x. x \ A \ R x (f x)" + by(auto dest: bi_unique_rel_set_bij_betw) + then show "rel_spmf R (spmf_of_set A) (spmf_of_set B)" by(rule rel_spmf_of_set_bij) +qed +end + +lemma map_mem_spmf_of_set: + assumes "finite B" "B \ {}" + shows "map_spmf (\x. x \ A) (spmf_of_set B) = spmf_of_pmf (bernoulli_pmf (card (A \ B) / card B))" + (is "?lhs = ?rhs") +proof(rule spmf_eqI) + fix i + have "ennreal (spmf ?lhs i) = card (B \ (\x. x \ A) -` {i}) / (card B)" + by(subst ennreal_spmf_map)(simp add: measure_spmf_spmf_of_set assms emeasure_pmf_of_set) + also have "\ = (if i then card (B \ A) / card B else card (B - A) / card B)" + by(auto intro: arg_cong[where f=card]) + also have "\ = (if i then card (B \ A) / card B else (card B - card (B \ A)) / card B)" + by(auto simp add: card_Diff_subset_Int assms) + also have "\ = ennreal (spmf ?rhs i)" + by(simp add: assms card_gt_0_iff field_simps card_mono Int_commute of_nat_diff) + finally show "spmf ?lhs i = spmf ?rhs i" by simp +qed + +abbreviation coin_spmf :: "bool spmf" +where "coin_spmf \ spmf_of_set UNIV" + +lemma map_eq_const_coin_spmf: "map_spmf (op = c) coin_spmf = coin_spmf" +proof - + have "inj (op \ c)" "range (op \ c) = UNIV" by(auto intro: inj_onI) + then show ?thesis by simp +qed + +lemma bind_coin_spmf_eq_const: "coin_spmf \ (\x :: bool. return_spmf (b = x)) = coin_spmf" +using map_eq_const_coin_spmf unfolding map_spmf_conv_bind_spmf by simp + +lemma bind_coin_spmf_eq_const': "coin_spmf \ (\x :: bool. return_spmf (x = b)) = coin_spmf" +by(rewrite in "_ = \" bind_coin_spmf_eq_const[symmetric, of b])(auto intro: bind_spmf_cong) + +subsection {* Losslessness *} + +definition lossless_spmf :: "'a spmf \ bool" +where "lossless_spmf p \ weight_spmf p = 1" + +lemma lossless_iff_pmf_None: "lossless_spmf p \ pmf p None = 0" +by(simp add: lossless_spmf_def pmf_None_eq_weight_spmf) + +lemma lossless_return_spmf [iff]: "lossless_spmf (return_spmf x)" +by(simp add: lossless_iff_pmf_None) + +lemma lossless_return_pmf_None [iff]: "\ lossless_spmf (return_pmf None)" +by(simp add: lossless_iff_pmf_None) + +lemma lossless_map_spmf [simp]: "lossless_spmf (map_spmf f p) \ lossless_spmf p" +by(auto simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf) + +lemma lossless_bind_spmf [simp]: + "lossless_spmf (p \ f) \ lossless_spmf p \ (\x\set_spmf p. lossless_spmf (f x))" +by(simp add: lossless_iff_pmf_None pmf_bind_spmf_None add_nonneg_eq_0_iff integral_nonneg_AE integral_nonneg_eq_0_iff_AE measure_spmf.integrable_const_bound[where B=1] pmf_le_1) + +lemma lossless_weight_spmfD: "lossless_spmf p \ weight_spmf p = 1" +by(simp add: lossless_spmf_def) + +lemma lossless_iff_set_pmf_None: + "lossless_spmf p \ None \ set_pmf p" +by (simp add: lossless_iff_pmf_None pmf_eq_0_set_pmf) + +lemma lossless_spmf_of_set [simp]: "lossless_spmf (spmf_of_set A) \ finite A \ A \ {}" +by(auto simp add: lossless_spmf_def weight_spmf_of_set) + +lemma lossless_spmf_spmf_of_spmf [simp]: "lossless_spmf (spmf_of_pmf p)" +by(simp add: lossless_spmf_def) + +lemma lossless_spmf_bind_pmf [simp]: + "lossless_spmf (bind_pmf p f) \ (\x\set_pmf p. lossless_spmf (f x))" +by(simp add: lossless_iff_pmf_None pmf_bind integral_nonneg_AE integral_nonneg_eq_0_iff_AE measure_pmf.integrable_const_bound[where B=1] AE_measure_pmf_iff pmf_le_1) + +lemma lossless_spmf_conv_spmf_of_pmf: "lossless_spmf p \ (\p'. p = spmf_of_pmf p')" +proof + assume "lossless_spmf p" + hence *: "\y. y \ set_pmf p \ \x. y = Some x" + by(case_tac y)(simp_all add: lossless_iff_set_pmf_None) + + let ?p = "map_pmf the p" + have "p = spmf_of_pmf ?p" + proof(rule spmf_eqI) + fix i + have "ennreal (pmf (map_pmf the p) i) = \\<^sup>+ x. indicator (the -` {i}) x \p" by(simp add: ennreal_pmf_map) + also have "\ = \\<^sup>+ x. indicator {i} x \measure_spmf p" unfolding measure_spmf_def + by(subst nn_integral_distr)(auto simp add: nn_integral_restrict_space AE_measure_pmf_iff simp del: nn_integral_indicator intro!: nn_integral_cong_AE split: split_indicator dest!: * ) + also have "\ = spmf p i" by(simp add: emeasure_spmf_single) + finally show "spmf p i = spmf (spmf_of_pmf ?p) i" by simp + qed + thus "\p'. p = spmf_of_pmf p'" .. +qed auto + +lemma spmf_False_conv_True: "lossless_spmf p \ spmf p False = 1 - spmf p True" +by(clarsimp simp add: lossless_spmf_conv_spmf_of_pmf pmf_False_conv_True) + +lemma spmf_True_conv_False: "lossless_spmf p \ spmf p True = 1 - spmf p False" +by(simp add: spmf_False_conv_True) + +lemma bind_eq_return_spmf: + "bind_spmf p f = return_spmf x \ (\y\set_spmf p. f y = return_spmf x) \ lossless_spmf p" +by(auto simp add: bind_spmf_def bind_eq_return_pmf in_set_spmf lossless_iff_pmf_None pmf_eq_0_set_pmf iff del: not_None_eq split: option.split) + +lemma rel_spmf_return_spmf2: + "rel_spmf R p (return_spmf x) \ lossless_spmf p \ (\a\set_spmf p. R a x)" +by(auto simp add: lossless_iff_set_pmf_None rel_pmf_return_pmf2 option_rel_Some2 in_set_spmf, metis in_set_spmf not_None_eq) + +lemma rel_spmf_return_spmf1: + "rel_spmf R (return_spmf x) p \ lossless_spmf p \ (\a\set_spmf p. R x a)" +using rel_spmf_return_spmf2[of "R\\"] by(simp add: spmf_rel_conversep) + +lemma rel_spmf_bindI1: + assumes f: "\x. x \ set_spmf p \ rel_spmf R (f x) q" + and p: "lossless_spmf p" + shows "rel_spmf R (bind_spmf p f) q" +proof - + fix x :: 'a + have "rel_spmf R (bind_spmf p f) (bind_spmf (return_spmf x) (\_. q))" + by(rule rel_spmf_bindI[where R="\x _. x \ set_spmf p"])(simp_all add: rel_spmf_return_spmf2 p f) + then show ?thesis by simp +qed + +lemma rel_spmf_bindI2: + "\ \x. x \ set_spmf q \ rel_spmf R p (f x); lossless_spmf q \ + \ rel_spmf R p (bind_spmf q f)" +using rel_spmf_bindI1[of q "conversep R" f p] by(simp add: spmf_rel_conversep) + +subsection {* Scaling *} + +definition scale_spmf :: "real \ 'a spmf \ 'a spmf" +where + "scale_spmf r p = embed_spmf (\x. min (inverse (weight_spmf p)) (max 0 r) * spmf p x)" + +lemma scale_spmf_le_1: + "(\\<^sup>+ x. min (inverse (weight_spmf p)) (max 0 r) * spmf p x \count_space UNIV) \ 1" (is "?lhs \ _") +proof - + have "?lhs = min (inverse (weight_spmf p)) (max 0 r) * \\<^sup>+ x. spmf p x \count_space UNIV" + by(subst nn_integral_cmult[symmetric])(simp_all add: weight_spmf_nonneg max_def min_def ennreal_mult) + also have "\ \ 1" unfolding weight_spmf_eq_nn_integral_spmf[symmetric] + by(simp add: min_def max_def weight_spmf_nonneg order.strict_iff_order field_simps ennreal_mult[symmetric]) + finally show ?thesis . +qed + +lemma spmf_scale_spmf: "spmf (scale_spmf r p) x = max 0 (min (inverse (weight_spmf p)) r) * spmf p x" (is "?lhs = ?rhs") +unfolding scale_spmf_def +apply(subst spmf_embed_spmf[OF scale_spmf_le_1]) +apply(simp add: max_def min_def weight_spmf_le_0 field_simps weight_spmf_nonneg not_le order.strict_iff_order) +apply(metis antisym_conv order_trans weight_spmf_nonneg zero_le_mult_iff zero_le_one) +done + +lemma real_inverse_le_1_iff: fixes x :: real + shows "\ 0 \ x; x \ 1 \ \ 1 / x \ 1 \ x = 1 \ x = 0" +by auto + +lemma spmf_scale_spmf': "r \ 1 \ spmf (scale_spmf r p) x = max 0 r * spmf p x" +using real_inverse_le_1_iff[OF weight_spmf_nonneg weight_spmf_le_1, of p] +by(auto simp add: spmf_scale_spmf max_def min_def field_simps)(metis pmf_le_0_iff spmf_le_weight) + +lemma scale_spmf_neg: "r \ 0 \ scale_spmf r p = return_pmf None" +by(rule spmf_eqI)(simp add: spmf_scale_spmf' max_def) + +lemma scale_spmf_return_None [simp]: "scale_spmf r (return_pmf None) = return_pmf None" +by(rule spmf_eqI)(simp add: spmf_scale_spmf) + +lemma scale_spmf_conv_bind_bernoulli: + assumes "r \ 1" + shows "scale_spmf r p = bind_pmf (bernoulli_pmf r) (\b. if b then p else return_pmf None)" (is "?lhs = ?rhs") +proof(rule spmf_eqI) + fix x + have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using assms + unfolding spmf_scale_spmf ennreal_pmf_bind nn_integral_measure_pmf UNIV_bool bernoulli_pmf.rep_eq + apply(auto simp add: nn_integral_count_space_finite max_def min_def field_simps real_inverse_le_1_iff[OF weight_spmf_nonneg weight_spmf_le_1] weight_spmf_lt_0 not_le ennreal_mult[symmetric]) + apply (metis pmf_le_0_iff spmf_le_weight) + apply (metis pmf_le_0_iff spmf_le_weight) + apply (meson le_divide_eq_1_pos measure_spmf.subprob_measure_le_1 not_less order_trans weight_spmf_le_0) + by (meson divide_le_0_1_iff less_imp_le order_trans weight_spmf_le_0) + thus "spmf ?lhs x = spmf ?rhs x" by simp +qed + +lemma nn_integral_spmf: "(\\<^sup>+ x. spmf p x \count_space A) = emeasure (measure_spmf p) A" +apply(simp add: measure_spmf_def emeasure_distr emeasure_restrict_space space_restrict_space nn_integral_pmf[symmetric]) +apply(rule nn_integral_bij_count_space[where g=Some]) +apply(auto simp add: bij_betw_def) +done + +lemma measure_spmf_scale_spmf: "measure_spmf (scale_spmf r p) = scale_measure (min (inverse (weight_spmf p)) r) (measure_spmf p)" +apply(rule measure_eqI) + apply simp +apply(simp add: nn_integral_spmf[symmetric] spmf_scale_spmf) +apply(subst nn_integral_cmult[symmetric]) +apply(auto simp add: max_def min_def ennreal_mult[symmetric] not_le ennreal_lt_0) +done + +lemma measure_spmf_scale_spmf': + "r \ 1 \ measure_spmf (scale_spmf r p) = scale_measure r (measure_spmf p)" +unfolding measure_spmf_scale_spmf +apply(cases "weight_spmf p > 0") + apply(simp add: min.absorb2 field_simps weight_spmf_le_1 mult_le_one) +apply(clarsimp simp add: weight_spmf_le_0 min_def scale_spmf_neg weight_spmf_eq_0 not_less) +done + +lemma scale_spmf_1 [simp]: "scale_spmf 1 p = p" +apply(rule spmf_eqI) +apply(simp add: spmf_scale_spmf max_def min_def order.strict_iff_order field_simps weight_spmf_nonneg) +apply(metis antisym_conv divide_le_eq_1 less_imp_le pmf_nonneg spmf_le_weight weight_spmf_nonneg weight_spmf_le_1) +done + +lemma scale_spmf_0 [simp]: "scale_spmf 0 p = return_pmf None" +by(rule spmf_eqI)(simp add: spmf_scale_spmf min_def max_def weight_spmf_le_0) + +lemma bind_scale_spmf: + assumes r: "r \ 1" + shows "bind_spmf (scale_spmf r p) f = bind_spmf p (\x. scale_spmf r (f x))" + (is "?lhs = ?rhs") +proof(rule spmf_eqI) + fix x + have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using r + by(simp add: ennreal_spmf_bind measure_spmf_scale_spmf' nn_integral_scale_measure spmf_scale_spmf') + (simp add: ennreal_mult ennreal_lt_0 nn_integral_cmult max_def min_def) + thus "spmf ?lhs x = spmf ?rhs x" by simp +qed + +lemma scale_bind_spmf: + assumes "r \ 1" + shows "scale_spmf r (bind_spmf p f) = bind_spmf p (\x. scale_spmf r (f x))" + (is "?lhs = ?rhs") +proof(rule spmf_eqI) + fix x + have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" using assms + unfolding spmf_scale_spmf'[OF assms] + by(simp add: ennreal_mult ennreal_spmf_bind spmf_scale_spmf' nn_integral_cmult max_def min_def) + thus "spmf ?lhs x = spmf ?rhs x" by simp +qed + +lemma bind_spmf_const: "bind_spmf p (\x. q) = scale_spmf (weight_spmf p) q" (is "?lhs = ?rhs") +proof(rule spmf_eqI) + fix x + have "ennreal (spmf ?lhs x) = ennreal (spmf ?rhs x)" + using measure_spmf.subprob_measure_le_1[of p "space (measure_spmf p)"] + by(subst ennreal_spmf_bind)(simp add: spmf_scale_spmf' weight_spmf_le_1 ennreal_mult mult.commute max_def min_def measure_spmf.emeasure_eq_measure) + thus "spmf ?lhs x = spmf ?rhs x" by simp +qed + +lemma map_scale_spmf: "map_spmf f (scale_spmf r p) = scale_spmf r (map_spmf f p)" (is "?lhs = ?rhs") +proof(rule spmf_eqI) + fix i + show "spmf ?lhs i = spmf ?rhs i" unfolding spmf_scale_spmf + by(subst (1 2) spmf_map)(auto simp add: measure_spmf_scale_spmf max_def min_def ennreal_lt_0) +qed + +lemma set_scale_spmf: "set_spmf (scale_spmf r p) = (if r > 0 then set_spmf p else {})" +apply(auto simp add: in_set_spmf_iff_spmf spmf_scale_spmf) +apply(simp add: max_def min_def not_le weight_spmf_lt_0 weight_spmf_eq_0 split: if_split_asm) +done + +lemma set_scale_spmf' [simp]: "0 < r \ set_spmf (scale_spmf r p) = set_spmf p" +by(simp add: set_scale_spmf) + +lemma rel_spmf_scaleI: + assumes "r > 0 \ rel_spmf A p q" + shows "rel_spmf A (scale_spmf r p) (scale_spmf r q)" +proof(cases "r > 0") + case True + from assms[OF this] show ?thesis + by(rule rel_spmfE)(auto simp add: map_scale_spmf[symmetric] spmf_rel_map True intro: rel_spmf_reflI) +qed(simp add: not_less scale_spmf_neg) + +lemma weight_scale_spmf: "weight_spmf (scale_spmf r p) = min 1 (max 0 r * weight_spmf p)" +proof - + have "ennreal (weight_spmf (scale_spmf r p)) = min 1 (max 0 r * ennreal (weight_spmf p))" + unfolding weight_spmf_eq_nn_integral_spmf + apply(simp add: spmf_scale_spmf ennreal_mult zero_ereal_def[symmetric] nn_integral_cmult) + apply(auto simp add: weight_spmf_eq_nn_integral_spmf[symmetric] field_simps min_def max_def not_le weight_spmf_lt_0 ennreal_mult[symmetric]) + subgoal by(subst (asm) ennreal_mult[symmetric], meson divide_less_0_1_iff le_less_trans not_le weight_spmf_lt_0, simp+, meson not_le pos_divide_le_eq weight_spmf_le_0) + subgoal by(cases "r \ 0")(simp_all add: ennreal_mult[symmetric] weight_spmf_nonneg ennreal_lt_0, meson le_less_trans not_le pos_divide_le_eq zero_less_divide_1_iff) + done + thus ?thesis by(auto simp add: min_def max_def ennreal_mult[symmetric] split: if_split_asm) +qed + +lemma weight_scale_spmf' [simp]: + "\ 0 \ r; r \ 1 \ \ weight_spmf (scale_spmf r p) = r * weight_spmf p" +by(simp add: weight_scale_spmf max_def min_def)(metis antisym_conv mult_left_le order_trans weight_spmf_le_1) + +lemma pmf_scale_spmf_None: + "pmf (scale_spmf k p) None = 1 - min 1 (max 0 k * (1 - pmf p None))" +unfolding pmf_None_eq_weight_spmf by(simp add: weight_scale_spmf) + +lemma scale_scale_spmf: + "scale_spmf r (scale_spmf r' p) = scale_spmf (r * max 0 (min (inverse (weight_spmf p)) r')) p" + (is "?lhs = ?rhs") +proof(rule spmf_eqI) + fix i + have "max 0 (min (1 / weight_spmf p) r') * + max 0 (min (1 / min 1 (weight_spmf p * max 0 r')) r) = + max 0 (min (1 / weight_spmf p) (r * max 0 (min (1 / weight_spmf p) r')))" + proof(cases "weight_spmf p > 0") + case False + thus ?thesis by(simp add: not_less weight_spmf_le_0) + next + case True + thus ?thesis by(simp add: field_simps max_def min.absorb_iff2[symmetric])(auto simp add: min_def field_simps zero_le_mult_iff) + qed + then show "spmf ?lhs i = spmf ?rhs i" + by(simp add: spmf_scale_spmf field_simps weight_scale_spmf) +qed + +lemma scale_scale_spmf' [simp]: + "\ 0 \ r; r \ 1; 0 \ r'; r' \ 1 \ + \ scale_spmf r (scale_spmf r' p) = scale_spmf (r * r') p" +apply(cases "weight_spmf p > 0") +apply(auto simp add: scale_scale_spmf min_def max_def field_simps not_le weight_spmf_lt_0 weight_spmf_eq_0 not_less weight_spmf_le_0) +apply(subgoal_tac "1 = r'") + apply (metis (no_types) divide_1 eq_iff measure_spmf.subprob_measure_le_1 mult.commute mult_cancel_right1) +apply(meson eq_iff le_divide_eq_1_pos measure_spmf.subprob_measure_le_1 mult_imp_div_pos_le order.trans) +done + +lemma scale_spmf_eq_same: "scale_spmf r p = p \ weight_spmf p = 0 \ r = 1 \ r \ 1 \ weight_spmf p = 1" + (is "?lhs \ ?rhs") +proof + assume ?lhs + hence "weight_spmf (scale_spmf r p) = weight_spmf p" by simp + hence *: "min 1 (max 0 r * weight_spmf p) = weight_spmf p" by(simp add: weight_scale_spmf) + hence **: "weight_spmf p = 0 \ r \ 1" by(auto simp add: min_def max_def split: if_split_asm) + show ?rhs + proof(cases "weight_spmf p = 0") + case False + with ** have "r \ 1" by simp + with * False have "r = 1 \ weight_spmf p = 1" by(simp add: max_def min_def not_le split: if_split_asm) + with \r \ 1\ show ?thesis by simp + qed simp +qed(auto intro!: spmf_eqI simp add: spmf_scale_spmf, metis pmf_le_0_iff spmf_le_weight) + +lemma map_const_spmf_of_set: + "\ finite A; A \ {} \ \ map_spmf (\_. c) (spmf_of_set A) = return_spmf c" +by(simp add: map_spmf_conv_bind_spmf bind_spmf_const) + +subsection {* Conditional spmfs *} + +lemma set_pmf_Int_Some: "set_pmf p \ Some ` A = {} \ set_spmf p \ A = {}" +by(auto simp add: in_set_spmf) + +lemma measure_spmf_zero_iff: "measure (measure_spmf p) A = 0 \ set_spmf p \ A = {}" +unfolding measure_measure_spmf_conv_measure_pmf by(simp add: measure_pmf_zero_iff set_pmf_Int_Some) + +definition cond_spmf :: "'a spmf \ 'a set \ 'a spmf" +where "cond_spmf p A = (if set_spmf p \ A = {} then return_pmf None else cond_pmf p (Some ` A))" + +lemma set_cond_spmf [simp]: "set_spmf (cond_spmf p A) = set_spmf p \ A" +by(auto 4 4 simp add: cond_spmf_def in_set_spmf iff: set_cond_pmf[THEN set_eq_iff[THEN iffD1], THEN spec, rotated]) + +lemma cond_map_spmf [simp]: "cond_spmf (map_spmf f p) A = map_spmf f (cond_spmf p (f -` A))" +proof - + have "map_option f -` Some ` A = Some ` f -` A" by auto + moreover have "set_pmf p \ map_option f -` Some ` A \ {}" if "Some x \ set_pmf p" "f x \ A" for x + using that by auto + ultimately show ?thesis by(auto simp add: cond_spmf_def in_set_spmf cond_map_pmf) +qed + +lemma spmf_cond_spmf [simp]: + "spmf (cond_spmf p A) x = (if x \ A then spmf p x / measure (measure_spmf p) A else 0)" +by(auto simp add: cond_spmf_def pmf_cond set_pmf_Int_Some[symmetric] measure_measure_spmf_conv_measure_pmf measure_pmf_zero_iff) + +lemma bind_eq_return_pmf_None: + "bind_spmf p f = return_pmf None \ (\x\set_spmf p. f x = return_pmf None)" +by(auto simp add: bind_spmf_def bind_eq_return_pmf in_set_spmf split: option.splits) + +lemma return_pmf_None_eq_bind: + "return_pmf None = bind_spmf p f \ (\x\set_spmf p. f x = return_pmf None)" +using bind_eq_return_pmf_None[of p f] by auto + +(* Conditional probabilities do not seem to interact nicely with bind. *) + +subsection {* Product spmf *} + +definition pair_spmf :: "'a spmf \ 'b spmf \ ('a \ 'b) spmf" +where "pair_spmf p q = bind_pmf (pair_pmf p q) (\xy. case xy of (Some x, Some y) \ return_spmf (x, y) | _ \ return_pmf None)" + +lemma map_fst_pair_spmf [simp]: "map_spmf fst (pair_spmf p q) = scale_spmf (weight_spmf q) p" +unfolding bind_spmf_const[symmetric] +apply(simp add: pair_spmf_def map_bind_pmf pair_pmf_def bind_assoc_pmf option.case_distrib) +apply(subst bind_commute_pmf) +apply(auto intro!: bind_pmf_cong[OF refl] simp add: bind_return_pmf bind_spmf_def bind_return_pmf' case_option_collapse option.case_distrib[where h="map_spmf _"] option.case_distrib[symmetric] case_option_id split: option.split cong del: option.case_cong) +done + +lemma map_snd_pair_spmf [simp]: "map_spmf snd (pair_spmf p q) = scale_spmf (weight_spmf p) q" +unfolding bind_spmf_const[symmetric] +apply(simp add: pair_spmf_def map_bind_pmf pair_pmf_def bind_assoc_pmf option.case_distrib cong del: option.case_cong) +apply(auto intro!: bind_pmf_cong[OF refl] simp add: bind_return_pmf bind_spmf_def bind_return_pmf' case_option_collapse option.case_distrib[where h="map_spmf _"] option.case_distrib[symmetric] case_option_id split: option.split cong del: option.case_cong) +done + +lemma set_pair_spmf [simp]: "set_spmf (pair_spmf p q) = set_spmf p \ set_spmf q" +by(auto 4 3 simp add: pair_spmf_def set_spmf_bind_pmf bind_UNION in_set_spmf intro: rev_bexI split: option.splits) + +lemma spmf_pair [simp]: "spmf (pair_spmf p q) (x, y) = spmf p x * spmf q y" (is "?lhs = ?rhs") +proof - + have "ennreal ?lhs = \\<^sup>+ a. \\<^sup>+ b. indicator {(x, y)} (a, b) \measure_spmf q \measure_spmf p" + unfolding measure_spmf_def pair_spmf_def ennreal_pmf_bind nn_integral_pair_pmf' + by(auto simp add: zero_ereal_def[symmetric] nn_integral_distr nn_integral_restrict_space nn_integral_multc[symmetric] intro!: nn_integral_cong split: option.split split_indicator) + also have "\ = \\<^sup>+ a. (\\<^sup>+ b. indicator {y} b \measure_spmf q) * indicator {x} a \measure_spmf p" + by(subst nn_integral_multc[symmetric])(auto intro!: nn_integral_cong split: split_indicator) + also have "\ = ennreal ?rhs" by(simp add: emeasure_spmf_single max_def ennreal_mult mult.commute) + finally show ?thesis by simp +qed + +lemma pair_map_spmf2: "pair_spmf p (map_spmf f q) = map_spmf (apsnd f) (pair_spmf p q)" +by(auto simp add: pair_spmf_def pair_map_pmf2 bind_map_pmf map_bind_pmf intro: bind_pmf_cong split: option.split) + +lemma pair_map_spmf1: "pair_spmf (map_spmf f p) q = map_spmf (apfst f) (pair_spmf p q)" +by(auto simp add: pair_spmf_def pair_map_pmf1 bind_map_pmf map_bind_pmf intro: bind_pmf_cong split: option.split) + +lemma pair_map_spmf: "pair_spmf (map_spmf f p) (map_spmf g q) = map_spmf (map_prod f g) (pair_spmf p q)" +unfolding pair_map_spmf2 pair_map_spmf1 spmf.map_comp by(simp add: apfst_def apsnd_def o_def prod.map_comp) + +lemma pair_spmf_alt_def: "pair_spmf p q = bind_spmf p (\x. bind_spmf q (\y. return_spmf (x, y)))" +by(auto simp add: pair_spmf_def pair_pmf_def bind_spmf_def bind_assoc_pmf bind_return_pmf split: option.split intro: bind_pmf_cong) + +lemma weight_pair_spmf [simp]: "weight_spmf (pair_spmf p q) = weight_spmf p * weight_spmf q" +unfolding pair_spmf_alt_def by(simp add: weight_bind_spmf o_def) + +lemma pair_scale_spmf1: (* FIXME: generalise to arbitrary r *) + "r \ 1 \ pair_spmf (scale_spmf r p) q = scale_spmf r (pair_spmf p q)" +by(simp add: pair_spmf_alt_def scale_bind_spmf bind_scale_spmf) + +lemma pair_scale_spmf2: (* FIXME: generalise to arbitrary r *) + "r \ 1 \ pair_spmf p (scale_spmf r q) = scale_spmf r (pair_spmf p q)" +by(simp add: pair_spmf_alt_def scale_bind_spmf bind_scale_spmf) + +lemma pair_spmf_return_None1 [simp]: "pair_spmf (return_pmf None) p = return_pmf None" +by(rule spmf_eqI)(clarsimp) + +lemma pair_spmf_return_None2 [simp]: "pair_spmf p (return_pmf None) = return_pmf None" +by(rule spmf_eqI)(clarsimp) + +lemma pair_spmf_return_spmf1: "pair_spmf (return_spmf x) q = map_spmf (Pair x) q" +by(rule spmf_eqI)(auto split: split_indicator simp add: spmf_map_inj' inj_on_def intro: spmf_map_outside) + +lemma pair_spmf_return_spmf2: "pair_spmf p (return_spmf y) = map_spmf (\x. (x, y)) p" +by(rule spmf_eqI)(auto split: split_indicator simp add: inj_on_def intro!: spmf_map_outside spmf_map_inj'[symmetric]) + +lemma pair_spmf_return_spmf [simp]: "pair_spmf (return_spmf x) (return_spmf y) = return_spmf (x, y)" +by(simp add: pair_spmf_return_spmf1) + +lemma rel_pair_spmf_prod: + "rel_spmf (rel_prod A B) (pair_spmf p q) (pair_spmf p' q') \ + rel_spmf A (scale_spmf (weight_spmf q) p) (scale_spmf (weight_spmf q') p') \ + rel_spmf B (scale_spmf (weight_spmf p) q) (scale_spmf (weight_spmf p') q')" + (is "?lhs \ ?rhs" is "_ \ ?A \ ?B" is "_ \ rel_spmf _ ?p ?p' \ rel_spmf _ ?q ?q'") +proof(intro iffI conjI) + assume ?rhs + then obtain pq pq' where p: "map_spmf fst pq = ?p" and p': "map_spmf snd pq = ?p'" + and q: "map_spmf fst pq' = ?q" and q': "map_spmf snd pq' = ?q'" + and *: "\x x'. (x, x') \ set_spmf pq \ A x x'" + and **: "\y y'. (y, y') \ set_spmf pq' \ B y y'" by(auto elim!: rel_spmfE) + let ?f = "\((x, x'), (y, y')). ((x, y), (x', y'))" + let ?r = "1 / (weight_spmf p * weight_spmf q)" + let ?pq = "scale_spmf ?r (map_spmf ?f (pair_spmf pq pq'))" + + { fix p :: "'x spmf" and q :: "'y spmf" + assume "weight_spmf q \ 0" + and "weight_spmf p \ 0" + and "1 / (weight_spmf p * weight_spmf q) \ weight_spmf p * weight_spmf q" + hence "1 \ (weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q)" + by(simp add: pos_divide_le_eq order.strict_iff_order weight_spmf_nonneg) + moreover have "(weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q) \ (1 * 1) * (1 * 1)" + by(intro mult_mono)(simp_all add: weight_spmf_nonneg weight_spmf_le_1) + ultimately have "(weight_spmf p * weight_spmf q) * (weight_spmf p * weight_spmf q) = 1" by simp + hence *: "weight_spmf p * weight_spmf q = 1" + by(metis antisym_conv less_le mult_less_cancel_left1 weight_pair_spmf weight_spmf_le_1 weight_spmf_nonneg) + hence "weight_spmf p = 1" by(metis antisym_conv mult_left_le weight_spmf_le_1 weight_spmf_nonneg) + moreover with * have "weight_spmf q = 1" by simp + moreover note calculation } + note full = this + + show ?lhs + proof + have [simp]: "fst \ ?f = map_prod fst fst" by(simp add: fun_eq_iff) + have "map_spmf fst ?pq = scale_spmf ?r (pair_spmf ?p ?q)" + by(simp add: pair_map_spmf[symmetric] p q map_scale_spmf spmf.map_comp) + also have "\ = pair_spmf p q" using full[of p q] + by(simp add: pair_scale_spmf1 pair_scale_spmf2 weight_spmf_le_1 weight_spmf_nonneg) + (auto simp add: scale_scale_spmf max_def min_def field_simps weight_spmf_nonneg weight_spmf_eq_0) + finally show "map_spmf fst ?pq = \" . + + have [simp]: "snd \ ?f = map_prod snd snd" by(simp add: fun_eq_iff) + from \?rhs\ have eq: "weight_spmf p * weight_spmf q = weight_spmf p' * weight_spmf q'" + by(auto dest!: rel_spmf_weightD simp add: weight_spmf_le_1 weight_spmf_nonneg) + + have "map_spmf snd ?pq = scale_spmf ?r (pair_spmf ?p' ?q')" + by(simp add: pair_map_spmf[symmetric] p' q' map_scale_spmf spmf.map_comp) + also have "\ = pair_spmf p' q'" using full[of p' q'] eq + by(simp add: pair_scale_spmf1 pair_scale_spmf2 weight_spmf_le_1 weight_spmf_nonneg) + (auto simp add: scale_scale_spmf max_def min_def field_simps weight_spmf_nonneg weight_spmf_eq_0) + finally show "map_spmf snd ?pq = \" . + qed(auto simp add: set_scale_spmf split: if_split_asm dest: * ** ) +next + assume ?lhs + then obtain pq where pq: "map_spmf fst pq = pair_spmf p q" + and pq': "map_spmf snd pq = pair_spmf p' q'" + and *: "\x y x' y'. ((x, y), (x', y')) \ set_spmf pq \ A x x' \ B y y'" + by(auto elim: rel_spmfE) + + show ?A + proof + let ?f = "(\((x, y), (x', y')). (x, x'))" + let ?pq = "map_spmf ?f pq" + have [simp]: "fst \ ?f = fst \ fst" by(simp add: split_def o_def) + show "map_spmf fst ?pq = scale_spmf (weight_spmf q) p" using pq + by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric]) + + have [simp]: "snd \ ?f = fst \ snd" by(simp add: split_def o_def) + show "map_spmf snd ?pq = scale_spmf (weight_spmf q') p'" using pq' + by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric]) + qed(auto dest: * ) + + show ?B + proof + let ?f = "(\((x, y), (x', y')). (y, y'))" + let ?pq = "map_spmf ?f pq" + have [simp]: "fst \ ?f = snd \ fst" by(simp add: split_def o_def) + show "map_spmf fst ?pq = scale_spmf (weight_spmf p) q" using pq + by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric]) + + have [simp]: "snd \ ?f = snd \ snd" by(simp add: split_def o_def) + show "map_spmf snd ?pq = scale_spmf (weight_spmf p') q'" using pq' + by(simp add: spmf.map_comp)(simp add: spmf.map_comp[symmetric]) + qed(auto dest: * ) +qed + +lemma pair_pair_spmf: + "pair_spmf (pair_spmf p q) r = map_spmf (\(x, (y, z)). ((x, y), z)) (pair_spmf p (pair_spmf q r))" +by(simp add: pair_spmf_alt_def map_spmf_conv_bind_spmf) + +lemma pair_commute_spmf: + "pair_spmf p q = map_spmf (\(y, x). (x, y)) (pair_spmf q p)" +unfolding pair_spmf_alt_def by(subst bind_commute_spmf)(simp add: map_spmf_conv_bind_spmf) + +subsection {* Assertions *} + +definition assert_spmf :: "bool \ unit spmf" +where "assert_spmf b = (if b then return_spmf () else return_pmf None)" + +lemma assert_spmf_simps [simp]: + "assert_spmf True = return_spmf ()" + "assert_spmf False = return_pmf None" +by(simp_all add: assert_spmf_def) + +lemma in_set_assert_spmf [simp]: "x \ set_spmf (assert_spmf p) \ p" +by(cases p) simp_all + +lemma set_spmf_assert_spmf_eq_empty [simp]: "set_spmf (assert_spmf b) = {} \ \ b" +by(cases b) simp_all + +lemma lossless_assert_spmf [iff]: "lossless_spmf (assert_spmf b) \ b" +by(cases b) simp_all + +subsection {* Try *} + +definition try_spmf :: "'a spmf \ 'a spmf \ 'a spmf" ("TRY _ ELSE _" [0,60] 59) +where "try_spmf p q = bind_pmf p (\x. case x of None \ q | Some y \ return_spmf y)" + +lemma try_spmf_lossless [simp]: + assumes "lossless_spmf p" + shows "TRY p ELSE q = p" +proof - + have "TRY p ELSE q = bind_pmf p return_pmf" unfolding try_spmf_def using assms + by(auto simp add: lossless_iff_set_pmf_None split: option.split intro: bind_pmf_cong) + thus ?thesis by(simp add: bind_return_pmf') +qed + +lemma try_spmf_return_spmf1: "TRY return_spmf x ELSE q = return_spmf x" +by(simp add: try_spmf_def bind_return_pmf) + +lemma try_spmf_return_None [simp]: "TRY return_pmf None ELSE q = q" +by(simp add: try_spmf_def bind_return_pmf) + +lemma try_spmf_return_pmf_None2 [simp]: "TRY p ELSE return_pmf None = p" +by(simp add: try_spmf_def option.case_distrib[symmetric] bind_return_pmf' case_option_id) + +lemma map_try_spmf: "map_spmf f (try_spmf p q) = try_spmf (map_spmf f p) (map_spmf f q)" +by(simp add: try_spmf_def map_bind_pmf bind_map_pmf option.case_distrib[where h="map_spmf f"] o_def cong del: option.case_cong_weak) + +lemma try_spmf_bind_pmf: "TRY (bind_pmf p f) ELSE q = bind_pmf p (\x. TRY (f x) ELSE q)" +by(simp add: try_spmf_def bind_assoc_pmf) + +lemma try_spmf_bind_spmf_lossless: + "lossless_spmf p \ TRY (bind_spmf p f) ELSE q = bind_spmf p (\x. TRY (f x) ELSE q)" +by(auto simp add: try_spmf_def bind_spmf_def bind_assoc_pmf bind_return_pmf lossless_iff_set_pmf_None intro!: bind_pmf_cong split: option.split) + +lemma try_spmf_bind_out: + "lossless_spmf p \ bind_spmf p (\x. TRY (f x) ELSE q) = TRY (bind_spmf p f) ELSE q" +by(simp add: try_spmf_bind_spmf_lossless) + +lemma lossless_try_spmf [simp]: + "lossless_spmf (TRY p ELSE q) \ lossless_spmf p \ lossless_spmf q" +by(auto simp add: try_spmf_def in_set_spmf lossless_iff_set_pmf_None split: option.splits) + +context begin interpretation lifting_syntax . +lemma try_spmf_parametric [transfer_rule]: + "(rel_spmf A ===> rel_spmf A ===> rel_spmf A) try_spmf try_spmf" +unfolding try_spmf_def[abs_def] by transfer_prover +end + +lemma try_spmf_cong: + "\ p = p'; \ lossless_spmf p' \ q = q' \ \ TRY p ELSE q = TRY p' ELSE q'" +unfolding try_spmf_def +by(rule bind_pmf_cong)(auto split: option.split simp add: lossless_iff_set_pmf_None) + +lemma rel_spmf_try_spmf: + "\ rel_spmf R p p'; \ lossless_spmf p' \ rel_spmf R q q' \ + \ rel_spmf R (TRY p ELSE q) (TRY p' ELSE q')" +unfolding try_spmf_def +apply(rule rel_pmf_bindI[where R="\x y. rel_option R x y \ x \ set_pmf p \ y \ set_pmf p'"]) + apply(erule pmf.rel_mono_strong; simp) +apply(auto split: option.split simp add: lossless_iff_set_pmf_None) +done + +lemma spmf_try_spmf: + "spmf (TRY p ELSE q) x = spmf p x + pmf p None * spmf q x" +proof - + have "ennreal (spmf (TRY p ELSE q) x) = \\<^sup>+ y. ennreal (spmf q x) * indicator {None} y + indicator {Some x} y \measure_pmf p" + unfolding try_spmf_def ennreal_pmf_bind by(rule nn_integral_cong)(simp split: option.split split_indicator) + also have "\ = (\\<^sup>+ y. ennreal (spmf q x) * indicator {None} y \measure_pmf p) + \\<^sup>+ y. indicator {Some x} y \measure_pmf p" + by(simp add: nn_integral_add) + also have "\ = ennreal (spmf q x) * pmf p None + spmf p x" by(simp add: emeasure_pmf_single) + finally show ?thesis by(simp add: ennreal_mult[symmetric] ennreal_plus[symmetric] del: ennreal_plus) +qed + +lemma try_scale_spmf_same [simp]: "lossless_spmf p \ TRY scale_spmf k p ELSE p = p" +by(rule spmf_eqI)(auto simp add: spmf_try_spmf spmf_scale_spmf pmf_scale_spmf_None lossless_iff_pmf_None weight_spmf_conv_pmf_None min_def max_def field_simps) + +lemma pmf_try_spmf_None [simp]: "pmf (TRY p ELSE q) None = pmf p None * pmf q None" (is "?lhs = ?rhs") +proof - + have "?lhs = \ x. pmf q None * indicator {None} x \measure_pmf p" + unfolding try_spmf_def pmf_bind by(rule integral_cong)(simp_all split: option.split) + also have "\ = ?rhs" by(simp add: measure_pmf_single) + finally show ?thesis . +qed + +lemma try_bind_spmf_lossless2: + "lossless_spmf q \ TRY (bind_spmf p f) ELSE q = TRY (p \ (\x. TRY (f x) ELSE q)) ELSE q" +by(rule spmf_eqI)(simp add: spmf_try_spmf pmf_bind_spmf_None spmf_bind field_simps measure_spmf.integrable_const_bound[where B=1] pmf_le_1 lossless_iff_pmf_None) + +lemma try_bind_spmf_lossless2': + fixes f :: "'a \ 'b spmf" shows + "\ NO_MATCH (\x :: 'a. try_spmf (g x :: 'b spmf) (h x)) f; lossless_spmf q \ + \ TRY (bind_spmf p f) ELSE q = TRY (p \ (\x :: 'a. TRY (f x) ELSE q)) ELSE q" +by(rule try_bind_spmf_lossless2) + +lemma try_bind_assert_spmf: + "TRY (assert_spmf b \ f) ELSE q = (if b then TRY (f ()) ELSE q else q)" +by simp + +subsection \Miscellaneous\ + +lemma assumes "rel_spmf (\x y. bad1 x = bad2 y \ (\ bad2 y \ A x \ B y)) p q" (is "rel_spmf ?A _ _") + shows fundamental_lemma_bad: "measure (measure_spmf p) {x. bad1 x} = measure (measure_spmf q) {y. bad2 y}" (is "?bad") + and fundamental_lemma: "\measure (measure_spmf p) {x. A x} - measure (measure_spmf q) {y. B y}\ \ + measure (measure_spmf p) {x. bad1 x}" (is ?fundamental) +proof - + have good: "rel_fun ?A op = (\x. A x \ \ bad1 x) (\y. B y \ \ bad2 y)" by(auto simp add: rel_fun_def) + from assms have 1: "measure (measure_spmf p) {x. A x \ \ bad1 x} = measure (measure_spmf q) {y. B y \ \ bad2 y}" + by(rule measure_spmf_parametric[THEN rel_funD, THEN rel_funD])(rule Collect_parametric[THEN rel_funD, OF good]) + + have bad: "rel_fun ?A op = bad1 bad2" by(simp add: rel_fun_def) + show 2: ?bad using assms + by(rule measure_spmf_parametric[THEN rel_funD, THEN rel_funD])(rule Collect_parametric[THEN rel_funD, OF bad]) + + let ?\p = "measure (measure_spmf p)" and ?\q = "measure (measure_spmf q)" + have "{x. A x \ bad1 x} \ {x. A x \ \ bad1 x} = {x. A x}" + and "{y. B y \ bad2 y} \ {y. B y \ \ bad2 y} = {y. B y}" by auto + then have "\?\p {x. A x} - ?\q {x. B x}\ = \?\p ({x. A x \ bad1 x} \ {x. A x \ \ bad1 x}) - ?\q ({y. B y \ bad2 y} \ {y. B y \ \ bad2 y})\" + by simp + also have "\ = \?\p {x. A x \ bad1 x} + ?\p {x. A x \ \ bad1 x} - ?\q {y. B y \ bad2 y} - ?\q {y. B y \ \ bad2 y}\" + by(subst (1 2) measure_Union)(auto) + also have "\ = \?\p {x. A x \ bad1 x} - ?\q {y. B y \ bad2 y}\" using 1 by simp + also have "\ \ max (?\p {x. A x \ bad1 x}) (?\q {y. B y \ bad2 y})" + by(rule abs_leI)(auto simp add: max_def not_le, simp_all only: add_increasing measure_nonneg mult_2) + also have "\ \ max (?\p {x. bad1 x}) (?\q {y. bad2 y})" + by(rule max.mono; rule measure_spmf.finite_measure_mono; auto) + also note 2[symmetric] + finally show ?fundamental by simp +qed + +end diff -r 9986559617ee -r ea13f44da888 src/HOL/Probability/document/root.tex --- a/src/HOL/Probability/document/root.tex Mon Jun 06 22:22:05 2016 +0200 +++ b/src/HOL/Probability/document/root.tex Wed Jun 08 09:05:32 2016 +0200 @@ -2,6 +2,7 @@ \usepackage{graphicx,isabelle,isabellesym,latexsym,textcomp} \usepackage{amsmath} \usepackage{amssymb} +\usepackage{wasysym} \usepackage[only,bigsqcap]{stmaryrd} \usepackage[utf8]{inputenc} \usepackage{pdfsetup}