src/HOL/Random.thy
 author wenzelm Sun Sep 18 20:33:48 2016 +0200 (2016-09-18) changeset 63915 bab633745c7f parent 63882 018998c00003 child 68249 949d93804740 permissions -rw-r--r--
tuned proofs;
 (* Author: Florian Haftmann, TU Muenchen *)

section \<open>A HOL random engine\<close>

theory Random
imports List Groups_List
begin

(* Local infix notation for forward composition (fcomp) and state-passing
   composition (scomp); both notations are removed again at the end of the theory. *)
notation fcomp (infixl "\<circ>>" 60)
notation scomp (infixl "\<circ>\<rightarrow>" 60)


subsection \<open>Auxiliary functions\<close>

(* "log b i" is 1 when b <= 1 or i < b, and otherwise 1 + log b (i div b).
   The "b <= 1" guard covers exactly the bases for which "i div b" would not
   get smaller, so the recursion always reaches the base case. *)
fun log :: "natural \<Rightarrow> natural \<Rightarrow> natural" where
  "log b i = (if b \<le> 1 \<or> i < b then 1 else 1 + log b (i div b))"

(* "inc_shift v k" is k + 1, except that it wraps around to 1 when k equals v. *)
definition inc_shift :: "natural \<Rightarrow> natural \<Rightarrow> natural" where
  "inc_shift v k = (if v = k then 1 else k + 1)"

(* "minus_shift r k l" is k - l, with r added beforehand when k < l, so that the
   subtraction on "natural" is not cut off at 0 in that case. *)
definition minus_shift :: "natural \<Rightarrow> natural \<Rightarrow> natural \<Rightarrow> natural" where
  "minus_shift r k l = (if k < l then r + k - l else k - l)"


subsection \<open>Random seeds\<close>

(* A seed is a pair of naturals. *)
type_synonym seed = "natural \<times> natural"

(* One generator step: from seed (v, w) compute the successor seed (v', w')
   and an output value z derived from v' and w'.
   NOTE(review): the constants look like those of L'Ecuyer's combined linear
   congruential generator -- confirm against the literature before relying on it. *)
primrec "next" :: "seed \<Rightarrow> natural \<times> seed" where
  "next (v, w) = (let
     k = v div 53668;
     v' = minus_shift 2147483563 ((v mod 53668) * 40014) (k * 12211);
     l = w div 52774;
     w' = minus_shift 2147483399 ((w mod 52774) * 40692) (l * 3791);
     z = minus_shift 2147483562 v' (w' + 1) + 1
   in (z, (v', w')))"

(* Splits seed s = (v, w) into two seeds. With (v', w') the successor seed
   "snd (next s)", the results are (inc_shift 2147483562 v, w') and
   (v', inc_shift 2147483398 w). *)
definition split_seed :: "seed \<Rightarrow> seed \<times> seed" where
  "split_seed s = (let
     (v, w) = s;
     (v', w') = snd (next s);
     v'' = inc_shift 2147483562 v;
     w'' = inc_shift 2147483398 w
   in ((v'', w'), (v', w'')))"


subsection \<open>Base selectors\<close>

(* "iterate k f x" chains k applications of the state-passing function f,
   starting from value x; for k = 0 it just pairs x with the incoming state. *)
fun iterate :: "natural \<Rightarrow> ('b \<Rightarrow> 'a \<Rightarrow> 'b \<times> 'a) \<Rightarrow> 'b \<Rightarrow> 'a \<Rightarrow> 'b \<times> 'a" where
  "iterate k f x = (if k = 0 then Pair x else f x \<circ>\<rightarrow> iterate (k - 1) f)"

(* "range k" runs "next" (log 2147483561 k) times, accumulating
   v + l * 2147483561 at each step (starting from 1), and finally reduces the
   accumulated value modulo k. It returns that value together with the new seed. *)
definition range :: "natural \<Rightarrow> seed \<Rightarrow> natural \<times> seed" where
  "range k = iterate (log 2147483561 k)
      (\<lambda>l. next \<circ>\<rightarrow> (\<lambda>v. Pair (v + l * 2147483561))) 1
    \<circ>\<rightarrow> (\<lambda>v. Pair (v mod k))"

(* For positive k the value produced by "range k" is strictly below k. *)
lemma range:
  "k > 0 \<Longrightarrow> fst (range k s) < k"
  by (simp add: range_def split_def less_natural_def del: log.simps iterate.simps)

(* "select xs" draws an index with "range (length xs)" and returns the list
   element at that index, together with the new seed. *)
definition select :: "'a list \<Rightarrow> seed \<Rightarrow> 'a \<times> seed" where
  "select xs = range (natural_of_nat (length xs))
    \<circ>\<rightarrow> (\<lambda>k. Pair (nth xs (nat_of_natural k)))"

(* On a non-empty list, "select" returns a member of the list. *)
lemma select:
  assumes "xs \<noteq> []"
  shows "fst (select xs s) \<in> set xs"
proof -
  from assms have "natural_of_nat (length xs) > 0" by (simp add: less_natural_def)
  with range have
    "fst (range (natural_of_nat (length xs)) s) < natural_of_nat (length xs)" by best
  then have
    "nat_of_natural (fst (range (natural_of_nat (length xs)) s)) < length xs" by (simp add: less_natural_def)
  then show ?thesis
    by (simp add: split_beta select_def)
qed

(* "pick xs i" walks a list of (weight, value) pairs: it returns the value of the
   first pair whose weight exceeds the remaining index, subtracting each skipped
   weight from the index. Only the Cons case has a defining equation. *)
primrec pick :: "(natural \<times> 'a) list \<Rightarrow> natural \<Rightarrow> 'a" where
  "pick (x # xs) i = (if i < fst x then snd x else pick xs (i - fst x))"

(* An index below the total weight always picks one of the listed values. *)
lemma pick_member:
  "i < sum_list (map fst xs) \<Longrightarrow> pick xs i \<in> set (map snd xs)"
  by (induct xs arbitrary: i) (simp_all add: less_natural_def)

(* Entries of weight 0 do not influence "pick". *)
lemma pick_drop_zero:
  "pick (filter (\<lambda>(k, _). k > 0) xs) = pick xs"
  by (induct xs) (auto simp add: fun_eq_iff less_natural_def minus_natural_def)

(* With every weight equal to 1, "pick" coincides with list indexing. *)
lemma pick_same:
  "l < length xs \<Longrightarrow> Random.pick (map (Pair 1) xs) (natural_of_nat l) = nth xs l"
proof (induct xs arbitrary: l)
  case Nil then show ?case by simp
next
  case (Cons x xs) then show ?case by (cases l) (simp_all add: less_natural_def)
qed

(* "select_weight xs" draws an index below the sum of all weights with "range"
   and resolves it to a value with "pick". *)
definition select_weight :: "(natural \<times> 'a) list \<Rightarrow> seed \<Rightarrow> 'a \<times> seed" where
  "select_weight xs = range (sum_list (map fst xs))
   \<circ>\<rightarrow> (\<lambda>k. Pair (pick xs k))"

(* With positive total weight, "select_weight" returns one of the listed values. *)
lemma select_weight_member:
  assumes "0 < sum_list (map fst xs)"
  shows "fst (select_weight xs s) \<in> set (map snd xs)"
proof -
  from range assms
    have "fst (range (sum_list (map fst xs)) s) < sum_list (map fst xs)" .
  with pick_member
    have "pick xs (fst (range (sum_list (map fst xs)) s)) \<in> set (map snd xs)" .
  then show ?thesis by (simp add: select_weight_def scomp_def split_def)
qed

(* A leading entry of weight 0 can be dropped. *)
lemma select_weight_cons_zero:
  "select_weight ((0, x) # xs) = select_weight xs"
  by (simp add: select_weight_def less_natural_def)

(* All entries of weight 0 can be dropped. *)
lemma select_weight_drop_zero:
  "select_weight (filter (\<lambda>(k, _). k > 0) xs) = select_weight xs"
proof -
  have "sum_list (map fst [(k, _)\<leftarrow>xs . 0 < k]) = sum_list (map fst xs)"
    by (induct xs) (auto simp add: less_natural_def natural_eq_iff)
  then show ?thesis by (simp only: select_weight_def pick_drop_zero)
qed

(* On a non-empty list, uniform weights 1 make "select_weight" equal to "select". *)
lemma select_weight_select:
  assumes "xs \<noteq> []"
  shows "select_weight (map (Pair 1) xs) = select xs"
proof -
  have less: "\<And>s. fst (range (natural_of_nat (length xs)) s) < natural_of_nat (length xs)"
    using assms by (intro range) (simp add: less_natural_def)
  moreover have "sum_list (map fst (map (Pair 1) xs)) = natural_of_nat (length xs)"
    by (induct xs) simp_all
  ultimately show ?thesis
    by (auto simp add: select_weight_def select_def scomp_def split_def
      fun_eq_iff pick_same [symmetric] less_natural_def)
qed


subsection \<open>\<open>ML\<close> interface\<close>

(* Generate ML code for the three selectors into structure Random_Engine;
   the ML block below re-opens and extends that structure. *)
code_reflect Random_Engine
  functions range select select_weight

ML \<open>
structure Random_Engine =
struct

open Random_Engine;

type seed = Code_Numeral.natural * Code_Numeral.natural;

local

(* Mutable current seed, initialised from the current time in milliseconds:
   s1 is the time modulo 2147483562, s2 is the quotient modulo 2147483398,
   and 1 is added to both components. *)
val seed = Unsynchronized.ref
  (let
    val now = Time.toMilliseconds (Time.now ());
    val (q, s1) = IntInf.divMod (now, 2147483562);
    val s2 = q mod 2147483398;
  in apply2 Code_Numeral.natural_of_integer (s1 + 1, s2 + 1) end);

in

(* Split the current seed; return the first half and keep the second half
   as the new current seed. *)
fun next_seed () =
  let
    val (seed1, seed') = @{code split_seed} (! seed)
    val _ = seed := seed'
  in
    seed1
  end

(* Apply the seed-passing function f to the current seed, store the seed it
   returns, and hand back its result. *)
fun run f =
  let
    val (x, seed') = f (! seed);
    val _ = seed := seed'
  in x end;

end;

end;
\<close>

(* Names below remain accessible only in qualified form (e.g. Random.range). *)
hide_type (open) seed
hide_const (open) inc_shift minus_shift log "next" split_seed
  iterate range select pick select_weight
hide_fact (open) range_def

no_notation fcomp (infixl "\<circ>>" 60)
no_notation scomp (infixl "\<circ>\<rightarrow>" 60)

end