src/HOL/Library/Efficient_Nat.thy
author haftmann
Fri Jul 23 10:58:13 2010 +0200 (2010-07-23)
changeset 37947 844977c7abeb
parent 37892 3d8857f42a64
child 37958 9728342bcd56
permissions -rw-r--r--
avoid unreliable Haskell Int type
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Integer Main
     9 begin
    10 
    11 text {*
    12   When generating code for functions on natural numbers, the
    13   canonical representation using @{term "0::nat"} and
    14   @{term Suc} is unsuitable for computations involving large
    15   numbers.  The efficiency of the generated code can be improved
    16   drastically by implementing natural numbers by target-language
    17   integers.  To do this, just include this theory.
    18 *}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
       (* For code generation nat values are built from the single numeral
          constructor number_nat_inst.number_of_nat; the equations below
          express the basic nat operations via that constructor and via int
          arithmetic. *)
    27 code_datatype number_nat_inst.number_of_nat
    28 
       (* 0 and 1 as numerals.  They are also declared code_unfold_post,
          presumably so that numerals are shown as 0 and 1 again -- TODO
          confirm. *)
    29 lemma zero_nat_code [code, code_unfold_post]:
    30   "0 = (Numeral0 :: nat)"
    31   by simp
    32 
    33 lemma one_nat_code [code, code_unfold_post]:
    34   "1 = (Numeral1 :: nat)"
    35   by simp
    36 
       (* Suc is no longer a constructor in generated code: it adds 1. *)
    37 lemma Suc_code [code]:
    38   "Suc n = n + 1"
    39   by simp
    40 
       (* +, - and * are computed on int and mapped back with nat.  For
          subtraction, nat turns a negative difference into 0, which is
          exactly truncated nat subtraction. *)
    41 lemma plus_nat_code [code]:
    42   "n + m = nat (of_nat n + of_nat m)"
    43   by simp
    44 
    45 lemma minus_nat_code [code]:
    46   "n - m = nat (of_nat n - of_nat m)"
    47   by simp
    48 
    49 lemma times_nat_code [code]:
    50   "n * m = nat (of_nat n * of_nat m)"
    51   unfolding of_nat_mult [symmetric] by simp
    52 
    53 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    54   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    55 
       (* divmod_aux is a copy of divmod_nat whose own code equation is
          deleted ([code del]).  The code equation of divmod_nat below only
          calls it for a non-zero divisor, so when divmod_nat is generated
          from that equation a target mapping of divmod_aux never sees a
          zero divisor. *)
    56 definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
    57   [code del]: "divmod_aux = divmod_nat"
    58 
       (* Division by zero is handled here: divmod_nat n 0 = (0, n). *)
    59 lemma [code]:
    60   "divmod_nat n m = (if m = 0 then (0, n) else divmod_aux n m)"
    61   unfolding divmod_aux_def divmod_nat_div_mod by simp
    62 
       (* Quotient and remainder computed on int and converted back. *)
    63 lemma divmod_aux_code [code]:
    64   "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
    65   unfolding divmod_aux_def divmod_nat_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    66 
       (* Equality and the orderings on nat are decided on int. *)
    67 lemma eq_nat_code [code]:
    68   "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
    69   by (simp add: eq)
    70 
       (* Reflexivity of equality, registered with the nbe variant of the
          code attribute only. *)
    71 lemma eq_nat_refl [code nbe]:
    72   "eq_class.eq (n::nat) n \<longleftrightarrow> True"
    73   by (rule HOL.eq_refl)
    74 
    75 lemma less_eq_nat_code [code]:
    76   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
    77   by simp
    78 
    79 lemma less_nat_code [code]:
    80   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
    81   by simp
    82 
    83 subsection {* Case analysis *}
    84 
    85 text {*
    86   Case analysis on natural numbers is rephrased using a conditional
    87   expression:
    88 *}
    89 
       (* Case analysis on nat as a test against 0 plus the predecessor
          n - 1; registered both as code equation and as code_unfold rule. *)
    90 lemma [code, code_unfold]:
    91   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    92   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    93 
    94 
    95 subsection {* Preprocessors *}
    96 
    97 text {*
    98   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
    99   a constructor term. Therefore, all occurrences of this term in a position
   100   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   101   equation or in the arguments of an inductive relation in an introduction
   102   rule) must be eliminated.
   103   This can be accomplished by applying the following transformation rules:
   104 *}
   105 
       (* Suc_if_eq: equations for f (Suc n) and f 0 combine into a single
          equation for f n that tests n = 0 and uses n - 1.  Instantiated by
          the equation preprocessor (remove_suc) below. *)
   106 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
   107   f n \<equiv> if n = 0 then g else h (n - 1)"
   108   by (rule eq_reflection) (cases n, simp_all)
   109 
       (* Suc_clause: a statement about Suc n becomes one about n that
          requires n ~= 0 and uses n - 1.  Instantiated by the clause
          preprocessor (remove_suc_clause) below. *)
   110 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   111   by (cases n) simp_all
   112 
   113 text {*
   114   The rules above are built into a preprocessor that is plugged into
   115   the code generator. Since the preprocessor for introduction rules
   116   does not know anything about modes, some of the modes that worked
   117   for the canonical representation of natural numbers may no longer work.
   118 *}
   119 
   120 (*<*)
   121 setup {*
   122 let
   123 
       (* Try to eliminate one pattern Suc ?x (with ?x schematic) from the
          left-hand sides of the code equations thms.  Candidate occurrences
          are tried in turn (get_first); the first one that succeeds yields
          SOME of the updated equation list, and NONE means that no
          occurrence could be eliminated. *)
   124 fun remove_suc thy thms =
   125   let
       (* Fresh schematic nat variable: a variant of "n" avoiding all
          schematic variable names of thms.  It marks the hole left by the
          eliminated Suc ?x. *)
   126     val vname = Name.variant (map fst
   127       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
   128     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
       (* Left- and right-hand side of a meta-equation lhs == rhs. *)
   129     fun lhs_of th = snd (Thm.dest_comb
   130       (fst (Thm.dest_comb (cprop_of th))));
   131     fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
       (* For every subterm Suc ?x of ct, return the pair
          (ct with that subterm replaced by cv, ?x). *)
   132     fun find_vars ct = (case term_of ct of
   133         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   134       | _ $ _ =>
   135         let val (ct1, ct2) = Thm.dest_comb ct
   136         in 
   137           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   138           map (apfst (Thm.capply ct1)) (find_vars ct2)
   139         end
   140       | _ => []);
       (* All candidates: each equation paired with each Suc pattern found in
          its left-hand side. *)
   141     val eqs = maps
   142       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
       (* Try one candidate (th, (context, ?x)).  th' instantiates Suc_if_eq
          -- presumably f with the context abstracted over cv, h with the
          right-hand side of th abstracted over ?x, g left open and n with
          ?x; confirm against the variable order of Suc_if_eq -- and
          discharges its first premise with th generalised over ?x.  Every
          equation of thms that resolves (RS, inside Variable.trade) against
          the remaining premise is consumed and replaced by the resolvent;
          THM failures are skipped.  If no equation resolves, the candidate
          fails with NONE. *)
   143     fun mk_thms (th, (ct, cv')) =
   144       let
   145         val th' =
   146           Thm.implies_elim
   147            (Conv.fconv_rule (Thm.beta_conversion true)
   148              (Drule.instantiate'
   149                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   150                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   151                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   152       in
   153         case map_filter (fn th'' =>
   154             SOME (th'', singleton
   155               (Variable.trade (K (fn [th'''] => [th''' RS th']))
   156                 (Variable.global_thm_context th'')) th'')
   157           handle THM _ => NONE) thms of
   158             [] => NONE
   159           | thps =>
   160               let val (ths1, ths2) = split_list thps
       (* Drop th and the consumed equations, append the new ones. *)
   161               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   162       end
   163   in get_first mk_thms eqs end;
   164 
   165 fun eqn_suc_base_preproc thy thms =
   166   let
   167     val dest = fst o Logic.dest_equals o prop_of;
   168     val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   169   in
   170     if forall (can dest) thms andalso exists (contains_suc o dest) thms
   171       then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
   172        else NONE
   173   end;
   174 
   175 val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;
   176 
   177 fun remove_suc_clause thy thms =
   178   let
   179     val vname = Name.variant (map fst
   180       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
   181     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   182       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   183       | find_var _ = NONE;
   184     fun find_thm th =
   185       let val th' = Conv.fconv_rule Object_Logic.atomize th
   186       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   187   in
   188     case get_first find_thm thms of
   189       NONE => thms
   190     | SOME ((th, th'), (Sucv, v)) =>
   191         let
   192           val cert = cterm_of (Thm.theory_of_thm th);
   193           val th'' = Object_Logic.rulify (Thm.implies_elim
   194             (Conv.fconv_rule (Thm.beta_conversion true)
   195               (Drule.instantiate' []
   196                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   197                    abstract_over (Sucv,
   198                      HOLogic.dest_Trueprop (prop_of th')))))),
   199                  SOME (cert v)] @{thm Suc_clause}))
   200             (Thm.forall_intr (cert v) th'))
   201         in
   202           remove_suc_clause thy (map (fn th''' =>
   203             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   204         end
   205   end;
   206 
   207 fun clause_suc_preproc thy ths =
   208   let
   209     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   210   in
   211     if forall (can (dest o concl_of)) ths andalso
   212       exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
   213         (map_filter (try dest) (concl_of th :: prems_of th))) ths
   214     then remove_suc_clause thy ths else ths
   215   end;
   216 in
   217 
       (* Register eqn_suc_preproc as the code preprocessor function
          transformer "eqn_Suc", and clause_suc_preproc as a preprocessor of
          the Codegen generator. *)
   218   Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)
   219   #> Codegen.add_preprocessor clause_suc_preproc
   220 
   221 end;
   222 *}
   223 (*>*)
   224 
   225 
   226 subsection {* Target language setup *}
   227 
   228 text {*
   229   For ML, we map @{typ nat} to target language integers, where we
   230   ensure that values are always non-negative.
   231 *}
   232 
       (* SML and OCaml: nat is represented by the target's arbitrary
          precision integer type (IntInf.int / Big_int.big_int). *)
   233 code_type nat
   234   (SML "IntInf.int")
   235   (OCaml "Big'_int.big'_int")
   236 
       (* Codegen setup: nat is printed as ML int.  The attached code defines
          term_of_nat, which reflects a value as a nat numeral term, and
          gen_nat, which draws n with random_range 0 i and returns it together
          with a thunk producing its term. *)
   237 types_code
   238   nat ("int")
   239 attach (term_of) {*
   240 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   241 *}
   242 attach (test) {*
   243 fun gen_nat i =
   244   let val n = random_range 0 i
   245   in (n, fn () => term_of_nat n) end;
   246 *}
   247 
   248 text {*
   249   For Haskell and Scala we define our own @{typ nat} type.  The reason
   250   is that we have to distinguish type class instances for @{typ nat}
   251   and @{typ int}.
   252 *}
   253 
   254 code_include Haskell "Nat" {*
       -- Natural numbers for Haskell: a newtype around Integer, so that nat
       -- and int get distinct type class instances.  The Num operations keep
       -- the wrapped Integer non-negative (fromInteger clamps at 0).
   255 newtype Nat = Nat Integer deriving (Eq, Show, Read);
   256 
   257 instance Num Nat where {
   258   fromInteger k = Nat (if k >= 0 then k else 0);
   259   Nat n + Nat m = Nat (n + m);
       -- Truncated subtraction: a negative difference becomes 0.
   260   Nat n - Nat m = fromInteger (n - m);
   261   Nat n * Nat m = Nat (n * m);
   262   abs n = n;
       -- 0 for zero, 1 otherwise (the wrapped value is non-negative).
   263   signum (Nat n) = Nat (signum n);
   264   negate n = error "negate Nat";
   265 };
   266 
   267 instance Ord Nat where {
   268   Nat n <= Nat m = n <= m;
   269   Nat n < Nat m = n < m;
   270 };
   271 
   272 instance Real Nat where {
   273   toRational (Nat n) = toRational n;
   274 };
   275 
   276 instance Enum Nat where {
   277   toEnum k = fromInteger (toEnum k);
   278   fromEnum (Nat n) = fromEnum n;
   279 };
   280 
   281 instance Integral Nat where {
   282   toInteger (Nat n) = n;
   283   divMod n m = quotRem n m;
       -- A zero divisor follows HOL (n div 0 = 0, n mod 0 = n) instead of
       -- raising an exception; needed because divmod_nat is mapped directly
       -- to divMod for Haskell.
   284   quotRem (Nat n) (Nat m)
           | (m == 0) = (Nat 0, Nat n)
           | otherwise = (Nat k, Nat l) where (k, l) = quotRem n m;
   285 };
   286 *}
   287 
   288 code_reserved Haskell Nat
   289 
   290 code_include Scala "Nat" {*
   291 import scala.Math
   292 
       // Natural numbers as a wrapper around BigInt.  The factory methods of
       // the companion object clamp negative arguments to 0.
   293 object Nat {
   294 
   295   def apply(numeral: BigInt): Nat = new Nat(numeral max 0)
   296   def apply(numeral: Int): Nat = Nat(BigInt(numeral))
   297   def apply(numeral: String): Nat = Nat(BigInt(numeral))
   298 
   299 }
   300 
   301 class Nat private(private val value: BigInt) {
   302 
   303   override def hashCode(): Int = this.value.hashCode()
   304 
   305   override def equals(that: Any): Boolean = that match {
   306     case that: Nat => this equals that
   307     case _ => false
   308   }
   309 
   310   override def toString(): String = this.value.toString
   311 
   312   def equals(that: Nat): Boolean = this.value == that.value
   313 
   314   def as_BigInt: BigInt = this.value
         // Conversion to a machine Int: values outside the Int range raise an
         // ArithmeticException instead of being silently truncated.
   315   def as_Int: Int = if (this.value >= Int.MinValue && this.value <= Int.MaxValue)
   316       this.value.intValue
   317     else throw new ArithmeticException("Int value out of range: " + this.value.toString)
   318 
   319   def +(that: Nat): Nat = new Nat(this.value + that.value)
         // Truncated subtraction: Nat.apply clamps a negative difference to 0.
   320   def -(that: Nat): Nat = Nat(this.value - that.value)
   321   def *(that: Nat): Nat = new Nat(this.value * that.value)
   322 
         // Quotient and remainder; a zero divisor yields (0, this), matching
         // n div 0 = 0 and n mod 0 = n in HOL.
   323   def /%(that: Nat): (Nat, Nat) = if (that.value == 0) (new Nat(0), this)
   324     else {
   325       val (k, l) = this.value /% that.value
   326       (new Nat(k), new Nat(l))
   327     }
   328 
   329   def <=(that: Nat): Boolean = this.value <= that.value
   330 
   331   def <(that: Nat): Boolean = this.value < that.value
   332 
   333 }
   334 *}
   335 
   336 code_reserved Scala Nat
   337 
       (* Haskell and Scala: nat is the Nat type of the included "Nat"
          module. *)
   338 code_type nat
   339   (Haskell "Nat.Nat")
   340   (Scala "Nat")
   341 
       (* No Eq instance is generated for nat in Haskell, presumably because
          Nat already derives Eq. *)
   342 code_instance nat :: eq
   343   (Haskell -)
   344 
   345 text {*
   346   Natural numerals.
   347 *}
   348 
   349 lemma [code_unfold_post]:
   350   "nat (number_of i) = number_nat_inst.number_of_nat i"
   351   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   352   by (simp add: number_nat_inst.number_of_nat)
   353 
   354 setup {*
   355   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   356     false Code_Printer.literal_positive_numeral) ["SML", "OCaml", "Haskell"]
   357   #> Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   358     false Code_Printer.literal_positive_numeral "Scala"
   359 *}
   360 
   361 text {*
   362   Since natural numbers are implemented
   363   using integers in ML, the coercion function @{const "of_nat"} of type
   364   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   365   For the @{const nat} function for converting an integer to a natural
   366   number, we give a specific implementation using an ML function that
   367   returns its input value, provided that it is non-negative, and otherwise
   368   returns @{text "0"}.
   369 *}
   370 
       (* int is a copy of of_nat :: nat => int whose own code equation is
          deleted; of_nat is rewritten to int (of_nat_int, of_nat_aux_int) so
          that the target mappings of int below apply.  The constant is
          hidden at the end of the theory. *)
   371 definition int :: "nat \<Rightarrow> int" where
   372   [code del]: "int = of_nat"
   373 
       (* Code equations on numerals: a negative numeral yields 0. *)
   374 lemma int_code' [code]:
   375   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   376   unfolding int_nat_number_of [folded int_def] ..
   377 
   378 lemma nat_code' [code]:
   379   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   380   unfolding nat_number_of_def number_of_is_id neg_def by simp
   381 
   382 lemma of_nat_int [code_unfold_post]:
   383   "of_nat = int" by (simp add: int_def)
   384 
   385 lemma of_nat_aux_int [code_unfold]:
   386   "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
   387   by (simp add: int_def Nat.of_nat_code)
   388 
       (* In SML and OCaml a nat already is a target integer: int is the
          identity. *)
   389 code_const int
   390   (SML "_")
   391   (OCaml "_")
   392 
       (* Codegen: int is the identity, nat is the attached ML function that
          maps negative inputs to 0. *)
   393 consts_code
   394   int ("(_)")
   395   nat ("\<module>nat")
   396 attach {*
   397 fun nat i = if i < 0 then 0 else i;
   398 *}
   399 
       (* nat :: int => nat as a maximum with 0 in SML and OCaml. *)
   400 code_const nat
   401   (SML "IntInf.max/ (/0,/ _)")
   402   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   403 
   404 text {* For Haskell and Scala, things are slightly different again. *}
   405 
       (* Haskell: via the Integral/Num methods of Nat (fromInteger clamps at
          0).  Scala: unwrap with as_BigInt, wrap with Nat (clamps at 0). *)
   406 code_const int and nat
   407   (Haskell "toInteger" and "fromInteger")
   408   (Scala "!_.as'_BigInt" and "Nat")
   409 
   410 text {* Conversion from and to indices. *}
   411 
       (* nat to code_numeral.  In SML, IntInf.toInt raises Overflow for
          values that do not fit into Int. *)
   412 code_const Code_Numeral.of_nat
   413   (SML "IntInf.toInt")
   414   (OCaml "_")
   415   (Haskell "toInteger")
   416   (Scala "!_.as'_Int")
   417 
       (* code_numeral to nat. *)
   418 code_const Code_Numeral.nat_of
   419   (SML "IntInf.fromInt")
   420   (OCaml "_")
   421   (Haskell "fromInteger")
   422   (Scala "Nat")
   423 
   424 text {* Using target language arithmetic operations whenever appropriate *}
   425 
   426 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   427   (SML "IntInf.+ ((_), (_))")
   428   (OCaml "Big'_int.add'_big'_int")
   429   (Haskell infixl 6 "+")
   430   (Scala infixl 7 "+")
   431 
       (* No SML/OCaml mapping for subtraction: there the truncating code
          equation minus_nat_code is used; Haskell and Scala Nat subtraction
          truncates itself. *)
   432 code_const "op - \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   433   (Haskell infixl 6 "-")
   434   (Scala infixl 7 "-")
   435 
   436 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   437   (SML "IntInf.* ((_), (_))")
   438   (OCaml "Big'_int.mult'_big'_int")
   439   (Haskell infixl 7 "*")
   440   (Scala infixl 8 "*")
   441 
   442 code_const divmod_aux
   443   (SML "IntInf.divMod/ ((_),/ (_))")
   444   (OCaml "Big'_int.quomod'_big'_int")
   445   (Haskell "divMod")
   446   (Scala infixl 8 "/%")
   447 
       (* For Haskell and Scala, divmod_nat is mapped directly, bypassing the
          zero-divisor case of its code equation; these target operations
          must return (0, n) for a zero divisor themselves. *)
   448 code_const divmod_nat
   449   (Haskell "divMod")
   450   (Scala infixl 8 "/%")
   451 
   452 code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   453   (SML "!((_ : IntInf.int) = _)")
   454   (OCaml "Big'_int.eq'_big'_int")
   455   (Haskell infixl 4 "==")
   456   (Scala infixl 5 "==")
   457 
   458 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   459   (SML "IntInf.<= ((_), (_))")
   460   (OCaml "Big'_int.le'_big'_int")
   461   (Haskell infix 4 "<=")
   462   (Scala infixl 4 "<=")
   463 
   464 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   465   (SML "IntInf.< ((_), (_))")
   466   (OCaml "Big'_int.lt'_big'_int")
   467   (Haskell infix 4 "<")
   468   (Scala infixl 4 "<")
   469 
       (* The same operations for the Codegen generator (nat printed as ML
          int there). *)
   470 consts_code
   471   "0::nat"                     ("0")
   472   "1::nat"                     ("1")
   473   Suc                          ("(_ +/ 1)")
   474   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
   475   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
   476   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   477   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   478 
   479 
   480 text {* Evaluation *}
   481 
       (* [code, code del] presumably drops any code equation for term_of at
          nat, so that the SML mapping below is used -- TODO confirm.  SML
          evaluation reflects a nat value as a numeral term via
          HOLogic.mk_number. *)
   482 lemma [code, code del]:
   483   "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..
   484 
   485 code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
   486   (SML "HOLogic.mk'_number/ HOLogic.natT")
   487 
   488 
   489 text {* Module names *}
   490 
       (* Generated code from this theory goes into a module named Arith.
          NOTE(review): there is no Scala entry here -- confirm whether that
          is intended. *)
   491 code_modulename SML
   492   Efficient_Nat Arith
   493 
   494 code_modulename OCaml
   495   Efficient_Nat Arith
   496 
   497 code_modulename Haskell
   498   Efficient_Nat Arith
   499 
       (* int is only an implementation device of this theory. *)
   500 hide_const int
   501 
   502 end