src/HOL/Tools/numeral_simprocs.ML
author wenzelm
Sat Oct 17 00:52:37 2009 +0200 (2009-10-17)
changeset 32957 675c0c7e6a37
parent 32155 e2bf2f73b0c8
child 33359 8b673ae1bf39
permissions -rw-r--r--
explicitly qualify Drule.standard;
     1 (* Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
     2    Copyright   2000  University of Cambridge
     3 
     4 Simprocs for the integer numerals.
     5 *)
     6 
     7 (*To quote from Provers/Arith/cancel_numeral_factor.ML:
     8 
     9 Cancels common coefficients in balanced expressions:
    10 
    11      u*#m ~~ u'*#m'  ==  #n*u ~~ #n'*u'
    12 
    13 where ~~ is an appropriate balancing operation (e.g. =, <=, <, div, /)
    14 and d = gcd(m,m') and n=m/d and n'=m'/d.
    15 *)
    16 
    17 signature NUMERAL_SIMPROCS =
    18 sig
    19   val mk_sum: typ -> term list -> term
    20   val dest_sum: term -> term list
    21 
    22   val assoc_fold_simproc: simproc
    23   val combine_numerals: simproc
    24   val cancel_numerals: simproc list
    25   val cancel_factors: simproc list
    26   val cancel_numeral_factors: simproc list
    27   val field_combine_numerals: simproc
    28   val field_cancel_numeral_factors: simproc list
    29   val num_ss: simpset
    30 end;
    31 
    32 structure Numeral_Simprocs : NUMERAL_SIMPROCS =
    33 struct
    34 
    35 fun mk_number T n = HOLogic.number_of_const T $ HOLogic.mk_numeral n;
    36 
    37 fun find_first_numeral past (t::terms) =
    38         ((snd (HOLogic.dest_number t), rev past @ terms)
    39          handle TERM _ => find_first_numeral (t::past) terms)
    40   | find_first_numeral past [] = raise TERM("find_first_numeral", []);
    41 
    42 val mk_plus = HOLogic.mk_binop @{const_name HOL.plus};
    43 
    44 fun mk_minus t = 
    45   let val T = Term.fastype_of t
    46   in Const (@{const_name HOL.uminus}, T --> T) $ t end;
    47 
    48 (*Thus mk_sum[t] yields t+0; longer sums don't have a trailing zero*)
    49 fun mk_sum T []        = mk_number T 0
    50   | mk_sum T [t,u]     = mk_plus (t, u)
    51   | mk_sum T (t :: ts) = mk_plus (t, mk_sum T ts);
    52 
    53 (*this version ALWAYS includes a trailing zero*)
    54 fun long_mk_sum T []        = mk_number T 0
    55   | long_mk_sum T (t :: ts) = mk_plus (t, mk_sum T ts);
    56 
    57 val dest_plus = HOLogic.dest_bin @{const_name HOL.plus} Term.dummyT;
    58 
    59 (*decompose additions AND subtractions as a sum*)
    60 fun dest_summing (pos, Const (@{const_name HOL.plus}, _) $ t $ u, ts) =
    61         dest_summing (pos, t, dest_summing (pos, u, ts))
    62   | dest_summing (pos, Const (@{const_name HOL.minus}, _) $ t $ u, ts) =
    63         dest_summing (pos, t, dest_summing (not pos, u, ts))
    64   | dest_summing (pos, t, ts) =
    65         if pos then t::ts else mk_minus t :: ts;
    66 
    67 fun dest_sum t = dest_summing (true, t, []);
    68 
    69 val mk_diff = HOLogic.mk_binop @{const_name HOL.minus};
    70 val dest_diff = HOLogic.dest_bin @{const_name HOL.minus} Term.dummyT;
    71 
    72 val mk_times = HOLogic.mk_binop @{const_name HOL.times};
    73 
    74 fun one_of T = Const(@{const_name HOL.one},T);
    75 
    76 (* build product with trailing 1 rather than Numeral 1 in order to avoid the
    77    unnecessary restriction to type class number_ring
    78    which is not required for cancellation of common factors in divisions.
    79 *)
    80 fun mk_prod T = 
    81   let val one = one_of T
    82   fun mk [] = one
    83     | mk [t] = t
    84     | mk (t :: ts) = if t = one then mk ts else mk_times (t, mk ts)
    85   in mk end;
    86 
    87 (*This version ALWAYS includes a trailing one*)
    88 fun long_mk_prod T []        = one_of T
    89   | long_mk_prod T (t :: ts) = mk_times (t, mk_prod T ts);
    90 
    91 val dest_times = HOLogic.dest_bin @{const_name HOL.times} Term.dummyT;
    92 
    93 fun dest_prod t =
    94       let val (t,u) = dest_times t
    95       in dest_prod t @ dest_prod u end
    96       handle TERM _ => [t];
    97 
    98 (*DON'T do the obvious simplifications; that would create special cases*)
    99 fun mk_coeff (k, t) = mk_times (mk_number (Term.fastype_of t) k, t);
   100 
   101 (*Express t as a product of (possibly) a numeral with other sorted terms*)
   102 fun dest_coeff sign (Const (@{const_name HOL.uminus}, _) $ t) = dest_coeff (~sign) t
   103   | dest_coeff sign t =
   104     let val ts = sort TermOrd.term_ord (dest_prod t)
   105         val (n, ts') = find_first_numeral [] ts
   106                           handle TERM _ => (1, ts)
   107     in (sign*n, mk_prod (Term.fastype_of t) ts') end;
   108 
   109 (*Find first coefficient-term THAT MATCHES u*)
   110 fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
   111   | find_first_coeff past u (t::terms) =
   112         let val (n,u') = dest_coeff 1 t
   113         in if u aconv u' then (n, rev past @ terms)
   114                          else find_first_coeff (t::past) u terms
   115         end
   116         handle TERM _ => find_first_coeff (t::past) u terms;
   117 
   118 (*Fractions as pairs of ints. Can't use Rat.rat because the representation
   119   needs to preserve negative values in the denominator.*)
   120 fun mk_frac (p, q) = if q = 0 then raise Div else (p, q);
   121 
   122 (*Don't reduce fractions; sums must be proved by rule add_frac_eq.
   123   Fractions are reduced later by the cancel_numeral_factor simproc.*)
   124 fun add_frac ((p1, q1), (p2, q2)) = (p1 * q2 + p2 * q1, q1 * q2);
   125 
   126 val mk_divide = HOLogic.mk_binop @{const_name HOL.divide};
   127 
   128 (*Build term (p / q) * t*)
   129 fun mk_fcoeff ((p, q), t) =
   130   let val T = Term.fastype_of t
   131   in mk_times (mk_divide (mk_number T p, mk_number T q), t) end;
   132 
   133 (*Express t as a product of a fraction with other sorted terms*)
   134 fun dest_fcoeff sign (Const (@{const_name HOL.uminus}, _) $ t) = dest_fcoeff (~sign) t
   135   | dest_fcoeff sign (Const (@{const_name HOL.divide}, _) $ t $ u) =
   136     let val (p, t') = dest_coeff sign t
   137         val (q, u') = dest_coeff 1 u
   138     in (mk_frac (p, q), mk_divide (t', u')) end
   139   | dest_fcoeff sign t =
   140     let val (p, t') = dest_coeff sign t
   141         val T = Term.fastype_of t
   142     in (mk_frac (p, 1), mk_divide (t', one_of T)) end;
   143 
   144 
   145 (** New term ordering so that AC-rewriting brings numerals to the front **)
   146 
   147 (*Order integers by absolute value and then by sign. The standard integer
   148   ordering is not well-founded.*)
   149 fun num_ord (i,j) =
   150   (case int_ord (abs i, abs j) of
   151     EQUAL => int_ord (Int.sign i, Int.sign j) 
   152   | ord => ord);
   153 
   154 (*This resembles TermOrd.term_ord, but it puts binary numerals before other
   155   non-atomic terms.*)
   156 local open Term 
   157 in 
   158 fun numterm_ord (Abs (_, T, t), Abs(_, U, u)) =
   159       (case numterm_ord (t, u) of EQUAL => TermOrd.typ_ord (T, U) | ord => ord)
   160   | numterm_ord
   161      (Const(@{const_name Int.number_of}, _) $ v, Const(@{const_name Int.number_of}, _) $ w) =
   162      num_ord (HOLogic.dest_numeral v, HOLogic.dest_numeral w)
   163   | numterm_ord (Const(@{const_name Int.number_of}, _) $ _, _) = LESS
   164   | numterm_ord (_, Const(@{const_name Int.number_of}, _) $ _) = GREATER
   165   | numterm_ord (t, u) =
   166       (case int_ord (size_of_term t, size_of_term u) of
   167         EQUAL =>
   168           let val (f, ts) = strip_comb t and (g, us) = strip_comb u in
   169             (case TermOrd.hd_ord (f, g) of EQUAL => numterms_ord (ts, us) | ord => ord)
   170           end
   171       | ord => ord)
   172 and numterms_ord (ts, us) = list_ord numterm_ord (ts, us)
   173 end;
   174 
   175 fun numtermless tu = (numterm_ord tu = LESS);
   176 
   177 val num_ss = HOL_ss settermless numtermless;
   178 
   179 (*Maps 0 to Numeral0 and 1 to Numeral1 so that arithmetic isn't complicated by the abstract 0 and 1.*)
   180 val numeral_syms = [@{thm numeral_0_eq_0} RS sym, @{thm numeral_1_eq_1} RS sym];
   181 
   182 (*Simplify Numeral0+n, n+Numeral0, Numeral1*n, n*Numeral1, 1*x, x*1, x/1 *)
   183 val add_0s =  @{thms add_0s};
   184 val mult_1s = @{thms mult_1s mult_1_left mult_1_right divide_1};
   185 
   186 (*Simplify inverse Numeral1, a/Numeral1*)
   187 val inverse_1s = [@{thm inverse_numeral_1}];
   188 val divide_1s = [@{thm divide_numeral_1}];
   189 
   190 (*To perform binary arithmetic.  The "left" rewriting handles patterns
   191   created by the Numeral_Simprocs, such as 3 * (5 * x). *)
   192 val simps = [@{thm numeral_0_eq_0} RS sym, @{thm numeral_1_eq_1} RS sym,
   193                  @{thm add_number_of_left}, @{thm mult_number_of_left}] @
   194                 @{thms arith_simps} @ @{thms rel_simps};
   195 
   196 (*Binary arithmetic BUT NOT ADDITION since it may collapse adjacent terms
   197   during re-arrangement*)
   198 val non_add_simps =
   199   subtract Thm.eq_thm [@{thm add_number_of_left}, @{thm number_of_add} RS sym] simps;
   200 
   201 (*To evaluate binary negations of coefficients*)
   202 val minus_simps = [@{thm numeral_m1_eq_minus_1} RS sym, @{thm number_of_minus} RS sym] @
   203                    @{thms minus_bin_simps} @ @{thms pred_bin_simps};
   204 
   205 (*To let us treat subtraction as addition*)
   206 val diff_simps = [@{thm diff_minus}, @{thm minus_add_distrib}, @{thm minus_minus}];
   207 
   208 (*To let us treat division as multiplication*)
   209 val divide_simps = [@{thm divide_inverse}, @{thm inverse_mult_distrib}, @{thm inverse_inverse_eq}];
   210 
   211 (*push the unary minus down: - x * y = x * - y *)
   212 val minus_mult_eq_1_to_2 =
   213     [@{thm mult_minus_left}, @{thm minus_mult_right}] MRS trans |> Drule.standard;
   214 
   215 (*to extract again any uncancelled minuses*)
   216 val minus_from_mult_simps =
   217     [@{thm minus_minus}, @{thm mult_minus_left}, @{thm mult_minus_right}];
   218 
   219 (*combine unary minus with numeric literals, however nested within a product*)
   220 val mult_minus_simps =
   221     [@{thm mult_assoc}, @{thm minus_mult_left}, minus_mult_eq_1_to_2];
   222 
   223 val norm_ss1 = num_ss addsimps numeral_syms @ add_0s @ mult_1s @
   224   diff_simps @ minus_simps @ @{thms add_ac}
   225 val norm_ss2 = num_ss addsimps non_add_simps @ mult_minus_simps
   226 val norm_ss3 = num_ss addsimps minus_from_mult_simps @ @{thms add_ac} @ @{thms mult_ac}
   227 
   228 structure CancelNumeralsCommon =
   229   struct
   230   val mk_sum            = mk_sum
   231   val dest_sum          = dest_sum
   232   val mk_coeff          = mk_coeff
   233   val dest_coeff        = dest_coeff 1
   234   val find_first_coeff  = find_first_coeff []
   235   fun trans_tac _       = Arith_Data.trans_tac
   236 
   237   fun norm_tac ss =
   238     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   239     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   240     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   241 
   242   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps
   243   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   244   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s)
   245   end;
   246 
   247 
   248 structure EqCancelNumerals = CancelNumeralsFun
   249  (open CancelNumeralsCommon
   250   val prove_conv = Arith_Data.prove_conv
   251   val mk_bal   = HOLogic.mk_eq
   252   val dest_bal = HOLogic.dest_bin "op =" Term.dummyT
   253   val bal_add1 = @{thm eq_add_iff1} RS trans
   254   val bal_add2 = @{thm eq_add_iff2} RS trans
   255 );
   256 
   257 structure LessCancelNumerals = CancelNumeralsFun
   258  (open CancelNumeralsCommon
   259   val prove_conv = Arith_Data.prove_conv
   260   val mk_bal   = HOLogic.mk_binrel @{const_name HOL.less}
   261   val dest_bal = HOLogic.dest_bin @{const_name HOL.less} Term.dummyT
   262   val bal_add1 = @{thm less_add_iff1} RS trans
   263   val bal_add2 = @{thm less_add_iff2} RS trans
   264 );
   265 
   266 structure LeCancelNumerals = CancelNumeralsFun
   267  (open CancelNumeralsCommon
   268   val prove_conv = Arith_Data.prove_conv
   269   val mk_bal   = HOLogic.mk_binrel @{const_name HOL.less_eq}
   270   val dest_bal = HOLogic.dest_bin @{const_name HOL.less_eq} Term.dummyT
   271   val bal_add1 = @{thm le_add_iff1} RS trans
   272   val bal_add2 = @{thm le_add_iff2} RS trans
   273 );
   274 
   275 val cancel_numerals =
   276   map (Arith_Data.prep_simproc @{theory})
   277    [("inteq_cancel_numerals",
   278      ["(l::'a::number_ring) + m = n",
   279       "(l::'a::number_ring) = m + n",
   280       "(l::'a::number_ring) - m = n",
   281       "(l::'a::number_ring) = m - n",
   282       "(l::'a::number_ring) * m = n",
   283       "(l::'a::number_ring) = m * n"],
   284      K EqCancelNumerals.proc),
   285     ("intless_cancel_numerals",
   286      ["(l::'a::{ordered_idom,number_ring}) + m < n",
   287       "(l::'a::{ordered_idom,number_ring}) < m + n",
   288       "(l::'a::{ordered_idom,number_ring}) - m < n",
   289       "(l::'a::{ordered_idom,number_ring}) < m - n",
   290       "(l::'a::{ordered_idom,number_ring}) * m < n",
   291       "(l::'a::{ordered_idom,number_ring}) < m * n"],
   292      K LessCancelNumerals.proc),
   293     ("intle_cancel_numerals",
   294      ["(l::'a::{ordered_idom,number_ring}) + m <= n",
   295       "(l::'a::{ordered_idom,number_ring}) <= m + n",
   296       "(l::'a::{ordered_idom,number_ring}) - m <= n",
   297       "(l::'a::{ordered_idom,number_ring}) <= m - n",
   298       "(l::'a::{ordered_idom,number_ring}) * m <= n",
   299       "(l::'a::{ordered_idom,number_ring}) <= m * n"],
   300      K LeCancelNumerals.proc)];
   301 
   302 structure CombineNumeralsData =
   303   struct
   304   type coeff            = int
   305   val iszero            = (fn x => x = 0)
   306   val add               = op +
   307   val mk_sum            = long_mk_sum    (*to work for e.g. 2*x + 3*x *)
   308   val dest_sum          = dest_sum
   309   val mk_coeff          = mk_coeff
   310   val dest_coeff        = dest_coeff 1
   311   val left_distrib      = @{thm combine_common_factor} RS trans
   312   val prove_conv        = Arith_Data.prove_conv_nohyps
   313   fun trans_tac _       = Arith_Data.trans_tac
   314 
   315   fun norm_tac ss =
   316     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   317     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   318     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   319 
   320   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps
   321   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   322   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s)
   323   end;
   324 
   325 structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);
   326 
   327 (*Version for fields, where coefficients can be fractions*)
   328 structure FieldCombineNumeralsData =
   329   struct
   330   type coeff            = int * int
   331   val iszero            = (fn (p, q) => p = 0)
   332   val add               = add_frac
   333   val mk_sum            = long_mk_sum
   334   val dest_sum          = dest_sum
   335   val mk_coeff          = mk_fcoeff
   336   val dest_coeff        = dest_fcoeff 1
   337   val left_distrib      = @{thm combine_common_factor} RS trans
   338   val prove_conv        = Arith_Data.prove_conv_nohyps
   339   fun trans_tac _       = Arith_Data.trans_tac
   340 
   341   val norm_ss1a = norm_ss1 addsimps inverse_1s @ divide_simps
   342   fun norm_tac ss =
   343     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1a))
   344     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   345     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   346 
   347   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps @ [@{thm add_frac_eq}]
   348   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   349   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s @ divide_1s)
   350   end;
   351 
   352 structure FieldCombineNumerals = CombineNumeralsFun(FieldCombineNumeralsData);
   353 
   354 val combine_numerals =
   355   Arith_Data.prep_simproc @{theory}
   356     ("int_combine_numerals", 
   357      ["(i::'a::number_ring) + j", "(i::'a::number_ring) - j"], 
   358      K CombineNumerals.proc);
   359 
   360 val field_combine_numerals =
   361   Arith_Data.prep_simproc @{theory}
   362     ("field_combine_numerals", 
   363      ["(i::'a::{number_ring,field,division_by_zero}) + j",
   364       "(i::'a::{number_ring,field,division_by_zero}) - j"], 
   365      K FieldCombineNumerals.proc);
   366 
   367 (** Constant folding for multiplication in semirings **)
   368 
   369 (*We do not need folding for addition: combine_numerals does the same thing*)
   370 
   371 structure Semiring_Times_Assoc_Data : ASSOC_FOLD_DATA =
   372 struct
   373   val assoc_ss = HOL_ss addsimps @{thms mult_ac}
   374   val eq_reflection = eq_reflection
   375   fun is_numeral (Const(@{const_name Int.number_of}, _) $ _) = true
   376     | is_numeral _ = false;
   377 end;
   378 
   379 structure Semiring_Times_Assoc = Assoc_Fold (Semiring_Times_Assoc_Data);
   380 
   381 val assoc_fold_simproc =
   382   Arith_Data.prep_simproc @{theory}
   383    ("semiring_assoc_fold", ["(a::'a::comm_semiring_1_cancel) * b"],
   384     K Semiring_Times_Assoc.proc);
   385 
   386 structure CancelNumeralFactorCommon =
   387   struct
   388   val mk_coeff          = mk_coeff
   389   val dest_coeff        = dest_coeff 1
   390   fun trans_tac _       = Arith_Data.trans_tac
   391 
   392   val norm_ss1 = HOL_ss addsimps minus_from_mult_simps @ mult_1s
   393   val norm_ss2 = HOL_ss addsimps simps @ mult_minus_simps
   394   val norm_ss3 = HOL_ss addsimps @{thms mult_ac}
   395   fun norm_tac ss =
   396     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   397     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   398     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   399 
   400   val numeral_simp_ss = HOL_ss addsimps
   401     [@{thm eq_number_of_eq}, @{thm less_number_of}, @{thm le_number_of}] @ simps
   402   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   403   val simplify_meta_eq = Arith_Data.simplify_meta_eq
   404     [@{thm add_0}, @{thm add_0_right}, @{thm mult_zero_left},
   405       @{thm mult_zero_right}, @{thm mult_Bit1}, @{thm mult_1_right}];
   406   end
   407 
   408 (*Version for semiring_div*)
   409 structure DivCancelNumeralFactor = CancelNumeralFactorFun
   410  (open CancelNumeralFactorCommon
   411   val prove_conv = Arith_Data.prove_conv
   412   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   413   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   414   val cancel = @{thm div_mult_mult1} RS trans
   415   val neg_exchanges = false
   416 )
   417 
   418 (*Version for fields*)
   419 structure DivideCancelNumeralFactor = CancelNumeralFactorFun
   420  (open CancelNumeralFactorCommon
   421   val prove_conv = Arith_Data.prove_conv
   422   val mk_bal   = HOLogic.mk_binop @{const_name HOL.divide}
   423   val dest_bal = HOLogic.dest_bin @{const_name HOL.divide} Term.dummyT
   424   val cancel = @{thm mult_divide_mult_cancel_left} RS trans
   425   val neg_exchanges = false
   426 )
   427 
   428 structure EqCancelNumeralFactor = CancelNumeralFactorFun
   429  (open CancelNumeralFactorCommon
   430   val prove_conv = Arith_Data.prove_conv
   431   val mk_bal   = HOLogic.mk_eq
   432   val dest_bal = HOLogic.dest_bin "op =" Term.dummyT
   433   val cancel = @{thm mult_cancel_left} RS trans
   434   val neg_exchanges = false
   435 )
   436 
   437 structure LessCancelNumeralFactor = CancelNumeralFactorFun
   438  (open CancelNumeralFactorCommon
   439   val prove_conv = Arith_Data.prove_conv
   440   val mk_bal   = HOLogic.mk_binrel @{const_name HOL.less}
   441   val dest_bal = HOLogic.dest_bin @{const_name HOL.less} Term.dummyT
   442   val cancel = @{thm mult_less_cancel_left} RS trans
   443   val neg_exchanges = true
   444 )
   445 
   446 structure LeCancelNumeralFactor = CancelNumeralFactorFun
   447  (open CancelNumeralFactorCommon
   448   val prove_conv = Arith_Data.prove_conv
   449   val mk_bal   = HOLogic.mk_binrel @{const_name HOL.less_eq}
   450   val dest_bal = HOLogic.dest_bin @{const_name HOL.less_eq} Term.dummyT
   451   val cancel = @{thm mult_le_cancel_left} RS trans
   452   val neg_exchanges = true
   453 )
   454 
   455 val cancel_numeral_factors =
   456   map (Arith_Data.prep_simproc @{theory})
   457    [("ring_eq_cancel_numeral_factor",
   458      ["(l::'a::{idom,number_ring}) * m = n",
   459       "(l::'a::{idom,number_ring}) = m * n"],
   460      K EqCancelNumeralFactor.proc),
   461     ("ring_less_cancel_numeral_factor",
   462      ["(l::'a::{ordered_idom,number_ring}) * m < n",
   463       "(l::'a::{ordered_idom,number_ring}) < m * n"],
   464      K LessCancelNumeralFactor.proc),
   465     ("ring_le_cancel_numeral_factor",
   466      ["(l::'a::{ordered_idom,number_ring}) * m <= n",
   467       "(l::'a::{ordered_idom,number_ring}) <= m * n"],
   468      K LeCancelNumeralFactor.proc),
   469     ("int_div_cancel_numeral_factors",
   470      ["((l::'a::{semiring_div,number_ring}) * m) div n",
   471       "(l::'a::{semiring_div,number_ring}) div (m * n)"],
   472      K DivCancelNumeralFactor.proc),
   473     ("divide_cancel_numeral_factor",
   474      ["((l::'a::{division_by_zero,field,number_ring}) * m) / n",
   475       "(l::'a::{division_by_zero,field,number_ring}) / (m * n)",
   476       "((number_of v)::'a::{division_by_zero,field,number_ring}) / (number_of w)"],
   477      K DivideCancelNumeralFactor.proc)];
   478 
   479 val field_cancel_numeral_factors =
   480   map (Arith_Data.prep_simproc @{theory})
   481    [("field_eq_cancel_numeral_factor",
   482      ["(l::'a::{field,number_ring}) * m = n",
   483       "(l::'a::{field,number_ring}) = m * n"],
   484      K EqCancelNumeralFactor.proc),
   485     ("field_cancel_numeral_factor",
   486      ["((l::'a::{division_by_zero,field,number_ring}) * m) / n",
   487       "(l::'a::{division_by_zero,field,number_ring}) / (m * n)",
   488       "((number_of v)::'a::{division_by_zero,field,number_ring}) / (number_of w)"],
   489      K DivideCancelNumeralFactor.proc)]
   490 
   491 
   492 (** Declarations for ExtractCommonTerm **)
   493 
   494 (*Find first term that matches u*)
   495 fun find_first_t past u []         = raise TERM ("find_first_t", [])
   496   | find_first_t past u (t::terms) =
   497         if u aconv t then (rev past @ terms)
   498         else find_first_t (t::past) u terms
   499         handle TERM _ => find_first_t (t::past) u terms;
   500 
   501 (** Final simplification for the CancelFactor simprocs **)
   502 val simplify_one = Arith_Data.simplify_meta_eq  
   503   [@{thm mult_1_left}, @{thm mult_1_right}, @{thm div_by_1}, @{thm numeral_1_eq_1}];
   504 
   505 fun cancel_simplify_meta_eq ss cancel_th th =
   506     simplify_one ss (([th, cancel_th]) MRS trans);
   507 
   508 local
   509   val Tp_Eq = Thm.reflexive (Thm.cterm_of @{theory HOL} HOLogic.Trueprop)
   510   fun Eq_True_elim Eq = 
   511     Thm.equal_elim (Thm.combination Tp_Eq (Thm.symmetric Eq)) @{thm TrueI}
   512 in
   513 fun sign_conv pos_th neg_th ss t =
   514   let val T = fastype_of t;
   515       val zero = Const(@{const_name HOL.zero}, T);
   516       val less = Const(@{const_name HOL.less}, [T,T] ---> HOLogic.boolT);
   517       val pos = less $ zero $ t and neg = less $ t $ zero
   518       fun prove p =
   519         Option.map Eq_True_elim (Lin_Arith.simproc ss p)
   520         handle THM _ => NONE
   521     in case prove pos of
   522          SOME th => SOME(th RS pos_th)
   523        | NONE => (case prove neg of
   524                     SOME th => SOME(th RS neg_th)
   525                   | NONE => NONE)
   526     end;
   527 end
   528 
   529 structure CancelFactorCommon =
   530   struct
   531   val mk_sum            = long_mk_prod
   532   val dest_sum          = dest_prod
   533   val mk_coeff          = mk_coeff
   534   val dest_coeff        = dest_coeff
   535   val find_first        = find_first_t []
   536   fun trans_tac _       = Arith_Data.trans_tac
   537   val norm_ss = HOL_ss addsimps mult_1s @ @{thms mult_ac}
   538   fun norm_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss))
   539   val simplify_meta_eq  = cancel_simplify_meta_eq 
   540   end;
   541 
   542 (*mult_cancel_left requires a ring with no zero divisors.*)
   543 structure EqCancelFactor = ExtractCommonTermFun
   544  (open CancelFactorCommon
   545   val prove_conv = Arith_Data.prove_conv
   546   val mk_bal   = HOLogic.mk_eq
   547   val dest_bal = HOLogic.dest_bin "op =" Term.dummyT
   548   fun simp_conv _ _ = SOME @{thm mult_cancel_left}
   549 );
   550 
   551 (*for ordered rings*)
   552 structure LeCancelFactor = ExtractCommonTermFun
   553  (open CancelFactorCommon
   554   val prove_conv = Arith_Data.prove_conv
   555   val mk_bal   = HOLogic.mk_binrel @{const_name HOL.less_eq}
   556   val dest_bal = HOLogic.dest_bin @{const_name HOL.less_eq} Term.dummyT
   557   val simp_conv = sign_conv
   558     @{thm mult_le_cancel_left_pos} @{thm mult_le_cancel_left_neg}
   559 );
   560 
   561 (*for ordered rings*)
   562 structure LessCancelFactor = ExtractCommonTermFun
   563  (open CancelFactorCommon
   564   val prove_conv = Arith_Data.prove_conv
   565   val mk_bal   = HOLogic.mk_binrel @{const_name HOL.less}
   566   val dest_bal = HOLogic.dest_bin @{const_name HOL.less} Term.dummyT
   567   val simp_conv = sign_conv
   568     @{thm mult_less_cancel_left_pos} @{thm mult_less_cancel_left_neg}
   569 );
   570 
   571 (*for semirings with division*)
   572 structure DivCancelFactor = ExtractCommonTermFun
   573  (open CancelFactorCommon
   574   val prove_conv = Arith_Data.prove_conv
   575   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   576   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   577   fun simp_conv _ _ = SOME @{thm div_mult_mult1_if}
   578 );
   579 
   580 structure ModCancelFactor = ExtractCommonTermFun
   581  (open CancelFactorCommon
   582   val prove_conv = Arith_Data.prove_conv
   583   val mk_bal   = HOLogic.mk_binop @{const_name Divides.mod}
   584   val dest_bal = HOLogic.dest_bin @{const_name Divides.mod} Term.dummyT
   585   fun simp_conv _ _ = SOME @{thm mod_mult_mult1}
   586 );
   587 
   588 (*for idoms*)
   589 structure DvdCancelFactor = ExtractCommonTermFun
   590  (open CancelFactorCommon
   591   val prove_conv = Arith_Data.prove_conv
   592   val mk_bal   = HOLogic.mk_binrel @{const_name Ring_and_Field.dvd}
   593   val dest_bal = HOLogic.dest_bin @{const_name Ring_and_Field.dvd} Term.dummyT
   594   fun simp_conv _ _ = SOME @{thm dvd_mult_cancel_left}
   595 );
   596 
   597 (*Version for all fields, including unordered ones (type complex).*)
   598 structure DivideCancelFactor = ExtractCommonTermFun
   599  (open CancelFactorCommon
   600   val prove_conv = Arith_Data.prove_conv
   601   val mk_bal   = HOLogic.mk_binop @{const_name HOL.divide}
   602   val dest_bal = HOLogic.dest_bin @{const_name HOL.divide} Term.dummyT
   603   fun simp_conv _ _ = SOME @{thm mult_divide_mult_cancel_left_if}
   604 );
   605 
   606 val cancel_factors =
   607   map (Arith_Data.prep_simproc @{theory})
   608    [("ring_eq_cancel_factor",
   609      ["(l::'a::idom) * m = n",
   610       "(l::'a::idom) = m * n"],
   611      K EqCancelFactor.proc),
   612     ("ordered_ring_le_cancel_factor",
   613      ["(l::'a::ordered_ring) * m <= n",
   614       "(l::'a::ordered_ring) <= m * n"],
   615      K LeCancelFactor.proc),
   616     ("ordered_ring_less_cancel_factor",
   617      ["(l::'a::ordered_ring) * m < n",
   618       "(l::'a::ordered_ring) < m * n"],
   619      K LessCancelFactor.proc),
   620     ("int_div_cancel_factor",
   621      ["((l::'a::semiring_div) * m) div n", "(l::'a::semiring_div) div (m * n)"],
   622      K DivCancelFactor.proc),
   623     ("int_mod_cancel_factor",
   624      ["((l::'a::semiring_div) * m) mod n", "(l::'a::semiring_div) mod (m * n)"],
   625      K ModCancelFactor.proc),
   626     ("dvd_cancel_factor",
   627      ["((l::'a::idom) * m) dvd n", "(l::'a::idom) dvd (m * n)"],
   628      K DvdCancelFactor.proc),
   629     ("divide_cancel_factor",
   630      ["((l::'a::{division_by_zero,field}) * m) / n",
   631       "(l::'a::{division_by_zero,field}) / (m * n)"],
   632      K DivideCancelFactor.proc)];
   633 
   634 end;
   635 
   636 Addsimprocs Numeral_Simprocs.cancel_numerals;
   637 Addsimprocs [Numeral_Simprocs.combine_numerals];
   638 Addsimprocs [Numeral_Simprocs.field_combine_numerals];
   639 Addsimprocs [Numeral_Simprocs.assoc_fold_simproc];
   640 
   641 (*examples:
   642 print_depth 22;
   643 set timing;
   644 set trace_simp;
    645 fun test s = (Goal s; by (Simp_tac 1));
   646 
   647 test "l + 2 + 2 + 2 + (l + 2) + (oo + 2) = (uu::int)";
   648 
   649 test "2*u = (u::int)";
   650 test "(i + j + 12 + (k::int)) - 15 = y";
   651 test "(i + j + 12 + (k::int)) - 5 = y";
   652 
   653 test "y - b < (b::int)";
   654 test "y - (3*b + c) < (b::int) - 2*c";
   655 
   656 test "(2*x - (u*v) + y) - v*3*u = (w::int)";
   657 test "(2*x*u*v + (u*v)*4 + y) - v*u*4 = (w::int)";
   658 test "(2*x*u*v + (u*v)*4 + y) - v*u = (w::int)";
   659 test "u*v - (x*u*v + (u*v)*4 + y) = (w::int)";
   660 
   661 test "(i + j + 12 + (k::int)) = u + 15 + y";
   662 test "(i + j*2 + 12 + (k::int)) = j + 5 + y";
   663 
   664 test "2*y + 3*z + 6*w + 2*y + 3*z + 2*u = 2*y' + 3*z' + 6*w' + 2*y' + 3*z' + u + (vv::int)";
   665 
   666 test "a + -(b+c) + b = (d::int)";
   667 test "a + -(b+c) - b = (d::int)";
   668 
   669 (*negative numerals*)
   670 test "(i + j + -2 + (k::int)) - (u + 5 + y) = zz";
   671 test "(i + j + -3 + (k::int)) < u + 5 + y";
   672 test "(i + j + 3 + (k::int)) < u + -6 + y";
   673 test "(i + j + -12 + (k::int)) - 15 = y";
   674 test "(i + j + 12 + (k::int)) - -15 = y";
   675 test "(i + j + -12 + (k::int)) - -15 = y";
   676 *)
   677 
   678 Addsimprocs Numeral_Simprocs.cancel_numeral_factors;
   679 
   680 (*examples:
   681 print_depth 22;
   682 set timing;
   683 set trace_simp;
   684 fun test s = (Goal s; by (Simp_tac 1));
   685 
   686 test "9*x = 12 * (y::int)";
   687 test "(9*x) div (12 * (y::int)) = z";
   688 test "9*x < 12 * (y::int)";
   689 test "9*x <= 12 * (y::int)";
   690 
   691 test "-99*x = 132 * (y::int)";
   692 test "(-99*x) div (132 * (y::int)) = z";
   693 test "-99*x < 132 * (y::int)";
   694 test "-99*x <= 132 * (y::int)";
   695 
   696 test "999*x = -396 * (y::int)";
   697 test "(999*x) div (-396 * (y::int)) = z";
   698 test "999*x < -396 * (y::int)";
   699 test "999*x <= -396 * (y::int)";
   700 
   701 test "-99*x = -81 * (y::int)";
   702 test "(-99*x) div (-81 * (y::int)) = z";
   703 test "-99*x <= -81 * (y::int)";
   704 test "-99*x < -81 * (y::int)";
   705 
   706 test "-2 * x = -1 * (y::int)";
   707 test "-2 * x = -(y::int)";
   708 test "(-2 * x) div (-1 * (y::int)) = z";
   709 test "-2 * x < -(y::int)";
   710 test "-2 * x <= -1 * (y::int)";
   711 test "-x < -23 * (y::int)";
   712 test "-x <= -23 * (y::int)";
   713 *)
   714 
   715 (*And the same examples for fields such as rat or real:
   716 test "0 <= (y::rat) * -2";
   717 test "9*x = 12 * (y::rat)";
   718 test "(9*x) / (12 * (y::rat)) = z";
   719 test "9*x < 12 * (y::rat)";
   720 test "9*x <= 12 * (y::rat)";
   721 
   722 test "-99*x = 132 * (y::rat)";
   723 test "(-99*x) / (132 * (y::rat)) = z";
   724 test "-99*x < 132 * (y::rat)";
   725 test "-99*x <= 132 * (y::rat)";
   726 
   727 test "999*x = -396 * (y::rat)";
   728 test "(999*x) / (-396 * (y::rat)) = z";
   729 test "999*x < -396 * (y::rat)";
   730 test "999*x <= -396 * (y::rat)";
   731 
   732 test  "(- ((2::rat) * x) <= 2 * y)";
   733 test "-99*x = -81 * (y::rat)";
   734 test "(-99*x) / (-81 * (y::rat)) = z";
   735 test "-99*x <= -81 * (y::rat)";
   736 test "-99*x < -81 * (y::rat)";
   737 
   738 test "-2 * x = -1 * (y::rat)";
   739 test "-2 * x = -(y::rat)";
   740 test "(-2 * x) / (-1 * (y::rat)) = z";
   741 test "-2 * x < -(y::rat)";
   742 test "-2 * x <= -1 * (y::rat)";
   743 test "-x < -23 * (y::rat)";
   744 test "-x <= -23 * (y::rat)";
   745 *)
   746 
   747 Addsimprocs Numeral_Simprocs.cancel_factors;
   748 
   749 
   750 (*examples:
   751 print_depth 22;
   752 set timing;
   753 set trace_simp;
   754 fun test s = (Goal s; by (Asm_simp_tac 1));
   755 
   756 test "x*k = k*(y::int)";
   757 test "k = k*(y::int)";
   758 test "a*(b*c) = (b::int)";
   759 test "a*(b*c) = d*(b::int)*(x*a)";
   760 
   761 test "(x*k) div (k*(y::int)) = (uu::int)";
   762 test "(k) div (k*(y::int)) = (uu::int)";
   763 test "(a*(b*c)) div ((b::int)) = (uu::int)";
   764 test "(a*(b*c)) div (d*(b::int)*(x*a)) = (uu::int)";
   765 *)
   766 
   767 (*And the same examples for fields such as rat or real:
   768 print_depth 22;
   769 set timing;
   770 set trace_simp;
   771 fun test s = (Goal s; by (Asm_simp_tac 1));
   772 
   773 test "x*k = k*(y::rat)";
   774 test "k = k*(y::rat)";
   775 test "a*(b*c) = (b::rat)";
   776 test "a*(b*c) = d*(b::rat)*(x*a)";
   777 
   778 
   779 test "(x*k) / (k*(y::rat)) = (uu::rat)";
   780 test "(k) / (k*(y::rat)) = (uu::rat)";
   781 test "(a*(b*c)) / ((b::rat)) = (uu::rat)";
   782 test "(a*(b*c)) / (d*(b::rat)*(x*a)) = (uu::rat)";
   783 
   784 (*FIXME: what do we do about this?*)
   785 test "a*(b*c)/(y*z) = d*(b::rat)*(x*a)/z";
   786 *)