--- a/src/HOL/Tools/Quickcheck/exhaustive_generators.ML Tue May 29 11:41:37 2012 +0200
+++ b/src/HOL/Tools/Quickcheck/exhaustive_generators.ML Tue May 29 13:46:50 2012 +0200
@@ -17,6 +17,7 @@
val put_validator_batch: (unit -> (int -> bool) list) -> Proof.context -> Proof.context
exception Counterexample of term list
val smart_quantifier : bool Config.T
+ val optimise_equality : bool Config.T
val quickcheck_pretty : bool Config.T
val setup_exhaustive_datatype_interpretation : theory -> theory
val setup: theory -> theory
@@ -36,6 +37,8 @@
(** dynamic options **)
val smart_quantifier = Attrib.setup_config_bool @{binding quickcheck_smart_quantifier} (K true)
+val optimise_equality = Attrib.setup_config_bool @{binding quickcheck_optimise_equality} (K true)
+
val fast = Attrib.setup_config_bool @{binding quickcheck_fast} (K false)
val bounded_forall = Attrib.setup_config_bool @{binding quickcheck_bounded_forall} (K false)
val full_support = Attrib.setup_config_bool @{binding quickcheck_full_support} (K true)
@@ -288,26 +291,72 @@
(* building and compiling generator expressions *)
+fun mk_let_expr (x, t, e) genuine =
+ let
+ val (T1, T2) = (fastype_of x, fastype_of (e genuine))
+ in
+ Const (@{const_name Let}, T1 --> (T1 --> T2) --> T2) $ t $ lambda x (e genuine)
+ end
-fun mk_test_term lookup mk_closure mk_if none_t return ctxt =
+fun mk_test_term lookup mk_closure mk_if mk_let none_t return ctxt =
let
+ val cnstrs = flat (maps
+ (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
+ (Symtab.dest (Datatype.get_all (Proof_Context.theory_of ctxt))))
+ fun is_constrt (Const (s, T), ts) = (case (AList.lookup (op =) cnstrs s, body_type T) of
+ (SOME (i, Tname), Type (Tname', _)) => length ts = i andalso Tname = Tname'
+ | _ => false)
+ | is_constrt _ = false
fun mk_naive_test_term t =
fold_rev mk_closure (map lookup (Term.add_free_names t []))
(mk_if (t, none_t, return) true)
fun mk_smart_test_term' concl bound_vars assms genuine =
let
fun vars_of t = subtract (op =) bound_vars (Term.add_free_names t [])
+ fun mk_equality_term (lhs, f as Free (x, _)) c (assm, assms) =
+ if member (op =) (Term.add_free_names lhs bound_vars) x then
+ c (assm, assms)
+ else
+ (remove (op =) x (vars_of assm),
+ mk_let f (try lookup x) lhs
+ (mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms) genuine)
+ | mk_equality_term (lhs, t) c (assm, assms) =
+ if is_constrt (strip_comb t) then
+ let
+ val (constr, args) = strip_comb t
+ val T = fastype_of t
+ val vars = map Free (Variable.variant_frees ctxt (concl :: assms)
+ (map (fn t => ("x", fastype_of t)) args))
+ val varnames = map (fst o dest_Free) vars
+ val dummy_var = Free (singleton
+ (Variable.variant_frees ctxt (concl :: assms @ vars)) ("dummy", T))
+ val new_assms = map HOLogic.mk_eq (vars ~~ args)
+ val cont_t = mk_smart_test_term' concl (union (op =) varnames bound_vars)
+ (new_assms @ assms) genuine
+ in
+ (vars_of lhs, Datatype_Case.make_case ctxt Datatype_Case.Quiet [] lhs
+ [(list_comb (constr, vars), cont_t), (dummy_var, none_t)])
+ end
+ else c (assm, assms)
+ fun default (assm, assms) = (vars_of assm,
+ mk_if (HOLogic.mk_not assm, none_t,
+ mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms) genuine)
val (vars, check) =
- case assms of [] => (vars_of concl, (concl, none_t, return))
- | assm :: assms => (vars_of assm, (HOLogic.mk_not assm, none_t,
- mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms))
+ case assms of [] => (vars_of concl, mk_if (concl, none_t, return) genuine)
+ | assm :: assms =>
+ if Config.get ctxt optimise_equality then
+ (case try HOLogic.dest_eq assm of
+ SOME (lhs, rhs) =>
+ mk_equality_term (lhs, rhs) (mk_equality_term (rhs, lhs) default) (assm, assms)
+ | NONE => default (assm, assms))
+ else default (assm, assms)
in
- fold_rev mk_closure (map lookup vars) (mk_if check genuine)
+ fold_rev mk_closure (map lookup vars) check
end
val mk_smart_test_term =
Quickcheck_Common.strip_imp #> (fn (assms, concl) => mk_smart_test_term' concl [] assms true)
in
- if Config.get ctxt smart_quantifier then mk_smart_test_term else mk_naive_test_term
+ if Config.get ctxt smart_quantifier then mk_smart_test_term else mk_naive_test_term
end
fun mk_fast_generator_expr ctxt (t, eval_terms) =
@@ -327,7 +376,8 @@
$ lambda free t $ depth
val none_t = @{term "()"}
fun mk_safe_if (cond, then_t, else_t) genuine = mk_if (cond, then_t, else_t genuine)
- val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_safe_if none_t return ctxt
+ fun mk_let def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
+ val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_safe_if mk_let none_t return ctxt
in lambda depth (@{term "catch_Counterexample :: unit => term list option"} $ mk_test_term t) end
fun mk_unknown_term T = HOLogic.reflect_term (Const ("Quickcheck_Exhaustive.unknown", T))
@@ -355,7 +405,8 @@
$ lambda free t $ depth
val none_t = Const (@{const_name "None"}, resultT)
val mk_if = Quickcheck_Common.mk_safe_if genuine_only none_t
- val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_if none_t return ctxt
+ fun mk_let def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
+ val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_if mk_let none_t return ctxt
in lambda genuine_only (lambda depth (mk_test_term t)) end
fun mk_full_generator_expr ctxt (t, eval_terms) =
@@ -384,7 +435,11 @@
$ lambda free (lambda term_var t)) $ depth
val none_t = Const (@{const_name "None"}, resultT)
val mk_if = Quickcheck_Common.mk_safe_if genuine_only none_t
- val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_if none_t return ctxt
+ fun mk_let _ (SOME (v, term_var)) t e =
+ mk_let_expr (v, t,
+ e #> subst_free [(term_var, absdummy @{typ unit} (HOLogic.mk_term_of (fastype_of t) t))])
+ | mk_let v NONE t e = mk_let_expr (v, t, e)
+ val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_if mk_let none_t return ctxt
in lambda genuine_only (lambda depth (mk_test_term t)) end
fun mk_parametric_generator_expr mk_generator_expr =
@@ -406,8 +461,9 @@
Const (@{const_name "Quickcheck_Exhaustive.bounded_forall_class.bounded_forall"}, bounded_forallT T)
$ lambda (Free (s, T)) t $ depth
fun mk_safe_if (cond, then_t, else_t) genuine = mk_if (cond, then_t, else_t genuine)
+ fun mk_let def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
val mk_test_term =
- mk_test_term lookup mk_bounded_forall mk_safe_if @{term True} (K @{term False}) ctxt
+ mk_test_term lookup mk_bounded_forall mk_safe_if mk_let @{term True} (K @{term False}) ctxt
in lambda depth (mk_test_term t) end