--- a/src/Pure/raw_simplifier.ML Wed Aug 14 18:59:49 2024 +0200
+++ b/src/Pure/raw_simplifier.ML Wed Aug 14 21:23:22 2024 +0200
@@ -25,9 +25,11 @@
type simpset
val empty_ss: simpset
val merge_ss: simpset * simpset -> simpset
+ datatype proc_kind = Simproc | Congproc of bool
type simproc
val cert_simproc: theory ->
- {name: string, lhss: term list, proc: proc Morphism.entity, identifier: thm list} -> simproc
+ {name: string, kind: proc_kind, lhss: term list,
+ proc: proc Morphism.entity, identifier: thm list} -> simproc
val transform_simproc: morphism -> simproc -> simproc
val trim_context_simproc: simproc -> simproc
val simpset_of: Proof.context -> simpset
@@ -68,6 +70,7 @@
val dest_ss: simpset ->
{simps: (Thm_Name.T * thm) list,
simprocs: (string * term list) list,
+ congprocs: (string * {lhss: term list, proc: proc Morphism.entity}) list,
congs: (cong_name * thm) list,
weak_congs: cong_name list,
loopers: string list,
@@ -210,15 +213,31 @@
(* simplification procedures *)
+datatype proc_kind = Simproc | Congproc of bool;
+
+val is_congproc = fn Congproc _ => true | _ => false;
+val is_weak_congproc = fn Congproc weak => weak | _ => false;
+
+fun map_procs kind f (simprocs, congprocs) =
+ if is_congproc kind then (simprocs, f congprocs) else (f simprocs, congprocs);
+
+fun print_proc_kind Simproc = "simplification procedure"
+ | print_proc_kind (Congproc false) = "simplification procedure (cong)"
+ | print_proc_kind (Congproc true) = "simplification procedure (weak cong)";
+
type proc = Proof.context -> cterm -> thm option;
datatype 'lhs procedure =
Procedure of
{name: string,
+ kind: proc_kind,
lhs: 'lhs,
proc: proc Morphism.entity,
id: stamp * thm list};
+fun procedure_kind (Procedure {kind, ...}) = kind;
+fun procedure_lhs (Procedure {lhs, ...}) = lhs;
+
fun eq_procedure_id (Procedure {id = (s1, ths1), ...}, Procedure {id = (s2, ths2), ...}) =
s1 = s2 andalso eq_list Thm.eq_thm_prop (ths1, ths2);
@@ -249,6 +268,7 @@
A congruence is `weak' if it avoids normalization of some argument.
procs: simplification procedures indexed via discrimination net
simprocs: functions that prove rewrite rules on the fly;
+ congprocs: functions that prove congruence rules on the fly;
mk_rews:
mk: turn simplification thms into rewrite rules;
mk_cong: prepare congruence rules;
@@ -262,7 +282,7 @@
prems: thm list,
depth: int * bool Unsynchronized.ref} *
{congs: thm Congtab.table * cong_name list,
- procs: term procedure Net.net,
+ procs: term procedure Net.net * term procedure Net.net,
mk_rews:
{mk: Proof.context -> thm -> thm list,
mk_cong: Proof.context -> thm -> thm,
@@ -298,11 +318,13 @@
fun dest_procs procs =
Net.entries procs
|> partition_eq eq_procedure_id
- |> map (fn ps as Procedure {name, ...} :: _ => (name, map (fn Procedure {lhs, ...} => lhs) ps));
+ |> map (fn ps as Procedure {name, proc, ...} :: _ =>
+ (name, {lhss = map (fn Procedure {lhs, ...} => lhs) ps, proc = proc}));
-fun dest_ss (ss as Simpset (_, {congs, procs = simprocs, loop_tacs, solvers, ...})) =
+fun dest_ss (ss as Simpset (_, {congs, procs = (simprocs, congprocs), loop_tacs, solvers, ...})) =
{simps = dest_simps ss,
- simprocs = dest_procs simprocs,
+ simprocs = map (apsnd #lhss) (dest_procs simprocs),
+ congprocs = dest_procs congprocs,
congs = dest_congs ss,
weak_congs = #2 congs,
loopers = map fst loop_tacs,
@@ -314,7 +336,7 @@
fun init_ss depth mk_rews term_ord subgoal_tac solvers =
make_simpset ((Net.empty, [], depth),
- ((Congtab.empty, []), Net.empty, mk_rews, term_ord, subgoal_tac, [], solvers));
+ ((Congtab.empty, []), (Net.empty, Net.empty), mk_rews, term_ord, subgoal_tac, [], solvers));
fun default_mk_sym _ th = SOME (th RS Drule.symmetric_thm);
@@ -335,11 +357,11 @@
else
let
val Simpset ({rules = rules1, prems = prems1, depth = depth1},
- {congs = (congs1, weak1), procs = simprocs1, mk_rews, term_ord, subgoal_tac,
+ {congs = (congs1, weak1), procs = (simprocs1, congprocs1), mk_rews, term_ord, subgoal_tac,
loop_tacs = loop_tacs1, solvers = (unsafe_solvers1, solvers1)}) = ss1;
val Simpset ({rules = rules2, prems = prems2, depth = depth2},
- {congs = (congs2, weak2), procs = simprocs2, mk_rews = _, term_ord = _, subgoal_tac = _,
- loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2)}) = ss2;
+ {congs = (congs2, weak2), procs = (simprocs2, congprocs2),
+ loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2), ...}) = ss2;
val rules' = Net.merge eq_rrule (rules1, rules2);
val prems' = Thm.merge_thms (prems1, prems2);
@@ -347,11 +369,12 @@
val congs' = Congtab.merge (K true) (congs1, congs2);
val weak' = merge (op =) (weak1, weak2);
val simprocs' = Net.merge eq_procedure_id (simprocs1, simprocs2);
+ val congprocs' = Net.merge eq_procedure_id (congprocs1, congprocs2);
val loop_tacs' = AList.merge (op =) (K true) (loop_tacs1, loop_tacs2);
val unsafe_solvers' = merge eq_solver (unsafe_solvers1, unsafe_solvers2);
val solvers' = merge eq_solver (solvers1, solvers2);
in
- make_simpset ((rules', prems', depth'), ((congs', weak'), simprocs',
+ make_simpset ((rules', prems', depth'), ((congs', weak'), (simprocs', congprocs'),
mk_rews, term_ord, subgoal_tac, loop_tacs', (unsafe_solvers', solvers')))
end;
@@ -728,51 +751,55 @@
type simproc = term list procedure;
-fun cert_simproc thy {name, lhss, proc, identifier} : simproc =
+fun cert_simproc thy {name, kind, lhss, proc, identifier} : simproc =
Procedure
{name = name,
+ kind = kind,
lhs = map (Sign.cert_term thy) lhss,
proc = proc,
id = (stamp (), map (Thm.transfer thy) identifier)};
-fun transform_simproc phi (Procedure {name, lhs, proc, id = (stamp, identifier)}) : simproc =
+fun transform_simproc phi (Procedure {name, kind, lhs, proc, id = (stamp, identifier)}) : simproc =
Procedure
{name = name,
+ kind = kind,
lhs = map (Morphism.term phi) lhs,
proc = Morphism.transform phi proc,
id = (stamp, Morphism.fact phi identifier)};
-fun trim_context_simproc (Procedure {name, lhs, proc, id = (stamp, identifier)}) : simproc =
+fun trim_context_simproc (Procedure {name, kind, lhs, proc, id = (stamp, identifier)}) : simproc =
Procedure
{name = name,
+ kind = kind,
lhs = lhs,
proc = Morphism.entity_reset_context proc,
id = (stamp, map Thm.trim_context identifier)};
local
-fun add_proc1 (proc as Procedure {name, lhs, ...}) ctxt =
+fun add_proc1 (proc as Procedure {name, kind, lhs, ...}) ctxt =
(cond_tracing ctxt (fn () =>
- print_term ctxt ("Adding simplification procedure " ^ quote name ^ " for") lhs);
+ print_term ctxt ("Adding " ^ print_proc_kind kind ^ " " ^ quote name ^ " for") lhs);
ctxt |> map_simpset2
(fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) =>
- (congs, Net.insert_term eq_procedure_id (lhs, proc) procs,
+ (congs, map_procs kind (Net.insert_term eq_procedure_id (lhs, proc)) procs,
mk_rews, term_ord, subgoal_tac, loop_tacs, solvers))
handle Net.INSERT =>
- (cond_warning ctxt (fn () => "Ignoring duplicate simplification procedure " ^ quote name);
+ (cond_warning ctxt (fn () =>
+ "Ignoring duplicate " ^ print_proc_kind kind ^ " " ^ quote name);
ctxt));
-fun del_proc1 (proc as Procedure {name, lhs, ...}) ctxt =
+fun del_proc1 (proc as Procedure {name, kind, lhs, ...}) ctxt =
ctxt |> map_simpset2
(fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) =>
- (congs, Net.delete_term eq_procedure_id (lhs, proc) procs,
+ (congs, map_procs kind (Net.delete_term eq_procedure_id (lhs, proc)) procs,
mk_rews, term_ord, subgoal_tac, loop_tacs, solvers))
handle Net.DELETE =>
- (cond_warning ctxt (fn () => "Simplification procedure " ^ quote name ^ " not in simpset");
+ (cond_warning ctxt (fn () => "No " ^ print_proc_kind kind ^ " " ^ quote name ^ " in simpset");
ctxt);
-fun split_proc (Procedure {name, lhs = lhss, proc, id} : simproc) =
- lhss |> map (fn lhs => Procedure {name = name, lhs = lhs, proc = proc, id = id});
+fun split_proc (Procedure {name, kind, lhs = lhss, proc, id} : simproc) =
+ lhss |> map (fn lhs => Procedure {name = name, kind = kind, lhs = lhs, proc = proc, id = id});
in
@@ -961,19 +988,24 @@
The latter may happen iff there are weak congruence rules for constants
in the lhs.*)
-fun uncond_skel ((_, weak), (lhs, rhs)) =
- if null weak then rhs (*optimization*)
- else if exists_subterm
+fun weak_cong weak lhs =
+ if null weak then false (*optimization*)
+ else exists_subterm
(fn Const (a, _) => member (op =) weak (true, a)
| Free (a, _) => member (op =) weak (false, a)
- | _ => false) lhs then skel0
+ | _ => false) lhs
+
+fun uncond_skel ((_, weak), congprocs, (lhs, rhs)) =
+ if weak_cong weak lhs then skel0
+ else if Net.is_empty congprocs then rhs (*optimization*)
+ else if exists (is_weak_congproc o procedure_kind) (Net.match_term congprocs lhs) then skel0
else rhs;
(*Behaves like unconditional rule if rhs does not contain vars not in the lhs.
Otherwise those vars may become instantiated with unnormalized terms
while the premises are solved.*)
-fun cond_skel (args as (_, (lhs, rhs))) =
+fun cond_skel (args as (_, _, (lhs, rhs))) =
if Vars.subset (vars_set rhs, vars_set lhs) then uncond_skel args
else skel0;
@@ -989,7 +1021,8 @@
fun rewritec (prover, maxt) ctxt t =
let
- val Simpset ({rules, ...}, {congs, procs = simprocs, term_ord, ...}) = simpset_of ctxt;
+ val Simpset ({rules, ...}, {congs, procs = (simprocs, congprocs), term_ord, ...}) =
+ simpset_of ctxt;
val eta_thm = Thm.eta_conversion t;
val eta_t' = Thm.rhs_of eta_thm;
val eta_t = Thm.term_of eta_t';
@@ -1028,7 +1061,7 @@
let
val lr = Logic.dest_equals prop;
val SOME thm'' = check_conv ctxt' false eta_thm thm';
- in SOME (thm'', uncond_skel (congs, lr)) end))
+ in SOME (thm'', uncond_skel (congs, congprocs, lr)) end))
else
(cond_tracing ctxt (fn () => print_thm0 ctxt "Trying to rewrite:" thm');
if simp_depth ctxt > Config.get ctxt simp_depth_limit
@@ -1044,7 +1077,7 @@
let
val concl = Logic.strip_imp_concl prop;
val lr = Logic.dest_equals concl;
- in SOME (thm2', cond_skel (congs, lr)) end)))))
+ in SOME (thm2', cond_skel (congs, congprocs, lr)) end)))))
end;
fun rews [] = NONE
@@ -1098,6 +1131,63 @@
end;
+(* apply congprocs *)
+
+(* pattern order:
+ p1 GREATER p2: p1 is more general than p2, p1 matches p2 but not vice versa
+ p1 LESS p2: p1 is more specific than p2, p2 matches p1 but not vice versa
+ p1 EQUAL p2: both match each other or neither match each other
+*)
+
+fun pattern_order thy =
+ let
+ fun matches arg = can (Pattern.match thy arg) (Vartab.empty, Vartab.empty);
+ in
+ fn (p1, p2) =>
+ if matches (p1, p2) then
+ if matches (p2, p1) then EQUAL
+ else GREATER
+ else
+ if matches (p2, p1) then LESS
+ else EQUAL
+ end;
+
+fun app_congprocs ctxt ct =
+ let
+ val thy = Proof_Context.theory_of ctxt;
+ val Simpset (_, {procs = (_, congprocs), ...}) = simpset_of ctxt;
+
+ val eta_ct = Thm.rhs_of (Thm.eta_conversion ct);
+
+ fun proc_congs [] = NONE
+ | proc_congs (Procedure {name, lhs, proc, ...} :: ps) =
+ if Pattern.matches thy (lhs, Thm.term_of ct) then
+ let
+ val _ =
+ cond_tracing' ctxt simp_debug (fn () =>
+ print_term ctxt ("Trying procedure " ^ quote name ^ " on:") (Thm.term_of eta_ct));
+
+ val ctxt' = Config.put simp_trace (Config.get ctxt simp_debug) ctxt;
+ val res =
+ trace_simproc {name = name, cterm = eta_ct} ctxt'
+ (fn ctxt'' => Morphism.form_context' ctxt'' proc eta_ct);
+ in
+ (case res of
+ NONE => (cond_tracing' ctxt simp_debug (fn () => "FAILED"); proc_congs ps)
+ | SOME raw_thm =>
+ (cond_tracing ctxt (fn () =>
+ print_thm0 ctxt ("Procedure " ^ quote name ^ " produced congruence rule:")
+ raw_thm);
+ SOME (raw_thm, skel0)))
+ end
+ else proc_congs ps;
+ in
+ Net.match_term congprocs (Thm.term_of eta_ct)
+ |> sort (pattern_order thy o apply2 procedure_lhs)
+ |> proc_congs
+ end;
+
+
(* conversion to apply a congruence rule to a term *)
fun congc prover ctxt maxt cong t =
@@ -1202,8 +1292,15 @@
val (h, ts) = strip_comb t;
+ (*Prefer congprocs over plain cong rules. In congprocs prefer most specific rules.
+ If there is a matching congproc, then look into the result:
+ 1. plain equality: consider normalisation complete (just as with a plain congruence rule),
+ 2. conditional rule: treat like congruence rules like SOME cong case below.*)
+
fun app_cong () =
- Option.mapPartial (Congtab.lookup (fst congs)) (cong_name h);
+ (case app_congprocs ctxt t0 of
+ SOME (thm, _) => SOME thm
+ | NONE => Option.mapPartial (Congtab.lookup (fst congs)) (cong_name h));
in
(case app_cong () of
NONE => appc ()
@@ -1213,7 +1310,9 @@
(let
val thm = congc (prover ctxt) ctxt maxidx cong t0;
val t = the_default t0 (Option.map Thm.rhs_of thm);
- val (cl, cr) = Thm.dest_comb t;
+ val (cl, cr) = Thm.dest_comb t
+ handle CTERM _ => Thm.dest_comb t0; (*e.g. congproc has
+ normalized such that head is removed from t*)
val dVar = Var (("", 0), dummyT);
val skel = list_comb (h, replicate (length ts) dVar);
in
--- a/src/Pure/simplifier.ML Wed Aug 14 18:59:49 2024 +0200
+++ b/src/Pure/simplifier.ML Wed Aug 14 21:23:22 2024 +0200
@@ -40,9 +40,10 @@
val check_simproc: Proof.context -> xstring * Position.T -> string * simproc
val the_simproc: Proof.context -> string -> simproc
val make_simproc: Proof.context ->
- {name: string, lhss: term list, proc: morphism -> proc, identifier: thm list} -> simproc
+ {name: string, kind: proc_kind, lhss: term list, proc: morphism -> proc, identifier: thm list} ->
+ simproc
type ('a, 'b, 'c) simproc_spec =
- {passive: bool, name: binding, lhss: 'a list, proc: 'b, identifier: 'c}
+ {passive: bool, name: binding, kind: proc_kind, lhss: 'a list, proc: 'b, identifier: 'c}
val read_simproc_spec: Proof.context ->
(string, 'b, 'c) simproc_spec -> (term, 'b, 'c) simproc_spec
val define_simproc: (term, morphism -> proc, thm list) simproc_spec -> local_theory ->
@@ -131,30 +132,33 @@
(* define simprocs *)
-fun make_simproc ctxt {name, lhss, proc, identifier} =
+fun make_simproc ctxt {name, lhss, kind, proc, identifier} =
let
val ctxt' = fold Proof_Context.augment lhss ctxt;
val lhss' = Variable.export_terms ctxt' ctxt lhss;
in
cert_simproc (Proof_Context.theory_of ctxt)
- {name = name, lhss = lhss', proc = Morphism.entity proc, identifier = identifier}
+ {name = name, kind = kind, lhss = lhss', proc = Morphism.entity proc, identifier = identifier}
end;
type ('a, 'b, 'c) simproc_spec =
- {passive: bool, name: binding, lhss: 'a list, proc: 'b, identifier: 'c};
+ {passive: bool, name: binding, kind: proc_kind, lhss: 'a list, proc: 'b, identifier: 'c};
-fun read_simproc_spec ctxt {passive, name, lhss, proc, identifier} =
+fun read_simproc_spec ctxt {passive, name, kind, lhss, proc, identifier} =
let
val lhss' =
Syntax.read_terms ctxt lhss handle ERROR msg =>
error (msg ^ Position.here_list (map Syntax.read_input_pos lhss));
- in {passive = passive, name = name, lhss = lhss', proc = proc, identifier = identifier} end;
+ in
+ {passive = passive, name = name, kind = kind, lhss = lhss', proc = proc, identifier = identifier}
+ end;
-fun define_simproc {passive, name, lhss, proc, identifier} lthy =
+fun define_simproc {passive, name, kind, lhss, proc, identifier} lthy =
let
val simproc0 =
make_simproc lthy
- {name = Local_Theory.full_name lthy name, lhss = lhss, proc = proc, identifier = identifier};
+ {name = Local_Theory.full_name lthy name,
+ kind = kind, lhss = lhss, proc = proc, identifier = identifier};
in
lthy |> Local_Theory.declaration {syntax = false, pervasive = false, pos = Binding.pos_of name}
(fn phi => fn context =>
@@ -179,14 +183,24 @@
Named_Target.setup_result Raw_Simplifier.transform_simproc
(fn lthy => lthy |> define_simproc (read_simproc_spec lthy args));
+val parse_proc_kind =
+ Parse.$$$ "congproc" >> K (Congproc false) ||
+ Parse.$$$ "weak_congproc" >> K (Congproc true) ||
+ Scan.succeed Simproc;
+
+fun print_proc_kind kind =
+ (case kind of
+ Simproc => "Simplifier.Simproc"
+ | Congproc weak => "Simplifier.Congproc " ^ Bool.toString weak);
val parse_simproc_spec =
- Scan.optional (Parse.$$$ "passive" >> K true) false --
+ Scan.optional (Parse.$$$ "passive" >> K true) false -- parse_proc_kind --
Parse.binding --
(Parse.$$$ "(" |-- Parse.enum1 "|" Parse.term --| Parse.$$$ ")") --
(Parse.$$$ "=" |-- Parse.ML_source) --
Scan.option ((Parse.position (Parse.$$$ "identifier") >> #2) -- Parse.thms1)
- >> (fn ((((a, b), c), d), e) => {passive = a, name = b, lhss = c, proc = d, identifier = e});
+ >> (fn (((((a, b), c), d), e), f) =>
+ {passive = a, kind = b, name = c, lhss = d, proc = e, identifier = f});
val _ = Theory.setup
(ML_Context.add_antiquotation_embedded \<^binding>\<open>simproc_setup\<close>
@@ -194,7 +208,7 @@
let
val ml = ML_Lex.tokenize_no_range;
- val {passive, name, lhss, proc, identifier} = input
+ val {passive, name, kind, lhss, proc, identifier} = input
|> Parse.read_embedded ctxt (Thy_Header.get_keywords' ctxt) parse_simproc_spec
|> read_simproc_spec ctxt;
@@ -211,6 +225,7 @@
val ml_body' =
ml "Simplifier.simproc_setup {passive = " @ ml (Bool.toString passive) @
ml ", name = " @ ml (ML_Syntax.make_binding (Binding.name_of name, Binding.pos_of name)) @
+ ml ", kind = " @ ml (print_proc_kind kind) @
ml ", lhss = " @ ml (ML_Syntax.print_list ML_Syntax.print_term lhss) @
ml ", proc = (" @ ml_body1 @ ml ")" @
ml ", identifier = (" @ ml_body2 @ ml ")}";
@@ -218,7 +233,7 @@
in (decl', ctxt2) end));
val simproc_setup_command =
- parse_simproc_spec >> (fn {passive, name, lhss, proc, identifier} =>
+ parse_simproc_spec >> (fn {passive, name, kind, lhss, proc, identifier} =>
(case identifier of
NONE =>
Context.proof_map
@@ -226,6 +241,7 @@
(ML_Lex.read
("Simplifier.simproc_setup_cmd {passive = " ^ Bool.toString passive ^
", name = " ^ ML_Syntax.make_binding (Binding.name_of name, Binding.pos_of name) ^
+ ", kind = " ^ print_proc_kind kind ^
", lhss = " ^ ML_Syntax.print_strings lhss ^
", proc = (") @ ML_Lex.read_source proc @ ML_Lex.read "), identifier = []}"))
| SOME (pos, _) =>
@@ -276,18 +292,31 @@
(Pretty.mark_str name :: Pretty.str ":" :: Pretty.fbrk ::
Pretty.fbreaks (map (Pretty.item o single o pretty_term) lhss));
+ fun pretty_congproc (name, {lhss, proc}) =
+ let
+ val prt_rule =
+ (case try (Morphism.form_context' ctxt proc) @{cterm dummy} of
+ SOME (SOME thm) => [Pretty.fbrk, Pretty.str "rule:", Pretty.fbrk, pretty_thm thm]
+ | NONE => []);
+ in
+ Pretty.block
+ (Pretty.mark_str name :: Pretty.str ":" :: Pretty.fbrk ::
+ Pretty.fbreaks (map (Pretty.item o single o pretty_term) lhss) @ prt_rule)
+ end;
+
fun pretty_cong_name (const, name) =
pretty_term ((if const then Const else Free) (name, dummyT));
fun pretty_cong (name, thm) =
Pretty.block [pretty_cong_name name, Pretty.str ":", Pretty.brk 1, pretty_thm thm];
val ss = dest_ss (simpset_of ctxt);
- val simprocs =
- Name_Space.markup_entries verbose ctxt
- (Name_Space.space_of_table (get_simprocs ctxt)) (#simprocs ss);
+ val simproc_space = Name_Space.space_of_table (get_simprocs ctxt);
+ val simprocs = Name_Space.markup_entries verbose ctxt simproc_space (#simprocs ss);
+ val congprocs = Name_Space.markup_entries verbose ctxt simproc_space (#congprocs ss);
in
[Pretty.big_list "simplification rules:" (map (pretty_thm_item o #2) (#simps ss)),
Pretty.big_list "simplification procedures:" (map pretty_simproc simprocs),
+ Pretty.big_list "congruence procedures:" (map pretty_congproc congprocs),
Pretty.big_list "congruences:" (map pretty_cong (#congs ss)),
Pretty.strs ("loopers:" :: map quote (#loopers ss)),
Pretty.strs ("unsafe solvers:" :: map quote (#unsafe_solvers ss)),