src/HOL/Library/Efficient_Nat.thy
 author haftmann Thu Sep 25 10:17:22 2008 +0200 (2008-09-25) changeset 28351 abfc66969d1f parent 28346 b8390cd56b8f child 28423 9fc3befd8191 permissions -rw-r--r--
non left-linear equations for nbe
1 (*  Title:      HOL/Library/Efficient_Nat.thy
2     ID:         \$Id\$
3     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
4 *)
6 header {* Implementation of natural numbers by target-language integers *}
8 theory Efficient_Nat
9 imports Plain Code_Index Code_Integer
10 begin
12 text {*
13   When generating code for functions on natural numbers, the
14   canonical representation using @{term "0::nat"} and
15   @{term "Suc"} is unsuitable for computations involving large
16   numbers.  The efficiency of the generated code can be improved
17   drastically by implementing natural numbers by target-language
18   integers.  To do this, just include this theory.
19 *}
21 subsection {* Basic arithmetic *}
23 text {*
24   Most standard arithmetic functions on natural numbers are implemented
25   using their counterparts on the integers:
26 *}
(* Declare number_of_nat as the only code constructor of nat, so that generated
   code works with numerals instead of the canonical 0/Suc representation. *)
28 code_datatype number_nat_inst.number_of_nat
(* 0 becomes the numeral Numeral0: used as code equation and as preprocessor
   rewrite (code unfold); the symmetric version is a postprocessor rule
   (code post) that turns Numeral0 back into 0. *)
30 lemma zero_nat_code [code, code unfold]:
31   "0 = (Numeral0 :: nat)"
32   by simp
33 lemmas [code post] = zero_nat_code [symmetric]
(* Same treatment for 1 and Numeral1. *)
35 lemma one_nat_code [code, code unfold]:
36   "1 = (Numeral1 :: nat)"
37   by simp
38 lemmas [code post] = one_nat_code [symmetric]
(* Suc is not a code constructor any more (number_of_nat is), so it is
   executed as addition of 1. *)
40 lemma Suc_code [code]:
41   "Suc n = n + 1"
42   by simp
(* +, - and * on nat are reduced to the corresponding int operations:
   arguments are embedded with of_nat, results are converted back with nat.
   For subtraction this gives truncation, since nat maps negative integers
   to 0. *)
44 lemma plus_nat_code [code]:
45   "n + m = nat (of_nat n + of_nat m)"
46   by simp
48 lemma minus_nat_code [code]:
49   "n - m = nat (of_nat n - of_nat m)"
50   by simp
52 lemma times_nat_code [code]:
53   "n * m = nat (of_nat n * of_nat m)"
54   unfolding of_nat_mult [symmetric] by simp
56 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"}
57   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
(* divmod_aux is a copy of divmod.  Its defining equation is not a code
   equation (code func del); instead it gets the int-based code equation
   below and is mapped to target-language division by code_const further down
   in this theory. *)
59 definition
60   divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat"
61 where
62   [code func del]: "divmod_aux = divmod"
(* Division by 0 is handled here and yields (0, n), so in generated code
   divmod_aux is only reached with m \<noteq> 0. *)
64 lemma [code func]:
65   "divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
66   unfolding divmod_aux_def divmod_div_mod by simp
(* Quotient and remainder computed on int and converted back with nat. *)
68 lemma divmod_aux_code [code]:
69   "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
70   unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
(* Equality on nat is decided by equality of the int images.  HOL.eq rewrites
   eq_class.eq to op = on both sides; injectivity of of_nat closes the goal. *)
lemma eq_nat_code [code]:
  "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
  by (simp add: HOL.eq)
(* Reflexivity of nat equality, registered only for normalization by
   evaluation (code nbe).  Its left-hand side uses n twice, i.e. it is a
   non left-linear equation. *)
76 lemma eq_nat_refl [code nbe]:
77   "eq_class.eq (n::nat) n \<longleftrightarrow> True"
78   by (rule HOL.eq_refl)
(* The order on nat is decided by the order on the int images. *)
80 lemma less_eq_nat_code [code]:
81   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
82   by simp
84 lemma less_nat_code [code]:
85   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
86   by simp
88 subsection {* Case analysis *}
90 text {*
91   Case analysis on natural numbers is rephrased using a conditional
92   expression:
93 *}
(* nat_case as an if-then-else on n = 0, with the predecessor n - 1 in the
   other branch.  Used both as code equation and as preprocessor rewrite
   (code unfold). *)
95 lemma [code func, code unfold]:
96   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
97   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
100 subsection {* Preprocessors *}
102 text {*
103   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
104   a constructor term. Therefore, all occurrences of this term in a position
105   where a pattern is expected (i.e.\ on the left-hand side of a recursion
106   equation or in the arguments of an inductive relation in an introduction
107   rule) must be eliminated.
108   This can be accomplished by applying the following transformation rules:
109 *}
(* Combines an equation for f (Suc n) and one for f 0 into a single equation
   for f n.  Instantiated by the ML function remove_suc below. *)
111 lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
112   f n = (if n = 0 then g else h (n - 1))"
113   by (case_tac n) simp_all
(* Turns a clause about Suc n into one about n under the premise n \<noteq> 0.
   Instantiated by the ML function remove_suc_clause below. *)
115 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
116   by (case_tac n) simp_all
118 text {*
119   The rules above are built into a preprocessor that is plugged into
120   the code generator. Since the preprocessor for introduction rules
121   does not know anything about modes, some of the modes that worked
122   for the canonical representation of natural numbers may no longer work.
123 *}
(* ML setup defining the Suc-eliminating preprocessors for code equations and
   clauses and registering them as code function transformers at its end. *)
125 (*<*)
126 setup {*
127 let
(* Eliminate patterns Suc ?v from the left-hand sides of a list of code
   equations.  Every theorem is assumed to have the shape Trueprop (lhs = rhs);
   eqn_suc_preproc checks this before calling.

   For an equation  lhs[Suc ?v] = rhs[?v]  Suc_if_eq is instantiated to
     lhs[0] = ?g ==> lhs[?v] = (if ?v = 0 then ?g else rhs[?v - 1])
   and every theorem of the list that resolves with the premise lhs[0] = ?g
   is resolved with it.  If at least one theorem resolves, the equation and
   the resolved theorems are replaced by the results and the whole process
   repeats.  If nothing resolves, the next Suc occurrence is tried.  The list
   is returned unchanged once no occurrence can be eliminated. *)
fun remove_suc thy thms =
  let
    (* fresh schematic nat variable, distinct from all variables in thms *)
    val vname = Name.variant (map fst
      (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
    val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    (* left- and right-hand side of a theorem Trueprop (lhs = rhs) *)
    fun lhs_of th = snd (Thm.dest_comb
      (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
    fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
    (* every occurrence of Suc ?v in ct, as a pair (ct with that occurrence
       replaced by cv, ?v); the constant name matches the one used by
       eqn_suc_preproc *)
    fun find_vars ct = (case term_of ct of
        (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
      | _ $ _ =>
        let val (ct1, ct2) = Thm.dest_comb ct
        in
          map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
          map (apfst (Thm.capply ct1)) (find_vars ct2)
        end
      | _ => []);
    val eqs = maps
      (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    fun mk_thms (th, (ct, cv')) =
      let
        (* Suc_if_eq with its first premise discharged by th itself *)
        val th' =
          Thm.implies_elim
           (Conv.fconv_rule (Thm.beta_conversion true)
             (Drule.instantiate'
               [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
                 SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
               @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
      in
        (* resolve every theorem that fits the remaining premise;
           resolution failures (THM) are skipped *)
        case map_filter (fn th'' =>
            SOME (th'', singleton
              (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
          handle THM _ => NONE) thms of
            [] => NONE
          | thps =>
              let val (ths1, ths2) = split_list thps
              in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
      end
  in case get_first mk_thms eqs of
      NONE => thms
    | SOME x => remove_suc thy x
  end;
(* Preprocessor for code equations.  It runs remove_suc only if every theorem
   is an equation lhs = rhs and at least one left-hand side mentions Suc.
   Otherwise the theorems are passed through untouched. *)
fun eqn_suc_preproc thy ths =
  let
    val lhs_of_eqn = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of;
    fun mentions_suc th =
      member (op =) (term_consts (lhs_of_eqn th)) @{const_name Suc};
    val applicable =
      forall (can lhs_of_eqn) ths andalso exists mentions_suc ths;
  in
    if not applicable then ths else remove_suc thy ths
  end;
(* Eliminate patterns Suc ?v from a list of clauses (e.g. introduction rules).
   The first theorem whose atomized form contains Suc ?v is rewritten with
   Suc_clause.  In it, that occurrence of Suc ?v becomes ?v, the remaining
   occurrences of ?v become ?v - 1, and ?v ~= 0 is added as a premise.  The
   theorem is replaced in the list and the process repeats until no Suc ?v
   occurrence is found.
   (The previously computed fresh variable name was never used and is
   dropped.) *)
fun remove_suc_clause thy thms =
  let
    (* first subterm Suc ?v found by searching applications left to right;
       abstractions are not entered *)
    fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
      | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
      | find_var _ = NONE;
    (* (th, atomized th) together with its Suc ?v occurrence, if any *)
    fun find_thm th =
      let val th' = Conv.fconv_rule ObjectLogic.atomize th
      in Option.map (pair (th, th')) (find_var (prop_of th')) end
  in
    case get_first find_thm thms of
      NONE => thms
    | SOME ((th, th'), (Sucv, v)) =>
        let
          val cert = cterm_of (Thm.theory_of_thm th);
          (* Suc_clause with P := %v x. C, where C is the conclusion of th'
             with Suc ?v abstracted to x, and n := ?v; its premise is
             discharged by th' generalized over ?v *)
          val th'' = ObjectLogic.rulify (Thm.implies_elim
            (Conv.fconv_rule (Thm.beta_conversion true)
              (Drule.instantiate' []
                [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
                   abstract_over (Sucv,
                     HOLogic.dest_Trueprop (prop_of th')))))),
                 SOME (cert v)] @{thm Suc_clause}))
            (Thm.forall_intr (cert v) th'))
        in
          remove_suc_clause thy (map (fn th''' =>
            if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
        end
  end;
(* Preprocessor for clauses.  It runs remove_suc_clause only if every
   theorem's conclusion is a set membership t : S and, for at least one
   theorem, the left operands of the memberships among its conclusion and
   premises mention Suc.  Otherwise the theorems are passed through untouched.
   Suc is referred to via @{const_name Suc}, the same constant that
   remove_suc_clause matches, so gate and rewriter cannot disagree. *)
fun clause_suc_preproc thy ths =
  let
    (* left operand t of a membership proposition Trueprop (t : S) *)
    val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
  in
    if forall (can (dest o concl_of)) ths andalso
      exists (fn th => member (op =) (foldr add_term_consts
        [] (map_filter (try dest) (concl_of th :: prems_of th))) @{const_name Suc}) ths
    then remove_suc_clause thy ths else ths
  end;
(* Adapt a preprocessor f that works on object-level equations (op =) to
   theorems given as meta-level equations (==).  The pipeline:
   - normalize variable indexes;
   - convert with meta_eq_to_obj_eq;
   - apply f;
   - convert back with eq_reflection;
   - beta-eta normalize;
   - normalize indexes again.
   Returns NONE if any step raises an exception (via try) or the result equals
   the input, and SOME of the new theorems otherwise. *)
fun lift f thy thms1 =
  let
    val input = Drule.zero_var_indexes_list thms1;
    val to_obj = map (fn thm => thm RS @{thm meta_eq_to_obj_eq});
    val to_meta = map (fn thm => thm RS @{thm eq_reflection});
    val normalize = map (Conv.fconv_rule Drule.beta_eta_conversion);
    val output =
      try (to_obj #> f thy #> to_meta #> normalize) input
      |> Option.map Drule.zero_var_indexes_list;
  in
    case output of
      NONE => NONE
    | SOME thms' => if Thm.eq_thms (input, thms') then NONE else SOME thms'
  end
235 in
(* Register both preprocessors, wrapped by lift, as code function
   transformers named eqn_Suc and clause_Suc.
   NOTE(review): the expression after "in" starts with #>, so its first
   operand is not visible here; the line(s) preceding the first #> appear to
   be missing - confirm against the repository. *)
239   #> Code.add_functrans ("eqn_Suc", lift eqn_suc_preproc)
240   #> Code.add_functrans ("clause_Suc", lift clause_suc_preproc)
242 end;
243 *}
244 (*>*)
247 subsection {* Target language setup *}
249 text {*
250   For ML, we map @{typ nat} to target language integers, where we
251   assert that values are always non-negative.
252 *}
(* In SML and OCaml, nat is represented by the big-integer types IntInf.int and
   Big_int.big_int. *)
254 code_type nat
255   (SML "IntInf.int")
256   (OCaml "Big'_int.big'_int")
(* Legacy code generator setup: nat is represented by ML int.
   term_of_nat builds a nat numeral term from a value.  gen_nat draws a value
   with random_range 0 i and returns it together with a thunk producing its
   term. *)
258 types_code
259   nat ("int")
260 attach (term_of) {*
261 val term_of_nat = HOLogic.mk_number HOLogic.natT;
262 *}
263 attach (test) {*
264 fun gen_nat i =
265   let val n = random_range 0 i
266   in (n, fn () => term_of_nat n) end;
267 *}
269 text {*
270   For Haskell we define our own @{typ nat} type.  The reason
271   is that we have to distinguish type class instances
272   for @{typ nat} and @{typ int}.
273 *}
-- Natural numbers as a newtype over Integer, so that type class instances
-- for nat and int can differ.
276 newtype Nat = Nat Integer deriving (Show, Eq);
-- fromInteger clamps negative arguments to 0.  Subtraction goes through
-- fromInteger and therefore truncates at 0.  negate is an error.
-- abs is the identity.  signum always yields 1.
278 instance Num Nat where {
279   fromInteger k = Nat (if k >= 0 then k else 0);
280   Nat n + Nat m = Nat (n + m);
281   Nat n - Nat m = fromInteger (n - m);
282   Nat n * Nat m = Nat (n * m);
283   abs n = n;
284   signum _ = 1;
285   negate n = error "negate Nat";
286 };
-- Only <= and < are defined explicitly here.
288 instance Ord Nat where {
289   Nat n <= Nat m = n <= m;
290   Nat n < Nat m = n < m;
291 };
293 instance Real Nat where {
294   toRational (Nat n) = toRational n;
295 };
-- toEnum converts through Integer and fromInteger, i.e. clamps at 0.
-- NOTE(review): only toEnum is visible for this instance; confirm that
-- fromEnum is defined as well (a line seems to be missing here).
297 instance Enum Nat where {
298   toEnum k = fromInteger (toEnum k);
300 };
-- divMod is defined as quotRem; both agree for the non-negative values the
-- instances above produce.
302 instance Integral Nat where {
303   toInteger (Nat n) = n;
304   divMod n m = quotRem n m;
305   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
306 };
307 *}
(* NOTE(review): neither of the following two declarations shows a target
   clause.  Presumably they map nat to the Haskell Nat type defined above and
   give it the eq instance; the clause lines appear to be missing - confirm. *)
311 code_type nat
314 code_instance nat :: eq
317 text {*
318   Natural numerals.
319 *}
(* Natural numerals: the equation is used for inlining (code inline), and its
   symmetric version is used as postprocessor rule (code post).
   NOTE(review): no proof is visible for this lemma; a proof line appears to
   be missing after the marginal comment - confirm. *)
321 lemma [code inline, symmetric, code post]:
322   "nat (number_of i) = number_nat_inst.number_of_nat i"
323   -- {* this interacts as desired with @{thm nat_number_of_def} *}
(* Setup applied for the targets SML, OCaml and Haskell.
   NOTE(review): the ML expression visibly starts in the middle of an
   application (the text before "true false)" is missing) - confirm the full
   expression against the repository. *)
326 setup {*
328     true false) ["SML", "OCaml", "Haskell"]
329 *}
331 text {*
332   Since natural numbers are implemented
333   using integers in ML, the coercion function @{const "of_nat"} of type
334   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
335   For the @{const "nat"} function for converting an integer to a natural
336   number, we give a specific implementation using an ML function that
337   returns its input value, provided that it is non-negative, and otherwise
338   returns @{text "0"}.
339 *}
(* int is of_nat at type nat => int, introduced as its own constant so that it
   can be mapped to target code by code_const below.  Its defining equation is
   not used as code equation (code func del). *)
341 definition
342   int :: "nat \<Rightarrow> int"
343 where
344   [code func del]: "int = of_nat"
(* Code equations for int and nat applied to numerals: a negative numeral
   yields 0, otherwise the numeral itself at the target type. *)
346 lemma int_code' [code func]:
347   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
348   unfolding int_nat_number_of [folded int_def] ..
350 lemma nat_code' [code func]:
351   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
352   by auto
(* Preprocessing replaces of_nat by int (code unfold); postprocessing turns int
   back into of_nat (symmetric, code post). *)
354 lemma of_nat_int [code unfold]:
355   "of_nat = int" by (simp add: int_def)
356 declare of_nat_int [symmetric, code post]
(* int is the identity in SML and OCaml. *)
358 code_const int
359   (SML "_")
360   (OCaml "_")
(* Legacy code generator: int is the identity; nat is the attached ML function
   mapping negative integers to 0. *)
362 consts_code
363   int ("(_)")
364   nat ("\<module>nat")
365 attach {*
366 fun nat i = if i < 0 then 0 else i;
367 *}
(* nat is max(0, x) in SML and OCaml. *)
369 code_const nat
370   (SML "IntInf.max/ (/0,/ _)")
371   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
373 text {* For Haskell, things are slightly different again. *}
(* NOTE(review): no target clause is visible for this declaration; presumably
   it gives the Haskell mappings of int and nat, and the clause line appears
   to be missing - confirm. *)
375 code_const int and nat
378 text {* Conversion from and to indices. *}
(* Conversions between nat (big integers) and index, mapped to the
   big-integer/machine-integer conversion functions of SML and OCaml.
   NOTE(review): these conversions presumably fail for values outside the
   machine-integer range - confirm this is acceptable for index. *)
380 code_const index_of_nat
381   (SML "IntInf.toInt")
382   (OCaml "Big'_int.int'_of'_big'_int")
385 code_const nat_of_index
386   (SML "IntInf.fromInt")
387   (OCaml "Big'_int.big'_int'_of'_int")
390 text {* Using target language arithmetic operations whenever appropriate *}
(* Addition on nat mapped to IntInf.+ in SML.
   NOTE(review): unlike the declarations below, no OCaml clause is visible
   here - confirm whether one was intended. *)
392 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
393   (SML "IntInf.+ ((_), (_))")
(* Multiplication on nat mapped to big-integer multiplication. *)
397 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
398   (SML "IntInf.* ((_), (_))")
399   (OCaml "Big'_int.mult'_big'_int")
(* divmod_aux mapped to the targets' combined quotient/remainder operations.
   divmod only calls it with a non-zero divisor, and both operands are
   naturals. *)
402 code_const divmod_aux
403   (SML "IntInf.divMod/ ((_),/ (_))")
404   (OCaml "Big'_int.quomod'_big'_int")
(* Equality and order on nat mapped to the big-integer comparisons. *)
407 code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
408   (SML "!((_ : IntInf.int) = _)")
409   (OCaml "Big'_int.eq'_big'_int")
412 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
413   (SML "IntInf.<= ((_), (_))")
414   (OCaml "Big'_int.le'_big'_int")
417 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
418   (SML "IntInf.< ((_), (_))")
419   (OCaml "Big'_int.lt'_big'_int")
(* Legacy code generator: 0, Suc (as x + 1), +, *, <= and < on nat are
   printed as the corresponding ML integer expressions. *)
422 consts_code
423   0                            ("0")
424   Suc                          ("(_ +/ 1)")
425   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
426   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
427   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
428   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
431 text {* Evaluation *}
(* Evaluation: the trivial equation is declared as code equation and
   immediately deleted again (code func, code func del), presumably to drop
   existing code equations for term_of at nat.  In SML, term_of for nat is
   then mapped to HOLogic.mk_number HOLogic.natT. *)
433 lemma [code func, code func del]:
434   "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..
436 code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
437   (SML "HOLogic.mk'_number/ HOLogic.natT")
440 text {* Module names *}
(* In SML and OCaml, code from the theories Nat, Divides and Efficient_Nat is
   placed in the module Integer. *)
442 code_modulename SML
443   Nat Integer
444   Divides Integer
445   Efficient_Nat Integer
447 code_modulename OCaml
448   Nat Integer
449   Divides Integer
450   Efficient_Nat Integer