src/HOL/Tools/numeral_simprocs.ML
author haftmann
Sat Sep 17 00:37:21 2011 +0200 (2011-09-17)
changeset 44945 2625de88c994
parent 44064 5bce8ff0d9ae
child 44947 8ae418dfe561
permissions -rw-r--r--
tuned
     1 (* Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
     2    Copyright   2000  University of Cambridge
     3 
     4 Simprocs for the (integer) numerals.
     5 *)
     6 
     7 (*To quote from Provers/Arith/cancel_numeral_factor.ML:
     8 
     9 Cancels common coefficients in balanced expressions:
    10 
    11      u*#m ~~ u'*#m'  ==  #n*u ~~ #n'*u'
    12 
    13 where ~~ is an appropriate balancing operation (e.g. =, <=, <, div, /)
    14 and d = gcd(m,m') and n=m/d and n'=m'/d.
    15 *)
    16 
    17 signature NUMERAL_SIMPROCS =
    18 sig
    19   val prep_simproc: theory -> string * string list * (theory -> simpset -> term -> thm option)
    20     -> simproc
    21   val trans_tac: thm option -> tactic
    22   val assoc_fold_simproc: simproc
    23   val combine_numerals: simproc
    24   val cancel_numerals: simproc list
    25   val cancel_factors: simproc list
    26   val cancel_numeral_factors: simproc list
    27   val field_combine_numerals: simproc
    28   val field_cancel_numeral_factors: simproc list
    29   val num_ss: simpset
    30   val field_comp_conv: conv
    31 end;
    32 
    33 structure Numeral_Simprocs : NUMERAL_SIMPROCS =
    34 struct
    35 
    36 fun prep_simproc thy (name, pats, proc) =
    37   Simplifier.simproc_global thy name pats proc;
    38 
    39 fun trans_tac NONE  = all_tac
    40   | trans_tac (SOME th) = ALLGOALS (rtac (th RS trans));
    41 
    42 val mk_number = Arith_Data.mk_number;
    43 val mk_sum = Arith_Data.mk_sum;
    44 val long_mk_sum = Arith_Data.long_mk_sum;
    45 val dest_sum = Arith_Data.dest_sum;
    46 
    47 val mk_diff = HOLogic.mk_binop @{const_name Groups.minus};
    48 val dest_diff = HOLogic.dest_bin @{const_name Groups.minus} Term.dummyT;
    49 
    50 val mk_times = HOLogic.mk_binop @{const_name Groups.times};
    51 
    52 fun one_of T = Const(@{const_name Groups.one}, T);
    53 
    54 (* build product with trailing 1 rather than Numeral 1 in order to avoid the
    55    unnecessary restriction to type class number_ring
    56    which is not required for cancellation of common factors in divisions.
    57 *)
    58 fun mk_prod T = 
    59   let val one = one_of T
    60   fun mk [] = one
    61     | mk [t] = t
    62     | mk (t :: ts) = if t = one then mk ts else mk_times (t, mk ts)
    63   in mk end;
    64 
    65 (*This version ALWAYS includes a trailing one*)
    66 fun long_mk_prod T []        = one_of T
    67   | long_mk_prod T (t :: ts) = mk_times (t, mk_prod T ts);
    68 
    69 val dest_times = HOLogic.dest_bin @{const_name Groups.times} Term.dummyT;
    70 
    71 fun dest_prod t =
    72       let val (t,u) = dest_times t
    73       in dest_prod t @ dest_prod u end
    74       handle TERM _ => [t];
    75 
    76 fun find_first_numeral past (t::terms) =
    77         ((snd (HOLogic.dest_number t), rev past @ terms)
    78          handle TERM _ => find_first_numeral (t::past) terms)
    79   | find_first_numeral past [] = raise TERM("find_first_numeral", []);
    80 
    81 (*DON'T do the obvious simplifications; that would create special cases*)
    82 fun mk_coeff (k, t) = mk_times (mk_number (Term.fastype_of t) k, t);
    83 
    84 (*Express t as a product of (possibly) a numeral with other sorted terms*)
    85 fun dest_coeff sign (Const (@{const_name Groups.uminus}, _) $ t) = dest_coeff (~sign) t
    86   | dest_coeff sign t =
    87     let val ts = sort Term_Ord.term_ord (dest_prod t)
    88         val (n, ts') = find_first_numeral [] ts
    89                           handle TERM _ => (1, ts)
    90     in (sign*n, mk_prod (Term.fastype_of t) ts') end;
    91 
    92 (*Find first coefficient-term THAT MATCHES u*)
    93 fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
    94   | find_first_coeff past u (t::terms) =
    95         let val (n,u') = dest_coeff 1 t
    96         in if u aconv u' then (n, rev past @ terms)
    97                          else find_first_coeff (t::past) u terms
    98         end
    99         handle TERM _ => find_first_coeff (t::past) u terms;
   100 
   101 (*Fractions as pairs of ints. Can't use Rat.rat because the representation
   102   needs to preserve negative values in the denominator.*)
   103 fun mk_frac (p, q) = if q = 0 then raise Div else (p, q);
   104 
   105 (*Don't reduce fractions; sums must be proved by rule add_frac_eq.
   106   Fractions are reduced later by the cancel_numeral_factor simproc.*)
   107 fun add_frac ((p1, q1), (p2, q2)) = (p1 * q2 + p2 * q1, q1 * q2);
   108 
   109 val mk_divide = HOLogic.mk_binop @{const_name Fields.divide};
   110 
   111 (*Build term (p / q) * t*)
   112 fun mk_fcoeff ((p, q), t) =
   113   let val T = Term.fastype_of t
   114   in mk_times (mk_divide (mk_number T p, mk_number T q), t) end;
   115 
   116 (*Express t as a product of a fraction with other sorted terms*)
   117 fun dest_fcoeff sign (Const (@{const_name Groups.uminus}, _) $ t) = dest_fcoeff (~sign) t
   118   | dest_fcoeff sign (Const (@{const_name Fields.divide}, _) $ t $ u) =
   119     let val (p, t') = dest_coeff sign t
   120         val (q, u') = dest_coeff 1 u
   121     in (mk_frac (p, q), mk_divide (t', u')) end
   122   | dest_fcoeff sign t =
   123     let val (p, t') = dest_coeff sign t
   124         val T = Term.fastype_of t
   125     in (mk_frac (p, 1), mk_divide (t', one_of T)) end;
   126 
   127 
   128 (** New term ordering so that AC-rewriting brings numerals to the front **)
   129 
   130 (*Order integers by absolute value and then by sign. The standard integer
   131   ordering is not well-founded.*)
   132 fun num_ord (i,j) =
   133   (case int_ord (abs i, abs j) of
   134     EQUAL => int_ord (Int.sign i, Int.sign j) 
   135   | ord => ord);
   136 
   137 (*This resembles Term_Ord.term_ord, but it puts binary numerals before other
   138   non-atomic terms.*)
   139 local open Term 
   140 in 
   141 fun numterm_ord (Abs (_, T, t), Abs(_, U, u)) =
   142       (case numterm_ord (t, u) of EQUAL => Term_Ord.typ_ord (T, U) | ord => ord)
   143   | numterm_ord
   144      (Const(@{const_name Int.number_of}, _) $ v, Const(@{const_name Int.number_of}, _) $ w) =
   145      num_ord (HOLogic.dest_numeral v, HOLogic.dest_numeral w)
   146   | numterm_ord (Const(@{const_name Int.number_of}, _) $ _, _) = LESS
   147   | numterm_ord (_, Const(@{const_name Int.number_of}, _) $ _) = GREATER
   148   | numterm_ord (t, u) =
   149       (case int_ord (size_of_term t, size_of_term u) of
   150         EQUAL =>
   151           let val (f, ts) = strip_comb t and (g, us) = strip_comb u in
   152             (case Term_Ord.hd_ord (f, g) of EQUAL => numterms_ord (ts, us) | ord => ord)
   153           end
   154       | ord => ord)
   155 and numterms_ord (ts, us) = list_ord numterm_ord (ts, us)
   156 end;
   157 
   158 fun numtermless tu = (numterm_ord tu = LESS);
   159 
   160 val num_ss = HOL_ss settermless numtermless;
   161 
   162 (*Maps 0 to Numeral0 and 1 to Numeral1 so that arithmetic isn't complicated by the abstract 0 and 1.*)
   163 val numeral_syms = [@{thm numeral_0_eq_0} RS sym, @{thm numeral_1_eq_1} RS sym];
   164 
   165 (*Simplify Numeral0+n, n+Numeral0, Numeral1*n, n*Numeral1, 1*x, x*1, x/1 *)
   166 val add_0s =  @{thms add_0s};
   167 val mult_1s = @{thms mult_1s mult_1_left mult_1_right divide_1};
   168 
   169 (*Simplify inverse Numeral1, a/Numeral1*)
   170 val inverse_1s = [@{thm inverse_numeral_1}];
   171 val divide_1s = [@{thm divide_numeral_1}];
   172 
   173 (*To perform binary arithmetic.  The "left" rewriting handles patterns
   174   created by the Numeral_Simprocs, such as 3 * (5 * x). *)
   175 val simps = [@{thm numeral_0_eq_0} RS sym, @{thm numeral_1_eq_1} RS sym,
   176                  @{thm add_number_of_left}, @{thm mult_number_of_left}] @
   177                 @{thms arith_simps} @ @{thms rel_simps};
   178 
   179 (*Binary arithmetic BUT NOT ADDITION since it may collapse adjacent terms
   180   during re-arrangement*)
   181 val non_add_simps =
   182   subtract Thm.eq_thm [@{thm add_number_of_left}, @{thm number_of_add} RS sym] simps;
   183 
   184 (*To evaluate binary negations of coefficients*)
   185 val minus_simps = [@{thm numeral_m1_eq_minus_1} RS sym, @{thm number_of_minus} RS sym] @
   186                    @{thms minus_bin_simps} @ @{thms pred_bin_simps};
   187 
   188 (*To let us treat subtraction as addition*)
   189 val diff_simps = [@{thm diff_minus}, @{thm minus_add_distrib}, @{thm minus_minus}];
   190 
   191 (*To let us treat division as multiplication*)
   192 val divide_simps = [@{thm divide_inverse}, @{thm inverse_mult_distrib}, @{thm inverse_inverse_eq}];
   193 
   194 (*push the unary minus down*)
   195 val minus_mult_eq_1_to_2 = @{lemma "- (a::'a::ring) * b = a * - b" by simp};
   196 
   197 (*to extract again any uncancelled minuses*)
   198 val minus_from_mult_simps =
   199     [@{thm minus_minus}, @{thm mult_minus_left}, @{thm mult_minus_right}];
   200 
   201 (*combine unary minus with numeric literals, however nested within a product*)
   202 val mult_minus_simps =
   203     [@{thm mult_assoc}, @{thm minus_mult_left}, minus_mult_eq_1_to_2];
   204 
   205 val norm_ss1 = num_ss addsimps numeral_syms @ add_0s @ mult_1s @
   206   diff_simps @ minus_simps @ @{thms add_ac}
   207 val norm_ss2 = num_ss addsimps non_add_simps @ mult_minus_simps
   208 val norm_ss3 = num_ss addsimps minus_from_mult_simps @ @{thms add_ac} @ @{thms mult_ac}
   209 
   210 structure CancelNumeralsCommon =
   211 struct
   212   val mk_sum = mk_sum
   213   val dest_sum = dest_sum
   214   val mk_coeff = mk_coeff
   215   val dest_coeff = dest_coeff 1
   216   val find_first_coeff = find_first_coeff []
   217   val trans_tac = K trans_tac
   218 
   219   fun norm_tac ss =
   220     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   221     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   222     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   223 
   224   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps
   225   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   226   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s)
   227   val prove_conv = Arith_Data.prove_conv
   228 end;
   229 
   230 structure EqCancelNumerals = CancelNumeralsFun
   231  (open CancelNumeralsCommon
   232   val mk_bal   = HOLogic.mk_eq
   233   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   234   val bal_add1 = @{thm eq_add_iff1} RS trans
   235   val bal_add2 = @{thm eq_add_iff2} RS trans
   236 );
   237 
   238 structure LessCancelNumerals = CancelNumeralsFun
   239  (open CancelNumeralsCommon
   240   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   241   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   242   val bal_add1 = @{thm less_add_iff1} RS trans
   243   val bal_add2 = @{thm less_add_iff2} RS trans
   244 );
   245 
   246 structure LeCancelNumerals = CancelNumeralsFun
   247  (open CancelNumeralsCommon
   248   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   249   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   250   val bal_add1 = @{thm le_add_iff1} RS trans
   251   val bal_add2 = @{thm le_add_iff2} RS trans
   252 );
   253 
   254 val cancel_numerals =
   255   map (prep_simproc @{theory})
   256    [("inteq_cancel_numerals",
   257      ["(l::'a::number_ring) + m = n",
   258       "(l::'a::number_ring) = m + n",
   259       "(l::'a::number_ring) - m = n",
   260       "(l::'a::number_ring) = m - n",
   261       "(l::'a::number_ring) * m = n",
   262       "(l::'a::number_ring) = m * n"],
   263      K EqCancelNumerals.proc),
   264     ("intless_cancel_numerals",
   265      ["(l::'a::{linordered_idom,number_ring}) + m < n",
   266       "(l::'a::{linordered_idom,number_ring}) < m + n",
   267       "(l::'a::{linordered_idom,number_ring}) - m < n",
   268       "(l::'a::{linordered_idom,number_ring}) < m - n",
   269       "(l::'a::{linordered_idom,number_ring}) * m < n",
   270       "(l::'a::{linordered_idom,number_ring}) < m * n"],
   271      K LessCancelNumerals.proc),
   272     ("intle_cancel_numerals",
   273      ["(l::'a::{linordered_idom,number_ring}) + m <= n",
   274       "(l::'a::{linordered_idom,number_ring}) <= m + n",
   275       "(l::'a::{linordered_idom,number_ring}) - m <= n",
   276       "(l::'a::{linordered_idom,number_ring}) <= m - n",
   277       "(l::'a::{linordered_idom,number_ring}) * m <= n",
   278       "(l::'a::{linordered_idom,number_ring}) <= m * n"],
   279      K LeCancelNumerals.proc)];
   280 
   281 structure CombineNumeralsData =
   282 struct
   283   type coeff = int
   284   val iszero = (fn x => x = 0)
   285   val add  = op +
   286   val mk_sum = long_mk_sum    (*to work for e.g. 2*x + 3*x *)
   287   val dest_sum = dest_sum
   288   val mk_coeff = mk_coeff
   289   val dest_coeff = dest_coeff 1
   290   val left_distrib = @{thm combine_common_factor} RS trans
   291   val prove_conv = Arith_Data.prove_conv_nohyps
   292   val trans_tac = K trans_tac
   293 
   294   fun norm_tac ss =
   295     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   296     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   297     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   298 
   299   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps
   300   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   301   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s)
   302 end;
   303 
   304 structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);
   305 
   306 (*Version for fields, where coefficients can be fractions*)
   307 structure FieldCombineNumeralsData =
   308 struct
   309   type coeff = int * int
   310   val iszero = (fn (p, q) => p = 0)
   311   val add = add_frac
   312   val mk_sum = long_mk_sum
   313   val dest_sum = dest_sum
   314   val mk_coeff = mk_fcoeff
   315   val dest_coeff = dest_fcoeff 1
   316   val left_distrib = @{thm combine_common_factor} RS trans
   317   val prove_conv = Arith_Data.prove_conv_nohyps
   318   val trans_tac = K trans_tac
   319 
   320   val norm_ss1a = norm_ss1 addsimps inverse_1s @ divide_simps
   321   fun norm_tac ss =
   322     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1a))
   323     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   324     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   325 
   326   val numeral_simp_ss = HOL_ss addsimps add_0s @ simps @ [@{thm add_frac_eq}]
   327   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   328   val simplify_meta_eq = Arith_Data.simplify_meta_eq (add_0s @ mult_1s @ divide_1s)
   329 end;
   330 
   331 structure FieldCombineNumerals = CombineNumeralsFun(FieldCombineNumeralsData);
   332 
   333 val combine_numerals =
   334   prep_simproc @{theory}
   335     ("int_combine_numerals", 
   336      ["(i::'a::number_ring) + j", "(i::'a::number_ring) - j"], 
   337      K CombineNumerals.proc);
   338 
   339 val field_combine_numerals =
   340   prep_simproc @{theory}
   341     ("field_combine_numerals", 
   342      ["(i::'a::{field_inverse_zero, number_ring}) + j",
   343       "(i::'a::{field_inverse_zero, number_ring}) - j"], 
   344      K FieldCombineNumerals.proc);
   345 
   346 (** Constant folding for multiplication in semirings **)
   347 
   348 (*We do not need folding for addition: combine_numerals does the same thing*)
   349 
   350 structure Semiring_Times_Assoc_Data : ASSOC_FOLD_DATA =
   351 struct
   352   val assoc_ss = HOL_ss addsimps @{thms mult_ac}
   353   val eq_reflection = eq_reflection
   354   val is_numeral = can HOLogic.dest_number
   355 end;
   356 
   357 structure Semiring_Times_Assoc = Assoc_Fold (Semiring_Times_Assoc_Data);
   358 
   359 val assoc_fold_simproc =
   360   prep_simproc @{theory}
   361    ("semiring_assoc_fold", ["(a::'a::comm_semiring_1_cancel) * b"],
   362     K Semiring_Times_Assoc.proc);
   363 
   364 structure CancelNumeralFactorCommon =
   365 struct
   366   val mk_coeff = mk_coeff
   367   val dest_coeff = dest_coeff 1
   368   val trans_tac = K trans_tac
   369 
   370   val norm_ss1 = HOL_ss addsimps minus_from_mult_simps @ mult_1s
   371   val norm_ss2 = HOL_ss addsimps simps @ mult_minus_simps
   372   val norm_ss3 = HOL_ss addsimps @{thms mult_ac}
   373   fun norm_tac ss =
   374     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   375     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   376     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   377 
   378   val numeral_simp_ss = HOL_ss addsimps
   379     [@{thm eq_number_of_eq}, @{thm less_number_of}, @{thm le_number_of}] @ simps
   380   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   381   val simplify_meta_eq = Arith_Data.simplify_meta_eq
   382     [@{thm Nat.add_0}, @{thm Nat.add_0_right}, @{thm mult_zero_left},
   383       @{thm mult_zero_right}, @{thm mult_Bit1}, @{thm mult_1_right}];
   384   val prove_conv = Arith_Data.prove_conv
   385 end
   386 
   387 (*Version for semiring_div*)
   388 structure DivCancelNumeralFactor = CancelNumeralFactorFun
   389  (open CancelNumeralFactorCommon
   390   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   391   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   392   val cancel = @{thm div_mult_mult1} RS trans
   393   val neg_exchanges = false
   394 )
   395 
   396 (*Version for fields*)
   397 structure DivideCancelNumeralFactor = CancelNumeralFactorFun
   398  (open CancelNumeralFactorCommon
   399   val mk_bal   = HOLogic.mk_binop @{const_name Fields.divide}
   400   val dest_bal = HOLogic.dest_bin @{const_name Fields.divide} Term.dummyT
   401   val cancel = @{thm mult_divide_mult_cancel_left} RS trans
   402   val neg_exchanges = false
   403 )
   404 
   405 structure EqCancelNumeralFactor = CancelNumeralFactorFun
   406  (open CancelNumeralFactorCommon
   407   val mk_bal   = HOLogic.mk_eq
   408   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   409   val cancel = @{thm mult_cancel_left} RS trans
   410   val neg_exchanges = false
   411 )
   412 
   413 structure LessCancelNumeralFactor = CancelNumeralFactorFun
   414  (open CancelNumeralFactorCommon
   415   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   416   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   417   val cancel = @{thm mult_less_cancel_left} RS trans
   418   val neg_exchanges = true
   419 )
   420 
   421 structure LeCancelNumeralFactor = CancelNumeralFactorFun
   422  (open CancelNumeralFactorCommon
   423   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   424   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   425   val cancel = @{thm mult_le_cancel_left} RS trans
   426   val neg_exchanges = true
   427 )
   428 
   429 val cancel_numeral_factors =
   430   map (prep_simproc @{theory})
   431    [("ring_eq_cancel_numeral_factor",
   432      ["(l::'a::{idom,number_ring}) * m = n",
   433       "(l::'a::{idom,number_ring}) = m * n"],
   434      K EqCancelNumeralFactor.proc),
   435     ("ring_less_cancel_numeral_factor",
   436      ["(l::'a::{linordered_idom,number_ring}) * m < n",
   437       "(l::'a::{linordered_idom,number_ring}) < m * n"],
   438      K LessCancelNumeralFactor.proc),
   439     ("ring_le_cancel_numeral_factor",
   440      ["(l::'a::{linordered_idom,number_ring}) * m <= n",
   441       "(l::'a::{linordered_idom,number_ring}) <= m * n"],
   442      K LeCancelNumeralFactor.proc),
   443     ("int_div_cancel_numeral_factors",
   444      ["((l::'a::{semiring_div,number_ring}) * m) div n",
   445       "(l::'a::{semiring_div,number_ring}) div (m * n)"],
   446      K DivCancelNumeralFactor.proc),
   447     ("divide_cancel_numeral_factor",
   448      ["((l::'a::{field_inverse_zero,number_ring}) * m) / n",
   449       "(l::'a::{field_inverse_zero,number_ring}) / (m * n)",
   450       "((number_of v)::'a::{field_inverse_zero,number_ring}) / (number_of w)"],
   451      K DivideCancelNumeralFactor.proc)];
   452 
   453 val field_cancel_numeral_factors =
   454   map (prep_simproc @{theory})
   455    [("field_eq_cancel_numeral_factor",
   456      ["(l::'a::{field,number_ring}) * m = n",
   457       "(l::'a::{field,number_ring}) = m * n"],
   458      K EqCancelNumeralFactor.proc),
   459     ("field_cancel_numeral_factor",
   460      ["((l::'a::{field_inverse_zero,number_ring}) * m) / n",
   461       "(l::'a::{field_inverse_zero,number_ring}) / (m * n)",
   462       "((number_of v)::'a::{field_inverse_zero,number_ring}) / (number_of w)"],
   463      K DivideCancelNumeralFactor.proc)]
   464 
   465 
   466 (** Declarations for ExtractCommonTerm **)
   467 
   468 (*Find first term that matches u*)
   469 fun find_first_t past u []         = raise TERM ("find_first_t", [])
   470   | find_first_t past u (t::terms) =
   471         if u aconv t then (rev past @ terms)
   472         else find_first_t (t::past) u terms
   473         handle TERM _ => find_first_t (t::past) u terms;
   474 
   475 (** Final simplification for the CancelFactor simprocs **)
   476 val simplify_one = Arith_Data.simplify_meta_eq  
   477   [@{thm mult_1_left}, @{thm mult_1_right}, @{thm div_by_1}, @{thm numeral_1_eq_1}];
   478 
   479 fun cancel_simplify_meta_eq ss cancel_th th =
   480     simplify_one ss (([th, cancel_th]) MRS trans);
   481 
   482 local
   483   val Tp_Eq = Thm.reflexive (Thm.cterm_of @{theory HOL} HOLogic.Trueprop)
   484   fun Eq_True_elim Eq = 
   485     Thm.equal_elim (Thm.combination Tp_Eq (Thm.symmetric Eq)) @{thm TrueI}
   486 in
   487 fun sign_conv pos_th neg_th ss t =
   488   let val T = fastype_of t;
   489       val zero = Const(@{const_name Groups.zero}, T);
   490       val less = Const(@{const_name Orderings.less}, [T,T] ---> HOLogic.boolT);
   491       val pos = less $ zero $ t and neg = less $ t $ zero
   492       fun prove p =
   493         Option.map Eq_True_elim (Lin_Arith.simproc ss p)
   494         handle THM _ => NONE
   495     in case prove pos of
   496          SOME th => SOME(th RS pos_th)
   497        | NONE => (case prove neg of
   498                     SOME th => SOME(th RS neg_th)
   499                   | NONE => NONE)
   500     end;
   501 end
   502 
   503 structure CancelFactorCommon =
   504 struct
   505   val mk_sum = long_mk_prod
   506   val dest_sum = dest_prod
   507   val mk_coeff = mk_coeff
   508   val dest_coeff = dest_coeff
   509   val find_first = find_first_t []
   510   val trans_tac = K trans_tac
   511   val norm_ss = HOL_ss addsimps mult_1s @ @{thms mult_ac}
   512   fun norm_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss))
   513   val simplify_meta_eq  = cancel_simplify_meta_eq 
   514   val prove_conv = Arith_Data.prove_conv
   515 end;
   516 
   517 (*mult_cancel_left requires a ring with no zero divisors.*)
   518 structure EqCancelFactor = ExtractCommonTermFun
   519  (open CancelFactorCommon
   520   val mk_bal   = HOLogic.mk_eq
   521   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   522   fun simp_conv _ _ = SOME @{thm mult_cancel_left}
   523 );
   524 
   525 (*for ordered rings*)
   526 structure LeCancelFactor = ExtractCommonTermFun
   527  (open CancelFactorCommon
   528   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   529   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   530   val simp_conv = sign_conv
   531     @{thm mult_le_cancel_left_pos} @{thm mult_le_cancel_left_neg}
   532 );
   533 
   534 (*for ordered rings*)
   535 structure LessCancelFactor = ExtractCommonTermFun
   536  (open CancelFactorCommon
   537   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   538   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   539   val simp_conv = sign_conv
   540     @{thm mult_less_cancel_left_pos} @{thm mult_less_cancel_left_neg}
   541 );
   542 
   543 (*for semirings with division*)
   544 structure DivCancelFactor = ExtractCommonTermFun
   545  (open CancelFactorCommon
   546   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   547   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   548   fun simp_conv _ _ = SOME @{thm div_mult_mult1_if}
   549 );
   550 
   551 structure ModCancelFactor = ExtractCommonTermFun
   552  (open CancelFactorCommon
   553   val mk_bal   = HOLogic.mk_binop @{const_name Divides.mod}
   554   val dest_bal = HOLogic.dest_bin @{const_name Divides.mod} Term.dummyT
   555   fun simp_conv _ _ = SOME @{thm mod_mult_mult1}
   556 );
   557 
   558 (*for idoms*)
   559 structure DvdCancelFactor = ExtractCommonTermFun
   560  (open CancelFactorCommon
   561   val mk_bal   = HOLogic.mk_binrel @{const_name Rings.dvd}
   562   val dest_bal = HOLogic.dest_bin @{const_name Rings.dvd} Term.dummyT
   563   fun simp_conv _ _ = SOME @{thm dvd_mult_cancel_left}
   564 );
   565 
   566 (*Version for all fields, including unordered ones (type complex).*)
   567 structure DivideCancelFactor = ExtractCommonTermFun
   568  (open CancelFactorCommon
   569   val mk_bal   = HOLogic.mk_binop @{const_name Fields.divide}
   570   val dest_bal = HOLogic.dest_bin @{const_name Fields.divide} Term.dummyT
   571   fun simp_conv _ _ = SOME @{thm mult_divide_mult_cancel_left_if}
   572 );
   573 
   574 val cancel_factors =
   575   map (prep_simproc @{theory})
   576    [("ring_eq_cancel_factor",
   577      ["(l::'a::idom) * m = n",
   578       "(l::'a::idom) = m * n"],
   579      K EqCancelFactor.proc),
   580     ("linordered_ring_le_cancel_factor",
   581      ["(l::'a::linordered_ring) * m <= n",
   582       "(l::'a::linordered_ring) <= m * n"],
   583      K LeCancelFactor.proc),
   584     ("linordered_ring_less_cancel_factor",
   585      ["(l::'a::linordered_ring) * m < n",
   586       "(l::'a::linordered_ring) < m * n"],
   587      K LessCancelFactor.proc),
   588     ("int_div_cancel_factor",
   589      ["((l::'a::semiring_div) * m) div n", "(l::'a::semiring_div) div (m * n)"],
   590      K DivCancelFactor.proc),
   591     ("int_mod_cancel_factor",
   592      ["((l::'a::semiring_div) * m) mod n", "(l::'a::semiring_div) mod (m * n)"],
   593      K ModCancelFactor.proc),
   594     ("dvd_cancel_factor",
   595      ["((l::'a::idom) * m) dvd n", "(l::'a::idom) dvd (m * n)"],
   596      K DvdCancelFactor.proc),
   597     ("divide_cancel_factor",
   598      ["((l::'a::field_inverse_zero) * m) / n",
   599       "(l::'a::field_inverse_zero) / (m * n)"],
   600      K DivideCancelFactor.proc)];
   601 
   602 local
   603  val zr = @{cpat "0"}
   604  val zT = ctyp_of_term zr
   605  val geq = @{cpat HOL.eq}
   606  val eqT = Thm.dest_ctyp (ctyp_of_term geq) |> hd
   607  val add_frac_eq = mk_meta_eq @{thm "add_frac_eq"}
   608  val add_frac_num = mk_meta_eq @{thm "add_frac_num"}
   609  val add_num_frac = mk_meta_eq @{thm "add_num_frac"}
   610 
   611  fun prove_nz ss T t =
   612     let
   613       val z = Thm.instantiate_cterm ([(zT,T)],[]) zr
   614       val eq = Thm.instantiate_cterm ([(eqT,T)],[]) geq
   615       val th = Simplifier.rewrite (ss addsimps @{thms simp_thms})
   616            (Thm.capply @{cterm "Trueprop"} (Thm.capply @{cterm "Not"}
   617                   (Thm.capply (Thm.capply eq t) z)))
   618     in Thm.equal_elim (Thm.symmetric th) TrueI
   619     end
   620 
   621  fun proc phi ss ct =
   622   let
   623     val ((x,y),(w,z)) =
   624          (Thm.dest_binop #> (fn (a,b) => (Thm.dest_binop a, Thm.dest_binop b))) ct
   625     val _ = map (HOLogic.dest_number o term_of) [x,y,z,w]
   626     val T = ctyp_of_term x
   627     val [y_nz, z_nz] = map (prove_nz ss T) [y, z]
   628     val th = instantiate' [SOME T] (map SOME [y,z,x,w]) add_frac_eq
   629   in SOME (Thm.implies_elim (Thm.implies_elim th y_nz) z_nz)
   630   end
   631   handle CTERM _ => NONE | TERM _ => NONE | THM _ => NONE
   632 
   633  fun proc2 phi ss ct =
   634   let
   635     val (l,r) = Thm.dest_binop ct
   636     val T = ctyp_of_term l
   637   in (case (term_of l, term_of r) of
   638       (Const(@{const_name Fields.divide},_)$_$_, _) =>
   639         let val (x,y) = Thm.dest_binop l val z = r
   640             val _ = map (HOLogic.dest_number o term_of) [x,y,z]
   641             val ynz = prove_nz ss T y
   642         in SOME (Thm.implies_elim (instantiate' [SOME T] (map SOME [y,x,z]) add_frac_num) ynz)
   643         end
   644      | (_, Const (@{const_name Fields.divide},_)$_$_) =>
   645         let val (x,y) = Thm.dest_binop r val z = l
   646             val _ = map (HOLogic.dest_number o term_of) [x,y,z]
   647             val ynz = prove_nz ss T y
   648         in SOME (Thm.implies_elim (instantiate' [SOME T] (map SOME [y,z,x]) add_num_frac) ynz)
   649         end
   650      | _ => NONE)
   651   end
   652   handle CTERM _ => NONE | TERM _ => NONE | THM _ => NONE
   653 
   654  fun is_number (Const(@{const_name Fields.divide},_)$a$b) = is_number a andalso is_number b
   655    | is_number t = can HOLogic.dest_number t
   656 
   657  val is_number = is_number o term_of
   658 
   659  fun proc3 phi ss ct =
   660   (case term_of ct of
   661     Const(@{const_name Orderings.less},_)$(Const(@{const_name Fields.divide},_)$_$_)$_ =>
   662       let
   663         val ((a,b),c) = Thm.dest_binop ct |>> Thm.dest_binop
   664         val _ = map is_number [a,b,c]
   665         val T = ctyp_of_term c
   666         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "divide_less_eq"}
   667       in SOME (mk_meta_eq th) end
   668   | Const(@{const_name Orderings.less_eq},_)$(Const(@{const_name Fields.divide},_)$_$_)$_ =>
   669       let
   670         val ((a,b),c) = Thm.dest_binop ct |>> Thm.dest_binop
   671         val _ = map is_number [a,b,c]
   672         val T = ctyp_of_term c
   673         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "divide_le_eq"}
   674       in SOME (mk_meta_eq th) end
   675   | Const(@{const_name HOL.eq},_)$(Const(@{const_name Fields.divide},_)$_$_)$_ =>
   676       let
   677         val ((a,b),c) = Thm.dest_binop ct |>> Thm.dest_binop
   678         val _ = map is_number [a,b,c]
   679         val T = ctyp_of_term c
   680         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "divide_eq_eq"}
   681       in SOME (mk_meta_eq th) end
   682   | Const(@{const_name Orderings.less},_)$_$(Const(@{const_name Fields.divide},_)$_$_) =>
   683     let
   684       val (a,(b,c)) = Thm.dest_binop ct ||> Thm.dest_binop
   685         val _ = map is_number [a,b,c]
   686         val T = ctyp_of_term c
   687         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "less_divide_eq"}
   688       in SOME (mk_meta_eq th) end
   689   | Const(@{const_name Orderings.less_eq},_)$_$(Const(@{const_name Fields.divide},_)$_$_) =>
   690     let
   691       val (a,(b,c)) = Thm.dest_binop ct ||> Thm.dest_binop
   692         val _ = map is_number [a,b,c]
   693         val T = ctyp_of_term c
   694         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "le_divide_eq"}
   695       in SOME (mk_meta_eq th) end
   696   | Const(@{const_name HOL.eq},_)$_$(Const(@{const_name Fields.divide},_)$_$_) =>
   697     let
   698       val (a,(b,c)) = Thm.dest_binop ct ||> Thm.dest_binop
   699         val _ = map is_number [a,b,c]
   700         val T = ctyp_of_term c
   701         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "eq_divide_eq"}
   702       in SOME (mk_meta_eq th) end
   703   | _ => NONE)
   704   handle TERM _ => NONE | CTERM _ => NONE | THM _ => NONE
   705 
   706 (* Simproc for a sum of two quotients, x/y + w/z, over a field type; the rewriting itself is done by proc. *)
   707 val add_frac_frac_simproc = make_simproc
   708   {name = "add_frac_frac_simproc", proc = proc, identifier = [],
   709    lhss = [@{cpat "(?x::?'a::field)/?y + (?w::?'a::field)/?z"}]}
   710 
   711 (* Simproc for a quotient added to another term, in either order (x/y + z and z + x/y); handled by proc2. *)
   712 val add_frac_num_simproc = make_simproc
   713   {name = "add_frac_num_simproc", proc = proc2, identifier = [],
   714    lhss = [@{cpat "(?x::?'a::field)/?y + ?z"}, @{cpat "?z + (?x::?'a::field)/?y"}]}
   715 
   716 (* Simproc for <, <= and = between a quotient a/b and a term c (quotient on either side), over {field, ord}; handled by proc3. *)
   717 val ord_frac_simproc = make_simproc
   718   {name = "ord_frac_simproc", proc = proc3, identifier = [],
   719    lhss = [@{cpat "(?a::(?'a::{field, ord}))/?b < ?c"},
   720            @{cpat "(?a::(?'a::{field, ord}))/?b <= ?c"},
   721            @{cpat "?c < (?a::(?'a::{field, ord}))/?b"},
   722            @{cpat "?c <= (?a::(?'a::{field, ord}))/?b"},
   723            @{cpat "?c = ((?a::(?'a::{field, ord}))/?b)"},
   724            @{cpat "((?a::(?'a::{field, ord}))/ ?b) = ?c"}]}
   725 
   726 (* Rewrite rules for field_comp_conv; includes field_divide_inverse reversed (RS sym) plus a copy of it whose left-hand product has its factors commuted. *)
   727 val ths =
   728   let
   729     val divide_inverse_rev = @{thm field_divide_inverse} RS sym
   730     val commute_lhs = Conv.fconv_rule (Conv.arg_conv (Conv.arg1_conv (Conv.rewr_conv (mk_meta_eq @{thm mult_commute}))))
   731   in
   732     [@{thm "mult_numeral_1"}, @{thm "mult_numeral_1_right"}, @{thm "divide_Numeral1"}, @{thm "divide_zero"},
   733      @{thm "divide_Numeral0"}, @{thm "divide_divide_eq_left"}, @{thm "times_divide_eq_left"}, @{thm "times_divide_eq_right"},
   734      @{thm "times_divide_times_eq"}, @{thm "divide_divide_eq_right"}, @{thm "diff_minus"}, @{thm "minus_divide_left"},
   735      @{thm "Numeral1_eq1_nat"}, @{thm "add_divide_distrib"} RS sym,
   736      divide_inverse_rev, @{thm inverse_divide}, commute_lhs divide_inverse_rev]
   737   end
   738 
   739 in
   740 
   741 (* Two rewrite passes: first semiring_norm, ths, simp_thms and the fraction simprocs; then numeral_1_eq_1, numeral_0_eq_0 and numerals(1-2). *)
   742 val field_comp_conv =
   743   let
   744     val main_ss = HOL_basic_ss addsimps @{thms "semiring_norm"} addsimps ths addsimps @{thms simp_thms}
   745       addsimprocs field_cancel_numeral_factors
   746       addsimprocs [add_frac_frac_simproc, add_frac_num_simproc, ord_frac_simproc]
   747       addcongs [@{thm "if_weak_cong"}]
   748     val numeral_ss = HOL_basic_ss addsimps [@{thm numeral_1_eq_1}, @{thm numeral_0_eq_0}] @ @{thms numerals(1-2)}
   749   in Simplifier.rewrite main_ss then_conv Simplifier.rewrite numeral_ss end
   750 
   751 end
   752 
   753 end;
   754 
   755 Addsimprocs Numeral_Simprocs.cancel_numerals; (* global (old-style) registration of the Numeral_Simprocs simprocs; see the commented examples below *)
   756 Addsimprocs [Numeral_Simprocs.combine_numerals];
   757 Addsimprocs [Numeral_Simprocs.field_combine_numerals];
   758 Addsimprocs [Numeral_Simprocs.assoc_fold_simproc];
   759 
   760 (*examples:
   761 print_depth 22;
   762 set timing;
   763 set simp_trace;
   764 fun test s = (Goal s; by (Simp_tac 1));
   765 
   766 test "l + 2 + 2 + 2 + (l + 2) + (oo + 2) = (uu::int)";
   767 
   768 test "2*u = (u::int)";
   769 test "(i + j + 12 + (k::int)) - 15 = y";
   770 test "(i + j + 12 + (k::int)) - 5 = y";
   771 
   772 test "y - b < (b::int)";
   773 test "y - (3*b + c) < (b::int) - 2*c";
   774 
   775 test "(2*x - (u*v) + y) - v*3*u = (w::int)";
   776 test "(2*x*u*v + (u*v)*4 + y) - v*u*4 = (w::int)";
   777 test "(2*x*u*v + (u*v)*4 + y) - v*u = (w::int)";
   778 test "u*v - (x*u*v + (u*v)*4 + y) = (w::int)";
   779 
   780 test "(i + j + 12 + (k::int)) = u + 15 + y";
   781 test "(i + j*2 + 12 + (k::int)) = j + 5 + y";
   782 
   783 test "2*y + 3*z + 6*w + 2*y + 3*z + 2*u = 2*y' + 3*z' + 6*w' + 2*y' + 3*z' + u + (vv::int)";
   784 
   785 test "a + -(b+c) + b = (d::int)";
   786 test "a + -(b+c) - b = (d::int)";
   787 
   788 (*negative numerals*)
   789 test "(i + j + -2 + (k::int)) - (u + 5 + y) = zz";
   790 test "(i + j + -3 + (k::int)) < u + 5 + y";
   791 test "(i + j + 3 + (k::int)) < u + -6 + y";
   792 test "(i + j + -12 + (k::int)) - 15 = y";
   793 test "(i + j + 12 + (k::int)) - -15 = y";
   794 test "(i + j + -12 + (k::int)) - -15 = y";
   795 *)
   796 
   797 Addsimprocs Numeral_Simprocs.cancel_numeral_factors; (* cancel common numeric coefficients, e.g. 9*x = 12*y; examples below *)
   798 
   799 (*examples:
   800 print_depth 22;
   801 set timing;
   802 set simp_trace;
   803 fun test s = (Goal s; by (Simp_tac 1));
   804 
   805 test "9*x = 12 * (y::int)";
   806 test "(9*x) div (12 * (y::int)) = z";
   807 test "9*x < 12 * (y::int)";
   808 test "9*x <= 12 * (y::int)";
   809 
   810 test "-99*x = 132 * (y::int)";
   811 test "(-99*x) div (132 * (y::int)) = z";
   812 test "-99*x < 132 * (y::int)";
   813 test "-99*x <= 132 * (y::int)";
   814 
   815 test "999*x = -396 * (y::int)";
   816 test "(999*x) div (-396 * (y::int)) = z";
   817 test "999*x < -396 * (y::int)";
   818 test "999*x <= -396 * (y::int)";
   819 
   820 test "-99*x = -81 * (y::int)";
   821 test "(-99*x) div (-81 * (y::int)) = z";
   822 test "-99*x <= -81 * (y::int)";
   823 test "-99*x < -81 * (y::int)";
   824 
   825 test "-2 * x = -1 * (y::int)";
   826 test "-2 * x = -(y::int)";
   827 test "(-2 * x) div (-1 * (y::int)) = z";
   828 test "-2 * x < -(y::int)";
   829 test "-2 * x <= -1 * (y::int)";
   830 test "-x < -23 * (y::int)";
   831 test "-x <= -23 * (y::int)";
   832 *)
   833 
   834 (*And the same examples for fields such as rat or real:
   835 test "0 <= (y::rat) * -2";
   836 test "9*x = 12 * (y::rat)";
   837 test "(9*x) / (12 * (y::rat)) = z";
   838 test "9*x < 12 * (y::rat)";
   839 test "9*x <= 12 * (y::rat)";
   840 
   841 test "-99*x = 132 * (y::rat)";
   842 test "(-99*x) / (132 * (y::rat)) = z";
   843 test "-99*x < 132 * (y::rat)";
   844 test "-99*x <= 132 * (y::rat)";
   845 
   846 test "999*x = -396 * (y::rat)";
   847 test "(999*x) / (-396 * (y::rat)) = z";
   848 test "999*x < -396 * (y::rat)";
   849 test "999*x <= -396 * (y::rat)";
   850 
   851 test  "(- ((2::rat) * x) <= 2 * y)";
   852 test "-99*x = -81 * (y::rat)";
   853 test "(-99*x) / (-81 * (y::rat)) = z";
   854 test "-99*x <= -81 * (y::rat)";
   855 test "-99*x < -81 * (y::rat)";
   856 
   857 test "-2 * x = -1 * (y::rat)";
   858 test "-2 * x = -(y::rat)";
   859 test "(-2 * x) / (-1 * (y::rat)) = z";
   860 test "-2 * x < -(y::rat)";
   861 test "-2 * x <= -1 * (y::rat)";
   862 test "-x < -23 * (y::rat)";
   863 test "-x <= -23 * (y::rat)";
   864 *)
   865 
   866 Addsimprocs Numeral_Simprocs.cancel_factors; (* cancel common factors such as k in x*k = k*y; examples below *)
   867 
   868 
   869 (*examples:
   870 print_depth 22;
   871 set timing;
   872 set simp_trace;
   873 fun test s = (Goal s; by (Asm_simp_tac 1));
   874 
   875 test "x*k = k*(y::int)";
   876 test "k = k*(y::int)";
   877 test "a*(b*c) = (b::int)";
   878 test "a*(b*c) = d*(b::int)*(x*a)";
   879 
   880 test "(x*k) div (k*(y::int)) = (uu::int)";
   881 test "(k) div (k*(y::int)) = (uu::int)";
   882 test "(a*(b*c)) div ((b::int)) = (uu::int)";
   883 test "(a*(b*c)) div (d*(b::int)*(x*a)) = (uu::int)";
   884 *)
   885 
   886 (*And the same examples for fields such as rat or real:
   887 print_depth 22;
   888 set timing;
   889 set simp_trace;
   890 fun test s = (Goal s; by (Asm_simp_tac 1));
   891 
   892 test "x*k = k*(y::rat)";
   893 test "k = k*(y::rat)";
   894 test "a*(b*c) = (b::rat)";
   895 test "a*(b*c) = d*(b::rat)*(x*a)";
   896 
   897 
   898 test "(x*k) / (k*(y::rat)) = (uu::rat)";
   899 test "(k) / (k*(y::rat)) = (uu::rat)";
   900 test "(a*(b*c)) / ((b::rat)) = (uu::rat)";
   901 test "(a*(b*c)) / (d*(b::rat)*(x*a)) = (uu::rat)";
   902 
   903 (*FIXME: what do we do about this?*)
   904 test "a*(b*c)/(y*z) = d*(b::rat)*(x*a)/z";
   905 *)