Added code generation for PMFs
authoreberlm
Tue May 31 13:02:44 2016 +0200 (2016-05-31)
changeset 631940b7bdb75f451
parent 63190 3e79279c10ca
child 63195 f3f08c0d4aaf
Added code generation for PMFs
src/HOL/Library/AList_Mapping.thy
src/HOL/Library/Mapping.thy
src/HOL/Library/RBT.thy
src/HOL/Library/RBT_Mapping.thy
src/HOL/Library/RBT_Set.thy
src/HOL/Option.thy
src/HOL/Probability/PMF_Impl.thy
src/HOL/Probability/Probability.thy
src/HOL/Probability/Probability_Mass_Function.thy
src/HOL/Probability/Random_Permutations.thy
     1.1 --- a/src/HOL/Library/AList_Mapping.thy	Tue May 31 12:24:43 2016 +0200
     1.2 +++ b/src/HOL/Library/AList_Mapping.thy	Tue May 31 13:02:44 2016 +0200
     1.3 @@ -63,9 +63,43 @@
     1.4      by (auto intro!: map_of_eqI) (auto dest!: map_of_eq_dom intro: aux)
     1.5  qed
     1.6  
     1.7 +lemma map_values_Mapping [code]:
     1.8 +  fixes f :: "'a \<Rightarrow> 'b" and xs :: "('c \<times> 'a) list"
     1.9 +  shows "Mapping.map_values f (Mapping xs) = Mapping (map (\<lambda>(x,y). (x, f y)) xs)"
    1.10 +proof (transfer, rule ext, goal_cases)
    1.11 +  case (1 f xs x)
    1.12 +  thus ?case by (induction xs) auto
    1.13 +qed
    1.14 +
    1.15 +lemma combine_code [code]: 
    1.16 +  "Mapping.combine f (Mapping xs) (Mapping ys) =
    1.17 +     Mapping.tabulate (remdups (map fst xs @ map fst ys)) 
    1.18 +       (\<lambda>x. the (combine_options f (map_of xs x) (map_of ys x)))"
    1.19 +proof (transfer, rule ext, rule sym, goal_cases)
    1.20 +  case (1 f xs ys x)
    1.21 +  show ?case
    1.22 +  by (cases "map_of xs x"; cases "map_of ys x"; simp)
    1.23 +     (force simp: map_of_eq_None_iff combine_options_altdef option.the_def o_def image_iff
    1.24 +            dest: map_of_SomeD split: option.splits)+
    1.25 +qed
    1.26 +
    1.27 +(* TODO: Move? *)
    1.28 +lemma map_of_filter_distinct:
    1.29 +  assumes "distinct (map fst xs)"
    1.30 +  shows   "map_of (filter P xs) x = 
    1.31 +             (case map_of xs x of None \<Rightarrow> None | Some y \<Rightarrow> if P (x,y) then Some y else None)"
    1.32 +  using assms
    1.33 +  by (auto simp: map_of_eq_None_iff filter_map distinct_map_filter dest: map_of_SomeD
    1.34 +           simp del: map_of_eq_Some_iff intro!: map_of_is_SomeI split: option.splits)
    1.35 +(* END TODO *)
    1.36 +  
    1.37 +lemma filter_Mapping [code]:
    1.38 +  "Mapping.filter P (Mapping xs) = Mapping (filter (\<lambda>(k,v). P k v) (AList.clearjunk xs))"
    1.39 + by (transfer, rule ext)
    1.40 +    (subst map_of_filter_distinct, simp_all add: map_of_clearjunk split: option.split)
    1.41 +
    1.42  lemma [code nbe]:
    1.43    "HOL.equal (x :: ('a, 'b) mapping) x \<longleftrightarrow> True"
    1.44    by (fact equal_refl)
    1.45  
    1.46  end
    1.47 -
     2.1 --- a/src/HOL/Library/Mapping.thy	Tue May 31 12:24:43 2016 +0200
     2.2 +++ b/src/HOL/Library/Mapping.thy	Tue May 31 13:02:44 2016 +0200
     2.3 @@ -88,6 +88,18 @@
     2.4    "((A ===> B) ===> (C ===> D) ===> (B ===> rel_option C) ===> A ===> rel_option D) 
     2.5       (\<lambda>f g m. (map_option g \<circ> m \<circ> f)) (\<lambda>f g m. (map_option g \<circ> m \<circ> f))"
     2.6    by transfer_prover
     2.7 +  
     2.8 +lemma combine_with_key_parametric: 
     2.9 +  shows "((A ===> B ===> B ===> B) ===> (A ===> rel_option B) ===> (A ===> rel_option B) ===>
    2.10 +           (A ===> rel_option B)) (\<lambda>f m1 m2 x. combine_options (f x) (m1 x) (m2 x))
    2.11 +           (\<lambda>f m1 m2 x. combine_options (f x) (m1 x) (m2 x))"
    2.12 +  unfolding combine_options_def by transfer_prover
    2.13 +  
    2.14 +lemma combine_parametric: 
    2.15 +  shows "((B ===> B ===> B) ===> (A ===> rel_option B) ===> (A ===> rel_option B) ===>
    2.16 +           (A ===> rel_option B)) (\<lambda>f m1 m2 x. combine_options f (m1 x) (m2 x))
    2.17 +           (\<lambda>f m1 m2 x. combine_options f (m1 x) (m2 x))"
    2.18 +  unfolding combine_options_def by transfer_prover
    2.19  
    2.20  end
    2.21  
    2.22 @@ -106,6 +118,8 @@
    2.23  lift_definition lookup :: "('a, 'b) mapping \<Rightarrow> 'a \<Rightarrow> 'b option"
    2.24    is "\<lambda>m k. m k" parametric lookup_parametric .
    2.25  
    2.26 +definition "lookup_default d m k = (case Mapping.lookup m k of None \<Rightarrow> d | Some v \<Rightarrow> v)"
    2.27 +
    2.28  declare [[code drop: Mapping.lookup]]
    2.29  setup \<open>Code.add_default_eqn @{thm Mapping.lookup.abs_eq}\<close> \<comment> \<open>FIXME lifting\<close>
    2.30  
    2.31 @@ -115,6 +129,9 @@
    2.32  lift_definition delete :: "'a \<Rightarrow> ('a, 'b) mapping \<Rightarrow> ('a, 'b) mapping"
    2.33    is "\<lambda>k m. m(k := None)" parametric delete_parametric .
    2.34  
    2.35 +lift_definition filter :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> ('a, 'b) mapping \<Rightarrow> ('a, 'b) mapping"
    2.36 +  is "\<lambda>P m k. case m k of None \<Rightarrow> None | Some v \<Rightarrow> if P k v then Some v else None" . 
    2.37 +
    2.38  lift_definition keys :: "('a, 'b) mapping \<Rightarrow> 'a set"
    2.39    is dom parametric dom_parametric .
    2.40  
    2.41 @@ -126,6 +143,20 @@
    2.42  
    2.43  lift_definition map :: "('c \<Rightarrow> 'a) \<Rightarrow> ('b \<Rightarrow> 'd) \<Rightarrow> ('a, 'b) mapping \<Rightarrow> ('c, 'd) mapping"
    2.44    is "\<lambda>f g m. (map_option g \<circ> m \<circ> f)" parametric map_parametric .
    2.45 +  
    2.46 +lift_definition map_values :: "('c \<Rightarrow> 'a \<Rightarrow> 'b) \<Rightarrow> ('c, 'a) mapping \<Rightarrow> ('c, 'b) mapping"
    2.47 +  is "\<lambda>f m x. map_option (f x) (m x)" . 
    2.48 +
    2.49 +lift_definition combine_with_key :: 
    2.50 +  "('a \<Rightarrow> 'b \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> ('a,'b) mapping \<Rightarrow> ('a,'b) mapping \<Rightarrow> ('a,'b) mapping"
    2.51 +  is "\<lambda>f m1 m2 x. combine_options (f x) (m1 x) (m2 x)" parametric combine_with_key_parametric .
    2.52 +
    2.53 +lift_definition combine :: 
    2.54 +  "('b \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> ('a,'b) mapping \<Rightarrow> ('a,'b) mapping \<Rightarrow> ('a,'b) mapping"
    2.55 +  is "\<lambda>f m1 m2 x. combine_options f (m1 x) (m2 x)" parametric combine_parametric .
    2.56 +
    2.57 +definition All_mapping where
    2.58 +  "All_mapping m P \<longleftrightarrow> (\<forall>x. case Mapping.lookup m x of None \<Rightarrow> True | Some y \<Rightarrow> P x y)"
    2.59  
    2.60  declare [[code drop: map]]
    2.61  
    2.62 @@ -217,10 +248,80 @@
    2.63    "k \<noteq> k' \<Longrightarrow> lookup (update k v m) k' = lookup m k'" 
    2.64    by transfer simp
    2.65  
    2.66 +lemma lookup_update': 
    2.67 +  "Mapping.lookup (update k v m) k' = (if k = k' then Some v else lookup m k')"
    2.68 +  by (auto simp: lookup_update lookup_update_neq)
    2.69 +
    2.70  lemma lookup_empty:
    2.71    "lookup empty k = None" 
    2.72    by transfer simp
    2.73  
    2.74 +lemma lookup_filter:
    2.75 +  "lookup (filter P m) k = 
    2.76 +     (case lookup m k of None \<Rightarrow> None | Some v \<Rightarrow> if P k v then Some v else None)"
    2.77 +  by transfer simp_all
    2.78 +
    2.79 +lemma lookup_map_values:
    2.80 +  "lookup (map_values f m) k = map_option (f k) (lookup m k)"
    2.81 +  by transfer simp_all
    2.82 +
    2.83 +lemma lookup_default_empty: "lookup_default d empty k = d"
    2.84 +  by (simp add: lookup_default_def lookup_empty)
    2.85 +
    2.86 +lemma lookup_default_update:
    2.87 +  "lookup_default d (update k v m) k = v" 
    2.88 +  by (simp add: lookup_default_def lookup_update)
    2.89 +
    2.90 +lemma lookup_default_update_neq:
    2.91 +  "k \<noteq> k' \<Longrightarrow> lookup_default d (update k v m) k' = lookup_default d m k'" 
    2.92 +  by (simp add: lookup_default_def lookup_update_neq)
    2.93 +
    2.94 +lemma lookup_default_update': 
    2.95 +  "lookup_default d (update k v m) k' = (if k = k' then v else lookup_default d m k')"
    2.96 +  by (auto simp: lookup_default_update lookup_default_update_neq)
    2.97 +
    2.98 +lemma lookup_default_filter:
    2.99 +  "lookup_default d (filter P m) k =  
   2.100 +     (if P k (lookup_default d m k) then lookup_default d m k else d)"
   2.101 +  by (simp add: lookup_default_def lookup_filter split: option.splits)
   2.102 +
   2.103 +lemma lookup_default_map_values:
   2.104 +  "lookup_default (f k d) (map_values f m) k = f k (lookup_default d m k)"
   2.105 +  by (simp add: lookup_default_def lookup_map_values split: option.splits)  
   2.106 +
   2.107 +lemma lookup_combine_with_key:
   2.108 +  "Mapping.lookup (combine_with_key f m1 m2) x = 
   2.109 +     combine_options (f x) (Mapping.lookup m1 x) (Mapping.lookup m2 x)"
   2.110 +  by transfer (auto split: option.splits)
   2.111 +  
   2.112 +lemma combine_altdef: "combine f m1 m2 = combine_with_key (\<lambda>_. f) m1 m2"
   2.113 +  by transfer' (rule refl)
   2.114 +
   2.115 +lemma lookup_combine:
   2.116 +  "Mapping.lookup (combine f m1 m2) x = 
   2.117 +     combine_options f (Mapping.lookup m1 x) (Mapping.lookup m2 x)"
   2.118 +  by transfer (auto split: option.splits)
   2.119 +  
   2.120 +lemma lookup_default_neutral_combine_with_key: 
   2.121 +  assumes "\<And>x. f k d x = x" "\<And>x. f k x d = x"
   2.122 +  shows   "Mapping.lookup_default d (combine_with_key f m1 m2) k = 
   2.123 +             f k (Mapping.lookup_default d m1 k) (Mapping.lookup_default d m2 k)"
   2.124 +  by (auto simp: lookup_default_def lookup_combine_with_key assms split: option.splits)
   2.125 +  
   2.126 +lemma lookup_default_neutral_combine: 
   2.127 +  assumes "\<And>x. f d x = x" "\<And>x. f x d = x"
   2.128 +  shows   "Mapping.lookup_default d (combine f m1 m2) x = 
   2.129 +             f (Mapping.lookup_default d m1 x) (Mapping.lookup_default d m2 x)"
   2.130 +  by (auto simp: lookup_default_def lookup_combine assms split: option.splits)
   2.131 +
   2.132 +lemma lookup_tabulate: 
   2.133 +  assumes "distinct xs"
   2.134 +  shows   "Mapping.lookup (Mapping.tabulate xs f) x = (if x \<in> set xs then Some (f x) else None)"
   2.135 +  using assms by transfer (auto simp: map_of_eq_None_iff o_def dest!: map_of_SomeD)
   2.136 +
   2.137 +lemma lookup_of_alist: "Mapping.lookup (Mapping.of_alist xs) k = map_of xs k"
   2.138 +  by transfer simp_all
   2.139 +
   2.140  lemma keys_is_none_rep [code_unfold]:
   2.141    "k \<in> keys m \<longleftrightarrow> \<not> (Option.is_none (lookup m k))"
   2.142    by transfer (auto simp add: Option.is_none_def)
   2.143 @@ -247,6 +348,13 @@
   2.144    "k \<notin> keys m \<Longrightarrow> replace k v m = m"
   2.145    "k \<in> keys m \<Longrightarrow> replace k v m = update k v m"
   2.146    by (transfer, auto simp add: replace_def fun_upd_twist)+
   2.147 +  
   2.148 +lemma map_values_update: "map_values f (update k v m) = update k (f k v) (map_values f m)"
   2.149 +  by transfer (simp_all add: fun_eq_iff)
   2.150 +  
   2.151 +lemma size_mono:
   2.152 +  "finite (keys m') \<Longrightarrow> keys m \<subseteq> keys m' \<Longrightarrow> size m \<le> size m'"
   2.153 +  unfolding size_def by (auto intro: card_mono)
   2.154  
   2.155  lemma size_empty [simp]:
   2.156    "size empty = 0"
   2.157 @@ -265,6 +373,13 @@
   2.158    "size (tabulate ks f) = length (remdups ks)"
   2.159    unfolding size_def by transfer (auto simp add: map_of_map_restrict  card_set comp_def)
   2.160  
   2.161 +lemma keys_filter: "keys (filter P m) \<subseteq> keys m"
   2.162 +  by transfer (auto split: option.splits)
   2.163 +
   2.164 +lemma size_filter: "finite (keys m) \<Longrightarrow> size (filter P m) \<le> size m"
   2.165 +  by (intro size_mono keys_filter)
   2.166 +
   2.167 +
   2.168  lemma bulkload_tabulate:
   2.169    "bulkload xs = tabulate [0..<length xs] (nth xs)"
   2.170    by transfer (auto simp add: map_of_map_restrict)
   2.171 @@ -293,6 +408,10 @@
   2.172    "is_empty (map_entry k f m) \<longleftrightarrow> is_empty m"
   2.173    unfolding is_empty_def by transfer (auto split: option.split)
   2.174  
   2.175 +lemma is_empty_map_values [simp]:
   2.176 +  "is_empty (map_values f m) \<longleftrightarrow> is_empty m"
   2.177 +  unfolding is_empty_def by transfer (auto simp: fun_eq_iff)
   2.178 +
   2.179  lemma is_empty_map_default [simp]:
   2.180    "\<not> is_empty (map_default k v f m)"
   2.181    by (simp add: map_default_def)
   2.182 @@ -329,10 +448,24 @@
   2.183    "keys (map_default k v f m) = insert k (keys m)"
   2.184    by (simp add: map_default_def)
   2.185  
   2.186 +lemma keys_map_values [simp]:
   2.187 +  "keys (map_values f m) = keys m"
   2.188 +  by transfer (simp_all add: dom_def)
   2.189 +
   2.190 +lemma keys_combine_with_key [simp]: 
   2.191 +  "Mapping.keys (combine_with_key f m1 m2) = Mapping.keys m1 \<union> Mapping.keys m2"
   2.192 +  by transfer (auto simp: dom_def combine_options_def split: option.splits)  
   2.193 +
   2.194 +lemma keys_combine [simp]: "Mapping.keys (combine f m1 m2) = Mapping.keys m1 \<union> Mapping.keys m2"
   2.195 +  by (simp add: combine_altdef)
   2.196 +
   2.197  lemma keys_tabulate [simp]:
   2.198    "keys (tabulate ks f) = set ks"
   2.199    by transfer (simp add: map_of_map_restrict o_def)
   2.200  
   2.201 +lemma keys_of_alist [simp]: "keys (of_alist xs) = set (List.map fst xs)"
   2.202 +  by transfer (simp_all add: dom_map_of_conv_image_fst)
   2.203 +
   2.204  lemma keys_bulkload [simp]:
   2.205    "keys (bulkload xs) = {0..<length xs}"
   2.206    by (simp add: bulkload_tabulate)
   2.207 @@ -407,11 +540,91 @@
   2.208      by simp
   2.209  qed
   2.210  
   2.211 +lemma All_mapping_mono:
   2.212 +  "(\<And>k v. k \<in> keys m \<Longrightarrow> P k v \<Longrightarrow> Q k v) \<Longrightarrow> All_mapping m P \<Longrightarrow> All_mapping m Q"
   2.213 +  unfolding All_mapping_def by transfer (auto simp: All_mapping_def dom_def split: option.splits)
   2.214  
   2.215 -subsection \<open>Code generator setup\<close>
   2.216 +lemma All_mapping_empty [simp]: "All_mapping Mapping.empty P"
   2.217 +  by (auto simp: All_mapping_def lookup_empty)
   2.218 +  
   2.219 +lemma All_mapping_update_iff: 
   2.220 +  "All_mapping (Mapping.update k v m) P \<longleftrightarrow> P k v \<and> All_mapping m (\<lambda>k' v'. k = k' \<or> P k' v')"
   2.221 +  unfolding All_mapping_def 
   2.222 +proof safe
   2.223 +  assume "\<forall>x. case Mapping.lookup (Mapping.update k v m) x of None \<Rightarrow> True | Some y \<Rightarrow> P x y"
   2.224 +  hence A: "case Mapping.lookup (Mapping.update k v m) x of None \<Rightarrow> True | Some y \<Rightarrow> P x y" for x
   2.225 +    by blast
   2.226 +  from A[of k] show "P k v" by (simp add: lookup_update)
   2.227 +  show "case Mapping.lookup m x of None \<Rightarrow> True | Some v' \<Rightarrow> k = x \<or> P x v'" for x
   2.228 +    using A[of x] by (auto simp add: lookup_update' split: if_splits option.splits)
   2.229 +next
   2.230 +  assume "P k v"
   2.231 +  assume "\<forall>x. case Mapping.lookup m x of None \<Rightarrow> True | Some v' \<Rightarrow> k = x \<or> P x v'"
   2.232 +  hence A: "case Mapping.lookup m x of None \<Rightarrow> True | Some v' \<Rightarrow> k = x \<or> P x v'" for x by blast
   2.233 +  show "case Mapping.lookup (Mapping.update k v m) x of None \<Rightarrow> True | Some xa \<Rightarrow> P x xa" for x
   2.234 +    using \<open>P k v\<close> A[of x] by (auto simp: lookup_update' split: option.splits)
   2.235 +qed
   2.236 +
   2.237 +lemma All_mapping_update:
   2.238 +  "P k v \<Longrightarrow> All_mapping m (\<lambda>k' v'. k = k' \<or> P k' v') \<Longrightarrow> All_mapping (Mapping.update k v m) P"
   2.239 +  by (simp add: All_mapping_update_iff)
   2.240 +
   2.241 +lemma All_mapping_filter_iff:
   2.242 +  "All_mapping (filter P m) Q \<longleftrightarrow> All_mapping m (\<lambda>k v. P k v \<longrightarrow> Q k v)"
   2.243 +  by (auto simp: All_mapping_def lookup_filter split: option.splits)
   2.244 +
   2.245 +lemma All_mapping_filter:
   2.246 +  "All_mapping m Q \<Longrightarrow> All_mapping (filter P m) Q"
   2.247 +  by (auto simp: All_mapping_filter_iff intro: All_mapping_mono)
   2.248  
   2.249 -hide_const (open) empty is_empty rep lookup update delete ordered_keys keys size
   2.250 -  replace default map_entry map_default tabulate bulkload map of_alist
   2.251 +lemma All_mapping_map_values:
   2.252 +  "All_mapping (map_values f m) P \<longleftrightarrow> All_mapping m (\<lambda>k v. P k (f k v))"
   2.253 +  by (auto simp: All_mapping_def lookup_map_values split: option.splits)
   2.254 +
   2.255 +lemma All_mapping_tabulate: 
   2.256 +  "(\<forall>x\<in>set xs. P x (f x)) \<Longrightarrow> All_mapping (Mapping.tabulate xs f) P"
   2.257 +  unfolding All_mapping_def 
   2.258 +  by (intro allI,  transfer) (auto split: option.split dest!: map_of_SomeD)
   2.259 +
   2.260 +lemma All_mapping_alist:
   2.261 +  "(\<And>k v. (k, v) \<in> set xs \<Longrightarrow> P k v) \<Longrightarrow> All_mapping (Mapping.of_alist xs) P"
   2.262 +  by (auto simp: All_mapping_def lookup_of_alist dest!: map_of_SomeD split: option.splits)
   2.263 +
   2.264 +
   2.265 +lemma combine_empty [simp]:
   2.266 +  "combine f Mapping.empty y = y" "combine f y Mapping.empty = y"
   2.267 +  by (transfer, force)+
   2.268 +
   2.269 +lemma (in abel_semigroup) comm_monoid_set_combine: "comm_monoid_set (combine f) Mapping.empty"
   2.270 +  by standard (transfer fixing: f, simp add: combine_options_ac[of f] ac_simps)+
   2.271 +
   2.272 +locale combine_mapping_abel_semigroup = abel_semigroup
   2.273 +begin
   2.274 +
   2.275 +sublocale combine: comm_monoid_set "combine f" Mapping.empty
   2.276 +  by (rule comm_monoid_set_combine)
   2.277 +
   2.278 +lemma fold_combine_code:
   2.279 +  "combine.F g (set xs) = foldr (\<lambda>x. combine f (g x)) (remdups xs) Mapping.empty"
   2.280 +proof -
   2.281 +  have "combine.F g (set xs) = foldr (\<lambda>x. combine f (g x)) xs Mapping.empty"
   2.282 +    if "distinct xs" for xs
   2.283 +    using that by (induction xs) simp_all
   2.284 +  from this[of "remdups xs"] show ?thesis by simp
   2.285 +qed
   2.286 +  
   2.287 +lemma keys_fold_combine:
   2.288 +  assumes "finite A"
   2.289 +  shows   "Mapping.keys (combine.F g A) = (\<Union>x\<in>A. Mapping.keys (g x))"
   2.290 +  using assms by (induction A rule: finite_induct) simp_all
   2.291  
   2.292  end
   2.293  
   2.294 +  
   2.295 +subsection \<open>Code generator setup\<close>
   2.296 +
   2.297 +hide_const (open) empty is_empty rep lookup lookup_default filter update delete ordered_keys
   2.298 +  keys size replace default map_entry map_default tabulate bulkload map map_values combine of_alist
   2.299 +
   2.300 +end
   2.301 +
     3.1 --- a/src/HOL/Library/RBT.thy	Tue May 31 12:24:43 2016 +0200
     3.2 +++ b/src/HOL/Library/RBT.thy	Tue May 31 13:02:44 2016 +0200
     3.3 @@ -67,12 +67,22 @@
     3.4  
     3.5  lift_definition foldi :: "('c \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> 'b \<Rightarrow> 'c \<Rightarrow> 'c) \<Rightarrow> ('a :: linorder, 'b) rbt \<Rightarrow> 'c \<Rightarrow> 'c"
     3.6    is RBT_Impl.foldi .
     3.7 +  
     3.8 +lift_definition combine_with_key :: "('a \<Rightarrow> 'b \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> ('a::linorder, 'b) rbt \<Rightarrow> ('a, 'b) rbt \<Rightarrow> ('a, 'b) rbt"
     3.9 +  is RBT_Impl.rbt_union_with_key by (rule is_rbt_rbt_unionwk)
    3.10 +
    3.11 +lift_definition combine :: "('b \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> ('a::linorder, 'b) rbt \<Rightarrow> ('a, 'b) rbt \<Rightarrow> ('a, 'b) rbt"
    3.12 +  is RBT_Impl.rbt_union_with by (rule rbt_unionw_is_rbt)
    3.13  
    3.14  subsection \<open>Derived operations\<close>
    3.15  
    3.16  definition is_empty :: "('a::linorder, 'b) rbt \<Rightarrow> bool" where
    3.17    [code]: "is_empty t = (case impl_of t of RBT_Impl.Empty \<Rightarrow> True | _ \<Rightarrow> False)"
    3.18  
    3.19 +(* TODO: Is deleting more efficient than re-building the tree? 
    3.20 +   (Probably more difficult to prove though *)
    3.21 +definition filter :: "('a \<Rightarrow> 'b \<Rightarrow> bool) \<Rightarrow> ('a::linorder, 'b) rbt \<Rightarrow> ('a, 'b) rbt" where
    3.22 +  [code]: "filter P t = fold (\<lambda>k v t. if P k v then insert k v t else t) t empty" 
    3.23  
    3.24  subsection \<open>Abstract lookup properties\<close>
    3.25  
    3.26 @@ -128,6 +138,17 @@
    3.27    "lookup (map f t) k = map_option (f k) (lookup t k)"
    3.28    by transfer (rule rbt_lookup_map)
    3.29  
    3.30 +lemma lookup_combine_with_key [simp]:
    3.31 +  "lookup (combine_with_key f t1 t2) k = combine_options (f k) (lookup t1 k) (lookup t2 k)"
    3.32 +  by transfer (simp_all add: combine_options_def rbt_lookup_rbt_unionwk)
    3.33 +
    3.34 +lemma combine_altdef: "combine f t1 t2 = combine_with_key (\<lambda>_. f) t1 t2"
    3.35 +  by transfer (simp add: rbt_union_with_def)
    3.36 +
    3.37 +lemma lookup_combine [simp]:
    3.38 +  "lookup (combine f t1 t2) k = combine_options f (lookup t1 k) (lookup t2 k)"
    3.39 +  by (simp add: combine_altdef)
    3.40 +
    3.41  lemma fold_fold:
    3.42    "fold f t = List.fold (case_prod f) (entries t)"
    3.43    by transfer (rule RBT_Impl.fold_def)
    3.44 @@ -182,6 +203,26 @@
    3.45    "keys t = List.map fst (entries t)"
    3.46    by transfer (simp add: RBT_Impl.keys_def)
    3.47  
    3.48 +context
    3.49 +begin
    3.50 +
    3.51 +private lemma lookup_filter_aux:
    3.52 +  assumes "distinct (List.map fst xs)"
    3.53 +  shows   "lookup (List.fold (\<lambda>(k, v) t. if P k v then insert k v t else t) xs t) k =
    3.54 +             (case map_of xs k of 
    3.55 +                None \<Rightarrow> lookup t k
    3.56 +              | Some v \<Rightarrow> if P k v then Some v else lookup t k)"
    3.57 +  using assms by (induction xs arbitrary: t) (force split: option.splits)+
    3.58 +
    3.59 +lemma lookup_filter: 
    3.60 +  "lookup (filter P t) k = 
    3.61 +     (case lookup t k of None \<Rightarrow> None | Some v \<Rightarrow> if P k v then Some v else None)"
    3.62 +  unfolding filter_def using lookup_filter_aux[of "entries t" P empty k]
    3.63 +  by (simp add: fold_fold distinct_entries split: option.splits)
    3.64 +  
    3.65 +end
    3.66 +
    3.67 +
    3.68  subsection \<open>Quickcheck generators\<close>
    3.69  
    3.70  quickcheck_generator rbt predicate: is_rbt constructors: empty, insert
     4.1 --- a/src/HOL/Library/RBT_Mapping.thy	Tue May 31 12:24:43 2016 +0200
     4.2 +++ b/src/HOL/Library/RBT_Mapping.thy	Tue May 31 13:02:44 2016 +0200
     4.3 @@ -77,6 +77,24 @@
     4.4  
     4.5  end
     4.6  
     4.7 +lemma map_values_Mapping [code]: 
     4.8 +  "Mapping.map_values f (Mapping t) = Mapping (RBT.map f t)"
     4.9 +  by (transfer fixing: t) (auto simp: fun_eq_iff)
    4.10 +
    4.11 +lemma filter_Mapping [code]: 
    4.12 +  "Mapping.filter P (Mapping t) = Mapping (RBT.filter P t)"
    4.13 +  by (transfer' fixing: P t) (simp add: RBT.lookup_filter fun_eq_iff)
    4.14 +
    4.15 +lemma combine_with_key_Mapping [code]:
    4.16 +  "Mapping.combine_with_key f (Mapping t1) (Mapping t2) =
    4.17 +     Mapping (RBT.combine_with_key f t1 t2)"
    4.18 +  by (transfer fixing: f t1 t2) (simp_all add: fun_eq_iff)
    4.19 +
    4.20 +lemma combine_Mapping [code]:
    4.21 +  "Mapping.combine f (Mapping t1) (Mapping t2) =
    4.22 +     Mapping (RBT.combine f t1 t2)"
    4.23 +  by (transfer fixing: f t1 t2) (simp_all add: fun_eq_iff)
    4.24 +
    4.25  lemma equal_Mapping [code]:
    4.26    "HOL.equal (Mapping t1) (Mapping t2) \<longleftrightarrow> RBT.entries t1 = RBT.entries t2"
    4.27    by (transfer fixing: t1 t2) (simp add: entries_lookup)
     5.1 --- a/src/HOL/Library/RBT_Set.thy	Tue May 31 12:24:43 2016 +0200
     5.2 +++ b/src/HOL/Library/RBT_Set.thy	Tue May 31 13:02:44 2016 +0200
     5.3 @@ -838,10 +838,10 @@
     5.4  
     5.5  lemma Bleast_code [code]:
     5.6    "Bleast (Set t) P =
     5.7 -    (case filter P (RBT.keys t) of
     5.8 +    (case List.filter P (RBT.keys t) of
     5.9        x # xs \<Rightarrow> x
    5.10      | [] \<Rightarrow> abort_Bleast (Set t) P)"
    5.11 -proof (cases "filter P (RBT.keys t)")
    5.12 +proof (cases "List.filter P (RBT.keys t)")
    5.13    case Nil
    5.14    thus ?thesis by (simp add: Bleast_def abort_Bleast_def)
    5.15  next
     6.1 --- a/src/HOL/Option.thy	Tue May 31 12:24:43 2016 +0200
     6.2 +++ b/src/HOL/Option.thy	Tue May 31 13:02:44 2016 +0200
     6.3 @@ -136,6 +136,43 @@
     6.4      | _ \<Rightarrow> False)"
     6.5    by (auto split: prod.split option.split)
     6.6  
     6.7 +
     6.8 +definition combine_options :: "('a \<Rightarrow> 'a \<Rightarrow> 'a) \<Rightarrow> 'a option \<Rightarrow> 'a option \<Rightarrow> 'a option"
     6.9 +  where "combine_options f x y = 
    6.10 +           (case x of None \<Rightarrow> y | Some x \<Rightarrow> (case y of None \<Rightarrow> Some x | Some y \<Rightarrow> Some (f x y)))"
    6.11 +
    6.12 +lemma combine_options_simps [simp]:
    6.13 +  "combine_options f None y = y"
    6.14 +  "combine_options f x None = x"
    6.15 +  "combine_options f (Some a) (Some b) = Some (f a b)"
    6.16 +  by (simp_all add: combine_options_def split: option.splits)
    6.17 +  
    6.18 +lemma combine_options_cases [case_names None1 None2 Some]:
    6.19 +  "(x = None \<Longrightarrow> P x y) \<Longrightarrow> (y = None \<Longrightarrow> P x y) \<Longrightarrow> 
    6.20 +     (\<And>a b. x = Some a \<Longrightarrow> y = Some b \<Longrightarrow> P x y) \<Longrightarrow> P x y"
    6.21 +  by (cases x; cases y) simp_all
    6.22 +
    6.23 +lemma combine_options_commute: 
    6.24 +  "(\<And>x y. f x y = f y x) \<Longrightarrow> combine_options f x y = combine_options f y x"
    6.25 +  using combine_options_cases[of x ]
    6.26 +  by (induction x y rule: combine_options_cases) simp_all
    6.27 +
    6.28 +lemma combine_options_assoc:
    6.29 +  "(\<And>x y z. f (f x y) z = f x (f y z)) \<Longrightarrow> 
    6.30 +     combine_options f (combine_options f x y) z =
    6.31 +     combine_options f x (combine_options f y z)"
    6.32 +  by (auto simp: combine_options_def split: option.splits)
    6.33 +
    6.34 +lemma combine_options_left_commute:
    6.35 +  "(\<And>x y. f x y = f y x) \<Longrightarrow> (\<And>x y z. f (f x y) z = f x (f y z)) \<Longrightarrow> 
    6.36 +     combine_options f y (combine_options f x z) =
    6.37 +     combine_options f x (combine_options f y z)"
    6.38 +  by (auto simp: combine_options_def split: option.splits)
    6.39 +
    6.40 +lemmas combine_options_ac = 
    6.41 +  combine_options_commute combine_options_assoc combine_options_left_commute
    6.42 +
    6.43 +
    6.44  context
    6.45  begin
    6.46  
     7.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     7.2 +++ b/src/HOL/Probability/PMF_Impl.thy	Tue May 31 13:02:44 2016 +0200
     7.3 @@ -0,0 +1,478 @@
     7.4 +(*  Title:      HOL/Probability/PMF_Impl.thy
     7.5 +    Author:     Manuel Eberl, TU M√ľnchen
     7.6 +    
     7.7 +    An implementation of PMFs using Mappings, which are implemented with association lists
     7.8 +    by default. Also includes Quickcheck setup for PMFs.
     7.9 +*)
    7.10 +
    7.11 +theory PMF_Impl
    7.12 +imports Probability_Mass_Function "~~/src/HOL/Library/AList_Mapping"
    7.13 +begin
    7.14 +
    7.15 +definition pmf_of_mapping :: "('a, real) mapping \<Rightarrow> 'a pmf" where
    7.16 +  "pmf_of_mapping m = embed_pmf (Mapping.lookup_default 0 m)" 
    7.17 +
    7.18 +lemma nn_integral_lookup_default:
    7.19 +  fixes m :: "('a, real) mapping"
    7.20 +  assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ x. x \<ge> 0)"
    7.21 +  shows   "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) = 
    7.22 +             ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)"
    7.23 +proof -
    7.24 +  have "nn_integral (count_space UNIV) (\<lambda>k. ennreal (Mapping.lookup_default 0 m k)) = 
    7.25 +          (\<Sum>x\<in>Mapping.keys m. ennreal (Mapping.lookup_default 0 m x))" using assms
    7.26 +    by (subst nn_integral_count_space'[of "Mapping.keys m"])
    7.27 +       (auto simp: Mapping.lookup_default_def keys_is_none_rep Option.is_none_def)
    7.28 +  also from assms have "\<dots> = ennreal (\<Sum>k\<in>Mapping.keys m. Mapping.lookup_default 0 m k)"
    7.29 +    by (intro setsum_ennreal) 
    7.30 +       (auto simp: Mapping.lookup_default_def All_mapping_def split: option.splits)
    7.31 +  finally show ?thesis .
    7.32 +qed
    7.33 +
    7.34 +lemma pmf_of_mapping: 
    7.35 +  assumes "finite (Mapping.keys m)" "All_mapping m (\<lambda>_ p. p \<ge> 0)"
    7.36 +  assumes "(\<Sum>x\<in>Mapping.keys m. Mapping.lookup_default 0 m x) = 1"
    7.37 +  shows   "pmf (pmf_of_mapping m) x = Mapping.lookup_default 0 m x"
    7.38 +  unfolding pmf_of_mapping_def
    7.39 +proof (intro pmf_embed_pmf)
    7.40 +  from assms show "(\<integral>\<^sup>+x. ennreal (Mapping.lookup_default 0 m x) \<partial>count_space UNIV) = 1"
    7.41 +    by (subst nn_integral_lookup_default) (simp_all)
    7.42 +qed (insert assms, simp add: All_mapping_def Mapping.lookup_default_def split: option.splits)
    7.43 +
    7.44 +lemma pmf_of_set_pmf_of_mapping:
    7.45 +  assumes "A \<noteq> {}" "set xs = A" "distinct xs"
    7.46 +  shows   "pmf_of_set A = pmf_of_mapping (Mapping.tabulate xs (\<lambda>_. 1 / real (length xs)))" 
    7.47 +           (is "?lhs = ?rhs")
    7.48 +  by (rule pmf_eqI, subst pmf_of_mapping)
    7.49 +     (insert assms, auto intro!: All_mapping_tabulate 
    7.50 +                         simp: Mapping.lookup_default_def lookup_tabulate distinct_card)
    7.51 +
    7.52 +lift_definition mapping_of_pmf :: "'a pmf \<Rightarrow> ('a, real) mapping" is
    7.53 +  "\<lambda>p x. if pmf p x = 0 then None else Some (pmf p x)" .
    7.54 +  
    7.55 +lemma lookup_default_mapping_of_pmf: 
    7.56 +  "Mapping.lookup_default 0 (mapping_of_pmf p) x = pmf p x"
    7.57 +  by (simp add: mapping_of_pmf.abs_eq lookup_default_def Mapping.lookup.abs_eq)
    7.58 +
    7.59 +context
    7.60 +begin
    7.61 +
    7.62 +interpretation pmf_as_function .
    7.63 +
    7.64 +lemma nn_integral_pmf_eq_1: "(\<integral>\<^sup>+ x. ennreal (pmf p x) \<partial>count_space UNIV) = 1"
    7.65 +  by transfer simp_all
    7.66 +end
    7.67 +  
    7.68 +lemma pmf_of_mapping_mapping_of_pmf [code abstype]: 
    7.69 +    "pmf_of_mapping (mapping_of_pmf p) = p"
    7.70 +  unfolding pmf_of_mapping_def
    7.71 +  by (rule pmf_eqI, subst pmf_embed_pmf)
    7.72 +     (insert nn_integral_pmf_eq_1[of p], 
    7.73 +      auto simp: lookup_default_mapping_of_pmf split: option.splits)
    7.74 +
    7.75 +lemma mapping_of_pmfI:
    7.76 +  assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup m x = Some (pmf p x)" 
    7.77 +  assumes "Mapping.keys m = set_pmf p"
    7.78 +  shows   "mapping_of_pmf p = m"
    7.79 +  using assms by transfer (rule ext, auto simp: set_pmf_eq)
    7.80 +  
    7.81 +lemma mapping_of_pmfI':
    7.82 +  assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup_default 0 m x = pmf p x" 
    7.83 +  assumes "Mapping.keys m = set_pmf p"
    7.84 +  shows   "mapping_of_pmf p = m"
    7.85 +  using assms unfolding Mapping.lookup_default_def 
    7.86 +  by transfer (rule ext, force simp: set_pmf_eq)
    7.87 +
    7.88 +lemma return_pmf_code [code abstract]:
    7.89 +  "mapping_of_pmf (return_pmf x) = Mapping.update x 1 Mapping.empty"
    7.90 +  by (intro mapping_of_pmfI) (auto simp: lookup_update')
    7.91 +
    7.92 +lemma pmf_of_set_code_aux:
    7.93 +  assumes "A \<noteq> {}" "set xs = A" "distinct xs"
    7.94 +  shows   "mapping_of_pmf (pmf_of_set A) = Mapping.tabulate xs (\<lambda>_. 1 / real (length xs))"
    7.95 +  using assms
    7.96 +  by (intro mapping_of_pmfI, subst pmf_of_set)
    7.97 +     (auto simp: lookup_tabulate distinct_card)
    7.98 +
    7.99 +definition pmf_of_set_impl where
   7.100 +  "pmf_of_set_impl A = mapping_of_pmf (pmf_of_set A)"
   7.101 +  
   7.102 +lemma pmf_of_set_impl_code [code]:
   7.103 +  "pmf_of_set_impl (set xs) = 
   7.104 +    (if xs = [] then
   7.105 +         Code.abort (STR ''pmf_of_set of empty set'') (\<lambda>_. mapping_of_pmf (pmf_of_set (set xs)))
   7.106 +      else let xs' = remdups xs; p = 1 / real (length xs') in
   7.107 +         Mapping.tabulate xs' (\<lambda>_. p))"
   7.108 +  unfolding pmf_of_set_impl_def
   7.109 +  using pmf_of_set_code_aux[of "set xs" "remdups xs"] by (simp add: Let_def)
   7.110 +
   7.111 +lemma pmf_of_set_code [code abstract]:
   7.112 +  "mapping_of_pmf (pmf_of_set A) = pmf_of_set_impl A"
   7.113 +  by (simp add: pmf_of_set_impl_def)
   7.114 +
   7.115 +
   7.116 +lemma pmf_of_multiset_pmf_of_mapping:
   7.117 +  assumes "A \<noteq> {#}" "set xs = set_mset A" "distinct xs"
   7.118 +  shows   "mapping_of_pmf (pmf_of_multiset A) = Mapping.tabulate xs (\<lambda>x. count A x / real (size A))" 
   7.119 +  using assms by (intro mapping_of_pmfI) (auto simp: lookup_tabulate)
   7.120 +
   7.121 +definition pmf_of_multiset_impl where
   7.122 +  "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)"
   7.123 +
   7.124 +lemma pmf_of_multiset_impl_code [code]:
   7.125 +  "pmf_of_multiset_impl (mset xs) =
   7.126 +     (if xs = [] then 
   7.127 +        Code.abort (STR ''pmf_of_multiset of empty multiset'') 
   7.128 +          (\<lambda>_. mapping_of_pmf (pmf_of_multiset (mset xs)))
   7.129 +      else let xs' = remdups xs; p = 1 / real (length xs) in
   7.130 +         Mapping.tabulate xs' (\<lambda>x. real (count (mset xs) x) * p))"
   7.131 +  using pmf_of_multiset_pmf_of_mapping[of "mset xs" "remdups xs"]
   7.132 +  by (simp add: pmf_of_multiset_impl_def)
   7.133 +
   7.134 +lemma pmf_of_multiset_code [code abstract]:
   7.135 +  "mapping_of_pmf (pmf_of_multiset A) = pmf_of_multiset_impl A"
   7.136 +  by (simp add: pmf_of_multiset_impl_def)
   7.137 +
   7.138 +lemma bernoulli_pmf_code [code abstract]:
   7.139 +  "mapping_of_pmf (bernoulli_pmf p) = 
   7.140 +     (if p \<le> 0 then Mapping.update False 1 Mapping.empty 
   7.141 +      else if p \<ge> 1 then Mapping.update True 1 Mapping.empty
   7.142 +      else Mapping.update False (1 - p) (Mapping.update True p Mapping.empty))"
   7.143 +  by (intro mapping_of_pmfI) (auto simp: bernoulli_pmf.rep_eq lookup_update' set_pmf_eq)
   7.144 +
   7.145 +
   7.146 +
   7.147 +
   7.148 +lemma pmf_code [code]: "pmf p x = Mapping.lookup_default 0 (mapping_of_pmf p) x"
   7.149 +  unfolding mapping_of_pmf_def Mapping.lookup_default_def
   7.150 +  by (auto split: option.splits simp: id_def Mapping.lookup.abs_eq)
   7.151 +
   7.152 +lemma set_pmf_code [code]: "set_pmf p = Mapping.keys (mapping_of_pmf p)"
   7.153 +  by transfer (auto simp: dom_def set_pmf_eq)
   7.154 +
   7.155 +lemma keys_mapping_of_pmf [simp]: "Mapping.keys (mapping_of_pmf p) = set_pmf p"
   7.156 +  by transfer (auto simp: dom_def set_pmf_eq)
   7.157 +  
   7.158 +
   7.159 +
   7.160 +(* This is necessary since we want something the guarantees finiteness, but simply using 
   7.161 +   "finite" restricts the code equations to types where finiteness of the universe can 
   7.162 +   be decided. This simply fails when finiteness is not clear *)
   7.163 +definition is_list_set where "is_list_set A = finite A"
   7.164 +
   7.165 +lemma is_list_set_code [code]: "is_list_set (set xs) = True"
   7.166 +  by (simp add: is_list_set_def)
   7.167 +
   7.168 +definition fold_combine_plus where
   7.169 +  "fold_combine_plus = comm_monoid_set.F (Mapping.combine (op + :: real \<Rightarrow> _)) Mapping.empty"
   7.170 +
   7.171 +context
   7.172 +begin
   7.173 +
   7.174 +interpretation fold_combine_plus: combine_mapping_abel_semigroup "op + :: real \<Rightarrow> _"
   7.175 +  by unfold_locales (simp_all add: add_ac)
   7.176 +  
   7.177 +qualified lemma lookup_default_fold_combine_plus: 
   7.178 +  fixes A :: "'b set" and f :: "'b \<Rightarrow> ('a, real) mapping"
   7.179 +  assumes "finite A"
   7.180 +  shows   "Mapping.lookup_default 0 (fold_combine_plus f A) x = 
   7.181 +             (\<Sum>y\<in>A. Mapping.lookup_default 0 (f y) x)"
   7.182 +  unfolding fold_combine_plus_def using assms 
   7.183 +    by (induction A rule: finite_induct) 
   7.184 +       (simp_all add: lookup_default_empty lookup_default_neutral_combine)
   7.185 +
   7.186 +qualified lemma keys_fold_combine_plus: 
   7.187 +  "finite A \<Longrightarrow> Mapping.keys (fold_combine_plus f A) = (\<Union>x\<in>A. Mapping.keys (f x))"
   7.188 +  by (simp add: fold_combine_plus_def fold_combine_plus.keys_fold_combine)
   7.189 +
   7.190 +qualified lemma fold_combine_plus_code [code]:
   7.191 +  "fold_combine_plus g (set xs) = foldr (\<lambda>x. Mapping.combine op+ (g x)) (remdups xs) Mapping.empty"
   7.192 +  by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code)
   7.193 +
   7.194 +private lemma lookup_default_0_map_values:
   7.195 +  assumes "f 0 = 0"
   7.196 +  shows   "Mapping.lookup_default 0 (Mapping.map_values f m) x = f (Mapping.lookup_default 0 m x)"
   7.197 +  unfolding Mapping.lookup_default_def
   7.198 +  using assms by transfer (auto split: option.splits)  
   7.199 +
   7.200 +qualified lemma mapping_of_bind_pmf:
   7.201 +  assumes "finite (set_pmf p)"
   7.202 +  shows   "mapping_of_pmf (bind_pmf p f) = 
   7.203 +             fold_combine_plus (\<lambda>x. Mapping.map_values (op * (pmf p x)) 
   7.204 +               (mapping_of_pmf (f x))) (set_pmf p)"
   7.205 +  using assms
   7.206 +  by (intro mapping_of_pmfI')
   7.207 +     (auto simp: keys_fold_combine_plus lookup_default_fold_combine_plus 
   7.208 +                 pmf_bind integral_measure_pmf lookup_default_0_map_values 
   7.209 +                 lookup_default_mapping_of_pmf mult_ac)
   7.210 +
   7.211 +lemma bind_pmf_code [code abstract]:
   7.212 +  "mapping_of_pmf (bind_pmf p f) = 
   7.213 +     (let A = set_pmf p in if is_list_set A then
   7.214 +       fold_combine_plus (\<lambda>x. Mapping.map_values (op * (pmf p x)) (mapping_of_pmf (f x))) A
   7.215 +     else
   7.216 +       Code.abort (STR ''bind_pmf with infinite support.'') (\<lambda>_. mapping_of_pmf (bind_pmf p f)))"
   7.217 +  using mapping_of_bind_pmf[of p f] by (auto simp: Let_def is_list_set_def)
   7.218 +
   7.219 +end
   7.220 +
   7.221 +hide_const (open) is_list_set fold_combine_plus
   7.222 +
   7.223 +
   7.224 +lift_definition cond_pmf_impl :: "'a pmf \<Rightarrow> 'a set \<Rightarrow> ('a, real) mapping option" is
   7.225 +  "\<lambda>p A. if A \<inter> set_pmf p = {} then None else 
   7.226 +     Some (\<lambda>x. if x \<in> A \<inter> set_pmf p then Some (pmf p x / measure_pmf.prob p A) else None)" .
   7.227 +
   7.228 +lemma cond_pmf_impl_code [code]:
   7.229 +  "cond_pmf_impl p (set xs) = (
   7.230 +     let B = set_pmf p;
   7.231 +         xs' = remdups (filter (\<lambda>x. x \<in> B) xs);
   7.232 +         prob = listsum (map (pmf p) xs')
   7.233 +     in  if prob = 0 then 
   7.234 +           None
   7.235 +         else
   7.236 +           Some (Mapping.map_values (\<lambda>y. y / prob) 
   7.237 +             (Mapping.filter (\<lambda>k _. k \<in> set xs') (mapping_of_pmf p))))"     
   7.238 +proof -
   7.239 +  define xs' where "xs' = remdups (filter (\<lambda>x. x \<in> set_pmf p) xs)"
   7.240 +  have xs': "set xs' = set xs \<inter> set_pmf p" "distinct xs'" by (auto simp: xs'_def)
   7.241 +  define prob where "prob = listsum (map (pmf p) xs')"
   7.242 +  have "prob = (\<Sum>x\<in>set xs'. pmf p x)"
   7.243 +    unfolding prob_def by (rule listsum_distinct_conv_setsum_set) (simp_all add: xs'_def)
   7.244 +  also note xs'(1)
   7.245 +  also have "(\<Sum>x\<in>set xs \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>set xs. pmf p x)"
   7.246 +    by (intro setsum.mono_neutral_left) (auto simp: set_pmf_eq)
   7.247 +  finally have prob1: "prob = (\<Sum>x\<in>set xs. pmf p x)" .
   7.248 +  hence prob2: "prob = measure_pmf.prob p (set xs)"
   7.249 +    by (subst measure_measure_pmf_finite) simp_all
   7.250 +  have prob3: "prob = 0 \<longleftrightarrow> set xs \<inter> set_pmf p = {}"
   7.251 +    by (subst prob1, subst setsum_nonneg_eq_0_iff) (auto simp: set_pmf_eq)
   7.252 +  
   7.253 +  show ?thesis
   7.254 +  proof (cases "prob = 0")
   7.255 +    case True
   7.256 +    hence "set xs \<inter> set_pmf p = {}" by (subst (asm) prob3)
   7.257 +    with True show ?thesis by (simp add: Let_def prob_def xs'_def cond_pmf_impl.abs_eq)
   7.258 +  next
   7.259 +    case False
   7.260 +    hence A: "set xs' \<noteq> {}" unfolding xs' by (subst (asm) prob3) auto
   7.261 +    with xs' prob3 have prob_nz: "prob \<noteq> 0" by auto
   7.262 +    fix x
   7.263 +    have "cond_pmf_impl p (set xs) = 
   7.264 +            Some (mapping.Mapping (\<lambda>x. if x \<in> set xs' then 
   7.265 +              Some (pmf p x / measure_pmf.prob p (set xs)) else None))" 
   7.266 +         (is "_ = Some ?m")
   7.267 +      using A unfolding xs'_def by transfer auto
   7.268 +    also have "?m = Mapping.map_values (\<lambda>y. y / prob) 
   7.269 +                 (Mapping.filter (\<lambda>k _. k \<in> set xs') (mapping_of_pmf p))"
   7.270 +      unfolding prob2 [symmetric] xs' using xs' prob_nz 
   7.271 +      by transfer (rule ext, simp add: set_pmf_eq)
   7.272 +    finally show ?thesis using False by (simp add: Let_def prob_def xs'_def)
   7.273 +  qed
   7.274 +qed
   7.275 +
   7.276 +lemma cond_pmf_code [code abstract]:
   7.277 +  "mapping_of_pmf (cond_pmf p A) = 
   7.278 +     (case cond_pmf_impl p A of
   7.279 +        None \<Rightarrow> Code.abort (STR ''cond_pmf with set of probability 0'')
   7.280 +                  (\<lambda>_. mapping_of_pmf (cond_pmf p A))
   7.281 +      | Some m \<Rightarrow> m)"
   7.282 +proof (cases "cond_pmf_impl p A")
   7.283 +  case (Some m)
   7.284 +  hence A: "set_pmf p \<inter> A \<noteq> {}" by transfer (auto split: if_splits)
   7.285 +  from Some have B: "Mapping.keys m = set_pmf (cond_pmf p A)"
   7.286 +    by (subst set_cond_pmf[OF A], transfer) (auto split: if_splits)
   7.287 +  with Some A have "mapping_of_pmf (cond_pmf p A) = m"
   7.288 +    by (intro mapping_of_pmfI[OF _ B], transfer) (auto split: if_splits simp: pmf_cond)
   7.289 +  with Some show ?thesis by simp
   7.290 +qed simp_all
   7.291 +
   7.292 +
   7.293 +lemma binomial_pmf_code [code abstract]:
   7.294 +  "mapping_of_pmf (binomial_pmf n p) = (
   7.295 +     if p < 0 \<or> p > 1 then 
   7.296 +       Code.abort (STR ''binomial_pmf with invalid probability'') (\<lambda>_. mapping_of_pmf (binomial_pmf n p))
   7.297 +     else if p = 0 then Mapping.update 0 1 Mapping.empty
   7.298 +     else if p = 1 then Mapping.update n 1 Mapping.empty
   7.299 +     else Mapping.tabulate [0..<Suc n] (\<lambda>k. real (n choose k) * p ^ k * (1 - p) ^ (n - k)))"
   7.300 +  by (cases "p < 0 \<or> p > 1")
   7.301 +     (simp, intro mapping_of_pmfI, 
   7.302 +      auto simp: lookup_update' lookup_empty set_pmf_binomial_eq lookup_tabulate split: if_splits)
   7.303 +
   7.304 +lemma pred_pmf_code [code]:
   7.305 +  "pred_pmf P p = (\<forall>x\<in>set_pmf p. P x)"
   7.306 +  by (auto simp: pred_pmf_def)
   7.307 +
   7.308 +
   7.309 +definition pmf_integral where
   7.310 +  "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)"
   7.311 +
   7.312 +definition pmf_set_integral where
   7.313 +  "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)"
   7.314 +
   7.315 +definition pmf_prob where
   7.316 +  "pmf_prob p A = measure_pmf.prob p A"
   7.317 +
   7.318 +lemma pmf_integral_pmf_set_integral [code]:
   7.319 +  "pmf_integral p f = pmf_set_integral p f (set_pmf p)"
   7.320 +  unfolding pmf_integral_def pmf_set_integral_def
   7.321 +  by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff)
   7.322 +
   7.323 +lemma pmf_set_integral_code [code]:
   7.324 +  "pmf_set_integral p f (set xs) = listsum (map (\<lambda>x. pmf p x * f x) (remdups xs))"
   7.325 +proof -
   7.326 +  have "listsum (map (\<lambda>x. pmf p x * f x) (remdups xs)) = (\<Sum>x\<in>set xs. pmf p x * f x)"
   7.327 +    by (subst listsum_distinct_conv_setsum_set) simp_all
   7.328 +  also have "\<dots> = pmf_set_integral p f (set xs)" unfolding pmf_set_integral_def
   7.329 +   by (subst integral_measure_pmf[of "set xs"])
   7.330 +      (auto simp: indicator_def mult_ac split: if_splits)
   7.331 +  finally show ?thesis ..
   7.332 +qed
   7.333 +
   7.334 +lemma pmf_prob_code [code]:
   7.335 +  "pmf_prob p (set xs) = listsum (map (pmf p) (remdups xs))"
   7.336 +proof -
   7.337 +  have "pmf_prob p (set xs) = pmf_set_integral p (\<lambda>_. 1) (set xs)"
   7.338 +    unfolding pmf_prob_def pmf_set_integral_def by simp
   7.339 +  also have "\<dots> = listsum (map (pmf p) (remdups xs))"
   7.340 +    unfolding pmf_set_integral_code by simp
   7.341 +  finally show ?thesis .
   7.342 +qed
   7.343 +
   7.344 +lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p"
   7.345 +  by (intro ext) (simp add: pmf_prob_def)
   7.346 +
   7.347 +(* Why does this not work without parameters? *)
   7.348 +lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p"
   7.349 +  by (intro ext) (simp add: pmf_integral_def)
   7.350 +
   7.351 +lemma mapping_of_pmf_pmf_of_list:
   7.352 +  assumes "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0" "listsum (map snd xs) = 1"
   7.353 +  shows   "mapping_of_pmf (pmf_of_list xs) = 
   7.354 +             Mapping.tabulate (remdups (map fst xs)) 
   7.355 +               (\<lambda>x. listsum (map snd (filter (\<lambda>z. fst z = x) xs)))"
   7.356 +proof -
   7.357 +  from assms have wf: "pmf_of_list_wf xs" by (intro pmf_of_list_wfI) force
   7.358 +  moreover from this assms have "set_pmf (pmf_of_list xs) = fst ` set xs"
   7.359 +    by (intro set_pmf_of_list_eq) auto
   7.360 +  ultimately show ?thesis
   7.361 +    by (intro mapping_of_pmfI) (auto simp: lookup_tabulate pmf_pmf_of_list)
   7.362 +qed
   7.363 +
   7.364 +lemma mapping_of_pmf_pmf_of_list':
   7.365 +  assumes "pmf_of_list_wf xs"
   7.366 +  defines "xs' \<equiv> filter (\<lambda>z. snd z \<noteq> 0) xs"
   7.367 +  shows   "mapping_of_pmf (pmf_of_list xs) = 
   7.368 +             Mapping.tabulate (remdups (map fst xs')) 
   7.369 +               (\<lambda>x. listsum (map snd (filter (\<lambda>z. fst z = x) xs')))" (is "_ = ?rhs") 
   7.370 +proof -
   7.371 +  have wf: "pmf_of_list_wf xs'" unfolding xs'_def by (rule pmf_of_list_remove_zeros) fact
   7.372 +  have pos: "\<forall>x\<in>snd`set xs'. x > 0" using assms(1) unfolding xs'_def
   7.373 +    by (force simp: pmf_of_list_wf_def)
   7.374 +  from assms have "pmf_of_list xs = pmf_of_list xs'" 
   7.375 +    unfolding xs'_def by (subst pmf_of_list_remove_zeros) simp_all
   7.376 +  also from wf pos have "mapping_of_pmf \<dots> = ?rhs"
   7.377 +    by (intro mapping_of_pmf_pmf_of_list) (auto simp: pmf_of_list_wf_def)
   7.378 +  finally show ?thesis .
   7.379 +qed
   7.380 +
   7.381 +lemma pmf_of_list_wf_code [code]:
   7.382 +  "pmf_of_list_wf xs \<longleftrightarrow> list_all (\<lambda>z. snd z \<ge> 0) xs \<and> listsum (map snd xs) = 1"
   7.383 +  by (auto simp add: pmf_of_list_wf_def list_all_def)
   7.384 +
   7.385 +lemma pmf_of_list_code [code abstract]:
   7.386 +  "mapping_of_pmf (pmf_of_list xs) = (
   7.387 +     if pmf_of_list_wf xs then
   7.388 +       let xs' = filter (\<lambda>z. snd z \<noteq> 0) xs
   7.389 +       in  Mapping.tabulate (remdups (map fst xs')) 
   7.390 +             (\<lambda>x. listsum (map snd (filter (\<lambda>z. fst z = x) xs')))
   7.391 +     else
   7.392 +       Code.abort (STR ''Invalid list for pmf_of_list'') (\<lambda>_. mapping_of_pmf (pmf_of_list xs)))"
   7.393 +  using mapping_of_pmf_pmf_of_list'[of xs] by (simp add: Let_def)
   7.394 +
   7.395 +  
   7.396 +lemma mapping_of_pmf_eq_iff [simp]:
   7.397 +  "mapping_of_pmf p = mapping_of_pmf q \<longleftrightarrow> p = (q :: 'a pmf)"
   7.398 +proof (transfer, intro iffI pmf_eqI)
   7.399 +  fix p q :: "'a pmf" and x :: 'a
   7.400 +  assume "(\<lambda>x. if pmf p x = 0 then None else Some (pmf p x)) =
   7.401 +            (\<lambda>x. if pmf q x = 0 then None else Some (pmf q x))"
   7.402 +  hence "(if pmf p x = 0 then None else Some (pmf p x)) =
   7.403 +           (if pmf q x = 0 then None else Some (pmf q x))" for x
   7.404 +    by (simp add: fun_eq_iff)
   7.405 +  from this[of x] show "pmf p x = pmf q x" by (auto split: if_splits)
   7.406 +qed (simp_all cong: if_cong)
   7.407 +
   7.408 +definition "pmf_of_alist xs = embed_pmf (\<lambda>x. case map_of xs x of Some p \<Rightarrow> p | None \<Rightarrow> 0)"
   7.409 +
   7.410 +lemma pmf_of_mapping_Mapping [code_post]:
   7.411 +    "pmf_of_mapping (Mapping xs) = pmf_of_alist xs"
   7.412 +  unfolding pmf_of_mapping_def Mapping.lookup_default_def [abs_def] pmf_of_alist_def
   7.413 +  by transfer simp_all
   7.414 +
   7.415 +
   7.416 +instantiation pmf :: (equal) equal
   7.417 +begin
   7.418 +
   7.419 +definition "equal_pmf p q = (mapping_of_pmf p = mapping_of_pmf (q :: 'a pmf))"
   7.420 +
   7.421 +instance by standard (simp add: equal_pmf_def)
   7.422 +end
   7.423 +
   7.424 +
   7.425 +definition (in term_syntax)
   7.426 +  pmfify :: "('a::typerep multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)) \<Rightarrow>
   7.427 +             'a \<times> (unit \<Rightarrow> Code_Evaluation.term) \<Rightarrow>
   7.428 +             'a pmf \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
   7.429 +  [code_unfold]: "pmfify A x =  
   7.430 +    Code_Evaluation.valtermify pmf_of_multiset {\<cdot>} 
   7.431 +      (Code_Evaluation.valtermify (op +) {\<cdot>} A {\<cdot>} 
   7.432 +       (Code_Evaluation.valtermify single {\<cdot>} x))"
   7.433 +
   7.434 +
   7.435 +notation fcomp (infixl "\<circ>>" 60)
   7.436 +notation scomp (infixl "\<circ>\<rightarrow>" 60)
   7.437 +
   7.438 +instantiation pmf :: (random) random
   7.439 +begin
   7.440 +
   7.441 +definition
   7.442 +  "Quickcheck_Random.random i = 
   7.443 +     Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>A. 
   7.444 +       Quickcheck_Random.random i \<circ>\<rightarrow> (\<lambda>x. Pair (pmfify A x)))"
   7.445 +
   7.446 +instance ..
   7.447 +
   7.448 +end
   7.449 +
   7.450 +no_notation fcomp (infixl "\<circ>>" 60)
   7.451 +no_notation scomp (infixl "\<circ>\<rightarrow>" 60)
   7.452 +
   7.453 +(*
   7.454 +instantiation pmf :: (exhaustive) exhaustive
   7.455 +begin
   7.456 +
   7.457 +definition exhaustive_pmf :: "('a pmf \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
   7.458 +where
   7.459 +  "exhaustive_pmf f i =
   7.460 +     Quickcheck_Exhaustive.exhaustive (\<lambda>A. 
   7.461 +       Quickcheck_Exhaustive.exhaustive (\<lambda>x. f (pmf_of_multiset (A + {#x#}))) i) i"
   7.462 +
   7.463 +instance ..
   7.464 +
   7.465 +end
   7.466 +*)
   7.467 +
   7.468 +instantiation pmf :: (full_exhaustive) full_exhaustive
   7.469 +begin
   7.470 +
   7.471 +definition full_exhaustive_pmf :: "('a pmf \<times> (unit \<Rightarrow> term) \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
   7.472 +where
   7.473 +  "full_exhaustive_pmf f i =
   7.474 +     Quickcheck_Exhaustive.full_exhaustive (\<lambda>A. 
   7.475 +       Quickcheck_Exhaustive.full_exhaustive (\<lambda>x. f (pmfify A x)) i) i"
   7.476 +
   7.477 +instance ..
   7.478 +
   7.479 +end
   7.480 +
   7.481 +end
   7.482 \ No newline at end of file
     8.1 --- a/src/HOL/Probability/Probability.thy	Tue May 31 12:24:43 2016 +0200
     8.2 +++ b/src/HOL/Probability/Probability.thy	Tue May 31 13:02:44 2016 +0200
     8.3 @@ -8,6 +8,7 @@
     8.4    Complete_Measure
     8.5    Projective_Limit
     8.6    Probability_Mass_Function
     8.7 +  PMF_Impl
     8.8    Stream_Space
     8.9    Random_Permutations
    8.10    Embed_Measure
     9.1 --- a/src/HOL/Probability/Probability_Mass_Function.thy	Tue May 31 12:24:43 2016 +0200
     9.2 +++ b/src/HOL/Probability/Probability_Mass_Function.thy	Tue May 31 13:02:44 2016 +0200
     9.3 @@ -1787,6 +1787,58 @@
     9.4  end
     9.5  
     9.6  
     9.7 +primrec replicate_pmf :: "nat \<Rightarrow> 'a pmf \<Rightarrow> 'a list pmf" where
     9.8 +  "replicate_pmf 0 _ = return_pmf []"
     9.9 +| "replicate_pmf (Suc n) p = do {x \<leftarrow> p; xs \<leftarrow> replicate_pmf n p; return_pmf (x#xs)}"
    9.10 +
    9.11 +lemma replicate_pmf_1: "replicate_pmf 1 p = map_pmf (\<lambda>x. [x]) p"
    9.12 +  by (simp add: map_pmf_def bind_return_pmf)
    9.13 +  
    9.14 +lemma set_replicate_pmf: 
    9.15 +  "set_pmf (replicate_pmf n p) = {xs\<in>lists (set_pmf p). length xs = n}"
    9.16 +  by (induction n) (auto simp: length_Suc_conv)
    9.17 +
    9.18 +lemma replicate_pmf_distrib:
    9.19 +  "replicate_pmf (m + n) p = 
    9.20 +     do {xs \<leftarrow> replicate_pmf m p; ys \<leftarrow> replicate_pmf n p; return_pmf (xs @ ys)}"
    9.21 +  by (induction m) (simp_all add: bind_return_pmf bind_return_pmf' bind_assoc_pmf)
    9.22 +
    9.23 +lemma power_diff': 
    9.24 +  assumes "b \<le> a"
    9.25 +  shows   "x ^ (a - b) = (if x = 0 \<and> a = b then 1 else x ^ a / (x::'a::field) ^ b)"
    9.26 +proof (cases "x = 0")
    9.27 +  case True
    9.28 +  with assms show ?thesis by (cases "a - b") simp_all
    9.29 +qed (insert assms, simp_all add: power_diff)
    9.30 +
    9.31 +  
    9.32 +lemma binomial_pmf_Suc:
    9.33 +  assumes "p \<in> {0..1}"
    9.34 +  shows   "binomial_pmf (Suc n) p = 
    9.35 +             do {b \<leftarrow> bernoulli_pmf p; 
    9.36 +                 k \<leftarrow> binomial_pmf n p; 
    9.37 +                 return_pmf ((if b then 1 else 0) + k)}" (is "_ = ?rhs")
    9.38 +proof (intro pmf_eqI)
    9.39 +  fix k
    9.40 +  have A: "indicator {Suc a} (Suc b) = indicator {a} b" for a b
    9.41 +    by (simp add: indicator_def)
    9.42 +  show "pmf (binomial_pmf (Suc n) p) k = pmf ?rhs k"
    9.43 +    by (cases k; cases "k > n")
    9.44 +       (insert assms, auto simp: pmf_bind measure_pmf_single A divide_simps algebra_simps
    9.45 +          not_less less_eq_Suc_le [symmetric] power_diff')
    9.46 +qed
    9.47 +
    9.48 +lemma binomial_pmf_0: "p \<in> {0..1} \<Longrightarrow> binomial_pmf 0 p = return_pmf 0"
    9.49 +  by (rule pmf_eqI) (simp_all add: indicator_def)
    9.50 +
    9.51 +lemma binomial_pmf_altdef:
    9.52 +  assumes "p \<in> {0..1}"
    9.53 +  shows   "binomial_pmf n p = map_pmf (length \<circ> filter id) (replicate_pmf n (bernoulli_pmf p))"
    9.54 +  by (induction n) 
    9.55 +     (insert assms, auto simp: binomial_pmf_Suc map_pmf_def bind_return_pmf bind_assoc_pmf 
    9.56 +        bind_return_pmf' binomial_pmf_0 intro!: bind_pmf_cong)
    9.57 +
    9.58 +
    9.59  subsection \<open>PMFs from assiciation lists\<close>
    9.60  
    9.61  definition pmf_of_list ::" ('a \<times> real) list \<Rightarrow> 'a pmf" where 
    9.62 @@ -1921,4 +1973,52 @@
    9.63    using assms unfolding pmf_of_list_wf_def Sigma_Algebra.measure_def
    9.64    by (subst emeasure_pmf_of_list [OF assms], subst enn2real_ennreal) (auto intro!: listsum_nonneg)
    9.65  
    9.66 +(* TODO Move? *)
    9.67 +lemma listsum_nonneg_eq_zero_iff:
    9.68 +  fixes xs :: "'a :: linordered_ab_group_add list"
    9.69 +  shows "(\<And>x. x \<in> set xs \<Longrightarrow> x \<ge> 0) \<Longrightarrow> listsum xs = 0 \<longleftrightarrow> set xs \<subseteq> {0}"
    9.70 +proof (induction xs)
    9.71 +  case (Cons x xs)
    9.72 +  from Cons.prems have "listsum (x#xs) = 0 \<longleftrightarrow> x = 0 \<and> listsum xs = 0"
    9.73 +    unfolding listsum_simps by (subst add_nonneg_eq_0_iff) (auto intro: listsum_nonneg)
    9.74 +  with Cons.IH Cons.prems show ?case by simp
    9.75 +qed simp_all
    9.76 +
    9.77 +lemma listsum_filter_nonzero:
    9.78 +  "listsum (filter (\<lambda>x. x \<noteq> 0) xs) = listsum xs"
    9.79 +  by (induction xs) simp_all
    9.80 +(* END MOVE *)
    9.81 +  
    9.82 +lemma set_pmf_of_list_eq:
    9.83 +  assumes "pmf_of_list_wf xs" "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0"
    9.84 +  shows   "set_pmf (pmf_of_list xs) = fst ` set xs"
    9.85 +proof
    9.86 +  {
    9.87 +    fix x assume A: "x \<in> fst ` set xs" and B: "x \<notin> set_pmf (pmf_of_list xs)"
    9.88 +    then obtain y where y: "(x, y) \<in> set xs" by auto
    9.89 +    from B have "listsum (map snd [z\<leftarrow>xs. fst z = x]) = 0"
    9.90 +      by (simp add: pmf_pmf_of_list[OF assms(1)] set_pmf_eq)
    9.91 +    moreover from y have "y \<in> snd ` {xa \<in> set xs. fst xa = x}" by force
    9.92 +    ultimately have "y = 0" using assms(1) 
    9.93 +      by (subst (asm) listsum_nonneg_eq_zero_iff) (auto simp: pmf_of_list_wf_def)
    9.94 +    with assms(2) y have False by force
    9.95 +  }
    9.96 +  thus "fst ` set xs \<subseteq> set_pmf (pmf_of_list xs)" by blast
    9.97 +qed (insert set_pmf_of_list[OF assms(1)], simp_all)
    9.98 +  
    9.99 +lemma pmf_of_list_remove_zeros:
   9.100 +  assumes "pmf_of_list_wf xs"
   9.101 +  defines "xs' \<equiv> filter (\<lambda>z. snd z \<noteq> 0) xs"
   9.102 +  shows   "pmf_of_list_wf xs'" "pmf_of_list xs' = pmf_of_list xs"
   9.103 +proof -
   9.104 +  have "map snd [z\<leftarrow>xs . snd z \<noteq> 0] = filter (\<lambda>x. x \<noteq> 0) (map snd xs)"
   9.105 +    by (induction xs) simp_all
   9.106 +  with assms(1) show wf: "pmf_of_list_wf xs'"
   9.107 +    by (auto simp: pmf_of_list_wf_def xs'_def listsum_filter_nonzero)
   9.108 +  have "listsum (map snd [z\<leftarrow>xs' . fst z = i]) = listsum (map snd [z\<leftarrow>xs . fst z = i])" for i
   9.109 +    unfolding xs'_def by (induction xs) simp_all
   9.110 +  with assms(1) wf show "pmf_of_list xs' = pmf_of_list xs"
   9.111 +    by (intro pmf_eqI) (simp_all add: pmf_pmf_of_list)
   9.112 +qed
   9.113 +
   9.114  end
    10.1 --- a/src/HOL/Probability/Random_Permutations.thy	Tue May 31 12:24:43 2016 +0200
    10.2 +++ b/src/HOL/Probability/Random_Permutations.thy	Tue May 31 13:02:44 2016 +0200
    10.3 @@ -102,7 +102,11 @@
    10.4               map_pmf (\<lambda>xs. fold f xs x) (pmf_of_set (permutations_of_set A))"
    10.5    by (subst fold_random_permutation_foldl [OF assms], intro map_pmf_cong)
    10.6       (simp_all add: foldl_conv_fold)
    10.7 -
    10.8 +     
    10.9 +lemma fold_random_permutation_code [code]: 
   10.10 +  "fold_random_permutation f x (set xs) =
   10.11 +     map_pmf (foldl (\<lambda>x y. f y x) x) (pmf_of_set (permutations_of_set (set xs)))"
   10.12 +  by (simp add: fold_random_permutation_foldl)
   10.13  
   10.14  text \<open>
   10.15    We now introduce a slightly generalised version of the above fold 
   10.16 @@ -134,7 +138,7 @@
   10.17    We now show that the recursive definition is equivalent to 
   10.18    a random fold followed by a monadic bind.
   10.19  \<close>
   10.20 -lemma fold_bind_random_permutation_altdef:
   10.21 +lemma fold_bind_random_permutation_altdef [code]:
   10.22    "fold_bind_random_permutation f g x A = fold_random_permutation f x A \<bind> g"
   10.23  proof (induction f x A rule: fold_random_permutation.induct [case_names empty infinite remove])
   10.24    case (remove A f x)