src/HOL/Library/Efficient_Nat.thy
author wenzelm, Thu Feb 16 22:53:24 2012 +0100 (2012-02-16)

changeset 46507 1b24c24017dd (parent 46497 89ccf66aa73d, child 47108 2a1953f0d20d); permissions -rw-r--r--
tuned proofs;
```isabelle
(*  Title:      HOL/Library/Efficient_Nat.thy
    Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
*)

header {* Implementation of natural numbers by target-language integers *}

theory Efficient_Nat
imports Code_Integer Main
begin

text {*
  When generating code for functions on natural numbers, the
  canonical representation using @{term "0::nat"} and
  @{term Suc} is unsuitable for computations involving large
  numbers.  The efficiency of the generated code can be improved
  drastically by implementing natural numbers by target-language
  integers.  To do this, just include this theory.
*}
```
```isabelle
subsection {* Basic arithmetic *}

text {*
  Most standard arithmetic functions on natural numbers are implemented
  using their counterparts on the integers:
*}

code_datatype number_nat_inst.number_of_nat

lemma zero_nat_code [code, code_unfold]:
  "0 = (Numeral0 :: nat)"
  by simp

lemma one_nat_code [code, code_unfold]:
  "1 = (Numeral1 :: nat)"
  by simp

lemma Suc_code [code]:
  "Suc n = n + 1"
  by simp

lemma plus_nat_code [code]:
  "n + m = nat (of_nat n + of_nat m)"
  by simp

lemma minus_nat_code [code]:
  "n - m = nat (of_nat n - of_nat m)"
  by simp

lemma times_nat_code [code]:
  "n * m = nat (of_nat n * of_nat m)"
  unfolding of_nat_mult [symmetric] by simp

lemma divmod_nat_code [code]:
  "divmod_nat n m = map_pair nat nat (pdivmod (of_nat n) (of_nat m))"
  by (simp add: map_pair_def split_def pdivmod_def nat_div_distrib nat_mod_distrib divmod_nat_div_mod)

lemma eq_nat_code [code]:
  "HOL.equal n m \<longleftrightarrow> HOL.equal (of_nat n \<Colon> int) (of_nat m)"
  by (simp add: equal)

lemma eq_nat_refl [code nbe]:
  "HOL.equal (n::nat) n \<longleftrightarrow> True"
  by (rule equal_refl)

lemma less_eq_nat_code [code]:
  "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
  by simp

lemma less_nat_code [code]:
  "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
  by simp
```
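The code equations above route each operation through `int` and back through `nat`, which maps negative results to zero. The following Haskell sketch models that detour, assuming a natural number is represented by a non-negative `Integer`; the names `nat`, `ofNat`, `plusNat`, `minusNat`, `timesNat`, `divmodNat`, `lessEqNat` and `lessNat` are hypothetical and are not what the code generator emits. The zero-divisor case follows HOL's convention that `n div 0 = 0` and `n mod 0 = n`, whereas Haskell's `divMod` would raise an exception.

```haskell
-- Hypothetical model: a natural number is represented by a non-negative Integer.
type NatI = Integer

-- Counterpart of HOL's nat :: int => nat, which maps negative integers to 0.
nat :: Integer -> NatI
nat k = max 0 k

-- Counterpart of of_nat :: nat => int, the identity on this representation.
ofNat :: NatI -> Integer
ofNat = id

-- Mirrors plus_nat_code, minus_nat_code and times_nat_code:
-- compute on integers, then coerce the result back with nat.
plusNat, minusNat, timesNat :: NatI -> NatI -> NatI
plusNat  n m = nat (ofNat n + ofNat m)
minusNat n m = nat (ofNat n - ofNat m)  -- truncated subtraction
timesNat n m = nat (ofNat n * ofNat m)

-- Mirrors divmod_nat_code for non-negative arguments, using HOL's
-- conventions n div 0 = 0 and n mod 0 = n for a zero divisor.
divmodNat :: NatI -> NatI -> (NatI, NatI)
divmodNat n m
  | m == 0    = (0, n)
  | otherwise = let (q, r) = ofNat n `divMod` ofNat m in (nat q, nat r)

-- Mirrors less_eq_nat_code and less_nat_code: compare as integers.
lessEqNat, lessNat :: NatI -> NatI -> Bool
lessEqNat n m = ofNat n <= ofNat m
lessNat   n m = ofNat n <  ofNat m
```

For instance, `minusNat 3 5` evaluates to `0` and `divmodNat 7 0` to `(0, 7)`.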
```isabelle
subsection {* Case analysis *}

text {*
  Case analysis on natural numbers is rephrased using a conditional
  expression:
*}

lemma [code, code_unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: fun_eq_iff dest!: gr0_implies_Suc)
```
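The code equation for `nat_case` replaces constructor matching by a test against zero and a predecessor. A minimal Haskell sketch of the same idea, again assuming naturals are modelled as non-negative `Integer`s; `natCase`, `predNat` and `sumTo` are hypothetical names:

```haskell
-- Hypothetical model: naturals are non-negative Integers.
-- natCase f g n behaves like HOL's nat_case f g n, whose code equation is
--   nat_case f g n = (if n = 0 then f else g (n - 1))
natCase :: a -> (Integer -> a) -> Integer -> a
natCase f g n = if n == 0 then f else g (n - 1)

-- Predecessor via case analysis: 0 for 0, m for m + 1.
predNat :: Integer -> Integer
predNat = natCase 0 id

-- A recursive function written with natCase instead of a Suc pattern:
-- sumTo n = 0 + 1 + ... + n.
sumTo :: Integer -> Integer
sumTo = natCase 0 (\m -> (m + 1) + sumTo m)
```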
```isabelle
subsection {* Preprocessors *}

text {*
  In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
  a constructor term. Therefore, all occurrences of this term in a position
  where a pattern is expected (i.e.\ on the left-hand side of a recursion
  equation or in the arguments of an inductive relation in an introduction
  rule) must be eliminated.
  This can be accomplished by applying the following transformation rules:
*}

lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
  f n \<equiv> if n = 0 then g else h (n - 1)"
  by (rule eq_reflection) (cases n, simp_all)

lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
  by (cases n) simp_all
```
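To illustrate what `Suc_if_eq` achieves, the following Haskell sketch (all names hypothetical) contrasts a factorial defined by pattern matching on a unary datatype with the conditional form the rule produces, where `g` is the right-hand side for `0` and `h n` the one for `Suc n`. The datatype `PNat` only stands in for HOL's constructor view of `nat`; it is not part of any generated code. `Suc_clause` performs the analogous step for introduction rules; the sketch covers only the equational case.

```haskell
-- Stand-in for HOL's constructor view of nat (hypothetical, for illustration).
data PNat = Zero | Suc PNat

toI :: PNat -> Integer
toI Zero    = 0
toI (Suc n) = 1 + toI n

fromI :: Integer -> PNat
fromI k
  | k <= 0    = Zero
  | otherwise = Suc (fromI (k - 1))

-- Before: equations of the shape  f 0 = g  and  f (Suc n) = h n,
-- with g = 1 and h n = (n + 1) * fact n.
factP :: PNat -> Integer
factP Zero    = 1
factP (Suc n) = (toI n + 1) * factP n

-- After Suc_if_eq: f n = (if n = 0 then g else h (n - 1)).
-- Here h (n - 1) = ((n - 1) + 1) * fact (n - 1) = n * fact (n - 1).
factI :: Integer -> Integer
factI n = if n == 0 then 1 else n * factI (n - 1)
```

Both definitions agree on every natural number; only the second can run directly on an integer representation, because it never matches on `Suc`.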
```isabelle
text {*
  The rules above are built into a preprocessor that is plugged into
  the code generator. Since the preprocessor for introduction rules
  does not know anything about modes, some of the modes that worked
  for the canonical representation of natural numbers may no longer work.
*}

(*<*)
setup {*
let

fun remove_suc thy thms =
  let
    val vname = singleton (Name.variant_list (map fst
      (fold (Term.add_var_names o Thm.full_prop_of) thms []))) "n";
    val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    fun lhs_of th = snd (Thm.dest_comb
      (fst (Thm.dest_comb (cprop_of th))));
    fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
    fun find_vars ct = (case term_of ct of
        (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
      | _ $ _ =>
        let val (ct1, ct2) = Thm.dest_comb ct
        in
          map (apfst (fn ct => Thm.apply ct ct2)) (find_vars ct1) @
          map (apfst (Thm.apply ct1)) (find_vars ct2)
        end
      | _ => []);
    val eqs = maps
      (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    fun mk_thms (th, (ct, cv')) =
      let
        val th' =
          Thm.implies_elim
           (Conv.fconv_rule (Thm.beta_conversion true)
             (Drule.instantiate'
               [SOME (ctyp_of_term ct)] [SOME (Thm.lambda cv ct),
                 SOME (Thm.lambda cv' (rhs_of th)), NONE, SOME cv']
               @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
      in
        case map_filter (fn th'' =>
            SOME (th'', singleton
              (Variable.trade (K (fn [th'''] => [th''' RS th']))
                (Variable.global_thm_context th'')) th'')
          handle THM _ => NONE) thms of
            [] => NONE
          | thps =>
              let val (ths1, ths2) = split_list thps
              in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
      end
  in get_first mk_thms eqs end;

fun eqn_suc_base_preproc thy thms =
  let
    val dest = fst o Logic.dest_equals o prop_of;
    val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
  in
    if forall (can dest) thms andalso exists (contains_suc o dest) thms
      then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
       else NONE
  end;

val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;

in

  Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)

end;
*}
(*>*)
```
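The central step of `remove_suc` is `find_vars`, which walks the left-hand side of an equation and, for each subterm of the form `Suc ?v`, returns the left-hand side with that subterm replaced by the fresh variable `cv`, paired with `?v`. The Haskell sketch below is a simplified model of only this traversal, together with the `contains_suc` guard of `eqn_suc_base_preproc`, on a hypothetical untyped term datatype; it ignores certified terms, types, and the subsequent instantiation of `Suc_if_eq`.

```haskell
-- Hypothetical untyped term language: constants, schematic variables, application.
data Term = Const String | Var String | App Term Term
  deriving (Eq, Show)

-- Model of contains_suc: does the constant Suc occur anywhere in the term?
containsSuc :: Term -> Bool
containsSuc (Const c) = c == "Suc"
containsSuc (Var _)   = False
containsSuc (App a b) = containsSuc a || containsSuc b

-- Model of find_vars: for each occurrence of (Suc ?v), return the whole term
-- with that occurrence replaced by the fresh variable `fresh`, paired with ?v.
findSucVars :: String -> Term -> [(Term, Term)]
findSucVars fresh t = case t of
  App (Const "Suc") v@(Var _) -> [(Var fresh, v)]
  App t1 t2 ->
    [ (App c t2, v) | (c, v) <- findSucVars fresh t1 ] ++
    [ (App t1 c, v) | (c, v) <- findSucVars fresh t2 ]
  _ -> []
```

For the left-hand side `fact (Suc ?n)`, the model yields the single pair `(fact ?n', ?n)`, which corresponds to the instantiation of `f` and the variable that `Suc_if_eq` then abstracts over.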
```isabelle
subsection {* Target language setup *}

text {*
  For ML, we map @{typ nat} to target language integers, where we
  ensure that values are always non-negative.
*}

code_type nat
  (SML "IntInf.int")
  (OCaml "Big'_int.big'_int")
  (Eval "int")

text {*
  For Haskell and Scala we define our own @{typ nat} type.  The reason
  is that we have to distinguish type class instances for @{typ nat}
  and @{typ int}.
*}

code_include Haskell "Nat"
{*newtype Nat = Nat Integer deriving (Eq, Show, Read);

instance Num Nat where {
  fromInteger k = Nat (if k >= 0 then k else 0);
  Nat n + Nat m = Nat (n + m);
  Nat n - Nat m = fromInteger (n - m);
  Nat n * Nat m = Nat (n * m);
  abs n = n;
  signum _ = 1;
  negate n = error "negate Nat";
};

instance Ord Nat where {
  Nat n <= Nat m = n <= m;
  Nat n < Nat m = n < m;
};

instance Real Nat where {
  toRational (Nat n) = toRational n;
};

instance Enum Nat where {
  toEnum k = fromInteger (toEnum k);
  fromEnum (Nat n) = fromEnum n;
};

instance Integral Nat where {
  toInteger (Nat n) = n;
  divMod n m = quotRem n m;
  quotRem (Nat n) (Nat m)
    | (m == 0) = (0, Nat n)
    | otherwise = (Nat k, Nat l) where (k, l) = quotRem n m;
};
*}

code_reserved Haskell Nat
```
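The need for a separate type shows up with any class whose behaviour differs on `nat` and `int`. In the sketch below, the class `Minus` is a hypothetical stand-in for an Isabelle type class, not a class the code generator emits. Because `Nat` is a `newtype`, it can carry a truncating instance next to the ordinary one for `Integer`; with a synonym `type Nat = Integer` the two instance heads would be identical and GHC would reject them as duplicates.

```haskell
-- A newtype over Integer, like the Nat type in the included Haskell code,
-- so that it can carry class instances of its own.
newtype Nat = Nat Integer deriving (Eq, Show)

-- Hypothetical class standing in for an Isabelle type class.
class Minus a where
  minus :: a -> a -> a

-- On integers, subtraction is ordinary.
instance Minus Integer where
  minus = (-)

-- On naturals, subtraction truncates at zero, matching the Num Nat instance
-- above, where fromInteger maps negative results to 0.
instance Minus Nat where
  minus (Nat n) (Nat m) = Nat (max 0 (n - m))
```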
```isabelle
code_include Scala "Nat"
{*object Nat {

  def apply(numeral: BigInt): Nat = new Nat(numeral max 0)
  def apply(numeral: Int): Nat = Nat(BigInt(numeral))
  def apply(numeral: String): Nat = Nat(BigInt(numeral))

}

class Nat private(private val value: BigInt) {

  override def hashCode(): Int = this.value.hashCode()

  override def equals(that: Any): Boolean = that match {
    case that: Nat => this equals that
    case _ => false
  }

  override def toString(): String = this.value.toString

  def equals(that: Nat): Boolean = this.value == that.value

  def as_BigInt: BigInt = this.value
  def as_Int: Int = if (this.value >= scala.Int.MinValue && this.value <= scala.Int.MaxValue)
      this.value.intValue
    else error("Int value out of range: " + this.value.toString)

  def +(that: Nat): Nat = new Nat(this.value + that.value)
  def -(that: Nat): Nat = Nat(this.value - that.value)
  def *(that: Nat): Nat = new Nat(this.value * that.value)

  def /%(that: Nat): (Nat, Nat) = if (that.value == 0) (new Nat(0), this)
    else {
      val (k, l) = this.value /% that.value
      (new Nat(k), new Nat(l))
    }

  def <=(that: Nat): Boolean = this.value <= that.value

  def <(that: Nat): Boolean = this.value < that.value

}
*}

code_reserved Scala Nat

code_type nat
  (Haskell "Nat.Nat")
  (Scala "Nat")

code_instance nat :: equal
  (Haskell -)
```
```isabelle
text {*
  Natural numerals.
*}

lemma [code_abbrev]:
  "number_nat_inst.number_of_nat i = nat (number_of i)"
  -- {* this interacts as desired with @{thm nat_number_of_def} *}
  by (simp add: number_nat_inst.number_of_nat)

setup {*
  fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
    false Code_Printer.literal_positive_numeral) ["SML", "OCaml", "Haskell", "Scala"]
*}

text {*
  Since natural numbers are implemented
  using integers in ML, the coercion function @{const "of_nat"} of type
  @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
  For the @{const nat} function for converting an integer to a natural
  number, we give a specific implementation using an ML function that
  returns its input value, provided that it is non-negative, and otherwise
  returns @{text "0"}.
*}

definition int :: "nat \<Rightarrow> int" where
  [code del, code_abbrev]: "int = of_nat"

lemma int_code' [code]:
  "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding int_nat_number_of [folded int_def] ..

lemma nat_code' [code]:
  "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding nat_number_of_def number_of_is_id neg_def by simp

lemma of_nat_int: (* FIXME delete candidate *)
  "of_nat = int" by (simp add: int_def)

lemma of_nat_aux_int [code_unfold]:
  "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
  by (simp add: int_def Nat.of_nat_code)

code_const int
  (SML "_")
  (OCaml "_")

code_const nat
  (SML "IntInf.max/ (0,/ _)")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
  (Eval "Integer.max/ _/ 0")

text {* For Haskell and Scala, things are slightly different again. *}

code_const int and nat
  (Haskell "toInteger" and "fromInteger")
  (Scala "!_.as'_BigInt" and "Nat")

text {* Conversion from and to code numerals. *}

code_const Code_Numeral.of_nat
  (SML "IntInf.toInt")
  (OCaml "_")
  (Haskell "!(fromInteger/ ./ toInteger)")
  (Scala "!Natural(_.as'_BigInt)")
  (Eval "_")

code_const Code_Numeral.nat_of
  (SML "IntInf.fromInt")
  (OCaml "_")
  (Haskell "!(fromInteger/ ./ toInteger)")
  (Scala "!Nat(_.as'_BigInt)")
  (Eval "_")
```
```isabelle
text {* Using target language arithmetic operations whenever appropriate *}

code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.+ ((_), (_))")
  (OCaml "Big'_int.add'_big'_int")
  (Haskell infixl 6 "+")
  (Scala infixl 7 "+")
  (Eval infixl 8 "+")

code_const "op - \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (Haskell infixl 6 "-")
  (Scala infixl 7 "-")

code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.* ((_), (_))")
  (OCaml "Big'_int.mult'_big'_int")
  (Haskell infixl 7 "*")
  (Scala infixl 8 "*")
  (Eval infixl 9 "*")

code_const divmod_nat
  (SML "IntInf.divMod/ ((_),/ (_))")
  (OCaml "Big'_int.quomod'_big'_int")
  (Haskell "divMod")
  (Scala infixl 8 "/%")
  (Eval "Integer.div'_mod")

code_const "HOL.equal \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "!((_ : IntInf.int) = _)")
  (OCaml "Big'_int.eq'_big'_int")
  (Haskell infix 4 "==")
  (Scala infixl 5 "==")
  (Eval infixl 6 "=")

code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.<= ((_), (_))")
  (OCaml "Big'_int.le'_big'_int")
  (Haskell infix 4 "<=")
  (Scala infixl 4 "<=")
  (Eval infixl 6 "<=")

code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.< ((_), (_))")
  (OCaml "Big'_int.lt'_big'_int")
  (Haskell infix 4 "<")
  (Scala infixl 4 "<")
  (Eval infixl 6 "<")


text {* Evaluation *}

lemma [code, code del]:
  "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..

code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
  (SML "HOLogic.mk'_number/ HOLogic.natT")

text {* Evaluation with @{text "Quickcheck_Narrowing"} does not work, as
  @{text "code_module"} is very aggressive leading to bad Haskell code.
  Therefore, we simply deactivate the narrowing-based quickcheck from here on.
*}

declare [[quickcheck_narrowing_active = false]]

text {* Module names *}

code_modulename SML
  Efficient_Nat Arith

code_modulename OCaml
  Efficient_Nat Arith

code_modulename Haskell
  Efficient_Nat Arith

hide_const int

end
```