src/HOL/Library/Efficient_Nat.thy
author wenzelm
Tue Dec 30 21:46:48 2008 +0100 (2008-12-30)
changeset 29258 bce03c644efb
parent 28969 4ed63cdda799
child 29270 0eade173f77e
permissions -rw-r--r--
canonical Term.add_var_names;
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     4 *)
     5 
     6 header {* Implementation of natural numbers by target-language integers *}
     7 
     8 theory Efficient_Nat
     9 imports Code_Index Code_Integer
    10 begin
    11 
    12 text {*
    13   When generating code for functions on natural numbers, the
    14   canonical representation using @{term "0::nat"} and
    15   @{term "Suc"} is unsuitable for computations involving large
    16   numbers.  The efficiency of the generated code can be improved
    17   drastically by implementing natural numbers by target-language
    18   integers.  To do this, just include this theory.
    19 *}
    20 
    21 subsection {* Basic arithmetic *}
    22 
    23 text {*
    24   Most standard arithmetic functions on natural numbers are implemented
    25   using their counterparts on the integers:
    26 *}
    27 
(* Declare the numeral embedding number_nat_inst.number_of_nat as the code
   constructor of nat, so that code equations below are stated over numerals
   instead of 0 / Suc. *)
code_datatype number_nat_inst.number_of_nat
    29 
(* The literals 0 and 1 on nat are rewritten into numeral form for code
   generation ("code inline"), and the symmetric forms are registered as
   "code post" rules to fold numerals back into 0 / 1 afterwards. *)
lemma zero_nat_code [code, code inline]:
  "0 = (Numeral0 :: nat)"
  by simp
lemmas [code post] = zero_nat_code [symmetric]

lemma one_nat_code [code, code inline]:
  "1 = (Numeral1 :: nat)"
  by simp
lemmas [code post] = one_nat_code [symmetric]
    39 
(* Suc is not a code constructor any more (see code_datatype above), so it is
   implemented as addition of 1. *)
lemma Suc_code [code]:
  "Suc n = n + 1"
  by simp

(* +, - and * on nat are computed on the int images and mapped back with nat.
   For -, nat sends a negative difference to 0, which is exactly truncated
   subtraction on nat. *)
lemma plus_nat_code [code]:
  "n + m = nat (of_nat n + of_nat m)"
  by simp

lemma minus_nat_code [code]:
  "n - m = nat (of_nat n - of_nat m)"
  by simp

lemma times_nat_code [code]:
  "n * m = nat (of_nat n * of_nat m)"
  unfolding of_nat_mult [symmetric] by simp
    55 
    56 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    57   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    58 
(* divmod_aux is a copy of divmod; [code del] keeps its defining equation
   "divmod_aux = divmod" out of the code equations.  The code equation for
   divmod below calls divmod_aux only with a nonzero divisor, so a
   target-language implementation of divmod_aux never sees division by 0. *)
definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
  [code del]: "divmod_aux = divmod"

lemma [code]:
  "divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
  unfolding divmod_aux_def divmod_div_mod by simp

(* Generic code equation for divmod_aux via integer div / mod. *)
lemma divmod_aux_code [code]:
  "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
  unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    69 
(* Equality on nat is decided on the int images. *)
lemma eq_nat_code [code]:
  "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
  by (simp add: eq)

(* Reflexivity of equality, registered only for normalization by evaluation
   ([code nbe]). *)
lemma eq_nat_refl [code nbe]:
  "eq_class.eq (n::nat) n \<longleftrightarrow> True"
  by (rule HOL.eq_refl)

(* The orderings on nat are likewise decided on the int images. *)
lemma less_eq_nat_code [code]:
  "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
  by simp

lemma less_nat_code [code]:
  "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
  by simp
    85 
    86 subsection {* Case analysis *}
    87 
    88 text {*
    89   Case analysis on natural numbers is rephrased using a conditional
    90   expression:
    91 *}
    92 
(* nat_case is replaced by a conditional on n = 0, with n - 1 playing the
   role of the Suc predecessor; used as a code equation and as an unfold rule. *)
lemma [code, code unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    96 
    97 
    98 subsection {* Preprocessors *}
    99 
   100 text {*
   101   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   102   a constructor term. Therefore, all occurrences of this term in a position
   103   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   104   equation or in the arguments of an inductive relation in an introduction
   105   rule) must be eliminated.
   106   This can be accomplished by applying the following transformation rules:
   107 *}
   108 
(* Suc_if_eq: an equation for f (Suc n) together with one for f 0 yields a
   single equation for f n whose right-hand side is a conditional, so the
   left-hand side no longer contains a Suc pattern.  The ML preprocessor
   remove_suc instantiates this rule positionally, so its statement (and thus
   the order of its schematic variables) must not be changed casually. *)
lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
  f n = (if n = 0 then g else h (n - 1))"
  by (case_tac n) simp_all

(* Suc_clause: a property stated for (n, Suc n) transfers to (n - 1, n) under
   the side condition n \<noteq> 0; used by remove_suc_clause for introduction
   rules. *)
lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
  by (case_tac n) simp_all
   115 
   116 text {*
   117   The rules above are built into a preprocessor that is plugged into
   118   the code generator. Since the preprocessor for introduction rules
   119   does not know anything about modes, some of the modes that worked
   120   for the canonical representation of natural numbers may no longer work.
   121 *}
   122 
   123 (*<*)
   124 setup {*
   125 let
   126 
   127 fun remove_suc thy thms =
   128   let
   129     val vname = Name.variant (map fst
   130       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
   131     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
   132     fun lhs_of th = snd (Thm.dest_comb
   133       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
   134     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
   135     fun find_vars ct = (case term_of ct of
   136         (Const ("Suc", _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   137       | _ $ _ =>
   138         let val (ct1, ct2) = Thm.dest_comb ct
   139         in 
   140           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   141           map (apfst (Thm.capply ct1)) (find_vars ct2)
   142         end
   143       | _ => []);
   144     val eqs = maps
   145       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
   146     fun mk_thms (th, (ct, cv')) =
   147       let
   148         val th' =
   149           Thm.implies_elim
   150            (Conv.fconv_rule (Thm.beta_conversion true)
   151              (Drule.instantiate'
   152                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   153                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   154                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   155       in
   156         case map_filter (fn th'' =>
   157             SOME (th'', singleton
   158               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   159           handle THM _ => NONE) thms of
   160             [] => NONE
   161           | thps =>
   162               let val (ths1, ths2) = split_list thps
   163               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   164       end
   165   in case get_first mk_thms eqs of
   166       NONE => thms
   167     | SOME x => remove_suc thy x
   168   end;
   169 
   170 fun eqn_suc_preproc thy ths =
   171   let
   172     val dest = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of;
   173     fun contains_suc t = member (op =) (term_consts t) @{const_name Suc};
   174   in
   175     if forall (can dest) ths andalso
   176       exists (contains_suc o dest) ths
   177     then remove_suc thy ths else ths
   178   end;
   179 
   180 fun remove_suc_clause thy thms =
   181   let
   182     val vname = Name.variant (map fst
   183       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
   184     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   185       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   186       | find_var _ = NONE;
   187     fun find_thm th =
   188       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   189       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   190   in
   191     case get_first find_thm thms of
   192       NONE => thms
   193     | SOME ((th, th'), (Sucv, v)) =>
   194         let
   195           val cert = cterm_of (Thm.theory_of_thm th);
   196           val th'' = ObjectLogic.rulify (Thm.implies_elim
   197             (Conv.fconv_rule (Thm.beta_conversion true)
   198               (Drule.instantiate' []
   199                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   200                    abstract_over (Sucv,
   201                      HOLogic.dest_Trueprop (prop_of th')))))),
   202                  SOME (cert v)] @{thm Suc_clause}))
   203             (Thm.forall_intr (cert v) th'))
   204         in
   205           remove_suc_clause thy (map (fn th''' =>
   206             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   207         end
   208   end;
   209 
   210 fun clause_suc_preproc thy ths =
   211   let
   212     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   213   in
   214     if forall (can (dest o concl_of)) ths andalso
   215       exists (fn th => member (op =) (foldr add_term_consts
   216         [] (map_filter (try dest) (concl_of th :: prems_of th))) "Suc") ths
   217     then remove_suc_clause thy ths else ths
   218   end;
   219 
(* Adapt a preprocessor f on object-level equations to the function
   transformer interface of the Code generator.  eqns1 is a list of pairs
   whose first components are the code equations (meta-level "=="); these are
   converted to object-level "=", passed through f, converted back and
   beta-eta normalized.  Returns NONE if any step raises or if the result
   equals the input (up to variable renaming); otherwise SOME of the new
   equations, rebuilt with Code_Unit.mk_eqn and AxClass.overload. *)
fun lift f thy eqns1 =
  let
    (* Normalize variable indices so the before/after comparison is stable. *)
    val eqns2 = burrow_fst Drule.zero_var_indexes_list eqns1;
    val thms3 = try (map fst
      #> map (fn thm => thm RS @{thm meta_eq_to_obj_eq})
      #> f thy
      #> map (fn thm => thm RS @{thm eq_reflection})
      #> map (Conv.fconv_rule Drule.beta_eta_conversion)) eqns2;
    val thms4 = Option.map Drule.zero_var_indexes_list thms3;
  in case thms4
   of NONE => NONE
    | SOME thms4 => if Thm.eq_thms (map fst eqns2, thms4)
        then NONE else SOME (map (apfst (AxClass.overload thy) o Code_Unit.mk_eqn thy) thms4)
  end
   234 
in

  (* Register both Suc-eliminating preprocessors with the Codegen-based
     generator directly, and with the Code generator as named function
     transformers adapted by lift. *)
  Codegen.add_preprocessor eqn_suc_preproc
  #> Codegen.add_preprocessor clause_suc_preproc
  #> Code.add_functrans ("eqn_Suc", lift eqn_suc_preproc)
  #> Code.add_functrans ("clause_Suc", lift clause_suc_preproc)

end;
   243 *}
   244 (*>*)
   245 
   246 
   247 subsection {* Target language setup *}
   248 
   249 text {*
   250   For ML, we map @{typ nat} to target language integers, where we
   251   assert that values are always non-negative.
   252 *}
   253 
(* Code generator: nat is represented by the unbounded integer types of SML
   (IntInf.int) and OCaml (Big_int.big_int). *)
code_type nat
  (SML "IntInf.int")
  (OCaml "Big'_int.big'_int")

(* Codegen-based generator (types_code): nat is mapped to ML "int".  term_of
   turns a value back into a HOL numeral of type nat; gen_nat produces test
   values from random_range 0 i (presumably uniform in 0..i -- confirm) and
   pairs each with a thunk yielding its term. *)
types_code
  nat ("int")
attach (term_of) {*
val term_of_nat = HOLogic.mk_number HOLogic.natT;
*}
attach (test) {*
fun gen_nat i =
  let val n = random_range 0 i
  in (n, fn () => term_of_nat n) end;
*}
   268 
   269 text {*
   270   For Haskell we define our own @{typ nat} type.  The reason
   271   is that we have to distinguish type class instances
   272   for @{typ nat} and @{typ int}.
   273 *}
   274 
   275 code_include Haskell "Nat" {*
   276 newtype Nat = Nat Integer deriving (Show, Eq);
   277 
   278 instance Num Nat where {
   279   fromInteger k = Nat (if k >= 0 then k else 0);
   280   Nat n + Nat m = Nat (n + m);
   281   Nat n - Nat m = fromInteger (n - m);
   282   Nat n * Nat m = Nat (n * m);
   283   abs n = n;
   284   signum _ = 1;
   285   negate n = error "negate Nat";
   286 };
   287 
   288 instance Ord Nat where {
   289   Nat n <= Nat m = n <= m;
   290   Nat n < Nat m = n < m;
   291 };
   292 
   293 instance Real Nat where {
   294   toRational (Nat n) = toRational n;
   295 };
   296 
   297 instance Enum Nat where {
   298   toEnum k = fromInteger (toEnum k);
   299   fromEnum (Nat n) = fromEnum n;
   300 };
   301 
   302 instance Integral Nat where {
   303   toInteger (Nat n) = n;
   304   divMod n m = quotRem n m;
   305   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   306 };
   307 *}
   308 
(* Reserve the identifier Nat in generated Haskell code (it names the newtype
   from the "Nat" include) and map HOL's nat to that type. *)
code_reserved Haskell Nat

code_type nat
  (Haskell "Nat")

(* No eq instance is generated for nat in Haskell ("-"): the Nat newtype
   already derives Eq. *)
code_instance nat :: eq
  (Haskell -)
   316 
   317 text {*
   318   Natural numerals.
   319 *}
   320 
(* nat applied to an integer numeral is the nat numeral itself: used
   left-to-right as an inline rule and, after the symmetric attribute,
   right-to-left as a post-processing rule. *)
lemma [code inline, symmetric, code post]:
  "nat (number_of i) = number_nat_inst.number_of_nat i"
  -- {* this interacts as desired with @{thm nat_number_of_def} *}
  by (simp add: number_nat_inst.number_of_nat)

(* Emit nat numerals as target-language numerals for SML, OCaml and Haskell;
   the flags true / false are passed unchanged to Numeral.add_code (see its
   definition for their meaning). *)
setup {*
  fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
    true false) ["SML", "OCaml", "Haskell"]
*}
   330 
   331 text {*
   332   Since natural numbers are implemented
   333   using integers in ML, the coercion function @{const "of_nat"} of type
   334   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
  For the function @{const "nat"}, which converts an integer to a natural
  number, we give a specific implementation using an ML function that
  returns its input value if it is non-negative and @{text "0"} otherwise.
   339 *}
   340 
(* int: an auxiliary alias of of_nat at type nat => int; its defining equation
   is excluded from code generation ([code del]), and the constant is hidden
   at the end of this theory. *)
definition
  int :: "nat \<Rightarrow> int"
where
  [code del]: "int = of_nat"

(* Code equations for int and nat on numerals: 0 for a negative numeral, the
   numeral itself otherwise. *)
lemma int_code' [code]:
  "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding int_nat_number_of [folded int_def] ..

lemma nat_code' [code]:
  "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding nat_number_of_def number_of_is_id neg_def by simp

(* of_nat is replaced by int before code generation and restored afterwards. *)
lemma of_nat_int [code unfold]:
  "of_nat = int" by (simp add: int_def)
declare of_nat_int [symmetric, code post]
   357 
(* ML: int (nat => int) is the identity, as nat values are already
   non-negative target integers. *)
code_const int
  (SML "_")
  (OCaml "_")

(* Codegen-based generator: int is the identity; nat is the attached ML
   function mapping negative inputs to 0. *)
consts_code
  int ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if i < 0 then 0 else i;
*}

(* Code generator: nat i = max (0, i). *)
code_const nat
  (SML "IntInf.max/ (/0,/ _)")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   372 
   373 text {* For Haskell, things are slightly different again. *}
   374 
(* Haskell: int is toInteger of the Integral Nat instance; nat is fromInteger
   of the Num Nat instance, which clamps negative values to 0. *)
code_const int and nat
  (Haskell "toInteger" and "fromInteger")
   377 
   378 text {* Conversion from and to indices. *}
   379 
(* Conversions between nat and index, mapped to the targets' conversions
   between unbounded integers and machine ints.  NOTE(review): out-of-range
   values in index_of_nat behave as the target conversion does (e.g. SML
   IntInf.toInt raises Overflow) -- confirm callers can tolerate this. *)
code_const index_of_nat
  (SML "IntInf.toInt")
  (OCaml "Big'_int.int'_of'_big'_int")
  (Haskell "fromEnum")

code_const nat_of_index
  (SML "IntInf.fromInt")
  (OCaml "Big'_int.big'_int'_of'_int")
  (Haskell "toEnum")
   389 
   390 text {* Using target language arithmetic operations whenever appropriate *}
   391 
(* Map nat addition, multiplication, division, equality and ordering directly
   onto target-language integer operations.  divmod_aux is only reached with
   a nonzero divisor (see its code equation), and all nat values are
   non-negative, so truncating and flooring division give the same result. *)
code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.+ ((_), (_))")
  (OCaml "Big'_int.add'_big'_int")
  (Haskell infixl 6 "+")

code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.* ((_), (_))")
  (OCaml "Big'_int.mult'_big'_int")
  (Haskell infixl 7 "*")

code_const divmod_aux
  (SML "IntInf.divMod/ ((_),/ (_))")
  (OCaml "Big'_int.quomod'_big'_int")
  (Haskell "divMod")

code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "!((_ : IntInf.int) = _)")
  (OCaml "Big'_int.eq'_big'_int")
  (Haskell infixl 4 "==")

code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.<= ((_), (_))")
  (OCaml "Big'_int.le'_big'_int")
  (Haskell infix 4 "<=")

code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.< ((_), (_))")
  (OCaml "Big'_int.lt'_big'_int")
  (Haskell infix 4 "<")

(* Codegen-based generator: the corresponding operations on ML int; Suc is
   emitted as addition of 1. *)
consts_code
  "0::nat"                     ("0")
  "1::nat"                     ("1")
  Suc                          ("(_ +/ 1)")
  "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
  "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
  "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
  "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   430 
   431 
   432 text {* Evaluation *}
   433 
(* NOTE(review): attaching [code, code del] to this reflexive fact presumably
   makes term_of at nat known to the code generator without leaving a usable
   equation, so the SML mapping below (building a HOL numeral of type nat) is
   used instead -- confirm. *)
lemma [code, code del]:
  "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..

code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
  (SML "HOLogic.mk'_number/ HOLogic.natT")
   439 
   440 
   441 text {* Module names *}
   442 
(* Generated code from Nat, Divides, Ring_and_Field and this theory is placed
   in a module named Integer, for each of SML, OCaml and Haskell. *)
code_modulename SML
  Nat Integer
  Divides Integer
  Ring_and_Field Integer
  Efficient_Nat Integer

code_modulename OCaml
  Nat Integer
  Divides Integer
  Ring_and_Field Integer
  Efficient_Nat Integer

code_modulename Haskell
  Nat Integer
  Divides Integer
  Ring_and_Field Integer
  Efficient_Nat Integer

(* Hide the auxiliary constant int from the name space. *)
hide const int
   462 
   463 end