src/HOL/Integ/nat_simprocs.ML
author nipkow
Fri Dec 01 19:53:29 2000 +0100 (2000-12-01)
changeset 10574 8f98f0301d67
parent 10536 8f34ecae1446
child 10693 9e4a0e84d0d6
permissions -rw-r--r--
Linear arithmetic now copes with mixed nat/int formulae.
     1 (*  Title:      HOL/Integ/nat_simprocs.ML
     2     ID:         $Id$
     3     Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
     4     Copyright   2000  University of Cambridge
     5 
     6 Simprocs for nat numerals.
     7 *)
     8 
     (*Adds two numerals at the head of a right-nested nat sum.  A numeral
       for which neg holds is dropped from the sum instead of being added.*)
     Goal "number_of v + (number_of v' + (k::nat)) = \
     \        (if neg (number_of v) then number_of v' + k \
     \         else if neg (number_of v') then number_of v + k \
     \         else number_of (bin_add v v') + k)";
     by (simp_tac (simpset ()) 1);
     qed "nat_number_of_add_left";
    15 
    16 
    (** For combine_numerals **)

    (*Merges two multiples of u at the head of a sum; used via RS trans as
      left_distrib of the CombineNumerals instance in Nat_Numeral_Simprocs.*)
    Goal "i*u + (j*u + k) = (i+j)*u + (k::nat)";
    by (let val ss = simpset () addsimps [add_mult_distrib]
        in  asm_simp_tac ss 1  end);
    qed "left_add_mult_distrib";
    22 
    23 
    (** For cancel_numerals **)

    (*Balancing rules (bal_add1 / bal_add2) of the CancelNumerals instances in
      Nat_Numeral_Simprocs.  When both sides contain a multiple of u, with
      coefficients i and j, the smaller multiple is cancelled.  The ...1
      variants assume j <= i and keep i-j copies of u on the left; the ...2
      variants assume i <= j and keep j-i copies of u on the right.*)

    Goal "j <= (i::nat) ==> ((i*u + m) - (j*u + n)) = (((i-j)*u + m) - n)";
    by (let val ss = simpset () addsplits [nat_diff_split]
                                addsimps [add_mult_distrib]
        in  asm_simp_tac ss 1  end);
    qed "nat_diff_add_eq1";

    Goal "i <= (j::nat) ==> ((i*u + m) - (j*u + n)) = (m - ((j-i)*u + n))";
    by (let val ss = simpset () addsplits [nat_diff_split]
                                addsimps [add_mult_distrib]
        in  asm_simp_tac ss 1  end);
    qed "nat_diff_add_eq2";

    Goal "j <= (i::nat) ==> (i*u + m = j*u + n) = ((i-j)*u + m = n)";
    by (let val ss = simpset () addsplits [nat_diff_split]
                                addsimps [add_mult_distrib]
        in  auto_tac (claset (), ss)  end);
    qed "nat_eq_add_iff1";

    Goal "i <= (j::nat) ==> (i*u + m = j*u + n) = (m = (j-i)*u + n)";
    by (let val ss = simpset () addsplits [nat_diff_split]
                                addsimps [add_mult_distrib]
        in  auto_tac (claset (), ss)  end);
    qed "nat_eq_add_iff2";

    Goal "j <= (i::nat) ==> (i*u + m < j*u + n) = ((i-j)*u + m < n)";
    by (let val ss = simpset () addsplits [nat_diff_split]
                                addsimps [add_mult_distrib]
        in  auto_tac (claset (), ss)  end);
    qed "nat_less_add_iff1";

    Goal "i <= (j::nat) ==> (i*u + m < j*u + n) = (m < (j-i)*u + n)";
    by (let val ss = simpset () addsplits [nat_diff_split]
                                addsimps [add_mult_distrib]
        in  auto_tac (claset (), ss)  end);
    qed "nat_less_add_iff2";

    Goal "j <= (i::nat) ==> (i*u + m <= j*u + n) = ((i-j)*u + m <= n)";
    by (let val ss = simpset () addsplits [nat_diff_split]
                                addsimps [add_mult_distrib]
        in  auto_tac (claset (), ss)  end);
    qed "nat_le_add_iff1";

    Goal "i <= (j::nat) ==> (i*u + m <= j*u + n) = (m <= (j-i)*u + n)";
    by (let val ss = simpset () addsplits [nat_diff_split]
                                addsimps [add_mult_distrib]
        in  auto_tac (claset (), ss)  end);
    qed "nat_le_add_iff2";
    65 
    66 
    (** For cancel_numeral_factors **)

    (*Cancelling a common factor k that is known to be positive; these are the
      cancel rules (via RS trans) of the CancelNumeralFactor instances in
      Nat_Numeral_Simprocs.*)

    Goal "(#0::nat) < k ==> (k*m <= k*n) = (m<=n)";
    by (auto_tac (claset (), simpset ()));
    qed "nat_mult_le_cancel1";

    Goal "(#0::nat) < k ==> (k*m < k*n) = (m<n)";
    by (auto_tac (claset (), simpset ()));
    qed "nat_mult_less_cancel1";

    Goal "(#0::nat) < k ==> (k*m = k*n) = (m=n)";
    by (auto_tac (claset (), simpset ()));
    qed "nat_mult_eq_cancel1";

    Goal "(#0::nat) < k ==> (k*m) div (k*n) = (m div n)";
    by (auto_tac (claset (), simpset ()));
    qed "nat_mult_div_cancel1";
    84 
    85 
    structure Nat_Numeral_Simprocs =
    struct

    (*Utilities*)

    fun mk_numeral n = HOLogic.number_of_const HOLogic.natT $
                       NumeralSyntax.mk_bin n;

    (*Decodes a unary numeral (built from 0 and Suc) or a binary numeral
      (number_of) to a NATURAL NUMBER.  A negative binary numeral is clamped
      to 0.  Any other term raises TERM.*)
    fun dest_numeral (Const ("0", _)) = 0
      | dest_numeral (Const ("Suc", _) $ t) = 1 + dest_numeral t
      | dest_numeral (Const("Numeral.number_of", _) $ w) =
          (BasisLibrary.Int.max (0, NumeralSyntax.dest_bin w)
           handle Match => raise TERM("Nat_Numeral_Simprocs.dest_numeral:1", [w]))
      | dest_numeral t = raise TERM("Nat_Numeral_Simprocs.dest_numeral:2", [t]);

    (*Finds the first numeral in a list of terms.  Returns its value, the
      numeral itself and the other terms in their original order; raises
      TERM if there is no numeral.  The argument past holds the terms
      already skipped, in reverse order.*)
    fun find_first_numeral past (t::terms) =
            ((dest_numeral t, t, rev past @ terms)
             handle TERM _ => find_first_numeral (t::past) terms)
      | find_first_numeral past [] = raise TERM("find_first_numeral", []);

    val zero = mk_numeral 0;
    val mk_plus = HOLogic.mk_binop "op +";

    (*Builds a right-nested sum.  mk_sum [] is #0 and mk_sum [t] is t+#0;
      lists of two or more terms get no trailing zero.*)
    fun mk_sum []        = zero
      | mk_sum [t,u]     = mk_plus (t, u)
      | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);

    (*Agrees with mk_sum except on two-element lists, where [t,u] yields
      t+ (u+#0) rather than t+u.  NB: this does NOT always add a trailing
      zero; a list of three or more terms gets none, exactly as with mk_sum.*)
    fun long_mk_sum []        = zero
      | long_mk_sum (t :: ts) = mk_plus (t, mk_sum ts);

    val dest_plus = HOLogic.dest_bin "op +" HOLogic.natT;

    (*Extracts the outer Sucs from a term and converts them to a binary
      numeral: k applications of Suc around t become #k + t, or just t when
      k is 0.*)
    fun dest_Sucs (k, Const ("Suc", _) $ t) = dest_Sucs (k+1, t)
      | dest_Sucs (0, t) = t
      | dest_Sucs (k, t) = mk_plus (mk_numeral k, t);

    (*Flattens a nested nat sum into the list of its summands.*)
    fun dest_sum t =
          let val (t,u) = dest_plus t
          in  dest_sum t @ dest_sum u  end
          handle TERM _ => [t];

    (*Summands of a term after its outer Sucs have been turned into a numeral.*)
    fun dest_Sucs_sum t = dest_sum (dest_Sucs (0,t));

    val trans_tac = Int_Numeral_Simprocs.trans_tac;

    val prove_conv = Int_Numeral_Simprocs.prove_conv;

    (*Simp rules for nat numerals, plus the generic bin_arith_simps and
      bin_rel_simps; used by the normalisation tactics further down.*)
    val bin_simps = [add_nat_number_of, nat_number_of_add_left,
                     diff_nat_number_of, le_nat_number_of_eq_not_less,
                     less_nat_number_of, mult_nat_number_of,
                     Let_number_of, nat_number_of] @
                    bin_arith_simps @ bin_rel_simps;

    fun prep_simproc (name, pats, proc) = Simplifier.mk_simproc name pats proc;
    fun prep_pat s = Thm.read_cterm (Theory.sign_of (the_context ())) (s, HOLogic.termT);
    val prep_pats = map prep_pat;


    (*** CancelNumerals simprocs ***)

    val one = mk_numeral 1;
    val mk_times = HOLogic.mk_binop "op *";

    (*Builds a right-nested product, skipping #1 factors that are not the
      last element; mk_prod [] is #1.*)
    fun mk_prod [] = one
      | mk_prod [t] = t
      | mk_prod (t :: ts) = if t = one then mk_prod ts
                            else mk_times (t, mk_prod ts);

    val dest_times = HOLogic.dest_bin "op *" HOLogic.natT;

    (*Flattens a nested nat product into the list of its factors.*)
    fun dest_prod t =
          let val (t,u) = dest_times t
          in  dest_prod t @ dest_prod u  end
          handle TERM _ => [t];

    (*DON'T do the obvious simplifications; that would create special cases*)
    fun mk_coeff (k,t) = mk_times (mk_numeral k, t);

    (*Express t as a product of (possibly) a numeral with other factors,
      sorted.  Returns the numeral's value (1 if there is none) and the
      product of the remaining factors.*)
    fun dest_coeff t =
        let val ts = sort Term.term_ord (dest_prod t)
            val (n, _, ts') = find_first_numeral [] ts
                              handle TERM _ => (1, one, ts)
        in (n, mk_prod ts') end;

    (*Find first coefficient-term THAT MATCHES u.  Returns its coefficient
      and the other terms in their original order; raises TERM if no term
      matches.  The TERM handler covers only dest_coeff: if it also wrapped
      the recursive call, a failed search would be retried at every level,
      taking time exponential in the number of terms.*)
    fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
      | find_first_coeff past u (t::terms) =
            (case ([dest_coeff t] handle TERM _ => []) of
                 [(n, u')] => if u aconv u' then (n, rev past @ terms)
                              else find_first_coeff (t::past) u terms
               | _ => find_first_coeff (t::past) u terms);


    (*The add_0, add_0_right, mult_1 and mult_1_right rules passed through
      rename_numerals, presumably so that they simplify #0+n, n+#0, #1*n and
      n*#1 to n.*)
    val add_0s = map rename_numerals [add_0, add_0_right];
    val mult_1s = map rename_numerals [mult_1, mult_1_right];

    (*Final simplification: cancel + and *; replace #0 by 0 and #1 by 1*)
    val simplify_meta_eq =
        Int_Numeral_Simprocs.simplify_meta_eq
             [numeral_0_eq_0, numeral_1_eq_1, add_0, add_0_right,
             mult_0, mult_0_right, mult_1, mult_1_right];

    (*Normalisation shared by the CancelNumerals and CombineNumerals
      instances.  It first rewrites with add_0s, mult_1s, add_0 and
      Suc_eq_add_numeral_1 modulo add_ac, then with bin_simps modulo add_ac
      and mult_ac.*)
    val nat_norm_tac =
        ALLGOALS (simp_tac (HOL_ss addsimps add_0s@mult_1s@
                                            [add_0, Suc_eq_add_numeral_1]@add_ac))
        THEN ALLGOALS (simp_tac (HOL_ss addsimps bin_simps@add_ac@mult_ac));

    (*Numeral evaluation shared by the same two instances; numeral_0_eq_0
      RS sym presumably rewrites 0 back to #0 so bin_simps can apply.*)
    val nat_numeral_simp_tac =
        ALLGOALS (simp_tac (HOL_ss addsimps [numeral_0_eq_0 RS sym]@add_0s@bin_simps));


    (*** Instantiating CancelNumeralsFun ***)

    structure CancelNumeralsCommon =
      struct
      val mk_sum            = mk_sum
      val dest_sum          = dest_Sucs_sum
      val mk_coeff          = mk_coeff
      val dest_coeff        = dest_coeff
      val find_first_coeff  = find_first_coeff []
      val trans_tac         = trans_tac
      val norm_tac          = nat_norm_tac
      val numeral_simp_tac  = nat_numeral_simp_tac
      val simplify_meta_eq  = simplify_meta_eq
      end;


    structure EqCancelNumerals = CancelNumeralsFun
     (open CancelNumeralsCommon
      val prove_conv = prove_conv "nateq_cancel_numerals"
      val mk_bal   = HOLogic.mk_eq
      val dest_bal = HOLogic.dest_bin "op =" HOLogic.natT
      val bal_add1 = nat_eq_add_iff1 RS trans
      val bal_add2 = nat_eq_add_iff2 RS trans
    );

    structure LessCancelNumerals = CancelNumeralsFun
     (open CancelNumeralsCommon
      val prove_conv = prove_conv "natless_cancel_numerals"
      val mk_bal   = HOLogic.mk_binrel "op <"
      val dest_bal = HOLogic.dest_bin "op <" HOLogic.natT
      val bal_add1 = nat_less_add_iff1 RS trans
      val bal_add2 = nat_less_add_iff2 RS trans
    );

    structure LeCancelNumerals = CancelNumeralsFun
     (open CancelNumeralsCommon
      val prove_conv = prove_conv "natle_cancel_numerals"
      val mk_bal   = HOLogic.mk_binrel "op <="
      val dest_bal = HOLogic.dest_bin "op <=" HOLogic.natT
      val bal_add1 = nat_le_add_iff1 RS trans
      val bal_add2 = nat_le_add_iff2 RS trans
    );

    structure DiffCancelNumerals = CancelNumeralsFun
     (open CancelNumeralsCommon
      val prove_conv = prove_conv "natdiff_cancel_numerals"
      val mk_bal   = HOLogic.mk_binop "op -"
      val dest_bal = HOLogic.dest_bin "op -" HOLogic.natT
      val bal_add1 = nat_diff_add_eq1 RS trans
      val bal_add2 = nat_diff_add_eq2 RS trans
    );


    val cancel_numerals =
      map prep_simproc
       [("nateq_cancel_numerals",
         prep_pats ["(l::nat) + m = n", "(l::nat) = m + n",
                    "(l::nat) * m = n", "(l::nat) = m * n",
                    "Suc m = n", "m = Suc n"],
         EqCancelNumerals.proc),
        ("natless_cancel_numerals",
         prep_pats ["(l::nat) + m < n", "(l::nat) < m + n",
                    "(l::nat) * m < n", "(l::nat) < m * n",
                    "Suc m < n", "m < Suc n"],
         LessCancelNumerals.proc),
        ("natle_cancel_numerals",
         prep_pats ["(l::nat) + m <= n", "(l::nat) <= m + n",
                    "(l::nat) * m <= n", "(l::nat) <= m * n",
                    "Suc m <= n", "m <= Suc n"],
         LeCancelNumerals.proc),
        ("natdiff_cancel_numerals",
         prep_pats ["((l::nat) + m) - n", "(l::nat) - (m + n)",
                    "(l::nat) * m - n", "(l::nat) - m * n",
                    "Suc m - n", "m - Suc n"],
         DiffCancelNumerals.proc)];


    (*** Instantiating CombineNumeralsFun ***)

    structure CombineNumeralsData =
      struct
      val add               = op + : int*int -> int
      val mk_sum            = long_mk_sum    (*to work for e.g. #2*x + #3*x *)
      val dest_sum          = dest_Sucs_sum
      val mk_coeff          = mk_coeff
      val dest_coeff        = dest_coeff
      val left_distrib      = left_add_mult_distrib RS trans
      val prove_conv =
           Int_Numeral_Simprocs.prove_conv_nohyps "nat_combine_numerals"
      val trans_tac         = trans_tac
      val norm_tac          = nat_norm_tac
      val numeral_simp_tac  = nat_numeral_simp_tac
      val simplify_meta_eq  = simplify_meta_eq
      end;

    structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);

    val combine_numerals =
        prep_simproc ("nat_combine_numerals",
                      prep_pats ["(i::nat) + j", "Suc (i + j)"],
                      CombineNumerals.proc);


    (*** Instantiating CancelNumeralFactorFun ***)

    structure CancelNumeralFactorCommon =
      struct
      val mk_coeff          = mk_coeff
      val dest_coeff        = dest_coeff
      val trans_tac         = trans_tac
      val norm_tac = ALLGOALS (simp_tac (HOL_ss addsimps
                                                 [Suc_eq_add_numeral_1]@mult_1s))
                     THEN ALLGOALS (simp_tac (HOL_ss addsimps bin_simps@mult_ac))
      val numeral_simp_tac  = ALLGOALS (simp_tac (HOL_ss addsimps bin_simps))
      val simplify_meta_eq  = simplify_meta_eq
      end

    structure DivCancelNumeralFactor = CancelNumeralFactorFun
     (open CancelNumeralFactorCommon
      val prove_conv = prove_conv "natdiv_cancel_numeral_factor"
      val mk_bal   = HOLogic.mk_binop "Divides.op div"
      val dest_bal = HOLogic.dest_bin "Divides.op div" HOLogic.natT
      val cancel = nat_mult_div_cancel1 RS trans
      val neg_exchanges = false
    )

    structure EqCancelNumeralFactor = CancelNumeralFactorFun
     (open CancelNumeralFactorCommon
      val prove_conv = prove_conv "nateq_cancel_numeral_factor"
      val mk_bal   = HOLogic.mk_eq
      val dest_bal = HOLogic.dest_bin "op =" HOLogic.natT
      val cancel = nat_mult_eq_cancel1 RS trans
      val neg_exchanges = false
    )

    structure LessCancelNumeralFactor = CancelNumeralFactorFun
     (open CancelNumeralFactorCommon
      val prove_conv = prove_conv "natless_cancel_numeral_factor"
      val mk_bal   = HOLogic.mk_binrel "op <"
      val dest_bal = HOLogic.dest_bin "op <" HOLogic.natT
      val cancel = nat_mult_less_cancel1 RS trans
      val neg_exchanges = true
    )

    structure LeCancelNumeralFactor = CancelNumeralFactorFun
     (open CancelNumeralFactorCommon
      val prove_conv = prove_conv "natle_cancel_numeral_factor"
      val mk_bal   = HOLogic.mk_binrel "op <="
      val dest_bal = HOLogic.dest_bin "op <=" HOLogic.natT
      val cancel = nat_mult_le_cancel1 RS trans
      val neg_exchanges = true
    )

    val cancel_numeral_factors =
      map prep_simproc
       [("nateq_cancel_numeral_factors",
         prep_pats ["(l::nat) * m = n", "(l::nat) = m * n"],
         EqCancelNumeralFactor.proc),
        ("natless_cancel_numeral_factors",
         prep_pats ["(l::nat) * m < n", "(l::nat) < m * n"],
         LessCancelNumeralFactor.proc),
        ("natle_cancel_numeral_factors",
         prep_pats ["(l::nat) * m <= n", "(l::nat) <= m * n"],
         LeCancelNumeralFactor.proc),
        ("natdiv_cancel_numeral_factors",
         prep_pats ["((l::nat) * m) div n", "(l::nat) div (m * n)"],
         DivCancelNumeralFactor.proc)];

    end;
   375 
   376 
   (*Install the nat numeral simprocs in the current simpset, in the order
     cancel_numerals, combine_numerals, cancel_numeral_factors.*)
   Addsimprocs (Nat_Numeral_Simprocs.cancel_numerals @
                [Nat_Numeral_Simprocs.combine_numerals] @
                Nat_Numeral_Simprocs.cancel_numeral_factors);
   380 
   381 
   382 (*examples:
   383 print_depth 22;
   384 set timing;
   385 set trace_simp;
   386 fun test s = (Goal s; by (Simp_tac 1));
   387 
   388 (*cancel_numerals*)
   389 test "l +( #2) + (#2) + #2 + (l + #2) + (oo  + #2) = (uu::nat)";
   390 test "(#2*length xs < #2*length xs + j)";
   391 test "(#2*length xs < length xs * #2 + j)";
   392 test "#2*u = (u::nat)";
   393 test "#2*u = Suc (u)";
   394 test "(i + j + #12 + (k::nat)) - #15 = y";
   395 test "(i + j + #12 + (k::nat)) - #5 = y";
   396 test "Suc u - #2 = y";
   397 test "Suc (Suc (Suc u)) - #2 = y";
   398 test "(i + j + #2 + (k::nat)) - 1 = y";
   399 test "(i + j + #1 + (k::nat)) - 2 = y";
   400 
   401 test "(#2*x + (u*v) + y) - v*#3*u = (w::nat)";
   402 test "(#2*x*u*v + #5 + (u*v)*#4 + y) - v*u*#4 = (w::nat)";
   403 test "(#2*x*u*v + (u*v)*#4 + y) - v*u = (w::nat)";
   404 test "Suc (Suc (#2*x*u*v + u*#4 + y)) - u = w";
   405 test "Suc ((u*v)*#4) - v*#3*u = w";
   406 test "Suc (Suc ((u*v)*#3)) - v*#3*u = w";
   407 
   408 test "(i + j + #12 + (k::nat)) = u + #15 + y";
   409 test "(i + j + #32 + (k::nat)) - (u + #15 + y) = zz";
   410 test "(i + j + #12 + (k::nat)) = u + #5 + y";
   411 (*Suc*)
   412 test "(i + j + #12 + k) = Suc (u + y)";
   413 test "Suc (Suc (Suc (Suc (Suc (u + y))))) <= ((i + j) + #41 + k)";
   414 test "(i + j + #5 + k) < Suc (Suc (Suc (Suc (Suc (u + y)))))";
   415 test "Suc (Suc (Suc (Suc (Suc (u + y))))) - #5 = v";
   416 test "(i + j + #5 + k) = Suc (Suc (Suc (Suc (Suc (Suc (Suc (u + y)))))))";
   417 test "#2*y + #3*z + #2*u = Suc (u)";
   418 test "#2*y + #3*z + #6*w + #2*y + #3*z + #2*u = Suc (u)";
   419 test "#2*y + #3*z + #6*w + #2*y + #3*z + #2*u = #2*y' + #3*z' + #6*w' + #2*y' + #3*z' + u + (vv::nat)";
   420 test "#6 + #2*y + #3*z + #4*u = Suc (vv + #2*u + z)";
   421 test "(#2*n*m) < (#3*(m*n)) + (u::nat)";
   422 
   423 (*negative numerals: FAIL*)
   424 test "(i + j + #-23 + (k::nat)) < u + #15 + y";
   425 test "(i + j + #3 + (k::nat)) < u + #-15 + y";
   426 test "(i + j + #-12 + (k::nat)) - #15 = y";
   427 test "(i + j + #12 + (k::nat)) - #-15 = y";
   428 test "(i + j + #-12 + (k::nat)) - #-15 = y";
   429 
   430 (*combine_numerals*)
   431 test "k + #3*k = (u::nat)";
   432 test "Suc (i + #3) = u";
   433 test "Suc (i + j + #3 + k) = u";
   434 test "k + j + #3*k + j = (u::nat)";
   435 test "Suc (j*i + i + k + #5 + #3*k + i*j*#4) = (u::nat)";
   436 test "(#2*n*m) + (#3*(m*n)) = (u::nat)";
   437 (*negative numerals: FAIL*)
   438 test "Suc (i + j + #-3 + k) = u";
   439 
   440 (*cancel_numeral_factor*)
   441 test "#9*x = #12 * (y::nat)";
   442 test "(#9*x) div (#12 * (y::nat)) = z";
   443 test "#9*x < #12 * (y::nat)";
   444 test "#9*x <= #12 * (y::nat)";
   445 *)
   446 
   447 
   (*** Prepare linear arithmetic for nat numerals ***)

   local

   (*Numeral rules added to the linear-arithmetic simpset; the original note
     says they reduce contradictory comparisons such as #3 <= #2 to False.*)
   val add_rules =
     [add_nat_number_of, diff_nat_number_of, mult_nat_number_of,
      eq_nat_number_of, less_nat_number_of, le_nat_number_of_eq_not_less,
      le_Suc_number_of, le_number_of_Suc,
      less_Suc_number_of, less_number_of_Suc,
      Suc_eq_number_of, eq_number_of_Suc,
      eq_number_of_0, eq_0_number_of, less_0_number_of,
      nat_number_of, Let_number_of, if_True, if_False];

   (*Simprocs for the arithmetic simpset: product associativity, then
     combine_numerals, then the cancel_numerals family.*)
   val simprocs =
     Nat_Times_Assoc.conv ::
     Nat_Numeral_Simprocs.combine_numerals ::
     Nat_Numeral_Simprocs.cancel_numerals;

   (*Extends only the simpset field of the Fast_Arith data; the other
     fields are passed through unchanged.*)
   fun add_nat_numeral_simps {add_mono_thms, inj_thms, lessD, simpset} =
     {add_mono_thms = add_mono_thms, inj_thms = inj_thms, lessD = lessD,
      simpset = simpset addsimps add_rules
                        addsimps basic_renamed_arith_simps
                        addsimprocs simprocs};

   in

   val nat_simprocs_setup = [Fast_Arith.map_data add_nat_numeral_simps];

   end;