replace Nat_Arith simprocs with simpler conversions that do less rearrangement of terms
authorhuffman
Fri, 27 Jul 2012 17:59:18 +0200
changeset 48560 e0875d956a6b
parent 48559 686cc7c47589
child 48561 12aa0cb2b447
replace Nat_Arith simprocs with simpler conversions that do less rearrangement of terms
src/HOL/IsaMakefile
src/HOL/Nat.thy
src/HOL/Tools/lin_arith.ML
src/HOL/Tools/nat_arith.ML
src/Provers/Arith/cancel_sums.ML
--- a/src/HOL/IsaMakefile	Fri Jul 27 17:57:31 2012 +0200
+++ b/src/HOL/IsaMakefile	Fri Jul 27 17:59:18 2012 +0200
@@ -171,7 +171,6 @@
 
 PLAIN_DEPENDENCIES = $(BASE_DEPENDENCIES) \
   $(SRC)/Provers/Arith/cancel_div_mod.ML \
-  $(SRC)/Provers/Arith/cancel_sums.ML \
   $(SRC)/Provers/Arith/fast_lin_arith.ML \
   $(SRC)/Provers/order.ML \
   $(SRC)/Provers/trancl.ML \
--- a/src/HOL/Nat.thy	Fri Jul 27 17:57:31 2012 +0200
+++ b/src/HOL/Nat.thy	Fri Jul 27 17:59:18 2012 +0200
@@ -11,7 +11,6 @@
 imports Inductive Typedef Fun Fields
 uses
   "~~/src/Tools/rat.ML"
-  "~~/src/Provers/Arith/cancel_sums.ML"
   "Tools/arith_data.ML"
   ("Tools/nat_arith.ML")
   "~~/src/Provers/Arith/fast_lin_arith.ML"
@@ -1497,19 +1496,19 @@
 
 simproc_setup nateq_cancel_sums
   ("(l::nat) + m = n" | "(l::nat) = m + n" | "Suc m = n" | "m = Suc n") =
-  {* fn phi => Nat_Arith.nateq_cancel_sums *}
+  {* fn phi => fn ss => try Nat_Arith.cancel_eq_conv *}
 
 simproc_setup natless_cancel_sums
   ("(l::nat) + m < n" | "(l::nat) < m + n" | "Suc m < n" | "m < Suc n") =
-  {* fn phi => Nat_Arith.natless_cancel_sums *}
+  {* fn phi => fn ss => try Nat_Arith.cancel_less_conv *}
 
 simproc_setup natle_cancel_sums
   ("(l::nat) + m \<le> n" | "(l::nat) \<le> m + n" | "Suc m \<le> n" | "m \<le> Suc n") =
-  {* fn phi => Nat_Arith.natle_cancel_sums *}
+  {* fn phi => fn ss => try Nat_Arith.cancel_le_conv *}
 
 simproc_setup natdiff_cancel_sums
   ("(l::nat) + m - n" | "(l::nat) - (m + n)" | "Suc m - n" | "m - Suc n") =
-  {* fn phi => Nat_Arith.natdiff_cancel_sums *}
+  {* fn phi => fn ss => try Nat_Arith.cancel_diff_conv *}
 
 use "Tools/lin_arith.ML"
 setup {* Lin_Arith.global_setup *}
--- a/src/HOL/Tools/lin_arith.ML	Fri Jul 27 17:57:31 2012 +0200
+++ b/src/HOL/Tools/lin_arith.ML	Fri Jul 27 17:59:18 2012 +0200
@@ -805,8 +805,9 @@
       addsimps @{thms ring_distribs}
       addsimps [@{thm if_True}, @{thm if_False}]
       addsimps
-       [@{thm add_0_left},
-        @{thm add_0_right},
+       [@{thm add_0_left}, @{thm add_0_right},
+        @{thm add_Suc}, @{thm add_Suc_right},
+        @{thm nat.inject}, @{thm Suc_le_mono}, @{thm Suc_less_eq},
         @{thm "Zero_not_Suc"}, @{thm "Suc_not_Zero"}, @{thm "le_0_eq"}, @{thm "One_nat_def"},
         @{thm "order_less_irrefl"}, @{thm "zero_neq_one"}, @{thm "zero_less_one"},
         @{thm "zero_le_one"}, @{thm "zero_neq_one"} RS not_sym, @{thm "not_one_le_zero"},
--- a/src/HOL/Tools/nat_arith.ML	Fri Jul 27 17:57:31 2012 +0200
+++ b/src/HOL/Tools/nat_arith.ML	Fri Jul 27 17:59:18 2012 +0200
@@ -1,17 +1,19 @@
 (* Author: Markus Wenzel, Stefan Berghofer, and Tobias Nipkow
+   Author: Brian Huffman
 
 Basic arithmetic for natural numbers.
 *)
 
 signature NAT_ARITH =
 sig
+  val cancel_diff_conv: conv
+  val cancel_eq_conv: conv
+  val cancel_le_conv: conv
+  val cancel_less_conv: conv
+  (* legacy functions: *)
   val mk_sum: term list -> term
   val mk_norm_sum: term list -> term
   val dest_sum: term -> term list
-  val nateq_cancel_sums: simpset -> cterm -> thm option
-  val natless_cancel_sums: simpset -> cterm -> thm option
-  val natle_cancel_sums: simpset -> cterm -> thm option
-  val natdiff_cancel_sums: simpset -> cterm -> thm option
 end;
 
 structure Nat_Arith: NAT_ARITH =
@@ -42,55 +44,58 @@
           SOME (t, u) => dest_sum t @ dest_sum u
         | NONE => [tm]));
 
+val add1 = @{lemma "(A::'a::comm_monoid_add) == k + a ==> A + b == k + (a + b)"
+      by (simp only: add_ac)}
+val add2 = @{lemma "(B::'a::comm_monoid_add) == k + b ==> a + B == k + (a + b)"
+      by (simp only: add_ac)}
+val suc1 = @{lemma "A == k + a ==> Suc A == k + Suc a"
+      by (simp only: add_Suc_right)}
+val rule0 = @{lemma "(a::'a::comm_monoid_add) == a + 0"
+      by (simp only: add_0_right)}
 
-(** cancel common summands **)
+val norm_rules = map mk_meta_eq @{thms add_0_left add_0_right}
 
-structure CommonCancelSums =
-struct
-  val mk_sum = mk_norm_sum;
-  val dest_sum = dest_sum;
-  val mk_plus = HOLogic.mk_binop @{const_name Groups.plus};
-  val norm_tac1 = Arith_Data.simp_all_tac [@{thm add_Suc}, @{thm add_Suc_right},
-    @{thm Nat.add_0}, @{thm Nat.add_0_right}];
-  val norm_tac2 = Arith_Data.simp_all_tac @{thms add_ac};
-  fun norm_tac ss = norm_tac1 ss THEN norm_tac2 ss;
-end;
+fun move_to_front path = Conv.every_conv
+    [Conv.rewr_conv (Library.foldl (op RS) (rule0, path)),
+     Conv.arg_conv (Raw_Simplifier.rewrite false norm_rules)]
 
-structure EqCancelSums = CancelSumsFun
-(struct
-  open CommonCancelSums;
-  val mk_bal = HOLogic.mk_eq;
-  val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} HOLogic.natT;
-  val cancel_rule = mk_meta_eq @{thm nat_add_left_cancel};
-end);
+fun add_atoms path (Const (@{const_name Groups.plus}, _) $ x $ y) =
+      add_atoms (add1::path) x #> add_atoms (add2::path) y
+  | add_atoms path (Const (@{const_name Nat.Suc}, _) $ x) =
+      add_atoms (suc1::path) x
+  | add_atoms _ (Const (@{const_name Groups.zero}, _)) = I
+  | add_atoms path x = cons (x, path)
+
+fun atoms t = add_atoms [] t []
+
+exception Cancel
 
-structure LessCancelSums = CancelSumsFun
-(struct
-  open CommonCancelSums;
-  val mk_bal = HOLogic.mk_binrel @{const_name Orderings.less};
-  val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} HOLogic.natT;
-  val cancel_rule = mk_meta_eq @{thm nat_add_left_cancel_less};
-end);
+fun find_common ord xs ys =
+  let
+    fun find (xs as (x, px)::xs') (ys as (y, py)::ys') =
+        (case ord (x, y) of
+          EQUAL => (px, py)
+        | LESS => find xs' ys
+        | GREATER => find xs ys')
+      | find _ _ = raise Cancel
+    fun ord' ((x, _), (y, _)) = ord (x, y)
+  in
+    find (sort ord' xs) (sort ord' ys)
+  end
 
-structure LeCancelSums = CancelSumsFun
-(struct
-  open CommonCancelSums;
-  val mk_bal = HOLogic.mk_binrel @{const_name Orderings.less_eq};
-  val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} HOLogic.natT;
-  val cancel_rule = mk_meta_eq @{thm nat_add_left_cancel_le};
-end);
+fun cancel_conv rule ct =
+  let
+    val ((_, lhs), rhs) = (apfst dest_comb o dest_comb) (Thm.term_of ct)
+    val (lpath, rpath) = find_common Term_Ord.term_ord (atoms lhs) (atoms rhs)
+    val lconv = move_to_front lpath
+    val rconv = move_to_front rpath
+    val conv1 = Conv.combination_conv (Conv.arg_conv lconv) rconv
+    val conv = conv1 then_conv Conv.rewr_conv rule
+  in conv ct handle Cancel => raise CTERM ("no_conversion", []) end
 
-structure DiffCancelSums = CancelSumsFun
-(struct
-  open CommonCancelSums;
-  val mk_bal = HOLogic.mk_binop @{const_name Groups.minus};
-  val dest_bal = HOLogic.dest_bin @{const_name Groups.minus} HOLogic.natT;
-  val cancel_rule = mk_meta_eq @{thm diff_cancel};
-end);
-
-fun nateq_cancel_sums ss = EqCancelSums.proc ss o Thm.term_of
-fun natless_cancel_sums ss = LessCancelSums.proc ss o Thm.term_of
-fun natle_cancel_sums ss = LeCancelSums.proc ss o Thm.term_of
-fun natdiff_cancel_sums ss = DiffCancelSums.proc ss o Thm.term_of
+val cancel_diff_conv = cancel_conv (mk_meta_eq @{thm diff_cancel})
+val cancel_eq_conv = cancel_conv (mk_meta_eq @{thm add_left_cancel})
+val cancel_le_conv = cancel_conv (mk_meta_eq @{thm add_le_cancel_left})
+val cancel_less_conv = cancel_conv (mk_meta_eq @{thm add_less_cancel_left})
 
 end;
--- a/src/Provers/Arith/cancel_sums.ML	Fri Jul 27 17:57:31 2012 +0200
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,78 +0,0 @@
-(*  Title:      Provers/Arith/cancel_sums.ML
-    Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
-
-Cancel common summands of balanced expressions:
-
-  A + x + B ~~ A' + x + B'  ==  A + B ~~ A' + B'
-
-where + is AC0 and ~~ an appropriate balancing operation (e.g. =, <=, <, -).
-*)
-
-signature CANCEL_SUMS_DATA =
-sig
-  (*abstract syntax*)
-  val mk_sum: term list -> term
-  val dest_sum: term -> term list
-  val mk_plus: term * term -> term
-  val mk_bal: term * term -> term
-  val dest_bal: term -> term * term
-  (*rules*)
-  val norm_tac: simpset -> tactic            (*AC0 etc.*)
-  val cancel_rule: thm                       (* x + A ~~ x + B == A ~~ B *)
-end;
-
-signature CANCEL_SUMS =
-sig
-  val proc: simpset -> term -> thm option
-end;
-
-
-functor CancelSumsFun(Data: CANCEL_SUMS_DATA): CANCEL_SUMS =
-struct
-
-
-(* cancel *)
-
-fun cons1 x (xs, y, z) = (x :: xs, y, z);
-fun cons2 y (x, ys, z) = (x, y :: ys, z);
-
-(*symmetric difference of multisets -- assumed to be sorted wrt. Term_Ord.term_ord*)
-fun cancel ts [] vs = (ts, [], vs)
-  | cancel [] us vs = ([], us, vs)
-  | cancel (t :: ts) (u :: us) vs =
-      (case Term_Ord.term_ord (t, u) of
-        EQUAL => cancel ts us (t :: vs)
-      | LESS => cons1 t (cancel ts (u :: us) vs)
-      | GREATER => cons2 u (cancel (t :: ts) us vs));
-
-
-(* the simplification procedure *)
-
-fun proc ss t =
-  (case try Data.dest_bal t of
-    NONE => NONE
-  | SOME bal =>
-      let
-        val thy = Proof_Context.theory_of (Simplifier.the_context ss);
-        val (ts, us) = pairself (sort Term_Ord.term_ord o Data.dest_sum) bal;
-        val (ts', us', vs) = cancel ts us [];
-      in
-        if null vs then NONE
-        else
-          let
-            val cert = Thm.cterm_of thy
-            val t' = Data.mk_sum ts'
-            val u' = Data.mk_sum us'
-            val v = Data.mk_sum vs
-            val t1 = Data.mk_bal (Data.mk_plus (v, t'), Data.mk_plus (v, u'))
-            val t2 = Data.mk_bal (t', u')
-            val goal1 = Thm.cterm_of thy (Logic.mk_equals (t, t1))
-            val goal2 = Thm.cterm_of thy (Logic.mk_equals (t1, t2))
-            val thm1 = Goal.prove_internal [] goal1 (K (Data.norm_tac ss))
-            val thm2 = Goal.prove_internal [] goal2 (K (rtac Data.cancel_rule 1))
-          in
-            SOME (Thm.transitive thm1 thm2)
-          end
-      end);
-
-end;