src/HOL/Random.thy
changeset 31203 5c8fb4fd67e0
parent 31196 82ff416d7d66
child 31205 98370b26c2ce
--- a/src/HOL/Random.thy	Tue May 19 13:57:31 2009 +0200
+++ b/src/HOL/Random.thy	Tue May 19 13:57:32 2009 +0200
@@ -3,7 +3,7 @@
 header {* A HOL random engine *}
 
 theory Random
-imports Code_Index
+imports Code_Index List
 begin
 
 notation fcomp (infixl "o>" 60)
@@ -42,9 +42,6 @@
 primrec seed_invariant :: "seed \<Rightarrow> bool" where
   "seed_invariant (v, w) \<longleftrightarrow> 0 < v \<and> v < 9438322952 \<and> 0 < w \<and> True"
 
-lemma if_same: "(if b then f x else f y) = f (if b then x else y)"
-  by (cases b) simp_all
-
 definition split_seed :: "seed \<Rightarrow> seed \<times> seed" where
   "split_seed s = (let
      (v, w) = s;
@@ -98,6 +95,14 @@
   "pick (filter (\<lambda>(k, _). k > 0) xs) = pick xs"
   by (induct xs) (auto simp add: expand_fun_eq)
 
+lemma pick_same:
+  "l < length xs \<Longrightarrow> Random.pick (map (Pair 1) xs) (Code_Index.of_nat l) = nth xs l"
+proof (induct xs arbitrary: l)
+  case Nil then show ?case by simp
+next
+  case (Cons x xs) then show ?case by (cases l) simp_all
+qed
+
 definition select_weight :: "(index \<times> 'a) list \<Rightarrow> seed \<Rightarrow> 'a \<times> seed" where
   "select_weight xs = range (listsum (map fst xs))
    o\<rightarrow> (\<lambda>k. Pair (pick xs k))"
@@ -113,6 +118,27 @@
   then show ?thesis by (simp add: select_weight_def scomp_def split_def) 
 qed
 
+lemma select_weigth_drop_zero:
+  "Random.select_weight (filter (\<lambda>(k, _). k > 0) xs) = Random.select_weight xs"
+proof -
+  have "listsum (map fst [(k, _)\<leftarrow>xs . 0 < k]) = listsum (map fst xs)"
+    by (induct xs) auto
+  then show ?thesis by (simp only: select_weight_def pick_drop_zero)
+qed
+
+lemma select_weigth_select:
+  assumes "xs \<noteq> []"
+  shows "Random.select_weight (map (Pair 1) xs) = Random.select xs"
+proof -
+  have less: "\<And>s. fst (Random.range (Code_Index.of_nat (length xs)) s) < Code_Index.of_nat (length xs)"
+    using assms by (intro range) simp
+  moreover have "listsum (map fst (map (Pair 1) xs)) = Code_Index.of_nat (length xs)"
+    by (induct xs) simp_all
+  ultimately show ?thesis
+    by (auto simp add: select_weight_def select_def scomp_def split_def
+      expand_fun_eq pick_same [symmetric])
+qed
+
 definition select_default :: "index \<Rightarrow> 'a \<Rightarrow> 'a \<Rightarrow> seed \<Rightarrow> 'a \<times> seed" where
   [code del]: "select_default k x y = range k
      o\<rightarrow> (\<lambda>l. Pair (if l + 1 < k then x else y))"
@@ -169,7 +195,6 @@
 hide (open) type seed
 hide (open) const inc_shift minus_shift log "next" seed_invariant split_seed
   iterate range select pick select_weight select_default
-hide (open) fact log_def
 
 no_notation fcomp (infixl "o>" 60)
 no_notation scomp (infixl "o\<rightarrow>" 60)