src/HOL/Tools/numeral_simprocs.ML
author wenzelm
Sun Feb 07 19:31:55 2010 +0100 (2010-02-07)
changeset 35020 862a20ffa8e2
parent 34974 18b41bba42b5
child 35030 f2f1e50bf65d
permissions -rw-r--r--
prefer explicit @{lemma} over adhoc forward reasoning;
     1 (* Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
     2    Copyright   2000  University of Cambridge
     3 
     4 Simprocs for the integer numerals.
     5 *)
     6 
     7 (*To quote from Provers/Arith/cancel_numeral_factor.ML:
     8 
     9 Cancels common coefficients in balanced expressions:
    10 
    11      u*#m ~~ u'*#m'  ==  #n*u ~~ #n'*u'
    12 
    13 where ~~ is an appropriate balancing operation (e.g. =, <=, <, div, /)
    14 and d = gcd(m,m') and n=m/d and n'=m'/d.
    15 *)
    16 
    17 signature NUMERAL_SIMPROCS =
    18 sig
       (*constant folding for products, registered as "semiring_assoc_fold"*)
    19   val assoc_fold_simproc: simproc
       (*CombineNumeralsFun instance for sums and differences ("int_combine_numerals")*)
    20   val combine_numerals: simproc
       (*CancelNumeralsFun instances for =, < and <= *)
    21   val cancel_numerals: simproc list
       (*ExtractCommonTermFun instances for =, <=, <, div, mod, dvd and / *)
    22   val cancel_factors: simproc list
       (*CancelNumeralFactorFun instances for =, <, <=, div and / *)
    23   val cancel_numeral_factors: simproc list
       (*variant of combine_numerals whose coefficients are fractions*)
    24   val field_combine_numerals: simproc
       (*numeral factor cancellation for = and / with field type constraints*)
    25   val field_cancel_numeral_factors: simproc list
       (*HOL_ss equipped with the term ordering numtermless*)
    26   val num_ss: simpset
    27 end;
    28 
    29 structure Numeral_Simprocs : NUMERAL_SIMPROCS =
    30 struct
    31 
    32 val mk_number = Arith_Data.mk_number;   (*local aliases for the Arith_Data term operations*)
    33 val mk_sum = Arith_Data.mk_sum;
    34 val long_mk_sum = Arith_Data.long_mk_sum;
    35 val dest_sum = Arith_Data.dest_sum;
    36 
       (*build / take apart a binary Algebras.minus application*)
    37 val mk_diff = HOLogic.mk_binop @{const_name Algebras.minus};
    38 val dest_diff = HOLogic.dest_bin @{const_name Algebras.minus} Term.dummyT;
    39 
       (*build a binary Algebras.times application*)
    40 val mk_times = HOLogic.mk_binop @{const_name Algebras.times};
    41 
       (*the constant Algebras.one at type T*)
    42 fun one_of T = Const(@{const_name Algebras.one}, T);
    43 
    44 (*Product of a list of factors of type T.  The empty product is the abstract
    45   constant 1 rather than Numeral 1, which avoids an unnecessary restriction
    46   to type class number_ring: cancellation of common factors in divisions
    47   does not need it.  A factor equal to 1 is dropped unless it comes last.*)
    48 fun mk_prod T =
    49   let
    50     val unit = one_of T
    51     fun build [] = unit
    52       | build [last] = last
    53       | build (x :: xs) =
                 let val rest = build xs
                 in if x = unit then rest else mk_times (x, rest) end
         in build end;
    54 
    55 (*This version includes a trailing one for a singleton list: [t] becomes t * 1.
         A longer list becomes t * mk_prod T ts, and mk_prod adds no one of its own
         to a non-empty list.  The empty list yields the constant 1 itself.*)
    56 fun long_mk_prod T []        = one_of T
    57   | long_mk_prod T (t :: ts) = mk_times (t, mk_prod T ts);
    58 
    59 val dest_times = HOLogic.dest_bin @{const_name Algebras.times} Term.dummyT;
    60 
    61 (*Flatten nested products into the list of their factors, left to right;
    62   a term that dest_times rejects is returned as a singleton list*)
    63 fun dest_prod t =
    64   (case (SOME (dest_times t) handle TERM _ => NONE) of
           SOME (lhs, rhs) => dest_prod lhs @ dest_prod rhs
         | NONE => [t]);
    65 
    66 (*Split off the first element of the list that HOLogic.dest_number accepts:
    67   returns its numeric value together with the other elements (the reversed
    68   accumulator followed by the unvisited rest), or raises TERM if the list
    69   holds no such element.  The first argument accumulates skipped elements.*)
       fun find_first_numeral _ [] = raise TERM("find_first_numeral", [])
         | find_first_numeral skipped (t :: rest) =
             (case (SOME (HOLogic.dest_number t) handle TERM _ => NONE) of
               SOME (_, n) => (n, rev skipped @ rest)
             | NONE => find_first_numeral (t :: skipped) rest);
    70 
    71 (*Build k * t with k as a numeral at the type of t.  Deliberately no
    72   simplification of special cases such as k = 1.*)
    73 fun mk_coeff (k, t) =
    74   let val T = Term.fastype_of t
    75   in mk_times (mk_number T k, t) end;
    76 
    77 (*Split t into a signed numeral coefficient and the product of its other
    78   factors, sorted by TermOrd.term_ord.  A missing numeral counts as 1 and
    79   every leading unary minus negates the sign.*)
    80 fun dest_coeff sign (Const (@{const_name Algebras.uminus}, _) $ t) = dest_coeff (~sign) t
         | dest_coeff sign t =
             let
               val factors = sort TermOrd.term_ord (dest_prod t)
               val (coeff, others) = find_first_numeral [] factors handle TERM _ => (1, factors)
               val T = Term.fastype_of t
             in (sign * coeff, mk_prod T others) end;
    81 
    82 (*Find first coefficient-term THAT MATCHES u.
    83   Returns the coefficient of the first list element whose non-numeral part
    84   (as computed by dest_coeff 1) is aconv to u, together with the list
    85   without that element; raises TERM if there is none.  Elements on which
    86   dest_coeff raises TERM are skipped.  The handler covers dest_coeff only:
    87   if it also wrapped the recursive call, a failed search would be repeated
    88   at every level, i.e. exponentially often in the length of the list.*)
    89 fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
         | find_first_coeff past u (t::terms) =
             (case (SOME (dest_coeff 1 t) handle TERM _ => NONE) of
               SOME (n, u') =>
                 if u aconv u' then (n, rev past @ terms)
                 else find_first_coeff (t::past) u terms
             | NONE => find_first_coeff (t::past) u terms);
    90 
    91 (*A fraction is a plain pair (numerator, denominator) of ints.  Rat.rat is
    92   unsuitable because negative denominators must be preserved.  A zero
    93   denominator is rejected with Div.*)
    94 fun mk_frac (_, 0) = raise Div
    95   | mk_frac frac = frac;
    96 
    97 (*Sum of two fractions, deliberately left unreduced: sums must be proved by
         rule add_frac_eq, and the cancel_numeral_factor simproc reduces the
         result afterwards.*)
       fun add_frac ((p1, q1), (p2, q2)) =
         let
           val numer = p1 * q2 + p2 * q1
           val denom = q1 * q2
         in (numer, denom) end;
    98 
    99 val mk_divide = HOLogic.mk_binop @{const_name Algebras.divide};
   100 
   101 (*Build term (p / q) * t, with both numerals at the type of t*)
   102 fun mk_fcoeff ((p, q), t) =
   103   let
   104     val numeral = mk_number (Term.fastype_of t)
           val fraction = mk_divide (numeral p, numeral q)
         in mk_times (fraction, t) end;
   105 
   106 (*Express t as a product of a fraction with other sorted terms.
   107   Unary minus flips the sign; a quotient contributes the numeral
   108   coefficients of its two sides as numerator and denominator; any other
   109   term gets denominator 1.  mk_frac raises Div for a zero denominator.*)
   110 fun dest_fcoeff sign (Const (@{const_name Algebras.uminus}, _) $ t) = dest_fcoeff (~sign) t
   111   | dest_fcoeff sign (Const (@{const_name Algebras.divide}, _) $ num $ den) =
   112       let
   113         val (p, num') = dest_coeff sign num
   114         val (q, den') = dest_coeff 1 den
   115       in (mk_frac (p, q), mk_divide (num', den')) end
         | dest_fcoeff sign t =
             let
               val (p, rest) = dest_coeff sign t
               val one = one_of (Term.fastype_of t)
             in (mk_frac (p, 1), mk_divide (rest, one)) end;
   116 
   117 
   118 (** New term ordering so that AC-rewriting brings numerals to the front **)
   119 
   120 (*Compare integers by magnitude first and by sign second.  Unlike the
   121   standard ordering on the integers, this one is well-founded.*)
   122 fun num_ord (i, j) =
   123   let val by_size = int_ord (abs i, abs j)
   124   in if by_size = EQUAL then int_ord (Int.sign i, Int.sign j) else by_size end;
   126 
   127 (*This resembles TermOrd.term_ord, but it puts binary numerals before other
   128   non-atomic terms.
         Clauses, in order: two abstractions compare their bodies, then the types
         of the bound variables; two number_of terms compare by num_ord; a
         number_of term is LESS than any remaining term; all other pairs compare
         by size, then by head (TermOrd.hd_ord), then by argument lists.*)
   129 local open Term 
   130 in 
   131 fun numterm_ord (Abs (_, T, t), Abs(_, U, u)) =
   132       (case numterm_ord (t, u) of EQUAL => TermOrd.typ_ord (T, U) | ord => ord)
   133   | numterm_ord
   134      (Const(@{const_name Int.number_of}, _) $ v, Const(@{const_name Int.number_of}, _) $ w) =
   135      num_ord (HOLogic.dest_numeral v, HOLogic.dest_numeral w)
   136   | numterm_ord (Const(@{const_name Int.number_of}, _) $ _, _) = LESS
   137   | numterm_ord (_, Const(@{const_name Int.number_of}, _) $ _) = GREATER
   138   | numterm_ord (t, u) =
   139       (case int_ord (size_of_term t, size_of_term u) of
   140         EQUAL =>
   141           let val (f, ts) = strip_comb t and (g, us) = strip_comb u in
   142             (case TermOrd.hd_ord (f, g) of EQUAL => numterms_ord (ts, us) | ord => ord)
   143           end
   144       | ord => ord)
       (*argument lists are compared with list_ord over numterm_ord*)
   145 and numterms_ord (ts, us) = list_ord numterm_ord (ts, us)
   146 end;
   147 
       (*strict "less than" test derived from numterm_ord*)
   148 fun numtermless tu = (numterm_ord tu = LESS);
   149 
       (*HOL_ss with numtermless installed as its term ordering*)
   150 val num_ss = HOL_ss settermless numtermless;
   151 
   152 (*Maps 0 to Numeral0 and 1 to Numeral1 so that arithmetic isn't complicated by the abstract 0 and 1.*)
   153 val numeral_syms = [@{thm numeral_0_eq_0} RS sym, @{thm numeral_1_eq_1} RS sym];
   154 
   155 (*Simplify Numeral0+n, n+Numeral0, Numeral1*n, n*Numeral1, 1*x, x*1, x/1 *)
   156 val add_0s =  @{thms add_0s};
   157 val mult_1s = @{thms mult_1s mult_1_left mult_1_right divide_1};
   158 
   159 (*Simplify inverse Numeral1, a/Numeral1*)
   160 val inverse_1s = [@{thm inverse_numeral_1}];
   161 val divide_1s = [@{thm divide_numeral_1}];
   162 
   163 (*To perform binary arithmetic.  The "left" rewriting handles patterns
   164   created by the Numeral_Simprocs, such as 3 * (5 * x).
         The first two entries are the same rules as in numeral_syms.*)
   165 val simps = [@{thm numeral_0_eq_0} RS sym, @{thm numeral_1_eq_1} RS sym,
   166                  @{thm add_number_of_left}, @{thm mult_number_of_left}] @
   167                 @{thms arith_simps} @ @{thms rel_simps};
   168 
   169 (*Binary arithmetic BUT NOT ADDITION since it may collapse adjacent terms
   170   during re-arrangement.
         Obtained by removing the two listed addition rules from simps, compared
         with Thm.eq_thm.*)
   171 val non_add_simps =
   172   subtract Thm.eq_thm [@{thm add_number_of_left}, @{thm number_of_add} RS sym] simps;
   173 
   174 (*To evaluate binary negations of coefficients*)
   175 val minus_simps = [@{thm numeral_m1_eq_minus_1} RS sym, @{thm number_of_minus} RS sym] @
   176                    @{thms minus_bin_simps} @ @{thms pred_bin_simps};
   177 
   178 (*To let us treat subtraction as addition*)
   179 val diff_simps = [@{thm diff_minus}, @{thm minus_add_distrib}, @{thm minus_minus}];
   180 
   181 (*To let us treat division as multiplication*)
   182 val divide_simps = [@{thm divide_inverse}, @{thm inverse_mult_distrib}, @{thm inverse_inverse_eq}];
   183 
   184 (*push the unary minus down*)
   185 val minus_mult_eq_1_to_2 = @{lemma "- (a::'a::ring) * b = a * - b" by simp};
   186 
   187 (*to extract again any uncancelled minuses*)
   188 val minus_from_mult_simps =
   189     [@{thm minus_minus}, @{thm mult_minus_left}, @{thm mult_minus_right}];
   190 
   191 (*combine unary minus with numeric literals, however nested within a product*)
   192 val mult_minus_simps =
   193     [@{thm mult_assoc}, @{thm minus_mult_left}, minus_mult_eq_1_to_2];
   194 
       (*Simpsets for three normalisation passes; the norm_tac functions of
         CancelNumeralsCommon and CombineNumeralsData apply them in this order,
         and FieldCombineNumeralsData uses an extended copy of norm_ss1.
         All three are built on num_ss and thus use the numtermless ordering.*)
   195 val norm_ss1 = num_ss addsimps numeral_syms @ add_0s @ mult_1s @
   196   diff_simps @ minus_simps @ @{thms add_ac}
   197 val norm_ss2 = num_ss addsimps non_add_simps @ mult_minus_simps
   198 val norm_ss3 = num_ss addsimps minus_from_mult_simps @ @{thms add_ac} @ @{thms mult_ac}
   199 
   200 structure CancelNumeralsCommon =
   201   struct
         (*Components shared by EqCancelNumerals, LessCancelNumerals and
           LeCancelNumerals, each of which opens this structure.*)
   202   val mk_sum            = mk_sum
   203   val dest_sum          = dest_sum
   204   val mk_coeff          = mk_coeff
         (*coefficients are extracted with initial sign 1*)
   205   val dest_coeff        = dest_coeff 1
         (*search starts with an empty accumulator of skipped terms*)
   206   val find_first_coeff  = find_first_coeff []
   207   fun trans_tac _       = Arith_Data.trans_tac
   208 
         (*three simplification passes over all goals: norm_ss1, norm_ss2, norm_ss3,
           each combined with ss through Simplifier.inherit_context*)
   209   fun norm_tac ss =
   210     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   211     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   212     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   213 
         (*numeral evaluation: add_0s plus the binary arithmetic rules in simps*)
   214   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps
   215   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   216   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s)
   217   end;
   218 
   219 
   220 structure EqCancelNumerals = CancelNumeralsFun
   221  (open CancelNumeralsCommon
        (*instance for equality: the balancing operation is "op =" and the
          cancellation rules are eq_add_iff1 / eq_add_iff2 chained with trans*)
   222   val prove_conv = Arith_Data.prove_conv
   223   val mk_bal   = HOLogic.mk_eq
   224   val dest_bal = HOLogic.dest_bin "op =" Term.dummyT
   225   val bal_add1 = @{thm eq_add_iff1} RS trans
   226   val bal_add2 = @{thm eq_add_iff2} RS trans
   227 );
   228 
   229 structure LessCancelNumerals = CancelNumeralsFun
   230  (open CancelNumeralsCommon
        (*instance for Algebras.less, using less_add_iff1 / less_add_iff2*)
   231   val prove_conv = Arith_Data.prove_conv
   232   val mk_bal   = HOLogic.mk_binrel @{const_name Algebras.less}
   233   val dest_bal = HOLogic.dest_bin @{const_name Algebras.less} Term.dummyT
   234   val bal_add1 = @{thm less_add_iff1} RS trans
   235   val bal_add2 = @{thm less_add_iff2} RS trans
   236 );
   237 
   238 structure LeCancelNumerals = CancelNumeralsFun
   239  (open CancelNumeralsCommon
        (*instance for Algebras.less_eq, using le_add_iff1 / le_add_iff2*)
   240   val prove_conv = Arith_Data.prove_conv
   241   val mk_bal   = HOLogic.mk_binrel @{const_name Algebras.less_eq}
   242   val dest_bal = HOLogic.dest_bin @{const_name Algebras.less_eq} Term.dummyT
   243   val bal_add1 = @{thm le_add_iff1} RS trans
   244   val bal_add2 = @{thm le_add_iff2} RS trans
   245 );
   246 
   247 val cancel_numerals =
       (*One simproc per relation (=, <, <=), prepared in the current theory.
         Each has six patterns: a sum, difference or product on either side of
         the relation.  Equality needs number_ring; the order relations also
         need ordered_idom.*)
   248   map (Arith_Data.prep_simproc @{theory})
   249    [("inteq_cancel_numerals",
   250      ["(l::'a::number_ring) + m = n",
   251       "(l::'a::number_ring) = m + n",
   252       "(l::'a::number_ring) - m = n",
   253       "(l::'a::number_ring) = m - n",
   254       "(l::'a::number_ring) * m = n",
   255       "(l::'a::number_ring) = m * n"],
   256      K EqCancelNumerals.proc),
   257     ("intless_cancel_numerals",
   258      ["(l::'a::{ordered_idom,number_ring}) + m < n",
   259       "(l::'a::{ordered_idom,number_ring}) < m + n",
   260       "(l::'a::{ordered_idom,number_ring}) - m < n",
   261       "(l::'a::{ordered_idom,number_ring}) < m - n",
   262       "(l::'a::{ordered_idom,number_ring}) * m < n",
   263       "(l::'a::{ordered_idom,number_ring}) < m * n"],
   264      K LessCancelNumerals.proc),
   265     ("intle_cancel_numerals",
   266      ["(l::'a::{ordered_idom,number_ring}) + m <= n",
   267       "(l::'a::{ordered_idom,number_ring}) <= m + n",
   268       "(l::'a::{ordered_idom,number_ring}) - m <= n",
   269       "(l::'a::{ordered_idom,number_ring}) <= m - n",
   270       "(l::'a::{ordered_idom,number_ring}) * m <= n",
   271       "(l::'a::{ordered_idom,number_ring}) <= m * n"],
   272      K LeCancelNumerals.proc)];
   273 
   274 structure CombineNumeralsData =
   275   struct
         (*Argument structure for CombineNumeralsFun with integer coefficients.
           Compared with CancelNumeralsCommon it uses long_mk_sum, supplies
           left_distrib and proves with Arith_Data.prove_conv_nohyps; the
           tactics and simpsets are the same.*)
   276   type coeff            = int
   277   val iszero            = (fn x => x = 0)
   278   val add               = op +
   279   val mk_sum            = long_mk_sum    (*to work for e.g. 2*x + 3*x *)
   280   val dest_sum          = dest_sum
   281   val mk_coeff          = mk_coeff
   282   val dest_coeff        = dest_coeff 1
   283   val left_distrib      = @{thm combine_common_factor} RS trans
   284   val prove_conv        = Arith_Data.prove_conv_nohyps
   285   fun trans_tac _       = Arith_Data.trans_tac
   286 
         (*three passes over all goals: norm_ss1, then norm_ss2, then norm_ss3*)
   287   fun norm_tac ss =
   288     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   289     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   290     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   291 
   292   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps
   293   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   294   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s)
   295   end;
   296 
   297 structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);
   298 
   299 (*Version for fields, where coefficients can be fractions*)
   300 structure FieldCombineNumeralsData =
   301   struct
         (*a coefficient is an unreduced pair (numerator, denominator)*)
   302   type coeff            = int * int
         (*only the numerator decides whether a fraction is zero*)
   303   val iszero            = (fn (p, q) => p = 0)
   304   val add               = add_frac
   305   val mk_sum            = long_mk_sum
   306   val dest_sum          = dest_sum
   307   val mk_coeff          = mk_fcoeff
   308   val dest_coeff        = dest_fcoeff 1
   309   val left_distrib      = @{thm combine_common_factor} RS trans
   310   val prove_conv        = Arith_Data.prove_conv_nohyps
   311   fun trans_tac _       = Arith_Data.trans_tac
   312 
         (*first pass additionally rewrites with inverse_1s and divide_simps,
           i.e. division is expressed through inverse and multiplication*)
   313   val norm_ss1a = norm_ss1 addsimps inverse_1s @ divide_simps
   314   fun norm_tac ss =
   315     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1a))
   316     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   317     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   318 
         (*add_frac_eq is the rule by which sums of fractions are proved*)
   319   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps @ [@{thm add_frac_eq}]
   320   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   321   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s @ divide_1s)
   322   end;
   323 
   324 structure FieldCombineNumerals = CombineNumeralsFun(FieldCombineNumeralsData);
   325 
   326 val combine_numerals =
       (*simproc for sums and differences over number_ring, run by CombineNumerals.proc*)
   327   Arith_Data.prep_simproc @{theory}
   328     ("int_combine_numerals", 
   329      ["(i::'a::number_ring) + j", "(i::'a::number_ring) - j"], 
   330      K CombineNumerals.proc);
   331 
   332 val field_combine_numerals =
       (*same patterns restricted to fields with division_by_zero, run by
         FieldCombineNumerals.proc*)
   333   Arith_Data.prep_simproc @{theory}
   334     ("field_combine_numerals", 
   335      ["(i::'a::{number_ring,field,division_by_zero}) + j",
   336       "(i::'a::{number_ring,field,division_by_zero}) - j"], 
   337      K FieldCombineNumerals.proc);
   338 
   339 (** Constant folding for multiplication in semirings **)
   340 
   341 (*We do not need folding for addition: combine_numerals does the same thing*)
   342 
   343 structure Semiring_Times_Assoc_Data : ASSOC_FOLD_DATA =
   344 struct
       (*rewriting modulo associativity/commutativity of multiplication*)
   345   val assoc_ss = HOL_ss addsimps @{thms mult_ac}
   346   val eq_reflection = eq_reflection
       (*a numeral is exactly an application of Int.number_of*)
   347   fun is_numeral (Const(@{const_name Int.number_of}, _) $ _) = true
   348     | is_numeral _ = false;
   349 end;
   350 
   351 structure Semiring_Times_Assoc = Assoc_Fold (Semiring_Times_Assoc_Data);
   352 
       (*simproc for products over comm_semiring_1_cancel*)
   353 val assoc_fold_simproc =
   354   Arith_Data.prep_simproc @{theory}
   355    ("semiring_assoc_fold", ["(a::'a::comm_semiring_1_cancel) * b"],
   356     K Semiring_Times_Assoc.proc);
   357 
   358 structure CancelNumeralFactorCommon =
   359   struct
         (*Components shared by the five CancelNumeralFactorFun instances below,
           each of which opens this structure.*)
   360   val mk_coeff          = mk_coeff
   361   val dest_coeff        = dest_coeff 1
   362   fun trans_tac _       = Arith_Data.trans_tac
   363 
         (*These three simpsets shadow the top-level norm_ss1..3 inside this
           structure.  They are built on HOL_ss rather than num_ss and, unlike
           the top-level norm_ss2, use the full simps including addition.*)
   364   val norm_ss1 = HOL_ss addsimps minus_from_mult_simps @ mult_1s
   365   val norm_ss2 = HOL_ss addsimps simps @ mult_minus_simps
   366   val norm_ss3 = HOL_ss addsimps @{thms mult_ac}
   367   fun norm_tac ss =
   368     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   369     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   370     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   371 
         (*numeral evaluation including comparisons of numerals*)
   372   val numeral_simp_ss = HOL_ss addsimps
   373     [@{thm eq_number_of_eq}, @{thm less_number_of}, @{thm le_number_of}] @ simps
   374   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
         (*NOTE(review): mult_Bit1 stands among rules that otherwise concern 0
           and 1 -- confirm that it is the intended theorem here*)
   375   val simplify_meta_eq = Arith_Data.simplify_meta_eq
   376     [@{thm add_0}, @{thm add_0_right}, @{thm mult_zero_left},
   377       @{thm mult_zero_right}, @{thm mult_Bit1}, @{thm mult_1_right}];
   378   end
   379 
   380 (*Version for semiring_div*)
   381 structure DivCancelNumeralFactor = CancelNumeralFactorFun
   382  (open CancelNumeralFactorCommon
        (*balancing operation Divides.div, cancelled with div_mult_mult1*)
   383   val prove_conv = Arith_Data.prove_conv
   384   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   385   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   386   val cancel = @{thm div_mult_mult1} RS trans
   387   val neg_exchanges = false
   388 )
   389 
   390 (*Version for fields*)
   391 structure DivideCancelNumeralFactor = CancelNumeralFactorFun
   392  (open CancelNumeralFactorCommon
        (*balancing operation Algebras.divide, cancelled with
          mult_divide_mult_cancel_left*)
   393   val prove_conv = Arith_Data.prove_conv
   394   val mk_bal   = HOLogic.mk_binop @{const_name Algebras.divide}
   395   val dest_bal = HOLogic.dest_bin @{const_name Algebras.divide} Term.dummyT
   396   val cancel = @{thm mult_divide_mult_cancel_left} RS trans
   397   val neg_exchanges = false
   398 )
   399 
   400 structure EqCancelNumeralFactor = CancelNumeralFactorFun
   401  (open CancelNumeralFactorCommon
        (*equality, cancelled with mult_cancel_left*)
   402   val prove_conv = Arith_Data.prove_conv
   403   val mk_bal   = HOLogic.mk_eq
   404   val dest_bal = HOLogic.dest_bin "op =" Term.dummyT
   405   val cancel = @{thm mult_cancel_left} RS trans
   406   val neg_exchanges = false
   407 )
   408 
   409 structure LessCancelNumeralFactor = CancelNumeralFactorFun
   410  (open CancelNumeralFactorCommon
        (*strict order, cancelled with mult_less_cancel_left.  Only the two order
          instances set neg_exchanges; presumably it tells the functor that
          cancelling a negative factor swaps the two sides -- confirm in
          CancelNumeralFactorFun.*)
   411   val prove_conv = Arith_Data.prove_conv
   412   val mk_bal   = HOLogic.mk_binrel @{const_name Algebras.less}
   413   val dest_bal = HOLogic.dest_bin @{const_name Algebras.less} Term.dummyT
   414   val cancel = @{thm mult_less_cancel_left} RS trans
   415   val neg_exchanges = true
   416 )
   417 
   418 structure LeCancelNumeralFactor = CancelNumeralFactorFun
   419  (open CancelNumeralFactorCommon
        (*non-strict order, cancelled with mult_le_cancel_left*)
   420   val prove_conv = Arith_Data.prove_conv
   421   val mk_bal   = HOLogic.mk_binrel @{const_name Algebras.less_eq}
   422   val dest_bal = HOLogic.dest_bin @{const_name Algebras.less_eq} Term.dummyT
   423   val cancel = @{thm mult_le_cancel_left} RS trans
   424   val neg_exchanges = true
   425 )
   426 
   427 val cancel_numeral_factors =
       (*One simproc per balancing operation (=, <, <=, div, /).  Each matches a
         product on either side; the / simproc also matches a quotient of two
         numerals.*)
   428   map (Arith_Data.prep_simproc @{theory})
   429    [("ring_eq_cancel_numeral_factor",
   430      ["(l::'a::{idom,number_ring}) * m = n",
   431       "(l::'a::{idom,number_ring}) = m * n"],
   432      K EqCancelNumeralFactor.proc),
   433     ("ring_less_cancel_numeral_factor",
   434      ["(l::'a::{ordered_idom,number_ring}) * m < n",
   435       "(l::'a::{ordered_idom,number_ring}) < m * n"],
   436      K LessCancelNumeralFactor.proc),
   437     ("ring_le_cancel_numeral_factor",
   438      ["(l::'a::{ordered_idom,number_ring}) * m <= n",
   439       "(l::'a::{ordered_idom,number_ring}) <= m * n"],
   440      K LeCancelNumeralFactor.proc),
   441     ("int_div_cancel_numeral_factors",
   442      ["((l::'a::{semiring_div,number_ring}) * m) div n",
   443       "(l::'a::{semiring_div,number_ring}) div (m * n)"],
   444      K DivCancelNumeralFactor.proc),
   445     ("divide_cancel_numeral_factor",
   446      ["((l::'a::{division_by_zero,field,number_ring}) * m) / n",
   447       "(l::'a::{division_by_zero,field,number_ring}) / (m * n)",
   448       "((number_of v)::'a::{division_by_zero,field,number_ring}) / (number_of w)"],
   449      K DivideCancelNumeralFactor.proc)];
   450 
   451 val field_cancel_numeral_factors =
       (*Equality and division only.  The equality simproc runs
         EqCancelNumeralFactor.proc on patterns constrained to field instead of
         idom; the division simproc runs DivideCancelNumeralFactor.proc.*)
   452   map (Arith_Data.prep_simproc @{theory})
   453    [("field_eq_cancel_numeral_factor",
   454      ["(l::'a::{field,number_ring}) * m = n",
   455       "(l::'a::{field,number_ring}) = m * n"],
   456      K EqCancelNumeralFactor.proc),
   457     ("field_cancel_numeral_factor",
   458      ["((l::'a::{division_by_zero,field,number_ring}) * m) / n",
   459       "(l::'a::{division_by_zero,field,number_ring}) / (m * n)",
   460       "((number_of v)::'a::{division_by_zero,field,number_ring}) / (number_of w)"],
   461      K DivideCancelNumeralFactor.proc)]
   462 
   463 
   464 (** Declarations for ExtractCommonTerm **)
   465 
   466 (*Find first term that matches u.
   467   Returns the list without its first element that is aconv to u (the
   468   reversed accumulator followed by the unvisited rest); raises TERM if no
   469   element matches.  There is deliberately no TERM handler around the
   470   recursive call: retrying the same call could only fail again, and doing
   471   so at every level costs exponential time on a failed search.*)
       fun find_first_t past u []         = raise TERM ("find_first_t", [])
         | find_first_t past u (t::terms) =
               if u aconv t then (rev past @ terms)
               else find_first_t (t::past) u terms;
   472 
   473 (** Final simplification for the CancelFactor simprocs **)
       (*removes multiplications and divisions by 1 and rewrites with numeral_1_eq_1*)
   474 val simplify_one = Arith_Data.simplify_meta_eq  
   475   [@{thm mult_1_left}, @{thm mult_1_right}, @{thm div_by_1}, @{thm numeral_1_eq_1}];
   476 
       (*chains th with the cancellation rule cancel_th by transitivity
         (MRS trans), then applies simplify_one to the result*)
   477 fun cancel_simplify_meta_eq ss cancel_th th =
   478     simplify_one ss (([th, cancel_th]) MRS trans);
   479 
   480 local
       (*reflexivity theorem for the Trueprop constant, certified in theory HOL*)
   481   val Tp_Eq = Thm.reflexive (Thm.cterm_of @{theory HOL} HOLogic.Trueprop)
       (*Reverses the meta-equation Eq, lifts it under Trueprop by combination
         and uses it to transport TrueI.  Presumably Eq has the form P == True,
         so that the result is the theorem P -- confirm against Lin_Arith.simproc.*)
   482   fun Eq_True_elim Eq = 
   483     Thm.equal_elim (Thm.combination Tp_Eq (Thm.symmetric Eq)) @{thm TrueI}
   484 in
       (*Select a cancellation rule according to the sign of t.  Lin_Arith.simproc
         is tried on 0 < t first and on t < 0 second; the first success is
         resolved (RS) with pos_th or neg_th respectively.  NONE is returned when
         neither attempt succeeds.*)
   485 fun sign_conv pos_th neg_th ss t =
   486   let val T = fastype_of t;
   487       val zero = Const(@{const_name Algebras.zero}, T);
   488       val less = Const(@{const_name Algebras.less}, [T,T] ---> HOLogic.boolT);
   489       val pos = less $ zero $ t and neg = less $ t $ zero
             (*a THM exception during the attempt counts as failure*)
   490       fun prove p =
   491         Option.map Eq_True_elim (Lin_Arith.simproc ss p)
   492         handle THM _ => NONE
   493     in case prove pos of
   494          SOME th => SOME(th RS pos_th)
   495        | NONE => (case prove neg of
   496                     SOME th => SOME(th RS neg_th)
   497                   | NONE => NONE)
   498     end;
   499 end
   500 
   501 structure CancelFactorCommon =
   502   struct
         (*Components shared by the ExtractCommonTermFun instances below.  The
           "sum" operations are instantiated with products, so the common terms
           that get extracted are factors.*)
   503   val mk_sum            = long_mk_prod
   504   val dest_sum          = dest_prod
   505   val mk_coeff          = mk_coeff
         (*passed without a sign argument here, unlike the other structures,
           which use dest_coeff 1*)
   506   val dest_coeff        = dest_coeff
   507   val find_first        = find_first_t []
   508   fun trans_tac _       = Arith_Data.trans_tac
         (*normalisation: drop factors 1, then rewrite modulo mult_ac*)
   509   val norm_ss = HOL_ss addsimps mult_1s @ @{thms mult_ac}
   510   fun norm_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss))
   511   val simplify_meta_eq  = cancel_simplify_meta_eq 
   512   end;
   513 
   514 (*mult_cancel_left requires a ring with no zero divisors.*)
   515 structure EqCancelFactor = ExtractCommonTermFun
   516  (open CancelFactorCommon
   517   val prove_conv = Arith_Data.prove_conv
   518   val mk_bal   = HOLogic.mk_eq
   519   val dest_bal = HOLogic.dest_bin "op =" Term.dummyT
        (*the same rule regardless of simpset and common term*)
   520   fun simp_conv _ _ = SOME @{thm mult_cancel_left}
   521 );
   522 
   523 (*for ordered rings*)
   524 structure LeCancelFactor = ExtractCommonTermFun
   525  (open CancelFactorCommon
   526   val prove_conv = Arith_Data.prove_conv
   527   val mk_bal   = HOLogic.mk_binrel @{const_name Algebras.less_eq}
   528   val dest_bal = HOLogic.dest_bin @{const_name Algebras.less_eq} Term.dummyT
        (*the rule depends on the sign of the common factor, see sign_conv;
          no rule is produced when the sign cannot be established*)
   529   val simp_conv = sign_conv
   530     @{thm mult_le_cancel_left_pos} @{thm mult_le_cancel_left_neg}
   531 );
   532 
   533 (*for ordered rings*)
   534 structure LessCancelFactor = ExtractCommonTermFun
   535  (open CancelFactorCommon
   536   val prove_conv = Arith_Data.prove_conv
   537   val mk_bal   = HOLogic.mk_binrel @{const_name Algebras.less}
   538   val dest_bal = HOLogic.dest_bin @{const_name Algebras.less} Term.dummyT
        (*sign-dependent rule, as for LeCancelFactor*)
   539   val simp_conv = sign_conv
   540     @{thm mult_less_cancel_left_pos} @{thm mult_less_cancel_left_neg}
   541 );
   542 
   543 (*for semirings with division*)
   544 structure DivCancelFactor = ExtractCommonTermFun
   545  (open CancelFactorCommon
   546   val prove_conv = Arith_Data.prove_conv
   547   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   548   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   549   fun simp_conv _ _ = SOME @{thm div_mult_mult1_if}
   550 );
   551 
       (*remainder counterpart of DivCancelFactor*)
   552 structure ModCancelFactor = ExtractCommonTermFun
   553  (open CancelFactorCommon
   554   val prove_conv = Arith_Data.prove_conv
   555   val mk_bal   = HOLogic.mk_binop @{const_name Divides.mod}
   556   val dest_bal = HOLogic.dest_bin @{const_name Divides.mod} Term.dummyT
   557   fun simp_conv _ _ = SOME @{thm mod_mult_mult1}
   558 );
   559 
   560 (*for idoms*)
   561 structure DvdCancelFactor = ExtractCommonTermFun
   562  (open CancelFactorCommon
   563   val prove_conv = Arith_Data.prove_conv
   564   val mk_bal   = HOLogic.mk_binrel @{const_name Ring_and_Field.dvd}
   565   val dest_bal = HOLogic.dest_bin @{const_name Ring_and_Field.dvd} Term.dummyT
   566   fun simp_conv _ _ = SOME @{thm dvd_mult_cancel_left}
   567 );
   568 
   569 (*Version for all fields, including unordered ones (type complex).*)
   570 structure DivideCancelFactor = ExtractCommonTermFun
   571  (open CancelFactorCommon
   572   val prove_conv = Arith_Data.prove_conv
   573   val mk_bal   = HOLogic.mk_binop @{const_name Algebras.divide}
   574   val dest_bal = HOLogic.dest_bin @{const_name Algebras.divide} Term.dummyT
   575   fun simp_conv _ _ = SOME @{thm mult_divide_mult_cancel_left_if}
   576 );
   577 
   578 val cancel_factors =
       (*One simproc per balancing operation (=, <=, <, div, mod, dvd, /), each
         matching a product on either side.  Unlike cancel_numeral_factors, none
         of the patterns requires number_ring.*)
   579   map (Arith_Data.prep_simproc @{theory})
   580    [("ring_eq_cancel_factor",
   581      ["(l::'a::idom) * m = n",
   582       "(l::'a::idom) = m * n"],
   583      K EqCancelFactor.proc),
   584     ("ordered_ring_le_cancel_factor",
   585      ["(l::'a::ordered_ring) * m <= n",
   586       "(l::'a::ordered_ring) <= m * n"],
   587      K LeCancelFactor.proc),
   588     ("ordered_ring_less_cancel_factor",
   589      ["(l::'a::ordered_ring) * m < n",
   590       "(l::'a::ordered_ring) < m * n"],
   591      K LessCancelFactor.proc),
   592     ("int_div_cancel_factor",
   593      ["((l::'a::semiring_div) * m) div n", "(l::'a::semiring_div) div (m * n)"],
   594      K DivCancelFactor.proc),
   595     ("int_mod_cancel_factor",
   596      ["((l::'a::semiring_div) * m) mod n", "(l::'a::semiring_div) mod (m * n)"],
   597      K ModCancelFactor.proc),
   598     ("dvd_cancel_factor",
   599      ["((l::'a::idom) * m) dvd n", "(l::'a::idom) dvd (m * n)"],
   600      K DvdCancelFactor.proc),
   601     ("divide_cancel_factor",
   602      ["((l::'a::{division_by_zero,field}) * m) / n",
   603       "(l::'a::{division_by_zero,field}) / (m * n)"],
   604      K DivideCancelFactor.proc)];
   605 
   606 end;
   607 
   608 Addsimprocs Numeral_Simprocs.cancel_numerals;   (*register the simprocs defined above via Addsimprocs*)
   609 Addsimprocs [Numeral_Simprocs.combine_numerals];
   610 Addsimprocs [Numeral_Simprocs.field_combine_numerals];
   611 Addsimprocs [Numeral_Simprocs.assoc_fold_simproc];
   612 
   613 (*examples:
   614 print_depth 22;
   615 set timing;
   616 set trace_simp;
   617 fun test s = (Goal s; by (Simp_tac 1));
   618 
   619 test "l + 2 + 2 + 2 + (l + 2) + (oo + 2) = (uu::int)";
   620 
   621 test "2*u = (u::int)";
   622 test "(i + j + 12 + (k::int)) - 15 = y";
   623 test "(i + j + 12 + (k::int)) - 5 = y";
   624 
   625 test "y - b < (b::int)";
   626 test "y - (3*b + c) < (b::int) - 2*c";
   627 
   628 test "(2*x - (u*v) + y) - v*3*u = (w::int)";
   629 test "(2*x*u*v + (u*v)*4 + y) - v*u*4 = (w::int)";
   630 test "(2*x*u*v + (u*v)*4 + y) - v*u = (w::int)";
   631 test "u*v - (x*u*v + (u*v)*4 + y) = (w::int)";
   632 
   633 test "(i + j + 12 + (k::int)) = u + 15 + y";
   634 test "(i + j*2 + 12 + (k::int)) = j + 5 + y";
   635 
   636 test "2*y + 3*z + 6*w + 2*y + 3*z + 2*u = 2*y' + 3*z' + 6*w' + 2*y' + 3*z' + u + (vv::int)";
   637 
   638 test "a + -(b+c) + b = (d::int)";
   639 test "a + -(b+c) - b = (d::int)";
   640 
   641 (*negative numerals*)
   642 test "(i + j + -2 + (k::int)) - (u + 5 + y) = zz";
   643 test "(i + j + -3 + (k::int)) < u + 5 + y";
   644 test "(i + j + 3 + (k::int)) < u + -6 + y";
   645 test "(i + j + -12 + (k::int)) - 15 = y";
   646 test "(i + j + 12 + (k::int)) - -15 = y";
   647 test "(i + j + -12 + (k::int)) - -15 = y";
   648 *)
   649 
   650 Addsimprocs Numeral_Simprocs.cancel_numeral_factors;   (*numeral factor cancellation for =, <, <=, div and / *)
   651 
   652 (*examples:
   653 print_depth 22;
   654 set timing;
   655 set trace_simp;
   656 fun test s = (Goal s; by (Simp_tac 1));
   657 
   658 test "9*x = 12 * (y::int)";
   659 test "(9*x) div (12 * (y::int)) = z";
   660 test "9*x < 12 * (y::int)";
   661 test "9*x <= 12 * (y::int)";
   662 
   663 test "-99*x = 132 * (y::int)";
   664 test "(-99*x) div (132 * (y::int)) = z";
   665 test "-99*x < 132 * (y::int)";
   666 test "-99*x <= 132 * (y::int)";
   667 
   668 test "999*x = -396 * (y::int)";
   669 test "(999*x) div (-396 * (y::int)) = z";
   670 test "999*x < -396 * (y::int)";
   671 test "999*x <= -396 * (y::int)";
   672 
   673 test "-99*x = -81 * (y::int)";
   674 test "(-99*x) div (-81 * (y::int)) = z";
   675 test "-99*x <= -81 * (y::int)";
   676 test "-99*x < -81 * (y::int)";
   677 
   678 test "-2 * x = -1 * (y::int)";
   679 test "-2 * x = -(y::int)";
   680 test "(-2 * x) div (-1 * (y::int)) = z";
   681 test "-2 * x < -(y::int)";
   682 test "-2 * x <= -1 * (y::int)";
   683 test "-x < -23 * (y::int)";
   684 test "-x <= -23 * (y::int)";
   685 *)
   686 
   687 (*And the same examples for fields such as rat or real:
   688 test "0 <= (y::rat) * -2";
   689 test "9*x = 12 * (y::rat)";
   690 test "(9*x) / (12 * (y::rat)) = z";
   691 test "9*x < 12 * (y::rat)";
   692 test "9*x <= 12 * (y::rat)";
   693 
   694 test "-99*x = 132 * (y::rat)";
   695 test "(-99*x) / (132 * (y::rat)) = z";
   696 test "-99*x < 132 * (y::rat)";
   697 test "-99*x <= 132 * (y::rat)";
   698 
   699 test "999*x = -396 * (y::rat)";
   700 test "(999*x) / (-396 * (y::rat)) = z";
   701 test "999*x < -396 * (y::rat)";
   702 test "999*x <= -396 * (y::rat)";
   703 
   704 test  "(- ((2::rat) * x) <= 2 * y)";
   705 test "-99*x = -81 * (y::rat)";
   706 test "(-99*x) / (-81 * (y::rat)) = z";
   707 test "-99*x <= -81 * (y::rat)";
   708 test "-99*x < -81 * (y::rat)";
   709 
   710 test "-2 * x = -1 * (y::rat)";
   711 test "-2 * x = -(y::rat)";
   712 test "(-2 * x) / (-1 * (y::rat)) = z";
   713 test "-2 * x < -(y::rat)";
   714 test "-2 * x <= -1 * (y::rat)";
   715 test "-x < -23 * (y::rat)";
   716 test "-x <= -23 * (y::rat)";
   717 *)
   718 
   719 Addsimprocs Numeral_Simprocs.cancel_factors;   (*common term factor cancellation for =, <=, <, div, mod, dvd and / *)
   720 
   721 
   722 (*examples:
   723 print_depth 22;
   724 set timing;
   725 set trace_simp;
   726 fun test s = (Goal s; by (Asm_simp_tac 1));
   727 
   728 test "x*k = k*(y::int)";
   729 test "k = k*(y::int)";
   730 test "a*(b*c) = (b::int)";
   731 test "a*(b*c) = d*(b::int)*(x*a)";
   732 
   733 test "(x*k) div (k*(y::int)) = (uu::int)";
   734 test "(k) div (k*(y::int)) = (uu::int)";
   735 test "(a*(b*c)) div ((b::int)) = (uu::int)";
   736 test "(a*(b*c)) div (d*(b::int)*(x*a)) = (uu::int)";
   737 *)
   738 
   739 (*And the same examples for fields such as rat or real:
   740 print_depth 22;
   741 set timing;
   742 set trace_simp;
   743 fun test s = (Goal s; by (Asm_simp_tac 1));
   744 
   745 test "x*k = k*(y::rat)";
   746 test "k = k*(y::rat)";
   747 test "a*(b*c) = (b::rat)";
   748 test "a*(b*c) = d*(b::rat)*(x*a)";
   749 
   750 
   751 test "(x*k) / (k*(y::rat)) = (uu::rat)";
   752 test "(k) / (k*(y::rat)) = (uu::rat)";
   753 test "(a*(b*c)) / ((b::rat)) = (uu::rat)";
   754 test "(a*(b*c)) / (d*(b::rat)*(x*a)) = (uu::rat)";
   755 
   756 (*FIXME: what do we do about this?*)
   757 test "a*(b*c)/(y*z) = d*(b::rat)*(x*a)/z";
   758 *)