|
1 (* Title: HOL/Library/Code_Binary_Nat.thy |
|
2 Author: Stefan Berghofer, Florian Haftmann, TU Muenchen |
|
3 *) |
|
4 |
|
5 header {* Implementation of natural numbers as binary numerals *} |
|
6 |
|
7 theory Code_Binary_Nat |
|
8 imports Main |
|
9 begin |
|
10 |
|
11 text {* |
|
12 When generating code for functions on natural numbers, the |
|
13 canonical representation using @{term "0::nat"} and |
|
14 @{term Suc} is unsuitable for computations involving large |
|
15 numbers. This theory refines the representation of |
|
16 natural numbers for code generation to use binary |
|
17 numerals, whose size grows logarithmically rather than linearly in the value. |
|
18 *} |
|
19 |
|
20 subsection {* Representation *} |
|
21 |
|
(* Code constructors for nat: in generated code a natural number is either
   0 or nat_of_num k for a binary numeral k :: num.  Suc is not among the
   declared constructors, so it can no longer be used as a pattern. *)
22 code_datatype "0::nat" nat_of_num |
|
23 |
|
(* Declared [code_abbrev]: relates the code constructor nat_of_num to
   numeral on nat.  The equation itself is nat_of_num_numeral. *)
24 lemma [code_abbrev]: |
|
25 "nat_of_num = numeral" |
|
26 by (fact nat_of_num_numeral) |
|
27 |
|
(* Code equations for num_of_nat on both nat constructors.  num has no
   zero, so num_of_nat 0 is mapped to Num.One.  On nat_of_num k it is the
   inverse (nat_of_num_inverse). *)
28 lemma [code]: |
|
29 "num_of_nat 0 = Num.One" |
|
30 "num_of_nat (nat_of_num k) = k" |
|
31 by (simp_all add: nat_of_num_inverse) |
|
32 |
|
(* Code equation for the constant 1 :: nat, stated in terms of the binary
   numeral Numeral1. *)
33 lemma [code]: |
|
34 "(1\<Colon>nat) = Numeral1" |
|
35 by simp |
|
36 |
|
(* The reverse equation, declared [code_abbrev]. *)
37 lemma [code_abbrev]: "Numeral1 = (1\<Colon>nat)" |
|
38 by simp |
|
39 |
|
(* Suc is not one of the code constructors declared by code_datatype, so
   it receives an ordinary code equation in terms of addition. *)
40 lemma [code]: |
|
41 "Suc n = n + 1" |
|
42 by simp |
|
43 |
|
44 |
|
45 subsection {* Basic arithmetic *} |
|
46 |
|
(* Trivial equation declared [code] and then [code del].  This is intended
   to clear the code equations previously registered for plus on nat, so
   that only plus_nat_code below is used. *)
47 lemma [code, code del]: |
|
48 "(plus :: nat \<Rightarrow> _) = plus" .. |
|
49 |
|
(* Addition by cases on the constructors 0 and nat_of_num.  Two binary
   numerals are added with addition on num.  The equations cover every
   constructor combination; 0 + 0 matches both of the last two, with the
   same result. *)
50 lemma plus_nat_code [code]: |
|
51 "nat_of_num k + nat_of_num l = nat_of_num (k + l)" |
|
52 "m + 0 = (m::nat)" |
|
53 "0 + n = (n::nat)" |
|
54 by (simp_all add: nat_of_num_numeral) |
|
55 |
|
56 text {* Bounded subtraction needs some auxiliary definitions *} |
|
57 |
|
(* dup n doubles n.  It is an auxiliary for the bitwise subtraction sub and
   is hidden at the end of the theory with hide_const (open). *)
58 definition dup :: "nat \<Rightarrow> nat" where |
|
59 "dup n = n + n" |
|
60 |
|
(* Doubling a binary numeral appends a zero bit (Num.Bit0); dup 0 = 0. *)
61 lemma dup_code [code]: |
|
62 "dup 0 = 0" |
|
63 "dup (nat_of_num k) = nat_of_num (Num.Bit0 k)" |
|
64 by (simp_all add: dup_def numeral_Bit0) |
|
65 |
|
(* sub k l is the difference of two binary numerals as a nat, or None when
   the difference would be negative (k < l).  It is used to implement
   truncated subtraction on nat. *)
66 definition sub :: "num \<Rightarrow> num \<Rightarrow> nat option" where |
|
67 "sub k l = (if k \<ge> l then Some (numeral k - numeral l) else None)" |
|
68 |
|
(* Bitwise recursion for sub.  Numeral values are read as
   Bit0 m = 2m, Bit1 m = 2m + 1 and BitM m = 2m - 1:
   - One - One = 0;
   - Bit0 m - One = 2m - 1 (BitM m), and Bit1 m - One = 2m (Bit0 m);
   - One minus Bit0 n or Bit1 n (both at least 2) is negative, hence None;
   - equal low bits: the result is 2(m - n), i.e. dup of the recursive result;
   - Bit1 m - Bit0 n = 2(m - n) + 1;
   - Bit0 m - Bit1 n = 2(m - n) - 1, which is a nat only when m - n > 0,
     hence the q = 0 check. *)
69 lemma sub_code [code]: |
|
70 "sub Num.One Num.One = Some 0" |
|
71 "sub (Num.Bit0 m) Num.One = Some (nat_of_num (Num.BitM m))" |
|
72 "sub (Num.Bit1 m) Num.One = Some (nat_of_num (Num.Bit0 m))" |
|
73 "sub Num.One (Num.Bit0 n) = None" |
|
74 "sub Num.One (Num.Bit1 n) = None" |
|
75 "sub (Num.Bit0 m) (Num.Bit0 n) = Option.map dup (sub m n)" |
|
76 "sub (Num.Bit1 m) (Num.Bit1 n) = Option.map dup (sub m n)" |
|
77 "sub (Num.Bit1 m) (Num.Bit0 n) = Option.map (\<lambda>q. dup q + 1) (sub m n)" |
|
78 "sub (Num.Bit0 m) (Num.Bit1 n) = (case sub m n of None \<Rightarrow> None |
|
79 | Some q \<Rightarrow> if q = 0 then None else Some (dup q - 1))" |
|
80 apply (auto simp add: nat_of_num_numeral |
|
81 Num.dbl_def Num.dbl_inc_def Num.dbl_dec_def |
|
82 Let_def le_imp_diff_is_add BitM_plus_one sub_def dup_def) |
|
83 apply (simp_all add: sub_non_positive) |
|
84 apply (simp_all add: sub_non_negative [symmetric, where ?'a = int]) |
|
85 done |
|
86 |
|
(* Trivial [code, code del] equation.  It is intended to clear the code
   equations previously registered for minus on nat. *)
87 lemma [code, code del]: |
|
88 "(minus :: nat \<Rightarrow> _) = minus" .. |
|
89 |
|
(* Truncated subtraction on nat.  Two binary numerals are subtracted with
   sub, and a negative difference (sub returns None) yields 0. *)
90 lemma minus_nat_code [code]: |
|
91 "nat_of_num k - nat_of_num l = (case sub k l of None \<Rightarrow> 0 | Some j \<Rightarrow> j)" |
|
92 "m - 0 = (m::nat)" |
|
93 "0 - n = (0::nat)" |
|
94 by (simp_all add: nat_of_num_numeral sub_non_positive sub_def) |
|
95 |
|
(* Trivial [code, code del] equation.  It is intended to clear the code
   equations previously registered for times on nat. *)
96 lemma [code, code del]: |
|
97 "(times :: nat \<Rightarrow> _) = times" .. |
|
98 |
|
(* Multiplication by constructor cases.  Binary numerals are multiplied
   with multiplication on num, and a zero factor gives 0. *)
99 lemma times_nat_code [code]: |
|
100 "nat_of_num k * nat_of_num l = nat_of_num (k * l)" |
|
101 "m * 0 = (0::nat)" |
|
102 "0 * n = (0::nat)" |
|
103 by (simp_all add: nat_of_num_numeral) |
|
104 |
|
(* Trivial [code, code del] equation.  It is intended to clear the code
   equations previously registered for HOL.equal on nat. *)
105 lemma [code, code del]: |
|
106 "(HOL.equal :: nat \<Rightarrow> _) = HOL.equal" .. |
|
107 |
|
(* Equality by constructor cases.  0 differs from every nat_of_num l, and
   two binary numerals are compared with equality on num. *)
108 lemma equal_nat_code [code]: |
|
109 "HOL.equal 0 (0::nat) \<longleftrightarrow> True" |
|
110 "HOL.equal 0 (nat_of_num l) \<longleftrightarrow> False" |
|
111 "HOL.equal (nat_of_num k) 0 \<longleftrightarrow> False" |
|
112 "HOL.equal (nat_of_num k) (nat_of_num l) \<longleftrightarrow> HOL.equal k l" |
|
113 by (simp_all add: nat_of_num_numeral equal) |
|
114 |
|
(* Reflexivity of HOL.equal on nat.  It is declared [code nbe], i.e. as a
   code equation for the normalisation-by-evaluation target only. *)
115 lemma equal_nat_refl [code nbe]: |
|
116 "HOL.equal (n::nat) n \<longleftrightarrow> True" |
|
117 by (rule equal_refl) |
|
118 |
|
(* Trivial [code, code del] equation.  It is intended to clear the code
   equations previously registered for less_eq on nat. *)
119 lemma [code, code del]: |
|
120 "(less_eq :: nat \<Rightarrow> _) = less_eq" .. |
|
121 |
|
(* Non-strict order by constructor cases.  0 is below everything, no
   nat_of_num k is below 0, and binary numerals are compared with the order
   on num. *)
122 lemma less_eq_nat_code [code]: |
|
123 "0 \<le> (n::nat) \<longleftrightarrow> True" |
|
124 "nat_of_num k \<le> 0 \<longleftrightarrow> False" |
|
125 "nat_of_num k \<le> nat_of_num l \<longleftrightarrow> k \<le> l" |
|
126 by (simp_all add: nat_of_num_numeral) |
|
127 |
|
(* Trivial [code, code del] equation.  It is intended to clear the code
   equations previously registered for less on nat. *)
128 lemma [code, code del]: |
|
129 "(less :: nat \<Rightarrow> _) = less" .. |
|
130 |
|
(* Strict order by constructor cases.  Nothing is below 0, 0 is below every
   nat_of_num l, and binary numerals are compared with the strict order on
   num. *)
131 lemma less_nat_code [code]: |
|
132 "(m::nat) < 0 \<longleftrightarrow> False" |
|
133 "0 < nat_of_num l \<longleftrightarrow> True" |
|
134 "nat_of_num k < nat_of_num l \<longleftrightarrow> k < l" |
|
135 by (simp_all add: nat_of_num_numeral) |
|
136 |
|
137 |
|
138 subsection {* Conversions *} |
|
139 |
|
(* Trivial [code, code del] equation.  It is intended to clear the code
   equations previously registered for of_nat. *)
140 lemma [code, code del]: |
|
141 "of_nat = of_nat" .. |
|
142 |
|
(* Conversion from nat by constructor cases.  A binary numeral k is mapped
   to numeral k in the target type, without going through Suc. *)
143 lemma of_nat_code [code]: |
|
144 "of_nat 0 = 0" |
|
145 "of_nat (nat_of_num k) = numeral k" |
|
146 by (simp_all add: nat_of_num_numeral) |
|
147 |
|
148 |
|
149 subsection {* Case analysis *} |
|
150 |
|
151 text {* |
|
152 Case analysis on natural numbers is rephrased using a conditional |
|
153 expression: |
|
154 *} |
|
155 |
|
(* Suc is not a code constructor, so case distinction on nat is expressed as
   a conditional on n = 0, using n - 1 as the predecessor.  The equation is
   declared both as a code equation and as a code_unfold rule. *)
156 lemma [code, code_unfold]: |
|
157 "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))" |
|
158 by (auto simp add: fun_eq_iff dest!: gr0_implies_Suc) |
|
159 |
|
160 |
|
161 subsection {* Preprocessors *} |
|
162 |
|
163 text {* |
|
164 The term @{term "Suc n"} is no longer a valid pattern. |
|
165 Therefore, all occurrences of this term in a position |
|
166 where a pattern is expected (i.e.~on the left-hand side of a recursion |
|
167 equation) must be eliminated. |
|
168 This can be accomplished by applying the following transformation rule: |
|
169 *} |
|
170 |
|
(* Transformation rule used by the eqn_Suc preprocessor.  Suppose a function
   f is given by a step equation f (Suc n) == h n and a base equation
   f 0 == g.  Then f n equals "if n = 0 then g else h (n - 1)", an equation
   whose left-hand side no longer contains Suc. *)
171 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow> |
|
172 f n \<equiv> if n = 0 then g else h (n - 1)" |
|
173 by (rule eq_reflection) (cases n, simp_all) |
|
174 |
|
175 text {* |
|
176 The rule above is built into a preprocessor that is plugged into |
|
177 the code generator. Since the preprocessor for introduction rules |
|
178 does not know anything about modes, some of the modes that worked |
|
179 for the canonical representation of natural numbers may no longer work. |
|
180 *} |
|
181 |
|
182 (*<*) |
|
183 setup {* |
|
184 let |
|
185 |
|
(* remove_suc performs one elimination step.  It looks for an equation in
   thms whose left-hand side contains a pattern Suc ?v, and for a base
   equation in thms that resolves with the remaining premise of the
   instantiated Suc_if_eq.  The equations involved are then replaced by
   if-then-else equations built from Suc_if_eq.  The result is SOME of the
   new equation list, or NONE if no step applies. *)
186 fun remove_suc thy thms = |
|
187 let |
|
(* Fresh schematic variable ?n :: nat; its name is a variant of "n" that
   avoids all variable names occurring in thms. *)
188 val vname = singleton (Name.variant_list (map fst |
|
189 (fold (Term.add_var_names o Thm.full_prop_of) thms []))) "n"; |
|
190 val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT)); |
|
(* Left- and right-hand side of a meta-equation lhs == rhs. *)
191 fun lhs_of th = snd (Thm.dest_comb |
|
192 (fst (Thm.dest_comb (cprop_of th)))); |
|
193 fun rhs_of th = snd (Thm.dest_comb (cprop_of th)); |
|
(* For each occurrence of Suc ?v in ct, returns a pair: ct with that
   occurrence replaced by the fresh variable cv, and the variable ?v.
   Application nodes are traversed on both sides and rebuilt around the
   replacement. *)
194 fun find_vars ct = (case term_of ct of |
|
195 (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))] |
|
196 | _ $ _ => |
|
197 let val (ct1, ct2) = Thm.dest_comb ct |
|
198 in |
|
199 map (apfst (fn ct => Thm.apply ct ct2)) (find_vars ct1) @ |
|
200 map (apfst (Thm.apply ct1)) (find_vars ct2) |
|
201 end |
|
202 | _ => []); |
|
(* All (equation, Suc occurrence) pairs over the left-hand sides of thms. *)
203 val eqs = maps |
|
204 (fn th => map (pair th) (find_vars (lhs_of th))) thms; |
|
(* Suc_if_eq is instantiated with f := (%cv. ct), h := (%cv'. rhs of th) and
   n := cv'; g is left open.  After beta conversion, its first premise is
   discharged with th generalised over cv'.  The remaining premise
   (f 0 == g) is then resolved (RS) against each equation of thms.  If none
   resolves, the result is NONE.  Otherwise th and the resolving equations
   are removed from thms, and the resolvents are appended. *)
205 fun mk_thms (th, (ct, cv')) = |
|
206 let |
|
207 val th' = |
|
208 Thm.implies_elim |
|
209 (Conv.fconv_rule (Thm.beta_conversion true) |
|
210 (Drule.instantiate' |
|
211 [SOME (ctyp_of_term ct)] [SOME (Thm.lambda cv ct), |
|
212 SOME (Thm.lambda cv' (rhs_of th)), NONE, SOME cv'] |
|
213 @{thm Suc_if_eq})) (Thm.forall_intr cv' th) |
|
214 in |
|
215 case map_filter (fn th'' => |
|
216 SOME (th'', singleton |
|
217 (Variable.trade (K (fn [th'''] => [th''' RS th'])) |
|
218 (Variable.global_thm_context th'')) th'') |
|
219 handle THM _ => NONE) thms of |
|
220 [] => NONE |
|
221 | thps => |
|
222 let val (ths1, ths2) = split_list thps |
|
223 in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end |
|
224 end |
|
225 in get_first mk_thms eqs end; |
|
226 |
|
(* Preprocessing of the equations thms.  It only acts when every theorem is
   a meta-equation and some left-hand side contains Suc.  In that case
   remove_suc is applied repeatedly via perhaps_loop, and the variable
   indices of the result are normalised.  Otherwise it returns NONE. *)
227 fun eqn_suc_base_preproc thy thms = |
|
228 let |
|
229 val dest = fst o Logic.dest_equals o prop_of; |
|
230 val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc}); |
|
231 in |
|
232 if forall (can dest) thms andalso exists (contains_suc o dest) thms |
|
233 then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes |
|
234 else NONE |
|
235 end; |
|
236 |
|
(* Wraps the preprocessing function as a code-generator function
   transformer. *)
237 val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc; |
|
238 |
|
239 in |
|
240 |
|
(* Registers the transformer under the name "eqn_Suc". *)
241 Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc) |
|
242 |
|
243 end; |
|
244 *} |
|
245 (*>*) |
|
246 |
|
(* Generated code for this theory goes into a module named Arith in each of
   the SML, OCaml and Haskell targets. *)
247 code_modulename SML |
|
248 Code_Binary_Nat Arith |
|
249 |
|
250 code_modulename OCaml |
|
251 Code_Binary_Nat Arith |
|
252 |
|
253 code_modulename Haskell |
|
254 Code_Binary_Nat Arith |
|
255 |
|
(* dup and sub are auxiliary constants.  hide_const (open) hides their short
   names, so they can only be referred to by qualified name. *)
256 hide_const (open) dup sub |
|
257 |
|
258 end |
|
259 |