src/HOL/Library/Efficient_Nat.thy
 author bulwahn Fri Apr 08 16:31:14 2011 +0200 (2011-04-08) changeset 42316 12635bb655fd parent 40607 30d512bf47a7 child 43324 2b47822868e4 permissions -rw-r--r--
deactivating other compilations in quickcheck_exhaustive momentarily that only interesting for my benchmarks and experiments
1 (*  Title:      HOL/Library/Efficient_Nat.thy
2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
3 *)
5 header {* Implementation of natural numbers by target-language integers *}
7 theory Efficient_Nat
8 imports Code_Integer Main
9 begin
11 text {*
12   When generating code for functions on natural numbers, the
13   canonical representation using @{term "0::nat"} and
14   @{term Suc} is unsuitable for computations involving large
15   numbers.  The efficiency of the generated code can be improved
16   drastically by implementing natural numbers by target-language
17   integers.  To do this, just include this theory.
18 *}
20 subsection {* Basic arithmetic *}
22 text {*
23   Most standard arithmetic functions on natural numbers are implemented
24   using their counterparts on the integers:
25 *}
(* For code generation, nat values are built only from numerals: number_of_nat is
   declared as the sole code constructor, so 0, 1 and Suc are no longer
   constructors and are given code equations below. *)
27 code_datatype number_nat_inst.number_of_nat
(* 0 and 1 become the numerals Numeral0 and Numeral1; both lemmas are also
   registered as code_unfold_post rules. *)
29 lemma zero_nat_code [code, code_unfold_post]:
30   "0 = (Numeral0 :: nat)"
31   by simp
33 lemma one_nat_code [code, code_unfold_post]:
34   "1 = (Numeral1 :: nat)"
35   by simp
(* Suc is no longer a constructor, so it is implemented as addition of 1. *)
37 lemma Suc_code [code]:
38   "Suc n = n + 1"
39   by simp
(* Addition, subtraction and multiplication on nat are computed on int (via of_nat)
   and converted back with nat.  For subtraction, nat maps a negative difference
   to 0, which gives truncated subtraction. *)
41 lemma plus_nat_code [code]:
42   "n + m = nat (of_nat n + of_nat m)"
43   by simp
45 lemma minus_nat_code [code]:
46   "n - m = nat (of_nat n - of_nat m)"
47   by simp
49 lemma times_nat_code [code]:
50   "n * m = nat (of_nat n * of_nat m)"
51   unfolding of_nat_mult [symmetric] by simp
(* Division with remainder is delegated to pdivmod on int; both components of the
   resulting pair are mapped back to nat. *)
53 lemma divmod_nat_code [code]:
54   "divmod_nat n m = map_pair nat nat (pdivmod (of_nat n) (of_nat m))"
55   by (simp add: map_pair_def split_def pdivmod_def nat_div_distrib nat_mod_distrib divmod_nat_div_mod)
(* Equality on nat is decided by comparing the of_nat images in int. *)
57 lemma eq_nat_code [code]:
58   "HOL.equal n m \<longleftrightarrow> HOL.equal (of_nat n \<Colon> int) (of_nat m)"
59   by (simp add: equal)
(* Reflexivity of equality, declared only for the nbe target ([code nbe]). *)
61 lemma eq_nat_refl [code nbe]:
62   "HOL.equal (n::nat) n \<longleftrightarrow> True"
63   by (rule equal_refl)
(* The orderings on nat are likewise decided on the of_nat images in int. *)
65 lemma less_eq_nat_code [code]:
66   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
67   by simp
69 lemma less_nat_code [code]:
70   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
71   by simp
73 subsection {* Case analysis *}
75 text {*
76   Case analysis on natural numbers is rephrased using a conditional
77   expression:
78 *}
(* nat_case is replaced by a conditional on n = 0.  The lemma serves both as a code
   equation and as a preprocessing (code_unfold) rule. *)
80 lemma [code, code_unfold]:
81   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
82   by (auto simp add: fun_eq_iff dest!: gr0_implies_Suc)
85 subsection {* Preprocessors *}
87 text {*
88   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
89   a constructor term. Therefore, all occurrences of this term in a position
90   where a pattern is expected (i.e.\ on the left-hand side of a recursion
91   equation or in the arguments of an inductive relation in an introduction
92   rule) must be eliminated.
93   This can be accomplished by applying the following transformation rules:
94 *}
(* Suc_if_eq combines an equation for f (Suc n) and one for f 0 into a single
   equation for f n that uses if-then-else.  It is instantiated by remove_suc in
   the ML setup below to eliminate Suc patterns from code equations. *)
96 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
97   f n \<equiv> if n = 0 then g else h (n - 1)"
98   by (rule eq_reflection) (cases n, simp_all)
(* Suc_clause turns a property stated for (n, Suc n) into one for (n - 1, n) under
   the side condition n ~= 0.  It is instantiated by remove_suc_clause in the ML
   setup below to eliminate Suc patterns from introduction rules. *)
100 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
101   by (cases n) simp_all
103 text {*
104   The rules above are built into a preprocessor that is plugged into
105   the code generator. Since the preprocessor for introduction rules
106   does not know anything about modes, some of the modes that worked
107   for the canonical representation of natural numbers may no longer work.
108 *}
(*<*)
setup {*
let

(* Eliminates one "Suc ?n" pattern from the left-hand sides of a list of code
   equations.  For the first equation th whose lhs contains Suc ?n, Suc_if_eq is
   instantiated with
     f := lhs with that Suc ?n replaced by a fresh variable, and
     h := rhs of th,
   and its first premise is discharged by th.  Every other equation that resolves
   with the remaining premise (the f 0 case) yields a combined if-then-else
   equation.  Those equations and th are replaced by the combined results.
   Returns NONE if no candidate could be combined. *)
fun remove_suc thy thms =
  let
    val vname = Name.variant (map fst
      (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
    val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    fun lhs_of th = snd (Thm.dest_comb
      (fst (Thm.dest_comb (cprop_of th))));
    fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
    (* Each result pairs the term with one "Suc ?n" occurrence replaced by cv
       with the variable ?n itself. *)
    fun find_vars ct = (case term_of ct of
        (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
      | _ $ _ =>
        let val (ct1, ct2) = Thm.dest_comb ct
        in
          map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
          map (apfst (Thm.capply ct1)) (find_vars ct2)
        end
      | _ => []);
    val eqs = maps
      (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    fun mk_thms (th, (ct, cv')) =
      let
        val th' =
          Thm.implies_elim
           (Conv.fconv_rule (Thm.beta_conversion true)
             (Drule.instantiate'
               [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
                 SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
               @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
      in
        (* Try every equation as the zero case of th'. *)
        case map_filter (fn th'' =>
            SOME (th'', singleton
              (Variable.trade (K (fn [th'''] => [th''' RS th']))
                (Variable.global_thm_context th'')) th'')
          handle THM _ => NONE) thms of
            [] => NONE
          | thps =>
              let val (ths1, ths2) = split_list thps
              in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
      end
  in get_first mk_thms eqs end;

(* Code-equation preprocessor.  It applies only if every theorem is a
   meta-equation and at least one left-hand side mentions Suc.  It then applies
   remove_suc repeatedly (perhaps_loop) and normalises the variable indices of
   the result. *)
fun eqn_suc_base_preproc thy thms =
  let
    val dest = fst o Logic.dest_equals o prop_of;
    val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
  in
    if forall (can dest) thms andalso exists (contains_suc o dest) thms
      then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
       else NONE
  end;

val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;

(* Introduction-rule preprocessor step.  It finds the first rule whose atomized
   statement contains "Suc ?v" and instantiates Suc_clause with
   P := (%v x. statement with Suc ?v abstracted to x).  In the result, Suc ?v
   becomes ?v and ?v becomes ?v - 1, under the premise ?v ~= 0.  The step
   recurses until no rule contains such a pattern. *)
fun remove_suc_clause thy thms =
  let
    fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
      | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
      | find_var _ = NONE;
    fun find_thm th =
      let val th' = Conv.fconv_rule Object_Logic.atomize th
      in Option.map (pair (th, th')) (find_var (prop_of th')) end
  in
    case get_first find_thm thms of
      NONE => thms
    | SOME ((th, th'), (Sucv, v)) =>
        let
          val cert = cterm_of (Thm.theory_of_thm th);
          val th'' = Object_Logic.rulify (Thm.implies_elim
            (Conv.fconv_rule (Thm.beta_conversion true)
              (Drule.instantiate' []
                [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
                   abstract_over (Sucv,
                     HOLogic.dest_Trueprop (prop_of th')))))),
                 SOME (cert v)] @{thm Suc_clause}))
            (Thm.forall_intr (cert v) th'))
        in
          (* Replace the transformed rule, identified by its proposition. *)
          remove_suc_clause thy (map (fn th''' =>
            if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
        end
  end;

(* Runs remove_suc_clause only when every rule concludes a set membership and
   Suc occurs in the element of the conclusion or of some membership premise. *)
fun clause_suc_preproc thy ths =
  let
    val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
  in
    if forall (can (dest o concl_of)) ths andalso
      exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
        (map_filter (try dest) (concl_of th :: prems_of th))) ths
    then remove_suc_clause thy ths else ths
  end;
in

  Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)
  #> Codegen.add_preprocessor clause_suc_preproc

end;
*}
(*>*)
216 subsection {* Target language setup *}
218 text {*
219   For ML, we map @{typ nat} to target language integers, where we
220   ensure that values are always non-negative.
221 *}
(* In the ML targets and Eval, nat is represented by the target's integer type.
   The nat operations of this theory keep such values non-negative. *)
code_type nat
  (Eval "int")
  (OCaml "Big'_int.big'_int")
  (SML "IntInf.int")
(* types_code mapping: nat is represented as ML int.  A term_of function and a
   random value generator (attach (test)) are attached. *)
types_code
  nat ("int")
attach (term_of) {*
val term_of_nat = HOLogic.mk_number HOLogic.natT;
*}
attach (test) {*
fun gen_nat i =
  let
    val k = random_range 0 i
  in
    (k, fn () => term_of_nat k)
  end;
*}
text {*
  For Haskell and Scala we define our own @{typ nat} type.  The reason
  is that we have to distinguish type class instances for @{typ nat}
  and @{typ int}.
*}
(* Haskell implementation of nat: a newtype around Integer.
   - fromInteger clamps negative arguments to 0, and subtraction goes through
     fromInteger, so it truncates at 0.
   - quotRem (and divMod, defined as quotRem) returns (0, n) for a zero divisor.
   - signum is derived from the wrapped Integer so that signum 0 = 0, as the Num
     class requires; a constant 1 would be wrong for 0. *)
code_include Haskell "Nat"
{*newtype Nat = Nat Integer deriving (Eq, Show, Read);

instance Num Nat where {
  fromInteger k = Nat (if k >= 0 then k else 0);
  Nat n + Nat m = Nat (n + m);
  Nat n - Nat m = fromInteger (n - m);
  Nat n * Nat m = Nat (n * m);
  abs n = n;
  signum (Nat n) = Nat (signum n);
  negate n = error "negate Nat";
};

instance Ord Nat where {
  Nat n <= Nat m = n <= m;
  Nat n < Nat m = n < m;
};

instance Real Nat where {
  toRational (Nat n) = toRational n;
};

instance Enum Nat where {
  toEnum k = fromInteger (toEnum k);
  fromEnum (Nat n) = fromEnum n;
};

instance Integral Nat where {
  toInteger (Nat n) = n;
  divMod n m = quotRem n m;
  quotRem (Nat n) (Nat m)
    | (m == 0) = (0, Nat n)
    | otherwise = (Nat k, Nat l) where (k, l) = quotRem n m;
};
*}

code_reserved Haskell Nat
(* Scala implementation of nat: a wrapper around BigInt with a private
   constructor.  The Nat.apply factories clamp negative values to 0, and
   subtraction goes through Nat.apply, so it truncates at 0.  /% by zero yields
   (0, this). *)
code_include Scala "Nat"
{*object Nat {

  def apply(numeral: BigInt): Nat = new Nat(numeral max 0)
  def apply(numeral: Int): Nat = apply(BigInt(numeral))
  def apply(numeral: String): Nat = apply(BigInt(numeral))

}

class Nat private(private val value: BigInt) {

  override def hashCode(): Int = value.hashCode()

  override def equals(that: Any): Boolean = that match {
    case other: Nat => equals(other)
    case _ => false
  }

  override def toString(): String = value.toString

  def equals(that: Nat): Boolean = value == that.value

  def as_BigInt: BigInt = value

  def as_Int: Int =
    if (value < scala.Int.MinValue || value > scala.Int.MaxValue)
      error("Int value out of range: " + value.toString)
    else
      value.intValue

  def +(that: Nat): Nat = new Nat(value + that.value)
  def -(that: Nat): Nat = Nat(value - that.value)
  def *(that: Nat): Nat = new Nat(value * that.value)

  def /%(that: Nat): (Nat, Nat) =
    if (that.value == 0) (new Nat(0), this)
    else {
      val (quot, rem) = value /% that.value
      (new Nat(quot), new Nat(rem))
    }

  def <=(that: Nat): Boolean = value <= that.value

  def <(that: Nat): Boolean = value < that.value

}
*}

code_reserved Scala Nat
(* Haskell and Scala represent nat by the Nat type defined in the code_include
   blocks of this theory.  Both targets need the mapping: without it, the Haskell
   mappings of int/nat to toInteger/fromInteger would not refer to Nat.  Haskell
   equality on Nat comes from the derived Eq instance, so no equal instance is
   generated for it (Haskell -). *)
code_type nat
  (Haskell "Nat")
  (Scala "Nat")

code_instance nat :: equal
  (Haskell -)
336 text {*
337   Natural numerals.
338 *}
(* Relates nat applied to an integer numeral with the nat numeral
   number_of_nat, registered as a code_unfold_post rule. *)
340 lemma [code_unfold_post]:
341   "nat (number_of i) = number_nat_inst.number_of_nat i"
342   -- {* this interacts as desired with @{thm nat_number_of_def} *}
343   by (simp add: number_nat_inst.number_of_nat)
(* Registers number_nat_inst.number_of_nat as a numeral, printed with
   Code_Printer.literal_positive_numeral, for the SML, OCaml, Haskell and Scala
   targets. *)
setup {*
  let
    val add_nat_numeral =
      Numeral.add_code @{const_name number_nat_inst.number_of_nat}
        false Code_Printer.literal_positive_numeral
  in
    fold add_nat_numeral ["SML", "OCaml", "Haskell", "Scala"]
  end
*}
350 text {*
351   Since natural numbers are implemented
352   using integers in ML, the coercion function @{const "of_nat"} of type
353   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
354   For the @{const nat} function for converting an integer to a natural
355   number, we give a specific implementation using an ML function that
356   returns its input value, provided that it is non-negative, and otherwise
357   returns @{text "0"}.
358 *}
(* int is an alias for of_nat on nat.  [code del] keeps its defining equation from
   being used as a code equation. *)
360 definition int :: "nat \<Rightarrow> int" where
361   [code del]: "int = of_nat"
(* int and nat applied to a numeral: if the numeral read as an int is negative
   (neg), the result is 0; otherwise it is the numeral itself. *)
363 lemma int_code' [code]:
364   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
365   unfolding int_nat_number_of [folded int_def] ..
367 lemma nat_code' [code]:
368   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
369   unfolding nat_number_of_def number_of_is_id neg_def by simp
(* of_nat is identified with int (code_unfold_post). *)
371 lemma of_nat_int [code_unfold_post]:
372   "of_nat = int" by (simp add: int_def)
(* Preprocessing rewrites the of_nat_aux form of the coercion to int k.  The proof
   relies on Nat.of_nat_code. *)
374 lemma of_nat_aux_int [code_unfold]:
375   "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
376   by (simp add: int_def Nat.of_nat_code)
(* In SML and OCaml, nat and int share the same target integer type, so int (the
   coercion nat => int) is the identity.  nat (int => nat) returns its argument
   if non-negative and 0 otherwise. *)
code_const int
  (OCaml "_")
  (SML "_")

consts_code
  int ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if i >= 0 then i else 0;
*}

code_const nat
  (Eval "Integer.max/ _/ 0")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
  (SML "IntInf.max/ (/0,/ _)")
394 text {* For Haskell and Scala, things are slightly different again. *}
(* Haskell and Scala: int unwraps the Nat wrapper and nat constructs one;
   construction clamps negative integers to 0. *)
code_const int
  (Haskell "toInteger")
  (Scala "!_.as'_BigInt")

code_const nat
  (Haskell "fromInteger")
  (Scala "Nat")
400 text {* Conversion from and to code numerals. *}
(* Conversions between code numerals and nat, per target: Code_Numeral.of_nat
   (first) and Code_Numeral.nat_of (second). *)
code_const Code_Numeral.of_nat and Code_Numeral.nat_of
  (SML "IntInf.toInt" and "IntInf.fromInt")
  (OCaml "_" and "_")
  (Haskell "!(fromInteger/ ./ toInteger)" and "!(fromInteger/ ./ toInteger)")
  (Scala "!Natural(_.as'_BigInt)" and "!Nat(_.as'_BigInt)")
  (Eval "_" and "_")
416 text {* Using target language arithmetic operations whenever appropriate *}
(* Addition and multiplication map to the native integer operations of every
   target.  Adding OCaml Big_int.add_big_int for + matches the OCaml mapping
   already present for *; without it, OCaml addition on nat falls back to the
   round trip through int of plus_nat_code. *)
code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.+ ((_), (_))")
  (OCaml "Big'_int.add'_big'_int")
  (Haskell infixl 6 "+")
  (Scala infixl 7 "+")
  (Eval infixl 8 "+")

(* Subtraction is mapped only where the target's Nat subtraction already
   truncates at 0 (Haskell via fromInteger, Scala via Nat.apply).  The ML targets
   use minus_nat_code, which computes on int and clamps with nat. *)
code_const "op - \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (Haskell infixl 6 "-")
  (Scala infixl 7 "-")

code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.* ((_), (_))")
  (OCaml "Big'_int.mult'_big'_int")
  (Haskell infixl 7 "*")
  (Scala infixl 8 "*")
  (Eval infixl 9 "*")
(* divmod_nat maps to each target's combined quotient/remainder operation.  In
   Haskell this is divMod of the Nat Integral instance, which is defined via
   quotRem and returns (0, n) for a zero divisor, like Scala's /%.
   NOTE(review): IntInf.divMod (SML) and Big_int.quomod_big_int (OCaml) are
   expected to raise on a zero divisor, whereas divmod_nat n 0 = (0, n) in HOL.
   Confirm whether generated ML code can reach that case. *)
code_const divmod_nat
  (SML "IntInf.divMod/ ((_),/ (_))")
  (OCaml "Big'_int.quomod'_big'_int")
  (Haskell "divMod")
  (Scala infixl 8 "/%")
  (Eval "Integer.div'_mod")

(* Equality and orderings on nat are the native integer comparisons. *)
code_const "HOL.equal \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "!((_ : IntInf.int) = _)")
  (OCaml "Big'_int.eq'_big'_int")
  (Haskell infix 4 "==")
  (Scala infixl 5 "==")
  (Eval infixl 6 "=")

code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.<= ((_), (_))")
  (OCaml "Big'_int.le'_big'_int")
  (Haskell infix 4 "<=")
  (Scala infixl 4 "<=")
  (Eval infixl 6 "<=")

code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.< ((_), (_))")
  (OCaml "Big'_int.lt'_big'_int")
  (Haskell infix 4 "<")
  (Scala infixl 4 "<")
  (Eval infixl 6 "<")
(* consts_code mappings: nat constants and operations are printed as the
   corresponding ML integer expressions (nat is ML int, see types_code above). *)
consts_code
  Suc                          ("(_ +/ 1)")
  "0::nat"                     ("0")
  "1::nat"                     ("1")
  "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
  "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
  "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
  "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
474 text {* Evaluation *}
(* NOTE(review): declaring this trivial equation as [code, code del] appears to be
   meant to clear the code equations of term_of at type nat, so that the SML
   mapping below is used instead.  Confirm. *)
476 lemma [code, code del]:
477   "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..
(* In SML, term_of at nat is implemented by HOLogic.mk_number HOLogic.natT. *)
479 code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
480   (SML "HOLogic.mk'_number/ HOLogic.natT")
483 text {* Module names *}
(* In generated OCaml and SML code, this theory's module is named Arith. *)
code_modulename OCaml
  Efficient_Nat Arith

code_modulename SML
  Efficient_Nat Arith