src/HOL/Library/Efficient_Nat.thy
author haftmann
Tue Sep 16 09:21:24 2008 +0200 (2008-09-16)
changeset 28228 7ebe8dc06cbb
parent 27673 52056ddac194
child 28346 b8390cd56b8f
permissions -rw-r--r--
evaluation using code generator
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     4 *)
     5 
     6 header {* Implementation of natural numbers by target-language integers *}
     7 
     8 theory Efficient_Nat
     9 imports Plain Code_Index Code_Integer
    10 begin
    11 
    12 text {*
    13   When generating code for functions on natural numbers, the
    14   canonical representation using @{term "0::nat"} and
    15   @{term "Suc"} is unsuitable for computations involving large
    16   numbers.  The efficiency of the generated code can be improved
    17   drastically by implementing natural numbers by target-language
    18   integers.  To do this, just include this theory.
    19 *}
    20 
    21 subsection {* Basic arithmetic *}
    22 
    23 text {*
    24   Most standard arithmetic functions on natural numbers are implemented
    25   using their counterparts on the integers:
    26 *}
    27 
    (* nat values are built from the numeral constructor number_of_nat
       (declared as code datatype).  0 and 1 are unfolded into numerals
       before code generation and folded back by the symmetric
       postprocessing rules. *)
    28 code_datatype number_nat_inst.number_of_nat
    29
    30 lemma zero_nat_code [code, code unfold]:
    31   "0 = (Numeral0 :: nat)"
    32   by simp
    33 lemmas [code post] = zero_nat_code [symmetric]
    34
    35 lemma one_nat_code [code, code unfold]:
    36   "1 = (Numeral1 :: nat)"
    37   by simp
    38 lemmas [code post] = one_nat_code [symmetric]
    39 
    (* Suc is expressed by addition.  +, - and * are computed on int via
       of_nat and converted back with nat.  nat maps negative integers to 0,
       so the subtraction equation yields truncated subtraction. *)
    40 lemma Suc_code [code]:
    41   "Suc n = n + 1"
    42   by simp
    43
    44 lemma plus_nat_code [code]:
    45   "n + m = nat (of_nat n + of_nat m)"
    46   by simp
    47
    48 lemma minus_nat_code [code]:
    49   "n - m = nat (of_nat n - of_nat m)"
    50   by simp
    51
    52 lemma times_nat_code [code]:
    53   "n * m = nat (of_nat n * of_nat m)"
    54   unfolding of_nat_mult [symmetric] by simp
    55 
    56 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    57   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    58 
    (* divmod_aux is a copy of divmod whose defining equation is excluded
       from code generation.  The code equation for divmod answers m = 0
       with (0, n) and otherwise calls divmod_aux.  Through divmod, therefore,
       divmod_aux is only reached with a non-zero divisor.  divmod_aux_code
       computes it via int div and mod. *)
    59 definition
    60   divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat"
    61 where
    62   [code func del]: "divmod_aux = divmod"
    63
    64 lemma [code func]:
    65   "divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
    66   unfolding divmod_aux_def divmod_div_mod by simp
    67
    68 lemma divmod_aux_code [code]:
    69   "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
    70   unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    71 
    (* Equality and order on nat are decided by comparing the int images
       of the operands. *)
    72 lemma eq_nat_code [code]:
    73   "n = m \<longleftrightarrow> (of_nat n \<Colon> int) = of_nat m"
    74   by simp
    75
    76 lemma less_eq_nat_code [code]:
    77   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
    78   by simp
    79
    80 lemma less_nat_code [code]:
    81   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
    82   by simp
    83 
    84 subsection {* Case analysis *}
    85 
    86 text {*
    87   Case analysis on natural numbers is rephrased using a conditional
    88   expression:
    89 *}
    90 
    (* Suc is not a constructor of the code datatype, so nat_case is
       replaced by a conditional on n = 0.  The replacement is used both as
       code equation and as unfolding rule. *)
    91 lemma [code func, code unfold]:
    92   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    93   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    94 
    95 
    96 subsection {* Preprocessors *}
    97 
    98 text {*
    99   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   100   a constructor term. Therefore, all occurrences of this term in a position
   101   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   102   equation or in the arguments of an inductive relation in an introduction
   103   rule) must be eliminated.
   104   This can be accomplished by applying the following transformation rules:
   105 *}
   106 
   (* Transformation rules used by the Suc-eliminating preprocessors set up
      below.  Suc_if_eq combines an equation for f (Suc n) with an equation
      for f 0 into a single conditional equation on a plain n.  Suc_clause
      turns a rule P n (Suc n) into P (n - 1) n under the premise
      n \<noteq> 0. *)
   107 lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
   108   f n = (if n = 0 then g else h (n - 1))"
   109   by (case_tac n) simp_all
   110
   111 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   112   by (case_tac n) simp_all
   113 
   114 text {*
   115   The rules above are built into a preprocessor that is plugged into
   116   the code generator. Since the preprocessor for introduction rules
   117   does not know anything about modes, some of the modes that worked
   118   for the canonical representation of natural numbers may no longer work.
   119 *}
   120 
   121 (*<*)
   122 setup {*
   123 let
   124
       (* remove_suc thy thms: THMS are object-level equations lhs = rhs.
          Repeatedly pick an equation whose lhs contains a subterm Suc v with
          v schematic, and instantiate Suc_if_eq with it.  Then resolve the
          result with every equation that fits the corresponding 0 case.  The
          consumed equations are replaced by the merged ones.  The process
          stops when no such merge succeeds. *)
   125 fun remove_suc thy thms =
   126   let
           (* fresh schematic variable standing in for an abstracted Suc v *)
   127     val vname = Name.variant (map fst
   128       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
   129     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
           (* certified lhs and rhs of a theorem Trueprop (lhs = rhs) *)
   130     fun lhs_of th = snd (Thm.dest_comb
   131       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
   132     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
           (* one result per occurrence of Suc v in ct: ct with that
              occurrence replaced by cv, paired with v *)
   133     fun find_vars ct = (case term_of ct of
   134         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   135       | _ $ _ =>
   136         let val (ct1, ct2) = Thm.dest_comb ct
   137         in
   138           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   139           map (apfst (Thm.capply ct1)) (find_vars ct2)
   140         end
   141       | _ => []);
   142     val eqs = maps
   143       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
           (* ct is the lhs of th with Suc cv' replaced by cv.  th' is Suc_if_eq
              with its first premise discharged by th.  The remaining 0-case
              premise is resolved against each equation in thms.  Returns NONE
              if no equation resolves. *)
   144     fun mk_thms (th, (ct, cv')) =
   145       let
   146         val th' =
   147           Thm.implies_elim
   148            (Conv.fconv_rule (Thm.beta_conversion true)
   149              (Drule.instantiate'
   150                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   151                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   152                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   153       in
   154         case map_filter (fn th'' =>
   155             SOME (th'', singleton
   156               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   157           handle THM _ => NONE) thms of
   158             [] => NONE
   159           | thps =>
   160               let val (ths1, ths2) = split_list thps
   161               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   162       end
   163   in case get_first mk_thms eqs of
   164       NONE => thms
   165     | SOME x => remove_suc thy x
   166   end;
   167
       (* apply remove_suc only if every theorem is an equation and the lhs
          of at least one of them mentions Suc *)
   168 fun eqn_suc_preproc thy ths =
   169   let
   170     val dest = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of;
   171     fun contains_suc t = member (op =) (term_consts t) @{const_name Suc};
   172   in
   173     if forall (can dest) ths andalso
   174       exists (contains_suc o dest) ths
   175     then remove_suc thy ths else ths
   176   end;
   177
       (* remove_suc_clause thy thms: take the first theorem whose atomized
          proposition contains Suc v with v schematic.  Use Suc_clause to turn
          it into a rule under the premise v <> 0.  In the new rule that
          Suc v becomes v and the remaining occurrences of v become v - 1.
          The theorem is replaced by the result, and the process repeats
          until no theorem contains such a pattern. *)
   178 fun remove_suc_clause thy thms =
   179   let
   182     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   183       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   184       | find_var _ = NONE;
           (* pair a theorem and its atomized form with the first Suc v found *)
   185     fun find_thm th =
   186       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   187       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   188   in
   189     case get_first find_thm thms of
   190       NONE => thms
   191     | SOME ((th, th'), (Sucv, v)) =>
   192         let
   193           val cert = cterm_of (Thm.theory_of_thm th);
   194           val th'' = ObjectLogic.rulify (Thm.implies_elim
   195             (Conv.fconv_rule (Thm.beta_conversion true)
   196               (Drule.instantiate' []
   197                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   198                    abstract_over (Sucv,
   199                      HOLogic.dest_Trueprop (prop_of th')))))),
   200                  SOME (cert v)] @{thm Suc_clause}))
   201             (Thm.forall_intr (cert v) th'))
   202         in
   203           remove_suc_clause thy (map (fn th''' =>
   204             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   205         end
   206   end;
   207
       (* apply remove_suc_clause only if every conclusion is a set
          membership and, for some theorem, the element of a membership in
          its conclusion or premises mentions Suc *)
   208 fun clause_suc_preproc thy ths =
   209   let
   210     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   211   in
   212     if forall (can (dest o concl_of)) ths andalso
   213       exists (fn th => member (op =) (foldr add_term_consts
   214         [] (map_filter (try dest) (concl_of th :: prems_of th))) @{const_name Suc}) ths
   215     then remove_suc_clause thy ths else ths
   216   end;
   217
       (* lift f: turn a preprocessor f on object-level equations into a
          code function transformer on meta-level equations.  Variable
          indices are normalised on input and output.  Returns NONE if any
          step raises an exception or the theorems come back unchanged, and
          SOME of the new theorems otherwise. *)
   218 fun lift f thy thms1 =
   219   let
   220     val thms2 = Drule.zero_var_indexes_list thms1;
   221     val thms3 = try (map (fn thm => thm RS @{thm meta_eq_to_obj_eq})
   222       #> f thy
   223       #> map (fn thm => thm RS @{thm eq_reflection})
   224       #> map (Conv.fconv_rule Drule.beta_eta_conversion)) thms2;
   225     val thms4 = Option.map Drule.zero_var_indexes_list thms3;
   226   in case thms4
   227    of NONE => NONE
   228     | SOME thms4 => if Thm.eq_thms (thms2, thms4) then NONE else SOME thms4
   229   end
   230
   231 in
   232
       (* register both preprocessors with the old code generator, and their
          lifted forms as function transformers of the new one *)
   233   Codegen.add_preprocessor eqn_suc_preproc
   234   #> Codegen.add_preprocessor clause_suc_preproc
   235   #> Code.add_functrans ("eqn_Suc", lift eqn_suc_preproc)
   236   #> Code.add_functrans ("clause_Suc", lift clause_suc_preproc)
   237
   238 end;
   239 *}
   240 (*>*)
   241 
   242 
   243 subsection {* Target language setup *}
   244 
   245 text {*
   246   For ML, we map @{typ nat} to target language integers, where we
   247   assert that values are always non-negative.
   248 *}
   249 
   (* ML targets: nat is represented by the targets' big-integer types,
      IntInf.int for SML and Big_int.big_int for OCaml.  For the old code
      generator (types_code), nat becomes ML int.  term_of_nat builds a nat
      numeral term, and gen_nat draws a random test value with
      random_range 0 i. *)
   250 code_type nat
   251   (SML "IntInf.int")
   252   (OCaml "Big'_int.big'_int")
   253
   254 types_code
   255   nat ("int")
   256 attach (term_of) {*
   257 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   258 *}
   259 attach (test) {*
   260 fun gen_nat i =
   261   let val n = random_range 0 i
   262   in (n, fn () => term_of_nat n) end;
   263 *}
   264 
   265 text {*
   266   For Haskell we define our own @{typ nat} type.  The reason
   267   is that we have to distinguish type class instances
   268   for @{typ nat} and @{typ int}.
   269 *}
   270 
   (* Haskell runtime support: Nat is a newtype over Integer.
      - fromInteger maps negative integers to 0, which also makes
        subtraction truncating.
      - divMod is implemented by quotRem, which coincides with it on
        non-negative operands.
      - negate is unsupported and raises an error.
      NOTE(review): signum returns 1 even for 0. *)
   271 code_include Haskell "Nat" {*
   272 newtype Nat = Nat Integer deriving (Show, Eq);
   273
   274 instance Num Nat where {
   275   fromInteger k = Nat (if k >= 0 then k else 0);
   276   Nat n + Nat m = Nat (n + m);
   277   Nat n - Nat m = fromInteger (n - m);
   278   Nat n * Nat m = Nat (n * m);
   279   abs n = n;
   280   signum _ = 1;
   281   negate n = error "negate Nat";
   282 };
   283
   284 instance Ord Nat where {
   285   Nat n <= Nat m = n <= m;
   286   Nat n < Nat m = n < m;
   287 };
   288
   289 instance Real Nat where {
   290   toRational (Nat n) = toRational n;
   291 };
   292
   293 instance Enum Nat where {
   294   toEnum k = fromInteger (toEnum k);
   295   fromEnum (Nat n) = fromEnum n;
   296 };
   297
   298 instance Integral Nat where {
   299   toInteger (Nat n) = n;
   300   divMod n m = quotRem n m;
   301   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   302 };
   303 *}
   304 
   (* Reserve the identifier Nat in generated Haskell code.  Map nat to the
      included Nat type, and emit no eq instance for it (presumably because
      Nat already derives Eq). *)
   305 code_reserved Haskell Nat
   306
   307 code_type nat
   308   (Haskell "Nat")
   309
   310 code_instance nat :: eq
   311   (Haskell -)
   312 
   313 text {*
   314   Natural numerals.
   315 *}
   316 
   (* nat applied to an integer numeral is inlined as a nat numeral.  The
      attributes act left to right, so the symmetric form of the equation is
      the one registered as postprocessing rule.  The setup registers
      numeral syntax for nat in SML, OCaml and Haskell; see Numeral.add_code
      for the meaning of the two boolean flags. *)
   317 lemma [code inline, symmetric, code post]:
   318   "nat (number_of i) = number_nat_inst.number_of_nat i"
   319   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   320   by (simp add: number_nat_inst.number_of_nat)
   321
   322 setup {*
   323   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   324     true false) ["SML", "OCaml", "Haskell"]
   325 *}
   326 
   327 text {*
   328   Since natural numbers are implemented
   329   using integers in ML, the coercion function @{const "of_nat"} of type
   330   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   331   For the @{const "nat"} function for converting an integer to a natural
   332   number, we give a specific implementation using an ML function that
   333   returns its input value, provided that it is non-negative, and otherwise
   334   returns @{text "0"}.
   335 *}
   336 
   (* int is an auxiliary copy of of_nat at the fixed type nat => int.  Its
      defining equation is excluded from code generation.  of_nat is
      unfolded to int before code generation and folded back afterwards,
      presumably so that int can receive target-specific implementations.
      int_code' and nat_code' are code equations on numerals: if the numeral
      is negative as an integer, the result is 0. *)
   337 definition
   338   int :: "nat \<Rightarrow> int"
   339 where
   340   [code func del]: "int = of_nat"
   341
   342 lemma int_code' [code func]:
   343   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   344   unfolding int_nat_number_of [folded int_def] ..
   345
   346 lemma nat_code' [code func]:
   347   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   348   by auto
   349
   350 lemma of_nat_int [code unfold]:
   351   "of_nat = int" by (simp add: int_def)
   352 declare of_nat_int [symmetric, code post]
   353 
   (* int is the identity in SML and OCaml.  nat maps negative integers
      to 0:
      - for the old code generator via the attached ML function nat;
      - for SML via IntInf.max with 0;
      - for OCaml via max_big_int with zero_big_int. *)
   354 code_const int
   355   (SML "_")
   356   (OCaml "_")
   357
   358 consts_code
   359   int ("(_)")
   360   nat ("\<module>nat")
   361 attach {*
   362 fun nat i = if i < 0 then 0 else i;
   363 *}
   364
   365 code_const nat
   366   (SML "IntInf.max/ (/0,/ _)")
   367   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   368 
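   369 text {* For Haskell, the conversions are @{text toInteger} and @{text fromInteger}; the latter maps negative integers to @{text "0"}, as defined in the @{text Num} instance of @{text Nat}. *}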
   370 
   (* Haskell: int unwraps a Nat via toInteger.  nat uses fromInteger,
      which the Num instance of Nat defines to map negative integers
      to 0. *)
   371 code_const int and nat
   372   (Haskell "toInteger" and "fromInteger")
   373 
   374 text {* Conversion from and to indices. *}
   375 
   (* Conversions between nat and the index type (presumably the one from
      the imported Code_Index theory).
      NOTE(review): index_of_nat narrows an unbounded integer to the
      target's index representation, and its behaviour on out-of-range
      values is target-dependent.  Confirm that callers stay in range. *)
   376 code_const index_of_nat
   377   (SML "IntInf.toInt")
   378   (OCaml "Big'_int.int'_of'_big'_int")
   379   (Haskell "fromEnum")
   380
   381 code_const nat_of_index
   382   (SML "IntInf.fromInt")
   383   (OCaml "Big'_int.big'_int'_of'_int")
   384   (Haskell "toEnum")
   385 
   386 text {* Use target-language arithmetic operations wherever appropriate. *}
   387 
   (* Addition, multiplication and combined division/remainder use the
      target operations directly.  Through the code equation for divmod,
      divmod_aux is only reached with a non-zero divisor.  Subtraction is
      not mapped here; it is computed through minus_nat_code via int. *)
   388 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   389   (SML "IntInf.+ ((_), (_))")
   390   (OCaml "Big'_int.add'_big'_int")
   391   (Haskell infixl 6 "+")
   392
   393 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   394   (SML "IntInf.* ((_), (_))")
   395   (OCaml "Big'_int.mult'_big'_int")
   396   (Haskell infixl 7 "*")
   397
   398 code_const divmod_aux
   399   (SML "IntInf.divMod/ ((_),/ (_))")
   400   (OCaml "Big'_int.quomod'_big'_int")
   401   (Haskell "divMod")
   402 
   (* Comparisons map to the target comparisons.  In Haskell, ==, <= and <
      are non-associative (infix 4), so all three are declared infix rather
      than infixl.  A left-associative declaration would let the printer
      drop the parentheses in nested comparisons such as (a == b) == c,
      which Haskell rejects. *)
   403 code_const "op = \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   404   (SML "!((_ : IntInf.int) = _)")
   405   (OCaml "Big'_int.eq'_big'_int")
   406   (Haskell infix 4 "==")
   407
   408 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   409   (SML "IntInf.<= ((_), (_))")
   410   (OCaml "Big'_int.le'_big'_int")
   411   (Haskell infix 4 "<=")
   412
   413 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   414   (SML "IntInf.< ((_), (_))")
   415   (OCaml "Big'_int.lt'_big'_int")
   416   (Haskell infix 4 "<")
   417 
   (* Old code generator: nat constants and operations map directly to ML
      integer syntax; Suc becomes addition of 1. *)
   418 consts_code
   419   0                            ("0")
   420   Suc                          ("(_ +/ 1)")
   421   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
   422   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
   423   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   424   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   425 
   426 
   427 text {* Evaluation *}
   428 
   (* Evaluation: the trivial equation is declared a code equation and
      immediately deleted again, which leaves term_of on nat without code
      equations.  NOTE(review): presumably this suppresses generated default
      equations; confirm.  In SML, term_of on nat builds a nat numeral term
      with HOLogic.mk_number. *)
   429 lemma [code func, code func del]:
   430   "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..
   431
   432 code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
   433   (SML "HOLogic.mk'_number/ HOLogic.natT")
   434 
   435 
   436 text {* Module names *}
   437 
   (* Place the code from the theories Nat, Divides and Efficient_Nat into
      the module Integer for every target.  Afterwards, hide the auxiliary
      constant int. *)
   438 code_modulename SML
   439   Nat Integer
   440   Divides Integer
   441   Efficient_Nat Integer
   442
   443 code_modulename OCaml
   444   Nat Integer
   445   Divides Integer
   446   Efficient_Nat Integer
   447
   448 code_modulename Haskell
   449   Nat Integer
   450   Divides Integer
   451   Efficient_Nat Integer
   452
   453 hide const int
   454 
   455 end