merged
authorbulwahn
Sat, 16 May 2009 15:24:35 +0200
changeset 31173 bbe9e29b9672
parent 31172 74d72ba262fb (diff)
parent 31167 8741df04d1ae (current diff)
child 31174 f1f1e9b53c81
merged
src/HOL/HOL.thy
--- a/src/HOL/HOL.thy	Sat May 16 11:28:23 2009 +0200
+++ b/src/HOL/HOL.thy	Sat May 16 15:24:35 2009 +0200
@@ -1987,6 +1987,18 @@
 
 subsubsection {* Quickcheck *}
 
+ML {*
+structure Quickcheck_RecFun_Simp_Thms = NamedThmsFun
+(
+  val name = "quickcheck_recfun_simp"
+  val description = "simplification rules of recursive functions as needed by Quickcheck"
+)
+*}
+
+setup {*
+  Quickcheck_RecFun_Simp_Thms.setup
+*}
+
 setup {*
   Quickcheck.add_generator ("SML", Codegen.test_term)
 *}
--- a/src/HOL/Tools/function_package/fundef_package.ML	Sat May 16 11:28:23 2009 +0200
+++ b/src/HOL/Tools/function_package/fundef_package.ML	Sat May 16 15:24:35 2009 +0200
@@ -37,7 +37,8 @@
 val simp_attribs = map (Attrib.internal o K)
     [Simplifier.simp_add,
      Code.add_default_eqn_attribute,
-     Nitpick_Const_Simp_Thms.add]
+     Nitpick_Const_Simp_Thms.add,
+     Quickcheck_RecFun_Simp_Thms.add]
 
 val psimp_attribs = map (Attrib.internal o K)
     [Simplifier.simp_add,
--- a/src/HOL/Tools/primrec_package.ML	Sat May 16 11:28:23 2009 +0200
+++ b/src/HOL/Tools/primrec_package.ML	Sat May 16 15:24:35 2009 +0200
@@ -247,7 +247,7 @@
     val spec' = (map o apfst)
       (fn (b, attrs) => (qualify b, Code.add_default_eqn_attrib :: attrs)) spec;
     val simp_atts = map (Attrib.internal o K)
-      [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add];
+      [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add];
   in
     lthy
     |> set_group ? LocalTheory.set_group (serial_string ())
--- a/src/HOL/Tools/recdef_package.ML	Sat May 16 11:28:23 2009 +0200
+++ b/src/HOL/Tools/recdef_package.ML	Sat May 16 15:24:35 2009 +0200
@@ -208,7 +208,7 @@
                congs wfs name R eqs;
     val rules = (map o map) fst (partition_eq (eq_snd (op = : int * int -> bool)) rules_idx);
     val simp_att = if null tcs then [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add,
-      Code.add_default_eqn_attribute] else [];
+      Code.add_default_eqn_attribute, Quickcheck_RecFun_Simp_Thms.add] else [];
 
     val ((simps' :: rules', [induct']), thy) =
       thy
--- a/src/HOL/ex/predicate_compile.ML	Sat May 16 11:28:23 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Sat May 16 15:24:35 2009 +0200
@@ -18,8 +18,9 @@
   val code_pred_cmd: string -> Proof.context -> Proof.state
   val print_alternative_rules: theory -> theory (*FIXME diagnostic command?*)
   val do_proofs: bool ref
-  val pred_intros: theory -> string -> thm list
-  val get_nparams: theory -> string -> int
+  val pred_intros : theory -> string -> thm list
+  val get_nparams : theory -> string -> int
+  val pred_term_of : theory -> term -> term option
 end;
 
 structure Predicate_Compile : PREDICATE_COMPILE =
@@ -270,7 +271,7 @@
 datatype hmode = Mode of mode * int list * hmode option list; (*FIXME don't understand
   why there is another mode type!?*)
 
-fun modes_of modes t =
+fun modes_of_term modes t =
   let
     val ks = 1 upto length (binder_types (fastype_of t));
     val default = [Mode (([], ks), ks, [])];
@@ -288,7 +289,7 @@
           in map (fn x => Mode (m, is', x)) (cprods (map
             (fn (NONE, _) => [NONE]
               | (SOME js, arg) => map SOME (filter
-                  (fn Mode (_, js', _) => js=js') (modes_of modes arg)))
+                  (fn Mode (_, js', _) => js=js') (modes_of_term modes arg)))
                     (iss ~~ args1)))
           end
         end)) (AList.lookup op = modes name)
@@ -317,13 +318,13 @@
             term_vs t subset vs andalso
             forall is_eqT dupTs
           end)
-            (modes_of modes t handle Option =>
+            (modes_of_term modes t handle Option =>
                error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
       | Negprem (us, t) => find_first (fn Mode (_, is, _) =>
             length us = length is andalso
             terms_vs us subset vs andalso
             term_vs t subset vs)
-            (modes_of modes t handle Option =>
+            (modes_of_term modes t handle Option =>
                error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
       | Sidecond t => if term_vs t subset vs then SOME (Mode (([], []), [], []))
           else NONE
@@ -1426,4 +1427,27 @@
 - Naming of auxiliary rules necessary?
 *)
 
+(* transformation for code generation *)
+
+fun pred_term_of thy t = let
+   val (vars, body) = strip_abs t
+   val (pred, all_args) = strip_comb body
+   val (name, T) = dest_Const pred 
+   val (params, args) = chop (get_nparams thy name) all_args
+   val user_mode = flat (map_index
+      (fn (i, t) => case t of Bound j => if j < length vars then [] else [i+1] | _ => [i+1])
+        args)
+  val (inargs, _) = get_args user_mode args
+  val all_modes = Symtab.dest (#modes (IndCodegenData.get thy))
+  val modes = filter (fn Mode (_, is, _) => is = user_mode) (modes_of_term all_modes (list_comb (pred, params)))
+  fun compile m = list_comb (compile_expr thy all_modes (SOME m, list_comb (pred, params)), inargs)
+  in
+    case modes of
+      []  => (let val _ = error "No mode possible for this term" in NONE end)
+    | [m] => SOME (compile m)
+    | ms  => (let val _ = warning "Multiple modes possible for this term"
+        in SOME (compile (hd ms)) end)
+  end;
+
 end;
+