new compilation for exhaustive quickcheck
authorbulwahn
Fri, 08 Apr 2011 16:31:14 +0200
changeset 42304 34366f39d32d
parent 42291 682b35dc1926
child 42305 494c31fdec95
new compilation for exhaustive quickcheck
src/HOL/Quickcheck_Exhaustive.thy
src/HOL/Tools/Quickcheck/exhaustive_generators.ML
--- a/src/HOL/Quickcheck_Exhaustive.thy	Thu Apr 07 21:49:24 2011 +0200
+++ b/src/HOL/Quickcheck_Exhaustive.thy	Fri Apr 08 16:31:14 2011 +0200
@@ -16,22 +16,35 @@
 subsection {* exhaustive generator type classes *}
 
 class exhaustive = term_of +
-fixes exhaustive :: "('a * (unit => term) \<Rightarrow> term list option) \<Rightarrow> code_numeral \<Rightarrow> term list option"
+  fixes exhaustive :: "('a \<Rightarrow> term list option) \<Rightarrow> code_numeral \<Rightarrow> term list option"
+  fixes full_exhaustive :: "('a * (unit => term) \<Rightarrow> term list option) \<Rightarrow> code_numeral \<Rightarrow> term list option"
 
 instantiation code_numeral :: exhaustive
 begin
 
-function exhaustive_code_numeral' :: "(code_numeral * (unit => term) => term list option) => code_numeral => code_numeral => term list option"
+function full_exhaustive_code_numeral' :: "(code_numeral * (unit => term) => term list option) => code_numeral => code_numeral => term list option"
+  where "full_exhaustive_code_numeral' f d i =
+    (if d < i then None
+    else (f (i, %_. Code_Evaluation.term_of i)) orelse (full_exhaustive_code_numeral' f d (i + 1)))"
+by pat_completeness auto
+
+termination
+  by (relation "measure (%(_, d, i). Code_Numeral.nat_of (d + 1 - i))") auto
+
+definition "full_exhaustive f d = full_exhaustive_code_numeral' f d 0"
+
+function exhaustive_code_numeral' :: "(code_numeral => term list option) => code_numeral => code_numeral => term list option"
   where "exhaustive_code_numeral' f d i =
     (if d < i then None
-    else (f (i, %_. Code_Evaluation.term_of i)) orelse (exhaustive_code_numeral' f d (i + 1)))"
+    else (f i orelse exhaustive_code_numeral' f d (i + 1)))"
 by pat_completeness auto
 
-termination 
+termination
   by (relation "measure (%(_, d, i). Code_Numeral.nat_of (d + 1 - i))") auto
 
 definition "exhaustive f d = exhaustive_code_numeral' f d 0"
 
+
 instance ..
 
 end
@@ -39,7 +52,9 @@
 instantiation nat :: exhaustive
 begin
 
-definition "exhaustive f d = exhaustive (%(x, xt). f (Code_Numeral.nat_of x, %_. Code_Evaluation.term_of (Code_Numeral.nat_of x))) d"
+definition "exhaustive f d = exhaustive (%x. f (Code_Numeral.nat_of x)) d"
+
+definition "full_exhaustive f d = full_exhaustive (%(x, xt). f (Code_Numeral.nat_of x, %_. Code_Evaluation.term_of (Code_Numeral.nat_of x))) d"
 
 instance ..
 
@@ -48,8 +63,8 @@
 instantiation int :: exhaustive
 begin
 
-function exhaustive' :: "(int * (unit => term) => term list option) => int => int => term list option"
-  where "exhaustive' f d i = (if d < i then None else (case f (i, %_. Code_Evaluation.term_of i) of Some t => Some t | None => exhaustive' f d (i + 1)))"
+function exhaustive' :: "(int => term list option) => int => int => term list option"
+  where "exhaustive' f d i = (if d < i then None else (f i orelse exhaustive' f d (i + 1)))"
 by pat_completeness auto
 
 termination 
@@ -57,6 +72,15 @@
 
 definition "exhaustive f d = exhaustive' f (Code_Numeral.int_of d) (- (Code_Numeral.int_of d))"
 
+function full_exhaustive' :: "(int * (unit => term) => term list option) => int => int => term list option"
+  where "full_exhaustive' f d i = (if d < i then None else (case f (i, %_. Code_Evaluation.term_of i) of Some t => Some t | None => full_exhaustive' f d (i + 1)))"
+by pat_completeness auto
+
+termination 
+  by (relation "measure (%(_, d, i). nat (d + 1 - i))") auto
+
+definition "full_exhaustive f d = full_exhaustive' f (Code_Numeral.int_of d) (- (Code_Numeral.int_of d))"
+
 instance ..
 
 end
@@ -65,7 +89,10 @@
 begin
 
 definition
-  "exhaustive f d = exhaustive (%(x, t1). exhaustive (%(y, t2). f ((x, y),
+  "exhaustive f d = exhaustive (%x. exhaustive (%y. f ((x, y))) d) d"
+
+definition
+  "full_exhaustive f d = full_exhaustive (%(x, t1). full_exhaustive (%(y, t2). f ((x, y),
     %u. let T1 = (Typerep.typerep (TYPE('a)));
             T2 = (Typerep.typerep (TYPE('b)))
     in Code_Evaluation.App (Code_Evaluation.App (
@@ -80,11 +107,23 @@
 instantiation "fun" :: ("{equal, exhaustive}", exhaustive) exhaustive
 begin
 
-fun exhaustive_fun' :: "(('a => 'b) * (unit => term) => term list option) => code_numeral => code_numeral => term list option"
+fun exhaustive_fun' :: "(('a => 'b) => term list option) => code_numeral => code_numeral => term list option"
+where
+  "exhaustive_fun' f i d = (exhaustive (%b. f (%_. b)) d)
+   orelse (if i > 1 then
+     exhaustive_fun' (%g. exhaustive (%a. exhaustive (%b.
+       f (g(a := b))) d) d) (i - 1) d else None)"
+
+definition exhaustive_fun :: "(('a => 'b) => term list option) => code_numeral => term list option"
 where
-  "exhaustive_fun' f i d = (exhaustive (%(b, t). f (%_. b, %_. Code_Evaluation.Abs (STR ''x'') (Typerep.typerep TYPE('a)) (t ()))) d)
+  "exhaustive_fun f d = exhaustive_fun' f d d" 
+
+
+fun full_exhaustive_fun' :: "(('a => 'b) * (unit => term) => term list option) => code_numeral => code_numeral => term list option"
+where
+  "full_exhaustive_fun' f i d = (full_exhaustive (%(b, t). f (%_. b, %_. Code_Evaluation.Abs (STR ''x'') (Typerep.typerep TYPE('a)) (t ()))) d)
    orelse (if i > 1 then
-     exhaustive_fun' (%(g, gt). exhaustive (%(a, at). exhaustive (%(b, bt).
+     full_exhaustive_fun' (%(g, gt). full_exhaustive (%(a, at). full_exhaustive (%(b, bt).
        f (g(a := b),
          (%_. let A = (Typerep.typerep (TYPE('a)));
                   B = (Typerep.typerep (TYPE('b)));
@@ -94,9 +133,9 @@
                   (Code_Evaluation.Const (STR ''Fun.fun_upd'') (fun (fun A B) (fun A (fun B (fun A B)))))
                 (gt ())) (at ())) (bt ())))) d) d) (i - 1) d else None)"
 
-definition exhaustive_fun :: "(('a => 'b) * (unit => term) => term list option) => code_numeral => term list option"
+definition full_exhaustive_fun :: "(('a => 'b) * (unit => term) => term list option) => code_numeral => term list option"
 where
-  "exhaustive_fun f d = exhaustive_fun' f d d" 
+  "full_exhaustive_fun f d = full_exhaustive_fun' f d d" 
 
 instance ..
 
--- a/src/HOL/Tools/Quickcheck/exhaustive_generators.ML	Thu Apr 07 21:49:24 2011 +0200
+++ b/src/HOL/Tools/Quickcheck/exhaustive_generators.ML	Fri Apr 08 16:31:14 2011 +0200
@@ -28,6 +28,9 @@
 val (smart_quantifier, setup_smart_quantifier) =
   Attrib.config_bool "quickcheck_smart_quantifier" (K true)
 
+val (full_support, setup_full_support) =
+  Attrib.config_bool "quickcheck_full_support" (K true)
+
 val (quickcheck_pretty, setup_quickcheck_pretty) =
   Attrib.config_bool "quickcheck_pretty" (K true)
  
@@ -61,7 +64,9 @@
 val size = @{term "i :: code_numeral"}
 val size_pred = @{term "(i :: code_numeral) - 1"}
 val size_ge_zero = @{term "(i :: code_numeral) > 0"}
-fun test_function T = Free ("f", termifyT T --> @{typ "term list option"})
+
+fun test_function T = Free ("f", T --> @{typ "term list option"})
+fun full_test_function T = Free ("f", termifyT T --> @{typ "term list option"})
 
 fun mk_none_continuation (x, y) =
   let
@@ -76,8 +81,12 @@
 
 exception FUNCTION_TYPE;
 val exhaustiveN = "exhaustive";
+val full_exhaustiveN = "full_exhaustive";
 
-fun exhaustiveT T = (termifyT T --> @{typ "Code_Evaluation.term list option"})
+fun exhaustiveT T = (T --> @{typ "Code_Evaluation.term list option"})
+  --> @{typ code_numeral} --> @{typ "Code_Evaluation.term list option"}
+
+fun full_exhaustiveT T = (termifyT T --> @{typ "Code_Evaluation.term list option"})
   --> @{typ code_numeral} --> @{typ "Code_Evaluation.term list option"}
 
 fun check_allT T = (termifyT T --> @{typ "Code_Evaluation.term list option"})
@@ -89,7 +98,45 @@
       let
         val exhaustive = Const (@{const_name "Quickcheck_Exhaustive.exhaustive_class.exhaustive"}, exhaustiveT T)
       in
-        (T, (fn t => exhaustive $
+        (T, fn t => exhaustive $ absdummy (T, t) $ size_pred)
+      end
+    fun mk_aux_call fTs (k, _) (tyco, Ts) =
+      let
+        val T = Type (tyco, Ts)
+        val _ = if not (null fTs) then raise FUNCTION_TYPE else ()
+      in
+       (T, fn t => nth exhaustives k $ absdummy (T, t) $ size_pred)
+      end
+    fun mk_consexpr simpleT (c, xs) =
+      let
+        val (Ts, fns) = split_list xs
+        val constr = Const (c, Ts ---> simpleT)
+        val bounds = map Bound (((length xs) - 1) downto 0)
+        val term_bounds = map (fn x => Bound (2 * x)) (((length xs) - 1) downto 0)
+        val start_term = test_function simpleT $ list_comb (constr, bounds)
+      in fold_rev (fn f => fn t => f t) fns start_term end
+    fun mk_rhs exprs =
+        @{term "If :: bool => term list option => term list option => term list option"}
+            $ size_ge_zero $ (foldr1 mk_none_continuation exprs) $ @{term "None :: term list option"}
+    val rhss =
+      Datatype_Aux.interpret_construction descr vs
+        { atyp = mk_call, dtyp = mk_aux_call }
+      |> (map o apfst) Type
+      |> map (fn (T, cs) => map (mk_consexpr T) cs)
+      |> map mk_rhs
+    val lhss = map2 (fn t => fn T => t $ test_function T $ size) exhaustives (Ts @ Us)
+    val eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhss ~~ rhss)
+  in
+    eqs
+  end
+    
+fun mk_full_equations descr vs tycos full_exhaustives (Ts, Us) =
+  let
+    fun mk_call T =
+      let
+        val full_exhaustive = Const (@{const_name "Quickcheck_Exhaustive.exhaustive_class.full_exhaustive"}, full_exhaustiveT T)
+      in
+        (T, (fn t => full_exhaustive $
           (HOLogic.split_const (T, @{typ "unit => Code_Evaluation.term"}, @{typ "Code_Evaluation.term list option"})
           $ absdummy (T, absdummy (@{typ "unit => Code_Evaluation.term"}, t))) $ size_pred))
       end
@@ -98,7 +145,7 @@
         val T = Type (tyco, Ts)
         val _ = if not (null fTs) then raise FUNCTION_TYPE else ()
       in
-       (T, (fn t => nth exhaustives k $
+       (T, (fn t => nth full_exhaustives k $
           (HOLogic.split_const (T, @{typ "unit => Code_Evaluation.term"}, @{typ "Code_Evaluation.term list option"})
             $ absdummy (T, absdummy (@{typ "unit => Code_Evaluation.term"}, t))) $ size_pred))
       end
@@ -112,7 +159,7 @@
         val Eval_Const = Const ("Code_Evaluation.Const", HOLogic.literalT --> @{typ typerep} --> HOLogic.termT)
         val term = fold (fn u => fn t => Eval_App $ t $ (u $ @{term "()"}))
           bounds (Eval_Const $ HOLogic.mk_literal c $ HOLogic.mk_typerep (Ts ---> simpleT))
-        val start_term = test_function simpleT $ 
+        val start_term = full_test_function simpleT $ 
         (HOLogic.pair_const simpleT @{typ "unit => Code_Evaluation.term"}
           $ (list_comb (constr, bounds)) $ absdummy (@{typ unit}, term))
       in fold_rev (fn f => fn t => f t) fns start_term end
@@ -125,7 +172,7 @@
       |> (map o apfst) Type
       |> map (fn (T, cs) => map (mk_consexpr T) cs)
       |> map mk_rhs
-    val lhss = map2 (fn t => fn T => t $ test_function T $ size) exhaustives (Ts @ Us);
+    val lhss = map2 (fn t => fn T => t $ full_test_function T $ size) full_exhaustives (Ts @ Us);
     val eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhss ~~ rhss)
   in
     eqs
@@ -157,13 +204,18 @@
 fun instantiate_exhaustive_datatype config descr vs tycos prfx (names, auxnames) (Ts, Us) thy =
   let
     val _ = Datatype_Aux.message config "Creating exhaustive generators...";
-    val exhaustivesN = map (prefix (exhaustiveN ^ "_")) (names @ auxnames);
+    val exhaustivesN = map (prefix (exhaustiveN ^ "_")) (names @ auxnames)
+    val full_exhaustivesN = map (prefix (full_exhaustiveN ^ "_")) (names @ auxnames)
   in
     thy
     |> Class.instantiation (tycos, vs, @{sort exhaustive})
     |> Quickcheck_Common.define_functions
         (fn exhaustives => mk_equations descr vs tycos exhaustives (Ts, Us), SOME termination_tac)
         prfx ["f", "i"] exhaustivesN (map exhaustiveT (Ts @ Us))
+    |> Quickcheck_Common.define_functions
+        (fn full_exhaustives => mk_full_equations descr vs tycos full_exhaustives (Ts, Us),
+        SOME termination_tac)
+        prfx ["f", "i"] full_exhaustivesN (map full_exhaustiveT (Ts @ Us))
     |> Class.prove_instantiation_exit (K (Class.intro_classes_tac []))
   end handle FUNCTION_TYPE =>
     (Datatype_Aux.message config
@@ -244,6 +296,48 @@
     val names = Term.add_free_names t []
     val frees = map Free (Term.add_frees t [])
     val ([depth_name], ctxt'') = Variable.variant_fixes ["depth"] ctxt'
+    val depth = Free (depth_name, @{typ code_numeral})
+    val return = @{term "Some :: term list => term list option"} $
+      (HOLogic.mk_list @{typ "term"}
+        (map (fn t => HOLogic.mk_term_of (fastype_of t) t) (frees @ eval_terms)))
+    fun mk_exhaustive_closure (free as Free (_, T)) t =
+      Const (@{const_name "Quickcheck_Exhaustive.exhaustive_class.exhaustive"}, exhaustiveT T)
+        $ lambda free t $ depth
+    val none_t = @{term "None :: term list option"}
+    fun mk_safe_if (cond, then_t, else_t) =
+      @{term "Quickcheck_Exhaustive.catch_match :: term list option => term list option => term list option"} $
+        (@{term "If :: bool => term list option => term list option => term list option"}
+        $ cond $ then_t $ else_t) $ none_t;
+    fun lookup v = the (AList.lookup (op =) (names ~~ frees) v)
+    fun mk_naive_test_term t =
+      fold_rev mk_exhaustive_closure frees (mk_safe_if (t, none_t, return)) 
+    fun mk_smart_test_term' concl bound_vars assms =
+      let
+        fun vars_of t = subtract (op =) bound_vars (Term.add_free_names t [])
+        val (vars, check) =
+          case assms of [] => (vars_of concl, (concl, none_t, return))
+            | assm :: assms => (vars_of assm, (assm,
+                mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms, none_t))
+      in
+        fold_rev mk_exhaustive_closure (map lookup vars) (mk_safe_if check)
+      end
+    fun mk_smart_test_term t =
+      let
+        val (assms, concl) = Quickcheck_Common.strip_imp t
+      in
+        mk_smart_test_term' concl [] assms
+      end
+    val mk_test_term =
+      if Config.get ctxt smart_quantifier then mk_smart_test_term else mk_naive_test_term 
+  in lambda depth (mk_test_term t) end
+
+fun mk_full_generator_expr ctxt (t, eval_terms) =
+  let
+    val thy = ProofContext.theory_of ctxt
+    val ctxt' = Variable.auto_fixes t ctxt
+    val names = Term.add_free_names t []
+    val frees = map Free (Term.add_frees t [])
+    val ([depth_name], ctxt'') = Variable.variant_fixes ["depth"] ctxt'
     val (term_names, ctxt''') = Variable.variant_fixes (map (prefix "t_") names) ctxt''
     val depth = Free (depth_name, @{typ code_numeral})
     val term_vars = map (fn n => Free (n, @{typ "unit => term"})) term_names
@@ -288,7 +382,7 @@
       if Config.get ctxt smart_quantifier then mk_smart_test_term else mk_naive_test_term 
   in lambda depth (mk_test_term t) end
 
-val mk_parametric_generator_expr =
+fun mk_parametric_generator_expr mk_generator_expr =
   Quickcheck_Common.gen_mk_parametric_generator_expr 
     ((mk_generator_expr, absdummy (@{typ "code_numeral"}, @{term "None :: term list option"})),
       @{typ "code_numeral => term list option"})
@@ -359,7 +453,9 @@
 fun compile_generator_expr ctxt ts =
   let
     val thy = ProofContext.theory_of ctxt
-    val t' = mk_parametric_generator_expr ctxt ts;
+    val mk_generator_expr =
+      if Config.get ctxt full_support then mk_full_generator_expr else mk_generator_expr
+    val t' = mk_parametric_generator_expr mk_generator_expr ctxt ts;
     val compile = Code_Runtime.dynamic_value_strict
       (Counterexample.get, put_counterexample, "Exhaustive_Generators.put_counterexample")
       thy (SOME target) (fn proc => fn g =>
@@ -402,6 +498,7 @@
   #> Datatype.interpretation (Quickcheck_Common.ensure_sort_datatype
       (((@{sort type}, @{sort type}), @{sort bounded_forall}), instantiate_bounded_forall_datatype))
   #> setup_smart_quantifier
+  #> setup_full_support
   #> setup_quickcheck_pretty
   #> Context.theory_map (Quickcheck.add_generator ("exhaustive", compile_generator_expr))
   #> Context.theory_map (Quickcheck.add_batch_generator ("exhaustive", compile_generator_exprs))