src/HOL/ex/Random.thy
changeset 22528 8501c4a62a3c
child 22799 ed7d53db2170
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/ex/Random.thy	Tue Mar 27 12:28:42 2007 +0200
     1.3 @@ -0,0 +1,186 @@
     1.4 +(*  ID:         $Id$
     1.5 +    Author:     Florian Haftmann, TU Muenchen
     1.6 +*)
     1.7 +
     1.8 +header {* A simple random engine *}
     1.9 +
    1.10 +theory Random
    1.11 +imports State_Monad
    1.12 +begin
    1.13 +
    1.14 +fun
    1.15 +  pick :: "(nat \<times> 'a) list \<Rightarrow> nat \<Rightarrow> 'a"
    1.16 +where
    1.17 +  pick_undef: "pick [] n = undefined"
    1.18 +  | pick_simp: "pick ((k, v)#xs) n = (if n < k then v else pick xs (n - k))"
    1.19 +lemmas [code nofunc] = pick_undef
    1.20 +
    1.21 +typedecl randseed
    1.22 +
    1.23 +axiomatization
    1.24 +  random_shift :: "randseed \<Rightarrow> randseed"
    1.25 +
    1.26 +axiomatization
    1.27 +  random_seed :: "randseed \<Rightarrow> nat"
    1.28 +
    1.29 +definition
    1.30 +  random :: "nat \<Rightarrow> randseed \<Rightarrow> nat \<times> randseed" where
    1.31 +  "random n s = (random_seed s mod n, random_shift s)"
    1.32 +
    1.33 +lemma random_bound:
    1.34 +  assumes "0 < n"
    1.35 +  shows "fst (random n s) < n"
    1.36 +proof -
    1.37 +  from prems mod_less_divisor have "!!m .m mod n < n" by auto
    1.38 +  then show ?thesis unfolding random_def by simp 
    1.39 +qed
    1.40 +
    1.41 +lemma random_random_seed [simp]:
    1.42 +  "snd (random n s) = random_shift s" unfolding random_def by simp
    1.43 +
    1.44 +definition
    1.45 +  select :: "'a list \<Rightarrow> randseed \<Rightarrow> 'a \<times> randseed" where
    1.46 +  [simp]: "select xs = (do
    1.47 +      n \<leftarrow> random (length xs);
    1.48 +      return (nth xs n)
    1.49 +    done)"
    1.50 +definition
    1.51 +  select_weight :: "(nat \<times> 'a) list \<Rightarrow> randseed \<Rightarrow> 'a \<times> randseed" where
    1.52 +  [simp]: "select_weight xs = (do
    1.53 +      n \<leftarrow> random (foldl (op +) 0 (map fst xs));
    1.54 +      return (pick xs n)
    1.55 +    done)"
    1.56 +
    1.57 +lemma
    1.58 +  "select (x#xs) s = select_weight (map (Pair 1) (x#xs)) s"
    1.59 +proof (induct xs)
    1.60 +  case Nil show ?case by (simp add: monad_collapse random_def)
    1.61 +next
    1.62 +  have map_fst_Pair: "!!xs y. map fst (map (Pair y) xs) = replicate (length xs) y"
    1.63 +  proof -
    1.64 +    fix xs
    1.65 +    fix y
    1.66 +    show "map fst (map (Pair y) xs) = replicate (length xs) y"
    1.67 +      by (induct xs) simp_all
    1.68 +  qed
    1.69 +  have pick_nth: "!!xs n. n < length xs \<Longrightarrow> pick (map (Pair 1) xs) n = nth xs n"
    1.70 +  proof -
    1.71 +    fix xs
    1.72 +    fix n
    1.73 +    assume "n < length xs"
    1.74 +    then show "pick (map (Pair 1) xs) n = nth xs n"
    1.75 +    proof (induct xs arbitrary: n)
    1.76 +      case Nil then show ?case by simp
    1.77 +    next
    1.78 +      case (Cons x xs) show ?case
    1.79 +      proof (cases n)
    1.80 +        case 0 then show ?thesis by simp
    1.81 +      next
    1.82 +        case (Suc _)
    1.83 +    from Cons have "n < length (x # xs)" by auto
    1.84 +        then have "n < Suc (length xs)" by simp
    1.85 +        with Suc have "n - 1 < Suc (length xs) - 1" by auto
    1.86 +        with Cons have "pick (map (Pair (1\<Colon>nat)) xs) (n - 1) = xs ! (n - 1)" by auto
    1.87 +        with Suc show ?thesis by auto
    1.88 +      qed
    1.89 +    qed
    1.90 +  qed
    1.91 +  have sum_length: "!!xs. foldl (op +) 0 (map fst (map (Pair 1) xs)) = length xs"
    1.92 +  proof -
    1.93 +    have replicate_append:
    1.94 +      "!!x xs y. replicate (length (x # xs)) y = replicate (length xs) y @ [y]"
    1.95 +      by (simp add: replicate_app_Cons_same)
    1.96 +    fix xs
    1.97 +    show "foldl (op +) 0 (map fst (map (Pair 1) xs)) = length xs"
    1.98 +    unfolding map_fst_Pair proof (induct xs)
    1.99 +      case Nil show ?case by simp
   1.100 +    next
   1.101 +      case (Cons x xs) then show ?case unfolding replicate_append by simp
   1.102 +    qed
   1.103 +  qed
   1.104 +  have pick_nth_random:
   1.105 +    "!!x xs s. pick (map (Pair 1) (x#xs)) (fst (random (length (x#xs)) s)) = nth (x#xs) (fst (random (length (x#xs)) s))"
   1.106 +  proof -
   1.107 +    fix s
   1.108 +    fix x
   1.109 +    fix xs
   1.110 +    have bound: "fst (random (length (x#xs)) s) < length (x#xs)" by (rule random_bound) simp
   1.111 +    from pick_nth [OF bound] show
   1.112 +      "pick (map (Pair 1) (x#xs)) (fst (random (length (x#xs)) s)) = nth (x#xs) (fst (random (length (x#xs)) s))" .
   1.113 +  qed
   1.114 +  have pick_nth_random_do:
   1.115 +    "!!x xs s. (do n \<leftarrow> random (length (x#xs)); return (pick (map (Pair 1) (x#xs)) n) done) s =
   1.116 +      (do n \<leftarrow> random (length (x#xs)); return (nth (x#xs) n) done) s"
   1.117 +  unfolding monad_collapse split_def unfolding pick_nth_random ..
   1.118 +  case (Cons x xs) then show ?case
   1.119 +    unfolding select_weight_def sum_length pick_nth_random_do
   1.120 +    by simp
   1.121 +qed
   1.122 +
   1.123 +definition
   1.124 +  random_int :: "int \<Rightarrow> randseed \<Rightarrow> int * randseed" where
   1.125 +  "random_int k = (do n \<leftarrow> random (nat k); return (int n) done)"
   1.126 +
   1.127 +lemma random_nat [code]:
   1.128 +  "random n = (do k \<leftarrow> random_int (int n); return (nat k) done)"
   1.129 +unfolding random_int_def by simp
   1.130 +
   1.131 +axiomatization
   1.132 +  run_random :: "(randseed \<Rightarrow> 'a * randseed) \<Rightarrow> 'a"
   1.133 +
   1.134 +ML {*
   1.135 +signature RANDOM =
   1.136 +sig
   1.137 +  type seed = IntInf.int;
   1.138 +  val seed: unit -> seed;
   1.139 +  val value: IntInf.int -> seed -> IntInf.int * seed;
   1.140 +end;
   1.141 +
   1.142 +structure Random : RANDOM =
   1.143 +struct
   1.144 +
   1.145 +open IntInf;
   1.146 +
   1.147 +exception RANDOM;
   1.148 +
   1.149 +type seed = int;
   1.150 +
   1.151 +local
   1.152 +  val a = fromInt 16807;
   1.153 +    (*greetings to SML/NJ*)
   1.154 +  val m = (the o fromString) "2147483647";
   1.155 +in
   1.156 +  fun next s = (a * s) mod m;
   1.157 +end;
   1.158 +
   1.159 +local
   1.160 +  val seed_ref = ref (fromInt 1);
   1.161 +in
   1.162 +  fun seed () =
   1.163 +    let
   1.164 +      val r = next (!seed_ref)
   1.165 +    in
   1.166 +      (seed_ref := r; r)
   1.167 +    end;
   1.168 +end;
   1.169 +
   1.170 +fun value h s =
   1.171 +  if h < 1 then raise RANDOM
   1.172 +  else (s mod (h - 1), seed ());
   1.173 +
   1.174 +end;
   1.175 +*}
   1.176 +
   1.177 +code_type randseed
   1.178 +  (SML "Random.seed")
   1.179 +
   1.180 +code_const random_int
   1.181 +  (SML "Random.value")
   1.182 +
   1.183 +code_const run_random
   1.184 +  (SML "case _ (Random.seed ()) of (x, '_) => x")
   1.185 +
   1.186 +code_gen select select_weight
   1.187 +  (SML #)
   1.188 +
   1.189 +end