monomorphization of divmod wrt. code generation avoids costly dictionary unpacking at runtime
authorhaftmann
Sun, 27 Sep 2015 10:11:15 +0200
changeset 61275 053ec04ea866
parent 61274 0261eec37233
child 61276 8a4bd05c1735
monomorphization of divmod wrt. code generation avoids costly dictionary unpacking at runtime
src/HOL/Code_Numeral.thy
src/HOL/Divides.thy
src/HOL/Library/Code_Target_Int.thy
src/HOL/Library/Code_Target_Nat.thy
src/HOL/NSA/StarDef.thy
--- a/src/HOL/Code_Numeral.thy	Sun Sep 27 10:11:14 2015 +0200
+++ b/src/HOL/Code_Numeral.thy	Sun Sep 27 10:11:15 2015 +0200
@@ -147,6 +147,25 @@
   "int_of_integer (Num.sub k l) = Num.sub k l"
   by transfer rule
 
+lift_definition integer_of_num :: "num \<Rightarrow> integer"
+  is "numeral :: num \<Rightarrow> int"
+  .
+
+lemma integer_of_num [code]:
+  "integer_of_num num.One = 1"
+  "integer_of_num (num.Bit0 n) = (let k = integer_of_num n in k + k)"
+  "integer_of_num (num.Bit1 n) = (let k = integer_of_num n in k + k + 1)"
+  by (transfer, simp only: numeral.simps Let_def)+
+
+lemma numeral_unfold_integer_of_num:
+  "numeral = integer_of_num"
+  by (simp add: integer_of_num_def map_fun_def fun_eq_iff)
+
+lemma integer_of_num_triv:
+  "integer_of_num Num.One = 1"
+  "integer_of_num (Num.Bit0 Num.One) = 2"
+  by (transfer, simp)+
+
 instantiation integer :: "{ring_div, equal, linordered_idom}"
 begin
 
@@ -215,18 +234,43 @@
   "of_nat (nat_of_integer k) = max 0 k"
   by transfer auto
 
-instance integer :: semiring_numeral_div
-  by intro_classes (transfer,
-    fact le_add_diff_inverse2
-    semiring_numeral_div_class.div_less
-    semiring_numeral_div_class.mod_less
-    semiring_numeral_div_class.div_positive
-    semiring_numeral_div_class.mod_less_eq_dividend
-    semiring_numeral_div_class.pos_mod_bound
-    semiring_numeral_div_class.pos_mod_sign
-    semiring_numeral_div_class.mod_mult2_eq
-    semiring_numeral_div_class.div_mult2_eq
-    semiring_numeral_div_class.discrete)+
+instantiation integer :: semiring_numeral_div
+begin
+
+definition divmod_integer :: "num \<Rightarrow> num \<Rightarrow> integer \<times> integer"
+where
+  divmod_integer'_def: "divmod_integer m n = (numeral m div numeral n, numeral m mod numeral n)"
+
+definition divmod_step_integer :: "num \<Rightarrow> integer \<times> integer \<Rightarrow> integer \<times> integer"
+where
+  "divmod_step_integer l qr = (let (q, r) = qr
+    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
+    else (2 * q, r))"
+
+instance proof
+  show "divmod m n = (numeral m div numeral n :: integer, numeral m mod numeral n)"
+    for m n by (fact divmod_integer'_def)
+  show "divmod_step l qr = (let (q, r) = qr
+    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
+    else (2 * q, r))" for l and qr :: "integer \<times> integer"
+    by (fact divmod_step_integer_def)
+qed (transfer,
+  fact le_add_diff_inverse2
+  semiring_numeral_div_class.div_less
+  semiring_numeral_div_class.mod_less
+  semiring_numeral_div_class.div_positive
+  semiring_numeral_div_class.mod_less_eq_dividend
+  semiring_numeral_div_class.pos_mod_bound
+  semiring_numeral_div_class.pos_mod_sign
+  semiring_numeral_div_class.mod_mult2_eq
+  semiring_numeral_div_class.div_mult2_eq
+  semiring_numeral_div_class.discrete)+
+
+end
+
+declare divmod_algorithm_code [where ?'a = integer,
+  unfolded numeral_unfold_integer_of_num, unfolded integer_of_num_triv, 
+  code]
 
 lemma integer_of_nat_0: "integer_of_nat 0 = 0"
 by transfer simp
@@ -440,16 +484,6 @@
   "Neg k < Neg l \<longleftrightarrow> l < k"
   by simp_all
 
-lift_definition integer_of_num :: "num \<Rightarrow> integer"
-  is "numeral :: num \<Rightarrow> int"
-  .
-
-lemma integer_of_num [code]:
-  "integer_of_num num.One = 1"
-  "integer_of_num (num.Bit0 n) = (let k = integer_of_num n in k + k)"
-  "integer_of_num (num.Bit1 n) = (let k = integer_of_num n in k + k + 1)"
-  by (transfer, simp only: numeral.simps Let_def)+
-
 lift_definition num_of_integer :: "integer \<Rightarrow> num"
   is "num_of_nat \<circ> nat"
   .
--- a/src/HOL/Divides.thy	Sun Sep 27 10:11:14 2015 +0200
+++ b/src/HOL/Divides.thy	Sun Sep 27 10:11:15 2015 +0200
@@ -567,6 +567,16 @@
     and mod_mult2_eq: "0 \<le> c \<Longrightarrow> a mod (b * c) = b * (a div b mod c) + a mod b"
     and div_mult2_eq: "0 \<le> c \<Longrightarrow> a div (b * c) = a div b div c"
   assumes discrete: "a < b \<longleftrightarrow> a + 1 \<le> b"
+  fixes divmod :: "num \<Rightarrow> num \<Rightarrow> 'a \<times> 'a"
+    and divmod_step :: "num \<Rightarrow> 'a \<times> 'a \<Rightarrow> 'a \<times> 'a"
+  assumes divmod_def: "divmod m n = (numeral m div numeral n, numeral m mod numeral n)"
+    and divmod_step_def: "divmod_step l qr = (let (q, r) = qr
+    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
+    else (2 * q, r))"
+    -- \<open>These are conceptually definitions but force generated code
+    to be monomorphic wrt. particular instances of this class which
+    yields a significant speedup.\<close>
+
 begin
 
 lemma mult_div_cancel:
@@ -650,10 +660,6 @@
     by (simp_all add: div mod)
 qed
 
-definition divmod :: "num \<Rightarrow> num \<Rightarrow> 'a \<times> 'a"
-where
-  "divmod m n = (numeral m div numeral n, numeral m mod numeral n)"
-
 lemma fst_divmod:
   "fst (divmod m n) = numeral m div numeral n"
   by (simp add: divmod_def)
@@ -662,12 +668,6 @@
   "snd (divmod m n) = numeral m mod numeral n"
   by (simp add: divmod_def)
 
-definition divmod_step :: "num \<Rightarrow> 'a \<times> 'a \<Rightarrow> 'a \<times> 'a"
-where
-  "divmod_step l qr = (let (q, r) = qr
-    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
-    else (2 * q, r))"
-
 text \<open>
   This is a formulation of one step (referring to one digit position)
   in school-method division: compare the dividend at the current
@@ -675,7 +675,7 @@
   and evaluate accordingly.
 \<close>
 
-lemma divmod_step_eq [code, simp]:
+lemma divmod_step_eq [simp]:
   "divmod_step l (q, r) = (if numeral l \<le> r
     then (2 * q + 1, r - numeral l) else (2 * q, r))"
   by (simp add: divmod_step_def)
@@ -735,7 +735,7 @@
 
 text \<open>The division rewrite proper -- first, trivial results involving @{text 1}\<close>
 
-lemma divmod_trivial [simp, code]:
+lemma divmod_trivial [simp]:
   "divmod Num.One Num.One = (numeral Num.One, 0)"
   "divmod (Num.Bit0 m) Num.One = (numeral (Num.Bit0 m), 0)"
   "divmod (Num.Bit1 m) Num.One = (numeral (Num.Bit1 m), 0)"
@@ -745,7 +745,7 @@
 
 text \<open>Division by an even number is a right-shift\<close>
 
-lemma divmod_cancel [simp, code]:
+lemma divmod_cancel [simp]:
   "divmod (Num.Bit0 m) (Num.Bit0 n) = (case divmod m n of (q, r) \<Rightarrow> (q, 2 * r))" (is ?P)
   "divmod (Num.Bit1 m) (Num.Bit0 n) = (case divmod m n of (q, r) \<Rightarrow> (q, 2 * r + 1))" (is ?Q)
 proof -
@@ -761,7 +761,7 @@
 
 text \<open>The really hard work\<close>
 
-lemma divmod_steps [simp, code]:
+lemma divmod_steps [simp]:
   "divmod (num.Bit0 m) (num.Bit1 n) =
       (if m \<le> n then (0, numeral (num.Bit0 m))
        else divmod_step (num.Bit1 n)
@@ -774,6 +774,8 @@
                (num.Bit0 (num.Bit1 n))))"
   by (simp_all add: divmod_divmod_step)
 
+lemmas divmod_algorithm_code = divmod_step_eq divmod_trivial divmod_cancel divmod_steps  
+
 text \<open>Special case: divisibility\<close>
 
 definition divides_aux :: "'a \<times> 'a \<Rightarrow> bool"
@@ -1177,9 +1179,26 @@
 lemma mod_mult2_eq: "a mod (b * c) = b * (a div b mod c) + a mod (b::nat)"
 by (auto simp add: mult.commute divmod_nat_rel [THEN divmod_nat_rel_mult2_eq, THEN mod_nat_unique])
 
-instance nat :: semiring_numeral_div
-  by intro_classes (auto intro: div_positive simp add: mult_div_cancel mod_mult2_eq div_mult2_eq)
-
+instantiation nat :: semiring_numeral_div
+begin
+
+definition divmod_nat :: "num \<Rightarrow> num \<Rightarrow> nat \<times> nat"
+where
+  divmod'_nat_def: "divmod_nat m n = (numeral m div numeral n, numeral m mod numeral n)"
+
+definition divmod_step_nat :: "num \<Rightarrow> nat \<times> nat \<Rightarrow> nat \<times> nat"
+where
+  "divmod_step_nat l qr = (let (q, r) = qr
+    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
+    else (2 * q, r))"
+
+instance
+  by standard (auto intro: div_positive simp add: divmod'_nat_def divmod_step_nat_def mod_mult2_eq div_mult2_eq)
+
+end
+
+declare divmod_algorithm_code [where ?'a = nat, code]
+  
 
 subsubsection \<open>Further Facts about Quotient and Remainder\<close>
 
@@ -2304,12 +2323,26 @@
 
 subsubsection \<open>Computation of Division and Remainder\<close>
 
-instance int :: semiring_numeral_div
-  by intro_classes (auto intro: zmod_le_nonneg_dividend
-    simp add:
-    zmult_div_cancel
-    pos_imp_zdiv_pos_iff div_pos_pos_trivial mod_pos_pos_trivial
-    zmod_zmult2_eq zdiv_zmult2_eq)
+instantiation int :: semiring_numeral_div
+begin
+
+definition divmod_int :: "num \<Rightarrow> num \<Rightarrow> int \<times> int"
+where
+  "divmod_int m n = (numeral m div numeral n, numeral m mod numeral n)"
+
+definition divmod_step_int :: "num \<Rightarrow> int \<times> int \<Rightarrow> int \<times> int"
+where
+  "divmod_step_int l qr = (let (q, r) = qr
+    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
+    else (2 * q, r))"
+
+instance
+  by standard (auto intro: zmod_le_nonneg_dividend simp add: divmod_int_def divmod_step_int_def
+    pos_imp_zdiv_pos_iff div_pos_pos_trivial mod_pos_pos_trivial zmod_zmult2_eq zdiv_zmult2_eq)
+
+end
+
+declare divmod_algorithm_code [where ?'a = int, code]
 
 context
 begin
--- a/src/HOL/Library/Code_Target_Int.thy	Sun Sep 27 10:11:14 2015 +0200
+++ b/src/HOL/Library/Code_Target_Int.thy	Sun Sep 27 10:11:15 2015 +0200
@@ -79,6 +79,11 @@
   by simp
 
 lemma [code]:
+  "divmod m n = map_prod int_of_integer int_of_integer (divmod m n)"
+  unfolding prod_eq_iff divmod_def map_prod_def case_prod_beta fst_conv snd_conv
+  by transfer simp
+
+lemma [code]:
   "HOL.equal k l = HOL.equal (of_int k :: integer) (of_int l)"
   by transfer (simp add: equal)
 
--- a/src/HOL/Library/Code_Target_Nat.thy	Sun Sep 27 10:11:14 2015 +0200
+++ b/src/HOL/Library/Code_Target_Nat.thy	Sun Sep 27 10:11:15 2015 +0200
@@ -85,6 +85,12 @@
   by (fact divmod_nat_div_mod)
 
 lemma [code]:
+  "divmod m n = map_prod nat_of_integer nat_of_integer (divmod m n)"
+  by (simp only: prod_eq_iff divmod_def map_prod_def case_prod_beta fst_conv snd_conv)
+    (transfer, simp_all only: nat_div_distrib nat_mod_distrib
+        zero_le_numeral nat_numeral)
+  
+lemma [code]:
   "HOL.equal m n = HOL.equal (of_nat m :: integer) (of_nat n)"
   by transfer (simp add: equal)
 
--- a/src/HOL/NSA/StarDef.thy	Sun Sep 27 10:11:14 2015 +0200
+++ b/src/HOL/NSA/StarDef.thy	Sun Sep 27 10:11:15 2015 +0200
@@ -1009,18 +1009,42 @@
 apply(transfer, rule zero_not_eq_two)
 done
 
-instance star :: (semiring_numeral_div) semiring_numeral_div
-apply intro_classes
-apply(transfer, fact semiring_numeral_div_class.div_less)
-apply(transfer, fact semiring_numeral_div_class.mod_less)
-apply(transfer, fact semiring_numeral_div_class.div_positive)
-apply(transfer, fact semiring_numeral_div_class.mod_less_eq_dividend)
-apply(transfer, fact semiring_numeral_div_class.pos_mod_bound)
-apply(transfer, fact semiring_numeral_div_class.pos_mod_sign)
-apply(transfer, fact semiring_numeral_div_class.mod_mult2_eq)
-apply(transfer, fact semiring_numeral_div_class.div_mult2_eq)
-apply(transfer, fact discrete)
-done
+instantiation star :: (semiring_numeral_div) semiring_numeral_div
+begin
+
+definition divmod_star :: "num \<Rightarrow> num \<Rightarrow> 'a star \<times> 'a star"
+where
+  divmod_star_def: "divmod_star m n = (numeral m div numeral n, numeral m mod numeral n)"
+
+definition divmod_step_star :: "num \<Rightarrow> 'a star \<times> 'a star \<Rightarrow> 'a star \<times> 'a star"
+where
+  "divmod_step_star l qr = (let (q, r) = qr
+    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
+    else (2 * q, r))"
+
+instance proof
+  show "divmod m n = (numeral m div numeral n :: 'a star, numeral m mod numeral n)"
+    for m n by (fact divmod_star_def)
+  show "divmod_step l qr = (let (q, r) = qr
+    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
+    else (2 * q, r))" for l and qr :: "'a star \<times> 'a star"
+    by (fact divmod_step_star_def)
+qed (transfer,
+  fact
+  semiring_numeral_div_class.div_less
+  semiring_numeral_div_class.mod_less
+  semiring_numeral_div_class.div_positive
+  semiring_numeral_div_class.mod_less_eq_dividend
+  semiring_numeral_div_class.pos_mod_bound
+  semiring_numeral_div_class.pos_mod_sign
+  semiring_numeral_div_class.mod_mult2_eq
+  semiring_numeral_div_class.div_mult2_eq
+  semiring_numeral_div_class.discrete)+
+
+end
+
+declare divmod_algorithm_code [where ?'a = "'a::semiring_numeral_div star", code]
+
 
 subsection {* Finite class *}