Tuned code equations for mappings and PMFs
authoreberlm
Wed, 01 Jun 2016 13:48:34 +0200
changeset 63195 f3f08c0d4aaf
parent 63194 0b7bdb75f451
child 63196 82552b478356
Tuned code equations for mappings and PMFs
src/HOL/Library/AList_Mapping.thy
src/HOL/Library/DAList_Multiset.thy
src/HOL/Library/Mapping.thy
src/HOL/Library/Multiset.thy
src/HOL/Probability/PMF_Impl.thy
--- a/src/HOL/Library/AList_Mapping.thy	Tue May 31 13:02:44 2016 +0200
+++ b/src/HOL/Library/AList_Mapping.thy	Wed Jun 01 13:48:34 2016 +0200
@@ -64,13 +64,25 @@
 qed
 
 lemma map_values_Mapping [code]:
-  fixes f :: "'a \<Rightarrow> 'b" and xs :: "('c \<times> 'a) list"
-  shows "Mapping.map_values f (Mapping xs) = Mapping (map (\<lambda>(x,y). (x, f y)) xs)"
+  fixes f :: "'c \<Rightarrow> 'a \<Rightarrow> 'b" and xs :: "('c \<times> 'a) list"
+  shows "Mapping.map_values f (Mapping xs) = Mapping (map (\<lambda>(x,y). (x, f x y)) xs)"
 proof (transfer, rule ext, goal_cases)
   case (1 f xs x)
   thus ?case by (induction xs) auto
 qed
 
+lemma combine_with_key_code [code]: 
+  "Mapping.combine_with_key f (Mapping xs) (Mapping ys) =
+     Mapping.tabulate (remdups (map fst xs @ map fst ys)) 
+       (\<lambda>x. the (combine_options (f x) (map_of xs x) (map_of ys x)))"
+proof (transfer, rule ext, rule sym, goal_cases)
+  case (1 f xs ys x)
+  show ?case
+  by (cases "map_of xs x"; cases "map_of ys x"; simp)
+     (force simp: map_of_eq_None_iff combine_options_def option.the_def o_def image_iff
+            dest: map_of_SomeD split: option.splits)+
+qed
+
 lemma combine_code [code]: 
   "Mapping.combine f (Mapping xs) (Mapping ys) =
      Mapping.tabulate (remdups (map fst xs @ map fst ys)) 
@@ -79,7 +91,7 @@
   case (1 f xs ys x)
   show ?case
   by (cases "map_of xs x"; cases "map_of ys x"; simp)
-     (force simp: map_of_eq_None_iff combine_options_altdef option.the_def o_def image_iff
+     (force simp: map_of_eq_None_iff combine_options_def option.the_def o_def image_iff
             dest: map_of_SomeD split: option.splits)+
 qed
 
--- a/src/HOL/Library/DAList_Multiset.thy	Tue May 31 13:02:44 2016 +0200
+++ b/src/HOL/Library/DAList_Multiset.thy	Wed Jun 01 13:48:34 2016 +0200
@@ -12,6 +12,8 @@
 
 lemma [code, code del]: "{#} = {#}" ..
 
+lemma [code, code del]: "Multiset.is_empty = Multiset.is_empty" ..
+
 lemma [code, code del]: "single = single" ..
 
 lemma [code, code del]: "plus = (plus :: 'a multiset \<Rightarrow> _)" ..
@@ -187,6 +189,27 @@
 lemma Mempty_Bag [code]: "{#} = Bag (DAList.empty)"
   by (simp add: multiset_eq_iff alist.Alist_inverse DAList.empty_def)
 
+lift_definition is_empty_Bag_impl :: "('a, nat) alist \<Rightarrow> bool" is
+  "\<lambda>xs. list_all (\<lambda>x. snd x = 0) xs" .
+
+lemma is_empty_Bag [code]: "Multiset.is_empty (Bag xs) \<longleftrightarrow> is_empty_Bag_impl xs"
+proof -
+  have "Multiset.is_empty (Bag xs) \<longleftrightarrow> (\<forall>x. count (Bag xs) x = 0)"
+    unfolding Multiset.is_empty_def multiset_eq_iff by simp
+  also have "\<dots> \<longleftrightarrow> (\<forall>x\<in>fst ` set (alist.impl_of xs). count (Bag xs) x = 0)"
+  proof (intro iffI allI ballI)
+    fix x assume A: "\<forall>x\<in>fst ` set (alist.impl_of xs). count (Bag xs) x = 0"
+    thus "count (Bag xs) x = 0"
+    proof (cases "x \<in> fst ` set (alist.impl_of xs)")
+      case False
+      thus ?thesis by (force simp: count_of_def split: option.splits)
+    qed (insert A, auto)
+  qed simp_all
+  also have "\<dots> \<longleftrightarrow> list_all (\<lambda>x. snd x = 0) (alist.impl_of xs)" 
+    by (auto simp: count_of_def list_all_def)
+  finally show ?thesis by (simp add: is_empty_Bag_impl.rep_eq)
+qed
+
 lemma single_Bag [code]: "{#x#} = Bag (DAList.update x 1 DAList.empty)"
   by (simp add: multiset_eq_iff alist.Alist_inverse update.rep_eq empty.rep_eq)
 
--- a/src/HOL/Library/Mapping.thy	Tue May 31 13:02:44 2016 +0200
+++ b/src/HOL/Library/Mapping.thy	Wed Jun 01 13:48:34 2016 +0200
@@ -240,6 +240,32 @@
 
 subsection \<open>Properties\<close>
 
+lemma mapping_eqI:
+  "(\<And>x. lookup m x = lookup m' x) \<Longrightarrow> m = m'"
+  by transfer (simp add: fun_eq_iff)
+
+lemma mapping_eqI': 
+  assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup_default d m x = Mapping.lookup_default d m' x" 
+      and "Mapping.keys m = Mapping.keys m'"
+  shows   "m = m'"
+proof (intro mapping_eqI)
+  fix x
+  show "Mapping.lookup m x = Mapping.lookup m' x"
+  proof (cases "Mapping.lookup m x")
+    case None
+    hence "x \<notin> Mapping.keys m" by transfer (simp add: dom_def)
+    hence "x \<notin> Mapping.keys m'" by (simp add: assms)
+    hence "Mapping.lookup m' x = None" by transfer (simp add: dom_def)
+    with None show ?thesis by simp
+  next
+    case (Some y)
+    hence A: "x \<in> Mapping.keys m" by transfer (simp add: dom_def)
+    hence "x \<in> Mapping.keys m'" by (simp add: assms)
+    hence "\<exists>y'. Mapping.lookup m' x = Some y'" by transfer (simp add: dom_def)
+    with Some assms(1)[OF A] show ?thesis by (auto simp add: lookup_default_def)
+  qed
+qed
+
 lemma lookup_update:
   "lookup (update k v m) k = Some v" 
   by transfer simp
@@ -314,6 +340,51 @@
              f (Mapping.lookup_default d m1 x) (Mapping.lookup_default d m2 x)"
   by (auto simp: lookup_default_def lookup_combine assms split: option.splits)
 
+lemma lookup_map_entry:
+  "lookup (map_entry x f m) x = map_option f (lookup m x)"
+  by transfer (auto split: option.splits)
+
+lemma lookup_map_entry_neq:
+  "x \<noteq> y \<Longrightarrow> lookup (map_entry x f m) y = lookup m y"
+  by transfer (auto split: option.splits)
+
+lemma lookup_map_entry':
+  "lookup (map_entry x f m) y = 
+     (if x = y then map_option f (lookup m y) else lookup m y)"
+  by transfer (auto split: option.splits)
+  
+lemma lookup_default:
+  "lookup (default x d m) x = Some (lookup_default d m x)"
+    unfolding lookup_default_def default_def
+    by transfer (auto split: option.splits)
+
+lemma lookup_default_neq:
+  "x \<noteq> y \<Longrightarrow> lookup (default x d m) y = lookup m y"
+    unfolding lookup_default_def default_def
+    by transfer (auto split: option.splits)
+
+lemma lookup_default':
+  "lookup (default x d m) y = 
+     (if x = y then Some (lookup_default d m x) else lookup m y)"
+  unfolding lookup_default_def default_def
+  by transfer (auto split: option.splits)
+  
+lemma lookup_map_default:
+  "lookup (map_default x d f m) x = Some (f (lookup_default d m x))"
+    unfolding lookup_default_def default_def
+    by (simp add: map_default_def lookup_map_entry lookup_default lookup_default_def)
+
+lemma lookup_map_default_neq:
+  "x \<noteq> y \<Longrightarrow> lookup (map_default x d f m) y = lookup m y"
+    unfolding lookup_default_def default_def
+    by (simp add: map_default_def lookup_map_entry_neq lookup_default_neq) 
+
+lemma lookup_map_default':
+  "lookup (map_default x d f m) y = 
+     (if x = y then Some (f (lookup_default d m x)) else lookup m y)"
+    unfolding lookup_default_def default_def
+    by (simp add: map_default_def lookup_map_entry' lookup_default' lookup_default_def)  
+
 lemma lookup_tabulate: 
   assumes "distinct xs"
   shows   "Mapping.lookup (Mapping.tabulate xs f) x = (if x \<in> set xs then Some (f x) else None)"
--- a/src/HOL/Library/Multiset.thy	Tue May 31 13:02:44 2016 +0200
+++ b/src/HOL/Library/Multiset.thy	Wed Jun 01 13:48:34 2016 +0200
@@ -90,6 +90,15 @@
 
 end
 
+context
+begin
+
+qualified definition is_empty :: "'a multiset \<Rightarrow> bool" where
+  [code_abbrev]: "is_empty A \<longleftrightarrow> A = {#}"
+
+end
+
+
 lift_definition single :: "'a \<Rightarrow> 'a multiset" is "\<lambda>a b. if b = a then 1 else 0"
 by (rule only1_in_multiset)
 
@@ -2583,6 +2592,9 @@
 lemma [code]: "{#} = mset []"
   by simp
 
+lemma [code]: "Multiset.is_empty (mset xs) \<longleftrightarrow> List.null xs"
+  by (simp add: Multiset.is_empty_def List.null_def)
+
 lemma [code]: "{#x#} = mset [x]"
   by simp
 
--- a/src/HOL/Probability/PMF_Impl.thy	Tue May 31 13:02:44 2016 +0200
+++ b/src/HOL/Probability/PMF_Impl.thy	Wed Jun 01 13:48:34 2016 +0200
@@ -5,10 +5,14 @@
     by default. Also includes Quickcheck setup for PMFs.
 *)
 
+section \<open>Code generation for PMFs\<close>
+
 theory PMF_Impl
 imports Probability_Mass_Function "~~/src/HOL/Library/AList_Mapping"
 begin
 
+subsection \<open>General code generation setup\<close>
+
 definition pmf_of_mapping :: "('a, real) mapping \<Rightarrow> 'a pmf" where
   "pmf_of_mapping m = embed_pmf (Mapping.lookup_default 0 m)" 
 
@@ -95,7 +99,26 @@
 
 definition pmf_of_set_impl where
   "pmf_of_set_impl A = mapping_of_pmf (pmf_of_set A)"
-  
+
+(* This equation can be used to easily implement pmf_of_set for other set implementations *)
+lemma pmf_of_set_impl_code_alt:
+  assumes "A \<noteq> {}" "finite A"
+  shows   "pmf_of_set_impl A = 
+             (let p = 1 / real (card A) 
+              in  Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A)"
+proof -
+  define p where "p = 1 / real (card A)"
+  let ?m = "Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A"
+  interpret comp_fun_idem "\<lambda>x. Mapping.update x p"
+    by standard (transfer, force simp: fun_eq_iff)+
+  have keys: "Mapping.keys ?m = A"
+    using assms(2) by (induction A rule: finite_induct) simp_all
+  have lookup: "Mapping.lookup ?m x = Some p" if "x \<in> A" for x
+    using assms(2) that by (induction A rule: finite_induct) (auto simp: lookup_update')
+  from keys lookup assms show ?thesis unfolding pmf_of_set_impl_def
+    by (intro mapping_of_pmfI) (simp_all add: Let_def p_def)
+qed
+
 lemma pmf_of_set_impl_code [code]:
   "pmf_of_set_impl (set xs) = 
     (if xs = [] then
@@ -116,7 +139,27 @@
   using assms by (intro mapping_of_pmfI) (auto simp: lookup_tabulate)
 
 definition pmf_of_multiset_impl where
-  "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)"
+  "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)"  
+
+lemma pmf_of_multiset_impl_code_alt:
+  assumes "A \<noteq> {#}"
+  shows   "pmf_of_multiset_impl A =
+             (let p = 1 / real (size A)
+              in  fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A)"
+proof -
+  define p where "p = 1 / real (size A)"
+  interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 (op + p)"
+    unfolding Mapping.map_default_def [abs_def]
+    by (standard, intro mapping_eqI ext) 
+       (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def)
+  let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A"
+  have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all
+  have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x
+    by (induction A)
+       (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs)
+  from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def
+    by (intro mapping_of_pmfI') (simp_all add: Let_def p_def)
+qed
 
 lemma pmf_of_multiset_impl_code [code]:
   "pmf_of_multiset_impl (mset xs) =
@@ -126,12 +169,13 @@
       else let xs' = remdups xs; p = 1 / real (length xs) in
          Mapping.tabulate xs' (\<lambda>x. real (count (mset xs) x) * p))"
   using pmf_of_multiset_pmf_of_mapping[of "mset xs" "remdups xs"]
-  by (simp add: pmf_of_multiset_impl_def)
+  by (simp add: pmf_of_multiset_impl_def)        
 
 lemma pmf_of_multiset_code [code abstract]:
   "mapping_of_pmf (pmf_of_multiset A) = pmf_of_multiset_impl A"
   by (simp add: pmf_of_multiset_impl_def)
 
+  
 lemma bernoulli_pmf_code [code abstract]:
   "mapping_of_pmf (bernoulli_pmf p) = 
      (if p \<le> 0 then Mapping.update False 1 Mapping.empty 
@@ -140,8 +184,6 @@
   by (intro mapping_of_pmfI) (auto simp: bernoulli_pmf.rep_eq lookup_update' set_pmf_eq)
 
 
-
-
 lemma pmf_code [code]: "pmf p x = Mapping.lookup_default 0 (mapping_of_pmf p) x"
   unfolding mapping_of_pmf_def Mapping.lookup_default_def
   by (auto split: option.splits simp: id_def Mapping.lookup.abs_eq)
@@ -154,14 +196,6 @@
   
 
 
-(* This is necessary since we want something the guarantees finiteness, but simply using 
-   "finite" restricts the code equations to types where finiteness of the universe can 
-   be decided. This simply fails when finiteness is not clear *)
-definition is_list_set where "is_list_set A = finite A"
-
-lemma is_list_set_code [code]: "is_list_set (set xs) = True"
-  by (simp add: is_list_set_def)
-
 definition fold_combine_plus where
   "fold_combine_plus = comm_monoid_set.F (Mapping.combine (op + :: real \<Rightarrow> _)) Mapping.empty"
 
@@ -189,15 +223,15 @@
   by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code)
 
 private lemma lookup_default_0_map_values:
-  assumes "f 0 = 0"
-  shows   "Mapping.lookup_default 0 (Mapping.map_values f m) x = f (Mapping.lookup_default 0 m x)"
+  assumes "f x 0 = 0"
+  shows   "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)"
   unfolding Mapping.lookup_default_def
-  using assms by transfer (auto split: option.splits)  
+  using assms by transfer (auto split: option.splits)
 
 qualified lemma mapping_of_bind_pmf:
   assumes "finite (set_pmf p)"
   shows   "mapping_of_pmf (bind_pmf p f) = 
-             fold_combine_plus (\<lambda>x. Mapping.map_values (op * (pmf p x)) 
+             fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x)) 
                (mapping_of_pmf (f x))) (set_pmf p)"
   using assms
   by (intro mapping_of_pmfI')
@@ -205,71 +239,127 @@
                  pmf_bind integral_measure_pmf lookup_default_0_map_values 
                  lookup_default_mapping_of_pmf mult_ac)
 
-lemma bind_pmf_code [code abstract]:
-  "mapping_of_pmf (bind_pmf p f) = 
-     (let A = set_pmf p in if is_list_set A then
-       fold_combine_plus (\<lambda>x. Mapping.map_values (op * (pmf p x)) (mapping_of_pmf (f x))) A
-     else
-       Code.abort (STR ''bind_pmf with infinite support.'') (\<lambda>_. mapping_of_pmf (bind_pmf p f)))"
-  using mapping_of_bind_pmf[of p f] by (auto simp: Let_def is_list_set_def)
+lift_definition bind_pmf_aux :: "'a pmf \<Rightarrow> ('a \<Rightarrow> 'b pmf) \<Rightarrow> 'a set \<Rightarrow> ('b, real) mapping" is
+  "\<lambda>(p :: 'a pmf) (f :: 'a \<Rightarrow> 'b pmf) (A::'a set) (x::'b). 
+     if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then 
+       Some (measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)) 
+     else None" .
+
+lemma keys_bind_pmf_aux [simp]:
+  "Mapping.keys (bind_pmf_aux p f A) = (\<Union>x\<in>A. set_pmf (f x))"
+  by transfer (auto split: if_splits)
+
+lemma lookup_default_bind_pmf_aux:
+  "Mapping.lookup_default 0 (bind_pmf_aux p f A) x = 
+     (if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then 
+        measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x) else 0)"
+  unfolding lookup_default_def by transfer' simp_all
+
+lemma lookup_default_bind_pmf_aux' [simp]:
+  "Mapping.lookup_default 0 (bind_pmf_aux p f (set_pmf p)) x = pmf (bind_pmf p f) x"
+  unfolding lookup_default_def
+  by transfer (auto simp: pmf_bind AE_measure_pmf_iff set_pmf_eq
+                    intro!: integral_cong_AE integral_eq_zero_AE)
+  
+lemma bind_pmf_aux_correct:
+  "mapping_of_pmf (bind_pmf p f) = bind_pmf_aux p f (set_pmf p)"
+  by (intro mapping_of_pmfI') simp_all
+
+lemma bind_pmf_aux_code_aux:
+  assumes "finite A"
+  shows   "bind_pmf_aux p f A = 
+             fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x))
+               (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs")
+proof (intro mapping_eqI'[where d = 0])
+  fix x assume "x \<in> Mapping.keys ?lhs"
+  then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto
+  hence "Mapping.lookup_default 0 ?lhs x = 
+           measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)"
+    by (auto simp: lookup_default_bind_pmf_aux)
+  also from assms have "\<dots> = (\<Sum>y\<in>A. pmf p y * pmf (f y) x)"
+    by (subst integral_measure_pmf [of A])
+       (auto simp: set_pmf_eq indicator_def mult_ac split: if_splits)
+  also from assms have "\<dots> = Mapping.lookup_default 0 ?rhs x"
+    by (simp add: lookup_default_fold_combine_plus lookup_default_0_map_values
+          lookup_default_mapping_of_pmf)
+  finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" .
+qed (insert assms, simp_all add: keys_fold_combine_plus)
+
+lemma bind_pmf_aux_code [code]:
+  "bind_pmf_aux p f (set xs) = 
+     fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x))
+               (mapping_of_pmf (f x))) (set xs)"
+  by (rule bind_pmf_aux_code_aux) simp_all
+
+lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct
 
 end
 
-hide_const (open) is_list_set fold_combine_plus
+hide_const (open) fold_combine_plus
 
 
 lift_definition cond_pmf_impl :: "'a pmf \<Rightarrow> 'a set \<Rightarrow> ('a, real) mapping option" is
   "\<lambda>p A. if A \<inter> set_pmf p = {} then None else 
      Some (\<lambda>x. if x \<in> A \<inter> set_pmf p then Some (pmf p x / measure_pmf.prob p A) else None)" .
 
-lemma cond_pmf_impl_code [code]:
-  "cond_pmf_impl p (set xs) = (
-     let B = set_pmf p;
-         xs' = remdups (filter (\<lambda>x. x \<in> B) xs);
-         prob = listsum (map (pmf p) xs')
-     in  if prob = 0 then 
-           None
-         else
-           Some (Mapping.map_values (\<lambda>y. y / prob) 
-             (Mapping.filter (\<lambda>k _. k \<in> set xs') (mapping_of_pmf p))))"     
+lemma cond_pmf_impl_code_alt:
+  assumes "finite A"
+  shows   "cond_pmf_impl p A = (
+             let C = A \<inter> set_pmf p;
+                 prob = (\<Sum>x\<in>C. pmf p x)
+             in  if prob = 0 then 
+                   None
+                 else
+                   Some (Mapping.map_values (\<lambda>_ y. y / prob) 
+                     (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))"
 proof -
-  define xs' where "xs' = remdups (filter (\<lambda>x. x \<in> set_pmf p) xs)"
-  have xs': "set xs' = set xs \<inter> set_pmf p" "distinct xs'" by (auto simp: xs'_def)
-  define prob where "prob = listsum (map (pmf p) xs')"
-  have "prob = (\<Sum>x\<in>set xs'. pmf p x)"
-    unfolding prob_def by (rule listsum_distinct_conv_setsum_set) (simp_all add: xs'_def)
-  also note xs'(1)
-  also have "(\<Sum>x\<in>set xs \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>set xs. pmf p x)"
+  define C where "C = A \<inter> set_pmf p"
+  define prob where "prob = (\<Sum>x\<in>C. pmf p x)"
+  also note C_def
+  also from assms have "(\<Sum>x\<in>A \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>A. pmf p x)"
     by (intro setsum.mono_neutral_left) (auto simp: set_pmf_eq)
-  finally have prob1: "prob = (\<Sum>x\<in>set xs. pmf p x)" .
-  hence prob2: "prob = measure_pmf.prob p (set xs)"
-    by (subst measure_measure_pmf_finite) simp_all
-  have prob3: "prob = 0 \<longleftrightarrow> set xs \<inter> set_pmf p = {}"
-    by (subst prob1, subst setsum_nonneg_eq_0_iff) (auto simp: set_pmf_eq)
+  finally have prob1: "prob = (\<Sum>x\<in>A. pmf p x)" .
+  hence prob2: "prob = measure_pmf.prob p A"
+    using assms by (subst measure_measure_pmf_finite) simp_all
+  have prob3: "prob = 0 \<longleftrightarrow> A \<inter> set_pmf p = {}"
+    by (subst prob1, subst setsum_nonneg_eq_0_iff) (auto simp: set_pmf_eq assms)
+  from assms have prob4: "prob = measure_pmf.prob p C"
+    unfolding prob_def by (intro measure_measure_pmf_finite [symmetric]) (simp_all add: C_def)
   
   show ?thesis
   proof (cases "prob = 0")
     case True
-    hence "set xs \<inter> set_pmf p = {}" by (subst (asm) prob3)
-    with True show ?thesis by (simp add: Let_def prob_def xs'_def cond_pmf_impl.abs_eq)
+    hence "A \<inter> set_pmf p = {}" by (subst (asm) prob3)
+    with True show ?thesis by (simp add: Let_def prob_def C_def cond_pmf_impl.abs_eq)
   next
     case False
-    hence A: "set xs' \<noteq> {}" unfolding xs' by (subst (asm) prob3) auto
-    with xs' prob3 have prob_nz: "prob \<noteq> 0" by auto
+    hence A: "C \<noteq> {}" unfolding C_def by (subst (asm) prob3) auto
+    with prob3 have prob_nz: "prob \<noteq> 0" by (auto simp: C_def)
     fix x
-    have "cond_pmf_impl p (set xs) = 
-            Some (mapping.Mapping (\<lambda>x. if x \<in> set xs' then 
-              Some (pmf p x / measure_pmf.prob p (set xs)) else None))" 
+    have "cond_pmf_impl p A = 
+            Some (mapping.Mapping (\<lambda>x. if x \<in> C then 
+              Some (pmf p x / measure_pmf.prob p C) else None))" 
          (is "_ = Some ?m")
-      using A unfolding xs'_def by transfer auto
-    also have "?m = Mapping.map_values (\<lambda>y. y / prob) 
-                 (Mapping.filter (\<lambda>k _. k \<in> set xs') (mapping_of_pmf p))"
-      unfolding prob2 [symmetric] xs' using xs' prob_nz 
-      by transfer (rule ext, simp add: set_pmf_eq)
-    finally show ?thesis using False by (simp add: Let_def prob_def xs'_def)
+      using A prob2 prob4 unfolding C_def by transfer (auto simp: fun_eq_iff)
+    also have "?m = Mapping.map_values (\<lambda>_ y. y / prob) 
+                 (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))"
+      using prob_nz prob4 assms unfolding C_def
+      by transfer (auto simp: fun_eq_iff set_pmf_eq)
+    finally show ?thesis using False by (simp add: Let_def prob_def C_def)
   qed
 qed
 
+lemma cond_pmf_impl_code [code]:
+  "cond_pmf_impl p (set xs) = (
+     let C = set xs \<inter> set_pmf p;
+         prob = (\<Sum>x\<in>C. pmf p x)
+     in  if prob = 0 then 
+           None
+         else
+           Some (Mapping.map_values (\<lambda>_ y. y / prob) 
+             (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))"
+  by (rule cond_pmf_impl_code_alt) simp_all
+
 lemma cond_pmf_code [code abstract]:
   "mapping_of_pmf (cond_pmf p A) = 
      (case cond_pmf_impl p A of
@@ -290,7 +380,8 @@
 lemma binomial_pmf_code [code abstract]:
   "mapping_of_pmf (binomial_pmf n p) = (
      if p < 0 \<or> p > 1 then 
-       Code.abort (STR ''binomial_pmf with invalid probability'') (\<lambda>_. mapping_of_pmf (binomial_pmf n p))
+       Code.abort (STR ''binomial_pmf with invalid probability'') 
+         (\<lambda>_. mapping_of_pmf (binomial_pmf n p))
      else if p = 0 then Mapping.update 0 1 Mapping.empty
      else if p = 1 then Mapping.update n 1 Mapping.empty
      else Mapping.tabulate [0..<Suc n] (\<lambda>k. real (n choose k) * p ^ k * (1 - p) ^ (n - k)))"
@@ -298,53 +389,12 @@
      (simp, intro mapping_of_pmfI, 
       auto simp: lookup_update' lookup_empty set_pmf_binomial_eq lookup_tabulate split: if_splits)
 
+
 lemma pred_pmf_code [code]:
   "pred_pmf P p = (\<forall>x\<in>set_pmf p. P x)"
   by (auto simp: pred_pmf_def)
 
 
-definition pmf_integral where
-  "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)"
-
-definition pmf_set_integral where
-  "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)"
-
-definition pmf_prob where
-  "pmf_prob p A = measure_pmf.prob p A"
-
-lemma pmf_integral_pmf_set_integral [code]:
-  "pmf_integral p f = pmf_set_integral p f (set_pmf p)"
-  unfolding pmf_integral_def pmf_set_integral_def
-  by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff)
-
-lemma pmf_set_integral_code [code]:
-  "pmf_set_integral p f (set xs) = listsum (map (\<lambda>x. pmf p x * f x) (remdups xs))"
-proof -
-  have "listsum (map (\<lambda>x. pmf p x * f x) (remdups xs)) = (\<Sum>x\<in>set xs. pmf p x * f x)"
-    by (subst listsum_distinct_conv_setsum_set) simp_all
-  also have "\<dots> = pmf_set_integral p f (set xs)" unfolding pmf_set_integral_def
-   by (subst integral_measure_pmf[of "set xs"])
-      (auto simp: indicator_def mult_ac split: if_splits)
-  finally show ?thesis ..
-qed
-
-lemma pmf_prob_code [code]:
-  "pmf_prob p (set xs) = listsum (map (pmf p) (remdups xs))"
-proof -
-  have "pmf_prob p (set xs) = pmf_set_integral p (\<lambda>_. 1) (set xs)"
-    unfolding pmf_prob_def pmf_set_integral_def by simp
-  also have "\<dots> = listsum (map (pmf p) (remdups xs))"
-    unfolding pmf_set_integral_code by simp
-  finally show ?thesis .
-qed
-
-lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p"
-  by (intro ext) (simp add: pmf_prob_def)
-
-(* Why does this not work without parameters? *)
-lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p"
-  by (intro ext) (simp add: pmf_integral_def)
-
 lemma mapping_of_pmf_pmf_of_list:
   assumes "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0" "listsum (map snd xs) = 1"
   shows   "mapping_of_pmf (pmf_of_list xs) = 
@@ -389,7 +439,6 @@
        Code.abort (STR ''Invalid list for pmf_of_list'') (\<lambda>_. mapping_of_pmf (pmf_of_list xs)))"
   using mapping_of_pmf_pmf_of_list'[of xs] by (simp add: Let_def)
 
-  
 lemma mapping_of_pmf_eq_iff [simp]:
   "mapping_of_pmf p = mapping_of_pmf q \<longleftrightarrow> p = (q :: 'a pmf)"
 proof (transfer, intro iffI pmf_eqI)
@@ -402,6 +451,66 @@
   from this[of x] show "pmf p x = pmf q x" by (auto split: if_splits)
 qed (simp_all cong: if_cong)
 
+
+subsection \<open>Code abbreviations for integrals and probabilities\<close>
+
+text \<open>
+  Integrals and probabilities are defined for general measures, so we cannot give any
+  code equations directly. We can, however, specialise these constants them to PMFs, 
+  give code equations for these specialised constants, and tell the code generator 
+  to unfold the original constants to the specialised ones whenever possible.
+\<close>
+
+definition pmf_integral where
+  "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)"
+
+definition pmf_set_integral where
+  "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)"
+
+definition pmf_prob where
+  "pmf_prob p A = measure_pmf.prob p A"
+
+lemma pmf_prob_compl: "pmf_prob p (-A) = 1 - pmf_prob p A"
+  using measure_pmf.prob_compl[of A p] by (simp add: pmf_prob_def Compl_eq_Diff_UNIV)
+
+lemma pmf_integral_pmf_set_integral [code]:
+  "pmf_integral p f = pmf_set_integral p f (set_pmf p)"
+  unfolding pmf_integral_def pmf_set_integral_def
+  by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff)
+
+lemma pmf_prob_pmf_set_integral:
+  "pmf_prob p A = pmf_set_integral p (\<lambda>_. 1) A"
+  by (simp add: pmf_prob_def pmf_set_integral_def)
+  
+lemma pmf_set_integral_code_alt_finite:
+  "finite A \<Longrightarrow> pmf_set_integral p f A = (\<Sum>x\<in>A. pmf p x * f x)"
+  unfolding pmf_set_integral_def
+  by (subst integral_measure_pmf[of A]) (auto simp: indicator_def mult_ac split: if_splits)
+  
+lemma pmf_set_integral_code [code]:
+  "pmf_set_integral p f (set xs) = (\<Sum>x\<in>set xs. pmf p x * f x)"
+  by (rule pmf_set_integral_code_alt_finite) simp_all
+
+
+lemma pmf_prob_code_alt_finite:
+  "finite A \<Longrightarrow> pmf_prob p A = (\<Sum>x\<in>A. pmf p x)"
+  by (simp add: pmf_prob_pmf_set_integral pmf_set_integral_code_alt_finite)
+
+lemma pmf_prob_code [code]:
+  "pmf_prob p (set xs) = (\<Sum>x\<in>set xs. pmf p x)"
+  "pmf_prob p (List.coset xs) = 1 - (\<Sum>x\<in>set xs. pmf p x)"
+  by (simp_all add: pmf_prob_code_alt_finite pmf_prob_compl)
+  
+
+lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p"
+  by (intro ext) (simp add: pmf_prob_def)
+
+(* FIXME: Why does this not work without parameters? *)
+lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p"
+  by (intro ext) (simp add: pmf_integral_def)
+
+
+
 definition "pmf_of_alist xs = embed_pmf (\<lambda>x. case map_of xs x of Some p \<Rightarrow> p | None \<Rightarrow> 0)"
 
 lemma pmf_of_mapping_Mapping [code_post]:
@@ -447,21 +556,6 @@
 no_notation fcomp (infixl "\<circ>>" 60)
 no_notation scomp (infixl "\<circ>\<rightarrow>" 60)
 
-(*
-instantiation pmf :: (exhaustive) exhaustive
-begin
-
-definition exhaustive_pmf :: "('a pmf \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
-where
-  "exhaustive_pmf f i =
-     Quickcheck_Exhaustive.exhaustive (\<lambda>A. 
-       Quickcheck_Exhaustive.exhaustive (\<lambda>x. f (pmf_of_multiset (A + {#x#}))) i) i"
-
-instance ..
-
-end
-*)
-
 instantiation pmf :: (full_exhaustive) full_exhaustive
 begin