now works for coefficients, not just for numerals
authorpaulson
Fri, 21 Apr 2000 11:31:03 +0200
changeset 8760 9139453d7033
parent 8759 49154c960140
child 8761 8043130d3dcf
now works for coefficients, not just for numerals no longer works by subtraction, so no need for inverse_fold
src/Provers/Arith/cancel_numerals.ML
--- a/src/Provers/Arith/cancel_numerals.ML	Fri Apr 21 11:29:57 2000 +0200
+++ b/src/Provers/Arith/cancel_numerals.ML	Fri Apr 21 11:31:03 2000 +0200
@@ -1,61 +1,103 @@
-(*  Title:      Provers/Arith/cancel_sums.ML
+(*  Title:      Provers/Arith/cancel_numerals.ML
     ID:         $Id$
     Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
     Copyright   2000  University of Cambridge
 
-Cancel common literals in balanced expressions:
+Cancel common coefficients in balanced expressions:
 
-     i + #m + j ~~ i' + #m' + j'  ==  #(m-m') + i + j ~~ i' + j'
+     i + #m*u + j ~~ i' + #m'*u + j'  ==  #(m-m')*u + i + j ~~ i' + j'
 
 where ~~ is an appropriate balancing operation (e.g. =, <=, <, -).
+
+It works by (a) massaging both sides to bring the selected term to the front:
+
+     #m*u + (i + j) ~~ #m'*u + (i' + j') 
+
+(b) then using bal_add1 or bal_add2 to reach
+
+     #(m-m')*u + i + j ~~ i' + j'       (if m'<=m)
+
+or
+
+     i + j ~~ #(m'-m)*u + i' + j'       (otherwise)
 *)
 
 signature CANCEL_NUMERALS_DATA =
 sig
   (*abstract syntax*)
-  val mk_numeral: int -> term
-  val find_first_numeral: term list -> int * term * term list
   val mk_sum: term list -> term
   val dest_sum: term -> term list
   val mk_bal: term * term -> term
   val dest_bal: term -> term * term
+  val mk_coeff: int * term -> term
+  val dest_coeff: term -> int * term
+  val find_first_coeff: term -> term list -> int * term list
+  (*rules*)
+  val bal_add1: thm
+  val bal_add2: thm
   (*proof tools*)
   val prove_conv: tactic list -> Sign.sg -> term * term -> thm option
-  val subst_tac: term -> tactic
-  val all_simp_tac: tactic
-end;
-
-signature CANCEL_NUMERALS =
-sig
-  val proc: Sign.sg -> thm list -> term -> thm option
+  val norm_tac: tactic
+  val numeral_simp_tac: tactic
 end;
 
 
-functor CancelNumeralsFun(Data: CANCEL_NUMERALS_DATA): CANCEL_NUMERALS =
+functor CancelNumeralsFun(Data: CANCEL_NUMERALS_DATA):
+  sig
+  val proc: Sign.sg -> thm list -> term -> thm option
+  end 
+=
 struct
 
-(*predicting the outputs of other simprocs given a term of the form
-   (i + ... #m + ... j) - #n   *)
-fun cancelled m n terms =
-    if m = n then (*cancel_sums: sort the terms*)
-	sort Term.term_ord terms 
-    else          (*inverse_fold: subtract, keeping original term order*)
-	Data.mk_numeral (m - n) :: terms;
+fun listof None = []
+  | listof (Some x) = [x];
+
+(*If t = #n*u then put u in the table*)
+fun update_by_coeff (tab, t) =
+  Termtab.update ((#2 (Data.dest_coeff t), ()), tab)
+  handle TERM _ => tab;
+
+(*a left-to-right scan of terms1, seeking a term of the form #n*u, where
+  #m*u is in terms2 for some m*)
+fun find_common (terms1,terms2) =
+  let val tab2 = foldl update_by_coeff (Termtab.empty, terms2)
+      fun seek [] = raise TERM("find_common", []) 
+	| seek (t::terms) =
+	      let val (_,u) = Data.dest_coeff t 
+	      in  if is_some (Termtab.lookup (tab2, u)) then u
+		  else seek terms
+	      end
+	      handle TERM _ => seek terms
+  in  seek terms1 end;
 
 (*the simplification procedure*)
 fun proc sg _ t =
   let val (t1,t2) = Data.dest_bal t 
-      val (n1, lit1, terms1) = Data.find_first_numeral (Data.dest_sum t1)
-      and (n2, lit2, terms2) = Data.find_first_numeral (Data.dest_sum t2)
-      val lit_n = if n1<n2 then lit1 else lit2
-      and n     = BasisLibrary.Int.min (n1,n2)
-          (*having both the literals and their integer values makes it
-            more robust against negative natural number literals*)
+      val terms1 = Data.dest_sum t1
+      and terms2 = Data.dest_sum t2
+      val u = find_common (terms1,terms2)
+      val (n1, terms1') = Data.find_first_coeff u terms1
+      and (n2, terms2') = Data.find_first_coeff u terms2
+      fun newshape (i,terms) = Data.mk_sum (Data.mk_coeff(i,u)::terms)
+      val reshapes =  (*Move i*u to the front and put j*u into standard form
+		      i + #m + j + k == #m + i + (j + k) *)
+	    listof (Data.prove_conv [Data.norm_tac] sg 
+		    (t, 
+		     Data.mk_bal (newshape(n1,terms1'), 
+				  newshape(n2,terms2'))))
   in
-      Data.prove_conv [Data.subst_tac lit_n, Data.all_simp_tac] sg
-	 (t, Data.mk_bal (Data.mk_sum (cancelled n1 n terms1),
-			  Data.mk_sum (cancelled n2 n terms2)))
+
+      if n2<=n1 then 
+	  Data.prove_conv 
+	     [rewrite_goals_tac reshapes, rtac Data.bal_add1 1,
+	      Data.numeral_simp_tac] sg
+	     (t, Data.mk_bal (newshape(n1-n2,terms1'), Data.mk_sum terms2'))
+      else
+	  Data.prove_conv 
+	     [rewrite_goals_tac reshapes, rtac Data.bal_add2 1,
+	      Data.numeral_simp_tac] sg
+	     (t, Data.mk_bal (Data.mk_sum terms1', newshape(n2-n1,terms2')))
   end
-  handle _ => None;
+  handle TERM _ => None;
 
 end;