src/HOL/Library/Efficient_Nat.thy
author haftmann
Tue Jan 29 10:19:56 2008 +0100 (2008-01-29)
changeset 26009 b6a64fe38634
parent 25967 dd602eb20f3f
child 26100 fbc60cd02ae2
permissions -rw-r--r--
treating division by zero properly
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     4 *)
     5 
     6 header {* Implementation of natural numbers by target-language integers *}
     7 
     8 theory Efficient_Nat
     9 imports Main Code_Integer Code_Index
    10 begin
    11 
    12 text {*
    13   When generating code for functions on natural numbers, the
    14   canonical representation using @{term "0::nat"} and
    15   @{term "Suc"} is unsuitable for computations involving large
    16   numbers.  The efficiency of the generated code can be improved
    17   drastically by implementing natural numbers by target-language
    18   integers.  To do this, just include this theory.
    19 *}
    20 
    21 subsection {* Basic arithmetic *}
    22 
    23 text {*
    24   Most standard arithmetic functions on natural numbers are implemented
    25   using their counterparts on the integers:
    26 *}
    27
       (* Numerals (number_of_nat) are declared as the only code constructor
          of nat, so 0 and Suc are no longer constructor terms for code
          generation. *)
    28 code_datatype number_nat_inst.number_of_nat
    29
       (* 0 and 1 are unfolded to the numerals Numeral0/Numeral1 during
          preprocessing; the symmetric equations are declared [code post],
          turning the numerals back into 0 and 1 when post-processing. *)
    30 lemma zero_nat_code [code, code unfold]:
    31   "0 = (Numeral0 :: nat)"
    32   by simp
    33 lemmas [code post] = zero_nat_code [symmetric]
    34
    35 lemma one_nat_code [code, code unfold]:
    36   "1 = (Numeral1 :: nat)"
    37   by simp
    38 lemmas [code post] = one_nat_code [symmetric]
    39
       (* Suc is reduced to addition of 1.  Addition, truncated subtraction
          and multiplication are expressed through int: the operands are
          embedded with of_nat, combined in int, and mapped back with nat,
          which sends negative integers to 0.  For + and * the code_const
          mappings further below provide direct target-language operations. *)
    40 lemma Suc_code [code]:
    41   "Suc n = n + 1"
    42   by simp
    43
    44 lemma plus_nat_code [code]:
    45   "n + m = nat (of_nat n + of_nat m)"
    46   by simp
    47
    48 lemma minus_nat_code [code]:
    49   "n - m = nat (of_nat n - of_nat m)"
    50   by simp
    51
    52 lemma times_nat_code [code]:
    53   "n * m = nat (of_nat n * of_nat m)"
    54   unfolding of_nat_mult [symmetric] by simp
    55 
    56 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    57   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    58
       (* div_mod_nat_aux equals Divides.divmod, but its defining equation is
          excluded from code generation [code func del].  The code equation
          for Divides.divmod below returns (0, n) for a zero divisor itself
          and calls div_mod_nat_aux only when m is nonzero, so the target
          operation mapped to div_mod_nat_aux never sees a zero divisor when
          reached through divmod. *)
    59 definition
    60   div_mod_nat_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat"
    61 where
    62   [code func del]: "div_mod_nat_aux = Divides.divmod"
    63
    64 lemma [code func]:
    65   "Divides.divmod n m = (if m = 0 then (0, n) else div_mod_nat_aux n m)"
    66   unfolding div_mod_nat_aux_def divmod_def by simp
    67
       (* Implementation of div_mod_nat_aux via int div and mod, mapped back
          with nat.  Targets with a code_const mapping for div_mod_nat_aux
          (see the target-language setup further below) use that instead. *)
    68 lemma div_mod_aux_code [code]:
    69   "div_mod_nat_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
    70   unfolding div_mod_nat_aux_def divmod_def zdiv_int [symmetric] zmod_int [symmetric] by simp
    71
       (* Equality and the orderings on nat are decided on the int images
          of_nat n and of_nat m; this is sound because of_nat is injective
          and order-preserving (each equation is proved by simp). *)
    72 lemma eq_nat_code [code]:
    73   "n = m \<longleftrightarrow> (of_nat n \<Colon> int) = of_nat m"
    74   by simp
    75
    76 lemma less_eq_nat_code [code]:
    77   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
    78   by simp
    79
    80 lemma less_nat_code [code]:
    81   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
    82   by simp
    83 
    84 subsection {* Case analysis *}
    85 
    86 text {*
    87   Case analysis on natural numbers is rephrased using a conditional
    88   expression:
    89 *}
    90
       (* nat_case f g n becomes a test n = 0, with n - 1 playing the role of
          the predecessor.  Declared [code unfold] as well, so occurrences of
          nat_case are rewritten in this way during preprocessing. *)
    91 lemma [code func, code unfold]:
    92   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    93   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    94 
    95 
    96 subsection {* Preprocessors *}
    97 
    98 text {*
    99   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   100   a constructor term. Therefore, all occurrences of this term in a position
   101   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   102   equation or in the arguments of an inductive relation in an introduction
   103   rule) must be eliminated.
   104   This can be accomplished by applying the following transformation rules:
   105 *}
   106
       (* Suc_if_eq: equations for f (Suc n) and f 0 combine into a single
          equation for f n with a conditional; used by the ML function
          remove_suc below.
          Suc_clause: a statement about n and Suc n yields one about n - 1
          and n under the premise n ~= 0; used by remove_suc_clause below. *)
   107 lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
   108   f n = (if n = 0 then g else h (n - 1))"
   109   by (case_tac n) simp_all
   110
   111 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   112   by (case_tac n) simp_all
   113 
   114 text {*
   115   The rules above are built into a preprocessor that is plugged into
   116   the code generator. Since the preprocessor for introduction rules
   117   does not know anything about modes, some of the modes that worked
   118   for the canonical representation of natural numbers may no longer work.
   119 *}
   120 
   121 (*<*)
   122 
   123 ML {*
       (* remove_suc thy thms: thms are object-level equations lhs = rhs (as
          checked by eqn_suc_preproc).  For an occurrence of Suc v, with v a
          schematic variable, in a left-hand side, that equation is combined
          via Suc_if_eq with those equations of thms that fit the matching
          zero case.  The combined equations replace the originals, and the
          process repeats until no occurrence can be eliminated. *)
   124 fun remove_suc thy thms =
   125   let
          (* a schematic nat variable whose name does not occur in thms *)
   126     val vname = Name.variant (map fst
   127       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
   128     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
          (* left- and right-hand side of an equation Trueprop (lhs = rhs) *)
   129     fun lhs_of th = snd (Thm.dest_comb
   130       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
   131     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
          (* every occurrence of Suc v in ct, as a pair of ct with that
             occurrence replaced by cv, and v itself; the constant name is
             taken from the antiquotation, as elsewhere in this ML section *)
   132     fun find_vars ct = (case term_of ct of
   133         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   134       | _ $ _ =>
   135         let val (ct1, ct2) = Thm.dest_comb ct
   136         in
   137           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   138           map (apfst (Thm.capply ct1)) (find_vars ct2)
   139         end
   140       | _ => []);
   141     val eqs = maps
   142       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
          (* Instantiate Suc_if_eq with f := %x. ct, h := %v. rhs of th and
             n := v, and discharge its first premise with th generalized over
             v.  The result th' has the form
               f 0 = ?g ==> f v = (if v = 0 then ?g else h (v - 1)).
             Every equation of thms that resolves with the premise of th' is
             a zero-case equation and yields a combined equation.  If there
             are any, th and those equations are replaced by the combined
             ones.  Returns NONE if no equation resolves. *)
   143     fun mk_thms (th, (ct, cv')) =
   144       let
   145         val th' =
   146           Thm.implies_elim
   147            (Conv.fconv_rule (Thm.beta_conversion true)
   148              (Drule.instantiate'
   149                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   150                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   151                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   152       in
   153         case map_filter (fn th'' =>
   154             SOME (th'', singleton
   155               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   156           handle THM _ => NONE) thms of
   157             [] => NONE
   158           | thps =>
   159               let val (ths1, ths2) = split_list thps
   160               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   161       end
   162   in
          (* rewrite at the first occurrence that admits a combination, then
             start over on the new list *)
   163     case get_first mk_thms eqs of
   164       NONE => thms
   165     | SOME x => remove_suc thy x
   166   end;
   167
       (* eqn_suc_preproc thy ths: if every theorem in ths is an object-level
          equation and at least one left-hand side mentions Suc, hands ths to
          remove_suc; otherwise ths is returned unchanged. *)
   168 fun eqn_suc_preproc thy ths =
   169   let
   170     fun lhs_of th = fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of th)));
   171     fun mentions_suc th = member (op =) (term_consts (lhs_of th)) @{const_name Suc};
   172     val all_eqs = forall (can lhs_of) ths;
   173   in
   174     if not (all_eqs andalso exists mentions_suc ths) then ths
   175     else remove_suc thy ths
   176   end;
   177
       (* remove_suc_clause thy thms: thms are rules (the clause preprocessor
          targets introduction rules).  Picks the first rule whose atomized
          statement contains Suc v, with v a schematic variable, and rewrites
          it with Suc_clause:
            - occurrences of Suc v become v;
            - the remaining occurrences of v become v - 1;
            - the premise v ~= 0 is added.
          Repeats until no rule contains such a pattern. *)
   178 fun remove_suc_clause thy thms =
   179   let
          (* first subterm Suc v (v a schematic variable), paired with v *)
   182     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   183       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   184       | find_var _ = NONE;
          (* atomize th to an object-level statement th' and search it *)
   185     fun find_thm th =
   186       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   187       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   188   in
   189     case get_first find_thm thms of
   190       NONE => thms
   191     | SOME ((th, th'), (Sucv, v)) =>
   192         let
   193           val cert = cterm_of (Thm.theory_of_thm th);
                (* Suc_clause with P := %v x. statement of th' with Suc v
                   abstracted to x, and n := v, applied to th' generalized
                   over v; rulify turns the result back into a rule *)
   194           val th'' = ObjectLogic.rulify (Thm.implies_elim
   195             (Conv.fconv_rule (Thm.beta_conversion true)
   196               (Drule.instantiate' []
   197                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   198                    abstract_over (Sucv,
   199                      HOLogic.dest_Trueprop (prop_of th')))))),
   200                  SOME (cert v)] @{thm Suc_clause}))
   201             (Thm.forall_intr (cert v) th'))
   202         in
                (* replace th (identified by its proposition) and recurse *)
   203           remove_suc_clause thy (map (fn th''' =>
   204             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   205         end
   206   end;
   207
       (* clause_suc_preproc thy ths: applies remove_suc_clause when every
          theorem concludes a set membership t : S and, for some theorem, an
          element t of a membership among its conclusion and premises mentions
          Suc.  Otherwise ths is returned unchanged.  The constant name comes
          from the antiquotation, as in eqn_suc_preproc. *)
   208 fun clause_suc_preproc thy ths =
   209   let
   210     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   211   in
   212     if forall (can (dest o concl_of)) ths andalso
   213       exists (fn th => member (op =) (foldr add_term_consts
   214         [] (map_filter (try dest) (concl_of th :: prems_of th))) @{const_name Suc}) ths
   215     then remove_suc_clause thy ths else ths
   216   end;
   217
       (* lift_obj_eq f thy thms: runs f, which works on object-level
          equations, on meta-level equations.  The steps are:
            1. turn thms into object equalities with meta_eq_to_obj_eq;
            2. apply f thy;
            3. turn the results back with eq_reflection;
            4. beta-eta-normalize them.
          If any step raises an exception, thms is returned unchanged. *)
   218 fun lift_obj_eq f thy thms =
   219   let
   220     val to_obj = map (fn thm => thm RS @{thm meta_eq_to_obj_eq});
   221     val to_meta = map (fn thm => thm RS @{thm eq_reflection});
   222     val norm = map (Conv.fconv_rule Drule.beta_eta_conversion);
   223     fun lifted ths = norm (to_meta (f thy (to_obj ths)));
   224   in
   225     case try lifted thms of SOME ths' => ths' | NONE => thms end
   226 *}
   227
       (* Registers both Suc-elimination preprocessors twice:
            - with Codegen, directly;
            - with Code, under the names "eqn_Suc" and "clause_Suc", wrapped
              in lift_obj_eq so that they operate on meta-level equations. *)
   228 setup {*
   229   Codegen.add_preprocessor eqn_suc_preproc
   230   #> Codegen.add_preprocessor clause_suc_preproc
   231   #> Code.add_preproc ("eqn_Suc", lift_obj_eq eqn_suc_preproc)
   232   #> Code.add_preproc ("clause_Suc", lift_obj_eq clause_suc_preproc)
   233 *}
   234 (*>*)
   235 
   236 subsection {* Target language setup *}
   237 
   238 text {*
   239   For ML, we map @{typ nat} to target language integers, where we
   240   assert that values are always non-negative.
   241 *}
   242
       (* SML: nat values are IntInf.int.  All SML operations on nat below
          are IntInf operations (IntInf.+, IntInf.divMod, IntInf.toInt), and
          int is mapped to the identity.  Plain int is the default Int.int,
          which is fixed-precision on some SML implementations. *)
   243 code_type nat
   244   (SML "IntInf.int")
   245   (OCaml "Big'_int.big'_int")
   246
       (* Setup for the types_code/consts_code code generator: nat is
          represented by ML int.
            - term_of_nat turns a value back into a HOL numeral of type nat.
            - gen_nat i draws n with random_range 0 i and returns it together
              with a thunk building its term.  Whether the upper bound is
              inclusive depends on random_range, which is not shown here. *)
   247 types_code
   248   nat ("int")
   249 attach (term_of) {*
   250 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   251 *}
   252 attach (test) {*
   253 fun gen_nat i =
   254   let val n = random_range 0 i
   255   in (n, fn () => term_of_nat n) end;
   256 *}
   257 
   258 text {*
   259   For Haskell we define our own @{typ nat} type.  The reason
   260   is that we have to distinguish type class instances
   261   for @{typ nat} and @{typ int}.
   262 *}
   263
       (* Haskell: a dedicated Nat newtype over Integer, so that type class
          instances for nat and int can differ. *)
   264 code_include Haskell "Nat" {*
       -- Values are kept non-negative: fromInteger clamps negative integers
       -- to 0, and the other operations only combine non-negative values.
   265 newtype Nat = Nat Integer deriving (Show, Eq);
   266
   267 instance Num Nat where {
   268   fromInteger k = Nat (if k >= 0 then k else 0);
   269   Nat n + Nat m = Nat (n + m);
       -- truncated subtraction: a negative difference is clamped to 0
   270   Nat n - Nat m = fromInteger (n - m);
   271   Nat n * Nat m = Nat (n * m);
   272   abs n = n;
       -- sign of a natural number: 0 for zero, 1 otherwise
   273   signum (Nat n) = Nat (signum n);
   274   negate n = error "negate Nat";
   275 };
   276
   277 instance Ord Nat where {
   278   Nat n <= Nat m = n <= m;
   279   Nat n < Nat m = n < m;
   280 };
   281
   282 instance Real Nat where {
   283   toRational (Nat n) = toRational n;
   284 };
   285
   286 instance Enum Nat where {
   287   toEnum k = fromInteger (toEnum k);
   288   fromEnum (Nat n) = fromEnum n;
   289 };
   290
   291 instance Integral Nat where {
   292   toInteger (Nat n) = n;
       -- for non-negative operands divMod and quotRem coincide
   293   divMod n m = quotRem n m;
   294   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   295 };
   296 *}
   297
       (* Haskell: the name Nat is reserved, and nat is printed as the Nat
          newtype from the include above.  "(Haskell -)" means no eq
          instance is generated; presumably the derived Eq instance of the
          newtype is relied upon instead. *)
   298 code_reserved Haskell Nat
   299
   300 code_type nat
   301   (Haskell "Nat")
   302
   303 code_instance nat :: eq
   304   (Haskell -)
   305 
   306 text {*
   307   Natural numerals.
   308 *}
   309
       (* nat applied to an integer numeral is inlined as a nat numeral; the
          symmetric equation is declared [code post] to undo this when
          post-processing. *)
   310 lemma [code inline, symmetric, code post]:
   311   "nat (number_of i) = number_nat_inst.number_of_nat i"
   312   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   313   by (simp add: number_nat_inst.number_of_nat)
   314
       (* Registers target-language printing of nat numerals for SML, OCaml
          and Haskell.  The meaning of the two boolean arguments is defined
          by Numeral.add_code, which is not shown here. *)
   315 setup {*
   316   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   317     true false) ["SML", "OCaml", "Haskell"]
   318 *}
   319 
   320 text {*
   321   Since natural numbers are implemented
   322   using integers in ML, the coercion function @{const "of_nat"} of type
   323   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   324   For the @{const "nat"} function for converting an integer to a natural
   325   number, we give a specific implementation using an ML function that
   326   returns its input value, provided that it is non-negative, and otherwise
   327   returns @{text "0"}.
   328 *}
   329
       (* int is an auxiliary copy of of_nat :: nat => int, hidden again at
          the end of this theory.  It gives the coercion a constant of its
          own that can be mapped to target code.  Its defining equation is
          excluded from code generation. *)
   330 definition
   331   int :: "nat \<Rightarrow> int"
   332 where
   333   [code func del]: "int = of_nat"
   334
       (* On numerals, int and nat yield 0 for a negative numeral and the
          numeral itself otherwise. *)
   335 lemma int_code' [code func]:
   336   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   337   unfolding int_nat_number_of [folded int_def] ..
   338
   339 lemma nat_code' [code func]:
   340   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   341   by auto
   342
       (* of_nat is rewritten to int during preprocessing and back during
          post-processing. *)
   343 lemma of_nat_int [code unfold]:
   344   "of_nat = int" by (simp add: int_def)
   345 declare of_nat_int [symmetric, code post]
   346
       (* SML and OCaml: int is printed as its argument, i.e. the identity. *)
   347 code_const int
   348   (SML "_")
   349   (OCaml "_")
   350
       (* consts_code: int is the identity.  nat uses the attached ML
          function, which maps negative integers to 0. *)
   351 consts_code
   352   int ("(_)")
   353   nat ("\<module>nat")
   354 attach {*
   355 fun nat i = if i < 0 then 0 else i;
   356 *}
   357
       (* code_const: nat is the maximum of 0 and its argument in SML and
          OCaml. *)
   358 code_const nat
   359   (SML "IntInf.max/ (/0,/ _)")
   360   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   361 
   362 text {* For Haskell, things are slightly different again. *}
   363
       (* Haskell: int unwraps a Nat through its Integral instance
          (toInteger); nat builds one through its Num instance (fromInteger),
          which clamps negative integers to 0. *)
   364 code_const int and nat
   365   (Haskell "toInteger" and "fromInteger")
   366 
   367 text {* Conversion from and to indices. *}
   368
       (* Conversions between nat and index.
            - SML: IntInf.toInt raises Overflow for values that do not fit a
              machine int.
            - Haskell: the conversion goes through Integer, using toInteger
              (Integral) and then fromInteger (Num).  This type-checks in both
              directions for the Nat newtype and any integral index type.
              toEnum cannot be applied to a Nat (it takes an Int), and
              fromEnum returns Int rather than Nat. *)
   369 code_const index_of_nat
   370   (SML "IntInf.toInt")
   371   (OCaml "Big'_int.int'_of'_big'_int")
   372   (Haskell "!(fromInteger/ ./ toInteger)")
   373
   374 code_const nat_of_index
   375   (SML "IntInf.fromInt")
   376   (OCaml "Big'_int.big'_int'_of'_int")
   377   (Haskell "!(fromInteger/ ./ toInteger)")
   378 
   379 text {* Using target language arithmetic operations whenever appropriate *}
   380
       (* Addition and multiplication map directly to target-language
          operations.  div_mod_nat_aux maps to a combined quotient/remainder
          operation.  When reached via the code equation for Divides.divmod,
          which handles a zero divisor separately, its divisor is nonzero. *)
   381 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   382   (SML "IntInf.+ ((_), (_))")
   383   (OCaml "Big'_int.add'_big'_int")
   384   (Haskell infixl 6 "+")
   385
   386 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   387   (SML "IntInf.* ((_), (_))")
   388   (OCaml "Big'_int.mult'_big'_int")
   389   (Haskell infixl 7 "*")
   390
   391 code_const div_mod_nat_aux
   392   (SML "IntInf.divMod/ ((_),/ (_))")
   393   (OCaml "Big'_int.quomod'_big'_int")
   394   (Haskell "divMod")
   395
       (* Comparisons.  The Haskell Prelude declares ==, <= and < as
          non-associative infix 4 operators, so all three use infix 4 here.
          With infixl the printer may omit parentheses that Haskell requires
          around a nested comparison. *)
   396 code_const "op = \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   397   (SML "!((_ : IntInf.int) = _)")
   398   (OCaml "Big'_int.eq'_big'_int")
   399   (Haskell infix 4 "==")
   400
   401 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   402   (SML "IntInf.<= ((_), (_))")
   403   (OCaml "Big'_int.le'_big'_int")
   404   (Haskell infix 4 "<=")
   405
   406 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   407   (SML "IntInf.< ((_), (_))")
   408   (OCaml "Big'_int.lt'_big'_int")
   409   (Haskell infix 4 "<")
   410
       (* consts_code setup: 0, +, *, <= and < on nat map to the
          corresponding ML integer syntax, and Suc n becomes n + 1. *)
   411 consts_code
   412   0                            ("0")
   413   Suc                          ("(_ +/ 1)")
   414   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
   415   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
   416   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   417   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   418 
   419 
   420 text {* Module names *}
   421
       (* For each target, code from theories Nat, Divides and Efficient_Nat
          is presumably placed into one module named Integer. *)
   422 code_modulename SML
   423   Nat Integer
   424   Divides Integer
   425   Efficient_Nat Integer
   426
   427 code_modulename OCaml
   428   Nat Integer
   429   Divides Integer
   430   Efficient_Nat Integer
   431
   432 code_modulename Haskell
   433   Nat Integer
   434   Divides Integer
   435   Efficient_Nat Integer
   436
       (* int was only needed as an auxiliary constant for code generation;
          its name is hidden from the name space. *)
   437 hide const int
   438 
   439 end