src/HOL/ex/CodeRandom.thy
changeset 22528 8501c4a62a3c
parent 22527 84690fcd3db9
child 22529 902ed60d53a7
equal deleted inserted replaced
22527:84690fcd3db9 22528:8501c4a62a3c
     1 (*  ID:         $Id$
       
     2     Author:     Florian Haftmann, TU Muenchen
       
     3 *)
       
     4 
       
     5 header {* A simple random engine *}
       
     6 
       
     7 theory CodeRandom
       
     8 imports State_Monad
       
     9 begin
       
    10 
       
    11 consts
       
       (* pick ws n: walk a list of (weight, value) pairs, subtracting each
          weight from n, and return the first value whose weight exceeds what
          is left of n. Only the Cons case is given, so pick [] n is left
          unspecified. *)
    12   pick :: "(nat \<times> 'a) list \<Rightarrow> nat \<Rightarrow> 'a"
       
    13 
       
    14 primrec
       
    15   "pick (x#xs) n = (let (k, v) = x in
       
    16     if n < k then v else pick xs (n - k))"
       
    17 
       
       (* Pair-pattern form of the Cons equation. It replaces pick.simps, which
          uses a let-binding, both as the simp rule and as the code equation. *)
    18 lemma pick_def [code, simp]:
       
    19   "pick ((k, v)#xs) n = (if n < k then v else pick xs (n - k))" by simp
       
    20 declare pick.simps [simp del, code del]
       
    21 
       
    22 typedecl randseed
       
       (* Abstract type of random seeds. No HOL properties of it are stated in
          this theory; for generated SML it is mapped to Random.seed. *)
    23 
       
       (* random_shift advances a seed; random_seed extracts a number from it.
          Both are introduced by axiomatization without any axioms, so only
          their types are fixed. Neither has a code_const mapping in this
          theory. *)
    24 axiomatization
       
    25   random_shift :: "randseed \<Rightarrow> randseed"
       
    26 
       
    27 axiomatization
       
    28   random_seed :: "randseed \<Rightarrow> nat"
       
    29 
       
    30 definition
       
       (* random n s: state-monad step that draws a number from seed s by
          reducing random_seed s modulo n, and returns it together with the
          shifted seed. random_bound below shows the number is below n
          whenever 0 < n. *)
    31   random :: "nat \<Rightarrow> randseed \<Rightarrow> nat \<times> randseed" where
       
    32   "random n s = (random_seed s mod n, random_shift s)"
       
    33 
       
       (* For a positive bound n, the drawn number is strictly below n. *)
    34 lemma random_bound:
       
    35   assumes "0 < n"
       
    36   shows "fst (random n s) < n"
       
    37 proof -
       
       (* NOTE(review): prems is the legacy way of referring to the local
          assumptions; newer Isabelle versions expect assms here. *)
    38   from prems mod_less_divisor have "!!m .m mod n < n" by auto
       
    39   then show ?thesis unfolding random_def by simp 
       
    40 qed
       
    41 
       
       (* The successor seed depends only on s, not on the bound n (simp rule). *)
    42 lemma random_random_seed [simp]:
       
    43   "snd (random n s) = random_shift s" unfolding random_def by simp
       
    44 
       
    45 definition
       
       (* select xs: state-monad action that draws an index below length xs
          and returns that element of xs. For xs = [] the bound is 0, so the
          result is presumably not meaningful; intended for non-empty lists
          (TODO confirm). Declared as a simp rule. *)
    46   select :: "'a list \<Rightarrow> randseed \<Rightarrow> 'a \<times> randseed" where
       
    47   [simp]: "select xs = (do
       
    48       n \<leftarrow> random (length xs);
       
    49       return (nth xs n)
       
    50     done)"
       
    51 definition
       
       (* select_weight xs: draws a number below the total weight (the sum of
          the first components, via foldl) and returns the value pick assigns
          to it. A pair with a larger weight covers more of the drawn range.
          Declared as a simp rule. *)
    52   select_weight :: "(nat \<times> 'a) list \<Rightarrow> randseed \<Rightarrow> 'a \<times> randseed" where
       
    53   [simp]: "select_weight xs = (do
       
    54       n \<leftarrow> random (foldl (op +) 0 (map fst xs));
       
    55       return (pick xs n)
       
    56     done)"
       
    57 
       
    58 lemma
       
       (* On a non-empty list, select agrees with select_weight when every
          element gets weight 1.
          NOTE(review): the lemma is unnamed and so cannot be referenced later;
          consider giving it a name. *)
    59   "select (x#xs) s = select_weight (map (Pair 1) (x#xs)) s"
       
    60 proof (induct xs)
       
    61   case Nil show ?case by (simp add: monad_collapse random_def)
       
    62 next
       
       (* Cons case: auxiliary facts are established first and combined at
          the end. *)
    63   have map_fst_Pair: "!!xs y. map fst (map (Pair y) xs) = replicate (length xs) y"
       
    64   proof -
       
    65     fix xs
       
    66     fix y
       
    67     show "map fst (map (Pair y) xs) = replicate (length xs) y"
       
    68       by (induct xs) simp_all
       
    69   qed
       
       (* With all weights 1, pick coincides with nth on in-range indices. *)
    70   have pick_nth: "!!xs n. n < length xs \<Longrightarrow> pick (map (Pair 1) xs) n = nth xs n"
       
    71   proof -
       
    72     fix xs
       
    73     fix n
       
    74     assume "n < length xs"
       
    75     then show "pick (map (Pair 1) xs) n = nth xs n"
       
    76     proof (induct xs arbitrary: n)
       
    77       case Nil then show ?case by simp
       
    78     next
       
    79       case (Cons x xs) show ?case
       
    80       proof (cases n)
       
    81         case 0 then show ?thesis by simp
       
    82       next
       
    83         case (Suc _)
       
    84     from Cons have "n < length (x # xs)" by auto
       
    85         then have "n < Suc (length xs)" by simp
       
    86         with Suc have "n - 1 < Suc (length xs) - 1" by auto
       
    87         with Cons have "pick (map (Pair (1\<Colon>nat)) xs) (n - 1) = xs ! (n - 1)" by auto
       
    88         with Suc show ?thesis by auto
       
    89       qed
       
    90     qed
       
    91   qed
       
       (* With all weights 1, the total weight equals the length of the list. *)
    92   have sum_length: "!!xs. foldl (op +) 0 (map fst (map (Pair 1) xs)) = length xs"
       
    93   proof -
       
    94     have replicate_append:
       
    95       "!!x xs y. replicate (length (x # xs)) y = replicate (length xs) y @ [y]"
       
    96       by (simp add: replicate_app_Cons_same)
       
    97     fix xs
       
    98     show "foldl (op +) 0 (map fst (map (Pair 1) xs)) = length xs"
       
    99     unfolding map_fst_Pair proof (induct xs)
       
   100       case Nil show ?case by simp
       
   101     next
       
   102       case (Cons x xs) then show ?case unfolding replicate_append by simp
       
   103     qed
       
   104   qed
       
       (* pick_nth instantiated at the index drawn by random; random_bound
          supplies the range condition. *)
   105   have pick_nth_random:
       
   106     "!!x xs s. pick (map (Pair 1) (x#xs)) (fst (random (length (x#xs)) s)) = nth (x#xs) (fst (random (length (x#xs)) s))"
       
   107   proof -
       
   108     fix s
       
   109     fix x
       
   110     fix xs
       
   111     have bound: "fst (random (length (x#xs)) s) < length (x#xs)" by (rule random_bound) simp
       
   112     from pick_nth [OF bound] show
       
   113       "pick (map (Pair 1) (x#xs)) (fst (random (length (x#xs)) s)) = nth (x#xs) (fst (random (length (x#xs)) s))" .
       
   114   qed
       
       (* The same equation lifted to the two monadic do-blocks. *)
   115   have pick_nth_random_do:
       
   116     "!!x xs s. (do n \<leftarrow> random (length (x#xs)); return (pick (map (Pair 1) (x#xs)) n) done) s =
       
   117       (do n \<leftarrow> random (length (x#xs)); return (nth (x#xs) n) done) s"
       
   118   unfolding monad_collapse split_def unfolding pick_nth_random ..
       
       (* Unfold select_weight, replace the total weight by the length and
          pick by nth; simp (which unfolds select) closes the goal. *)
   119   case (Cons x xs) then show ?case
       
   120     unfolding select_weight_def sum_length pick_nth_random_do
       
   121     by simp
       
   122 qed
       
   123 
       
   124 definition
       
       (* Integer-valued variant of random. The bound k is converted with nat,
          so a non-positive k behaves like bound 0. This is the constant that
          code_const maps to the ML function Random.value. *)
   125   random_int :: "int \<Rightarrow> randseed \<Rightarrow> int * randseed" where
       
   126   "random_int k = (do n \<leftarrow> random (nat k); return (int n) done)"
       
   127 
       
       (* Code equation expressing random through random_int, so generated code
          for random ends up calling the ML implementation of random_int. *)
   128 lemma random_nat [code]:
       
   129   "random n = (do k \<leftarrow> random_int (int n); return (nat k) done)"
       
   130 unfolding random_int_def by simp
       
   131 
       
       (* Runs a state-monad computation from an initial seed and returns only
          its result component. Declared without axioms; its meaning is given
          only by its code_const mapping. *)
   132 axiomatization
       
   133   run_random :: "(randseed \<Rightarrow> 'a * randseed) \<Rightarrow> 'a"
       
   134 
       
   135 ML {*
       
   (* Random: minimal pseudo-random engine backing the HOL constants
      randseed, random_int and run_random in generated SML code. *)
signature RANDOM =
sig
  type seed = IntInf.int;
  (* raised by value when the requested bound is below 1 *)
  exception RANDOM;
  (* draw a seed from the global generator state, advancing that state *)
  val seed: unit -> seed;
  (* value h s = (s mod h, successor of s): a number in [0, h) together
     with the next seed; raises RANDOM if h < 1 *)
  val value: IntInf.int -> seed -> IntInf.int * seed;
end;

structure Random : RANDOM =
struct

open IntInf;

exception RANDOM;

type seed = int;

local
  (* Park-Miller "minimal standard" generator:
     s' = 16807 * s mod (2^31 - 1) *)
  val a = fromInt 16807;
    (*greetings to SML/NJ*)
  val m = valOf (fromString "2147483647");
in
  fun next s = (a * s) mod m;
end;

local
  val seed_ref = ref (fromInt 1);
in
  fun seed () =
    let
      val r = next (!seed_ref)
    in
      (seed_ref := r; r)
    end;
end;

(* Reduce modulo h itself, not h - 1. This makes the result range over
   [0, h), as required by the HOL-level random (see random_bound).
   Reducing modulo h - 1 raised Div for h = 1 and could never return h - 1.
   The successor seed is computed from s alone; this matches the lemma
   random_random_seed, which makes the next seed a function of s. *)
fun value h s =
  if h < 1 then raise RANDOM
  else (s mod h, next s);

end;
       
   176 *}
       
   177 
       
   178 code_type randseed
       
       (* Seeds are represented by Random.seed, i.e. IntInf.int, in generated
          SML code. *)
   179   (SML "Random.seed")
       
   180 
       
       (* Random.value must mirror the HOL random_int: for a positive bound it
          yields a number below that bound, paired with the next seed. *)
   181 code_const random_int
       
   182   (SML "Random.value")
       
   183 
       
       (* run_random f applies f to a fresh seed from Random.seed () and keeps
          only the first component. Presumably '_ escapes a literal wildcard
          in the generated pattern, while the bare _ stands for the argument
          f. *)
   184 code_const run_random
       
   185   (SML "case _ (Random.seed ()) of (x, '_) => x")
       
   186 
       
       (* Generate SML code for select and select_weight. The # target
          presumably places the code in the running ML environment rather
          than in a file (TODO confirm). *)
   187 code_gen select select_weight
       
   188   (SML #)
       
   189 
       
   190 end