merged
authorhaftmann
Sat, 13 Jun 2009 16:32:38 +0200
changeset 31626 fe35b72b9ef0
parent 31614 ef6d67b1ad10 (current diff)
parent 31625 9e4d7d60c3e7 (diff)
child 31627 bc2de3795756
child 31628 28699098b5f3
merged
NEWS
--- a/NEWS	Sat Jun 13 16:29:15 2009 +0200
+++ b/NEWS	Sat Jun 13 16:32:38 2009 +0200
@@ -37,6 +37,9 @@
 * New method "linarith" invokes existing linear arithmetic decision
 procedure only.
 
+* Implementation of quickcheck using generic code generator; default generators
+are provided for all suitable HOL types, records and datatypes.
+
 
 *** ML ***
 
--- a/src/HOL/Quickcheck.thy	Sat Jun 13 16:29:15 2009 +0200
+++ b/src/HOL/Quickcheck.thy	Sat Jun 13 16:32:38 2009 +0200
@@ -93,15 +93,15 @@
   \<Rightarrow> (Random.seed \<Rightarrow> ('b \<times> (unit \<Rightarrow> term)) \<times> Random.seed) \<Rightarrow> (Random.seed \<Rightarrow> Random.seed \<times> Random.seed)
   \<Rightarrow> Random.seed \<Rightarrow> (('a \<Rightarrow> 'b) \<times> (unit \<Rightarrow> term)) \<times> Random.seed"
 
-definition random_fun_lift :: "(code_numeral \<Rightarrow> Random.seed \<Rightarrow> ('b \<times> (unit \<Rightarrow> term)) \<times> Random.seed)
-  \<Rightarrow> code_numeral \<Rightarrow> Random.seed \<Rightarrow> (('a\<Colon>term_of \<Rightarrow> 'b\<Colon>typerep) \<times> (unit \<Rightarrow> term)) \<times> Random.seed" where
-  "random_fun_lift f i = random_fun_aux TYPEREP('a) TYPEREP('b) (op =) Code_Eval.term_of (f i) Random.split_seed"
+definition random_fun_lift :: "(Random.seed \<Rightarrow> ('b \<times> (unit \<Rightarrow> term)) \<times> Random.seed)
+  \<Rightarrow> Random.seed \<Rightarrow> (('a\<Colon>term_of \<Rightarrow> 'b\<Colon>typerep) \<times> (unit \<Rightarrow> term)) \<times> Random.seed" where
+  "random_fun_lift f = random_fun_aux TYPEREP('a) TYPEREP('b) (op =) Code_Eval.term_of f Random.split_seed"
 
 instantiation "fun" :: ("{eq, term_of}", "{type, random}") random
 begin
 
 definition random_fun :: "code_numeral \<Rightarrow> Random.seed \<Rightarrow> (('a \<Rightarrow> 'b) \<times> (unit \<Rightarrow> term)) \<times> Random.seed" where
-  "random = random_fun_lift random"
+  "random i = random_fun_lift (random i)"
 
 instance ..
 
--- a/src/HOL/Tools/datatype_package/datatype_codegen.ML	Sat Jun 13 16:29:15 2009 +0200
+++ b/src/HOL/Tools/datatype_package/datatype_codegen.ML	Sat Jun 13 16:32:38 2009 +0200
@@ -15,13 +15,7 @@
 structure DatatypeCodegen : DATATYPE_CODEGEN =
 struct
 
-(** SML code generator **)
-
-open Codegen;
-
-(**** datatype definition ****)
-
-(* find shortest path to constructor with no recursive arguments *)
+(** find shortest path to constructor with no recursive arguments **)
 
 fun find_nonempty (descr: DatatypeAux.descr) is i =
   let
@@ -41,6 +35,13 @@
 
 fun find_shortest_path descr i = find_nonempty descr [i] i;
 
+
+(** SML code generator **)
+
+open Codegen;
+
+(* datatype definition *)
+
 fun add_dt_defs thy defs dep module (descr: DatatypeAux.descr) sorts gr =
   let
     val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;
@@ -210,7 +211,7 @@
   end;
 
 
-(**** case expressions ****)
+(* case expressions *)
 
 fun pretty_case thy defs dep module brack constrs (c as Const (_, T)) ts gr =
   let val i = length constrs
@@ -252,7 +253,7 @@
   end;
 
 
-(**** constructors ****)
+(* constructors *)
 
 fun pretty_constr thy defs dep module brack args (c as Const (s, T)) ts gr =
   let val i = length args
@@ -271,7 +272,7 @@
   end;
 
 
-(**** code generators for terms and types ****)
+(* code generators for terms and types *)
 
 fun datatype_codegen thy defs dep module brack t gr = (case strip_comb t of
    (c as Const (s, T), ts) =>
@@ -313,6 +314,18 @@
 
 (** generic code generator **)
 
+(* liberal addition of code data for datatypes *)
+
+fun mk_constr_consts thy vs dtco cos =
+  let
+    val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
+    val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs;
+  in if is_some (try (Code.constrset_of_consts thy) cs')
+    then SOME cs
+    else NONE
+  end;
+
+
 (* case certificates *)
 
 fun mk_case_cert thy tyco =
@@ -371,7 +384,7 @@
     val simpset = Simplifier.context (ProofContext.init thy) (HOL_basic_ss
       addsimps (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms))
       addsimprocs [DatatypePackage.distinct_simproc]);
-    fun prove prop = Goal.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
+    fun prove prop = SkipProof.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
       |> Simpdata.mk_eq;
   in map (rpair true o prove) (triv_injects @ injects @ distincts) @ [(prove refl, false)] end;
 
@@ -411,16 +424,7 @@
   end;
 
 
-(* liberal addition of code data for datatypes *)
-
-fun mk_constr_consts thy vs dtco cos =
-  let
-    val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
-    val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs;
-  in if is_some (try (Code.constrset_of_consts thy) cs')
-    then SOME cs
-    else NONE
-  end;
+(* register a datatype etc. *)
 
 fun add_all_code dtcos thy =
   let
@@ -433,6 +437,7 @@
   in
     if null css then thy
     else thy
+      |> tap (fn _ => DatatypeAux.message "Registering datatype for code generator ...")
       |> fold Code.add_datatype css
       |> fold_rev Code.add_default_eqn case_rewrites
       |> fold Code.add_case certs
@@ -440,7 +445,6 @@
    end;
 
 
-
 (** theory setup **)
 
 val setup = 
--- a/src/HOL/Tools/quickcheck_generators.ML	Sat Jun 13 16:29:15 2009 +0200
+++ b/src/HOL/Tools/quickcheck_generators.ML	Sat Jun 13 16:32:38 2009 +0200
@@ -170,9 +170,9 @@
     val eqs1 = map (Pattern.rewrite_term thy rew_ts []) eqs0;
     val ((_, eqs2), lthy') = PrimrecPackage.add_primrec_simple
       [((Binding.name random_aux, T), NoSyn)] eqs1 lthy;
-    val eq_tac = ALLGOALS Goal.conjunction_tac THEN ALLGOALS (simp_tac rew_ss)
+    val eq_tac = ALLGOALS (simp_tac rew_ss)
       THEN (ALLGOALS (ProofContext.fact_tac (flat eqs2)));
-    val eqs3 = Goal.prove_multi lthy' [v] [] eqs0 (K eq_tac);
+    val eqs3 = map (fn prop => SkipProof.prove lthy' [v] [] prop (K eq_tac)) eqs0;
     val cT_random_aux = inst pt_random_aux;
     val cT_rhs = inst pt_rhs;
     val rule = @{thm random_aux_rec}
@@ -180,7 +180,7 @@
            [(cT_random_aux, cert t_random_aux), (cT_rhs, cert t_rhs)])
       |> (fn thm => thm OF eqs3);
     val tac = ALLGOALS (rtac rule);
-    val simp = Goal.prove lthy' [v] [] eq (K tac);
+    val simp = SkipProof.prove lthy' [v] [] eq (K tac);
   in (simp, lthy') end;
 
 end;
@@ -212,11 +212,11 @@
         fun prove_eqs aux_simp proj_defs lthy = 
           let
             val proj_simps = map (snd o snd) proj_defs;
-            fun tac { context = ctxt, ... } = ALLGOALS Goal.conjunction_tac
-              THEN ALLGOALS (simp_tac (HOL_ss addsimps proj_simps))
+            fun tac { context = ctxt, ... } =
+              ALLGOALS (simp_tac (HOL_ss addsimps proj_simps))
               THEN ALLGOALS (EqSubst.eqsubst_tac ctxt [0] [aux_simp])
               THEN ALLGOALS (simp_tac (HOL_ss addsimps [fst_conv, snd_conv]));
-          in (Goal.prove_multi lthy [v] [] eqs tac, lthy) end;
+          in (map (fn prop => SkipProof.prove lthy [v] [] prop tac) eqs, lthy) end;
       in
         lthy
         |> random_aux_primrec aux_eq'
@@ -235,10 +235,9 @@
     val proto_eqs = map mk_proto_eq eqs;
     fun prove_simps proto_simps lthy =
       let
-        val ext_simps = map (fn thm => fun_cong OF [fun_cong OF  [thm]]) proto_simps;
-        val tac = ALLGOALS Goal.conjunction_tac
-          THEN ALLGOALS (ProofContext.fact_tac ext_simps);
-      in (Goal.prove_multi lthy vs [] eqs (K tac), lthy) end;
+        val ext_simps = map (fn thm => fun_cong OF [fun_cong OF [thm]]) proto_simps;
+        val tac = ALLGOALS (ProofContext.fact_tac ext_simps);
+      in (map (fn prop => SkipProof.prove lthy vs [] prop (K tac)) eqs, lthy) end;
     val b = Binding.qualify true prefix (Binding.name "simps");
   in
     lthy
@@ -254,8 +253,6 @@
 
 (* constructing random instances on datatypes *)
 
-exception Datatype_Fun; (*FIXME*)
-
 fun mk_random_aux_eqs thy descr vs tycos (names, auxnames) (Ts, Us) =
   let
     val mk_const = curry (Sign.mk_const thy);
@@ -274,19 +271,20 @@
     val random_auxT = sizeT o random_resultT;
     val random_auxs = map2 (fn s => fn rT => Free (s, random_auxT rT))
       random_auxsN rTs;
-
     fun mk_random_call T = (NONE, (HOLogic.mk_random T j, T));
     fun mk_random_aux_call fTs (k, _) (tyco, Ts) =
       let
-        val _ = if null fTs then () else raise Datatype_Fun; (*FIXME*)
-        val random = nth random_auxs k;
+        val T = Type (tyco, Ts);
+        fun mk_random_fun_lift [] t = t
+          | mk_random_fun_lift (fT :: fTs) t =
+              mk_const @{const_name random_fun_lift} [fTs ---> T, fT] $
+                mk_random_fun_lift fTs t;
+        val t = mk_random_fun_lift fTs (nth random_auxs k $ i1 $ j);
         val size = Option.map snd (DatatypeCodegen.find_shortest_path descr k)
           |> the_default 0;
-      in (SOME size, (random $ i1 $ j, Type (tyco, Ts))) end;
-
+      in (SOME size, (t, fTs ---> T)) end;
     val tss = DatatypeAux.interpret_construction descr vs
       { atyp = mk_random_call, dtyp = mk_random_aux_call };
-
     fun mk_consexpr simpleT (c, xs) =
       let
         val (ks, simple_tTs) = split_list xs;
@@ -324,6 +322,7 @@
 
 fun mk_random_datatype descr vs tycos (names, auxnames) (Ts, Us) thy =
   let
+    val _ = DatatypeAux.message "Creating quickcheck generators ...";
     val i = @{term "i\<Colon>code_numeral"};
     val mk_prop_eq = HOLogic.mk_Trueprop o HOLogic.mk_eq;
     fun mk_size_arg k = case DatatypeCodegen.find_shortest_path descr k
@@ -347,35 +346,39 @@
     |> Class.prove_instantiation_exit (K (Class.intro_classes_tac []))
   end;
 
+fun perhaps_constrain thy insts raw_vs =
+  let
+    fun meet_random (T, sort) = Sorts.meet_sort (Sign.classes_of thy) 
+      (Logic.varifyT T, sort);
+    val vtab = Vartab.empty
+      |> fold (fn (v, sort) => Vartab.update ((v, 0), sort)) raw_vs
+      |> fold meet_random insts;
+  in SOME (fn (v, _) => (v, (the o Vartab.lookup vtab) (v, 0)))
+  end handle CLASS_ERROR => NONE;
+
 fun ensure_random_datatype raw_tycos thy =
   let
     val pp = Syntax.pp_global thy;
     val algebra = Sign.classes_of thy;
     val (descr, raw_vs, tycos, (names, auxnames), raw_TUs) =
       DatatypePackage.the_datatype_descr thy raw_tycos;
-    val aTs = (flat o maps snd o maps snd) (DatatypeAux.interpret_construction descr raw_vs
-      { atyp = single, dtyp = K o K });
-    fun meet_random T = Sorts.meet_sort (Sign.classes_of thy) (Logic.varifyT T, @{sort random});
-    val vtab = (Vartab.empty
-      |> fold (fn (v, sort) => Vartab.update ((v, 0), sort)) raw_vs
-      |> fold meet_random aTs
-      |> SOME) handle CLASS_ERROR => NONE;
-    val vconstrain = case vtab of SOME vtab => (fn (v, _) =>
-          (v, (the o Vartab.lookup vtab) (v, 0)))
-      | NONE => I;
-    val vs = map vconstrain raw_vs;
-    val constrain = map_atyps (fn TFree v => TFree (vconstrain v));
+    val random_insts = (map (rpair @{sort random}) o flat o maps snd o maps snd)
+      (DatatypeAux.interpret_construction descr raw_vs { atyp = single, dtyp = (K o K o K) [] });
+    val term_of_insts = (map (rpair @{sort term_of}) o flat o maps snd o maps snd)
+      (DatatypeAux.interpret_construction descr raw_vs { atyp = K [], dtyp = K o K });
     val has_inst = exists (fn tyco =>
       can (Sorts.mg_domain algebra tyco) @{sort random}) tycos;
-  in if is_some vtab andalso not has_inst then
-    (mk_random_datatype descr vs tycos (names, auxnames)
-      ((pairself o map) constrain raw_TUs) thy
-    (*FIXME ephemeral handles*)
-    handle Datatype_Fun => thy
-         | e as TERM (msg, ts) => (tracing (cat_lines (msg :: map (Syntax.string_of_term_global thy) ts)); raise e)
-         | e as TYPE (msg, _, _) =>  (tracing msg; raise e)
-         | e as ERROR msg =>  (tracing msg; raise e))
-  else thy end;
+  in if has_inst then thy
+    else case perhaps_constrain thy (random_insts @ term_of_insts) raw_vs
+     of SOME constrain => (mk_random_datatype descr
+          (map constrain raw_vs) tycos (names, auxnames)
+            ((pairself o map o map_atyps) (fn TFree v => TFree (constrain v)) raw_TUs) thy
+            (*FIXME ephemeral handles*)
+          handle e as TERM (msg, ts) => (tracing (cat_lines (msg :: map (Syntax.string_of_term_global thy) ts)); raise e)
+               | e as TYPE (msg, _, _) => (tracing msg; raise e)
+               | e as ERROR msg => (tracing msg; raise e))
+      | NONE => thy
+  end;
 
 
 (** setup **)