src/HOL/Library/Efficient_Nat.thy
author haftmann
Fri Oct 30 18:32:40 2009 +0100 (2009-10-30)
changeset 33364 2bd12592c5e8
parent 33343 2eb0b672ab40
child 34893 ecdc526af73a
permissions -rw-r--r--
tuned code setup
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Integer Main
     9 begin
    10 
text {*
  When generating code for functions on natural numbers, the
  canonical representation using @{term "0::nat"} and
  @{term "Suc"} is unsuitable for computations involving large
  numbers.  The efficiency of the generated code can be improved
  drastically by implementing natural numbers by target-language
  integers.  To enable this, simply import this theory.
*}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
(* Numerals (number_of_nat) are declared as the only code constructor of nat;
   0 and Suc are therefore no longer constructors in generated code. *)
code_datatype number_nat_inst.number_of_nat

(* 0 and 1 are expressed as numerals; both lemmas are also declared
   [code_unfold_post]. *)
lemma zero_nat_code [code, code_unfold_post]:
  "0 = (Numeral0 :: nat)"
  by simp

lemma one_nat_code [code, code_unfold_post]:
  "1 = (Numeral1 :: nat)"
  by simp

(* Suc is computed by addition rather than used as a constructor. *)
lemma Suc_code [code]:
  "Suc n = n + 1"
  by simp

(* Addition, subtraction and multiplication are computed on int (via of_nat)
   and converted back with nat.  For subtraction, nat maps a negative
   difference to 0, which matches truncated subtraction on nat. *)
lemma plus_nat_code [code]:
  "n + m = nat (of_nat n + of_nat m)"
  by simp

lemma minus_nat_code [code]:
  "n - m = nat (of_nat n - of_nat m)"
  by simp

lemma times_nat_code [code]:
  "n * m = nat (of_nat n * of_nat m)"
  unfolding of_nat_mult [symmetric] by simp
    52 
    53 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    54   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    55 
(* divmod_aux is an alias of divmod_nat whose defining equation is excluded
   from code generation ([code del]).  Its code equation is divmod_aux_code
   below, and a code_const declaration later in this theory also maps it
   directly to target-language division. *)
definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
  [code del]: "divmod_aux = divmod_nat"

(* The zero-divisor case is decided here, so the call to divmod_aux
   produced by this equation always has m ~= 0. *)
lemma [code]:
  "divmod_nat n m = (if m = 0 then (0, n) else divmod_aux n m)"
  unfolding divmod_aux_def divmod_nat_div_mod by simp

(* Quotient and remainder are computed on int and converted back with nat. *)
lemma divmod_aux_code [code]:
  "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
  unfolding divmod_aux_def divmod_nat_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    66 
(* Equality and the orderings on nat are decided on the int images. *)
lemma eq_nat_code [code]:
  "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
  by (simp add: eq)

(* Reflexivity shortcut, registered only for normalization by evaluation
   ([code nbe]). *)
lemma eq_nat_refl [code nbe]:
  "eq_class.eq (n::nat) n \<longleftrightarrow> True"
  by (rule HOL.eq_refl)

lemma less_eq_nat_code [code]:
  "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
  by simp

lemma less_nat_code [code]:
  "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
  by simp
    82 
    83 subsection {* Case analysis *}
    84 
    85 text {*
    86   Case analysis on natural numbers is rephrased using a conditional
    87   expression:
    88 *}
    89 
(* Case distinction on nat becomes a conditional, because Suc is no longer a
   constructor.  As [code_unfold] the rewrite is also applied during
   preprocessing. *)
lemma [code, code_unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    93 
    94 
    95 subsection {* Preprocessors *}
    96 
    97 text {*
    98   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
    99   a constructor term. Therefore, all occurrences of this term in a position
   100   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   101   equation or in the arguments of an inductive relation in an introduction
   102   rule) must be eliminated.
   103   This can be accomplished by applying the following transformation rules:
   104 *}
   105 
(* Merges an equation for f on a Suc pattern and an equation for f at 0 into
   one equation on a plain variable n.  Used by the equation preprocessor
   (remove_suc) in the setup below. *)
lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
  f n \<equiv> if n = 0 then g else h (n - 1)"
  by (rule eq_reflection) (cases n, simp_all)

(* Turns a clause P n (Suc n) into one about n, guarded by n ~= 0.  Used by
   the introduction-rule preprocessor (remove_suc_clause) in the setup below. *)
lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
  by (cases n) simp_all
   112 
   113 text {*
   114   The rules above are built into a preprocessor that is plugged into
   115   the code generator. Since the preprocessor for introduction rules
   116   does not know anything about modes, some of the modes that worked
   117   for the canonical representation of natural numbers may no longer work.
   118 *}
   119 
   120 (*<*)
   121 setup {*
   122 let
   123 
   124 fun remove_suc thy thms =
   125   let
   126     val vname = Name.variant (map fst
   127       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
   128     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
   129     fun lhs_of th = snd (Thm.dest_comb
   130       (fst (Thm.dest_comb (cprop_of th))));
   131     fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
   132     fun find_vars ct = (case term_of ct of
   133         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   134       | _ $ _ =>
   135         let val (ct1, ct2) = Thm.dest_comb ct
   136         in 
   137           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   138           map (apfst (Thm.capply ct1)) (find_vars ct2)
   139         end
   140       | _ => []);
   141     val eqs = maps
   142       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
   143     fun mk_thms (th, (ct, cv')) =
   144       let
   145         val th' =
   146           Thm.implies_elim
   147            (Conv.fconv_rule (Thm.beta_conversion true)
   148              (Drule.instantiate'
   149                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   150                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   151                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   152       in
   153         case map_filter (fn th'' =>
   154             SOME (th'', singleton
   155               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   156           handle THM _ => NONE) thms of
   157             [] => NONE
   158           | thps =>
   159               let val (ths1, ths2) = split_list thps
   160               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   161       end
   162   in get_first mk_thms eqs end;
   163 
   164 fun eqn_suc_preproc thy thms =
   165   let
   166     val dest = fst o Logic.dest_equals o prop_of;
   167     val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   168   in
   169     if forall (can dest) thms andalso exists (contains_suc o dest) thms
   170       then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
   171        else NONE
   172   end;
   173 
   174 val eqn_suc_preproc1 = Code_Preproc.simple_functrans eqn_suc_preproc;
   175 
   176 fun eqn_suc_preproc2 thy thms = eqn_suc_preproc thy thms
   177   |> the_default thms;
   178 
   179 fun remove_suc_clause thy thms =
   180   let
   181     val vname = Name.variant (map fst
   182       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
   183     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   184       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   185       | find_var _ = NONE;
   186     fun find_thm th =
   187       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   188       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   189   in
   190     case get_first find_thm thms of
   191       NONE => thms
   192     | SOME ((th, th'), (Sucv, v)) =>
   193         let
   194           val cert = cterm_of (Thm.theory_of_thm th);
   195           val th'' = ObjectLogic.rulify (Thm.implies_elim
   196             (Conv.fconv_rule (Thm.beta_conversion true)
   197               (Drule.instantiate' []
   198                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   199                    abstract_over (Sucv,
   200                      HOLogic.dest_Trueprop (prop_of th')))))),
   201                  SOME (cert v)] @{thm Suc_clause}))
   202             (Thm.forall_intr (cert v) th'))
   203         in
   204           remove_suc_clause thy (map (fn th''' =>
   205             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   206         end
   207   end;
   208 
   209 fun clause_suc_preproc thy ths =
   210   let
   211     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   212   in
   213     if forall (can (dest o concl_of)) ths andalso
   214       exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
   215         (map_filter (try dest) (concl_of th :: prems_of th))) ths
   216     then remove_suc_clause thy ths else ths
   217   end;
   218 in
   219 
   220   Codegen.add_preprocessor eqn_suc_preproc2
   221   #> Codegen.add_preprocessor clause_suc_preproc
   222   #> Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc1)
   223 
   224 end;
   225 *}
   226 (*>*)
   227 
   228 
   229 subsection {* Target language setup *}
   230 
   231 text {*
   232   For ML, we map @{typ nat} to target language integers, where we
   233   assert that values are always non-negative.
   234 *}
   235 
(* SML and OCaml: nat is represented by the target's unbounded integer type.
   Non-negativity is an invariant of the generated code; the target type
   does not enforce it. *)
code_type nat
  (SML "IntInf.int")
  (OCaml "Big'_int.big'_int")

(* Setup for the types_code/consts_code generator: nat becomes ML int, with
   a term reconstruction function and a random value generator for testing.
   NOTE(review): gen_nat presumably draws n between 0 and i; random_range is
   not defined here, so confirm its bounds. *)
types_code
  nat ("int")
attach (term_of) {*
val term_of_nat = HOLogic.mk_number HOLogic.natT;
*}
attach (test) {*
fun gen_nat i =
  let val n = random_range 0 i
  in (n, fn () => term_of_nat n) end;
*}
   250 
text {*
  For Haskell, we define our own @{typ nat} type, because
  the type class instances for @{typ nat} and @{typ int}
  must be distinguished.
*}
   256 
   257 code_include Haskell "Nat" {*
   258 newtype Nat = Nat Integer deriving (Show, Eq);
   259 
   260 instance Num Nat where {
   261   fromInteger k = Nat (if k >= 0 then k else 0);
   262   Nat n + Nat m = Nat (n + m);
   263   Nat n - Nat m = fromInteger (n - m);
   264   Nat n * Nat m = Nat (n * m);
   265   abs n = n;
   266   signum _ = 1;
   267   negate n = error "negate Nat";
   268 };
   269 
   270 instance Ord Nat where {
   271   Nat n <= Nat m = n <= m;
   272   Nat n < Nat m = n < m;
   273 };
   274 
   275 instance Real Nat where {
   276   toRational (Nat n) = toRational n;
   277 };
   278 
   279 instance Enum Nat where {
   280   toEnum k = fromInteger (toEnum k);
   281   fromEnum (Nat n) = fromEnum n;
   282 };
   283 
   284 instance Integral Nat where {
   285   toInteger (Nat n) = n;
   286   divMod n m = quotRem n m;
   287   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   288 };
   289 *}
   290 
(* Reserve the identifier Nat in Haskell output, since it is used by the Nat
   module included above. *)
code_reserved Haskell Nat

code_type nat
  (Haskell "Nat.Nat")

(* No Eq instance is generated for Haskell: the Nat newtype derives Eq. *)
code_instance nat :: eq
  (Haskell -)
   298 
   299 text {*
   300   Natural numerals.
   301 *}
   302 
(* Post-unfolding rewrite: nat applied to an int numeral becomes a nat
   numeral. *)
lemma [code_unfold_post]:
  "nat (number_of i) = number_nat_inst.number_of_nat i"
  -- {* this interacts as desired with @{thm nat_number_of_def} *}
  by (simp add: number_nat_inst.number_of_nat)

(* Register numeral printing for number_of_nat in SML, OCaml and Haskell.
   NOTE(review): the meaning of the two boolean arguments (false, true) of
   Numeral.add_code is not visible here; confirm against its definition. *)
setup {*
  fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
    false true) ["SML", "OCaml", "Haskell"]
*}
   312 
text {*
  Since natural numbers are implemented
  using integers in SML and OCaml, the coercion function @{const "of_nat"} of type
  @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
  For the @{const "nat"} function, which converts an integer to a natural
  number, we give a specific implementation that
  returns its input value if it is non-negative, and
  @{text "0"} otherwise.
*}
   322 
(* int is an auxiliary copy of of_nat whose definition is excluded from code
   generation ([code del]), so that the conversion can be mapped to target
   code by code_const.  The constant is hidden at the end of this theory. *)
definition int :: "nat \<Rightarrow> int" where
  [code del]: "int = of_nat"

(* Code equations on numerals: a numeral whose int value is negative yields
   0, otherwise the numeral itself. *)
lemma int_code' [code]:
  "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding int_nat_number_of [folded int_def] ..

lemma nat_code' [code]:
  "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding nat_number_of_def number_of_is_id neg_def by simp

(* Post-unfolding rewrite of of_nat (at the types where this applies) to int. *)
lemma of_nat_int [code_unfold_post]:
  "of_nat = int" by (simp add: int_def)

(* Preprocessing rewrite of the of_nat_aux form (successor %i. i + 1,
   start 0) to int. *)
lemma of_nat_aux_int [code_unfold]:
  "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
  by (simp add: int_def Nat.of_nat_code)
   340 
(* SML/OCaml: nat and int share one representation, so int is the identity. *)
code_const int
  (SML "_")
  (OCaml "_")

(* consts_code setup: int is the identity and nat maps negative integers
   to 0. *)
consts_code
  int ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if i < 0 then 0 else i;
*}

(* SML/OCaml: nat is the maximum of 0 and its argument. *)
code_const nat
  (SML "IntInf.max/ (/0,/ _)")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   355 
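text {* For Haskell, the conversions are mapped to @{text toInteger} and @{text fromInteger} of the @{text Nat} type; @{text fromInteger} maps negative values to @{text "0"}. *}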
   357 
(* Haskell: int and nat go through Integer; fromInteger on Nat clamps
   negative values to 0. *)
code_const int and nat
  (Haskell "toInteger" and "fromInteger")
   360 
   361 text {* Conversion from and to indices. *}
   362 
(* Conversions between nat and code_numeral.
   NOTE(review): in SML, IntInf.toInt can raise Overflow for values outside
   the range of int; confirm this is acceptable for callers.  In OCaml both
   conversions are the identity, which presumes code_numeral is also
   represented by Big_int there; TODO confirm. *)
code_const Code_Numeral.of_nat
  (SML "IntInf.toInt")
  (OCaml "_")
  (Haskell "fromEnum")

code_const Code_Numeral.nat_of
  (SML "IntInf.fromInt")
  (OCaml "_")
  (Haskell "toEnum")
   372 
   373 text {* Using target language arithmetic operations whenever appropriate *}
   374 
(* Addition and multiplication mapped to target integer operations.
   Subtraction has no mapping here and uses minus_nat_code (via int and
   nat). *)
code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.+ ((_), (_))")
  (OCaml "Big'_int.add'_big'_int")
  (Haskell infixl 6 "+")

code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.* ((_), (_))")
  (OCaml "Big'_int.mult'_big'_int")
  (Haskell infixl 7 "*")

(* Target division with remainder.  The operands are non-negative, and the
   divmod_nat code equation handles a zero divisor before divmod_aux is
   reached. *)
code_const divmod_aux
  (SML "IntInf.divMod/ ((_),/ (_))")
  (OCaml "Big'_int.quomod'_big'_int")
  (Haskell "divMod")

(* Equality and orderings on nat mapped to target integer comparisons. *)
code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "!((_ : IntInf.int) = _)")
  (OCaml "Big'_int.eq'_big'_int")
  (Haskell infixl 4 "==")

code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.<= ((_), (_))")
  (OCaml "Big'_int.le'_big'_int")
  (Haskell infix 4 "<=")

code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.< ((_), (_))")
  (OCaml "Big'_int.lt'_big'_int")
  (Haskell infix 4 "<")

(* The same operations for the consts_code generator, using ML int syntax;
   Suc is printed as addition of 1. *)
consts_code
  "0::nat"                     ("0")
  "1::nat"                     ("1")
  Suc                          ("(_ +/ 1)")
  "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
  "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
  "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
  "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   413 
   414 
   415 text {* Evaluation *}
   416 
(* The trivial equation for term_of at type nat is declared [code] and
   immediately [code del], presumably so that no code equations remain and
   term_of is only implemented by the SML mapping below. *)
lemma [code, code del]:
  "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..

(* SML: reconstruct a HOL numeral term of type nat from the integer value. *)
code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
  (SML "HOLogic.mk'_number/ HOLogic.natT")
   422 
   423 
   424 text {* Module names *}
   425 
(* Code generated from this theory is placed in a module named Arith in all
   three targets. *)
code_modulename SML
  Efficient_Nat Arith

code_modulename OCaml
  Efficient_Nat Arith

code_modulename Haskell
  Efficient_Nat Arith

(* Hide the name of the auxiliary constant int defined in this theory. *)
hide const int
   436 
   437 end