src/HOL/Tools/Quickcheck/exhaustive_generators.ML
changeset 48414 43875bab3a4c
parent 48273 65233084e9d7
child 50046 0051dc4f301f
--- a/src/HOL/Tools/Quickcheck/exhaustive_generators.ML	Fri Jul 20 23:38:15 2012 +0200
+++ b/src/HOL/Tools/Quickcheck/exhaustive_generators.ML	Sat Jul 21 10:53:26 2012 +0200
@@ -299,6 +299,16 @@
     Const (@{const_name Let}, T1 --> (T1 --> T2) --> T2) $ t $ lambda x (e genuine)
   end
 
+fun mk_safe_let_expr genuine_only none safe (x, t, e) genuine =
+  let
+    val (T1, T2) = (fastype_of x, fastype_of (e genuine))
+    val if_t = Const (@{const_name "If"}, @{typ bool} --> T2 --> T2 --> T2)
+  in
+    Const (@{const_name "Quickcheck.catch_match"}, T2 --> T2 --> T2) $ 
+      (Const (@{const_name Let}, T1 --> (T1 --> T2) --> T2) $ t $ lambda x (e genuine)) $
+      (if_t $ genuine_only $ none $ safe false)
+  end
+
 fun mk_test_term lookup mk_closure mk_if mk_let none_t return ctxt =
   let
     val cnstrs = flat (maps
@@ -311,6 +321,7 @@
     fun mk_naive_test_term t =
       fold_rev mk_closure (map lookup (Term.add_free_names t []))
         (mk_if (t, none_t, return) true)
+    fun mk_test (vars, check) = fold_rev mk_closure (map lookup vars) check
     fun mk_smart_test_term' concl bound_vars assms genuine =
       let
         fun vars_of t = subtract (op =) bound_vars (Term.add_free_names t [])
@@ -318,9 +329,16 @@
           if member (op =) (Term.add_free_names lhs bound_vars) x then
             c (assm, assms)
           else
-            (remove (op =) x (vars_of assm),
-              mk_let f (try lookup x) lhs 
-                (mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms) genuine)
+            (let
+               val rec_call = mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms
+               fun safe genuine =
+                 the_default I (Option.map mk_closure (try lookup x)) (rec_call genuine)
+            in
+              mk_test (remove (op =) x (vars_of assm),
+                mk_let safe f (try lookup x) lhs 
+                  (mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms) genuine)
+            
+            end)
           | mk_equality_term (lhs, t) c (assm, assms) =
             if is_constrt (strip_comb t) then
               let
@@ -335,24 +353,23 @@
                 val bound_vars' = union (op =) (vars_of lhs) (union (op =) varnames bound_vars)
                 val cont_t = mk_smart_test_term' concl bound_vars' (new_assms @ assms) genuine
               in
-                (vars_of lhs, Datatype_Case.make_case ctxt Datatype_Case.Quiet [] lhs
+                mk_test (vars_of lhs, Datatype_Case.make_case ctxt Datatype_Case.Quiet [] lhs
                   [(list_comb (constr, vars), cont_t), (dummy_var, none_t)])
               end
             else c (assm, assms)
-        fun default (assm, assms) = (vars_of assm,
-          mk_if (HOLogic.mk_not assm, none_t, 
-          mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms) genuine)
-        val (vars, check) =
-          case assms of [] => (vars_of concl, mk_if (concl, none_t, return) genuine)
-            | assm :: assms =>
-              if Config.get ctxt optimise_equality then
-                (case try HOLogic.dest_eq assm of
-                  SOME (lhs, rhs) =>
-                    mk_equality_term (lhs, rhs) (mk_equality_term (rhs, lhs) default) (assm, assms)
-                | NONE => default (assm, assms))
-              else default (assm, assms)
+        fun default (assm, assms) =
+          mk_test (vars_of assm,
+            mk_if (HOLogic.mk_not assm, none_t, 
+            mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms) genuine)
       in
-        fold_rev mk_closure (map lookup vars) check
+        case assms of [] => mk_test (vars_of concl, mk_if (concl, none_t, return) genuine)
+          | assm :: assms =>
+            if Config.get ctxt optimise_equality then
+              (case try HOLogic.dest_eq assm of
+                SOME (lhs, rhs) =>
+                  mk_equality_term (lhs, rhs) (mk_equality_term (rhs, lhs) default) (assm, assms)
+              | NONE => default (assm, assms))
+            else default (assm, assms)
       end
     val mk_smart_test_term =
       Quickcheck_Common.strip_imp #> (fn (assms, concl) => mk_smart_test_term' concl [] assms true)
@@ -377,7 +394,7 @@
         $ lambda free t $ depth
     val none_t = @{term "()"}
     fun mk_safe_if (cond, then_t, else_t) genuine = mk_if (cond, then_t, else_t genuine)
-    fun mk_let def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
+    fun mk_let _ def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
     val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_safe_if mk_let none_t return ctxt 
   in lambda depth (@{term "catch_Counterexample :: unit => term list option"} $ mk_test_term t) end
 
@@ -406,7 +423,7 @@
         $ lambda free t $ depth
     val none_t = Const (@{const_name "None"}, resultT)
     val mk_if = Quickcheck_Common.mk_safe_if genuine_only none_t
-    fun mk_let def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
+    fun mk_let safe def v_opt t e = mk_safe_let_expr genuine_only none_t safe (the_default def v_opt, t, e)
     val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_if mk_let none_t return ctxt
   in lambda genuine_only (lambda depth (mk_test_term t)) end
 
@@ -436,10 +453,10 @@
             $ lambda free (lambda term_var t)) $ depth
     val none_t = Const (@{const_name "None"}, resultT)
     val mk_if = Quickcheck_Common.mk_safe_if genuine_only none_t
-    fun mk_let _ (SOME (v, term_var)) t e =
-      mk_let_expr (v, t, 
-        e #> subst_free [(term_var, absdummy @{typ unit} (HOLogic.mk_term_of (fastype_of t) t))])
-      | mk_let v NONE t e = mk_let_expr (v, t, e)
+    fun mk_let safe _ (SOME (v, term_var)) t e =
+        mk_safe_let_expr genuine_only none_t safe (v, t, 
+          e #> subst_free [(term_var, absdummy @{typ unit} (mk_safe_term t))])
+      | mk_let safe v NONE t e = mk_safe_let_expr genuine_only none_t safe (v, t, e)
     val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_if mk_let none_t return ctxt
   in lambda genuine_only (lambda depth (mk_test_term t)) end
 
@@ -462,7 +479,7 @@
       Const (@{const_name "Quickcheck_Exhaustive.bounded_forall_class.bounded_forall"}, bounded_forallT T)
         $ lambda (Free (s, T)) t $ depth
     fun mk_safe_if (cond, then_t, else_t) genuine = mk_if (cond, then_t, else_t genuine)
-    fun mk_let def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
+    fun mk_let safe def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
     val mk_test_term =
       mk_test_term lookup mk_bounded_forall mk_safe_if mk_let @{term True} (K @{term False}) ctxt
   in lambda depth (mk_test_term t) end