src/HOL/Library/Code_Binary_Nat.thy
1 (*  Title:      HOL/Library/Code_Binary_Nat.thy
2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
3 *)
5 header {* Implementation of natural numbers as binary numerals *}
7 theory Code_Binary_Nat
8 imports Main
9 begin
11 text {*
12   When generating code for functions on natural numbers, the
13   canonical representation using @{term "0::nat"} and
14   @{term Suc} is unsuitable for computations involving large
15   numbers.  This theory refines the representation of
16   natural numbers for code generation to use binary
17   numerals, whose size grows logarithmically rather than linearly.
18 *}
20 subsection {* Representation *}
(* For code generation, nat is represented by the two code constructors
   0 and nat_of_num, i.e. zero or a binary numeral of type num. *)
22 code_datatype "0::nat" nat_of_num
(* nat_of_num coincides with numeral on nat; declared as code_abbrev,
   presumably so that numerals are expressed via the constructor
   nat_of_num in generated code -- confirm against the code generator
   manual. *)
24 lemma [code_abbrev]:
25   "nat_of_num = numeral"
26   by (fact nat_of_num_numeral)
(* Code equations for num_of_nat on both constructors; note that 0 is
   mapped to Num.One.
   NOTE(review): no proof follows in this listing (source lines 31-32 are
   absent); confirm against the original theory file. *)
28 lemma [code]:
29   "num_of_nat 0 = Num.One"
30   "num_of_nat (nat_of_num k) = k"
(* The constant 1 on nat is implemented as the numeral Numeral1. *)
33 lemma [code]:
34   "(1\<Colon>nat) = Numeral1"
35   by simp
(* Conversely, Numeral1 is abbreviated back to 1 on nat. *)
37 lemma [code_abbrev]: "Numeral1 = (1\<Colon>nat)"
38   by simp
(* Suc is no longer a code constructor; it is computed as n + 1. *)
40 lemma [code]:
41   "Suc n = n + 1"
42   by simp
45 subsection {* Basic arithmetic *}
(* Declaring the trivial equation as [code] and then [code del]
   presumably clears the previously registered code equations for
   addition on nat, so only the binary equations below are used. *)
47 lemma [code, code del]:
48   "(plus :: nat \<Rightarrow> _) = plus" ..
(* Addition on the binary representation: add the underlying nums;
   0 is the identity on either side.
   NOTE(review): the proof (presumably source line 54) is absent from this
   listing; confirm against the original theory file. *)
50 lemma plus_nat_code [code]:
51   "nat_of_num k + nat_of_num l = nat_of_num (k + l)"
52   "m + 0 = (m::nat)"
53   "0 + n = (n::nat)"
56 text {* Bounded subtraction needs some auxiliary functions. *}
(* dup n doubles n; used by sub_code to rebuild a difference bit by bit. *)
58 definition dup :: "nat \<Rightarrow> nat" where
59   "dup n = n + n"
(* On the binary representation, doubling a positive numeral appends a
   0 bit (Num.Bit0); doubling 0 gives 0. *)
61 lemma dup_code [code]:
62   "dup 0 = 0"
63   "dup (nat_of_num k) = nat_of_num (Num.Bit0 k)"
64   by (simp_all add: dup_def numeral_Bit0)
(* sub k l is the difference of the numerals k and l as a nat, or None
   when l exceeds k (i.e. the difference would be negative). *)
66 definition sub :: "num \<Rightarrow> num \<Rightarrow> nat option" where
67   "sub k l = (if k \<ge> l then Some (numeral k - numeral l) else None)"
(* Code equations for sub by recursion on the bits of both arguments:
   equal low bits give the doubled recursive difference, Bit1 - Bit0 adds
   1 after doubling, and Bit0 - Bit1 borrows (doubled difference minus 1),
   which fails with None when the recursive difference is 0.
   NOTE(review): source line 83 is absent from this listing; confirm the
   proof script is complete against the original theory file. *)
69 lemma sub_code [code]:
70   "sub Num.One Num.One = Some 0"
71   "sub (Num.Bit0 m) Num.One = Some (nat_of_num (Num.BitM m))"
72   "sub (Num.Bit1 m) Num.One = Some (nat_of_num (Num.Bit0 m))"
73   "sub Num.One (Num.Bit0 n) = None"
74   "sub Num.One (Num.Bit1 n) = None"
75   "sub (Num.Bit0 m) (Num.Bit0 n) = Option.map dup (sub m n)"
76   "sub (Num.Bit1 m) (Num.Bit1 n) = Option.map dup (sub m n)"
77   "sub (Num.Bit1 m) (Num.Bit0 n) = Option.map (\<lambda>q. dup q + 1) (sub m n)"
78   "sub (Num.Bit0 m) (Num.Bit1 n) = (case sub m n of None \<Rightarrow> None
79      | Some q \<Rightarrow> if q = 0 then None else Some (dup q - 1))"
80   apply (auto simp add: nat_of_num_numeral
81     Num.dbl_def Num.dbl_inc_def Num.dbl_dec_def
82     Let_def le_imp_diff_is_add BitM_plus_one sub_def dup_def)
84   apply (simp_all add: sub_non_negative [symmetric, where ?'a = int])
85   done
(* Presumably clears the previously registered code equations for
   subtraction on nat (declare as [code], then [code del]). *)
87 lemma [code, code del]:
88   "(minus :: nat \<Rightarrow> _) = minus" ..
(* Truncated subtraction on the binary representation: sub returns None
   when the result would be negative, which is mapped to 0. *)
90 lemma minus_nat_code [code]:
91   "nat_of_num k - nat_of_num l = (case sub k l of None \<Rightarrow> 0 | Some j \<Rightarrow> j)"
92   "m - 0 = (m::nat)"
93   "0 - n = (0::nat)"
94   by (simp_all add: nat_of_num_numeral sub_non_positive sub_def)
(* Presumably clears the previously registered code equations for
   multiplication on nat (declare as [code], then [code del]). *)
96 lemma [code, code del]:
97   "(times :: nat \<Rightarrow> _) = times" ..
(* Multiplication on the binary representation: multiply the underlying
   nums; 0 is absorbing on either side.
   NOTE(review): the proof (presumably source line 103) is absent from this
   listing; confirm against the original theory file. *)
99 lemma times_nat_code [code]:
100   "nat_of_num k * nat_of_num l = nat_of_num (k * l)"
101   "m * 0 = (0::nat)"
102   "0 * n = (0::nat)"
(* Presumably clears the previously registered code equations for
   equality on nat (declare as [code], then [code del]). *)
105 lemma [code, code del]:
106   "(HOL.equal :: nat \<Rightarrow> _) = HOL.equal" ..
(* Equality by cases on the constructors: 0 differs from every
   nat_of_num, and two nat_of_num values are equal iff their nums are. *)
108 lemma equal_nat_code [code]:
109   "HOL.equal 0 (0::nat) \<longleftrightarrow> True"
110   "HOL.equal 0 (nat_of_num l) \<longleftrightarrow> False"
111   "HOL.equal (nat_of_num k) 0 \<longleftrightarrow> False"
112   "HOL.equal (nat_of_num k) (nat_of_num l) \<longleftrightarrow> HOL.equal k l"
113   by (simp_all add: nat_of_num_numeral equal)
(* Reflexivity shortcut, registered via [code nbe] for the normalization-
   by-evaluation code target only. *)
115 lemma equal_nat_refl [code nbe]:
116   "HOL.equal (n::nat) n \<longleftrightarrow> True"
117   by (rule equal_refl)
(* Presumably clears the previously registered code equations for <= on
   nat (declare as [code], then [code del]). *)
119 lemma [code, code del]:
120   "(less_eq :: nat \<Rightarrow> _) = less_eq" ..
(* <= by cases on the constructors, reducing to <= on num.
   NOTE(review): the proof (presumably source line 126) is absent from this
   listing; confirm against the original theory file. *)
122 lemma less_eq_nat_code [code]:
123   "0 \<le> (n::nat) \<longleftrightarrow> True"
124   "nat_of_num k \<le> 0 \<longleftrightarrow> False"
125   "nat_of_num k \<le> nat_of_num l \<longleftrightarrow> k \<le> l"
(* Presumably clears the previously registered code equations for < on
   nat (declare as [code], then [code del]). *)
128 lemma [code, code del]:
129   "(less :: nat \<Rightarrow> _) = less" ..
(* < by cases on the constructors, reducing to < on num.
   NOTE(review): the proof (presumably source line 135) is absent from this
   listing; confirm against the original theory file. *)
131 lemma less_nat_code [code]:
132   "(m::nat) < 0 \<longleftrightarrow> False"
133   "0 < nat_of_num l \<longleftrightarrow> True"
134   "nat_of_num k < nat_of_num l \<longleftrightarrow> k < l"
138 subsection {* Conversions *}
(* Presumably clears the previously registered code equations for of_nat
   (declare as [code], then [code del]). *)
140 lemma [code, code del]:
141   "of_nat = of_nat" ..
(* of_nat maps 0 to 0 and nat_of_num k to the numeral k of the target
   type, avoiding unary iteration.
   NOTE(review): the proof (presumably source line 146) is absent from this
   listing; confirm against the original theory file. *)
143 lemma of_nat_code [code]:
144   "of_nat 0 = 0"
145   "of_nat (nat_of_num k) = numeral k"
149 subsection {* Case analysis *}
151 text {*
152   Case analysis on natural numbers is rephrased using a conditional
153   expression:
154 *}
(* nat_case as a conditional on n = 0 with predecessor n - 1, since Suc
   can no longer be matched; declared both as code equation and as
   code_unfold rule. *)
156 lemma [code, code_unfold]:
157   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
158   by (auto simp add: fun_eq_iff dest!: gr0_implies_Suc)
161 subsection {* Preprocessors *}
163 text {*
164   The term @{term "Suc n"} is no longer a valid pattern.
165   Therefore, all occurrences of this term in a position
166   where a pattern is expected (i.e.~on the left-hand side of a recursion
167   equation) must be eliminated.
168   This can be accomplished by applying the following transformation rule:
169 *}
(* Combines a Suc equation for f (first premise) and a 0 equation for f
   (second premise) into one equation for f n using an if-then-else on
   n = 0.  Instantiated by the ML function remove_suc in the setup block
   of this theory. *)
171 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
172   f n \<equiv> if n = 0 then g else h (n - 1)"
173   by (rule eq_reflection) (cases n, simp_all)
175 text {*
176   The rule above is built into a preprocessor that is plugged into
177   the code generator. Since the preprocessor for introduction rules
178   does not know anything about modes, some of the modes that worked
179   for the canonical representation of natural numbers may no longer work.
180 *}
182 (*<*)
183 setup {*
184 let
(* One step of Suc elimination on the code equations thms.
   Searches for the first equation whose left-hand side contains Suc
   applied to a schematic variable, instantiates Suc_if_eq with it, and
   resolves every equation in thms against the remaining zero-case premise
   of the instantiated rule.  On success returns SOME of thms with the
   consumed equations replaced by the combined ones; returns NONE if no
   candidate yields any combination. *)
186 fun remove_suc thy thms =
187   let
    (* Fresh nat-typed schematic variable; its name is a variant of n that
       does not occur in any equation of thms. *)
188     val vname = singleton (Name.variant_list (map fst
189       (fold (Term.add_var_names o Thm.full_prop_of) thms []))) "n";
190     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    (* For an equation th : lhs == rhs, the certified lhs resp. rhs. *)
191     fun lhs_of th = snd (Thm.dest_comb
192       (fst (Thm.dest_comb (cprop_of th))));
193     fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
    (* For every subterm Suc ?x of ct with ?x schematic, returns the pair
       (context, ?x), where context is ct with that occurrence of Suc ?x
       replaced by cv.  Applications are searched on both sides. *)
194     fun find_vars ct = (case term_of ct of
195         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
196       | _ $ _ =>
197         let val (ct1, ct2) = Thm.dest_comb ct
198         in
199           map (apfst (fn ct => Thm.apply ct ct2)) (find_vars ct1) @
200           map (apfst (Thm.apply ct1)) (find_vars ct2)
201         end
202       | _ => []);
    (* All rewrite candidates as (equation, (context, ?x)). *)
203     val eqs = maps
204       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    (* th' is Suc_if_eq instantiated with f := (lambda cv. context),
       h := (lambda ?x. rhs of th) and n := ?x, after beta reduction, with
       its Suc premise discharged by th generalised over ?x; what remains
       is the zero-case premise f 0 == ?g. *)
205     fun mk_thms (th, (ct, cv')) =
206       let
207         val th' =
208           Thm.implies_elim
209            (Conv.fconv_rule (Thm.beta_conversion true)
210              (Drule.instantiate'
211                [SOME (ctyp_of_term ct)] [SOME (Thm.lambda cv ct),
212                  SOME (Thm.lambda cv' (rhs_of th)), NONE, SOME cv']
213                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
214       in
        (* Resolve each equation against the zero-case premise; equations
           that do not fit raise THM and are skipped.  If any fit, drop th
           and the consumed equations and append the combined results. *)
215         case map_filter (fn th'' =>
216             SOME (th'', singleton
217               (Variable.trade (K (fn [th'''] => [th''' RS th']))
218                 (Variable.global_thm_context th'')) th'')
219           handle THM _ => NONE) thms of
220             [] => NONE
221           | thps =>
222               let val (ths1, ths2) = split_list thps
223               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
224       end
225   in get_first mk_thms eqs end;
(* Suc-elimination preprocessor for the code equations of one function.
   It applies only when every theorem in thms is a meta-equation and at
   least one left-hand side mentions the constant Suc; it then iterates
   remove_suc via perhaps_loop and normalises the variable indices of the
   resulting equations.  Otherwise it returns NONE. *)
227 fun eqn_suc_base_preproc thy thms =
228   let
229     val lhs_of_eqn = fst o Logic.dest_equals o prop_of;
230     fun mentions_suc t = exists_Const (fn (name, _) => name = @{const_name Suc}) t;
231     val applicable = forall (can lhs_of_eqn) thms andalso exists (mentions_suc o lhs_of_eqn) thms;
232   in
233     if not applicable then NONE
234     else Option.map (map Drule.zero_var_indexes) (perhaps_loop (remove_suc thy) thms)
235   end;
(* Lift the per-function transformation to a code preprocessor. *)
237 val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;
239 in