src/HOL/Library/Efficient_Nat.thy
 author haftmann Tue Jan 29 10:19:56 2008 +0100 (2008-01-29) changeset 26009 b6a64fe38634 parent 25967 dd602eb20f3f child 26100 fbc60cd02ae2 permissions -rw-r--r--
treating division by zero properly
1 (*  Title:      HOL/Library/Efficient_Nat.thy
2     ID:         $Id$
3     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
4 *)
6 header {* Implementation of natural numbers by target-language integers *}
8 theory Efficient_Nat
9 imports Main Code_Integer Code_Index
10 begin
12 text {*
13   When generating code for functions on natural numbers, the
14   canonical representation using @{term "0::nat"} and
15   @{term "Suc"} is unsuitable for computations involving large
16   numbers.  The efficiency of the generated code can be improved
17   drastically by implementing natural numbers by target-language
18   integers.  To do this, just include this theory.
19 *}
21 subsection {* Basic arithmetic *}
23 text {*
24   Most standard arithmetic functions on natural numbers are implemented
25   using their counterparts on the integers:
26 *}
(* Code generation for nat: number_of_nat is declared as the code constructor,
   so nat values are represented by numerals instead of 0 and Suc. *)
28 code_datatype number_nat_inst.number_of_nat
(* 0 and 1 are rewritten to numerals during preprocessing (code unfold);
   the symmetric versions turn numerals back into 0 and 1 during
   postprocessing (code post). *)
30 lemma zero_nat_code [code, code unfold]:
31   "0 = (Numeral0 :: nat)"
32   by simp
33 lemmas [code post] = zero_nat_code [symmetric]
35 lemma one_nat_code [code, code unfold]:
36   "1 = (Numeral1 :: nat)"
37   by simp
38 lemmas [code post] = one_nat_code [symmetric]
(* Suc is no longer a code constructor; it is computed by addition. *)
40 lemma Suc_code [code]:
41   "Suc n = n + 1"
42   by simp
(* Addition, subtraction and multiplication are computed on int and mapped
   back with nat.  For subtraction, nat sends negative differences to 0,
   which is exactly truncated subtraction on nat. *)
44 lemma plus_nat_code [code]:
45   "n + m = nat (of_nat n + of_nat m)"
46   by simp
48 lemma minus_nat_code [code]:
49   "n - m = nat (of_nat n - of_nat m)"
50   by simp
52 lemma times_nat_code [code]:
53   "n * m = nat (of_nat n * of_nat m)"
54   unfolding of_nat_mult [symmetric] by simp
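56 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"}
57   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations.
  The code equation for @{text "Divides.divmod"} handles a zero divisor
  itself, so the auxiliary function @{text "div_mod_nat_aux"} is only
  ever called with a non-zero divisor. *}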
(* div_mod_nat_aux is the same function as Divides.divmod.  Its defining
   equation is excluded from code generation (code func del); instead it
   gets its own code equation below and target-language serializations in
   the code_const declaration for div_mod_nat_aux. *)
59 definition
60   div_mod_nat_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat"
61 where
62   [code func del]: "div_mod_nat_aux = Divides.divmod"
(* Code equation for divmod: a zero divisor is handled here and yields
   (0, n), so div_mod_nat_aux is only invoked with m ~= 0. *)
64 lemma [code func]:
65   "Divides.divmod n m = (if m = 0 then (0, n) else div_mod_nat_aux n m)"
66   unfolding div_mod_nat_aux_def divmod_def by simp
(* Generic code equation via integer division and remainder; targets for
   which div_mod_nat_aux has a code_const serialization use that instead. *)
68 lemma div_mod_aux_code [code]:
69   "div_mod_nat_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
70   unfolding div_mod_nat_aux_def divmod_def zdiv_int [symmetric] zmod_int [symmetric] by simp
(* Equality and the orderings on nat are decided on the int images. *)
72 lemma eq_nat_code [code]:
73   "n = m \<longleftrightarrow> (of_nat n \<Colon> int) = of_nat m"
74   by simp
76 lemma less_eq_nat_code [code]:
77   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
78   by simp
80 lemma less_nat_code [code]:
81   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
82   by simp
84 subsection {* Case analysis *}
86 text {*
87   Case analysis on natural numbers is rephrased using a conditional
88   expression:
89 *}
(* nat_case is expressed by a test for 0 and predecessor subtraction, so no
   match on Suc is needed.  It serves both as code equation and as
   preprocessing rule (code unfold). *)
91 lemma [code func, code unfold]:
92   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
93   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
96 subsection {* Preprocessors *}
98 text {*
99   Since @{term "Suc n"} is now implemented as @{term "n + (1::nat)"}, it is
100   no longer a constructor term. Therefore, all its occurrences in a position
101   where a pattern is expected (i.e.\ on the left-hand side of a recursion
102   equation or in the arguments of an inductive relation in an introduction
103   rule) must be eliminated.
104   This can be accomplished by applying the following transformation rules:
105 *}
(* Combines an equation for f (Suc n) and one for f 0 into a single
   equation for f n with a conditional on n = 0; used by the ML function
   remove_suc. *)
107 lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
108   f n = (if n = 0 then g else h (n - 1))"
109   by (case_tac n) simp_all
(* Turns a clause about Suc n into one about n under the side condition
   n ~= 0; used by the ML function remove_suc_clause. *)
111 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
112   by (case_tac n) simp_all
114 text {*
115   The rules above are built into a preprocessor that is plugged into
116   the code generator. Since the preprocessor for introduction rules
117   does not know anything about modes, some of the modes that worked
118   for the canonical representation of natural numbers may no longer work.
119 *}
121 (*<*)
123 ML {*
(* Eliminate "Suc ?v" patterns from the left-hand sides of a list of
   equations of the form Trueprop (lhs = rhs).  For the first occurrence
   that can be rewritten, the equation is instantiated into Suc_if_eq and
   its remaining "f 0 = g" premise is resolved with every equation that
   fits; the equations used are replaced by the resulting conditional
   equations.  This repeats until no occurrence can be rewritten, and the
   current list is then returned. *)
fun remove_suc thy thms =
  let
    (* a schematic variable name not occurring in any of the theorems *)
    val vname = Name.variant (map fst
      (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
    val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    (* left- and right-hand side of Trueprop (lhs = rhs) *)
    fun lhs_of th = snd (Thm.dest_comb
      (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
    fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
    (* every occurrence of "Suc ?v" in ct, as the pair of ct with that
       occurrence replaced by cv, and the variable ?v itself *)
    fun find_vars ct = (case term_of ct of
        (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
      | _ $ _ =>
        let val (ct1, ct2) = Thm.dest_comb ct
        in
          map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
          map (apfst (Thm.capply ct1)) (find_vars ct2)
        end
      | _ => []);
    val eqs = maps
      (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    fun mk_thms (th, (ct, cv')) =
      let
        (* Suc_if_eq with f := %cv. ct, h := %cv'. rhs and n := cv'; its
           first premise is discharged by th generalized over cv' *)
        val th' =
          Thm.implies_elim
           (Conv.fconv_rule (Thm.beta_conversion true)
             (Drule.instantiate'
               [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
                 SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
               @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
      in
        (* resolve each theorem with the remaining "f 0 = g" premise;
           NONE if no theorem fits *)
        case map_filter (fn th'' =>
            SOME (th'', singleton
              (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
          handle THM _ => NONE) thms of
            [] => NONE
          | thps =>
              let val (ths1, ths2) = split_list thps
              in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
      end
  in
    case get_first mk_thms eqs of
      NONE => thms
    | SOME x => remove_suc thy x
  end;
(* Code-equation preprocessor.  It applies remove_suc only when every
   theorem is an equation Trueprop (l = r) and at least one left-hand side
   mentions Suc.  Otherwise the theorems are returned untouched. *)
fun eqn_suc_preproc thy ths =
  let
    fun lhs_of th = fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of th)));
    val all_eqns = forall (can lhs_of) ths;
    fun mentions_suc th = member (op =) (term_consts (lhs_of th)) @{const_name Suc};
  in
    if not all_eqns orelse not (exists mentions_suc ths) then ths
    else remove_suc thy ths
  end;
(* Eliminate "Suc ?v" patterns from introduction rules.  The first rule
   containing such a pattern is atomized and generalized over ?v, then
   resolved with Suc_clause.  In the result, "Suc ?v" is replaced by ?v
   under the side condition ?v ~= 0, and the remaining occurrences of ?v
   become ?v - 1.  The rule is rulified back and replaces the original in
   the list.  This repeats until no rule contains such a pattern. *)
fun remove_suc_clause thy thms =
  let
    (* first subterm "Suc ?v", paired with ?v *)
    fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
      | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
      | find_var _ = NONE;
    fun find_thm th =
      let val th' = Conv.fconv_rule ObjectLogic.atomize th
      in Option.map (pair (th, th')) (find_var (prop_of th')) end
  in
    case get_first find_thm thms of
      NONE => thms
    | SOME ((th, th'), (Sucv, v)) =>
        let
          val cert = cterm_of (Thm.theory_of_thm th);
          (* P := %v x. prop with "Suc v" abstracted to x, n := v *)
          val th'' = ObjectLogic.rulify (Thm.implies_elim
            (Conv.fconv_rule (Thm.beta_conversion true)
              (Drule.instantiate' []
                [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
                   abstract_over (Sucv,
                     HOLogic.dest_Trueprop (prop_of th')))))),
                 SOME (cert v)] @{thm Suc_clause}))
            (Thm.forall_intr (cert v) th'))
        in
          remove_suc_clause thy (map (fn th''' =>
            if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
        end
  end;
(* Introduction-rule preprocessor.  It applies remove_suc_clause only when
   every conclusion is a membership statement Trueprop (x : S) and, for some
   rule, the element x of its conclusion or of one of its membership
   premises mentions Suc.  Otherwise the theorems are returned untouched. *)
fun clause_suc_preproc thy ths =
  let
    val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop;
    (* same constant name as matched by remove_suc_clause *)
    fun contains_suc t = member (op =) (term_consts t) @{const_name Suc};
  in
    if forall (can (dest o concl_of)) ths andalso
      exists (fn th => exists contains_suc
        (map_filter (try dest) (concl_of th :: prems_of th))) ths
    then remove_suc_clause thy ths else ths
  end;
(* Adapt a preprocessor f on object-level equations (l = r) to meta-level
   equations (l == r).  The inputs are converted with meta_eq_to_obj_eq, f is
   run, and the results are converted back with eq_reflection and
   beta-eta-normalized.  If any step raises an exception, the original
   theorems are returned unchanged. *)
fun lift_obj_eq f thy thms =
  let
    fun to_obj thm = thm RS @{thm meta_eq_to_obj_eq};
    fun to_meta thm = Conv.fconv_rule Drule.beta_eta_conversion
      (thm RS @{thm eq_reflection});
    fun transform ths = map to_meta (f thy (map to_obj ths));
  in
    the_default thms (try transform thms)
  end;
226 *}
setup {*
  (* Register the Suc-eliminating preprocessors with Codegen and with Code.
     The Code preprocessors receive meta-equations, hence the wrapping in
     lift_obj_eq. *)
  Codegen.add_preprocessor eqn_suc_preproc
  #> Codegen.add_preprocessor clause_suc_preproc
  #> Code.add_preproc ("eqn_Suc", lift_obj_eq eqn_suc_preproc)
  #> Code.add_preproc ("clause_Suc", lift_obj_eq clause_suc_preproc)
*}
234 (*>*)
236 subsection {* Target language setup *}
238 text {*
239   For ML, we map @{typ nat} to target language integers, where we
240   assert that values are always non-negative.
241 *}
(* nat is represented by the target integer type; values are assumed to be
   non-negative. *)
code_type nat
  (SML "int")
  (OCaml "Big'_int.big'_int")

(* The same representation for the old code generator, together with a
   term reconstruction function and a random generator for testing. *)
types_code
  nat ("int")
attach (term_of) {*
fun term_of_nat n = HOLogic.mk_number HOLogic.natT n;
*}
attach (test) {*
fun gen_nat i =
  (fn n => (n, fn () => term_of_nat n)) (random_range 0 i);
*}
258 text {*
259   For Haskell we define our own @{typ nat} type.  The reason
260   is that we have to distinguish type class instances
261   for @{typ nat} and @{typ int}.
262 *}
code_include Haskell "Nat" {*
-- Natural numbers as a wrapper around Integer.  Assuming n >= 0 holds for
-- every Nat n, all operations below preserve that invariant: fromInteger
-- clamps negative arguments to 0, and subtraction is truncated through it.
newtype Nat = Nat Integer deriving (Show, Eq);

instance Num Nat where {
  fromInteger k = Nat (if k >= 0 then k else 0);
  Nat n + Nat m = Nat (n + m);
  Nat n - Nat m = fromInteger (n - m);
  Nat n * Nat m = Nat (n * m);
  abs n = n;
  -- 0 for zero, 1 for every positive value
  signum (Nat n) = Nat (signum n);
  negate n = error "negate Nat";
};

instance Ord Nat where {
  Nat n <= Nat m = n <= m;
  Nat n < Nat m = n < m;
};

instance Real Nat where {
  toRational (Nat n) = toRational n;
};

instance Enum Nat where {
  toEnum k = fromInteger (toEnum k);
  fromEnum (Nat n) = fromEnum n;
};

-- For non-negative operands divMod and quotRem coincide.
instance Integral Nat where {
  toInteger (Nat n) = n;
  divMod n m = quotRem n m;
  quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
};
*}
code_reserved Haskell Nat

(* In Haskell, nat is the Nat type from the "Nat" include.  No eq instance
   is generated for it; the Nat type's derived Eq instance is used. *)
code_type nat
  (Haskell "Nat")

code_instance nat :: eq
  (Haskell -)
306 text {*
307   Natural numerals.
308 *}
(* Preprocessing rewrites nat applied to an int numeral into a nat numeral
   (code inline); the symmetric rule is used for postprocessing. *)
310 lemma [code inline, symmetric, code post]:
311   "nat (number_of i) = number_nat_inst.number_of_nat i"
312   -- {* this interacts as desired with @{thm nat_number_of_def} *}
313   by (simp add: number_nat_inst.number_of_nat)
(* Serialize number_of_nat numerals as literals in SML, OCaml and Haskell.
   NOTE(review): the meaning of the two boolean flags is defined by
   Numeral.add_code, which is not shown here. *)
315 setup {*
316   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
317     true false) ["SML", "OCaml", "Haskell"]
318 *}
320 text {*
321   Since natural numbers are implemented
322   using integers in ML, the coercion function @{const "of_nat"} of type
323   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
324   For the @{const "nat"} function for converting an integer to a natural
325   number, we give a specific implementation using an ML function that
326   returns its input value, provided that it is non-negative, and otherwise
327   returns @{text "0"}.
328 *}
(* int is a separate name for of_nat :: nat => int, so that it can receive
   its own code equations and serializations; its defining equation is
   excluded from code generation. *)
330 definition
331   int :: "nat \<Rightarrow> int"
332 where
333   [code func del]: "int = of_nat"
(* On numerals, int and nat reduce to a sign test. *)
335 lemma int_code' [code func]:
336   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
337   unfolding int_nat_number_of [folded int_def] ..
339 lemma nat_code' [code func]:
340   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
341   by auto
(* Preprocessing replaces of_nat by int; postprocessing reverts this. *)
343 lemma of_nat_int [code unfold]:
344   "of_nat = int" by (simp add: int_def)
345 declare of_nat_int [symmetric, code post]
(* In SML and OCaml, int is the identity and nat clamps negative integers
   to 0. *)
code_const int
  (SML "_")
  (OCaml "_")

code_const nat
  (SML "IntInf.max/ (/0,/ _)")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")

(* The same two conversions for the old code generator; nat is provided by
   an attached ML function. *)
consts_code
  int ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if 0 <= i then i else 0;
*}
362 text {* For Haskell, things are slightly different again. *}
(* In Haskell, int is toInteger on Nat, and nat is fromInteger, which for
   Nat clamps negative values to 0. *)
code_const int
  (Haskell "toInteger")

code_const nat
  (Haskell "fromInteger")
367 text {* Conversion from and to indices. *}
(* Conversions between nat and the index type in SML and OCaml; no Haskell
   serialization is given here.  In SML, IntInf.toInt raises Overflow for
   values outside the range of int. *)
369 code_const index_of_nat
370   (SML "IntInf.toInt")
371   (OCaml "Big'_int.int'_of'_big'_int")
374 code_const nat_of_index
375   (SML "IntInf.fromInt")
376   (OCaml "Big'_int.big'_int'_of'_int")
379 text {* Using target language arithmetic operations whenever appropriate *}
(* Addition and multiplication map to the target integer operations in all
   three targets. *)
code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.+ ((_), (_))")
  (OCaml "Big'_int.add'_big'_int")
  (Haskell infixl 6 "+")

code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.* ((_), (_))")
  (OCaml "Big'_int.mult'_big'_int")
  (Haskell infixl 7 "*")

(* Combined division with remainder.  It is only reached with a non-zero
   divisor, because the code equation for Divides.divmod handles m = 0. *)
code_const div_mod_nat_aux
  (SML "IntInf.divMod/ ((_),/ (_))")
  (OCaml "Big'_int.quomod'_big'_int")
  (Haskell "divMod")
(* Equality and the orderings on nat map directly to the target's integer
   comparisons. *)
396 code_const "op = \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
397   (SML "!((_ : IntInf.int) = _)")
398   (OCaml "Big'_int.eq'_big'_int")
399   (Haskell infixl 4 "==")
401 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
402   (SML "IntInf.<= ((_), (_))")
403   (OCaml "Big'_int.le'_big'_int")
404   (Haskell infix 4 "<=")
406 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
407   (SML "IntInf.< ((_), (_))")
408   (OCaml "Big'_int.lt'_big'_int")
409   (Haskell infix 4 "<")
(* Old code generator (consts_code): 0, Suc, +, *, <= and < on nat map to
   ML integer syntax; Suc becomes addition of 1. *)
411 consts_code
412   0                            ("0")
413   Suc                          ("(_ +/ 1)")
414   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
415   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
416   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
417   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
420 text {* Module names *}
(* Generated code from the theories Nat, Divides and Efficient_Nat is placed
   into the module Integer, for SML and for OCaml. *)
422 code_modulename SML
423   Nat Integer
424   Divides Integer
425   Efficient_Nat Integer
427 code_modulename OCaml
428   Nat Integer
429   Divides Integer
430   Efficient_Nat Integer