--- a/src/HOL/ex/predicate_compile.ML Tue Aug 04 01:01:23 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML Tue Aug 04 08:34:56 2009 +0200
@@ -17,6 +17,7 @@
val predfun_name_of: theory -> string -> mode -> string
val all_preds_of : theory -> string list
val modes_of: theory -> string -> mode list
+ val string_of_mode : mode -> string
val intros_of: theory -> string -> thm list
val nparams_of: theory -> string -> int
val add_intro: thm -> theory -> theory
@@ -25,12 +26,17 @@
val code_pred: string -> Proof.context -> Proof.state
val code_pred_cmd: string -> Proof.context -> Proof.state
val print_stored_rules: theory -> unit
+ val print_all_modes: theory -> unit
val do_proofs: bool ref
val mk_casesrule : Proof.context -> int -> thm list -> term
val analyze_compr: theory -> term -> term
val eval_ref: (unit -> term Predicate.pred) option ref
val add_equations : string -> theory -> theory
val code_pred_intros_attrib : attribute
+ (* used by Quickcheck_Generator *)
+ val funT_of : mode -> typ -> typ
+ val mk_if_pred : term -> term
+ val mk_Eval : term * term -> term
end;
structure Predicate_Compile : PREDICATE_COMPILE =
@@ -43,8 +49,7 @@
fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single);
-fun new_print_tac s = Tactical.print_tac s
-fun debug_tac msg = (fn st => (Output.tracing msg; Seq.single st));
+fun debug_tac msg = (fn st => (tracing msg; Seq.single st));
val do_proofs = ref true;
@@ -113,7 +118,7 @@
val mk_sup = HOLogic.mk_binop @{const_name sup};
-fun mk_if_predenum cond = Const (@{const_name Predicate.if_pred},
+fun mk_if_pred cond = Const (@{const_name Predicate.if_pred},
HOLogic.boolT --> mk_pred_enumT HOLogic.unitT) $ cond;
fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT
@@ -248,6 +253,18 @@
fold print preds ()
end;
+fun print_all_modes thy =
+ let
+ val _ = writeln ("Inferred modes:")
+ fun print (pred, modes) u =
+ let
+ val _ = writeln ("predicate: " ^ pred)
+ val _ = writeln ("modes: " ^ (commas (map string_of_mode modes)))
+ in u end
+ in
+ fold print (all_modes_of thy) ()
+ end
+
(** preprocessing rules **)
fun imp_prems_conv cv ct =
@@ -465,12 +482,13 @@
end
end)) (AList.lookup op = modes name)
- in (case strip_comb t of
+ in
+ case strip_comb (Envir.eta_contract t) of
(Const (name, _), args) => the_default default (mk_modes name args)
| (Var ((name, _), _), args) => the (mk_modes name args)
| (Free (name, _), args) => the (mk_modes name args)
- | (Abs _, []) => modes_of_param default modes t
- | _ => default)
+ | (Abs _, []) => error "Abs at param position" (* modes_of_param default modes t *)
+ | _ => default
end
datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term;
@@ -529,7 +547,7 @@
in (p, List.filter (fn m => case find_index
(not o check_mode_clause thy param_vs modes m) rs of
~1 => true
- | i => (tracing ("Clause " ^ string_of_int (i+1) ^ " of " ^
+ | i => (Output.tracing ("Clause " ^ string_of_int (i+1) ^ " of " ^
p ^ " violates mode " ^ string_of_mode m); false)) ms)
end;
@@ -547,18 +565,18 @@
(* term construction *)
-(* for simple modes (e.g. parameters) only: better call it param_funT *)
-(* or even better: remove it and only use funT'_of - some modifications to funT'_of necessary *)
-fun funT_of T NONE = T
- | funT_of T (SOME mode) = let
+(* Remark: types of param_funT_of and funT_of are swapped - which is the more
+canonical order? *)
+fun param_funT_of T NONE = T
+ | param_funT_of T (SOME mode) = let
val Ts = binder_types T;
val (Us1, Us2) = get_args mode Ts
in Us1 ---> (mk_pred_enumT (mk_tupleT Us2)) end;
-fun funT'_of (iss, is) T = let
+fun funT_of (iss, is) T = let
val Ts = binder_types T
val (paramTs, argTs) = chop (length iss) Ts
- val paramTs' = map2 (fn SOME is => funT'_of ([], is) | NONE => I) iss paramTs
+ val paramTs' = map2 (fn SOME is => funT_of ([], is) | NONE => I) iss paramTs
val (inargTs, outargTs) = get_args is argTs
in
(paramTs' @ inargTs) ---> (mk_pred_enumT (mk_tupleT outargTs))
@@ -636,9 +654,9 @@
val f' = case f of
Const (name, T) =>
if AList.defined op = modes name then
- Const (predfun_name_of thy name (iss, is'), funT'_of (iss, is') T)
+ Const (predfun_name_of thy name (iss, is'), funT_of (iss, is') T)
else error "compile param: Not an inductive predicate with correct mode"
- | Free (name, T) => Free (name, funT_of T (SOME is'))
+ | Free (name, T) => Free (name, param_funT_of T (SOME is'))
val outTs = dest_tupleT (dest_pred_enumT (body_type (fastype_of f')))
val out_vs = map Free (out_names ~~ outTs)
val params' = map (compile_param thy modes) (ms ~~ params)
@@ -662,9 +680,9 @@
val f' = case f of
Const (name, T) =>
if AList.defined op = modes name then
- Const (predfun_name_of thy name (iss, is'), funT'_of (iss, is') T)
+ Const (predfun_name_of thy name (iss, is'), funT_of (iss, is') T)
else error "compile param: Not an inductive predicate with correct mode"
- | Free (name, T) => Free (name, funT_of T (SOME is'))
+ | Free (name, T) => Free (name, param_funT_of T (SOME is'))
in list_comb (f', params' @ args') end
| compile_param _ _ _ = error "compile params"
@@ -684,7 +702,7 @@
| (Free (name, T), args) =>
(*if name mem param_vs then *)
(* Higher order mode call *)
- let val r = Free (name, funT_of T (SOME is))
+ let val r = Free (name, param_funT_of T (SOME is))
in list_comb (r, args) end)
| compile_expr _ _ _ = error "not a valid inductive expression"
@@ -746,7 +764,7 @@
let
val rest = compile_prems [] vs' names'' ps';
in
- (mk_if_predenum t, rest)
+ (mk_if_pred t, rest)
end
in
compile_match thy constr_vs' eqs out_ts''
@@ -761,7 +779,7 @@
let
val Ts = binder_types T;
val (Ts1, Ts2) = chop (length param_vs) Ts;
- val Ts1' = map2 funT_of Ts1 (fst mode)
+ val Ts1' = map2 param_funT_of Ts1 (fst mode)
val (Us1, Us2) = get_args (snd mode) Ts2;
val xnames = Name.variant_list param_vs
(map (fn i => "x" ^ string_of_int i) (snd mode));
@@ -817,7 +835,7 @@
val argnames = Name.variant_list []
(map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
val (Ts1, Ts2) = chop nparams Ts;
- val Ts1' = map2 funT_of Ts1 (fst mode)
+ val Ts1' = map2 param_funT_of Ts1 (fst mode)
val args = map Free (argnames ~~ (Ts1' @ Ts2))
val (params, io_args) = chop nparams args
val (inargs, outargs) = get_args (snd mode) io_args
@@ -834,7 +852,7 @@
val funpropI = HOLogic.mk_Trueprop (mk_Eval (list_comb (funtrm, funargs),
mk_tuple outargs))
val introtrm = Logic.list_implies (predpropI :: param_eqs, funpropI)
- val _ = Output.tracing (Syntax.string_of_term_global thy introtrm)
+ val _ = tracing (Syntax.string_of_term_global thy introtrm)
val simprules = [defthm, @{thm eval_pred},
@{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}]
val unfolddef_tac = (Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simprules) 1)
@@ -860,7 +878,7 @@
^ (string_of_mode (snd mode))
val Ts = binder_types T;
val (Ts1, Ts2) = chop nparams Ts;
- val Ts1' = map2 funT_of Ts1 (fst mode)
+ val Ts1' = map2 param_funT_of Ts1 (fst mode)
val (Us1, Us2) = get_args (snd mode) Ts2;
val names = Name.variant_list []
(map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
@@ -921,12 +939,12 @@
REPEAT_DETERM (etac @{thm thin_rl} 1)
THEN REPEAT_DETERM (rtac @{thm ext} 1)
THEN (rtac @{thm iffI} 1)
- THEN new_print_tac "prove_param"
+ THEN print_tac "prove_param"
(* proof in one direction *)
THEN (atac 1)
(* proof in the other direction *)
THEN (atac 1)
- THEN new_print_tac "after prove_param"
+ THEN print_tac "after prove_param"
(* let
val (f, args) = strip_comb t
val (params, _) = chop (length ms) args
@@ -964,10 +982,10 @@
(* for the right assumption in first position *)
THEN rotate_tac premposition 1
THEN rtac introrule 1
- THEN new_print_tac "after intro rule"
+ THEN print_tac "after intro rule"
(* work with parameter arguments *)
THEN (atac 1)
- THEN (new_print_tac "parameter goal")
+ THEN (print_tac "parameter goal")
THEN (EVERY (map (prove_param thy modes) (ms ~~ args1)))
THEN (REPEAT_DETERM (atac 1)) end)
else error "Prove expr if case not implemented"
@@ -1110,7 +1128,7 @@
(fn i => EVERY' (select_sup (length clauses) i) i)
(1 upto (length clauses))))
THEN (EVERY (map (prove_clause thy nargs all_vs param_vs modes mode) clauses))
- THEN new_print_tac "proved one direction"
+ THEN print_tac "proved one direction"
end;
(*******************************************************************************************************)
@@ -1168,15 +1186,15 @@
if AList.defined op = modes name then
etac @{thm bindE} 1
THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})))
- THEN new_print_tac "prove_expr2-before"
+ THEN print_tac "prove_expr2-before"
THEN (debug_tac (Syntax.string_of_term_global thy
(prop_of (predfun_elim_of thy name mode))))
THEN (etac (predfun_elim_of thy name mode) 1)
- THEN new_print_tac "prove_expr2"
+ THEN print_tac "prove_expr2"
(* TODO -- FIXME: replace remove_last_goal*)
(* THEN (EVERY (replicate (length args) (remove_last_goal thy))) *)
THEN (EVERY (map (prove_param thy modes) (ms ~~ args)))
- THEN new_print_tac "finished prove_expr2"
+ THEN print_tac "finished prove_expr2"
else error "Prove expr2 if case not implemented"
| _ => etac @{thm bindE} 1)
@@ -1273,7 +1291,7 @@
end;
val prems_tac = prove_prems2 in_ts' param_vs ps
in
- new_print_tac "starting prove_clause2"
+ print_tac "starting prove_clause2"
THEN etac @{thm bindE} 1
THEN (etac @{thm singleE'} 1)
THEN (TRY (etac @{thm Pair_inject} 1))
@@ -1401,10 +1419,10 @@
val clauses' = map (fn (s, cls) => (s, (the (AList.lookup (op =) preds s), cls))) clauses
val _ = tracing "Compiling equations..."
val ts = compile_preds thy' all_vs param_vs (extra_modes @ modes) clauses'
- val _ = map (Output.tracing o (Syntax.string_of_term_global thy')) (flat ts)
+(* val _ = map (tracing o (Syntax.string_of_term_global thy')) (flat ts) *)
val pred_mode =
maps (fn (s, (T, _)) => map (pair (s, T)) ((the o AList.lookup (op =) modes) s)) clauses'
- val _ = Output.tracing "Proving equations..."
+ val _ = tracing "Proving equations..."
val result_thms =
prove_preds thy' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts))
val thy'' = fold (fn (name, result_thms) => fn thy => snd (PureThy.add_thmss
@@ -1486,7 +1504,6 @@
assumes = [("", Logic.strip_imp_prems case_rule)],
binds = [], cases = []}) cases_rules
val case_env = map2 (fn p => fn c => (Long_Name.base_name p, SOME c)) preds cases
- val _ = Output.tracing (commas (map fst case_env))
val lthy'' = ProofContext.add_cases true case_env lthy'
fun after_qed thms =