(*  Title:      HOL/Library/Code_Nat.thy
    Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
*)

header {* Implementation of natural numbers as binary numerals *}

theory Code_Nat
imports Main
begin

text {*
  When generating code for functions on natural numbers, the
  canonical representation using @{term "0::nat"} and
  @{term Suc} is unsuitable for computations involving large
  numbers.  This theory refines the representation of
  natural numbers for code generation to use binary
  numerals, whose size grows logarithmically rather than linearly.
*}

subsection {* Representation *}

lemma [code_abbrev]:
  "nat_of_num = numeral"
  by (fact nat_of_num_numeral)

(* For code generation, nat is built from 0 and nat_of_num (binary
   numerals); Suc is no longer a constructor (see "Preprocessors" below). *)
code_datatype "0::nat" nat_of_num

lemma [code]:
  "num_of_nat 0 = Num.One"
  "num_of_nat (nat_of_num k) = k"
  by (simp_all add: nat_of_num_inverse)

lemma [code]:
  "(1\<Colon>nat) = Numeral1"
  by simp

lemma [code_abbrev]: "Numeral1 = (1\<Colon>nat)"
  by simp

(* Suc is expressed via addition, since it is not a code constructor. *)
lemma [code]:
  "Suc n = n + 1"
  by simp


subsection {* Basic arithmetic *}

(* NOTE(review): the recurring "[code, code del]" idiom below presumably
   discards previously declared code equations for the respective nat
   operation, so that only the equations stated next are used -- confirm
   against the code generator documentation. *)
lemma [code, code del]:
  "(plus :: nat \<Rightarrow> _) = plus" ..

lemma plus_nat_code [code]:
  "nat_of_num k + nat_of_num l = nat_of_num (k + l)"
  "m + 0 = (m::nat)"
  "0 + n = (n::nat)"
  by (simp_all add: nat_of_num_numeral)

text {* Bounded subtraction needs some auxiliary functions *}

(* Doubling; in binary this is just appending a zero bit (see dup_code). *)
definition dup :: "nat \<Rightarrow> nat" where
  "dup n = n + n"

lemma dup_code [code]:
  "dup 0 = 0"
  "dup (nat_of_num k) = nat_of_num (Num.Bit0 k)"
  unfolding Num_def by (simp_all add: dup_def numeral_Bit0)

(* Subtraction on binary numerals: Some (k - l) if l \<le> k, None otherwise. *)
definition sub :: "num \<Rightarrow> num \<Rightarrow> nat option" where
  "sub k l = (if k \<ge> l then Some (numeral k - numeral l) else None)"

lemma sub_code [code]:
  "sub Num.One Num.One = Some 0"
  "sub (Num.Bit0 m) Num.One = Some (nat_of_num (Num.BitM m))"
  "sub (Num.Bit1 m) Num.One = Some (nat_of_num (Num.Bit0 m))"
  "sub Num.One (Num.Bit0 n) = None"
  "sub Num.One (Num.Bit1 n) = None"
  "sub (Num.Bit0 m) (Num.Bit0 n) = Option.map dup (sub m n)"
  "sub (Num.Bit1 m) (Num.Bit1 n) = Option.map dup (sub m n)"
  "sub (Num.Bit1 m) (Num.Bit0 n) = Option.map (\<lambda>q. dup q + 1) (sub m n)"
  "sub (Num.Bit0 m) (Num.Bit1 n) = (case sub m n of None \<Rightarrow> None
     | Some q \<Rightarrow> if q = 0 then None else Some (dup q - 1))"
  apply (auto simp add: nat_of_num_numeral
    Num.dbl_def Num.dbl_inc_def Num.dbl_dec_def
    Let_def le_imp_diff_is_add BitM_plus_one sub_def dup_def)
  apply (simp_all add: sub_non_positive)
  apply (simp_all add: sub_non_negative [symmetric, where ?'a = int])
  done

lemma [code, code del]:
  "(minus :: nat \<Rightarrow> _) = minus" ..

(* Truncated subtraction: a None result from sub (l > k) yields 0. *)
lemma minus_nat_code [code]:
  "nat_of_num k - nat_of_num l = (case sub k l of None \<Rightarrow> 0 | Some j \<Rightarrow> j)"
  "m - 0 = (m::nat)"
  "0 - n = (0::nat)"
  by (simp_all add: nat_of_num_numeral sub_non_positive sub_def)

lemma [code, code del]:
  "(times :: nat \<Rightarrow> _) = times" ..

lemma times_nat_code [code]:
  "nat_of_num k * nat_of_num l = nat_of_num (k * l)"
  "m * 0 = (0::nat)"
  "0 * n = (0::nat)"
  by (simp_all add: nat_of_num_numeral)

lemma [code, code del]:
  "(HOL.equal :: nat \<Rightarrow> _) = HOL.equal" ..

lemma equal_nat_code [code]:
  "HOL.equal 0 (0::nat) \<longleftrightarrow> True"
  "HOL.equal 0 (nat_of_num l) \<longleftrightarrow> False"
  "HOL.equal (nat_of_num k) 0 \<longleftrightarrow> False"
  "HOL.equal (nat_of_num k) (nat_of_num l) \<longleftrightarrow> HOL.equal k l"
  by (simp_all add: nat_of_num_numeral equal)

(* Reflexivity shortcut, registered only for normalization by evaluation. *)
lemma equal_nat_refl [code nbe]:
  "HOL.equal (n::nat) n \<longleftrightarrow> True"
  by (rule equal_refl)

lemma [code, code del]:
  "(less_eq :: nat \<Rightarrow> _) = less_eq" ..

lemma less_eq_nat_code [code]:
  "0 \<le> (n::nat) \<longleftrightarrow> True"
  "nat_of_num k \<le> 0 \<longleftrightarrow> False"
  "nat_of_num k \<le> nat_of_num l \<longleftrightarrow> k \<le> l"
  by (simp_all add: nat_of_num_numeral)

lemma [code, code del]:
  "(less :: nat \<Rightarrow> _) = less" ..

lemma less_nat_code [code]:
  "(m::nat) < 0 \<longleftrightarrow> False"
  "0 < nat_of_num l \<longleftrightarrow> True"
  "nat_of_num k < nat_of_num l \<longleftrightarrow> k < l"
  by (simp_all add: nat_of_num_numeral)


subsection {* Conversions *}

lemma [code, code del]:
  "of_nat = of_nat" ..

lemma of_nat_code [code]:
  "of_nat 0 = 0"
  "of_nat (nat_of_num k) = numeral k"
  by (simp_all add: nat_of_num_numeral)


subsection {* Case analysis *}

text {*
  Case analysis on natural numbers is rephrased using a conditional
  expression:
*}

lemma [code, code_unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: fun_eq_iff dest!: gr0_implies_Suc)


subsection {* Preprocessors *}

text {*
  The term @{term "Suc n"} is no longer a valid pattern.
  Therefore, all occurrences of this term in a position
  where a pattern is expected (i.e.~on the left-hand side of a recursion
  equation) must be eliminated.
  This can be accomplished by applying the following transformation rule:
*}

lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
  f n \<equiv> if n = 0 then g else h (n - 1)"
  by (rule eq_reflection) (cases n, simp_all)

text {*
  The rule above is built into a preprocessor that is plugged into
  the code generator. Since the preprocessor for introduction rules
  does not know anything about modes, some of the modes that worked
  for the canonical representation of natural numbers may no longer work.
*}

(*<*)
setup {*
let

(* remove_suc thy thms: thms are code equations of the form lhs == rhs.
   Takes the first left-hand-side occurrence of Suc applied to a schematic
   variable, abstracts it via Suc_if_eq into a fresh variable ?n, and
   resolves the resulting rule with every equation in thms that matches
   its premise f 0 == g.  Returns SOME of the updated equation list (the
   consumed equations replaced by the resolvents), or NONE if no occurrence
   could be eliminated. *)
fun remove_suc thy thms =
  let
    (* fresh schematic nat variable, distinct from all variables in thms *)
    val vname = singleton (Name.variant_list (map fst
      (fold (Term.add_var_names o Thm.full_prop_of) thms []))) "n";
    val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    (* left and right sides of a Pure equation lhs == rhs *)
    fun lhs_of th = snd (Thm.dest_comb
      (fst (Thm.dest_comb (cprop_of th))));
    fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
    (* every subterm Suc ?x of ct, paired as (ct with that subterm
       replaced by cv, ?x) *)
    fun find_vars ct = (case term_of ct of
        (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
      | _ $ _ =>
        let val (ct1, ct2) = Thm.dest_comb ct
        in 
          map (apfst (fn ct => Thm.apply ct ct2)) (find_vars ct1) @
          map (apfst (Thm.apply ct1)) (find_vars ct2)
        end
      | _ => []);
    val eqs = maps
      (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    (* instantiate Suc_if_eq with f := (fn cv => context),
       h := (fn cv' => rhs of th), n := cv', and discharge its first
       premise with th generalised over cv' *)
    fun mk_thms (th, (ct, cv')) =
      let
        val th' =
          Thm.implies_elim
           (Conv.fconv_rule (Thm.beta_conversion true)
             (Drule.instantiate'
               [SOME (ctyp_of_term ct)] [SOME (Thm.lambda cv ct),
                 SOME (Thm.lambda cv' (rhs_of th)), NONE, SOME cv']
               @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
      in
        case map_filter (fn th'' =>
            SOME (th'', singleton
              (Variable.trade (K (fn [th'''] => [th''' RS th']))
                (Variable.global_thm_context th'')) th'')
          handle THM _ => NONE) thms of
            [] => NONE
          | thps =>
              let val (ths1, ths2) = split_list thps
              in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
      end
  in get_first mk_thms eqs end;

(* Applies remove_suc iteratively (via perhaps_loop) when every theorem is
   a Pure equation and some left-hand side mentions Suc; schematic variable
   indexes of the result are normalised with Drule.zero_var_indexes.
   Otherwise returns NONE. *)
fun eqn_suc_base_preproc thy thms =
  let
    val dest = fst o Logic.dest_equals o prop_of;
    val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
  in
    if forall (can dest) thms andalso exists (contains_suc o dest) thms
      then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
       else NONE
  end;

val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;

in

  Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)

end;
*}
(*>*)

(* Generated code for this theory is placed in a module named Arith. *)
code_modulename SML
  Code_Nat Arith

code_modulename OCaml
  Code_Nat Arith

code_modulename Haskell
  Code_Nat Arith

hide_const (open) dup sub

end