option "sort_updates" for record update simproc. Make proper record simproc definitions.
authorNorbert Schirmer <nschirmer@apple.com>
Mon, 07 Mar 2022 13:37:17 +0100
changeset 76039 ca7737249aa4
parent 76038 46eea084f393
child 76040 5326abe1fff8
option "sort_updates" for record update simproc. Make proper record simproc definitions.
src/HOL/Examples/Records.thy
src/HOL/Tools/record.ML
--- a/src/HOL/Examples/Records.thy	Thu Sep 01 13:01:38 2022 +0100
+++ b/src/HOL/Examples/Records.thy	Mon Mar 07 13:37:17 2022 +0100
@@ -1,6 +1,7 @@
 (*  Title:      HOL/Examples/Records.thy
     Author:     Wolfgang Naraschewski, TU Muenchen
     Author:     Norbert Schirmer, TU Muenchen
+    Author:     Norbert Schirmer, Apple, 2022
     Author:     Markus Wenzel, TU Muenchen
 *)
 
@@ -294,8 +295,72 @@
   by the following lemma.\<close>
 
 lemma "\<exists>r. xpos r = x"
-  by (tactic \<open>simp_tac (put_simpset HOL_basic_ss \<^context>
-    addsimprocs [Record.ex_sel_eq_simproc]) 1\<close>)
+  supply [[simproc add: Record.ex_sel_eq]]
+  apply (simp)
+  done
+
+subsection \<open>Simprocs for update and equality\<close>
+
+record alph1 =
+a::nat
+b::nat
+
+record alph2 = alph1 +
+c::nat
+d::nat
+
+record alph3 = alph2 + 
+  e::nat
+  f::nat
+
+text \<open>The simprocs that are activated by default are:
+\<^item> @{ML [source] Record.simproc}: field selection of (nested) record updates.
+\<^item> @{ML [source] Record.upd_simproc}: nested record updates.
+\<^item> @{ML [source] Record.eq_simproc}: (componentwise) equality of records.
+\<close>
+
+
+text \<open>By default record updates are not ordered by simplification.\<close>
+schematic_goal "r\<lparr>b := x, a:= y\<rparr> = ?X"
+  by simp
+
+text \<open>Normalisation towards an update ordering (string ordering of update function names) can
+be configured as follows.\<close>
+schematic_goal "r\<lparr>b := y, a := x\<rparr> = ?X"
+  supply [[record_sort_updates = true]]
+  by simp
+
+text \<open>Note the interplay between update ordering and record equality. Without update ordering
+the following equality is handled by @{ML [source] Record.eq_simproc}. Record equality is thus
+solved by componentwise comparison of all the fields of the records which can be expensive 
+in the presence of many fields.\<close>
+
+lemma "r\<lparr>f := x1, a:= x2\<rparr> = r\<lparr>a := x2, f:= x1\<rparr>"
+  by simp
+
+lemma "r\<lparr>f := x1, a:= x2\<rparr> = r\<lparr>a := x2, f:= x1\<rparr>"
+  supply [[simproc del: Record.eq]]
+  apply (simp?)
+  oops
+
+text \<open>With update ordering the equality is already established after update normalisation. There
+is no need for componentwise comparison.\<close>
+
+lemma "r\<lparr>f := x1, a:= x2\<rparr> = r\<lparr>a := x2, f:= x1\<rparr>"
+  supply [[record_sort_updates = true, simproc del: Record.eq]]
+  apply simp
+  done
+
+schematic_goal "r\<lparr>f := x1, e := x2, d:= x3, c:= x4, b:=x5, a:= x6\<rparr> = ?X"
+  supply [[record_sort_updates = true]]
+  by simp
+
+schematic_goal "r\<lparr>f := x1, e := x2, d:= x3, c:= x4, e:=x5, a:= x6\<rparr> = ?X"
+  supply [[record_sort_updates = true]]
+  by simp
+
+schematic_goal "r\<lparr>f := x1, e := x2, d:= x3, c:= x4, e:=x5, a:= x6\<rparr> = ?X"
+  by simp
 
 
 subsection \<open>A more complex record expression\<close>
@@ -324,6 +389,24 @@
   bar520 :: nat
   bar521 :: "nat \<times> nat"
 
+
+setup \<open>
+let
+ val N = 300
+in
+  Record.add_record {overloaded=false} ([], \<^binding>\<open>large_record\<close>) NONE 
+    (map (fn i => (Binding.make ("fld_" ^ string_of_int i, \<^here>), @{typ nat}, Mixfix.NoSyn)) 
+      (1 upto N))
+end
+\<close>
+
 declare [[record_codegen = true]]
 
+schematic_goal \<open>fld_1 (r\<lparr>fld_300 := x300, fld_20 := x20, fld_200 := x200\<rparr>) = ?X\<close>
+  by simp
+
+schematic_goal \<open>r\<lparr>fld_300 := x300, fld_20 := x20, fld_200 := x200\<rparr> = ?X\<close>
+  supply [[record_sort_updates]]
+  by simp
+
 end
--- a/src/HOL/Tools/record.ML	Thu Sep 01 13:01:38 2022 +0100
+++ b/src/HOL/Tools/record.ML	Mon Mar 07 13:37:17 2022 +0100
@@ -2,6 +2,7 @@
     Author:     Wolfgang Naraschewski, TU Muenchen
     Author:     Markus Wenzel, TU Muenchen
     Author:     Norbert Schirmer, TU Muenchen
+    Author:     Norbert Schirmer, Apple, 2022
     Author:     Thomas Sewell, NICTA
 
 Extensible records with structural subtyping.
@@ -50,6 +51,7 @@
   val string_of_record: Proof.context -> string -> string
 
   val codegen: bool Config.T
+  val sort_updates: bool Config.T
   val updateN: string
   val ext_typeN: string
   val extN: string
@@ -196,6 +198,7 @@
 val updacc_cong_from_eq = @{thm iso_tuple_update_accessor_cong_from_eq};
 
 val codegen = Attrib.setup_config_bool \<^binding>\<open>record_codegen\<close> (K true);
+val sort_updates = Attrib.setup_config_bool \<^binding>\<open>record_sort_updates\<close> (K false);
 
 
 (** name components **)
@@ -984,9 +987,8 @@
     val dest = if comp then @{thm o_eq_dest_lhs} else @{thm o_eq_dest};
   in Drule.export_without_context (othm RS dest) end;
 
-fun get_updupd_simps ctxt term defset =
+fun gen_get_updupd_simps ctxt upd_funs defset =
   let
-    val upd_funs = get_upd_funs term;
     val cname = fst o dest_Const;
     fun getswap u u' = get_updupd_simp ctxt defset u u' (cname u = cname u');
     fun build_swaps_to_eq _ [] swaps = swaps
@@ -1007,7 +1009,9 @@
           else swaps_needed us (u :: prev) (Symtab.insert (K true) (cname u, ()) seen) swaps;
   in swaps_needed upd_funs [] Symtab.empty Symreltab.empty end;
 
-fun prove_unfold_defs thy ex_simps ex_simprs prop =
+fun get_updupd_simps ctxt term defset = gen_get_updupd_simps ctxt (get_upd_funs term) defset;
+
+fun prove_unfold_defs thy upd_funs ex_simps ex_simprs prop =
   let
     val ctxt = Proof_Context.init_global thy;
 
@@ -1015,7 +1019,10 @@
     val prop' = Envir.beta_eta_contract prop;
     val (lhs, _) = Logic.dest_equals (Logic.strip_assums_concl prop');
     val (_, args) = strip_comb lhs;
-    val simps = (if length args = 1 then get_accupd_simps else get_updupd_simps) ctxt lhs defset;
+    val simps = if null upd_funs then 
+                   (if length args = 1 then get_accupd_simps else get_updupd_simps) ctxt lhs defset
+                else
+                  gen_get_updupd_simps ctxt upd_funs defset
   in
     Goal.prove ctxt [] [] prop'
       (fn {context = ctxt', ...} =>
@@ -1053,8 +1060,7 @@
   - If X is a more-selector we have to make sure that S is not in the updated
     subrecord.
 *)
-val simproc =
-  Simplifier.make_simproc \<^context> "record"
+val  _ = Theory.setup (Named_Target.theory_map (Simplifier.define_simproc \<^binding>\<open>select_update\<close> 
    {lhss = [\<^term>\<open>x::'a::{}\<close>],
     proc = fn _ => fn ctxt => fn ct =>
       let
@@ -1102,13 +1108,15 @@
                 (case mk_eq_terms (upd $ k $ r) of
                   SOME (trm, trm', vars) =>
                     SOME
-                      (prove_unfold_defs thy [] []
+                      (prove_unfold_defs thy [] [] []
                         (Logic.list_all (vars, Logic.mk_equals (sel $ trm, trm'))))
                 | NONE => NONE)
               end
             else NONE
         | _ => NONE)
-      end};
+      end}));
+val simproc_name = Simplifier.check_simproc (Context.the_local_context ()) ("select_update", Position.none);
+val simproc = Simplifier.the_simproc (Context.the_local_context ()) simproc_name;
 
 fun get_upd_acc_cong_thm upd acc thy ss =
   let
@@ -1126,7 +1134,25 @@
         TRY (resolve_tac ctxt' [updacc_cong_idI] 1))
   end;
 
-
+fun sorted ord [] = true
+  | sorted ord [x] = true
+  | sorted ord (x::y::xs) = 
+      (case ord (x, y) of 
+         LESS => sorted ord (y::xs)
+       | EQUAL => sorted ord (y::xs)
+       | GREATER => false)
+
+fun insert_unique ord x [] = [x]
+  | insert_unique ord x (y::ys) = 
+      (case ord (x, y) of 
+         LESS => (x::y::ys)
+       | EQUAL => (x::ys)
+       | GREATER => y :: insert_unique ord x ys)
+
+fun insert_unique_hd ord (x::xs) = x :: insert_unique ord x xs
+  | insert_unique_hd ord xs = xs
+
+                                 
 (* upd_simproc *)
 
 (*Simplify multiple updates:
@@ -1137,8 +1163,7 @@
   In both cases "more" updates complicate matters: for this reason
   we omit considering further updates if doing so would introduce
   both a more update and an update to a field within it.*)
-val upd_simproc =
-  Simplifier.make_simproc \<^context> "record_upd"
+val  _ = Theory.setup (Named_Target.theory_map (Simplifier.define_simproc \<^binding>\<open>update\<close> 
    {lhss = [\<^term>\<open>x::'a::{}\<close>],
     proc = fn _ => fn ctxt => fn ct =>
       let
@@ -1169,6 +1194,13 @@
           | getupdseq term _ _ = ([], term, HOLogic.unitT);
 
         val (upds, base, baseT) = getupdseq t 0 ~1;
+        val orig_upds = map_index (fn (i, (x, y, z)) => (x, y, z, i)) upds
+        val upd_ord = rev_order o fast_string_ord o apply2 #2
+        val (upds, commuted) = 
+          if not (null orig_upds) andalso Config.get ctxt sort_updates andalso not (sorted upd_ord orig_upds) then
+             (sort upd_ord orig_upds, true)
+          else 
+             (orig_upds, false)
 
         fun is_upd_noop s (Abs (n, T, Const (s', T') $ tm')) tm =
               if s = s' andalso null (loose_bnos tm')
@@ -1193,7 +1225,7 @@
           | add_upd f fs = (f :: fs);
 
         (*mk_updterm returns
-          (orig-term-skeleton, simplified-skeleton,
+          (orig-term-skeleton-update list , simplified-skeleton,
             variables, duplicate-updates, simp-flag, noop-simps)
 
           where duplicate-updates is a table used to pass upward
@@ -1201,9 +1233,9 @@
           into an update above them, simp-flag indicates whether
           any simplification was achieved, and noop-simps are
           used for eliminating case (2) defined above*)
-        fun mk_updterm ((upd as Const (u, T), s, f) :: upds) above term =
+        fun mk_updterm ((upd as Const (u, T), s, f, i) :: upds) above term =
               let
-                val (lhs, rhs, vars, dups, simp, noops) =
+                val (lhs_upds, rhs, vars, dups, simp, noops) =
                   mk_updterm upds (Symtab.update (u, ()) above) term;
                 val (fvar, skelf) =
                   K_skeleton (Long_Name.base_name s) (domain_type T) (Bound (length vars)) f;
@@ -1213,35 +1245,45 @@
                   Const (\<^const_name>\<open>Fun.comp\<close>, funT --> funT --> funT) $ f $ f';
               in
                 if isnoop then
-                  (upd $ skelf' $ lhs, rhs, vars,
+                  ((upd $ skelf', i)::lhs_upds, rhs, vars,
                     Symtab.update (u, []) dups, true,
                     if Symtab.defined noops u then noops
                     else Symtab.update (u, get_noop_simps upd skelf') noops)
                 else if Symtab.defined above u then
-                  (upd $ skelf $ lhs, rhs, fvar :: vars,
+                  ((upd $ skelf, i)::lhs_upds, rhs, fvar :: vars,
                     Symtab.map_default (u, []) (add_upd skelf) dups,
                     true, noops)
                 else
                   (case Symtab.lookup dups u of
                     SOME fs =>
-                     (upd $ skelf $ lhs,
+                     ((upd $ skelf, i)::lhs_upds,
                       upd $ foldr1 mk_comp_local (add_upd skelf fs) $ rhs,
                       fvar :: vars, dups, true, noops)
-                  | NONE => (upd $ skelf $ lhs, upd $ skelf $ rhs, fvar :: vars, dups, simp, noops))
+                  | NONE => ((upd $ skelf, i)::lhs_upds, upd $ skelf $ rhs, fvar :: vars, dups, simp, noops))
               end
           | mk_updterm [] _ _ =
-              (Bound 0, Bound 0, [("r", baseT)], Symtab.empty, false, Symtab.empty)
-          | mk_updterm us _ _ = raise TERM ("mk_updterm match", map (fn (x, _, _) => x) us);
-
-        val (lhs, rhs, vars, _, simp, noops) = mk_updterm upds Symtab.empty base;
+              ([], Bound 0, [("r", baseT)], Symtab.empty, false, Symtab.empty)
+          | mk_updterm us _ _ = raise TERM ("mk_updterm match", map (fn (x, _, _, _) => x) us);
+
+        val (lhs_upds, rhs, vars, _, simp, noops) = mk_updterm upds Symtab.empty base;
+        val orig_order_lhs_upds = lhs_upds |> sort (rev_order o int_ord o apply2 snd)
+        val lhs = Bound 0 |> fold (fn (upd, _) => fn s => upd $ s) orig_order_lhs_upds
+        (* Note that the simplifier works bottom up. So all nested updates are already
+           normalised, e.g. sorted. 'commuted' thus means that the outermost update has to be 
+           inserted at its place inside the sorted nested updates. The necessary swaps can be 
+           expressed via 'upd_funs' by replicating the outer update at the designated position: *)
+        val upd_funs = (if commuted then insert_unique_hd upd_ord orig_upds else orig_upds) |> map #1 
         val noops' = maps snd (Symtab.dest noops);
       in
-        if simp then
+        if simp orelse commuted then
           SOME
-            (prove_unfold_defs thy noops' [simproc]
+            (prove_unfold_defs thy upd_funs noops' [simproc]
               (Logic.list_all (vars, Logic.mk_equals (lhs, rhs))))
         else NONE
-      end};
+      end}));
+val upd_simproc_name = Simplifier.check_simproc (Context.the_local_context ()) 
+      ("update", Position.none);
+val upd_simproc = Simplifier.the_simproc (Context.the_local_context ()) upd_simproc_name;
 
 end;
 
@@ -1260,8 +1302,8 @@
              eq_simproc          split_simp_tac
  Complexity: #components * #updates     #updates
 *)
-val eq_simproc =
-  Simplifier.make_simproc \<^context> "record_eq"
+
+val  _ = Theory.setup (Named_Target.theory_map (Simplifier.define_simproc \<^binding>\<open>eq\<close>  
    {lhss = [\<^term>\<open>r = s\<close>],
     proc = fn _ => fn ctxt => fn ct =>
       (case Thm.term_of ct of
@@ -1272,8 +1314,10 @@
               (case get_equalities (Proof_Context.theory_of ctxt) name of
                 NONE => NONE
               | SOME thm => SOME (thm RS @{thm Eq_TrueI})))
-      | _ => NONE)};
-
+      | _ => NONE)}))
+val eq_simproc_name = Simplifier.check_simproc (Context.the_local_context ()) 
+      ("eq", Position.none);
+val eq_simproc = Simplifier.the_simproc (Context.the_local_context ()) eq_simproc_name;
 
 (* split_simproc *)
 
@@ -1311,8 +1355,7 @@
           else NONE
       | _ => NONE)};
 
-val ex_sel_eq_simproc =
-  Simplifier.make_simproc \<^context> "ex_sel_eq"
+val  _ = Theory.setup (Named_Target.theory_map (Simplifier.define_simproc \<^binding>\<open>ex_sel_eq\<close>
    {lhss = [\<^term>\<open>Ex t\<close>],
     proc = fn _ => fn ctxt => fn ct =>
       let
@@ -1350,7 +1393,11 @@
                     addsimps @{thms simp_thms} addsimprocs [split_simproc (K ~1)]) 1))
             end handle TERM _ => NONE)
         | _ => NONE)
-      end};
+      end}));
+val ex_sel_eq_simproc_name = Simplifier.check_simproc (Context.the_local_context ()) 
+      ("ex_sel_eq", Position.none);
+val ex_sel_eq_simproc = Simplifier.the_simproc (Context.the_local_context ()) ex_sel_eq_simproc_name;
+val _ = Theory.setup (map_theory_simpset (fn ctxt => ctxt delsimprocs [ex_sel_eq_simproc]));
 
 
 (* split_simp_tac *)
@@ -1431,11 +1478,6 @@
     else no_tac
   end);
 
-val _ =
-  Theory.setup
-    (map_theory_simpset (fn ctxt => ctxt addsimprocs [simproc, upd_simproc, eq_simproc]));
-
-
 (* wrapper *)
 
 val split_name = "record_split_tac";