adding an even faster compilation scheme
authorbulwahn
Fri, 08 Apr 2011 16:31:14 +0200
changeset 42306 51a08b2699d5
parent 42305 494c31fdec95
child 42307 72e2fabb4bc2
adding an even faster compilation scheme
src/HOL/Tools/Quickcheck/exhaustive_generators.ML
--- a/src/HOL/Tools/Quickcheck/exhaustive_generators.ML	Fri Apr 08 16:31:14 2011 +0200
+++ b/src/HOL/Tools/Quickcheck/exhaustive_generators.ML	Fri Apr 08 16:31:14 2011 +0200
@@ -15,6 +15,7 @@
   val put_counterexample_batch: (unit -> (int -> term list option) list)
     -> Proof.context -> Proof.context
   val put_validator_batch: (unit -> (int -> bool) list) -> Proof.context -> Proof.context
+  exception Counterexample of term list
   val smart_quantifier : bool Config.T
   val quickcheck_pretty : bool Config.T
   val setup: theory -> theory
@@ -28,6 +29,9 @@
 val (smart_quantifier, setup_smart_quantifier) =
   Attrib.config_bool "quickcheck_smart_quantifier" (K true)
 
+val (fast, setup_fast) =
+  Attrib.config_bool "quickcheck_fast" (K true)
+  
 val (full_support, setup_full_support) =
   Attrib.config_bool "quickcheck_full_support" (K true)
 
@@ -67,6 +71,7 @@
 
 fun test_function T = Free ("f", T --> @{typ "term list option"})
 fun full_test_function T = Free ("f", termifyT T --> @{typ "term list option"})
+fun fast_test_function T = Free ("f", T --> @{typ "unit"})
 
 fun mk_none_continuation (x, y) =
   let
@@ -75,13 +80,24 @@
     Const (@{const_name "Quickcheck_Exhaustive.orelse"}, T --> T --> T) $ x $ y
   end
 
+fun mk_unit_let (x, y) =
+  Const (@{const_name "Let"}, @{typ "unit => (unit => unit) => unit"}) $ x $ (absdummy (@{typ unit}, y))
+  
 (** datatypes **)
 
 (* constructing exhaustive generator instances on datatypes *)
 
 exception FUNCTION_TYPE;
+
+exception Counterexample of term list
+
 val exhaustiveN = "exhaustive";
 val full_exhaustiveN = "full_exhaustive";
+val fast_exhaustiveN = "fast_exhaustive";
+
+fun fast_exhaustiveT T = (T --> @{typ unit})
+  --> @{typ code_numeral} --> @{typ unit}
+
 
 fun exhaustiveT T = (T --> @{typ "Code_Evaluation.term list option"})
   --> @{typ code_numeral} --> @{typ "Code_Evaluation.term list option"}
@@ -92,6 +108,46 @@
 fun check_allT T = (termifyT T --> @{typ "Code_Evaluation.term list option"})
   --> @{typ "Code_Evaluation.term list option"}
 
+
+fun mk_fast_equations descr vs tycos fast_exhaustives (Ts, Us) =
+  let
+    fun mk_call T =
+      let
+        val fast_exhaustive =
+          Const (@{const_name "Quickcheck_Exhaustive.fast_exhaustive_class.fast_exhaustive"}, fast_exhaustiveT T)
+      in
+        (T, fn t => fast_exhaustive $ absdummy (T, t) $ size_pred)
+      end
+    fun mk_aux_call fTs (k, _) (tyco, Ts) =
+      let
+        val T = Type (tyco, Ts)
+        val _ = if not (null fTs) then raise FUNCTION_TYPE else ()
+      in
+       (T, fn t => nth fast_exhaustives k $ absdummy (T, t) $ size_pred)
+      end
+    fun mk_consexpr simpleT (c, xs) =
+      let
+        val (Ts, fns) = split_list xs
+        val constr = Const (c, Ts ---> simpleT)
+        val bounds = map Bound (((length xs) - 1) downto 0)
+        val term_bounds = map (fn x => Bound (2 * x)) (((length xs) - 1) downto 0)
+        val start_term = fast_test_function simpleT $ list_comb (constr, bounds)
+      in fold_rev (fn f => fn t => f t) fns start_term end
+    fun mk_rhs exprs =
+        @{term "If :: bool => unit => unit => unit"}
+            $ size_ge_zero $ (foldr1 mk_unit_let exprs) $ @{term "()"}
+    val rhss =
+      Datatype_Aux.interpret_construction descr vs
+        { atyp = mk_call, dtyp = mk_aux_call }
+      |> (map o apfst) Type
+      |> map (fn (T, cs) => map (mk_consexpr T) cs)
+      |> map mk_rhs
+    val lhss = map2 (fn t => fn T => t $ fast_test_function T $ size) fast_exhaustives (Ts @ Us)
+    val eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhss ~~ rhss)
+  in
+    eqs
+  end
+  
 fun mk_equations descr vs tycos exhaustives (Ts, Us) =
   let
     fun mk_call T =
@@ -222,6 +278,22 @@
       "Creation of exhaustive generators failed because the datatype contains a function type";
     thy)
 
+fun instantiate_fast_exhaustive_datatype config descr vs tycos prfx (names, auxnames) (Ts, Us) thy =
+  let
+    val _ = Datatype_Aux.message config "Creating fast exhaustive generators...";
+    val fast_exhaustivesN = map (prefix (fast_exhaustiveN ^ "_")) (names @ auxnames)
+  in
+    thy
+    |> Class.instantiation (tycos, vs, @{sort fast_exhaustive})
+    |> Quickcheck_Common.define_functions
+        (fn exhaustives => mk_fast_equations descr vs tycos exhaustives (Ts, Us), SOME termination_tac)
+        prfx ["f", "i"] fast_exhaustivesN (map fast_exhaustiveT (Ts @ Us))
+    |> Class.prove_instantiation_exit (K (Class.intro_classes_tac []))
+  end handle FUNCTION_TYPE =>
+    (Datatype_Aux.message config
+      "Creation of exhaustive generators failed because the datatype contains a function type";
+    thy)
+       
 (* constructing bounded_forall instances on datatypes *)
 
 val bounded_forallN = "bounded_forall";
@@ -289,6 +361,46 @@
     
 (** building and compiling generator expressions **)
 
+fun mk_fast_generator_expr ctxt (t, eval_terms) =
+  let
+    val thy = ProofContext.theory_of ctxt
+    val ctxt' = Variable.auto_fixes t ctxt
+    val names = Term.add_free_names t []
+    val frees = map Free (Term.add_frees t [])
+    val ([depth_name], ctxt'') = Variable.variant_fixes ["depth"] ctxt'
+    val depth = Free (depth_name, @{typ code_numeral})
+    val return = @{term "throw_Counterexample :: term list => unit"} $
+      (HOLogic.mk_list @{typ "term"}
+        (map (fn t => HOLogic.mk_term_of (fastype_of t) t) (frees @ eval_terms)))
+    fun mk_exhaustive_closure (free as Free (_, T)) t =
+      Const (@{const_name "Quickcheck_Exhaustive.fast_exhaustive_class.fast_exhaustive"}, fast_exhaustiveT T)
+        $ lambda free t $ depth
+    val none_t = @{term "()"}
+    fun mk_safe_if (cond, then_t, else_t) =
+      @{term "If :: bool => unit => unit => unit"} $ cond $ then_t $ else_t
+    fun lookup v = the (AList.lookup (op =) (names ~~ frees) v)
+    fun mk_naive_test_term t =
+      fold_rev mk_exhaustive_closure frees (mk_safe_if (t, none_t, return)) 
+    fun mk_smart_test_term' concl bound_vars assms =
+      let
+        fun vars_of t = subtract (op =) bound_vars (Term.add_free_names t [])
+        val (vars, check) =
+          case assms of [] => (vars_of concl, (concl, none_t, return))
+            | assm :: assms => (vars_of assm, (assm,
+                mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms, none_t))
+      in
+        fold_rev mk_exhaustive_closure (map lookup vars) (mk_safe_if check)
+      end
+    fun mk_smart_test_term t =
+      let
+        val (assms, concl) = Quickcheck_Common.strip_imp t
+      in
+        mk_smart_test_term' concl [] assms
+      end
+    val mk_test_term =
+      if Config.get ctxt smart_quantifier then mk_smart_test_term else mk_naive_test_term 
+  in lambda depth (@{term "catch_Counterexample :: unit => term list option"} $ mk_test_term t) end
+
 fun mk_generator_expr ctxt (t, eval_terms) =
   let
     val thy = ProofContext.theory_of ctxt
@@ -453,8 +565,9 @@
 fun compile_generator_expr ctxt ts =
   let
     val thy = ProofContext.theory_of ctxt
-    val mk_generator_expr =
-      if Config.get ctxt full_support then mk_full_generator_expr else mk_generator_expr
+    val mk_generator_expr = 
+      if Config.get ctxt fast then mk_fast_generator_expr
+      else if Config.get ctxt full_support then mk_full_generator_expr else mk_generator_expr
     val t' = mk_parametric_generator_expr mk_generator_expr ctxt ts;
     val compile = Code_Runtime.dynamic_value_strict
       (Counterexample.get, put_counterexample, "Exhaustive_Generators.put_counterexample")
@@ -497,8 +610,11 @@
       (((@{sort typerep}, @{sort term_of}), @{sort exhaustive}), instantiate_exhaustive_datatype))
   #> Datatype.interpretation (Quickcheck_Common.ensure_sort_datatype
       (((@{sort type}, @{sort type}), @{sort bounded_forall}), instantiate_bounded_forall_datatype))
+  #> Datatype.interpretation (Quickcheck_Common.ensure_sort_datatype
+      (((@{sort typerep}, @{sort term_of}), @{sort fast_exhaustive}), instantiate_fast_exhaustive_datatype))
   #> setup_smart_quantifier
   #> setup_full_support
+  #> setup_fast
   #> setup_quickcheck_pretty
   #> Context.theory_map (Quickcheck.add_generator ("exhaustive", compile_generator_expr))
   #> Context.theory_map (Quickcheck.add_batch_generator ("exhaustive", compile_generator_exprs))