143 |
143 |
144 lemma pmf_of_multiset_impl_code_alt: |
144 lemma pmf_of_multiset_impl_code_alt: |
145 assumes "A \<noteq> {#}" |
145 assumes "A \<noteq> {#}" |
146 shows "pmf_of_multiset_impl A = |
146 shows "pmf_of_multiset_impl A = |
147 (let p = 1 / real (size A) |
147 (let p = 1 / real (size A) |
148 in fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A)" |
148 in fold_mset (\<lambda>x. Mapping.map_default x 0 ((+) p)) Mapping.empty A)" |
149 proof - |
149 proof - |
150 define p where "p = 1 / real (size A)" |
150 define p where "p = 1 / real (size A)" |
151 interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 (op + p)" |
151 interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 ((+) p)" |
152 unfolding Mapping.map_default_def [abs_def] |
152 unfolding Mapping.map_default_def [abs_def] |
153 by (standard, intro mapping_eqI ext) |
153 by (standard, intro mapping_eqI ext) |
154 (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def) |
154 (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def) |
155 let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A" |
155 let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 ((+) p)) Mapping.empty A" |
156 have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all |
156 have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all |
157 have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x |
157 have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x |
158 by (induction A) |
158 by (induction A) |
159 (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs) |
159 (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs) |
160 from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def |
160 from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def |
195 by transfer (auto simp: dom_def set_pmf_eq) |
195 by transfer (auto simp: dom_def set_pmf_eq) |
196 |
196 |
197 |
197 |
198 |
198 |
(* fold_combine_plus: folds Mapping.combine (+) over a finite set, starting from
   Mapping.empty, via comm_monoid_set.F. The type annotation fixes the value type
   of the combined mappings to real. The two copies of line 200 differ only in
   section syntax: the older "op +" form versus the "(+)" form. *)
199 definition fold_combine_plus where |
199 definition fold_combine_plus where |
200 "fold_combine_plus = comm_monoid_set.F (Mapping.combine (op + :: real \<Rightarrow> _)) Mapping.empty" |
200 "fold_combine_plus = comm_monoid_set.F (Mapping.combine ((+) :: real \<Rightarrow> _)) Mapping.empty" |
201 |
201 |
202 context |
202 context |
203 begin |
203 begin |
204 |
204 |
(* Interprets combine_mapping_abel_semigroup for (+) on real, with the locale
   assumptions discharged by unfold_locales plus add_ac. The resulting facts are
   available under the prefix "fold_combine_plus." (e.g. keys_fold_combine and
   fold_combine_code, which the following lemmas use). *)
205 interpretation fold_combine_plus: combine_mapping_abel_semigroup "op + :: real \<Rightarrow> _" |
205 interpretation fold_combine_plus: combine_mapping_abel_semigroup "(+) :: real \<Rightarrow> _" |
206 by unfold_locales (simp_all add: add_ac) |
206 by unfold_locales (simp_all add: add_ac) |
207 |
207 |
208 qualified lemma lookup_default_fold_combine_plus: |
208 qualified lemma lookup_default_fold_combine_plus: |
209 fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping" |
209 fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping" |
210 assumes "finite A" |
210 assumes "finite A" |
(* For finite A, the key set of fold_combine_plus f A is the union of the key sets
   of the individual mappings f x. It is obtained by unfolding the definition and
   applying the interpreted locale fact keys_fold_combine. *)
217 qualified lemma keys_fold_combine_plus: |
217 qualified lemma keys_fold_combine_plus: |
218 "finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))" |
218 "finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))" |
219 by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine) |
219 by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine) |
220 |
220 |
(* Code equation: when the set is given as set xs, fold_combine_plus g is computed
   by a foldr of Mapping.combine (+) over remdups xs, starting from Mapping.empty.
   remdups makes the list visit each element of set xs exactly once. The two
   copies of line 222 differ only in "op+" versus "(+)" syntax. *)
221 qualified lemma fold_combine_plus_code [code]: |
221 qualified lemma fold_combine_plus_code [code]: |
222 "fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine op+ (g x)) (remdups xs) Mapping.empty" |
222 "fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine (+) (g x)) (remdups xs) Mapping.empty" |
223 by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code) |
223 by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code) |
224 |
224 |
225 private lemma lookup_default_0_map_values: |
225 private lemma lookup_default_0_map_values: |
226 assumes "f x 0 = 0" |
226 assumes "f x 0 = 0" |
227 shows "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)" |
227 shows "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)" |
229 using assms by transfer (auto split: option.splits) |
229 using assms by transfer (auto split: option.splits) |
230 |
230 |
231 qualified lemma mapping_of_bind_pmf: |
231 qualified lemma mapping_of_bind_pmf: |
232 assumes "finite (set_pmf p)" |
232 assumes "finite (set_pmf p)" |
233 shows "mapping_of_pmf (bind_pmf p f) = |
233 shows "mapping_of_pmf (bind_pmf p f) = |
234 fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x)) |
234 fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. ( * ) (pmf p x)) |
235 (mapping_of_pmf (f x))) (set_pmf p)" |
235 (mapping_of_pmf (f x))) (set_pmf p)" |
236 using assms |
236 using assms |
237 by (intro mapping_of_pmfI') |
237 by (intro mapping_of_pmfI') |
238 (auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus |
238 (auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus |
239 pmf_bind integral_measure_pmf lookup_default_0_map_values |
239 pmf_bind integral_measure_pmf lookup_default_0_map_values |
266 by (intro mapping_of_pmfI') simp_all |
266 by (intro mapping_of_pmfI') simp_all |
267 |
267 |
268 lemma bind_pmf_aux_code_aux: |
268 lemma bind_pmf_aux_code_aux: |
269 assumes "finite A" |
269 assumes "finite A" |
270 shows "bind_pmf_aux p f A = |
270 shows "bind_pmf_aux p f A = |
271 fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x)) |
271 fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. ( * ) (pmf p x)) |
272 (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs") |
272 (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs") |
273 proof (intro mapping_eqI'[where d = 0]) |
273 proof (intro mapping_eqI'[where d = 0]) |
274 fix x assume "x \<in> Mapping.keys ?lhs" |
274 fix x assume "x \<in> Mapping.keys ?lhs" |
275 then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto |
275 then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto |
276 hence "Mapping.lookup_default 0 ?lhs x = |
276 hence "Mapping.lookup_default 0 ?lhs x = |
285 finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" . |
285 finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" . |
286 qed (insert assms, simp_all add: keys_fold_combine_plus) |
286 qed (insert assms, simp_all add: keys_fold_combine_plus) |
287 |
287 |
(* Code equation for bind_pmf_aux on sets of the form set xs. For each x, the
   mapping of f x has its values multiplied by pmf p x via Mapping.map_values,
   and the scaled mappings are combined with fold_combine_plus. The lemma is an
   instance of bind_pmf_aux_code_aux; its finiteness premise on the set is
   discharged by simp_all. *)
288 lemma bind_pmf_aux_code [code]: |
288 lemma bind_pmf_aux_code [code]: |
289 "bind_pmf_aux p f (set xs) = |
289 "bind_pmf_aux p f (set xs) = |
290 fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x)) |
290 fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. ( * ) (pmf p x)) |
291 (mapping_of_pmf (f x))) (set xs)" |
291 (mapping_of_pmf (f x))) (set xs)" |
292 by (rule bind_pmf_aux_code_aux) simp_all |
292 by (rule bind_pmf_aux_code_aux) simp_all |
293 |
293 |
(* Registers bind_pmf_aux_correct, under the name bind_pmf_code, as an abstract
   code equation. NOTE(review): bind_pmf_aux_correct is stated outside this view;
   presumably it expresses the representation of bind_pmf via bind_pmf_aux -- confirm. *)
294 lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct |
294 lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct |
295 |
295 |
534 pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow> |
534 pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow> |
535 'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow> |
535 'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow> |
536 'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where |
536 'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where |
537 [code_unfold]: "pmfify A x = |
537 [code_unfold]: "pmfify A x = |
538 Code_Evaluation.valtermify pmf_of_multiset {\<cdot>} |
538 Code_Evaluation.valtermify pmf_of_multiset {\<cdot>} |
539 (Code_Evaluation.valtermify (op +) {\<cdot>} A {\<cdot>} |
539 (Code_Evaluation.valtermify (+) {\<cdot>} A {\<cdot>} |
540 (Code_Evaluation.valtermify single {\<cdot>} x))" |
540 (Code_Evaluation.valtermify single {\<cdot>} x))" |
541 |
541 |
542 |
542 |
(* Declares left-associative infix syntax at priority 60: "\<circ>>" for fcomp and
   "\<circ>\<rightarrow>" for scomp. Both constants are defined outside this view. *)
543 notation fcomp (infixl "\<circ>>" 60) |
543 notation fcomp (infixl "\<circ>>" 60) |
544 notation scomp (infixl "\<circ>\<rightarrow>" 60) |
544 notation scomp (infixl "\<circ>\<rightarrow>" 60) |