src/HOL/Tools/Quickcheck/random_generators.ML
changeset 45763 3bb2bdf654f7
parent 45762 daf57640d4df
child 45923 473b744c23f2
--- a/src/HOL/Tools/Quickcheck/random_generators.ML	Mon Dec 05 12:36:19 2011 +0100
+++ b/src/HOL/Tools/Quickcheck/random_generators.ML	Mon Dec 05 12:36:20 2011 +0100
@@ -14,7 +14,7 @@
     Proof.context -> (term * term list) list -> bool -> int list -> (bool * term list) option * Quickcheck.report option
   val put_counterexample: (unit -> int -> bool -> int -> seed -> (bool * term list) option * seed)
     -> Proof.context -> Proof.context
-  val put_counterexample_report: (unit -> int -> int -> seed -> ((bool * term list) option * (bool list * bool)) * seed)
+  val put_counterexample_report: (unit -> int -> bool -> int -> seed -> ((bool * term list) option * (bool list * bool)) * seed)
     -> Proof.context -> Proof.context
   val setup: theory -> theory
 end;
@@ -284,7 +284,7 @@
 
 structure Counterexample_Report = Proof_Data
 (
-  type T = unit -> int -> int -> seed -> ((bool * term list) option * (bool list * bool)) * seed
+  type T = unit -> int -> bool -> int -> seed -> ((bool * term list) option * (bool list * bool)) * seed
   (* FIXME avoid user error with non-user text *)
   fun init _ () = error "Counterexample_Report"
 );
@@ -304,7 +304,8 @@
     val terms = HOLogic.mk_list @{typ term} (map (fn (_, i, _, _) => Bound i $ @{term "()"}) bounds);
     val ([genuine_only_name], ctxt') = Variable.variant_fixes ["genuine_only"] ctxt
     val genuine_only = Free (genuine_only_name, @{typ bool})
-    val check = Quickcheck_Common.mk_safe_if genuine_only (result, Const (@{const_name "None"}, resultT),
+    val none_t = Const (@{const_name "None"}, resultT)
+    val check = Quickcheck_Common.mk_safe_if genuine_only none_t (result, none_t,
       fn genuine => @{term "Some :: bool * term list => (bool * term list) option"} $
         HOLogic.mk_prod (Quickcheck_Common.reflect_bool genuine, terms))
     val return = HOLogic.pair_const resultT @{typ Random.seed};
@@ -344,19 +345,16 @@
         @{term False}))
     fun mk_concl_report b =
       HOLogic.mk_prod (HOLogic.mk_list HOLogic.boolT (replicate (length assms) @{term True}),
-        if b then @{term True} else @{term False})
-    val concl_check = mk_if (concl,
-      HOLogic.mk_prod (@{term "None :: (bool * term list) option"}, mk_concl_report true),
-      HOLogic.mk_prod (@{term "Some :: bool * term list => (bool * term list) option"} $
-        HOLogic.mk_prod (@{term True}, terms), mk_concl_report false))
-    val check = fold_rev (fn (i, assm) => fn t => mk_if (assm, t, mk_assms_report i))
+        Quickcheck_Common.reflect_bool b)
+    val ([genuine_only_name], ctxt') = Variable.variant_fixes ["genuine_only"] ctxt
+    val genuine_only = Free (genuine_only_name, @{typ bool})
+    val none_t = HOLogic.mk_prod (@{term "None :: (bool * term list) option"}, mk_concl_report true)
+    val concl_check = Quickcheck_Common.mk_safe_if genuine_only none_t (concl, none_t,
+      fn genuine => HOLogic.mk_prod (@{term "Some :: bool * term list => (bool * term list) option"} $
+        HOLogic.mk_prod (Quickcheck_Common.reflect_bool genuine, terms), mk_concl_report false))
+    val check = fold_rev (fn (i, assm) => fn t => Quickcheck_Common.mk_safe_if genuine_only
+      (mk_assms_report i) (HOLogic.mk_not assm, mk_assms_report i, t))
       (map_index I assms) concl_check
-    val check' = Const (@{const_name Quickcheck.catch_match}, resultT --> resultT --> resultT) $
-      check $ (if not (Config.get ctxt Quickcheck.genuine_only) then
-        HOLogic.mk_prod (@{term "Some :: bool * term list  => (bool * term list) option"} $
-          HOLogic.mk_prod (@{term False}, terms), mk_concl_report false)
-      else
-        HOLogic.mk_prod (@{term "None :: (bool * term list) option"}, mk_concl_report true))
     fun liftT T sT = sT --> HOLogic.mk_prodT (T, sT);
     fun mk_termtyp T = HOLogic.mk_prodT (T, @{typ "unit => term"});
     fun mk_scomp T1 T2 sT f g = Const (@{const_name scomp},
@@ -369,7 +367,8 @@
     fun mk_bindclause (_, _, i, T) = mk_scomp_split T
       (Sign.mk_const thy (@{const_name Quickcheck.random}, [T]) $ Bound i);
   in
-    Abs ("n", @{typ code_numeral}, fold_rev mk_bindclause bounds (return $ check'))
+    lambda genuine_only
+      (Abs ("n", @{typ code_numeral}, fold_rev mk_bindclause bounds (return $ check true)))
   end
 
 val mk_parametric_generator_expr = Quickcheck_Common.gen_mk_parametric_generator_expr 
@@ -380,9 +379,10 @@
 
 val mk_parametric_reporting_generator_expr = Quickcheck_Common.gen_mk_parametric_generator_expr 
   ((mk_reporting_generator_expr,
-    absdummy @{typ code_numeral}
-      @{term "Pair (None, ([], False)) :: Random.seed => ((bool * term list) option * (bool list * bool)) * Random.seed"}),
-    @{typ "code_numeral => Random.seed => ((bool * term list) option * (bool list * bool)) * Random.seed"})
+    absdummy @{typ bool} (absdummy @{typ code_numeral}
+      @{term "Pair (None, ([], False)) :: Random.seed =>
+        ((bool * term list) option * (bool list * bool)) * Random.seed"})),
+    @{typ "bool => code_numeral => Random.seed => ((bool * term list) option * (bool list * bool)) * Random.seed"})
     
     
 (* single quickcheck report *)
@@ -417,19 +417,21 @@
         val compile = Code_Runtime.dynamic_value_strict
           (Counterexample_Report.get, put_counterexample_report, "Random_Generators.put_counterexample_report")
           thy (SOME target)
-          (fn proc => fn g => fn c => fn s => g c s #>> (apfst o Option.map o apsnd o map) proc) t' [];
-        fun single_tester c s = compile c s |> Random_Engine.run
-        fun iterate_and_collect (card, size) 0 report = (NONE, report)
-          | iterate_and_collect (card, size) j report =
+          (fn proc => fn g => fn c => fn b => fn s => g c b s
+            #>> (apfst o Option.map o apsnd o map) proc) t' [];
+        fun single_tester c b s = compile c b s |> Random_Engine.run
+        fun iterate_and_collect _ (card, size) 0 report = (NONE, report)
+          | iterate_and_collect genuine_only (card, size) j report =
             let
-              val (test_result, single_report) = apsnd Run (single_tester card size)
+              val (test_result, single_report) = apsnd Run (single_tester card genuine_only size)
               val report = collect_single_report single_report report
             in
-              case test_result of NONE => iterate_and_collect (card, size) (j - 1) report
+              case test_result of NONE => iterate_and_collect genuine_only (card, size) (j - 1) report
                 | SOME q => (SOME q, report)
             end
       in
-        fn _ => fn [card, size] => apsnd SOME (iterate_and_collect (card, size) iterations empty_report)
+        fn genuine_only => fn [card, size] =>
+          apsnd SOME (iterate_and_collect genuine_only (card, size) iterations empty_report)
       end
     else
       let