src/ZF/int_arith.ML
author wenzelm
Fri Mar 20 15:24:18 2009 +0100 (2009-03-20)
changeset 30607 c3d1590debd8
parent 29269 5c25a2012975
child 32149 ef59550a55d3
permissions -rw-r--r--
eliminated global SIMPSET, CLASET etc. -- refer to explicit context;
     1 (*  Title:      ZF/int_arith.ML
     2     Author:     Larry Paulson
     3 
     4 Simprocs for linear arithmetic.
     5 *)
     6 
     7 structure Int_Numeral_Simprocs =
     8 struct
     9 
    10 (*Utilities*)
    11 
    12 fun mk_numeral n = @{const integ_of} $ NumeralSyntax.mk_bin n;
    13 
    14 (*Decodes a binary INTEGER*)
    15 fun dest_numeral (Const(@{const_name integ_of}, _) $ w) =
    16      (NumeralSyntax.dest_bin w
    17       handle Match => raise TERM("Int_Numeral_Simprocs.dest_numeral:1", [w]))
    18   | dest_numeral t =  raise TERM("Int_Numeral_Simprocs.dest_numeral:2", [t]);
    19 
    20 fun find_first_numeral past (t::terms) =
    21         ((dest_numeral t, rev past @ terms)
    22          handle TERM _ => find_first_numeral (t::past) terms)
    23   | find_first_numeral past [] = raise TERM("find_first_numeral", []);
    24 
    25 val zero = mk_numeral 0;
    26 val mk_plus = FOLogic.mk_binop @{const_name "zadd"};
    27 
    28 (*Thus mk_sum[t] yields t+#0; longer sums don't have a trailing zero*)
    29 fun mk_sum []        = zero
    30   | mk_sum [t,u]     = mk_plus (t, u)
    31   | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
    32 
    33 (*this version ALWAYS includes a trailing zero*)
    34 fun long_mk_sum []        = zero
    35   | long_mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
    36 
    37 val dest_plus = FOLogic.dest_bin @{const_name "zadd"} @{typ i};
    38 
    39 (*decompose additions AND subtractions as a sum*)
    40 fun dest_summing (pos, Const (@{const_name "zadd"}, _) $ t $ u, ts) =
    41         dest_summing (pos, t, dest_summing (pos, u, ts))
    42   | dest_summing (pos, Const (@{const_name "zdiff"}, _) $ t $ u, ts) =
    43         dest_summing (pos, t, dest_summing (not pos, u, ts))
    44   | dest_summing (pos, t, ts) =
    45         if pos then t::ts else @{const zminus} $ t :: ts;
    46 
    47 fun dest_sum t = dest_summing (true, t, []);
    48 
    49 val mk_diff = FOLogic.mk_binop @{const_name "zdiff"};
    50 val dest_diff = FOLogic.dest_bin @{const_name "zdiff"} @{typ i};
    51 
    52 val one = mk_numeral 1;
    53 val mk_times = FOLogic.mk_binop @{const_name "zmult"};
    54 
    55 fun mk_prod [] = one
    56   | mk_prod [t] = t
    57   | mk_prod (t :: ts) = if t = one then mk_prod ts
    58                         else mk_times (t, mk_prod ts);
    59 
    60 val dest_times = FOLogic.dest_bin @{const_name "zmult"} @{typ i};
    61 
    62 fun dest_prod t =
    63       let val (t,u) = dest_times t
    64       in  dest_prod t @ dest_prod u  end
    65       handle TERM _ => [t];
    66 
    67 (*DON'T do the obvious simplifications; that would create special cases*)
    68 fun mk_coeff (k, t) = mk_times (mk_numeral k, t);
    69 
    70 (*Express t as a product of (possibly) a numeral with other sorted terms*)
    71 fun dest_coeff sign (Const (@{const_name "zminus"}, _) $ t) = dest_coeff (~sign) t
    72   | dest_coeff sign t =
    73     let val ts = sort TermOrd.term_ord (dest_prod t)
    74         val (n, ts') = find_first_numeral [] ts
    75                           handle TERM _ => (1, ts)
    76     in (sign*n, mk_prod ts') end;
    77 
    78 (*Find first coefficient-term THAT MATCHES u*)
    79 fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
    80   | find_first_coeff past u (t::terms) =
    81         let val (n,u') = dest_coeff 1 t
    82         in  if u aconv u' then (n, rev past @ terms)
    83                           else find_first_coeff (t::past) u terms
    84         end
    85         handle TERM _ => find_first_coeff (t::past) u terms;
    86 
    87 
    88 (*Simplify #1*n and n*#1 to n*)
    89 val add_0s = [@{thm zadd_0_intify}, @{thm zadd_0_right_intify}];
    90 
    91 val mult_1s = [@{thm zmult_1_intify}, @{thm zmult_1_right_intify},
    92                @{thm zmult_minus1}, @{thm zmult_minus1_right}];
    93 
(*Typing rules (the *_type theorems, intify_in_int and bin.intros),
  presumably added to simpsets so that integer typing side conditions can
  be discharged by simplification -- inferred from the names, confirm.*)
val tc_rules = [@{thm integ_of_type}, @{thm intify_in_int},
                @{thm int_of_type}, @{thm zadd_type}, @{thm zdiff_type}, @{thm zmult_type}] @ 
               @{thms bin.intros};
(*Rules about intify for zadd, zdiff, zmult, zless and zle; presumably
  they eliminate redundant intify coercions -- inferred from the names.*)
val intifys = [@{thm intify_ident}, @{thm zadd_intify1}, @{thm zadd_intify2},
               @{thm zdiff_intify1}, @{thm zdiff_intify2}, @{thm zmult_intify1}, @{thm zmult_intify2},
               @{thm zless_intify1}, @{thm zless_intify2}, @{thm zle_intify1}, @{thm zle_intify2}];

(*To perform binary arithmetic*)
val bin_simps = [@{thm add_integ_of_left}] @ @{thms bin_arith_simps} @ @{thms bin_rel_simps};

(*To evaluate binary negations of coefficients*)
val zminus_simps = @{thms NCons_simps} @
                   [@{thm integ_of_minus} RS sym,
                    @{thm bin_minus_1}, @{thm bin_minus_0}, @{thm bin_minus_Pls}, @{thm bin_minus_Min},
                    @{thm bin_pred_1}, @{thm bin_pred_0}, @{thm bin_pred_Pls}, @{thm bin_pred_Min}];

(*To let us treat subtraction as addition*)
val diff_simps = [@{thm zdiff_def}, @{thm zminus_zadd_distrib}, @{thm zminus_zminus}];

(*push the unary minus down: - x * y = x * - y
  Derived by chaining zmult_zminus with the reversed zmult_zminus_right
  through trans.*)
val int_minus_mult_eq_1_to_2 =
    [@{thm zmult_zminus}, @{thm zmult_zminus_right} RS sym] MRS trans |> standard;

(*to extract again any uncancelled minuses*)
val int_minus_from_mult_simps =
    [@{thm zminus_zminus}, @{thm zmult_zminus}, @{thm zmult_zminus_right}];

(*combine unary minus with numeric literals, however nested within a product
  (uses the derived rule int_minus_mult_eq_1_to_2 defined just above)*)
val int_mult_minus_simps =
    [@{thm zmult_assoc}, @{thm zmult_zminus} RS sym, int_minus_mult_eq_1_to_2];
   124 
   125 fun prep_simproc (name, pats, proc) =
   126   Simplifier.simproc (the_context ()) name pats proc;
   127 
(*Parameters shared by the three CancelNumeralsFun instances below (for =,
  $< and $<=): how sums and coefficient terms are built and taken apart,
  and the tactics used to prove the resulting rewrites.*)
structure CancelNumeralsCommon =
  struct
  (*The type argument is ignored.*)
  val mk_sum            = (fn T:typ => mk_sum)
  val dest_sum          = dest_sum
  val mk_coeff          = mk_coeff
  (*Coefficient extraction always starts from a positive sign.*)
  val dest_coeff        = dest_coeff 1
  (*Searches start with no skipped terms.*)
  val find_first_coeff  = find_first_coeff []
  (*Chains with iff_trans; the bal_add lemmas of the instances are likewise
    combined with iff_trans.*)
  fun trans_tac _       = ArithData.gen_trans_tac iff_trans

  (*norm_tac runs asm_simp_tac on all goals with three simpsets in turn,
    each inheriting the context of ss:
      norm_ss1: #0/#1 rules, subtraction as addition, numeral negation, AC for $+;
      norm_ss2: binary arithmetic, minus pushed into products, intify rules;
      norm_ss3: re-extraction of minuses, AC for $+ and $*, typing rules.
    NOTE(review): unlike CombineNumeralsData.norm_ss1, norm_ss1 here does not
    include intifys; confirm this difference is intended.*)
  val norm_ss1 = ZF_ss addsimps add_0s @ mult_1s @ diff_simps @ zminus_simps @ @{thms zadd_ac}
  val norm_ss2 = ZF_ss addsimps bin_simps @ int_mult_minus_simps @ intifys
  val norm_ss3 = ZF_ss addsimps int_minus_from_mult_simps @ @{thms zadd_ac} @ @{thms zmult_ac} @ tc_rules @ intifys
  fun norm_tac ss =
    ALLGOALS (asm_simp_tac (Simplifier.inherit_context ss norm_ss1))
    THEN ALLGOALS (asm_simp_tac (Simplifier.inherit_context ss norm_ss2))
    THEN ALLGOALS (asm_simp_tac (Simplifier.inherit_context ss norm_ss3))

  (*Evaluates numeral arithmetic with numeral_simp_ss.  It then runs
    asm_simp_tac with the local simpset of ss's context, presumably to
    discharge any remaining side goals.*)
  val numeral_simp_ss = ZF_ss addsimps add_0s @ bin_simps @ tc_rules @ intifys
  fun numeral_simp_tac ss =
    ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
    THEN ALLGOALS (asm_simp_tac (local_simpset_of (Simplifier.the_context ss)))
  (*Final clean-up of the produced equation with the #0 and #1 rules.*)
  val simplify_meta_eq  = ArithData.simplify_meta_eq (add_0s @ mult_1s)
  end;
   151 
   152 
(*Instances of CancelNumeralsFun for the relations =, $< and $<=.  Each
  supplies its relation's constructor and destructor (mk_bal, dest_bal) and
  two cancellation lemmas composed with iff_trans.  The name passed to
  ArithData.prove_conv equals the simproc name used in cancel_numerals.*)

(*Cancellation in equations  l = r*)
structure EqCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = ArithData.prove_conv "inteq_cancel_numerals"
  val mk_bal   = FOLogic.mk_eq
  val dest_bal = FOLogic.dest_eq
  val bal_add1 = @{thm eq_add_iff1} RS iff_trans
  val bal_add2 = @{thm eq_add_iff2} RS iff_trans
);

(*Cancellation in strict inequalities  l $< r*)
structure LessCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = ArithData.prove_conv "intless_cancel_numerals"
  val mk_bal   = FOLogic.mk_binrel @{const_name "zless"}
  val dest_bal = FOLogic.dest_bin @{const_name "zless"} @{typ i}
  val bal_add1 = @{thm less_add_iff1} RS iff_trans
  val bal_add2 = @{thm less_add_iff2} RS iff_trans
);

(*Cancellation in non-strict inequalities  l $<= r*)
structure LeCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = ArithData.prove_conv "intle_cancel_numerals"
  val mk_bal   = FOLogic.mk_binrel @{const_name "zle"}
  val dest_bal = FOLogic.dest_bin @{const_name "zle"} @{typ i}
  val bal_add1 = @{thm le_add_iff1} RS iff_trans
  val bal_add2 = @{thm le_add_iff2} RS iff_trans
);
   179 
   180 val cancel_numerals =
   181   map prep_simproc
   182    [("inteq_cancel_numerals",
   183      ["l $+ m = n", "l = m $+ n",
   184       "l $- m = n", "l = m $- n",
   185       "l $* m = n", "l = m $* n"],
   186      K EqCancelNumerals.proc),
   187     ("intless_cancel_numerals",
   188      ["l $+ m $< n", "l $< m $+ n",
   189       "l $- m $< n", "l $< m $- n",
   190       "l $* m $< n", "l $< m $* n"],
   191      K LessCancelNumerals.proc),
   192     ("intle_cancel_numerals",
   193      ["l $+ m $<= n", "l $<= m $+ n",
   194       "l $- m $<= n", "l $<= m $- n",
   195       "l $* m $<= n", "l $<= m $* n"],
   196      K LeCancelNumerals.proc)];
   197 
   198 
   199 (*version without the hyps argument*)
   200 fun prove_conv_nohyps name tacs sg = ArithData.prove_conv name tacs sg [];
   201 
(*Parameters for CombineNumeralsFun over sums.  Coefficients are ML ints
  and are combined by addition (add = op +, with left_zadd_zmult_distrib as
  the distributive law).  Results are proved with trans and
  prove_conv_nohyps, where CancelNumeralsCommon uses iff_trans.*)
structure CombineNumeralsData =
  struct
  type coeff            = int
  val iszero            = (fn x => x = 0)
  val add               = op + 
  (*long_mk_sum pads one- and two-term results with #0, presumably so that a
    two-term input that combines into one term still yields a sum.*)
  val mk_sum            = (fn T:typ => long_mk_sum) (*to work for #2*x $+ #3*x *)
  val dest_sum          = dest_sum
  val mk_coeff          = mk_coeff
  val dest_coeff        = dest_coeff 1
  val left_distrib      = @{thm left_zadd_zmult_distrib} RS trans
  val prove_conv        = prove_conv_nohyps "int_combine_numerals"
  fun trans_tac _       = ArithData.gen_trans_tac trans

  (*Same three-stage normalisation as in CancelNumeralsCommon, except that
    norm_ss1 also contains intifys.*)
  val norm_ss1 = ZF_ss addsimps add_0s @ mult_1s @ diff_simps @ zminus_simps @ @{thms zadd_ac} @ intifys
  val norm_ss2 = ZF_ss addsimps bin_simps @ int_mult_minus_simps @ intifys
  val norm_ss3 = ZF_ss addsimps int_minus_from_mult_simps @ @{thms zadd_ac} @ @{thms zmult_ac} @ tc_rules @ intifys
  fun norm_tac ss =
    ALLGOALS (asm_simp_tac (Simplifier.inherit_context ss norm_ss1))
    THEN ALLGOALS (asm_simp_tac (Simplifier.inherit_context ss norm_ss2))
    THEN ALLGOALS (asm_simp_tac (Simplifier.inherit_context ss norm_ss3))

  (*Numeral evaluation only; unlike CancelNumeralsCommon there is no extra
    pass with the context's local simpset.*)
  val numeral_simp_ss = ZF_ss addsimps add_0s @ bin_simps @ tc_rules @ intifys
  fun numeral_simp_tac ss =
    ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
  val simplify_meta_eq  = ArithData.simplify_meta_eq (add_0s @ mult_1s)
  end;
   228 
(*CombineNumeralsFun instantiated with CombineNumeralsData, packaged as a
  simproc triggered on any $+ or $- term.*)
structure CombineNumerals = CombineNumeralsFun(CombineNumeralsData);

val combine_numerals =
  prep_simproc ("int_combine_numerals", ["i $+ j", "i $- j"], K CombineNumerals.proc);
   233 
   234 
   235 
   236 (** Constant folding for integer multiplication **)
   237 
   238 (*The trick is to regard products as sums, e.g. #3 $* x $* #4 as
   239   the "sum" of #3, x, #4; the literals are then multiplied*)
   240 
   241 
   242 structure CombineNumeralsProdData =
   243   struct
   244   type coeff            = int
   245   val iszero            = (fn x => x = 0)
   246   val add               = op *
   247   val mk_sum            = (fn T:typ => mk_prod)
   248   val dest_sum          = dest_prod
   249   fun mk_coeff(k,t) = if t=one then mk_numeral k
   250                       else raise TERM("mk_coeff", [])
   251   fun dest_coeff t = (dest_numeral t, one)  (*We ONLY want pure numerals.*)
   252   val left_distrib      = @{thm zmult_assoc} RS sym RS trans
   253   val prove_conv        = prove_conv_nohyps "int_combine_numerals_prod"
   254   fun trans_tac _       = ArithData.gen_trans_tac trans
   255 
   256 
   257 
   258 val norm_ss1 = ZF_ss addsimps mult_1s @ diff_simps @ zminus_simps
   259   val norm_ss2 = ZF_ss addsimps [@{thm zmult_zminus_right} RS sym] @
   260     bin_simps @ @{thms zmult_ac} @ tc_rules @ intifys
   261   fun norm_tac ss =
   262     ALLGOALS (asm_simp_tac (Simplifier.inherit_context ss norm_ss1))
   263     THEN ALLGOALS (asm_simp_tac (Simplifier.inherit_context ss norm_ss2))
   264 
   265   val numeral_simp_ss = ZF_ss addsimps bin_simps @ tc_rules @ intifys
   266   fun numeral_simp_tac ss =
   267     ALLGOALS (simp_tac (Simplifier.inherit_context ss numeral_simp_ss))
   268   val simplify_meta_eq  = ArithData.simplify_meta_eq (mult_1s);
   269   end;
   270 
   271 
(*CombineNumeralsFun instantiated with CombineNumeralsProdData (literals in
  a product are multiplied together), packaged as a simproc triggered on
  any $* term.*)
structure CombineNumeralsProd = CombineNumeralsFun(CombineNumeralsProdData);

val combine_numerals_prod =
  prep_simproc ("int_combine_numerals_prod", ["i $* j"], K CombineNumeralsProd.proc);
   276 
   277 end;
   278 
   279 
   280 Addsimprocs Int_Numeral_Simprocs.cancel_numerals;
   281 Addsimprocs [Int_Numeral_Simprocs.combine_numerals,
   282              Int_Numeral_Simprocs.combine_numerals_prod];
   283 
   284 
   285 (*examples:*)
   286 (*
   287 print_depth 22;
   288 set timing;
   289 set trace_simp;
   290 fun test s = (Goal s; by (Asm_simp_tac 1));
   291 val sg = #sign (rep_thm (topthm()));
   292 val t = FOLogic.dest_Trueprop (Logic.strip_assums_concl(getgoal 1));
   293 val (t,_) = FOLogic.dest_eq t;
   294 
   295 (*combine_numerals_prod (products of separate literals) *)
   296 test "#5 $* x $* #3 = y";
   297 
   298 test "y2 $+ ?x42 = y $+ y2";
   299 
   300 test "oo : int ==> l $+ (l $+ #2) $+ oo = oo";
   301 
   302 test "#9$*x $+ y = x$*#23 $+ z";
   303 test "y $+ x = x $+ z";
   304 
   305 test "x : int ==> x $+ y $+ z = x $+ z";
   306 test "x : int ==> y $+ (z $+ x) = z $+ x";
   307 test "z : int ==> x $+ y $+ z = (z $+ y) $+ (x $+ w)";
   308 test "z : int ==> x$*y $+ z = (z $+ y) $+ (y$*x $+ w)";
   309 
   310 test "#-3 $* x $+ y $<= x $* #2 $+ z";
   311 test "y $+ x $<= x $+ z";
   312 test "x $+ y $+ z $<= x $+ z";
   313 
   314 test "y $+ (z $+ x) $< z $+ x";
   315 test "x $+ y $+ z $< (z $+ y) $+ (x $+ w)";
   316 test "x$*y $+ z $< (z $+ y) $+ (y$*x $+ w)";
   317 
   318 test "l $+ #2 $+ #2 $+ #2 $+ (l $+ #2) $+ (oo $+ #2) = uu";
   319 test "u : int ==> #2 $* u = u";
   320 test "(i $+ j $+ #12 $+ k) $- #15 = y";
   321 test "(i $+ j $+ #12 $+ k) $- #5 = y";
   322 
   323 test "y $- b $< b";
   324 test "y $- (#3 $* b $+ c) $< b $- #2 $* c";
   325 
   326 test "(#2 $* x $- (u $* v) $+ y) $- v $* #3 $* u = w";
   327 test "(#2 $* x $* u $* v $+ (u $* v) $* #4 $+ y) $- v $* u $* #4 = w";
   328 test "(#2 $* x $* u $* v $+ (u $* v) $* #4 $+ y) $- v $* u = w";
   329 test "u $* v $- (x $* u $* v $+ (u $* v) $* #4 $+ y) = w";
   330 
   331 test "(i $+ j $+ #12 $+ k) = u $+ #15 $+ y";
   332 test "(i $+ j $* #2 $+ #12 $+ k) = j $+ #5 $+ y";
   333 
   334 test "#2 $* y $+ #3 $* z $+ #6 $* w $+ #2 $* y $+ #3 $* z $+ #2 $* u = #2 $* y' $+ #3 $* z' $+ #6 $* w' $+ #2 $* y' $+ #3 $* z' $+ u $+ vv";
   335 
   336 test "a $+ $-(b$+c) $+ b = d";
   337 test "a $+ $-(b$+c) $- b = d";
   338 
   339 (*negative numerals*)
   340 test "(i $+ j $+ #-2 $+ k) $- (u $+ #5 $+ y) = zz";
   341 test "(i $+ j $+ #-3 $+ k) $< u $+ #5 $+ y";
   342 test "(i $+ j $+ #3 $+ k) $< u $+ #-6 $+ y";
   343 test "(i $+ j $+ #-12 $+ k) $- #15 = y";
   344 test "(i $+ j $+ #12 $+ k) $- #-15 = y";
   345 test "(i $+ j $+ #-12 $+ k) $- #-15 = y";
   346 
   347 (*Multiplying separated numerals*)
   348 Goal "#6 $* ($# x $* #2) =  uu";
   349 Goal "#4 $* ($# x $* $# x) $* (#2 $* $# x) =  uu";
   350 *)
   351