src/HOL/Tools/numeral_simprocs.ML
author haftmann
Sat Aug 28 16:14:32 2010 +0200 (2010-08-28)
changeset 38864 4abe644fcea5
parent 38549 d0385f2764d8
child 40878 7695e4de4d86
permissions -rw-r--r--
formerly unnamed infix equality now named HOL.eq
     1 (* Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
     2    Copyright   2000  University of Cambridge
     3 
     4 Simprocs for the (integer) numerals.
     5 *)
     6 
     7 (*To quote from Provers/Arith/cancel_numeral_factor.ML:
     8 
     9 Cancels common coefficients in balanced expressions:
    10 
    11      u*#m ~~ u'*#m'  ==  #n*u ~~ #n'*u'
    12 
    13 where ~~ is an appropriate balancing operation (e.g. =, <=, <, div, /)
    14 and d = gcd(m,m') and n=m/d and n'=m'/d.
    15 *)
    16 
    (*Interface of structure Numeral_Simprocs: simprocs for arithmetic on numerals,
      the simpset num_ss and the conversion field_comp_conv.*)
    17 signature NUMERAL_SIMPROCS =
    18 sig
    19   val assoc_fold_simproc: simproc
    20   val combine_numerals: simproc
    21   val cancel_numerals: simproc list
    22   val cancel_factors: simproc list
    23   val cancel_numeral_factors: simproc list
    24   val field_combine_numerals: simproc
    25   val field_cancel_numeral_factors: simproc list
    26   val num_ss: simpset
    27   val field_comp_conv: conv
    28 end;
    29 
    30 structure Numeral_Simprocs : NUMERAL_SIMPROCS =
    31 struct
    32 
    (*Local abbreviations for the number and sum operations of Arith_Data*)
    33 val mk_number = Arith_Data.mk_number;
    34 val mk_sum = Arith_Data.mk_sum;
    35 val long_mk_sum = Arith_Data.long_mk_sum;
    36 val dest_sum = Arith_Data.dest_sum;
    37 
       (*Builder and destructor for binary Groups.minus*)
    38 val mk_diff = HOLogic.mk_binop @{const_name Groups.minus};
    39 val dest_diff = HOLogic.dest_bin @{const_name Groups.minus} Term.dummyT;
    40 
       (*Builder for binary Groups.times*)
    41 val mk_times = HOLogic.mk_binop @{const_name Groups.times};
    42 
       (*The constant Groups.one at type T*)
    43 fun one_of T = Const(@{const_name Groups.one}, T);
    44 
    (*Right-nested product of a list of factors of type T.  The empty product is
      the abstract constant 1 (one_of T) rather than Numeral 1, in order to avoid
      the unnecessary restriction to type class number_ring, which is not required
      for cancellation of common factors in divisions.  Factors equal to that
      constant are dropped unless they come last.*)
    fun mk_prod T =
      let
        val unit = one_of T
        fun build [] = unit
          | build [factor] = factor
          | build (factor :: rest) =
              if factor = unit then build rest else mk_times (factor, build rest)
      in build end;
    55 
    56 (*This version yields an application of times for every non-empty list: the
         tail is built by mk_prod, so a singleton [t] becomes t * 1.  Longer lists
         get no additional trailing one.  The empty list gives one_of T.*)
    57 fun long_mk_prod T []        = one_of T
    58   | long_mk_prod T (t :: ts) = mk_times (t, mk_prod T ts);
    59 
    (*Destructor for binary Groups.times; dest_prod relies on it raising TERM for
      terms it does not accept*)
    val dest_times = HOLogic.dest_bin @{const_name Groups.times} Term.dummyT;

    (*Flatten a nested product into the list of its factors, left to right.  A term
      that dest_times rejects is a single factor.*)
    fun dest_prod t =
      (case (SOME (dest_times t) handle TERM _ => NONE) of
        SOME (left, right) => dest_prod left @ dest_prod right
      | NONE => [t]);
    66 
    (*Search the list for the first term t accepted by HOLogic.dest_number.  Returns
      the second component of that result together with all other terms: the ones
      already skipped (accumulated in reverse in past) followed by the remaining
      ones.  Raises TERM when the list is exhausted.*)
    67 fun find_first_numeral past (t::terms) =
    68         ((snd (HOLogic.dest_number t), rev past @ terms)
    69          handle TERM _ => find_first_numeral (t::past) terms)
    70   | find_first_numeral past [] = raise TERM("find_first_numeral", []);
    71 
    72 (*DON'T do the obvious simplifications; that would create special cases*)
       (*Build the product of the number k (made by mk_number at the type of t) with t*)
    73 fun mk_coeff (k, t) = mk_times (mk_number (Term.fastype_of t) k, t);
    74 
    75 (*Express t as a product of (possibly) a numeral with other sorted terms*)
       (*Each leading unary minus negates sign.  The factors of t are sorted with
         Term_Ord.term_ord and the first numeral among them is split off; if there
         is none the coefficient is 1.  Result: (sign * coefficient, product of the
         remaining factors built by mk_prod).*)
    76 fun dest_coeff sign (Const (@{const_name Groups.uminus}, _) $ t) = dest_coeff (~sign) t
    77   | dest_coeff sign t =
    78     let val ts = sort Term_Ord.term_ord (dest_prod t)
    79         val (n, ts') = find_first_numeral [] ts
    80                           handle TERM _ => (1, ts)
    81     in (sign*n, mk_prod (Term.fastype_of t) ts') end;
    82 
    (*Find first coefficient-term THAT MATCHES u.
      Scans the list for a term t such that dest_coeff 1 t = (n, u') with u' aconv u
      and returns n together with all other terms in their original order; past
      accumulates the skipped terms in reverse.  Terms on which dest_coeff raises
      TERM are skipped.  Raises TERM if no term matches.*)
    fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
      | find_first_coeff past u (t::terms) =
          (*Only dest_coeff is guarded.  A handler around the whole clause would
            also catch the "not found" TERM of the recursive call and repeat that
            call, doubling the work at every list element of a failing search.*)
          (case (SOME (dest_coeff 1 t) handle TERM _ => NONE) of
            SOME (n, u') =>
              if u aconv u' then (n, rev past @ terms)
              else find_first_coeff (t::past) u terms
          | NONE => find_first_coeff (t::past) u terms);
    91 
    92 (*Fractions as pairs of ints. Can't use Rat.rat because the representation
    93   needs to preserve negative values in the denominator.*)
       (*mk_frac only checks the denominator: it raises Div when q = 0 and otherwise
         returns the pair unchanged*)
    94 fun mk_frac (p, q) = if q = 0 then raise Div else (p, q);
    95 
    96 (*Don't reduce fractions; sums must be proved by rule add_frac_eq.
    97   Fractions are reduced later by the cancel_numeral_factor simproc.*)
    98 fun add_frac ((p1, q1), (p2, q2)) = (p1 * q2 + p2 * q1, q1 * q2);
    99 
       (*Builder for binary Rings.divide*)
   100 val mk_divide = HOLogic.mk_binop @{const_name Rings.divide};
   101 
   (*Term for (p / q) * t; numerator and denominator are made by mk_number at the
     type of t*)
   fun mk_fcoeff ((p, q), t) =
     let
       val ty = Term.fastype_of t
       val numerator = mk_number ty p
       val denominator = mk_number ty q
     in mk_times (mk_divide (numerator, denominator), t) end;
   106 
   107 (*Express t as a product of a fraction with other sorted terms*)
       (*A leading unary minus negates sign.  For a quotient t / u the coefficients
         of numerator and denominator are extracted separately with dest_coeff (the
         sign goes to the numerator) and form the fraction, while the residual terms
         form t' / u'.  Any other term gets denominator 1 and residual t' / 1.
         mk_frac raises Div when the denominator coefficient is 0.*)
   108 fun dest_fcoeff sign (Const (@{const_name Groups.uminus}, _) $ t) = dest_fcoeff (~sign) t
   109   | dest_fcoeff sign (Const (@{const_name Rings.divide}, _) $ t $ u) =
   110     let val (p, t') = dest_coeff sign t
   111         val (q, u') = dest_coeff 1 u
   112     in (mk_frac (p, q), mk_divide (t', u')) end
   113   | dest_fcoeff sign t =
   114     let val (p, t') = dest_coeff sign t
   115         val T = Term.fastype_of t
   116     in (mk_frac (p, 1), mk_divide (t', one_of T)) end;
   117 
   118 
   119 (** New term ordering so that AC-rewriting brings numerals to the front **)
   120 
   (*Compare integers by magnitude first; integers of equal magnitude are ordered
     by sign, so ~n comes before n.  Used because the standard integer ordering is
     not well-founded.*)
   fun num_ord (i, j) =
     let val by_magnitude = int_ord (abs i, abs j)
     in
       if by_magnitude = EQUAL then int_ord (Int.sign i, Int.sign j)
       else by_magnitude
     end;
   127 
   128 (*This resembles Term_Ord.term_ord, but it puts binary numerals before other
   129   non-atomic terms.*)
       (*Clause order matters:
           1. two abstractions: compare bodies first, then the bound-variable types;
           2. two number_of applications: compare the numeral values with num_ord;
           3./4. a number_of application is LESS than anything not matched above;
           5. otherwise: compare term sizes, then the heads (Term_Ord.hd_ord), then
              the argument lists lexicographically via numterms_ord.*)
   130 local open Term 
   131 in 
   132 fun numterm_ord (Abs (_, T, t), Abs(_, U, u)) =
   133       (case numterm_ord (t, u) of EQUAL => Term_Ord.typ_ord (T, U) | ord => ord)
   134   | numterm_ord
   135      (Const(@{const_name Int.number_of}, _) $ v, Const(@{const_name Int.number_of}, _) $ w) =
   136      num_ord (HOLogic.dest_numeral v, HOLogic.dest_numeral w)
   137   | numterm_ord (Const(@{const_name Int.number_of}, _) $ _, _) = LESS
   138   | numterm_ord (_, Const(@{const_name Int.number_of}, _) $ _) = GREATER
   139   | numterm_ord (t, u) =
   140       (case int_ord (size_of_term t, size_of_term u) of
   141         EQUAL =>
   142           let val (f, ts) = strip_comb t and (g, us) = strip_comb u in
   143             (case Term_Ord.hd_ord (f, g) of EQUAL => numterms_ord (ts, us) | ord => ord)
   144           end
   145       | ord => ord)
   146 and numterms_ord (ts, us) = list_ord numterm_ord (ts, us)
   147 end;
   148 
   (*Strict "less than" test derived from numterm_ord*)
   149 fun numtermless tu = (numterm_ord tu = LESS);
   150 
       (*HOL_ss with numtermless installed through settermless*)
   151 val num_ss = HOL_ss settermless numtermless;
   152 
   153 (*Maps 0 to Numeral0 and 1 to Numeral1 so that arithmetic isn't complicated by the abstract 0 and 1.*)
   154 val numeral_syms = [@{thm numeral_0_eq_0} RS sym, @{thm numeral_1_eq_1} RS sym];
   155 
   156 (*Simplify Numeral0+n, n+Numeral0, Numeral1*n, n*Numeral1, 1*x, x*1, x/1 *)
   157 val add_0s =  @{thms add_0s};
   158 val mult_1s = @{thms mult_1s mult_1_left mult_1_right divide_1};
   159 
   160 (*Simplify inverse Numeral1, a/Numeral1*)
   161 val inverse_1s = [@{thm inverse_numeral_1}];
   162 val divide_1s = [@{thm divide_numeral_1}];
   163 
   164 (*To perform binary arithmetic.  The "left" rewriting handles patterns
   165   created by the Numeral_Simprocs, such as 3 * (5 * x). *)
       (*The first two rules are the same as those in numeral_syms*)
   166 val simps = [@{thm numeral_0_eq_0} RS sym, @{thm numeral_1_eq_1} RS sym,
   167                  @{thm add_number_of_left}, @{thm mult_number_of_left}] @
   168                 @{thms arith_simps} @ @{thms rel_simps};
   169 
   170 (*Binary arithmetic BUT NOT ADDITION since it may collapse adjacent terms
   171   during re-arrangement*)
       (*simps without the two addition rules listed here (compared by Thm.eq_thm)*)
   172 val non_add_simps =
   173   subtract Thm.eq_thm [@{thm add_number_of_left}, @{thm number_of_add} RS sym] simps;
   174 
   175 (*To evaluate binary negations of coefficients*)
   176 val minus_simps = [@{thm numeral_m1_eq_minus_1} RS sym, @{thm number_of_minus} RS sym] @
   177                    @{thms minus_bin_simps} @ @{thms pred_bin_simps};
   178 
   179 (*To let us treat subtraction as addition*)
   180 val diff_simps = [@{thm diff_minus}, @{thm minus_add_distrib}, @{thm minus_minus}];
   181 
   182 (*To let us treat division as multiplication*)
   183 val divide_simps = [@{thm divide_inverse}, @{thm inverse_mult_distrib}, @{thm inverse_inverse_eq}];
   184 
   185 (*push the unary minus down*)
   186 val minus_mult_eq_1_to_2 = @{lemma "- (a::'a::ring) * b = a * - b" by simp};
   187 
   188 (*to extract again any uncancelled minuses*)
   189 val minus_from_mult_simps =
   190     [@{thm minus_minus}, @{thm mult_minus_left}, @{thm mult_minus_right}];
   191 
   192 (*combine unary minus with numeric literals, however nested within a product*)
   193 val mult_minus_simps =
   194     [@{thm mult_assoc}, @{thm minus_mult_left}, minus_mult_eq_1_to_2];
   195 
       (*Simpsets for the three normalisation passes run by the norm_tac functions
         below, all based on num_ss:
           norm_ss1: numeral_syms, add_0s, mult_1s, diff_simps, minus_simps, add_ac;
           norm_ss2: non_add_simps and mult_minus_simps;
           norm_ss3: minus_from_mult_simps, add_ac and mult_ac.*)
   196 val norm_ss1 = num_ss addsimps numeral_syms @ add_0s @ mult_1s @
   197   diff_simps @ minus_simps @ @{thms add_ac}
   198 val norm_ss2 = num_ss addsimps non_add_simps @ mult_minus_simps
   199 val norm_ss3 = num_ss addsimps minus_from_mult_simps @ @{thms add_ac} @ @{thms mult_ac}
   200 
   (*Components shared by the three CancelNumeralsFun instances (=, <, <=) that open
     this structure*)
   201 structure CancelNumeralsCommon =
   202   struct
   203   val mk_sum            = mk_sum
   204   val dest_sum          = dest_sum
   205   val mk_coeff          = mk_coeff
       (*coefficient extraction starts with sign 1, search starts with empty "past"*)
   206   val dest_coeff        = dest_coeff 1
   207   val find_first_coeff  = find_first_coeff []
   208   fun trans_tac _       = Arith_Data.trans_tac
   209 
       (*Normalisation: three simplification passes over all goals with norm_ss1,
         norm_ss2 and norm_ss3 in turn, each inheriting the context of ss*)
   210   fun norm_tac ss =
   211     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   212     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   213     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   214 
       (*Numeral arithmetic proper: HOL_ss with add_0s and simps*)
   215   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps
   216   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   217   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s)
   218   end;
   219 
   220 
   (*CancelNumeralsFun instance for equations: the balance operation is HOL.eq and
     the rules are eq_add_iff1 / eq_add_iff2, each composed with trans*)
   221 structure EqCancelNumerals = CancelNumeralsFun
   222  (open CancelNumeralsCommon
   223   val prove_conv = Arith_Data.prove_conv
   224   val mk_bal   = HOLogic.mk_eq
   225   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   226   val bal_add1 = @{thm eq_add_iff1} RS trans
   227   val bal_add2 = @{thm eq_add_iff2} RS trans
   228 );
   229 
   (*CancelNumeralsFun instance for Orderings.less, with rules less_add_iff1 /
     less_add_iff2, each composed with trans*)
   230 structure LessCancelNumerals = CancelNumeralsFun
   231  (open CancelNumeralsCommon
   232   val prove_conv = Arith_Data.prove_conv
   233   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   234   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   235   val bal_add1 = @{thm less_add_iff1} RS trans
   236   val bal_add2 = @{thm less_add_iff2} RS trans
   237 );
   238 
   (*CancelNumeralsFun instance for Orderings.less_eq, with rules le_add_iff1 /
     le_add_iff2, each composed with trans*)
   239 structure LeCancelNumerals = CancelNumeralsFun
   240  (open CancelNumeralsCommon
   241   val prove_conv = Arith_Data.prove_conv
   242   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   243   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   244   val bal_add1 = @{thm le_add_iff1} RS trans
   245   val bal_add2 = @{thm le_add_iff2} RS trans
   246 );
   247 
   (*Simprocs built with Arith_Data.prep_simproc from (name, left-hand side patterns,
     procedure): one each for =, < and <=.  The patterns require a sum, difference
     or product on one side of the relation; the type is a number_ring, and
     additionally a linordered_idom for the two order relations.*)
   248 val cancel_numerals =
   249   map (Arith_Data.prep_simproc @{theory})
   250    [("inteq_cancel_numerals",
   251      ["(l::'a::number_ring) + m = n",
   252       "(l::'a::number_ring) = m + n",
   253       "(l::'a::number_ring) - m = n",
   254       "(l::'a::number_ring) = m - n",
   255       "(l::'a::number_ring) * m = n",
   256       "(l::'a::number_ring) = m * n"],
   257      K EqCancelNumerals.proc),
   258     ("intless_cancel_numerals",
   259      ["(l::'a::{linordered_idom,number_ring}) + m < n",
   260       "(l::'a::{linordered_idom,number_ring}) < m + n",
   261       "(l::'a::{linordered_idom,number_ring}) - m < n",
   262       "(l::'a::{linordered_idom,number_ring}) < m - n",
   263       "(l::'a::{linordered_idom,number_ring}) * m < n",
   264       "(l::'a::{linordered_idom,number_ring}) < m * n"],
   265      K LessCancelNumerals.proc),
   266     ("intle_cancel_numerals",
   267      ["(l::'a::{linordered_idom,number_ring}) + m <= n",
   268       "(l::'a::{linordered_idom,number_ring}) <= m + n",
   269       "(l::'a::{linordered_idom,number_ring}) - m <= n",
   270       "(l::'a::{linordered_idom,number_ring}) <= m - n",
   271       "(l::'a::{linordered_idom,number_ring}) * m <= n",
   272       "(l::'a::{linordered_idom,number_ring}) <= m * n"],
   273      K LeCancelNumerals.proc)];
   274 
   (*Argument structure for CombineNumeralsFun with integer coefficients: a
     coefficient is zero when it equals 0, and coefficients are combined with +*)
   275 structure CombineNumeralsData =
   276   struct
   277   type coeff            = int
   278   val iszero            = (fn x => x = 0)
   279   val add               = op +
   280   val mk_sum            = long_mk_sum    (*to work for e.g. 2*x + 3*x *)
   281   val dest_sum          = dest_sum
   282   val mk_coeff          = mk_coeff
   283   val dest_coeff        = dest_coeff 1
   284   val left_distrib      = @{thm combine_common_factor} RS trans
   285   val prove_conv        = Arith_Data.prove_conv_nohyps
   286   fun trans_tac _       = Arith_Data.trans_tac
   287 
       (*Same three-pass normalisation as in CancelNumeralsCommon*)
   288   fun norm_tac ss =
   289     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   290     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   291     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   292 
   293   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps
   294   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   295   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s)
   296   end;
   297 
   298 structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);
   299 
   300 (*Version for fields, where coefficients can be fractions*)
       (*Coefficients are (numerator, denominator) pairs; a coefficient is zero when
         its numerator is 0, and coefficients are combined with add_frac*)
   301 structure FieldCombineNumeralsData =
   302   struct
   303   type coeff            = int * int
   304   val iszero            = (fn (p, q) => p = 0)
   305   val add               = add_frac
   306   val mk_sum            = long_mk_sum
   307   val dest_sum          = dest_sum
   308   val mk_coeff          = mk_fcoeff
   309   val dest_coeff        = dest_fcoeff 1
   310   val left_distrib      = @{thm combine_common_factor} RS trans
   311   val prove_conv        = Arith_Data.prove_conv_nohyps
   312   fun trans_tac _       = Arith_Data.trans_tac
   313 
       (*First pass additionally uses inverse_1s and divide_simps*)
   314   val norm_ss1a = norm_ss1 addsimps inverse_1s @ divide_simps
   315   fun norm_tac ss =
   316     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1a))
   317     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   318     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   319 
       (*add_frac_eq is included here, matching the comment on add_frac: sums of
         fractions must be proved by that rule*)
   320   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps @ [@{thm add_frac_eq}]
   321   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   322   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s @ divide_1s)
   323   end;
   324 
   325 structure FieldCombineNumerals = CombineNumeralsFun(FieldCombineNumeralsData);
   326 
   (*Simproc for sums and differences in a number_ring, running CombineNumerals.proc*)
   327 val combine_numerals =
   328   Arith_Data.prep_simproc @{theory}
   329     ("int_combine_numerals", 
   330      ["(i::'a::number_ring) + j", "(i::'a::number_ring) - j"], 
   331      K CombineNumerals.proc);
   332 
       (*The same patterns at types of class {field_inverse_zero, number_ring},
         running FieldCombineNumerals.proc*)
   333 val field_combine_numerals =
   334   Arith_Data.prep_simproc @{theory}
   335     ("field_combine_numerals", 
   336      ["(i::'a::{field_inverse_zero, number_ring}) + j",
   337       "(i::'a::{field_inverse_zero, number_ring}) - j"], 
   338      K FieldCombineNumerals.proc);
   339 
   340 (** Constant folding for multiplication in semirings **)
   341 
   342 (*We do not need folding for addition: combine_numerals does the same thing*)
   343 
       (*Argument for Assoc_Fold: rewriting uses mult_ac on top of HOL_ss, and a term
         counts as a numeral when HOLogic.dest_number succeeds on it*)
   344 structure Semiring_Times_Assoc_Data : ASSOC_FOLD_DATA =
   345 struct
   346   val assoc_ss = HOL_ss addsimps @{thms mult_ac}
   347   val eq_reflection = eq_reflection
   348   val is_numeral = can HOLogic.dest_number
   349 end;
   350 
   351 structure Semiring_Times_Assoc = Assoc_Fold (Semiring_Times_Assoc_Data);
   352 
       (*Simproc for products in a comm_semiring_1_cancel*)
   353 val assoc_fold_simproc =
   354   Arith_Data.prep_simproc @{theory}
   355    ("semiring_assoc_fold", ["(a::'a::comm_semiring_1_cancel) * b"],
   356     K Semiring_Times_Assoc.proc);
   357 
   (*Components shared by the CancelNumeralFactorFun instances that open this
     structure*)
   358 structure CancelNumeralFactorCommon =
   359   struct
   360   val mk_coeff          = mk_coeff
   361   val dest_coeff        = dest_coeff 1
   362   fun trans_tac _       = Arith_Data.trans_tac
   363 
       (*These three simpsets shadow the top-level norm_ss1..3 inside this structure;
         they are based on HOL_ss rather than num_ss*)
   364   val norm_ss1 = HOL_ss addsimps minus_from_mult_simps @ mult_1s
   365   val norm_ss2 = HOL_ss addsimps simps @ mult_minus_simps
   366   val norm_ss3 = HOL_ss addsimps @{thms mult_ac}
   367   fun norm_tac ss =
   368     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   369     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   370     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   371 
   372   val numeral_simp_ss = HOL_ss addsimps
   373     [@{thm eq_number_of_eq}, @{thm less_number_of}, @{thm le_number_of}] @ simps
   374   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   375   val simplify_meta_eq = Arith_Data.simplify_meta_eq
   376     [@{thm Nat.add_0}, @{thm Nat.add_0_right}, @{thm mult_zero_left},
   377       @{thm mult_zero_right}, @{thm mult_Bit1}, @{thm mult_1_right}];
   378   end
   379 
   380 (*Version for semiring_div*)
       (*Balance operation Divides.div, cancellation rule div_mult_mult1 composed
         with trans.  NOTE(review): neg_exchanges presumably tells the functor whether
         cancelling a negative factor swaps the two sides; confirm in
         Provers/Arith/cancel_numeral_factor.ML.*)
   381 structure DivCancelNumeralFactor = CancelNumeralFactorFun
   382  (open CancelNumeralFactorCommon
   383   val prove_conv = Arith_Data.prove_conv
   384   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   385   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   386   val cancel = @{thm div_mult_mult1} RS trans
   387   val neg_exchanges = false
   388 )
   389 
   390 (*Version for fields*)
       (*Balance operation Rings.divide, cancellation rule
         mult_divide_mult_cancel_left composed with trans*)
   391 structure DivideCancelNumeralFactor = CancelNumeralFactorFun
   392  (open CancelNumeralFactorCommon
   393   val prove_conv = Arith_Data.prove_conv
   394   val mk_bal   = HOLogic.mk_binop @{const_name Rings.divide}
   395   val dest_bal = HOLogic.dest_bin @{const_name Rings.divide} Term.dummyT
   396   val cancel = @{thm mult_divide_mult_cancel_left} RS trans
   397   val neg_exchanges = false
   398 )
   399 
   (*Instance for equations: balance operation HOL.eq, cancellation rule
     mult_cancel_left composed with trans*)
   400 structure EqCancelNumeralFactor = CancelNumeralFactorFun
   401  (open CancelNumeralFactorCommon
   402   val prove_conv = Arith_Data.prove_conv
   403   val mk_bal   = HOLogic.mk_eq
   404   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   405   val cancel = @{thm mult_cancel_left} RS trans
   406   val neg_exchanges = false
   407 )
   408 
   (*Instance for Orderings.less: cancellation rule mult_less_cancel_left composed
     with trans; neg_exchanges is true here, unlike for =, div and /*)
   409 structure LessCancelNumeralFactor = CancelNumeralFactorFun
   410  (open CancelNumeralFactorCommon
   411   val prove_conv = Arith_Data.prove_conv
   412   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   413   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   414   val cancel = @{thm mult_less_cancel_left} RS trans
   415   val neg_exchanges = true
   416 )
   417 
   (*Instance for Orderings.less_eq: cancellation rule mult_le_cancel_left composed
     with trans; neg_exchanges is true here as for less*)
   418 structure LeCancelNumeralFactor = CancelNumeralFactorFun
   419  (open CancelNumeralFactorCommon
   420   val prove_conv = Arith_Data.prove_conv
   421   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   422   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   423   val cancel = @{thm mult_le_cancel_left} RS trans
   424   val neg_exchanges = true
   425 )
   426 
   (*Simprocs for cancelling numeral factors, one per balance operation (=, <, <=,
     div, /).  Each pattern has a product on one side.  The last entry also matches
     a quotient of two number_of terms.*)
   427 val cancel_numeral_factors =
   428   map (Arith_Data.prep_simproc @{theory})
   429    [("ring_eq_cancel_numeral_factor",
   430      ["(l::'a::{idom,number_ring}) * m = n",
   431       "(l::'a::{idom,number_ring}) = m * n"],
   432      K EqCancelNumeralFactor.proc),
   433     ("ring_less_cancel_numeral_factor",
   434      ["(l::'a::{linordered_idom,number_ring}) * m < n",
   435       "(l::'a::{linordered_idom,number_ring}) < m * n"],
   436      K LessCancelNumeralFactor.proc),
   437     ("ring_le_cancel_numeral_factor",
   438      ["(l::'a::{linordered_idom,number_ring}) * m <= n",
   439       "(l::'a::{linordered_idom,number_ring}) <= m * n"],
   440      K LeCancelNumeralFactor.proc),
   441     ("int_div_cancel_numeral_factors",
   442      ["((l::'a::{semiring_div,number_ring}) * m) div n",
   443       "(l::'a::{semiring_div,number_ring}) div (m * n)"],
   444      K DivCancelNumeralFactor.proc),
   445     ("divide_cancel_numeral_factor",
   446      ["((l::'a::{field_inverse_zero,number_ring}) * m) / n",
   447       "(l::'a::{field_inverse_zero,number_ring}) / (m * n)",
   448       "((number_of v)::'a::{field_inverse_zero,number_ring}) / (number_of w)"],
   449      K DivideCancelNumeralFactor.proc)];
   450 
   (*Field variants: the equation simproc at types of class {field,number_ring} and
     the quotient simproc with the same patterns and procedure as
     "divide_cancel_numeral_factor" above, under a different name*)
   451 val field_cancel_numeral_factors =
   452   map (Arith_Data.prep_simproc @{theory})
   453    [("field_eq_cancel_numeral_factor",
   454      ["(l::'a::{field,number_ring}) * m = n",
   455       "(l::'a::{field,number_ring}) = m * n"],
   456      K EqCancelNumeralFactor.proc),
   457     ("field_cancel_numeral_factor",
   458      ["((l::'a::{field_inverse_zero,number_ring}) * m) / n",
   459       "(l::'a::{field_inverse_zero,number_ring}) / (m * n)",
   460       "((number_of v)::'a::{field_inverse_zero,number_ring}) / (number_of w)"],
   461      K DivideCancelNumeralFactor.proc)]
   462 
   463 
   464 (** Declarations for ExtractCommonTerm **)
   465 
   (*Find first term that matches u.
     Returns the list without its first element that is aconv to u; the other
     elements keep their order (past accumulates the skipped elements in reverse).
     Raises TERM if there is no such element.*)
   fun find_first_t past u []         = raise TERM ("find_first_t", [])
     | find_first_t past u (t::terms) =
         (*No TERM handler here: it could only intercept the "not found" exception
           of the recursive call and then repeat the very same call, which fails
           again after doubling the work at every list element.*)
         if u aconv t then (rev past @ terms)
         else find_first_t (t::past) u terms;
   472 
   473 (** Final simplification for the CancelFactor simprocs **)
   474 val simplify_one = Arith_Data.simplify_meta_eq  
   475   [@{thm mult_1_left}, @{thm mult_1_right}, @{thm div_by_1}, @{thm numeral_1_eq_1}];
   476 
       (*Chain th and cancel_th through trans (MRS), then simplify the result with
         simplify_one*)
   477 fun cancel_simplify_meta_eq ss cancel_th th =
   478     simplify_one ss (([th, cancel_th]) MRS trans);
   479 
   (*sign_conv pos_th neg_th ss t tries to establish the sign of t: first 0 < t, then
     t < 0, each via Lin_Arith.simproc.  On success the proved fact is resolved
     (RS) with pos_th or neg_th respectively; if neither can be proved the result
     is NONE.*)
   480 local
       (*Reflexivity theorem for the constant Trueprop*)
   481   val Tp_Eq = Thm.reflexive (Thm.cterm_of @{theory HOL} HOLogic.Trueprop)
       (*Applies Trueprop to both sides of the reversed equation and eliminates it
         against TrueI.  NOTE(review): presumably Eq has the form P == True, so that
         the result is the theorem P; confirm against Lin_Arith.simproc.*)
   482   fun Eq_True_elim Eq = 
   483     Thm.equal_elim (Thm.combination Tp_Eq (Thm.symmetric Eq)) @{thm TrueI}
   484 in
   485 fun sign_conv pos_th neg_th ss t =
   486   let val T = fastype_of t;
   487       val zero = Const(@{const_name Groups.zero}, T);
   488       val less = Const(@{const_name Orderings.less}, [T,T] ---> HOLogic.boolT);
   489       val pos = less $ zero $ t and neg = less $ t $ zero
           (*a THM exception during the proof attempt counts as failure*)
   490       fun prove p =
   491         Option.map Eq_True_elim (Lin_Arith.simproc ss p)
   492         handle THM _ => NONE
   493     in case prove pos of
   494          SOME th => SOME(th RS pos_th)
   495        | NONE => (case prove neg of
   496                     SOME th => SOME(th RS neg_th)
   497                   | NONE => NONE)
   498     end;
   499 end
   500 
   (*Components shared by the ExtractCommonTermFun instances that open this
     structure.  The functor's "sum" operations are instantiated with products here
     (long_mk_prod / dest_prod).*)
   501 structure CancelFactorCommon =
   502   struct
   503   val mk_sum            = long_mk_prod
   504   val dest_sum          = dest_prod
   505   val mk_coeff          = mk_coeff
       (*unlike in the other *Common structures, dest_coeff is passed without its
         sign argument*)
   506   val dest_coeff        = dest_coeff
   507   val find_first        = find_first_t []
   508   fun trans_tac _       = Arith_Data.trans_tac
   509   val norm_ss = HOL_ss addsimps mult_1s @ @{thms mult_ac}
   510   fun norm_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss))
   511   val simplify_meta_eq  = cancel_simplify_meta_eq 
   512   end;
   513 
   514 (*mult_cancel_left requires a ring with no zero divisors.*)
       (*Instance for equations; simp_conv ignores both arguments and always offers
         mult_cancel_left*)
   515 structure EqCancelFactor = ExtractCommonTermFun
   516  (open CancelFactorCommon
   517   val prove_conv = Arith_Data.prove_conv
   518   val mk_bal   = HOLogic.mk_eq
   519   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   520   fun simp_conv _ _ = SOME @{thm mult_cancel_left}
   521 );
   522 
   523 (*for ordered rings*)
       (*Instance for less_eq; the cancellation rule depends on the sign of the common
         factor, which sign_conv tries to establish (pos or neg variant, else NONE)*)
   524 structure LeCancelFactor = ExtractCommonTermFun
   525  (open CancelFactorCommon
   526   val prove_conv = Arith_Data.prove_conv
   527   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   528   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   529   val simp_conv = sign_conv
   530     @{thm mult_le_cancel_left_pos} @{thm mult_le_cancel_left_neg}
   531 );
   532 
   533 (*for ordered rings*)
       (*Instance for less; as for less_eq, sign_conv selects the pos or neg variant
         of the cancellation rule*)
   534 structure LessCancelFactor = ExtractCommonTermFun
   535  (open CancelFactorCommon
   536   val prove_conv = Arith_Data.prove_conv
   537   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   538   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   539   val simp_conv = sign_conv
   540     @{thm mult_less_cancel_left_pos} @{thm mult_less_cancel_left_neg}
   541 );
   542 
   543 (*for semirings with division*)
       (*Instance for Divides.div; simp_conv always offers div_mult_mult1_if*)
   544 structure DivCancelFactor = ExtractCommonTermFun
   545  (open CancelFactorCommon
   546   val prove_conv = Arith_Data.prove_conv
   547   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   548   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   549   fun simp_conv _ _ = SOME @{thm div_mult_mult1_if}
   550 );
   551 
   (*Instance for Divides.mod; simp_conv always offers mod_mult_mult1*)
   552 structure ModCancelFactor = ExtractCommonTermFun
   553  (open CancelFactorCommon
   554   val prove_conv = Arith_Data.prove_conv
   555   val mk_bal   = HOLogic.mk_binop @{const_name Divides.mod}
   556   val dest_bal = HOLogic.dest_bin @{const_name Divides.mod} Term.dummyT
   557   fun simp_conv _ _ = SOME @{thm mod_mult_mult1}
   558 );
   559 
   560 (*for idoms*)
       (*Instance for Rings.dvd; simp_conv always offers dvd_mult_cancel_left*)
   561 structure DvdCancelFactor = ExtractCommonTermFun
   562  (open CancelFactorCommon
   563   val prove_conv = Arith_Data.prove_conv
   564   val mk_bal   = HOLogic.mk_binrel @{const_name Rings.dvd}
   565   val dest_bal = HOLogic.dest_bin @{const_name Rings.dvd} Term.dummyT
   566   fun simp_conv _ _ = SOME @{thm dvd_mult_cancel_left}
   567 );
   568 
   569 (*Version for all fields, including unordered ones (type complex).*)
       (*Instance for Rings.divide; simp_conv always offers
         mult_divide_mult_cancel_left_if, so no sign information is needed*)
   570 structure DivideCancelFactor = ExtractCommonTermFun
   571  (open CancelFactorCommon
   572   val prove_conv = Arith_Data.prove_conv
   573   val mk_bal   = HOLogic.mk_binop @{const_name Rings.divide}
   574   val dest_bal = HOLogic.dest_bin @{const_name Rings.divide} Term.dummyT
   575   fun simp_conv _ _ = SOME @{thm mult_divide_mult_cancel_left_if}
   576 );
   577 
   (*Simprocs for cancelling common (non-numeral) factors, one per balance operation:
     =, <=, <, div, mod, dvd and /.  Each pattern has a product on one side.*)
   578 val cancel_factors =
   579   map (Arith_Data.prep_simproc @{theory})
   580    [("ring_eq_cancel_factor",
   581      ["(l::'a::idom) * m = n",
   582       "(l::'a::idom) = m * n"],
   583      K EqCancelFactor.proc),
   584     ("linordered_ring_le_cancel_factor",
   585      ["(l::'a::linordered_ring) * m <= n",
   586       "(l::'a::linordered_ring) <= m * n"],
   587      K LeCancelFactor.proc),
   588     ("linordered_ring_less_cancel_factor",
   589      ["(l::'a::linordered_ring) * m < n",
   590       "(l::'a::linordered_ring) < m * n"],
   591      K LessCancelFactor.proc),
   592     ("int_div_cancel_factor",
   593      ["((l::'a::semiring_div) * m) div n", "(l::'a::semiring_div) div (m * n)"],
   594      K DivCancelFactor.proc),
   595     ("int_mod_cancel_factor",
   596      ["((l::'a::semiring_div) * m) mod n", "(l::'a::semiring_div) mod (m * n)"],
   597      K ModCancelFactor.proc),
   598     ("dvd_cancel_factor",
   599      ["((l::'a::idom) * m) dvd n", "(l::'a::idom) dvd (m * n)"],
   600      K DvdCancelFactor.proc),
   601     ("divide_cancel_factor",
   602      ["((l::'a::field_inverse_zero) * m) / n",
   603       "(l::'a::field_inverse_zero) / (m * n)"],
   604      K DivideCancelFactor.proc)];
   605 
   606 local
   (*Schematic building blocks for prove_nz: the certified patterns for 0 and for
     HOL.eq together with their type variables (zT, eqT), which prove_nz
     instantiates to a concrete type*)
   607  val zr = @{cpat "0"}
   608  val zT = ctyp_of_term zr
   609  val geq = @{cpat HOL.eq}
   610  val eqT = Thm.dest_ctyp (ctyp_of_term geq) |> hd
        (*Meta-equality versions of the fraction addition rules used by proc and proc2*)
   611  val add_frac_eq = mk_meta_eq @{thm "add_frac_eq"}
   612  val add_frac_num = mk_meta_eq @{thm "add_frac_num"}
   613  val add_num_frac = mk_meta_eq @{thm "add_num_frac"}
   614 
   (*Builds the certified proposition Trueprop (Not (t = 0)) at type T, rewrites it
     with ss extended by simp_thms and eliminates the resulting equation against
     TrueI.  NOTE(review): presumably this yields the theorem "t ~= 0" only when the
     rewrite reaches True, and otherwise raises THM, which the callers proc and
     proc2 handle; confirm against Thm.equal_elim.*)
   615  fun prove_nz ss T t =
   616     let
   617       val z = Thm.instantiate_cterm ([(zT,T)],[]) zr
   618       val eq = Thm.instantiate_cterm ([(eqT,T)],[]) geq
   619       val th = Simplifier.rewrite (ss addsimps @{thms simp_thms})
   620            (Thm.capply @{cterm "Trueprop"} (Thm.capply @{cterm "Not"}
   621                   (Thm.capply (Thm.capply eq t) z)))
   622     in Thm.equal_elim (Thm.symmetric th) TrueI
   623     end
   624 
   (*Simproc body for x/y + w/z (see add_frac_frac_simproc): instantiates add_frac_eq
     and discharges its two premises with the non-zero facts for y and z.  Returns
     NONE when any step raises CTERM, TERM or THM.*)
   625  fun proc phi ss ct =
   626   let
   627     val ((x,y),(w,z)) =
   628          (Thm.dest_binop #> (fn (a,b) => (Thm.dest_binop a, Thm.dest_binop b))) ct
          (*guard: HOLogic.dest_number is applied only for its failure on a term
            that is not a number; the exception is handled below*)
   629     val _ = map (HOLogic.dest_number o term_of) [x,y,z,w]
   630     val T = ctyp_of_term x
   631     val [y_nz, z_nz] = map (prove_nz ss T) [y, z]
   632     val th = instantiate' [SOME T] (map SOME [y,z,x,w]) add_frac_eq
   633   in SOME (Thm.implies_elim (Thm.implies_elim th y_nz) z_nz)
   634   end
   635   handle CTERM _ => NONE | TERM _ => NONE | THM _ => NONE
   636 
   (*Simproc body for x/y + z and z + x/y (see add_frac_num_simproc).  Depending on
     which side is a quotient, add_frac_num or add_num_frac is instantiated and its
     premise discharged with the non-zero fact for y.  All three terms must be
     accepted by HOLogic.dest_number.  Returns NONE if neither side is a quotient or
     if any step raises CTERM, TERM or THM.*)
   637  fun proc2 phi ss ct =
   638   let
   639     val (l,r) = Thm.dest_binop ct
   640     val T = ctyp_of_term l
   641   in (case (term_of l, term_of r) of
          (*quotient on the left: x/y + z*)
   642       (Const(@{const_name Rings.divide},_)$_$_, _) =>
   643         let val (x,y) = Thm.dest_binop l val z = r
   644             val _ = map (HOLogic.dest_number o term_of) [x,y,z]
   645             val ynz = prove_nz ss T y
   646         in SOME (Thm.implies_elim (instantiate' [SOME T] (map SOME [y,x,z]) add_frac_num) ynz)
   647         end
          (*quotient on the right: z + x/y*)
   648      | (_, Const (@{const_name Rings.divide},_)$_$_) =>
   649         let val (x,y) = Thm.dest_binop r val z = l
   650             val _ = map (HOLogic.dest_number o term_of) [x,y,z]
   651             val ynz = prove_nz ss T y
   652         in SOME (Thm.implies_elim (instantiate' [SOME T] (map SOME [y,z,x]) add_num_frac) ynz)
   653         end
   654      | _ => NONE)
   655   end
   656   handle CTERM _ => NONE | TERM _ => NONE | THM _ => NONE
   657 
   (*Test for number terms: a quotient qualifies when both operands do, any other
     term when HOLogic.dest_number succeeds on it.  Returns a bool and does not
     raise for non-numbers.*)
   658  fun is_number (Const(@{const_name Rings.divide},_)$a$b) = is_number a andalso is_number b
   659    | is_number t = can HOLogic.dest_number t
   660 
        (*From here on is_number takes a certified term*)
   661  val is_number = is_number o term_of
   662 
   (*Simproc body for comparisons (<, <=, =) with a quotient on one side.  The
     pattern match on the term selects one of divide_less_eq, divide_le_eq,
     divide_eq_eq (quotient on the left) or less_divide_eq, le_divide_eq,
     eq_divide_eq (quotient on the right); the theorem is instantiated with the
     three operands and returned as a meta-equation.  Returns NONE for any other
     term or when a step raises TERM, CTERM or THM.
     NOTE(review): in every branch "val _ = map is_number [a,b,c]" computes three
     booleans and discards them.  is_number does not raise, so unlike the
     dest_number guards in proc and proc2 this restricts nothing.  If the operands
     are meant to be numbers the result must be tested explicitly; confirm the
     intended behaviour before changing it.*)
   663  fun proc3 phi ss ct =
   664   (case term_of ct of
          (*a / b < c*)
   665     Const(@{const_name Orderings.less},_)$(Const(@{const_name Rings.divide},_)$_$_)$_ =>
   666       let
   667         val ((a,b),c) = Thm.dest_binop ct |>> Thm.dest_binop
   668         val _ = map is_number [a,b,c]
   669         val T = ctyp_of_term c
   670         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "divide_less_eq"}
   671       in SOME (mk_meta_eq th) end
          (*a / b <= c*)
   672   | Const(@{const_name Orderings.less_eq},_)$(Const(@{const_name Rings.divide},_)$_$_)$_ =>
   673       let
   674         val ((a,b),c) = Thm.dest_binop ct |>> Thm.dest_binop
   675         val _ = map is_number [a,b,c]
   676         val T = ctyp_of_term c
   677         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "divide_le_eq"}
   678       in SOME (mk_meta_eq th) end
          (*a / b = c*)
   679   | Const(@{const_name HOL.eq},_)$(Const(@{const_name Rings.divide},_)$_$_)$_ =>
   680       let
   681         val ((a,b),c) = Thm.dest_binop ct |>> Thm.dest_binop
   682         val _ = map is_number [a,b,c]
   683         val T = ctyp_of_term c
   684         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "divide_eq_eq"}
   685       in SOME (mk_meta_eq th) end
          (*a < b / c*)
   686   | Const(@{const_name Orderings.less},_)$_$(Const(@{const_name Rings.divide},_)$_$_) =>
   687     let
   688       val (a,(b,c)) = Thm.dest_binop ct ||> Thm.dest_binop
   689         val _ = map is_number [a,b,c]
   690         val T = ctyp_of_term c
   691         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "less_divide_eq"}
   692       in SOME (mk_meta_eq th) end
          (*a <= b / c*)
   693   | Const(@{const_name Orderings.less_eq},_)$_$(Const(@{const_name Rings.divide},_)$_$_) =>
   694     let
   695       val (a,(b,c)) = Thm.dest_binop ct ||> Thm.dest_binop
   696         val _ = map is_number [a,b,c]
   697         val T = ctyp_of_term c
   698         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "le_divide_eq"}
   699       in SOME (mk_meta_eq th) end
          (*a = b / c*)
   700   | Const(@{const_name HOL.eq},_)$_$(Const(@{const_name Rings.divide},_)$_$_) =>
   701     let
   702       val (a,(b,c)) = Thm.dest_binop ct ||> Thm.dest_binop
   703         val _ = map is_number [a,b,c]
   704         val T = ctyp_of_term c
   705         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "eq_divide_eq"}
   706       in SOME (mk_meta_eq th) end
   707   | _ => NONE)
   708   handle TERM _ => NONE | CTERM _ => NONE | THM _ => NONE
   709 
   (*Simproc "add_frac_frac_simproc": registered for sums of two quotients
     over a field type; the rewriting itself is done by proc.*)
   val add_frac_frac_simproc =
     let
       val sum_of_quotients = @{cpat "(?x::?'a::field)/?y + (?w::?'a::field)/?z"}
     in
       make_simproc
         {name = "add_frac_frac_simproc",
          lhss = [sum_of_quotients],
          proc = proc,
          identifier = []}
     end
   714 
   (*Simproc "add_frac_num_simproc": registered for a quotient over a field
     type added to another term, with the quotient as either summand; the
     rewriting itself is done by proc2.*)
   val add_frac_num_simproc =
     let
       val quotient_first = @{cpat "(?x::?'a::field)/?y + ?z"}
       val quotient_second = @{cpat "?z + (?x::?'a::field)/?y"}
     in
       make_simproc
         {name = "add_frac_num_simproc",
          lhss = [quotient_first, quotient_second],
          proc = proc2,
          identifier = []}
     end
   719 
   (*Simproc "ord_frac_simproc": registered for <, <= and = where one side is
     a quotient over a type of sorts field and ord; the rewriting itself is
     done by proc3.  The patterns are listed in the same order as before.*)
   val ord_frac_simproc =
     let
       val quotient_on_left =
         [@{cpat "(?a::(?'a::{field, ord}))/?b < ?c"},
          @{cpat "(?a::(?'a::{field, ord}))/?b <= ?c"}]
       val quotient_on_right =
         [@{cpat "?c < (?a::(?'a::{field, ord}))/?b"},
          @{cpat "?c <= (?a::(?'a::{field, ord}))/?b"}]
       val equations =
         [@{cpat "?c = ((?a::(?'a::{field, ord}))/?b)"},
          @{cpat "((?a::(?'a::{field, ord}))/ ?b) = ?c"}]
     in
       make_simproc
         {name = "ord_frac_simproc",
          lhss = quotient_on_left @ quotient_on_right @ equations,
          proc = proc3,
          identifier = []}
     end
   729 
   (*Rewrite rules that field_comp_conv adds to its first simpset.  The order
     of the list is unchanged.*)
   val ths =
     let
       (*field_divide_inverse resolved with sym*)
       val divide_inverse_sym = @{thm field_divide_inverse} RS sym

       (*The same rule after one rewrite with mult_commute at the subterm
         selected by arg_conv followed by arg1_conv -- presumably the product
         on the left-hand side of the equation; confirm against the theorem.*)
       val divide_inverse_sym_comm =
         Conv.fconv_rule
           (Conv.arg_conv (Conv.arg1_conv (Conv.rewr_conv (mk_meta_eq @{thm mult_commute}))))
           divide_inverse_sym
     in
       [@{thm "mult_numeral_1"},
        @{thm "mult_numeral_1_right"},
        @{thm "divide_Numeral1"},
        @{thm "divide_zero"},
        @{thm "divide_Numeral0"},
        @{thm "divide_divide_eq_left"},
        @{thm "times_divide_eq_left"},
        @{thm "times_divide_eq_right"},
        @{thm "times_divide_times_eq"},
        @{thm "divide_divide_eq_right"},
        @{thm "diff_minus"},
        @{thm "minus_divide_left"},
        @{thm "Numeral1_eq1_nat"},
        @{thm "add_divide_distrib"} RS sym,
        divide_inverse_sym,
        @{thm inverse_divide},
        divide_inverse_sym_comm]
     end
   742 
   743 in
   744 
   (*Conversion made of two simplifier passes run one after the other.

     The first pass starts from HOL_basic_ss and adds the semiring_norm rules,
     ths, simp_thms, the field_cancel_numeral_factors simprocs, the three
     fraction simprocs and the congruence rule if_weak_cong.  The second pass
     rewrites with numeral_1_eq_1, numeral_0_eq_0 and numerals(1-2) only.*)
   val field_comp_conv =
     let
       val main_ss =
         HOL_basic_ss
           addsimps @{thms "semiring_norm"}
           addsimps ths
           addsimps @{thms simp_thms}
           addsimprocs field_cancel_numeral_factors
           addsimprocs [add_frac_frac_simproc, add_frac_num_simproc, ord_frac_simproc]
           addcongs [@{thm "if_weak_cong"}]

       (*parentheses spell out what the infix precedences already gave:
         @ binds tighter than addsimps*)
       val numeral_ss =
         HOL_basic_ss addsimps
           ([@{thm numeral_1_eq_1}, @{thm numeral_0_eq_0}] @ @{thms numerals(1-2)})
     in
       Simplifier.rewrite main_ss then_conv Simplifier.rewrite numeral_ss
     end
   754 
   755 end
   756 
   757 end;
   758 
   (*Register the numeral cancellation, combination and assoc-fold simprocs.
     Addsimprocs is still called four times, in the original order.*)
   val _ =
     List.app Addsimprocs
       [Numeral_Simprocs.cancel_numerals,
        [Numeral_Simprocs.combine_numerals],
        [Numeral_Simprocs.field_combine_numerals],
        [Numeral_Simprocs.assoc_fold_simproc]];
   763 
   764 (*examples:
   765 print_depth 22;
   766 set timing;
   767 set trace_simp;
   768 fun test s = (Goal s, by (Simp_tac 1));
   769 
   770 test "l + 2 + 2 + 2 + (l + 2) + (oo + 2) = (uu::int)";
   771 
   772 test "2*u = (u::int)";
   773 test "(i + j + 12 + (k::int)) - 15 = y";
   774 test "(i + j + 12 + (k::int)) - 5 = y";
   775 
   776 test "y - b < (b::int)";
   777 test "y - (3*b + c) < (b::int) - 2*c";
   778 
   779 test "(2*x - (u*v) + y) - v*3*u = (w::int)";
   780 test "(2*x*u*v + (u*v)*4 + y) - v*u*4 = (w::int)";
   781 test "(2*x*u*v + (u*v)*4 + y) - v*u = (w::int)";
   782 test "u*v - (x*u*v + (u*v)*4 + y) = (w::int)";
   783 
   784 test "(i + j + 12 + (k::int)) = u + 15 + y";
   785 test "(i + j*2 + 12 + (k::int)) = j + 5 + y";
   786 
   787 test "2*y + 3*z + 6*w + 2*y + 3*z + 2*u = 2*y' + 3*z' + 6*w' + 2*y' + 3*z' + u + (vv::int)";
   788 
   789 test "a + -(b+c) + b = (d::int)";
   790 test "a + -(b+c) - b = (d::int)";
   791 
   792 (*negative numerals*)
   793 test "(i + j + -2 + (k::int)) - (u + 5 + y) = zz";
   794 test "(i + j + -3 + (k::int)) < u + 5 + y";
   795 test "(i + j + 3 + (k::int)) < u + -6 + y";
   796 test "(i + j + -12 + (k::int)) - 15 = y";
   797 test "(i + j + 12 + (k::int)) - -15 = y";
   798 test "(i + j + -12 + (k::int)) - -15 = y";
   799 *)
   800 
   (*Register the simprocs that cancel common numeral factors.*)
   val _ = Addsimprocs Numeral_Simprocs.cancel_numeral_factors;
   802 
   803 (*examples:
   804 print_depth 22;
   805 set timing;
   806 set trace_simp;
   807 fun test s = (Goal s; by (Simp_tac 1));
   808 
   809 test "9*x = 12 * (y::int)";
   810 test "(9*x) div (12 * (y::int)) = z";
   811 test "9*x < 12 * (y::int)";
   812 test "9*x <= 12 * (y::int)";
   813 
   814 test "-99*x = 132 * (y::int)";
   815 test "(-99*x) div (132 * (y::int)) = z";
   816 test "-99*x < 132 * (y::int)";
   817 test "-99*x <= 132 * (y::int)";
   818 
   819 test "999*x = -396 * (y::int)";
   820 test "(999*x) div (-396 * (y::int)) = z";
   821 test "999*x < -396 * (y::int)";
   822 test "999*x <= -396 * (y::int)";
   823 
   824 test "-99*x = -81 * (y::int)";
   825 test "(-99*x) div (-81 * (y::int)) = z";
   826 test "-99*x <= -81 * (y::int)";
   827 test "-99*x < -81 * (y::int)";
   828 
   829 test "-2 * x = -1 * (y::int)";
   830 test "-2 * x = -(y::int)";
   831 test "(-2 * x) div (-1 * (y::int)) = z";
   832 test "-2 * x < -(y::int)";
   833 test "-2 * x <= -1 * (y::int)";
   834 test "-x < -23 * (y::int)";
   835 test "-x <= -23 * (y::int)";
   836 *)
   837 
   838 (*And the same examples for fields such as rat or real:
   839 test "0 <= (y::rat) * -2";
   840 test "9*x = 12 * (y::rat)";
   841 test "(9*x) / (12 * (y::rat)) = z";
   842 test "9*x < 12 * (y::rat)";
   843 test "9*x <= 12 * (y::rat)";
   844 
   845 test "-99*x = 132 * (y::rat)";
   846 test "(-99*x) / (132 * (y::rat)) = z";
   847 test "-99*x < 132 * (y::rat)";
   848 test "-99*x <= 132 * (y::rat)";
   849 
   850 test "999*x = -396 * (y::rat)";
   851 test "(999*x) / (-396 * (y::rat)) = z";
   852 test "999*x < -396 * (y::rat)";
   853 test "999*x <= -396 * (y::rat)";
   854 
   855 test  "(- ((2::rat) * x) <= 2 * y)";
   856 test "-99*x = -81 * (y::rat)";
   857 test "(-99*x) / (-81 * (y::rat)) = z";
   858 test "-99*x <= -81 * (y::rat)";
   859 test "-99*x < -81 * (y::rat)";
   860 
   861 test "-2 * x = -1 * (y::rat)";
   862 test "-2 * x = -(y::rat)";
   863 test "(-2 * x) / (-1 * (y::rat)) = z";
   864 test "-2 * x < -(y::rat)";
   865 test "-2 * x <= -1 * (y::rat)";
   866 test "-x < -23 * (y::rat)";
   867 test "-x <= -23 * (y::rat)";
   868 *)
   869 
   (*Register the simprocs that cancel common (non-numeral) factors.*)
   val _ = Addsimprocs Numeral_Simprocs.cancel_factors;
   871 
   872 
   873 (*examples:
   874 print_depth 22;
   875 set timing;
   876 set trace_simp;
   877 fun test s = (Goal s; by (Asm_simp_tac 1));
   878 
   879 test "x*k = k*(y::int)";
   880 test "k = k*(y::int)";
   881 test "a*(b*c) = (b::int)";
   882 test "a*(b*c) = d*(b::int)*(x*a)";
   883 
   884 test "(x*k) div (k*(y::int)) = (uu::int)";
   885 test "(k) div (k*(y::int)) = (uu::int)";
   886 test "(a*(b*c)) div ((b::int)) = (uu::int)";
   887 test "(a*(b*c)) div (d*(b::int)*(x*a)) = (uu::int)";
   888 *)
   889 
   890 (*And the same examples for fields such as rat or real:
   891 print_depth 22;
   892 set timing;
   893 set trace_simp;
   894 fun test s = (Goal s; by (Asm_simp_tac 1));
   895 
   896 test "x*k = k*(y::rat)";
   897 test "k = k*(y::rat)";
   898 test "a*(b*c) = (b::rat)";
   899 test "a*(b*c) = d*(b::rat)*(x*a)";
   900 
   901 
   902 test "(x*k) / (k*(y::rat)) = (uu::rat)";
   903 test "(k) / (k*(y::rat)) = (uu::rat)";
   904 test "(a*(b*c)) / ((b::rat)) = (uu::rat)";
   905 test "(a*(b*c)) / (d*(b::rat)*(x*a)) = (uu::rat)";
   906 
   907 (*FIXME: what do we do about this?*)
   908 test "a*(b*c)/(y*z) = d*(b::rat)*(x*a)/z";
   909 *)