|
1 (* Title: HOL/Library/Efficient_Nat.thy |
|
2 ID: $Id$ |
|
3 Author: Stefan Berghofer, TU Muenchen |
|
4 *) |
|
5 |
|
6 header {* Implementation of natural numbers by integers *} |
|
7 |
|
8 theory Efficient_Nat |
|
9 imports Main Pretty_Int |
|
10 begin |
|
11 |
|
12 text {* |
|
13 When generating code for functions on natural numbers, the canonical |
|
14 representation using @{term "0::nat"} and @{term "Suc"} is unsuitable for |
|
15 computations involving large numbers. The efficiency of the generated |
|
16 code can be improved drastically by implementing natural numbers by |
|
17 integers. To do this, just include this theory. |
|
18 *} |
|
19 |
|
20 subsection {* Logical rewrites *} |
|
21 |
|
text {*
  A conversion from integers to natural numbers that is specified
  only for non-negative integers (in contrast to @{const nat}).
  Note that this restriction has no logical relevance and
  is just a kind of proof hint -- nothing prevents you from
  writing nonsense like @{term "nat_of_int (-4)"}.
*}
|
29 |
|
(* nat_of_int: conversion from int to nat.  The defining equation is
   conditional: it only states nat_of_int k = nat k under the premise
   k >= 0 and says nothing about negative arguments. *)
definition
  nat_of_int :: "int \<Rightarrow> nat" where
  "k \<ge> 0 \<Longrightarrow> nat_of_int k = nat k"
|
33 |
|
(* int': embedding of nat into int, defined as of_nat at result type int. *)
definition
  int' :: "nat \<Rightarrow> int" where
  "int' n = of_nat n"
|
37 |
|
(* The successor on nat becomes "1 +" on the int side.
   Declared as a simp rule. *)
lemma int'_Suc [simp]: "int' (Suc n) = 1 + int' n"
  unfolding int'_def by simp
|
40 |
|
(* int' distributes over addition; follows from of_nat_add after
   unfolding int'_def. *)
lemma int'_add: "int' (m + n) = int' m + int' n"
  unfolding int'_def by (rule of_nat_add)
|
43 |
|
(* int' distributes over multiplication; follows from of_nat_mult after
   unfolding int'_def. *)
lemma int'_mult: "int' (m * n) = int' m * int' n"
  unfolding int'_def by (rule of_nat_mult)
|
46 |
|
(* A numeral at type nat equals nat_of_int applied to the same numeral
   read at type int, provided the numeral representation k is
   non-negative.  The assumption is needed to discharge the premise of
   nat_of_int_def.  This fact is used by the ML function
   nat_of_int_of_number_of further down in this theory. *)
lemma nat_of_int_of_number_of:
  fixes k
  assumes "k \<ge> 0"
  shows "number_of k = nat_of_int (number_of k)"
  unfolding nat_of_int_def [OF assms] nat_number_of_def number_of_is_id ..
|
52 |
|
(* Auxiliary step: converts a meta-equation of the form
   "Numeral.Pls <= k == True" into the object-level fact "k >= 0",
   which is the shape required as assumption of nat_of_int_of_number_of. *)
lemma nat_of_int_of_number_of_aux:
  fixes k
  assumes "Numeral.Pls \<le> k \<equiv> True"
  shows "k \<ge> 0"
  using assms unfolding Pls_def by simp
|
58 |
|
(* nat_of_int is a left inverse of int'. *)
lemma nat_of_int_int:
  "nat_of_int (int' n) = n"
  using nat_of_int_def int'_def by simp
|
62 |
|
(* If x is the int' image of n, then n can be recovered from x via
   nat_of_int.  Used below (via simp add:) to prove the nat_of_int-based
   code equations for Suc, + and *. *)
lemma eq_nat_of_int: "int' n = x \<Longrightarrow> n = nat_of_int x"
  by (erule subst, simp only: nat_of_int_int)
|
65 |
|
66 text {* |
|
67 Case analysis on natural numbers is rephrased using a conditional |
|
68 expression: |
|
69 *} |
|
70 |
|
(* nat_case as a conditional on "n = 0" with predecessor "n - 1".
   Stated as a meta-equation; the attributes register it as
   [code unfold] and delete it from the [code inline] set. *)
lemma [code unfold, code inline del]:
  "nat_case \<equiv> (\<lambda>f g n. if n = 0 then f else g (n - 1))"
proof -
  (* pointwise version, proved by case distinction on n *)
  have rewrite: "\<And>f g n. nat_case f g n = (if n = 0 then f else g (n - 1))"
  proof -
    fix f g n
    show "nat_case f g n = (if n = 0 then f else g (n - 1))"
      by (cases n) simp_all
  qed
  (* lift the pointwise equation to a meta-equation between functions
     by repeated extensionality *)
  show "nat_case \<equiv> (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    by (rule eq_reflection ext rewrite)+
qed
|
83 |
|
(* Variant of the nat_case rewrite registered as [code inline]: the
   predecessor is computed on the integer side as
   nat_of_int (int' n - 1) instead of n - 1 on nat. *)
lemma [code inline]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (nat_of_int (int' n - 1)))"
proof (rule ext)+
  fix f g n
  show "nat_case f g n = (if n = 0 then f else g (nat_of_int (int' n - 1)))"
  by (cases n) (simp_all add: nat_of_int_int)
qed
|
91 |
|
92 text {* |
|
93 Most standard arithmetic functions on natural numbers are implemented |
|
94 using their counterparts on the integers: |
|
95 *} |
|
96 |
|
(* Code equations for the constants 0, 1 and Suc on nat, each expressed
   as nat_of_int applied to the corresponding integer expression. *)
lemma [code func]: "0 = nat_of_int 0"
  by (simp add: nat_of_int_def)
lemma [code func, code inline]: "1 = nat_of_int 1"
  by (simp add: nat_of_int_def)
lemma [code func]: "Suc n = nat_of_int (int' n + 1)"
  by (simp add: eq_nat_of_int)
|
(* Code equations for +, - and * on nat via the integers.
   Addition and multiplication are each given twice: once through nat
   (attribute [code]) and once through nat_of_int (attribute [code func]).
   NOTE(review): [code] and [code func] presumably address the two code
   generator frameworks that are both configured in this theory -- confirm.
   Subtraction is only stated through nat; the integer difference
   int' m - int' n may be negative, where nat_of_int is unspecified. *)
lemma [code]: "m + n = nat (int' m + int' n)"
  by (simp add: int'_def nat_eq_iff2)
lemma [code func, code inline]: "m + n = nat_of_int (int' m + int' n)"
  by (simp add: eq_nat_of_int int'_add)
lemma [code, code inline]: "m - n = nat (int' m - int' n)"
  by (simp add: int'_def nat_eq_iff2 of_nat_diff)
lemma [code]: "m * n = nat (int' m * int' n)"
  unfolding int'_def
  by (simp add: of_nat_mult [symmetric] del: of_nat_mult)
lemma [code func, code inline]: "m * n = nat_of_int (int' m * int' n)"
  by (simp add: eq_nat_of_int int'_mult)
|
(* Code equations for div and mod on nat.
   The [code] variants go through integer div/mod and nat; the
   [code func] variants project the components of Divides.divmod. *)
lemma [code]: "m div n = nat (int' m div int' n)"
  unfolding int'_def zdiv_int [symmetric] by simp
lemma [code func]: "m div n = fst (Divides.divmod m n)"
  unfolding divmod_def by simp
lemma [code]: "m mod n = nat (int' m mod int' n)"
  unfolding int'_def zmod_int [symmetric] by simp
lemma [code func]: "m mod n = snd (Divides.divmod m n)"
  unfolding divmod_def by simp
|
(* Comparisons on nat (<, <=, =) are reduced to the same comparison on
   the int' images of both arguments. *)
lemma [code, code inline]: "(m < n) \<longleftrightarrow> (int' m < int' n)"
  unfolding int'_def by simp
lemma [code func, code inline]: "(m \<le> n) \<longleftrightarrow> (int' m \<le> int' n)"
  unfolding int'_def by simp
lemma [code func, code inline]: "m = n \<longleftrightarrow> int' m = int' n"
  unfolding int'_def by simp
|
(* Code equation for nat: negative integers are mapped to 0, all others
   are passed to nat_of_int.  Proved by case distinction on k < 0; the
   non-negative case uses the conditional definition of nat_of_int. *)
lemma [code func]: "nat k = (if k < 0 then 0 else nat_of_int k)"
proof (cases "k < 0")
  case True then show ?thesis by simp
next
  case False then show ?thesis by (simp add: nat_of_int_def)
qed
|
(* Code equation for int_aux: the recursion over n is phrased with a test
   on int' n = 0 and the predecessor nat_of_int (int' n - 1), counting
   up in the second argument. *)
lemma [code func]:
  "int_aux n i = (if int' n = 0 then i else int_aux (nat_of_int (int' n - 1)) (i + 1))"
proof -
  (* key fact: for positive n, taking the predecessor on the integer side
     and adding 1 gives back int' n *)
  have "0 < n \<Longrightarrow> int' n = 1 + int' (nat_of_int (int' n - 1))"
  proof -
    assume prem: "n > 0"
    (* the predecessor is non-negative, so nat_of_int_def applies *)
    then have "int' n - 1 \<ge> 0" unfolding int'_def by auto
    then have "nat_of_int (int' n - 1) = nat (int' n - 1)" by (simp add: nat_of_int_def)
    with prem show "int' n = 1 + int' (nat_of_int (int' n - 1))" unfolding int'_def by simp
  qed
  then show ?thesis unfolding int_aux_def int'_def by auto
qed
|
146 |
|
(* Division on nat computed as the first component of the integer
   division algorithm divAlg on the int' images, converted back with
   nat_of_int. *)
lemma div_nat_code [code func]:
  "m div k = nat_of_int (fst (divAlg (int' m, int' k)))"
  unfolding div_def [symmetric] int'_def zdiv_int [symmetric]
  unfolding int'_def [symmetric] nat_of_int_int ..
|
151 |
|
(* Remainder on nat computed as the second component of the integer
   division algorithm divAlg on the int' images, converted back with
   nat_of_int. *)
lemma mod_nat_code [code func]:
  "m mod k = nat_of_int (snd (divAlg (int' m, int' k)))"
  unfolding mod_def [symmetric] int'_def zmod_int [symmetric]
  unfolding int'_def [symmetric] nat_of_int_int ..
|
156 |
|
157 |
|
158 subsection {* Code generator setup for basic functions *} |
|
159 |
|
160 text {* |
|
161 @{typ nat} is no longer a datatype but embedded into the integers. |
|
162 *} |
|
163 |
|
(* Register nat_of_int as the only code-level constructor of nat, so that
   generated code no longer treats nat as the 0/Suc datatype. *)
code_datatype nat_of_int
|
165 |
|
(* Target-language types used for nat: the arbitrary-precision integer
   type of each target (SML, OCaml, Haskell). *)
code_type nat
  (SML "IntInf.int")
  (OCaml "Big'_int.big'_int")
  (Haskell "Integer")
|
170 |
|
(* types_code setup: nat is represented by the ML type "int".
   Two ML helpers are attached:
   - term_of_nat turns an ML integer back into a HOL numeral of type nat;
   - gen_nat produces a random test value via random_range 0 i. *)
types_code
  nat ("int")
attach (term_of) {*
val term_of_nat = HOLogic.mk_number HOLogic.natT o IntInf.fromInt;
*}
attach (test) {*
fun gen_nat i = random_range 0 i;
*}
|
179 |
|
(* consts_code setup: 0 on nat is emitted as the literal 0 and Suc as
   "argument + 1". *)
consts_code
  "0 \<Colon> nat" ("0")
  Suc ("(_ + 1)")
|
183 |
|
text {*
  Since natural numbers are implemented
  using integers, the coercion function @{const "int'"} of type
  @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function,
  likewise @{const nat_of_int} of type @{typ "int \<Rightarrow> nat"}.
  For the @{const "nat"} function for converting an integer to a natural
  number, we give a specific implementation using an ML function that
  returns its input value, provided that it is non-negative, and otherwise
  returns @{text "0"}.
*}
|
194 |
|
(* consts_code setup: int' is emitted as its bare argument (identity);
   nat is emitted as a call to the attached ML function nat, which clamps
   negative inputs to 0 and returns non-negative inputs unchanged. *)
consts_code
  int' ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if i < 0 then 0 else i;
*}
|
201 |
|
(* int' is serialised as just its argument ("_") in all three targets. *)
code_const int'
  (SML "_")
  (OCaml "_")
  (Haskell "_")
|
206 |
|
(* nat_of_int is serialised as just its argument ("_") in all three
   targets. *)
code_const nat_of_int
  (SML "_")
  (OCaml "_")
  (Haskell "_")
|
211 |
|
212 |
|
213 subsection {* Preprocessors *} |
|
214 |
|
215 text {* |
|
216 Natural numerals should be expressed using @{const nat_of_int}. |
|
217 *} |
|
218 |
|
(* Remove nat_number_of_def from the [code inline] set; nat numerals are
   rewritten by the ML procedure nat_of_int_of_number_of instead. *)
lemmas [code inline del] = nat_number_of_def
|
220 |
|
ML {*
(* Inline procedure: for every numeral of type nat with non-negative value
   occurring in the given cterms, produce a meta-equation rewriting it to
   nat_of_int applied to the corresponding int numeral. *)
fun nat_of_int_of_number_of thy cts =
  let
    (* decides "Numeral.Pls <= k" by rewriting with the numeral comparison rules *)
    val decide_nonneg = Simplifier.rewrite
      (HOL_basic_ss addsimps (@{thms less_numeral_code} @ @{thms less_eq_numeral_code}));
    fun rewrite_rule_for (t, ty) =
      if ty = HOLogic.natT andalso IntInf.<= (0, HOLogic.dest_numeral t) then
        let
          val nonneg_ct = Thm.capply @{cterm "(op \<le>) Numeral.Pls"} (Thm.cterm_of thy t);
          val nonneg_eq = decide_nonneg nonneg_ct;
          val nonneg = @{thm nat_of_int_of_number_of_aux} OF [nonneg_eq];
          val obj_eq = @{thm nat_of_int_of_number_of} OF [nonneg];
        in SOME (@{thm eq_reflection} OF [obj_eq]) end
      else NONE;
    val numerals = fold (HOLogic.add_numerals o Thm.term_of) cts [];
  in
    map_filter rewrite_rule_for numerals
  end;
*}
|
240 |
|
(* Register the ML function nat_of_int_of_number_of as an inline procedure
   of the code generator under the name "nat_of_int_of_number_of". *)
setup {*
  CodegenData.add_inline_proc ("nat_of_int_of_number_of", nat_of_int_of_number_of)
*}
|
244 |
|
245 text {* |
|
246 In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer |
|
247 a constructor term. Therefore, all occurrences of this term in a position |
|
248 where a pattern is expected (i.e.\ on the left-hand side of a recursion |
|
249 equation or in the arguments of an inductive relation in an introduction |
|
250 rule) must be eliminated. |
|
251 This can be accomplished by applying the following transformation rules: |
|
252 *} |
|
253 |
|
(* Merges a pair of recursion equations, one for f (Suc n) and one for
   f 0, into a single equation for f n with a conditional on n = 0.
   Used by the ML function remove_suc. *)
theorem Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
  f n = (if n = 0 then g else h (n - 1))"
  by (case_tac n) simp_all
|
257 |
|
(* Replaces a statement about the pair (n, Suc n) by one about
   (n - 1, n) under the side condition n ~= 0.
   Used by the ML function remove_suc_clause. *)
theorem Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
  by (case_tac n) simp_all
|
260 |
|
261 text {* |
|
262 The rules above are built into a preprocessor that is plugged into |
|
263 the code generator. Since the preprocessor for introduction rules |
|
264 does not know anything about modes, some of the modes that worked |
|
265 for the canonical representation of natural numbers may no longer work. |
|
266 *} |
|
267 |
|
268 (*<*) |
|
269 |
|
270 ML {* |
|
271 local |
|
(* Theorems used by the Suc-elimination preprocessors.  They are referenced
   through the @{thm ...} antiquotation, as elsewhere in this theory, so the
   names are resolved when this ML text is checked rather than looked up by
   string at run time. *)
val Suc_if_eq = @{thm Suc_if_eq};
val Suc_clause = @{thm Suc_clause};
(* true iff the constant "Suc" occurs somewhere in term t *)
fun contains_suc t = member (op =) (term_consts t) "Suc";
|
275 in |
|
276 |
|
(* remove_suc thy thms: eliminates occurrences of "Suc ?v" (Suc applied to
   a schematic variable) from the left-hand sides of the equations thms,
   using theorem Suc_if_eq.  Repeats until no further elimination step
   succeeds and returns the resulting list of theorems.
   Assumes every theorem has the shape Trueprop (lhs = rhs); the caller
   eqn_suc_preproc checks this before calling. *)
fun remove_suc thy thms =
  let
    val Suc_if_eq' = Thm.transfer thy Suc_if_eq;
    (* variant of "x" that differs from all schematic variable names in thms *)
    val vname = Name.variant (map fst
      (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
    (* certified placeholder variable of type nat *)
    val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    (* left/right argument of the equation under Trueprop, as cterms *)
    fun lhs_of th = snd (Thm.dest_comb
      (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
    fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
    (* For each subterm "Suc ?v" of ct, yields a pair consisting of ct with
       that occurrence replaced by the placeholder cv, and the variable ?v. *)
    fun find_vars ct = (case term_of ct of
        (Const ("Suc", _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
      | _ $ _ =>
        let val (ct1, ct2) = Thm.dest_comb ct
        in
          map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
          map (apfst (Thm.capply ct1)) (find_vars ct2)
        end
      | _ => []);
    (* all candidates: each theorem paired with each Suc-occurrence in its lhs *)
    val eqs = maps
      (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    (* Tries one elimination step for candidate (th, (ct, cv')).
       Returns SOME of the updated theorem list, or NONE if no theorem in
       thms can be resolved against the instantiated rule. *)
    fun mk_thms (th, (ct, cv')) =
      let
        (* Instantiate Suc_if_eq' positionally (the third variable is left
           open) and discharge its first premise with th generalised over
           cv'.  NOTE(review): the positions presumably correspond to
           f, h, g, n of Suc_if_eq -- confirm. *)
        val th' =
          Thm.implies_elim
           (Conv.fconv_rule (Thm.beta_conversion true)
             (Drule.instantiate'
               [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
                 SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
               Suc_if_eq')) (Thm.forall_intr cv' th)
      in
        (* resolve every theorem against th'; a THM exception means the
           theorem does not fit and it is skipped *)
        case map_filter (fn th'' =>
            SOME (th'', singleton
              (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
          handle THM _ => NONE) thms of
            [] => NONE
          | thps =>
              (* drop th and the theorems that were consumed (ths1), and
                 append the merged equations (ths2) *)
              let val (ths1, ths2) = split_list thps
              in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
      end
  in
    (* take the first candidate for which a step succeeds and recurse on
       the new theorem list; stop when none succeeds *)
    case get_first mk_thms eqs of
      NONE => thms
    | SOME x => remove_suc thy x
  end;
|
321 |
|
(* Preprocessor for recursion equations: if every theorem is an object-level
   equation and at least one left-hand side mentions Suc, eliminate Suc
   patterns with remove_suc; otherwise return the theorems unchanged. *)
fun eqn_suc_preproc thy ths =
  let
    val lhs_of = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of;
    val all_equations = forall (can lhs_of) ths;
  in
    if all_equations andalso exists (contains_suc o lhs_of) ths
    then remove_suc thy ths
    else ths
  end;
|
330 |
|
(* remove_suc_clause thy thms: eliminates occurrences of "Suc ?v" (Suc
   applied to a schematic variable) from the clauses thms, using theorem
   Suc_clause.  Each step rewrites the first theorem that contains such an
   occurrence and recurses; the list is returned once no theorem contains
   one any more. *)
fun remove_suc_clause thy thms =
  let
    val Suc_clause' = Thm.transfer thy Suc_clause;
    (* first subterm of the form "Suc ?v", returned together with ?v *)
    fun find_var (t as Const ("Suc", _) $ (v as Var _)) = SOME (t, v)
      | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
      | find_var _ = NONE;
    (* search the atomized form of th; keep both th and its atomized form *)
    fun find_thm th =
      let val th' = Conv.fconv_rule ObjectLogic.atomize th
      in Option.map (pair (th, th')) (find_var (prop_of th')) end
  in
    case get_first find_thm thms of
      NONE => thms
    | SOME ((th, th'), (Sucv, v)) =>
        let
          val cert = cterm_of (Thm.theory_of_thm th);
          (* Instantiate Suc_clause' with a predicate built by abstracting
             the proposition of th' over v and over the occurrence Sucv
             (bound name "x"), discharge its first premise with th'
             generalised over v, and turn the result back into rule form. *)
          val th'' = ObjectLogic.rulify (Thm.implies_elim
            (Conv.fconv_rule (Thm.beta_conversion true)
              (Drule.instantiate' []
                [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
                   abstract_over (Sucv,
                     HOLogic.dest_Trueprop (prop_of th')))))),
                 SOME (cert v)] Suc_clause'))
            (Thm.forall_intr (cert v) th'))
        in
          (* replace every theorem whose proposition equals that of th by
             the transformed clause, then continue *)
          remove_suc_clause thy (map (fn th''' =>
            if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
        end
  end;
|
361 |
|
(* Preprocessor for introduction rules: if the conclusion of every theorem
   is a membership statement and Suc occurs in the element part of the
   conclusion or of some membership premise of at least one theorem,
   eliminate Suc patterns with remove_suc_clause; otherwise return the
   theorems unchanged. *)
fun clause_suc_preproc thy ths =
  let
    val elem_of = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop;
    (* constants occurring in the element parts of conclusion and premises *)
    fun elem_consts th =
      foldr add_term_consts [] (map_filter (try elem_of) (concl_of th :: prems_of th));
    fun mentions_suc th = member (op =) (elem_consts th) "Suc";
  in
    if forall (can (elem_of o concl_of)) ths andalso exists mentions_suc ths
    then remove_suc_clause thy ths
    else ths
  end;
|
371 |
|
372 end; (*local*) |
|
373 |
|
(* Adapts a preprocessor f working on object-level equations to one working
   on meta-level equations: convert to object equality, apply f thy,
   convert back to meta equality, and finally beta-eta normalise. *)
fun lift_obj_eq f thy =
  let
    val to_obj_eqs = map (fn eq => eq RS @{thm meta_eq_to_obj_eq});
    val to_meta_eqs = map (fn eq => eq RS @{thm eq_reflection});
    val normalise = map (Conv.fconv_rule Drule.beta_eta_conversion);
  in
    to_obj_eqs #> f thy #> to_meta_eqs #> normalise
  end
|
379 *} |
|
380 |
|
(* Install the two Suc-elimination preprocessors: directly via
   Codegen.add_preprocessor, and wrapped with lift_obj_eq via
   CodegenData.add_preproc under the names "eqn_Suc" and "clause_Suc". *)
setup {*
  Codegen.add_preprocessor eqn_suc_preproc
  #> Codegen.add_preprocessor clause_suc_preproc
  #> CodegenData.add_preproc ("eqn_Suc", lift_obj_eq eqn_suc_preproc)
  #> CodegenData.add_preproc ("clause_Suc", lift_obj_eq clause_suc_preproc)
*}
|
387 (*>*) |
|
388 |
|
389 |
|
390 subsection {* Module names *} |
|
391 |
|
(* SML target: code from theories Nat, Divides and Efficient_Nat is placed
   in a single module named Integer. *)
code_modulename SML
  Nat Integer
  Divides Integer
  Efficient_Nat Integer
|
396 |
|
(* OCaml target: code from theories Nat, Divides and Efficient_Nat is
   placed in a single module named Integer. *)
code_modulename OCaml
  Nat Integer
  Divides Integer
  Efficient_Nat Integer
|
401 |
|
(* Haskell target: code from theories Nat, Divides and Efficient_Nat is
   placed in a single module named Integer.  Divides is included, as for
   the SML and OCaml targets, because the code equations for div and mod
   on nat in this theory refer to Divides.divmod; keeping it in a separate
   module could make the generated modules depend on each other. *)
code_modulename Haskell
  Nat Integer
  Divides Integer
  Efficient_Nat Integer
|
405 |
|
(* Hide the auxiliary constants nat_of_int and int' from the name space of
   theories importing this one. *)
hide const nat_of_int int'
|
407 |
|
408 end |