src/HOL/Library/Efficient_Nat.thy
author haftmann
Tue Jul 07 17:37:00 2009 +0200 (2009-07-07)
changeset 31954 8db19c99b00a
parent 31377 a48f9ef9de15
child 31998 2c7a24f74db9
permissions -rw-r--r--
Stefan Berghofer's code generator uses Pure equality instead of HOL equality now
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Integer Main
     9 begin
    10 
text {*
  When generating code for functions on natural numbers, the
  canonical representation using @{term "0::nat"} and
  @{term "Suc"} is unsuitable for computations involving large
  numbers.  The efficiency of the generated code can be improved
  drastically by implementing natural numbers by target-language
  integers.  To do this, simply import this theory.
*}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
(* Natural numbers are represented in generated code by numerals: the code
   datatype of nat is number_of_nat instead of 0 and Suc. *)
code_datatype number_nat_inst.number_of_nat

(* 0 and 1 are inlined as numerals before code generation; the symmetric
   equations are declared as code post, turning the numerals back into
   0 and 1 during post-processing. *)
lemma zero_nat_code [code, code inline]:
  "0 = (Numeral0 :: nat)"
  by simp
lemmas [code post] = zero_nat_code [symmetric]

lemma one_nat_code [code, code inline]:
  "1 = (Numeral1 :: nat)"
  by simp
lemmas [code post] = one_nat_code [symmetric]

(* Suc is no longer a constructor for code generation, so it is computed
   by addition. *)
lemma Suc_code [code]:
  "Suc n = n + 1"
  by simp

(* Addition, subtraction and multiplication are computed on int and
   converted back with nat.  For subtraction, nat maps a negative
   difference to 0, which is exactly truncated subtraction on nat. *)
lemma plus_nat_code [code]:
  "n + m = nat (of_nat n + of_nat m)"
  by simp

lemma minus_nat_code [code]:
  "n - m = nat (of_nat n - of_nat m)"
  by simp

lemma times_nat_code [code]:
  "n * m = nat (of_nat n * of_nat m)"
  unfolding of_nat_mult [symmetric] by simp
    54 
    55 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    56   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    57 
(* divmod_aux is Divides.divmod with its code equations removed, so that it
   can be mapped directly to the target's integer division (see the
   code_const declaration for divmod_aux). *)
definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
  [code del]: "divmod_aux = Divides.divmod"

(* Division by zero is resolved here (n div 0 = 0, n mod 0 = n); this
   equation only calls divmod_aux with a non-zero divisor. *)
lemma [code]:
  "Divides.divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
  unfolding divmod_aux_def divmod_div_mod by simp

(* Implementation of divmod_aux via int division and modulus. *)
lemma divmod_aux_code [code]:
  "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
  unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp

(* Equality on nat is decided on int. *)
lemma eq_nat_code [code]:
  "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
  by (simp add: eq)

(* Reflexivity of equality on nat, declared only for the nbe evaluator. *)
lemma eq_nat_refl [code nbe]:
  "eq_class.eq (n::nat) n \<longleftrightarrow> True"
  by (rule HOL.eq_refl)

(* The orderings on nat are decided on int. *)
lemma less_eq_nat_code [code]:
  "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
  by simp

lemma less_nat_code [code]:
  "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
  by simp
    84 
    85 subsection {* Case analysis *}
    86 
    87 text {*
    88   Case analysis on natural numbers is rephrased using a conditional
    89   expression:
    90 *}
    91 
(* Case analysis on nat becomes a test against 0 followed by a predecessor
   computation (n - 1).  Because it is also declared code unfold, the
   equation is used as a preprocessing rewrite on other code equations
   that mention nat_case. *)
lemma [code, code unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    95 
    96 
    97 subsection {* Preprocessors *}
    98 
    99 text {*
   100   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   101   a constructor term. Therefore, all occurrences of this term in a position
   102   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   103   equation or in the arguments of an inductive relation in an introduction
   104   rule) must be eliminated.
   105   This can be accomplished by applying the following transformation rules:
   106 *}
   107 
(* Used by the equation preprocessor below.  From the Suc case
   f (Suc n) == h n and the zero case f 0 == g, derive one equation for
   f n on an arbitrary n, where Suc no longer occurs on the left. *)
lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
  f n \<equiv> if n = 0 then g else h (n - 1)"
  by (rule eq_reflection) (cases n, simp_all)

(* Used by the introduction-rule preprocessor below.  A clause P n (Suc n)
   that holds for all n yields P (n - 1) n for every n other than 0. *)
lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
  by (cases n) simp_all
   114 
   115 text {*
   116   The rules above are built into a preprocessor that is plugged into
   117   the code generator. Since the preprocessor for introduction rules
   118   does not know anything about modes, some of the modes that worked
   119   for the canonical representation of natural numbers may no longer work.
   120 *}
   121 
   122 (*<*)
   123 setup {*
   124 let
   125 
   126 fun remove_suc thy thms =
   127   let
   128     val vname = Name.variant (map fst
   129       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
   130     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
   131     fun lhs_of th = snd (Thm.dest_comb
   132       (fst (Thm.dest_comb (cprop_of th))));
   133     fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
   134     fun find_vars ct = (case term_of ct of
   135         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   136       | _ $ _ =>
   137         let val (ct1, ct2) = Thm.dest_comb ct
   138         in 
   139           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   140           map (apfst (Thm.capply ct1)) (find_vars ct2)
   141         end
   142       | _ => []);
   143     val eqs = maps
   144       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
   145     fun mk_thms (th, (ct, cv')) =
   146       let
   147         val th' =
   148           Thm.implies_elim
   149            (Conv.fconv_rule (Thm.beta_conversion true)
   150              (Drule.instantiate'
   151                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   152                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   153                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   154       in
   155         case map_filter (fn th'' =>
   156             SOME (th'', singleton
   157               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   158           handle THM _ => NONE) thms of
   159             [] => NONE
   160           | thps =>
   161               let val (ths1, ths2) = split_list thps
   162               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   163       end
   164   in get_first mk_thms eqs end;
   165 
   166 fun eqn_suc_preproc thy thms =
   167   let
   168     val dest = fst o Logic.dest_equals o prop_of;
   169     val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   170   in
   171     if forall (can dest) thms andalso exists (contains_suc o dest) thms
   172       then perhaps_loop (remove_suc thy) thms
   173        else NONE
   174   end;
   175 
   176 val eqn_suc_preproc1 = Code_Preproc.simple_functrans eqn_suc_preproc;
   177 
   178 fun eqn_suc_preproc2 thy thms = eqn_suc_preproc thy thms
   179   |> the_default thms;
   180 
   181 fun remove_suc_clause thy thms =
   182   let
   183     val vname = Name.variant (map fst
   184       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
   185     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   186       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   187       | find_var _ = NONE;
   188     fun find_thm th =
   189       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   190       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   191   in
   192     case get_first find_thm thms of
   193       NONE => thms
   194     | SOME ((th, th'), (Sucv, v)) =>
   195         let
   196           val cert = cterm_of (Thm.theory_of_thm th);
   197           val th'' = ObjectLogic.rulify (Thm.implies_elim
   198             (Conv.fconv_rule (Thm.beta_conversion true)
   199               (Drule.instantiate' []
   200                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   201                    abstract_over (Sucv,
   202                      HOLogic.dest_Trueprop (prop_of th')))))),
   203                  SOME (cert v)] @{thm Suc_clause}))
   204             (Thm.forall_intr (cert v) th'))
   205         in
   206           remove_suc_clause thy (map (fn th''' =>
   207             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   208         end
   209   end;
   210 
   211 fun clause_suc_preproc thy ths =
   212   let
   213     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   214   in
   215     if forall (can (dest o concl_of)) ths andalso
   216       exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
   217         (map_filter (try dest) (concl_of th :: prems_of th))) ths
   218     then remove_suc_clause thy ths else ths
   219   end;
   220 in
   221 
   222   Codegen.add_preprocessor eqn_suc_preproc2
   223   #> Codegen.add_preprocessor clause_suc_preproc
   224   #> Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc1)
   225 
   226 end;
   227 *}
   228 (*>*)
   229 
   230 
   231 subsection {* Target language setup *}
   232 
text {*
  For SML and OCaml, we map @{typ nat} to unbounded target-language
  integers, relying on the invariant (maintained by the operations
  below) that these values are never negative.
*}
   237 
   238 code_type nat
   239   (SML "IntInf.int")
   240   (OCaml "Big'_int.big'_int")
   241 
   242 types_code
   243   nat ("int")
   244 attach (term_of) {*
   245 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   246 *}
   247 attach (test) {*
   248 fun gen_nat i =
   249   let val n = random_range 0 i
   250   in (n, fn () => term_of_nat n) end;
   251 *}
   252 
   253 text {*
   254   For Haskell we define our own @{typ nat} type.  The reason
   255   is that we have to distinguish type class instances
   256   for @{typ nat} and @{typ int}.
   257 *}
   258 
   259 code_include Haskell "Nat" {*
   260 newtype Nat = Nat Integer deriving (Show, Eq);
   261 
   262 instance Num Nat where {
   263   fromInteger k = Nat (if k >= 0 then k else 0);
   264   Nat n + Nat m = Nat (n + m);
   265   Nat n - Nat m = fromInteger (n - m);
   266   Nat n * Nat m = Nat (n * m);
   267   abs n = n;
   268   signum _ = 1;
   269   negate n = error "negate Nat";
   270 };
   271 
   272 instance Ord Nat where {
   273   Nat n <= Nat m = n <= m;
   274   Nat n < Nat m = n < m;
   275 };
   276 
   277 instance Real Nat where {
   278   toRational (Nat n) = toRational n;
   279 };
   280 
   281 instance Enum Nat where {
   282   toEnum k = fromInteger (toEnum k);
   283   fromEnum (Nat n) = fromEnum n;
   284 };
   285 
   286 instance Integral Nat where {
   287   toInteger (Nat n) = n;
   288   divMod n m = quotRem n m;
   289   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   290 };
   291 *}
   292 
   293 code_reserved Haskell Nat
   294 
   295 code_type nat
   296   (Haskell "Nat.Nat")
   297 
   298 code_instance nat :: eq
   299   (Haskell -)
   300 
   301 text {*
   302   Natural numerals.
   303 *}
   304 
(* nat (number_of i) is inlined as a nat numeral (code inline); the
   attribute symmetric then flips the equation, and the flipped version is
   declared as code post. *)
lemma [code inline, symmetric, code post]:
  "nat (number_of i) = number_nat_inst.number_of_nat i"
  -- {* this interacts as desired with @{thm nat_number_of_def} *}
  by (simp add: number_nat_inst.number_of_nat)

(* Register target syntax for nat numerals in SML, OCaml and Haskell.
   NOTE(review): the meaning of the two boolean flags passed to
   Numeral.add_code (false, true) is not visible here; confirm in the
   Numeral module. *)
setup {*
  fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
    false true) ["SML", "OCaml", "Haskell"]
*}
   314 
   315 text {*
   316   Since natural numbers are implemented
   317   using integers in ML, the coercion function @{const "of_nat"} of type
   318   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   319   For the @{const "nat"} function for converting an integer to a natural
   320   number, we give a specific implementation using an ML function that
   321   returns its input value, provided that it is non-negative, and otherwise
   322   returns @{text "0"}.
   323 *}
   324 
(* int is an auxiliary copy of of_nat at type nat => int, with its code
   equations removed, so that the coercion can be mapped to target code.
   It is hidden at the end of this theory. *)
definition
  int :: "nat \<Rightarrow> int"
where
  [code del]: "int = of_nat"

(* On numerals, both conversions yield 0 for a negative numeral and the
   numeral itself otherwise. *)
lemma int_code' [code]:
  "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding int_nat_number_of [folded int_def] ..

lemma nat_code' [code]:
  "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding nat_number_of_def number_of_is_id neg_def by simp

(* of_nat is replaced by int in code equations; the symmetric equation
   turns int back into of_nat during post-processing. *)
lemma of_nat_int [code unfold]:
  "of_nat = int" by (simp add: int_def)
declare of_nat_int [symmetric, code post]

(* In SML and OCaml nat and int share their representation, so int is
   the identity. *)
code_const int
  (SML "_")
  (OCaml "_")

(* Old code generator: int is the identity; nat clamps negative values to
   0 via the attached ML function. *)
consts_code
  int ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if i < 0 then 0 else i;
*}

(* nat clamps negative integers to 0. *)
code_const nat
  (SML "IntInf.max/ (/0,/ _)")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   356 
text {* For Haskell, @{const int} and @{const nat} are implemented by the conversions @{text toInteger} and @{text fromInteger} of the @{text Nat} type. *}
   358 
   359 code_const int and nat
   360   (Haskell "toInteger" and "fromInteger")
   361 
text {* Conversions between natural numbers and code numerals (indices). *}
   363 
   364 code_const Code_Numeral.of_nat
   365   (SML "IntInf.toInt")
   366   (OCaml "_")
   367   (Haskell "fromEnum")
   368 
   369 code_const Code_Numeral.nat_of
   370   (SML "IntInf.fromInt")
   371   (OCaml "_")
   372   (Haskell "toEnum")
   373 
text {* Target-language arithmetic operations are used whenever appropriate. *}
   375 
   376 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   377   (SML "IntInf.+ ((_), (_))")
   378   (OCaml "Big'_int.add'_big'_int")
   379   (Haskell infixl 6 "+")
   380 
   381 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   382   (SML "IntInf.* ((_), (_))")
   383   (OCaml "Big'_int.mult'_big'_int")
   384   (Haskell infixl 7 "*")
   385 
   386 code_const divmod_aux
   387   (SML "IntInf.divMod/ ((_),/ (_))")
   388   (OCaml "Big'_int.quomod'_big'_int")
   389   (Haskell "divMod")
   390 
   391 code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   392   (SML "!((_ : IntInf.int) = _)")
   393   (OCaml "Big'_int.eq'_big'_int")
   394   (Haskell infixl 4 "==")
   395 
   396 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   397   (SML "IntInf.<= ((_), (_))")
   398   (OCaml "Big'_int.le'_big'_int")
   399   (Haskell infix 4 "<=")
   400 
   401 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   402   (SML "IntInf.< ((_), (_))")
   403   (OCaml "Big'_int.lt'_big'_int")
   404   (Haskell infix 4 "<")
   405 
   406 consts_code
   407   "0::nat"                     ("0")
   408   "1::nat"                     ("1")
   409   Suc                          ("(_ +/ 1)")
   410   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
   411   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
   412   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   413   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   414 
   415 
   416 text {* Evaluation *}
   417 
(* The trivial equation for term_of at nat is declared and immediately
   deleted, presumably to leave term_of at nat without code equations
   (TODO confirm).  SML implements it directly with HOLogic.mk_number. *)
lemma [code, code del]:
  "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..

code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
  (SML "HOLogic.mk'_number/ HOLogic.natT")
   423 
   424 
   425 text {* Module names *}
   426 
   427 code_modulename SML
   428   Nat Integer
   429   Divides Integer
   430   Ring_and_Field Integer
   431   Efficient_Nat Integer
   432 
   433 code_modulename OCaml
   434   Nat Integer
   435   Divides Integer
   436   Ring_and_Field Integer
   437   Efficient_Nat Integer
   438 
   439 code_modulename Haskell
   440   Nat Integer
   441   Divides Integer
   442   Ring_and_Field Integer
   443   Efficient_Nat Integer
   444 
   445 hide const int
   446 
   447 end