src/HOL/Tools/numeral_simprocs.ML
author huffman
Fri Mar 30 12:32:35 2012 +0200 (2012-03-30)
changeset 47220 52426c62b5d0
parent 47108 2a1953f0d20d
child 49323 6dff6b1f5417
permissions -rw-r--r--
replace lemmas eval_nat_numeral with a simpler reformulation
     1 (* Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
     2    Copyright   2000  University of Cambridge
     3 
     4 Simprocs for the (integer) numerals.
     5 *)
     6 
     7 (*To quote from Provers/Arith/cancel_numeral_factor.ML:
     8 
     9 Cancels common coefficients in balanced expressions:
    10 
    11      u*#m ~~ u'*#m'  ==  #n*u ~~ #n'*u'
    12 
    13 where ~~ is an appropriate balancing operation (e.g. =, <=, <, div, /)
    14 and d = gcd(m,m') and n=m/d and n'=m'/d.
    15 *)
    16 
    17 signature NUMERAL_SIMPROCS =
    18 sig
    19   val prep_simproc: theory -> string * string list * (theory -> simpset -> term -> thm option)
    20     -> simproc
    21   val trans_tac: thm option -> tactic
    22   val assoc_fold: simpset -> cterm -> thm option
    23   val combine_numerals: simpset -> cterm -> thm option
    24   val eq_cancel_numerals: simpset -> cterm -> thm option
    25   val less_cancel_numerals: simpset -> cterm -> thm option
    26   val le_cancel_numerals: simpset -> cterm -> thm option
    27   val eq_cancel_factor: simpset -> cterm -> thm option
    28   val le_cancel_factor: simpset -> cterm -> thm option
    29   val less_cancel_factor: simpset -> cterm -> thm option
    30   val div_cancel_factor: simpset -> cterm -> thm option
    31   val mod_cancel_factor: simpset -> cterm -> thm option
    32   val dvd_cancel_factor: simpset -> cterm -> thm option
    33   val divide_cancel_factor: simpset -> cterm -> thm option
    34   val eq_cancel_numeral_factor: simpset -> cterm -> thm option
    35   val less_cancel_numeral_factor: simpset -> cterm -> thm option
    36   val le_cancel_numeral_factor: simpset -> cterm -> thm option
    37   val div_cancel_numeral_factor: simpset -> cterm -> thm option
    38   val divide_cancel_numeral_factor: simpset -> cterm -> thm option
    39   val field_combine_numerals: simpset -> cterm -> thm option
    40   val field_cancel_numeral_factors: simproc list
    41   val num_ss: simpset
    42   val field_comp_conv: conv
    43 end;
    44 
    45 structure Numeral_Simprocs : NUMERAL_SIMPROCS =
    46 struct
    47 
    48 fun prep_simproc thy (name, pats, proc) =
    49   Simplifier.simproc_global thy name pats proc;
    50 
    51 fun trans_tac NONE  = all_tac
    52   | trans_tac (SOME th) = ALLGOALS (rtac (th RS trans));
    53 
    54 val mk_number = Arith_Data.mk_number;
    55 val mk_sum = Arith_Data.mk_sum;
    56 val long_mk_sum = Arith_Data.long_mk_sum;
    57 val dest_sum = Arith_Data.dest_sum;
    58 
    59 val mk_diff = HOLogic.mk_binop @{const_name Groups.minus};
    60 val dest_diff = HOLogic.dest_bin @{const_name Groups.minus} Term.dummyT;
    61 
    62 val mk_times = HOLogic.mk_binop @{const_name Groups.times};
    63 
    64 fun one_of T = Const(@{const_name Groups.one}, T);
    65 
    66 (* build product with trailing 1 rather than Numeral 1 in order to avoid the
    67    unnecessary restriction to type class number_ring
    68    which is not required for cancellation of common factors in divisions.
    69    UPDATE: this reasoning no longer applies (number_ring is gone)
    70 *)
    71 fun mk_prod T = 
    72   let val one = one_of T
    73   fun mk [] = one
    74     | mk [t] = t
    75     | mk (t :: ts) = if t = one then mk ts else mk_times (t, mk ts)
    76   in mk end;
    77 
    78 (*This version ALWAYS includes a trailing one*)
    79 fun long_mk_prod T []        = one_of T
    80   | long_mk_prod T (t :: ts) = mk_times (t, mk_prod T ts);
    81 
    82 val dest_times = HOLogic.dest_bin @{const_name Groups.times} Term.dummyT;
    83 
    84 fun dest_prod t =
    85       let val (t,u) = dest_times t
    86       in dest_prod t @ dest_prod u end
    87       handle TERM _ => [t];
    88 
    89 fun find_first_numeral past (t::terms) =
    90         ((snd (HOLogic.dest_number t), rev past @ terms)
    91          handle TERM _ => find_first_numeral (t::past) terms)
    92   | find_first_numeral past [] = raise TERM("find_first_numeral", []);
    93 
    94 (*DON'T do the obvious simplifications; that would create special cases*)
    95 fun mk_coeff (k, t) = mk_times (mk_number (Term.fastype_of t) k, t);
    96 
    97 (*Express t as a product of (possibly) a numeral with other sorted terms*)
    98 fun dest_coeff sign (Const (@{const_name Groups.uminus}, _) $ t) = dest_coeff (~sign) t
    99   | dest_coeff sign t =
   100     let val ts = sort Term_Ord.term_ord (dest_prod t)
   101         val (n, ts') = find_first_numeral [] ts
   102                           handle TERM _ => (1, ts)
   103     in (sign*n, mk_prod (Term.fastype_of t) ts') end;
   104 
   105 (*Find first coefficient-term THAT MATCHES u*)
   106 fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
   107   | find_first_coeff past u (t::terms) =
   108         let val (n,u') = dest_coeff 1 t
   109         in if u aconv u' then (n, rev past @ terms)
   110                          else find_first_coeff (t::past) u terms
   111         end
   112         handle TERM _ => find_first_coeff (t::past) u terms;
   113 
   114 (*Fractions as pairs of ints. Can't use Rat.rat because the representation
   115   needs to preserve negative values in the denominator.*)
   116 fun mk_frac (p, q) = if q = 0 then raise Div else (p, q);
   117 
   118 (*Don't reduce fractions; sums must be proved by rule add_frac_eq.
   119   Fractions are reduced later by the cancel_numeral_factor simproc.*)
   120 fun add_frac ((p1, q1), (p2, q2)) = (p1 * q2 + p2 * q1, q1 * q2);
   121 
   122 val mk_divide = HOLogic.mk_binop @{const_name Fields.divide};
   123 
   124 (*Build term (p / q) * t*)
   125 fun mk_fcoeff ((p, q), t) =
   126   let val T = Term.fastype_of t
   127   in mk_times (mk_divide (mk_number T p, mk_number T q), t) end;
   128 
   129 (*Express t as a product of a fraction with other sorted terms*)
   130 fun dest_fcoeff sign (Const (@{const_name Groups.uminus}, _) $ t) = dest_fcoeff (~sign) t
   131   | dest_fcoeff sign (Const (@{const_name Fields.divide}, _) $ t $ u) =
   132     let val (p, t') = dest_coeff sign t
   133         val (q, u') = dest_coeff 1 u
   134     in (mk_frac (p, q), mk_divide (t', u')) end
   135   | dest_fcoeff sign t =
   136     let val (p, t') = dest_coeff sign t
   137         val T = Term.fastype_of t
   138     in (mk_frac (p, 1), mk_divide (t', one_of T)) end;
   139 
   140 
   141 (** New term ordering so that AC-rewriting brings numerals to the front **)
   142 
   143 (*Order integers by absolute value and then by sign. The standard integer
   144   ordering is not well-founded.*)
   145 fun num_ord (i,j) =
   146   (case int_ord (abs i, abs j) of
   147     EQUAL => int_ord (Int.sign i, Int.sign j) 
   148   | ord => ord);
   149 
(*Term order used (via numtermless) as the termless of num_ss.  It resembles
  Term_Ord.term_ord, but every term accepted by HOLogic.dest_number sorts
  before every other term, so that AC-rewriting brings numerals to the front.*)
local open Term
in
fun numterm_ord (t, u) =
    case (try HOLogic.dest_number t, try HOLogic.dest_number u) of
      (*two numerals: by magnitude, then by sign (num_ord)*)
      (SOME (_, i), SOME (_, j)) => num_ord (i, j)
      (*a numeral precedes any non-numeral*)
    | (SOME _, NONE) => LESS
    | (NONE, SOME _) => GREATER
    | _ => (
      case (t, u) of
        (*two abstractions: compare bodies, then the bound variables' types*)
        (Abs (_, T, t), Abs(_, U, u)) =>
        (prod_ord numterm_ord Term_Ord.typ_ord ((t, T), (u, U)))
      | _ => (
        (*otherwise the smaller term (size_of_term) comes first; on equal
          size compare heads with Term_Ord.hd_ord, then the argument lists*)
        case int_ord (size_of_term t, size_of_term u) of
          EQUAL =>
          let val (f, ts) = strip_comb t and (g, us) = strip_comb u in
            (prod_ord Term_Ord.hd_ord numterms_ord ((f, ts), (g, us)))
          end
        | ord => ord))
(*extension of numterm_ord to argument lists, via list_ord*)
and numterms_ord (ts, us) = list_ord numterm_ord (ts, us)
end;
   172 
   173 fun numtermless tu = (numterm_ord tu = LESS);
   174 
   175 val num_ss = HOL_basic_ss |> Simplifier.set_termless numtermless;
   176 
   177 (*Maps 1 to Numeral1 so that arithmetic isn't complicated by the abstract 1.*)
   178 val numeral_syms = [@{thm numeral_1_eq_1} RS sym];
   179 
   180 (*Simplify 0+n, n+0, Numeral1*n, n*Numeral1, 1*x, x*1, x/1 *)
   181 val add_0s =  @{thms add_0s};
   182 val mult_1s = @{thms mult_1s mult_1_left mult_1_right divide_1};
   183 
   184 (* For post-simplification of the rhs of simproc-generated rules *)
   185 val post_simps =
   186     [@{thm numeral_1_eq_1},
   187      @{thm add_0_left}, @{thm add_0_right},
   188      @{thm mult_zero_left}, @{thm mult_zero_right},
   189      @{thm mult_1_left}, @{thm mult_1_right},
   190      @{thm mult_minus1}, @{thm mult_minus1_right}]
   191 
   192 val field_post_simps =
   193     post_simps @ [@{thm divide_zero_left}, @{thm divide_1}]
   194                       
   195 (*Simplify inverse Numeral1, a/Numeral1*)
   196 val inverse_1s = [@{thm inverse_numeral_1}];
   197 val divide_1s = [@{thm divide_numeral_1}];
   198 
   199 (*To perform binary arithmetic.  The "left" rewriting handles patterns
   200   created by the Numeral_Simprocs, such as 3 * (5 * x). *)
   201 val simps =
   202     [@{thm numeral_1_eq_1} RS sym] @
   203     @{thms add_numeral_left} @
   204     @{thms add_neg_numeral_left} @
   205     @{thms mult_numeral_left} @
   206     @{thms arith_simps} @ @{thms rel_simps};
   207     
   208 (*Binary arithmetic BUT NOT ADDITION since it may collapse adjacent terms
   209   during re-arrangement*)
   210 val non_add_simps =
   211   subtract Thm.eq_thm
   212     (@{thms add_numeral_left} @
   213      @{thms add_neg_numeral_left} @
   214      @{thms numeral_plus_numeral} @
   215      @{thms add_neg_numeral_simps}) simps;
   216 
   217 (*To evaluate binary negations of coefficients*)
   218 val minus_simps = [@{thm minus_zero}, @{thm minus_one}, @{thm minus_numeral}, @{thm minus_neg_numeral}];
   219 
   220 (*To let us treat subtraction as addition*)
   221 val diff_simps = [@{thm diff_minus}, @{thm minus_add_distrib}, @{thm minus_minus}];
   222 
   223 (*To let us treat division as multiplication*)
   224 val divide_simps = [@{thm divide_inverse}, @{thm inverse_mult_distrib}, @{thm inverse_inverse_eq}];
   225 
   226 (*push the unary minus down*)
   227 val minus_mult_eq_1_to_2 = @{lemma "- (a::'a::ring) * b = a * - b" by simp};
   228 
   229 (*to extract again any uncancelled minuses*)
   230 val minus_from_mult_simps =
   231     [@{thm minus_minus}, @{thm mult_minus_left}, @{thm mult_minus_right}];
   232 
   233 (*combine unary minus with numeric literals, however nested within a product*)
   234 val mult_minus_simps =
   235     [@{thm mult_assoc}, @{thm minus_mult_left}, minus_mult_eq_1_to_2];
   236 
   237 val norm_ss1 = num_ss addsimps numeral_syms @ add_0s @ mult_1s @
   238   diff_simps @ minus_simps @ @{thms add_ac}
   239 val norm_ss2 = num_ss addsimps non_add_simps @ mult_minus_simps
   240 val norm_ss3 = num_ss addsimps minus_from_mult_simps @ @{thms add_ac} @ @{thms mult_ac}
   241 
   242 structure CancelNumeralsCommon =
   243 struct
   244   val mk_sum = mk_sum
   245   val dest_sum = dest_sum
   246   val mk_coeff = mk_coeff
   247   val dest_coeff = dest_coeff 1
   248   val find_first_coeff = find_first_coeff []
   249   val trans_tac = trans_tac
   250 
   251   fun norm_tac ss =
   252     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   253     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   254     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   255 
   256   val numeral_simp_ss = HOL_basic_ss addsimps add_0s @ simps
   257   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   258   val simplify_meta_eq = Arith_Data.simplify_meta_eq post_simps
   259   val prove_conv = Arith_Data.prove_conv
   260 end;
   261 
   262 structure EqCancelNumerals = CancelNumeralsFun
   263  (open CancelNumeralsCommon
   264   val mk_bal   = HOLogic.mk_eq
   265   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   266   val bal_add1 = @{thm eq_add_iff1} RS trans
   267   val bal_add2 = @{thm eq_add_iff2} RS trans
   268 );
   269 
   270 structure LessCancelNumerals = CancelNumeralsFun
   271  (open CancelNumeralsCommon
   272   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   273   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   274   val bal_add1 = @{thm less_add_iff1} RS trans
   275   val bal_add2 = @{thm less_add_iff2} RS trans
   276 );
   277 
   278 structure LeCancelNumerals = CancelNumeralsFun
   279  (open CancelNumeralsCommon
   280   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   281   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   282   val bal_add1 = @{thm le_add_iff1} RS trans
   283   val bal_add2 = @{thm le_add_iff2} RS trans
   284 );
   285 
   286 fun eq_cancel_numerals ss ct = EqCancelNumerals.proc ss (term_of ct)
   287 fun less_cancel_numerals ss ct = LessCancelNumerals.proc ss (term_of ct)
   288 fun le_cancel_numerals ss ct = LeCancelNumerals.proc ss (term_of ct)
   289 
   290 structure CombineNumeralsData =
   291 struct
   292   type coeff = int
   293   val iszero = (fn x => x = 0)
   294   val add  = op +
   295   val mk_sum = long_mk_sum    (*to work for e.g. 2*x + 3*x *)
   296   val dest_sum = dest_sum
   297   val mk_coeff = mk_coeff
   298   val dest_coeff = dest_coeff 1
   299   val left_distrib = @{thm combine_common_factor} RS trans
   300   val prove_conv = Arith_Data.prove_conv_nohyps
   301   val trans_tac = trans_tac
   302 
   303   fun norm_tac ss =
   304     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   305     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   306     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   307 
   308   val numeral_simp_ss = HOL_basic_ss addsimps add_0s @ simps
   309   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   310   val simplify_meta_eq = Arith_Data.simplify_meta_eq post_simps
   311 end;
   312 
   313 structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);
   314 
   315 (*Version for fields, where coefficients can be fractions*)
   316 structure FieldCombineNumeralsData =
   317 struct
   318   type coeff = int * int
   319   val iszero = (fn (p, q) => p = 0)
   320   val add = add_frac
   321   val mk_sum = long_mk_sum
   322   val dest_sum = dest_sum
   323   val mk_coeff = mk_fcoeff
   324   val dest_coeff = dest_fcoeff 1
   325   val left_distrib = @{thm combine_common_factor} RS trans
   326   val prove_conv = Arith_Data.prove_conv_nohyps
   327   val trans_tac = trans_tac
   328 
   329   val norm_ss1a = norm_ss1 addsimps inverse_1s @ divide_simps
   330   fun norm_tac ss =
   331     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1a))
   332     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   333     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   334 
   335   val numeral_simp_ss = HOL_basic_ss addsimps add_0s @ simps @ [@{thm add_frac_eq}, @{thm not_False_eq_True}]
   336   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   337   val simplify_meta_eq = Arith_Data.simplify_meta_eq field_post_simps
   338 end;
   339 
   340 structure FieldCombineNumerals = CombineNumeralsFun(FieldCombineNumeralsData);
   341 
   342 fun combine_numerals ss ct = CombineNumerals.proc ss (term_of ct)
   343 
   344 fun field_combine_numerals ss ct = FieldCombineNumerals.proc ss (term_of ct)
   345 
   346 (** Constant folding for multiplication in semirings **)
   347 
   348 (*We do not need folding for addition: combine_numerals does the same thing*)
   349 
   350 structure Semiring_Times_Assoc_Data : ASSOC_FOLD_DATA =
   351 struct
   352   val assoc_ss = HOL_basic_ss addsimps @{thms mult_ac}
   353   val eq_reflection = eq_reflection
   354   val is_numeral = can HOLogic.dest_number
   355 end;
   356 
   357 structure Semiring_Times_Assoc = Assoc_Fold (Semiring_Times_Assoc_Data);
   358 
   359 fun assoc_fold ss ct = Semiring_Times_Assoc.proc ss (term_of ct)
   360 
   361 structure CancelNumeralFactorCommon =
   362 struct
   363   val mk_coeff = mk_coeff
   364   val dest_coeff = dest_coeff 1
   365   val trans_tac = trans_tac
   366 
   367   val norm_ss1 = HOL_basic_ss addsimps minus_from_mult_simps @ mult_1s
   368   val norm_ss2 = HOL_basic_ss addsimps simps @ mult_minus_simps
   369   val norm_ss3 = HOL_basic_ss addsimps @{thms mult_ac}
   370   fun norm_tac ss =
   371     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   372     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   373     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss3))
   374 
   375   (* simp_thms are necessary because some of the cancellation rules below
   376   (e.g. mult_less_cancel_left) introduce various logical connectives *)
   377   val numeral_simp_ss = HOL_basic_ss addsimps simps @ @{thms simp_thms}
   378   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   379   val simplify_meta_eq = Arith_Data.simplify_meta_eq
   380     ([@{thm Nat.add_0}, @{thm Nat.add_0_right}] @ post_simps)
   381   val prove_conv = Arith_Data.prove_conv
   382 end
   383 
   384 (*Version for semiring_div*)
   385 structure DivCancelNumeralFactor = CancelNumeralFactorFun
   386  (open CancelNumeralFactorCommon
   387   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   388   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   389   val cancel = @{thm div_mult_mult1} RS trans
   390   val neg_exchanges = false
   391 )
   392 
   393 (*Version for fields*)
   394 structure DivideCancelNumeralFactor = CancelNumeralFactorFun
   395  (open CancelNumeralFactorCommon
   396   val mk_bal   = HOLogic.mk_binop @{const_name Fields.divide}
   397   val dest_bal = HOLogic.dest_bin @{const_name Fields.divide} Term.dummyT
   398   val cancel = @{thm mult_divide_mult_cancel_left} RS trans
   399   val neg_exchanges = false
   400 )
   401 
   402 structure EqCancelNumeralFactor = CancelNumeralFactorFun
   403  (open CancelNumeralFactorCommon
   404   val mk_bal   = HOLogic.mk_eq
   405   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   406   val cancel = @{thm mult_cancel_left} RS trans
   407   val neg_exchanges = false
   408 )
   409 
   410 structure LessCancelNumeralFactor = CancelNumeralFactorFun
   411  (open CancelNumeralFactorCommon
   412   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   413   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   414   val cancel = @{thm mult_less_cancel_left} RS trans
   415   val neg_exchanges = true
   416 )
   417 
   418 structure LeCancelNumeralFactor = CancelNumeralFactorFun
   419  (open CancelNumeralFactorCommon
   420   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   421   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   422   val cancel = @{thm mult_le_cancel_left} RS trans
   423   val neg_exchanges = true
   424 )
   425 
   426 fun eq_cancel_numeral_factor ss ct = EqCancelNumeralFactor.proc ss (term_of ct)
   427 fun less_cancel_numeral_factor ss ct = LessCancelNumeralFactor.proc ss (term_of ct)
   428 fun le_cancel_numeral_factor ss ct = LeCancelNumeralFactor.proc ss (term_of ct)
   429 fun div_cancel_numeral_factor ss ct = DivCancelNumeralFactor.proc ss (term_of ct)
   430 fun divide_cancel_numeral_factor ss ct = DivideCancelNumeralFactor.proc ss (term_of ct)
   431 
   432 val field_cancel_numeral_factors =
   433   map (prep_simproc @{theory})
   434    [("field_eq_cancel_numeral_factor",
   435      ["(l::'a::field) * m = n",
   436       "(l::'a::field) = m * n"],
   437      K EqCancelNumeralFactor.proc),
   438     ("field_cancel_numeral_factor",
   439      ["((l::'a::field_inverse_zero) * m) / n",
   440       "(l::'a::field_inverse_zero) / (m * n)",
   441       "((numeral v)::'a::field_inverse_zero) / (numeral w)",
   442       "((numeral v)::'a::field_inverse_zero) / (neg_numeral w)",
   443       "((neg_numeral v)::'a::field_inverse_zero) / (numeral w)",
   444       "((neg_numeral v)::'a::field_inverse_zero) / (neg_numeral w)"],
   445      K DivideCancelNumeralFactor.proc)]
   446 
   447 
   448 (** Declarations for ExtractCommonTerm **)
   449 
   450 (*Find first term that matches u*)
   451 fun find_first_t past u []         = raise TERM ("find_first_t", [])
   452   | find_first_t past u (t::terms) =
   453         if u aconv t then (rev past @ terms)
   454         else find_first_t (t::past) u terms
   455         handle TERM _ => find_first_t (t::past) u terms;
   456 
   457 (** Final simplification for the CancelFactor simprocs **)
   458 val simplify_one = Arith_Data.simplify_meta_eq  
   459   [@{thm mult_1_left}, @{thm mult_1_right}, @{thm div_by_1}, @{thm numeral_1_eq_1}];
   460 
   461 fun cancel_simplify_meta_eq ss cancel_th th =
   462     simplify_one ss (([th, cancel_th]) MRS trans);
   463 
   464 local
   465   val Tp_Eq = Thm.reflexive (Thm.cterm_of @{theory HOL} HOLogic.Trueprop)
   466   fun Eq_True_elim Eq = 
   467     Thm.equal_elim (Thm.combination Tp_Eq (Thm.symmetric Eq)) @{thm TrueI}
   468 in
   469 fun sign_conv pos_th neg_th ss t =
   470   let val T = fastype_of t;
   471       val zero = Const(@{const_name Groups.zero}, T);
   472       val less = Const(@{const_name Orderings.less}, [T,T] ---> HOLogic.boolT);
   473       val pos = less $ zero $ t and neg = less $ t $ zero
   474       val thy = Proof_Context.theory_of (Simplifier.the_context ss)
   475       fun prove p =
   476         SOME (Eq_True_elim (Simplifier.asm_rewrite ss (Thm.cterm_of thy p)))
   477         handle THM _ => NONE
   478     in case prove pos of
   479          SOME th => SOME(th RS pos_th)
   480        | NONE => (case prove neg of
   481                     SOME th => SOME(th RS neg_th)
   482                   | NONE => NONE)
   483     end;
   484 end
   485 
   486 structure CancelFactorCommon =
   487 struct
   488   val mk_sum = long_mk_prod
   489   val dest_sum = dest_prod
   490   val mk_coeff = mk_coeff
   491   val dest_coeff = dest_coeff
   492   val find_first = find_first_t []
   493   val trans_tac = trans_tac
   494   val norm_ss = HOL_basic_ss addsimps mult_1s @ @{thms mult_ac}
   495   fun norm_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss))
   496   val simplify_meta_eq  = cancel_simplify_meta_eq 
   497   fun mk_eq (a, b) = HOLogic.mk_Trueprop (HOLogic.mk_eq (a, b))
   498 end;
   499 
   500 (*mult_cancel_left requires a ring with no zero divisors.*)
   501 structure EqCancelFactor = ExtractCommonTermFun
   502  (open CancelFactorCommon
   503   val mk_bal   = HOLogic.mk_eq
   504   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} Term.dummyT
   505   fun simp_conv _ _ = SOME @{thm mult_cancel_left}
   506 );
   507 
   508 (*for ordered rings*)
   509 structure LeCancelFactor = ExtractCommonTermFun
   510  (open CancelFactorCommon
   511   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   512   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} Term.dummyT
   513   val simp_conv = sign_conv
   514     @{thm mult_le_cancel_left_pos} @{thm mult_le_cancel_left_neg}
   515 );
   516 
   517 (*for ordered rings*)
   518 structure LessCancelFactor = ExtractCommonTermFun
   519  (open CancelFactorCommon
   520   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   521   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} Term.dummyT
   522   val simp_conv = sign_conv
   523     @{thm mult_less_cancel_left_pos} @{thm mult_less_cancel_left_neg}
   524 );
   525 
   526 (*for semirings with division*)
   527 structure DivCancelFactor = ExtractCommonTermFun
   528  (open CancelFactorCommon
   529   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   530   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} Term.dummyT
   531   fun simp_conv _ _ = SOME @{thm div_mult_mult1_if}
   532 );
   533 
   534 structure ModCancelFactor = ExtractCommonTermFun
   535  (open CancelFactorCommon
   536   val mk_bal   = HOLogic.mk_binop @{const_name Divides.mod}
   537   val dest_bal = HOLogic.dest_bin @{const_name Divides.mod} Term.dummyT
   538   fun simp_conv _ _ = SOME @{thm mod_mult_mult1}
   539 );
   540 
   541 (*for idoms*)
   542 structure DvdCancelFactor = ExtractCommonTermFun
   543  (open CancelFactorCommon
   544   val mk_bal   = HOLogic.mk_binrel @{const_name Rings.dvd}
   545   val dest_bal = HOLogic.dest_bin @{const_name Rings.dvd} Term.dummyT
   546   fun simp_conv _ _ = SOME @{thm dvd_mult_cancel_left}
   547 );
   548 
   549 (*Version for all fields, including unordered ones (type complex).*)
   550 structure DivideCancelFactor = ExtractCommonTermFun
   551  (open CancelFactorCommon
   552   val mk_bal   = HOLogic.mk_binop @{const_name Fields.divide}
   553   val dest_bal = HOLogic.dest_bin @{const_name Fields.divide} Term.dummyT
   554   fun simp_conv _ _ = SOME @{thm mult_divide_mult_cancel_left_if}
   555 );
   556 
   557 fun eq_cancel_factor ss ct = EqCancelFactor.proc ss (term_of ct)
   558 fun le_cancel_factor ss ct = LeCancelFactor.proc ss (term_of ct)
   559 fun less_cancel_factor ss ct = LessCancelFactor.proc ss (term_of ct)
   560 fun div_cancel_factor ss ct = DivCancelFactor.proc ss (term_of ct)
   561 fun mod_cancel_factor ss ct = ModCancelFactor.proc ss (term_of ct)
   562 fun dvd_cancel_factor ss ct = DvdCancelFactor.proc ss (term_of ct)
   563 fun divide_cancel_factor ss ct = DivideCancelFactor.proc ss (term_of ct)
   564 
   565 local
   566  val zr = @{cpat "0"}
   567  val zT = ctyp_of_term zr
   568  val geq = @{cpat HOL.eq}
   569  val eqT = Thm.dest_ctyp (ctyp_of_term geq) |> hd
   570  val add_frac_eq = mk_meta_eq @{thm "add_frac_eq"}
   571  val add_frac_num = mk_meta_eq @{thm "add_frac_num"}
   572  val add_num_frac = mk_meta_eq @{thm "add_num_frac"}
   573 
   574  fun prove_nz ss T t =
   575     let
   576       val z = Thm.instantiate_cterm ([(zT,T)],[]) zr
   577       val eq = Thm.instantiate_cterm ([(eqT,T)],[]) geq
   578       val th = Simplifier.rewrite (ss addsimps @{thms simp_thms})
   579            (Thm.apply @{cterm "Trueprop"} (Thm.apply @{cterm "Not"}
   580                   (Thm.apply (Thm.apply eq t) z)))
   581     in Thm.equal_elim (Thm.symmetric th) TrueI
   582     end
   583 
   584  fun proc phi ss ct =
   585   let
   586     val ((x,y),(w,z)) =
   587          (Thm.dest_binop #> (fn (a,b) => (Thm.dest_binop a, Thm.dest_binop b))) ct
   588     val _ = map (HOLogic.dest_number o term_of) [x,y,z,w]
   589     val T = ctyp_of_term x
   590     val [y_nz, z_nz] = map (prove_nz ss T) [y, z]
   591     val th = instantiate' [SOME T] (map SOME [y,z,x,w]) add_frac_eq
   592   in SOME (Thm.implies_elim (Thm.implies_elim th y_nz) z_nz)
   593   end
   594   handle CTERM _ => NONE | TERM _ => NONE | THM _ => NONE
   595 
   596  fun proc2 phi ss ct =
   597   let
   598     val (l,r) = Thm.dest_binop ct
   599     val T = ctyp_of_term l
   600   in (case (term_of l, term_of r) of
   601       (Const(@{const_name Fields.divide},_)$_$_, _) =>
   602         let val (x,y) = Thm.dest_binop l val z = r
   603             val _ = map (HOLogic.dest_number o term_of) [x,y,z]
   604             val ynz = prove_nz ss T y
   605         in SOME (Thm.implies_elim (instantiate' [SOME T] (map SOME [y,x,z]) add_frac_num) ynz)
   606         end
   607      | (_, Const (@{const_name Fields.divide},_)$_$_) =>
   608         let val (x,y) = Thm.dest_binop r val z = l
   609             val _ = map (HOLogic.dest_number o term_of) [x,y,z]
   610             val ynz = prove_nz ss T y
   611         in SOME (Thm.implies_elim (instantiate' [SOME T] (map SOME [y,z,x]) add_num_frac) ynz)
   612         end
   613      | _ => NONE)
   614   end
   615   handle CTERM _ => NONE | TERM _ => NONE | THM _ => NONE
   616 
   617  fun is_number (Const(@{const_name Fields.divide},_)$a$b) = is_number a andalso is_number b
   618    | is_number t = can HOLogic.dest_number t
   619 
   620  val is_number = is_number o term_of
   621 
   622  fun proc3 phi ss ct =
   623   (case term_of ct of
   624     Const(@{const_name Orderings.less},_)$(Const(@{const_name Fields.divide},_)$_$_)$_ =>
   625       let
   626         val ((a,b),c) = Thm.dest_binop ct |>> Thm.dest_binop
   627         val _ = map is_number [a,b,c]
   628         val T = ctyp_of_term c
   629         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "divide_less_eq"}
   630       in SOME (mk_meta_eq th) end
   631   | Const(@{const_name Orderings.less_eq},_)$(Const(@{const_name Fields.divide},_)$_$_)$_ =>
   632       let
   633         val ((a,b),c) = Thm.dest_binop ct |>> Thm.dest_binop
   634         val _ = map is_number [a,b,c]
   635         val T = ctyp_of_term c
   636         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "divide_le_eq"}
   637       in SOME (mk_meta_eq th) end
   638   | Const(@{const_name HOL.eq},_)$(Const(@{const_name Fields.divide},_)$_$_)$_ =>
   639       let
   640         val ((a,b),c) = Thm.dest_binop ct |>> Thm.dest_binop
   641         val _ = map is_number [a,b,c]
   642         val T = ctyp_of_term c
   643         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "divide_eq_eq"}
   644       in SOME (mk_meta_eq th) end
   645   | Const(@{const_name Orderings.less},_)$_$(Const(@{const_name Fields.divide},_)$_$_) =>
   646     let
   647       val (a,(b,c)) = Thm.dest_binop ct ||> Thm.dest_binop
   648         val _ = map is_number [a,b,c]
   649         val T = ctyp_of_term c
   650         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "less_divide_eq"}
   651       in SOME (mk_meta_eq th) end
   652   | Const(@{const_name Orderings.less_eq},_)$_$(Const(@{const_name Fields.divide},_)$_$_) =>
   653     let
   654       val (a,(b,c)) = Thm.dest_binop ct ||> Thm.dest_binop
   655         val _ = map is_number [a,b,c]
   656         val T = ctyp_of_term c
   657         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "le_divide_eq"}
   658       in SOME (mk_meta_eq th) end
   659   | Const(@{const_name HOL.eq},_)$_$(Const(@{const_name Fields.divide},_)$_$_) =>
   660     let
   661       val (a,(b,c)) = Thm.dest_binop ct ||> Thm.dest_binop
   662         val _ = map is_number [a,b,c]
   663         val T = ctyp_of_term c
   664         val th = instantiate' [SOME T] (map SOME [a,b,c]) @{thm "eq_divide_eq"}
   665       in SOME (mk_meta_eq th) end
   666   | _ => NONE)
   667   handle TERM _ => NONE | CTERM _ => NONE | THM _ => NONE
   668 
   669 val add_frac_frac_simproc =
   670        make_simproc {lhss = [@{cpat "(?x::?'a::field)/?y + (?w::?'a::field)/?z"}],
   671                      name = "add_frac_frac_simproc",
   672                      proc = proc, identifier = []}
   673 
   674 val add_frac_num_simproc =
   675        make_simproc {lhss = [@{cpat "(?x::?'a::field)/?y + ?z"}, @{cpat "?z + (?x::?'a::field)/?y"}],
   676                      name = "add_frac_num_simproc",
   677                      proc = proc2, identifier = []}
   678 
   679 val ord_frac_simproc =
   680   make_simproc
   681     {lhss = [@{cpat "(?a::(?'a::{field, ord}))/?b < ?c"},
   682              @{cpat "(?a::(?'a::{field, ord}))/?b <= ?c"},
   683              @{cpat "?c < (?a::(?'a::{field, ord}))/?b"},
   684              @{cpat "?c <= (?a::(?'a::{field, ord}))/?b"},
   685              @{cpat "?c = ((?a::(?'a::{field, ord}))/?b)"},
   686              @{cpat "((?a::(?'a::{field, ord}))/ ?b) = ?c"}],
   687              name = "ord_frac_simproc", proc = proc3, identifier = []}
   688 
   689 val ths = [@{thm "mult_numeral_1"}, @{thm "mult_numeral_1_right"},
   690            @{thm "divide_Numeral1"},
   691            @{thm "divide_zero"}, @{thm divide_zero_left},
   692            @{thm "divide_divide_eq_left"}, 
   693            @{thm "times_divide_eq_left"}, @{thm "times_divide_eq_right"},
   694            @{thm "times_divide_times_eq"},
   695            @{thm "divide_divide_eq_right"},
   696            @{thm "diff_minus"}, @{thm "minus_divide_left"},
   697            @{thm "add_divide_distrib"} RS sym,
   698            @{thm field_divide_inverse} RS sym, @{thm inverse_divide}, 
   699            Conv.fconv_rule (Conv.arg_conv (Conv.arg1_conv (Conv.rewr_conv (mk_meta_eq @{thm mult_commute}))))   
   700            (@{thm field_divide_inverse} RS sym)]
   701 
   702 in
   703 
   704 val field_comp_conv =
   705   Simplifier.rewrite
   706     (HOL_basic_ss addsimps @{thms "semiring_norm"}
   707       addsimps ths addsimps @{thms simp_thms}
   708       addsimprocs field_cancel_numeral_factors
   709       addsimprocs [add_frac_frac_simproc, add_frac_num_simproc, ord_frac_simproc]
   710       |> Simplifier.add_cong @{thm "if_weak_cong"})
   711   then_conv
   712   Simplifier.rewrite (HOL_basic_ss addsimps [@{thm numeral_1_eq_1}])
   713 
   714 end
   715 
   716 end;
   717 
   718 (*examples:
   719 print_depth 22;
   720 set timing;
   721 set simp_trace;
   722 fun test s = (Goal s; by (Simp_tac 1));
   723 
   724 test "l + 2 + 2 + 2 + (l + 2) + (oo + 2) = (uu::int)";
   725 
   726 test "2*u = (u::int)";
   727 test "(i + j + 12 + (k::int)) - 15 = y";
   728 test "(i + j + 12 + (k::int)) - 5 = y";
   729 
   730 test "y - b < (b::int)";
   731 test "y - (3*b + c) < (b::int) - 2*c";
   732 
   733 test "(2*x - (u*v) + y) - v*3*u = (w::int)";
   734 test "(2*x*u*v + (u*v)*4 + y) - v*u*4 = (w::int)";
   735 test "(2*x*u*v + (u*v)*4 + y) - v*u = (w::int)";
   736 test "u*v - (x*u*v + (u*v)*4 + y) = (w::int)";
   737 
   738 test "(i + j + 12 + (k::int)) = u + 15 + y";
   739 test "(i + j*2 + 12 + (k::int)) = j + 5 + y";
   740 
   741 test "2*y + 3*z + 6*w + 2*y + 3*z + 2*u = 2*y' + 3*z' + 6*w' + 2*y' + 3*z' + u + (vv::int)";
   742 
   743 test "a + -(b+c) + b = (d::int)";
   744 test "a + -(b+c) - b = (d::int)";
   745 
   746 (*negative numerals*)
   747 test "(i + j + -2 + (k::int)) - (u + 5 + y) = zz";
   748 test "(i + j + -3 + (k::int)) < u + 5 + y";
   749 test "(i + j + 3 + (k::int)) < u + -6 + y";
   750 test "(i + j + -12 + (k::int)) - 15 = y";
   751 test "(i + j + 12 + (k::int)) - -15 = y";
   752 test "(i + j + -12 + (k::int)) - -15 = y";
   753 *)
   754 
   755 (*examples:
   756 print_depth 22;
   757 set timing;
   758 set simp_trace;
   759 fun test s = (Goal s; by (Simp_tac 1));
   760 
   761 test "9*x = 12 * (y::int)";
   762 test "(9*x) div (12 * (y::int)) = z";
   763 test "9*x < 12 * (y::int)";
   764 test "9*x <= 12 * (y::int)";
   765 
   766 test "-99*x = 132 * (y::int)";
   767 test "(-99*x) div (132 * (y::int)) = z";
   768 test "-99*x < 132 * (y::int)";
   769 test "-99*x <= 132 * (y::int)";
   770 
   771 test "999*x = -396 * (y::int)";
   772 test "(999*x) div (-396 * (y::int)) = z";
   773 test "999*x < -396 * (y::int)";
   774 test "999*x <= -396 * (y::int)";
   775 
   776 test "-99*x = -81 * (y::int)";
   777 test "(-99*x) div (-81 * (y::int)) = z";
   778 test "-99*x <= -81 * (y::int)";
   779 test "-99*x < -81 * (y::int)";
   780 
   781 test "-2 * x = -1 * (y::int)";
   782 test "-2 * x = -(y::int)";
   783 test "(-2 * x) div (-1 * (y::int)) = z";
   784 test "-2 * x < -(y::int)";
   785 test "-2 * x <= -1 * (y::int)";
   786 test "-x < -23 * (y::int)";
   787 test "-x <= -23 * (y::int)";
   788 *)
   789 
   790 (*And the same examples for fields such as rat or real:
   791 test "0 <= (y::rat) * -2";
   792 test "9*x = 12 * (y::rat)";
   793 test "(9*x) / (12 * (y::rat)) = z";
   794 test "9*x < 12 * (y::rat)";
   795 test "9*x <= 12 * (y::rat)";
   796 
   797 test "-99*x = 132 * (y::rat)";
   798 test "(-99*x) / (132 * (y::rat)) = z";
   799 test "-99*x < 132 * (y::rat)";
   800 test "-99*x <= 132 * (y::rat)";
   801 
   802 test "999*x = -396 * (y::rat)";
   803 test "(999*x) / (-396 * (y::rat)) = z";
   804 test "999*x < -396 * (y::rat)";
   805 test "999*x <= -396 * (y::rat)";
   806 
   807 test  "(- ((2::rat) * x) <= 2 * y)";
   808 test "-99*x = -81 * (y::rat)";
   809 test "(-99*x) / (-81 * (y::rat)) = z";
   810 test "-99*x <= -81 * (y::rat)";
   811 test "-99*x < -81 * (y::rat)";
   812 
   813 test "-2 * x = -1 * (y::rat)";
   814 test "-2 * x = -(y::rat)";
   815 test "(-2 * x) / (-1 * (y::rat)) = z";
   816 test "-2 * x < -(y::rat)";
   817 test "-2 * x <= -1 * (y::rat)";
   818 test "-x < -23 * (y::rat)";
   819 test "-x <= -23 * (y::rat)";
   820 *)
   821 
   822 (*examples:
   823 print_depth 22;
   824 set timing;
   825 set simp_trace;
   826 fun test s = (Goal s; by (Asm_simp_tac 1));
   827 
   828 test "x*k = k*(y::int)";
   829 test "k = k*(y::int)";
   830 test "a*(b*c) = (b::int)";
   831 test "a*(b*c) = d*(b::int)*(x*a)";
   832 
   833 test "(x*k) div (k*(y::int)) = (uu::int)";
   834 test "(k) div (k*(y::int)) = (uu::int)";
   835 test "(a*(b*c)) div ((b::int)) = (uu::int)";
   836 test "(a*(b*c)) div (d*(b::int)*(x*a)) = (uu::int)";
   837 *)
   838 
   839 (*And the same examples for fields such as rat or real:
   840 print_depth 22;
   841 set timing;
   842 set simp_trace;
   843 fun test s = (Goal s; by (Asm_simp_tac 1));
   844 
   845 test "x*k = k*(y::rat)";
   846 test "k = k*(y::rat)";
   847 test "a*(b*c) = (b::rat)";
   848 test "a*(b*c) = d*(b::rat)*(x*a)";
   849 
   850 
   851 test "(x*k) / (k*(y::rat)) = (uu::rat)";
   852 test "(k) / (k*(y::rat)) = (uu::rat)";
   853 test "(a*(b*c)) / ((b::rat)) = (uu::rat)";
   854 test "(a*(b*c)) / (d*(b::rat)*(x*a)) = (uu::rat)";
   855 
   856 (*FIXME: what do we do about this?*)
   857 test "a*(b*c)/(y*z) = d*(b::rat)*(x*a)/z";
   858 *)