author | paulson <lp15@cam.ac.uk> |
Wed, 25 Mar 2020 12:37:57 +0000 | |
changeset 71591 | 8e4d542f041b |
parent 69064 | 5840724b1d71 |
child 72581 | de581f98a3a1 |
permissions | -rw-r--r-- |
63194 | 1 |
(* Title: HOL/Probability/PMF_Impl.thy |
2 |
Author: Manuel Eberl, TU München |
|
3 |
||
4 |
An implementation of PMFs using Mappings, which are implemented with association lists |
|
5 |
by default. Also includes Quickcheck setup for PMFs. |
|
6 |
*) |
|
7 |
||
63195 | 8 |
section \<open>Code generation for PMFs\<close> |
9 |
||
63194 | 10 |
theory PMF_Impl |
66453
cc19f7ca2ed6
session-qualified theory imports: isabelle imports -U -i -d '~~/src/Benchmarks' -a;
wenzelm
parents:
64267
diff
changeset
|
11 |
imports Probability_Mass_Function "HOL-Library.AList_Mapping" |
63194 | 12 |
begin |
13 |
||
63195 | 14 |
subsection \<open>General code generation setup\<close> |
15 |
||
63194 | 16 |
(* Builds a PMF from a mapping of keys to reals by embedding the mapping's
   lookup function, where keys without an entry are looked up as 0. *)
definition pmf_of_mapping :: "('a, real) mapping \<Rightarrow> 'a pmf" where |
17 |
"pmf_of_mapping m = embed_pmf (Mapping.lookup_default 0 m)" |
|
18 |
||
19 |
(* For a mapping with finitely many keys and only nonnegative values, the
   nn_integral of its default-0 lookup over count_space UNIV equals the finite
   sum of the lookups over the mapping's keys. *)
lemma nn_integral_lookup_default: |

20 |
fixes m :: "('a, real) mapping" |

21 |
assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ x. x \<ge> 0)" |

22 |
shows "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) = |

23 |
ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)" |

24 |
proof - |

25 |
(* Step 1: rewrite the integral as a finite sum of ennreal values over the keys. *)
have "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) = |

26 |
(\<Sum>x\<in>Mapping.keys m. ennreal (Mapping.lookup_default 0 m x))" using assms |

27 |
by (subst nn_integral_count_space'[of "Mapping.keys m"]) |

28 |
(auto simp: Mapping.lookup_default_def keys_is_none_rep Option.is_none_def) |

29 |
(* Step 2: pull ennreal out of the sum. *)
also from assms have "\<dots> = ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)" |

64267 | 30 |
by (intro sum_ennreal) |
63194 | 31 |
(auto simp: Mapping.lookup_default_def All_mapping_def split: option.splits) |
32 |
finally show ?thesis . |

33 |
qed |
|
34 |
||
35 |
(* Characterises pmf_of_mapping: if the mapping has finitely many keys, only
   nonnegative values, and its values sum to 1, then the probability of x under
   pmf_of_mapping m is the default-0 lookup of x in m. *)
lemma pmf_of_mapping: |

36 |
assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ p. p \<ge> 0)" |

37 |
assumes "(\<Sum>x\<in>Mapping.keys m. Mapping.lookup_default 0 m x) = 1" |

38 |
shows "pmf (pmf_of_mapping m) x = Mapping.lookup_default 0 m x" |

39 |
unfolding pmf_of_mapping_def |

40 |
proof (intro pmf_embed_pmf) |

41 |
(* The total mass is 1, via nn_integral_lookup_default and the sum assumption. *)
from assms show "(\<integral>\<^sup>+x. ennreal (Mapping.lookup_default 0 m x) \<partial>count_space UNIV) = 1" |

42 |
by (subst nn_integral_lookup_default) (simp_all) |

43 |
qed (insert assms, simp add: All_mapping_def Mapping.lookup_default_def split: option.splits) |
|
44 |
||
45 |
(* If the distinct list xs enumerates the non-empty set A, then pmf_of_set A
   equals the PMF built from the mapping that assigns 1 / length xs to every
   element of xs. *)
lemma pmf_of_set_pmf_of_mapping: |

46 |
assumes "A \<noteq> {}" "set xs = A" "distinct xs" |

47 |
shows "pmf_of_set A = pmf_of_mapping (Mapping.tabulate xs (\<lambda>_. 1 / real (length xs)))" |

48 |
(* NOTE(review): the ?lhs / ?rhs abbreviations bound here do not appear in the proof below. *)
(is "?lhs = ?rhs") |

49 |
by (rule pmf_eqI, subst pmf_of_mapping) |

50 |
(insert assms, auto intro!: All_mapping_tabulate |

51 |
simp: Mapping.lookup_default_def lookup_tabulate distinct_card) |
|
52 |
||
53 |
(* Converts a PMF into a mapping: a point x with pmf p x = 0 gets no entry
   (None); every other point x is stored with the value pmf p x. *)
lift_definition mapping_of_pmf :: "'a pmf \<Rightarrow> ('a, real) mapping" is |

54 |
"\<lambda>p x. if pmf p x = 0 then None else Some (pmf p x)" . |
|
55 |
||
56 |
(* Looking up x with default 0 in mapping_of_pmf p yields pmf p x. *)
lemma lookup_default_mapping_of_pmf: |

57 |
"Mapping.lookup_default 0 (mapping_of_pmf p) x = pmf p x" |

58 |
by (simp add: mapping_of_pmf.abs_eq lookup_default_def Mapping.lookup.abs_eq) |
|
59 |
||
60 |
(* Anonymous context: pmf_as_function is interpreted only inside this block.
   NOTE(review): presumably the interpretation supplies the setup that the
   "transfer" step below relies on; confirm against Probability_Mass_Function. *)
context |

61 |
begin |

62 |
||
63 |
interpretation pmf_as_function . |

64 |
||
65 |
(* The nn_integral of the mass function of any PMF over count_space UNIV is 1. *)
lemma nn_integral_pmf_eq_1: "(\<integral>\<^sup>+ x. ennreal (pmf p x) \<partial>count_space UNIV) = 1" |

66 |
by transfer simp_all |

67 |
end |
|
68 |
||
69 |
(* pmf_of_mapping is a left inverse of mapping_of_pmf. The lemma carries the
   [code abstype] attribute, i.e. it is the equation given to the code generator
   for treating pmf as an abstract type over mappings. *)
lemma pmf_of_mapping_mapping_of_pmf [code abstype]: |

70 |
"pmf_of_mapping (mapping_of_pmf p) = p" |

71 |
unfolding pmf_of_mapping_def |

72 |
by (rule pmf_eqI, subst pmf_embed_pmf) |

73 |
(insert nn_integral_pmf_eq_1[of p], |

74 |
auto simp: lookup_default_mapping_of_pmf split: option.splits) |
|
75 |
||
76 |
(* Introduction rule for equations of the form mapping_of_pmf p = m: it suffices
   that the keys of m are exactly set_pmf p and that every key x of m is stored
   with the value pmf p x. *)
lemma mapping_of_pmfI: |

77 |
assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup m x = Some (pmf p x)" |

78 |
assumes "Mapping.keys m = set_pmf p" |

79 |
shows "mapping_of_pmf p = m" |

80 |
using assms by transfer (rule ext, auto simp: set_pmf_eq) |
|
81 |
||
82 |
(* Variant of mapping_of_pmfI whose first premise is phrased with
   Mapping.lookup_default 0 instead of Mapping.lookup. *)
lemma mapping_of_pmfI': |

83 |
assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup_default 0 m x = pmf p x" |

84 |
assumes "Mapping.keys m = set_pmf p" |

85 |
shows "mapping_of_pmf p = m" |

86 |
using assms unfolding Mapping.lookup_default_def |

87 |
by transfer (rule ext, force simp: set_pmf_eq) |
|
88 |
||
89 |
(* Code equation for return_pmf: the mapping with the single entry x |-> 1. *)
lemma return_pmf_code [code abstract]: |

90 |
"mapping_of_pmf (return_pmf x) = Mapping.update x 1 Mapping.empty" |

91 |
by (intro mapping_of_pmfI) (auto simp: lookup_update') |
|
92 |
||
93 |
(* Auxiliary lemma for the pmf_of_set code equations: if the distinct list xs
   enumerates the non-empty set A, then mapping_of_pmf (pmf_of_set A) is the
   tabulation of xs with constant value 1 / length xs. *)
lemma pmf_of_set_code_aux: |

94 |
assumes "A \<noteq> {}" "set xs = A" "distinct xs" |

95 |
shows "mapping_of_pmf (pmf_of_set A) = Mapping.tabulate xs (\<lambda>_. 1 / real (length xs))" |

96 |
using assms |

97 |
by (intro mapping_of_pmfI, subst pmf_of_set) |

98 |
(auto simp: lookup_tabulate distinct_card) |
|
99 |
||
100 |
(* Named constant for mapping_of_pmf (pmf_of_set A); the code equations for
   pmf_of_set in this theory are stated in terms of it. *)
definition pmf_of_set_impl where |

101 |
"pmf_of_set_impl A = mapping_of_pmf (pmf_of_set A)" |
|
63195 | 102 |
|
103 |
(* This equation can be used to easily implement pmf_of_set for other set implementations *) |

104 |
(* For a non-empty finite set A, pmf_of_set_impl A is obtained by folding over A
   and inserting each element with the value 1 / card A into an empty mapping. *)
lemma pmf_of_set_impl_code_alt: |

105 |
assumes "A \<noteq> {}" "finite A" |

106 |
shows "pmf_of_set_impl A = |

107 |
(let p = 1 / real (card A) |

108 |
in Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A)" |

109 |
proof - |

110 |
define p where "p = 1 / real (card A)" |

111 |
let ?m = "Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A" |

112 |
(* Interpreting comp_fun_idem for the update function makes the fold lemmas
   available for the two inductions below. *)
interpret comp_fun_idem "\<lambda>x. Mapping.update x p" |

113 |
by standard (transfer, force simp: fun_eq_iff)+ |

114 |
(* The folded mapping has exactly the elements of A as keys ... *)
have keys: "Mapping.keys ?m = A" |

115 |
using assms(2) by (induction A rule: finite_induct) simp_all |

116 |
(* ... and each of them is stored with the value p. *)
have lookup: "Mapping.lookup ?m x = Some p" if "x \<in> A" for x |

117 |
using assms(2) that by (induction A rule: finite_induct) (auto simp: lookup_update') |

118 |
from keys lookup assms show ?thesis unfolding pmf_of_set_impl_def |

119 |
by (intro mapping_of_pmfI) (simp_all add: Let_def p_def) |

120 |
qed |
|
121 |
||
63194 | 122 |
(* Code equation for pmf_of_set_impl on list-represented sets. For an empty list
   it goes through Code.abort with the given message; otherwise it removes
   duplicates from xs and tabulates the result with the constant value
   1 / length (remdups xs). *)
lemma pmf_of_set_impl_code [code]: |
123 |
"pmf_of_set_impl (set xs) = |

124 |
(if xs = [] then |

125 |
Code.abort (STR ''pmf_of_set of empty set'') (\<lambda>_. mapping_of_pmf (pmf_of_set (set xs))) |

126 |
else let xs' = remdups xs; p = 1 / real (length xs') in |

127 |
Mapping.tabulate xs' (\<lambda>_. p))" |

128 |
unfolding pmf_of_set_impl_def |

129 |
using pmf_of_set_code_aux[of "set xs" "remdups xs"] by (simp add: Let_def) |
|
130 |
||
131 |
(* Abstract code equation for pmf_of_set: delegates to pmf_of_set_impl. *)
lemma pmf_of_set_code [code abstract]: |

132 |
"mapping_of_pmf (pmf_of_set A) = pmf_of_set_impl A" |

133 |
by (simp add: pmf_of_set_impl_def) |
|
134 |
||
135 |
||
136 |
(* If the distinct list xs enumerates the elements of the non-empty multiset A,
   then mapping_of_pmf (pmf_of_multiset A) is the tabulation of xs that assigns
   count A x / size A to each x. *)
lemma pmf_of_multiset_pmf_of_mapping: |

137 |
assumes "A \<noteq> {#}" "set xs = set_mset A" "distinct xs" |

138 |
shows "mapping_of_pmf (pmf_of_multiset A) = Mapping.tabulate xs (\<lambda>x. count A x / real (size A))" |

139 |
using assms by (intro mapping_of_pmfI) (auto simp: lookup_tabulate) |
|
140 |
||
141 |
(* Named constant for mapping_of_pmf (pmf_of_multiset A); the code equations for
   pmf_of_multiset in this theory are stated in terms of it. *)
definition pmf_of_multiset_impl where |

63195 | 142 |
"pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)" |
143 |
||
144 |
(* For a non-empty multiset A, pmf_of_multiset_impl A is obtained by folding over
   A and, for every occurrence of an element x, adding 1 / size A to the entry of
   x (starting from 0 when x has no entry yet). *)
lemma pmf_of_multiset_impl_code_alt: |

145 |
assumes "A \<noteq> {#}" |

146 |
shows "pmf_of_multiset_impl A = |

147 |
(let p = 1 / real (size A) |

67399 | 148 |
in fold_mset (\<lambda>x. Mapping.map_default x 0 ((+) p)) Mapping.empty A)" |
63195 | 149 |
proof - |
150 |
define p where "p = 1 / real (size A)" |

67399 | 151 |
(* Interpreting comp_fun_commute for the per-element step makes the fold_mset
   lemmas available for the two inductions below. *)
interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 ((+) p)" |
63195 | 152 |
unfolding Mapping.map_default_def [abs_def] |
153 |
by (standard, intro mapping_eqI ext) |

154 |
(simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def) |

67399 | 155 |
let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 ((+) p)) Mapping.empty A" |
63195 | 156 |
(* The folded mapping has exactly the elements of A as keys ... *)
have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all |
157 |
(* ... and each x is stored with its multiplicity times p. *)
have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x |

158 |
by (induction A) |

159 |
(simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs) |

160 |
from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def |

161 |
by (intro mapping_of_pmfI') (simp_all add: Let_def p_def) |

162 |
qed |
|
63194 | 163 |
|
164 |
(* Code equation for pmf_of_multiset_impl on list-represented multisets. For an
   empty list it goes through Code.abort with the given message; otherwise it
   tabulates remdups xs, assigning to each x its count in mset xs times
   1 / length xs. *)
lemma pmf_of_multiset_impl_code [code]: |

165 |
"pmf_of_multiset_impl (mset xs) = |

166 |
(if xs = [] then |

167 |
Code.abort (STR ''pmf_of_multiset of empty multiset'') |

168 |
(\<lambda>_. mapping_of_pmf (pmf_of_multiset (mset xs))) |

169 |
else let xs' = remdups xs; p = 1 / real (length xs) in |

170 |
Mapping.tabulate xs' (\<lambda>x. real (count (mset xs) x) * p))" |

171 |
using pmf_of_multiset_pmf_of_mapping[of "mset xs" "remdups xs"] |

63195 | 172 |
by (simp add: pmf_of_multiset_impl_def) |
63194 | 173 |
|
174 |
(* Abstract code equation for pmf_of_multiset: delegates to pmf_of_multiset_impl. *)
lemma pmf_of_multiset_code [code abstract]: |

175 |
"mapping_of_pmf (pmf_of_multiset A) = pmf_of_multiset_impl A" |

176 |
by (simp add: pmf_of_multiset_impl_def) |
|
177 |
||
63195 | 178 |
|
63194 | 179 |
(* Code equation for bernoulli_pmf p, by cases on p:
     p <= 0     : single entry False |-> 1
     p >= 1     : single entry True |-> 1
     otherwise  : True |-> p and False |-> 1 - p *)
lemma bernoulli_pmf_code [code abstract]: |
180 |
"mapping_of_pmf (bernoulli_pmf p) = |

181 |
(if p \<le> 0 then Mapping.update False 1 Mapping.empty |

182 |
else if p \<ge> 1 then Mapping.update True 1 Mapping.empty |

183 |
else Mapping.update False (1 - p) (Mapping.update True p Mapping.empty))" |

184 |
by (intro mapping_of_pmfI) (auto simp: bernoulli_pmf.rep_eq lookup_update' set_pmf_eq) |
|
185 |
||
186 |
||
187 |
(* Code equation for pmf: the probability of x is the default-0 lookup of x in
   mapping_of_pmf p. *)
lemma pmf_code [code]: "pmf p x = Mapping.lookup_default 0 (mapping_of_pmf p) x" |

188 |
unfolding mapping_of_pmf_def Mapping.lookup_default_def |

189 |
by (auto split: option.splits simp: id_def Mapping.lookup.abs_eq) |
|
190 |
||
191 |
(* Code equation for set_pmf: the support of p is the key set of mapping_of_pmf p. *)
lemma set_pmf_code [code]: "set_pmf p = Mapping.keys (mapping_of_pmf p)" |

192 |
by transfer (auto simp: dom_def set_pmf_eq) |
|
193 |
||
194 |
(* Same equation as set_pmf_code with the sides swapped, declared as a simp rule. *)
lemma keys_mapping_of_pmf [simp]: "Mapping.keys (mapping_of_pmf p) = set_pmf p" |

195 |
by transfer (auto simp: dom_def set_pmf_eq) |
|
196 |
||
197 |
||
198 |
||
199 |
(* Big-operator over sets (comm_monoid_set.F) for real-valued mappings, using
   Mapping.combine with addition as the binary operation and Mapping.empty as
   the neutral element. *)
definition fold_combine_plus where |

67399 | 200 |
"fold_combine_plus = comm_monoid_set.F (Mapping.combine ((+) :: real \<Rightarrow> _)) Mapping.empty" |
63194 | 201 |
|
202 |
(* Anonymous context holding the lemmas about fold_combine_plus and the code
   setup for bind_pmf. Lemmas marked "qualified" / "private" are declared with
   those visibility modifiers; the interpretation below is local to this block. *)
context |

203 |
begin |

204 |
||
67399 | 205 |
interpretation fold_combine_plus: combine_mapping_abel_semigroup "(+) :: real \<Rightarrow> _" |
63194 | 206 |
by unfold_locales (simp_all add: add_ac) |
207 |
||
208 |
(* Default-0 lookup commutes with fold_combine_plus over a finite set: the
   lookup in the combined mapping is the sum of the individual lookups. *)
qualified lemma lookup_default_fold_combine_plus: |

209 |
fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping" |

210 |
assumes "finite A" |

211 |
shows "Mapping.lookup_default 0 (fold_combine_plus f A) x = |

212 |
(\<Sum>y\<in>A. Mapping.lookup_default 0 (f y) x)" |

213 |
unfolding fold_combine_plus_def using assms |

214 |
by (induction A rule: finite_induct) |

215 |
(simp_all add: lookup_default_empty lookup_default_neutral_combine) |

216 |
||
217 |
(* For finite A, the keys of the combined mapping are the union of the keys of
   the individual mappings. *)
qualified lemma keys_fold_combine_plus: |

218 |
"finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))" |

219 |
by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine) |

220 |
||
221 |
(* Code equation: on a list-represented set, fold_combine_plus is a foldr over
   the de-duplicated list, combining with addition, starting from the empty
   mapping. *)
qualified lemma fold_combine_plus_code [code]: |

67399 | 222 |
"fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine (+) (g x)) (remdups xs) Mapping.empty" |
63194 | 223 |
by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code) |
224 |
||
225 |
(* If f x maps 0 to 0, then default-0 lookup at x commutes with
   Mapping.map_values f. *)
private lemma lookup_default_0_map_values: |

63195 | 226 |
assumes "f x 0 = 0" |
227 |
shows "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)" |

63194 | 228 |
unfolding Mapping.lookup_default_def |
63195 | 229 |
using assms by transfer (auto split: option.splits) |
63194 | 230 |

231 |
(* For p with finite support, the mapping of bind_pmf p f is the combination,
   over all x in the support of p, of the mapping of f x with every value
   multiplied by pmf p x. *)
qualified lemma mapping_of_bind_pmf: |

232 |
assumes "finite (set_pmf p)" |

233 |
shows "mapping_of_pmf (bind_pmf p f) = |

69064
5840724b1d71
Prefix form of infix with * on either side no longer needs special treatment
nipkow
parents:
67399
diff
changeset
|
234 |
fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. (*) (pmf p x)) |
63194 | 235 |
(mapping_of_pmf (f x))) (set_pmf p)" |
236 |
using assms |

237 |
by (intro mapping_of_pmfI') |

238 |
(auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus |

239 |
pmf_bind integral_measure_pmf lookup_default_0_map_values |

240 |
lookup_default_mapping_of_pmf mult_ac) |

241 |
||
63195 | 242 |
(* Auxiliary mapping for bind_pmf with the set A as an explicit argument: x has
   an entry iff it lies in the support of f y for some y in A, and the stored
   value is the expectation under p of indicator A y * pmf (f y) x. *)
lift_definition bind_pmf_aux :: "'a pmf \<Rightarrow> ('a \<Rightarrow> 'b pmf) \<Rightarrow> 'a set \<Rightarrow> ('b, real) mapping" is |
243 |
"\<lambda>(p :: 'a pmf) (f :: 'a \<Rightarrow> 'b pmf) (A::'a set) (x::'b). |

244 |
if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then |

245 |
Some (measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)) |

246 |
else None" . |

247 |
||
248 |
(* Key set of bind_pmf_aux. *)
lemma keys_bind_pmf_aux [simp]: |

249 |
"Mapping.keys (bind_pmf_aux p f A) = (\<Union>x\<in>A. set_pmf (f x))" |

250 |
by transfer (auto split: if_splits) |

251 |
||
252 |
(* Default-0 lookup in bind_pmf_aux, for an arbitrary set A. *)
lemma lookup_default_bind_pmf_aux: |

253 |
"Mapping.lookup_default 0 (bind_pmf_aux p f A) x = |

254 |
(if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then |

255 |
measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x) else 0)" |

256 |
unfolding lookup_default_def by transfer' simp_all |

257 |
||
258 |
(* With A instantiated to the support of p, the lookup is the mass function of
   bind_pmf p f. *)
lemma lookup_default_bind_pmf_aux' [simp]: |

259 |
"Mapping.lookup_default 0 (bind_pmf_aux p f (set_pmf p)) x = pmf (bind_pmf p f) x" |

260 |
unfolding lookup_default_def |

261 |
by transfer (auto simp: pmf_bind AE_measure_pmf_iff set_pmf_eq |

262 |
intro!: integral_cong_AE integral_eq_zero_AE) |

263 |
||
264 |
(* bind_pmf_aux with A = set_pmf p is the mapping of bind_pmf p f. *)
lemma bind_pmf_aux_correct: |

265 |
"mapping_of_pmf (bind_pmf p f) = bind_pmf_aux p f (set_pmf p)" |

266 |
by (intro mapping_of_pmfI') simp_all |

267 |
||
268 |
(* For finite A, bind_pmf_aux p f A is the combination, over all x in A, of the
   mapping of f x with every value multiplied by pmf p x. *)
lemma bind_pmf_aux_code_aux: |

269 |
assumes "finite A" |

270 |
shows "bind_pmf_aux p f A = |

69064
5840724b1d71
Prefix form of infix with * on either side no longer needs special treatment
nipkow
parents:
67399
diff
changeset
|
271 |
fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. (*) (pmf p x)) |
63195 | 272 |
(mapping_of_pmf (f x))) A" (is "?lhs = ?rhs") |
273 |
proof (intro mapping_eqI'[where d = 0]) |

274 |
fix x assume "x \<in> Mapping.keys ?lhs" |

275 |
then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto |

276 |
hence "Mapping.lookup_default 0 ?lhs x = |

277 |
measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)" |

278 |
by (auto simp: lookup_default_bind_pmf_aux) |

279 |
(* The expectation of the indicator-weighted function is a finite sum over A. *)
also from assms have "\<dots> = (\<Sum>y\<in>A. pmf p y * pmf (f y) x)" |

280 |
by (subst integral_measure_pmf [of A]) |

281 |
(auto simp: set_pmf_eq indicator_def mult_ac split: if_splits) |

282 |
also from assms have "\<dots> = Mapping.lookup_default 0 ?rhs x" |

283 |
by (simp add: lookup_default_fold_combine_plus lookup_default_0_map_values |

284 |
lookup_default_mapping_of_pmf) |

285 |
finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" . |

286 |
qed (insert assms, simp_all add: keys_fold_combine_plus) |

287 |
||
288 |
(* Code equation for bind_pmf_aux on list-represented sets; an instance of
   bind_pmf_aux_code_aux. *)
lemma bind_pmf_aux_code [code]: |

289 |
"bind_pmf_aux p f (set xs) = |

69064
5840724b1d71
Prefix form of infix with * on either side no longer needs special treatment
nipkow
parents:
67399
diff
changeset
|
290 |
fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. (*) (pmf p x)) |
63195 | 291 |
(mapping_of_pmf (f x))) (set xs)" |
292 |
by (rule bind_pmf_aux_code_aux) simp_all |

293 |
||
294 |
(* Abstract code equation for bind_pmf: bind_pmf_aux_correct under a new name. *)
lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct |

63194 | 295 |

296 |
end |
|
297 |
||
63195 | 298 |
(* From here on, fold_combine_plus is hidden in "open" mode. *)
hide_const (open) fold_combine_plus |
63194 | 299 |
|
300 |
||
301 |
(* Optional mapping for conditioning p on A: None if A does not meet the support
   of p; otherwise Some of the mapping whose keys are the points of A in the
   support of p, each stored with pmf p x divided by the probability of A. *)
lift_definition cond_pmf_impl :: "'a pmf \<Rightarrow> 'a set \<Rightarrow> ('a, real) mapping option" is |

302 |
"\<lambda>p A. if A \<inter> set_pmf p = {} then None else |

303 |
Some (\<lambda>x. if x \<in> A \<inter> set_pmf p then Some (pmf p x / measure_pmf.prob p A) else None)" . |
|
304 |
||
63195 | 305 |
(* Computable form of cond_pmf_impl for finite A: with C = A Int set_pmf p and
   prob = the sum of pmf p x over C, the result is None if prob = 0, and
   otherwise the mapping of p restricted to keys in C with every value divided
   by prob. *)
lemma cond_pmf_impl_code_alt: |
306 |
assumes "finite A" |

307 |
shows "cond_pmf_impl p A = ( |

308 |
let C = A \<inter> set_pmf p; |

309 |
prob = (\<Sum>x\<in>C. pmf p x) |

310 |
in if prob = 0 then |

311 |
None |

312 |
else |

313 |
Some (Mapping.map_values (\<lambda>_ y. y / prob) |

314 |
(Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))" |

63194 | 315 |
proof - |
63195 | 316 |
define C where "C = A \<inter> set_pmf p" |
317 |
define prob where "prob = (\<Sum>x\<in>C. pmf p x)" |

318 |
also note C_def |

319 |
also from assms have "(\<Sum>x\<in>A \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>A. pmf p x)" |

64267 | 320 |
by (intro sum.mono_neutral_left) (auto simp: set_pmf_eq) |
63195 | 321 |
(* prob1 .. prob4: four equivalent descriptions of prob used in the case split below. *)
finally have prob1: "prob = (\<Sum>x\<in>A. pmf p x)" . |
322 |
hence prob2: "prob = measure_pmf.prob p A" |

323 |
using assms by (subst measure_measure_pmf_finite) simp_all |

324 |
have prob3: "prob = 0 \<longleftrightarrow> A \<inter> set_pmf p = {}" |

64267 | 325 |
by (subst prob1, subst sum_nonneg_eq_0_iff) (auto simp: set_pmf_eq assms) |
63195 | 326 |
from assms have prob4: "prob = measure_pmf.prob p C" |
327 |
unfolding prob_def by (intro measure_measure_pmf_finite [symmetric]) (simp_all add: C_def) |

63194 | 328 |

329 |
show ?thesis |

330 |
proof (cases "prob = 0") |

331 |
case True |

63195 | 332 |
hence "A \<inter> set_pmf p = {}" by (subst (asm) prob3) |
333 |
with True show ?thesis by (simp add: Let_def prob_def C_def cond_pmf_impl.abs_eq) |

63194 | 334 |
next |
335 |
case False |

63195 | 336 |
hence A: "C \<noteq> {}" unfolding C_def by (subst (asm) prob3) auto |
337 |
with prob3 have prob_nz: "prob \<noteq> 0" by (auto simp: C_def) |

63194 | 338 |
(* NOTE(review): the fixed x does not appear free in any later statement of this case. *)
fix x |
63195 | 339 |
have "cond_pmf_impl p A = |
340 |
Some (mapping.Mapping (\<lambda>x. if x \<in> C then |

341 |
Some (pmf p x / measure_pmf.prob p C) else None))" |

63194 | 342 |
(is "_ = Some ?m") |
63195 | 343 |
using A prob2 prob4 unfolding C_def by transfer (auto simp: fun_eq_iff) |
344 |
also have "?m = Mapping.map_values (\<lambda>_ y. y / prob) |

345 |
(Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))" |

346 |
using prob_nz prob4 assms unfolding C_def |

347 |
by transfer (auto simp: fun_eq_iff set_pmf_eq) |

348 |
finally show ?thesis using False by (simp add: Let_def prob_def C_def) |

63194 | 349 |
qed |
350 |
qed |
|
351 |
||
63195 | 352 |
(* Code equation for cond_pmf_impl on list-represented sets; an instance of
   cond_pmf_impl_code_alt. *)
lemma cond_pmf_impl_code [code]: |
353 |
"cond_pmf_impl p (set xs) = ( |

354 |
let C = set xs \<inter> set_pmf p; |

355 |
prob = (\<Sum>x\<in>C. pmf p x) |

356 |
in if prob = 0 then |

357 |
None |

358 |
else |

359 |
Some (Mapping.map_values (\<lambda>_ y. y / prob) |

360 |
(Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))" |

361 |
by (rule cond_pmf_impl_code_alt) simp_all |
|
362 |
||
63194 | 363 |
(* Abstract code equation for cond_pmf: returns the mapping computed by
   cond_pmf_impl if there is one, and goes through Code.abort with the given
   message if cond_pmf_impl yields None. *)
lemma cond_pmf_code [code abstract]: |
364 |
"mapping_of_pmf (cond_pmf p A) = |

365 |
(case cond_pmf_impl p A of |

366 |
None \<Rightarrow> Code.abort (STR ''cond_pmf with set of probability 0'') |

367 |
(\<lambda>_. mapping_of_pmf (cond_pmf p A)) |

368 |
| Some m \<Rightarrow> m)" |

369 |
proof (cases "cond_pmf_impl p A") |

370 |
case (Some m) |

371 |
(* A Some result means A meets the support of p ... *)
hence A: "set_pmf p \<inter> A \<noteq> {}" by transfer (auto split: if_splits) |

372 |
(* ... and the keys of m are the support of cond_pmf p A. *)
from Some have B: "Mapping.keys m = set_pmf (cond_pmf p A)" |

373 |
by (subst set_cond_pmf[OF A], transfer) (auto split: if_splits) |

374 |
with Some A have "mapping_of_pmf (cond_pmf p A) = m" |

375 |
by (intro mapping_of_pmfI[OF _ B], transfer) (auto split: if_splits simp: pmf_cond) |

376 |
with Some show ?thesis by simp |

377 |
qed simp_all |
|
378 |
||
379 |
||
380 |
(* Code equation for binomial_pmf n p, by cases on p:
     p < 0 or p > 1 : goes through Code.abort with the given message
     p = 0          : single entry 0 |-> 1
     p = 1          : single entry n |-> 1
     otherwise      : tabulates k = 0..n with (n choose k) * p^k * (1 - p)^(n - k) *)
lemma binomial_pmf_code [code abstract]: |

381 |
"mapping_of_pmf (binomial_pmf n p) = ( |

382 |
if p < 0 \<or> p > 1 then |

63195 | 383 |
Code.abort (STR ''binomial_pmf with invalid probability'') |
384 |
(\<lambda>_. mapping_of_pmf (binomial_pmf n p)) |

63194 | 385 |
else if p = 0 then Mapping.update 0 1 Mapping.empty |
386 |
else if p = 1 then Mapping.update n 1 Mapping.empty |

387 |
else Mapping.tabulate [0..<Suc n] (\<lambda>k. real (n choose k) * p ^ k * (1 - p) ^ (n - k)))" |

388 |
by (cases "p < 0 \<or> p > 1") |

389 |
(simp, intro mapping_of_pmfI, |

390 |
auto simp: lookup_update' lookup_empty set_pmf_binomial_eq lookup_tabulate split: if_splits) |
|
391 |
||
63195 | 392 |
|
63194 | 393 |
(* Code equation for pred_pmf: P holds for every element of the support of p. *)
lemma pred_pmf_code [code]: |
394 |
"pred_pmf P p = (\<forall>x\<in>set_pmf p. P x)" |

395 |
by (auto simp: pred_pmf_def) |
|
396 |
||
397 |
||
398 |
(* For a list of (element, weight) pairs whose weights are all strictly positive
   and sum to 1, the mapping of pmf_of_list xs tabulates the distinct first
   components, assigning to each x the sum of the weights of the pairs whose
   first component is x. *)
lemma mapping_of_pmf_pmf_of_list: |

63882
018998c00003
renamed listsum -> sum_list, listprod ~> prod_list
nipkow
parents:
63793
diff
changeset
|
399 |
assumes "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0" "sum_list (map snd xs) = 1" |
63194 | 400 |
shows "mapping_of_pmf (pmf_of_list xs) = |
401 |
Mapping.tabulate (remdups (map fst xs)) |

63882
018998c00003
renamed listsum -> sum_list, listprod ~> prod_list
nipkow
parents:
63793
diff
changeset
|
402 |
(\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs)))" |
63194 | 403 |
proof - |
404 |
from assms have wf: "pmf_of_list_wf xs" by (intro pmf_of_list_wfI) force |

63539 | 405 |
(* With strictly positive weights the support is exactly the set of first components. *)
with assms have "set_pmf (pmf_of_list xs) = fst ` set xs" |
63194 | 406 |
by (intro set_pmf_of_list_eq) auto |
63539 | 407 |
with wf show ?thesis |
63194 | 408 |
by (intro mapping_of_pmfI) (auto simp: lookup_tabulate pmf_pmf_of_list) |
409 |
qed |
|
410 |
||
411 |
(* Variant of mapping_of_pmf_pmf_of_list for any well-formed list
   (pmf_of_list_wf xs): pairs with weight 0 are filtered out first (xs'), and the
   tabulation is then taken over xs'. *)
lemma mapping_of_pmf_pmf_of_list': |

412 |
assumes "pmf_of_list_wf xs" |

413 |
defines "xs' \<equiv> filter (\<lambda>z. snd z \<noteq> 0) xs" |

414 |
shows "mapping_of_pmf (pmf_of_list xs) = |

415 |
Mapping.tabulate (remdups (map fst xs')) |

63882
018998c00003
renamed listsum -> sum_list, listprod ~> prod_list
nipkow
parents:
63793
diff
changeset
|
416 |
(\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs')))" (is "_ = ?rhs") |
63194 | 417 |
proof - |
418 |
have wf: "pmf_of_list_wf xs'" unfolding xs'_def by (rule pmf_of_list_remove_zeros) fact |

419 |
have pos: "\<forall>x\<in>snd`set xs'. x > 0" using assms(1) unfolding xs'_def |

420 |
by (force simp: pmf_of_list_wf_def) |

421 |
(* Dropping the zero-weight pairs does not change the resulting PMF. *)
from assms have "pmf_of_list xs = pmf_of_list xs'" |

422 |
unfolding xs'_def by (subst pmf_of_list_remove_zeros) simp_all |

423 |
also from wf pos have "mapping_of_pmf \<dots> = ?rhs" |

424 |
by (intro mapping_of_pmf_pmf_of_list) (auto simp: pmf_of_list_wf_def) |

425 |
finally show ?thesis . |

426 |
qed |
|
427 |
||
428 |
(* Code equation for pmf_of_list_wf: all weights are nonnegative and they sum to 1. *)
lemma pmf_of_list_wf_code [code]: |

63882
018998c00003
renamed listsum -> sum_list, listprod ~> prod_list
nipkow
parents:
63793
diff
changeset
|
429 |
"pmf_of_list_wf xs \<longleftrightarrow> list_all (\<lambda>z. snd z \<ge> 0) xs \<and> sum_list (map snd xs) = 1" |
63194 | 430 |
by (auto simp add: pmf_of_list_wf_def list_all_def) |
431 |
||
432 |
(* Abstract code equation for pmf_of_list: for a well-formed list, drop the
   zero-weight pairs and tabulate the distinct first components with their
   summed weights; otherwise go through Code.abort with the given message. *)
lemma pmf_of_list_code [code abstract]: |

433 |
"mapping_of_pmf (pmf_of_list xs) = ( |

434 |
if pmf_of_list_wf xs then |

435 |
let xs' = filter (\<lambda>z. snd z \<noteq> 0) xs |

436 |
in Mapping.tabulate (remdups (map fst xs')) |

63882
018998c00003
renamed listsum -> sum_list, listprod ~> prod_list
nipkow
parents:
63793
diff
changeset
|
437 |
(\<lambda>x. sum_list (map snd (filter (\<lambda>z. fst z = x) xs'))) |
63194 | 438 |
else |
439 |
Code.abort (STR ''Invalid list for pmf_of_list'') (\<lambda>_. mapping_of_pmf (pmf_of_list xs)))" |

440 |
using mapping_of_pmf_pmf_of_list'[of xs] by (simp add: Let_def) |
|
441 |
||
442 |
(* mapping_of_pmf is injective: two PMFs have equal mappings iff they are equal. *)
lemma mapping_of_pmf_eq_iff [simp]: |

443 |
"mapping_of_pmf p = mapping_of_pmf q \<longleftrightarrow> p = (q :: 'a pmf)" |

444 |
proof (transfer, intro iffI pmf_eqI) |

445 |
fix p q :: "'a pmf" and x :: 'a |

446 |
assume "(\<lambda>x. if pmf p x = 0 then None else Some (pmf p x)) = |

447 |
(\<lambda>x. if pmf q x = 0 then None else Some (pmf q x))" |

448 |
(* Pointwise form of the assumed function equality. *)
hence "(if pmf p x = 0 then None else Some (pmf p x)) = |

449 |
(if pmf q x = 0 then None else Some (pmf q x))" for x |

450 |
by (simp add: fun_eq_iff) |

451 |
from this[of x] show "pmf p x = pmf q x" by (auto split: if_splits) |

452 |
qed (simp_all cong: if_cong) |
|
453 |
||
63195 | 454 |
|
455 |
subsection \<open>Code abbreviations for integrals and probabilities\<close> |
|
456 |
||
457 |
text \<open> |
|
458 |
Integrals and probabilities are defined for general measures, so we cannot give any |
|
459 |
code equations directly. We can, however, specialise these constants to PMFs, |
|
460 |
give code equations for these specialised constants, and tell the code generator |
|
461 |
to unfold the original constants to the specialised ones whenever possible. |
|
462 |
\<close> |
|
463 |
||
464 |
(* Lebesgue integral of a real-valued function f with respect to the measure of
   the PMF p. *)
definition pmf_integral where |

465 |
"pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)" |
|
466 |
||
467 |
(* Lebesgue integral of f restricted to A, with respect to the measure of the
   PMF p; the restriction is expressed by multiplying with indicator A. *)
definition pmf_set_integral where |

468 |
"pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)" |
|
469 |
||
470 |
(* Probability of the set A under the PMF p. *)
definition pmf_prob where |

471 |
"pmf_prob p A = measure_pmf.prob p A" |
|
472 |
||
473 |
(* The probability of a complement is 1 minus the probability of the set. *)
lemma pmf_prob_compl: "pmf_prob p (-A) = 1 - pmf_prob p A" |

474 |
using measure_pmf.prob_compl[of A p] by (simp add: pmf_prob_def Compl_eq_Diff_UNIV) |
|
475 |
||
476 |
(* Code equation for pmf_integral: integrating over everything is the same as
   integrating over the support of p. *)
lemma pmf_integral_pmf_set_integral [code]: |

477 |
"pmf_integral p f = pmf_set_integral p f (set_pmf p)" |

478 |
unfolding pmf_integral_def pmf_set_integral_def |

479 |
by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff) |
|
480 |
||
481 |
(* The probability of A is the set integral of the constant function 1 over A. *)
lemma pmf_prob_pmf_set_integral: |

482 |
"pmf_prob p A = pmf_set_integral p (\<lambda>_. 1) A" |

483 |
by (simp add: pmf_prob_def pmf_set_integral_def) |
|
484 |
||
485 |
(* For finite A, the set integral is the finite sum of pmf p x * f x over A. *)
lemma pmf_set_integral_code_alt_finite: |

486 |
"finite A \<Longrightarrow> pmf_set_integral p f A = (\<Sum>x\<in>A. pmf p x * f x)" |

487 |
unfolding pmf_set_integral_def |

488 |
by (subst integral_measure_pmf[of A]) (auto simp: indicator_def mult_ac split: if_splits) |
|
489 |
||
490 |
(* Code equation for pmf_set_integral on list-represented sets; an instance of
   pmf_set_integral_code_alt_finite. *)
lemma pmf_set_integral_code [code]: |

491 |
"pmf_set_integral p f (set xs) = (\<Sum>x\<in>set xs. pmf p x * f x)" |

492 |
by (rule pmf_set_integral_code_alt_finite) simp_all |
|
493 |
||
494 |
||
495 |
(* For finite A, the probability of A is the finite sum of pmf p x over A. *)
lemma pmf_prob_code_alt_finite: |

496 |
"finite A \<Longrightarrow> pmf_prob p A = (\<Sum>x\<in>A. pmf p x)" |

497 |
by (simp add: pmf_prob_pmf_set_integral pmf_set_integral_code_alt_finite) |
|
498 |
||
499 |
(* Code equations for pmf_prob, one per set representation: for "set xs" sum the
   point probabilities over xs; for "List.coset xs" take 1 minus that sum. *)
lemma pmf_prob_code [code]: |

500 |
"pmf_prob p (set xs) = (\<Sum>x\<in>set xs. pmf p x)" |

501 |
"pmf_prob p (List.coset xs) = 1 - (\<Sum>x\<in>set xs. pmf p x)" |

502 |
by (simp_all add: pmf_prob_code_alt_finite pmf_prob_compl) |
|
503 |
||
504 |
||
505 |
(* Relates pmf_prob p to measure_pmf.prob p; registered with [code_abbrev]. *)
lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p" |

506 |
by (intro ext) (simp add: pmf_prob_def) |
|
507 |
||
508 |
(* FIXME: Why does this not work without parameters? *) |

509 |
(* Relates pmf_integral p to measure_pmf.expectation p; registered with
   [code_abbrev]. The equation is stated with the parameter p on both sides. *)
lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p" |

510 |
by (intro ext) (simp add: pmf_integral_def) |
|
511 |
||
512 |
||
513 |
||
63194 | 514 |
(* PMF built from an association list: the mass of x is the value that map_of
   finds for x in xs, or 0 if there is none. *)
definition "pmf_of_alist xs = embed_pmf (\<lambda>x. case map_of xs x of Some p \<Rightarrow> p | None \<Rightarrow> 0)" |
515 |
||
516 |
(* Rewrites pmf_of_mapping applied to a Mapping built from an association list
   into pmf_of_alist; registered with [code_post]. *)
lemma pmf_of_mapping_Mapping [code_post]: |

517 |
"pmf_of_mapping (Mapping xs) = pmf_of_alist xs" |

518 |
unfolding pmf_of_mapping_def Mapping.lookup_default_def [abs_def] pmf_of_alist_def |

519 |
by transfer simp_all |
|
520 |
||
521 |
||
522 |
(* Instance of the "equal" class for pmf: two PMFs are compared by comparing
   their mappings. *)
instantiation pmf :: (equal) equal |

523 |
begin |

524 |
||
525 |
definition "equal_pmf p q = (mapping_of_pmf p = mapping_of_pmf (q :: 'a pmf))" |

526 |
||
527 |
instance by standard (simp add: equal_pmf_def) |

528 |
end |
|
529 |
||
63793
e68a0b651eb5
add_mset constructor in multisets
fleury <Mathias.Fleury@mpi-inf.mpg.de>
parents:
63539
diff
changeset
|
530 |
(* The singleton multiset containing s; referenced by pmfify in this theory. *)
definition single :: "'a \<Rightarrow> 'a multiset" where |
e68a0b651eb5
add_mset constructor in multisets
fleury <Mathias.Fleury@mpi-inf.mpg.de>
parents:
63539
diff
changeset
|
531 |
"single s = {#s#}" |
63194 | 532 |
|
533 |
(* Quickcheck helper: from a multiset A and an element x (each paired with a
   term), builds pmf_of_multiset applied to A + single x, together with the
   corresponding term.
   NOTE(review): presumably x is added so that the multiset passed to
   pmf_of_multiset is never empty; confirm. *)
definition (in term_syntax) |

534 |
pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow> |

535 |
'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow> |

536 |
'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where |

537 |
[code_unfold]: "pmfify A x = |

538 |
Code_Evaluation.valtermify pmf_of_multiset {\<cdot>} |

67399 | 539 |
(Code_Evaluation.valtermify (+) {\<cdot>} A {\<cdot>} |
63194 | 540 |
(Code_Evaluation.valtermify single {\<cdot>} x))" |
541 |
||
542 |
||
543 |
(* Infix notation for fcomp and scomp; removed again by the no_notation
   commands further down in this theory. *)
notation fcomp (infixl "\<circ>>" 60) |

544 |
notation scomp (infixl "\<circ>\<rightarrow>" 60) |
|
545 |
||
546 |
(* Instance of the Quickcheck "random" class for pmf: draws a random value A and
   then a random value x, both with size parameter i, and returns pmfify A x. *)
instantiation pmf :: (random) random |

547 |
begin |

548 |
||
549 |
definition |

550 |
"Quickcheck_Random.random i = |

551 |
Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>A. |

552 |
Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>x. Pair (pmfify A x)))" |

553 |
||
554 |
instance .. |

555 |
||
556 |
end |
|
557 |
||
558 |
(* Removes the infix notation for fcomp and scomp that was introduced by the
   notation commands earlier in this theory. *)
no_notation fcomp (infixl "\<circ>>" 60) |

559 |
no_notation scomp (infixl "\<circ>\<rightarrow>" 60) |
|
560 |
||
561 |
(* Instance of the Quickcheck "full_exhaustive" class for pmf: enumerates values
   A and x, both with bound i, and applies the test function f to pmfify A x. *)
instantiation pmf :: (full_exhaustive) full_exhaustive |

562 |
begin |

563 |
||
564 |
definition full_exhaustive_pmf :: "('a pmf \<times> (unit \<Rightarrow> term) \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option" |

565 |
where |

566 |
"full_exhaustive_pmf f i = |

567 |
Quickcheck_Exhaustive.full_exhaustive (\<lambda>A. |

568 |
Quickcheck_Exhaustive.full_exhaustive (\<lambda>x. f (pmfify A x)) i) i" |

569 |
||
570 |
instance .. |

571 |
||
572 |
end |
|
573 |
||
64267 | 574 |
end |