src/HOL/Library/Efficient_Nat.thy
author haftmann
Thu Sep 25 10:17:22 2008 +0200 (2008-09-25)
changeset 28351 abfc66969d1f
parent 28346 b8390cd56b8f
child 28423 9fc3befd8191
permissions -rw-r--r--
non left-linear equations for nbe
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     4 *)
     5 
     6 header {* Implementation of natural numbers by target-language integers *}
     7 
     8 theory Efficient_Nat
     9 imports Plain Code_Index Code_Integer
    10 begin
    11 
text {*
  When generating code for functions on natural numbers, the
  canonical representation using @{term "0::nat"} and
  @{term "Suc"} is unsuitable for computations involving large
  numbers.  The efficiency of the generated code can be improved
  drastically by implementing natural numbers by target-language
  integers.  To do this, just import this theory.
*}
    20 
    21 subsection {* Basic arithmetic *}
    22 
    23 text {*
    24   Most standard arithmetic functions on natural numbers are implemented
    25   using their counterparts on the integers:
    26 *}
    27 
(* In generated code, nat values are built with the numeral constructor
   number_nat_inst.number_of_nat instead of 0 and Suc. *)
code_datatype number_nat_inst.number_of_nat

(* 0 and 1 are rewritten into numerals when preprocessing code equations
   ([code unfold]); the symmetric forms are registered as post-processing
   rules ([code post]) so that numerals are shown as 0 and 1 again. *)
lemma zero_nat_code [code, code unfold]:
  "0 = (Numeral0 :: nat)"
  by simp
lemmas [code post] = zero_nat_code [symmetric]

lemma one_nat_code [code, code unfold]:
  "1 = (Numeral1 :: nat)"
  by simp
lemmas [code post] = one_nat_code [symmetric]
    39 
(* Suc is expressed by addition, since it is no longer a code constructor. *)
lemma Suc_code [code]:
  "Suc n = n + 1"
  by simp

(* Arithmetic on nat is reduced to the corresponding operation on int;
   the function nat maps the int result back to nat. *)
lemma plus_nat_code [code]:
  "n + m = nat (of_nat n + of_nat m)"
  by simp

(* Truncated subtraction: a negative int difference is mapped to 0 by nat. *)
lemma minus_nat_code [code]:
  "n - m = nat (of_nat n - of_nat m)"
  by simp

lemma times_nat_code [code]:
  "n * m = nat (of_nat n * of_nat m)"
  unfolding of_nat_mult [symmetric] by simp
    55 
    56 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    57   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    58 
(* divmod_aux is a copy of divmod whose defining equation is excluded
   from code generation ([code func del]); it serves as the hook for a
   target-language division on a non-zero divisor. *)
definition
  divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat"
where
  [code func del]: "divmod_aux = divmod"

(* The zero divisor is handled here: divmod n 0 = (0, n).  Hence
   divmod_aux is only reached with m ~= 0. *)
lemma [code func]:
  "divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
  unfolding divmod_aux_def divmod_div_mod by simp

(* Generic code equation for divmod_aux via int division and modulus. *)
lemma divmod_aux_code [code]:
  "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
  unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    71 
(* Equality on nat is decided on the corresponding int values. *)
lemma eq_nat_code [code]:
  "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
  by (simp add: eq)

(* Reflexivity equation; its left-hand side is not left-linear (n occurs
   twice), so it is registered for normalization by evaluation only
   ([code nbe]). *)
lemma eq_nat_refl [code nbe]:
  "eq_class.eq (n::nat) n \<longleftrightarrow> True"
  by (rule HOL.eq_refl)

(* The orderings on nat are likewise decided on int. *)
lemma less_eq_nat_code [code]:
  "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
  by simp

lemma less_nat_code [code]:
  "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
  by simp
    87 
    88 subsection {* Case analysis *}
    89 
    90 text {*
    91   Case analysis on natural numbers is rephrased using a conditional
    92   expression:
    93 *}
    94 
(* nat_case is replaced by a conditional on n = 0, both as code equation
   and as unfolding rule, because 0/Suc are no longer code constructors. *)
lemma [code func, code unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    98 
    99 
   100 subsection {* Preprocessors *}
   101 
   102 text {*
   103   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   104   a constructor term. Therefore, all occurrences of this term in a position
   105   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   106   equation or in the arguments of an inductive relation in an introduction
   107   rule) must be eliminated.
   108   This can be accomplished by applying the following transformation rules:
   109 *}
   110 
(* Combines an equation for f (Suc n) and an equation for f 0 into one
   conditional equation for f n; instantiated by the ML function
   remove_suc in the setup below. *)
lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
  f n = (if n = 0 then g else h (n - 1))"
  by (case_tac n) simp_all

(* Restates a clause about Suc n as a clause about n under the side
   condition n ~= 0; instantiated by remove_suc_clause in the setup below. *)
lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
  by (case_tac n) simp_all
   117 
   118 text {*
   119   The rules above are built into a preprocessor that is plugged into
   120   the code generator. Since the preprocessor for introduction rules
   121   does not know anything about modes, some of the modes that worked
   122   for the canonical representation of natural numbers may no longer work.
   123 *}
   124 
   125 (*<*)
   126 setup {*
   127 let
   128 
   129 fun remove_suc thy thms =
   130   let
   131     val vname = Name.variant (map fst
   132       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
   133     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
   134     fun lhs_of th = snd (Thm.dest_comb
   135       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
   136     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
   137     fun find_vars ct = (case term_of ct of
   138         (Const ("Suc", _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   139       | _ $ _ =>
   140         let val (ct1, ct2) = Thm.dest_comb ct
   141         in 
   142           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   143           map (apfst (Thm.capply ct1)) (find_vars ct2)
   144         end
   145       | _ => []);
   146     val eqs = maps
   147       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
   148     fun mk_thms (th, (ct, cv')) =
   149       let
   150         val th' =
   151           Thm.implies_elim
   152            (Conv.fconv_rule (Thm.beta_conversion true)
   153              (Drule.instantiate'
   154                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   155                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   156                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   157       in
   158         case map_filter (fn th'' =>
   159             SOME (th'', singleton
   160               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   161           handle THM _ => NONE) thms of
   162             [] => NONE
   163           | thps =>
   164               let val (ths1, ths2) = split_list thps
   165               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   166       end
   167   in case get_first mk_thms eqs of
   168       NONE => thms
   169     | SOME x => remove_suc thy x
   170   end;
   171 
   172 fun eqn_suc_preproc thy ths =
   173   let
   174     val dest = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of;
   175     fun contains_suc t = member (op =) (term_consts t) @{const_name Suc};
   176   in
   177     if forall (can dest) ths andalso
   178       exists (contains_suc o dest) ths
   179     then remove_suc thy ths else ths
   180   end;
   181 
   182 fun remove_suc_clause thy thms =
   183   let
   184     val vname = Name.variant (map fst
   185       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
   186     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   187       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   188       | find_var _ = NONE;
   189     fun find_thm th =
   190       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   191       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   192   in
   193     case get_first find_thm thms of
   194       NONE => thms
   195     | SOME ((th, th'), (Sucv, v)) =>
   196         let
   197           val cert = cterm_of (Thm.theory_of_thm th);
   198           val th'' = ObjectLogic.rulify (Thm.implies_elim
   199             (Conv.fconv_rule (Thm.beta_conversion true)
   200               (Drule.instantiate' []
   201                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   202                    abstract_over (Sucv,
   203                      HOLogic.dest_Trueprop (prop_of th')))))),
   204                  SOME (cert v)] @{thm Suc_clause}))
   205             (Thm.forall_intr (cert v) th'))
   206         in
   207           remove_suc_clause thy (map (fn th''' =>
   208             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   209         end
   210   end;
   211 
   212 fun clause_suc_preproc thy ths =
   213   let
   214     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   215   in
   216     if forall (can (dest o concl_of)) ths andalso
   217       exists (fn th => member (op =) (foldr add_term_consts
   218         [] (map_filter (try dest) (concl_of th :: prems_of th))) "Suc") ths
   219     then remove_suc_clause thy ths else ths
   220   end;
   221 
   222 fun lift f thy thms1 =
   223   let
   224     val thms2 = Drule.zero_var_indexes_list thms1;
   225     val thms3 = try (map (fn thm => thm RS @{thm meta_eq_to_obj_eq})
   226       #> f thy
   227       #> map (fn thm => thm RS @{thm eq_reflection})
   228       #> map (Conv.fconv_rule Drule.beta_eta_conversion)) thms2;
   229     val thms4 = Option.map Drule.zero_var_indexes_list thms3;
   230   in case thms4
   231    of NONE => NONE
   232     | SOME thms4 => if Thm.eq_thms (thms2, thms4) then NONE else SOME thms4
   233   end
   234 
   235 in
   236 
   237   Codegen.add_preprocessor eqn_suc_preproc
   238   #> Codegen.add_preprocessor clause_suc_preproc
   239   #> Code.add_functrans ("eqn_Suc", lift eqn_suc_preproc)
   240   #> Code.add_functrans ("clause_Suc", lift clause_suc_preproc)
   241 
   242 end;
   243 *}
   244 (*>*)
   245 
   246 
   247 subsection {* Target language setup *}
   248 
   249 text {*
   250   For ML, we map @{typ nat} to target language integers, where we
   251   assert that values are always non-negative.
   252 *}
   253 
(* SML and OCaml: nat is represented by the arbitrary-precision integer
   types IntInf.int and Big_int.big_int. *)
code_type nat
  (SML "IntInf.int")
  (OCaml "Big'_int.big'_int")

(* Codegen (types_code): nat is represented by ML int.  The attached
   functions reconstruct a HOL numeral term from a value and produce
   random test values via random_range 0 i. *)
types_code
  nat ("int")
attach (term_of) {*
val term_of_nat = HOLogic.mk_number HOLogic.natT;
*}
attach (test) {*
fun gen_nat i =
  let val n = random_range 0 i
  in (n, fn () => term_of_nat n) end;
*}
   268 
   269 text {*
   270   For Haskell we define our own @{typ nat} type.  The reason
   271   is that we have to distinguish type class instances
   272   for @{typ nat} and @{typ int}.
   273 *}
   274 
   275 code_include Haskell "Nat" {*
   276 newtype Nat = Nat Integer deriving (Show, Eq);
   277 
   278 instance Num Nat where {
   279   fromInteger k = Nat (if k >= 0 then k else 0);
   280   Nat n + Nat m = Nat (n + m);
   281   Nat n - Nat m = fromInteger (n - m);
   282   Nat n * Nat m = Nat (n * m);
   283   abs n = n;
   284   signum _ = 1;
   285   negate n = error "negate Nat";
   286 };
   287 
   288 instance Ord Nat where {
   289   Nat n <= Nat m = n <= m;
   290   Nat n < Nat m = n < m;
   291 };
   292 
   293 instance Real Nat where {
   294   toRational (Nat n) = toRational n;
   295 };
   296 
   297 instance Enum Nat where {
   298   toEnum k = fromInteger (toEnum k);
   299   fromEnum (Nat n) = fromEnum n;
   300 };
   301 
   302 instance Integral Nat where {
   303   toInteger (Nat n) = n;
   304   divMod n m = quotRem n m;
   305   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   306 };
   307 *}
   308 
(* Reserve the identifier Nat in generated Haskell code and map HOL's nat
   to the Haskell Nat type from the "Nat" include above.  No eq instance
   is generated for Haskell ("-"); Nat derives Eq itself. *)
code_reserved Haskell Nat

code_type nat
  (Haskell "Nat")

code_instance nat :: eq
  (Haskell -)
   316 
   317 text {*
   318   Natural numerals.
   319 *}
   320 
(* The equation is registered as inline rule; after [symmetric], its
   reversed form is registered as post-processing rule. *)
lemma [code inline, symmetric, code post]:
  "nat (number_of i) = number_nat_inst.number_of_nat i"
  -- {* this interacts as desired with @{thm nat_number_of_def} *}
  by (simp add: number_nat_inst.number_of_nat)

(* Register target-language printing of nat numerals for SML, OCaml and
   Haskell.  NOTE(review): meaning of the flags true/false is defined by
   Numeral.add_code, not visible here. *)
setup {*
  fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
    true false) ["SML", "OCaml", "Haskell"]
*}
   330 
   331 text {*
   332   Since natural numbers are implemented
   333   using integers in ML, the coercion function @{const "of_nat"} of type
   334   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   335   For the @{const "nat"} function for converting an integer to a natural
   336   number, we give a specific implementation using an ML function that
   337   returns its input value, provided that it is non-negative, and otherwise
   338   returns @{text "0"}.
   339 *}
   340 
(* Auxiliary copy of of_nat that can be mapped to target-language code
   independently; its definition is not used as code equation. *)
definition
  int :: "nat \<Rightarrow> int"
where
  [code func del]: "int = of_nat"

(* Code equations for int and nat applied to numerals: a negative
   numeral is mapped to 0, any other numeral to itself. *)
lemma int_code' [code func]:
  "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding int_nat_number_of [folded int_def] ..

lemma nat_code' [code func]:
  "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  by auto

(* of_nat is replaced by int in code equations and turned back into
   of_nat in results. *)
lemma of_nat_int [code unfold]:
  "of_nat = int" by (simp add: int_def)
declare of_nat_int [symmetric, code post]
   357 
(* SML/OCaml: nat is already a target-language integer (see code_type
   nat), so int is the identity. *)
code_const int
  (SML "_")
  (OCaml "_")

(* Codegen: int is the identity; nat maps negative values to 0 via the
   attached ML function. *)
consts_code
  int ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if i < 0 then 0 else i;
*}

(* nat i = max (0, i) in SML and OCaml. *)
code_const nat
  (SML "IntInf.max/ (/0,/ _)")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")

text {* For Haskell, things are slightly different again. *}

(* Haskell: conversions through the Integral/Num methods of Nat;
   fromInteger clamps negative values to 0. *)
code_const int and nat
  (Haskell "toInteger" and "fromInteger")
   377 
   378 text {* Conversion from and to indices. *}
   379 
(* Conversions between nat and index (assumed to be the machine-integer
   type from Code_Index -- confirm).  In SML, IntInf.toInt raises
   Overflow for values outside the range of int. *)
code_const index_of_nat
  (SML "IntInf.toInt")
  (OCaml "Big'_int.int'_of'_big'_int")
  (Haskell "fromEnum")

code_const nat_of_index
  (SML "IntInf.fromInt")
  (OCaml "Big'_int.big'_int'_of'_int")
  (Haskell "toEnum")
   389 
   390 text {* Using target language arithmetic operations whenever appropriate *}
   391 
(* Addition and multiplication map directly to target-language integer
   operations.  Subtraction is not mapped here, so it is generated from
   the code equation minus_nat_code (truncated via nat). *)
code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.+ ((_), (_))")
  (OCaml "Big'_int.add'_big'_int")
  (Haskell infixl 6 "+")

code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.* ((_), (_))")
  (OCaml "Big'_int.mult'_big'_int")
  (Haskell infixl 7 "*")

(* divmod_aux is only called with a non-zero divisor (see the divmod code
   equation); for non-negative operands the target operations agree with
   HOL's div/mod. *)
code_const divmod_aux
  (SML "IntInf.divMod/ ((_),/ (_))")
  (OCaml "Big'_int.quomod'_big'_int")
  (Haskell "divMod")

(* Equality and orderings as target-language integer comparisons. *)
code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "!((_ : IntInf.int) = _)")
  (OCaml "Big'_int.eq'_big'_int")
  (Haskell infixl 4 "==")

code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.<= ((_), (_))")
  (OCaml "Big'_int.le'_big'_int")
  (Haskell infix 4 "<=")

code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.< ((_), (_))")
  (OCaml "Big'_int.lt'_big'_int")
  (Haskell infix 4 "<")

(* Codegen counterparts: Suc n is printed as n + 1. *)
consts_code
  0                            ("0")
  Suc                          ("(_ +/ 1)")
  "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
  "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
  "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
  "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   429 
   430 
   431 text {* Evaluation *}
   432 
(* The trivial equation for term_of at type nat is declared and
   immediately deleted as code equation.  NOTE(review): presumably this
   suppresses the default term_of equations for nat so that the SML
   mapping below is used -- confirm. *)
lemma [code func, code func del]:
  "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..

(* SML evaluation: a nat value becomes a HOL numeral term of type nat. *)
code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
  (SML "HOLogic.mk'_number/ HOLogic.natT")
   438 
   439 
   440 text {* Module names *}
   441 
(* For every target, code from the theories Nat, Divides and Efficient_Nat
   is placed in the module Integer. *)
code_modulename SML
  Nat Integer
  Divides Integer
  Efficient_Nat Integer

code_modulename OCaml
  Nat Integer
  Divides Integer
  Efficient_Nat Integer

code_modulename Haskell
  Nat Integer
  Divides Integer
  Efficient_Nat Integer

(* Hide the auxiliary constant int from the name space. *)
hide const int
   458 
   459 end