src/HOL/Library/Efficient_Nat.thy
author haftmann
Wed Sep 01 11:09:50 2010 +0200 (2010-09-01)
changeset 38968 e55deaa22fff
parent 38857 97775f3e8722
child 39198 f967a16dfcdd
permissions -rw-r--r--
do not print object frame around Scala includes -- this is in the responsibility of the user
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Integer Main
     9 begin
    10 
    11 text {*
    12   When generating code for functions on natural numbers, the
    13   canonical representation using @{term "0::nat"} and
    14   @{term Suc} is unsuitable for computations involving large
    15   numbers.  The efficiency of the generated code can be improved
    16   drastically by implementing natural numbers by target-language
    17   integers.  To do this, just import this theory.
    18 *}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26
       (* Code generation uses the numeral embedding number_of_nat as the only
          constructor of nat, so 0 and Suc are not constructors in generated
          code; the code equations below express them through numerals and
          addition instead. *)
    27 code_datatype number_nat_inst.number_of_nat
    28
    29 lemma zero_nat_code [code, code_unfold_post]:
    30   "0 = (Numeral0 :: nat)"
    31   by simp
    32
    33 lemma one_nat_code [code, code_unfold_post]:
    34   "1 = (Numeral1 :: nat)"
    35   by simp
    36
       (* Suc is computed as addition of one.  Because of this, Suc patterns on
          left-hand sides must be removed by the eqn_Suc preprocessor set up
          further down in this theory. *)
    37 lemma Suc_code [code]:
    38   "Suc n = n + 1"
    39   by simp
    40
       (* Addition, subtraction, multiplication and division on nat are
          delegated to int.  Both arguments are embedded with of_nat, and the
          result is mapped back with nat.  For subtraction this relies on nat
          sending negative integers to 0, which matches truncated subtraction
          on nat. *)
    41 lemma plus_nat_code [code]:
    42   "n + m = nat (of_nat n + of_nat m)"
    43   by simp
    44
    45 lemma minus_nat_code [code]:
    46   "n - m = nat (of_nat n - of_nat m)"
    47   by simp
    48
    49 lemma times_nat_code [code]:
    50   "n * m = nat (of_nat n * of_nat m)"
    51   unfolding of_nat_mult [symmetric] by simp
    52
       (* divmod_nat is computed with pdivmod on int; prod_fun nat nat applies
          nat to both components of the resulting pair. *)
    53 lemma divmod_nat_code [code]:
    54   "divmod_nat n m = prod_fun nat nat (pdivmod (of_nat n) (of_nat m))"
    55   by (simp add: prod_fun_def split_def pdivmod_def nat_div_distrib nat_mod_distrib divmod_nat_div_mod)
    56
       (* Equality on nat is decided on int after embedding both arguments
          with of_nat. *)
    57 lemma eq_nat_code [code]:
    58   "HOL.equal n m \<longleftrightarrow> HOL.equal (of_nat n \<Colon> int) (of_nat m)"
    59   by (simp add: equal)
    60
       (* Reflexivity of equality, registered via the code nbe attribute, i.e.
          as a code equation for normalization by evaluation. *)
    61 lemma eq_nat_refl [code nbe]:
    62   "HOL.equal (n::nat) n \<longleftrightarrow> True"
    63   by (rule equal_refl)
    64
       (* The orderings on nat are likewise decided on int via of_nat. *)
    65 lemma less_eq_nat_code [code]:
    66   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
    67   by simp
    68
    69 lemma less_nat_code [code]:
    70   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
    71   by simp
    72 
    73 subsection {* Case analysis *}
    74 
    75 text {*
    76   Case analysis on natural numbers is rephrased using a conditional
    77   expression:
    78 *}
    79
       (* Case analysis on nat becomes a test for n = 0, with n - 1 standing
          for the predecessor.  The rule is declared both as a code equation
          and as a code preprocessor unfold rule, since 0 and Suc are not
          code constructors. *)
    80 lemma [code, code_unfold]:
    81   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    82   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    83 
    84 
    85 subsection {* Preprocessors *}
    86 
    87 text {*
    88   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
    89   a constructor term. Therefore, all occurrences of this term in a position
    90   where a pattern is expected (i.e.\ on the left-hand side of a recursion
    91   equation or in the arguments of an inductive relation in an introduction
    92   rule) must be eliminated.
    93   This can be accomplished by applying the following transformation rules:
    94 *}
    95
       (* Suc_if_eq merges an equation for f (Suc n) and an equation for f 0
          into a single equation for f n that tests n = 0.  The eqn_Suc
          preprocessor further down instantiates it. *)
    96 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
    97   f n \<equiv> if n = 0 then g else h (n - 1)"
    98   by (rule eq_reflection) (cases n, simp_all)
    99
       (* Suc_clause turns a clause about n and Suc n into one about n - 1 and
          n under the side condition n ~= 0.  The clause preprocessor further
          down instantiates it. *)
   100 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   101   by (cases n) simp_all
   102 
   103 text {*
   104   The rules above are built into a preprocessor that is plugged into
   105   the code generator. Since the preprocessor for introduction rules
   106   does not know anything about modes, some of the modes that worked
   107   for the canonical representation of natural numbers may no longer work.
   108 *}
   109 
   110 (*<*)
   111 setup {*
   112 let
   113
       (* remove_suc thy thms: look for a subterm Suc ?v on the left-hand side
          of one of the equations thms and try to eliminate it with Suc_if_eq.
          The equations that resolve with the f 0 premise serve as the zero
          case.  Returns SOME of the updated equation list for the first
          occurrence where this succeeds, and NONE otherwise. *)
   114 fun remove_suc thy thms =
   115   let
       (* a variable name not yet used in thms, for the hole variable cv *)
   116     val vname = Name.variant (map fst
   117       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
   118     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
       (* for an equation lhs == rhs: lhs_of th is lhs and rhs_of th is rhs *)
   119     fun lhs_of th = snd (Thm.dest_comb
   120       (fst (Thm.dest_comb (cprop_of th))));
   121     fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
       (* all pairs (ctxt, v) such that ctxt is ct with one subterm Suc v,
          where v is a schematic variable, replaced by cv *)
   122     fun find_vars ct = (case term_of ct of
   123         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   124       | _ $ _ =>
   125         let val (ct1, ct2) = Thm.dest_comb ct
   126         in
   127           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   128           map (apfst (Thm.capply ct1)) (find_vars ct2)
   129         end
   130       | _ => []);
       (* every candidate occurrence, paired with the equation it comes from *)
   131     val eqs = maps
   132       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
       (* th' is Suc_if_eq instantiated with ct abstracted over cv, with the
          right-hand side of th abstracted over cv', and with cv'.  After beta
          normalisation, its first premise is discharged by th generalized
          over cv'.  Each equation that resolves with th' is replaced by its
          resolvent, and th is dropped.  The result is NONE if no equation
          resolves. *)
   133     fun mk_thms (th, (ct, cv')) =
   134       let
   135         val th' =
   136           Thm.implies_elim
   137            (Conv.fconv_rule (Thm.beta_conversion true)
   138              (Drule.instantiate'
   139                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   140                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   141                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   142       in
   143         case map_filter (fn th'' =>
   144             SOME (th'', singleton
   145               (Variable.trade (K (fn [th'''] => [th''' RS th']))
   146                 (Variable.global_thm_context th'')) th'')
   147           handle THM _ => NONE) thms of
   148             [] => NONE
   149           | thps =>
   150               let val (ths1, ths2) = split_list thps
   151               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   152       end
   153   in get_first mk_thms eqs end;
   154
       (* Equation preprocessor.  It acts only when every theorem is a
          meta-equation and some left-hand side mentions Suc.  In that case
          it iterates remove_suc via perhaps_loop and normalises the variable
          indices of the result; otherwise it returns NONE. *)
   155 fun eqn_suc_base_preproc thy thms =
   156   let
   157     val dest = fst o Logic.dest_equals o prop_of;
   158     val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   159   in
   160     if forall (can dest) thms andalso exists (contains_suc o dest) thms
   161       then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
   162        else NONE
   163   end;
   164
       (* eqn_suc_base_preproc wrapped as a code preprocessor function
          transformation *)
   165 val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;
   166
       (* remove_suc_clause thy thms: as long as some clause contains a
          subterm Suc ?v in its atomized form, rewrite the first such clause
          with Suc_clause.  Every clause with the same proposition is replaced
          by the result. *)
   167 fun remove_suc_clause thy thms =
   168   let
       (* some subterm of the form Suc ?v, together with the variable ?v *)
   171     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   172       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   173       | find_var _ = NONE;
       (* pairs a clause with its atomized form, if the latter contains Suc ?v *)
   174     fun find_thm th =
   175       let val th' = Conv.fconv_rule Object_Logic.atomize th
   176       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   177   in
   178     case get_first find_thm thms of
   179       NONE => thms
   180     | SOME ((th, th'), (Sucv, v)) =>
   181         let
   182           val cert = cterm_of (Thm.theory_of_thm th);
       (* Suc_clause is instantiated with the atomized clause abstracted over
          v and over the Suc v subterm, and with v.  It is then applied to th'
          generalized over v and turned back into rule form by rulify. *)
   183           val th'' = Object_Logic.rulify (Thm.implies_elim
   184             (Conv.fconv_rule (Thm.beta_conversion true)
   185               (Drule.instantiate' []
   186                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   187                    abstract_over (Sucv,
   188                      HOLogic.dest_Trueprop (prop_of th')))))),
   189                  SOME (cert v)] @{thm Suc_clause}))
   190             (Thm.forall_intr (cert v) th'))
   191         in
   192           remove_suc_clause thy (map (fn th''' =>
   193             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   194         end
   195   end;
   196
       (* Clause preprocessor.  It applies remove_suc_clause only when every
          clause concludes a membership statement and the element of some
          membership conclusion or premise mentions Suc; otherwise it returns
          ths unchanged. *)
   197 fun clause_suc_preproc thy ths =
   198   let
   199     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   200   in
   201     if forall (can (dest o concl_of)) ths andalso
   202       exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
   203         (map_filter (try dest) (concl_of th :: prems_of th))) ths
   204     then remove_suc_clause thy ths else ths
   205   end;
   206 in
   207
       (* register the equation preprocessor with Code_Preproc and the clause
          preprocessor with Codegen *)
   208   Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)
   209   #> Codegen.add_preprocessor clause_suc_preproc
   210
   211 end;
   212 *}
   213 (*>*)
   214 
   215 
   216 subsection {* Target language setup *}
   217 
   218 text {*
   219   For ML, we map @{typ nat} to target language integers, where we
   220   ensure that values are always non-negative.
   221 *}
   222
       (* In SML, OCaml and Eval, nat is represented by the target integer
          type.  That type itself admits negative values; non-negativity comes
          from how nat and the operations are implemented in this theory. *)
   223 code_type nat
   224   (SML "IntInf.int")
   225   (OCaml "Big'_int.big'_int")
   226   (Eval "int")
   227
       (* Setup for the older code generator: nat is ML int.  term_of_nat
          builds a HOL numeral term of type nat.  gen_nat i draws n with
          random_range 0 i and returns it together with a thunk producing its
          term. *)
   228 types_code
   229   nat ("int")
   230 attach (term_of) {*
   231 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   232 *}
   233 attach (test) {*
   234 fun gen_nat i =
   235   let val n = random_range 0 i
   236   in (n, fn () => term_of_nat n) end;
   237 *}
   238 
   239 text {*
   240   For Haskell and Scala we define our own @{typ nat} type.  The reason
   241   is that we have to distinguish type class instances for @{typ nat}
   242   and @{typ int}.
   243 *}
   244
       (* Haskell Nat: a newtype around Integer.  Given non-negative operands,
          every operation below yields a non-negative result.  fromInteger
          clamps negative inputs to 0, and subtraction goes through
          fromInteger. *)
   245 code_include Haskell "Nat"
   246 {*newtype Nat = Nat Integer deriving (Eq, Show, Read);
   247
   248 instance Num Nat where {
   249   fromInteger k = Nat (if k >= 0 then k else 0);
   250   Nat n + Nat m = Nat (n + m);
   251   Nat n - Nat m = fromInteger (n - m);
   252   Nat n * Nat m = Nat (n * m);
   253   abs n = n;
         -- sign of a non-negative number: 0 for zero, 1 otherwise
   254   signum (Nat n) = Nat (signum n);
   255   negate n = error "negate Nat";
   256 };
   257
   258 instance Ord Nat where {
   259   Nat n <= Nat m = n <= m;
   260   Nat n < Nat m = n < m;
   261 };
   262
   263 instance Real Nat where {
   264   toRational (Nat n) = toRational n;
   265 };
   266
   267 instance Enum Nat where {
   268   toEnum k = fromInteger (toEnum k);
   269   fromEnum (Nat n) = fromEnum n;
   270 };
   271
   272 instance Integral Nat where {
   273   toInteger (Nat n) = n;
         -- divMod and quotRem coincide on non-negative operands
   274   divMod n m = quotRem n m;
         -- a zero divisor yields quotient 0 and the dividend as remainder
   275   quotRem (Nat n) (Nat m)
   276     | (m == 0) = (0, Nat n)
   277     | otherwise = (Nat k, Nat l) where (k, l) = quotRem n m;
   278 };
   279 *}
   280
   281 code_reserved Haskell Nat
   282
       (* Scala Nat: a wrapper around BigInt with a private constructor.
          Instances are created only by the apply methods of object Nat or by
          the methods of class Nat itself. *)
   283 code_include Scala "Nat"
   284 {*object Nat {
   285
         // All apply variants end in the BigInt one, which maps negative
         // arguments to 0 via max.
   286   def apply(numeral: BigInt): Nat = new Nat(numeral max 0)
   287   def apply(numeral: Int): Nat = Nat(BigInt(numeral))
   288   def apply(numeral: String): Nat = Nat(BigInt(numeral))
   289
   290 }
   291
   292 class Nat private(private val value: BigInt) {
   293
   294   override def hashCode(): Int = this.value.hashCode()
   295
         // Equal to another Nat with the same value; delegates to the
         // overload equals(that: Nat) below.
   296   override def equals(that: Any): Boolean = that match {
   297     case that: Nat => this equals that
   298     case _ => false
   299   }
   300
   301   override def toString(): String = this.value.toString
   302
   303   def equals(that: Nat): Boolean = this.value == that.value
   304
   305   def as_BigInt: BigInt = this.value
         // Fails via error if the value does not fit into an Int.
   306   def as_Int: Int = if (this.value >= Int.MinValue && this.value <= Int.MaxValue)
   307       this.value.intValue
   308     else error("Int value out of range: " + this.value.toString)
   309
         // Subtraction goes through Nat.apply, so negative differences become 0.
   310   def +(that: Nat): Nat = new Nat(this.value + that.value)
   311   def -(that: Nat): Nat = Nat(this.value - that.value)
   312   def *(that: Nat): Nat = new Nat(this.value * that.value)
   313
         // Quotient and remainder; a zero divisor yields (0, this).
   314   def /%(that: Nat): (Nat, Nat) = if (that.value == 0) (new Nat(0), this)
   315     else {
   316       val (k, l) = this.value /% that.value
   317       (new Nat(k), new Nat(l))
   318     }
   319
   320   def <=(that: Nat): Boolean = this.value <= that.value
   321
   322   def <(that: Nat): Boolean = this.value < that.value
   323
   324 }
   325 *}
   326
   327 code_reserved Scala Nat
   328
       (* In Haskell and Scala, nat is the Nat type from the includes.  The
          Haskell equal instance is not generated; the include derives Eq for
          Nat. *)
   329 code_type nat
   330   (Haskell "Nat.Nat")
   331   (Scala "Nat")
   332
   333 code_instance nat :: equal
   334   (Haskell -)
   335 
   336 text {*
   337   Natural numerals.
   338 *}
   339
       (* Code unfold and postprocessing rule: relates nat applied to an
          integer numeral with the nat numeral number_of_nat. *)
   340 lemma [code_unfold_post]:
   341   "nat (number_of i) = number_nat_inst.number_of_nat i"
   342   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   343   by (simp add: number_nat_inst.number_of_nat)
   344
       (* Print number_of_nat numerals as target literals in SML, OCaml,
          Haskell and Scala, using Code_Printer.literal_positive_numeral. *)
   345 setup {*
   346   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   347     false Code_Printer.literal_positive_numeral) ["SML", "OCaml", "Haskell", "Scala"]
   348 *}
   349 
   350 text {*
   351   Since natural numbers are implemented
   352   using integers in ML, the coercion function @{const "of_nat"} of type
   353   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   354   For the @{const nat} function for converting an integer to a natural
   355   number, we give a specific implementation using an ML function that
   356   returns its input value, provided that it is non-negative, and otherwise
   357   returns @{text "0"}.
   358 *}
   359
       (* int is an auxiliary copy of of_nat that the targets map directly; it
          is hidden again at the end of this theory.  The code del attribute
          keeps its definition out of the code equations. *)
   360 definition int :: "nat \<Rightarrow> int" where
   361   [code del]: "int = of_nat"
   362
       (* On numerals, int and nat both return the numeral itself unless it is
          negative, in which case they return 0. *)
   363 lemma int_code' [code]:
   364   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   365   unfolding int_nat_number_of [folded int_def] ..
   366
   367 lemma nat_code' [code]:
   368   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   369   unfolding nat_number_of_def number_of_is_id neg_def by simp
   370
       (* Relate of_nat on nat with int in code pre- and postprocessing. *)
   371 lemma of_nat_int [code_unfold_post]:
   372   "of_nat = int" by (simp add: int_def)
   373
       (* of_nat_aux with successor function i + 1, started from 0, is
          unfolded to int k. *)
   374 lemma of_nat_aux_int [code_unfold]:
   375   "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
   376   by (simp add: int_def Nat.of_nat_code)
   377
       (* SML and OCaml: int is the identity, since nat already uses the
          target integer type. *)
   378 code_const int
   379   (SML "_")
   380   (OCaml "_")
   381
       (* Older code generator: int is the identity, and nat uses the attached
          ML function, which maps negative values to 0. *)
   382 consts_code
   383   int ("(_)")
   384   nat ("\<module>nat")
   385 attach {*
   386 fun nat i = if i < 0 then 0 else i;
   387 *}
   388
       (* nat is the maximum of its argument and 0 in SML, OCaml and Eval. *)
   389 code_const nat
   390   (SML "IntInf.max/ (/0,/ _)")
   391   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   392   (Eval "Integer.max/ _/ 0")
   393 
   394 text {* For Haskell and Scala, things are slightly different again. *}
   395
       (* int unwraps the Nat wrapper (toInteger, respectively as_BigInt).
          nat wraps an integer via fromInteger, respectively Nat.apply; both
          map negative values to 0 in the includes above. *)
   396 code_const int and nat
   397   (Haskell "toInteger" and "fromInteger")
   398   (Scala "!_.as'_BigInt" and "Nat")
   399 
   400 text {* Conversion from and to code numerals. *}
   401
       (* Conversions between nat and code_numeral.  SML uses IntInf.toInt and
          IntInf.fromInt; OCaml and Eval use the identity; Haskell converts
          through Integer; Scala goes through as_BigInt.
          NOTE(review): per the SML Basis, IntInf.toInt raises Overflow when
          the value does not fit into int.  Code_Numeral.of_nat may therefore
          fail on large numbers in SML; confirm this is intended. *)
   402 code_const Code_Numeral.of_nat
   403   (SML "IntInf.toInt")
   404   (OCaml "_")
   405   (Haskell "!(fromInteger/ ./ toInteger)")
   406   (Scala "!Natural(_.as'_BigInt)")
   407   (Eval "_")
   408
   409 code_const Code_Numeral.nat_of
   410   (SML "IntInf.fromInt")
   411   (OCaml "_")
   412   (Haskell "!(fromInteger/ ./ toInteger)")
   413   (Scala "!Nat(_.as'_BigInt)")
   414   (Eval "_")
   415 
   416 text {* Using target language arithmetic operations whenever appropriate *}
   417
       (* Map nat arithmetic and comparisons directly to target operations.
          Subtraction is mapped only for Haskell and Scala, whose Nat types
          map negative differences to 0; SML, OCaml and Eval use the code
          equation minus_nat_code instead.
          NOTE(review): for a zero divisor, the Haskell and Scala Nat division
          returns quotient 0 and the dividend as remainder.  Per the SML
          Basis, IntInf.divMod raises Div on a zero divisor, and the OCaml and
          Eval mappings presumably raise as well.  TODO: confirm this against
          the HOL value of divmod_nat n 0. *)
   418 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   419   (SML "IntInf.+ ((_), (_))")
   420   (OCaml "Big'_int.add'_big'_int")
   421   (Haskell infixl 6 "+")
   422   (Scala infixl 7 "+")
   423   (Eval infixl 8 "+")
   424
   425 code_const "op - \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   426   (Haskell infixl 6 "-")
   427   (Scala infixl 7 "-")
   428
   429 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   430   (SML "IntInf.* ((_), (_))")
   431   (OCaml "Big'_int.mult'_big'_int")
   432   (Haskell infixl 7 "*")
   433   (Scala infixl 8 "*")
   434   (Eval infixl 9 "*")
   435
   436 code_const divmod_nat
   437   (SML "IntInf.divMod/ ((_),/ (_))")
   438   (OCaml "Big'_int.quomod'_big'_int")
   439   (Haskell "divMod")
   440   (Scala infixl 8 "/%")
   441   (Eval "Integer.div'_mod")
   442
   443 code_const "HOL.equal \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   444   (SML "!((_ : IntInf.int) = _)")
   445   (OCaml "Big'_int.eq'_big'_int")
   446   (Haskell infixl 4 "==")
   447   (Scala infixl 5 "==")
   448   (Eval infixl 6 "=")
   449
   450 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   451   (SML "IntInf.<= ((_), (_))")
   452   (OCaml "Big'_int.le'_big'_int")
   453   (Haskell infix 4 "<=")
   454   (Scala infixl 4 "<=")
   455   (Eval infixl 6 "<=")
   456
   457 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   458   (SML "IntInf.< ((_), (_))")
   459   (OCaml "Big'_int.lt'_big'_int")
   460   (Haskell infix 4 "<")
   461   (Scala infixl 4 "<")
   462   (Eval infixl 6 "<")
   463
       (* Older code generator: nat constants, Suc, addition, multiplication
          and comparisons as ML int expressions. *)
   464 consts_code
   465   "0::nat"                     ("0")
   466   "1::nat"                     ("1")
   467   Suc                          ("(_ +/ 1)")
   468   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
   469   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
   470   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   471   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   472 
   473 
   474 text {* Evaluation *}
   475
       (* The trivial equation declared with both code and code del presumably
          leaves term_of at type nat without code equations, so that the SML
          mapping below is used.  TODO: confirm.  In SML, term_of builds a
          numeral term of type nat with HOLogic.mk_number. *)
   476 lemma [code, code del]:
   477   "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..
   478
   479 code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
   480   (SML "HOLogic.mk'_number/ HOLogic.natT")
   481 
   482 
   483 text {* Module names *}
   484
       (* Code generated from this theory goes into a module named Arith in
          SML, OCaml and Haskell. *)
   485 code_modulename SML
   486   Efficient_Nat Arith
   487
   488 code_modulename OCaml
   489   Efficient_Nat Arith
   490
   491 code_modulename Haskell
   492   Efficient_Nat Arith
   493
       (* The auxiliary constant int is hidden from the name space once the
          code setup is complete. *)
   494 hide_const int
   495 
   496 end