monomorphization of divmod wrt. code generation avoids costly dictionary unpacking at runtime
authorhaftmann
Sun Sep 27 10:11:15 2015 +0200 (2015-09-27)
changeset 61275053ec04ea866
parent 61274 0261eec37233
child 61276 8a4bd05c1735
monomorphization of divmod wrt. code generation avoids costly dictionary unpacking at runtime
src/HOL/Code_Numeral.thy
src/HOL/Divides.thy
src/HOL/Library/Code_Target_Int.thy
src/HOL/Library/Code_Target_Nat.thy
src/HOL/NSA/StarDef.thy
     1.1 --- a/src/HOL/Code_Numeral.thy	Sun Sep 27 10:11:14 2015 +0200
     1.2 +++ b/src/HOL/Code_Numeral.thy	Sun Sep 27 10:11:15 2015 +0200
     1.3 @@ -147,6 +147,25 @@
     1.4    "int_of_integer (Num.sub k l) = Num.sub k l"
     1.5    by transfer rule
     1.6  
     1.7 +lift_definition integer_of_num :: "num \<Rightarrow> integer"
     1.8 +  is "numeral :: num \<Rightarrow> int"
     1.9 +  .
    1.10 +
    1.11 +lemma integer_of_num [code]:
    1.12 +  "integer_of_num num.One = 1"
    1.13 +  "integer_of_num (num.Bit0 n) = (let k = integer_of_num n in k + k)"
    1.14 +  "integer_of_num (num.Bit1 n) = (let k = integer_of_num n in k + k + 1)"
    1.15 +  by (transfer, simp only: numeral.simps Let_def)+
    1.16 +
    1.17 +lemma numeral_unfold_integer_of_num:
    1.18 +  "numeral = integer_of_num"
    1.19 +  by (simp add: integer_of_num_def map_fun_def fun_eq_iff)
    1.20 +
    1.21 +lemma integer_of_num_triv:
    1.22 +  "integer_of_num Num.One = 1"
    1.23 +  "integer_of_num (Num.Bit0 Num.One) = 2"
    1.24 +  by (transfer, simp)+
    1.25 +
    1.26  instantiation integer :: "{ring_div, equal, linordered_idom}"
    1.27  begin
    1.28  
    1.29 @@ -215,18 +234,43 @@
    1.30    "of_nat (nat_of_integer k) = max 0 k"
    1.31    by transfer auto
    1.32  
    1.33 -instance integer :: semiring_numeral_div
    1.34 -  by intro_classes (transfer,
    1.35 -    fact le_add_diff_inverse2
    1.36 -    semiring_numeral_div_class.div_less
    1.37 -    semiring_numeral_div_class.mod_less
    1.38 -    semiring_numeral_div_class.div_positive
    1.39 -    semiring_numeral_div_class.mod_less_eq_dividend
    1.40 -    semiring_numeral_div_class.pos_mod_bound
    1.41 -    semiring_numeral_div_class.pos_mod_sign
    1.42 -    semiring_numeral_div_class.mod_mult2_eq
    1.43 -    semiring_numeral_div_class.div_mult2_eq
    1.44 -    semiring_numeral_div_class.discrete)+
    1.45 +instantiation integer :: semiring_numeral_div
    1.46 +begin
    1.47 +
    1.48 +definition divmod_integer :: "num \<Rightarrow> num \<Rightarrow> integer \<times> integer"
    1.49 +where
    1.50 +  divmod_integer'_def: "divmod_integer m n = (numeral m div numeral n, numeral m mod numeral n)"
    1.51 +
    1.52 +definition divmod_step_integer :: "num \<Rightarrow> integer \<times> integer \<Rightarrow> integer \<times> integer"
    1.53 +where
    1.54 +  "divmod_step_integer l qr = (let (q, r) = qr
    1.55 +    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
    1.56 +    else (2 * q, r))"
    1.57 +
    1.58 +instance proof
    1.59 +  show "divmod m n = (numeral m div numeral n :: integer, numeral m mod numeral n)"
    1.60 +    for m n by (fact divmod_integer'_def)
    1.61 +  show "divmod_step l qr = (let (q, r) = qr
    1.62 +    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
    1.63 +    else (2 * q, r))" for l and qr :: "integer \<times> integer"
    1.64 +    by (fact divmod_step_integer_def)
    1.65 +qed (transfer,
    1.66 +  fact le_add_diff_inverse2
    1.67 +  semiring_numeral_div_class.div_less
    1.68 +  semiring_numeral_div_class.mod_less
    1.69 +  semiring_numeral_div_class.div_positive
    1.70 +  semiring_numeral_div_class.mod_less_eq_dividend
    1.71 +  semiring_numeral_div_class.pos_mod_bound
    1.72 +  semiring_numeral_div_class.pos_mod_sign
    1.73 +  semiring_numeral_div_class.mod_mult2_eq
    1.74 +  semiring_numeral_div_class.div_mult2_eq
    1.75 +  semiring_numeral_div_class.discrete)+
    1.76 +
    1.77 +end
    1.78 +
    1.79 +declare divmod_algorithm_code [where ?'a = integer,
    1.80 +  unfolded numeral_unfold_integer_of_num, unfolded integer_of_num_triv, 
    1.81 +  code]
    1.82  
    1.83  lemma integer_of_nat_0: "integer_of_nat 0 = 0"
    1.84  by transfer simp
    1.85 @@ -440,16 +484,6 @@
    1.86    "Neg k < Neg l \<longleftrightarrow> l < k"
    1.87    by simp_all
    1.88  
    1.89 -lift_definition integer_of_num :: "num \<Rightarrow> integer"
    1.90 -  is "numeral :: num \<Rightarrow> int"
    1.91 -  .
    1.92 -
    1.93 -lemma integer_of_num [code]:
    1.94 -  "integer_of_num num.One = 1"
    1.95 -  "integer_of_num (num.Bit0 n) = (let k = integer_of_num n in k + k)"
    1.96 -  "integer_of_num (num.Bit1 n) = (let k = integer_of_num n in k + k + 1)"
    1.97 -  by (transfer, simp only: numeral.simps Let_def)+
    1.98 -
    1.99  lift_definition num_of_integer :: "integer \<Rightarrow> num"
   1.100    is "num_of_nat \<circ> nat"
   1.101    .
     2.1 --- a/src/HOL/Divides.thy	Sun Sep 27 10:11:14 2015 +0200
     2.2 +++ b/src/HOL/Divides.thy	Sun Sep 27 10:11:15 2015 +0200
     2.3 @@ -567,6 +567,16 @@
     2.4      and mod_mult2_eq: "0 \<le> c \<Longrightarrow> a mod (b * c) = b * (a div b mod c) + a mod b"
     2.5      and div_mult2_eq: "0 \<le> c \<Longrightarrow> a div (b * c) = a div b div c"
     2.6    assumes discrete: "a < b \<longleftrightarrow> a + 1 \<le> b"
     2.7 +  fixes divmod :: "num \<Rightarrow> num \<Rightarrow> 'a \<times> 'a"
     2.8 +    and divmod_step :: "num \<Rightarrow> 'a \<times> 'a \<Rightarrow> 'a \<times> 'a"
     2.9 +  assumes divmod_def: "divmod m n = (numeral m div numeral n, numeral m mod numeral n)"
    2.10 +    and divmod_step_def: "divmod_step l qr = (let (q, r) = qr
    2.11 +    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
    2.12 +    else (2 * q, r))"
    2.13 +    -- \<open>These are conceptually definitions but force generated code
    2.14 +    to be monomorphic wrt. particular instances of this class which
    2.15 +    yields a significant speedup.\<close>
    2.16 +
    2.17  begin
    2.18  
    2.19  lemma mult_div_cancel:
    2.20 @@ -650,10 +660,6 @@
    2.21      by (simp_all add: div mod)
    2.22  qed
    2.23  
    2.24 -definition divmod :: "num \<Rightarrow> num \<Rightarrow> 'a \<times> 'a"
    2.25 -where
    2.26 -  "divmod m n = (numeral m div numeral n, numeral m mod numeral n)"
    2.27 -
    2.28  lemma fst_divmod:
    2.29    "fst (divmod m n) = numeral m div numeral n"
    2.30    by (simp add: divmod_def)
    2.31 @@ -662,12 +668,6 @@
    2.32    "snd (divmod m n) = numeral m mod numeral n"
    2.33    by (simp add: divmod_def)
    2.34  
    2.35 -definition divmod_step :: "num \<Rightarrow> 'a \<times> 'a \<Rightarrow> 'a \<times> 'a"
    2.36 -where
    2.37 -  "divmod_step l qr = (let (q, r) = qr
    2.38 -    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
    2.39 -    else (2 * q, r))"
    2.40 -
    2.41  text \<open>
    2.42    This is a formulation of one step (referring to one digit position)
    2.43    in school-method division: compare the dividend at the current
    2.44 @@ -675,7 +675,7 @@
    2.45    and evaluate accordingly.
    2.46  \<close>
    2.47  
    2.48 -lemma divmod_step_eq [code, simp]:
    2.49 +lemma divmod_step_eq [simp]:
    2.50    "divmod_step l (q, r) = (if numeral l \<le> r
    2.51      then (2 * q + 1, r - numeral l) else (2 * q, r))"
    2.52    by (simp add: divmod_step_def)
    2.53 @@ -735,7 +735,7 @@
    2.54  
    2.55  text \<open>The division rewrite proper -- first, trivial results involving @{text 1}\<close>
    2.56  
    2.57 -lemma divmod_trivial [simp, code]:
    2.58 +lemma divmod_trivial [simp]:
    2.59    "divmod Num.One Num.One = (numeral Num.One, 0)"
    2.60    "divmod (Num.Bit0 m) Num.One = (numeral (Num.Bit0 m), 0)"
    2.61    "divmod (Num.Bit1 m) Num.One = (numeral (Num.Bit1 m), 0)"
    2.62 @@ -745,7 +745,7 @@
    2.63  
    2.64  text \<open>Division by an even number is a right-shift\<close>
    2.65  
    2.66 -lemma divmod_cancel [simp, code]:
    2.67 +lemma divmod_cancel [simp]:
    2.68    "divmod (Num.Bit0 m) (Num.Bit0 n) = (case divmod m n of (q, r) \<Rightarrow> (q, 2 * r))" (is ?P)
    2.69    "divmod (Num.Bit1 m) (Num.Bit0 n) = (case divmod m n of (q, r) \<Rightarrow> (q, 2 * r + 1))" (is ?Q)
    2.70  proof -
    2.71 @@ -761,7 +761,7 @@
    2.72  
    2.73  text \<open>The really hard work\<close>
    2.74  
    2.75 -lemma divmod_steps [simp, code]:
    2.76 +lemma divmod_steps [simp]:
    2.77    "divmod (num.Bit0 m) (num.Bit1 n) =
    2.78        (if m \<le> n then (0, numeral (num.Bit0 m))
    2.79         else divmod_step (num.Bit1 n)
    2.80 @@ -774,6 +774,8 @@
    2.81                 (num.Bit0 (num.Bit1 n))))"
    2.82    by (simp_all add: divmod_divmod_step)
    2.83  
    2.84 +lemmas divmod_algorithm_code = divmod_step_eq divmod_trivial divmod_cancel divmod_steps  
    2.85 +
    2.86  text \<open>Special case: divisibility\<close>
    2.87  
    2.88  definition divides_aux :: "'a \<times> 'a \<Rightarrow> bool"
    2.89 @@ -1177,9 +1179,26 @@
    2.90  lemma mod_mult2_eq: "a mod (b * c) = b * (a div b mod c) + a mod (b::nat)"
    2.91  by (auto simp add: mult.commute divmod_nat_rel [THEN divmod_nat_rel_mult2_eq, THEN mod_nat_unique])
    2.92  
    2.93 -instance nat :: semiring_numeral_div
    2.94 -  by intro_classes (auto intro: div_positive simp add: mult_div_cancel mod_mult2_eq div_mult2_eq)
    2.95 -
    2.96 +instantiation nat :: semiring_numeral_div
    2.97 +begin
    2.98 +
    2.99 +definition divmod_nat :: "num \<Rightarrow> num \<Rightarrow> nat \<times> nat"
   2.100 +where
   2.101 +  divmod'_nat_def: "divmod_nat m n = (numeral m div numeral n, numeral m mod numeral n)"
   2.102 +
   2.103 +definition divmod_step_nat :: "num \<Rightarrow> nat \<times> nat \<Rightarrow> nat \<times> nat"
   2.104 +where
   2.105 +  "divmod_step_nat l qr = (let (q, r) = qr
   2.106 +    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
   2.107 +    else (2 * q, r))"
   2.108 +
   2.109 +instance
   2.110 +  by standard (auto intro: div_positive simp add: divmod'_nat_def divmod_step_nat_def mod_mult2_eq div_mult2_eq)
   2.111 +
   2.112 +end
   2.113 +
   2.114 +declare divmod_algorithm_code [where ?'a = nat, code]
   2.115 +  
   2.116  
   2.117  subsubsection \<open>Further Facts about Quotient and Remainder\<close>
   2.118  
   2.119 @@ -2304,12 +2323,26 @@
   2.120  
   2.121  subsubsection \<open>Computation of Division and Remainder\<close>
   2.122  
   2.123 -instance int :: semiring_numeral_div
   2.124 -  by intro_classes (auto intro: zmod_le_nonneg_dividend
   2.125 -    simp add:
   2.126 -    zmult_div_cancel
   2.127 -    pos_imp_zdiv_pos_iff div_pos_pos_trivial mod_pos_pos_trivial
   2.128 -    zmod_zmult2_eq zdiv_zmult2_eq)
   2.129 +instantiation int :: semiring_numeral_div
   2.130 +begin
   2.131 +
   2.132 +definition divmod_int :: "num \<Rightarrow> num \<Rightarrow> int \<times> int"
   2.133 +where
   2.134 +  "divmod_int m n = (numeral m div numeral n, numeral m mod numeral n)"
   2.135 +
   2.136 +definition divmod_step_int :: "num \<Rightarrow> int \<times> int \<Rightarrow> int \<times> int"
   2.137 +where
   2.138 +  "divmod_step_int l qr = (let (q, r) = qr
   2.139 +    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
   2.140 +    else (2 * q, r))"
   2.141 +
   2.142 +instance
   2.143 +  by standard (auto intro: zmod_le_nonneg_dividend simp add: divmod_int_def divmod_step_int_def
   2.144 +    pos_imp_zdiv_pos_iff div_pos_pos_trivial mod_pos_pos_trivial zmod_zmult2_eq zdiv_zmult2_eq)
   2.145 +
   2.146 +end
   2.147 +
   2.148 +declare divmod_algorithm_code [where ?'a = int, code]
   2.149  
   2.150  context
   2.151  begin
     3.1 --- a/src/HOL/Library/Code_Target_Int.thy	Sun Sep 27 10:11:14 2015 +0200
     3.2 +++ b/src/HOL/Library/Code_Target_Int.thy	Sun Sep 27 10:11:15 2015 +0200
     3.3 @@ -79,6 +79,11 @@
     3.4    by simp
     3.5  
     3.6  lemma [code]:
     3.7 +  "divmod m n = map_prod int_of_integer int_of_integer (divmod m n)"
     3.8 +  unfolding prod_eq_iff divmod_def map_prod_def case_prod_beta fst_conv snd_conv
     3.9 +  by transfer simp
    3.10 +
    3.11 +lemma [code]:
    3.12    "HOL.equal k l = HOL.equal (of_int k :: integer) (of_int l)"
    3.13    by transfer (simp add: equal)
    3.14  
     4.1 --- a/src/HOL/Library/Code_Target_Nat.thy	Sun Sep 27 10:11:14 2015 +0200
     4.2 +++ b/src/HOL/Library/Code_Target_Nat.thy	Sun Sep 27 10:11:15 2015 +0200
     4.3 @@ -85,6 +85,12 @@
     4.4    by (fact divmod_nat_div_mod)
     4.5  
     4.6  lemma [code]:
     4.7 +  "divmod m n = map_prod nat_of_integer nat_of_integer (divmod m n)"
     4.8 +  by (simp only: prod_eq_iff divmod_def map_prod_def case_prod_beta fst_conv snd_conv)
     4.9 +    (transfer, simp_all only: nat_div_distrib nat_mod_distrib
    4.10 +        zero_le_numeral nat_numeral)
    4.11 +  
    4.12 +lemma [code]:
    4.13    "HOL.equal m n = HOL.equal (of_nat m :: integer) (of_nat n)"
    4.14    by transfer (simp add: equal)
    4.15  
     5.1 --- a/src/HOL/NSA/StarDef.thy	Sun Sep 27 10:11:14 2015 +0200
     5.2 +++ b/src/HOL/NSA/StarDef.thy	Sun Sep 27 10:11:15 2015 +0200
     5.3 @@ -1009,18 +1009,42 @@
     5.4  apply(transfer, rule zero_not_eq_two)
     5.5  done
     5.6  
     5.7 -instance star :: (semiring_numeral_div) semiring_numeral_div
     5.8 -apply intro_classes
     5.9 -apply(transfer, fact semiring_numeral_div_class.div_less)
    5.10 -apply(transfer, fact semiring_numeral_div_class.mod_less)
    5.11 -apply(transfer, fact semiring_numeral_div_class.div_positive)
    5.12 -apply(transfer, fact semiring_numeral_div_class.mod_less_eq_dividend)
    5.13 -apply(transfer, fact semiring_numeral_div_class.pos_mod_bound)
    5.14 -apply(transfer, fact semiring_numeral_div_class.pos_mod_sign)
    5.15 -apply(transfer, fact semiring_numeral_div_class.mod_mult2_eq)
    5.16 -apply(transfer, fact semiring_numeral_div_class.div_mult2_eq)
    5.17 -apply(transfer, fact discrete)
    5.18 -done
    5.19 +instantiation star :: (semiring_numeral_div) semiring_numeral_div
    5.20 +begin
    5.21 +
    5.22 +definition divmod_star :: "num \<Rightarrow> num \<Rightarrow> 'a star \<times> 'a star"
    5.23 +where
    5.24 +  divmod_star_def: "divmod_star m n = (numeral m div numeral n, numeral m mod numeral n)"
    5.25 +
    5.26 +definition divmod_step_star :: "num \<Rightarrow> 'a star \<times> 'a star \<Rightarrow> 'a star \<times> 'a star"
    5.27 +where
    5.28 +  "divmod_step_star l qr = (let (q, r) = qr
    5.29 +    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
    5.30 +    else (2 * q, r))"
    5.31 +
    5.32 +instance proof
    5.33 +  show "divmod m n = (numeral m div numeral n :: 'a star, numeral m mod numeral n)"
    5.34 +    for m n by (fact divmod_star_def)
    5.35 +  show "divmod_step l qr = (let (q, r) = qr
    5.36 +    in if r \<ge> numeral l then (2 * q + 1, r - numeral l)
    5.37 +    else (2 * q, r))" for l and qr :: "'a star \<times> 'a star"
    5.38 +    by (fact divmod_step_star_def)
    5.39 +qed (transfer,
    5.40 +  fact
    5.41 +  semiring_numeral_div_class.div_less
    5.42 +  semiring_numeral_div_class.mod_less
    5.43 +  semiring_numeral_div_class.div_positive
    5.44 +  semiring_numeral_div_class.mod_less_eq_dividend
    5.45 +  semiring_numeral_div_class.pos_mod_bound
    5.46 +  semiring_numeral_div_class.pos_mod_sign
    5.47 +  semiring_numeral_div_class.mod_mult2_eq
    5.48 +  semiring_numeral_div_class.div_mult2_eq
    5.49 +  semiring_numeral_div_class.discrete)+
    5.50 +
    5.51 +end
    5.52 +
    5.53 +declare divmod_algorithm_code [where ?'a = "'a::semiring_numeral_div star", code]
    5.54 +
    5.55  
    5.56  subsection {* Finite class *}
    5.57