src/Pure/raw_simplifier.ML
changeset 80709 e6f026505c5b
parent 80707 897c993293c5
child 80710 82c0bfbaaa86
--- a/src/Pure/raw_simplifier.ML	Wed Aug 14 18:59:49 2024 +0200
+++ b/src/Pure/raw_simplifier.ML	Wed Aug 14 21:23:22 2024 +0200
@@ -25,9 +25,11 @@
   type simpset
   val empty_ss: simpset
   val merge_ss: simpset * simpset -> simpset
+  datatype proc_kind = Simproc | Congproc of bool
   type simproc
   val cert_simproc: theory ->
-    {name: string, lhss: term list, proc: proc Morphism.entity, identifier: thm list} -> simproc
+    {name: string, kind: proc_kind, lhss: term list,
+      proc: proc Morphism.entity, identifier: thm list} -> simproc
   val transform_simproc: morphism -> simproc -> simproc
   val trim_context_simproc: simproc -> simproc
   val simpset_of: Proof.context -> simpset
@@ -68,6 +70,7 @@
   val dest_ss: simpset ->
    {simps: (Thm_Name.T * thm) list,
     simprocs: (string * term list) list,
+    congprocs: (string * {lhss: term list, proc: proc Morphism.entity}) list,
     congs: (cong_name * thm) list,
     weak_congs: cong_name list,
     loopers: string list,
@@ -210,15 +213,31 @@
 
 (* simplification procedures *)
 
+datatype proc_kind = Simproc | Congproc of bool;
+
+val is_congproc = fn Congproc _ => true | _ => false;
+val is_weak_congproc = fn Congproc weak => weak | _ => false;
+
+fun map_procs kind f (simprocs, congprocs) =
+  if is_congproc kind then (simprocs, f congprocs) else (f simprocs, congprocs);
+
+fun print_proc_kind Simproc = "simplification procedure"
+  | print_proc_kind (Congproc false) = "simplification procedure (cong)"
+  | print_proc_kind (Congproc true) = "simplification procedure (weak cong)";
+
 type proc = Proof.context -> cterm -> thm option;
 
 datatype 'lhs procedure =
   Procedure of
    {name: string,
+    kind: proc_kind,
     lhs: 'lhs,
     proc: proc Morphism.entity,
     id: stamp * thm list};
 
+fun procedure_kind (Procedure {kind, ...}) = kind;
+fun procedure_lhs (Procedure {lhs, ...}) = lhs;
+
 fun eq_procedure_id (Procedure {id = (s1, ths1), ...}, Procedure {id = (s2, ths2), ...}) =
   s1 = s2 andalso eq_list Thm.eq_thm_prop (ths1, ths2);
 
@@ -249,6 +268,7 @@
            A congruence is `weak' if it avoids normalization of some argument.
     procs: simplification procedures indexed via discrimination net
       simprocs: functions that prove rewrite rules on the fly;
+      congprocs: functions that prove congruence rules on the fly;
     mk_rews:
       mk: turn simplification thms into rewrite rules;
       mk_cong: prepare congruence rules;
@@ -262,7 +282,7 @@
     prems: thm list,
     depth: int * bool Unsynchronized.ref} *
    {congs: thm Congtab.table * cong_name list,
-    procs: term procedure Net.net,
+    procs: term procedure Net.net * term procedure Net.net,
     mk_rews:
      {mk: Proof.context -> thm -> thm list,
       mk_cong: Proof.context -> thm -> thm,
@@ -298,11 +318,13 @@
 fun dest_procs procs =
   Net.entries procs
   |> partition_eq eq_procedure_id
-  |> map (fn ps as Procedure {name, ...} :: _ => (name, map (fn Procedure {lhs, ...} => lhs) ps));
+  |> map (fn ps as Procedure {name, proc, ...} :: _ =>
+      (name, {lhss = map (fn Procedure {lhs, ...} => lhs) ps, proc = proc}));
 
-fun dest_ss (ss as Simpset (_, {congs, procs = simprocs, loop_tacs, solvers, ...})) =
+fun dest_ss (ss as Simpset (_, {congs, procs = (simprocs, congprocs), loop_tacs, solvers, ...})) =
  {simps = dest_simps ss,
-  simprocs = dest_procs simprocs,
+  simprocs = map (apsnd #lhss) (dest_procs simprocs),
+  congprocs = dest_procs congprocs,
   congs = dest_congs ss,
   weak_congs = #2 congs,
   loopers = map fst loop_tacs,
@@ -314,7 +336,7 @@
 
 fun init_ss depth mk_rews term_ord subgoal_tac solvers =
   make_simpset ((Net.empty, [], depth),
-    ((Congtab.empty, []), Net.empty, mk_rews, term_ord, subgoal_tac, [], solvers));
+    ((Congtab.empty, []), (Net.empty, Net.empty), mk_rews, term_ord, subgoal_tac, [], solvers));
 
 fun default_mk_sym _ th = SOME (th RS Drule.symmetric_thm);
 
@@ -335,11 +357,11 @@
   else
     let
       val Simpset ({rules = rules1, prems = prems1, depth = depth1},
-       {congs = (congs1, weak1), procs = simprocs1, mk_rews, term_ord, subgoal_tac,
+       {congs = (congs1, weak1), procs = (simprocs1, congprocs1), mk_rews, term_ord, subgoal_tac,
         loop_tacs = loop_tacs1, solvers = (unsafe_solvers1, solvers1)}) = ss1;
       val Simpset ({rules = rules2, prems = prems2, depth = depth2},
-       {congs = (congs2, weak2), procs = simprocs2, mk_rews = _, term_ord = _, subgoal_tac = _,
-        loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2)}) = ss2;
+       {congs = (congs2, weak2), procs = (simprocs2, congprocs2),
+        loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2), ...}) = ss2;
 
       val rules' = Net.merge eq_rrule (rules1, rules2);
       val prems' = Thm.merge_thms (prems1, prems2);
@@ -347,11 +369,12 @@
       val congs' = Congtab.merge (K true) (congs1, congs2);
       val weak' = merge (op =) (weak1, weak2);
       val simprocs' = Net.merge eq_procedure_id (simprocs1, simprocs2);
+      val congprocs' = Net.merge eq_procedure_id (congprocs1, congprocs2);
       val loop_tacs' = AList.merge (op =) (K true) (loop_tacs1, loop_tacs2);
       val unsafe_solvers' = merge eq_solver (unsafe_solvers1, unsafe_solvers2);
       val solvers' = merge eq_solver (solvers1, solvers2);
     in
-      make_simpset ((rules', prems', depth'), ((congs', weak'), simprocs',
+      make_simpset ((rules', prems', depth'), ((congs', weak'), (simprocs', congprocs'),
         mk_rews, term_ord, subgoal_tac, loop_tacs', (unsafe_solvers', solvers')))
     end;
 
@@ -728,51 +751,55 @@
 
 type simproc = term list procedure;
 
-fun cert_simproc thy {name, lhss, proc, identifier} : simproc =
+fun cert_simproc thy {name, kind, lhss, proc, identifier} : simproc =
   Procedure
    {name = name,
+    kind = kind,
     lhs = map (Sign.cert_term thy) lhss,
     proc = proc,
     id = (stamp (), map (Thm.transfer thy) identifier)};
 
-fun transform_simproc phi (Procedure {name, lhs, proc, id = (stamp, identifier)}) : simproc =
+fun transform_simproc phi (Procedure {name, kind, lhs, proc, id = (stamp, identifier)}) : simproc =
   Procedure
    {name = name,
+    kind = kind,
     lhs = map (Morphism.term phi) lhs,
     proc = Morphism.transform phi proc,
     id = (stamp, Morphism.fact phi identifier)};
 
-fun trim_context_simproc (Procedure {name, lhs, proc, id = (stamp, identifier)}) : simproc =
+fun trim_context_simproc (Procedure {name, kind, lhs, proc, id = (stamp, identifier)}) : simproc =
   Procedure
    {name = name,
+    kind = kind,
     lhs = lhs,
     proc = Morphism.entity_reset_context proc,
     id = (stamp, map Thm.trim_context identifier)};
 
 local
 
-fun add_proc1 (proc as Procedure {name, lhs, ...}) ctxt =
+fun add_proc1 (proc as Procedure {name, kind, lhs, ...}) ctxt =
  (cond_tracing ctxt (fn () =>
-    print_term ctxt ("Adding simplification procedure " ^ quote name ^ " for") lhs);
+    print_term ctxt ("Adding " ^ print_proc_kind kind ^ " " ^ quote name ^ " for") lhs);
   ctxt |> map_simpset2
     (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) =>
-      (congs, Net.insert_term eq_procedure_id (lhs, proc) procs,
+      (congs, map_procs kind (Net.insert_term eq_procedure_id (lhs, proc)) procs,
         mk_rews, term_ord, subgoal_tac, loop_tacs, solvers))
   handle Net.INSERT =>
-    (cond_warning ctxt (fn () => "Ignoring duplicate simplification procedure " ^ quote name);
+    (cond_warning ctxt (fn () =>
+      "Ignoring duplicate " ^ print_proc_kind kind ^ " " ^ quote name);
       ctxt));
 
-fun del_proc1 (proc as Procedure {name, lhs, ...}) ctxt =
+fun del_proc1 (proc as Procedure {name, kind, lhs, ...}) ctxt =
   ctxt |> map_simpset2
     (fn (congs, procs, mk_rews, term_ord, subgoal_tac, loop_tacs, solvers) =>
-      (congs, Net.delete_term eq_procedure_id (lhs, proc) procs,
+      (congs, map_procs kind (Net.delete_term eq_procedure_id (lhs, proc)) procs,
         mk_rews, term_ord, subgoal_tac, loop_tacs, solvers))
   handle Net.DELETE =>
-    (cond_warning ctxt (fn () => "Simplification procedure " ^ quote name ^ " not in simpset");
+    (cond_warning ctxt (fn () => "No " ^ print_proc_kind kind ^ " " ^ quote name ^ " in simpset");
       ctxt);
 
-fun split_proc (Procedure {name, lhs = lhss, proc, id} : simproc) =
-  lhss |> map (fn lhs => Procedure {name = name, lhs = lhs, proc = proc, id = id});
+fun split_proc (Procedure {name, kind, lhs = lhss, proc, id} : simproc) =
+  lhss |> map (fn lhs => Procedure {name = name, kind = kind, lhs = lhs, proc = proc, id = id});
 
 in
 
@@ -961,19 +988,24 @@
   The latter may happen iff there are weak congruence rules for constants
   in the lhs.*)
 
-fun uncond_skel ((_, weak), (lhs, rhs)) =
-  if null weak then rhs  (*optimization*)
-  else if exists_subterm
+fun weak_cong weak lhs =
+  if null weak then false  (*optimization*)
+  else exists_subterm
     (fn Const (a, _) => member (op =) weak (true, a)
       | Free (a, _) => member (op =) weak (false, a)
-      | _ => false) lhs then skel0
+      | _ => false) lhs
+
+fun uncond_skel ((_, weak), congprocs, (lhs, rhs)) =
+  if weak_cong weak lhs then skel0
+  else if Net.is_empty congprocs then rhs  (*optimization*)
+  else if exists (is_weak_congproc o procedure_kind) (Net.match_term congprocs lhs) then skel0
   else rhs;
 
 (*Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   Otherwise those vars may become instantiated with unnormalized terms
   while the premises are solved.*)
 
-fun cond_skel (args as (_, (lhs, rhs))) =
+fun cond_skel (args as (_, _, (lhs, rhs))) =
   if Vars.subset (vars_set rhs, vars_set lhs) then uncond_skel args
   else skel0;
 
@@ -989,7 +1021,8 @@
 
 fun rewritec (prover, maxt) ctxt t =
   let
-    val Simpset ({rules, ...}, {congs, procs = simprocs, term_ord, ...}) = simpset_of ctxt;
+    val Simpset ({rules, ...}, {congs, procs = (simprocs, congprocs), term_ord, ...}) =
+      simpset_of ctxt;
     val eta_thm = Thm.eta_conversion t;
     val eta_t' = Thm.rhs_of eta_thm;
     val eta_t = Thm.term_of eta_t';
@@ -1028,7 +1061,7 @@
               let
                 val lr = Logic.dest_equals prop;
                 val SOME thm'' = check_conv ctxt' false eta_thm thm';
-              in SOME (thm'', uncond_skel (congs, lr)) end))
+              in SOME (thm'', uncond_skel (congs, congprocs, lr)) end))
           else
            (cond_tracing ctxt (fn () => print_thm0 ctxt "Trying to rewrite:" thm');
             if simp_depth ctxt > Config.get ctxt simp_depth_limit
@@ -1044,7 +1077,7 @@
                         let
                           val concl = Logic.strip_imp_concl prop;
                           val lr = Logic.dest_equals concl;
-                        in SOME (thm2', cond_skel (congs, lr)) end)))))
+                        in SOME (thm2', cond_skel (congs, congprocs, lr)) end)))))
       end;
 
     fun rews [] = NONE
@@ -1098,6 +1131,63 @@
   end;
 
 
+(* apply congprocs *)
+
+(* pattern order:
+   p1 GREATER p2: p1 is more general than p2, p1 matches p2 but not vice versa
+   p1 LESS    p2: p1 is more specific than p2, p2 matches p1 but not vice versa
+   p1 EQUAL   p2: both match each other or neither match each other
+*)
+
+fun pattern_order thy =
+  let
+    fun matches arg = can (Pattern.match thy arg) (Vartab.empty, Vartab.empty);
+  in
+    fn (p1, p2) =>
+      if matches (p1, p2) then
+        if matches (p2, p1) then EQUAL
+        else GREATER
+      else
+        if matches (p2, p1) then LESS
+        else EQUAL
+  end;
+
+fun app_congprocs ctxt ct =
+  let
+    val thy = Proof_Context.theory_of ctxt;
+    val Simpset (_, {procs = (_, congprocs), ...}) = simpset_of ctxt;
+
+    val eta_ct = Thm.rhs_of (Thm.eta_conversion ct);
+
+    fun proc_congs [] = NONE
+      | proc_congs (Procedure {name, lhs, proc, ...} :: ps) =
+          if Pattern.matches thy (lhs, Thm.term_of ct) then
+            let
+              val _ =
+                cond_tracing' ctxt simp_debug (fn () =>
+                  print_term ctxt ("Trying procedure " ^ quote name ^ " on:") (Thm.term_of eta_ct));
+
+              val ctxt' = Config.put simp_trace (Config.get ctxt simp_debug) ctxt;
+              val res =
+                trace_simproc {name = name, cterm = eta_ct} ctxt'
+                  (fn ctxt'' => Morphism.form_context' ctxt'' proc eta_ct);
+            in
+              (case res of
+                NONE => (cond_tracing' ctxt simp_debug (fn () => "FAILED"); proc_congs ps)
+              | SOME raw_thm =>
+                  (cond_tracing ctxt (fn () =>
+                     print_thm0 ctxt ("Procedure " ^ quote name ^ " produced congruence rule:")
+                       raw_thm);
+                   SOME (raw_thm, skel0)))
+            end
+          else proc_congs ps;
+  in
+    Net.match_term congprocs (Thm.term_of eta_ct)
+    |> sort (pattern_order thy o apply2 procedure_lhs)
+    |> proc_congs
+  end;
+
+
 (* conversion to apply a congruence rule to a term *)
 
 fun congc prover ctxt maxt cong t =
@@ -1202,8 +1292,15 @@
 
                     val (h, ts) = strip_comb t;
 
+     (*Prefer congprocs over plain cong rules. In congprocs prefer most specific rules.
+       If there is a matching congproc, then look into the result:
+         1. plain equality: consider normalisation complete (just as with a plain congruence rule),
+         2. conditional rule: treat like congruence rules like SOME cong case below.*)
+
                     fun app_cong () =
-                      Option.mapPartial (Congtab.lookup (fst congs)) (cong_name h);
+                      (case app_congprocs ctxt t0 of
+                        SOME (thm, _) => SOME thm
+                      | NONE => Option.mapPartial (Congtab.lookup (fst congs)) (cong_name h));
                   in
                     (case app_cong () of
                       NONE => appc ()
@@ -1213,7 +1310,9 @@
                        (let
                           val thm = congc (prover ctxt) ctxt maxidx cong t0;
                           val t = the_default t0 (Option.map Thm.rhs_of thm);
-                          val (cl, cr) = Thm.dest_comb t;
+                          val (cl, cr) = Thm.dest_comb t
+                            handle CTERM _ => Thm.dest_comb t0;  (*e.g. congproc has
+                              normalized such that head is removed from t*)
                           val dVar = Var (("", 0), dummyT);
                           val skel = list_comb (h, replicate (length ts) dVar);
                         in