# HG changeset patch # User eberlm # Date 1464781714 -7200 # Node ID f3f08c0d4aafef6b6fb96867a3f73e17fe2dbeb7 # Parent 0b7bdb75f451b52fcfd4e1fd3217c654af60c995 Tuned code equations for mappings and PMFs diff -r 0b7bdb75f451 -r f3f08c0d4aaf src/HOL/Library/AList_Mapping.thy --- a/src/HOL/Library/AList_Mapping.thy Tue May 31 13:02:44 2016 +0200 +++ b/src/HOL/Library/AList_Mapping.thy Wed Jun 01 13:48:34 2016 +0200 @@ -64,13 +64,25 @@ qed lemma map_values_Mapping [code]: - fixes f :: "'a \ 'b" and xs :: "('c \ 'a) list" - shows "Mapping.map_values f (Mapping xs) = Mapping (map (\(x,y). (x, f y)) xs)" + fixes f :: "'c \ 'a \ 'b" and xs :: "('c \ 'a) list" + shows "Mapping.map_values f (Mapping xs) = Mapping (map (\(x,y). (x, f x y)) xs)" proof (transfer, rule ext, goal_cases) case (1 f xs x) thus ?case by (induction xs) auto qed +lemma combine_with_key_code [code]: + "Mapping.combine_with_key f (Mapping xs) (Mapping ys) = + Mapping.tabulate (remdups (map fst xs @ map fst ys)) + (\x. the (combine_options (f x) (map_of xs x) (map_of ys x)))" +proof (transfer, rule ext, rule sym, goal_cases) + case (1 f xs ys x) + show ?case + by (cases "map_of xs x"; cases "map_of ys x"; simp) + (force simp: map_of_eq_None_iff combine_options_def option.the_def o_def image_iff + dest: map_of_SomeD split: option.splits)+ +qed + lemma combine_code [code]: "Mapping.combine f (Mapping xs) (Mapping ys) = Mapping.tabulate (remdups (map fst xs @ map fst ys)) @@ -79,7 +91,7 @@ case (1 f xs ys x) show ?case by (cases "map_of xs x"; cases "map_of ys x"; simp) - (force simp: map_of_eq_None_iff combine_options_altdef option.the_def o_def image_iff + (force simp: map_of_eq_None_iff combine_options_def option.the_def o_def image_iff dest: map_of_SomeD split: option.splits)+ qed diff -r 0b7bdb75f451 -r f3f08c0d4aaf src/HOL/Library/DAList_Multiset.thy --- a/src/HOL/Library/DAList_Multiset.thy Tue May 31 13:02:44 2016 +0200 +++ b/src/HOL/Library/DAList_Multiset.thy Wed Jun 01 13:48:34 2016 +0200 @@ -12,6 +12,8 @@ lemma [code, code del]: "{#} = {#}" .. +lemma [code, code del]: "Multiset.is_empty = Multiset.is_empty" .. + lemma [code, code del]: "single = single" .. lemma [code, code del]: "plus = (plus :: 'a multiset \ _)" .. @@ -187,6 +189,27 @@ lemma Mempty_Bag [code]: "{#} = Bag (DAList.empty)" by (simp add: multiset_eq_iff alist.Alist_inverse DAList.empty_def) +lift_definition is_empty_Bag_impl :: "('a, nat) alist \ bool" is + "\xs. list_all (\x. snd x = 0) xs" . + +lemma is_empty_Bag [code]: "Multiset.is_empty (Bag xs) \ is_empty_Bag_impl xs" +proof - + have "Multiset.is_empty (Bag xs) \ (\x. count (Bag xs) x = 0)" + unfolding Multiset.is_empty_def multiset_eq_iff by simp + also have "\ \ (\x\fst ` set (alist.impl_of xs). count (Bag xs) x = 0)" + proof (intro iffI allI ballI) + fix x assume A: "\x\fst ` set (alist.impl_of xs). count (Bag xs) x = 0" + thus "count (Bag xs) x = 0" + proof (cases "x \ fst ` set (alist.impl_of xs)") + case False + thus ?thesis by (force simp: count_of_def split: option.splits) + qed (insert A, auto) + qed simp_all + also have "\ \ list_all (\x. snd x = 0) (alist.impl_of xs)" + by (auto simp: count_of_def list_all_def) + finally show ?thesis by (simp add: is_empty_Bag_impl.rep_eq) +qed + lemma single_Bag [code]: "{#x#} = Bag (DAList.update x 1 DAList.empty)" by (simp add: multiset_eq_iff alist.Alist_inverse update.rep_eq empty.rep_eq) diff -r 0b7bdb75f451 -r f3f08c0d4aaf src/HOL/Library/Mapping.thy --- a/src/HOL/Library/Mapping.thy Tue May 31 13:02:44 2016 +0200 +++ b/src/HOL/Library/Mapping.thy Wed Jun 01 13:48:34 2016 +0200 @@ -240,6 +240,32 @@ subsection \Properties\ +lemma mapping_eqI: + "(\x. lookup m x = lookup m' x) \ m = m'" + by transfer (simp add: fun_eq_iff) + +lemma mapping_eqI': + assumes "\x. x \ Mapping.keys m \ Mapping.lookup_default d m x = Mapping.lookup_default d m' x" + and "Mapping.keys m = Mapping.keys m'" + shows "m = m'" +proof (intro mapping_eqI) + fix x + show "Mapping.lookup m x = Mapping.lookup m' x" + proof (cases "Mapping.lookup m x") + case None + hence "x \ Mapping.keys m" by transfer (simp add: dom_def) + hence "x \ Mapping.keys m'" by (simp add: assms) + hence "Mapping.lookup m' x = None" by transfer (simp add: dom_def) + with None show ?thesis by simp + next + case (Some y) + hence A: "x \ Mapping.keys m" by transfer (simp add: dom_def) + hence "x \ Mapping.keys m'" by (simp add: assms) + hence "\y'. Mapping.lookup m' x = Some y'" by transfer (simp add: dom_def) + with Some assms(1)[OF A] show ?thesis by (auto simp add: lookup_default_def) + qed +qed + lemma lookup_update: "lookup (update k v m) k = Some v" by transfer simp @@ -314,6 +340,51 @@ f (Mapping.lookup_default d m1 x) (Mapping.lookup_default d m2 x)" by (auto simp: lookup_default_def lookup_combine assms split: option.splits) +lemma lookup_map_entry: + "lookup (map_entry x f m) x = map_option f (lookup m x)" + by transfer (auto split: option.splits) + +lemma lookup_map_entry_neq: + "x \ y \ lookup (map_entry x f m) y = lookup m y" + by transfer (auto split: option.splits) + +lemma lookup_map_entry': + "lookup (map_entry x f m) y = + (if x = y then map_option f (lookup m y) else lookup m y)" + by transfer (auto split: option.splits) + +lemma lookup_default: + "lookup (default x d m) x = Some (lookup_default d m x)" + unfolding lookup_default_def default_def + by transfer (auto split: option.splits) + +lemma lookup_default_neq: + "x \ y \ lookup (default x d m) y = lookup m y" + unfolding lookup_default_def default_def + by transfer (auto split: option.splits) + +lemma lookup_default': + "lookup (default x d m) y = + (if x = y then Some (lookup_default d m x) else lookup m y)" + unfolding lookup_default_def default_def + by transfer (auto split: option.splits) + +lemma lookup_map_default: + "lookup (map_default x d f m) x = Some (f (lookup_default d m x))" + unfolding lookup_default_def default_def + by (simp add: map_default_def lookup_map_entry lookup_default lookup_default_def) + +lemma lookup_map_default_neq: + "x \ y \ lookup (map_default x d f m) y = lookup m y" + unfolding lookup_default_def default_def + by (simp add: map_default_def lookup_map_entry_neq lookup_default_neq) + +lemma lookup_map_default': + "lookup (map_default x d f m) y = + (if x = y then Some (f (lookup_default d m x)) else lookup m y)" + unfolding lookup_default_def default_def + by (simp add: map_default_def lookup_map_entry' lookup_default' lookup_default_def) + lemma lookup_tabulate: assumes "distinct xs" shows "Mapping.lookup (Mapping.tabulate xs f) x = (if x \ set xs then Some (f x) else None)" diff -r 0b7bdb75f451 -r f3f08c0d4aaf src/HOL/Library/Multiset.thy --- a/src/HOL/Library/Multiset.thy Tue May 31 13:02:44 2016 +0200 +++ b/src/HOL/Library/Multiset.thy Wed Jun 01 13:48:34 2016 +0200 @@ -90,6 +90,15 @@ end +context +begin + +qualified definition is_empty :: "'a multiset \ bool" where + [code_abbrev]: "is_empty A \ A = {#}" + +end + + lift_definition single :: "'a \ 'a multiset" is "\a b. if b = a then 1 else 0" by (rule only1_in_multiset) @@ -2583,6 +2592,9 @@ lemma [code]: "{#} = mset []" by simp +lemma [code]: "Multiset.is_empty (mset xs) \ List.null xs" + by (simp add: Multiset.is_empty_def List.null_def) + lemma [code]: "{#x#} = mset [x]" by simp diff -r 0b7bdb75f451 -r f3f08c0d4aaf src/HOL/Probability/PMF_Impl.thy --- a/src/HOL/Probability/PMF_Impl.thy Tue May 31 13:02:44 2016 +0200 +++ b/src/HOL/Probability/PMF_Impl.thy Wed Jun 01 13:48:34 2016 +0200 @@ -5,10 +5,14 @@ by default. Also includes Quickcheck setup for PMFs. *) +section \Code generation for PMFs\ + theory PMF_Impl imports Probability_Mass_Function "~~/src/HOL/Library/AList_Mapping" begin +subsection \General code generation setup\ + definition pmf_of_mapping :: "('a, real) mapping \ 'a pmf" where "pmf_of_mapping m = embed_pmf (Mapping.lookup_default 0 m)" @@ -95,7 +99,26 @@ definition pmf_of_set_impl where "pmf_of_set_impl A = mapping_of_pmf (pmf_of_set A)" - + +(* This equation can be used to easily implement pmf_of_set for other set implementations *) +lemma pmf_of_set_impl_code_alt: + assumes "A \ {}" "finite A" + shows "pmf_of_set_impl A = + (let p = 1 / real (card A) + in Finite_Set.fold (\x. Mapping.update x p) Mapping.empty A)" +proof - + define p where "p = 1 / real (card A)" + let ?m = "Finite_Set.fold (\x. Mapping.update x p) Mapping.empty A" + interpret comp_fun_idem "\x. Mapping.update x p" + by standard (transfer, force simp: fun_eq_iff)+ + have keys: "Mapping.keys ?m = A" + using assms(2) by (induction A rule: finite_induct) simp_all + have lookup: "Mapping.lookup ?m x = Some p" if "x \ A" for x + using assms(2) that by (induction A rule: finite_induct) (auto simp: lookup_update') + from keys lookup assms show ?thesis unfolding pmf_of_set_impl_def + by (intro mapping_of_pmfI) (simp_all add: Let_def p_def) +qed + lemma pmf_of_set_impl_code [code]: "pmf_of_set_impl (set xs) = (if xs = [] then @@ -116,7 +139,27 @@ using assms by (intro mapping_of_pmfI) (auto simp: lookup_tabulate) definition pmf_of_multiset_impl where - "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)" + "pmf_of_multiset_impl A = mapping_of_pmf (pmf_of_multiset A)" + +lemma pmf_of_multiset_impl_code_alt: + assumes "A \ {#}" + shows "pmf_of_multiset_impl A = + (let p = 1 / real (size A) + in fold_mset (\x. Mapping.map_default x 0 (op + p)) Mapping.empty A)" +proof - + define p where "p = 1 / real (size A)" + interpret comp_fun_commute "\x. Mapping.map_default x 0 (op + p)" + unfolding Mapping.map_default_def [abs_def] + by (standard, intro mapping_eqI ext) + (simp_all add: o_def lookup_map_entry' lookup_default' lookup_default_def) + let ?m = "fold_mset (\x. Mapping.map_default x 0 (op + p)) Mapping.empty A" + have keys: "Mapping.keys ?m = set_mset A" by (induction A) simp_all + have lookup: "Mapping.lookup_default 0 ?m x = real (count A x) * p" for x + by (induction A) + (simp_all add: lookup_map_default' lookup_default_def lookup_empty ring_distribs) + from keys lookup assms show ?thesis unfolding pmf_of_multiset_impl_def + by (intro mapping_of_pmfI') (simp_all add: Let_def p_def) +qed lemma pmf_of_multiset_impl_code [code]: "pmf_of_multiset_impl (mset xs) = @@ -126,12 +169,13 @@ else let xs' = remdups xs; p = 1 / real (length xs) in Mapping.tabulate xs' (\x. real (count (mset xs) x) * p))" using pmf_of_multiset_pmf_of_mapping[of "mset xs" "remdups xs"] - by (simp add: pmf_of_multiset_impl_def) + by (simp add: pmf_of_multiset_impl_def) lemma pmf_of_multiset_code [code abstract]: "mapping_of_pmf (pmf_of_multiset A) = pmf_of_multiset_impl A" by (simp add: pmf_of_multiset_impl_def) + lemma bernoulli_pmf_code [code abstract]: "mapping_of_pmf (bernoulli_pmf p) = (if p \ 0 then Mapping.update False 1 Mapping.empty @@ -140,8 +184,6 @@ by (intro mapping_of_pmfI) (auto simp: bernoulli_pmf.rep_eq lookup_update' set_pmf_eq) - - lemma pmf_code [code]: "pmf p x = Mapping.lookup_default 0 (mapping_of_pmf p) x" unfolding mapping_of_pmf_def Mapping.lookup_default_def by (auto split: option.splits simp: id_def Mapping.lookup.abs_eq) @@ -154,14 +196,6 @@ -(* This is necessary since we want something the guarantees finiteness, but simply using - "finite" restricts the code equations to types where finiteness of the universe can - be decided. This simply fails when finiteness is not clear *) -definition is_list_set where "is_list_set A = finite A" - -lemma is_list_set_code [code]: "is_list_set (set xs) = True" - by (simp add: is_list_set_def) - definition fold_combine_plus where "fold_combine_plus = comm_monoid_set.F (Mapping.combine (op + :: real \ _)) Mapping.empty" @@ -189,15 +223,15 @@ by (simp add: fold_combine_plus_def fold_combine_plus.fold_combine_code) private lemma lookup_default_0_map_values: - assumes "f 0 = 0" - shows "Mapping.lookup_default 0 (Mapping.map_values f m) x = f (Mapping.lookup_default 0 m x)" + assumes "f x 0 = 0" + shows "Mapping.lookup_default 0 (Mapping.map_values f m) x = f x (Mapping.lookup_default 0 m x)" unfolding Mapping.lookup_default_def - using assms by transfer (auto split: option.splits) + using assms by transfer (auto split: option.splits) qualified lemma mapping_of_bind_pmf: assumes "finite (set_pmf p)" shows "mapping_of_pmf (bind_pmf p f) = - fold_combine_plus (\x. Mapping.map_values (op * (pmf p x)) + fold_combine_plus (\x. Mapping.map_values (\_. op * (pmf p x)) (mapping_of_pmf (f x))) (set_pmf p)" using assms by (intro mapping_of_pmfI') @@ -205,71 +239,127 @@ pmf_bind integral_measure_pmf lookup_default_0_map_values lookup_default_mapping_of_pmf mult_ac) -lemma bind_pmf_code [code abstract]: - "mapping_of_pmf (bind_pmf p f) = - (let A = set_pmf p in if is_list_set A then - fold_combine_plus (\x. Mapping.map_values (op * (pmf p x)) (mapping_of_pmf (f x))) A - else - Code.abort (STR ''bind_pmf with infinite support.'') (\_. mapping_of_pmf (bind_pmf p f)))" - using mapping_of_bind_pmf[of p f] by (auto simp: Let_def is_list_set_def) +lift_definition bind_pmf_aux :: "'a pmf \ ('a \ 'b pmf) \ 'a set \ ('b, real) mapping" is + "\(p :: 'a pmf) (f :: 'a \ 'b pmf) (A::'a set) (x::'b). + if x \ (\y\A. set_pmf (f y)) then + Some (measure_pmf.expectation p (\y. indicator A y * pmf (f y) x)) + else None" . + +lemma keys_bind_pmf_aux [simp]: + "Mapping.keys (bind_pmf_aux p f A) = (\x\A. set_pmf (f x))" + by transfer (auto split: if_splits) + +lemma lookup_default_bind_pmf_aux: + "Mapping.lookup_default 0 (bind_pmf_aux p f A) x = + (if x \ (\y\A. set_pmf (f y)) then + measure_pmf.expectation p (\y. indicator A y * pmf (f y) x) else 0)" + unfolding lookup_default_def by transfer' simp_all + +lemma lookup_default_bind_pmf_aux' [simp]: + "Mapping.lookup_default 0 (bind_pmf_aux p f (set_pmf p)) x = pmf (bind_pmf p f) x" + unfolding lookup_default_def + by transfer (auto simp: pmf_bind AE_measure_pmf_iff set_pmf_eq + intro!: integral_cong_AE integral_eq_zero_AE) + +lemma bind_pmf_aux_correct: + "mapping_of_pmf (bind_pmf p f) = bind_pmf_aux p f (set_pmf p)" + by (intro mapping_of_pmfI') simp_all + +lemma bind_pmf_aux_code_aux: + assumes "finite A" + shows "bind_pmf_aux p f A = + fold_combine_plus (\x. Mapping.map_values (\_. op * (pmf p x)) + (mapping_of_pmf (f x))) A" (is "?lhs = ?rhs") +proof (intro mapping_eqI'[where d = 0]) + fix x assume "x \ Mapping.keys ?lhs" + then obtain y where y: "y \ A" "x \ set_pmf (f y)" by auto + hence "Mapping.lookup_default 0 ?lhs x = + measure_pmf.expectation p (\y. indicator A y * pmf (f y) x)" + by (auto simp: lookup_default_bind_pmf_aux) + also from assms have "\ = (\y\A. pmf p y * pmf (f y) x)" + by (subst integral_measure_pmf [of A]) + (auto simp: set_pmf_eq indicator_def mult_ac split: if_splits) + also from assms have "\ = Mapping.lookup_default 0 ?rhs x" + by (simp add: lookup_default_fold_combine_plus lookup_default_0_map_values + lookup_default_mapping_of_pmf) + finally show "Mapping.lookup_default 0 ?lhs x = Mapping.lookup_default 0 ?rhs x" . +qed (insert assms, simp_all add: keys_fold_combine_plus) + +lemma bind_pmf_aux_code [code]: + "bind_pmf_aux p f (set xs) = + fold_combine_plus (\x. Mapping.map_values (\_. op * (pmf p x)) + (mapping_of_pmf (f x))) (set xs)" + by (rule bind_pmf_aux_code_aux) simp_all + +lemmas bind_pmf_code [code abstract] = bind_pmf_aux_correct end -hide_const (open) is_list_set fold_combine_plus +hide_const (open) fold_combine_plus lift_definition cond_pmf_impl :: "'a pmf \ 'a set \ ('a, real) mapping option" is "\p A. if A \ set_pmf p = {} then None else Some (\x. if x \ A \ set_pmf p then Some (pmf p x / measure_pmf.prob p A) else None)" . -lemma cond_pmf_impl_code [code]: - "cond_pmf_impl p (set xs) = ( - let B = set_pmf p; - xs' = remdups (filter (\x. x \ B) xs); - prob = listsum (map (pmf p) xs') - in if prob = 0 then - None - else - Some (Mapping.map_values (\y. y / prob) - (Mapping.filter (\k _. k \ set xs') (mapping_of_pmf p))))" +lemma cond_pmf_impl_code_alt: + assumes "finite A" + shows "cond_pmf_impl p A = ( + let C = A \ set_pmf p; + prob = (\x\C. pmf p x) + in if prob = 0 then + None + else + Some (Mapping.map_values (\_ y. y / prob) + (Mapping.filter (\k _. k \ C) (mapping_of_pmf p))))" proof - - define xs' where "xs' = remdups (filter (\x. x \ set_pmf p) xs)" - have xs': "set xs' = set xs \ set_pmf p" "distinct xs'" by (auto simp: xs'_def) - define prob where "prob = listsum (map (pmf p) xs')" - have "prob = (\x\set xs'. pmf p x)" - unfolding prob_def by (rule listsum_distinct_conv_setsum_set) (simp_all add: xs'_def) - also note xs'(1) - also have "(\x\set xs \ set_pmf p. pmf p x) = (\x\set xs. pmf p x)" + define C where "C = A \ set_pmf p" + define prob where "prob = (\x\C. pmf p x)" + also note C_def + also from assms have "(\x\A \ set_pmf p. pmf p x) = (\x\A. pmf p x)" by (intro setsum.mono_neutral_left) (auto simp: set_pmf_eq) - finally have prob1: "prob = (\x\set xs. pmf p x)" . - hence prob2: "prob = measure_pmf.prob p (set xs)" - by (subst measure_measure_pmf_finite) simp_all - have prob3: "prob = 0 \ set xs \ set_pmf p = {}" - by (subst prob1, subst setsum_nonneg_eq_0_iff) (auto simp: set_pmf_eq) + finally have prob1: "prob = (\x\A. pmf p x)" . + hence prob2: "prob = measure_pmf.prob p A" + using assms by (subst measure_measure_pmf_finite) simp_all + have prob3: "prob = 0 \ A \ set_pmf p = {}" + by (subst prob1, subst setsum_nonneg_eq_0_iff) (auto simp: set_pmf_eq assms) + from assms have prob4: "prob = measure_pmf.prob p C" + unfolding prob_def by (intro measure_measure_pmf_finite [symmetric]) (simp_all add: C_def) show ?thesis proof (cases "prob = 0") case True - hence "set xs \ set_pmf p = {}" by (subst (asm) prob3) - with True show ?thesis by (simp add: Let_def prob_def xs'_def cond_pmf_impl.abs_eq) + hence "A \ set_pmf p = {}" by (subst (asm) prob3) + with True show ?thesis by (simp add: Let_def prob_def C_def cond_pmf_impl.abs_eq) next case False - hence A: "set xs' \ {}" unfolding xs' by (subst (asm) prob3) auto - with xs' prob3 have prob_nz: "prob \ 0" by auto + hence A: "C \ {}" unfolding C_def by (subst (asm) prob3) auto + with prob3 have prob_nz: "prob \ 0" by (auto simp: C_def) fix x - have "cond_pmf_impl p (set xs) = - Some (mapping.Mapping (\x. if x \ set xs' then - Some (pmf p x / measure_pmf.prob p (set xs)) else None))" + have "cond_pmf_impl p A = + Some (mapping.Mapping (\x. if x \ C then + Some (pmf p x / measure_pmf.prob p C) else None))" (is "_ = Some ?m") - using A unfolding xs'_def by transfer auto - also have "?m = Mapping.map_values (\y. y / prob) - (Mapping.filter (\k _. k \ set xs') (mapping_of_pmf p))" - unfolding prob2 [symmetric] xs' using xs' prob_nz - by transfer (rule ext, simp add: set_pmf_eq) - finally show ?thesis using False by (simp add: Let_def prob_def xs'_def) + using A prob2 prob4 unfolding C_def by transfer (auto simp: fun_eq_iff) + also have "?m = Mapping.map_values (\_ y. y / prob) + (Mapping.filter (\k _. k \ C) (mapping_of_pmf p))" + using prob_nz prob4 assms unfolding C_def + by transfer (auto simp: fun_eq_iff set_pmf_eq) + finally show ?thesis using False by (simp add: Let_def prob_def C_def) qed qed +lemma cond_pmf_impl_code [code]: + "cond_pmf_impl p (set xs) = ( + let C = set xs \ set_pmf p; + prob = (\x\C. pmf p x) + in if prob = 0 then + None + else + Some (Mapping.map_values (\_ y. y / prob) + (Mapping.filter (\k _. k \ C) (mapping_of_pmf p))))" + by (rule cond_pmf_impl_code_alt) simp_all + lemma cond_pmf_code [code abstract]: "mapping_of_pmf (cond_pmf p A) = (case cond_pmf_impl p A of @@ -290,7 +380,8 @@ lemma binomial_pmf_code [code abstract]: "mapping_of_pmf (binomial_pmf n p) = ( if p < 0 \ p > 1 then - Code.abort (STR ''binomial_pmf with invalid probability'') (\_. mapping_of_pmf (binomial_pmf n p)) + Code.abort (STR ''binomial_pmf with invalid probability'') + (\_. mapping_of_pmf (binomial_pmf n p)) else if p = 0 then Mapping.update 0 1 Mapping.empty else if p = 1 then Mapping.update n 1 Mapping.empty else Mapping.tabulate [0..k. real (n choose k) * p ^ k * (1 - p) ^ (n - k)))" @@ -298,53 +389,12 @@ (simp, intro mapping_of_pmfI, auto simp: lookup_update' lookup_empty set_pmf_binomial_eq lookup_tabulate split: if_splits) + lemma pred_pmf_code [code]: "pred_pmf P p = (\x\set_pmf p. P x)" by (auto simp: pred_pmf_def) -definition pmf_integral where - "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \ real)" - -definition pmf_set_integral where - "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\x. indicator A x * f x :: real)" - -definition pmf_prob where - "pmf_prob p A = measure_pmf.prob p A" - -lemma pmf_integral_pmf_set_integral [code]: - "pmf_integral p f = pmf_set_integral p f (set_pmf p)" - unfolding pmf_integral_def pmf_set_integral_def - by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff) - -lemma pmf_set_integral_code [code]: - "pmf_set_integral p f (set xs) = listsum (map (\x. pmf p x * f x) (remdups xs))" -proof - - have "listsum (map (\x. pmf p x * f x) (remdups xs)) = (\x\set xs. pmf p x * f x)" - by (subst listsum_distinct_conv_setsum_set) simp_all - also have "\ = pmf_set_integral p f (set xs)" unfolding pmf_set_integral_def - by (subst integral_measure_pmf[of "set xs"]) - (auto simp: indicator_def mult_ac split: if_splits) - finally show ?thesis .. -qed - -lemma pmf_prob_code [code]: - "pmf_prob p (set xs) = listsum (map (pmf p) (remdups xs))" -proof - - have "pmf_prob p (set xs) = pmf_set_integral p (\_. 1) (set xs)" - unfolding pmf_prob_def pmf_set_integral_def by simp - also have "\ = listsum (map (pmf p) (remdups xs))" - unfolding pmf_set_integral_code by simp - finally show ?thesis . -qed - -lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p" - by (intro ext) (simp add: pmf_prob_def) - -(* Why does this not work without parameters? *) -lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p" - by (intro ext) (simp add: pmf_integral_def) - lemma mapping_of_pmf_pmf_of_list: assumes "\x. x \ snd ` set xs \ x > 0" "listsum (map snd xs) = 1" shows "mapping_of_pmf (pmf_of_list xs) = @@ -389,7 +439,6 @@ Code.abort (STR ''Invalid list for pmf_of_list'') (\_. mapping_of_pmf (pmf_of_list xs)))" using mapping_of_pmf_pmf_of_list'[of xs] by (simp add: Let_def) - lemma mapping_of_pmf_eq_iff [simp]: "mapping_of_pmf p = mapping_of_pmf q \ p = (q :: 'a pmf)" proof (transfer, intro iffI pmf_eqI) @@ -402,6 +451,66 @@ from this[of x] show "pmf p x = pmf q x" by (auto split: if_splits) qed (simp_all cong: if_cong) + +subsection \Code abbreviations for integrals and probabilities\ + +text \ + Integrals and probabilities are defined for general measures, so we cannot give any + code equations directly. We can, however, specialise these constants them to PMFs, + give code equations for these specialised constants, and tell the code generator + to unfold the original constants to the specialised ones whenever possible. +\ + +definition pmf_integral where + "pmf_integral p f = lebesgue_integral (measure_pmf p) (f :: _ \ real)" + +definition pmf_set_integral where + "pmf_set_integral p f A = lebesgue_integral (measure_pmf p) (\x. indicator A x * f x :: real)" + +definition pmf_prob where + "pmf_prob p A = measure_pmf.prob p A" + +lemma pmf_prob_compl: "pmf_prob p (-A) = 1 - pmf_prob p A" + using measure_pmf.prob_compl[of A p] by (simp add: pmf_prob_def Compl_eq_Diff_UNIV) + +lemma pmf_integral_pmf_set_integral [code]: + "pmf_integral p f = pmf_set_integral p f (set_pmf p)" + unfolding pmf_integral_def pmf_set_integral_def + by (intro integral_cong_AE) (simp_all add: AE_measure_pmf_iff) + +lemma pmf_prob_pmf_set_integral: + "pmf_prob p A = pmf_set_integral p (\_. 1) A" + by (simp add: pmf_prob_def pmf_set_integral_def) + +lemma pmf_set_integral_code_alt_finite: + "finite A \ pmf_set_integral p f A = (\x\A. pmf p x * f x)" + unfolding pmf_set_integral_def + by (subst integral_measure_pmf[of A]) (auto simp: indicator_def mult_ac split: if_splits) + +lemma pmf_set_integral_code [code]: + "pmf_set_integral p f (set xs) = (\x\set xs. pmf p x * f x)" + by (rule pmf_set_integral_code_alt_finite) simp_all + + +lemma pmf_prob_code_alt_finite: + "finite A \ pmf_prob p A = (\x\A. pmf p x)" + by (simp add: pmf_prob_pmf_set_integral pmf_set_integral_code_alt_finite) + +lemma pmf_prob_code [code]: + "pmf_prob p (set xs) = (\x\set xs. pmf p x)" + "pmf_prob p (List.coset xs) = 1 - (\x\set xs. pmf p x)" + by (simp_all add: pmf_prob_code_alt_finite pmf_prob_compl) + + +lemma pmf_prob_code_unfold [code_abbrev]: "pmf_prob p = measure_pmf.prob p" + by (intro ext) (simp add: pmf_prob_def) + +(* FIXME: Why does this not work without parameters? *) +lemma pmf_integral_code_unfold [code_abbrev]: "pmf_integral p = measure_pmf.expectation p" + by (intro ext) (simp add: pmf_integral_def) + + + definition "pmf_of_alist xs = embed_pmf (\x. case map_of xs x of Some p \ p | None \ 0)" lemma pmf_of_mapping_Mapping [code_post]: @@ -447,21 +556,6 @@ no_notation fcomp (infixl "\>" 60) no_notation scomp (infixl "\\" 60) -(* -instantiation pmf :: (exhaustive) exhaustive -begin - -definition exhaustive_pmf :: "('a pmf \ (bool \ term list) option) \ natural \ (bool \ term list) option" -where - "exhaustive_pmf f i = - Quickcheck_Exhaustive.exhaustive (\A. - Quickcheck_Exhaustive.exhaustive (\x. f (pmf_of_multiset (A + {#x#}))) i) i" - -instance .. - -end -*) - instantiation pmf :: (full_exhaustive) full_exhaustive begin