--- a/src/Tools/induct.ML Sat Jan 09 23:22:56 2010 +0100
+++ b/src/Tools/induct.ML Sun Jan 10 18:01:04 2010 +0100
@@ -10,6 +10,8 @@
val atomize: thm list
val rulify: thm list
val rulify_fallback: thm list
+ val dest_def: term -> (term * term) option
+ val trivial_tac: int -> tactic
end;
signature INDUCT =
@@ -42,6 +44,9 @@
val coinduct_type: string -> attribute
val coinduct_pred: string -> attribute
val coinduct_del: attribute
+ val map_simpset: (simpset -> simpset) -> Context.generic -> Context.generic
+ val add_simp_rule: attribute
+ val no_simpN: string
val casesN: string
val inductN: string
val coinductN: string
@@ -50,19 +55,24 @@
val setN: string
(*proof methods*)
val fix_tac: Proof.context -> int -> (string * typ) list -> int -> tactic
- val add_defs: (binding option * term) option list -> Proof.context ->
+ val add_defs: (binding option * (term * bool)) option list -> Proof.context ->
(term option list * thm list) * Proof.context
val atomize_term: theory -> term -> term
+ val atomize_cterm: conv
val atomize_tac: int -> tactic
val inner_atomize_tac: int -> tactic
val rulified_term: thm -> theory * term
val rulify_tac: int -> tactic
+ val simplified_rule: Proof.context -> thm -> thm
+ val simplify_tac: Proof.context -> int -> tactic
+ val trivial_tac: int -> tactic
+ val rotate_tac: int -> int -> int -> tactic
val internalize: int -> thm -> thm
val guess_instance: Proof.context -> thm -> int -> thm -> thm Seq.seq
val cases_tac: Proof.context -> term option list list -> thm option ->
thm list -> int -> cases_tactic
val get_inductT: Proof.context -> term option list list -> thm list list
- val induct_tac: Proof.context -> (binding option * term) option list list ->
+ val induct_tac: Proof.context -> bool -> (binding option * (term * bool)) option list list ->
(string * typ) list list -> term option list -> thm list option ->
thm list -> int -> cases_tactic
val coinduct_tac: Proof.context -> term option list -> term option list -> thm option ->
@@ -107,6 +117,77 @@
+(** constraint simplification **)
+
+(* rearrange parameters and premises to allow application of one-point-rules *)
+
+fun swap_params_conv ctxt i j cv =
+ let
+ fun conv1 0 ctxt = Conv.forall_conv (cv o snd) ctxt
+ | conv1 k ctxt =
+ Conv.rewr_conv @{thm swap_params} then_conv
+ Conv.forall_conv (conv1 (k-1) o snd) ctxt
+ fun conv2 0 ctxt = conv1 j ctxt
+ | conv2 k ctxt = Conv.forall_conv (conv2 (k-1) o snd) ctxt
+ in conv2 i ctxt end;
+
+fun swap_prems_conv 0 = Conv.all_conv
+ | swap_prems_conv i =
+ Conv.implies_concl_conv (swap_prems_conv (i-1)) then_conv
+ Conv.rewr_conv Drule.swap_prems_eq
+
+fun drop_judgment ctxt = ObjectLogic.drop_judgment (ProofContext.theory_of ctxt);
+
+fun find_eq ctxt t =
+ let
+ val l = length (Logic.strip_params t);
+ val Hs = Logic.strip_assums_hyp t;
+ fun find (i, t) =
+ case Data.dest_def (drop_judgment ctxt t) of
+ SOME (Bound j, _) => SOME (i, j)
+ | SOME (_, Bound j) => SOME (i, j)
+ | _ => NONE
+ in
+ case get_first find (map_index I Hs) of
+ NONE => NONE
+ | SOME (0, 0) => NONE
+ | SOME (i, j) => SOME (i, l-j-1, j)
+ end;
+
+fun mk_swap_rrule ctxt ct = case find_eq ctxt (term_of ct) of
+ NONE => NONE
+ | SOME (i, k, j) => SOME (swap_params_conv ctxt k j (K (swap_prems_conv i)) ct);
+
+val rearrange_eqs_simproc = Simplifier.simproc
+ (Thm.theory_of_thm Drule.swap_prems_eq) "rearrange_eqs" ["all t"]
+ (fn thy => fn ss => fn t =>
+ mk_swap_rrule (Simplifier.the_context ss) (cterm_of thy t))
+
+(* rotate k premises to the left by j, skipping over first j premises *)
+
+fun rotate_conv 0 j 0 = Conv.all_conv
+ | rotate_conv 0 j k = swap_prems_conv j then_conv rotate_conv 1 j (k-1)
+ | rotate_conv i j k = Conv.implies_concl_conv (rotate_conv (i-1) j k);
+
+fun rotate_tac j 0 = K all_tac
+ | rotate_tac j k = SUBGOAL (fn (goal, i) => CONVERSION (rotate_conv
+ j (length (Logic.strip_assums_hyp goal) - j - k) k) i);
+
+(* rulify operators around definition *)
+
+fun rulify_defs_conv ctxt ct =
+ if exists_subterm (is_some o Data.dest_def) (term_of ct) andalso
+ not (is_some (Data.dest_def (drop_judgment ctxt (term_of ct))))
+ then
+ (Conv.forall_conv (rulify_defs_conv o snd) ctxt else_conv
+ Conv.implies_conv (Conv.try_conv (rulify_defs_conv ctxt))
+ (Conv.try_conv (rulify_defs_conv ctxt)) else_conv
+ Conv.first_conv (map Conv.rewr_conv Data.rulify) then_conv
+ Conv.try_conv (rulify_defs_conv ctxt)) ct
+ else Conv.no_conv ct;
+
+
+
(** induct data **)
(* rules *)
@@ -132,23 +213,25 @@
structure InductData = Generic_Data
(
- type T = (rules * rules) * (rules * rules) * (rules * rules);
+ type T = (rules * rules) * (rules * rules) * (rules * rules) * simpset;
val empty =
((init_rules (left_var_prem o #2), init_rules (Thm.major_prem_of o #2)),
(init_rules (right_var_concl o #2), init_rules (Thm.major_prem_of o #2)),
- (init_rules (left_var_concl o #2), init_rules (Thm.concl_of o #2)));
+ (init_rules (left_var_concl o #2), init_rules (Thm.concl_of o #2)),
+ empty_ss addsimprocs [rearrange_eqs_simproc] addsimps [Drule.norm_hhf_eq]);
val extend = I;
- fun merge (((casesT1, casesP1), (inductT1, inductP1), (coinductT1, coinductP1)),
- ((casesT2, casesP2), (inductT2, inductP2), (coinductT2, coinductP2))) =
+ fun merge (((casesT1, casesP1), (inductT1, inductP1), (coinductT1, coinductP1), simpset1),
+ ((casesT2, casesP2), (inductT2, inductP2), (coinductT2, coinductP2), simpset2)) =
((Item_Net.merge (casesT1, casesT2), Item_Net.merge (casesP1, casesP2)),
- (Item_Net.merge (inductT1, inductT2), Item_Net.merge (inductP1, inductP2)),
- (Item_Net.merge (coinductT1, coinductT2), Item_Net.merge (coinductP1, coinductP2)));
+ (Item_Net.merge (inductT1, inductT2), Item_Net.merge (inductP1, inductP2)),
+ (Item_Net.merge (coinductT1, coinductT2), Item_Net.merge (coinductP1, coinductP2)),
+ merge_ss (simpset1, simpset2));
);
val get_local = InductData.get o Context.Proof;
fun dest_rules ctxt =
- let val ((casesT, casesP), (inductT, inductP), (coinductT, coinductP)) = get_local ctxt in
+ let val ((casesT, casesP), (inductT, inductP), (coinductT, coinductP), _) = get_local ctxt in
{type_cases = Item_Net.content casesT,
pred_cases = Item_Net.content casesP,
type_induct = Item_Net.content inductT,
@@ -158,7 +241,7 @@
end;
fun print_rules ctxt =
- let val ((casesT, casesP), (inductT, inductP), (coinductT, coinductP)) = get_local ctxt in
+ let val ((casesT, casesP), (inductT, inductP), (coinductT, coinductP), _) = get_local ctxt in
[pretty_rules ctxt "coinduct type:" coinductT,
pretty_rules ctxt "coinduct pred:" coinductP,
pretty_rules ctxt "induct type:" inductT,
@@ -206,9 +289,10 @@
fun del_att which = Thm.declaration_attribute (fn th => InductData.map (which (pairself (fn rs =>
fold Item_Net.remove (filter_rules rs th) rs))));
-fun map1 f (x, y, z) = (f x, y, z);
-fun map2 f (x, y, z) = (x, f y, z);
-fun map3 f (x, y, z) = (x, y, f z);
+fun map1 f (x, y, z, s) = (f x, y, z, s);
+fun map2 f (x, y, z, s) = (x, f y, z, s);
+fun map3 f (x, y, z, s) = (x, y, f z, s);
+fun map4 f (x, y, z, s) = (x, y, z, f s);
fun add_casesT rule x = map1 (apfst (Item_Net.update rule)) x;
fun add_casesP rule x = map1 (apsnd (Item_Net.update rule)) x;
@@ -234,12 +318,17 @@
val coinduct_pred = mk_att add_coinductP consumes1;
val coinduct_del = del_att map3;
+fun map_simpset f = InductData.map (map4 f);
+fun add_simp_rule (ctxt, thm) =
+ (map_simpset (fn ss => ss addsimps [thm]) ctxt, thm);
+
end;
(** attribute syntax **)
+val no_simpN = "no_simp";
val casesN = "cases";
val inductN = "induct";
val coinductN = "coinduct";
@@ -268,7 +357,9 @@
Attrib.setup @{binding induct} (attrib induct_type induct_pred induct_del)
"declaration of induction rule" #>
Attrib.setup @{binding coinduct} (attrib coinduct_type coinduct_pred coinduct_del)
- "declaration of coinduction rule";
+ "declaration of coinduction rule" #>
+ Attrib.setup @{binding induct_simp} (Scan.succeed add_simp_rule)
+ "declaration of rules for simplifying induction or cases rules";
end;
@@ -362,7 +453,8 @@
ruleq
|> Seq.maps (Rule_Cases.consume [] facts)
|> Seq.maps (fn ((cases, (_, more_facts)), rule) =>
- CASES (Rule_Cases.make_common false (thy, Thm.prop_of rule) cases)
+ CASES (Rule_Cases.make_common (thy,
+ Thm.prop_of (Rule_Cases.internalize_params rule)) cases)
(Method.insert_tac more_facts i THEN Tactic.rtac rule i) st)
end;
@@ -409,6 +501,22 @@
(Simplifier.rewrite_goal_tac [@{thm Pure.conjunction_imp}] THEN' Goal.norm_hhf_tac);
+(* simplify *)
+
+fun simplify_conv ctxt ct =
+ if exists_subterm (is_some o Data.dest_def) (term_of ct) then
+ (Conv.try_conv (rulify_defs_conv ctxt) then_conv
+ Simplifier.full_rewrite (Simplifier.context ctxt (#4 (get_local ctxt)))) ct
+ else Conv.all_conv ct;
+
+fun simplified_rule ctxt thm =
+ Conv.fconv_rule (Conv.prems_conv ~1 (simplify_conv ctxt)) thm;
+
+fun simplify_tac ctxt = CONVERSION (simplify_conv ctxt);
+
+val trivial_tac = Data.trivial_tac;
+
+
(* prepare rule *)
fun rule_instance ctxt inst rule =
@@ -548,11 +656,19 @@
fun add_defs def_insts =
let
- fun add (SOME (SOME x, t)) ctxt =
+ fun add (SOME (_, (t, true))) ctxt = ((SOME t, []), ctxt)
+ | add (SOME (SOME x, (t, _))) ctxt =
let val ([(lhs, (_, th))], ctxt') =
LocalDefs.add_defs [((x, NoSyn), (Thm.empty_binding, t))] ctxt
in ((SOME lhs, [th]), ctxt') end
- | add (SOME (NONE, t)) ctxt = ((SOME t, []), ctxt)
+ | add (SOME (NONE, (t as Free _, _))) ctxt = ((SOME t, []), ctxt)
+ | add (SOME (NONE, (t, _))) ctxt =
+ let
+ val ([s], _) = Name.variants ["x"] (Variable.names_of ctxt);
+ val ([(lhs, (_, th))], ctxt') =
+ LocalDefs.add_defs [((Binding.name s, NoSyn),
+ (Thm.empty_binding, t))] ctxt
+ in ((SOME lhs, [th]), ctxt') end
| add NONE ctxt = ((NONE, []), ctxt);
in fold_map add def_insts #> apfst (split_list #> apsnd flat) end;
@@ -576,12 +692,12 @@
fun get_inductP ctxt (fact :: _) = map single (find_inductP ctxt (Thm.concl_of fact))
| get_inductP _ _ = [];
-fun induct_tac ctxt def_insts arbitrary taking opt_rule facts =
+fun induct_tac ctxt simp def_insts arbitrary taking opt_rule facts =
let
val thy = ProofContext.theory_of ctxt;
val ((insts, defs), defs_ctxt) = fold_map add_defs def_insts ctxt |>> split_list;
- val atomized_defs = map (map (Conv.fconv_rule ObjectLogic.atomize)) defs;
+ val atomized_defs = map (map (Conv.fconv_rule atomize_cterm)) defs;
fun inst_rule (concls, r) =
(if null insts then `Rule_Cases.get r
@@ -601,8 +717,10 @@
|> tap (trace_rules ctxt inductN o map #2)
|> Seq.of_list |> Seq.maps (Seq.try inst_rule));
- fun rule_cases rule =
- Rule_Cases.make_nested false (Thm.prop_of rule) (rulified_term rule);
+ fun rule_cases ctxt rule =
+ let val rule' = (if simp then simplified_rule ctxt else I)
+ (Rule_Cases.internalize_params rule);
+ in Rule_Cases.make_nested (Thm.prop_of rule') (rulified_term rule') end;
in
(fn i => fn st =>
ruleq
@@ -610,19 +728,32 @@
|> Seq.maps (fn (((cases, concls), (more_consumes, more_facts)), rule) =>
(PRECISE_CONJUNCTS (length concls) (ALLGOALS (fn j =>
(CONJUNCTS (ALLGOALS
- (Method.insert_tac (more_facts @ nth_list atomized_defs (j - 1))
- THEN' fix_tac defs_ctxt
- (nth concls (j - 1) + more_consumes)
- (nth_list arbitrary (j - 1))))
+ let
+ val adefs = nth_list atomized_defs (j - 1);
+ val frees = fold (Term.add_frees o prop_of) adefs [];
+ val xs = nth_list arbitrary (j - 1);
+ val k = nth concls (j - 1) + more_consumes
+ in
+ Method.insert_tac (more_facts @ adefs) THEN'
+ (if simp then
+ rotate_tac k (length adefs) THEN'
+ fix_tac defs_ctxt k
+ (List.partition (member op = frees) xs |> op @)
+ else
+ fix_tac defs_ctxt k xs)
+ end)
THEN' inner_atomize_tac) j))
THEN' atomize_tac) i st |> Seq.maps (fn st' =>
guess_instance ctxt (internalize more_consumes rule) i st'
|> Seq.map (rule_instance ctxt (burrow_options (Variable.polymorphic ctxt) taking))
|> Seq.maps (fn rule' =>
- CASES (rule_cases rule' cases)
+ CASES (rule_cases ctxt rule' cases)
(Tactic.rtac rule' i THEN
PRIMITIVE (singleton (ProofContext.export defs_ctxt ctxt))) st'))))
- THEN_ALL_NEW_CASES rulify_tac
+ THEN_ALL_NEW_CASES
+ ((if simp then simplify_tac ctxt THEN' (TRY o trivial_tac)
+ else K all_tac)
+ THEN_ALL_NEW rulify_tac)
end;
@@ -672,7 +803,8 @@
guess_instance ctxt rule i st
|> Seq.map (rule_instance ctxt (burrow_options (Variable.polymorphic ctxt) taking))
|> Seq.maps (fn rule' =>
- CASES (Rule_Cases.make_common false (thy, Thm.prop_of rule') cases)
+ CASES (Rule_Cases.make_common (thy,
+ Thm.prop_of (Rule_Cases.internalize_params rule')) cases)
(Method.insert_tac more_facts i THEN Tactic.rtac rule' i) st)))
end;
@@ -711,10 +843,15 @@
val inst = Scan.lift (Args.$$$ "_") >> K NONE || Args.term >> SOME;
+val inst' = Scan.lift (Args.$$$ "_") >> K NONE ||
+ Args.term >> (SOME o rpair false) ||
+ Scan.lift (Args.$$$ "(") |-- (Args.term >> (SOME o rpair true)) --|
+ Scan.lift (Args.$$$ ")");
+
val def_inst =
((Scan.lift (Args.binding --| (Args.$$$ "\<equiv>" || Args.$$$ "==")) >> SOME)
- -- Args.term) >> SOME ||
- inst >> Option.map (pair NONE);
+ -- (Args.term >> rpair false)) >> SOME ||
+ inst' >> Option.map (pair NONE);
val free = Args.context -- Args.term >> (fn (_, Free v) => v | (ctxt, t) =>
error ("Bad free variable: " ^ Syntax.string_of_term ctxt t));
@@ -740,11 +877,11 @@
val induct_setup =
Method.setup @{binding induct}
- (P.and_list' (Scan.repeat (unless_more_args def_inst)) --
- (arbitrary -- taking -- Scan.option induct_rule) >>
- (fn (insts, ((arbitrary, taking), opt_rule)) => fn ctxt =>
+ (Args.mode no_simpN -- (P.and_list' (Scan.repeat (unless_more_args def_inst)) --
+ (arbitrary -- taking -- Scan.option induct_rule)) >>
+ (fn (no_simp, (insts, ((arbitrary, taking), opt_rule))) => fn ctxt =>
RAW_METHOD_CASES (fn facts =>
- Seq.DETERM (HEADGOAL (induct_tac ctxt insts arbitrary taking opt_rule facts)))))
+ Seq.DETERM (HEADGOAL (induct_tac ctxt (not no_simp) insts arbitrary taking opt_rule facts)))))
"induction on types or predicates/sets";
val coinduct_setup =