--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Sat Oct 24 16:55:35 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Sat Oct 24 16:55:37 2009 +0200
@@ -2112,10 +2112,24 @@
fun prove_by_skip thy _ _ _ _ compiled_terms =
map_preds_modes (fn pred => fn mode => fn t => Drule.standard (setmp quick_and_dirty true (SkipProof.make_thm thy) t))
compiled_terms
+
+fun dest_prem thy params t =
+ (case strip_comb t of
+ (v as Free _, ts) => if v mem params then Prem (ts, v) else Sidecond t
+ | (c as Const (@{const_name Not}, _), [t]) => (case dest_prem thy params t of
+ Prem (ts, t) => Negprem (ts, t)
+ | Negprem _ => error ("Double negation not allowed in premise: " ^ (Syntax.string_of_term_global thy (c $ t)))
+ | Sidecond t => Sidecond (c $ t))
+ | (c as Const (s, _), ts) =>
+ if is_registered thy s then
+ let val (ts1, ts2) = chop (nparams_of thy s) ts
+ in Prem (ts2, list_comb (c, ts1)) end
+ else Sidecond t
+ | _ => Sidecond t)
-fun prepare_intrs thy prednames =
+fun prepare_intrs thy prednames intros =
let
- val intrs = maps (intros_of thy) prednames
+ val intrs = intros
|> map (Logic.unvarify o prop_of)
val nparams = nparams_of thy (hd prednames)
val extra_modes = all_modes_of thy |> filter_out (fn (name, _) => member (op =) prednames name)
@@ -2124,25 +2138,12 @@
val params = List.take (snd (strip_comb u), nparams);
val param_vs = maps term_vs params
val all_vs = terms_vs intrs
- fun dest_prem t =
- (case strip_comb t of
- (v as Free _, ts) => if v mem params then Prem (ts, v) else Sidecond t
- | (c as Const (@{const_name Not}, _), [t]) => (case dest_prem t of
- Prem (ts, t) => Negprem (ts, t)
- | Negprem _ => error ("Double negation not allowed in premise: " ^ (Syntax.string_of_term_global thy (c $ t)))
- | Sidecond t => Sidecond (c $ t))
- | (c as Const (s, _), ts) =>
- if is_registered thy s then
- let val (ts1, ts2) = chop (nparams_of thy s) ts
- in Prem (ts2, list_comb (c, ts1)) end
- else Sidecond t
- | _ => Sidecond t)
fun add_clause intr (clauses, arities) =
let
val _ $ t = Logic.strip_imp_concl intr;
val (Const (name, T), ts) = strip_comb t;
val (ts1, ts2) = chop nparams ts;
- val prems = map (dest_prem o HOLogic.dest_Trueprop) (Logic.strip_imp_prems intr);
+ val prems = map (dest_prem thy params o HOLogic.dest_Trueprop) (Logic.strip_imp_prems intr);
val (Ts, Us) = chop nparams (binder_types T)
in
(AList.update op = (name, these (AList.lookup op = clauses name) @
@@ -2177,14 +2178,84 @@
val all_modes = map (fn (s, T) => (s, modes_of_typ T)) preds
in (preds, nparams, all_vs, param_vs, extra_modes, clauses, all_modes) end;
+fun check_format_of_intro_rule thy intro =
+ let
+ val concl = Logic.strip_imp_concl (prop_of intro)
+ val (p, args) = strip_comb (HOLogic.dest_Trueprop concl)
+ val params = List.take (args, nparams_of thy (fst (dest_Const p)))
+ fun check_arg arg = case HOLogic.strip_tupleT (fastype_of arg) of
+ (Ts as _ :: _ :: _) =>
+ if (length (HOLogic.strip_tuple arg) = length Ts) then true
+ else error ("Format of introduction rule is invalid: tuples must be expanded:"
+ ^ (Syntax.string_of_term_global thy arg) ^ " in " ^
+ (Display.string_of_thm_global thy intro))
+ | _ => true
+ val prems = Logic.strip_imp_prems (prop_of intro)
+ fun check_prem (Prem (args, _)) = forall check_arg args
+ | check_prem (Negprem (args, _)) = forall check_arg args
+ | check_prem _ = true
+ in
+ forall check_arg args andalso
+ forall (check_prem o dest_prem thy params o HOLogic.dest_Trueprop) prems
+ end
+
+fun expand_tuples thy intro =
+ let
+ fun rewrite_args [] (intro_t, names) = (intro_t, names)
+ | rewrite_args (arg::args) (intro_t, names) =
+ (case HOLogic.strip_tupleT (fastype_of arg) of
+ (Ts as _ :: _ :: _) =>
+ let
+ fun rewrite_arg' (Const ("Pair", _) $ _ $ t2, Type ("*", [_, T2]))
+ (args, (intro_t, names)) = rewrite_arg' (t2, T2) (args, (intro_t, names))
+ | rewrite_arg' (t, Type ("*", [T1, T2])) (args, (intro_t, names)) =
+ let
+ val [x, y] = Name.variant_list names ["x", "y"]
+ val pat = (t, HOLogic.mk_prod (Free (x, T1), Free (y, T2)))
+ val _ = tracing ("Rewriting term " ^
+ (Syntax.string_of_term_global thy (fst pat)) ^ " to " ^
+ (Syntax.string_of_term_global thy (snd pat)) ^ " in " ^
+ (Syntax.string_of_term_global thy intro_t))
+ val intro_t' = Pattern.rewrite_term thy [pat] [] intro_t
+ val args' = map (Pattern.rewrite_term thy [pat] []) args
+ in
+ rewrite_arg' (Free (y, T2), T2) (args', (intro_t', x::y::names))
+ end
+ | rewrite_arg' _ (args, (intro_t, names)) = (args, (intro_t, names))
+ val (args', (intro_t', names')) = rewrite_arg' (arg, fastype_of arg)
+ (args, (intro_t, names))
+ in
+ rewrite_args args' (intro_t', names')
+ end
+ | _ => rewrite_args args (intro_t, names))
+ fun rewrite_prem (Prem (args, _)) = rewrite_args args
+ | rewrite_prem (Negprem (args, _)) = rewrite_args args
+ | rewrite_prem _ = I
+ val intro_t = Logic.unvarify (prop_of intro)
+ val frees = Term.add_free_names intro_t []
+ val concl = Logic.strip_imp_concl intro_t
+ val (p, args) = strip_comb (HOLogic.dest_Trueprop concl)
+ val params = List.take (args, nparams_of thy (fst (dest_Const p)))
+ val (intro_t', frees') = rewrite_args args (intro_t, frees)
+ val (intro_t', _) =
+ fold (rewrite_prem o dest_prem thy params o HOLogic.dest_Trueprop)
+ (Logic.strip_imp_prems intro_t') (intro_t', frees')
+ val _ = tracing ("intro_t': " ^ (Syntax.string_of_term_global thy intro_t'))
+ in
+ Goal.prove (ProofContext.init thy) (Term.add_free_names intro_t' []) []
+ intro_t' (fn _ => setmp quick_and_dirty true SkipProof.cheat_tac thy)
+ end
+
(** main function of predicate compiler **)
fun add_equations_of steps prednames thy =
let
val _ = Output.tracing ("Starting predicate compiler for predicates " ^ commas prednames ^ "...")
val _ = Output.tracing (commas (map (Display.string_of_thm_global thy) (maps (intros_of thy) prednames)))
+ val intros' = map (expand_tuples thy) (maps (intros_of thy) prednames)
+ val _ = map (check_format_of_intro_rule thy) intros'
val (preds, nparams, all_vs, param_vs, extra_modes, clauses, all_modes) =
- prepare_intrs thy prednames
+ prepare_intrs thy prednames intros'
val _ = Output.tracing "Infering modes..."
val moded_clauses = #infer_modes steps thy extra_modes all_modes param_vs clauses
val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses