src/HOL/Library/Efficient_Nat.thy
author wenzelm
Sun Mar 07 12:19:47 2010 +0100 (2010-03-07)
changeset 35625 9c818cab0dd0
parent 34944 970e1466028d
child 35689 c3bef0c972d7
permissions -rw-r--r--
modernized structure Object_Logic;
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Integer Main
     9 begin
    10 
    11 text {*
    12   When generating code for functions on natural numbers, the
    13   canonical representation using @{term "0::nat"} and
    14   @{term "Suc"} is unsuitable for computations involving large
    15   numbers.  The efficiency of the generated code can be improved
    16   drastically by implementing natural numbers by target-language
    17   integers.  To do this, just include this theory.
    18 *}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
       (* For code generation nat is treated as a datatype whose only
          constructor is number_of_nat, i.e. every nat is a numeral.
          0, 1 and Suc therefore get the code equations below. *)
    27 code_datatype number_nat_inst.number_of_nat
    28 
       (* 0 and 1 are expressed as the numerals 0 and 1. *)
    29 lemma zero_nat_code [code, code_unfold_post]:
    30   "0 = (Numeral0 :: nat)"
    31   by simp
    32 
    33 lemma one_nat_code [code, code_unfold_post]:
    34   "1 = (Numeral1 :: nat)"
    35   by simp
    36 
       (* Suc is no longer a code constructor: it becomes "+ 1". *)
    37 lemma Suc_code [code]:
    38   "Suc n = n + 1"
    39   by simp
    40 
       (* Addition, subtraction and multiplication are computed on int and
          converted back with nat, which maps negative integers to 0; this
          yields the truncated subtraction of nat.  Some targets override
          these equations with the code_const declarations further down. *)
    41 lemma plus_nat_code [code]:
    42   "n + m = nat (of_nat n + of_nat m)"
    43   by simp
    44 
    45 lemma minus_nat_code [code]:
    46   "n - m = nat (of_nat n - of_nat m)"
    47   by simp
    48 
    49 lemma times_nat_code [code]:
    50   "n * m = nat (of_nat n * of_nat m)"
    51   unfolding of_nat_mult [symmetric] by simp
    52 
    53 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    54   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    55 
       (* divmod_aux is divmod_nat under another name.  The definition itself
          provides no code equation ([code del]).  It is meant to be reached
          through the divmod_nat equation below, after the m = 0 case has been
          split off, so its target-language mapping (code_const divmod_aux
          further down) only sees non-zero divisors on that path. *)
    56 definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
    57   [code del]: "divmod_aux = divmod_nat"
    58 
       (* Division by zero yields (0, n), as in HOL; otherwise delegate. *)
    59 lemma [code]:
    60   "divmod_nat n m = (if m = 0 then (0, n) else divmod_aux n m)"
    61   unfolding divmod_aux_def divmod_nat_div_mod by simp
    62 
       (* Generic code equation: quotient and remainder computed on int. *)
    63 lemma divmod_aux_code [code]:
    64   "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
    65   unfolding divmod_aux_def divmod_nat_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    66 
       (* Equality on nat is decided on int. *)
    67 lemma eq_nat_code [code]:
    68   "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
    69   by (simp add: eq)
    70 
       (* Additional equation registered only for normalization by evaluation
          ([code nbe]): reflexive equality reduces to True directly
          (presumably useful for arguments that are not numerals). *)
    71 lemma eq_nat_refl [code nbe]:
    72   "eq_class.eq (n::nat) n \<longleftrightarrow> True"
    73   by (rule HOL.eq_refl)
    74 
       (* The orderings on nat are likewise decided on int. *)
    75 lemma less_eq_nat_code [code]:
    76   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
    77   by simp
    78 
    79 lemma less_nat_code [code]:
    80   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
    81   by simp
    82 
    83 subsection {* Case analysis *}
    84 
    85 text {*
    86   Case analysis on natural numbers is rephrased using a conditional
    87   expression:
    88 *}
    89 
       (* Since Suc is not a code constructor, pattern matching via nat_case is
          replaced by a zero test followed by n - 1.  As [code_unfold] the
          equation is also used to rewrite nat_case during preprocessing. *)
    90 lemma [code, code_unfold]:
    91   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    92   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    93 
    94 
    95 subsection {* Preprocessors *}
    96 
    97 text {*
    98   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
    99   a constructor term. Therefore, all occurrences of this term in a position
   100   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   101   equation or in the arguments of an inductive relation in an introduction
   102   rule) must be eliminated.
   103   This can be accomplished by applying the following transformation rules:
   104 *}
   105 
       (* Suc_if_eq: an equation for f (Suc n) and one for f 0 are merged into
          a single equation for f n with an explicit zero test.  Instantiated
          by the code equation preprocessor (remove_suc) below. *)
   106 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
   107   f n \<equiv> if n = 0 then g else h (n - 1)"
   108   by (rule eq_reflection) (cases n, simp_all)
   109 
       (* Suc_clause: a statement about n and Suc n is restated about n - 1 and
          n under the side condition n ~= 0.  Instantiated by the
          introduction-rule preprocessor (remove_suc_clause) below. *)
   110 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   111   by (cases n) simp_all
   112 
   113 text {*
   114   The rules above are built into a preprocessor that is plugged into
   115   the code generator. Since the preprocessor for introduction rules
   116   does not know anything about modes, some of the modes that worked
   117   for the canonical representation of natural numbers may no longer work.
   118 *}
   119 
   120 (*<*)
   121 setup {*
   122 let
   123 
       (* remove_suc thy thms: one elimination step on a list of code equations
          (meta-equations "lhs == rhs").  It looks for an occurrence of "Suc v",
          v a schematic variable, on the left-hand side of some equation th and
          instantiates Suc_if_eq with it.  The instantiated rule still needs an
          equation for the 0 case; every equation of thms that resolves against
          that premise is combined with it.  Result: SOME (thms without th and
          without the consumed equations, followed by the combined ones), or
          NONE if no occurrence could be eliminated. *)
   124 fun remove_suc thy thms =
   125   let
           (* fresh schematic variable of type nat, named after "n" and
              distinct from the variable names occurring in thms *)
   126     val vname = Name.variant (map fst
   127       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
   128     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
           (* left- and right-hand side of a meta-equation theorem *)
   129     fun lhs_of th = snd (Thm.dest_comb
   130       (fst (Thm.dest_comb (cprop_of th))));
   131     fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
           (* all occurrences of "Suc v" (v schematic) reachable through
              applications in ct (abstractions are not entered).  Each result
              pairs ct with that occurrence replaced by cv, and the variable v. *)
   132     fun find_vars ct = (case term_of ct of
   133         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   134       | _ $ _ =>
   135         let val (ct1, ct2) = Thm.dest_comb ct
   136         in 
   137           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   138           map (apfst (Thm.capply ct1)) (find_vars ct2)
   139         end
   140       | _ => []);
           (* candidate (equation, occurrence) pairs, in list order *)
   141     val eqs = maps
   142       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
           (* Instantiate Suc_if_eq with f := %cv. ct, h := %v. rhs of th,
              n := v, beta-reduce, and discharge the first premise with th
              generalized over v.  The remaining premise "f 0 == g" is then
              resolved against each equation of thms; RS raises THM for the
              ones that do not fit, and those are skipped. *)
   143     fun mk_thms (th, (ct, cv')) =
   144       let
   145         val th' =
   146           Thm.implies_elim
   147            (Conv.fconv_rule (Thm.beta_conversion true)
   148              (Drule.instantiate'
   149                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   150                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   151                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   152       in
   153         case map_filter (fn th'' =>
   154             SOME (th'', singleton
   155               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   156           handle THM _ => NONE) thms of
   157             [] => NONE
   158           | thps =>
   159               let val (ths1, ths2) = split_list thps
   160               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   161       end
   162   in get_first mk_thms eqs end;
   163 
       (* Code equation preprocessor: acts only if every theorem is a
          meta-equation and some left-hand side mentions Suc.  It then applies
          remove_suc as long as that succeeds (perhaps_loop) and normalizes the
          variable indices of the result; otherwise it returns NONE. *)
   164 fun eqn_suc_base_preproc thy thms =
   165   let
   166     val dest = fst o Logic.dest_equals o prop_of;
   167     val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   168   in
   169     if forall (can dest) thms andalso exists (contains_suc o dest) thms
   170       then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
   171        else NONE
   172   end;
   173 
   174 val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;
   175 
       (* Introduction-rule counterpart of remove_suc.  Picks the first theorem
          whose atomized statement contains "Suc v" (v schematic) and
          instantiates Suc_clause with P := %v x. statement (that Suc v
          abstracted as x) and n := v.  This restates the rule with Suc v
          replaced by v and the remaining v by v - 1, under the extra premise
          v ~= 0.  Every theorem with the same statement is replaced by the
          result, and the function recurses until no theorem contains such an
          occurrence. *)
   176 fun remove_suc_clause thy thms =
   177   let
   180     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   181       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   182       | find_var _ = NONE;
   183     fun find_thm th =
   184       let val th' = Conv.fconv_rule Object_Logic.atomize th
   185       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   186   in
   187     case get_first find_thm thms of
   188       NONE => thms
   189     | SOME ((th, th'), (Sucv, v)) =>
   190         let
   191           val cert = cterm_of (Thm.theory_of_thm th);
   192           val th'' = Object_Logic.rulify (Thm.implies_elim
   193             (Conv.fconv_rule (Thm.beta_conversion true)
   194               (Drule.instantiate' []
   195                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   196                    abstract_over (Sucv,
   197                      HOLogic.dest_Trueprop (prop_of th')))))),
   198                  SOME (cert v)] @{thm Suc_clause}))
   199             (Thm.forall_intr (cert v) th'))
   200         in
   201           remove_suc_clause thy (map (fn th''' =>
   202             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   203         end
   204   end;
   205 
       (* Preprocessor for introduction rules (registered with the old Codegen
          below).  It acts only if every conclusion is a membership "t : S" and,
          for some theorem, the element t of a membership conclusion or premise
          mentions Suc; otherwise the theorems are returned unchanged. *)
   206 fun clause_suc_preproc thy ths =
   207   let
   208     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   209   in
   210     if forall (can (dest o concl_of)) ths andalso
   211       exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
   212         (map_filter (try dest) (concl_of th :: prems_of th))) ths
   213     then remove_suc_clause thy ths else ths
   214   end;
   215 in
   216 
       (* register both preprocessors *)
   217   Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)
   218   #> Codegen.add_preprocessor clause_suc_preproc
   219 
   220 end;
   221 *}
   222 (*>*)
   223 
   224 
   225 subsection {* Target language setup *}
   226 
   227 text {*
   228   For ML, we map @{typ nat} to target language integers, where we
   229   ensure that values are always non-negative.
   230 *}
   231 
       (* New code generator: SML and OCaml represent nat by their
          arbitrary-precision integer types; non-negativity is maintained by
          the nat mappings further down. *)
   232 code_type nat
   233   (SML "IntInf.int")
   234   (OCaml "Big'_int.big'_int")
   235 
       (* Old code generator (types_code): nat is represented by ML "int".
          term_of_nat turns a value back into a numeral term of type nat;
          gen_nat draws a random value between 0 and i for testing. *)
   236 types_code
   237   nat ("int")
   238 attach (term_of) {*
   239 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   240 *}
   241 attach (test) {*
   242 fun gen_nat i =
   243   let val n = random_range 0 i
   244   in (n, fn () => term_of_nat n) end;
   245 *}
   246 
   247 text {*
   248   For Haskell and Scala we define our own @{typ nat} type.  The reason
   249   is that we have to distinguish type class instances for @{typ nat}
   250   and @{typ int}.
   251 *}
   252 
   253 code_include Haskell "Nat" {*
   254 newtype Nat = Nat Integer deriving (Show, Eq);
       -- Natural numbers as a wrapper around Integer.  fromInteger clamps
       -- negative values to 0, so the wrapped value stays non-negative.
   255 
   256 instance Num Nat where {
   257   fromInteger k = Nat (if k >= 0 then k else 0);
   258   Nat n + Nat m = Nat (n + m);
         -- truncated subtraction as on HOL nat: the difference goes through
         -- fromInteger, which clamps it at 0
   259   Nat n - Nat m = fromInteger (n - m);
   260   Nat n * Nat m = Nat (n * m);
   261   abs n = n;
   262   signum _ = 1;
   263   negate n = error "negate Nat";
   264 };
   265 
   266 instance Ord Nat where {
   267   Nat n <= Nat m = n <= m;
   268   Nat n < Nat m = n < m;
   269 };
   270 
   271 instance Real Nat where {
   272   toRational (Nat n) = toRational n;
   273 };
   274 
   275 instance Enum Nat where {
   276   toEnum k = fromInteger (toEnum k);
   277   fromEnum (Nat n) = fromEnum n;
   278 };
   279 
   280 instance Integral Nat where {
   281   toInteger (Nat n) = n;
         -- Both operands are non-negative, so divMod and quotRem coincide.
         -- divmod_nat is mapped straight to divMod, so a zero divisor must
         -- give (0, n) as in HOL instead of failing with division by zero.
   282   divMod n m = quotRem n m;
   283   quotRem (Nat n) (Nat m)
           | m == 0 = (Nat 0, Nat n)
           | otherwise = (Nat k, Nat l) where (k, l) = quotRem n m;
   284 };
   285 *}
   286 
       (* Keep the identifier Nat free in generated Haskell code: it is used by
          the Nat module included above. *)
   287 code_reserved Haskell Nat
   288 
   289 code_include Scala "Nat" {*
   290 import scala.Math
   291 
       /** Constructors for Nat; negative arguments are clamped to 0. */
   292 object Nat {
   293 
   294   def apply(numeral: BigInt): Nat = new Nat(numeral max 0)
   295   def apply(numeral: Int): Nat = Nat(BigInt(numeral))
   296   def apply(numeral: String): Nat = Nat(BigInt(numeral))
   297 
   298 }
   299 
       /** Natural numbers backed by a non-negative BigInt.  Values are built by
        *  Nat.apply (which clamps at 0) or by operations whose result cannot be
        *  negative. */
   300 class Nat private(private val value: BigInt) {
   301 
   302   override def hashCode(): Int = this.value.hashCode()
   303 
   304   override def equals(that: Any): Boolean = that match {
   305     case that: Nat => this equals that
   306     case _ => false
   307   }
   308 
   309   override def toString(): String = this.value.toString
   310 
   311   def equals(that: Nat): Boolean = this.value == that.value
   312 
   313   def as_BigInt: BigInt = this.value
         /** Conversion to Int; raises an error if the value lies outside the
          *  range [MIN_INT, MAX_INT]. */
   314   def as_Int: Int = if (this.value >= Math.MIN_INT && this.value <= Math.MAX_INT)
   315       this.value.intValue
   316     else error("Int value too big:" + this.value.toString)
   317 
   318   def +(that: Nat): Nat = new Nat(this.value + that.value)
         /** Truncated subtraction as on HOL nat: Nat.apply clamps a negative
          *  difference to 0. */
   319   def -(that: Nat): Nat = Nat(this.value - that.value)
   320   def *(that: Nat): Nat = new Nat(this.value * that.value)
   321 
         /** Quotient and remainder; a zero divisor yields (0, this) as in HOL. */
   322   def /%(that: Nat): (Nat, Nat) = if (that.value == 0) (new Nat(0), this)
   323     else {
   324       val (k, l) = this.value /% that.value
   325       (new Nat(k), new Nat(l))
   326     }
   327 
   328   def <=(that: Nat): Boolean = this.value <= that.value
   329 
   330   def <(that: Nat): Boolean = this.value < that.value
   331 
   332 }
   333 *}
   334 
       (* Keep the identifier Nat free in generated Scala code as well. *)
   335 code_reserved Scala Nat
   336 
       (* Haskell and Scala use the Nat types of the included "Nat" modules. *)
   337 code_type nat
   338   (Haskell "Nat.Nat")
   339   (Scala "Nat.Nat")
   340 
       (* No Haskell eq instance is generated for nat: the included newtype
          already derives Eq. *)
   341 code_instance nat :: eq
   342   (Haskell -)
   343 
   344 text {*
   345   Natural numerals.
   346 *}
   347 
       (* Relates nat applied to an integer numeral with the nat numeral
          constructor number_of_nat; registered as [code_unfold_post]. *)
   348 lemma [code_unfold_post]:
   349   "nat (number_of i) = number_nat_inst.number_of_nat i"
   350   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   351   by (simp add: number_nat_inst.number_of_nat)
   352 
   353 setup {*
       (* Print number_of_nat numerals with literal_positive_numeral in every
          target; fold registers them in the order SML, OCaml, Haskell, Scala. *)
   354   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   355     false Code_Printer.literal_positive_numeral) ["SML", "OCaml", "Haskell", "Scala"]
   358 *}
   359 
   360 text {*
   361   Since natural numbers are implemented
   362   using integers in ML, the coercion function @{const "of_nat"} of type
   363   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   364   For the @{const "nat"} function for converting an integer to a natural
   365   number, we give a specific implementation using an ML function that
   366   returns its input value, provided that it is non-negative, and otherwise
   367   returns @{text "0"}.
   368 *}
   369 
       (* int is a private copy of of_nat on nat => int.  The definition gives
          no code equation ([code del]), so the coercion can be mapped to
          target code on its own; the constant is hidden at the end of the
          theory. *)
   370 definition int :: "nat \<Rightarrow> int" where
   371   [code del]: "int = of_nat"
   372 
       (* On numerals, int and nat are computed by a sign test. *)
   373 lemma int_code' [code]:
   374   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   375   unfolding int_nat_number_of [folded int_def] ..
   376 
   377 lemma nat_code' [code]:
   378   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   379   unfolding nat_number_of_def number_of_is_id neg_def by simp
   380 
       (* of_nat coincides with int ([code_unfold_post]); the of_nat_aux
          instance below is unfolded into int during preprocessing. *)
   381 lemma of_nat_int [code_unfold_post]:
   382   "of_nat = int" by (simp add: int_def)
   383 
   384 lemma of_nat_aux_int [code_unfold]:
   385   "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
   386   by (simp add: int_def Nat.of_nat_code)
   387 
       (* In SML and OCaml nat and int share one representation, so int is the
          identity. *)
   388 code_const int
   389   (SML "_")
   390   (OCaml "_")
   391 
       (* Old code generator: int is the identity; nat maps negative values
          to 0. *)
   392 consts_code
   393   int ("(_)")
   394   nat ("\<module>nat")
   395 attach {*
   396 fun nat i = if i < 0 then 0 else i;
   397 *}
   398 
       (* New code generator: nat i is the maximum of 0 and i. *)
   399 code_const nat
   400   (SML "IntInf.max/ (/0,/ _)")
   401   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   402 
   403 text {* For Haskell and Scala, things are slightly different again. *}
   404 
       (* Haskell: toInteger unwraps and fromInteger wraps (clamping negatives
          to 0).  Scala: as_BigInt unwraps and Nat.apply wraps (clamping via
          "max 0"). *)
   405 code_const int and nat
   406   (Haskell "toInteger" and "fromInteger")
   407   (Scala "!_.as'_BigInt" and "!Nat.Nat((_))")
   408 
   409 text {* Conversion from and to indices. *}
   410 
       (* nat => code_numeral: SML converts IntInf to int, OCaml keeps the
          value, Haskell uses the Enum instance of Nat, Scala uses as_Int. *)
   411 code_const Code_Numeral.of_nat
   412   (SML "IntInf.toInt")
   413   (OCaml "_")
   414   (Haskell "fromEnum")
   415   (Scala "!_.as'_Int")
   416 
       (* code_numeral => nat: the inverse conversions. *)
   417 code_const Code_Numeral.nat_of
   418   (SML "IntInf.fromInt")
   419   (OCaml "_")
   420   (Haskell "toEnum")
   421   (Scala "!Nat.Nat((_))")
   422 
   423 text {* Using target language arithmetic operations whenever appropriate *}
   424 
       (* Addition maps directly to target integer addition. *)
   425 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   426   (SML "IntInf.+ ((_), (_))")
   427   (OCaml "Big'_int.add'_big'_int")
   428   (Haskell infixl 6 "+")
   429   (Scala infixl 7 "+")
   430 
       (* Subtraction is mapped only for Haskell and Scala, whose Nat "-"
          operations (included above) are expected to truncate at 0.  Other
          targets use the code equation minus_nat_code. *)
   431 code_const "op - \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   432   (Haskell infixl 6 "-")
   433   (Scala infixl 7 "-")
   434 
   435 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   436   (SML "IntInf.* ((_), (_))")
   437   (OCaml "Big'_int.mult'_big'_int")
   438   (Haskell infixl 7 "*")
   439   (Scala infixl 8 "*")
   440 
       (* Division with remainder for the auxiliary divmod_aux. *)
   441 code_const divmod_aux
   442   (SML "IntInf.divMod/ ((_),/ (_))")
   443   (OCaml "Big'_int.quomod'_big'_int")
   444   (Haskell "divMod")
   445   (Scala infixl 8 "/%")
   446 
       (* For Haskell and Scala divmod_nat itself is mapped, bypassing its
          m = 0 code equation; the zero-divisor case is then handled by the
          included Nat implementations. *)
   447 code_const divmod_nat
   448   (Haskell "divMod")
   449   (Scala infixl 8 "/%")
   450 
       (* Equality and orderings. *)
   451 code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   452   (SML "!((_ : IntInf.int) = _)")
   453   (OCaml "Big'_int.eq'_big'_int")
   454   (Haskell infixl 4 "==")
   455   (Scala infixl 5 "==")
   456 
   457 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   458   (SML "IntInf.<= ((_), (_))")
   459   (OCaml "Big'_int.le'_big'_int")
   460   (Haskell infix 4 "<=")
   461   (Scala infixl 4 "<=")
   462 
   463 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   464   (SML "IntInf.< ((_), (_))")
   465   (OCaml "Big'_int.lt'_big'_int")
   466   (Haskell infix 4 "<")
   467   (Scala infixl 4 "<")
   468 
       (* Corresponding mappings for the old code generator. *)
   469 consts_code
   470   "0::nat"                     ("0")
   471   "1::nat"                     ("1")
   472   Suc                          ("(_ +/ 1)")
   473   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
   474   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
   475   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   476   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   477 
   478 
   479 text {* Evaluation *}
   480 
       (* The trivial equation is declared [code] and immediately removed again
          with [code del], presumably to clear the code equations for term_of
          on nat so that the SML mapping below is used. *)
   481 lemma [code, code del]:
   482   "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..
   483 
       (* SML evaluation builds the numeral term of type nat directly. *)
   484 code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
   485   (SML "HOLogic.mk'_number/ HOLogic.natT")
   486 
   487 
   488 text {* Module names *}
   489 
       (* Code generated from this theory is placed in modules named Arith. *)
   490 code_modulename SML
   491   Efficient_Nat Arith
   492 
   493 code_modulename OCaml
   494   Efficient_Nat Arith
   495 
   496 code_modulename Haskell
   497   Efficient_Nat Arith
   498 
       (* int is only an auxiliary constant of this theory. *)
   499 hide const int
   500 
   501 end