src/HOL/Tools/Quickcheck/quickcheck_common.ML
changeset 46331 f5598b604a54
parent 46327 ecda23528833
child 46478 cf1bcfb34c82
--- a/src/HOL/Tools/Quickcheck/quickcheck_common.ML	Wed Jan 25 09:50:34 2012 +0100
+++ b/src/HOL/Tools/Quickcheck/quickcheck_common.ML	Wed Jan 25 15:19:04 2012 +0100
@@ -16,11 +16,12 @@
     -> (typ option * (term * term list)) list list
   val mk_safe_if : term -> term -> term * term * (bool -> term) -> bool -> term
   val collect_results : ('a -> Quickcheck.result) -> 'a list -> Quickcheck.result list -> Quickcheck.result list
-  type compile_generator =
-    Proof.context -> (term * term list) list -> bool -> int list -> (bool * term list) option * Quickcheck.report option
+  type result = (bool * term list) option * Quickcheck.report option
+  type generator = string * ((theory -> typ list -> bool) * 
+      (Proof.context -> (term * term list) list -> bool -> int list -> result))
   val generator_test_goal_terms :
-    string * compile_generator -> Proof.context -> bool -> (string * typ) list
-    -> (term * term list) list -> Quickcheck.result list
+    generator -> Proof.context -> bool -> (string * typ) list
+      -> (term * term list) list -> Quickcheck.result list
   type instantiation = Datatype.config -> Datatype.descr -> (string * sort) list
      -> string list -> string -> string list * string list -> typ list * typ list -> theory -> theory
   val ensure_sort :
@@ -36,8 +37,7 @@
      -> Proof.context -> (term * term list) list -> term
   val mk_fun_upd : typ -> typ -> term * term -> term -> term
   val post_process_term : term -> term
-  val test_term : string * compile_generator
-    -> Proof.context -> bool -> term * term list -> Quickcheck.result
+  val test_term : generator -> Proof.context -> bool -> term * term list -> Quickcheck.result
 end;
 
 structure Quickcheck_Common : QUICKCHECK_COMMON =
@@ -58,8 +58,9 @@
 
 (* testing functions: testing with increasing sizes (and cardinalities) *)
 
-type compile_generator =
-  Proof.context -> (term * term list) list -> bool -> int list -> (bool * term list) option * Quickcheck.report option
+type result = (bool * term list) option * Quickcheck.report option
+type generator = string * ((theory -> typ list -> bool) * 
+      (Proof.context -> (term * term list) list -> bool -> int list -> result))
 
 fun check_test_term t =
   let
@@ -73,7 +74,7 @@
   let val ({cpu, ...}, result) = Timing.timing e ()
   in (result, (description, Time.toMilliseconds cpu)) end
 
-fun test_term (name, compile) ctxt catch_code_errors (t, eval_terms) =
+fun test_term (name, (_, compile)) ctxt catch_code_errors (t, eval_terms) =
   let
     val genuine_only = Config.get ctxt Quickcheck.genuine_only
     val _ = check_test_term t
@@ -165,7 +166,7 @@
       [comp_time, exec_time])
   end
 
-fun test_term_with_cardinality (name, compile) ctxt catch_code_errors ts =
+fun test_term_with_cardinality (name, (size_matters_for, compile)) ctxt catch_code_errors ts =
   let
     val genuine_only = Config.get ctxt Quickcheck.genuine_only
     val thy = Proof_Context.theory_of ctxt
@@ -189,13 +190,11 @@
         Option.map (pair (card, size)) ts
       end
     val enumeration_card_size =
-      if forall (fn T => Sign.of_sort thy (T,  ["Enum.enum"])) Ts then
-        (* size does not matter *)
-        map (rpair 0) (1 upto (length ts))
-      else
-        (* size does matter *)
+      if size_matters_for thy Ts then
         map_product pair (1 upto (length ts)) (1 upto (Config.get ctxt Quickcheck.size))
         |> sort (fn ((c1, s1), (c2, s2)) => int_ord ((c1 + s1), (c2 + s2)))
+      else
+        map (rpair 0) (1 upto (length ts))
     val act = if catch_code_errors then try else (fn f => SOME o f)
     val (test_fun, comp_time) = cpu_time "quickcheck compilation" (fn () => act (compile ctxt) ts)
     val _ = Quickcheck.add_timing comp_time current_result
@@ -325,7 +324,7 @@
         collect_results f ts (result :: results)
     end  
 
-fun generator_test_goal_terms (name, compile) ctxt catch_code_errors insts goals =
+fun generator_test_goal_terms generator ctxt catch_code_errors insts goals =
   let
     fun add_eval_term t ts = if is_Free t then ts else ts @ [t]
     fun add_equation_eval_terms (t, eval_terms) =
@@ -334,15 +333,15 @@
       | NONE => (t, eval_terms)
     fun test_term' goal =
       case goal of
-        [(NONE, t)] => test_term (name, compile) ctxt catch_code_errors t
-      | ts => test_term_with_cardinality (name, compile) ctxt catch_code_errors (map snd ts)
+        [(NONE, t)] => test_term generator ctxt catch_code_errors t
+      | ts => test_term_with_cardinality generator ctxt catch_code_errors (map snd ts)
     val goals' = instantiate_goals ctxt insts goals
       |> map (map (apsnd add_equation_eval_terms))
   in
     if Config.get ctxt Quickcheck.finite_types then
       collect_results test_term' goals' []
     else
-      collect_results (test_term (name, compile) ctxt catch_code_errors)
+      collect_results (test_term generator ctxt catch_code_errors)
         (maps (map snd) goals') []
   end;