src/HOL/Library/Efficient_Nat.thy
changeset 23854 688a8a7bcd4e
child 24195 7d1a16c77f7c
equal deleted inserted replaced
23853:2c69bb1374b8 23854:688a8a7bcd4e
       
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
       
     2     ID:         $Id$
       
     3     Author:     Stefan Berghofer, TU Muenchen
       
     4 *)
       
     5 
       
     6 header {* Implementation of natural numbers by integers *}
       
     7 
       
     8 theory Efficient_Nat
       
     9 imports Main Pretty_Int
       
    10 begin
       
    11 
       
    12 text {*
       
    13 When generating code for functions on natural numbers, the canonical
       
    14 representation using @{term "0::nat"} and @{term "Suc"} is unsuitable for
       
    15 computations involving large numbers. The efficiency of the generated
       
    16 code can be improved drastically by implementing natural numbers by
       
    17 integers. To do this, just include this theory.
       
    18 *}
       
    19 
       
    20 subsection {* Logical rewrites *}
       
    21 
       
    22 text {*
       
     23   A conversion from int to nat that is only specified for
       
     24   non-negative ints (in contrast to @{const nat}, which maps negative ints to 0).
       
     25   Note that this restriction has no logical relevance and
       
     26   is just a kind of proof hint -- nothing prevents you from
       
     27   writing nonsense like @{term "nat_of_int (-4)"}.
       
    28 *}
       
    29 
       
     30 definition
        (* Conversion from int to nat whose defining equation is conditional on k >= 0, where it agrees with nat. *)
     31   nat_of_int :: "int \<Rightarrow> nat" where
       
     32   "k \<ge> 0 \<Longrightarrow> nat_of_int k = nat k"
       
    33 
       
     34 definition
        (* Embedding of nat into int, defined as of_nat. *)
     35   int' :: "nat \<Rightarrow> int" where
       
     36   "int' n = of_nat n"
       
    37 
       
     38 lemma int'_Suc [simp]: "int' (Suc n) = 1 + int' n"
        (* int' turns Suc into adding 1 on the integer side; this one is declared as a [simp] rule. *)
     39 unfolding int'_def by simp
       
     40 
       
     41 lemma int'_add: "int' (m + n) = int' m + int' n"
        (* int' commutes with addition; follows from of_nat_add after unfolding the definition. *)
     42 unfolding int'_def by (rule of_nat_add)
       
     43 
       
     44 lemma int'_mult: "int' (m * n) = int' m * int' n"
        (* int' commutes with multiplication; follows from of_nat_mult after unfolding the definition. *)
     45 unfolding int'_def by (rule of_nat_mult)
       
    46 
       
     47 lemma nat_of_int_of_number_of:
        (* For k >= 0, the numeral built from k equals nat_of_int applied to the numeral built from the same k. *)
     48   fixes k
       
     49   assumes "k \<ge> 0"
       
     50   shows "number_of k = nat_of_int (number_of k)"
       
     51   unfolding nat_of_int_def [OF assms] nat_number_of_def number_of_is_id ..
       
     52 
       
     53 lemma nat_of_int_of_number_of_aux:
        (* Converts a comparison with Numeral.Pls that has been rewritten to True into the statement k >= 0. *)
     54   fixes k
       
     55   assumes "Numeral.Pls \<le> k \<equiv> True"
       
     56   shows "k \<ge> 0"
       
     57   using assms unfolding Pls_def by simp
       
    58 
       
     59 lemma nat_of_int_int:
        (* nat_of_int undoes int'. *)
     60   "nat_of_int (int' n) = n"
       
     61   using nat_of_int_def int'_def by simp
       
     62 
       
     63 lemma eq_nat_of_int: "int' n = x \<Longrightarrow> n = nat_of_int x"
        (* Hence n is determined by any x equal to int' n: substitute x away and apply nat_of_int_int. *)
     64 by (erule subst, simp only: nat_of_int_int)
       
    65 
       
    66 text {*
       
    67   Case analysis on natural numbers is rephrased using a conditional
       
    68   expression:
       
    69 *}
       
    70 
       
     71 lemma [code unfold, code inline del]:
        (* nat_case as a meta-equation: a conditional on n = 0 that uses the predecessor n - 1 otherwise. *)
     72   "nat_case \<equiv> (\<lambda>f g n. if n = 0 then f else g (n - 1))"
       
     73 proof -
        (* Step 1: the pointwise object-level equation, proved by case distinction on n. *)
     74   have rewrite: "\<And>f g n. nat_case f g n = (if n = 0 then f else g (n - 1))"
       
     75   proof -
       
     76     fix f g n
       
     77     show "nat_case f g n = (if n = 0 then f else g (n - 1))"
       
     78       by (cases n) simp_all
       
     79   qed
        (* Step 2: lift it to the point-free meta-equation by repeatedly applying eq_reflection, ext and the fact above. *)
     80   show "nat_case \<equiv> (\<lambda>f g n. if n = 0 then f else g (n - 1))"
       
     81     by (rule eq_reflection ext rewrite)+ 
       
     82 qed
       
    83 
       
     84 lemma [code inline]:
        (* nat_case as a conditional on n = 0 whose predecessor is computed on the integer side: nat_of_int (int' n - 1). *)
     85   "nat_case = (\<lambda>f g n. if n = 0 then f else g (nat_of_int (int' n - 1)))"
       
     86 proof (rule ext)+
       
     87   fix f g n
       
     88   show "nat_case f g n = (if n = 0 then f else g (nat_of_int (int' n - 1)))"
       
     89   by (cases n) (simp_all add: nat_of_int_int)
       
     90 qed
       
    91 
       
    92 text {*
       
    93   Most standard arithmetic functions on natural numbers are implemented
       
    94   using their counterparts on the integers:
       
    95 *}
       
    96 
       
     97 lemma [code func]: "0 = nat_of_int 0"
        (* The constants 0 and 1 and the successor function on nat, each expressed through nat_of_int on the integer side. *)
     98   by (simp add: nat_of_int_def)
       
     99 lemma [code func, code inline]:  "1 = nat_of_int 1"
       
    100   by (simp add: nat_of_int_def)
       
    101 lemma [code func]: "Suc n = nat_of_int (int' n + 1)"
       
    102   by (simp add: eq_nat_of_int)
       
    103 lemma [code]: "m + n = nat (int' m + int' n)"
        (* Addition: one equation converting the integer result back with nat, one converting it back with nat_of_int. *)
    104   by (simp add: int'_def nat_eq_iff2)
       
    105 lemma [code func, code inline]: "m + n = nat_of_int (int' m + int' n)"
       
    106   by (simp add: eq_nat_of_int int'_add)
       
    107 lemma [code, code inline]: "m - n = nat (int' m - int' n)"
        (* Subtraction only has the nat version. NOTE(review): presumably because int' m - int' n may be negative, where nat_of_int is unspecified -- confirm. *)
    108   by (simp add: int'_def nat_eq_iff2 of_nat_diff)
       
    109 lemma [code]: "m * n = nat (int' m * int' n)"
        (* Multiplication, again in a nat version and a nat_of_int version. *)
    110   unfolding int'_def
       
    111   by (simp add: of_nat_mult [symmetric] del: of_nat_mult)
       
    112 lemma [code func, code inline]: "m * n = nat_of_int (int' m * int' n)"
       
    113   by (simp add: eq_nat_of_int int'_mult)
       
    114 lemma [code]: "m div n = nat (int' m div int' n)"
        (* Division: either computed on the int' images and converted back with nat, or taken as the first component of Divides.divmod. *)
    115   unfolding int'_def zdiv_int [symmetric] by simp
       
    116 lemma [code func]: "m div n = fst (Divides.divmod m n)"
       
    117   unfolding divmod_def by simp
       
    118 lemma [code]: "m mod n = nat (int' m mod int' n)"
        (* Modulo: same scheme, using the second component of Divides.divmod. *)
    119   unfolding int'_def zmod_int [symmetric] by simp
       
    120 lemma [code func]: "m mod n = snd (Divides.divmod m n)"
       
    121   unfolding divmod_def by simp
       
    122 lemma [code, code inline]: "(m < n) \<longleftrightarrow> (int' m < int' n)"
        (* Strict order, non-strict order and equality on nat are each decided on the int' images of the arguments. *)
    123   unfolding int'_def by simp
       
    124 lemma [code func, code inline]: "(m \<le> n) \<longleftrightarrow> (int' m \<le> int' n)"
       
    125   unfolding int'_def by simp
       
    126 lemma [code func, code inline]: "m = n \<longleftrightarrow> int' m = int' n"
       
    127   unfolding int'_def by simp
       
    128 lemma [code func]: "nat k = (if k < 0 then 0 else nat_of_int k)"
        (* nat maps negative arguments to 0 and otherwise coincides with nat_of_int; proved by cases on k < 0. *)
    129 proof (cases "k < 0")
       
    130   case True then show ?thesis by simp
       
    131 next
       
    132   case False then show ?thesis by (simp add: nat_of_int_def)
       
    133 qed
       
    134 lemma [code func]:
        (* Recursion equation for int_aux in which both the zero test and the predecessor are computed on the integer side. *)
    135   "int_aux n i = (if int' n = 0 then i else int_aux (nat_of_int (int' n - 1)) (i + 1))"
       
    136 proof -
        (* Auxiliary fact: for positive n, int' n is 1 plus int' of the predecessor obtained via nat_of_int. *)
    137   have "0 < n \<Longrightarrow> int' n = 1 + int' (nat_of_int (int' n - 1))"
       
    138   proof -
       
    139     assume prem: "n > 0"
        (* int' n - 1 is non-negative here, so the conditional definition of nat_of_int applies to it. *)
    140     then have "int' n - 1 \<ge> 0" unfolding int'_def by auto
       
    141     then have "nat_of_int (int' n - 1) = nat (int' n - 1)" by (simp add: nat_of_int_def)
       
    142     with prem show "int' n = 1 + int' (nat_of_int (int' n - 1))" unfolding int'_def by simp
       
    143   qed
       
    144   then show ?thesis unfolding int_aux_def int'_def by auto
       
    145 qed
       
   146 
       
    147 lemma div_nat_code [code func]:
        (* Division on nat as the first component of divAlg applied to the int' images, converted back with nat_of_int. *)
    148   "m div k = nat_of_int (fst (divAlg (int' m, int' k)))"
        (* NOTE(review): div already has a [code func] equation via Divides.divmod earlier in this theory -- confirm which one is meant to take effect. *)
    149   unfolding div_def [symmetric] int'_def zdiv_int [symmetric]
       
    150   unfolding int'_def [symmetric] nat_of_int_int ..
       
    151 
       
    152 lemma mod_nat_code [code func]:
        (* Modulo on nat as the second component of divAlg applied to the int' images, converted back with nat_of_int. *)
    153   "m mod k = nat_of_int (snd (divAlg (int' m, int' k)))"
        (* NOTE(review): mod likewise already has a [code func] equation via Divides.divmod earlier in this theory -- confirm. *)
    154   unfolding mod_def [symmetric] int'_def zmod_int [symmetric]
       
    155   unfolding int'_def [symmetric] nat_of_int_int ..
       
   156 
       
   157 
       
   158 subsection {* Code generator setup for basic functions *}
       
   159 
       
   160 text {*
       
   161   @{typ nat} is no longer a datatype but embedded into the integers.
       
   162 *}
       
   163 
       
    164 code_datatype nat_of_int
        (* nat_of_int is declared as the constructor of nat for code generation (replacing 0 and Suc). *)
    165 
        (* Target-language types used for nat: arbitrary-precision integers in SML, OCaml and Haskell. *)
    166 code_type nat
       
    167   (SML "IntInf.int")
       
    168   (OCaml "Big'_int.big'_int")
       
    169   (Haskell "Integer")
       
   170 
       
    171 types_code
        (* nat is printed as the ML type int. The attachments supply term_of_nat (value to nat numeral term) and gen_nat (calls random_range 0 i). *)
    172   nat ("int")
       
    173 attach (term_of) {*
       
    174 val term_of_nat = HOLogic.mk_number HOLogic.natT o IntInf.fromInt;
       
    175 *}
       
    176 attach (test) {*
       
    177 fun gen_nat i = random_range 0 i;
       
    178 *}
       
   179 
       
    180 consts_code
        (* Zero on nat is printed as the literal 0, and Suc applied to an argument as that argument plus 1. *)
    181   "0 \<Colon> nat" ("0")
       
    182   Suc ("(_ + 1)")
       
   183 
       
   184 text {*
       
   185   Since natural numbers are implemented
       
   186   using integers, the coercion function @{const "int"} of type
       
   187   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function,
       
   188   likewise @{const nat_of_int} of type @{typ "int \<Rightarrow> nat"}.
       
   189   For the @{const "nat"} function for converting an integer to a natural
       
   190   number, we give a specific implementation using an ML function that
       
   191   returns its input value, provided that it is non-negative, and otherwise
       
   192   returns @{text "0"}.
       
   193 *}
       
   194 
       
    195 consts_code
        (* int' is printed as its bare argument; nat refers to the attached ML function below. *)
    196   int' ("(_)")
       
    197   nat ("\<module>nat")
       
    198 attach {*
        (* Returns 0 for negative input and the input itself otherwise. *)
    199 fun nat i = if i < 0 then 0 else i;
       
    200 *}
       
    201 
        (* For SML, OCaml and Haskell, int' is likewise printed as its bare argument ... *)
    202 code_const int'
       
    203   (SML "_")
       
    204   (OCaml "_")
       
    205   (Haskell "_")
       
    206 
        (* ... and so is nat_of_int. *)
    207 code_const nat_of_int
       
    208   (SML "_")
       
    209   (OCaml "_")
       
    210   (Haskell "_")
       
   211 
       
   212 
       
   213 subsection {* Preprocessors *}
       
   214 
       
   215 text {*
       
   216   Natural numerals should be expressed using @{const nat_of_int}.
       
   217 *}
       
   218 
       
    219 lemmas [code inline del] = nat_number_of_def
        (* nat_number_of_def is removed from the inline rules; nat numerals are handled by the ML procedure below instead. *)
    220 
       
    221 ML {*
        (* For the numerals collected from the given cterms, produce meta-equations that express non-negative nat numerals via nat_of_int. *)
    222 fun nat_of_int_of_number_of thy cts =
       
    223   let
        (* Rewriting conversion using a basic simpset extended with the less_numeral_code and less_eq_numeral_code rules. *)
    224     val simplify_less = Simplifier.rewrite 
       
    225       (HOL_basic_ss addsimps (@{thms less_numeral_code} @ @{thms less_eq_numeral_code}));
        (* mk_rew: SOME rule for an entry (t, ty) whose type is nat and whose numeral value is non-negative, NONE otherwise. *)
    226     fun mk_rew (t, ty) =
       
    227       if ty = HOLogic.natT andalso IntInf.<= (0, HOLogic.dest_numeral t) then
        (* Build the comparison of Numeral.Pls with t and rewrite it; the result has to fit the premise of nat_of_int_of_number_of_aux (... == True). *)
    228         Thm.capply @{cterm "(op \<le>) Numeral.Pls"} (Thm.cterm_of thy t)
       
    229         |> simplify_less
        (* Chain of rules: first k >= 0, then number_of k = nat_of_int (number_of k), finally the same fact as a meta-equation. *)
    230         |> (fn thm => @{thm nat_of_int_of_number_of_aux} OF [thm])
       
    231         |> (fn thm => @{thm nat_of_int_of_number_of} OF [thm])
       
    232         |> (fn thm => @{thm eq_reflection} OF [thm])
       
    233         |> SOME
       
    234       else NONE
       
    235   in
        (* Gather the numerals of all argument terms with HOLogic.add_numerals and keep the rules mk_rew could build. *)
    236     fold (HOLogic.add_numerals o Thm.term_of) cts []
       
    237     |> map_filter mk_rew
       
    238   end;
       
    239 *}
       
    240 
        (* Register the function above as an inline procedure under the name "nat_of_int_of_number_of". *)
    241 setup {*
       
    242   CodegenData.add_inline_proc ("nat_of_int_of_number_of", nat_of_int_of_number_of)
       
    243 *}
       
   244 
       
   245 text {*
       
    246   Since @{term "Suc n"} is now generated as @{term "n + (1::nat)"}, which
       
    247   is not a constructor term, all occurrences of @{term "Suc n"} in a position
       
    248   where a pattern is expected (i.e.\ on the left-hand side of a recursion
       
    249   equation or in the arguments of an inductive relation in an introduction
       
    250   rule) must be eliminated.
       
    251   This can be accomplished by applying the following transformation rules:
       
   252 *}
       
   253 
       
    254 theorem Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
       
    255   f n = (if n = 0 then g else h (n - 1))"
        (* Merges an equation for f (Suc n) and an equation for f 0 into one equation for f n with a conditional on n = 0. *)
    256   by (case_tac n) simp_all
       
    257 
       
    258 theorem Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
        (* Restates a property of the pair (n, Suc n) for the pair (n - 1, n), under the side condition that n is not 0. *)
    259   by (case_tac n) simp_all
       
   260 
       
   261 text {*
       
   262   The rules above are built into a preprocessor that is plugged into
       
   263   the code generator. Since the preprocessor for introduction rules
       
   264   does not know anything about modes, some of the modes that worked
       
   265   for the canonical representation of natural numbers may no longer work.
       
   266 *}
       
   267 
       
   268 (*<*)
       
   269 
       
   270 ML {*
       
   271 local
       
    272   val Suc_if_eq = thm "Suc_if_eq";
        (* The two transformation theorems are looked up by name. *)
    273   val Suc_clause = thm "Suc_clause";
        (* contains_suc t: does the name "Suc" occur among the constants of t? *)
    274   fun contains_suc t = member (op =) (term_consts t) "Suc";
       
   275 in
       
   276 
       
    277 fun remove_suc thy thms =
        (* Eliminates Suc applied to a schematic variable from left-hand sides of the equations thms by means of Suc_if_eq; repeats until no occurrence can be processed. *)
    278   let
       
    279     val Suc_if_eq' = Thm.transfer thy Suc_if_eq;
        (* vname: a variant of "x" chosen against the names of all schematic variables in thms; cv is the nat-typed schematic variable built from it. *)
    280     val vname = Name.variant (map fst
       
    281       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
       
    282     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
        (* lhs_of / rhs_of take the proposition apart with Thm.dest_comb; assumes the shape Trueprop (lhs = rhs) -- TODO confirm. *)
    283     fun lhs_of th = snd (Thm.dest_comb
       
    284       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
       
    285     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
        (* find_vars ct: one pair per subterm Suc ?v of ct, made of ct with that subterm replaced by cv, and the variable ?v itself. *)
    286     fun find_vars ct = (case term_of ct of
       
    287         (Const ("Suc", _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
       
    288       | _ $ _ =>
        (* Application: search both sides and rebuild the surrounding application around each result. *)
    289         let val (ct1, ct2) = Thm.dest_comb ct
       
    290         in 
       
    291           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
       
    292           map (apfst (Thm.capply ct1)) (find_vars ct2)
       
    293         end
       
    294       | _ => []);
        (* eqs: every theorem paired with each such occurrence found in its left-hand side. *)
    295     val eqs = maps
       
    296       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
        (* mk_thms: instantiate Suc_if_eq' for one occurrence and discharge its first premise with th generalized over cv'. *)
    297     fun mk_thms (th, (ct, cv')) =
       
    298       let
       
    299         val th' =
       
    300           Thm.implies_elim
       
    301            (Conv.fconv_rule (Thm.beta_conversion true)
       
    302              (Drule.instantiate'
        (* Instantiated in order: the abstraction of ct over cv, the abstraction of the right-hand side over cv', one variable left open, and cv'. *)
    303                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
       
    304                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
       
    305                Suc_if_eq')) (Thm.forall_intr cv' th)
       
    306       in
        (* Resolve each theorem of thms against th' (RS); theorems for which this raises THM are dropped from the result. *)
    307         case map_filter (fn th'' =>
       
    308             SOME (th'', singleton
       
    309               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
       
    310           handle THM _ => NONE) thms of
       
    311             [] => NONE
       
    312           | thps =>
        (* Success: remove th and the theorems that resolved from thms and append the resolution results. *)
    313               let val (ths1, ths2) = split_list thps
       
    314               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
       
    315       end
       
    316   in
        (* Take the first occurrence for which mk_thms succeeds and start over on the new list; without one, return thms unchanged. *)
    317     case get_first mk_thms eqs of
       
    318       NONE => thms
       
    319     | SOME x => remove_suc thy x
       
    320   end;
       
   321 
       
    322 fun eqn_suc_preproc thy ths =
        (* Preprocessor for equations: runs remove_suc only if every theorem is a HOL equation and some left-hand side mentions Suc; otherwise returns ths unchanged. *)
    323   let
        (* dest: left-hand side of a theorem of the form Trueprop (lhs = rhs); raises an exception on other shapes, which is what the "can dest" test below relies on. *)
    324     val dest = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of
       
    325   in
       
    326     if forall (can dest) ths andalso
       
    327       exists (contains_suc o dest) ths
       
    328     then remove_suc thy ths else ths
       
    329   end;
       
   330 
       
   331 fun remove_suc_clause thy thms =
       
   332   let
       
   333     val Suc_clause' = Thm.transfer thy Suc_clause;
       
   334     val vname = Name.variant (map fst
       
   335       (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
       
   336     fun find_var (t as Const ("Suc", _) $ (v as Var _)) = SOME (t, v)
       
   337       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
       
   338       | find_var _ = NONE;
       
   339     fun find_thm th =
       
   340       let val th' = Conv.fconv_rule ObjectLogic.atomize th
       
   341       in Option.map (pair (th, th')) (find_var (prop_of th')) end
       
   342   in
       
   343     case get_first find_thm thms of
       
   344       NONE => thms
       
   345     | SOME ((th, th'), (Sucv, v)) =>
       
   346         let
       
   347           val cert = cterm_of (Thm.theory_of_thm th);
       
   348           val th'' = ObjectLogic.rulify (Thm.implies_elim
       
   349             (Conv.fconv_rule (Thm.beta_conversion true)
       
   350               (Drule.instantiate' []
       
   351                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
       
   352                    abstract_over (Sucv,
       
   353                      HOLogic.dest_Trueprop (prop_of th')))))),
       
   354                  SOME (cert v)] Suc_clause'))
       
   355             (Thm.forall_intr (cert v) th'))
       
   356         in
       
   357           remove_suc_clause thy (map (fn th''' =>
       
   358             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
       
   359         end
       
   360   end;
       
   361 
       
    362 fun clause_suc_preproc thy ths =
        (* Preprocessor for introduction rules: runs remove_suc_clause only if every conclusion is a membership statement and Suc occurs in an element part; otherwise returns ths unchanged. *)
    363   let
        (* dest: first component returned by HOLogic.dest_mem for a proposition of the form Trueprop (membership). *)
    364     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
       
    365   in
       
    366     if forall (can (dest o concl_of)) ths andalso
        (* Collect the constants of all conclusion/premise parts on which dest succeeds and look for "Suc" among them. *)
    367       exists (fn th => member (op =) (foldr add_term_consts
       
    368         [] (map_filter (try dest) (concl_of th :: prems_of th))) "Suc") ths
       
    369     then remove_suc_clause thy ths else ths
       
    370   end;
       
   371 
       
   372 end; (*local*)
       
   373 
       
    374 fun lift_obj_eq f thy =
        (* Wraps a preprocessor f that works on object-level equations so that it can be applied to meta-level equations. *)
    375   map (fn thm => thm RS @{thm meta_eq_to_obj_eq})
        (* Pipeline: convert to object equations, run f, convert back with eq_reflection, then normalize with beta_eta_conversion. *)
    376   #> f thy
       
    377   #> map (fn thm => thm RS @{thm eq_reflection})
       
    378   #> map (Conv.fconv_rule Drule.beta_eta_conversion)
       
   379 *}
       
   380 
       
    381 setup {*
        (* Both preprocessors are registered directly with Codegen ... *)
    382   Codegen.add_preprocessor eqn_suc_preproc
       
    383   #> Codegen.add_preprocessor clause_suc_preproc
        (* ... and, wrapped by lift_obj_eq, with CodegenData under the names "eqn_Suc" and "clause_Suc". *)
    384   #> CodegenData.add_preproc ("eqn_Suc", lift_obj_eq eqn_suc_preproc)
       
    385   #> CodegenData.add_preproc ("clause_Suc", lift_obj_eq clause_suc_preproc)
       
    386 *}
       
   387 (*>*)
       
   388 
       
   389 
       
   390 subsection {* Module names *}
       
   391 
       
    392 code_modulename SML
        (* For SML, the theories Nat, Divides and Efficient_Nat are all mapped to the module name Integer. *)
    393   Nat Integer
       
    394   Divides Integer
       
    395   Efficient_Nat Integer
       
    396 
       
    397 code_modulename OCaml
        (* Same mapping for OCaml. *)
    398   Nat Integer
       
    399   Divides Integer
       
    400   Efficient_Nat Integer
       
    401 
       
    402 code_modulename Haskell
        (* NOTE(review): unlike the SML and OCaml declarations above, no mapping for Divides is given here -- confirm that this is intentional. *)
    403   Nat Integer
       
    404   Efficient_Nat Integer
       
   405 
       
   406 hide const nat_of_int int'
       
   407 
       
   408 end