| 22528 |      1 | (*  ID:         $Id$
 | 
|  |      2 |     Author:     Florian Haftmann, TU Muenchen
 | 
|  |      3 | *)
 | 
|  |      4 | 
 | 
| 26265 |      5 | header {* A HOL random engine *}
 | 
| 22528 |      6 | 
 | 
|  |      7 | theory Random
 | 
| 26265 |      8 | imports State_Monad Code_Index
 | 
| 22528 |      9 | begin
 | 
|  |     10 | 
 | 
| 26265 |     11 | subsection {* Auxiliary functions *}
 | 
|  |     12 | 
 | 
|  |     13 | definition
 | 
|  |     14 |   inc_shift :: "index \<Rightarrow> index \<Rightarrow> index"
 | 
|  |     15 | where
 | 
|  |     16 |   "inc_shift v k = (if v = k then 1 else k + 1)"
 | 
|  |     17 | 
 | 
|  |     18 | definition
 | 
|  |     19 |   minus_shift :: "index \<Rightarrow> index \<Rightarrow> index \<Rightarrow> index"
 | 
|  |     20 | where
 | 
|  |     21 |   "minus_shift r k l = (if k < l then r + k - l else k - l)"
 | 
|  |     22 | 
 | 
| 28042 |     23 | fun
 | 
| 26265 |     24 |   log :: "index \<Rightarrow> index \<Rightarrow> index"
 | 
|  |     25 | where
 | 
|  |     26 |   "log b i = (if b \<le> 1 \<or> i < b then 1 else 1 + log b (i div b))"
 | 
|  |     27 | 
 | 
|  |     28 | subsection {* Random seeds *}
 | 
| 26038 |     29 | 
 | 
|  |     30 | types seed = "index \<times> index"
 | 
| 22528 |     31 | 
 | 
| 26265 |     32 | primrec
 | 
| 26038 |     33 |   "next" :: "seed \<Rightarrow> index \<times> seed"
 | 
|  |     34 | where
 | 
| 26265 |     35 |   "next (v, w) = (let
 | 
|  |     36 |      k =  v div 53668;
 | 
|  |     37 |      v' = minus_shift 2147483563 (40014 * (v mod 53668)) (k * 12211);
 | 
|  |     38 |      l =  w div 52774;
 | 
|  |     39 |      w' = minus_shift 2147483399 (40692 * (w mod 52774)) (l * 3791);
 | 
|  |     40 |      z =  minus_shift 2147483562 v' (w' + 1) + 1
 | 
|  |     41 |    in (z, (v', w')))"
 | 
|  |     42 | 
 | 
|  |     43 | lemma next_not_0:
 | 
|  |     44 |   "fst (next s) \<noteq> 0"
 | 
|  |     45 | apply (cases s)
 | 
|  |     46 | apply (auto simp add: minus_shift_def Let_def)
 | 
|  |     47 | done
 | 
|  |     48 | 
 | 
|  |     49 | primrec
 | 
|  |     50 |   seed_invariant :: "seed \<Rightarrow> bool"
 | 
|  |     51 | where
 | 
|  |     52 |   "seed_invariant (v, w) \<longleftrightarrow> 0 < v \<and> v < 9438322952 \<and> 0 < w \<and> True"
 | 
|  |     53 | 
 | 
|  |     54 | lemma if_same:
 | 
|  |     55 |   "(if b then f x else f y) = f (if b then x else y)"
 | 
|  |     56 |   by (cases b) simp_all
 | 
|  |     57 | 
 | 
| 22528 |     58 | definition
 | 
| 26038 |     59 |   split_seed :: "seed \<Rightarrow> seed \<times> seed"
 | 
|  |     60 | where
 | 
|  |     61 |   "split_seed s = (let
 | 
|  |     62 |      (v, w) = s;
 | 
|  |     63 |      (v', w') = snd (next s);
 | 
| 26265 |     64 |      v'' = inc_shift 2147483562 v;
 | 
| 26038 |     65 |      s'' = (v'', w');
 | 
| 26265 |     66 |      w'' = inc_shift 2147483398 w;
 | 
| 26038 |     67 |      s''' = (v', w'')
 | 
|  |     68 |    in (s'', s'''))"
 | 
|  |     69 | 
 | 
|  |     70 | 
 | 
| 26265 |     71 | subsection {* Base selectors *}
 | 
| 22528 |     72 | 
 | 
| 26038 |     73 | function
 | 
|  |     74 |   range_aux :: "index \<Rightarrow> index \<Rightarrow> seed \<Rightarrow> index \<times> seed"
 | 
|  |     75 | where
 | 
|  |     76 |   "range_aux k l s = (if k = 0 then (l, s) else
 | 
|  |     77 |     let (v, s') = next s
 | 
|  |     78 |   in range_aux (k - 1) (v + l * 2147483561) s')"
 | 
|  |     79 | by pat_completeness auto
 | 
|  |     80 | termination
 | 
|  |     81 |   by (relation "measure (nat_of_index o fst)")
 | 
|  |     82 |     (auto simp add: index)
 | 
| 22528 |     83 | 
 | 
|  |     84 | definition
 | 
| 26038 |     85 |   range :: "index \<Rightarrow> seed \<Rightarrow> index \<times> seed"
 | 
|  |     86 | where
 | 
| 26265 |     87 |   "range k = (do
 | 
|  |     88 |      v \<leftarrow> range_aux (log 2147483561 k) 1;
 | 
|  |     89 |      return (v mod k)
 | 
|  |     90 |    done)"
 | 
|  |     91 | 
 | 
|  |     92 | lemma range:
 | 
|  |     93 |   assumes "k > 0"
 | 
|  |     94 |   shows "fst (range k s) < k"
 | 
|  |     95 | proof -
 | 
|  |     96 |   obtain v w where range_aux:
 | 
|  |     97 |     "range_aux (log 2147483561 k) 1 s = (v, w)"
 | 
|  |     98 |     by (cases "range_aux (log 2147483561 k) 1 s")
 | 
|  |     99 |   with assms show ?thesis
 | 
| 28145 |    100 |     by (simp add: monad_collapse range_def del: range_aux.simps log.simps)
 | 
| 26265 |    101 | qed
 | 
| 26038 |    102 | 
 | 
| 22528 |    103 | definition
 | 
| 26038 |    104 |   select :: "'a list \<Rightarrow> seed \<Rightarrow> 'a \<times> seed"
 | 
|  |    105 | where
 | 
| 26265 |    106 |   "select xs = (do
 | 
|  |    107 |      k \<leftarrow> range (index_of_nat (length xs));
 | 
|  |    108 |      return (nth xs (nat_of_index k))
 | 
|  |    109 |    done)"
 | 
|  |    110 | 
 | 
|  |    111 | lemma select:
 | 
|  |    112 |   assumes "xs \<noteq> []"
 | 
|  |    113 |   shows "fst (select xs s) \<in> set xs"
 | 
|  |    114 | proof -
 | 
|  |    115 |   from assms have "index_of_nat (length xs) > 0" by simp
 | 
|  |    116 |   with range have
 | 
|  |    117 |     "fst (range (index_of_nat (length xs)) s) < index_of_nat (length xs)" by best
 | 
|  |    118 |   then have
 | 
|  |    119 |     "nat_of_index (fst (range (index_of_nat (length xs)) s)) < length xs" by simp
 | 
|  |    120 |   then show ?thesis
 | 
| 28145 |    121 |     by (auto simp add: monad_collapse select_def)
 | 
| 26265 |    122 | qed
 | 
| 22528 |    123 | 
 | 
| 26038 |    124 | definition
 | 
| 26265 |    125 |   select_default :: "index \<Rightarrow> 'a \<Rightarrow> 'a \<Rightarrow> seed \<Rightarrow> 'a \<times> seed"
 | 
| 26038 |    126 | where
 | 
| 28562 |    127 |   [code del]: "select_default k x y = (do
 | 
| 26265 |    128 |      l \<leftarrow> range k;
 | 
|  |    129 |      return (if l + 1 < k then x else y)
 | 
|  |    130 |    done)"
 | 
|  |    131 | 
 | 
|  |    132 | lemma select_default_zero:
 | 
|  |    133 |   "fst (select_default 0 x y s) = y"
 | 
| 28145 |    134 |   by (simp add: monad_collapse select_default_def)
 | 
| 26038 |    135 | 
 | 
| 26265 |    136 | lemma select_default_code [code]:
 | 
|  |    137 |   "select_default k x y = (if k = 0 then do
 | 
|  |    138 |      _ \<leftarrow> range 1;
 | 
|  |    139 |      return y
 | 
|  |    140 |    done else do
 | 
|  |    141 |      l \<leftarrow> range k;
 | 
|  |    142 |      return (if l + 1 < k then x else y)
 | 
|  |    143 |    done)"
 | 
|  |    144 | proof (cases "k = 0")
 | 
|  |    145 |   case False then show ?thesis by (simp add: select_default_def)
 | 
| 22528 |    146 | next
 | 
| 26265 |    147 |   case True then show ?thesis
 | 
| 28145 |    148 |     by (simp add: monad_collapse select_default_def range_def)
 | 
| 26265 |    149 | qed
 | 
| 22528 |    150 | 
 | 
| 26265 |    151 | 
 | 
|  |    152 | subsection {* @{text ML} interface *}
 | 
| 22528 |    153 | 
 | 
|  |    154 | ML {*
 | 
| 26265 |    155 | structure Random_Engine =
 | 
| 22528 |    156 | struct
 | 
|  |    157 | 
 | 
| 26038 |    158 | type seed = int * int;
 | 
| 22528 |    159 | 
 | 
|  |    160 | local
 | 
| 26038 |    161 | 
 | 
| 26265 |    162 | val seed = ref 
 | 
|  |    163 |   (let
 | 
|  |    164 |     val now = Time.toMilliseconds (Time.now ());
 | 
| 26038 |    165 |     val (q, s1) = IntInf.divMod (now, 2147483562);
 | 
|  |    166 |     val s2 = q mod 2147483398;
 | 
| 26265 |    167 |   in (s1 + 1, s2 + 1) end);
 | 
|  |    168 | 
 | 
| 22528 |    169 | in
 | 
| 26038 |    170 | 
 | 
|  |    171 | fun run f =
 | 
|  |    172 |   let
 | 
| 26265 |    173 |     val (x, seed') = f (! seed);
 | 
| 26038 |    174 |     val _ = seed := seed'
 | 
|  |    175 |   in x end;
 | 
|  |    176 | 
 | 
| 22528 |    177 | end;
 | 
|  |    178 | 
 | 
|  |    179 | end;
 | 
|  |    180 | *}
 | 
|  |    181 | 
 | 
| 26038 |    182 | end
 | 
| 28145 |    183 | 
 |