| author | wenzelm | 
| Fri, 26 May 2017 11:09:16 +0200 | |
| changeset 65930 | 9a28fc03c3fe | 
| parent 64267 | b9a1486e79be | 
| child 66453 | cc19f7ca2ed6 | 
| permissions | -rw-r--r-- | 
| 63194 | 1 | (* Title: HOL/Probability/PMF_Impl.thy | 
| 2 | Author: Manuel Eberl, TU München | |
| 3 | ||
| 4 | An implementation of PMFs using Mappings, which are implemented with association lists | |
| 5 | by default. Also includes Quickcheck setup for PMFs. | |
| 6 | *) | |
| 7 | ||
| 63195 | 8 | section \<open>Code generation for PMFs\<close> | 
| 9 | ||
| 63194 | 10 | theory PMF_Impl | 
| 11 | imports Probability_Mass_Function "~~/src/HOL/Library/AList_Mapping" | |
| 12 | begin | |
| 13 | ||
| 63195 | 14 | subsection \<open>General code generation setup\<close> | 
| 15 | ||
(* pmf_of_mapping m: the PMF obtained by handing the function "Mapping.lookup_default 0 m" to embed_pmf, so keys without an entry in m are looked up with default value 0. *)
| 63194 | 16 | definition pmf_of_mapping :: "('a, real) mapping \<Rightarrow> 'a pmf" where
 | 
| 17 | "pmf_of_mapping m = embed_pmf (Mapping.lookup_default 0 m)" | |
| 18 | ||
| 19 | lemma nn_integral_lookup_default: | |
| 20 |   fixes m :: "('a, real) mapping"
 | |
| 21 | assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ x. x \<ge> 0)" | |
| 22 | shows "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) = | |
| 23 | ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)" | |
| 24 | proof - | |
| 25 | have "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) = | |
| 26 | (\<Sum>x\<in>Mapping.keys m. ennreal (Mapping.lookup_default 0 m x))" using assms | |
| 27 | by (subst nn_integral_count_space'[of "Mapping.keys m"]) | |
| 28 | (auto simp: Mapping.lookup_default_def keys_is_none_rep Option.is_none_def) | |
| 29 | also from assms have "\<dots> = ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)" | |
| 64267 | 30 | by (intro sum_ennreal) | 
| 63194 | 31 | (auto simp: Mapping.lookup_default_def All_mapping_def split: option.splits) | 
| 32 | finally show ?thesis . | |
| 33 | qed | |
| 34 | ||
| 35 | lemma pmf_of_mapping: | |
| 36 | assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ p. p \<ge> 0)" | |
| 37 | assumes "(\<Sum>x\<in>Mapping.keys m. Mapping.lookup_default 0 m x) = 1" | |
| 38 | shows "pmf (pmf_of_mapping m) x = Mapping.lookup_default 0 m x" | |
| 39 | unfolding pmf_of_mapping_def | |
| 40 | proof (intro pmf_embed_pmf) | |
| 41 | from assms show "(\<integral>\<^sup>+x. ennreal (Mapping.lookup_default 0 m x) \<partial>count_space UNIV) = 1" | |
| 42 | by (subst nn_integral_lookup_default) (simp_all) | |
| 43 | qed (insert assms, simp add: All_mapping_def Mapping.lookup_default_def split: option.splits) | |
| 44 | ||
| 45 | lemma pmf_of_set_pmf_of_mapping: | |
| 46 |   assumes "A \<noteq> {}" "set xs = A" "distinct xs"
 | |
| 47 | shows "pmf_of_set A = pmf_of_mapping (Mapping.tabulate xs (\<lambda>_. 1 / real (length xs)))" | |
| 48 | (is "?lhs = ?rhs") | |
| 49 | by (rule pmf_eqI, subst pmf_of_mapping) | |
| 50 | (insert assms, auto intro!: All_mapping_tabulate | |
| 51 | simp: Mapping.lookup_default_def lookup_tabulate distinct_card) | |
| 52 | ||
(* mapping_of_pmf p: the mapping that has no entry (None) for x when pmf p x = 0, and the entry Some (pmf p x) otherwise. *)
| 53 | lift_definition mapping_of_pmf :: "'a pmf \<Rightarrow> ('a, real) mapping" is
 | 
| 54 | "\<lambda>p x. if pmf p x = 0 then None else Some (pmf p x)" . | |
| 55 | ||
| 56 | lemma lookup_default_mapping_of_pmf: | |
| 57 | "Mapping.lookup_default 0 (mapping_of_pmf p) x = pmf p x" | |
| 58 | by (simp add: mapping_of_pmf.abs_eq lookup_default_def Mapping.lookup.abs_eq) | |
| 59 | ||
| 60 | context | |
| 61 | begin | |
| 62 | ||
| 63 | interpretation pmf_as_function . | |
| 64 | ||
| 65 | lemma nn_integral_pmf_eq_1: "(\<integral>\<^sup>+ x. ennreal (pmf p x) \<partial>count_space UNIV) = 1" | |
| 66 | by transfer simp_all | |
| 67 | end | |
| 68 | ||
| 69 | lemma pmf_of_mapping_mapping_of_pmf [code abstype]: | |
| 70 | "pmf_of_mapping (mapping_of_pmf p) = p" | |
| 71 | unfolding pmf_of_mapping_def | |
| 72 | by (rule pmf_eqI, subst pmf_embed_pmf) | |
| 73 | (insert nn_integral_pmf_eq_1[of p], | |
| 74 | auto simp: lookup_default_mapping_of_pmf split: option.splits) | |
| 75 | ||
| 76 | lemma mapping_of_pmfI: | |
| 77 | assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup m x = Some (pmf p x)" | |
| 78 | assumes "Mapping.keys m = set_pmf p" | |
| 79 | shows "mapping_of_pmf p = m" | |
| 80 | using assms by transfer (rule ext, auto simp: set_pmf_eq) | |
| 81 | ||
| 82 | lemma mapping_of_pmfI': | |
| 83 | assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup_default 0 m x = pmf p x" | |
| 84 | assumes "Mapping.keys m = set_pmf p" | |
| 85 | shows "mapping_of_pmf p = m" | |
| 86 | using assms unfolding Mapping.lookup_default_def | |
| 87 | by transfer (rule ext, force simp: set_pmf_eq) | |
| 88 | ||
| 89 | lemma return_pmf_code [code abstract]: | |
| 90 | "mapping_of_pmf (return_pmf x) = Mapping.update x 1 Mapping.empty" | |
| 91 | by (intro mapping_of_pmfI) (auto simp: lookup_update') | |
| 92 | ||
| 93 | lemma pmf_of_set_code_aux: | |
| 94 |   assumes "A \<noteq> {}" "set xs = A" "distinct xs"
 | |
| 95 | shows "mapping_of_pmf (pmf_of_set A) = Mapping.tabulate xs (\<lambda>_. 1 / real (length xs))" | |
| 96 | using assms | |
| 97 | by (intro mapping_of_pmfI, subst pmf_of_set) | |
| 98 | (auto simp: lookup_tabulate distinct_card) | |
| 99 | ||
(* pmf_of_set_impl A is a name for "mapping_of_pmf (pmf_of_set A)"; the [code] equation pmf_of_set_impl_code is stated for this constant. *)
| 100 | definition pmf_of_set_impl where | |
| 101 | "pmf_of_set_impl A = mapping_of_pmf (pmf_of_set A)" | |
| 63195 | 102 | |
| 103 | (* This equation can be used to easily implement pmf_of_set for other set implementations *) | |
| 104 | lemma pmf_of_set_impl_code_alt: | |
| 105 |   assumes "A \<noteq> {}" "finite A"
 | |
| 106 | shows "pmf_of_set_impl A = | |
| 107 | (let p = 1 / real (card A) | |
| 108 | in Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A)" | |
| 109 | proof - | |
| 110 | define p where "p = 1 / real (card A)" | |
| 111 | let ?m = "Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A" | |
| 112 | interpret comp_fun_idem "\<lambda>x. Mapping.update x p" | |
| 113 | by standard (transfer, force simp: fun_eq_iff)+ | |
| 114 | have keys: "Mapping.keys ?m = A" | |
| 115 | using assms(2) by (induction A rule: finite_induct) simp_all | |
| 116 | have lookup: "Mapping.lookup ?m x = Some p" if "x \<in> A" for x | |
| 117 | using assms(2) that by (induction A rule: finite_induct) (auto simp: lookup_update') | |
| 118 | from keys lookup assms show ?thesis unfolding pmf_of_set_impl_def | |
| 119 | by (intro mapping_of_pmfI) (simp_all add: Let_def p_def) | |
| 120 | qed | |
| 121 | ||
(* Code equation for pmf_of_set_impl on sets of the form "set xs". If xs = [] it goes through Code.abort, passing the original term as the thunk argument; otherwise it removes duplicates from xs and tabulates each remaining element with the same weight 1 / (length of the deduplicated list). *)
| 63194 | 122 | lemma pmf_of_set_impl_code [code]: | 
| 123 | "pmf_of_set_impl (set xs) = | |
| 124 | (if xs = [] then | |
| 125 | Code.abort (STR ''pmf_of_set of empty set'') (\<lambda>_. mapping_of_pmf (pmf_of_set (set xs))) | |
| 126 | else let xs' = remdups xs; p = 1 / real (length xs') in | |
| 127 | Mapping.tabulate xs' (\<lambda>_. p))" | |
| 128 | unfolding pmf_of_set_impl_def | |
| 129 | using pmf_of_set_code_aux[of "set xs" "remdups xs"] by (simp add: Let_def) | |
| 130 | ||
| 131 | lemma pmf_of_set_code [code abstract]: | |
| 132 | "mapping_of_pmf (pmf_of_set A) = pmf_of_set_impl A" | |
| 133 | by (simp add: pmf_of_set_impl_def) | |
| 134 | ||
| 135 | ||
| 136 | lemma pmf_of_multiset_pmf_of_mapping: | |
| 137 |   assumes "A \<noteq> {#}" "set xs = set_mset A" "distinct xs"
 | |
| 138 | shows "mapping_of_pmf (pmf_of_multiset A) = Mapping.tabulate xs (\<lambda>x. count A x / real (size A))" | |
| 139 | using assms by (intro mapping_of_pmfI) (auto simp: lookup_tabulate) | |
| 140 | ||
(* pmf_of_multiset_impl A is a name for "mapping_of_pmf (pmf_of_multiset A)"; the [code] equation pmf_of_multiset_impl_code is stated for this constant. *)
| 141 | definition pmf_of_multiset_impl where | |
| 63195 | 142 | "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)" | 
| 143 | ||
| 144 | lemma pmf_of_multiset_impl_code_alt: | |
| 145 |   assumes "A \<noteq> {#}"
 | |
| 146 | shows "pmf_of_multiset_impl A = | |
| 147 | (let p = 1 / real (size A) | |
| 148 | in fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A)" | |
| 149 | proof - | |
| 150 | define p where "p = 1 / real (size A)" | |
| 151 | interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 (op + p)" | |
| 152 | unfolding Mapping.map_default_def [abs_def] | |
| 153 | by (standard, intro mapping_eqI ext) | |
| 154 | (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def) | |
| 155 | let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A" | |
| 156 | have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all | |
| 157 | have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x | |
| 158 | by (induction A) | |
| 159 | (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs) | |
| 160 | from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def | |
| 161 | by (intro mapping_of_pmfI') (simp_all add: Let_def p_def) | |
| 162 | qed | |
| 63194 | 163 | |
(* Code equation for pmf_of_multiset_impl on multisets of the form "mset xs". If xs = [] it goes through Code.abort, passing the original term as the thunk argument. Otherwise the keys are the elements of remdups xs, while p = 1 / length xs uses the list WITH duplicates; each key x gets weight (count of x in mset xs) * p. *)
| 164 | lemma pmf_of_multiset_impl_code [code]: | |
| 165 | "pmf_of_multiset_impl (mset xs) = | |
| 166 | (if xs = [] then | |
| 167 | Code.abort (STR ''pmf_of_multiset of empty multiset'') | |
| 168 | (\<lambda>_. mapping_of_pmf (pmf_of_multiset (mset xs))) | |
| 169 | else let xs' = remdups xs; p = 1 / real (length xs) in | |
| 170 | Mapping.tabulate xs' (\<lambda>x. real (count (mset xs) x) * p))" | |
| 171 | using pmf_of_multiset_pmf_of_mapping[of "mset xs" "remdups xs"] | |
| 63195 | 172 | by (simp add: pmf_of_multiset_impl_def) | 
| 63194 | 173 | |
| 174 | lemma pmf_of_multiset_code [code abstract]: | |
| 175 | "mapping_of_pmf (pmf_of_multiset A) = pmf_of_multiset_impl A" | |
| 176 | by (simp add: pmf_of_multiset_impl_def) | |
| 177 | ||
| 63195 | 178 | |
(* Abstract code equation for bernoulli_pmf p, by cases on p: for p at most 0 the mapping has the single entry False -> 1; for p at least 1 the single entry True -> 1; otherwise the two entries True -> p and False -> 1 - p. *)
| 63194 | 179 | lemma bernoulli_pmf_code [code abstract]: | 
| 180 | "mapping_of_pmf (bernoulli_pmf p) = | |
| 181 | (if p \<le> 0 then Mapping.update False 1 Mapping.empty | |
| 182 | else if p \<ge> 1 then Mapping.update True 1 Mapping.empty | |
| 183 | else Mapping.update False (1 - p) (Mapping.update True p Mapping.empty))" | |
| 184 | by (intro mapping_of_pmfI) (auto simp: bernoulli_pmf.rep_eq lookup_update' set_pmf_eq) | |
| 185 | ||
| 186 | ||
(* Code equation for pmf: the probability of x under p is read from mapping_of_pmf p with default 0 for keys that have no entry. *)
| 187 | lemma pmf_code [code]: "pmf p x = Mapping.lookup_default 0 (mapping_of_pmf p) x" | |
| 188 | unfolding mapping_of_pmf_def Mapping.lookup_default_def | |
| 189 | by (auto split: option.splits simp: id_def Mapping.lookup.abs_eq) | |
| 190 | ||
| 191 | lemma set_pmf_code [code]: "set_pmf p = Mapping.keys (mapping_of_pmf p)" | |
| 192 | by transfer (auto simp: dom_def set_pmf_eq) | |
| 193 | ||
| 194 | lemma keys_mapping_of_pmf [simp]: "Mapping.keys (mapping_of_pmf p) = set_pmf p" | |
| 195 | by transfer (auto simp: dom_def set_pmf_eq) | |
| 196 | ||
| 197 | ||
| 198 | ||
(* fold_combine_plus is comm_monoid_set.F applied to "Mapping.combine (op +)" on real-valued mappings and to Mapping.empty. For finite A, lemma lookup_default_fold_combine_plus below states that looking up x (default 0) in "fold_combine_plus f A" gives the sum over y in A of the lookups of x in f y. *)
| 199 | definition fold_combine_plus where | |
| 200 | "fold_combine_plus = comm_monoid_set.F (Mapping.combine (op + :: real \<Rightarrow> _)) Mapping.empty" | |
| 201 | ||
| 202 | context | |
| 203 | begin | |
| 204 | ||
| 205 | interpretation fold_combine_plus: combine_mapping_abel_semigroup "op + :: real \<Rightarrow> _" | |
| 206 | by unfold_locales (simp_all add: add_ac) | |
| 207 | ||
| 208 | qualified lemma lookup_default_fold_combine_plus: | |
| 209 |   fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping"
 | |
| 210 | assumes "finite A" | |
| 211 | shows "Mapping.lookup_default 0 (fold_combine_plus f A) x = | |
| 212 | (\<Sum>y\<in>A. Mapping.lookup_default 0 (f y) x)" | |
| 213 | unfolding fold_combine_plus_def using assms | |
| 214 | by (induction A rule: finite_induct) | |
| 215 | (simp_all add: lookup_default_empty lookup_default_neutral_combine) | |
| 216 | ||
| 217 | qualified lemma keys_fold_combine_plus: | |
| 218 | "finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))" | |
| 219 | by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine) | |
| 220 | ||
| 221 | qualified lemma fold_combine_plus_code [code]: | |
| 222 | "fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine op+ (g x)) (remdups xs) Mapping.empty" | |
| 223 | by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code) | |
| 224 | ||
| 225 | private lemma lookup_default_0_map_values: | |
| 63195 | 226 | assumes "f x 0 = 0" | 
| 227 | shows "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)" | |
| 63194 | 228 | unfolding Mapping.lookup_default_def | 
| 63195 | 229 | using assms by transfer (auto split: option.splits) | 
| 63194 | 230 | |
| 231 | qualified lemma mapping_of_bind_pmf: | |
| 232 | assumes "finite (set_pmf p)" | |
| 233 | shows "mapping_of_pmf (bind_pmf p f) = | |
| 63195 | 234 | fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x)) | 
| 63194 | 235 | (mapping_of_pmf (f x))) (set_pmf p)" | 
| 236 | using assms | |
| 237 | by (intro mapping_of_pmfI') | |
| 238 | (auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus | |
| 239 | pmf_bind integral_measure_pmf lookup_default_0_map_values | |
| 240 | lookup_default_mapping_of_pmf mult_ac) | |
| 241 | ||
(* bind_pmf_aux p f A: a mapping whose keys are exactly the x lying in set_pmf (f y) for some y in A. For such x the entry is the expectation under p of (indicator A y * pmf (f y) x) as a function of y; all other x have no entry. Lemma bind_pmf_aux_correct below relates the instance A = set_pmf p to mapping_of_pmf (bind_pmf p f). *)
| 63195 | 242 | lift_definition bind_pmf_aux :: "'a pmf \<Rightarrow> ('a \<Rightarrow> 'b pmf) \<Rightarrow> 'a set \<Rightarrow> ('b, real) mapping" is
 | 
| 243 | "\<lambda>(p :: 'a pmf) (f :: 'a \<Rightarrow> 'b pmf) (A::'a set) (x::'b). | |
| 244 | if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then | |
| 245 | Some (measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)) | |
| 246 | else None" . | |
| 247 | ||
| 248 | lemma keys_bind_pmf_aux [simp]: | |
| 249 | "Mapping.keys (bind_pmf_aux p f A) = (\<Union>x\<in>A. set_pmf (f x))" | |
| 250 | by transfer (auto split: if_splits) | |
| 251 | ||
| 252 | lemma lookup_default_bind_pmf_aux: | |
| 253 | "Mapping.lookup_default 0 (bind_pmf_aux p f A) x = | |
| 254 | (if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then | |
| 255 | measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x) else 0)" | |
| 256 | unfolding lookup_default_def by transfer' simp_all | |
| 257 | ||
| 258 | lemma lookup_default_bind_pmf_aux' [simp]: | |
| 259 | "Mapping.lookup_default 0 (bind_pmf_aux p f (set_pmf p)) x = pmf (bind_pmf p f) x" | |
| 260 | unfolding lookup_default_def | |
| 261 | by transfer (auto simp: pmf_bind AE_measure_pmf_iff set_pmf_eq | |
| 262 | intro!: integral_cong_AE integral_eq_zero_AE) | |
| 263 | ||
| 264 | lemma bind_pmf_aux_correct: | |
| 265 | "mapping_of_pmf (bind_pmf p f) = bind_pmf_aux p f (set_pmf p)" | |
| 266 | by (intro mapping_of_pmfI') simp_all | |
| 267 | ||
| 268 | lemma bind_pmf_aux_code_aux: | |
| 269 | assumes "finite A" | |
| 270 | shows "bind_pmf_aux p f A = | |
| 271 | fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x)) | |
| 272 | (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs") | |
| 273 | proof (intro mapping_eqI'[where d = 0]) | |
| 274 | fix x assume "x \<in> Mapping.keys ?lhs" | |
| 275 | then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto | |
| 276 | hence "Mapping.lookup_default 0 ?lhs x = | |
| 277 | measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)" | |
| 278 | by (auto simp: lookup_default_bind_pmf_aux) | |
| 279 | also from assms have "\<dots> = (\<Sum>y\<in>A. pmf p y * pmf (f y) x)" | |
| 280 | by (subst integral_measure_pmf [of A]) | |
| 281 | (auto simp: set_pmf_eq indicator_def mult_ac split: if_splits) | |
| 282 | also from assms have "\<dots> = Mapping.lookup_default 0 ?rhs x" | |
| 283 | by (simp add: lookup_default_fold_combine_plus lookup_default_0_map_values | |
| 284 | lookup_default_mapping_of_pmf) | |
| 285 | finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" . | |
| 286 | qed (insert assms, simp_all add: keys_fold_combine_plus) | |
| 287 | ||
| 288 | lemma bind_pmf_aux_code [code]: | |
| 289 | "bind_pmf_aux p f (set xs) = | |
| 290 | fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x)) | |
| 291 | (mapping_of_pmf (f x))) (set xs)" | |
| 292 | by (rule bind_pmf_aux_code_aux) simp_all | |
| 293 | ||
| 294 | lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct | |
| 63194 | 295 | |
| 296 | end | |
| 297 | ||
| 63195 | 298 | hide_const (open) fold_combine_plus | 
| 63194 | 299 | |
| 300 | ||
(* cond_pmf_impl p A: None when A and set_pmf p are disjoint; otherwise Some of the mapping whose keys are the elements of the intersection of A and set_pmf p, each x carrying the value pmf p x / measure_pmf.prob p A. *)
| 301 | lift_definition cond_pmf_impl :: "'a pmf \<Rightarrow> 'a set \<Rightarrow> ('a, real) mapping option" is
 | |
| 302 |   "\<lambda>p A. if A \<inter> set_pmf p = {} then None else 
 | |
| 303 | Some (\<lambda>x. if x \<in> A \<inter> set_pmf p then Some (pmf p x / measure_pmf.prob p A) else None)" . | |
| 304 | ||
| 63195 | 305 | lemma cond_pmf_impl_code_alt: | 
| 306 | assumes "finite A" | |
| 307 | shows "cond_pmf_impl p A = ( | |
| 308 | let C = A \<inter> set_pmf p; | |
| 309 | prob = (\<Sum>x\<in>C. pmf p x) | |
| 310 | in if prob = 0 then | |
| 311 | None | |
| 312 | else | |
| 313 | Some (Mapping.map_values (\<lambda>_ y. y / prob) | |
| 314 | (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))" | |
| 63194 | 315 | proof - | 
| 63195 | 316 | define C where "C = A \<inter> set_pmf p" | 
| 317 | define prob where "prob = (\<Sum>x\<in>C. pmf p x)" | |
| 318 | also note C_def | |
| 319 | also from assms have "(\<Sum>x\<in>A \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>A. pmf p x)" | |
| 64267 | 320 | by (intro sum.mono_neutral_left) (auto simp: set_pmf_eq) | 
| 63195 | 321 | finally have prob1: "prob = (\<Sum>x\<in>A. pmf p x)" . | 
| 322 | hence prob2: "prob = measure_pmf.prob p A" | |
| 323 | using assms by (subst measure_measure_pmf_finite) simp_all | |
| 324 |   have prob3: "prob = 0 \<longleftrightarrow> A \<inter> set_pmf p = {}"
 | |
| 64267 | 325 | by (subst prob1, subst sum_nonneg_eq_0_iff) (auto simp: set_pmf_eq assms) | 
| 63195 | 326 | from assms have prob4: "prob = measure_pmf.prob p C" | 
| 327 | unfolding prob_def by (intro measure_measure_pmf_finite [symmetric]) (simp_all add: C_def) | |
| 63194 | 328 | |
| 329 | show ?thesis | |
| 330 | proof (cases "prob = 0") | |
| 331 | case True | |
| 63195 | 332 |     hence "A \<inter> set_pmf p = {}" by (subst (asm) prob3)
 | 
| 333 | with True show ?thesis by (simp add: Let_def prob_def C_def cond_pmf_impl.abs_eq) | |
| 63194 | 334 | next | 
| 335 | case False | |
| 63195 | 336 |     hence A: "C \<noteq> {}" unfolding C_def by (subst (asm) prob3) auto
 | 
| 337 | with prob3 have prob_nz: "prob \<noteq> 0" by (auto simp: C_def) | |
| 63194 | 338 | fix x | 
| 63195 | 339 | have "cond_pmf_impl p A = | 
| 340 | Some (mapping.Mapping (\<lambda>x. if x \<in> C then | |
| 341 | Some (pmf p x / measure_pmf.prob p C) else None))" | |
| 63194 | 342 | (is "_ = Some ?m") | 
| 63195 | 343 | using A prob2 prob4 unfolding C_def by transfer (auto simp: fun_eq_iff) | 
| 344 | also have "?m = Mapping.map_values (\<lambda>_ y. y / prob) | |
| 345 | (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))" | |
| 346 | using prob_nz prob4 assms unfolding C_def | |
| 347 | by transfer (auto simp: fun_eq_iff set_pmf_eq) | |
| 348 | finally show ?thesis using False by (simp add: Let_def prob_def C_def) | |
| 63194 | 349 | qed | 
| 350 | qed | |
| 351 | ||
| 63195 | 352 | lemma cond_pmf_impl_code [code]: | 
| 353 | "cond_pmf_impl p (set xs) = ( | |
| 354 | let C = set xs \<inter> set_pmf p; | |
| 355 | prob = (\<Sum>x\<in>C. pmf p x) | |
| 356 | in if prob = 0 then | |
| 357 | None | |
| 358 | else | |
| 359 | Some (Mapping.map_values (\<lambda>_ y. y / prob) | |
| 360 | (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))" | |
| 361 | by (rule cond_pmf_impl_code_alt) simp_all | |
| 362 | ||
| 63194 | 363 | lemma cond_pmf_code [code abstract]: | 
| 364 | "mapping_of_pmf (cond_pmf p A) = | |
| 365 | (case cond_pmf_impl p A of | |
| 366 | None \<Rightarrow> Code.abort (STR ''cond_pmf with set of probability 0'') | |
| 367 | (\<lambda>_. mapping_of_pmf (cond_pmf p A)) | |
| 368 | | Some m \<Rightarrow> m)" | |
| 369 | proof (cases "cond_pmf_impl p A") | |
| 370 | case (Some m) | |
| 371 |   hence A: "set_pmf p \<inter> A \<noteq> {}" by transfer (auto split: if_splits)
 | |
| 372 | from Some have B: "Mapping.keys m = set_pmf (cond_pmf p A)" | |
| 373 | by (subst set_cond_pmf[OF A], transfer) (auto split: if_splits) | |
| 374 | with Some A have "mapping_of_pmf (cond_pmf p A) = m" | |
| 375 | by (intro mapping_of_pmfI[OF _ B], transfer) (auto split: if_splits simp: pmf_cond) | |
| 376 | with Some show ?thesis by simp | |
| 377 | qed simp_all | |
| 378 | ||
| 379 | ||
(* Abstract code equation for binomial_pmf n p. For p outside [0, 1] it goes through Code.abort, passing the original term as the thunk argument. p = 0 and p = 1 are handled separately with the single entries 0 -> 1 and n -> 1 respectively. Otherwise every k in [0..<Suc n] is tabulated with weight (n choose k) * p ^ k * (1 - p) ^ (n - k). *)
| 380 | lemma binomial_pmf_code [code abstract]: | |
| 381 | "mapping_of_pmf (binomial_pmf n p) = ( | |
| 382 | if p < 0 \<or> p > 1 then | |
| 63195 | 383 | Code.abort (STR ''binomial_pmf with invalid probability'') | 
| 384 | (\<lambda>_. mapping_of_pmf (binomial_pmf n p)) | |
| 63194 | 385 | else if p = 0 then Mapping.update 0 1 Mapping.empty | 
| 386 | else if p = 1 then Mapping.update n 1 Mapping.empty | |
| 387 | else Mapping.tabulate [0..<Suc n] (\<lambda>k. real (n choose k) * p ^ k * (1 - p) ^ (n - k)))" | |
| 388 | by (cases "p < 0 \<or> p > 1") | |
| 389 | (simp, intro mapping_of_pmfI, | |
| 390 | auto simp: lookup_update' lookup_empty set_pmf_binomial_eq lookup_tabulate split: if_splits) | |
| 391 | ||
| 63195 | 392 | |
| 63194 | 393 | lemma pred_pmf_code [code]: | 
| 394 | "pred_pmf P p = (\<forall>x\<in>set_pmf p. P x)" | |
| 395 | by (auto simp: pred_pmf_def) | |
| 396 | ||
| 397 | ||
| 398 | lemma mapping_of_pmf_pmf_of_list: | |
| 63882 
018998c00003
renamed listsum -> sum_list, listprod ~> prod_list
 nipkow parents: 
63793diff
changeset | 399 | assumes "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0" "sum_list (map snd xs) = 1" | 
| 63194 | 400 | shows "mapping_of_pmf (pmf_of_list xs) = | 
| 401 | Mapping.tabulate (remdups (map fst xs)) | |
| 63882 
018998c00003
renamed listsum -> sum_list, listprod ~> prod_list
 nipkow parents: 
63793diff
changeset | 402 | (\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs)))" | 
| 63194 | 403 | proof - | 
| 404 | from assms have wf: "pmf_of_list_wf xs" by (intro pmf_of_list_wfI) force | |
| 63539 | 405 | with assms have "set_pmf (pmf_of_list xs) = fst ` set xs" | 
| 63194 | 406 | by (intro set_pmf_of_list_eq) auto | 
| 63539 | 407 | with wf show ?thesis | 
| 63194 | 408 | by (intro mapping_of_pmfI) (auto simp: lookup_tabulate pmf_pmf_of_list) | 
| 409 | qed | |
| 410 | ||
| 411 | lemma mapping_of_pmf_pmf_of_list': | |
| 412 | assumes "pmf_of_list_wf xs" | |
| 413 | defines "xs' \<equiv> filter (\<lambda>z. snd z \<noteq> 0) xs" | |
| 414 | shows "mapping_of_pmf (pmf_of_list xs) = | |
| 415 | Mapping.tabulate (remdups (map fst xs')) | |
| 63882 
018998c00003
renamed listsum -> sum_list, listprod ~> prod_list
 nipkow parents: 
63793diff
changeset | 416 | (\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs')))" (is "_ = ?rhs") | 
| 63194 | 417 | proof - | 
| 418 | have wf: "pmf_of_list_wf xs'" unfolding xs'_def by (rule pmf_of_list_remove_zeros) fact | |
| 419 | have pos: "\<forall>x\<in>snd`set xs'. x > 0" using assms(1) unfolding xs'_def | |
| 420 | by (force simp: pmf_of_list_wf_def) | |
| 421 | from assms have "pmf_of_list xs = pmf_of_list xs'" | |
| 422 | unfolding xs'_def by (subst pmf_of_list_remove_zeros) simp_all | |
| 423 | also from wf pos have "mapping_of_pmf \<dots> = ?rhs" | |
| 424 | by (intro mapping_of_pmf_pmf_of_list) (auto simp: pmf_of_list_wf_def) | |
| 425 | finally show ?thesis . | |
| 426 | qed | |
| 427 | ||
| 428 | lemma pmf_of_list_wf_code [code]: | |
| 63882 
018998c00003
renamed listsum -> sum_list, listprod ~> prod_list
 nipkow parents: 
63793diff
changeset | 429 | "pmf_of_list_wf xs \<longleftrightarrow> list_all (\<lambda>z. snd z \<ge> 0) xs \<and> sum_list (map snd xs) = 1" | 
| 63194 | 430 | by (auto simp add: pmf_of_list_wf_def list_all_def) | 
| 431 | ||
| 432 | lemma pmf_of_list_code [code abstract]: | |
| 433 | "mapping_of_pmf (pmf_of_list xs) = ( | |
| 434 | if pmf_of_list_wf xs then | |
| 435 | let xs' = filter (\<lambda>z. snd z \<noteq> 0) xs | |
| 436 | in Mapping.tabulate (remdups (map fst xs')) | |
| 63882 
018998c00003
renamed listsum -> sum_list, listprod ~> prod_list
 nipkow parents: 
63793diff
changeset | 437 | (\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs'))) | 
| 63194 | 438 | else | 
| 439 | Code.abort (STR ''Invalid list for pmf_of_list'') (\<lambda>_. mapping_of_pmf (pmf_of_list xs)))" | |
| 440 | using mapping_of_pmf_pmf_of_list'[of xs] by (simp add: Let_def) | |
| 441 | ||
| 442 | lemma mapping_of_pmf_eq_iff [simp]: | |
| 443 | "mapping_of_pmf p = mapping_of_pmf q \<longleftrightarrow> p = (q :: 'a pmf)" | |
| 444 | proof (transfer, intro iffI pmf_eqI) | |
| 445 | fix p q :: "'a pmf" and x :: 'a | |
| 446 | assume "(\<lambda>x. if pmf p x = 0 then None else Some (pmf p x)) = | |
| 447 | (\<lambda>x. if pmf q x = 0 then None else Some (pmf q x))" | |
| 448 | hence "(if pmf p x = 0 then None else Some (pmf p x)) = | |
| 449 | (if pmf q x = 0 then None else Some (pmf q x))" for x | |
| 450 | by (simp add: fun_eq_iff) | |
| 451 | from this[of x] show "pmf p x = pmf q x" by (auto split: if_splits) | |
| 452 | qed (simp_all cong: if_cong) | |
| 453 | ||
| 63195 | 454 | |
| 455 | subsection \<open>Code abbreviations for integrals and probabilities\<close> | |
| 456 | ||
| 457 | text \<open> | |
| 458 | Integrals and probabilities are defined for general measures, so we cannot give any | |
| 459 | code equations directly. We can, however, specialise these constants to PMFs, | |
| 460 | give code equations for these specialised constants, and tell the code generator | |
| 461 | to unfold the original constants to the specialised ones whenever possible. | |
| 462 | \<close> | |
| 463 | ||
(* pmf_integral p f: lebesgue_integral of the real-valued function f with respect to measure_pmf p. Lemma pmf_integral_code_unfold below registers it as a code abbreviation for measure_pmf.expectation p. *)
| 464 | definition pmf_integral where | |
| 465 | "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)" | |
| 466 | ||
(* pmf_set_integral p f A: lebesgue_integral with respect to measure_pmf p of the function mapping x to indicator A x * f x. For finite A, lemma pmf_set_integral_code_alt_finite below rewrites it to the sum over x in A of pmf p x * f x. *)
| 467 | definition pmf_set_integral where | |
| 468 | "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)" | |
| 469 | ||
(* pmf_prob p A is a name for measure_pmf.prob p A. Code equations for it are given in pmf_prob_code below, and pmf_prob_code_unfold registers it as a code abbreviation. *)
| 470 | definition pmf_prob where | |
| 471 | "pmf_prob p A = measure_pmf.prob p A" | |
| 472 | ||
| 473 | lemma pmf_prob_compl: "pmf_prob p (-A) = 1 - pmf_prob p A" | |
| 474 | using measure_pmf.prob_compl[of A p] by (simp add: pmf_prob_def Compl_eq_Diff_UNIV) | |
| 475 | ||
| 476 | lemma pmf_integral_pmf_set_integral [code]: | |
| 477 | "pmf_integral p f = pmf_set_integral p f (set_pmf p)" | |
| 478 | unfolding pmf_integral_def pmf_set_integral_def | |
| 479 | by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff) | |
| 480 | ||
| 481 | lemma pmf_prob_pmf_set_integral: | |
| 482 | "pmf_prob p A = pmf_set_integral p (\<lambda>_. 1) A" | |
| 483 | by (simp add: pmf_prob_def pmf_set_integral_def) | |
| 484 | ||
| 485 | lemma pmf_set_integral_code_alt_finite: | |
| 486 | "finite A \<Longrightarrow> pmf_set_integral p f A = (\<Sum>x\<in>A. pmf p x * f x)" | |
| 487 | unfolding pmf_set_integral_def | |
| 488 | by (subst integral_measure_pmf[of A]) (auto simp: indicator_def mult_ac split: if_splits) | |
| 489 | ||
| 490 | lemma pmf_set_integral_code [code]: | |
| 491 | "pmf_set_integral p f (set xs) = (\<Sum>x\<in>set xs. pmf p x * f x)" | |
| 492 | by (rule pmf_set_integral_code_alt_finite) simp_all | |
| 493 | ||
| 494 | ||
| 495 | lemma pmf_prob_code_alt_finite: | |
| 496 | "finite A \<Longrightarrow> pmf_prob p A = (\<Sum>x\<in>A. pmf p x)" | |
| 497 | by (simp add: pmf_prob_pmf_set_integral pmf_set_integral_code_alt_finite) | |
| 498 | ||
| 499 | lemma pmf_prob_code [code]: | |
| 500 | "pmf_prob p (set xs) = (\<Sum>x\<in>set xs. pmf p x)" | |
| 501 | "pmf_prob p (List.coset xs) = 1 - (\<Sum>x\<in>set xs. pmf p x)" | |
| 502 | by (simp_all add: pmf_prob_code_alt_finite pmf_prob_compl) | |
| 503 | ||
| 504 | ||
| 505 | lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p" | |
| 506 | by (intro ext) (simp add: pmf_prob_def) | |
| 507 | ||
| 508 | (* FIXME: Why does this not work without parameters? *) | |
| 509 | lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p" | |
| 510 | by (intro ext) (simp add: pmf_integral_def) | |
| 511 | ||
| 512 | ||
| 513 | ||
(* pmf_of_alist xs: embed_pmf applied to the function that looks x up in the association list xs via map_of, returning the stored value when there is one and 0 otherwise. It is the right-hand side of the [code_post] lemma pmf_of_mapping_Mapping below. *)
| 63194 | 514 | definition "pmf_of_alist xs = embed_pmf (\<lambda>x. case map_of xs x of Some p \<Rightarrow> p | None \<Rightarrow> 0)" | 
| 515 | ||
| 516 | lemma pmf_of_mapping_Mapping [code_post]: | |
| 517 | "pmf_of_mapping (Mapping xs) = pmf_of_alist xs" | |
| 518 | unfolding pmf_of_mapping_def Mapping.lookup_default_def [abs_def] pmf_of_alist_def | |
| 519 | by transfer simp_all | |
| 520 | ||
| 521 | ||
(* Instance of the "equal" class for PMFs: equal_pmf p q is defined as equality of the mappings mapping_of_pmf p and mapping_of_pmf q. The instance proof is by simp with equal_pmf_def (mapping_of_pmf_eq_iff above is declared [simp]). *)
| 522 | instantiation pmf :: (equal) equal | |
| 523 | begin | |
| 524 | ||
| 525 | definition "equal_pmf p q = (mapping_of_pmf p = mapping_of_pmf (q :: 'a pmf))" | |
| 526 | ||
| 527 | instance by standard (simp add: equal_pmf_def) | |
| 528 | end | |
| 529 | ||
| 63793 
e68a0b651eb5
add_mset constructor in multisets
 fleury <Mathias.Fleury@mpi-inf.mpg.de> parents: 
63539diff
changeset | 530 | definition single :: "'a \<Rightarrow> 'a multiset" where | 
| 
e68a0b651eb5
add_mset constructor in multisets
 fleury <Mathias.Fleury@mpi-inf.mpg.de> parents: 
63539diff
changeset | 531 | "single s = {#s#}"
 | 
| 63194 | 532 | |
| 533 | definition (in term_syntax) | |
| 534 |   pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow>
 | |
| 535 | 'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow> | |
| 536 | 'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where | |
| 537 | [code_unfold]: "pmfify A x = | |
| 538 |     Code_Evaluation.valtermify pmf_of_multiset {\<cdot>} 
 | |
| 539 |       (Code_Evaluation.valtermify (op +) {\<cdot>} A {\<cdot>} 
 | |
| 540 |        (Code_Evaluation.valtermify single {\<cdot>} x))"
 | |
| 541 | ||
| 542 | ||
| 543 | notation fcomp (infixl "\<circ>>" 60) | |
| 544 | notation scomp (infixl "\<circ>\<rightarrow>" 60) | |
| 545 | ||
| 546 | instantiation pmf :: (random) random | |
| 547 | begin | |
| 548 | ||
| 549 | definition | |
| 550 | "Quickcheck_Random.random i = | |
| 551 | Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>A. | |
| 552 | Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>x. Pair (pmfify A x)))" | |
| 553 | ||
| 554 | instance .. | |
| 555 | ||
| 556 | end | |
| 557 | ||
| 558 | no_notation fcomp (infixl "\<circ>>" 60) | |
| 559 | no_notation scomp (infixl "\<circ>\<rightarrow>" 60) | |
| 560 | ||
| 561 | instantiation pmf :: (full_exhaustive) full_exhaustive | |
| 562 | begin | |
| 563 | ||
| 564 | definition full_exhaustive_pmf :: "('a pmf \<times> (unit \<Rightarrow> term) \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
 | |
| 565 | where | |
| 566 | "full_exhaustive_pmf f i = | |
| 567 | Quickcheck_Exhaustive.full_exhaustive (\<lambda>A. | |
| 568 | Quickcheck_Exhaustive.full_exhaustive (\<lambda>x. f (pmfify A x)) i) i" | |
| 569 | ||
| 570 | instance .. | |
| 571 | ||
| 572 | end | |
| 573 | ||
| 64267 | 574 | end |