src/HOL/Library/Efficient_Nat.thy
author haftmann
Thu Jan 14 17:47:39 2010 +0100 (2010-01-14)
changeset 34899 8674bb6f727b
parent 34893 ecdc526af73a
child 34902 780172c006e1
permissions -rw-r--r--
added Scala setup
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Integer Main
     9 begin
    10 
    11 text {*
    12   When generating code for functions on natural numbers, the
    13   canonical representation using @{term "0::nat"} and
    14   @{term "Suc"} is unsuitable for computations involving large
    15   numbers.  The efficiency of the generated code can be improved
    16   drastically by implementing natural numbers by target-language
    17   integers.  To do this, just include this theory.
    18 *}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
(* Make the numeral injection number_of_nat the code constructor of nat.
   0 and Suc are thus no longer constructors for code generation; the
   equations below express the operations on nat via numerals and int. *)
code_datatype number_nat_inst.number_of_nat
    28 
(* 0 and 1 are expressed as numerals, i.e. in terms of the code constructor
   number_of_nat.  Both equations are also registered for post-processing
   of generated terms (code_unfold_post). *)
lemma zero_nat_code [code, code_unfold_post]:
  "0 = (Numeral0 :: nat)"
  by simp

lemma one_nat_code [code, code_unfold_post]:
  "1 = (Numeral1 :: nat)"
  by simp
    36 
(* Suc is not a constructor any more; it is computed as addition of 1. *)
lemma Suc_code [code]:
  "Suc n = n + 1"
  by simp

(* Addition, subtraction and multiplication are computed on int and mapped
   back with nat.  Since nat sends negative integers to 0, minus_nat_code
   yields the truncated subtraction of nat. *)
lemma plus_nat_code [code]:
  "n + m = nat (of_nat n + of_nat m)"
  by simp

lemma minus_nat_code [code]:
  "n - m = nat (of_nat n - of_nat m)"
  by simp

lemma times_nat_code [code]:
  "n * m = nat (of_nat n * of_nat m)"
  unfolding of_nat_mult [symmetric] by simp
    52 
    53 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    54   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    55 
(* divmod_aux is a copy of divmod_nat whose defining equation is excluded
   from code generation (code del).  The code equation for divmod_nat below
   splits off the case m = 0, so through that equation divmod_aux is only
   invoked with a nonzero divisor; divmod_aux itself is computed on int, or
   mapped to a native division primitive by code_const further down. *)
definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
  [code del]: "divmod_aux = divmod_nat"

(* HOL semantics for a zero divisor: n div 0 = 0 and n mod 0 = n. *)
lemma [code]:
  "divmod_nat n m = (if m = 0 then (0, n) else divmod_aux n m)"
  unfolding divmod_aux_def divmod_nat_div_mod by simp

(* Nonzero case: integer division and remainder of the int images. *)
lemma divmod_aux_code [code]:
  "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
  unfolding divmod_aux_def divmod_nat_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    66 
(* Equality on nat is decided on the int images. *)
lemma eq_nat_code [code]:
  "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
  by (simp add: eq)

(* Reflexivity shortcut, registered only for normalization by evaluation
   (code nbe). *)
lemma eq_nat_refl [code nbe]:
  "eq_class.eq (n::nat) n \<longleftrightarrow> True"
  by (rule HOL.eq_refl)

(* The orderings on nat are likewise decided on the int images. *)
lemma less_eq_nat_code [code]:
  "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
  by simp

lemma less_nat_code [code]:
  "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
  by simp
    82 
    83 subsection {* Case analysis *}
    84 
    85 text {*
    86   Case analysis on natural numbers is rephrased using a conditional
    87   expression:
    88 *}
    89 
(* nat_case becomes a conditional on n = 0, since 0/Suc patterns are no
   longer available; registered both as code equation and as code_unfold
   rule. *)
lemma [code, code_unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    93 
    94 
    95 subsection {* Preprocessors *}
    96 
    97 text {*
    98   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
    99   a constructor term. Therefore, all occurrences of this term in a position
   100   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   101   equation or in the arguments of an inductive relation in an introduction
   102   rule) must be eliminated.
   103   This can be accomplished by applying the following transformation rules:
   104 *}
   105 
(* Suc_if_eq: merges an equation for f on Suc n with an equation for f on 0
   into a single equation for f n, guarded by n = 0.  Instantiated by the
   code-equation preprocessor (remove_suc) in the setup below. *)
lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
  f n \<equiv> if n = 0 then g else h (n - 1)"
  by (rule eq_reflection) (cases n, simp_all)

(* Suc_clause: turns a clause stated for (n, Suc n) into one stated for
   (n - 1, n) under the side condition n ~= 0.  Instantiated by the clause
   preprocessor (remove_suc_clause) in the setup below. *)
lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
  by (cases n) simp_all
   112 
   113 text {*
   114   The rules above are built into a preprocessor that is plugged into
   115   the code generator. Since the preprocessor for introduction rules
   116   does not know anything about modes, some of the modes that worked
   117   for the canonical representation of natural numbers may no longer work.
   118 *}
   119 
   120 (*<*)
setup {*
let

(* remove_suc: one rewriting step on the code equations thms of a single
   constant.  It looks for an equation whose left-hand side contains a
   pattern Suc ?v and merges it, via Suc_if_eq, with every equation that
   discharges the corresponding 0 case.  Returns SOME of the updated list
   of equations, or NONE if no such step applies. *)
fun remove_suc thy thms =
  let
    (* A schematic variable of type nat, renamed apart from all variables of
       thms; it marks the hole left where a Suc pattern was cut out. *)
    val vname = Name.variant (map fst
      (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
    val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    (* Left- and right-hand side of a theorem of the form lhs == rhs. *)
    fun lhs_of th = snd (Thm.dest_comb
      (fst (Thm.dest_comb (cprop_of th))));
    fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
    (* For each occurrence of Suc ?v in ct, return the pair
       (ct with that occurrence replaced by cv, ?v). *)
    fun find_vars ct = (case term_of ct of
        (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
      | _ $ _ =>
        let val (ct1, ct2) = Thm.dest_comb ct
        in 
          map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
          map (apfst (Thm.capply ct1)) (find_vars ct2)
        end
      | _ => []);
    (* All candidates: (equation, Suc occurrence in its left-hand side). *)
    val eqs = maps
      (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    (* Try to eliminate the Suc occurrence (ct, cv') of equation th. *)
    fun mk_thms (th, (ct, cv')) =
      let
        (* Instantiate Suc_if_eq with f := %cv. ct, h := %cv'. rhs and
           n := cv' (g stays schematic), beta-normalise, and discharge the
           first premise with th generalised over cv'.  th' then reads
             f 0 == ?g ==> f cv' == (if cv' = 0 then ?g else h (cv' - 1)) *)
        val th' =
          Thm.implies_elim
           (Conv.fconv_rule (Thm.beta_conversion true)
             (Drule.instantiate'
               [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
                 SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
               @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
      in
        (* Resolve every equation of thms against the premise of th';
           equations that do not fit raise THM and are skipped. *)
        case map_filter (fn th'' =>
            SOME (th'', singleton
              (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
          handle THM _ => NONE) thms of
            [] => NONE
          | thps =>
              (* Drop th and the consumed equations; add the merged ones. *)
              let val (ths1, ths2) = split_list thps
              in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
      end
  in get_first mk_thms eqs end;

(* Fires only if every theorem is a meta-equation and some left-hand side
   mentions Suc.  It then applies remove_suc repeatedly (perhaps_loop) and
   normalises the variable indices of the resulting equations; otherwise
   it returns NONE. *)
fun eqn_suc_base_preproc thy thms =
  let
    val dest = fst o Logic.dest_equals o prop_of;
    val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
  in
    if forall (can dest) thms andalso exists (contains_suc o dest) thms
      then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
       else NONE
  end;

(* The above, wrapped as a function transformer for Code_Preproc. *)
val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;

(* Clause variant for introduction rules.  It finds the first theorem whose
   atomized statement contains a pattern Suc ?v and rewrites it with
   Suc_clause into a rule about ?v under the side condition ?v ~= 0.  The
   rewritten rule replaces the original in thms, and the function recurses
   until no such pattern is left. *)
fun remove_suc_clause thy thms =
  let
    (* NOTE(review): vname is computed but never used in this function;
       it looks like a leftover and could be removed. *)
    val vname = Name.variant (map fst
      (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
    (* First subterm of the form Suc ?v (leftmost first), paired with ?v. *)
    fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
      | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
      | find_var _ = NONE;
    (* Atomize th (meta-level to object-level connectives) and search it. *)
    fun find_thm th =
      let val th' = Conv.fconv_rule ObjectLogic.atomize th
      in Option.map (pair (th, th')) (find_var (prop_of th')) end
  in
    case get_first find_thm thms of
      NONE => thms
    | SOME ((th, th'), (Sucv, v)) =>
        let
          val cert = cterm_of (Thm.theory_of_thm th);
          (* Instantiate Suc_clause with P := %v x. (statement of th' with
             Suc v abstracted to x) and n := v, beta-normalise, discharge
             the premise with th' generalised over v, and turn the result
             back into a meta-level rule. *)
          val th'' = ObjectLogic.rulify (Thm.implies_elim
            (Conv.fconv_rule (Thm.beta_conversion true)
              (Drule.instantiate' []
                [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
                   abstract_over (Sucv,
                     HOLogic.dest_Trueprop (prop_of th')))))),
                 SOME (cert v)] @{thm Suc_clause}))
            (Thm.forall_intr (cert v) th'))
        in
          (* Replace th (compared by proposition) and iterate. *)
          remove_suc_clause thy (map (fn th''' =>
            if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
        end
  end;

(* Applies remove_suc_clause only when every rule concludes a membership
   x : S and the element of some membership, in a conclusion or a premise,
   mentions Suc. *)
fun clause_suc_preproc thy ths =
  let
    val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
  in
    if forall (can (dest o concl_of)) ths andalso
      exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
        (map_filter (try dest) (concl_of th :: prems_of th))) ths
    then remove_suc_clause thy ths else ths
  end;
in

  (* Register eqn_Suc for code equations (Code_Preproc) and the clause
     variant as a Codegen preprocessor. *)
  Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)
  #> Codegen.add_preprocessor clause_suc_preproc

end;
*}
   222 (*>*)
   223 
   224 
   225 subsection {* Target language setup *}
   226 
   227 text {*
   228   For ML, we map @{typ nat} to target language integers, where we
   229   ensure that values are always non-negative.
   230 *}
   231 
(* SML and OCaml: nat is represented by the arbitrary-precision integer
   type of the target (IntInf.int and Big_int.big_int). *)
code_type nat
  (SML "IntInf.int")
  (OCaml "Big'_int.big'_int")

(* types_code (Codegen): nat is the ML type int.  The attached term_of_nat
   reifies a value as a HOL numeral of type nat; gen_nat i draws a random
   value with random_range 0 i for testing. *)
types_code
  nat ("int")
attach (term_of) {*
val term_of_nat = HOLogic.mk_number HOLogic.natT;
*}
attach (test) {*
fun gen_nat i =
  let val n = random_range 0 i
  in (n, fn () => term_of_nat n) end;
*}
   246 
   247 text {*
  For Haskell and Scala we define our own @{typ nat} type.  The reason
   249   is that we have to distinguish type class instances for @{typ nat}
   250   and @{typ int}.
   251 *}
   252 
   253 code_include Haskell "Nat" {*
   254 newtype Nat = Nat Integer deriving (Show, Eq);
   255 
   256 instance Num Nat where {
   257   fromInteger k = Nat (if k >= 0 then k else 0);
   258   Nat n + Nat m = Nat (n + m);
   259   Nat n - Nat m = fromInteger (n - m);
   260   Nat n * Nat m = Nat (n * m);
   261   abs n = n;
   262   signum _ = 1;
   263   negate n = error "negate Nat";
   264 };
   265 
   266 instance Ord Nat where {
   267   Nat n <= Nat m = n <= m;
   268   Nat n < Nat m = n < m;
   269 };
   270 
   271 instance Real Nat where {
   272   toRational (Nat n) = toRational n;
   273 };
   274 
   275 instance Enum Nat where {
   276   toEnum k = fromInteger (toEnum k);
   277   fromEnum (Nat n) = fromEnum n;
   278 };
   279 
   280 instance Integral Nat where {
   281   toInteger (Nat n) = n;
   282   divMod n m = quotRem n m;
   283   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   284 };
   285 *}
   286 
   287 code_reserved Haskell Nat
   288 
   289 code_include Scala "Nat" {*
   290 object Nat {
   291 
   292   def apply(numeral: BigInt): Nat = new Nat(numeral max 0)
   293   def apply(numeral: Int): Nat = Nat(BigInt(numeral))
   294   def apply(numeral: String): Nat = Nat(BigInt(numeral))
   295 
   296 }
   297 
   298 class Nat private(private val value: BigInt) {
   299 
   300   override def hashCode(): Int = this.value.hashCode()
   301 
   302   override def equals(that: Any): Boolean = that match {
   303     case that: Nat => this equals that
   304     case _ => false
   305   }
   306 
   307   override def toString(): String = this.value.toString
   308 
   309   def equals(that: Nat): Boolean = this.value == that.value
   310 
   311   def as_BigInt: BigInt = this.value
   312 
   313   def +(that: Nat): Nat = new Nat(this.value + that.value)
   314   def -(that: Nat): Nat = Nat(this.value + that.value)
   315   def *(that: Nat): Nat = new Nat(this.value * that.value)
   316 
   317   def /%(that: Nat): (Nat, Nat) = if (that.value == 0) (new Nat(0), this)
   318     else {
   319       val (k, l) = this.value /% that.value
   320       (new Nat(k), new Nat(l))
   321     }
   322 
   323   def <=(that: Nat): Boolean = this.value <= that.value
   324 
   325   def <(that: Nat): Boolean = this.value < that.value
   326 
   327 }
   328 *}
   329 
   330 code_reserved Scala Nat
   331 
(* Haskell and Scala use the Nat types defined by the code_include
   declarations above. *)
code_type nat
  (Haskell "Nat.Nat")
  (Scala "Nat.Nat")

(* The Haskell Nat type already derives Eq, so no Eq instance is generated
   for it. *)
code_instance nat :: eq
  (Haskell -)
   338 
   339 text {*
   340   Natural numerals.
   341 *}
   342 
(* Post-processing rule: nat applied to an int numeral is shown as the
   corresponding nat numeral number_of_nat. *)
lemma [code_unfold_post]:
  "nat (number_of i) = number_nat_inst.number_of_nat i"
  -- {* this interacts as desired with @{thm nat_number_of_def} *}
  by (simp add: number_nat_inst.number_of_nat)
   347 
setup {*
  (* Printing of nat numerals: plain target literals for SML, OCaml and
     Haskell; for Scala the literal is preceded by Nat.Nat so that the
     value is built through the Nat companion object. *)
  fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
    false true Code_Printer.str) ["SML", "OCaml", "Haskell"]
  #> Numeral.add_code @{const_name number_nat_inst.number_of_nat}
    false true (fn s => (Pretty.block o map Code_Printer.str) ["Nat.Nat", s]) "Scala"
*}
   354 
   355 text {*
   356   Since natural numbers are implemented
   357   using integers in ML, the coercion function @{const "of_nat"} of type
   358   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   359   For the @{const "nat"} function for converting an integer to a natural
   360   number, we give a specific implementation using an ML function that
   361   returns its input value, provided that it is non-negative, and otherwise
   362   returns @{text "0"}.
   363 *}
   364 
(* int: an auxiliary alias for of_nat at type nat => int.  Target mappings
   below attach to it, and it is hidden at the end of this theory. *)
definition int :: "nat \<Rightarrow> int" where
  [code del]: "int = of_nat"

(* Numeral cases of int and nat: a negative int numeral yields 0. *)
lemma int_code' [code]:
  "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding int_nat_number_of [folded int_def] ..

lemma nat_code' [code]:
  "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding nat_number_of_def number_of_is_id neg_def by simp

(* of_nat at type nat => int is identified with int (code_unfold_post), and
   the recursive of_nat_aux form is replaced by int (code_unfold). *)
lemma of_nat_int [code_unfold_post]:
  "of_nat = int" by (simp add: int_def)

lemma of_nat_aux_int [code_unfold]:
  "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
  by (simp add: int_def Nat.of_nat_code)
   382 
(* SML and OCaml: nat and int share one integer representation, so int is
   the identity.  In these target strings _ is an argument placeholder,
   '_ a literal underscore and / a break hint. *)
code_const int
  (SML "_")
  (OCaml "_")

(* Codegen (consts_code): int is the identity; nat uses the attached ML
   function, which maps negative values to 0. *)
consts_code
  int ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if i < 0 then 0 else i;
*}

(* nat clamps negative values to 0: maximum with 0 in both targets. *)
code_const nat
  (SML "IntInf.max/ (/0,/ _)")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   397 
text {* For Haskell and Scala, things are slightly different again. *}
   399 
(* Haskell: toInteger unwraps a Nat; fromInteger clamps negative values
   to 0 (see the Num instance in the Haskell include).  Scala: as_BigInt
   unwraps a Nat; Nat.Nat applied to a BigInt clamps negative values to 0
   via the apply method of the Nat companion object. *)
code_const int and nat
  (Haskell "toInteger" and "fromInteger")
  (Scala "!_.as'_BigInt" and "!Nat.Nat((_))")
   403 
   404 text {* Conversion from and to indices. *}
   405 
(* Conversions between nat and the index type of Code_Numeral.  SML
   converts between IntInf.int and int, OCaml uses the identity, and
   Haskell uses the Enum methods of Nat.  Scala uses as_BigInt and the Nat
   companion's apply.
   NOTE(review): these mappings presuppose the target representation of
   the index type, which is declared elsewhere (e.g. plain int in SML and
   the same type as nat in OCaml) -- confirm against Code_Numeral. *)
code_const Code_Numeral.of_nat
  (SML "IntInf.toInt")
  (OCaml "_")
  (Haskell "fromEnum")
  (Scala "!_.as'_BigInt")

code_const Code_Numeral.nat_of
  (SML "IntInf.fromInt")
  (OCaml "_")
  (Haskell "toEnum")
  (Scala "!Nat.Nat((_))")
   417 
   418 text {* Using target language arithmetic operations whenever appropriate *}
   419 
(* Native arithmetic on nat.  Subtraction has no SML/OCaml mapping, so there
   it is computed by the code equation minus_nat_code (via int and nat,
   truncating at 0).  For Haskell and Scala, the Nat types of the includes
   are required to implement truncated subtraction. *)
code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.+ ((_), (_))")
  (OCaml "Big'_int.add'_big'_int")
  (Haskell infixl 6 "+")
  (Scala infixl 7 "+")

code_const "op - \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (Haskell infixl 6 "-")
  (Scala infixl 7 "-")

code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.* ((_), (_))")
  (OCaml "Big'_int.mult'_big'_int")
  (Haskell infixl 7 "*")
  (Scala infixl 8 "*")
   435 
(* The divmod_nat code equation calls divmod_aux only for m ~= 0, so the
   native SML/OCaml primitives never see a zero divisor through it.
   divmod_nat itself is mapped directly for Haskell and Scala, bypassing
   that case split.  divMod of the Haskell Nat type and /% of the Scala Nat
   type must therefore return (0, n) for a zero divisor themselves (the
   Scala /% does; Integer quotRem alone raises a divide-by-zero exception). *)
code_const divmod_aux
  (SML "IntInf.divMod/ ((_),/ (_))")
  (OCaml "Big'_int.quomod'_big'_int")
  (Haskell "divMod")
  (Scala infixl 8 "/%")

code_const divmod_nat
  (Haskell "divMod")
  (Scala infixl 8 "/%")
   445 
(* Equality and orderings on nat, mapped to each target's native
   comparisons on the underlying integer representation. *)
code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "!((_ : IntInf.int) = _)")
  (OCaml "Big'_int.eq'_big'_int")
  (Haskell infixl 4 "==")
  (Scala infixl 5 "==")

code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.<= ((_), (_))")
  (OCaml "Big'_int.le'_big'_int")
  (Haskell infix 4 "<=")
  (Scala infixl 4 "<=")

code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.< ((_), (_))")
  (OCaml "Big'_int.lt'_big'_int")
  (Haskell infix 4 "<")
  (Scala infixl 4 "<")
   463 
(* Codegen (consts_code): nat literals and operations as ML integer
   expressions; Suc becomes addition of 1. *)
consts_code
  "0::nat"                     ("0")
  "1::nat"                     ("1")
  Suc                          ("(_ +/ 1)")
  "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
  "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
  "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
  "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   472 
   473 
   474 text {* Evaluation *}
   475 
(* NOTE(review): declaring this trivial equation as [code, code del]
   presumably clears any code equations for term_of at type nat, so that
   the code_const mapping below is used instead -- confirm. *)
lemma [code, code del]:
  "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..

(* SML: reify a nat value as a HOL numeral term of type nat. *)
code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
  (SML "HOLogic.mk'_number/ HOLogic.natT")
   481 
   482 
   483 text {* Module names *}
   484 
(* Code generated from this theory is placed in a module named Arith for
   SML, OCaml and Haskell; no renaming is declared here for Scala. *)
code_modulename SML
  Efficient_Nat Arith

code_modulename OCaml
  Efficient_Nat Arith

code_modulename Haskell
  Efficient_Nat Arith

(* Hide the auxiliary constant int defined above. *)
hide const int
   495 
   496 end