adding smart quantifiers to exhaustive testing
authorbulwahn
Fri, 03 Dec 2010 08:40:47 +0100
changeset 40907 45ba9f05583a
parent 40906 b5a319668955
child 40908 e8806880819e
adding smart quantifiers to exhaustive testing
src/HOL/Tools/smallvalue_generators.ML
--- a/src/HOL/Tools/smallvalue_generators.ML	Fri Dec 03 08:40:47 2010 +0100
+++ b/src/HOL/Tools/smallvalue_generators.ML	Fri Dec 03 08:40:47 2010 +0100
@@ -10,6 +10,7 @@
     Proof.context -> term -> int -> term list option * (bool list * bool)
   val put_counterexample: (unit -> int -> term list option)
     -> Proof.context -> Proof.context
+  val smart_quantifier : bool Config.T;
   val setup: theory -> theory
 end;
 
@@ -20,6 +21,11 @@
 
 val define_foundationally = false
 
+(* dynamic options *)
+
+val (smart_quantifier, setup_smart_quantifier) =
+  Attrib.config_bool "quickcheck_smart_quantifier" (K true)
+
 (** general term functions **)
 
 fun mk_measure f =
@@ -209,12 +215,50 @@
 
 val target = "Quickcheck";
 
-fun mk_generator_expr thy prop Ts =
+fun mk_smart_generator_expr ctxt t =
   let
+    val ((vnames, Ts), t') = apfst split_list (strip_abs t)
+    val ([depth_name], ctxt') = Variable.variant_fixes ["depth"] ctxt
+    val (names, ctxt'') = Variable.variant_fixes vnames ctxt'
+    val (term_names, ctxt''') = Variable.variant_fixes (map (prefix "t_") vnames) ctxt''
+    val depth = Free (depth_name, @{typ code_numeral})
+    val frees = map2 (curry Free) names Ts
+    val term_vars = map (fn n => Free (n, @{typ "unit => term"})) term_names 
+    fun strip_imp (Const(@{const_name HOL.implies},_) $ A $ B) = apfst (cons A) (strip_imp B)
+      | strip_imp A = ([], A)
+    val (assms, concl) = strip_imp (subst_bounds (rev frees, t'))
+    val terms = HOLogic.mk_list @{typ term} (map (fn v => v $ @{term "()"}) term_vars)
+    fun mk_small_closure (free as Free (_, T), term_var) t =
+      Const (@{const_name "Smallcheck.full_small_class.full_small"}, full_smallT T)
+        $ (HOLogic.split_const (T, @{typ "unit => term"}, @{typ "term list option"}) 
+          $ lambda free (lambda term_var t)) $ depth
+    fun lookup v = the (AList.lookup (op =) (names ~~ (frees ~~ term_vars)) v)
+    val none_t = @{term "None :: term list option"}
+    fun mk_safe_if (cond, then_t, else_t) =
+      @{term "Smallcheck.catch_match :: term list option => term list option => term list option"} $
+        (@{term "If :: bool => term list option => term list option => term list option"}
+        $ cond $ then_t $ else_t) $ none_t;
+    fun mk_test_term bound_vars assms =
+      let
+        fun vars_of t = subtract (op =) bound_vars (Term.add_free_names t [])
+        val (vars, check) =
+          case assms of [] =>
+            (vars_of concl, (concl, none_t, @{term "Some :: term list => term list option"} $ terms))
+          | assm :: assms =>
+            (vars_of assm, (assm, mk_test_term (union (op =) (vars_of assm) bound_vars) assms, none_t))
+      in
+        fold_rev mk_small_closure (map lookup vars) (mk_safe_if check)
+      end
+  in lambda depth (mk_test_term [] assms) end
+
+fun mk_generator_expr ctxt t =
+  let
+    val Ts = (map snd o fst o strip_abs) t;
+    val thy = ProofContext.theory_of ctxt
     val bound_max = length Ts - 1;
     val bounds = map_index (fn (i, ty) =>
       (2 * (bound_max - i) + 1, 2 * (bound_max - i), 2 * i, ty)) Ts;
-    val result = list_comb (prop, map (fn (i, _, _, _) => Bound i) bounds);
+    val result = list_comb (t, map (fn (i, _, _, _) => Bound i) bounds);
     val terms = HOLogic.mk_list @{typ term} (map (fn (_, i, _, _) => Bound i $ @{term "()"}) bounds);
     val check =
       @{term "Smallcheck.catch_match :: term list option => term list option => term list option"} $
@@ -228,26 +272,26 @@
   in Abs ("d", @{typ code_numeral}, fold_rev mk_small_closure bounds check) end
 
 fun compile_generator_expr ctxt t =
-  let
-    val Ts = (map snd o fst o strip_abs) t;
-    val thy = ProofContext.theory_of ctxt
-  in if Config.get ctxt Quickcheck.report then
+  if Config.get ctxt Quickcheck.report then
     error "Compilation with reporting facility is not supported"
   else
     let
-      val t' = mk_generator_expr thy t Ts;
+      val thy = ProofContext.theory_of ctxt
+      val t' =
+        (if Config.get ctxt smart_quantifier then mk_smart_generator_expr else mk_generator_expr)
+          ctxt t;
       val compile = Code_Runtime.dynamic_value_strict
         (Counterexample.get, put_counterexample, "Smallvalue_Generators.put_counterexample")
         thy (SOME target) (fn proc => fn g => g #> (Option.map o map) proc) t' [];
       val dummy_report = ([], false)
-    in compile #> rpair dummy_report end
-  end;
+    in compile #> rpair dummy_report end;
 
 (** setup **)
 
 val setup =
   Datatype.interpretation
     (Quickcheck_Generators.ensure_sort_datatype (@{sort full_small}, instantiate_smallvalue_datatype))
+  #> setup_smart_quantifier
   #> Context.theory_map
     (Quickcheck.add_generator ("small", compile_generator_expr));