essential instance about bit structure
authorhaftmann
Thu, 18 Jun 2020 09:07:30 +0000
changeset 71951 ac6f9738c200
parent 71950 c9251bc7da4e
child 71952 2efc5b8c7456
essential instance about bit structure
src/HOL/Word/Word.thy
--- a/src/HOL/Word/Word.thy	Thu Jun 18 09:07:29 2020 +0000
+++ b/src/HOL/Word/Word.thy	Thu Jun 18 09:07:30 2020 +0000
@@ -286,6 +286,7 @@
 end
 
 
+
 text \<open>Legacy theorems:\<close>
 
 lemma word_arith_wis [code]:
@@ -406,6 +407,44 @@
 
 end
 
+instance word :: (len) semiring_modulo
+proof
+  show "a div b * b + a mod b = a" for a b :: "'a word"
+  proof transfer
+    fix k l :: int
+    define r :: int where "r = 2 ^ LENGTH('a)"
+    then have r: "take_bit LENGTH('a) k = k mod r" for k
+      by (simp add: take_bit_eq_mod)
+    have "k mod r = ((k mod r) div (l mod r) * (l mod r)
+      + (k mod r) mod (l mod r)) mod r"
+      by (simp add: div_mult_mod_eq)
+    also have "... = (((k mod r) div (l mod r) * (l mod r)) mod r
+      + (k mod r) mod (l mod r)) mod r"
+      by (simp add: mod_add_left_eq)
+    also have "... = (((k mod r) div (l mod r) * l) mod r
+      + (k mod r) mod (l mod r)) mod r"
+      by (simp add: mod_mult_right_eq)
+    finally have "k mod r = ((k mod r) div (l mod r) * l
+      + (k mod r) mod (l mod r)) mod r"
+      by (simp add: mod_simps)
+    with r show "take_bit LENGTH('a) (take_bit LENGTH('a) k div take_bit LENGTH('a) l * l
+      + take_bit LENGTH('a) k mod take_bit LENGTH('a) l) = take_bit LENGTH('a) k"
+      by simp
+  qed
+qed
+
+instance word :: (len) semiring_parity
+proof
+  show "\<not> 2 dvd (1::'a word)"
+    by transfer simp
+  show even_iff_mod_2_eq_0: "2 dvd a \<longleftrightarrow> a mod 2 = 0"
+    for a :: "'a word"
+    by transfer (simp_all add: mod_2_eq_odd take_bit_Suc)
+  show "\<not> 2 dvd a \<longleftrightarrow> a mod 2 = 1"
+    for a :: "'a word"
+    by transfer (simp_all add: mod_2_eq_odd take_bit_Suc)
+qed
+
 
 subsection \<open>Ordering\<close>
 
@@ -433,6 +472,42 @@
   "a < b \<longleftrightarrow> uint a < uint b"
   by transfer rule
 
+lemma word_greater_zero_iff:
+  \<open>a > 0 \<longleftrightarrow> a \<noteq> 0\<close> for a :: \<open>'a::len0 word\<close>
+  by transfer (simp add: less_le)
+
+lemma of_nat_word_eq_iff:
+  \<open>of_nat m = (of_nat n :: 'a::len word) \<longleftrightarrow> take_bit LENGTH('a) m = take_bit LENGTH('a) n\<close>
+  by transfer (simp add: take_bit_of_nat)
+
+lemma of_nat_word_less_eq_iff:
+  \<open>of_nat m \<le> (of_nat n :: 'a::len word) \<longleftrightarrow> take_bit LENGTH('a) m \<le> take_bit LENGTH('a) n\<close>
+  by transfer (simp add: take_bit_of_nat)
+
+lemma of_nat_word_less_iff:
+  \<open>of_nat m < (of_nat n :: 'a::len word) \<longleftrightarrow> take_bit LENGTH('a) m < take_bit LENGTH('a) n\<close>
+  by transfer (simp add: take_bit_of_nat)
+
+lemma of_nat_word_eq_0_iff:
+  \<open>of_nat n = (0 :: 'a::len word) \<longleftrightarrow> 2 ^ LENGTH('a) dvd n\<close>
+  using of_nat_word_eq_iff [where ?'a = 'a, of n 0] by (simp add: take_bit_eq_0_iff)
+
+lemma of_int_word_eq_iff:
+  \<open>of_int k = (of_int l :: 'a::len word) \<longleftrightarrow> take_bit LENGTH('a) k = take_bit LENGTH('a) l\<close>
+  by transfer rule
+
+lemma of_int_word_less_eq_iff:
+  \<open>of_int k \<le> (of_int l :: 'a::len word) \<longleftrightarrow> take_bit LENGTH('a) k \<le> take_bit LENGTH('a) l\<close>
+  by transfer rule
+
+lemma of_int_word_less_iff:
+  \<open>of_int k < (of_int l :: 'a::len word) \<longleftrightarrow> take_bit LENGTH('a) k < take_bit LENGTH('a) l\<close>
+  by transfer rule
+
+lemma of_int_word_eq_0_iff:
+  \<open>of_int k = (0 :: 'a::len word) \<longleftrightarrow> 2 ^ LENGTH('a) dvd k\<close>
+  using of_int_word_eq_iff [where ?'a = 'a, of k 0] by (simp add: take_bit_eq_0_iff)
+
 definition word_sle :: "'a::len word \<Rightarrow> 'a word \<Rightarrow> bool"  ("(_/ <=s _)" [50, 51] 50)
   where "a <=s b \<longleftrightarrow> sint a \<le> sint b"
 
@@ -442,6 +517,224 @@
 
 subsection \<open>Bit-wise operations\<close>
 
+lemma word_bit_induct [case_names zero even odd]:
+  \<open>P a\<close> if word_zero: \<open>P 0\<close>
+    and word_even: \<open>\<And>a. P a \<Longrightarrow> 0 < a \<Longrightarrow> a < 2 ^ (LENGTH('a) - 1) \<Longrightarrow> P (2 * a)\<close>
+    and word_odd: \<open>\<And>a. P a \<Longrightarrow> a < 2 ^ (LENGTH('a) - 1) \<Longrightarrow> P (1 + 2 * a)\<close>
+  for P and a :: \<open>'a::len word\<close>
+proof -
+  define m :: nat where \<open>m = LENGTH('a) - 1\<close>
+  then have l: \<open>LENGTH('a) = Suc m\<close>
+    by simp
+  define n :: nat where \<open>n = unat a\<close>
+  then have \<open>n < 2 ^ LENGTH('a)\<close>
+    by (unfold unat_def) (transfer, simp add: take_bit_eq_mod)
+  then have \<open>n < 2 * 2 ^ m\<close>
+    by (simp add: l)
+  then have \<open>P (of_nat n)\<close>
+  proof (induction n rule: nat_bit_induct)
+    case zero
+    show ?case
+      by simp (rule word_zero)
+  next
+    case (even n)
+    then have \<open>n < 2 ^ m\<close>
+      by simp
+    with even.IH have \<open>P (of_nat n)\<close>
+      by simp
+    moreover from \<open>n < 2 ^ m\<close> even.hyps have \<open>0 < (of_nat n :: 'a word)\<close>
+      by (auto simp add: word_greater_zero_iff of_nat_word_eq_0_iff l)
+    moreover from \<open>n < 2 ^ m\<close> have \<open>(of_nat n :: 'a word) < 2 ^ (LENGTH('a) - 1)\<close>
+      using of_nat_word_less_iff [where ?'a = 'a, of n \<open>2 ^ m\<close>]
+      by (cases \<open>m = 0\<close>) (simp_all add: not_less take_bit_eq_self ac_simps l)
+    ultimately have \<open>P (2 * of_nat n)\<close>
+      by (rule word_even)
+    then show ?case
+      by simp
+  next
+    case (odd n)
+    then have \<open>Suc n \<le> 2 ^ m\<close>
+      by simp
+    with odd.IH have \<open>P (of_nat n)\<close>
+      by simp
+    moreover from \<open>Suc n \<le> 2 ^ m\<close> have \<open>(of_nat n :: 'a word) < 2 ^ (LENGTH('a) - 1)\<close>
+      using of_nat_word_less_iff [where ?'a = 'a, of n \<open>2 ^ m\<close>]
+      by (cases \<open>m = 0\<close>) (simp_all add: not_less take_bit_eq_self ac_simps l)
+    ultimately have \<open>P (1 + 2 * of_nat n)\<close>
+      by (rule word_odd)
+    then show ?case
+      by simp
+  qed
+  moreover have \<open>of_nat (nat (uint a)) = a\<close>
+    by transfer simp
+  ultimately show ?thesis
+    by (simp add: n_def unat_def)
+qed
+
+lemma bit_word_half_eq:
+  \<open>(of_bool b + a * 2) div 2 = a\<close>
+    if \<open>a < 2 ^ (LENGTH('a) - Suc 0)\<close>
+    for a :: \<open>'a::len word\<close>
+proof (cases \<open>2 \<le> LENGTH('a::len)\<close>)
+  case False
+  have \<open>of_bool (odd k) < (1 :: int) \<longleftrightarrow> even k\<close> for k :: int
+    by auto
+  with False that show ?thesis
+    by transfer (simp add: eq_iff)
+next
+  case True
+  obtain n where length: \<open>LENGTH('a) = Suc n\<close>
+    by (cases \<open>LENGTH('a)\<close>) simp_all
+  show ?thesis proof (cases b)
+    case False
+    moreover have \<open>a * 2 div 2 = a\<close>
+    using that proof transfer
+      fix k :: int
+      from length have \<open>k * 2 mod 2 ^ LENGTH('a) = (k mod 2 ^ n) * 2\<close>
+        by simp
+      moreover assume \<open>take_bit LENGTH('a) k < take_bit LENGTH('a) (2 ^ (LENGTH('a) - Suc 0))\<close>
+      with \<open>LENGTH('a) = Suc n\<close>
+      have \<open>k mod 2 ^ LENGTH('a) = k mod 2 ^ n\<close>
+        by (simp add: take_bit_eq_mod divmod_digit_0)
+      ultimately have \<open>take_bit LENGTH('a) (k * 2) = take_bit LENGTH('a) k * 2\<close>
+        by (simp add: take_bit_eq_mod)
+      with True show \<open>take_bit LENGTH('a) (take_bit LENGTH('a) (k * 2) div take_bit LENGTH('a) 2)
+        = take_bit LENGTH('a) k\<close>
+        by simp
+    qed
+    ultimately show ?thesis
+      by simp
+  next
+    case True
+    moreover have \<open>(1 + a * 2) div 2 = a\<close>
+    using that proof transfer
+      fix k :: int
+      from length have \<open>(1 + k * 2) mod 2 ^ LENGTH('a) = 1 + (k mod 2 ^ n) * 2\<close>
+        using pos_zmod_mult_2 [of \<open>2 ^ n\<close> k] by (simp add: ac_simps)
+      moreover assume \<open>take_bit LENGTH('a) k < take_bit LENGTH('a) (2 ^ (LENGTH('a) - Suc 0))\<close>
+      with \<open>LENGTH('a) = Suc n\<close>
+      have \<open>k mod 2 ^ LENGTH('a) = k mod 2 ^ n\<close>
+        by (simp add: take_bit_eq_mod divmod_digit_0)
+      ultimately have \<open>take_bit LENGTH('a) (1 + k * 2) = 1 + take_bit LENGTH('a) k * 2\<close>
+        by (simp add: take_bit_eq_mod)
+      with True show \<open>take_bit LENGTH('a) (take_bit LENGTH('a) (1 + k * 2) div take_bit LENGTH('a) 2)
+        = take_bit LENGTH('a) k\<close>
+        by (auto simp add: take_bit_Suc)
+    qed
+    ultimately show ?thesis
+      by simp
+  qed
+qed
+
+lemma even_mult_exp_div_word_iff:
+  \<open>even (a * 2 ^ m div 2 ^ n) \<longleftrightarrow> \<not> (
+    m \<le> n \<and>
+    n < LENGTH('a) \<and> odd (a div 2 ^ (n - m)))\<close> for a :: \<open>'a::len word\<close>
+  by transfer
+    (auto simp flip: drop_bit_eq_div simp add: even_drop_bit_iff_not_bit bit_take_bit_iff,
+      simp_all flip: push_bit_eq_mult add: bit_push_bit_iff_int)
+
+instance word :: (len) semiring_bits
+proof
+  show \<open>P a\<close> if stable: \<open>\<And>a. a div 2 = a \<Longrightarrow> P a\<close>
+    and rec: \<open>\<And>a b. P a \<Longrightarrow> (of_bool b + 2 * a) div 2 = a \<Longrightarrow> P (of_bool b + 2 * a)\<close>
+  for P and a :: \<open>'a word\<close>
+  proof (induction a rule: word_bit_induct)
+    case zero
+    have \<open>0 div 2 = (0::'a word)\<close>
+      by transfer simp
+    with stable [of 0] show ?case
+      by simp
+  next
+    case (even a)
+    with rec [of a False] show ?case
+      using bit_word_half_eq [of a False] by (simp add: ac_simps)
+  next
+    case (odd a)
+    with rec [of a True] show ?case
+      using bit_word_half_eq [of a True] by (simp add: ac_simps)
+  qed
+  show \<open>0 div a = 0\<close>
+    for a :: \<open>'a word\<close>
+    by transfer simp
+  show \<open>a div 1 = a\<close>
+    for a :: \<open>'a word\<close>
+    by transfer simp
+  show \<open>a mod b div b = 0\<close>
+    for a b :: \<open>'a word\<close>
+    apply transfer
+    apply (simp add: take_bit_eq_mod)
+    apply (subst (3) mod_pos_pos_trivial [of _ \<open>2 ^ LENGTH('a)\<close>])
+      apply simp_all
+     apply (metis le_less mod_by_0 pos_mod_conj zero_less_numeral zero_less_power)
+    using pos_mod_bound [of \<open>2 ^ LENGTH('a)\<close>] apply simp
+  proof -
+    fix aa :: int and ba :: int
+    have f1: "\<And>i n. (i::int) mod 2 ^ n = 0 \<or> 0 < i mod 2 ^ n"
+      by (metis le_less take_bit_eq_mod take_bit_nonnegative)
+    have "(0::int) < 2 ^ len_of (TYPE('a)::'a itself) \<and> ba mod 2 ^ len_of (TYPE('a)::'a itself) \<noteq> 0 \<or> aa mod 2 ^ len_of (TYPE('a)::'a itself) mod (ba mod 2 ^ len_of (TYPE('a)::'a itself)) < 2 ^ len_of (TYPE('a)::'a itself)"
+      by (metis (no_types) mod_by_0 unique_euclidean_semiring_numeral_class.pos_mod_bound zero_less_numeral zero_less_power)
+    then show "aa mod 2 ^ len_of (TYPE('a)::'a itself) mod (ba mod 2 ^ len_of (TYPE('a)::'a itself)) < 2 ^ len_of (TYPE('a)::'a itself)"
+      using f1 by (meson le_less less_le_trans unique_euclidean_semiring_numeral_class.pos_mod_bound)
+  qed
+  show \<open>(1 + a) div 2 = a div 2\<close>
+    if \<open>even a\<close>
+    for a :: \<open>'a word\<close>
+    using that by transfer (auto dest: le_Suc_ex simp add: take_bit_Suc)
+  show \<open>(2 :: 'a word) ^ m div 2 ^ n = of_bool ((2 :: 'a word) ^ m \<noteq> 0 \<and> n \<le> m) * 2 ^ (m - n)\<close>
+    for m n :: nat
+    by transfer (simp, simp add: exp_div_exp_eq)
+  show "a div 2 ^ m div 2 ^ n = a div 2 ^ (m + n)"
+    for a :: "'a word" and m n :: nat
+    apply transfer
+    apply (auto simp add: not_less take_bit_drop_bit ac_simps simp flip: drop_bit_eq_div)
+    apply (simp add: drop_bit_take_bit)
+    done
+  show "a mod 2 ^ m mod 2 ^ n = a mod 2 ^ min m n"
+    for a :: "'a word" and m n :: nat
+    by transfer (auto simp flip: take_bit_eq_mod simp add: ac_simps)
+  show \<open>a * 2 ^ m mod 2 ^ n = a mod 2 ^ (n - m) * 2 ^ m\<close>
+    if \<open>m \<le> n\<close> for a :: "'a word" and m n :: nat
+    using that apply transfer
+    apply (auto simp flip: take_bit_eq_mod)
+           apply (auto simp flip: push_bit_eq_mult simp add: push_bit_take_bit split: split_min_lin)
+    done
+  show \<open>a div 2 ^ n mod 2 ^ m = a mod (2 ^ (n + m)) div 2 ^ n\<close>
+    for a :: "'a word" and m n :: nat
+    by transfer (auto simp add: not_less take_bit_drop_bit ac_simps simp flip: take_bit_eq_mod drop_bit_eq_div split: split_min_lin)
+  show \<open>even ((2 ^ m - 1) div (2::'a word) ^ n) \<longleftrightarrow> 2 ^ n = (0::'a word) \<or> m \<le> n\<close>
+    for m n :: nat
+    by transfer (auto simp add: take_bit_of_mask even_mask_div_iff)
+  show \<open>even (a * 2 ^ m div 2 ^ n) \<longleftrightarrow> n < m \<or> (2::'a word) ^ n = 0 \<or> m \<le> n \<and> even (a div 2 ^ (n - m))\<close>
+    for a :: \<open>'a word\<close> and m n :: nat
+  proof transfer
+    show \<open>even (take_bit LENGTH('a) (k * 2 ^ m) div take_bit LENGTH('a) (2 ^ n)) \<longleftrightarrow>
+      n < m
+      \<or> take_bit LENGTH('a) ((2::int) ^ n) = take_bit LENGTH('a) 0
+      \<or> (m \<le> n \<and> even (take_bit LENGTH('a) k div take_bit LENGTH('a) (2 ^ (n - m))))\<close>
+    for m n :: nat and k l :: int
+      by (auto simp flip: take_bit_eq_mod drop_bit_eq_div push_bit_eq_mult
+        simp add: div_push_bit_of_1_eq_drop_bit drop_bit_take_bit drop_bit_push_bit_int [of n m])
+  qed
+qed
+
+context
+  includes lifting_syntax
+begin
+
+lemma transfer_rule_bit_word [transfer_rule]:
+  \<open>((pcr_word :: int \<Rightarrow> 'a::len word \<Rightarrow> bool) ===> (=)) (\<lambda>k n. n < LENGTH('a) \<and> bit k n) bit\<close>
+proof -
+  let ?t = \<open>\<lambda>a n. odd (take_bit LENGTH('a) a div take_bit LENGTH('a) ((2::int) ^ n))\<close>
+  have \<open>((pcr_word :: int \<Rightarrow> 'a word \<Rightarrow> bool) ===> (=)) ?t bit\<close>
+    by (unfold bit_def) transfer_prover
+  also have \<open>?t = (\<lambda>k n. n < LENGTH('a) \<and> bit k n)\<close>
+    by (simp add: fun_eq_iff bit_take_bit_iff flip: bit_def)
+  finally show ?thesis .
+qed
+
+end
+
 definition shiftl1 :: "'a::len0 word \<Rightarrow> 'a word"
   where "shiftl1 w = word_of_int (uint w BIT False)"