--- a/src/Pure/raw_simplifier.ML Wed Aug 14 18:59:49 2024 +0200
+++ b/src/Pure/raw_simplifier.ML Wed Aug 14 21:23:22 2024 +0200
@@ -25,9 +25,11 @@
type simpset
val empty_ss: simpset
val merge_ss: simpset * simpset -> simpset
+ datatype proc_kind = Simproc | Congproc of bool
type simproc
val cert_simproc: theory ->
- {name: string, lhss: term list, proc: proc Morphism.entity, identifier: thm list} -> simproc
+ {name: string, kind: proc_kind, lhss: term list,
+ proc: proc Morphism.entity, identifier: thm list} -> simproc
val transform_simproc: morphism -> simproc -> simproc
val trim_context_simproc: simproc -> simproc
val simpset_of: Proof.context -> simpset
@@ -68,6 +70,7 @@
val dest_ss: simpset ->
{simps: (Thm_Name.T * thm) list,
simprocs: (string * term list) list,
+ congprocs: (string * {lhss: term list, proc: proc Morphism.entity}) list,
congs: (cong_name * thm) list,
weak_congs: cong_name list,
loopers: string list,
@@ -210,15 +213,31 @@
(* simplification procedures *)
+datatype proc_kind = Simproc | Congproc of bool;
+
+val is_congproc = fn Congproc _ => true | _ => false;
+val is_weak_congproc = fn Congproc weak => weak | _ => false;
+
+fun map_procs kind f (simprocs, congprocs) =
+ if is_congproc kind then (simprocs, f congprocs) else (f simprocs, congprocs);
+
+fun print_proc_kind Simproc = "simplification procedure"
+ | print_proc_kind (Congproc false) = "simplification procedure (cong)"
+ | print_proc_kind (Congproc true) = "simplification procedure (weak cong)";
+
type proc = Proof.context -> cterm -> thm option;
datatype 'lhs procedure =
Procedure of
{name: string,
+ kind: proc_kind,
lhs: 'lhs,
proc: proc Morphism.entity,
id: stamp * thm list};
+fun procedure_kind (Procedure {kind, ...}) = kind;
+fun procedure_lhs (Procedure {lhs, ...}) = lhs;
+
fun eq_procedure_id (Procedure {id = (s1, ths1), ...}, Procedure {id = (s2, ths2), ...}) =
s1 = s2 andalso eq_list Thm.eq_thm_prop (ths1, ths2);
@@ -249,6 +268,7 @@
A congruence is `weak' if it avoids normalization of some argument.
procs: simplification procedures indexed via discrimination net
simprocs: functions that prove rewrite rules on the fly;
+ congprocs: functions that prove congruence rules on the fly;
mk_rews:
mk: turn simplification thms into rewrite rules;
mk_cong: prepare congruence rules;
@@ -262,7 +282,7 @@
prems: thm list,
depth: int * bool Unsynchronized.ref} *
{congs: thm Congtab.table * cong_name list,
- procs: term procedure Net.net,
+ procs: term procedure Net.net * term procedure Net.net,
mk_rews:
{mk: Proof.context -> thm -> thm list,
mk_cong: Proof.context -> thm -> thm,
@@ -298,11 +318,13 @@
fun dest_procs procs =
Net.entries procs
|> partition_eq eq_procedure_id
- |> map (fn ps as Procedure {name, ...} :: _ => (name, map (fn Procedure {lhs, ...} => lhs) ps));
+ |> map (fn ps as Procedure {name, proc, ...} :: _ =>
+ (name, {lhss = map (fn Procedure {lhs, ...} => lhs) ps, proc = proc}));
-fun dest_ss (ss as Simpset (_, {congs, procs = simprocs, loop_tacs, solvers, ...})) =
+fun dest_ss (ss as Simpset (_, {congs, procs = (simprocs, congprocs), loop_tacs, solvers, ...})) =
{simps = dest_simps ss,
- simprocs = dest_procs simprocs,
+ simprocs = map (apsnd #lhss) (dest_procs simprocs),
+ congprocs = dest_procs congprocs,
congs = dest_congs ss,
weak_congs = #2 congs,
loopers = map fst loop_tacs,
@@ -314,7 +336,7 @@
fun init_ss depth mk_rews term_ord subgoal_tac solvers =
make_simpset ((Net.empty, [], depth),
- ((Congtab.empty, []), Net.empty, mk_rews, term_ord, subgoal_tac, [], solvers));
+ ((Congtab.empty, []), (Net.empty, Net.empty), mk_rews, term_ord, subgoal_tac, [], solvers));
fun default_mk_sym _ th = SOME (th RS Drule.symmetric_thm);
@@ -335,11 +357,11 @@
else
let
val Simpset ({rules = rules1, prems = prems1, depth = depth1},
- {congs = (congs1, weak1), procs = simprocs1, mk_rews, term_ord, subgoal_tac,
+ {congs = (congs1, weak1), procs = (simprocs1, congprocs1), mk_rews, term_ord, subgoal_tac,
loop_tacs = loop_tacs1, solvers = (unsafe_solvers1, solvers1)}) = ss1;
val Simpset ({rules = rules2, prems = prems2, depth = depth2},
- {congs = (congs2, weak2), procs = simprocs2, mk_rews = _, term_ord = _, subgoal_tac = _,
- loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2)}) = ss2;
+ {congs = (congs2, weak2), procs = (simprocs2, congprocs2),
+ loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2), ...}) = ss2;
val rules' = Net.merge eq_rrule (rules1, rules2);
val prems' = Thm.merge_thms (prems1, prems2);
@@ -347,11 +369,12 @@
val congs' = Congtab.merge (K true) (congs1, congs2);
val weak' = merge (op =) (weak1, weak2);
val simprocs' = Net.merge eq_procedure_id (simprocs1, simprocs2);
+ val congprocs' = Net.merge eq_procedure_id (congprocs1, congprocs2);
val loop_tacs' = AList.merge (op =) (K true) (loop_tacs1, loop_tacs2);
val unsafe_solvers' = merge eq_solver (unsafe_solvers1, unsafe_solvers2);
val solvers' = merge eq_solver (solvers1, solvers2);
in
- make_simpset ((rules', prems', depth'), ((congs', weak'), simprocs',
+ make_simpset ((rules', prems', depth'), ((congs', weak'), (simprocs', congprocs'),
mk_rews, term_ord, subgoal_tac, loop_tacs', (unsafe_solvers', solvers')))
end;
@@ -728,51 +751,55 @@
type simproc = term list procedure;
-fun cert_simproc thy {name, lhss, proc, identifier} : simproc =
+fun cert_simproc thy {name, kind, lhss, proc, identifier} : simproc =
Procedure
{name = name,
+ kind = kind,
lhs = map (Sign.cert_term thy) lhss,
proc = proc,
id = (stamp (), map (Thm.transfer thy) identifier)};
-fun transform_simproc phi (Procedure {name, lhs, proc, id = (stamp, identifier)}) : simproc =
+fun transform_simproc phi (Procedure {name, kind, lhs, proc, id = (stamp, identifier)}) : simproc =
Procedure
{name = name,
+ kind = kind,
lhs = map (Morphism.term phi) lhs,
proc = Morphism.transform phi proc,
id = (stamp, Morphism.fact phi identifier)};
-fun trim_context_simproc (Procedure {name, lhs, proc, id = (stamp, identifier)}) : simproc =
+fun trim_context_simproc (Procedure {name, kind, lhs, proc, id = (stamp, identifier)}) : simproc =
Procedure
{name = name,
+ kind = kind,
lhs = lhs,
proc = Morphism.entity_reset_context proc,
id = (stamp, map Thm.trim_context identifier)};
local
-fun add_proc1 (proc as Procedure {name, lhs, ...}) ctxt =
+fun add_proc1 (proc as Procedure {name, kind, lhs, ...}) ctxt =
(cond_tracing ctxt (fn () =>
- print_term ctxt ("Adding simplification procedure " ^ quote name ^ " for") lhs);
+ print_term ctxt ("Adding " ^ print_proc_kind kind ^ " " ^ quote name ^ " for") lhs);
ctxt |> map_simpset2
(fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) =>
- (congs, Net.insert_term eq_procedure_id (lhs, proc) procs,
+ (congs, map_procs kind (Net.insert_term eq_procedure_id (lhs, proc)) procs,
mk_rews, term_ord, subgoal_tac, loop_tacs, solvers))
handle Net.INSERT =>
- (cond_warning ctxt (fn () => "Ignoring duplicate simplification procedure " ^ quote name);
+ (cond_warning ctxt (fn () =>
+ "Ignoring duplicate " ^ print_proc_kind kind ^ " " ^ quote name);
ctxt));
-fun del_proc1 (proc as Procedure {name, lhs, ...}) ctxt =
+fun del_proc1 (proc as Procedure {name, kind, lhs, ...}) ctxt =
ctxt |> map_simpset2
(fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) =>
- (congs, Net.delete_term eq_procedure_id (lhs, proc) procs,
+ (congs, map_procs kind (Net.delete_term eq_procedure_id (lhs, proc)) procs,
mk_rews, term_ord, subgoal_tac, loop_tacs, solvers))
handle Net.DELETE =>
- (cond_warning ctxt (fn () => "Simplification procedure " ^ quote name ^ " not in simpset");
+ (cond_warning ctxt (fn () => "No " ^ print_proc_kind kind ^ " " ^ quote name ^ " in simpset");
ctxt);
-fun split_proc (Procedure {name, lhs = lhss, proc, id} : simproc) =
- lhss |> map (fn lhs => Procedure {name = name, lhs = lhs, proc = proc, id = id});
+fun split_proc (Procedure {name, kind, lhs = lhss, proc, id} : simproc) =
+ lhss |> map (fn lhs => Procedure {name = name, kind = kind, lhs = lhs, proc = proc, id = id});
in
@@ -961,19 +988,24 @@
The latter may happen iff there are weak congruence rules for constants
in the lhs.*)
-fun uncond_skel ((_, weak), (lhs, rhs)) =
- if null weak then rhs (*optimization*)
- else if exists_subterm
+fun weak_cong weak lhs =
+ if null weak then false (*optimization*)
+ else exists_subterm
(fn Const (a, _) => member (op =) weak (true, a)
| Free (a, _) => member (op =) weak (false, a)
- | _ => false) lhs then skel0
+ | _ => false) lhs
+
+fun uncond_skel ((_, weak), congprocs, (lhs, rhs)) =
+ if weak_cong weak lhs then skel0
+ else if Net.is_empty congprocs then rhs (*optimization*)
+ else if exists (is_weak_congproc o procedure_kind) (Net.match_term congprocs lhs) then skel0
else rhs;
(*Behaves like unconditional rule if rhs does not contain vars not in the lhs.
Otherwise those vars may become instantiated with unnormalized terms
while the premises are solved.*)
-fun cond_skel (args as (_, (lhs, rhs))) =
+fun cond_skel (args as (_, _, (lhs, rhs))) =
if Vars.subset (vars_set rhs, vars_set lhs) then uncond_skel args
else skel0;
@@ -989,7 +1021,8 @@
fun rewritec (prover, maxt) ctxt t =
let
- val Simpset ({rules, ...}, {congs, procs = simprocs, term_ord, ...}) = simpset_of ctxt;
+ val Simpset ({rules, ...}, {congs, procs = (simprocs, congprocs), term_ord, ...}) =
+ simpset_of ctxt;
val eta_thm = Thm.eta_conversion t;
val eta_t' = Thm.rhs_of eta_thm;
val eta_t = Thm.term_of eta_t';
@@ -1028,7 +1061,7 @@
let
val lr = Logic.dest_equals prop;
val SOME thm'' = check_conv ctxt' false eta_thm thm';
- in SOME (thm'', uncond_skel (congs, lr)) end))
+ in SOME (thm'', uncond_skel (congs, congprocs, lr)) end))
else
(cond_tracing ctxt (fn () => print_thm0 ctxt "Trying to rewrite:" thm');
if simp_depth ctxt > Config.get ctxt simp_depth_limit
@@ -1044,7 +1077,7 @@
let
val concl = Logic.strip_imp_concl prop;
val lr = Logic.dest_equals concl;
- in SOME (thm2', cond_skel (congs, lr)) end)))))
+ in SOME (thm2', cond_skel (congs, congprocs, lr)) end)))))
end;
fun rews [] = NONE
@@ -1098,6 +1131,63 @@
end;
+(* apply congprocs *)
+
+(* pattern order:
+ p1 GREATER p2: p1 is more general than p2, p1 matches p2 but not vice versa
+ p1 LESS p2: p1 is more specific than p2, p2 matches p1 but not vice versa
+ p1 EQUAL p2: both match each other or neither match each other
+*)
+
+fun pattern_order thy =
+ let
+ fun matches arg = can (Pattern.match thy arg) (Vartab.empty, Vartab.empty);
+ in
+ fn (p1, p2) =>
+ if matches (p1, p2) then
+ if matches (p2, p1) then EQUAL
+ else GREATER
+ else
+ if matches (p2, p1) then LESS
+ else EQUAL
+ end;
+
+fun app_congprocs ctxt ct =
+ let
+ val thy = Proof_Context.theory_of ctxt;
+ val Simpset (_, {procs = (_, congprocs), ...}) = simpset_of ctxt;
+
+ val eta_ct = Thm.rhs_of (Thm.eta_conversion ct);
+
+ fun proc_congs [] = NONE
+ | proc_congs (Procedure {name, lhs, proc, ...} :: ps) =
+ if Pattern.matches thy (lhs, Thm.term_of ct) then
+ let
+ val _ =
+ cond_tracing' ctxt simp_debug (fn () =>
+ print_term ctxt ("Trying procedure " ^ quote name ^ " on:") (Thm.term_of eta_ct));
+
+ val ctxt' = Config.put simp_trace (Config.get ctxt simp_debug) ctxt;
+ val res =
+ trace_simproc {name = name, cterm = eta_ct} ctxt'
+ (fn ctxt'' => Morphism.form_context' ctxt'' proc eta_ct);
+ in
+ (case res of
+ NONE => (cond_tracing' ctxt simp_debug (fn () => "FAILED"); proc_congs ps)
+ | SOME raw_thm =>
+ (cond_tracing ctxt (fn () =>
+ print_thm0 ctxt ("Procedure " ^ quote name ^ " produced congruence rule:")
+ raw_thm);
+ SOME (raw_thm, skel0)))
+ end
+ else proc_congs ps;
+ in
+ Net.match_term congprocs (Thm.term_of eta_ct)
+ |> sort (pattern_order thy o apply2 procedure_lhs)
+ |> proc_congs
+ end;
+
+
(* conversion to apply a congruence rule to a term *)
fun congc prover ctxt maxt cong t =
@@ -1202,8 +1292,15 @@
val (h, ts) = strip_comb t;
+ (*Prefer congprocs over plain cong rules. In congprocs prefer most specific rules.
+ If there is a matching congproc, then look into the result:
+ 1. plain equality: consider normalisation complete (just as with a plain congruence rule),
+ 2. conditional rule: treat like congruence rules like SOME cong case below.*)
+
fun app_cong () =
- Option.mapPartial (Congtab.lookup (fst congs)) (cong_name h);
+ (case app_congprocs ctxt t0 of
+ SOME (thm, _) => SOME thm
+ | NONE => Option.mapPartial (Congtab.lookup (fst congs)) (cong_name h));
in
(case app_cong () of
NONE => appc ()
@@ -1213,7 +1310,9 @@
(let
val thm = congc (prover ctxt) ctxt maxidx cong t0;
val t = the_default t0 (Option.map Thm.rhs_of thm);
- val (cl, cr) = Thm.dest_comb t;
+ val (cl, cr) = Thm.dest_comb t
+ handle CTERM _ => Thm.dest_comb t0; (*e.g. congproc has
+ normalized such that head is removed from t*)
val dVar = Var (("", 0), dummyT);
val skel = list_comb (h, replicate (length ts) dVar);
in