quickcheck with term reconstruction
authorhaftmann
Wed, 19 Mar 2008 07:20:29 +0100
changeset 26325 6ecae5c8175b
parent 26324 456f726a11e4
child 26326 a68045977f60
quickcheck with term reconstruction
src/HOL/ex/Quickcheck.thy
--- a/src/HOL/ex/Quickcheck.thy	Wed Mar 19 07:20:28 2008 +0100
+++ b/src/HOL/ex/Quickcheck.thy	Wed Mar 19 07:20:29 2008 +0100
@@ -10,16 +10,16 @@
 
 subsection {* The @{text random} class *}
 
-class random = type +
-  fixes random :: "index \<Rightarrow> seed \<Rightarrow> 'a \<times> seed"
+class random = rtype +
+  fixes random :: "index \<Rightarrow> seed \<Rightarrow> ('a \<times> (unit \<Rightarrow> term)) \<times> seed"
 
 text {* Type @{typ "'a itself"} *}
 
-instantiation itself :: (type) random
+instantiation itself :: ("{type, rtype}") random
 begin
 
 definition
-  "random _ = return TYPE('a)"
+  "random _ = return (TYPE('a), \<lambda>u. Eval.Const (STR ''TYPE'') RTYPE('a))"
 
 instance ..
 
@@ -28,7 +28,7 @@
 text {* Datatypes *}
 
 lemma random'_if:
-  fixes random' :: "index \<Rightarrow> index \<Rightarrow> seed \<Rightarrow> 'a \<times> seed"
+  fixes random' :: "index \<Rightarrow> index \<Rightarrow> seed \<Rightarrow> ('a \<times> (unit \<Rightarrow> term)) \<times> seed"
   assumes "random' 0 j = undefined"
     and "\<And>i. random' (Suc_index i) j = rhs2 i"
   shows "random' i j s = (if i = 0 then undefined else rhs2 (i - 1) s)"
@@ -39,38 +39,48 @@
   exception REC of string;
   fun mk_collapse thy ty = Sign.mk_const thy
     (@{const_name collapse}, [@{typ seed}, ty]);
-  fun mk_cons this_ty (c, args) =
+  fun term_ty ty = HOLogic.mk_prodT (ty, @{typ "unit \<Rightarrow> term"});
+  fun mk_split thy ty ty' = Sign.mk_const thy
+    (@{const_name split}, [ty, @{typ "unit \<Rightarrow> term"}, StateMonad.liftT (term_ty ty') @{typ seed}]);
+  fun mk_mbind_split thy ty ty' t t' =
+    StateMonad.mbind (term_ty ty) (term_ty ty') @{typ seed} t
+      (mk_split thy ty ty' $ Abs ("", ty, Abs ("", @{typ "unit \<Rightarrow> term"}, t')))
+  fun mk_cons thy this_ty (c, args) =
     let
       val tys = map (fst o fst) args;
-      val return = StateMonad.return this_ty @{typ seed}
-        (list_comb (Const (c, tys ---> this_ty),
-           map Bound (length tys - 1 downto 0)));
-      val t = fold_rev (fn ((ty, _), random) => fn t =>
-        StateMonad.mbind ty this_ty @{typ seed} random (Abs ("", ty, t)))
+      val c_ty = tys ---> this_ty;
+      val c = Const (c, tys ---> this_ty);
+      val t_indices = map (curry ( op * ) 2) (length tys - 1 downto 0);
+      val c_indices = map (curry ( op + ) 1) t_indices;
+      val c_t = list_comb (c, map Bound c_indices);
+      val t_t = Abs ("", @{typ unit}, Eval.mk_term Free RType.rtype
+        (list_comb (c, map (fn k => Bound (k + 1)) t_indices))
+        |> map_aterms (fn t as Bound _ => t $ @{term "()"} | t => t));
+      val return = StateMonad.return (term_ty this_ty) @{typ seed}
+        (HOLogic.mk_prod (c_t, t_t));
+      val t = fold_rev (fn ((ty, _), random) =>
+        mk_mbind_split thy ty this_ty random)
           args return;
       val is_rec = exists (snd o fst) args;
-    in (is_rec, StateMonad.run this_ty @{typ seed} t) end;
+    in (is_rec, StateMonad.run (term_ty this_ty) @{typ seed} t) end;
   fun mk_conss thy ty [] = NONE
     | mk_conss thy ty [(_, t)] = SOME t
-    | mk_conss thy ty ts = SOME (mk_collapse thy ty $
-          (Sign.mk_const thy (@{const_name select}, [StateMonad.liftT ty @{typ seed}]) $
-            HOLogic.mk_list (StateMonad.liftT ty @{typ seed}) (map snd ts)));
+    | mk_conss thy ty ts = SOME (mk_collapse thy (term_ty ty) $
+          (Sign.mk_const thy (@{const_name select}, [StateMonad.liftT (term_ty ty) @{typ seed}]) $
+            HOLogic.mk_list (StateMonad.liftT (term_ty ty) @{typ seed}) (map snd ts)));
   fun mk_clauses thy ty (tyco, (ts_rec, ts_atom)) = 
     let
       val SOME t_atom = mk_conss thy ty ts_atom;
     in case mk_conss thy ty ts_rec
-     of SOME t_rec => mk_collapse thy ty $
-          (Sign.mk_const thy (@{const_name select_default}, [StateMonad.liftT ty @{typ seed}]) $
+     of SOME t_rec => mk_collapse thy (term_ty ty) $
+          (Sign.mk_const thy (@{const_name select_default}, [StateMonad.liftT (term_ty ty) @{typ seed}]) $
              @{term "i\<Colon>index"} $ t_rec $ t_atom)
       | NONE => t_atom
     end;
-  fun mk_random_eqs thy tycos =
+  fun mk_random_eqs thy vs tycos =
     let
-      val (raw_vs, _) = DatatypePackage.the_datatype_spec thy (hd tycos);
-      val vs = (map o apsnd)
-        (curry (Sorts.inter_sort (Sign.classes_of thy)) @{sort random}) raw_vs;
       val this_ty = Type (hd tycos, map TFree vs);
-      val this_ty' = StateMonad.liftT this_ty @{typ seed};
+      val this_ty' = StateMonad.liftT (term_ty this_ty) @{typ seed};
       val random_name = NameSpace.base @{const_name random};
       val random'_name = random_name ^ "_" ^ Class.type_name (hd tycos) ^ "'";
       fun random ty = Sign.mk_const thy (@{const_name random}, [ty]);
@@ -82,7 +92,7 @@
         ("Will not generate random elements for mutual recursive type " ^ quote (hd tycos));
       val rhss = DatatypePackage.construction_interpretation thy
             { atom = atom, dtyp = dtyp, rtyp = rtyp } vs tycos
-        |> (map o apsnd o map) (mk_cons this_ty) 
+        |> (map o apsnd o map) (mk_cons thy this_ty) 
         |> (map o apsnd) (List.partition fst)
         |> map (mk_clauses thy this_ty)
       val eqss = map ((apsnd o map) (HOLogic.mk_Trueprop o HOLogic.mk_eq) o (fn rhs => ((this_ty, random'), [
@@ -96,7 +106,7 @@
           val vs = (map o apsnd)
             (curry (Sorts.inter_sort (Sign.classes_of thy)) @{sort random}) raw_vs;
           val { descr, index, ... } = DatatypePackage.the_datatype thy tyco;
-          val ((this_ty, random'), eqs') = singleton (mk_random_eqs thy) tyco;
+          val ((this_ty, random'), eqs') = singleton (mk_random_eqs thy vs) tyco;
           val eq = (HOLogic.mk_Trueprop o HOLogic.mk_eq)
             (Sign.mk_const thy (@{const_name random}, [this_ty]) $ @{term "i\<Colon>index"},
                random' $ @{term "i\<Colon>index"} $ @{term "i\<Colon>index"})
@@ -142,8 +152,11 @@
 
 definition
   "random n = (do
-     (b, m) \<leftarrow> random n;
-     return (if b then int m else - int m)
+     (b, _) \<leftarrow> random n;
+     (m, t) \<leftarrow> random n;
+     return (if b then (int m, \<lambda>u. Eval.App (Eval.Const (STR ''Int.int'') RTYPE(nat \<Rightarrow> int)) (t ()))
+       else (- int m, \<lambda>u. Eval.App (Eval.Const (STR ''HOL.uminus_class.uminus'') RTYPE(int \<Rightarrow> int))
+         (Eval.App (Eval.Const (STR ''Int.int'') RTYPE(nat \<Rightarrow> int)) (t ()))))
    done)"
 
 instance ..
@@ -152,19 +165,23 @@
 
 text {* Type @{typ "'a set"} *}
 
-instantiation set :: (random) random
+instantiation set :: ("{random, type}") random
 begin
 
-primrec random_set' :: "index \<Rightarrow> index \<Rightarrow> seed \<Rightarrow> 'a set \<times> seed" where
+primrec random_set' :: "index \<Rightarrow> index \<Rightarrow> seed \<Rightarrow> ('a\<Colon>{random, type} set \<times> (unit \<Rightarrow> term)) \<times> seed" where
   "random_set' 0 j = undefined"
   | "random_set' (Suc_index i) j = collapse (select_default i
-       (do x \<leftarrow> random i; xs \<leftarrow> random_set' i j; return (insert x xs) done)
-       (return {}))"
+       (do (x, t) \<leftarrow> random i;
+           (xs, ts) \<leftarrow> random_set' i j;
+           return (insert x xs, \<lambda>u. Eval.App (Eval.App (Eval.Const (STR ''insert'') RTYPE('a \<Rightarrow> 'a set \<Rightarrow> 'a set)) (t ())) (ts ())) done)
+       (return ({}, \<lambda>u. Eval.Const (STR ''{}'') RTYPE('a set))))"
 
 lemma random_set'_code [code func]:
   "random_set' i j s = (if i = 0 then undefined else collapse (select_default (i - 1)
-       (do x \<leftarrow> random (i - 1); xs \<leftarrow> random_set' (i - 1) j; return (insert x xs) done)
-       (return {})) s)"
+       (do (x \<Colon> 'a\<Colon>{random, type}, t) \<leftarrow> random (i - 1);
+           (xs, ts) \<leftarrow> random_set' (i - 1) j;
+           return (insert x xs, \<lambda>u. Eval.App (Eval.App (Eval.Const (STR ''insert'') RTYPE('a \<Rightarrow> 'a set \<Rightarrow> 'a set)) (t ())) (ts ())) done)
+       (return ({}, \<lambda>u. Eval.Const (STR ''{}'') RTYPE('a set)))) s)"
   by (rule random'_if random_set'.simps)+
 
 definition
@@ -182,49 +199,46 @@
 
 open Random_Engine;
 
-fun random_fun (eq : 'a -> 'a -> bool)
-    (random : Random_Engine.seed -> 'b * Random_Engine.seed)
+fun random_fun (T1 : typ) (T2 : typ) (eq : 'a -> 'a -> bool) (term_of : 'a -> term)
+    (random : Random_Engine.seed -> ('b * (unit -> term)) * Random_Engine.seed)
     (random_split : Random_Engine.seed -> Random_Engine.seed * Random_Engine.seed)
     (seed : Random_Engine.seed) =
   let
     val (seed', seed'') = random_split seed;
-    val state = ref (seed', []);
+    val state = ref (seed', [], Const (@{const_name arbitrary}, T1 --> T2));
+    val fun_upd = Const (@{const_name fun_upd},
+      (T1 --> T2) --> T1 --> T2 --> T1 --> T2);
     fun random_fun' x =
       let
-        val (seed, fun_map) = ! state;
+        val (seed, fun_map, f_t) = ! state;
       in case AList.lookup (uncurry eq) fun_map x
        of SOME y => y
         | NONE => let
-              val (y, seed') = random seed;
-              val _ = state := (seed', (x, y) :: fun_map);
+              val t1 = term_of x;
+              val ((y, t2), seed') = random seed;
+              val fun_map' = (x, y) :: fun_map;
+              val f_t' = fun_upd $ f_t $ t1 $ t2 ();
+              val _ = state := (seed', fun_map', f_t');
             in y end
       end;
-  in (random_fun', seed'') end;
+    fun term_fun' () = #3 (! state);
+  in ((random_fun', term_fun'), seed'') end;
 
 end
 *}
 
 axiomatization
-  random_fun_aux :: "('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> (seed \<Rightarrow> 'b \<times> seed)
-    \<Rightarrow> (seed \<Rightarrow> seed \<times> seed) \<Rightarrow> seed \<Rightarrow> ('a \<Rightarrow> 'b) \<times> seed"
+  random_fun_aux :: "rtype \<Rightarrow> rtype \<Rightarrow> ('a \<Rightarrow> 'a \<Rightarrow> bool) \<Rightarrow> ('a \<Rightarrow> term)
+    \<Rightarrow> (seed \<Rightarrow> ('b \<times> (unit \<Rightarrow> term)) \<times> seed) \<Rightarrow> (seed \<Rightarrow> seed \<times> seed)
+    \<Rightarrow> seed \<Rightarrow> (('a \<Rightarrow> 'b) \<times> (unit \<Rightarrow> term)) \<times> seed"
 
 code_const random_fun_aux (SML "Random'_Engine.random'_fun")
 
-instantiation "fun" :: (term_of, term_of) term_of
+instantiation "fun" :: ("{eq, term_of}", "{type, random}") random
 begin
 
-instance ..
-
-end
-
-code_const "Eval.term_of :: ('a\<Colon>term_of \<Rightarrow> 'b\<Colon>term_of) \<Rightarrow> _"
-  (SML "(fn '_ => Const (\"arbitrary\", dummyT))")
-
-instantiation "fun" :: (eq, "{type, random}") random
-begin
-
-definition
-  "random n = random_fun_aux (op =) (random n) split_seed"
+definition random_fun :: "index \<Rightarrow> seed \<Rightarrow> (('a \<Rightarrow> 'b) \<times> (unit \<Rightarrow> term)) \<times> seed" where
+  "random n = random_fun_aux RTYPE('a) RTYPE('b) (op =) Eval.term_of (random n) split_seed"
 
 instance ..
 
@@ -241,31 +255,36 @@
 
 val eval_ref : (unit -> int -> int * int -> term list option * (int * int)) option ref = ref NONE;
 
-fun mk_generator_expr prop tys =
+fun mk_generator_expr thy prop tys =
   let
-    val bounds = map_index (fn (i, ty) => (i, ty)) tys;
-    val result = list_comb (prop, map (fn (i, _) => Bound (length tys - i - 1)) bounds);
-    val terms = map (fn (i, ty) => Const (@{const_name Eval.term_of}, ty --> @{typ term}) $ Bound (length tys - i - 1)) bounds;
+    val bound_max = length tys - 1;
+    val bounds = map_index (fn (i, ty) =>
+      (2 * (bound_max - i) + 1, 2 * (bound_max - i), 2 * i, ty)) tys;
+    val result = list_comb (prop, map (fn (i, _, _, _) => Bound i) bounds);
+    val terms = HOLogic.mk_list @{typ term} (map (fn (_, i, _, _) => Bound i $ @{term "()"}) bounds);
     val check = @{term "If \<Colon> bool \<Rightarrow> term list option \<Rightarrow> term list option \<Rightarrow> term list option"}
-      $ result $ @{term "None \<Colon> term list option"} $ (@{term "Some \<Colon> term list \<Rightarrow> term list option "} $ HOLogic.mk_list @{typ term} terms);
+      $ result $ @{term "None \<Colon> term list option"} $ (@{term "Some \<Colon> term list \<Rightarrow> term list option "} $ terms);
     val return = @{term "Pair \<Colon> term list option \<Rightarrow> seed \<Rightarrow> term list option \<times> seed"};
-    fun mk_bindtyp ty = @{typ seed} --> HOLogic.mk_prodT (ty, @{typ seed});
-    fun mk_bindclause (i, ty) t = Const (@{const_name mbind}, mk_bindtyp ty
-      --> (ty --> mk_bindtyp @{typ "term list option"}) --> mk_bindtyp @{typ "term list option"})
-      $ (Const (@{const_name random}, @{typ index} --> mk_bindtyp ty)
-        $ Bound i) $ Abs ("a", ty, t);
+    fun mk_termtyp ty = HOLogic.mk_prodT (ty, @{typ "unit \<Rightarrow> term"});
+    fun mk_split ty = Sign.mk_const thy
+      (@{const_name split}, [ty, @{typ "unit \<Rightarrow> term"}, StateMonad.liftT @{typ "term list option"} @{typ seed}]);
+    fun mk_mbind_split ty t t' =
+      StateMonad.mbind (mk_termtyp ty) @{typ "term list option"} @{typ seed} t (*FIXME*)
+        (mk_split ty $ Abs ("", ty, Abs ("", @{typ "unit \<Rightarrow> term"}, t')));
+    fun mk_bindclause (_, _, i, ty) = mk_mbind_split ty
+      (Sign.mk_const thy (@{const_name random}, [ty]) $ Bound i)
     val t = fold_rev mk_bindclause bounds (return $ check);
   in Abs ("n", @{typ index}, t) end;
 
 fun compile_generator_expr thy prop tys =
   let
     val f = CodePackage.eval_term ("Quickcheck.eval_ref", eval_ref) thy
-      (mk_generator_expr prop tys) [];
+      (mk_generator_expr thy prop tys) [];
   in f #> Random_Engine.run #> (Option.map o map) (Code.postprocess_term thy) end;
 
 fun VALUE prop tys thy =
   let
-    val t = mk_generator_expr prop tys;
+    val t = mk_generator_expr thy prop tys;
     val eq = Logic.mk_equals (Free ("VALUE", fastype_of t), t)
   in
     thy
@@ -279,18 +298,24 @@
 end
 *}
 
-
 subsection {* Examples *}
 
-ML {* Quickcheck.mk_generator_expr
-  @{term "\<lambda>(n::nat) (m::nat) (q::nat). n = m + q + 1"} [@{typ nat}, @{typ nat}, @{typ nat}]
-|> Sign.string_of_term @{theory} *}
+(*export_code "random :: index \<Rightarrow> seed \<Rightarrow> ((_ \<Rightarrow> _) \<times> (unit \<Rightarrow> term)) \<times> seed"
+  in SML file -*)
 
-(*setup {* Quickcheck.VALUE @{term "\<lambda>(n::nat) (m::nat) (q::nat). n = m + q + 1"} [@{typ nat}, @{typ nat}, @{typ nat}] *}
+(*setup {* Quickcheck.VALUE
+  @{term "\<lambda>f k. int (f k) = k"} [@{typ "int \<Rightarrow> nat"}, @{typ int}] *}
+
 export_code VALUE in SML module_name QuickcheckExample file "~~/../../gen_code/quickcheck.ML"
 use "~~/../../gen_code/quickcheck.ML"
 ML {* Random_Engine.run (QuickcheckExample.range 1) *}*)
 
+(*definition "FOO = (True, Suc 0)"
+
+code_module (test) QuickcheckExample
+  file "~~/../../gen_code/quickcheck'.ML"
+  contains FOO*)
+
 ML {* val f = Quickcheck.compile_generator_expr @{theory}
   @{term "\<lambda>(n::nat) (m::nat) (q::nat). n = m + q + 1"} [@{typ nat}, @{typ nat}, @{typ nat}] *}
 
@@ -309,12 +334,6 @@
 ML {* val f = Quickcheck.compile_generator_expr @{theory}
   @{term "\<lambda>(n::int) (m::int) (q::int). n = m + q + 1"} [@{typ int}, @{typ int}, @{typ int}] *}
 
-(*definition "FOO = (True, Suc 0)"
-
-code_module (test) QuickcheckExample
-  file "~~/../../gen_code/quickcheck'.ML"
-  contains FOO*)
-
 ML {* f 5 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
 ML {* f 5 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
 ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
@@ -353,17 +372,6 @@
 ML {* f 8 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
 ML {* f 88 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
 
-
-ML {* val f = Quickcheck.compile_generator_expr @{theory}
-  @{term "\<lambda>f k. int (f k) = k"} [@{typ "int \<Rightarrow> nat"}, @{typ int}] *}
-
-ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
-ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
-ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
-ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
-ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
-ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
-
 ML {* val f = Quickcheck.compile_generator_expr @{theory}
   @{term "\<lambda>(A \<Colon> nat set) B. card (A \<union> B) = card A + card B"} [@{typ "nat set"}, @{typ "nat set"}] *}
 
@@ -391,4 +399,14 @@
 ML {* f 8 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
 ML {* f 8 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
 
+ML {* val f = Quickcheck.compile_generator_expr @{theory}
+  @{term "\<lambda>f k. int (f k) = k"} [@{typ "int \<Rightarrow> nat"}, @{typ int}] *}
+
+ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
+ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
+ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
+ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
+ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
+ML {* f 20 |> (Option.map o map) (Sign.string_of_term @{theory}) *}
+
 end