src/HOL/Library/EfficientNat.thy
author huffman
Wed Jun 13 03:31:11 2007 +0200 (2007-06-13)
changeset 23365 f31794033ae1
parent 23269 851b8ea067ac
child 23394 474ff28210c0
permissions -rw-r--r--
removed constant int :: nat => int;
int is now an abbreviation for of_nat :: nat => int
     1 (*  Title:      HOL/Library/EfficientNat.thy
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen
     4 *)
     5 
     6 header {* Implementation of natural numbers by integers *}
     7 
     8 theory EfficientNat
     9 imports Main Pretty_Int
    10 begin
    11 
    12 text {*
    13 When generating code for functions on natural numbers, the canonical
    14 representation using @{term "0::nat"} and @{term "Suc"} is unsuitable for
    15 computations involving large numbers. The efficiency of the generated
    16 code can be improved drastically by implementing natural numbers by
    17 integers. To do this, just import this theory.
    18 *}
    19 
    20 subsection {* Logical rewrites *}
    21 
    22 text {*
    23   An int-to-nat conversion
    24   restricted to non-negative ints (in contrast to @{const nat}, which
         maps every negative int to @{text "0"}).
    25   Note that this restriction has no logical relevance and
    26   is just a kind of proof hint -- nothing prevents you from
    27   writing nonsense like @{term "nat_of_int (-4)"}, whose value is unspecified.
    28 *}
    29 
       (* Conditional definition: only non-negative arguments are specified.
          For k >= 0, nat_of_int k coincides with nat k; for negative k no
          equation is given, so the value is left unspecified. *)
    30 definition
    31   nat_of_int :: "int \<Rightarrow> nat" where
    32   "k \<ge> 0 \<Longrightarrow> nat_of_int k = nat k"
    33 
       (* int' is a separate name for of_nat at type nat => int.  It is the
          coercion used in the code equations below, and code_const int'
          further down maps it to the identity in the target languages. *)
    34 definition
    35   int' :: "nat \<Rightarrow> int" where
    36   "int' n = of_nat n"
    37 
       (* Simp rule: int' of a successor is one more than int' n. *)
    38 lemma int'_Suc [simp]: "int' (Suc n) = 1 + int' n"
    39 unfolding int'_def by simp
    40 
       (* int' commutes with addition and with multiplication; these facts
          justify the nat_of_int-based code equations for both operations. *)
    41 lemma int'_add: "int' (m + n) = int' m + int' n"
    42 unfolding int'_def by (rule of_nat_add)
    43 
    44 lemma int'_mult: "int' (m * n) = int' m * int' n"
    45 unfolding int'_def by (rule of_nat_mult)
    46 
       (* A non-negative numeral of type nat equals nat_of_int applied to the
          same numeral at type int.  Instantiated by the ML preprocessor
          nat_of_int_of_number_of further below. *)
    47 lemma nat_of_int_of_number_of:
    48   fixes k
    49   assumes "k \<ge> 0"
    50   shows "number_of k = nat_of_int (number_of k)"
    51   unfolding nat_of_int_def [OF prems] nat_number_of_def number_of_is_id ..
    52 
       (* Turns the simplified comparison of Numeral.Pls with k, as produced
          by the ML preprocessor below, into the side condition k >= 0. *)
    53 lemma nat_of_int_of_number_of_aux:
    54   fixes k
    55   assumes "Numeral.Pls \<le> k \<equiv> True"
    56   shows "k \<ge> 0"
    57   using prems unfolding Pls_def by simp
    58 
       (* nat_of_int is a left inverse of int'. *)
    59 lemma nat_of_int_int:
    60   "nat_of_int (int' n) = n"
    61   using nat_of_int_def int'_def by simp
    62 
       (* From int' n = x conclude n = nat_of_int x; used to derive the
          nat_of_int-based code equations below. *)
    63 lemma eq_nat_of_int: "int' n = x \<Longrightarrow> n = nat_of_int x"
    64 by (erule subst, simp only: nat_of_int_int)
    65 
    66 text {*
    67   Case analysis on natural numbers is rephrased using a conditional
    68   expression:
    69 *}
    70 
       (* Rewrites nat_case into an if-expression on n = 0, taking n - 1 in
          the successor branch.  Declared [code unfold] and removed from the
          inline rules ([code inline del]). *)
    71 lemma [code unfold, code inline del]:
    72   "nat_case \<equiv> (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    73 proof -
    74   have rewrite: "\<And>f g n. nat_case f g n = (if n = 0 then f else g (n - 1))"
    75   proof -
    76     fix f g n
    77     show "nat_case f g n = (if n = 0 then f else g (n - 1))"
    78       by (cases n) simp_all
    79   qed
    80   show "nat_case \<equiv> (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    81     by (rule eq_reflection ext rewrite)+ 
    82 qed
    83 
       (* Inline variant of the same rewrite in which the predecessor is
          computed through the integers as nat_of_int (int' n - 1). *)
    84 lemma [code inline]:
    85   "nat_case = (\<lambda>f g n. if n = 0 then f else g (nat_of_int (int' n - 1)))"
    86 proof (rule ext)+
    87   fix f g n
    88   show "nat_case f g n = (if n = 0 then f else g (nat_of_int (int' n - 1)))"
    89   by (cases n) (simp_all add: nat_of_int_int)
    90 qed
    91 
    92 text {*
    93   Most standard arithmetic functions on natural numbers are implemented
    94   using their counterparts on the integers:
    95 *}
    96 
       (* Code equations for the operations on nat.  Equations tagged
          [code func] are stated via nat_of_int, the constructor declared by
          code_datatype below.  Equations tagged [code] are stated via nat;
          presumably they serve the consts_code-based ML code generator set
          up further below, which maps nat to an ML function -- confirm. *)
    97 lemma [code func]: "0 = nat_of_int 0"
    98   by (simp add: nat_of_int_def)
    99 lemma [code func, code inline]:  "1 = nat_of_int 1"
   100   by (simp add: nat_of_int_def)
   101 lemma [code func]: "Suc n = nat_of_int (int' n + 1)"
   102   by (simp add: eq_nat_of_int)
       (* Addition, subtraction and multiplication via the int images. *)
   103 lemma [code]: "m + n = nat (int' m + int' n)"
   104   by (simp add: int'_def nat_eq_iff2)
   105 lemma [code func, code inline]: "m + n = nat_of_int (int' m + int' n)"
   106   by (simp add: eq_nat_of_int int'_add)
   107 lemma [code, code inline]: "m - n = nat (int' m - int' n)"
   108   by (simp add: int'_def nat_eq_iff2)
   109 lemma [code]: "m * n = nat (int' m * int' n)"
   110   unfolding int'_def
   111   by (simp add: of_nat_mult [symmetric] del: of_nat_mult)
   112 lemma [code func, code inline]: "m * n = nat_of_int (int' m * int' n)"
   113   by (simp add: eq_nat_of_int int'_mult)
       (* Division and remainder.  NOTE(review): the [code func] equations
          via Divides.divmod here coexist with div_nat_code and mod_nat_code
          below, which are also [code func]; confirm which equation the code
          generator actually uses. *)
   114 lemma [code]: "m div n = nat (int' m div int' n)"
   115   unfolding int'_def zdiv_int [symmetric] by simp
   116 lemma [code func]: "m div n = fst (Divides.divmod m n)"
   117   unfolding divmod_def by simp
   118 lemma [code]: "m mod n = nat (int' m mod int' n)"
   119   unfolding int'_def zmod_int [symmetric] by simp
   120 lemma [code func]: "m mod n = snd (Divides.divmod m n)"
   121   unfolding divmod_def by simp
       (* Order and equality are decided on the int images.  NOTE(review):
          the lemma for < is tagged [code] rather than [code func], unlike
          the ones for <= and =; confirm this is intended. *)
   122 lemma [code, code inline]: "(m < n) \<longleftrightarrow> (int' m < int' n)"
   123   unfolding int'_def by simp
   124 lemma [code func, code inline]: "(m \<le> n) \<longleftrightarrow> (int' m \<le> int' n)"
   125   unfolding int'_def by simp
   126 lemma [code func, code inline]: "m = n \<longleftrightarrow> int' m = int' n"
   127   unfolding int'_def by simp
       (* nat maps negative arguments to 0 and otherwise agrees with
          nat_of_int. *)
   128 lemma [code func]: "nat k = (if k < 0 then 0 else nat_of_int k)"
   129 proof (cases "k < 0")
   130   case True then show ?thesis by simp
   131 next
   132   case False then show ?thesis by (simp add: nat_of_int_def)
   133 qed
       (* Recursive code equation for int_aux: counts n down, computing the
          predecessor through int' and nat_of_int.  The inner fact shows that
          for n > 0, int' n is one more than int' of that predecessor. *)
   134 lemma [code func]:
   135   "int_aux i n = (if int' n = 0 then i else int_aux (i + 1) (nat_of_int (int' n - 1)))"
   136 proof -
   137   have "0 < n \<Longrightarrow> int' n = 1 + int' (nat_of_int (int' n - 1))"
   138   proof -
   139     assume prem: "n > 0"
   140     then have "int' n - 1 \<ge> 0" unfolding int'_def by auto
   141     then have "nat_of_int (int' n - 1) = nat (int' n - 1)" by (simp add: nat_of_int_def)
   142     with prem show "int' n = 1 + int' (nat_of_int (int' n - 1))" unfolding int'_def by simp
   143   qed
   144   then show ?thesis unfolding int_aux_def int'_def by simp
   145 qed
   146 
       (* Division and remainder on nat computed with divAlg on the int
          images, mapped back with nat_of_int.  NOTE(review): div and mod
          already receive [code func] equations via Divides.divmod earlier
          in this theory; confirm which one the code generator selects. *)
   147 lemma div_nat_code [code func]:
   148   "m div k = nat_of_int (fst (divAlg (int' m, int' k)))"
   149   unfolding div_def [symmetric] int'_def zdiv_int [symmetric]
   150   unfolding int'_def [symmetric] nat_of_int_int ..
   151 
   152 lemma mod_nat_code [code func]:
   153   "m mod k = nat_of_int (snd (divAlg (int' m, int' k)))"
   154   unfolding mod_def [symmetric] int'_def zmod_int [symmetric]
   155   unfolding int'_def [symmetric] nat_of_int_int ..
   156 
   157 subsection {* Code generator setup for basic functions *}
   158 
   159 text {*
   160   @{typ nat} is no longer a datatype but embedded into the integers.
   161 *}
   162 
       (* From here on, nat_of_int is the sole code constructor of nat. *)
   163 code_datatype nat_of_int
   164 
       (* Target-language representation of nat: the integer types
          IntInf.int, Big_int.big_int and Integer respectively. *)
   165 code_type nat
   166   (SML "IntInf.int")
   167   (OCaml "Big'_int.big'_int")
   168   (Haskell "Integer")
   169 
       (* Setup for the types_code/consts_code code generator: nat is
          represented by the ML type int.  The attached term_of_nat turns a
          value back into a nat numeral term (via IntInf.fromInt), and
          gen_nat produces test values with random_range 0 i. *)
   170 types_code
   171   nat ("int")
   172 attach (term_of) {*
   173 val term_of_nat = HOLogic.mk_number HOLogic.natT o IntInf.fromInt;
   174 *}
   175 attach (test) {*
   176 fun gen_nat i = random_range 0 i;
   177 *}
   178 
       (* 0 becomes the literal 0 and Suc becomes adding 1. *)
   179 consts_code
   180   "0 \<Colon> nat" ("0")
   181   Suc ("(_ + 1)")
   182 
   183 text {*
   184   Since natural numbers are implemented
   185   using integers, the coercion function @{const "int'"} of type
   186   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function,
   187   likewise @{const nat_of_int} of type @{typ "int \<Rightarrow> nat"}.
   188   For the @{const "nat"} function, which converts an integer to a natural
   189   number, we give a specific implementation using an ML function that
   190   returns its input value if it is non-negative, and otherwise
   191   returns @{text "0"}.
   192 *}
   193 
       (* consts_code generator: int' is translated to its argument, and nat
          to the attached ML function nat, which maps negative inputs to 0
          and returns non-negative inputs unchanged. *)
   194 consts_code
   195   int' ("(_)")
   196   nat ("\<module>nat")
   197 attach {*
   198 fun nat i = if i < 0 then 0 else i;
   199 *}
   200 
       (* In SML, OCaml and Haskell, int' and nat_of_int are both translated
          to their argument, i.e. the identity. *)
   201 code_const int'
   202   (SML "_")
   203   (OCaml "_")
   204   (Haskell "_")
   205 
   206 code_const nat_of_int
   207   (SML "_")
   208   (OCaml "_")
   209   (Haskell "_")
   210 
   211 
   212 subsection {* Preprocessors *}
   213 
   214 text {*
   215   Natural numerals should be expressed using @{const nat_of_int}.
   216 *}
   217 
       (* Remove nat_number_of_def from the inline rules; natural numerals
          are instead rewritten by the nat_of_int_of_number_of preprocessor
          registered below. *)
   218 lemmas [code inline del] = nat_number_of_def
   219 
   220 ML {*
       (* Inline preprocessor: for every numeral of type nat with a
          non-negative value occurring in the terms cts, derive the
          meta-equation number_of k == nat_of_int (number_of k). *)
       fun nat_of_int_of_number_of thy cts =
         let
           val simplify_less = Simplifier.rewrite
             (HOL_basic_ss addsimps (@{thms less_numeral_code} @ @{thms less_eq_numeral_code}));
           (* Only numerals of type nat whose value is at least 0 qualify. *)
           fun is_nat_numeral (t, ty) =
             ty = HOLogic.natT andalso IntInf.<= (0, HOLogic.dest_numeral t);
           (* Decide the comparison of Numeral.Pls with t by simplification,
              then chain the result through the two lemmas and lift the
              object equation to a meta-equation. *)
           fun rewrite_for (t, _) =
             let
               val pls_le = simplify_less
                 (Thm.capply @{cterm "(op \<le>) Numeral.Pls"} (Thm.cterm_of thy t));
               val nonneg = @{thm nat_of_int_of_number_of_aux} OF [pls_le];
               val obj_eq = @{thm nat_of_int_of_number_of} OF [nonneg];
             in
               @{thm eq_reflection} OF [obj_eq]
             end;
           fun mk_rew entry =
             if is_nat_numeral entry then SOME (rewrite_for entry) else NONE;
           val numerals = fold (HOLogic.add_numerals o Thm.term_of) cts [];
         in
           map_filter mk_rew numerals
         end;
   238 *}
   239 
       (* Register nat_of_int_of_number_of as an inline procedure of the
          code generator. *)
   240 setup {*
   241   CodegenData.add_inline_proc ("nat_of_int_of_number_of", nat_of_int_of_number_of)
   242 *}
   243 
   244 text {*
   245   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   246   a constructor term. Therefore, all occurrences of this term in a position
   247   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   248   equation or in the arguments of an inductive relation in an introduction
   249   rule) must be eliminated.
   250   This can be accomplished by applying the following transformation rules:
   251 *}
   252 
       (* Suc_if_eq: from equations for f (Suc n) and for f 0, obtain one
          equation for f n as an if-expression on n = 0.  Used by the
          equation preprocessor remove_suc below. *)
   253 theorem Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
   254   f n = (if n = 0 then g else h (n - 1))"
   255   by (case_tac n) simp_all
   256 
       (* Suc_clause: a clause stated for Suc n yields one for n - 1 under the
          side condition n ~= 0.  Used by the clause preprocessor
          remove_suc_clause below. *)
   257 theorem Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   258   by (case_tac n) simp_all
   259 
   260 text {*
   261   The rules above are built into a preprocessor that is plugged into
   262   the code generator. Since the preprocessor for introduction rules
   263   does not know anything about modes, some of the modes that worked
   264   for the canonical representation of natural numbers may no longer work.
   265 *}
   266 
   267 (*<*)
   268 
   269 ML {*
       (* Preprocessors that eliminate Suc patterns from code equations and
          from introduction rules, using Suc_if_eq and Suc_clause. *)
   270 local
   271   val Suc_if_eq = thm "Suc_if_eq";
   272   val Suc_clause = thm "Suc_clause";
         (* Does the term t mention the constant Suc? *)
   273   fun contains_suc t = member (op =) (term_consts t) "Suc";
   274 in
   275 
       (* Eliminate occurrences of Suc ?v on the left-hand sides of the
          equations thms using Suc_if_eq, one occurrence per step.  A step
          removes the equation containing the occurrence and every equation
          that discharges the remaining premise, and adds the derived
          equations.  Steps repeat until none applies. *)
   276 fun remove_suc thy thms =
   277   let
   278     val Suc_if_eq' = Thm.transfer thy Suc_if_eq;
           (* Fresh schematic variable of type nat whose name differs from
              every schematic variable name occurring in thms. *)
   279     val vname = Name.variant (map fst
   280       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
   281     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
           (* Left- and right-hand side of a theorem Trueprop (lhs = rhs). *)
   282     fun lhs_of th = snd (Thm.dest_comb
   283       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
   284     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
           (* For each occurrence of Suc ?v in ct, return ct with that
              occurrence replaced by cv, paired with the variable ?v. *)
   285     fun find_vars ct = (case term_of ct of
   286         (Const ("Suc", _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   287       | _ $ _ =>
   288         let val (ct1, ct2) = Thm.dest_comb ct
   289         in 
   290           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   291           map (apfst (Thm.capply ct1)) (find_vars ct2)
   292         end
   293       | _ => []);
           (* Every equation paired with each Suc occurrence in its lhs. *)
   294     val eqs = List.concat (map
   295       (fn th => map (pair th) (find_vars (lhs_of th))) thms);
           (* Instantiate Suc_if_eq (schematic variables presumably in the
              order f, h, g, n) with the lhs abstracted over cv and the rhs
              abstracted over cv', then discharge its first premise with th
              generalized over cv'.  The result th' still assumes f 0 = g.
              Each equation of thms that resolves with th' is consumed.
              If none resolves, return NONE. *)
   296     fun mk_thms (th, (ct, cv')) =
   297       let
   298         val th' =
   299           Thm.implies_elim
   300            (Conv.fconv_rule (Thm.beta_conversion true)
   301              (Drule.instantiate'
   302                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   303                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   304                Suc_if_eq')) (Thm.forall_intr cv' th)
   305       in
   306         case map_filter (fn th'' =>
   307             SOME (th'', singleton
   308               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   309           handle THM _ => NONE) thms of
   310             [] => NONE
   311           | thps =>
   312               let val (ths1, ths2) = split_list thps
   313               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   314       end
   315   in
   316     case get_first mk_thms eqs of
   317       NONE => thms
   318     | SOME x => remove_suc thy x
   319   end;
   320 
       (* Run remove_suc only if every theorem is an object-level equation
          and at least one left-hand side mentions Suc. *)
   321 fun eqn_suc_preproc thy ths =
   322   let
   323     val dest = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of
   324   in
   325     if forall (can dest) ths andalso
   326       exists (contains_suc o dest) ths
   327     then remove_suc thy ths else ths
   328   end;
   329 
       (* Eliminate occurrences of Suc ?v from clauses using Suc_clause.
          Take the first theorem whose atomized statement contains Suc ?v,
          abstract over that subterm, and instantiate Suc_clause.  Discharge
          its premise with the theorem generalized over ?v and rulify.  The
          result replaces the original theorem, and the process repeats. *)
   330 fun remove_suc_clause thy thms =
   331   let
   332     val Suc_clause' = Thm.transfer thy Suc_clause;
           (* First subterm of the form Suc ?v, returned with ?v. *)
   335     fun find_var (t as Const ("Suc", _) $ (v as Var _)) = SOME (t, v)
   336       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   337       | find_var _ = NONE;
           (* Pair a theorem with its atomized form if the latter contains
              a Suc ?v subterm. *)
   338     fun find_thm th =
   339       let val th' = ObjectLogic.atomize_thm th
   340       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   341   in
   342     case get_first find_thm thms of
   343       NONE => thms
   344     | SOME ((th, th'), (Sucv, v)) =>
   345         let
   346           val cert = cterm_of (Thm.theory_of_thm th);
   347           val th'' = ObjectLogic.rulify (Thm.implies_elim
   348             (Conv.fconv_rule (Thm.beta_conversion true)
   349               (Drule.instantiate' []
   350                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   351                    abstract_over (Sucv,
   352                      HOLogic.dest_Trueprop (prop_of th')))))),
   353                  SOME (cert v)] Suc_clause'))
   354             (Thm.forall_intr (cert v) th'))
   355         in
   356           remove_suc_clause thy (map (fn th''' =>
   357             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   358         end
   359   end;
   360 
       (* Run remove_suc_clause only if every conclusion is a membership
          x : S and the element of some membership among the conclusion and
          premises of a theorem mentions Suc. *)
   361 fun clause_suc_preproc thy ths =
   362   let
   363     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   364   in
   365     if forall (can (dest o concl_of)) ths andalso
   366       exists (fn th => member (op =) (foldr add_term_consts
   367         [] (map_filter (try dest) (concl_of th :: prems_of th))) "Suc") ths
   368     then remove_suc_clause thy ths else ths
   369   end;
   370 
   371 end; (*local*)
   372 
       (* Lift a preprocessor f on object-level equations to meta-level
          equations: convert each theorem with meta_eq_to_obj_eq, apply f,
          convert back with eq_reflection and beta-eta normalize. *)
   373 fun lift_obj_eq f thy =
   374   map (fn thm => thm RS @{thm meta_eq_to_obj_eq})
   375   #> f thy
   376   #> map (fn thm => thm RS @{thm eq_reflection})
   377   #> map (Conv.fconv_rule Drule.beta_eta_conversion)
   378 *}
   379 
       (* Register both Suc preprocessors twice.  The Codegen registrations
          use them directly.  The CodegenData registrations wrap them with
          lift_obj_eq, which converts meta equations to object equations
          before the call and back afterwards. *)
   380 setup {*
   381   Codegen.add_preprocessor eqn_suc_preproc
   382   #> Codegen.add_preprocessor clause_suc_preproc
   383   #> CodegenData.add_preproc ("eqn_Suc", lift_obj_eq eqn_suc_preproc)
   384   #> CodegenData.add_preproc ("clause_Suc", lift_obj_eq clause_suc_preproc)
   385 *}
   386 (*>*)
   387 
   388 subsection {* Module names *}
   389 
       (* In each target language, code originating from the theories Nat
          and EfficientNat is placed in a module named Integer. *)
   390 code_modulename SML
   391   Nat Integer
   392   EfficientNat Integer
   393 
   394 code_modulename OCaml
   395   Nat Integer
   396   EfficientNat Integer
   397 
   398 code_modulename Haskell
   399   Nat Integer
   400   EfficientNat Integer
   401 
       (* Hide the names of the auxiliary constants nat_of_int and int'. *)
   402 hide const nat_of_int int'
   403 
   404 end