optimization for division by powers of 2
authorpaulson
Wed, 14 Jul 1999 10:40:51 +0200
changeset 6999 73f681047e5f
parent 6998 8a1a39b8fad8
child 7000 6920cf9b8623
optimization for division by powers of 2
src/HOL/Integ/IntDiv.ML
--- a/src/HOL/Integ/IntDiv.ML	Wed Jul 14 10:40:11 1999 +0200
+++ b/src/HOL/Integ/IntDiv.ML	Wed Jul 14 10:40:51 1999 +0200
@@ -280,24 +280,38 @@
 Goal "[| (#0::int) <= a;  a < b |] ==> a div b = #0";
 by (rtac quorem_div 1);
 by (auto_tac (claset(), simpset() addsimps [quorem_def]));
-qed "pos_div_trivial";
+qed "div_pos_pos_trivial";
 
 Goal "[| a <= (#0::int);  b < a |] ==> a div b = #0";
 by (rtac quorem_div 1);
 by (auto_tac (claset(), simpset() addsimps [quorem_def]));
-qed "neg_div_trivial";
+qed "div_neg_neg_trivial";
+
+Goal "[| (#0::int) < a;  a+b <= #0 |] ==> a div b = #-1";
+by (rtac quorem_div 1);
+by (auto_tac (claset(), simpset() addsimps [quorem_def]));
+qed "div_pos_neg_trivial";
+
+(*There is no div_neg_pos_trivial because  #0 div b = #0 would supersede it*)
 
 Goal "[| (#0::int) <= a;  a < b |] ==> a mod b = a";
 by (rtac quorem_mod 1);
 by (auto_tac (claset(), simpset() addsimps [quorem_def]));
 by (rtac zmult_0_right 1);
-qed "pos_mod_trivial";
+qed "mod_pos_pos_trivial";
 
 Goal "[| a <= (#0::int);  b < a |] ==> a mod b = a";
 by (rtac quorem_mod 1);
 by (auto_tac (claset(), simpset() addsimps [quorem_def]));
 by (rtac zmult_0_right 1);
-qed "neg_mod_trivial";
+qed "mod_neg_neg_trivial";
+
+Goal "[| (#0::int) < a;  a+b <= #0 |] ==> a mod b = a+b";
+by (res_inst_tac [("q","#-1")] quorem_mod 1);
+by (auto_tac (claset(), simpset() addsimps [quorem_def]));
+qed "mod_pos_neg_trivial";
+
+(*There is no mod_neg_pos_trivial...*)
 
 
 (*Simpler laws such as -a div b = -(a div b) FAIL*)
@@ -604,6 +618,7 @@
 by Auto_tac;
 val lemma = result();
 
+(*NOT suitable for rewriting: the RHS has an instance of the LHS*)
 Goal "(a+b) div (c::int) = a div c + b div c + ((a mod c + b mod c) div c)";
 by (undefined_case_tac "c = #0" 1);
 by (blast_tac (claset() addIs [[quorem_div_mod,quorem_div_mod]
@@ -620,18 +635,18 @@
 Goal "(a mod b) div b = (#0::int)";
 by (undefined_case_tac "b = #0" 1);
 by (auto_tac (claset(), 
-        simpset() addsimps [linorder_neq_iff, 
-			    pos_mod_sign, pos_mod_bound, pos_div_trivial, 
-			    neg_mod_sign, neg_mod_bound, neg_div_trivial]));
+       simpset() addsimps [linorder_neq_iff, 
+			   pos_mod_sign, pos_mod_bound, div_pos_pos_trivial, 
+			   neg_mod_sign, neg_mod_bound, div_neg_neg_trivial]));
 qed "mod_div_trivial";
 Addsimps [mod_div_trivial];
 
 Goal "(a mod b) mod b = a mod (b::int)";
 by (undefined_case_tac "b = #0" 1);
 by (auto_tac (claset(), 
-        simpset() addsimps [linorder_neq_iff, 
-			    pos_mod_sign, pos_mod_bound, pos_mod_trivial, 
-			    neg_mod_sign, neg_mod_bound, neg_mod_trivial]));
+       simpset() addsimps [linorder_neq_iff, 
+			   pos_mod_sign, pos_mod_bound, mod_pos_pos_trivial, 
+			   neg_mod_sign, neg_mod_bound, mod_neg_neg_trivial]));
 qed "mod_mod_trivial";
 Addsimps [mod_mod_trivial];
 
@@ -673,7 +688,7 @@
 by (rtac zmult_zle_mono2_neg 1);
 by (auto_tac
     (claset(),
-     simpset() addsimps zcompare_rls@[pos_mod_bound]));
+     simpset() addsimps zcompare_rls@[add1_zle_eq,pos_mod_bound]));
 val lemma1 = result();
 
 Goal "[| (#0::int) < c;   b < r;  r <= #0 |] ==> r + b * (q mod c) <= #0";
@@ -698,7 +713,7 @@
 by (rtac zmult_zle_mono2 2);
 by (auto_tac
     (claset(),
-     simpset() addsimps zcompare_rls@[pos_mod_bound]));
+     simpset() addsimps zcompare_rls@[add1_zle_eq,pos_mod_bound]));
 val lemma4 = result();
 
 
@@ -743,8 +758,7 @@
 Goal "[| b < (#0::int);  c ~= #0 |] ==> (c*a) div (c*b) = a div b";
 by (subgoal_tac "(c * -a) div (c * -b) = -a div -b" 1);
 by (rtac lemma1 2);
-by (auto_tac (claset(), 
-	      simpset() addsimps [zdiv_zminus_zminus]));
+by Auto_tac;
 val lemma2 = result();
 
 Goal "c ~= (#0::int) ==> (c*a) div (c*b) = a div b";
@@ -777,8 +791,9 @@
 	      simpset() addsimps [zmod_zminus_zminus]));
 val lemma2 = result();
 
-Goal "c ~= (#0::int) ==> (c*a) mod (c*b) = c * (a mod b)";
+Goal "(c*a) mod (c*b) = (c::int) * (a mod b)";
 by (undefined_case_tac "b = #0" 1);
+by (undefined_case_tac "c = #0" 1);
 by (auto_tac
     (claset(), 
      simpset() delsimps zmult_ac
@@ -786,10 +801,70 @@
 			 lemma1, lemma2]));
 qed "zmod_zmult_zmult1";
 
-Goal "c ~= (#0::int) ==> (a*c) mod (b*c) = (a mod b) * c";
-by (dtac zmod_zmult_zmult1 1);
+Goal "(a*c) mod (b*c) = (a mod b) * (c::int)";
+by (cut_inst_tac [("c","c")] zmod_zmult_zmult1 1);
 by Auto_tac;
 qed "zmod_zmult_zmult2";
 
 
+(*** Speeding up the division algorithm with shifting ***)
 
+(** NB Could do the same thing for "mod" **)
+
+Goal "(#0::int) <= a ==> (#1 + #2*b) div (#2*a) = b div a";
+by (undefined_case_tac "a = #0" 1);
+by (subgoal_tac "#1 <= a" 1);
+by (arith_tac 2);
+by (subgoal_tac "#1 < a * #2" 1);
+by (dres_inst_tac [("i","#1"), ("k", "#2")] zmult_zle_mono1 2);
+by (subgoal_tac "#2*(#1 + b mod a) <= #2*a" 1);
+br zmult_zle_mono2 2;
+by (auto_tac (claset(),
+	      simpset() addsimps [add1_zle_eq,pos_mod_bound]));
+by (stac zdiv_zadd1_eq 1);
+by (auto_tac (claset(),
+	      simpset() addsimps [zdiv_zmult_zmult2, zmod_zmult_zmult2, 
+				  div_pos_pos_trivial]));
+by (stac div_pos_pos_trivial 1);
+by (asm_simp_tac (simpset() addsimps [zmult_2_right, mod_pos_pos_trivial, 
+	   pos_mod_sign RS zadd_zle_mono1 RSN (2,order_trans)]) 1);
+by (auto_tac (claset(),
+	      simpset() addsimps [mod_pos_pos_trivial]));
+qed "pos_zdiv_times_2";
+
+
+Goal "a <= (#0::int) ==> (#1 + #2*b) div (#2*a) = (b+#1) div a";
+by (subgoal_tac "(#1 + #2*(-b-#1)) div (#2 * -a) = (-b-#1) div (-a)" 1);
+br pos_zdiv_times_2 2;
+by Auto_tac;
+by (subgoal_tac "(#-1 + - (b * #2)) = - (#1 + (b*#2))" 1);
+by (Simp_tac 2);
+by (asm_full_simp_tac (HOL_ss
+		       addsimps [zdiv_zminus_zminus, zdiff_def,
+				 zminus_zadd_distrib RS sym]) 1);
+qed "neg_zdiv_times_2";
+
+
+(*Not clear why this must be proved separately; probably number_of causes
+  simplification problems*)
+Goal "~ #0 <= x ==> x <= (#0::int)";
+auto();
+val lemma = result();
+
+Goal "number_of (v BIT b) div number_of (w BIT False) = \
+\         (if ~b | (#0::int) <= number_of w                   \
+\          then number_of v div (number_of w)    \
+\          else (number_of v + (#1::int)) div (number_of w))";
+by (simp_tac (simpset_of Int.thy
+			 addsimps [zadd_assoc, number_of_BIT]) 1);
+by (asm_simp_tac (simpset_of Int.thy
+		  addsimps [int_0, int_Suc, zadd_0_right,zmult_0_right, 
+			    zmult_2 RS sym, zdiv_zmult_zmult1,
+			    pos_zdiv_times_2,
+			    lemma, neg_zdiv_times_2]) 1);
+qed "zdiv_number_of_BIT";
+
+
+Addsimps [zdiv_number_of_BIT];
+
+