| 
22528
 | 
     1  | 
(*  ID:         $Id$
    Author:     Florian Haftmann, TU Muenchen
*)
  | 
| 
 | 
     4  | 
  | 
| 
 | 
     5  | 
header {* A simple random engine *}
 | 
| 
 | 
     6  | 
  | 
| 
 | 
     7  | 
theory Random
  | 
| 
 | 
     8  | 
imports State_Monad
  | 
| 
 | 
     9  | 
begin
  | 
| 
 | 
    10  | 
  | 
| 
 | 
    11  | 
(* pick xs n walks the weighted list xs from the left, subtracting each
   weight from n, and yields the first value whose weight exceeds what is
   left of n.  For n at or beyond the total weight it ends in the empty-list
   case, whose result is undefined. *)
fun
  pick :: "(nat \<times> 'a) list \<Rightarrow> nat \<Rightarrow> 'a"
where
  pick_undef: "pick [] m = undefined"
| pick_simp: "pick ((w, y) # ys) m = (if m < w then y else pick ys (m - w))"

(* The undefined equation is not used as a code equation. *)
lemmas [code nofunc] = pick_undef
  | 
| 
 | 
    17  | 
  | 
| 
 | 
    18  | 
(* Abstract type of random seeds; HOL knows nothing about its elements. *)
typedecl randseed

(* Two uninterpreted operations on seeds (no axioms are stated):
   random_shift maps a seed to a seed, random_seed maps a seed to a
   natural number. *)
axiomatization
  random_shift :: "randseed \<Rightarrow> randseed"
and
  random_seed :: "randseed \<Rightarrow> nat"
  | 
| 
 | 
    25  | 
  | 
| 
 | 
    26  | 
(* random m r draws a number from seed r, reduced modulo m, and pairs it
   with the shifted seed.  For m = 0, HOL's convention k mod 0 = k makes the
   first component random_seed r itself. *)
definition
  random :: "nat \<Rightarrow> randseed \<Rightarrow> nat \<times> randseed" where
  "random m r = (random_seed r mod m, random_shift r)"
  | 
| 
 | 
    29  | 
  | 
| 
 | 
    30  | 
(* For a positive bound n, the number drawn by random n is below n. *)
lemma random_bound:
  assumes "0 < n"
  shows "fst (random n s) < n"
proof -
  from prems mod_less_divisor have "!!m. m mod n < n" by auto
  then show ?thesis by (simp add: random_def)
qed
  | 
| 
 | 
    37  | 
  | 
| 
 | 
    38  | 
(* The seed returned by random n s is random_shift s, whatever n is.
   NOTE(review): the name mentions random_seed although the statement is
   about random_shift; kept as is so existing references keep working. *)
lemma random_random_seed [simp]:
  "snd (random n s) = random_shift s"
  by (simp add: random_def)
  | 
| 
 | 
    40  | 
  | 
| 
 | 
    41  | 
(* select xs returns the element of xs at the index drawn by
   random (length xs), threading the seed through the state monad. *)
definition
  select :: "'a list \<Rightarrow> randseed \<Rightarrow> 'a \<times> randseed" where
  [simp]: "select xs = (do
      i \<leftarrow> random (length xs);
      return (xs ! i)
    done)"
  | 
| 
 | 
    47  | 
(* select_weight xs draws a number below the total weight of xs (the sum of
   the first components) and returns the value chosen for it by pick. *)
definition
  select_weight :: "(nat \<times> 'a) list \<Rightarrow> randseed \<Rightarrow> 'a \<times> randseed" where
  [simp]: "select_weight xs = (do
      i \<leftarrow> random (foldl (op +) 0 (map fst xs));
      return (pick xs i)
    done)"
  | 
| 
 | 
    53  | 
  | 
| 
 | 
    54  | 
(* On a non-empty list, select agrees with select_weight when every element
   is given weight 1. *)
lemma select_eq_select_weight:
  "select (x#xs) s = select_weight (map (Pair 1) (x#xs)) s"
proof (induct xs)
  case Nil show ?case by (simp add: monad_collapse random_def)
next
  (* Weights of a list built with Pair y are all y. *)
  have map_fst_Pair: "!!xs y. map fst (map (Pair y) xs) = replicate (length xs) y"
  proof -
    fix xs
    fix y
    show "map fst (map (Pair y) xs) = replicate (length xs) y"
      by (induct xs) simp_all
  qed
  (* With unit weights, pick behaves like nth on in-range indices. *)
  have pick_nth: "!!xs n. n < length xs \<Longrightarrow> pick (map (Pair 1) xs) n = nth xs n"
  proof -
    fix xs
    fix n
    assume "n < length xs"
    then show "pick (map (Pair 1) xs) n = nth xs n"
    proof (induct xs arbitrary: n)
      case Nil then show ?case by simp
    next
      case (Cons x xs) show ?case
      proof (cases n)
        case 0 then show ?thesis by simp
      next
        case (Suc _)
        from Cons have "n < length (x # xs)" by auto
        then have "n < Suc (length xs)" by simp
        with Suc have "n - 1 < Suc (length xs) - 1" by auto
        with Cons have "pick (map (Pair (1\<Colon>nat)) xs) (n - 1) = xs ! (n - 1)" by auto
        with Suc show ?thesis by auto
      qed
    qed
  qed
  (* With unit weights, the total weight is the length of the list. *)
  have sum_length: "!!xs. foldl (op +) 0 (map fst (map (Pair 1) xs)) = length xs"
  proof -
    have replicate_append:
      "!!x xs y. replicate (length (x # xs)) y = replicate (length xs) y @ [y]"
      by (simp add: replicate_app_Cons_same)
    fix xs
    show "foldl (op +) 0 (map fst (map (Pair 1) xs)) = length xs"
    unfolding map_fst_Pair proof (induct xs)
      case Nil show ?case by simp
    next
      case (Cons x xs) then show ?case unfolding replicate_append by simp
    qed
  qed
  (* The drawn index is in range (random_bound), so pick and nth agree on it. *)
  have pick_nth_random:
    "!!x xs s. pick (map (Pair 1) (x#xs)) (fst (random (length (x#xs)) s)) = nth (x#xs) (fst (random (length (x#xs)) s))"
  proof -
    fix s
    fix x
    fix xs
    have bound: "fst (random (length (x#xs)) s) < length (x#xs)" by (rule random_bound) simp
    from pick_nth [OF bound] show
      "pick (map (Pair 1) (x#xs)) (fst (random (length (x#xs)) s)) = nth (x#xs) (fst (random (length (x#xs)) s))" .
  qed
  (* The same fact lifted to the monadic form used by the two definitions. *)
  have pick_nth_random_do:
    "!!x xs s. (do n \<leftarrow> random (length (x#xs)); return (pick (map (Pair 1) (x#xs)) n) done) s =
      (do n \<leftarrow> random (length (x#xs)); return (nth (x#xs) n) done) s"
  unfolding monad_collapse split_def unfolding pick_nth_random ..
  case (Cons x xs) then show ?case
    unfolding select_weight_def sum_length pick_nth_random_do
    by simp
qed
  | 
| 
 | 
   119  | 
  | 
| 
 | 
   120  | 
(* Integer variant of random: draws a natural number with bound nat k and
   injects the result into int. *)
definition
  random_int :: "int \<Rightarrow> randseed \<Rightarrow> int \<times> randseed" where
  "random_int k = (do m \<leftarrow> random (nat k); return (int m) done)"
  | 
| 
 | 
   123  | 
  | 
| 
 | 
   124  | 
(* Code equation expressing random on nat through random_int, so that
   generated code only needs an implementation of random_int. *)
lemma random_nat [code]:
  "random n = (do l \<leftarrow> random_int (int n); return (nat l) done)"
  unfolding random_int_def by simp
  | 
| 
 | 
   127  | 
  | 
| 
 | 
   128  | 
(* Uninterpreted in HOL (no axioms): by its type, it takes a seed-passing
   computation and yields a plain value. *)
axiomatization
  run_random :: "(randseed \<Rightarrow> 'a \<times> randseed) \<Rightarrow> 'a"
  | 
| 
 | 
   130  | 
  | 
| 
 | 
   131  | 
ML {*
 | 
| 
 | 
   132  | 
(* Interface of the ML random engine referenced by the SML code setup. *)
signature RANDOM =
sig
  type seed = IntInf.int;
  val seed: unit -> seed;
  val value: IntInf.int -> seed -> IntInf.int * seed;
end;

structure Random : RANDOM =
struct

open IntInf;

(* Raised by value when the requested bound is below 1. *)
exception RANDOM;

type seed = int;

(* Multiplicative linear congruential step: s -> (16807 * s) mod (2^31 - 1). *)
local
  val a = fromInt 16807;
    (*greetings to SML/NJ*)
  val m = (valOf o fromString) "2147483647";
    (* read from a string rather than via fromInt, presumably because the
       constant need not fit the default int of every ML system *)
in
  fun next s = (a * s) mod m;
end;

(* Global generator state, starting at 1; each seed () call advances it by
   one step of next and returns the new state. *)
local
  val seed_ref = ref (fromInt 1);
in
  fun seed () =
    let
      val r = next (!seed_ref)
    in
      (seed_ref := r; r)
    end;
end;

(* value h s returns (s mod h, fresh seed), mirroring the HOL definition
   random n s = (random_seed s mod n, random_shift s).  Since IntInf.mod
   takes the sign of the divisor, the first component lies in [0, h) for
   every h >= 1.  Raises RANDOM for h < 1.
   NOTE(review): the returned seed comes from the global state via seed ()
   rather than being computed from s. *)
fun value h s =
  if h < 1 then raise RANDOM
  else (s mod h, seed ());

end;
  | 
| 
 | 
   172  | 
*}
  | 
| 
 | 
   173  | 
  | 
| 
 | 
   174  | 
(* SML code generator setup: randseed is represented by the ML type
   Random.seed. *)
code_type randseed
  (SML "Random.seed")

(* run_random applies its argument to a seed obtained from Random.seed ()
   and keeps only the first component of the resulting pair. *)
code_const run_random
  (SML "case _ (Random.seed ()) of (x, '_) => x")

(* random_int is implemented directly by the ML function Random.value. *)
code_const random_int
  (SML "Random.value")

(* Invoke the SML code generator on the two selection functions. *)
code_gen select select_weight
  (SML #)
  | 
| 
 | 
   185  | 
  | 
| 
 | 
   186  | 
end
  |