src/HOL/Library/Efficient_Nat.thy
author haftmann
Fri Aug 24 14:14:20 2007 +0200 (2007-08-24)
changeset 24423 ae9cd0e92423
parent 24222 a8a28c15c5cc
child 24630 351a308ab58d
permissions -rw-r--r--
overloaded definitions accompanied by explicit constants
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen
     4 *)
     5 
     6 header {* Implementation of natural numbers by integers *}
     7 
     8 theory Efficient_Nat
     9 imports Main Pretty_Int
    10 begin
    11 
    12 text {*
    13 When generating code for functions on natural numbers, the canonical
    14 representation using @{term "0::nat"} and @{term "Suc"} is unsuitable for
    15 computations involving large numbers. The efficiency of the generated
    16 code can be improved drastically by implementing natural numbers by
    17 integers. To do this, just import this theory.
    18 *}
    19 
    20 subsection {* Logical rewrites *}
    21 
    22 text {*
    23   An int-to-nat conversion
    24   restricted to non-negative ints (in contrast to @{const nat}).
    25   Note that this restriction has no logical relevance and
    26   is just a kind of proof hint -- nothing prevents you from 
    27   writing nonsense like @{term "nat_of_int (-4)"}.
    28 *}
    29 
       (* nat_of_int is specified only on non-negative arguments, where it
          agrees with nat; nothing is stated about its value on negative k. *)
    30 definition
    31   nat_of_int :: "int \<Rightarrow> nat" where
    32   "k \<ge> 0 \<Longrightarrow> nat_of_int k = nat k"
    33 
       (* int' is a separately named copy of of_nat at type nat => int.
          Having its own constant lets the code generator setup further down
          map it to the identity (consts_code and code_const for int'). *)
    34 definition
    35   int' :: "nat \<Rightarrow> int" where
    36   "int' n = of_nat n"
    37 
       (* Rewrite rules pushing int' through Suc, addition and multiplication;
          int'_add and int'_mult are used to prove the nat_of_int code
          equations for addition and multiplication. *)
    38 lemma int'_Suc [simp]: "int' (Suc n) = 1 + int' n"
    39 unfolding int'_def by simp
    40 
    41 lemma int'_add: "int' (m + n) = int' m + int' n"
    42 unfolding int'_def by (rule of_nat_add)
    43 
    44 lemma int'_mult: "int' (m * n) = int' m * int' n"
    45 unfolding int'_def by (rule of_nat_mult)
    46 
       (* A numeral at type nat equals nat_of_int applied to the same numeral
          at type int, provided the numeral is non-negative.  Used by the ML
          inline procedure nat_of_int_of_number_of in the Preprocessors
          section. *)
    47 lemma nat_of_int_of_number_of:
    48   fixes k
    49   assumes "k \<ge> 0"
    50   shows "number_of k = nat_of_int (number_of k)"
    51   unfolding nat_of_int_def [OF assms] nat_number_of_def number_of_is_id ..
    52 
       (* Turns a meta-equation of the form Numeral.Pls <= k == True, as
          produced by simplify_less in the ML procedure
          nat_of_int_of_number_of, into the side condition k >= 0. *)
    53 lemma nat_of_int_of_number_of_aux:
    54   fixes k
    55   assumes "Numeral.Pls \<le> k \<equiv> True"
    56   shows "k \<ge> 0"
    57   using assms unfolding Pls_def by simp
    58 
       (* nat_of_int is a left inverse of int'. *)
    59 lemma nat_of_int_int:
    60   "nat_of_int (int' n) = n"
    61   using nat_of_int_def int'_def by simp
    62 
       (* Derives an equation n = nat_of_int x from int' n = x; used to prove
          the nat_of_int-shaped code equations. *)
    63 lemma eq_nat_of_int: "int' n = x \<Longrightarrow> n = nat_of_int x"
    64 by (erule subst, simp only: nat_of_int_int)
    65 
       (* Declare nat_of_int as the only constructor of nat for the code
          generator, so nat values appear as nat_of_int applied to an int. *)
    66 code_datatype nat_of_int
    67 
    68 text {*
    69   Case analysis on natural numbers is rephrased using a conditional
    70   expression:
    71 *}
    72 
       (* Meta-level unfolding of nat_case into a conditional on n = 0 with
          predecessor n - 1.  Registered with [code unfold] and explicitly
          removed from the [code inline] set; the inline variant that follows
          is used there instead. *)
    73 lemma [code unfold, code inline del]:
    74   "nat_case \<equiv> (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    75 proof -
    76   have rewrite: "\<And>f g n. nat_case f g n = (if n = 0 then f else g (n - 1))"
    77   proof -
    78     fix f g n
    79     show "nat_case f g n = (if n = 0 then f else g (n - 1))"
    80       by (cases n) simp_all
    81   qed
    82   show "nat_case \<equiv> (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    83     by (rule eq_reflection ext rewrite)+ 
    84 qed
    85 
       (* Inline variant of the same case split: the predecessor is computed
          on the integer representation as nat_of_int (int' n - 1). *)
    86 lemma [code inline]:
    87   "nat_case = (\<lambda>f g n. if n = 0 then f else g (nat_of_int (int' n - 1)))"
    88 proof (rule ext)+
    89   fix f g n
    90   show "nat_case f g n = (if n = 0 then f else g (nat_of_int (int' n - 1)))"
    91   by (cases n) (simp_all add: nat_of_int_int)
    92 qed
    93 
    94 text {*
    95   Most standard arithmetic functions on natural numbers are implemented
    96   using their counterparts on the integers:
    97 *}
    98 
       (* Code equations for nat operations, reduced to int via int'.
          Equations tagged [code func] build their results with nat_of_int,
          the constructor declared by code_datatype; equations tagged plain
          [code] build them with nat.  NOTE(review): the plain [code]
          equations presumably serve the consts_code/types_code generator,
          where nat is the attached ML function clamping negatives to 0 --
          confirm. *)
    99 lemma [code func]: "0 = nat_of_int 0"
   100   by (simp add: nat_of_int_def)
   101 lemma [code func, code inline]:  "1 = nat_of_int 1"
   102   by (simp add: nat_of_int_def)
   103 lemma [code func]: "Suc n = nat_of_int (int' n + 1)"
   104   by (simp add: eq_nat_of_int)
   105 lemma [code]: "m + n = nat (int' m + int' n)"
   106   by (simp add: int'_def nat_eq_iff2)
   107 lemma [code func, code inline]: "m + n = nat_of_int (int' m + int' n)"
   108   by (simp add: eq_nat_of_int int'_add)
       (* Truncated subtraction: the integer difference may be negative, so
          the result goes through nat rather than nat_of_int. *)
   109 lemma [code, code inline]: "m - n = nat (int' m - int' n)"
   110   by (simp add: int'_def nat_eq_iff2 of_nat_diff)
   111 lemma [code]: "m * n = nat (int' m * int' n)"
   112   unfolding int'_def
   113   by (simp add: of_nat_mult [symmetric] del: of_nat_mult)
   114 lemma [code func, code inline]: "m * n = nat_of_int (int' m * int' n)"
   115   by (simp add: eq_nat_of_int int'_mult)
       (* div and mod: via integer div/mod for [code], via Divides.divmod for
          [code func]. *)
   116 lemma [code]: "m div n = nat (int' m div int' n)"
   117   unfolding int'_def zdiv_int [symmetric] by simp
   118 lemma [code func]: "m div n = fst (Divides.divmod m n)"
   119   unfolding divmod_def by simp
   120 lemma [code]: "m mod n = nat (int' m mod int' n)"
   121   unfolding int'_def zmod_int [symmetric] by simp
   122 lemma [code func]: "m mod n = snd (Divides.divmod m n)"
   123   unfolding divmod_def by simp
       (* Comparisons are decided on the int images.  NOTE(review): unlike the
          <= and = equations, the < equation carries [code, code inline]
          rather than [code func, code inline] -- confirm this is
          intentional. *)
   124 lemma [code, code inline]: "(m < n) \<longleftrightarrow> (int' m < int' n)"
   125   unfolding int'_def by simp
   126 lemma [code func, code inline]: "(m \<le> n) \<longleftrightarrow> (int' m \<le> int' n)"
   127   unfolding int'_def by simp
   128 lemma [code func, code inline]: "m = n \<longleftrightarrow> int' m = int' n"
   129   unfolding int'_def by simp
       (* nat returns 0 for negative k; the else branch applies nat_of_int
          only to k >= 0, where nat_of_int is specified to agree with nat. *)
   130 lemma [code func]: "nat k = (if k < 0 then 0 else nat_of_int k)"
   131 proof (cases "k < 0")
   132   case True then show ?thesis by simp
   133 next
   134   case False then show ?thesis by (simp add: nat_of_int_def)
   135 qed
       (* Code equation for int_aux without Suc patterns: both the zero test
          and the predecessor are computed on int'.  The inner fact shows that
          for n > 0 the argument int' n - 1 is non-negative, so
          nat_of_int (int' n - 1) is specified and int' of it is int' n - 1. *)
   136 lemma [code func]:
   137   "int_aux n i = (if int' n = 0 then i else int_aux (nat_of_int (int' n - 1)) (i + 1))"
   138 proof -
   139   have "0 < n \<Longrightarrow> int' n = 1 + int' (nat_of_int (int' n - 1))"
   140   proof -
   141     assume prem: "n > 0"
   142     then have "int' n - 1 \<ge> 0" unfolding int'_def by auto
   143     then have "nat_of_int (int' n - 1) = nat (int' n - 1)" by (simp add: nat_of_int_def)
   144     with prem show "int' n = 1 + int' (nat_of_int (int' n - 1))" unfolding int'_def by simp
   145   qed
   146   then show ?thesis unfolding int_aux_def int'_def by auto
   147 qed
   148 
       (* Code equations for div and mod on nat that call the integer division
          algorithm divAlg directly on the int' images of the arguments.
          NOTE(review): m div n and m mod n already have [code func] equations
          via Divides.divmod earlier in this section -- confirm which equation
          the code generator ends up using. *)
   149 lemma div_nat_code [code func]:
   150   "m div k = nat_of_int (fst (divAlg (int' m, int' k)))"
   151   unfolding div_def [symmetric] int'_def zdiv_int [symmetric]
   152   unfolding int'_def [symmetric] nat_of_int_int ..
   153 
   154 lemma mod_nat_code [code func]:
   155   "m mod k = nat_of_int (snd (divAlg (int' m, int' k)))"
   156   unfolding mod_def [symmetric] int'_def zmod_int [symmetric]
   157   unfolding int'_def [symmetric] nat_of_int_int ..
   158 
   159 
   160 subsection {* Code generator setup for basic functions *}
   161 
   162 text {*
   163   @{typ nat} is no longer a datatype but embedded into the integers.
   164 *}
   165 
       (* code_type: serialise nat as each target's arbitrary-precision
          integer type. *)
   166 code_type nat
   167   (SML "IntInf.int")
   168   (OCaml "Big'_int.big'_int")
   169   (Haskell "Integer")
   170 
       (* types_code: nat is mapped to ML int; the attached term_of_nat and
          gen_nat provide term reconstruction and random test values in the
          range 0 to i. *)
   171 types_code
   172   nat ("int")
   173 attach (term_of) {*
   174 val term_of_nat = HOLogic.mk_number HOLogic.natT o IntInf.fromInt;
   175 *}
   176 attach (test) {*
   177 fun gen_nat i = random_range 0 i;
   178 *}
   179 
       (* consts_code: 0 and Suc become an integer literal and an addition. *)
   180 consts_code
   181   "0 \<Colon> nat" ("0")
   182   Suc ("(_ + 1)")
   183 
   184 text {*
   185   Since natural numbers are implemented
   186   using integers, the coercion function @{const "int'"} of type
   187   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function,
   188   likewise @{const nat_of_int} of type @{typ "int \<Rightarrow> nat"}.
   189   For the @{const "nat"} function for converting an integer to a natural
   190   number, we give a specific implementation using an ML function that
   191   returns its input value, provided that it is non-negative, and otherwise
   192   returns @{text "0"}.
   193 *}
   194 
       (* consts_code: int' is the identity; nat is the attached ML function
          that maps negative integers to 0 and returns others unchanged. *)
   195 consts_code
   196   int' ("(_)")
   197   nat ("\<module>nat")
   198 attach {*
   199 fun nat i = if i < 0 then 0 else i;
   200 *}
   201 
       (* code_const: both coercions between the nat and int representations
          are the identity in SML, OCaml and Haskell. *)
   202 code_const int'
   203   (SML "_")
   204   (OCaml "_")
   205   (Haskell "_")
   206 
   207 code_const nat_of_int
   208   (SML "_")
   209   (OCaml "_")
   210   (Haskell "_")
   211 
   212 
   213 subsection {* Preprocessors *}
   214 
   215 text {*
   216   Natural numerals should be expressed using @{const nat_of_int}.
   217 *}
   218 
       (* Remove nat_number_of_def from the inline set; nat numerals are
          instead rewritten by the inline procedure registered below. *)
   219 lemmas [code inline del] = nat_number_of_def
   220 
   221 ML {*
       (* Inline procedure: collect the numerals occurring in cts and, for each
          one of type nat with a non-negative value, produce the rewrite
          number_of k == nat_of_int (number_of k). *)
   222 fun nat_of_int_of_number_of thy cts =
   223   let
           (* Rewrites Numeral.Pls <= k with the numeral comparison rules; for a
              non-negative k the result is expected to be ... == True, as
              required by nat_of_int_of_number_of_aux. *)
   224     val simplify_less = Simplifier.rewrite 
   225       (HOL_basic_ss addsimps (@{thms less_numeral_code} @ @{thms less_eq_numeral_code}));
           (* t is the binary numeral, ty the type the numeral is used at *)
   226     fun mk_rew (t, ty) =
   227       if ty = HOLogic.natT andalso IntInf.<= (0, HOLogic.dest_numeral t) then
   228         Thm.capply @{cterm "(op \<le>) Numeral.Pls"} (Thm.cterm_of thy t)
   229         |> simplify_less
   230         |> (fn thm => @{thm nat_of_int_of_number_of_aux} OF [thm])
   231         |> (fn thm => @{thm nat_of_int_of_number_of} OF [thm])
   232         |> (fn thm => @{thm eq_reflection} OF [thm])
   233         |> SOME
   234       else NONE
   235   in
   236     fold (HOLogic.add_numerals o Thm.term_of) cts []
   237     |> map_filter mk_rew
   238   end;
   239 *}
   240 
       (* Register the procedure with the code generator's inline
          preprocessing. *)
   241 setup {*
   242   Code.add_inline_proc ("nat_of_int_of_number_of", nat_of_int_of_number_of)
   243 *}
   244 
   245 text {*
   246   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   247   a constructor term. Therefore, all occurrences of this term in a position
   248   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   249   equation or in the arguments of an inductive relation in an introduction
   250   rule) must be eliminated.
   251   This can be accomplished by applying the following transformation rules:
   252 *}
   253 
       (* Combines an equation for f (Suc n) and an equation for f 0 into one
          equation on an arbitrary argument n; used by the ML function
          remove_suc. *)
   254 theorem Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
   255   f n = (if n = 0 then g else h (n - 1))"
   256   by (case_tac n) simp_all
   257 
       (* Replaces a Suc n argument of a clause by n under the premise n ~= 0;
          used by the ML function remove_suc_clause. *)
   258 theorem Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   259   by (case_tac n) simp_all
   260 
   261 text {*
   262   The rules above are built into a preprocessor that is plugged into
   263   the code generator. Since the preprocessor for introduction rules
   264   does not know anything about modes, some of the modes that worked
   265   for the canonical representation of natural numbers may no longer work.
   266 *}
   267 
   268 (*<*)
   269 
   270 ML {*
       (* Eliminate Suc patterns from the left-hand sides of a list of
          object-level equations, each of the form Trueprop (l = r).
          For an equation th whose left-hand side contains Suc ?v, Suc_if_eq is
          instantiated so that th discharges its first premise; every theorem
          of thms that resolves with the remaining premise (the equation for
          the 0 case) yields a new equation in which Suc ?v is replaced by ?v
          and the right-hand side becomes a conditional on ?v = 0.  th and the
          theorems used are replaced by the results, and the process repeats
          until no equation can be combined. *)
   271 fun remove_suc thy thms =
   272   let
           (* fresh name for the hole variable cv, distinct from every
              schematic variable occurring in thms *)
   273     val vname = Name.variant (map fst
   274       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
   275     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
           (* l and r of a theorem Trueprop (l = r) *)
   276     fun lhs_of th = snd (Thm.dest_comb
   277       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
   278     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
           (* For each occurrence of Suc ?v in ct, return ct with that
              occurrence replaced by cv, paired with ?v.  Uses the
              const_name antiquotation so the match agrees with the Suc test
              in eqn_suc_preproc. *)
   279     fun find_vars ct = (case term_of ct of
   280         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   281       | _ $ _ =>
   282         let val (ct1, ct2) = Thm.dest_comb ct
   283         in 
   284           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   285           map (apfst (Thm.capply ct1)) (find_vars ct2)
   286         end
   287       | _ => []);
           (* every (equation, Suc occurrence in its left-hand side) pair *)
   288     val eqs = maps
   289       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
   290     fun mk_thms (th, (ct, cv')) =
   291       let
               (* Suc_if_eq with f := the left-hand side abstracted over the
                  hole cv, h := the right-hand side abstracted over cv',
                  n := cv', and its first premise discharged by th *)
   292         val th' =
   293           Thm.implies_elim
   294            (Conv.fconv_rule (Thm.beta_conversion true)
   295              (Drule.instantiate'
   296                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   297                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   298                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   299       in
               (* resolve every theorem of thms against the remaining premise;
                  theorems for which RS raises THM are skipped *)
   300         case map_filter (fn th'' =>
   301             SOME (th'', singleton
   302               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   303           handle THM _ => NONE) thms of
   304             [] => NONE
   305           | thps =>
   306               let val (ths1, ths2) = split_list thps
   307               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   308       end
   309   in
   310     case get_first mk_thms eqs of
   311       NONE => thms
   312     | SOME x => remove_suc thy x
   313   end;
   314 
   315 fun eqn_suc_preproc thy ths =
   316   let
   317     val dest = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of;
   318     fun contains_suc t = member (op =) (term_consts t) @{const_name Suc};
   319   in
   320     if forall (can dest) ths andalso
   321       exists (contains_suc o dest) ths
   322     then remove_suc thy ths else ths
   323   end;
   324 
   325 fun remove_suc_clause thy thms =
   326   let
   327     val vname = Name.variant (map fst
   328       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
   329     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   330       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   331       | find_var _ = NONE;
   332     fun find_thm th =
   333       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   334       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   335   in
   336     case get_first find_thm thms of
   337       NONE => thms
   338     | SOME ((th, th'), (Sucv, v)) =>
   339         let
   340           val cert = cterm_of (Thm.theory_of_thm th);
   341           val th'' = ObjectLogic.rulify (Thm.implies_elim
   342             (Conv.fconv_rule (Thm.beta_conversion true)
   343               (Drule.instantiate' []
   344                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   345                    abstract_over (Sucv,
   346                      HOLogic.dest_Trueprop (prop_of th')))))),
   347                  SOME (cert v)] @{thm Suc_clause}))
   348             (Thm.forall_intr (cert v) th'))
   349         in
   350           remove_suc_clause thy (map (fn th''' =>
   351             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   352         end
   353   end;
   354 
   355 fun clause_suc_preproc thy ths =
   356   let
   357     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   358   in
   359     if forall (can (dest o concl_of)) ths andalso
   360       exists (fn th => member (op =) (foldr add_term_consts
   361         [] (map_filter (try dest) (concl_of th :: prems_of th))) "Suc") ths
   362     then remove_suc_clause thy ths else ths
   363   end;
   364 
   365 fun lift_obj_eq f thy =
   366   map (fn thm => thm RS @{thm meta_eq_to_obj_eq})
   367   #> f thy
   368   #> map (fn thm => thm RS @{thm eq_reflection})
   369   #> map (Conv.fconv_rule Drule.beta_eta_conversion)
   370 *}
   371 
       (* Register the Suc-eliminating preprocessors with Codegen directly,
          and with Code through lift_obj_eq, which converts the meta-level
          equations Code passes in to object level and back. *)
   372 setup {*
   373   Codegen.add_preprocessor eqn_suc_preproc
   374   #> Codegen.add_preprocessor clause_suc_preproc
   375   #> Code.add_preproc ("eqn_Suc", lift_obj_eq eqn_suc_preproc)
   376   #> Code.add_preproc ("clause_Suc", lift_obj_eq clause_suc_preproc)
   377 *}
   378 (*>*)
   379 
   380 
   381 subsection {* Module names *}
   382 
       (* In code generated for SML, OCaml and Haskell, place what comes from
          the theories Nat, Divides and Efficient_Nat into a module named
          Integer. *)
   383 code_modulename SML
   384   Nat Integer
   385   Divides Integer
   386   Efficient_Nat Integer
   387 
   388 code_modulename OCaml
   389   Nat Integer
   390   Divides Integer
   391   Efficient_Nat Integer
   392 
   393 code_modulename Haskell
   394   Nat Integer
   395   Divides Integer
   396   Efficient_Nat Integer
   397 
       (* Hide the names of the auxiliary constants now that the code
          generator setup using them is complete. *)
   398 hide const nat_of_int int'
   399 
   400 end