src/HOL/Library/Efficient_Nat.thy
 author haftmann Tue Sep 16 09:21:24 2008 +0200 (2008-09-16) changeset 28228 7ebe8dc06cbb parent 27673 52056ddac194 child 28346 b8390cd56b8f permissions -rw-r--r--
evaluation using code generator
1 (*  Title:      HOL/Library/Efficient_Nat.thy
2     ID:         $Id$
3     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
4 *)
6 header {* Implementation of natural numbers by target-language integers *}
8 theory Efficient_Nat
9 imports Plain Code_Index Code_Integer
10 begin
12 text {*
13   When generating code for functions on natural numbers, the
14   canonical representation using @{term "0::nat"} and
15   @{term "Suc"} is unsuitable for computations involving large
16   numbers.  The efficiency of the generated code can be improved
17   drastically by implementing natural numbers by target-language
18   integers.  To do this, simply import this theory.
19 *}
21 subsection {* Basic arithmetic *}
23 text {*
24   Most standard arithmetic functions on natural numbers are implemented
25   using their counterparts on the integers:
26 *}
(* Declare number_nat_inst.number_of_nat as the code datatype constructor of
   nat.  0 and Suc are then no longer constructors in generated code, which is
   why Suc patterns have to be eliminated by the preprocessors further below. *)
28 code_datatype number_nat_inst.number_of_nat
(* 0 and 1 on nat are expressed as numerals, both as code equations and as
   preprocessor unfold rules; the symmetric versions are registered as
   code post rules, turning the numerals back into 0 and 1 when results are
   postprocessed. *)
30 lemma zero_nat_code [code, code unfold]:
31   "0 = (Numeral0 :: nat)"
32   by simp
33 lemmas [code post] = zero_nat_code [symmetric]
35 lemma one_nat_code [code, code unfold]:
36   "1 = (Numeral1 :: nat)"
37   by simp
38 lemmas [code post] = one_nat_code [symmetric]
(* Code equations for the basic arithmetic on nat.  Suc is reduced to + 1;
   +, - and * are computed on int via of_nat and converted back with nat.
   Since nat maps negative integers to 0, minus_nat_code yields truncated
   subtraction on nat. *)
40 lemma Suc_code [code]:
41   "Suc n = n + 1"
42   by simp
44 lemma plus_nat_code [code]:
45   "n + m = nat (of_nat n + of_nat m)"
46   by simp
48 lemma minus_nat_code [code]:
49   "n - m = nat (of_nat n - of_nat m)"
50   by simp
52 lemma times_nat_code [code]:
53   "n * m = nat (of_nat n * of_nat m)"
54   unfolding of_nat_mult [symmetric] by simp
56 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"}
57   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
(* divmod_aux is a copy of divmod whose defining equation is removed from the
   code equations ([code func del]); it is the constant that the code_const
   declarations for divmod_aux map to a target-language division.
   The code equation for divmod handles a zero divisor in the logic,
   returning (0, n), so divmod only reaches divmod_aux with m ~= 0.
   divmod_aux_code is the generic fallback computing quotient and remainder
   on int. *)
59 definition
60   divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat"
61 where
62   [code func del]: "divmod_aux = divmod"
64 lemma [code func]:
65   "divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
66   unfolding divmod_aux_def divmod_div_mod by simp
68 lemma divmod_aux_code [code]:
69   "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
70   unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
(* Equality and the orderings on nat are decided by comparing the int images
   of both arguments (of_nat into int). *)
72 lemma eq_nat_code [code]:
73   "n = m \<longleftrightarrow> (of_nat n \<Colon> int) = of_nat m"
74   by simp
76 lemma less_eq_nat_code [code]:
77   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
78   by simp
80 lemma less_nat_code [code]:
81   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
82   by simp
84 subsection {* Case analysis *}
86 text {*
87   Case analysis on natural numbers is rephrased using a conditional
88   expression:
89 *}
(* nat_case is replaced, as code equation and as unfold rule, by a conditional
   on n = 0 that passes n - 1 to the successor branch; pattern matching on Suc
   is unavailable once Suc is no longer a constructor. *)
91 lemma [code func, code unfold]:
92   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
93   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
96 subsection {* Preprocessors *}
98 text {*
99   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
100   a constructor term. Therefore, all occurrences of this term in a position
101   where a pattern is expected (i.e.\ on the left-hand side of a recursion
102   equation or in the arguments of an inductive relation in an introduction
103   rule) must be eliminated.
104   This can be accomplished by applying the following transformation rules:
105 *}
(* Transformation rules used by the preprocessors below.
   Suc_if_eq (instantiated in remove_suc) merges an equation for f (Suc n)
   and one for f 0 into a single equation on a variable n with a conditional.
   Suc_clause (instantiated in remove_suc_clause) turns a statement about
   Suc n into one about n - 1 under the premise n ~= 0. *)
107 lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
108   f n = (if n = 0 then g else h (n - 1))"
109   by (case_tac n) simp_all
111 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
112   by (case_tac n) simp_all
114 text {*
115   The rules above are built into a preprocessor that is plugged into
116   the code generator. Since the preprocessor for introduction rules
117   does not know anything about modes, some of the modes that worked
118   for the canonical representation of natural numbers may no longer work.
119 *}
121 (*<*)
122 setup {*
123 let
(* remove_suc thy thms: eliminate patterns Suc v (v a schematic variable) from
   the left-hand sides of the equations thms.
   For a candidate equation th: lhs[Suc v] = rhs, th is generalized over v and
   used to discharge the first premise of Suc_if_eq. The remaining 0-case
   premise is then resolved with every equation of thms that fits. th and the
   resolved 0-case equations are replaced by the results, and the whole process
   is repeated.
   Candidates for which no 0-case equation resolves are skipped. Once no
   candidate applies, the current equation list is returned. *)
125 fun remove_suc thy thms =
126   let
    (* fresh variable of type nat, named apart from all variables in thms *)
127     val vname = Name.variant (map fst
128       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
129     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    (* the two sides of an equation Trueprop (lhs = rhs) *)
130     fun lhs_of th = snd (Thm.dest_comb
131       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
132     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
    (* every occurrence of Suc applied to a Var in ct, not looking under
       abstractions: each result pairs ct with that occurrence replaced by cv,
       and the variable under Suc.  The constant name comes from the
       const_name antiquotation, as in the other preprocessors of this setup. *)
133     fun find_vars ct = (case term_of ct of
134         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
135       | _ $ _ =>
136         let val (ct1, ct2) = Thm.dest_comb ct
137         in
138           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
139           map (apfst (Thm.capply ct1)) (find_vars ct2)
140         end
141       | _ => []);
    (* all candidates: (equation, (lhs with hole cv, variable under Suc)) *)
142     val eqs = maps
143       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    (* process one candidate; NONE if no equation of thms resolves with the
       0-case premise, otherwise the updated equation list *)
144     fun mk_thms (th, (ct, cv')) =
145       let
          (* Suc_if_eq with f := lambda cv. ct, h := lambda cv'. rhs and n := cv'
             (g stays schematic), beta-reduced; its first premise is discharged
             by th generalized over cv', leaving the 0-case premise f 0 = ?g *)
146         val th' =
147           Thm.implies_elim
148            (Conv.fconv_rule (Thm.beta_conversion true)
149              (Drule.instantiate'
150                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
151                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
152                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
153       in
          (* resolve the 0-case premise of th' with each equation that fits;
             equations that fail to resolve (THM) are ignored *)
154         case map_filter (fn th'' =>
155             SOME (th'', singleton
156               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
157           handle THM _ => NONE) thms of
158             [] => NONE
159           | thps =>
160               let val (ths1, ths2) = split_list thps
161               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
162       end
    (* apply the first candidate that succeeds and repeat on the result *)
163   in case get_first mk_thms eqs of
164       NONE => thms
165     | SOME x => remove_suc thy x
166   end;
(* eqn_suc_preproc thy ths: apply remove_suc to ths, provided every theorem in
   ths is an equation and at least one left-hand side mentions Suc; otherwise
   return ths unchanged. *)
168 fun eqn_suc_preproc thy ths =
169   let
170     val lhs_of = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of;
171     fun lhs_has_suc th = member (op =) (term_consts (lhs_of th)) @{const_name Suc};
172     val all_eqns = forall (can lhs_of) ths;
173   in
174     if all_eqns andalso exists lhs_has_suc ths then remove_suc thy ths
175     else ths
176   end;
(* remove_suc_clause thy thms: eliminate patterns Suc v (v a schematic
   variable) from clauses.  The first clause whose atomized form contains such
   a pattern is rewritten with Suc_clause into a rule about v - 1 under the
   extra premise v ~= 0.  The rewritten rule replaces the original clause in
   thms, and the process repeats until no clause contains the pattern. *)
178 fun remove_suc_clause thy thms =
179   let
    (* first occurrence of Suc applied to a Var, searching the function part
       before the argument and not looking under abstractions; (Suc v, v) *)
182     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
183       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
184       | find_var _ = NONE;
    (* atomize th (meta-level rule into object-level formula) and search it *)
185     fun find_thm th =
186       let val th' = Conv.fconv_rule ObjectLogic.atomize th
187       in Option.map (pair (th, th')) (find_var (prop_of th')) end
188   in
189     case get_first find_thm thms of
190       NONE => thms
191     | SOME ((th, th'), (Sucv, v)) =>
192         let
193           val cert = cterm_of (Thm.theory_of_thm th);
          (* Suc_clause with P := lambda v x. (atomized th with Suc v abstracted
             to x) and n := v; its first premise is discharged by th'
             generalized over v, and the result is turned back into a rule *)
194           val th'' = ObjectLogic.rulify (Thm.implies_elim
195             (Conv.fconv_rule (Thm.beta_conversion true)
196               (Drule.instantiate' []
197                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
198                    abstract_over (Sucv,
199                      HOLogic.dest_Trueprop (prop_of th')))))),
200                  SOME (cert v)] @{thm Suc_clause}))
201             (Thm.forall_intr (cert v) th'))
202         in
          (* replace th (compared by proposition) by th'' and repeat *)
203           remove_suc_clause thy (map (fn th''' =>
204             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
205         end
206   end;
(* clause_suc_preproc thy ths: apply remove_suc_clause to ths, provided every
   conclusion is a set membership t : S and, for at least one clause, the
   element t of its conclusion or of one of its membership premises mentions
   Suc; otherwise return ths unchanged.  Suc is identified through its
   const_name antiquotation, the same way eqn_suc_preproc does it. *)
208 fun clause_suc_preproc thy ths =
209   let
210     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
211   in
212     if forall (can (dest o concl_of)) ths andalso
213       exists (fn th => member (op =) (foldr add_term_consts
214         [] (map_filter (try dest) (concl_of th :: prems_of th))) @{const_name Suc}) ths
215     then remove_suc_clause thy ths else ths
216   end;
(* lift f thy thms1: wrap a preprocessor f working on object-level equations
   as a function transformer on meta-level equations.  The result is NONE if
   the conversion or f raises an exception, or if the outcome equals the input
   after variable-index normalization; otherwise SOME of the new theorems. *)
218 fun lift f thy thms1 =
219   let
220     val thms2 = Drule.zero_var_indexes_list thms1;
221     fun transform thms = thms |> map (fn thm => thm RS @{thm meta_eq_to_obj_eq})
222       |> f thy |> map (fn thm => thm RS @{thm eq_reflection})
223       |> map (Conv.fconv_rule Drule.beta_eta_conversion);
224     fun changed thms4 = if Thm.eq_thms (thms2, thms4) then NONE else SOME thms4;
225   in
226     case try transform thms2 of
227       NONE => NONE
228     | SOME thms3 => changed (Drule.zero_var_indexes_list thms3)
229   end
231 in
  (* register both preprocessors with the old code generator (Codegen) and,
     wrapped by lift, as function transformers of the new one (Code) *)
233   Codegen.add_preprocessor eqn_suc_preproc
234   #> Codegen.add_preprocessor clause_suc_preproc
235   #> Code.add_functrans ("eqn_Suc", lift eqn_suc_preproc)
236   #> Code.add_functrans ("clause_Suc", lift clause_suc_preproc)
238 end;
239 *}
240 (*>*)
243 subsection {* Target language setup *}
245 text {*
246   For ML, we map @{typ nat} to target language integers, where we
247   assert that values are always non-negative.
248 *}
(* ML targets: nat is represented by the arbitrary-precision integer type of
   the target (IntInf.int in SML, Big_int.big_int in OCaml).  These types are
   signed; non-negativity is an invariant maintained by the generated code. *)
250 code_type nat
251   (SML "IntInf.int")
252   (OCaml "Big'_int.big'_int")
(* Old code generator: nat is represented by ML int.  term_of_nat reifies a
   value as a nat numeral term via HOLogic.mk_number.  gen_nat i draws
   n from random_range 0 i and returns it with a thunk producing its term.
   NOTE(review): the bounds of random_range (inclusive/exclusive) are defined
   elsewhere; confirm. *)
254 types_code
255   nat ("int")
256 attach (term_of) {*
257 val term_of_nat = HOLogic.mk_number HOLogic.natT;
258 *}
259 attach (test) {*
260 fun gen_nat i =
261   let val n = random_range 0 i
262   in (n, fn () => term_of_nat n) end;
263 *}
265 text {*
266   For Haskell we define our own @{typ nat} type.  The reason
267   is that we have to distinguish type class instances
268   for @{typ nat} and @{typ int}.
269 *}
(* Haskell runtime support for nat: a newtype Nat wrapping Integer.
   fromInteger and subtraction map negative results to 0, abs is the
   identity, and signum gives 0 for 0 and 1 otherwise (as for Integer on
   non-negative values); negate is an error.  divMod is defined as quotRem,
   which coincides with divMod on non-negative operands.  The identifier Nat
   is reserved so that generated Haskell code does not reuse it. *)
271 code_include Haskell "Nat" {*
272 newtype Nat = Nat Integer deriving (Show, Eq);
274 instance Num Nat where {
275   fromInteger k = Nat (if k >= 0 then k else 0);
276   Nat n + Nat m = Nat (n + m);
277   Nat n - Nat m = fromInteger (n - m);
278   Nat n * Nat m = Nat (n * m);
279   abs n = n;
280   signum (Nat n) = Nat (signum n);
281   negate n = error "negate Nat";
282 };
284 instance Ord Nat where {
285   Nat n <= Nat m = n <= m;
286   Nat n < Nat m = n < m;
287 };
289 instance Real Nat where {
290   toRational (Nat n) = toRational n;
291 };
293 instance Enum Nat where {
294   toEnum k = fromInteger (toEnum k);
295   fromEnum (Nat n) = fromEnum n;
296 };
298 instance Integral Nat where {
299   toInteger (Nat n) = n;
300   divMod n m = quotRem n m;
301   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
302 };
303 *}
305 code_reserved Haskell Nat
(* Haskell target: nat is the newtype Nat from the Haskell "Nat" include.
   Equality on nat uses the derived Eq instance of Nat, so the eq instance is
   declared as provided by the target (-) rather than generated. *)
307 code_type nat
308   (Haskell "Nat")
310 code_instance nat :: eq
311   (Haskell -)
313 text {*
314   Natural numerals.
315 *}
(* nat applied to an integer numeral equals the nat numeral with the same
   representation.  The attributes are applied left to right: the statement
   as written becomes a code inline rule; symmetric then flips it, and the
   flipped equation is registered as a code post rule. *)
317 lemma [code inline, symmetric, code post]:
318   "nat (number_of i) = number_nat_inst.number_of_nat i"
319   -- {* this interacts as desired with @{thm nat_number_of_def} *}
320   by (simp add: number_nat_inst.number_of_nat)
(* Register number_nat_inst.number_of_nat with Numeral.add_code for the SML,
   OCaml and Haskell targets, presumably so that nat numerals are emitted as
   target literals.  NOTE(review): the meaning of the flags true and false is
   defined by Numeral.add_code, not visible here; confirm. *)
322 setup {*
323   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
324     true false) ["SML", "OCaml", "Haskell"]
325 *}
327 text {*
328   Since natural numbers are implemented
329   using integers in ML, the coercion function @{const "of_nat"} of type
330   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
331   For the @{const "nat"} function for converting an integer to a natural
332   number, we give a specific implementation using an ML function that
333   returns its input value, provided that it is non-negative, and otherwise
334   returns @{text "0"}.
335 *}
(* int is a copy of of_nat (at type nat => int) whose definition is removed
   from the code equations, so that it can be mapped to target code directly.
   int_code' and nat_code' are code equations for numerals: a negative
   integer numeral yields 0.  of_nat_int makes the preprocessor replace of_nat
   by int; its symmetric version is a code post rule restoring of_nat. *)
337 definition
338   int :: "nat \<Rightarrow> int"
339 where
340   [code func del]: "int = of_nat"
342 lemma int_code' [code func]:
343   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
344   unfolding int_nat_number_of [folded int_def] ..
346 lemma nat_code' [code func]:
347   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
348   by auto
350 lemma of_nat_int [code unfold]:
351   "of_nat = int" by (simp add: int_def)
352 declare of_nat_int [symmetric, code post]
(* ML targets share one integer representation for nat and int, so int is
   the identity (_) in SML, OCaml and the old code generator.  nat maps
   negative integers to 0: via an attached ML function for the old code
   generator, via IntInf.max with 0 in SML, and via Big_int.max_big_int with
   zero_big_int in OCaml. *)
354 code_const int
355   (SML "_")
356   (OCaml "_")
358 consts_code
359   int ("(_)")
360   nat ("\<module>nat")
361 attach {*
362 fun nat i = if i < 0 then 0 else i;
363 *}
365 code_const nat
366   (SML "IntInf.max/ (/0,/ _)")
367   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
369 text {* For Haskell, things are slightly different again. *}
(* Haskell: int is toInteger on Nat, and nat is fromInteger, which in the
   Nat instance of Num maps negative integers to 0. *)
371 code_const int and nat
372   (Haskell "toInteger" and "fromInteger")
374 text {* Conversion from and to indices. *}
(* Conversions between nat and index for SML and OCaml.
   NOTE(review): IntInf.toInt and Big_int.int_of_big_int fail for values
   outside the machine-integer range; confirm callers stay in range.
   NOTE(review): no Haskell mapping is given here, although the Haskell Nat
   type has an Enum instance (fromEnum/toEnum); confirm this is intended. *)
376 code_const index_of_nat
377   (SML "IntInf.toInt")
378   (OCaml "Big'_int.int'_of'_big'_int")
381 code_const nat_of_index
382   (SML "IntInf.fromInt")
383   (OCaml "Big'_int.big'_int'_of'_int")
386 text {* Using target language arithmetic operations whenever appropriate *}
(* Addition on nat is mapped to target addition: IntInf.+ in SML,
   Big_int.add_big_int in OCaml (matching the OCaml mappings of
   multiplication and the comparisons), and + in Haskell. *)
388 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
389   (SML "IntInf.+ ((_), (_))")
390   (OCaml "Big'_int.add'_big'_int")
391   (Haskell infixl 6 "+")
(* Multiplication, divmod_aux and comparisons on nat are mapped to the
   corresponding target operations.  divmod_aux uses IntInf.divMod (SML) and
   Big_int.quomod_big_int (OCaml); nat values are non-negative, so floor and
   truncating division agree, and divmod only calls divmod_aux with a
   non-zero divisor.
   NOTE(review): divmod_aux has no Haskell mapping here, although the Haskell
   Nat type defines divMod; confirm this is intended. *)
393 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
394   (SML "IntInf.* ((_), (_))")
395   (OCaml "Big'_int.mult'_big'_int")
396   (Haskell infixl 7 "*")
398 code_const divmod_aux
399   (SML "IntInf.divMod/ ((_),/ (_))")
400   (OCaml "Big'_int.quomod'_big'_int")
403 code_const "op = \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
404   (SML "!((_ : IntInf.int) = _)")
405   (OCaml "Big'_int.eq'_big'_int")
406   (Haskell infixl 4 "==")
408 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
409   (SML "IntInf.<= ((_), (_))")
410   (OCaml "Big'_int.le'_big'_int")
411   (Haskell infix 4 "<=")
413 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
414   (SML "IntInf.< ((_), (_))")
415   (OCaml "Big'_int.lt'_big'_int")
416   (Haskell infix 4 "<")
(* Old code generator: 0, Suc, +, *, <= and < on nat are mapped directly to
   ML integer literals and operators (Suc n becomes n + 1). *)
418 consts_code
419   0                            ("0")
420   Suc                          ("(_ +/ 1)")
421   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
422   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
423   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
424   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
427 text {* Evaluation *}
(* Evaluation support: reifying a nat value as a term.  The trivial lemma is
   declared with both code func and code func del; NOTE(review): presumably
   this clears the code equations of term_of at type nat so that the SML
   mapping to HOLogic.mk_number HOLogic.natT below is used; confirm. *)
429 lemma [code func, code func del]:
430   "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..
432 code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
433   (SML "HOLogic.mk'_number/ HOLogic.natT")
436 text {* Module names *}
(* In SML and OCaml, code generated from the theories Nat, Divides and
   Efficient_Nat is placed into the target module Integer. *)
438 code_modulename SML
439   Nat Integer
440   Divides Integer
441   Efficient_Nat Integer
443 code_modulename OCaml
444   Nat Integer
445   Divides Integer
446   Efficient_Nat Integer