```     1 (*  Title:      HOL/Library/Code_Nat.thy
```     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
```     3 *)
```     4
```     5 header {* Implementation of natural numbers as binary numerals *}
```     6
```     7 theory Code_Nat
```     8 imports Main
```     9 begin
```    10
```    11 text {*
```    12   When generating code for functions on natural numbers, the
```    13   canonical representation using @{term "0::nat"} and
```    14   @{term Suc} is unsuitable for computations involving large
```    15   numbers.  This theory refines the representation of
```    16   natural numbers for code generation to use binary
```    17   numerals, whose size grows logarithmically rather than linearly.
```    18 *}
```    19
```    20 subsection {* Representation *}
```    21
```    22 lemma [code_abbrev]:  (* relates nat_of_num to numeral on nat; as a code_abbrev rule the preprocessor folds the right-hand side into the left-hand side and the postprocessor unfolds it again (see Isar reference, code_abbrev) *)
```    23   "nat_of_num = numeral"
```    24   by (fact nat_of_num_numeral)
```    25
```    26 code_datatype "0::nat" nat_of_num  (* constructor set of nat for code generation: zero, or nat_of_num applied to a binary num *)
```    27
```    28 lemma [code]:  (* code equations for num_of_nat on the arguments 0 and nat_of_num k *)
```    29   "num_of_nat 0 = Num.One"
```    30   "num_of_nat (nat_of_num k) = k"
```    31   by (simp_all add: nat_of_num_inverse)
```    32
```    33 lemma [code]:  (* code equation implementing 1 on nat by the numeral Numeral1 *)
```    34   "(1\<Colon>nat) = Numeral1"
```    35   by simp
```    36
```    37 lemma [code_abbrev]: "Numeral1 = (1\<Colon>nat)"  (* the same fact stated in the opposite direction and declared as a code_abbrev rule *)
```    38   by simp
```    39
```    40 lemma [code]:  (* code equation implementing Suc by addition of 1 *)
```    41   "Suc n = n + 1"
```    42   by simp
```    43
```    44
```    45 subsection {* Basic arithmetic *}
```    46
```    47 lemma [code, code del]:  (* declare-then-delete on a trivial equation; presumably an idiom to discard previously registered code equations for plus on nat -- confirm against the code generator documentation *)
```    48   "(plus :: nat \<Rightarrow> _) = plus" ..
```    49
```    50 lemma plus_nat_code [code]:  (* code equations for addition: add the underlying nums; 0 is neutral on either side *)
```    51   "nat_of_num k + nat_of_num l = nat_of_num (k + l)"
```    52   "m + 0 = (m::nat)"
```    53   "0 + n = (n::nat)"
```    54   by (simp_all add: nat_of_num_numeral)
```    55
```    56 text {* Bounded subtraction needs some auxiliary definitions *}
```    57
```    58 definition dup :: "nat \<Rightarrow> nat" where  (* doubling of a natural number *)
```    59   "dup n = n + n"
```    60
```    61 lemma dup_code [code]:  (* code equations for dup: 0 stays 0, nat_of_num k becomes nat_of_num (Bit0 k) *)
```    62   "dup 0 = 0"
```    63   "dup (nat_of_num k) = nat_of_num (Num.Bit0 k)"
```    64   unfolding Num_def by (simp_all add: dup_def numeral_Bit0)
```    65
```    66 definition sub :: "num \<Rightarrow> num \<Rightarrow> nat option" where  (* partial subtraction on nums: Some of the difference if k is at least l, otherwise None *)
```    67   "sub k l = (if k \<ge> l then Some (numeral k - numeral l) else None)"
```    68
```    69 lemma sub_code [code]:  (* code equations for sub, one for each combination of One / Bit0 / Bit1 in the two arguments; the Bit cases recurse on sub m n *)
```    70   "sub Num.One Num.One = Some 0"
```    71   "sub (Num.Bit0 m) Num.One = Some (nat_of_num (Num.BitM m))"
```    72   "sub (Num.Bit1 m) Num.One = Some (nat_of_num (Num.Bit0 m))"
```    73   "sub Num.One (Num.Bit0 n) = None"
```    74   "sub Num.One (Num.Bit1 n) = None"
```    75   "sub (Num.Bit0 m) (Num.Bit0 n) = Option.map dup (sub m n)"
```    76   "sub (Num.Bit1 m) (Num.Bit1 n) = Option.map dup (sub m n)"
```    77   "sub (Num.Bit1 m) (Num.Bit0 n) = Option.map (\<lambda>q. dup q + 1) (sub m n)"
```    78   "sub (Num.Bit0 m) (Num.Bit1 n) = (case sub m n of None \<Rightarrow> None
```    79      | Some q \<Rightarrow> if q = 0 then None else Some (dup q - 1))"
```    80   apply (auto simp add: nat_of_num_numeral
```    81     Num.dbl_def Num.dbl_inc_def Num.dbl_dec_def
```    82     Let_def le_imp_diff_is_add BitM_plus_one sub_def dup_def)
```    83   apply (simp_all add: sub_non_positive)
```    84   apply (simp_all add: sub_non_negative [symmetric, where ?'a = int])
```    85   done
```    86
```    87 lemma [code, code del]:  (* declare-then-delete on a trivial equation; presumably an idiom to discard previously registered code equations for minus on nat -- confirm against the code generator documentation *)
```    88   "(minus :: nat \<Rightarrow> _) = minus" ..
```    89
```    90 lemma minus_nat_code [code]:  (* code equations for subtraction on nat: delegate to sub and map a None result to 0; subtracting 0 or subtracting from 0 is handled directly *)
```    91   "nat_of_num k - nat_of_num l = (case sub k l of None \<Rightarrow> 0 | Some j \<Rightarrow> j)"
```    92   "m - 0 = (m::nat)"
```    93   "0 - n = (0::nat)"
```    94   by (simp_all add: nat_of_num_numeral sub_non_positive sub_def)
```    95
```    96 lemma [code, code del]:  (* declare-then-delete on a trivial equation; presumably an idiom to discard previously registered code equations for times on nat -- confirm against the code generator documentation *)
```    97   "(times :: nat \<Rightarrow> _) = times" ..
```    98
```    99 lemma times_nat_code [code]:  (* code equations for multiplication: multiply the underlying nums; a 0 factor on either side gives 0 *)
```   100   "nat_of_num k * nat_of_num l = nat_of_num (k * l)"
```   101   "m * 0 = (0::nat)"
```   102   "0 * n = (0::nat)"
```   103   by (simp_all add: nat_of_num_numeral)
```   104
```   105 lemma [code, code del]:  (* declare-then-delete on a trivial equation; presumably an idiom to discard previously registered code equations for HOL.equal on nat -- confirm against the code generator documentation *)
```   106   "(HOL.equal :: nat \<Rightarrow> _) = HOL.equal" ..
```   107
```   108 lemma equal_nat_code [code]:  (* code equations for equality on nat: 0 equals only 0; two nat_of_num values are compared via equality on their nums *)
```   109   "HOL.equal 0 (0::nat) \<longleftrightarrow> True"
```   110   "HOL.equal 0 (nat_of_num l) \<longleftrightarrow> False"
```   111   "HOL.equal (nat_of_num k) 0 \<longleftrightarrow> False"
```   112   "HOL.equal (nat_of_num k) (nat_of_num l) \<longleftrightarrow> HOL.equal k l"
```   113   by (simp_all add: nat_of_num_numeral equal)
```   114
```   115 lemma equal_nat_refl [code nbe]:  (* reflexivity of equality on nat, declared with the nbe variant of the code attribute *)
```   116   "HOL.equal (n::nat) n \<longleftrightarrow> True"
```   117   by (rule equal_refl)
```   118
```   119 lemma [code, code del]:  (* declare-then-delete on a trivial equation; presumably an idiom to discard previously registered code equations for less_eq on nat -- confirm against the code generator documentation *)
```   120   "(less_eq :: nat \<Rightarrow> _) = less_eq" ..
```   121
```   122 lemma less_eq_nat_code [code]:  (* code equations for the non-strict order: 0 is below everything, no nat_of_num value is below 0, otherwise compare the nums *)
```   123   "0 \<le> (n::nat) \<longleftrightarrow> True"
```   124   "nat_of_num k \<le> 0 \<longleftrightarrow> False"
```   125   "nat_of_num k \<le> nat_of_num l \<longleftrightarrow> k \<le> l"
```   126   by (simp_all add: nat_of_num_numeral)
```   127
```   128 lemma [code, code del]:  (* declare-then-delete on a trivial equation; presumably an idiom to discard previously registered code equations for less on nat -- confirm against the code generator documentation *)
```   129   "(less :: nat \<Rightarrow> _) = less" ..
```   130
```   131 lemma less_nat_code [code]:  (* code equations for the strict order: nothing is below 0, 0 is below every nat_of_num value, otherwise compare the nums *)
```   132   "(m::nat) < 0 \<longleftrightarrow> False"
```   133   "0 < nat_of_num l \<longleftrightarrow> True"
```   134   "nat_of_num k < nat_of_num l \<longleftrightarrow> k < l"
```   135   by (simp_all add: nat_of_num_numeral)
```   136
```   137
```   138 subsection {* Conversions *}
```   139
```   140 lemma [code, code del]:  (* declare-then-delete on a trivial equation; presumably an idiom to discard previously registered code equations for of_nat -- confirm against the code generator documentation *)
```   141   "of_nat = of_nat" ..
```   142
```   143 lemma of_nat_code [code]:  (* code equations for of_nat: 0 maps to 0, nat_of_num k maps to the numeral k of the target type *)
```   144   "of_nat 0 = 0"
```   145   "of_nat (nat_of_num k) = numeral k"
```   146   by (simp_all add: nat_of_num_numeral)
```   147
```   148
```   149 subsection {* Case analysis *}
```   150
```   151 text {*
```   152   Case analysis on natural numbers is rephrased using a conditional
```   153   expression:
```   154 *}
```   155
```   156 lemma [code, code_unfold]:  (* expresses nat_case by a test against 0 and the predecessor n - 1; declared both as code equation and as unfold rule *)
```   157   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
```   158   by (auto simp add: fun_eq_iff dest!: gr0_implies_Suc)
```   159
```   160
```   161 subsection {* Preprocessors *}
```   162
```   163 text {*
```   164   The term @{term "Suc n"} is no longer a valid pattern.
```   165   Therefore, all occurrences of this term in a position
```   166   where a pattern is expected (i.e.~on the left-hand side of a recursion
```   167   equation) must be eliminated.
```   168   This can be accomplished by applying the following transformation rule:
```   169 *}
```   170
```   171 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
```   172   f n \<equiv> if n = 0 then g else h (n - 1)"  (* combines an equation for f (Suc n) and one for f 0 into a single equation for f n that tests n = 0 *)
```   173   by (rule eq_reflection) (cases n, simp_all)
```   174
```   175 text {*
```   176   The rule above is built into a preprocessor that is plugged into
```   177   the code generator. Since the preprocessor for introduction rules
```   178   does not know anything about modes, some of the modes that worked
```   179   for the canonical representation of natural numbers may no longer work.
```   180 *}
```   181
```   182 (*<*)
```   183 setup {*
```   184 let
```   185
```   186 fun remove_suc thy thms =  (* one elimination step: picks an occurrence of Suc applied to a schematic variable in some left-hand side and merges equations via Suc_if_eq; SOME new list, or NONE if no candidate works *)
```   187   let
```   188     val vname = singleton (Name.variant_list (map fst
```   189       (fold (Term.add_var_names o Thm.full_prop_of) thms []))) "n";  (* name based on n, chosen as a variant of the schematic variable names collected from thms *)
```   190     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));  (* certified schematic variable of type nat carrying that name *)
```   191     fun lhs_of th = snd (Thm.dest_comb
```   192       (fst (Thm.dest_comb (cprop_of th))));  (* first argument of the binary top-level operator of the proposition, i.e. the left-hand side of an equation *)
```   193     fun rhs_of th = snd (Thm.dest_comb (cprop_of th));  (* second argument, i.e. the right-hand side *)
```   194     fun find_vars ct = (case term_of ct of  (* for each subterm of the form Suc applied to a Var: (ct with that subterm replaced by cv, the Var) *)
```   195         (Const (@{const_name Suc}, _) \$ Var _) => [(cv, snd (Thm.dest_comb ct))]
```   196       | _ \$ _ =>
```   197         let val (ct1, ct2) = Thm.dest_comb ct
```   198         in
```   199           map (apfst (fn ct => Thm.apply ct ct2)) (find_vars ct1) @  (* hits inside the function part: re-apply to the unchanged argument *)
```   200           map (apfst (Thm.apply ct1)) (find_vars ct2)  (* hits inside the argument: re-apply the unchanged function part *)
```   201         end
```   202       | _ => []);
```   203     val eqs = maps
```   204       (fn th => map (pair th) (find_vars (lhs_of th))) thms;  (* every theorem paired with every hit found in its left-hand side *)
```   205     fun mk_thms (th, (ct, cv')) =  (* th: equation with the Suc pattern; ct: its lhs with cv in place of the pattern; cv': the variable under Suc *)
```   206       let
```   207         val th' =  (* Suc_if_eq instantiated for this hit, with its first premise discharged by th generalized over cv' *)
```   208           Thm.implies_elim
```   209            (Conv.fconv_rule (Thm.beta_conversion true)
```   210              (Drule.instantiate'
```   211                [SOME (ctyp_of_term ct)] [SOME (Thm.lambda cv ct),
```   212                  SOME (Thm.lambda cv' (rhs_of th)), NONE, SOME cv']  (* the NONE leaves one schematic variable of Suc_if_eq uninstantiated *)
```   213                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
```   214       in
```   215         case map_filter (fn th'' =>  (* try each theorem of thms as the remaining premise of th'; a THM exception means it does not fit and is skipped *)
```   216             SOME (th'', singleton
```   217               (Variable.trade (K (fn [th'''] => [th''' RS th']))
```   218                 (Variable.global_thm_context th'')) th'')
```   219           handle THM _ => NONE) thms of
```   220             [] => NONE  (* no theorem could be resolved against th' *)
```   221           | thps =>
```   222               let val (ths1, ths2) = split_list thps  (* ths1: theorems that were consumed, ths2: the resolved results *)
```   223               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end  (* drop th and the consumed theorems, append the results *)
```   224       end
```   225   in get_first mk_thms eqs end;  (* result of the first candidate for which mk_thms returns SOME *)
```   226
```   227 fun eqn_suc_base_preproc thy thms =  (* if all thms are meta-equations and some left-hand side mentions Suc: apply remove_suc via perhaps_loop and renumber schematic variables; otherwise NONE *)
```   228   let
```   229     fun lhs_of_eqn th = fst (Logic.dest_equals (prop_of th));  (* fails on propositions that are not meta-equations, hence the can guard below *)
```   230     fun mentions_suc t = exists_Const (fn (name, _) => name = @{const_name Suc}) t;
```   231     val applicable = forall (can lhs_of_eqn) thms andalso exists (mentions_suc o lhs_of_eqn) thms;
```   232   in
```   233     if not applicable then NONE
```   234     else thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
```   235   end;
```   236
```   237 val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;  (* wrap the theorem-list transformation with Code_Preproc.simple_functrans *)
```   238
```   239 in
```   240
```   241   Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)  (* theory setup step: register the transformer with the code preprocessor under the name eqn_Suc *)
```   242
```   243 end;
```   244 *}
```   245 (*>*)
```   246
```   247 code_modulename SML  (* for target SML: associate Code_Nat with the module name Arith *)
```   248   Code_Nat Arith
```   249
```   250 code_modulename OCaml  (* same association for target OCaml *)
```   251   Code_Nat Arith
```   252
```   253 code_modulename Haskell  (* same association for target Haskell *)
```   254   Code_Nat Arith
```   255
```   256 hide_const (open) dup sub  (* hide the unqualified names of the auxiliary constants dup and sub defined in this theory *)
```   257
```   258 end
```