src/HOL/Probability/PMF_Impl.thy
changeset 67399 eab6ce8368fa
parent 66453 cc19f7ca2ed6
child 69064 5840724b1d71
equal deleted inserted replaced
67398:5eb932e604a2 67399:eab6ce8368fa
   143 
   143 
   144 lemma pmf_of_multiset_impl_code_alt:
   144 lemma pmf_of_multiset_impl_code_alt:
   145   assumes "A \<noteq> {#}"
   145   assumes "A \<noteq> {#}"
   146   shows   "pmf_of_multiset_impl A =
   146   shows   "pmf_of_multiset_impl A =
   147              (let p = 1 / real (size A)
   147              (let p = 1 / real (size A)
   148               in  fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A)"
   148               in  fold_mset (\<lambda>x. Mapping.map_default x 0 ((+) p)) Mapping.empty A)"
   149 proof -
   149 proof -
   150   define p where "p = 1 / real (size A)"
   150   define p where "p = 1 / real (size A)"
   151   interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 (op + p)"
   151   interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 ((+) p)"
   152     unfolding Mapping.map_default_def [abs_def]
   152     unfolding Mapping.map_default_def [abs_def]
   153     by (standard, intro mapping_eqI ext) 
   153     by (standard, intro mapping_eqI ext) 
   154        (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def)
   154        (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def)
   155   let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A"
   155   let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 ((+) p)) Mapping.empty A"
   156   have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all
   156   have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all
   157   have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x
   157   have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x
   158     by (induction A)
   158     by (induction A)
   159        (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs)
   159        (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs)
   160   from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def
   160   from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def
   195   by transfer (auto simp: dom_def set_pmf_eq)
   195   by transfer (auto simp: dom_def set_pmf_eq)
   196   
   196   
   197 
   197 
   198 
   198 
   199 definition fold_combine_plus where
   199 definition fold_combine_plus where
   200   "fold_combine_plus = comm_monoid_set.F (Mapping.combine (op + :: real \<Rightarrow> _)) Mapping.empty"
   200   "fold_combine_plus = comm_monoid_set.F (Mapping.combine ((+) :: real \<Rightarrow> _)) Mapping.empty"
   201 
   201 
   202 context
   202 context
   203 begin
   203 begin
   204 
   204 
   205 interpretation fold_combine_plus: combine_mapping_abel_semigroup "op + :: real \<Rightarrow> _"
   205 interpretation fold_combine_plus: combine_mapping_abel_semigroup "(+) :: real \<Rightarrow> _"
   206   by unfold_locales (simp_all add: add_ac)
   206   by unfold_locales (simp_all add: add_ac)
   207   
   207   
   208 qualified lemma lookup_default_fold_combine_plus: 
   208 qualified lemma lookup_default_fold_combine_plus: 
   209   fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping"
   209   fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping"
   210   assumes "finite A"
   210   assumes "finite A"
   217 qualified lemma keys_fold_combine_plus: 
   217 qualified lemma keys_fold_combine_plus: 
   218   "finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))"
   218   "finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))"
   219   by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine)
   219   by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine)
   220 
   220 
   221 qualified lemma fold_combine_plus_code [code]:
   221 qualified lemma fold_combine_plus_code [code]:
   222   "fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine op+ (g x)) (remdups xs) Mapping.empty"
   222   "fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine (+) (g x)) (remdups xs) Mapping.empty"
   223   by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code)
   223   by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code)
   224 
   224 
   225 private lemma lookup_default_0_map_values:
   225 private lemma lookup_default_0_map_values:
   226   assumes "f x 0 = 0"
   226   assumes "f x 0 = 0"
   227   shows   "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)"
   227   shows   "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)"
   229   using assms by transfer (auto split: option.splits)
   229   using assms by transfer (auto split: option.splits)
   230 
   230 
   231 qualified lemma mapping_of_bind_pmf:
   231 qualified lemma mapping_of_bind_pmf:
   232   assumes "finite (set_pmf p)"
   232   assumes "finite (set_pmf p)"
   233   shows   "mapping_of_pmf (bind_pmf p f) = 
   233   shows   "mapping_of_pmf (bind_pmf p f) = 
   234              fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x)) 
   234              fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. ( * ) (pmf p x)) 
   235                (mapping_of_pmf (f x))) (set_pmf p)"
   235                (mapping_of_pmf (f x))) (set_pmf p)"
   236   using assms
   236   using assms
   237   by (intro mapping_of_pmfI')
   237   by (intro mapping_of_pmfI')
   238      (auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus 
   238      (auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus 
   239                  pmf_bind integral_measure_pmf lookup_default_0_map_values 
   239                  pmf_bind integral_measure_pmf lookup_default_0_map_values 
   266   by (intro mapping_of_pmfI') simp_all
   266   by (intro mapping_of_pmfI') simp_all
   267 
   267 
   268 lemma bind_pmf_aux_code_aux:
   268 lemma bind_pmf_aux_code_aux:
   269   assumes "finite A"
   269   assumes "finite A"
   270   shows   "bind_pmf_aux p f A = 
   270   shows   "bind_pmf_aux p f A = 
   271              fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x))
   271              fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. ( * ) (pmf p x))
   272                (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs")
   272                (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs")
   273 proof (intro mapping_eqI'[where d = 0])
   273 proof (intro mapping_eqI'[where d = 0])
   274   fix x assume "x \<in> Mapping.keys ?lhs"
   274   fix x assume "x \<in> Mapping.keys ?lhs"
   275   then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto
   275   then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto
   276   hence "Mapping.lookup_default 0 ?lhs x = 
   276   hence "Mapping.lookup_default 0 ?lhs x = 
   285   finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" .
   285   finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" .
   286 qed (insert assms, simp_all add: keys_fold_combine_plus)
   286 qed (insert assms, simp_all add: keys_fold_combine_plus)
   287 
   287 
   288 lemma bind_pmf_aux_code [code]:
   288 lemma bind_pmf_aux_code [code]:
   289   "bind_pmf_aux p f (set xs) = 
   289   "bind_pmf_aux p f (set xs) = 
   290      fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x))
   290      fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. ( * ) (pmf p x))
   291                (mapping_of_pmf (f x))) (set xs)"
   291                (mapping_of_pmf (f x))) (set xs)"
   292   by (rule bind_pmf_aux_code_aux) simp_all
   292   by (rule bind_pmf_aux_code_aux) simp_all
   293 
   293 
   294 lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct
   294 lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct
   295 
   295 
   534   pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow>
   534   pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow>
   535              'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow>
   535              'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow>
   536              'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
   536              'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
   537   [code_unfold]: "pmfify A x =  
   537   [code_unfold]: "pmfify A x =  
   538     Code_Evaluation.valtermify pmf_of_multiset {\<cdot>} 
   538     Code_Evaluation.valtermify pmf_of_multiset {\<cdot>} 
   539       (Code_Evaluation.valtermify (op +) {\<cdot>} A {\<cdot>} 
   539       (Code_Evaluation.valtermify (+) {\<cdot>} A {\<cdot>} 
   540        (Code_Evaluation.valtermify single {\<cdot>} x))"
   540        (Code_Evaluation.valtermify single {\<cdot>} x))"
   541 
   541 
   542 
   542 
   543 notation fcomp (infixl "\<circ>>" 60)
   543 notation fcomp (infixl "\<circ>>" 60)
   544 notation scomp (infixl "\<circ>\<rightarrow>" 60)
   544 notation scomp (infixl "\<circ>\<rightarrow>" 60)