src/HOL/Tools/nat_numeral_simprocs.ML
author huffman
Fri Mar 30 12:32:35 2012 +0200 (2012-03-30)
changeset 47220 52426c62b5d0
parent 47108 2a1953f0d20d
child 51717 9e7d1c139569
permissions -rw-r--r--
replace lemmas eval_nat_numeral with a simpler reformulation
     1 (* Author: Lawrence C Paulson, Cambridge University Computer Laboratory
     2 
     3 Simprocs for nat numerals.
     4 *)
     5 
     6 signature NAT_NUMERAL_SIMPROCS =
     7 sig
     8   val combine_numerals: simpset -> cterm -> thm option
     9   val eq_cancel_numerals: simpset -> cterm -> thm option
    10   val less_cancel_numerals: simpset -> cterm -> thm option
    11   val le_cancel_numerals: simpset -> cterm -> thm option
    12   val diff_cancel_numerals: simpset -> cterm -> thm option
    13   val eq_cancel_numeral_factor: simpset -> cterm -> thm option
    14   val less_cancel_numeral_factor: simpset -> cterm -> thm option
    15   val le_cancel_numeral_factor: simpset -> cterm -> thm option
    16   val div_cancel_numeral_factor: simpset -> cterm -> thm option
    17   val dvd_cancel_numeral_factor: simpset -> cterm -> thm option
    18   val eq_cancel_factor: simpset -> cterm -> thm option
    19   val less_cancel_factor: simpset -> cterm -> thm option
    20   val le_cancel_factor: simpset -> cterm -> thm option
    21   val div_cancel_factor: simpset -> cterm -> thm option
    22   val dvd_cancel_factor: simpset -> cterm -> thm option
    23 end;
    24 
    25 structure Nat_Numeral_Simprocs : NAT_NUMERAL_SIMPROCS =
    26 struct
    27 
    28 (*Maps n to #n for n = 1, 2*)
    29 val numeral_syms = [@{thm numeral_1_eq_1} RS sym, @{thm numeral_2_eq_2} RS sym];
    30 val numeral_sym_ss = HOL_basic_ss addsimps numeral_syms;
    31 
    32 val rename_numerals = simplify numeral_sym_ss o Thm.transfer @{theory};
    33 
    34 (*Utilities*)
    35 
    36 fun mk_number 1 = HOLogic.numeral_const HOLogic.natT $ HOLogic.one_const
    37   | mk_number n = HOLogic.mk_number HOLogic.natT n;
    38 fun dest_number t = Int.max (0, snd (HOLogic.dest_number t));
    39 
    40 fun find_first_numeral past (t::terms) =
    41         ((dest_number t, t, rev past @ terms)
    42          handle TERM _ => find_first_numeral (t::past) terms)
    43   | find_first_numeral past [] = raise TERM("find_first_numeral", []);
    44 
    45 val zero = mk_number 0;
    46 val mk_plus = HOLogic.mk_binop @{const_name Groups.plus};
    47 
    48 (*Thus mk_sum[t] yields t+0; longer sums don't have a trailing zero*)
    49 fun mk_sum []        = zero
    50   | mk_sum [t,u]     = mk_plus (t, u)
    51   | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
    52 
    53 (*this version ALWAYS includes a trailing zero*)
    54 fun long_mk_sum []        = HOLogic.zero
    55   | long_mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
    56 
    57 val dest_plus = HOLogic.dest_bin @{const_name Groups.plus} HOLogic.natT;
    58 
    59 
    60 (** Other simproc items **)
    61 
    62 val bin_simps =
    63      [@{thm numeral_1_eq_1} RS sym,
    64       @{thm numeral_plus_numeral}, @{thm add_numeral_left},
    65       @{thm diff_nat_numeral}, @{thm diff_0_eq_0}, @{thm diff_0},
    66       @{thm numeral_times_numeral}, @{thm mult_numeral_left(1)},
    67       @{thm if_True}, @{thm if_False}, @{thm not_False_eq_True},
    68       @{thm nat_0}, @{thm nat_numeral}, @{thm nat_neg_numeral}] @
    69      @{thms arith_simps} @ @{thms rel_simps};
    70 
    71 
    72 (*** CancelNumerals simprocs ***)
    73 
    74 val one = mk_number 1;
    75 val mk_times = HOLogic.mk_binop @{const_name Groups.times};
    76 
    77 fun mk_prod [] = one
    78   | mk_prod [t] = t
    79   | mk_prod (t :: ts) = if t = one then mk_prod ts
    80                         else mk_times (t, mk_prod ts);
    81 
    82 val dest_times = HOLogic.dest_bin @{const_name Groups.times} HOLogic.natT;
    83 
    84 fun dest_prod t =
    85       let val (t,u) = dest_times t
    86       in  dest_prod t @ dest_prod u  end
    87       handle TERM _ => [t];
    88 
    89 (*DON'T do the obvious simplifications; that would create special cases*)
    90 fun mk_coeff (k,t) = mk_times (mk_number k, t);
    91 
    92 (*Express t as a product of (possibly) a numeral with other factors, sorted*)
    93 fun dest_coeff t =
    94     let val ts = sort Term_Ord.term_ord (dest_prod t)
    95         val (n, _, ts') = find_first_numeral [] ts
    96                           handle TERM _ => (1, one, ts)
    97     in (n, mk_prod ts') end;
    98 
    99 (*Find first coefficient-term THAT MATCHES u*)
   100 fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
   101   | find_first_coeff past u (t::terms) =
   102         let val (n,u') = dest_coeff t
   103         in  if u aconv u' then (n, rev past @ terms)
   104                           else find_first_coeff (t::past) u terms
   105         end
   106         handle TERM _ => find_first_coeff (t::past) u terms;
   107 
   108 
   109 (*Split up a sum into the list of its constituent terms, on the way removing any
   110   Sucs and counting them.*)
   111 fun dest_Suc_sum (Const (@{const_name Suc}, _) $ t, (k,ts)) = dest_Suc_sum (t, (k+1,ts))
   112   | dest_Suc_sum (t, (k,ts)) = 
   113       let val (t1,t2) = dest_plus t
   114       in  dest_Suc_sum (t1, dest_Suc_sum (t2, (k,ts)))  end
   115       handle TERM _ => (k, t::ts);
   116 
   117 (*Code for testing whether numerals are already used in the goal*)
   118 fun is_numeral (Const(@{const_name Num.numeral}, _) $ w) = true
   119   | is_numeral _ = false;
   120 
   121 fun prod_has_numeral t = exists is_numeral (dest_prod t);
   122 
   123 (*The Sucs found in the term are converted to a binary numeral. If relaxed is false,
   124   an exception is raised unless the original expression contains at least one
   125   numeral in a coefficient position.  This prevents nat_combine_numerals from 
   126   introducing numerals to goals.*)
   127 fun dest_Sucs_sum relaxed t = 
   128   let val (k,ts) = dest_Suc_sum (t,(0,[]))
   129   in
   130      if relaxed orelse exists prod_has_numeral ts then 
   131        if k=0 then ts
   132        else mk_number k :: ts
   133      else raise TERM("Nat_Numeral_Simprocs.dest_Sucs_sum", [t])
   134   end;
   135 
   136 
   137 (*Simplify 1*n and n*1 to n*)
   138 val add_0s  = map rename_numerals [@{thm Nat.add_0}, @{thm Nat.add_0_right}];
   139 val mult_1s = map rename_numerals [@{thm nat_mult_1}, @{thm nat_mult_1_right}];
   140 
   141 (*Final simplification: cancel + and *; replace Numeral0 by 0 and Numeral1 by 1*)
   142 
   143 (*And these help the simproc return False when appropriate, which helps
   144   the arith prover.*)
   145 val contra_rules = [@{thm add_Suc}, @{thm add_Suc_right}, @{thm Zero_not_Suc},
   146   @{thm Suc_not_Zero}, @{thm le_0_eq}];
   147 
   148 val simplify_meta_eq =
   149     Arith_Data.simplify_meta_eq
   150         ([@{thm numeral_1_eq_Suc_0}, @{thm Nat.add_0}, @{thm Nat.add_0_right},
   151           @{thm mult_0}, @{thm mult_0_right}, @{thm mult_1}, @{thm mult_1_right}] @ contra_rules);
   152 
   153 
   154 (*** Applying CancelNumeralsFun ***)
   155 
   156 structure CancelNumeralsCommon =
   157 struct
   158   val mk_sum = (fn T : typ => mk_sum)
   159   val dest_sum = dest_Sucs_sum true
   160   val mk_coeff = mk_coeff
   161   val dest_coeff = dest_coeff
   162   val find_first_coeff = find_first_coeff []
   163   val trans_tac = Numeral_Simprocs.trans_tac
   164 
   165   val norm_ss1 = Numeral_Simprocs.num_ss addsimps numeral_syms @ add_0s @ mult_1s @
   166     [@{thm Suc_eq_plus1_left}] @ @{thms add_ac}
   167   val norm_ss2 = Numeral_Simprocs.num_ss addsimps bin_simps @ @{thms add_ac} @ @{thms mult_ac}
   168   fun norm_tac ss = 
   169     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   170     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   171 
   172   val numeral_simp_ss = HOL_basic_ss addsimps add_0s @ bin_simps;
   173   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss));
   174   val simplify_meta_eq  = simplify_meta_eq
   175   val prove_conv = Arith_Data.prove_conv
   176 end;
   177 
   178 structure EqCancelNumerals = CancelNumeralsFun
   179  (open CancelNumeralsCommon
   180   val mk_bal   = HOLogic.mk_eq
   181   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} HOLogic.natT
   182   val bal_add1 = @{thm nat_eq_add_iff1} RS trans
   183   val bal_add2 = @{thm nat_eq_add_iff2} RS trans
   184 );
   185 
   186 structure LessCancelNumerals = CancelNumeralsFun
   187  (open CancelNumeralsCommon
   188   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   189   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} HOLogic.natT
   190   val bal_add1 = @{thm nat_less_add_iff1} RS trans
   191   val bal_add2 = @{thm nat_less_add_iff2} RS trans
   192 );
   193 
   194 structure LeCancelNumerals = CancelNumeralsFun
   195  (open CancelNumeralsCommon
   196   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   197   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} HOLogic.natT
   198   val bal_add1 = @{thm nat_le_add_iff1} RS trans
   199   val bal_add2 = @{thm nat_le_add_iff2} RS trans
   200 );
   201 
   202 structure DiffCancelNumerals = CancelNumeralsFun
   203  (open CancelNumeralsCommon
   204   val mk_bal   = HOLogic.mk_binop @{const_name Groups.minus}
   205   val dest_bal = HOLogic.dest_bin @{const_name Groups.minus} HOLogic.natT
   206   val bal_add1 = @{thm nat_diff_add_eq1} RS trans
   207   val bal_add2 = @{thm nat_diff_add_eq2} RS trans
   208 );
   209 
   210 fun eq_cancel_numerals ss ct = EqCancelNumerals.proc ss (term_of ct)
   211 fun less_cancel_numerals ss ct = LessCancelNumerals.proc ss (term_of ct)
   212 fun le_cancel_numerals ss ct = LeCancelNumerals.proc ss (term_of ct)
   213 fun diff_cancel_numerals ss ct = DiffCancelNumerals.proc ss (term_of ct)
   214 
   215 
   216 (*** Applying CombineNumeralsFun ***)
   217 
   218 structure CombineNumeralsData =
   219 struct
   220   type coeff = int
   221   val iszero = (fn x => x = 0)
   222   val add = op +
   223   val mk_sum = (fn T : typ => long_mk_sum)  (*to work for 2*x + 3*x *)
   224   val dest_sum = dest_Sucs_sum false
   225   val mk_coeff = mk_coeff
   226   val dest_coeff = dest_coeff
   227   val left_distrib = @{thm left_add_mult_distrib} RS trans
   228   val prove_conv = Arith_Data.prove_conv_nohyps
   229   val trans_tac = Numeral_Simprocs.trans_tac
   230 
   231   val norm_ss1 = Numeral_Simprocs.num_ss addsimps numeral_syms @ add_0s @ mult_1s @ [@{thm Suc_eq_plus1}] @ @{thms add_ac}
   232   val norm_ss2 = Numeral_Simprocs.num_ss addsimps bin_simps @ @{thms add_ac} @ @{thms mult_ac}
   233   fun norm_tac ss =
   234     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   235     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   236 
   237   val numeral_simp_ss = HOL_basic_ss addsimps add_0s @ bin_simps;
   238   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   239   val simplify_meta_eq = simplify_meta_eq
   240 end;
   241 
   242 structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);
   243 
   244 fun combine_numerals ss ct = CombineNumerals.proc ss (term_of ct)
   245 
   246 
   247 (*** Applying CancelNumeralFactorFun ***)
   248 
   249 structure CancelNumeralFactorCommon =
   250 struct
   251   val mk_coeff = mk_coeff
   252   val dest_coeff = dest_coeff
   253   val trans_tac = Numeral_Simprocs.trans_tac
   254 
   255   val norm_ss1 = Numeral_Simprocs.num_ss addsimps
   256     numeral_syms @ add_0s @ mult_1s @ [@{thm Suc_eq_plus1_left}] @ @{thms add_ac}
   257   val norm_ss2 = Numeral_Simprocs.num_ss addsimps bin_simps @ @{thms add_ac} @ @{thms mult_ac}
   258   fun norm_tac ss =
   259     ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss1))
   260     THEN ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss2))
   261 
   262   val numeral_simp_ss = HOL_basic_ss addsimps bin_simps
   263   fun numeral_simp_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   264   val simplify_meta_eq = simplify_meta_eq
   265   val prove_conv = Arith_Data.prove_conv
   266 end;
   267 
   268 structure DivCancelNumeralFactor = CancelNumeralFactorFun
   269  (open CancelNumeralFactorCommon
   270   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   271   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} HOLogic.natT
   272   val cancel = @{thm nat_mult_div_cancel1} RS trans
   273   val neg_exchanges = false
   274 );
   275 
   276 structure DvdCancelNumeralFactor = CancelNumeralFactorFun
   277  (open CancelNumeralFactorCommon
   278   val mk_bal   = HOLogic.mk_binrel @{const_name Rings.dvd}
   279   val dest_bal = HOLogic.dest_bin @{const_name Rings.dvd} HOLogic.natT
   280   val cancel = @{thm nat_mult_dvd_cancel1} RS trans
   281   val neg_exchanges = false
   282 );
   283 
   284 structure EqCancelNumeralFactor = CancelNumeralFactorFun
   285  (open CancelNumeralFactorCommon
   286   val mk_bal   = HOLogic.mk_eq
   287   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} HOLogic.natT
   288   val cancel = @{thm nat_mult_eq_cancel1} RS trans
   289   val neg_exchanges = false
   290 );
   291 
   292 structure LessCancelNumeralFactor = CancelNumeralFactorFun
   293  (open CancelNumeralFactorCommon
   294   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   295   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} HOLogic.natT
   296   val cancel = @{thm nat_mult_less_cancel1} RS trans
   297   val neg_exchanges = true
   298 );
   299 
   300 structure LeCancelNumeralFactor = CancelNumeralFactorFun
   301  (open CancelNumeralFactorCommon
   302   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   303   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} HOLogic.natT
   304   val cancel = @{thm nat_mult_le_cancel1} RS trans
   305   val neg_exchanges = true
   306 )
   307 
   308 fun eq_cancel_numeral_factor ss ct = EqCancelNumeralFactor.proc ss (term_of ct)
   309 fun less_cancel_numeral_factor ss ct = LessCancelNumeralFactor.proc ss (term_of ct)
   310 fun le_cancel_numeral_factor ss ct = LeCancelNumeralFactor.proc ss (term_of ct)
   311 fun div_cancel_numeral_factor ss ct = DivCancelNumeralFactor.proc ss (term_of ct)
   312 fun dvd_cancel_numeral_factor ss ct = DvdCancelNumeralFactor.proc ss (term_of ct)
   313 
   314 
   315 (*** Applying ExtractCommonTermFun ***)
   316 
   317 (*this version ALWAYS includes a trailing one*)
   318 fun long_mk_prod []        = one
   319   | long_mk_prod (t :: ts) = mk_times (t, mk_prod ts);
   320 
   321 (*Find first term that matches u*)
   322 fun find_first_t past u []         = raise TERM("find_first_t", [])
   323   | find_first_t past u (t::terms) =
   324         if u aconv t then (rev past @ terms)
   325         else find_first_t (t::past) u terms
   326         handle TERM _ => find_first_t (t::past) u terms;
   327 
   328 (** Final simplification for the CancelFactor simprocs **)
   329 val simplify_one = Arith_Data.simplify_meta_eq  
   330   [@{thm mult_1_left}, @{thm mult_1_right}, @{thm div_1}, @{thm numeral_1_eq_Suc_0}];
   331 
   332 fun cancel_simplify_meta_eq ss cancel_th th =
   333     simplify_one ss (([th, cancel_th]) MRS trans);
   334 
   335 structure CancelFactorCommon =
   336 struct
   337   val mk_sum = (fn T : typ => long_mk_prod)
   338   val dest_sum = dest_prod
   339   val mk_coeff = mk_coeff
   340   val dest_coeff = dest_coeff
   341   val find_first = find_first_t []
   342   val trans_tac = Numeral_Simprocs.trans_tac
   343   val norm_ss = HOL_basic_ss addsimps mult_1s @ @{thms mult_ac}
   344   fun norm_tac ss = ALLGOALS (simp_tac (Simplifier.inherit_context ss norm_ss))
   345   val simplify_meta_eq  = cancel_simplify_meta_eq
   346   fun mk_eq (a, b) = HOLogic.mk_Trueprop (HOLogic.mk_eq (a, b))
   347 end;
   348 
   349 structure EqCancelFactor = ExtractCommonTermFun
   350  (open CancelFactorCommon
   351   val mk_bal   = HOLogic.mk_eq
   352   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} HOLogic.natT
   353   fun simp_conv _ _ = SOME @{thm nat_mult_eq_cancel_disj}
   354 );
   355 
   356 structure LeCancelFactor = ExtractCommonTermFun
   357  (open CancelFactorCommon
   358   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   359   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} HOLogic.natT
   360   fun simp_conv _ _ = SOME @{thm nat_mult_le_cancel_disj}
   361 );
   362 
   363 structure LessCancelFactor = ExtractCommonTermFun
   364  (open CancelFactorCommon
   365   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   366   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} HOLogic.natT
   367   fun simp_conv _ _ = SOME @{thm nat_mult_less_cancel_disj}
   368 );
   369 
   370 structure DivideCancelFactor = ExtractCommonTermFun
   371  (open CancelFactorCommon
   372   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   373   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} HOLogic.natT
   374   fun simp_conv _ _ = SOME @{thm nat_mult_div_cancel_disj}
   375 );
   376 
   377 structure DvdCancelFactor = ExtractCommonTermFun
   378  (open CancelFactorCommon
   379   val mk_bal   = HOLogic.mk_binrel @{const_name Rings.dvd}
   380   val dest_bal = HOLogic.dest_bin @{const_name Rings.dvd} HOLogic.natT
   381   fun simp_conv _ _ = SOME @{thm nat_mult_dvd_cancel_disj}
   382 );
   383 
   384 fun eq_cancel_factor ss ct = EqCancelFactor.proc ss (term_of ct)
   385 fun less_cancel_factor ss ct = LessCancelFactor.proc ss (term_of ct)
   386 fun le_cancel_factor ss ct = LeCancelFactor.proc ss (term_of ct)
   387 fun div_cancel_factor ss ct = DivideCancelFactor.proc ss (term_of ct)
   388 fun dvd_cancel_factor ss ct = DvdCancelFactor.proc ss (term_of ct)
   389 
   390 end;