63194
|
1 |
(* Title: HOL/Probability/PMF_Impl.thy
|
|
2 |
Author: Manuel Eberl, TU München
|
|
3 |
|
|
4 |
An implementation of PMFs using Mappings, which are implemented with association lists
|
|
5 |
by default. Also includes Quickcheck setup for PMFs.
|
|
6 |
*)
|
|
7 |
|
63195
|
8 |
section \<open>Code generation for PMFs\<close>
|
|
9 |
|
63194
|
10 |
theory PMF_Impl
|
|
11 |
imports Probability_Mass_Function "~~/src/HOL/Library/AList_Mapping"
|
|
12 |
begin
|
|
13 |
|
63195
|
14 |
subsection \<open>General code generation setup\<close>
|
|
15 |
|
63194
|
16 |
(* Abstraction function for code generation: turns a finite real-valued mapping into a PMF.
   Each key gets its mapped value as probability mass; absent keys get 0 (lookup_default 0).
   The result is only meaningful when the values are non-negative and sum to 1; these are
   exactly the assumptions of lemma pmf_of_mapping below. *)
definition pmf_of_mapping :: "('a, real) mapping \<Rightarrow> 'a pmf" where
  "pmf_of_mapping m = embed_pmf (Mapping.lookup_default 0 m)"
|
|
18 |
|
|
19 |
(* For a mapping with finitely many keys and only non-negative values, the integral of its
   lookup_default function over the counting space is the finite sum over its keys. *)
lemma nn_integral_lookup_default:
  fixes m :: "('a, real) mapping"
  assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ x. x \<ge> 0)"
  shows "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) =
           ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)"
proof -
  (* Restrict the integral to the finite key set; outside of it lookup_default yields 0. *)
  have "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) =
          (\<Sum>x\<in>Mapping.keys m. ennreal (Mapping.lookup_default 0 m x))" using assms
    by (subst nn_integral_count_space'[of "Mapping.keys m"])
       (auto simp: Mapping.lookup_default_def keys_is_none_rep Option.is_none_def)
  (* Pull ennreal out of the finite sum; this needs the non-negativity assumption. *)
  also from assms have "\<dots> = ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)"
    by (intro setsum_ennreal)
       (auto simp: Mapping.lookup_default_def All_mapping_def split: option.splits)
  finally show ?thesis .
qed
|
|
34 |
|
|
35 |
(* Correctness of pmf_of_mapping for well-formed mappings. A mapping is well-formed here when it
   has finitely many keys, non-negative values, and values summing to 1. Its PMF then assigns
   each point exactly its mapped value, or 0 if the point is not a key. *)
lemma pmf_of_mapping:
  assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ p. p \<ge> 0)"
  assumes "(\<Sum>x\<in>Mapping.keys m. Mapping.lookup_default 0 m x) = 1"
  shows   "pmf (pmf_of_mapping m) x = Mapping.lookup_default 0 m x"
  unfolding pmf_of_mapping_def
proof (intro pmf_embed_pmf)
  (* Total mass is 1: reduce the integral to the finite sum given by the assumptions. *)
  from assms show "(\<integral>\<^sup>+x. ennreal (Mapping.lookup_default 0 m x) \<partial>count_space UNIV) = 1"
    by (subst nn_integral_lookup_default) (simp_all)
  (* The remaining goal of pmf_embed_pmf (non-negativity) follows from All_mapping. *)
qed (insert assms, simp add: All_mapping_def Mapping.lookup_default_def split: option.splits)
|
|
44 |
|
|
45 |
(* The uniform distribution on a non-empty set A is the PMF of the mapping that assigns
   1 / length xs to each element of xs. Here xs is a duplicate-free list enumerating A. *)
lemma pmf_of_set_pmf_of_mapping:
  assumes "A \<noteq> {}" "set xs = A" "distinct xs"
  shows   "pmf_of_set A = pmf_of_mapping (Mapping.tabulate xs (\<lambda>_. 1 / real (length xs)))"
           (is "?lhs = ?rhs")
  by (rule pmf_eqI, subst pmf_of_mapping)
     (insert assms, auto intro!: All_mapping_tabulate
                         simp: Mapping.lookup_default_def lookup_tabulate distinct_card)
|
|
52 |
|
|
53 |
(* Representation function for code generation: the mapping whose keys are the points of
   non-zero probability, each mapped to its probability. The trailing "." discharges the
   lifting obligation. *)
lift_definition mapping_of_pmf :: "'a pmf \<Rightarrow> ('a, real) mapping" is
  "\<lambda>p x. if pmf p x = 0 then None else Some (pmf p x)" .
|
|
55 |
|
|
56 |
(* Looking up a point in mapping_of_pmf p with default 0 gives back its probability exactly.
   Points that were left out of the mapping have probability 0, so nothing is lost. *)
lemma lookup_default_mapping_of_pmf:
  "Mapping.lookup_default 0 (mapping_of_pmf p) x = pmf p x"
  by (simp add: mapping_of_pmf.abs_eq lookup_default_def Mapping.lookup.abs_eq)
|
|
59 |
|
|
60 |
(* Anonymous context so that the pmf_as_function interpretation stays local to this lemma. *)
context
begin

(* Interpreted here so that the following proof can use transfer to view pmfs as functions. *)
interpretation pmf_as_function .

(* The probability masses of any pmf integrate to 1 over the counting space. *)
lemma nn_integral_pmf_eq_1: "(\<integral>\<^sup>+ x. ennreal (pmf p x) \<partial>count_space UNIV) = 1"
  by transfer simp_all
end
|
|
68 |
|
|
69 |
(* pmf_of_mapping is a left inverse of mapping_of_pmf. Registering this as [code abstype] makes
   ('a, real) mapping the concrete representation of 'a pmf in generated code:
   - pmf_of_mapping acts as the abstraction;
   - mapping_of_pmf acts as the representation. *)
lemma pmf_of_mapping_mapping_of_pmf [code abstype]:
    "pmf_of_mapping (mapping_of_pmf p) = p"
  unfolding pmf_of_mapping_def
  by (rule pmf_eqI, subst pmf_embed_pmf)
     (insert nn_integral_pmf_eq_1[of p],
      auto simp: lookup_default_mapping_of_pmf split: option.splits)
|
|
75 |
|
|
76 |
(* Introduction rule used for most [code abstract] equations below. To show
   mapping_of_pmf p = m, it suffices that:
   - the keys of m are exactly the support of p, and
   - every key x is mapped to Some (pmf p x). *)
lemma mapping_of_pmfI:
  assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup m x = Some (pmf p x)"
  assumes "Mapping.keys m = set_pmf p"
  shows   "mapping_of_pmf p = m"
  using assms by transfer (rule ext, auto simp: set_pmf_eq)
|
|
81 |
|
|
82 |
(* Variant of mapping_of_pmfI that states the value condition with lookup_default 0 instead of
   lookup. This is convenient when the value is computed arithmetically. *)
lemma mapping_of_pmfI':
  assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup_default 0 m x = pmf p x"
  assumes "Mapping.keys m = set_pmf p"
  shows   "mapping_of_pmf p = m"
  using assms unfolding Mapping.lookup_default_def
  by transfer (rule ext, force simp: set_pmf_eq)
|
|
88 |
|
|
89 |
(* Code equation for the Dirac distribution: a single entry mapping x to 1. *)
lemma return_pmf_code [code abstract]:
  "mapping_of_pmf (return_pmf x) = Mapping.update x 1 Mapping.empty"
  by (intro mapping_of_pmfI) (auto simp: lookup_update')
|
|
92 |
|
|
93 |
(* The mapping of the uniform distribution on a non-empty set A, given a duplicate-free
   enumeration xs of A: every element of xs is mapped to 1 / length xs. *)
lemma pmf_of_set_code_aux:
  assumes "A \<noteq> {}" "set xs = A" "distinct xs"
  shows   "mapping_of_pmf (pmf_of_set A) = Mapping.tabulate xs (\<lambda>_. 1 / real (length xs))"
  using assms
  by (intro mapping_of_pmfI, subst pmf_of_set)
     (auto simp: lookup_tabulate distinct_card)
|
|
99 |
|
|
100 |
(* Auxiliary constant: the mapping representation of pmf_of_set A. It lets the code equations
   below pattern-match on a concrete set representation such as set xs. *)
definition pmf_of_set_impl where
  "pmf_of_set_impl A = mapping_of_pmf (pmf_of_set A)"
|
63195
|
102 |
|
|
103 |
(* This equation can be used to easily implement pmf_of_set for other set implementations *)
(* For a finite non-empty A, the mapping is obtained by folding Mapping.update over A, with
   every element mapped to p = 1 / card A. Only Finite_Set.fold is required of the set type. *)
lemma pmf_of_set_impl_code_alt:
  assumes "A \<noteq> {}" "finite A"
  shows   "pmf_of_set_impl A =
             (let p = 1 / real (card A)
              in  Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A)"
proof -
  define p where "p = 1 / real (card A)"
  let ?m = "Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A"
  (* Updating with the same value p commutes and is idempotent, as Finite_Set.fold requires. *)
  interpret comp_fun_idem "\<lambda>x. Mapping.update x p"
    by standard (transfer, force simp: fun_eq_iff)+
  have keys: "Mapping.keys ?m = A"
    using assms(2) by (induction A rule: finite_induct) simp_all
  have lookup: "Mapping.lookup ?m x = Some p" if "x \<in> A" for x
    using assms(2) that by (induction A rule: finite_induct) (auto simp: lookup_update')
  from keys lookup assms show ?thesis unfolding pmf_of_set_impl_def
    by (intro mapping_of_pmfI) (simp_all add: Let_def p_def)
qed
|
|
121 |
|
63194
|
122 |
(* Code equation for list-represented sets:
   - An empty list aborts with an error message. Logically, Code.abort just evaluates its
     fallback, so the equation stays provable.
   - Otherwise duplicates are removed and each remaining element gets 1 / length xs'. *)
lemma pmf_of_set_impl_code [code]:
  "pmf_of_set_impl (set xs) =
    (if xs = [] then
       Code.abort (STR ''pmf_of_set of empty set'') (\<lambda>_. mapping_of_pmf (pmf_of_set (set xs)))
     else let xs' = remdups xs; p = 1 / real (length xs') in
       Mapping.tabulate xs' (\<lambda>_. p))"
  unfolding pmf_of_set_impl_def
  using pmf_of_set_code_aux[of "set xs" "remdups xs"] by (simp add: Let_def)
|
|
130 |
|
|
131 |
(* Code equation for pmf_of_set: delegate to pmf_of_set_impl, which has the set-specific code
   equations. *)
lemma pmf_of_set_code [code abstract]:
  "mapping_of_pmf (pmf_of_set A) = pmf_of_set_impl A"
  by (simp add: pmf_of_set_impl_def)
|
|
134 |
|
|
135 |
|
|
136 |
(* The mapping of pmf_of_multiset A for a non-empty multiset A. Given a duplicate-free list xs
   of A's distinct elements, each x in xs is mapped to count A x / size A. *)
lemma pmf_of_multiset_pmf_of_mapping:
  assumes "A \<noteq> {#}" "set xs = set_mset A" "distinct xs"
  shows   "mapping_of_pmf (pmf_of_multiset A) = Mapping.tabulate xs (\<lambda>x. count A x / real (size A))"
  using assms by (intro mapping_of_pmfI) (auto simp: lookup_tabulate)
|
|
140 |
|
|
141 |
(* Auxiliary constant: the mapping representation of pmf_of_multiset A. It lets the code
   equations below pattern-match on a concrete multiset representation such as mset xs. *)
definition pmf_of_multiset_impl where
  "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)"
|
|
143 |
|
|
144 |
(* Generic implementation via fold_mset. Each occurrence of x adds p = 1 / size A to the entry
   of x, which starts at 0. x therefore ends up with count A x * p; see the lookup fact in the
   proof. *)
lemma pmf_of_multiset_impl_code_alt:
  assumes "A \<noteq> {#}"
  shows   "pmf_of_multiset_impl A =
             (let p = 1 / real (size A)
              in  fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A)"
proof -
  define p where "p = 1 / real (size A)"
  (* The order in which occurrences are added does not matter, as fold_mset requires. *)
  interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 (op + p)"
    unfolding Mapping.map_default_def [abs_def]
    by (standard, intro mapping_eqI ext)
       (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def)
  let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A"
  have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all
  have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x
    by (induction A)
       (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs)
  from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def
    by (intro mapping_of_pmfI') (simp_all add: Let_def p_def)
qed
|
63194
|
163 |
|
|
164 |
(* Code equation for list-represented multisets. An empty list aborts. Otherwise each distinct
   element x is mapped to count (mset xs) x / length xs.
   Why length xs and not length xs': the size of mset xs is the length of xs including
   duplicates, whereas xs' = remdups xs only lists the keys. *)
lemma pmf_of_multiset_impl_code [code]:
  "pmf_of_multiset_impl (mset xs) =
     (if xs = [] then
        Code.abort (STR ''pmf_of_multiset of empty multiset'')
          (\<lambda>_. mapping_of_pmf (pmf_of_multiset (mset xs)))
      else let xs' = remdups xs; p = 1 / real (length xs) in
        Mapping.tabulate xs' (\<lambda>x. real (count (mset xs) x) * p))"
  using pmf_of_multiset_pmf_of_mapping[of "mset xs" "remdups xs"]
  by (simp add: pmf_of_multiset_impl_def)
|
63194
|
173 |
|
|
174 |
(* Code equation for pmf_of_multiset: delegate to pmf_of_multiset_impl. *)
lemma pmf_of_multiset_code [code abstract]:
  "mapping_of_pmf (pmf_of_multiset A) = pmf_of_multiset_impl A"
  by (simp add: pmf_of_multiset_impl_def)
|
|
177 |
|
63195
|
178 |
|
63194
|
179 |
(* Code equation for bernoulli_pmf. Out-of-range parameters are clamped:
   - p <= 0 always yields False;
   - p >= 1 always yields True;
   - otherwise True has mass p and False has mass 1 - p.
   In the degenerate cases only the outcome with mass 1 is a key, so the keys match the
   support. *)
lemma bernoulli_pmf_code [code abstract]:
  "mapping_of_pmf (bernoulli_pmf p) =
     (if p \<le> 0 then Mapping.update False 1 Mapping.empty
      else if p \<ge> 1 then Mapping.update True 1 Mapping.empty
      else Mapping.update False (1 - p) (Mapping.update True p Mapping.empty))"
  by (intro mapping_of_pmfI) (auto simp: bernoulli_pmf.rep_eq lookup_update' set_pmf_eq)
|
|
185 |
|
|
186 |
|
|
187 |
(* Evaluating pmf p x in generated code: look x up in p's mapping, defaulting to 0. *)
lemma pmf_code [code]: "pmf p x = Mapping.lookup_default 0 (mapping_of_pmf p) x"
  unfolding mapping_of_pmf_def Mapping.lookup_default_def
  by (auto split: option.splits simp: id_def Mapping.lookup.abs_eq)
|
|
190 |
|
|
191 |
(* The support of a pmf is computed as the key set of its mapping. *)
lemma set_pmf_code [code]: "set_pmf p = Mapping.keys (mapping_of_pmf p)"
  by transfer (auto simp: dom_def set_pmf_eq)
|
|
193 |
|
|
194 |
(* The same fact as set_pmf_code, oriented as a simp rule from keys to support. *)
lemma keys_mapping_of_pmf [simp]: "Mapping.keys (mapping_of_pmf p) = set_pmf p"
  by transfer (auto simp: dom_def set_pmf_eq)
|
|
196 |
|
|
197 |
|
|
198 |
|
|
199 |
(* Big-operator sum of a set-indexed family of real-valued mappings. It combines mappings with
   Mapping.combine (+), i.e. values of shared keys are added, and starts from the empty mapping.
   Its behaviour on lookups and keys is established in the context below. *)
definition fold_combine_plus where
  "fold_combine_plus = comm_monoid_set.F (Mapping.combine (op + :: real \<Rightarrow> _)) Mapping.empty"
|
|
201 |
|
|
202 |
(* Local context for the code setup of bind_pmf. Auxiliary lemmas are "qualified" (accessible
   only by their long name) or "private" (invisible outside this context). *)
context
begin

(* Adding real values is associative and commutative, so Mapping.combine (+) is a valid
   operation for the comm_monoid_set big operator behind fold_combine_plus. *)
interpretation fold_combine_plus: combine_mapping_abel_semigroup "op + :: real \<Rightarrow> _"
  by unfold_locales (simp_all add: add_ac)

(* Looking up a key in the combined mapping gives the sum of its lookups in the family. *)
qualified lemma lookup_default_fold_combine_plus:
  fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping"
  assumes "finite A"
  shows   "Mapping.lookup_default 0 (fold_combine_plus f A) x =
             (\<Sum>y\<in>A. Mapping.lookup_default 0 (f y) x)"
  unfolding fold_combine_plus_def using assms
    by (induction A rule: finite_induct)
       (simp_all add: lookup_default_empty lookup_default_neutral_combine)

(* The keys of the combined mapping are the union of the keys of the family. *)
qualified lemma keys_fold_combine_plus:
  "finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))"
  by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine)

(* Executable version for list-represented sets: a right fold over the duplicate-free list. *)
qualified lemma fold_combine_plus_code [code]:
  "fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine op+ (g x)) (remdups xs) Mapping.empty"
  by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code)

(* map_values commutes with lookup_default 0 whenever the value function maps 0 to 0. *)
private lemma lookup_default_0_map_values:
  assumes "f x 0 = 0"
  shows   "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)"
  unfolding Mapping.lookup_default_def
  using assms by transfer (auto split: option.splits)

(* For finite support, the mapping of bind_pmf p f is built as follows: scale each mapping of
   f x by pmf p x, then sum the scaled mappings over the support of p. *)
qualified lemma mapping_of_bind_pmf:
  assumes "finite (set_pmf p)"
  shows   "mapping_of_pmf (bind_pmf p f) =
             fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x))
               (mapping_of_pmf (f x))) (set_pmf p)"
  using assms
  by (intro mapping_of_pmfI')
     (auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus
                 pmf_bind integral_measure_pmf lookup_default_0_map_values
                 lookup_default_mapping_of_pmf mult_ac)

(* Auxiliary mapping that restricts the bind to a set A of outcomes of p.
   - Keys are the outcomes reachable from A.
   - Each such key x is mapped to the expectation of indicator A y * pmf (f y) x under p.
   Having A as an explicit argument lets the code equation below pattern-match on set xs. *)
lift_definition bind_pmf_aux :: "'a pmf \<Rightarrow> ('a \<Rightarrow> 'b pmf) \<Rightarrow> 'a set \<Rightarrow> ('b, real) mapping" is
  "\<lambda>(p :: 'a pmf) (f :: 'a \<Rightarrow> 'b pmf) (A::'a set) (x::'b).
     if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then
       Some (measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x))
     else None" .

lemma keys_bind_pmf_aux [simp]:
  "Mapping.keys (bind_pmf_aux p f A) = (\<Union>x\<in>A. set_pmf (f x))"
  by transfer (auto split: if_splits)

lemma lookup_default_bind_pmf_aux:
  "Mapping.lookup_default 0 (bind_pmf_aux p f A) x =
     (if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then
        measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x) else 0)"
  unfolding lookup_default_def by transfer' simp_all

(* With A = set_pmf p, the indicator is 1 almost everywhere, so lookups give pmf of the bind. *)
lemma lookup_default_bind_pmf_aux' [simp]:
  "Mapping.lookup_default 0 (bind_pmf_aux p f (set_pmf p)) x = pmf (bind_pmf p f) x"
  unfolding lookup_default_def
  by transfer (auto simp: pmf_bind AE_measure_pmf_iff set_pmf_eq
                    intro!: integral_cong_AE integral_eq_zero_AE)

lemma bind_pmf_aux_correct:
  "mapping_of_pmf (bind_pmf p f) = bind_pmf_aux p f (set_pmf p)"
  by (intro mapping_of_pmfI') simp_all

(* For finite A, bind_pmf_aux equals the fold_combine_plus construction. The proof compares
   lookups pointwise: expectation over p, then a finite sum, then a fold_combine_plus lookup. *)
lemma bind_pmf_aux_code_aux:
  assumes "finite A"
  shows   "bind_pmf_aux p f A =
             fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x))
               (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs")
proof (intro mapping_eqI'[where d = 0])
  fix x assume "x \<in> Mapping.keys ?lhs"
  then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto
  hence "Mapping.lookup_default 0 ?lhs x =
           measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)"
    by (auto simp: lookup_default_bind_pmf_aux)
  also from assms have "\<dots> = (\<Sum>y\<in>A. pmf p y * pmf (f y) x)"
    by (subst integral_measure_pmf [of A])
       (auto simp: set_pmf_eq indicator_def mult_ac split: if_splits)
  also from assms have "\<dots> = Mapping.lookup_default 0 ?rhs x"
    by (simp add: lookup_default_fold_combine_plus lookup_default_0_map_values
          lookup_default_mapping_of_pmf)
  finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" .
qed (insert assms, simp_all add: keys_fold_combine_plus)

(* Executable equation for list-represented sets, which are always finite. *)
lemma bind_pmf_aux_code [code]:
  "bind_pmf_aux p f (set xs) =
     fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x))
       (mapping_of_pmf (f x))) (set xs)"
  by (rule bind_pmf_aux_code_aux) simp_all

(* Code equation for bind_pmf: go through bind_pmf_aux over the support of p. *)
lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct

end
|
|
297 |
|
63195
|
298 |
(* Hide the unqualified name of the auxiliary constant; it remains usable via its qualified
   name. *)
hide_const (open) fold_combine_plus
|
63194
|
299 |
|
|
300 |
|
|
301 |
(* Auxiliary representation of cond_pmf p A.
   - None if A does not meet the support of p.
   - Otherwise a mapping whose keys are the support points in A, each mapped to
     pmf p x / prob p A. *)
lift_definition cond_pmf_impl :: "'a pmf \<Rightarrow> 'a set \<Rightarrow> ('a, real) mapping option" is
  "\<lambda>p A. if A \<inter> set_pmf p = {} then None else
     Some (\<lambda>x. if x \<in> A \<inter> set_pmf p then Some (pmf p x / measure_pmf.prob p A) else None)" .
|
|
304 |
|
63195
|
305 |
(* Computable form of cond_pmf_impl for finite A.
   - prob is the finite sum of the masses of C = A inter support.
   - If prob = 0 the result is None.
   - Otherwise the mapping of p is filtered to C and every value is divided by prob. *)
lemma cond_pmf_impl_code_alt:
  assumes "finite A"
  shows   "cond_pmf_impl p A = (
             let C = A \<inter> set_pmf p;
                 prob = (\<Sum>x\<in>C. pmf p x)
             in  if prob = 0 then
                   None
                 else
                   Some (Mapping.map_values (\<lambda>_ y. y / prob)
                     (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))"
proof -
  define C where "C = A \<inter> set_pmf p"
  define prob where "prob = (\<Sum>x\<in>C. pmf p x)"
  (* Summing over C equals summing over A: points outside the support have mass 0. *)
  also note C_def
  also from assms have "(\<Sum>x\<in>A \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>A. pmf p x)"
    by (intro setsum.mono_neutral_left) (auto simp: set_pmf_eq)
  finally have prob1: "prob = (\<Sum>x\<in>A. pmf p x)" .
  hence prob2: "prob = measure_pmf.prob p A"
    using assms by (subst measure_measure_pmf_finite) simp_all
  (* prob vanishes exactly when A misses the support, matching the None case of cond_pmf_impl. *)
  have prob3: "prob = 0 \<longleftrightarrow> A \<inter> set_pmf p = {}"
    by (subst prob1, subst setsum_nonneg_eq_0_iff) (auto simp: set_pmf_eq assms)
  from assms have prob4: "prob = measure_pmf.prob p C"
    unfolding prob_def by (intro measure_measure_pmf_finite [symmetric]) (simp_all add: C_def)

  show ?thesis
  proof (cases "prob = 0")
    case True
    hence "A \<inter> set_pmf p = {}" by (subst (asm) prob3)
    with True show ?thesis by (simp add: Let_def prob_def C_def cond_pmf_impl.abs_eq)
  next
    case False
    hence A: "C \<noteq> {}" unfolding C_def by (subst (asm) prob3) auto
    with prob3 have prob_nz: "prob \<noteq> 0" by (auto simp: C_def)
    (* NOTE(review): the fixed x below does not appear to be referenced; harmless. *)
    fix x
    have "cond_pmf_impl p A =
            Some (mapping.Mapping (\<lambda>x. if x \<in> C then
                                          Some (pmf p x / measure_pmf.prob p C) else None))"
             (is "_ = Some ?m")
      using A prob2 prob4 unfolding C_def by transfer (auto simp: fun_eq_iff)
    also have "?m = Mapping.map_values (\<lambda>_ y. y / prob)
                      (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))"
      using prob_nz prob4 assms unfolding C_def
      by transfer (auto simp: fun_eq_iff set_pmf_eq)
    finally show ?thesis using False by (simp add: Let_def prob_def C_def)
  qed
qed
|
|
351 |
|
63195
|
352 |
(* Code equation for cond_pmf_impl on list-represented sets, which are finite. This is
   cond_pmf_impl_code_alt instantiated to A = set xs. *)
lemma cond_pmf_impl_code [code]:
  "cond_pmf_impl p (set xs) = (
     let C = set xs \<inter> set_pmf p;
         prob = (\<Sum>x\<in>C. pmf p x)
     in  if prob = 0 then
           None
         else
           Some (Mapping.map_values (\<lambda>_ y. y / prob)
             (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))"
  by (rule cond_pmf_impl_code_alt) simp_all
|
|
362 |
|
63194
|
363 |
(* Code equation for cond_pmf.
   - If cond_pmf_impl reports None (the conditioning set has probability 0), generated code
     aborts with an error.
   - Otherwise the computed mapping is used. *)
lemma cond_pmf_code [code abstract]:
  "mapping_of_pmf (cond_pmf p A) =
     (case cond_pmf_impl p A of
        None \<Rightarrow> Code.abort (STR ''cond_pmf with set of probability 0'')
                  (\<lambda>_. mapping_of_pmf (cond_pmf p A))
      | Some m \<Rightarrow> m)"
proof (cases "cond_pmf_impl p A")
  case (Some m)
  hence A: "set_pmf p \<inter> A \<noteq> {}" by transfer (auto split: if_splits)
  from Some have B: "Mapping.keys m = set_pmf (cond_pmf p A)"
    by (subst set_cond_pmf[OF A], transfer) (auto split: if_splits)
  with Some A have "mapping_of_pmf (cond_pmf p A) = m"
    by (intro mapping_of_pmfI[OF _ B], transfer) (auto split: if_splits simp: pmf_cond)
  with Some show ?thesis by simp
qed simp_all
|
|
378 |
|
|
379 |
|
|
380 |
(* Code equation for binomial_pmf n p.
   - p outside [0, 1] aborts.
   - p = 0 or p = 1 is a point mass at 0 or n, respectively.
   - Otherwise each k in 0..n is mapped to (n choose k) * p^k * (1 - p)^(n - k).
   The edge cases are separate so that the key set equals the support. *)
lemma binomial_pmf_code [code abstract]:
  "mapping_of_pmf (binomial_pmf n p) = (
     if p < 0 \<or> p > 1 then
       Code.abort (STR ''binomial_pmf with invalid probability'')
         (\<lambda>_. mapping_of_pmf (binomial_pmf n p))
     else if p = 0 then Mapping.update 0 1 Mapping.empty
     else if p = 1 then Mapping.update n 1 Mapping.empty
     else Mapping.tabulate [0..<Suc n] (\<lambda>k. real (n choose k) * p ^ k * (1 - p) ^ (n - k)))"
  by (cases "p < 0 \<or> p > 1")
     (simp, intro mapping_of_pmfI,
      auto simp: lookup_update' lookup_empty set_pmf_binomial_eq lookup_tabulate split: if_splits)
|
|
391 |
|
63195
|
392 |
|
63194
|
393 |
(* A predicate holds for a pmf iff it holds on every point of the support; the support is
   computable via set_pmf_code. *)
lemma pred_pmf_code [code]:
  "pred_pmf P p = (\<forall>x\<in>set_pmf p. P x)"
  by (auto simp: pred_pmf_def)
|
|
396 |
|
|
397 |
|
|
398 |
(* Mapping of pmf_of_list xs, where xs is a list of (outcome, weight) pairs with strictly
   positive weights summing to 1. Each outcome is mapped to the total weight of its pairs. *)
lemma mapping_of_pmf_pmf_of_list:
  assumes "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0" "listsum (map snd xs) = 1"
  shows   "mapping_of_pmf (pmf_of_list xs) =
             Mapping.tabulate (remdups (map fst xs))
               (\<lambda>x. listsum (map snd (filter (\<lambda>z. fst z = x) xs)))"
proof -
  from assms have wf: "pmf_of_list_wf xs" by (intro pmf_of_list_wfI) force
  (* Strict positivity ensures every listed outcome really lies in the support. *)
  moreover from this assms have "set_pmf (pmf_of_list xs) = fst ` set xs"
    by (intro set_pmf_of_list_eq) auto
  ultimately show ?thesis
    by (intro mapping_of_pmfI) (auto simp: lookup_tabulate pmf_pmf_of_list)
qed
|
|
410 |
|
|
411 |
(* Variant for arbitrary well-formed lists, where zero weights are allowed. Zero-weight pairs
   are dropped first (xs'), so that only outcomes in the support become keys. *)
lemma mapping_of_pmf_pmf_of_list':
  assumes "pmf_of_list_wf xs"
  defines "xs' \<equiv> filter (\<lambda>z. snd z \<noteq> 0) xs"
  shows   "mapping_of_pmf (pmf_of_list xs) =
             Mapping.tabulate (remdups (map fst xs'))
               (\<lambda>x. listsum (map snd (filter (\<lambda>z. fst z = x) xs')))" (is "_ = ?rhs")
proof -
  have wf: "pmf_of_list_wf xs'" unfolding xs'_def by (rule pmf_of_list_remove_zeros) fact
  have pos: "\<forall>x\<in>snd`set xs'. x > 0" using assms(1) unfolding xs'_def
    by (force simp: pmf_of_list_wf_def)
  (* Removing zero weights does not change the distribution. *)
  from assms have "pmf_of_list xs = pmf_of_list xs'"
    unfolding xs'_def by (subst pmf_of_list_remove_zeros) simp_all
  also from wf pos have "mapping_of_pmf \<dots> = ?rhs"
    by (intro mapping_of_pmf_pmf_of_list) (auto simp: pmf_of_list_wf_def)
  finally show ?thesis .
qed
|
|
427 |
|
|
428 |
(* Executable well-formedness check for pmf_of_list: all weights are non-negative and the
   weights sum to 1. *)
lemma pmf_of_list_wf_code [code]:
  "pmf_of_list_wf xs \<longleftrightarrow> list_all (\<lambda>z. snd z \<ge> 0) xs \<and> listsum (map snd xs) = 1"
  by (auto simp add: pmf_of_list_wf_def list_all_def)
|
|
431 |
|
|
432 |
(* Code equation for pmf_of_list. Well-formed lists are converted as in
   mapping_of_pmf_pmf_of_list'; ill-formed lists make the generated code abort. *)
lemma pmf_of_list_code [code abstract]:
  "mapping_of_pmf (pmf_of_list xs) = (
     if pmf_of_list_wf xs then
       let xs' = filter (\<lambda>z. snd z \<noteq> 0) xs
       in  Mapping.tabulate (remdups (map fst xs'))
             (\<lambda>x. listsum (map snd (filter (\<lambda>z. fst z = x) xs')))
     else
       Code.abort (STR ''Invalid list for pmf_of_list'') (\<lambda>_. mapping_of_pmf (pmf_of_list xs)))"
  using mapping_of_pmf_pmf_of_list'[of xs] by (simp add: Let_def)
|
|
441 |
|
|
442 |
(* mapping_of_pmf is injective: two pmfs are equal iff their mappings are equal. This justifies
   the executable equality on pmfs defined further below. *)
lemma mapping_of_pmf_eq_iff [simp]:
  "mapping_of_pmf p = mapping_of_pmf q \<longleftrightarrow> p = (q :: 'a pmf)"
proof (transfer, intro iffI pmf_eqI)
  fix p q :: "'a pmf" and x :: 'a
  assume "(\<lambda>x. if pmf p x = 0 then None else Some (pmf p x)) =
            (\<lambda>x. if pmf q x = 0 then None else Some (pmf q x))"
  hence "(if pmf p x = 0 then None else Some (pmf p x)) =
           (if pmf q x = 0 then None else Some (pmf q x))" for x
    by (simp add: fun_eq_iff)
  (* Compare the two representations pointwise; both the None and the Some case give equal mass. *)
  from this[of x] show "pmf p x = pmf q x" by (auto split: if_splits)
qed (simp_all cong: if_cong)
|
|
453 |
|
63195
|
454 |
|
|
455 |
subsection \<open>Code abbreviations for integrals and probabilities\<close>
|
|
456 |
|
|
457 |
text \<open>
|
|
458 |
Integrals and probabilities are defined for general measures, so we cannot give any
|
|
459 |
code equations directly. We can, however, specialise these constants to PMFs,
|
|
460 |
give code equations for these specialised constants, and tell the code generator
|
|
461 |
to unfold the original constants to the specialised ones whenever possible.
|
|
462 |
\<close>
|
|
463 |
|
|
464 |
(* Lebesgue integral of a real-valued function with respect to a pmf, specialised so that it
   can receive code equations. *)
definition pmf_integral where
  "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)"
|
|
466 |
|
|
467 |
(* Integral of f restricted to the set A (via indicator A) with respect to a pmf. *)
definition pmf_set_integral where
  "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)"
|
|
469 |
|
|
470 |
(* Probability of the event A under a pmf, specialised so that it can receive code equations. *)
definition pmf_prob where
  "pmf_prob p A = measure_pmf.prob p A"
|
|
472 |
|
|
473 |
(* Complement rule; used to evaluate probabilities of List.coset sets. *)
lemma pmf_prob_compl: "pmf_prob p (-A) = 1 - pmf_prob p A"
  using measure_pmf.prob_compl[of A p] by (simp add: pmf_prob_def Compl_eq_Diff_UNIV)
|
|
475 |
|
|
476 |
(* Code equation: integrating over the whole space is the same as integrating over the
   support, since the integrands agree almost everywhere. *)
lemma pmf_integral_pmf_set_integral [code]:
  "pmf_integral p f = pmf_set_integral p f (set_pmf p)"
  unfolding pmf_integral_def pmf_set_integral_def
  by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff)
|
|
480 |
|
|
481 |
(* The probability of A is the set integral of the constant function 1 over A. *)
lemma pmf_prob_pmf_set_integral:
  "pmf_prob p A = pmf_set_integral p (\<lambda>_. 1) A"
  by (simp add: pmf_prob_def pmf_set_integral_def)
|
|
484 |
|
|
485 |
(* Over a finite set, the set integral is a finite weighted sum. *)
lemma pmf_set_integral_code_alt_finite:
  "finite A \<Longrightarrow> pmf_set_integral p f A = (\<Sum>x\<in>A. pmf p x * f x)"
  unfolding pmf_set_integral_def
  by (subst integral_measure_pmf[of A]) (auto simp: indicator_def mult_ac split: if_splits)
|
|
489 |
|
|
490 |
(* Code equation for list-represented sets, which are always finite. *)
lemma pmf_set_integral_code [code]:
  "pmf_set_integral p f (set xs) = (\<Sum>x\<in>set xs. pmf p x * f x)"
  by (rule pmf_set_integral_code_alt_finite) simp_all
|
|
493 |
|
|
494 |
|
|
495 |
(* Over a finite set, the probability is the sum of the point masses. *)
lemma pmf_prob_code_alt_finite:
  "finite A \<Longrightarrow> pmf_prob p A = (\<Sum>x\<in>A. pmf p x)"
  by (simp add: pmf_prob_pmf_set_integral pmf_set_integral_code_alt_finite)
|
|
498 |
|
|
499 |
(* Code equations for both list-based set representations:
   - set xs: sum the masses of the listed elements.
   - List.coset xs, the complement of set xs: 1 minus that sum. *)
lemma pmf_prob_code [code]:
  "pmf_prob p (set xs) = (\<Sum>x\<in>set xs. pmf p x)"
  "pmf_prob p (List.coset xs) = 1 - (\<Sum>x\<in>set xs. pmf p x)"
  by (simp_all add: pmf_prob_code_alt_finite pmf_prob_compl)
|
|
503 |
|
|
504 |
|
|
505 |
(* [code_abbrev]: the code generator rewrites measure_pmf.prob p to the executable pmf_prob p. *)
lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p"
  by (intro ext) (simp add: pmf_prob_def)
|
|
507 |
|
|
508 |
(* FIXME: Why does this not work without parameters? *)
(* [code_abbrev]: the code generator rewrites measure_pmf.expectation p to the executable
   pmf_integral p. *)
lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p"
  by (intro ext) (simp add: pmf_integral_def)
|
|
511 |
|
|
512 |
|
|
513 |
|
63194
|
514 |
(* PMF described by an association list of (outcome, probability) pairs; unlisted outcomes get
   0. Used to display evaluation results (see pmf_of_mapping_Mapping). *)
definition "pmf_of_alist xs = embed_pmf (\<lambda>x. case map_of xs x of Some p \<Rightarrow> p | None \<Rightarrow> 0)"
|
|
515 |
|
|
516 |
(* [code_post]: when evaluation yields a pmf represented by an association-list Mapping, the
   result is presented as pmf_of_alist xs instead of pmf_of_mapping (Mapping xs). *)
lemma pmf_of_mapping_Mapping [code_post]:
    "pmf_of_mapping (Mapping xs) = pmf_of_alist xs"
  unfolding pmf_of_mapping_def Mapping.lookup_default_def [abs_def] pmf_of_alist_def
  by transfer simp_all
|
|
520 |
|
|
521 |
|
|
522 |
(* Executable equality on pmfs: compare their mapping representations. The instance proof uses
   the simp rule mapping_of_pmf_eq_iff (injectivity of mapping_of_pmf). *)
instantiation pmf :: (equal) equal
begin

definition "equal_pmf p q = (mapping_of_pmf p = mapping_of_pmf (q :: 'a pmf))"

instance by standard (simp add: equal_pmf_def)
end
|
|
529 |
|
|
530 |
|
|
531 |
(* Quickcheck helper: builds pmf_of_multiset (A + {#x#}) together with its term
   representation, which is used for printing counterexamples. Adding the single element x
   guarantees a non-empty multiset, so pmf_of_multiset is always applicable. *)
definition (in term_syntax)
  pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow>
             'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow>
             'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
  [code_unfold]: "pmfify A x =
    Code_Evaluation.valtermify pmf_of_multiset {\<cdot>}
      (Code_Evaluation.valtermify (op +) {\<cdot>} A {\<cdot>}
       (Code_Evaluation.valtermify single {\<cdot>} x))"
|
|
539 |
|
|
540 |
|
|
541 |
(* Temporary infix syntax for forward composition (fcomp) and state-passing composition
   (scomp) of random generators; removed again after the instantiation. *)
notation fcomp (infixl "\<circ>>" 60)
notation scomp (infixl "\<circ>\<rightarrow>" 60)

(* Random generator for pmfs: draw a random multiset A and element x, then combine them with
   pmfify. *)
instantiation pmf :: (random) random
begin

definition
  "Quickcheck_Random.random i =
     Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>A.
       Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>x. Pair (pmfify A x)))"

instance ..

end

no_notation fcomp (infixl "\<circ>>" 60)
no_notation scomp (infixl "\<circ>\<rightarrow>" 60)
|
|
558 |
|
|
559 |
(* Exhaustive Quickcheck enumeration of pmfs. It enumerates multisets A and elements x up to
   size bound i and passes pmfify A x to the test function f. *)
instantiation pmf :: (full_exhaustive) full_exhaustive
begin

definition full_exhaustive_pmf :: "('a pmf \<times> (unit \<Rightarrow> term) \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
where
  "full_exhaustive_pmf f i =
     Quickcheck_Exhaustive.full_exhaustive (\<lambda>A.
       Quickcheck_Exhaustive.full_exhaustive (\<lambda>x. f (pmfify A x)) i) i"

instance ..

end
|
|
571 |
|
|
572 |
end |