--- a/src/HOL/HOL.thy Sat May 16 11:28:23 2009 +0200
+++ b/src/HOL/HOL.thy Sat May 16 15:24:35 2009 +0200
@@ -1987,6 +1987,18 @@
subsubsection {* Quickcheck *}
+ML {*
+structure Quickcheck_RecFun_Simp_Thms = NamedThmsFun
+(
+ val name = "quickcheck_recfun_simp"
+ val description = "simplification rules of recursive functions as needed by Quickcheck"
+)
+*}
+
+setup {*
+ Quickcheck_RecFun_Simp_Thms.setup
+*}
+
setup {*
Quickcheck.add_generator ("SML", Codegen.test_term)
*}
--- a/src/HOL/Tools/function_package/fundef_package.ML Sat May 16 11:28:23 2009 +0200
+++ b/src/HOL/Tools/function_package/fundef_package.ML Sat May 16 15:24:35 2009 +0200
@@ -37,7 +37,8 @@
val simp_attribs = map (Attrib.internal o K)
[Simplifier.simp_add,
Code.add_default_eqn_attribute,
- Nitpick_Const_Simp_Thms.add]
+ Nitpick_Const_Simp_Thms.add,
+ Quickcheck_RecFun_Simp_Thms.add]
val psimp_attribs = map (Attrib.internal o K)
[Simplifier.simp_add,
--- a/src/HOL/Tools/primrec_package.ML Sat May 16 11:28:23 2009 +0200
+++ b/src/HOL/Tools/primrec_package.ML Sat May 16 15:24:35 2009 +0200
@@ -247,7 +247,7 @@
val spec' = (map o apfst)
(fn (b, attrs) => (qualify b, Code.add_default_eqn_attrib :: attrs)) spec;
val simp_atts = map (Attrib.internal o K)
- [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add];
+ [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add];
in
lthy
|> set_group ? LocalTheory.set_group (serial_string ())
--- a/src/HOL/Tools/recdef_package.ML Sat May 16 11:28:23 2009 +0200
+++ b/src/HOL/Tools/recdef_package.ML Sat May 16 15:24:35 2009 +0200
@@ -208,7 +208,7 @@
congs wfs name R eqs;
val rules = (map o map) fst (partition_eq (eq_snd (op = : int * int -> bool)) rules_idx);
val simp_att = if null tcs then [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add,
- Code.add_default_eqn_attribute] else [];
+ Code.add_default_eqn_attribute, Quickcheck_RecFun_Simp_Thms.add] else [];
val ((simps' :: rules', [induct']), thy) =
thy
--- a/src/HOL/ex/predicate_compile.ML Sat May 16 11:28:23 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML Sat May 16 15:24:35 2009 +0200
@@ -18,8 +18,9 @@
val code_pred_cmd: string -> Proof.context -> Proof.state
val print_alternative_rules: theory -> theory (*FIXME diagnostic command?*)
val do_proofs: bool ref
- val pred_intros: theory -> string -> thm list
- val get_nparams: theory -> string -> int
+ val pred_intros : theory -> string -> thm list
+ val get_nparams : theory -> string -> int
+ val pred_term_of : theory -> term -> term option
end;
structure Predicate_Compile : PREDICATE_COMPILE =
@@ -270,7 +271,7 @@
datatype hmode = Mode of mode * int list * hmode option list; (*FIXME don't understand
why there is another mode type!?*)
-fun modes_of modes t =
+fun modes_of_term modes t =
let
val ks = 1 upto length (binder_types (fastype_of t));
val default = [Mode (([], ks), ks, [])];
@@ -288,7 +289,7 @@
in map (fn x => Mode (m, is', x)) (cprods (map
(fn (NONE, _) => [NONE]
| (SOME js, arg) => map SOME (filter
- (fn Mode (_, js', _) => js=js') (modes_of modes arg)))
+ (fn Mode (_, js', _) => js=js') (modes_of_term modes arg)))
(iss ~~ args1)))
end
end)) (AList.lookup op = modes name)
@@ -317,13 +318,13 @@
term_vs t subset vs andalso
forall is_eqT dupTs
end)
- (modes_of modes t handle Option =>
+ (modes_of_term modes t handle Option =>
error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
| Negprem (us, t) => find_first (fn Mode (_, is, _) =>
length us = length is andalso
terms_vs us subset vs andalso
term_vs t subset vs)
- (modes_of modes t handle Option =>
+ (modes_of_term modes t handle Option =>
error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
| Sidecond t => if term_vs t subset vs then SOME (Mode (([], []), [], []))
else NONE
@@ -1426,4 +1427,27 @@
- Naming of auxiliary rules necessary?
*)
+(* transformation for code generation *)
+
+fun pred_term_of thy t = let
+ val (vars, body) = strip_abs t
+ val (pred, all_args) = strip_comb body
+ val (name, T) = dest_Const pred
+ val (params, args) = chop (get_nparams thy name) all_args
+ val user_mode = flat (map_index
+ (fn (i, t) => case t of Bound j => if j < length vars then [] else [i+1] | _ => [i+1])
+ args)
+ val (inargs, _) = get_args user_mode args
+ val all_modes = Symtab.dest (#modes (IndCodegenData.get thy))
+ val modes = filter (fn Mode (_, is, _) => is = user_mode) (modes_of_term all_modes (list_comb (pred, params)))
+ fun compile m = list_comb (compile_expr thy all_modes (SOME m, list_comb (pred, params)), inargs)
+ in
+ case modes of
+ [] => (let val _ = error "No mode possible for this term" in NONE end)
+ | [m] => SOME (compile m)
+ | ms => (let val _ = warning "Multiple modes possible for this term"
+ in SOME (compile (hd ms)) end)
+ end;
+
end;
+