src/HOL/Library/Efficient_Nat.thy
author haftmann
Tue Oct 07 16:07:33 2008 +0200 (2008-10-07)
changeset 28522 eacb54d9e78d
parent 28423 9fc3befd8191
child 28562 4e74209f113e
permissions -rw-r--r--
only one theorem table for both code generators
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     4 *)
     5 
     6 header {* Implementation of natural numbers by target-language integers *}
     7 
     8 theory Efficient_Nat
     9 imports Plain Code_Index Code_Integer
    10 begin
    11 
    12 text {*
    13   When generating code for functions on natural numbers, the
    14   canonical representation using @{term "0::nat"} and
    15   @{term "Suc"} is unsuitable for computations involving large
    16   numbers.  The efficiency of the generated code can be improved
    17   drastically by implementing natural numbers by target-language
  integers.  To do this, just import this theory.
    19 *}
    20 
    21 subsection {* Basic arithmetic *}
    22 
    23 text {*
    24   Most standard arithmetic functions on natural numbers are implemented
    25   using their counterparts on the integers:
    26 *}
    27 
(* For code generation, nat is presented with number_of_nat (numerals) as its
   constructor instead of the 0/Suc representation. *)
code_datatype number_nat_inst.number_of_nat

(* 0 is rewritten to the numeral Numeral0 before code generation (code inline);
   the symmetric equation, declared code post, turns it back afterwards. *)
lemma zero_nat_code [code, code inline]:
  "0 = (Numeral0 :: nat)"
  by simp
lemmas [code post] = zero_nat_code [symmetric]

(* Same treatment for 1 and Numeral1. *)
lemma one_nat_code [code, code inline]:
  "1 = (Numeral1 :: nat)"
  by simp
lemmas [code post] = one_nat_code [symmetric]

(* Suc is no longer a constructor, so it is expressed by addition. *)
lemma Suc_code [code]:
  "Suc n = n + 1"
  by simp
    43 
(* Addition, subtraction and multiplication on nat are reduced to the integer
   operations: convert both arguments with of_nat, compute on int, and map the
   result back with nat.  For subtraction, nat maps negative differences to 0,
   which yields truncated subtraction on nat. *)
lemma plus_nat_code [code]:
  "n + m = nat (of_nat n + of_nat m)"
  by simp

lemma minus_nat_code [code]:
  "n - m = nat (of_nat n - of_nat m)"
  by simp

lemma times_nat_code [code]:
  "n * m = nat (of_nat n * of_nat m)"
  unfolding of_nat_mult [symmetric] by simp
    55 
    56 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    57   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    58 
(* divmod_aux is a copy of divmod that the code equation for divmod below only
   calls with a non-zero divisor.  Its own defining equation is excluded from
   code generation (code del); divmod_aux_code is used instead.
   NOTE(review): the split presumably exists so that the target mapping of
   divmod_aux (code_const divmod_aux) never sees a zero divisor. *)
definition
  divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat"
where
  [code del]: "divmod_aux = divmod"

(* Division by zero is handled here: divmod n 0 = (0, n). *)
lemma [code]:
  "divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
  unfolding divmod_aux_def divmod_div_mod by simp

(* Quotient and remainder computed via the integer div and mod. *)
lemma divmod_aux_code [code]:
  "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
  unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    71 
(* Equality on nat is decided on the integer images of both arguments. *)
lemma eq_nat_code [code]:
  "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
  by (simp add: eq)

(* Reflexivity of equality on nat, registered only for normalization by
   evaluation (code nbe). *)
lemma eq_nat_refl [code nbe]:
  "eq_class.eq (n::nat) n \<longleftrightarrow> True"
  by (rule HOL.eq_refl)

(* The orderings on nat are likewise decided on the integer images. *)
lemma less_eq_nat_code [code]:
  "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
  by simp

lemma less_nat_code [code]:
  "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
  by simp
    87 
    88 subsection {* Case analysis *}
    89 
    90 text {*
    91   Case analysis on natural numbers is rephrased using a conditional
    92   expression:
    93 *}
    94 
(* nat_case is replaced by a conditional, both as a code equation and as an
   unfold rule, because 0 and Suc are not constructors of the code datatype. *)
lemma [code, code unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    98 
    99 
   100 subsection {* Preprocessors *}
   101 
   102 text {*
   103   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   104   a constructor term. Therefore, all occurrences of this term in a position
   105   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   106   equation or in the arguments of an inductive relation in an introduction
   107   rule) must be eliminated.
   108   This can be accomplished by applying the following transformation rules:
   109 *}
   110 
(* Suc_if_eq merges an equation for f (Suc n) and one for f 0 into a single
   equation for f n with a conditional.  It is instantiated by the equation
   preprocessor (remove_suc) in the setup below. *)
lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
  f n = (if n = 0 then g else h (n - 1))"
  by (case_tac n) simp_all

(* Suc_clause turns a rule about P n (Suc n) into one about P (n - 1) n under
   the premise n ~= 0.  It is instantiated by the introduction-rule
   preprocessor (remove_suc_clause) in the setup below. *)
lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
  by (case_tac n) simp_all
   117 
   118 text {*
   119   The rules above are built into a preprocessor that is plugged into
   120   the code generator. Since the preprocessor for introduction rules
   121   does not know anything about modes, some of the modes that worked
   122   for the canonical representation of natural numbers may no longer work.
   123 *}
   124 
   125 (*<*)
   126 setup {*
   127 let
   128 
   129 fun remove_suc thy thms =
   130   let
   131     val vname = Name.variant (map fst
   132       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
   133     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
   134     fun lhs_of th = snd (Thm.dest_comb
   135       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
   136     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
   137     fun find_vars ct = (case term_of ct of
   138         (Const ("Suc", _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   139       | _ $ _ =>
   140         let val (ct1, ct2) = Thm.dest_comb ct
   141         in 
   142           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   143           map (apfst (Thm.capply ct1)) (find_vars ct2)
   144         end
   145       | _ => []);
   146     val eqs = maps
   147       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
   148     fun mk_thms (th, (ct, cv')) =
   149       let
   150         val th' =
   151           Thm.implies_elim
   152            (Conv.fconv_rule (Thm.beta_conversion true)
   153              (Drule.instantiate'
   154                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   155                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   156                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   157       in
   158         case map_filter (fn th'' =>
   159             SOME (th'', singleton
   160               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   161           handle THM _ => NONE) thms of
   162             [] => NONE
   163           | thps =>
   164               let val (ths1, ths2) = split_list thps
   165               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   166       end
   167   in case get_first mk_thms eqs of
   168       NONE => thms
   169     | SOME x => remove_suc thy x
   170   end;
   171 
   172 fun eqn_suc_preproc thy ths =
   173   let
   174     val dest = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of;
   175     fun contains_suc t = member (op =) (term_consts t) @{const_name Suc};
   176   in
   177     if forall (can dest) ths andalso
   178       exists (contains_suc o dest) ths
   179     then remove_suc thy ths else ths
   180   end;
   181 
   182 fun remove_suc_clause thy thms =
   183   let
   184     val vname = Name.variant (map fst
   185       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
   186     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   187       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   188       | find_var _ = NONE;
   189     fun find_thm th =
   190       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   191       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   192   in
   193     case get_first find_thm thms of
   194       NONE => thms
   195     | SOME ((th, th'), (Sucv, v)) =>
   196         let
   197           val cert = cterm_of (Thm.theory_of_thm th);
   198           val th'' = ObjectLogic.rulify (Thm.implies_elim
   199             (Conv.fconv_rule (Thm.beta_conversion true)
   200               (Drule.instantiate' []
   201                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   202                    abstract_over (Sucv,
   203                      HOLogic.dest_Trueprop (prop_of th')))))),
   204                  SOME (cert v)] @{thm Suc_clause}))
   205             (Thm.forall_intr (cert v) th'))
   206         in
   207           remove_suc_clause thy (map (fn th''' =>
   208             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   209         end
   210   end;
   211 
   212 fun clause_suc_preproc thy ths =
   213   let
   214     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   215   in
   216     if forall (can (dest o concl_of)) ths andalso
   217       exists (fn th => member (op =) (foldr add_term_consts
   218         [] (map_filter (try dest) (concl_of th :: prems_of th))) "Suc") ths
   219     then remove_suc_clause thy ths else ths
   220   end;
   221 
   222 fun lift f thy eqns1 =
   223   let
   224     val eqns2 = burrow_fst Drule.zero_var_indexes_list eqns1;
   225     val thms3 = try (map fst
   226       #> map (fn thm => thm RS @{thm meta_eq_to_obj_eq})
   227       #> f thy
   228       #> map (fn thm => thm RS @{thm eq_reflection})
   229       #> map (Conv.fconv_rule Drule.beta_eta_conversion)) eqns2;
   230     val thms4 = Option.map Drule.zero_var_indexes_list thms3;
   231   in case thms4
   232    of NONE => NONE
   233     | SOME thms4 => if Thm.eq_thms (map fst eqns2, thms4)
   234         then NONE else SOME (map (apfst (AxClass.overload thy) o Code_Unit.mk_eqn thy) thms4)
   235           (*FIXME*)
   236   end
   237 
   238 in
   239 
   240   Codegen.add_preprocessor eqn_suc_preproc
   241   #> Codegen.add_preprocessor clause_suc_preproc
   242   #> Code.add_functrans ("eqn_Suc", lift eqn_suc_preproc)
   243   #> Code.add_functrans ("clause_Suc", lift clause_suc_preproc)
   244 
   245 end;
   246 *}
   247 (*>*)
   248 
   249 
   250 subsection {* Target language setup *}
   251 
   252 text {*
   253   For ML, we map @{typ nat} to target language integers, where we
   254   assert that values are always non-negative.
   255 *}
   256 
(* Code generator: nat is represented by the targets' integer types IntInf.int
   (SML) and Big_int.big_int (OCaml). *)
code_type nat
  (SML "IntInf.int")
  (OCaml "Big'_int.big'_int")

(* types_code/consts_code generator: nat is represented by ML int.
   term_of_nat reifies a value as a HOL numeral of type nat.  gen_nat is the
   random test-value generator: it draws n with random_range 0 i and returns
   it together with a thunk producing its term. *)
types_code
  nat ("int")
attach (term_of) {*
val term_of_nat = HOLogic.mk_number HOLogic.natT;
*}
attach (test) {*
fun gen_nat i =
  let val n = random_range 0 i
  in (n, fn () => term_of_nat n) end;
*}
   271 
   272 text {*
   273   For Haskell we define our own @{typ nat} type.  The reason
   274   is that we have to distinguish type class instances
   275   for @{typ nat} and @{typ int}.
   276 *}
   277 
   278 code_include Haskell "Nat" {*
   279 newtype Nat = Nat Integer deriving (Show, Eq);
   280 
   281 instance Num Nat where {
   282   fromInteger k = Nat (if k >= 0 then k else 0);
   283   Nat n + Nat m = Nat (n + m);
   284   Nat n - Nat m = fromInteger (n - m);
   285   Nat n * Nat m = Nat (n * m);
   286   abs n = n;
   287   signum _ = 1;
   288   negate n = error "negate Nat";
   289 };
   290 
   291 instance Ord Nat where {
   292   Nat n <= Nat m = n <= m;
   293   Nat n < Nat m = n < m;
   294 };
   295 
   296 instance Real Nat where {
   297   toRational (Nat n) = toRational n;
   298 };
   299 
   300 instance Enum Nat where {
   301   toEnum k = fromInteger (toEnum k);
   302   fromEnum (Nat n) = fromEnum n;
   303 };
   304 
   305 instance Integral Nat where {
   306   toInteger (Nat n) = n;
   307   divMod n m = quotRem n m;
   308   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   309 };
   310 *}
   311 
(* The name Nat is reserved in generated Haskell code because the included
   Haskell module defines a type of that name. *)
code_reserved Haskell Nat

code_type nat
  (Haskell "Nat")

(* No eq instance is generated for nat in Haskell.  NOTE(review): presumably
   the Eq instance derived for the Nat newtype is used instead. *)
code_instance nat :: eq
  (Haskell -)
   319 
   320 text {*
   321   Natural numerals.
   322 *}
   323 
(* Before code generation, nat (number_of i) is rewritten to the numeral
   constructor number_of_nat i (code inline).  The symmetric equation is
   declared code post and reverts this afterwards. *)
lemma [code inline, symmetric, code post]:
  "nat (number_of i) = number_nat_inst.number_of_nat i"
  -- {* this interacts as desired with @{thm nat_number_of_def} *}
  by (simp add: number_nat_inst.number_of_nat)

(* Register target-language syntax for nat numerals in SML, OCaml and Haskell.
   NOTE(review): the meaning of the two boolean flags passed to
   Numeral.add_code is not visible here; check its definition. *)
setup {*
  fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
    true false) ["SML", "OCaml", "Haskell"]
*}
   333 
   334 text {*
   335   Since natural numbers are implemented
   336   using integers in ML, the coercion function @{const "of_nat"} of type
   337   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   338   For the @{const "nat"} function for converting an integer to a natural
   339   number, we give a specific implementation using an ML function that
   340   returns its input value, provided that it is non-negative, and otherwise
   341   returns @{text "0"}.
   342 *}
   343 
   344 definition
   345   int :: "nat \<Rightarrow> int"
   346 where
   347   [code func del]: "int = of_nat"
   348 
   349 lemma int_code' [code func]:
   350   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   351   unfolding int_nat_number_of [folded int_def] ..
   352 
   353 lemma nat_code' [code func]:
   354   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   355   by auto
   356 
   357 lemma of_nat_int [code unfold]:
   358   "of_nat = int" by (simp add: int_def)
   359 declare of_nat_int [symmetric, code post]
   360 
(* In SML and OCaml, nat and int share the same integer representation, so int
   is the identity. *)
code_const int
  (SML "_")
  (OCaml "_")

(* types_code/consts_code generator: int is the identity, and nat uses the
   attached ML function, which maps negative integers to 0. *)
consts_code
  int ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if i < 0 then 0 else i;
*}

(* nat i is computed as max (0, i) in SML and OCaml. *)
code_const nat
  (SML "IntInf.max/ (/0,/ _)")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")

text {* For Haskell, things are slightly different again. *}

(* In Haskell, int and nat convert between Nat and Integer.  fromInteger on Nat
   clamps negative values to 0 (see the included Haskell module). *)
code_const int and nat
  (Haskell "toInteger" and "fromInteger")
   380 
   381 text {* Conversion from and to indices. *}
   382 
(* Conversion from nat to index: IntInf.toInt (SML), Big_int.int_of_big_int
   (OCaml), fromEnum (Haskell).  NOTE(review): these presumably fail for
   values outside the machine-integer range; confirm against the target
   libraries. *)
code_const index_of_nat
  (SML "IntInf.toInt")
  (OCaml "Big'_int.int'_of'_big'_int")
  (Haskell "fromEnum")

(* Conversion from index to nat: IntInf.fromInt (SML),
   Big_int.big_int_of_int (OCaml), toEnum (Haskell). *)
code_const nat_of_index
  (SML "IntInf.fromInt")
  (OCaml "Big'_int.big'_int'_of'_int")
  (Haskell "toEnum")
   392 
   393 text {* Using target language arithmetic operations whenever appropriate *}
   394 
(* Addition and multiplication on nat map directly to the target integer
   operations.  No mapping for subtraction is declared here; NOTE(review):
   subtraction presumably goes through the code equation minus_nat_code,
   which truncates at 0. *)
code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.+ ((_), (_))")
  (OCaml "Big'_int.add'_big'_int")
  (Haskell infixl 6 "+")

code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.* ((_), (_))")
  (OCaml "Big'_int.mult'_big'_int")
  (Haskell infixl 7 "*")

(* divmod_aux is only reached with a non-zero divisor, because the code
   equation for divmod handles m = 0 itself. *)
code_const divmod_aux
  (SML "IntInf.divMod/ ((_),/ (_))")
  (OCaml "Big'_int.quomod'_big'_int")
  (Haskell "divMod")

(* Equality and orderings use the target integer comparisons; in SML the
   equality is annotated with IntInf.int. *)
code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "!((_ : IntInf.int) = _)")
  (OCaml "Big'_int.eq'_big'_int")
  (Haskell infixl 4 "==")

code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.<= ((_), (_))")
  (OCaml "Big'_int.le'_big'_int")
  (Haskell infix 4 "<=")

code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.< ((_), (_))")
  (OCaml "Big'_int.lt'_big'_int")
  (Haskell infix 4 "<")

(* The same operations, plus 0, 1 and Suc, for the types_code/consts_code
   generator. *)
consts_code
  "0::nat"                     ("0")
  "1::nat"                     ("1")
  Suc                          ("(_ +/ 1)")
  "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
  "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
  "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
  "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   433 
   434 
   435 text {* Evaluation *}
   436 
   437 lemma [code func, code func del]:
   438   "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..
   439 
   440 code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
   441   (SML "HOLogic.mk'_number/ HOLogic.natT")
   442 
   443 
   444 text {* Module names *}
   445 
(* In every target, code generated from the theories Nat, Divides and
   Efficient_Nat is placed in a single module named Integer. *)
code_modulename SML
  Nat Integer
  Divides Integer
  Efficient_Nat Integer

code_modulename OCaml
  Nat Integer
  Divides Integer
  Efficient_Nat Integer

code_modulename Haskell
  Nat Integer
  Divides Integer
  Efficient_Nat Integer

(* The auxiliary constant int is hidden from the name space. *)
hide const int
   462 
   463 end