--- a/src/HOL/Library/AList_Mapping.thy Tue May 31 13:02:44 2016 +0200
+++ b/src/HOL/Library/AList_Mapping.thy Wed Jun 01 13:48:34 2016 +0200
@@ -64,13 +64,25 @@
qed
lemma map_values_Mapping [code]:
- fixes f :: "'a \<Rightarrow> 'b" and xs :: "('c \<times> 'a) list"
- shows "Mapping.map_values f (Mapping xs) = Mapping (map (\<lambda>(x,y). (x, f y)) xs)"
+ fixes f :: "'c \<Rightarrow> 'a \<Rightarrow> 'b" and xs :: "('c \<times> 'a) list"
+ shows "Mapping.map_values f (Mapping xs) = Mapping (map (\<lambda>(x,y). (x, f x y)) xs)"
proof (transfer, rule ext, goal_cases)
case (1 f xs x)
thus ?case by (induction xs) auto
qed
+lemma combine_with_key_code [code]:
+ "Mapping.combine_with_key f (Mapping xs) (Mapping ys) =
+ Mapping.tabulate (remdups (map fst xs @ map fst ys))
+ (\<lambda>x. the (combine_options (f x) (map_of xs x) (map_of ys x)))"
+proof (transfer, rule ext, rule sym, goal_cases)
+ case (1 f xs ys x)
+ show ?case
+ by (cases "map_of xs x"; cases "map_of ys x"; simp)
+ (force simp: map_of_eq_None_iff combine_options_def option.the_def o_def image_iff
+ dest: map_of_SomeD split: option.splits)+
+qed
+
lemma combine_code [code]:
"Mapping.combine f (Mapping xs) (Mapping ys) =
Mapping.tabulate (remdups (map fst xs @ map fst ys))
@@ -79,7 +91,7 @@
case (1 f xs ys x)
show ?case
by (cases "map_of xs x"; cases "map_of ys x"; simp)
- (force simp: map_of_eq_None_iff combine_options_altdef option.the_def o_def image_iff
+ (force simp: map_of_eq_None_iff combine_options_def option.the_def o_def image_iff
dest: map_of_SomeD split: option.splits)+
qed
--- a/src/HOL/Library/DAList_Multiset.thy Tue May 31 13:02:44 2016 +0200
+++ b/src/HOL/Library/DAList_Multiset.thy Wed Jun 01 13:48:34 2016 +0200
@@ -12,6 +12,8 @@
lemma [code, code del]: "{#} = {#}" ..
+lemma [code, code del]: "Multiset.is_empty = Multiset.is_empty" ..
+
lemma [code, code del]: "single = single" ..
lemma [code, code del]: "plus = (plus :: 'a multiset \<Rightarrow> _)" ..
@@ -187,6 +189,27 @@
lemma Mempty_Bag [code]: "{#} = Bag (DAList.empty)"
by (simp add: multiset_eq_iff alist.Alist_inverse DAList.empty_def)
+lift_definition is_empty_Bag_impl :: "('a, nat) alist \<Rightarrow> bool" is
+ "\<lambda>xs. list_all (\<lambda>x. snd x = 0) xs" .
+
+lemma is_empty_Bag [code]: "Multiset.is_empty (Bag xs) \<longleftrightarrow> is_empty_Bag_impl xs"
+proof -
+ have "Multiset.is_empty (Bag xs) \<longleftrightarrow> (\<forall>x. count (Bag xs) x = 0)"
+ unfolding Multiset.is_empty_def multiset_eq_iff by simp
+ also have "\<dots> \<longleftrightarrow> (\<forall>x\<in>fst ` set (alist.impl_of xs). count (Bag xs) x = 0)"
+ proof (intro iffI allI ballI)
+ fix x assume A: "\<forall>x\<in>fst ` set (alist.impl_of xs). count (Bag xs) x = 0"
+ thus "count (Bag xs) x = 0"
+ proof (cases "x \<in> fst ` set (alist.impl_of xs)")
+ case False
+ thus ?thesis by (force simp: count_of_def split: option.splits)
+ qed (insert A, auto)
+ qed simp_all
+ also have "\<dots> \<longleftrightarrow> list_all (\<lambda>x. snd x = 0) (alist.impl_of xs)"
+ by (auto simp: count_of_def list_all_def)
+ finally show ?thesis by (simp add: is_empty_Bag_impl.rep_eq)
+qed
+
lemma single_Bag [code]: "{#x#} = Bag (DAList.update x 1 DAList.empty)"
by (simp add: multiset_eq_iff alist.Alist_inverse update.rep_eq empty.rep_eq)
--- a/src/HOL/Library/Mapping.thy Tue May 31 13:02:44 2016 +0200
+++ b/src/HOL/Library/Mapping.thy Wed Jun 01 13:48:34 2016 +0200
@@ -240,6 +240,32 @@
subsection \<open>Properties\<close>
+lemma mapping_eqI:
+ "(\<And>x. lookup m x = lookup m' x) \<Longrightarrow> m = m'"
+ by transfer (simp add: fun_eq_iff)
+
+lemma mapping_eqI':
+ assumes "\<And>x. x \<in> Mapping.keys m \<Longrightarrow> Mapping.lookup_default d m x = Mapping.lookup_default d m' x"
+ and "Mapping.keys m = Mapping.keys m'"
+ shows "m = m'"
+proof (intro mapping_eqI)
+ fix x
+ show "Mapping.lookup m x = Mapping.lookup m' x"
+ proof (cases "Mapping.lookup m x")
+ case None
+ hence "x \<notin> Mapping.keys m" by transfer (simp add: dom_def)
+ hence "x \<notin> Mapping.keys m'" by (simp add: assms)
+ hence "Mapping.lookup m' x = None" by transfer (simp add: dom_def)
+ with None show ?thesis by simp
+ next
+ case (Some y)
+ hence A: "x \<in> Mapping.keys m" by transfer (simp add: dom_def)
+ hence "x \<in> Mapping.keys m'" by (simp add: assms)
+ hence "\<exists>y'. Mapping.lookup m' x = Some y'" by transfer (simp add: dom_def)
+ with Some assms(1)[OF A] show ?thesis by (auto simp add: lookup_default_def)
+ qed
+qed
+
lemma lookup_update:
"lookup (update k v m) k = Some v"
by transfer simp
@@ -314,6 +340,51 @@
f (Mapping.lookup_default d m1 x) (Mapping.lookup_default d m2 x)"
by (auto simp: lookup_default_def lookup_combine assms split: option.splits)
+lemma lookup_map_entry:
+ "lookup (map_entry x f m) x = map_option f (lookup m x)"
+ by transfer (auto split: option.splits)
+
+lemma lookup_map_entry_neq:
+ "x \<noteq> y \<Longrightarrow> lookup (map_entry x f m) y = lookup m y"
+ by transfer (auto split: option.splits)
+
+lemma lookup_map_entry':
+ "lookup (map_entry x f m) y =
+ (if x = y then map_option f (lookup m y) else lookup m y)"
+ by transfer (auto split: option.splits)
+
+lemma lookup_default:
+ "lookup (default x d m) x = Some (lookup_default d m x)"
+ unfolding lookup_default_def default_def
+ by transfer (auto split: option.splits)
+
+lemma lookup_default_neq:
+ "x \<noteq> y \<Longrightarrow> lookup (default x d m) y = lookup m y"
+ unfolding lookup_default_def default_def
+ by transfer (auto split: option.splits)
+
+lemma lookup_default':
+ "lookup (default x d m) y =
+ (if x = y then Some (lookup_default d m x) else lookup m y)"
+ unfolding lookup_default_def default_def
+ by transfer (auto split: option.splits)
+
+lemma lookup_map_default:
+ "lookup (map_default x d f m) x = Some (f (lookup_default d m x))"
+ unfolding lookup_default_def default_def
+ by (simp add: map_default_def lookup_map_entry lookup_default lookup_default_def)
+
+lemma lookup_map_default_neq:
+ "x \<noteq> y \<Longrightarrow> lookup (map_default x d f m) y = lookup m y"
+ unfolding lookup_default_def default_def
+ by (simp add: map_default_def lookup_map_entry_neq lookup_default_neq)
+
+lemma lookup_map_default':
+ "lookup (map_default x d f m) y =
+ (if x = y then Some (f (lookup_default d m x)) else lookup m y)"
+ unfolding lookup_default_def default_def
+ by (simp add: map_default_def lookup_map_entry' lookup_default' lookup_default_def)
+
lemma lookup_tabulate:
assumes "distinct xs"
shows "Mapping.lookup (Mapping.tabulate xs f) x = (if x \<in> set xs then Some (f x) else None)"
--- a/src/HOL/Library/Multiset.thy Tue May 31 13:02:44 2016 +0200
+++ b/src/HOL/Library/Multiset.thy Wed Jun 01 13:48:34 2016 +0200
@@ -90,6 +90,15 @@
end
+context
+begin
+
+qualified definition is_empty :: "'a multiset \<Rightarrow> bool" where
+ [code_abbrev]: "is_empty A \<longleftrightarrow> A = {#}"
+
+end
+
+
lift_definition single :: "'a \<Rightarrow> 'a multiset" is "\<lambda>a b. if b = a then 1 else 0"
by (rule only1_in_multiset)
@@ -2583,6 +2592,9 @@
lemma [code]: "{#} = mset []"
by simp
+lemma [code]: "Multiset.is_empty (mset xs) \<longleftrightarrow> List.null xs"
+ by (simp add: Multiset.is_empty_def List.null_def)
+
lemma [code]: "{#x#} = mset [x]"
by simp
--- a/src/HOL/Probability/PMF_Impl.thy Tue May 31 13:02:44 2016 +0200
+++ b/src/HOL/Probability/PMF_Impl.thy Wed Jun 01 13:48:34 2016 +0200
@@ -5,10 +5,14 @@
by default. Also includes Quickcheck setup for PMFs.
*)
+section \<open>Code generation for PMFs\<close>
+
theory PMF_Impl
imports Probability_Mass_Function "~~/src/HOL/Library/AList_Mapping"
begin
+subsection \<open>General code generation setup\<close>
+
definition pmf_of_mapping :: "('a, real) mapping \<Rightarrow> 'a pmf" where
"pmf_of_mapping m = embed_pmf (Mapping.lookup_default 0 m)"
@@ -95,7 +99,26 @@
definition pmf_of_set_impl where
"pmf_of_set_impl A = mapping_of_pmf (pmf_of_set A)"
-
+
+(* This equation can be used to easily implement pmf_of_set for other set implementations *)
+lemma pmf_of_set_impl_code_alt:
+ assumes "A \<noteq> {}" "finite A"
+ shows "pmf_of_set_impl A =
+ (let p = 1 / real (card A)
+ in Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A)"
+proof -
+ define p where "p = 1 / real (card A)"
+ let ?m = "Finite_Set.fold (\<lambda>x. Mapping.update x p) Mapping.empty A"
+ interpret comp_fun_idem "\<lambda>x. Mapping.update x p"
+ by standard (transfer, force simp: fun_eq_iff)+
+ have keys: "Mapping.keys ?m = A"
+ using assms(2) by (induction A rule: finite_induct) simp_all
+ have lookup: "Mapping.lookup ?m x = Some p" if "x \<in> A" for x
+ using assms(2) that by (induction A rule: finite_induct) (auto simp: lookup_update')
+ from keys lookup assms show ?thesis unfolding pmf_of_set_impl_def
+ by (intro mapping_of_pmfI) (simp_all add: Let_def p_def)
+qed
+
lemma pmf_of_set_impl_code [code]:
"pmf_of_set_impl (set xs) =
(if xs = [] then
@@ -116,7 +139,27 @@
using assms by (intro mapping_of_pmfI) (auto simp: lookup_tabulate)
definition pmf_of_multiset_impl where
- "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)"
+ "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)"
+
+lemma pmf_of_multiset_impl_code_alt:
+ assumes "A \<noteq> {#}"
+ shows "pmf_of_multiset_impl A =
+ (let p = 1 / real (size A)
+ in fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A)"
+proof -
+ define p where "p = 1 / real (size A)"
+ interpret comp_fun_commute "\<lambda>x. Mapping.map_default x 0 (op + p)"
+ unfolding Mapping.map_default_def [abs_def]
+ by (standard, intro mapping_eqI ext)
+ (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def)
+ let ?m = "fold_mset (\<lambda>x. Mapping.map_default x 0 (op + p)) Mapping.empty A"
+ have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all
+ have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x
+ by (induction A)
+ (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs)
+ from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def
+ by (intro mapping_of_pmfI') (simp_all add: Let_def p_def)
+qed
lemma pmf_of_multiset_impl_code [code]:
"pmf_of_multiset_impl (mset xs) =
@@ -126,12 +169,13 @@
else let xs' = remdups xs; p = 1 / real (length xs) in
Mapping.tabulate xs' (\<lambda>x. real (count (mset xs) x) * p))"
using pmf_of_multiset_pmf_of_mapping[of "mset xs" "remdups xs"]
- by (simp add: pmf_of_multiset_impl_def)
+ by (simp add: pmf_of_multiset_impl_def)
lemma pmf_of_multiset_code [code abstract]:
"mapping_of_pmf (pmf_of_multiset A) = pmf_of_multiset_impl A"
by (simp add: pmf_of_multiset_impl_def)
+
lemma bernoulli_pmf_code [code abstract]:
"mapping_of_pmf (bernoulli_pmf p) =
(if p \<le> 0 then Mapping.update False 1 Mapping.empty
@@ -140,8 +184,6 @@
by (intro mapping_of_pmfI) (auto simp: bernoulli_pmf.rep_eq lookup_update' set_pmf_eq)
-
-
lemma pmf_code [code]: "pmf p x = Mapping.lookup_default 0 (mapping_of_pmf p) x"
unfolding mapping_of_pmf_def Mapping.lookup_default_def
by (auto split: option.splits simp: id_def Mapping.lookup.abs_eq)
@@ -154,14 +196,6 @@
-(* This is necessary since we want something the guarantees finiteness, but simply using
- "finite" restricts the code equations to types where finiteness of the universe can
- be decided. This simply fails when finiteness is not clear *)
-definition is_list_set where "is_list_set A = finite A"
-
-lemma is_list_set_code [code]: "is_list_set (set xs) = True"
- by (simp add: is_list_set_def)
-
definition fold_combine_plus where
"fold_combine_plus = comm_monoid_set.F (Mapping.combine (op + :: real \<Rightarrow> _)) Mapping.empty"
@@ -189,15 +223,15 @@
by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code)
private lemma lookup_default_0_map_values:
- assumes "f 0 = 0"
- shows "Mapping.lookup_default 0 (Mapping.map_values f m) x = f (Mapping.lookup_default 0 m x)"
+ assumes "f x 0 = 0"
+ shows "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)"
unfolding Mapping.lookup_default_def
- using assms by transfer (auto split: option.splits)
+ using assms by transfer (auto split: option.splits)
qualified lemma mapping_of_bind_pmf:
assumes "finite (set_pmf p)"
shows "mapping_of_pmf (bind_pmf p f) =
- fold_combine_plus (\<lambda>x. Mapping.map_values (op * (pmf p x))
+ fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x))
(mapping_of_pmf (f x))) (set_pmf p)"
using assms
by (intro mapping_of_pmfI')
@@ -205,71 +239,127 @@
pmf_bind integral_measure_pmf lookup_default_0_map_values
lookup_default_mapping_of_pmf mult_ac)
-lemma bind_pmf_code [code abstract]:
- "mapping_of_pmf (bind_pmf p f) =
- (let A = set_pmf p in if is_list_set A then
- fold_combine_plus (\<lambda>x. Mapping.map_values (op * (pmf p x)) (mapping_of_pmf (f x))) A
- else
- Code.abort (STR ''bind_pmf with infinite support.'') (\<lambda>_. mapping_of_pmf (bind_pmf p f)))"
- using mapping_of_bind_pmf[of p f] by (auto simp: Let_def is_list_set_def)
+lift_definition bind_pmf_aux :: "'a pmf \<Rightarrow> ('a \<Rightarrow> 'b pmf) \<Rightarrow> 'a set \<Rightarrow> ('b, real) mapping" is
+ "\<lambda>(p :: 'a pmf) (f :: 'a \<Rightarrow> 'b pmf) (A::'a set) (x::'b).
+ if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then
+ Some (measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x))
+ else None" .
+
+lemma keys_bind_pmf_aux [simp]:
+ "Mapping.keys (bind_pmf_aux p f A) = (\<Union>x\<in>A. set_pmf (f x))"
+ by transfer (auto split: if_splits)
+
+lemma lookup_default_bind_pmf_aux:
+ "Mapping.lookup_default 0 (bind_pmf_aux p f A) x =
+ (if x \<in> (\<Union>y\<in>A. set_pmf (f y)) then
+ measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x) else 0)"
+ unfolding lookup_default_def by transfer' simp_all
+
+lemma lookup_default_bind_pmf_aux' [simp]:
+ "Mapping.lookup_default 0 (bind_pmf_aux p f (set_pmf p)) x = pmf (bind_pmf p f) x"
+ unfolding lookup_default_def
+ by transfer (auto simp: pmf_bind AE_measure_pmf_iff set_pmf_eq
+ intro!: integral_cong_AE integral_eq_zero_AE)
+
+lemma bind_pmf_aux_correct:
+ "mapping_of_pmf (bind_pmf p f) = bind_pmf_aux p f (set_pmf p)"
+ by (intro mapping_of_pmfI') simp_all
+
+lemma bind_pmf_aux_code_aux:
+ assumes "finite A"
+ shows "bind_pmf_aux p f A =
+ fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x))
+ (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs")
+proof (intro mapping_eqI'[where d = 0])
+ fix x assume "x \<in> Mapping.keys ?lhs"
+ then obtain y where y: "y \<in> A" "x \<in> set_pmf (f y)" by auto
+ hence "Mapping.lookup_default 0 ?lhs x =
+ measure_pmf.expectation p (\<lambda>y. indicator A y * pmf (f y) x)"
+ by (auto simp: lookup_default_bind_pmf_aux)
+ also from assms have "\<dots> = (\<Sum>y\<in>A. pmf p y * pmf (f y) x)"
+ by (subst integral_measure_pmf [of A])
+ (auto simp: set_pmf_eq indicator_def mult_ac split: if_splits)
+ also from assms have "\<dots> = Mapping.lookup_default 0 ?rhs x"
+ by (simp add: lookup_default_fold_combine_plus lookup_default_0_map_values
+ lookup_default_mapping_of_pmf)
+ finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" .
+qed (insert assms, simp_all add: keys_fold_combine_plus)
+
+lemma bind_pmf_aux_code [code]:
+ "bind_pmf_aux p f (set xs) =
+ fold_combine_plus (\<lambda>x. Mapping.map_values (\<lambda>_. op * (pmf p x))
+ (mapping_of_pmf (f x))) (set xs)"
+ by (rule bind_pmf_aux_code_aux) simp_all
+
+lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct
end
-hide_const (open) is_list_set fold_combine_plus
+hide_const (open) fold_combine_plus
lift_definition cond_pmf_impl :: "'a pmf \<Rightarrow> 'a set \<Rightarrow> ('a, real) mapping option" is
"\<lambda>p A. if A \<inter> set_pmf p = {} then None else
Some (\<lambda>x. if x \<in> A \<inter> set_pmf p then Some (pmf p x / measure_pmf.prob p A) else None)" .
-lemma cond_pmf_impl_code [code]:
- "cond_pmf_impl p (set xs) = (
- let B = set_pmf p;
- xs' = remdups (filter (\<lambda>x. x \<in> B) xs);
- prob = listsum (map (pmf p) xs')
- in if prob = 0 then
- None
- else
- Some (Mapping.map_values (\<lambda>y. y / prob)
- (Mapping.filter (\<lambda>k _. k \<in> set xs') (mapping_of_pmf p))))"
+lemma cond_pmf_impl_code_alt:
+ assumes "finite A"
+ shows "cond_pmf_impl p A = (
+ let C = A \<inter> set_pmf p;
+ prob = (\<Sum>x\<in>C. pmf p x)
+ in if prob = 0 then
+ None
+ else
+ Some (Mapping.map_values (\<lambda>_ y. y / prob)
+ (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))"
proof -
- define xs' where "xs' = remdups (filter (\<lambda>x. x \<in> set_pmf p) xs)"
- have xs': "set xs' = set xs \<inter> set_pmf p" "distinct xs'" by (auto simp: xs'_def)
- define prob where "prob = listsum (map (pmf p) xs')"
- have "prob = (\<Sum>x\<in>set xs'. pmf p x)"
- unfolding prob_def by (rule listsum_distinct_conv_setsum_set) (simp_all add: xs'_def)
- also note xs'(1)
- also have "(\<Sum>x\<in>set xs \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>set xs. pmf p x)"
+ define C where "C = A \<inter> set_pmf p"
+ define prob where "prob = (\<Sum>x\<in>C. pmf p x)"
+ also note C_def
+ also from assms have "(\<Sum>x\<in>A \<inter> set_pmf p. pmf p x) = (\<Sum>x\<in>A. pmf p x)"
by (intro setsum.mono_neutral_left) (auto simp: set_pmf_eq)
- finally have prob1: "prob = (\<Sum>x\<in>set xs. pmf p x)" .
- hence prob2: "prob = measure_pmf.prob p (set xs)"
- by (subst measure_measure_pmf_finite) simp_all
- have prob3: "prob = 0 \<longleftrightarrow> set xs \<inter> set_pmf p = {}"
- by (subst prob1, subst setsum_nonneg_eq_0_iff) (auto simp: set_pmf_eq)
+ finally have prob1: "prob = (\<Sum>x\<in>A. pmf p x)" .
+ hence prob2: "prob = measure_pmf.prob p A"
+ using assms by (subst measure_measure_pmf_finite) simp_all
+ have prob3: "prob = 0 \<longleftrightarrow> A \<inter> set_pmf p = {}"
+ by (subst prob1, subst setsum_nonneg_eq_0_iff) (auto simp: set_pmf_eq assms)
+ from assms have prob4: "prob = measure_pmf.prob p C"
+ unfolding prob_def by (intro measure_measure_pmf_finite [symmetric]) (simp_all add: C_def)
show ?thesis
proof (cases "prob = 0")
case True
- hence "set xs \<inter> set_pmf p = {}" by (subst (asm) prob3)
- with True show ?thesis by (simp add: Let_def prob_def xs'_def cond_pmf_impl.abs_eq)
+ hence "A \<inter> set_pmf p = {}" by (subst (asm) prob3)
+ with True show ?thesis by (simp add: Let_def prob_def C_def cond_pmf_impl.abs_eq)
next
case False
- hence A: "set xs' \<noteq> {}" unfolding xs' by (subst (asm) prob3) auto
- with xs' prob3 have prob_nz: "prob \<noteq> 0" by auto
+ hence A: "C \<noteq> {}" unfolding C_def by (subst (asm) prob3) auto
+ with prob3 have prob_nz: "prob \<noteq> 0" by (auto simp: C_def)
fix x
- have "cond_pmf_impl p (set xs) =
- Some (mapping.Mapping (\<lambda>x. if x \<in> set xs' then
- Some (pmf p x / measure_pmf.prob p (set xs)) else None))"
+ have "cond_pmf_impl p A =
+ Some (mapping.Mapping (\<lambda>x. if x \<in> C then
+ Some (pmf p x / measure_pmf.prob p C) else None))"
(is "_ = Some ?m")
- using A unfolding xs'_def by transfer auto
- also have "?m = Mapping.map_values (\<lambda>y. y / prob)
- (Mapping.filter (\<lambda>k _. k \<in> set xs') (mapping_of_pmf p))"
- unfolding prob2 [symmetric] xs' using xs' prob_nz
- by transfer (rule ext, simp add: set_pmf_eq)
- finally show ?thesis using False by (simp add: Let_def prob_def xs'_def)
+ using A prob2 prob4 unfolding C_def by transfer (auto simp: fun_eq_iff)
+ also have "?m = Mapping.map_values (\<lambda>_ y. y / prob)
+ (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))"
+ using prob_nz prob4 assms unfolding C_def
+ by transfer (auto simp: fun_eq_iff set_pmf_eq)
+ finally show ?thesis using False by (simp add: Let_def prob_def C_def)
qed
qed
+lemma cond_pmf_impl_code [code]:
+ "cond_pmf_impl p (set xs) = (
+ let C = set xs \<inter> set_pmf p;
+ prob = (\<Sum>x\<in>C. pmf p x)
+ in if prob = 0 then
+ None
+ else
+ Some (Mapping.map_values (\<lambda>_ y. y / prob)
+ (Mapping.filter (\<lambda>k _. k \<in> C) (mapping_of_pmf p))))"
+ by (rule cond_pmf_impl_code_alt) simp_all
+
lemma cond_pmf_code [code abstract]:
"mapping_of_pmf (cond_pmf p A) =
(case cond_pmf_impl p A of
@@ -290,7 +380,8 @@
lemma binomial_pmf_code [code abstract]:
"mapping_of_pmf (binomial_pmf n p) = (
if p < 0 \<or> p > 1 then
- Code.abort (STR ''binomial_pmf with invalid probability'') (\<lambda>_. mapping_of_pmf (binomial_pmf n p))
+ Code.abort (STR ''binomial_pmf with invalid probability'')
+ (\<lambda>_. mapping_of_pmf (binomial_pmf n p))
else if p = 0 then Mapping.update 0 1 Mapping.empty
else if p = 1 then Mapping.update n 1 Mapping.empty
else Mapping.tabulate [0..<Suc n] (\<lambda>k. real (n choose k) * p ^ k * (1 - p) ^ (n - k)))"
@@ -298,53 +389,12 @@
(simp, intro mapping_of_pmfI,
auto simp: lookup_update' lookup_empty set_pmf_binomial_eq lookup_tabulate split: if_splits)
+
lemma pred_pmf_code [code]:
"pred_pmf P p = (\<forall>x\<in>set_pmf p. P x)"
by (auto simp: pred_pmf_def)
-definition pmf_integral where
- "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)"
-
-definition pmf_set_integral where
- "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)"
-
-definition pmf_prob where
- "pmf_prob p A = measure_pmf.prob p A"
-
-lemma pmf_integral_pmf_set_integral [code]:
- "pmf_integral p f = pmf_set_integral p f (set_pmf p)"
- unfolding pmf_integral_def pmf_set_integral_def
- by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff)
-
-lemma pmf_set_integral_code [code]:
- "pmf_set_integral p f (set xs) = listsum (map (\<lambda>x. pmf p x * f x) (remdups xs))"
-proof -
- have "listsum (map (\<lambda>x. pmf p x * f x) (remdups xs)) = (\<Sum>x\<in>set xs. pmf p x * f x)"
- by (subst listsum_distinct_conv_setsum_set) simp_all
- also have "\<dots> = pmf_set_integral p f (set xs)" unfolding pmf_set_integral_def
- by (subst integral_measure_pmf[of "set xs"])
- (auto simp: indicator_def mult_ac split: if_splits)
- finally show ?thesis ..
-qed
-
-lemma pmf_prob_code [code]:
- "pmf_prob p (set xs) = listsum (map (pmf p) (remdups xs))"
-proof -
- have "pmf_prob p (set xs) = pmf_set_integral p (\<lambda>_. 1) (set xs)"
- unfolding pmf_prob_def pmf_set_integral_def by simp
- also have "\<dots> = listsum (map (pmf p) (remdups xs))"
- unfolding pmf_set_integral_code by simp
- finally show ?thesis .
-qed
-
-lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p"
- by (intro ext) (simp add: pmf_prob_def)
-
-(* Why does this not work without parameters? *)
-lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p"
- by (intro ext) (simp add: pmf_integral_def)
-
lemma mapping_of_pmf_pmf_of_list:
assumes "\<And>x. x \<in> snd ` set xs \<Longrightarrow> x > 0" "listsum (map snd xs) = 1"
shows "mapping_of_pmf (pmf_of_list xs) =
@@ -389,7 +439,6 @@
Code.abort (STR ''Invalid list for pmf_of_list'') (\<lambda>_. mapping_of_pmf (pmf_of_list xs)))"
using mapping_of_pmf_pmf_of_list'[of xs] by (simp add: Let_def)
-
lemma mapping_of_pmf_eq_iff [simp]:
"mapping_of_pmf p = mapping_of_pmf q \<longleftrightarrow> p = (q :: 'a pmf)"
proof (transfer, intro iffI pmf_eqI)
@@ -402,6 +451,66 @@
from this[of x] show "pmf p x = pmf q x" by (auto split: if_splits)
qed (simp_all cong: if_cong)
+
+subsection \<open>Code abbreviations for integrals and probabilities\<close>
+
+text \<open>
+ Integrals and probabilities are defined for general measures, so we cannot give any
+ code equations directly. We can, however, specialise these constants them to PMFs,
+ give code equations for these specialised constants, and tell the code generator
+ to unfold the original constants to the specialised ones whenever possible.
+\<close>
+
+definition pmf_integral where
+ "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \<Rightarrow> real)"
+
+definition pmf_set_integral where
+ "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\<lambda>x. indicator A x * f x :: real)"
+
+definition pmf_prob where
+ "pmf_prob p A = measure_pmf.prob p A"
+
+lemma pmf_prob_compl: "pmf_prob p (-A) = 1 - pmf_prob p A"
+ using measure_pmf.prob_compl[of A p] by (simp add: pmf_prob_def Compl_eq_Diff_UNIV)
+
+lemma pmf_integral_pmf_set_integral [code]:
+ "pmf_integral p f = pmf_set_integral p f (set_pmf p)"
+ unfolding pmf_integral_def pmf_set_integral_def
+ by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff)
+
+lemma pmf_prob_pmf_set_integral:
+ "pmf_prob p A = pmf_set_integral p (\<lambda>_. 1) A"
+ by (simp add: pmf_prob_def pmf_set_integral_def)
+
+lemma pmf_set_integral_code_alt_finite:
+ "finite A \<Longrightarrow> pmf_set_integral p f A = (\<Sum>x\<in>A. pmf p x * f x)"
+ unfolding pmf_set_integral_def
+ by (subst integral_measure_pmf[of A]) (auto simp: indicator_def mult_ac split: if_splits)
+
+lemma pmf_set_integral_code [code]:
+ "pmf_set_integral p f (set xs) = (\<Sum>x\<in>set xs. pmf p x * f x)"
+ by (rule pmf_set_integral_code_alt_finite) simp_all
+
+
+lemma pmf_prob_code_alt_finite:
+ "finite A \<Longrightarrow> pmf_prob p A = (\<Sum>x\<in>A. pmf p x)"
+ by (simp add: pmf_prob_pmf_set_integral pmf_set_integral_code_alt_finite)
+
+lemma pmf_prob_code [code]:
+ "pmf_prob p (set xs) = (\<Sum>x\<in>set xs. pmf p x)"
+ "pmf_prob p (List.coset xs) = 1 - (\<Sum>x\<in>set xs. pmf p x)"
+ by (simp_all add: pmf_prob_code_alt_finite pmf_prob_compl)
+
+
+lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p"
+ by (intro ext) (simp add: pmf_prob_def)
+
+(* FIXME: Why does this not work without parameters? *)
+lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p"
+ by (intro ext) (simp add: pmf_integral_def)
+
+
+
definition "pmf_of_alist xs = embed_pmf (\<lambda>x. case map_of xs x of Some p \<Rightarrow> p | None \<Rightarrow> 0)"
lemma pmf_of_mapping_Mapping [code_post]:
@@ -447,21 +556,6 @@
no_notation fcomp (infixl "\<circ>>" 60)
no_notation scomp (infixl "\<circ>\<rightarrow>" 60)
-(*
-instantiation pmf :: (exhaustive) exhaustive
-begin
-
-definition exhaustive_pmf :: "('a pmf \<Rightarrow> (bool \<times> term list) option) \<Rightarrow> natural \<Rightarrow> (bool \<times> term list) option"
-where
- "exhaustive_pmf f i =
- Quickcheck_Exhaustive.exhaustive (\<lambda>A.
- Quickcheck_Exhaustive.exhaustive (\<lambda>x. f (pmf_of_multiset (A + {#x#}))) i) i"
-
-instance ..
-
-end
-*)
-
instantiation pmf :: (full_exhaustive) full_exhaustive
begin