src/HOL/Library/Code_Binary_Nat.thy
changeset 50023 28f3263d4d1b
parent 47108 2a1953f0d20d
child 51113 222fb6cb2c3e
equal deleted inserted replaced
50022:286dfcab9833 50023:28f3263d4d1b
       
     1 (*  Title:      HOL/Library/Code_Binary_Nat.thy
       
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
       
     3 *)
       
     4 
       
     5 header {* Implementation of natural numbers as binary numerals *}
       
     6 
       
     7 theory Code_Binary_Nat
       
     8 imports Main
       
     9 begin
       
    10 
       
text {*
  When generating code for functions on natural numbers, the
  canonical representation using @{term "0::nat"} and
  @{term Suc} is unsuitable for computations involving large
  numbers.  This theory refines the representation of
  natural numbers for code generation to use binary
  numerals, whose size grows logarithmically rather than
  linearly in the number they denote.
*}
       
    19 
       
    20 subsection {* Representation *}
       
    21 
       
(* Code datatype for nat.  In generated code every natural number is
   built from one of the two constructors declared here: 0, or
   nat_of_num applied to a binary numeral of type num.  Suc is therefore
   not a constructor in generated code (see the Suc equation below). *)
code_datatype "0::nat" nat_of_num

(* Identifies the code constructor nat_of_num with numeral at type nat.
   The [code_abbrev] attribute lets the code generator translate between
   the two forms.  NOTE(review): consult the code generator manual for
   the exact pre-/post-processing direction. *)
lemma [code_abbrev]:
  "nat_of_num = numeral"
  by (fact nat_of_num_numeral)

(* Code equations for num_of_nat, one per nat constructor.  Because num
   has no zero, 0 is mapped to Num.One.  On a positive numeral,
   num_of_nat undoes nat_of_num. *)
lemma [code]:
  "num_of_nat 0 = Num.One"
  "num_of_nat (nat_of_num k) = k"
  by (simp_all add: nat_of_num_inverse)

(* The constant 1::nat is implemented by the numeral Numeral1, which is
   expressible through the code constructors. *)
lemma [code]:
  "(1\<Colon>nat) = Numeral1"
  by simp

(* The converse equation, tagged [code_abbrev], so that Numeral1 and
   1::nat are identified by the code generator. *)
lemma [code_abbrev]: "Numeral1 = (1\<Colon>nat)"
  by simp

(* Suc is no longer a code constructor.  It is implemented as addition
   of one, which is computed on binary numerals by the addition
   equations of this theory. *)
lemma [code]:
  "Suc n = n + 1"
  by simp
       
    43 
       
    44 
       
    45 subsection {* Basic arithmetic *}
       
    46 
       
(* NOTE(review): declaring this trivial equation with both [code] and
   [code del] is presumably the usual idiom for discarding any code
   equations previously registered for plus at type nat, so that only
   plus_nat_code below is used -- confirm against the code generator
   documentation. *)
lemma [code, code del]:
  "(plus :: nat \<Rightarrow> _) = plus" ..

(* Addition on the code constructors.  Two positive numerals are added as
   binary numbers (addition on num); 0 is neutral on either side. *)
lemma plus_nat_code [code]:
  "nat_of_num k + nat_of_num l = nat_of_num (k + l)"
  "m + 0 = (m::nat)"
  "0 + n = (n::nat)"
  by (simp_all add: nat_of_num_numeral)
       
    55 
       
text {* Bounded subtraction needs some auxiliary functions. *}
       
    57 
       
(* Doubling on nat, an auxiliary for the bounded subtraction sub. *)
definition dup :: "nat \<Rightarrow> nat" where
  "dup n = n + n"

(* Code equations for dup: doubling 0 yields 0, and doubling a positive
   numeral k appends a zero bit, giving Num.Bit0 k. *)
lemma dup_code [code]:
  "dup 0 = 0"
  "dup (nat_of_num k) = nat_of_num (Num.Bit0 k)"
  by (simp_all add: dup_def numeral_Bit0)
       
    65 
       
(* Partial subtraction on binary numerals.  sub k l is Some of the
   difference numeral k - numeral l (a nat) when k >= l, and None when
   the difference would be negative. *)
definition sub :: "num \<Rightarrow> num \<Rightarrow> nat option" where
  "sub k l = (if k \<ge> l then Some (numeral k - numeral l) else None)"

(* Code equations for sub, by recursion on the binary digits of both
   arguments.  Read Bit0 m as 2m, Bit1 m as 2m+1 and BitM m as 2m-1.
   - Subtracting One from One, Bit0 m or Bit1 m is computed directly.
   - One minus any larger numeral is None.
   - If the low bits agree, the result is twice sub m n (via dup).
   - (2m+1) - 2n = 2(m-n) + 1.
   - 2m - (2n+1) = 2(m-n) - 1, which is negative (None) when m - n = 0. *)
lemma sub_code [code]:
  "sub Num.One Num.One = Some 0"
  "sub (Num.Bit0 m) Num.One = Some (nat_of_num (Num.BitM m))"
  "sub (Num.Bit1 m) Num.One = Some (nat_of_num (Num.Bit0 m))"
  "sub Num.One (Num.Bit0 n) = None"
  "sub Num.One (Num.Bit1 n) = None"
  "sub (Num.Bit0 m) (Num.Bit0 n) = Option.map dup (sub m n)"
  "sub (Num.Bit1 m) (Num.Bit1 n) = Option.map dup (sub m n)"
  "sub (Num.Bit1 m) (Num.Bit0 n) = Option.map (\<lambda>q. dup q + 1) (sub m n)"
  "sub (Num.Bit0 m) (Num.Bit1 n) = (case sub m n of None \<Rightarrow> None
     | Some q \<Rightarrow> if q = 0 then None else Some (dup q - 1))"
  apply (auto simp add: nat_of_num_numeral
    Num.dbl_def Num.dbl_inc_def Num.dbl_dec_def
    Let_def le_imp_diff_is_add BitM_plus_one sub_def dup_def)
  apply (simp_all add: sub_non_positive)
  apply (simp_all add: sub_non_negative [symmetric, where ?'a = int])
  done
       
    86 
       
(* NOTE(review): presumably clears previously registered code equations
   for minus at type nat before minus_nat_code is declared (same
   [code, code del] idiom) -- confirm. *)
lemma [code, code del]:
  "(minus :: nat \<Rightarrow> _) = minus" ..

(* Truncated subtraction on the code constructors.  For two positive
   numerals the difference is computed by sub, and a negative result
   (None) is truncated to 0.  Subtracting 0 leaves m unchanged, and 0
   minus anything is 0. *)
lemma minus_nat_code [code]:
  "nat_of_num k - nat_of_num l = (case sub k l of None \<Rightarrow> 0 | Some j \<Rightarrow> j)"
  "m - 0 = (m::nat)"
  "0 - n = (0::nat)"
  by (simp_all add: nat_of_num_numeral sub_non_positive sub_def)
       
    95 
       
(* NOTE(review): presumably clears previously registered code equations
   for times at type nat before times_nat_code is declared (same
   [code, code del] idiom) -- confirm. *)
lemma [code, code del]:
  "(times :: nat \<Rightarrow> _) = times" ..

(* Multiplication on the code constructors.  Two positive numerals are
   multiplied as binary numbers (multiplication on num); 0 is absorbing
   on either side. *)
lemma times_nat_code [code]:
  "nat_of_num k * nat_of_num l = nat_of_num (k * l)"
  "m * 0 = (0::nat)"
  "0 * n = (0::nat)"
  by (simp_all add: nat_of_num_numeral)
       
   104 
       
(* NOTE(review): presumably clears previously registered code equations
   for HOL.equal at type nat before equal_nat_code is declared (same
   [code, code del] idiom) -- confirm. *)
lemma [code, code del]:
  "(HOL.equal :: nat \<Rightarrow> _) = HOL.equal" ..

(* Equality test on the code constructors.  0 equals only 0, a positive
   numeral never equals 0, and two positive numerals are compared as
   nums. *)
lemma equal_nat_code [code]:
  "HOL.equal 0 (0::nat) \<longleftrightarrow> True"
  "HOL.equal 0 (nat_of_num l) \<longleftrightarrow> False"
  "HOL.equal (nat_of_num k) 0 \<longleftrightarrow> False"
  "HOL.equal (nat_of_num k) (nat_of_num l) \<longleftrightarrow> HOL.equal k l"
  by (simp_all add: nat_of_num_numeral equal)

(* Reflexivity of HOL.equal on nat, declared with [code nbe].
   NOTE(review): presumably used only by normalisation by evaluation,
   whose arguments may be open terms not built from the constructors
   above -- confirm. *)
lemma equal_nat_refl [code nbe]:
  "HOL.equal (n::nat) n \<longleftrightarrow> True"
  by (rule equal_refl)
       
   118 
       
(* NOTE(review): presumably clears previously registered code equations
   for less_eq at type nat before less_eq_nat_code is declared (same
   [code, code del] idiom) -- confirm. *)
lemma [code, code del]:
  "(less_eq :: nat \<Rightarrow> _) = less_eq" ..

(* Non-strict order on the code constructors.  0 is below everything, a
   positive numeral is never below 0, and two positive numerals are
   compared as nums. *)
lemma less_eq_nat_code [code]:
  "0 \<le> (n::nat) \<longleftrightarrow> True"
  "nat_of_num k \<le> 0 \<longleftrightarrow> False"
  "nat_of_num k \<le> nat_of_num l \<longleftrightarrow> k \<le> l"
  by (simp_all add: nat_of_num_numeral)
       
   127 
       
(* NOTE(review): presumably clears previously registered code equations
   for less at type nat before less_nat_code is declared (same
   [code, code del] idiom) -- confirm. *)
lemma [code, code del]:
  "(less :: nat \<Rightarrow> _) = less" ..

(* Strict order on the code constructors.  Nothing is strictly below 0,
   every positive numeral is strictly above 0, and two positive numerals
   are compared as nums. *)
lemma less_nat_code [code]:
  "(m::nat) < 0 \<longleftrightarrow> False"
  "0 < nat_of_num l \<longleftrightarrow> True"
  "nat_of_num k < nat_of_num l \<longleftrightarrow> k < l"
  by (simp_all add: nat_of_num_numeral)
       
   136 
       
   137 
       
   138 subsection {* Conversions *}
       
   139 
       
(* NOTE(review): presumably clears previously registered code equations
   for of_nat before of_nat_code is declared (same [code, code del]
   idiom) -- confirm. *)
lemma [code, code del]:
  "of_nat = of_nat" ..

(* Conversion of a nat into the target type of of_nat, one equation per
   constructor.  0 maps to 0, and a positive numeral k maps to numeral k
   in the target type, so the conversion never unfolds into a chain of
   Suc steps. *)
lemma of_nat_code [code]:
  "of_nat 0 = 0"
  "of_nat (nat_of_num k) = numeral k"
  by (simp_all add: nat_of_num_numeral)
       
   147 
       
   148 
       
   149 subsection {* Case analysis *}
       
   150 
       
   151 text {*
       
   152   Case analysis on natural numbers is rephrased using a conditional
       
   153   expression:
       
   154 *}
       
   155 
       
(* nat_case rephrased as a conditional on n = 0, with n - 1 standing in
   for the predecessor that a Suc pattern would bind; Suc is not a code
   constructor in this setup.  The equation is declared both as a code
   equation and (presumably, via [code_unfold]) as a preprocessor
   rewrite rule for occurrences of nat_case. *)
lemma [code, code_unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: fun_eq_iff dest!: gr0_implies_Suc)
       
   159 
       
   160 
       
   161 subsection {* Preprocessors *}
       
   162 
       
text {*
  The term @{term "Suc n"} is no longer a valid pattern, because
  @{term Suc} is not a constructor of the code datatype for natural numbers.
  Therefore, all occurrences of this term in a position
  where a pattern is expected (i.e.~on the left-hand side of a recursion
  equation) must be eliminated.
  This can be accomplished by applying the following transformation rule:
*}
       
   170 
       
(* Transformation rule used by the eqn_Suc preprocessor of this theory.
   Suppose f is given by an equation for f (Suc n) (with right-hand side
   h n) and an equation for f 0 (with right-hand side g).  Then f n can
   equivalently be given by a single equation that tests n = 0 and uses
   n - 1 instead of the Suc pattern. *)
lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
  f n \<equiv> if n = 0 then g else h (n - 1)"
  by (rule eq_reflection) (cases n, simp_all)
       
   174 
       
text {*
  The rule above is built into a preprocessor that is plugged into
  the code generator. Since the preprocessor for introduction rules
  does not know anything about modes, some of the modes that worked
  for the canonical representation of natural numbers may no longer work.
*}
       
   181 
       
(*<*)
(* Registers the code preprocessor "eqn_Suc".  It rewrites a set of code
   equations whose left-hand sides use Suc patterns into equations without
   such patterns, using the rule Suc_if_eq.  This setup is excluded from the
   generated document by the surrounding hiding markers. *)
setup {*
let

(* One elimination step.  Searches the left-hand sides of the equations in
   thms for an occurrence of Suc applied to a schematic variable.  It then
   tries to combine that equation, via Suc_if_eq, with the equations of
   thms that fit the corresponding zero case.  Returns SOME of the updated
   equation list for the first candidate that succeeds, and NONE if no
   candidate succeeds. *)
fun remove_suc thy thms =
  let
    (* a variable name derived from "n" that differs from every schematic
       variable name occurring in thms *)
    val vname = singleton (Name.variant_list (map fst
      (fold (Term.add_var_names o Thm.full_prop_of) thms []))) "n";
    val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    (* For a theorem whose proposition is a binary application such as
       lhs == rhs, these return its first and second argument. *)
    fun lhs_of th = snd (Thm.dest_comb
      (fst (Thm.dest_comb (cprop_of th))));
    fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
    (* All occurrences of Suc applied to a schematic variable inside ct.
       Each occurrence is returned as a pair (context, variable), where
       context is ct with that occurrence replaced by cv. *)
    fun find_vars ct = (case term_of ct of
        (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
      | _ $ _ =>
        let val (ct1, ct2) = Thm.dest_comb ct
        in 
          map (apfst (fn ct => Thm.apply ct ct2)) (find_vars ct1) @
          map (apfst (Thm.apply ct1)) (find_vars ct2)
        end
      | _ => []);
    (* candidates: every equation paired with every Suc occurrence found
       in its left-hand side *)
    val eqs = maps
      (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    fun mk_thms (th, (ct, cv')) =
      let
        (* Suc_if_eq with the function abstracted over cv in the context ct,
           and the Suc-case right-hand side abstracted over cv'.  After beta
           reduction, its first premise is discharged by th generalised over
           cv'.  NOTE(review): the positional instantiation assumes the
           rule's schematic variables are ordered f, h, g, n, with g left
           open -- confirm. *)
        val th' =
          Thm.implies_elim
           (Conv.fconv_rule (Thm.beta_conversion true)
             (Drule.instantiate'
               [SOME (ctyp_of_term ct)] [SOME (Thm.lambda cv ct),
                 SOME (Thm.lambda cv' (rhs_of th)), NONE, SOME cv']
               @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
      in
        (* Try to resolve each equation of thms against the remaining
           premise of th' (the zero case).  If none fits, this candidate
           fails.  Otherwise th and the fitting equations are removed from
           thms, and the combined equations are appended. *)
        case map_filter (fn th'' =>
            SOME (th'', singleton
              (Variable.trade (K (fn [th'''] => [th''' RS th']))
                (Variable.global_thm_context th'')) th'')
          handle THM _ => NONE) thms of
            [] => NONE
          | thps =>
              let val (ths1, ths2) = split_list thps
              in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
      end
  in get_first mk_thms eqs end;

(* Preprocessing of a list of code equations.  It acts only when every
   theorem is a meta-equation and at least one left-hand side mentions
   Suc.  In that case it applies remove_suc repeatedly via perhaps_loop,
   then normalises the variable indexes of the resulting theorems.
   Otherwise it returns NONE. *)
fun eqn_suc_base_preproc thy thms =
  let
    val dest = fst o Logic.dest_equals o prop_of;
    val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
  in
    if forall (can dest) thms andalso exists (contains_suc o dest) thms
      then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
       else NONE
  end;

(* the same transformation packaged as a function transformer for the
   code preprocessor *)
val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;

in

  Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)

end;
*}
(*>*)
       
   246 
       
(* In code generated for SML, OCaml and Haskell, the module for this
   theory is named Arith instead of Code_Binary_Nat. *)
code_modulename SML
  Code_Binary_Nat Arith

code_modulename OCaml
  Code_Binary_Nat Arith

code_modulename Haskell
  Code_Binary_Nat Arith

(* dup and sub are only auxiliaries for bounded subtraction.  After this
   declaration their short names are hidden, and only the qualified names
   remain usable. *)
hide_const (open) dup sub
       
   257 
       
   258 end
       
   259