--- a/src/HOL/ex/predicate_compile.ML Fri May 15 10:01:57 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML Fri May 15 15:56:28 2009 +0200
@@ -18,8 +18,9 @@
val code_pred_cmd: string -> Proof.context -> Proof.state
val print_alternative_rules: theory -> theory (*FIXME diagnostic command?*)
val do_proofs: bool ref
- val pred_intros: theory -> string -> thm list
- val get_nparams: theory -> string -> int
+ val pred_intros : theory -> string -> thm list
+ val get_nparams : theory -> string -> int
+ val pred_term_of : theory -> term -> term option
end;
structure Predicate_Compile : PREDICATE_COMPILE =
@@ -270,7 +271,7 @@
datatype hmode = Mode of mode * int list * hmode option list; (*FIXME don't understand
why there is another mode type!?*)
-fun modes_of modes t =
+fun modes_of_term modes t =
let
val ks = 1 upto length (binder_types (fastype_of t));
val default = [Mode (([], ks), ks, [])];
@@ -288,7 +289,7 @@
in map (fn x => Mode (m, is', x)) (cprods (map
(fn (NONE, _) => [NONE]
| (SOME js, arg) => map SOME (filter
- (fn Mode (_, js', _) => js=js') (modes_of modes arg)))
+ (fn Mode (_, js', _) => js=js') (modes_of_term modes arg)))
(iss ~~ args1)))
end
end)) (AList.lookup op = modes name)
@@ -317,13 +318,13 @@
term_vs t subset vs andalso
forall is_eqT dupTs
end)
- (modes_of modes t handle Option =>
+ (modes_of_term modes t handle Option =>
error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
| Negprem (us, t) => find_first (fn Mode (_, is, _) =>
length us = length is andalso
terms_vs us subset vs andalso
term_vs t subset vs)
- (modes_of modes t handle Option =>
+ (modes_of_term modes t handle Option =>
error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
| Sidecond t => if term_vs t subset vs then SOME (Mode (([], []), [], []))
else NONE
@@ -1426,4 +1427,27 @@
- Naming of auxiliary rules necessary?
*)
+(* transformation for code generation *)
+
+fun pred_term_of thy t = let
+ val (vars, body) = strip_abs t
+ val (pred, all_args) = strip_comb body
+ val (name, T) = dest_Const pred
+ val (params, args) = chop (get_nparams thy name) all_args
+ val user_mode = flat (map_index
+ (fn (i, t) => case t of Bound j => if j < length vars then [] else [i+1] | _ => [i+1])
+ args)
+ val (inargs, _) = get_args user_mode args
+ val all_modes = Symtab.dest (#modes (IndCodegenData.get thy))
+ val modes = filter (fn Mode (_, is, _) => is = user_mode) (modes_of_term all_modes (list_comb (pred, params)))
+ fun compile m = list_comb (compile_expr thy all_modes (SOME m, list_comb (pred, params)), inargs)
+ in
+ case modes of
+ [] => (let val _ = error "No mode possible for this term" in NONE end)
+ | [m] => SOME (compile m)
+ | ms => (let val _ = warning "Multiple modes possible for this term"
+ in SOME (compile (hd ms)) end)
+ end;
+
end;
+