src/HOL/Library/Efficient_Nat.thy
author haftmann
Wed Jan 13 10:18:45 2010 +0100 (2010-01-13)
changeset 34893 ecdc526af73a
parent 33364 2bd12592c5e8
child 34899 8674bb6f727b
permissions -rw-r--r--
function transformer preprocessor applies to both code generators
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Integer Main
     9 begin
    10 
    11 text {*
    12   When generating code for functions on natural numbers, the
    13   canonical representation using @{term "0::nat"} and
    14   @{term "Suc"} is unsuitable for computations involving large
    15   numbers.  The efficiency of the generated code can be improved
    16   drastically by implementing natural numbers by target-language
    17   integers.  To do this, simply import this theory.
    18 *}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
       (* Generated code represents nat by numerals: number_of_nat is declared
          as the only code constructor, so 0 and Suc are not constructors in
          generated code and are expressed by the equations below. *)
    27 code_datatype number_nat_inst.number_of_nat
    28 
       (* 0 and 1 as numerals; code_unfold_post presumably also turns numerals
          back into 0 and 1 when evaluation results are shown -- confirm. *)
    29 lemma zero_nat_code [code, code_unfold_post]:
    30   "0 = (Numeral0 :: nat)"
    31   by simp
    32 
    33 lemma one_nat_code [code, code_unfold_post]:
    34   "1 = (Numeral1 :: nat)"
    35   by simp
    36 
       (* Suc is no longer a constructor, so it is computed as n + 1. *)
    37 lemma Suc_code [code]:
    38   "Suc n = n + 1"
    39   by simp
    40 
       (* Arithmetic is delegated to int: compute on of_nat n and of_nat m and
          convert back with nat, which maps negative results to 0 (this is
          what makes the subtraction below truncating). *)
    41 lemma plus_nat_code [code]:
    42   "n + m = nat (of_nat n + of_nat m)"
    43   by simp
    44 
    45 lemma minus_nat_code [code]:
    46   "n - m = nat (of_nat n - of_nat m)"
    47   by simp
    48 
    49 lemma times_nat_code [code]:
    50   "n * m = nat (of_nat n * of_nat m)"
    51   unfolding of_nat_mult [symmetric] by simp
    52 
    53 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    54   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    55 
       (* divmod_aux coincides with divmod_nat, but its definition is not used
          as a code equation ([code del]).  Instead it gets the int-based code
          equation divmod_aux_code below and, for SML, OCaml and Haskell, a
          direct target mapping (code_const divmod_aux). *)
    56 definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
    57   [code del]: "divmod_aux = divmod_nat"
    58 
       (* A zero divisor is answered here with (0, n), so when called through
          divmod_nat, divmod_aux is only reached with m ~= 0. *)
    59 lemma [code]:
    60   "divmod_nat n m = (if m = 0 then (0, n) else divmod_aux n m)"
    61   unfolding divmod_aux_def divmod_nat_div_mod by simp
    62 
       (* Generic code equation via int div/mod. *)
    63 lemma divmod_aux_code [code]:
    64   "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
    65   unfolding divmod_aux_def divmod_nat_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    66 
       (* Equality and orderings on nat are decided on the int images. *)
    67 lemma eq_nat_code [code]:
    68   "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
    69   by (simp add: eq)
    70 
       (* Additional reflexivity equation, registered only for normalisation
          by evaluation ([code nbe]). *)
    71 lemma eq_nat_refl [code nbe]:
    72   "eq_class.eq (n::nat) n \<longleftrightarrow> True"
    73   by (rule HOL.eq_refl)
    74 
    75 lemma less_eq_nat_code [code]:
    76   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
    77   by simp
    78 
    79 lemma less_nat_code [code]:
    80   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
    81   by simp
    82 
    83 subsection {* Case analysis *}
    84 
    85 text {*
    86   Case analysis on natural numbers is rephrased using a conditional
    87   expression:
    88 *}
    89 
       (* Matching on 0/Suc is not available in generated code, so nat_case
          becomes a conditional on n = 0.  The rule serves both as code
          equation and as unfolding rule (code_unfold). *)
    90 lemma [code, code_unfold]:
    91   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    92   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    93 
    94 
    95 subsection {* Preprocessors *}
    96 
    97 text {*
    98   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
    99   a constructor term. Therefore, all occurrences of this term in a position
   100   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   101   equation or in the arguments of an inductive relation in an introduction
   102   rule) must be eliminated.
   103   This can be accomplished by applying the following transformation rules:
   104 *}
   105 
       (* Suc_if_eq: from an equation for f (Suc n) and one for f 0, derive a
          single equation for f n whose argument pattern is a variable.  Used
          by remove_suc in the setup below. *)
   106 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
   107   f n \<equiv> if n = 0 then g else h (n - 1)"
   108   by (rule eq_reflection) (cases n, simp_all)
   109 
       (* Suc_clause: a rule stated for Suc n yields one stated for n under
          the side condition n ~= 0.  Used by remove_suc_clause below. *)
   110 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   111   by (cases n) simp_all
   112 
   113 text {*
   114   The rules above are built into preprocessors that are plugged into
   115   both code generators. Since the preprocessor for introduction rules
   116   does not know anything about modes, some of the modes that worked
   117   for the canonical representation of natural numbers may no longer work.
   118 *}
   119 
   120 (*<*)
   121 setup {*
   122 let
   123 
       (* remove_suc thy thms: one elimination step on the equations thms.
          It looks for an occurrence of Suc ?v in the left-hand side of some
          equation th.  Using Suc_if_eq and those equations of thms that fit
          its 0 case, th and the fitting equations are replaced by equations
          with the plain variable ?v at that position.  Returns SOME of the
          new equation list, or NONE if no occurrence can be eliminated. *)
   124 fun remove_suc thy thms =
   125   let
       (* a schematic variable ?n of type nat not occurring in thms; it marks
          the position of the Suc ?v being eliminated *)
   126     val vname = Name.variant (map fst
   127       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
   128     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
       (* each th is  lhs == rhs ; these return the two sides as cterms *)
   129     fun lhs_of th = snd (Thm.dest_comb
   130       (fst (Thm.dest_comb (cprop_of th))));
   131     fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
       (* every occurrence of Suc ?v in ct (only applications are traversed,
          abstractions are not entered), as a pair: ct with that occurrence
          replaced by cv, and the cterm ?v *)
   132     fun find_vars ct = (case term_of ct of
   133         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   134       | _ $ _ =>
   135         let val (ct1, ct2) = Thm.dest_comb ct
   136         in 
   137           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   138           map (apfst (Thm.capply ct1)) (find_vars ct2)
   139         end
   140       | _ => []);
       (* all candidate (equation, occurrence) pairs *)
   141     val eqs = maps
   142       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
       (* try one candidate: instantiate Suc_if_eq (its type variable with the
          type of ct; terms presumably in the order f, h, g, n with
          f := %cv. ct, h := %cv'. rhs of th, g left open, n := cv'),
          beta-normalise, and discharge the Suc premise with th generalised
          over cv'.  The remaining premise of th' is the 0 case. *)
   143     fun mk_thms (th, (ct, cv')) =
   144       let
   145         val th' =
   146           Thm.implies_elim
   147            (Conv.fconv_rule (Thm.beta_conversion true)
   148              (Drule.instantiate'
   149                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   150                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   151                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   152       in
       (* resolve every equation th'' of thms against th' (RS, with variables
          handled via Variable.trade); equations for which this raises THM
          are skipped.  If none fits, the candidate fails; otherwise th and
          the used equations (ths1) are removed and the combined equations
          (ths2) appended. *)
   153         case map_filter (fn th'' =>
   154             SOME (th'', singleton
   155               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   156           handle THM _ => NONE) thms of
   157             [] => NONE
   158           | thps =>
   159               let val (ths1, ths2) = split_list thps
   160               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   161       end
       (* the result of the first candidate that succeeds, if any *)
   162   in get_first mk_thms eqs end;
   163 
   164 fun eqn_suc_base_preproc thy thms =
   165   let
   166     val dest = fst o Logic.dest_equals o prop_of;
   167     val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   168   in
   169     if forall (can dest) thms andalso exists (contains_suc o dest) thms
   170       then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
   171        else NONE
   172   end;
   173 
   174 val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;
   175 
   176 fun remove_suc_clause thy thms =
   177   let
   178     val vname = Name.variant (map fst
   179       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
   180     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   181       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   182       | find_var _ = NONE;
   183     fun find_thm th =
   184       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   185       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   186   in
   187     case get_first find_thm thms of
   188       NONE => thms
   189     | SOME ((th, th'), (Sucv, v)) =>
   190         let
   191           val cert = cterm_of (Thm.theory_of_thm th);
   192           val th'' = ObjectLogic.rulify (Thm.implies_elim
   193             (Conv.fconv_rule (Thm.beta_conversion true)
   194               (Drule.instantiate' []
   195                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   196                    abstract_over (Sucv,
   197                      HOLogic.dest_Trueprop (prop_of th')))))),
   198                  SOME (cert v)] @{thm Suc_clause}))
   199             (Thm.forall_intr (cert v) th'))
   200         in
   201           remove_suc_clause thy (map (fn th''' =>
   202             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   203         end
   204   end;
   205 
   206 fun clause_suc_preproc thy ths =
   207   let
   208     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   209   in
   210     if forall (can (dest o concl_of)) ths andalso
   211       exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
   212         (map_filter (try dest) (concl_of th :: prems_of th))) ths
   213     then remove_suc_clause thy ths else ths
   214   end;
       (* Register the transformations with both code generators: eqn_Suc as a
          function transformer of Code_Preproc (for equations), and
          clause_suc_preproc as a Codegen preprocessor (for introduction
          rules). *)
   215 in
   216 
   217   Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)
   218   #> Codegen.add_preprocessor clause_suc_preproc
   219 
   220 end;
   221 *}
   222 (*>*)
   223 
   224 
   225 subsection {* Target language setup *}
   226 
   227 text {*
   228   For ML, we map @{typ nat} to target language integers, where we
   229   assert that values are always non-negative.
   230 *}
   231 
       (* code_type: in SML and OCaml, nat is the target's unbounded integer
          type. *)
   232 code_type nat
   233   (SML "IntInf.int")
   234   (OCaml "Big'_int.big'_int")
   235 
       (* types_code (Codegen): nat is printed as ML type int.  term_of_nat
          rebuilds a HOL numeral from a value.  gen_nat i returns a value drawn
          with random_range 0 i together with a thunk producing its term, for
          the test generator. *)
   236 types_code
   237   nat ("int")
   238 attach (term_of) {*
   239 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   240 *}
   241 attach (test) {*
   242 fun gen_nat i =
   243   let val n = random_range 0 i
   244   in (n, fn () => term_of_nat n) end;
   245 *}
   246 
   247 text {*
   248   For Haskell we define our own @{typ nat} type.  The reason
   249   is that we have to distinguish type class instances
   250   for @{typ nat} and @{typ int}.
   251 *}
   252 
   253 code_include Haskell "Nat" {*
   254 newtype Nat = Nat Integer deriving (Show, Eq);
   255 
       -- Num: arithmetic on the wrapped Integer.  fromInteger, and hence (-),
       -- clamp negative results to 0 (truncated subtraction on nat).
   256 instance Num Nat where {
   257   fromInteger k = Nat (if k >= 0 then k else 0);
   258   Nat n + Nat m = Nat (n + m);
   259   Nat n - Nat m = fromInteger (n - m);
   260   Nat n * Nat m = Nat (n * m);
   261   abs n = n;
         -- values are never negative: signum is 0 for zero and 1 otherwise
         signum (Nat n) = Nat (signum n);
   263   negate n = error "negate Nat";
   264 };
   265 
   266 instance Ord Nat where {
   267   Nat n <= Nat m = n <= m;
   268   Nat n < Nat m = n < m;
   269 };
   270 
   271 instance Real Nat where {
   272   toRational (Nat n) = toRational n;
   273 };
   274 
   275 instance Enum Nat where {
   276   toEnum k = fromInteger (toEnum k);
   277   fromEnum (Nat n) = fromEnum n;
   278 };
   279 
       -- Integral: operands are non-negative, so divMod and quotRem coincide.
       -- Division by zero follows HOL (divmod_nat n 0 = (0, n)) instead of
       -- raising an exception.
   280 instance Integral Nat where {
   281   toInteger (Nat n) = n;
   282   divMod n m = quotRem n m;
         quotRem (Nat n) (Nat m)
           | (m == 0) = (0, Nat n)
           | otherwise = (Nat k, Nat l) where (k, l) = quotRem n m;
   284 };
   285 *}
   286 
       (* Reserve the name Nat and map nat to the newtype above.  Equality
          comes from the derived Eq instance, so no eq instance is generated
          (Haskell -). *)
   287 code_reserved Haskell Nat
   288 
   289 code_type nat
   290   (Haskell "Nat.Nat")
   291 
   292 code_instance nat :: eq
   293   (Haskell -)
   294 
   295 text {*
   296   Natural numerals.
   297 *}
   298 
       (* nat applied to an int numeral is identified with the nat numeral
          (code_unfold_post). *)
   299 lemma [code_unfold_post]:
   300   "nat (number_of i) = number_nat_inst.number_of_nat i"
   301   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   302   by (simp add: number_nat_inst.number_of_nat)
   303 
       (* Print nat numerals as target-language literals in SML, OCaml and
          Haskell.  The two boolean flags are passed to Numeral.add_code as
          given; presumably "no negative numerals" and "unbounded integers"
          -- confirm against Numeral.add_code. *)
   304 setup {*
   305   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   306     false true) ["SML", "OCaml", "Haskell"]
   307 *}
   308 
   309 text {*
   310   Since natural numbers are implemented
   311   using integers in ML, the coercion function @{const "of_nat"} of type
   312   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   313   For the @{const "nat"} function for converting an integer to a natural
   314   number, we give a specific implementation using an ML function that
   315   returns its input value, provided that it is non-negative, and otherwise
   316   returns @{text "0"}.
   317 *}
   318 
       (* int: a local alias for of_nat :: nat => int, so that the coercion can
          get its own target mapping.  Its definition is not used as a code
          equation ([code del]), and the name is hidden at the end of the
          theory. *)
   319 definition int :: "nat \<Rightarrow> int" where
   320   [code del]: "int = of_nat"
   321 
       (* int and nat on numerals: negative numerals are mapped to 0. *)
   322 lemma int_code' [code]:
   323   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   324   unfolding int_nat_number_of [folded int_def] ..
   325 
   326 lemma nat_code' [code]:
   327   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   328   unfolding nat_number_of_def number_of_is_id neg_def by simp
   329 
       (* of_nat at type nat => int, and the generic of_nat_aux implementation
          at int, are rewritten to int so that its target mapping is used. *)
   330 lemma of_nat_int [code_unfold_post]:
   331   "of_nat = int" by (simp add: int_def)
   332 
   333 lemma of_nat_aux_int [code_unfold]:
   334   "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
   335   by (simp add: int_def Nat.of_nat_code)
   336 
       (* SML/OCaml: int is the identity, since nat is already represented by
          target integers. *)
   337 code_const int
   338   (SML "_")
   339   (OCaml "_")
   340 
       (* Codegen: int is the identity, and nat is the attached ML function
          mapping negative integers to 0. *)
   341 consts_code
   342   int ("(_)")
   343   nat ("\<module>nat")
   344 attach {*
   345 fun nat i = if i < 0 then 0 else i;
   346 *}
   347 
       (* SML/OCaml: nat i is max (0, i). *)
   348 code_const nat
   349   (SML "IntInf.max/ (/0,/ _)")
   350   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   351 
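   352 text {* For Haskell, @{text int} and @{text nat} are mapped to the conversion functions @{text toInteger} and @{text fromInteger} of the @{text Nat} type defined above. *}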
   353 
   354 code_const int and nat
   355   (Haskell "toInteger" and "fromInteger")
   356 
   357 text {* Conversion from and to indices. *}
   358 
   359 code_const Code_Numeral.of_nat
   360   (SML "IntInf.toInt")
   361   (OCaml "_")
   362   (Haskell "fromEnum")
   363 
   364 code_const Code_Numeral.nat_of
   365   (SML "IntInf.fromInt")
   366   (OCaml "_")
   367   (Haskell "toEnum")
   368 
   369 text {* Using target language arithmetic operations whenever appropriate *}
   370 
       (* Addition and multiplication: native target operations; the Haskell
          fixities match the Prelude operators. *)
   371 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   372   (SML "IntInf.+ ((_), (_))")
   373   (OCaml "Big'_int.add'_big'_int")
   374   (Haskell infixl 6 "+")
   375 
   376 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   377   (SML "IntInf.* ((_), (_))")
   378   (OCaml "Big'_int.mult'_big'_int")
   379   (Haskell infixl 7 "*")
   380 
       (* divmod_aux: native division with remainder.  The operands are
          non-negative, so the target's rounding convention agrees with nat
          div/mod. *)
   381 code_const divmod_aux
   382   (SML "IntInf.divMod/ ((_),/ (_))")
   383   (OCaml "Big'_int.quomod'_big'_int")
   384   (Haskell "divMod")
   385 
   386 code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   387   (SML "!((_ : IntInf.int) = _)")
   388   (OCaml "Big'_int.eq'_big'_int")
   389   (Haskell infixl 4 "==")
   390 
   391 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   392   (SML "IntInf.<= ((_), (_))")
   393   (OCaml "Big'_int.le'_big'_int")
   394   (Haskell infix 4 "<=")
   395 
   396 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   397   (SML "IntInf.< ((_), (_))")
   398   (OCaml "Big'_int.lt'_big'_int")
   399   (Haskell infix 4 "<")
   400 
       (* Codegen: nat constants and operations as ML integer syntax; Suc n is
          printed as n + 1. *)
   401 consts_code
   402   "0::nat"                     ("0")
   403   "1::nat"                     ("1")
   404   Suc                          ("(_ +/ 1)")
   405   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
   406   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
   407   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   408   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   409 
   410 
   411 text {* Evaluation *}
   412 
       (* Declaring this trivial equation with [code, code del] presumably
          removes existing code equations for term_of at nat, so that the SML
          mapping below is used instead -- confirm. *)
   413 lemma [code, code del]:
   414   "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..
   415 
       (* SML: build the HOL numeral term directly from the integer value. *)
   416 code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
   417   (SML "HOLogic.mk'_number/ HOLogic.natT")
   418 
   419 
   420 text {* Module names *}
   421 
       (* Code from this theory goes into a module named Arith in every
          target. *)
   422 code_modulename SML
   423   Efficient_Nat Arith
   424 
   425 code_modulename OCaml
   426   Efficient_Nat Arith
   427 
   428 code_modulename Haskell
   429   Efficient_Nat Arith
   430 
       (* int was only needed for the code setup above; hide its name. *)
   431 hide const int
   432 
   433 end