src/HOL/Random.thy
 author haftmann Mon Nov 29 13:44:54 2010 +0100 (2010-11-29) changeset 40815 6e2d17cc0d1d parent 39302 d7728f65b353 child 42163 392fd6c4669c permissions -rw-r--r--
equivI has replaced equiv.intro
 (* Author: Florian Haftmann, TU Muenchen *)

header {* A HOL random engine *}

theory Random
imports Code_Numeral List
begin

(* Local infix syntax for forward composition (fcomp) and for the
   state-monad bind (scomp); both notations are removed again at the end
   of this theory. *)
notation fcomp (infixl "\<circ>>" 60)
notation scomp (infixl "\<circ>\<rightarrow>" 60)


subsection {* Auxiliary functions *}

(* Number of base-b digits of i; at least 1, also for i = 0.  Degenerate
   bases b <= 1 immediately yield 1, which guarantees termination. *)
fun log :: "code_numeral \<Rightarrow> code_numeral \<Rightarrow> code_numeral" where
  "log b i = (if b \<le> 1 \<or> i < b then 1 else 1 + log b (i div b))"

(* Successor of k that wraps around to 1 once k has reached the bound v. *)
definition inc_shift :: "code_numeral \<Rightarrow> code_numeral \<Rightarrow> code_numeral" where
  "inc_shift v k = (if v = k then 1 else k + 1)"

(* Difference k - l on the naturals, shifted up by r when k < l, i.e. when
   the plain difference would be negative.  next uses it to fold such a
   difference back into the range of modulus r. *)
definition minus_shift :: "code_numeral \<Rightarrow> code_numeral \<Rightarrow> code_numeral \<Rightarrow> code_numeral" where
  "minus_shift r k l = (if k < l then r + k - l else k - l)"


subsection {* Random seeds *}

(* A seed holds one state per component generator used by next. *)
types seed = "code_numeral \<times> code_numeral"

(* One generator step: returns a fresh value z together with the successor
   seed.  The constants are those of L'Ecuyer's combined multiplicative
   congruential generator (moduli 2147483563 and 2147483399, multipliers
   40014 and 40692).  Each component is advanced with Schrage's
   decomposition m = a * q + r (q, r = 53668, 12211 and 52774, 3791), so no
   intermediate product exceeds the modulus.  z combines both new
   component states. *)
primrec "next" :: "seed \<Rightarrow> code_numeral \<times> seed" where
  "next (v, w) = (let
     k = v div 53668;
     v' = minus_shift 2147483563 ((v mod 53668) * 40014) (k * 12211);
     l = w div 52774;
     w' = minus_shift 2147483399 ((w mod 52774) * 40692) (l * 3791);
     z = minus_shift 2147483562 v' (w' + 1) + 1
   in (z, (v', w')))"

(* Derives two seeds from s.  The first pairs the wrap-around increment of
   the first component of s with the second component of next's successor
   seed.  The second pairs the first component of that successor seed with
   the wrap-around increment of the second component of s. *)
definition split_seed :: "seed \<Rightarrow> seed \<times> seed" where
  "split_seed s = (let
     (v, w) = s;
     (v', w') = snd (next s);
     v'' = inc_shift 2147483562 v;
     w'' = inc_shift 2147483398 w
   in ((v'', w'), (v', w'')))"


subsection {* Base selectors *}

(* Runs the state transformer f k times, threading the result value of
   each step into the next one, starting from the value x. *)
fun iterate :: "code_numeral \<Rightarrow> ('b \<Rightarrow> 'a \<Rightarrow> 'b \<times> 'a) \<Rightarrow> 'b \<Rightarrow> 'a \<Rightarrow> 'b \<times> 'a" where
  "iterate k f x = (if k = 0 then Pair x else f x \<circ>\<rightarrow> iterate (k - 1) f)"

(* Draws log 2147483561 k values with next, accumulates them as
   v + l * 2147483561 starting from l = 1, and reduces the result modulo k.
   Lemma range below shows that the result is below k whenever k > 0. *)
definition range :: "code_numeral \<Rightarrow> seed \<Rightarrow> code_numeral \<times> seed" where
  "range k = iterate (log 2147483561 k)
      (\<lambda>l. next \<circ>\<rightarrow> (\<lambda>v. Pair (v + l * 2147483561))) 1
    \<circ>\<rightarrow> (\<lambda>v. Pair (v mod k))"

lemma range:
  "k > 0 \<Longrightarrow> fst (range k s) < k"
  by (simp add: range_def scomp_apply split_def del: log.simps iterate.simps)

(* Returns the element of xs at an index drawn by range.  Lemma select
   shows the result is a member of xs when xs is non-empty. *)
definition select :: "'a list \<Rightarrow> seed \<Rightarrow> 'a \<times> seed" where
  "select xs = range (Code_Numeral.of_nat (length xs))
    \<circ>\<rightarrow> (\<lambda>k. Pair (nth xs (Code_Numeral.nat_of k)))"

lemma select:
  assumes "xs \<noteq> []"
  shows "fst (select xs s) \<in> set xs"
proof -
  from assms have "Code_Numeral.of_nat (length xs) > 0" by simp
  with range have
    "fst (range (Code_Numeral.of_nat (length xs)) s) < Code_Numeral.of_nat (length xs)" by best
  then have
    "Code_Numeral.nat_of (fst (range (Code_Numeral.of_nat (length xs)) s)) < length xs" by simp
  then show ?thesis
    by (simp add: scomp_apply split_beta select_def)
qed

(* Walks a list of (weight, value) pairs, subtracting weights from i until
   i falls below the current weight, and returns that value.  No equation
   is given for the empty list, so pick [] i is unspecified. *)
primrec pick :: "(code_numeral \<times> 'a) list \<Rightarrow> code_numeral \<Rightarrow> 'a" where
  "pick (x # xs) i = (if i < fst x then snd x else pick xs (i - fst x))"

lemma pick_member:
  "i < listsum (map fst xs) \<Longrightarrow> pick xs i \<in> set (map snd xs)"
  by (induct xs arbitrary: i) simp_all

lemma pick_drop_zero:
  "pick (filter (\<lambda>(k, _). k > 0) xs) = pick xs"
  by (induct xs) (auto simp add: fun_eq_iff)

lemma pick_same:
  "l < length xs \<Longrightarrow> Random.pick (map (Pair 1) xs) (Code_Numeral.of_nat l) = nth xs l"
proof (induct xs arbitrary: l)
  case Nil then show ?case by simp
next
  case (Cons x xs) then show ?case by (cases l) simp_all
qed

(* Draws an index below the total weight of xs with range and returns the
   value pick associates with it. *)
definition select_weight :: "(code_numeral \<times> 'a) list \<Rightarrow> seed \<Rightarrow> 'a \<times> seed" where
  "select_weight xs = range (listsum (map fst xs))
   \<circ>\<rightarrow> (\<lambda>k. Pair (pick xs k))"

lemma select_weight_member:
  assumes "0 < listsum (map fst xs)"
  shows "fst (select_weight xs s) \<in> set (map snd xs)"
proof -
  from range assms
    have "fst (range (listsum (map fst xs)) s) < listsum (map fst xs)" .
  with pick_member
    have "pick xs (fst (range (listsum (map fst xs)) s)) \<in> set (map snd xs)" .
  then show ?thesis by (simp add: select_weight_def scomp_def split_def)
qed

lemma select_weight_cons_zero:
  "select_weight ((0, x) # xs) = select_weight xs"
  by (simp add: select_weight_def)

lemma select_weight_drop_zero:
  "select_weight (filter (\<lambda>(k, _). k > 0) xs) = select_weight xs"
proof -
  have "listsum (map fst [(k, _)\<leftarrow>xs . 0 < k]) = listsum (map fst xs)"
    by (induct xs) auto
  then show ?thesis by (simp only: select_weight_def pick_drop_zero)
qed

(* Backwards-compatible alias for the previously misspelled fact name. *)
lemmas select_weigth_drop_zero = select_weight_drop_zero

lemma select_weight_select:
  assumes "xs \<noteq> []"
  shows "select_weight (map (Pair 1) xs) = select xs"
proof -
  have less: "\<And>s. fst (range (Code_Numeral.of_nat (length xs)) s) < Code_Numeral.of_nat (length xs)"
    using assms by (intro range) simp
  moreover have "listsum (map fst (map (Pair 1) xs)) = Code_Numeral.of_nat (length xs)"
    by (induct xs) simp_all
  ultimately show ?thesis
    by (auto simp add: select_weight_def select_def scomp_def split_def
      fun_eq_iff pick_same [symmetric])
qed

(* Backwards-compatible alias for the previously misspelled fact name. *)
lemmas select_weigth_select = select_weight_select


subsection {* @{text ML} interface *}

code_reflect Random_Engine
  functions range select select_weight

ML {*
structure Random_Engine =
struct

open Random_Engine;

type seed = int * int;

local

(* Global generator state, initialised from the current time so that the
   first component lies in 1..2147483562 and the second in 1..2147483398,
   the wrap-around bounds used by split_seed.  It is an Unsynchronized.ref,
   so concurrent callers are not coordinated. *)
val seed = Unsynchronized.ref
  (let
    val now = Time.toMilliseconds (Time.now ());
    val (q, s1) = IntInf.divMod (now, 2147483562);
    val s2 = q mod 2147483398;
  in (s1 + 1, s2 + 1) end);

in

(* Splits the global seed: returns one half and keeps the other half as
   the new global state. *)
fun next_seed () =
  let
    val (seed1, seed') = @{code split_seed} (! seed)
    val _ = seed := seed'
  in
    seed1
  end;

(* Applies the seed transformer f to the global seed, stores the successor
   seed it returns, and yields f's result value. *)
fun run f =
  let
    val (x, seed') = f (! seed);
    val _ = seed := seed'
  in x end;

end;

end;
*}

hide_type (open) seed
hide_const (open) inc_shift minus_shift log "next" split_seed
  iterate range select pick select_weight
hide_fact (open) range_def

no_notation fcomp (infixl "\<circ>>" 60)
no_notation scomp (infixl "\<circ>\<rightarrow>" 60)

end