(*  ID:         $Id$
    Author:     Florian Haftmann, TU Muenchen
*)

header {* A simple random engine *}

theory Random
imports State_Monad
begin

(* pick xs n walks a list of weighted entries (k, v): it returns the first v
   whose cumulative weight k_1 + ... + k_i exceeds n, subtracting the weight
   of every entry it skips.  On the empty list the result is unspecified. *)
fun
  pick :: "(nat \<times> 'a) list \<Rightarrow> nat \<Rightarrow> 'a"
where
  pick_undef: "pick [] n = undefined"
  | pick_simp: "pick ((k, v)#xs) n = (if n < k then v else pick xs (n - k))"

(* The unspecified empty-list case is not used as a code equation. *)
lemmas [code nofunc] = pick_undef

(* Opaque seed type; its two observers below are left unspecified logically. *)
typedecl randseed

axiomatization
  random_shift :: "randseed \<Rightarrow> randseed"

axiomatization
  random_seed :: "randseed \<Rightarrow> nat"

(* random n s yields a number derived from seed s reduced modulo n (hence
   below n whenever 0 < n, see random_bound), together with the shifted seed. *)
definition
  random :: "nat \<Rightarrow> randseed \<Rightarrow> nat \<times> randseed" where
  "random n s = (random_seed s mod n, random_shift s)"

lemma random_bound:
  assumes n: "0 < n"
  shows "fst (random n s) < n"
proof -
  from n mod_less_divisor have "!!m. m mod n < n" by auto
  then show ?thesis unfolding random_def by simp
qed

lemma random_random_seed [simp]:
  "snd (random n s) = random_shift s" unfolding random_def by simp

(* select xs returns the element of xs at a random index below length xs;
   the index is only in range when xs is non-empty. *)
definition
  select :: "'a list \<Rightarrow> randseed \<Rightarrow> 'a \<times> randseed" where
  [simp]: "select xs = (do
      n \<leftarrow> random (length xs);
      return (nth xs n)
    done)"

(* select_weight xs draws a number below the total weight of xs and lets
   pick choose the corresponding weighted entry. *)
definition
  select_weight :: "(nat \<times> 'a) list \<Rightarrow> randseed \<Rightarrow> 'a \<times> randseed" where
  [simp]: "select_weight xs = (do
      n \<leftarrow> random (foldl (op +) 0 (map fst xs));
      return (pick xs n)
    done)"
(* For a non-empty list, select agrees with select_weight when every element
   is given weight 1. *)
lemma
  "select (x#xs) s = select_weight (map (Pair 1) (x#xs)) s"
proof (induct xs)
  case Nil show ?case by (simp add: monad_collapse random_def)
next
  have map_fst_Pair: "!!xs y. map fst (map (Pair y) xs) = replicate (length xs) y"
  proof -
    fix xs
    fix y
    show "map fst (map (Pair y) xs) = replicate (length xs) y"
      by (induct xs) simp_all
  qed
  have pick_nth: "!!xs n. n < length xs \<Longrightarrow> pick (map (Pair 1) xs) n = nth xs n"
  proof -
    fix xs
    fix n
    assume "n < length xs"
    then show "pick (map (Pair 1) xs) n = nth xs n"
    proof (induct xs arbitrary: n)
      case Nil then show ?case by simp
    next
      case (Cons x xs) show ?case
      proof (cases n)
        case 0 then show ?thesis by simp
      next
        case (Suc _)
        from Cons have "n < length (x # xs)" by auto
        then have "n < Suc (length xs)" by simp
        with Suc have "n - 1 < Suc (length xs) - 1" by auto
        with Cons have "pick (map (Pair (1\<Colon>nat)) xs) (n - 1) = xs ! (n - 1)" by auto
        with Suc show ?thesis by auto
      qed
    qed
  qed
  have sum_length: "!!xs. foldl (op +) 0 (map fst (map (Pair 1) xs)) = length xs"
  proof -
    have replicate_append:
      "!!x xs y. replicate (length (x # xs)) y = replicate (length xs) y @ [y]"
      by (simp add: replicate_app_Cons_same)
    fix xs
    show "foldl (op +) 0 (map fst (map (Pair 1) xs)) = length xs"
    unfolding map_fst_Pair proof (induct xs)
      case Nil show ?case by simp
    next
      case (Cons x xs) then show ?case unfolding replicate_append by simp
    qed
  qed
  have pick_nth_random:
    "!!x xs s. pick (map (Pair 1) (x#xs)) (fst (random (length (x#xs)) s)) = nth (x#xs) (fst (random (length (x#xs)) s))"
  proof -
    fix s
    fix x
    fix xs
    have bound: "fst (random (length (x#xs)) s) < length (x#xs)" by (rule random_bound) simp
    from pick_nth [OF bound] show
      "pick (map (Pair 1) (x#xs)) (fst (random (length (x#xs)) s)) = nth (x#xs) (fst (random (length (x#xs)) s))" .
  qed
  have pick_nth_random_do:
    "!!x xs s. (do n \<leftarrow> random (length (x#xs)); return (pick (map (Pair 1) (x#xs)) n) done) s =
      (do n \<leftarrow> random (length (x#xs)); return (nth (x#xs) n) done) s"
  unfolding monad_collapse split_def unfolding pick_nth_random ..
  case (Cons x xs) then show ?case
    unfolding select_weight_def sum_length pick_nth_random_do
    by simp
qed

(* Integer-valued variant of random; random_nat below re-expresses random
   through it so that code generation only needs an implementation of
   random_int. *)
definition
  random_int :: "int \<Rightarrow> randseed \<Rightarrow> int * randseed" where
  "random_int k = (do n \<leftarrow> random (nat k); return (int n) done)"

lemma random_nat [code]:
  "random n = (do k \<leftarrow> random_int (int n); return (nat k) done)"
  unfolding random_int_def by simp

axiomatization
  run_random :: "(randseed \<Rightarrow> 'a * randseed) \<Rightarrow> 'a"

ML {*
signature RANDOM =
sig
  type seed = IntInf.int;
  val seed: unit -> seed;
  val value: IntInf.int -> seed -> IntInf.int * seed;
end;

structure Random : RANDOM =
struct

open IntInf;

exception RANDOM;

type seed = int;

(* Multiplicative congruential step with a = 16807, m = 2^31 - 1. *)
local
  val a = fromInt 16807;
    (*greetings to SML/NJ*)
  val m = (the o fromString) "2147483647";
in
  fun next s = (a * s) mod m;
end;

(* seed () advances a global state held in seed_ref and returns the new value. *)
local
  val seed_ref = ref (fromInt 1);
in
  fun seed () =
    let
      val r = next (!seed_ref)
    in
      (seed_ref := r; r)
    end;
end;

(* value h s returns a number in [0, h) derived from s, paired with a fresh
   seed; raises RANDOM if h < 1.  The divisor must be h itself to agree with
   random_def ("random_seed s mod n"): h = 1 then yields 0 instead of a
   division by zero, and every residue below h, including h - 1, is
   reachable. *)
fun value h s =
  if h < 1 then raise RANDOM
  else (s mod h, seed ());

end;
*}

code_type randseed
  (SML "Random.seed")
(* Code generator setup.
   run_random applies a randomised computation to a fresh seed obtained from
   Random.seed () and keeps only its result component. *)
code_const run_random
  (SML "case _ (Random.seed ()) of (x, '_) => x")

(* random_int is implemented directly by the SML function Random.value. *)
code_const random_int
  (SML "Random.value")

(* Generate SML code for the two selection functions. *)
code_gen select select_weight
  (SML #)

end