src/HOL/Tools/nat_numeral_simprocs.ML
author haftmann
Fri Oct 10 19:55:32 2014 +0200 (2014-10-10)
changeset 58646 cd63a4b12a33
parent 57514 bdc2c6b40bf2
child 59582 0fbed69ff081
permissions -rw-r--r--
specialized specification: avoid trivial instances
     1 (* Author: Lawrence C Paulson, Cambridge University Computer Laboratory
     2 
     3 Simprocs for nat numerals.
     4 *)
     5 
(*Simproc functions for nat numerals.  Each one takes the current proof
  context and the candidate redex (a cterm) and yields an optional rewrite
  theorem.*)
signature NAT_NUMERAL_SIMPROCS =
sig
  (*Combine summands that share the same non-numeral part.*)
  val combine_numerals: Proof.context -> cterm -> thm option
  (*Cancel common summands from both sides of =, <, <= and from differences.*)
  val eq_cancel_numerals: Proof.context -> cterm -> thm option
  val less_cancel_numerals: Proof.context -> cterm -> thm option
  val le_cancel_numerals: Proof.context -> cterm -> thm option
  val diff_cancel_numerals: Proof.context -> cterm -> thm option
  (*Cancel a common numeral coefficient across =, <, <=, div and dvd.*)
  val eq_cancel_numeral_factor: Proof.context -> cterm -> thm option
  val less_cancel_numeral_factor: Proof.context -> cterm -> thm option
  val le_cancel_numeral_factor: Proof.context -> cterm -> thm option
  val div_cancel_numeral_factor: Proof.context -> cterm -> thm option
  val dvd_cancel_numeral_factor: Proof.context -> cterm -> thm option
  (*Cancel a common (product) factor across =, <, <=, div and dvd.*)
  val eq_cancel_factor: Proof.context -> cterm -> thm option
  val less_cancel_factor: Proof.context -> cterm -> thm option
  val le_cancel_factor: Proof.context -> cterm -> thm option
  val div_cancel_factor: Proof.context -> cterm -> thm option
  val dvd_cancel_factor: Proof.context -> cterm -> thm option
end;
    24 
    25 structure Nat_Numeral_Simprocs : NAT_NUMERAL_SIMPROCS =
    26 struct
    27 
    28 (*Maps n to #n for n = 1, 2*)
    29 val numeral_syms = [@{thm numeral_1_eq_1} RS sym, @{thm numeral_2_eq_2} RS sym];
    30 val numeral_sym_ss =
    31   simpset_of (put_simpset HOL_basic_ss @{context} addsimps numeral_syms);
    32 
    33 (*Utilities*)
    34 
    35 fun mk_number 1 = HOLogic.numeral_const HOLogic.natT $ HOLogic.one_const
    36   | mk_number n = HOLogic.mk_number HOLogic.natT n;
    37 fun dest_number t = Int.max (0, snd (HOLogic.dest_number t));
    38 
    39 fun find_first_numeral past (t::terms) =
    40         ((dest_number t, t, rev past @ terms)
    41          handle TERM _ => find_first_numeral (t::past) terms)
    42   | find_first_numeral past [] = raise TERM("find_first_numeral", []);
    43 
    44 val zero = mk_number 0;
    45 val mk_plus = HOLogic.mk_binop @{const_name Groups.plus};
    46 
    47 (*Thus mk_sum[t] yields t+0; longer sums don't have a trailing zero*)
    48 fun mk_sum []        = zero
    49   | mk_sum [t,u]     = mk_plus (t, u)
    50   | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
    51 
    52 (*this version ALWAYS includes a trailing zero*)
    53 fun long_mk_sum []        = HOLogic.zero
    54   | long_mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
    55 
    56 val dest_plus = HOLogic.dest_bin @{const_name Groups.plus} HOLogic.natT;
    57 
    58 
    59 (** Other simproc items **)
    60 
    61 val bin_simps =
    62      [@{thm numeral_1_eq_1} RS sym,
    63       @{thm numeral_plus_numeral}, @{thm add_numeral_left},
    64       @{thm diff_nat_numeral}, @{thm diff_0_eq_0}, @{thm diff_0},
    65       @{thm numeral_times_numeral}, @{thm mult_numeral_left(1)},
    66       @{thm if_True}, @{thm if_False}, @{thm not_False_eq_True},
    67       @{thm nat_0}, @{thm nat_numeral}, @{thm nat_neg_numeral}] @
    68      @{thms arith_simps} @ @{thms rel_simps};
    69 
    70 
    71 (*** CancelNumerals simprocs ***)
    72 
    73 val one = mk_number 1;
    74 val mk_times = HOLogic.mk_binop @{const_name Groups.times};
    75 
    76 fun mk_prod [] = one
    77   | mk_prod [t] = t
    78   | mk_prod (t :: ts) = if t = one then mk_prod ts
    79                         else mk_times (t, mk_prod ts);
    80 
    81 val dest_times = HOLogic.dest_bin @{const_name Groups.times} HOLogic.natT;
    82 
    83 fun dest_prod t =
    84       let val (t,u) = dest_times t
    85       in  dest_prod t @ dest_prod u  end
    86       handle TERM _ => [t];
    87 
    88 (*DON'T do the obvious simplifications; that would create special cases*)
    89 fun mk_coeff (k,t) = mk_times (mk_number k, t);
    90 
    91 (*Express t as a product of (possibly) a numeral with other factors, sorted*)
    92 fun dest_coeff t =
    93     let val ts = sort Term_Ord.term_ord (dest_prod t)
    94         val (n, _, ts') = find_first_numeral [] ts
    95                           handle TERM _ => (1, one, ts)
    96     in (n, mk_prod ts') end;
    97 
    98 (*Find first coefficient-term THAT MATCHES u*)
    99 fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
   100   | find_first_coeff past u (t::terms) =
   101         let val (n,u') = dest_coeff t
   102         in  if u aconv u' then (n, rev past @ terms)
   103                           else find_first_coeff (t::past) u terms
   104         end
   105         handle TERM _ => find_first_coeff (t::past) u terms;
   106 
   107 
   108 (*Split up a sum into the list of its constituent terms, on the way removing any
   109   Sucs and counting them.*)
   110 fun dest_Suc_sum (Const (@{const_name Suc}, _) $ t, (k,ts)) = dest_Suc_sum (t, (k+1,ts))
   111   | dest_Suc_sum (t, (k,ts)) = 
   112       let val (t1,t2) = dest_plus t
   113       in  dest_Suc_sum (t1, dest_Suc_sum (t2, (k,ts)))  end
   114       handle TERM _ => (k, t::ts);
   115 
   116 (*Code for testing whether numerals are already used in the goal*)
   117 fun is_numeral (Const(@{const_name Num.numeral}, _) $ w) = true
   118   | is_numeral _ = false;
   119 
   120 fun prod_has_numeral t = exists is_numeral (dest_prod t);
   121 
   122 (*The Sucs found in the term are converted to a binary numeral. If relaxed is false,
   123   an exception is raised unless the original expression contains at least one
   124   numeral in a coefficient position.  This prevents nat_combine_numerals from 
   125   introducing numerals to goals.*)
   126 fun dest_Sucs_sum relaxed t = 
   127   let val (k,ts) = dest_Suc_sum (t,(0,[]))
   128   in
   129      if relaxed orelse exists prod_has_numeral ts then 
   130        if k=0 then ts
   131        else mk_number k :: ts
   132      else raise TERM("Nat_Numeral_Simprocs.dest_Sucs_sum", [t])
   133   end;
   134 
   135 
   136 (* FIXME !? *)
   137 val rename_numerals = simplify (put_simpset numeral_sym_ss @{context}) o Thm.transfer @{theory};
   138 
   139 (*Simplify 1*n and n*1 to n*)
   140 val add_0s  = map rename_numerals [@{thm Nat.add_0}, @{thm Nat.add_0_right}];
   141 val mult_1s = map rename_numerals [@{thm nat_mult_1}, @{thm nat_mult_1_right}];
   142 
   143 (*Final simplification: cancel + and *; replace Numeral0 by 0 and Numeral1 by 1*)
   144 
   145 (*And these help the simproc return False when appropriate, which helps
   146   the arith prover.*)
   147 val contra_rules = [@{thm add_Suc}, @{thm add_Suc_right}, @{thm Zero_not_Suc},
   148   @{thm Suc_not_Zero}, @{thm le_0_eq}];
   149 
   150 val simplify_meta_eq =
   151     Arith_Data.simplify_meta_eq
   152         ([@{thm numeral_1_eq_Suc_0}, @{thm Nat.add_0}, @{thm Nat.add_0_right},
   153           @{thm mult_0}, @{thm mult_0_right}, @{thm mult_1}, @{thm mult_1_right}] @ contra_rules);
   154 
   155 
   156 (*** Applying CancelNumeralsFun ***)
   157 
   158 structure CancelNumeralsCommon =
   159 struct
   160   val mk_sum = (fn T : typ => mk_sum)
   161   val dest_sum = dest_Sucs_sum true
   162   val mk_coeff = mk_coeff
   163   val dest_coeff = dest_coeff
   164   val find_first_coeff = find_first_coeff []
   165   val trans_tac = Numeral_Simprocs.trans_tac
   166 
   167   val norm_ss1 =
   168     simpset_of (put_simpset Numeral_Simprocs.num_ss @{context}
   169       addsimps numeral_syms @ add_0s @ mult_1s @
   170         [@{thm Suc_eq_plus1_left}] @ @{thms ac_simps})
   171   val norm_ss2 =
   172     simpset_of (put_simpset Numeral_Simprocs.num_ss @{context}
   173       addsimps bin_simps @ @{thms ac_simps} @ @{thms ac_simps})
   174   fun norm_tac ctxt = 
   175     ALLGOALS (simp_tac (put_simpset norm_ss1 ctxt))
   176     THEN ALLGOALS (simp_tac (put_simpset norm_ss2 ctxt))
   177 
   178   val numeral_simp_ss =
   179     simpset_of (put_simpset HOL_basic_ss @{context}
   180       addsimps add_0s @ bin_simps);
   181   fun numeral_simp_tac ctxt =
   182     ALLGOALS (simp_tac (put_simpset numeral_simp_ss ctxt));
   183   val simplify_meta_eq  = simplify_meta_eq
   184   val prove_conv = Arith_Data.prove_conv
   185 end;
   186 
   187 structure EqCancelNumerals = CancelNumeralsFun
   188  (open CancelNumeralsCommon
   189   val mk_bal   = HOLogic.mk_eq
   190   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} HOLogic.natT
   191   val bal_add1 = @{thm nat_eq_add_iff1} RS trans
   192   val bal_add2 = @{thm nat_eq_add_iff2} RS trans
   193 );
   194 
   195 structure LessCancelNumerals = CancelNumeralsFun
   196  (open CancelNumeralsCommon
   197   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   198   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} HOLogic.natT
   199   val bal_add1 = @{thm nat_less_add_iff1} RS trans
   200   val bal_add2 = @{thm nat_less_add_iff2} RS trans
   201 );
   202 
   203 structure LeCancelNumerals = CancelNumeralsFun
   204  (open CancelNumeralsCommon
   205   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   206   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} HOLogic.natT
   207   val bal_add1 = @{thm nat_le_add_iff1} RS trans
   208   val bal_add2 = @{thm nat_le_add_iff2} RS trans
   209 );
   210 
   211 structure DiffCancelNumerals = CancelNumeralsFun
   212  (open CancelNumeralsCommon
   213   val mk_bal   = HOLogic.mk_binop @{const_name Groups.minus}
   214   val dest_bal = HOLogic.dest_bin @{const_name Groups.minus} HOLogic.natT
   215   val bal_add1 = @{thm nat_diff_add_eq1} RS trans
   216   val bal_add2 = @{thm nat_diff_add_eq2} RS trans
   217 );
   218 
   219 fun eq_cancel_numerals ctxt ct = EqCancelNumerals.proc ctxt (term_of ct)
   220 fun less_cancel_numerals ctxt ct = LessCancelNumerals.proc ctxt (term_of ct)
   221 fun le_cancel_numerals ctxt ct = LeCancelNumerals.proc ctxt (term_of ct)
   222 fun diff_cancel_numerals ctxt ct = DiffCancelNumerals.proc ctxt (term_of ct)
   223 
   224 
   225 (*** Applying CombineNumeralsFun ***)
   226 
   227 structure CombineNumeralsData =
   228 struct
   229   type coeff = int
   230   val iszero = (fn x => x = 0)
   231   val add = op +
   232   val mk_sum = (fn T : typ => long_mk_sum)  (*to work for 2*x + 3*x *)
   233   val dest_sum = dest_Sucs_sum false
   234   val mk_coeff = mk_coeff
   235   val dest_coeff = dest_coeff
   236   val left_distrib = @{thm left_add_mult_distrib} RS trans
   237   val prove_conv = Arith_Data.prove_conv_nohyps
   238   val trans_tac = Numeral_Simprocs.trans_tac
   239 
   240   val norm_ss1 =
   241     simpset_of (put_simpset Numeral_Simprocs.num_ss @{context}
   242       addsimps numeral_syms @ add_0s @ mult_1s @ [@{thm Suc_eq_plus1}] @ @{thms ac_simps})
   243   val norm_ss2 =
   244     simpset_of (put_simpset Numeral_Simprocs.num_ss @{context}
   245       addsimps bin_simps @ @{thms ac_simps} @ @{thms ac_simps})
   246   fun norm_tac ctxt =
   247     ALLGOALS (simp_tac (put_simpset norm_ss1 ctxt))
   248     THEN ALLGOALS (simp_tac (put_simpset norm_ss2 ctxt))
   249 
   250   val numeral_simp_ss =
   251     simpset_of (put_simpset HOL_basic_ss @{context}
   252       addsimps add_0s @ bin_simps);
   253   fun numeral_simp_tac ctxt =
   254     ALLGOALS (simp_tac (put_simpset numeral_simp_ss ctxt))
   255   val simplify_meta_eq = simplify_meta_eq
   256 end;
   257 
   258 structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);
   259 
   260 fun combine_numerals ctxt ct = CombineNumerals.proc ctxt (term_of ct)
   261 
   262 
   263 (*** Applying CancelNumeralFactorFun ***)
   264 
   265 structure CancelNumeralFactorCommon =
   266 struct
   267   val mk_coeff = mk_coeff
   268   val dest_coeff = dest_coeff
   269   val trans_tac = Numeral_Simprocs.trans_tac
   270 
   271   val norm_ss1 =
   272     simpset_of (put_simpset Numeral_Simprocs.num_ss @{context}
   273       addsimps numeral_syms @ add_0s @ mult_1s @ [@{thm Suc_eq_plus1_left}] @ @{thms ac_simps})
   274   val norm_ss2 =
   275     simpset_of (put_simpset Numeral_Simprocs.num_ss @{context}
   276       addsimps bin_simps @ @{thms ac_simps} @ @{thms ac_simps})
   277   fun norm_tac ctxt =
   278     ALLGOALS (simp_tac (put_simpset norm_ss1 ctxt))
   279     THEN ALLGOALS (simp_tac (put_simpset norm_ss2 ctxt))
   280 
   281   val numeral_simp_ss =
   282     simpset_of (put_simpset HOL_basic_ss @{context} addsimps bin_simps)
   283   fun numeral_simp_tac ctxt =
   284     ALLGOALS (simp_tac (put_simpset numeral_simp_ss ctxt))
   285   val simplify_meta_eq = simplify_meta_eq
   286   val prove_conv = Arith_Data.prove_conv
   287 end;
   288 
   289 structure DivCancelNumeralFactor = CancelNumeralFactorFun
   290  (open CancelNumeralFactorCommon
   291   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   292   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} HOLogic.natT
   293   val cancel = @{thm nat_mult_div_cancel1} RS trans
   294   val neg_exchanges = false
   295 );
   296 
   297 structure DvdCancelNumeralFactor = CancelNumeralFactorFun
   298  (open CancelNumeralFactorCommon
   299   val mk_bal   = HOLogic.mk_binrel @{const_name Rings.dvd}
   300   val dest_bal = HOLogic.dest_bin @{const_name Rings.dvd} HOLogic.natT
   301   val cancel = @{thm nat_mult_dvd_cancel1} RS trans
   302   val neg_exchanges = false
   303 );
   304 
   305 structure EqCancelNumeralFactor = CancelNumeralFactorFun
   306  (open CancelNumeralFactorCommon
   307   val mk_bal   = HOLogic.mk_eq
   308   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} HOLogic.natT
   309   val cancel = @{thm nat_mult_eq_cancel1} RS trans
   310   val neg_exchanges = false
   311 );
   312 
   313 structure LessCancelNumeralFactor = CancelNumeralFactorFun
   314  (open CancelNumeralFactorCommon
   315   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   316   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} HOLogic.natT
   317   val cancel = @{thm nat_mult_less_cancel1} RS trans
   318   val neg_exchanges = true
   319 );
   320 
   321 structure LeCancelNumeralFactor = CancelNumeralFactorFun
   322  (open CancelNumeralFactorCommon
   323   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   324   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} HOLogic.natT
   325   val cancel = @{thm nat_mult_le_cancel1} RS trans
   326   val neg_exchanges = true
   327 )
   328 
   329 fun eq_cancel_numeral_factor ctxt ct = EqCancelNumeralFactor.proc ctxt (term_of ct)
   330 fun less_cancel_numeral_factor ctxt ct = LessCancelNumeralFactor.proc ctxt (term_of ct)
   331 fun le_cancel_numeral_factor ctxt ct = LeCancelNumeralFactor.proc ctxt (term_of ct)
   332 fun div_cancel_numeral_factor ctxt ct = DivCancelNumeralFactor.proc ctxt (term_of ct)
   333 fun dvd_cancel_numeral_factor ctxt ct = DvdCancelNumeralFactor.proc ctxt (term_of ct)
   334 
   335 
   336 (*** Applying ExtractCommonTermFun ***)
   337 
   338 (*this version ALWAYS includes a trailing one*)
   339 fun long_mk_prod []        = one
   340   | long_mk_prod (t :: ts) = mk_times (t, mk_prod ts);
   341 
   342 (*Find first term that matches u*)
   343 fun find_first_t past u []         = raise TERM("find_first_t", [])
   344   | find_first_t past u (t::terms) =
   345         if u aconv t then (rev past @ terms)
   346         else find_first_t (t::past) u terms
   347         handle TERM _ => find_first_t (t::past) u terms;
   348 
   349 (** Final simplification for the CancelFactor simprocs **)
   350 val simplify_one = Arith_Data.simplify_meta_eq  
   351   [@{thm mult_1_left}, @{thm mult_1_right}, @{thm div_1}, @{thm numeral_1_eq_Suc_0}];
   352 
   353 fun cancel_simplify_meta_eq ctxt cancel_th th =
   354     simplify_one ctxt (([th, cancel_th]) MRS trans);
   355 
   356 structure CancelFactorCommon =
   357 struct
   358   val mk_sum = (fn T : typ => long_mk_prod)
   359   val dest_sum = dest_prod
   360   val mk_coeff = mk_coeff
   361   val dest_coeff = dest_coeff
   362   val find_first = find_first_t []
   363   val trans_tac = Numeral_Simprocs.trans_tac
   364   val norm_ss =
   365     simpset_of (put_simpset HOL_basic_ss @{context}
   366       addsimps mult_1s @ @{thms ac_simps})
   367   fun norm_tac ctxt =
   368     ALLGOALS (simp_tac (put_simpset norm_ss ctxt))
   369   val simplify_meta_eq  = cancel_simplify_meta_eq
   370   fun mk_eq (a, b) = HOLogic.mk_Trueprop (HOLogic.mk_eq (a, b))
   371 end;
   372 
   373 structure EqCancelFactor = ExtractCommonTermFun
   374  (open CancelFactorCommon
   375   val mk_bal   = HOLogic.mk_eq
   376   val dest_bal = HOLogic.dest_bin @{const_name HOL.eq} HOLogic.natT
   377   fun simp_conv _ _ = SOME @{thm nat_mult_eq_cancel_disj}
   378 );
   379 
   380 structure LeCancelFactor = ExtractCommonTermFun
   381  (open CancelFactorCommon
   382   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less_eq}
   383   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less_eq} HOLogic.natT
   384   fun simp_conv _ _ = SOME @{thm nat_mult_le_cancel_disj}
   385 );
   386 
   387 structure LessCancelFactor = ExtractCommonTermFun
   388  (open CancelFactorCommon
   389   val mk_bal   = HOLogic.mk_binrel @{const_name Orderings.less}
   390   val dest_bal = HOLogic.dest_bin @{const_name Orderings.less} HOLogic.natT
   391   fun simp_conv _ _ = SOME @{thm nat_mult_less_cancel_disj}
   392 );
   393 
   394 structure DivideCancelFactor = ExtractCommonTermFun
   395  (open CancelFactorCommon
   396   val mk_bal   = HOLogic.mk_binop @{const_name Divides.div}
   397   val dest_bal = HOLogic.dest_bin @{const_name Divides.div} HOLogic.natT
   398   fun simp_conv _ _ = SOME @{thm nat_mult_div_cancel_disj}
   399 );
   400 
   401 structure DvdCancelFactor = ExtractCommonTermFun
   402  (open CancelFactorCommon
   403   val mk_bal   = HOLogic.mk_binrel @{const_name Rings.dvd}
   404   val dest_bal = HOLogic.dest_bin @{const_name Rings.dvd} HOLogic.natT
   405   fun simp_conv _ _ = SOME @{thm nat_mult_dvd_cancel_disj}
   406 );
   407 
   408 fun eq_cancel_factor ctxt ct = EqCancelFactor.proc ctxt (term_of ct)
   409 fun less_cancel_factor ctxt ct = LessCancelFactor.proc ctxt (term_of ct)
   410 fun le_cancel_factor ctxt ct = LeCancelFactor.proc ctxt (term_of ct)
   411 fun div_cancel_factor ctxt ct = DivideCancelFactor.proc ctxt (term_of ct)
   412 fun dvd_cancel_factor ctxt ct = DvdCancelFactor.proc ctxt (term_of ct)
   413 
   414 end;