(*  Title:      HOL/Library/Code_Nat.thy
    Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
*)

header {* Implementation of natural numbers as binary numerals *}

theory Code_Nat
imports Main
begin

text {*
  When generating code for functions on natural numbers, the
  canonical representation using @{term "0::nat"} and
  @{term Suc} is unsuitable for computations involving large
  numbers.  This theory refines the representation of
  natural numbers for code generation to use binary
  numerals, whose size grows logarithmically rather than
  linearly in the number they denote.
*}


subsection {* Representation *}

(* nat_of_num and numeral (at type num => nat) coincide; declared as
   code_abbrev so that generated code works with nat_of_num. *)
lemma [code_abbrev]:
  "nat_of_num = numeral"
  by (fact nat_of_num_numeral)

(* The two code constructors of nat: 0 and nat_of_num, which wraps a
   positive binary numeral of type num. *)
code_datatype "0::nat" nat_of_num

(* num_of_nat on the code constructors; 0 is mapped to Num.One. *)
lemma [code]:
  "num_of_nat 0 = Num.One"
  "num_of_nat (nat_of_num k) = k"
  by (simp_all add: nat_of_num_inverse)

lemma [code]:
  "(1\<Colon>nat) = Numeral1"
  by simp

lemma [code_abbrev]: "Numeral1 = (1\<Colon>nat)"
  by simp

(* Suc is no longer a code constructor; it is computed by addition. *)
lemma [code]:
  "Suc n = n + 1"
  by simp


subsection {* Basic arithmetic *}

(* The trivial equation below is declared with [code, code del].
   NOTE(review): presumably this discards the code equations previously
   registered for plus on nat, so that plus_nat_code is the only set in
   effect; the same idiom is used for minus, times, equal, less_eq, less
   and of_nat further down. *)
lemma [code, code del]:
  "(plus :: nat \<Rightarrow> _) = plus" ..

lemma plus_nat_code [code]:
  "nat_of_num k + nat_of_num l = nat_of_num (k + l)"
  "m + 0 = (m::nat)"
  "0 + n = (n::nat)"
  by (simp_all add: nat_of_num_numeral)

text {* Bounded subtraction needs some auxiliary definitions. *}

(* dup n = n + n; on a binary numeral this is just Num.Bit0 (see dup_code). *)
definition dup :: "nat \<Rightarrow> nat" where
  "dup n = n + n"

lemma dup_code [code]:
  "dup 0 = 0"
  "dup (nat_of_num k) = nat_of_num (Num.Bit0 k)"
  unfolding Num_def by (simp_all add: dup_def numeral_Bit0)

(* sub k l is Some (k - l) if k >= l, and None otherwise. *)
definition sub :: "num \<Rightarrow> num \<Rightarrow> nat option" where
  "sub k l = (if k \<ge> l then Some (numeral k - numeral l) else None)"

(* Reading Bit0 m as 2m and Bit1 m as 2m+1, every case with two
   non-One arguments reduces to sub m n.  The result is then doubled
   (dup), doubled plus one, or doubled minus one.  The last case needs
   m > n, i.e. a non-zero q. *)
lemma sub_code [code]:
  "sub Num.One Num.One = Some 0"
  "sub (Num.Bit0 m) Num.One = Some (nat_of_num (Num.BitM m))"
  "sub (Num.Bit1 m) Num.One = Some (nat_of_num (Num.Bit0 m))"
  "sub Num.One (Num.Bit0 n) = None"
  "sub Num.One (Num.Bit1 n) = None"
  "sub (Num.Bit0 m) (Num.Bit0 n) = Option.map dup (sub m n)"
  "sub (Num.Bit1 m) (Num.Bit1 n) = Option.map dup (sub m n)"
  "sub (Num.Bit1 m) (Num.Bit0 n) = Option.map (\<lambda>q. dup q + 1) (sub m n)"
  "sub (Num.Bit0 m) (Num.Bit1 n) = (case sub m n of None \<Rightarrow> None
     | Some q \<Rightarrow> if q = 0 then None else Some (dup q - 1))"
  apply (auto simp add: nat_of_num_numeral
    Num.dbl_def Num.dbl_inc_def Num.dbl_dec_def
    Let_def le_imp_diff_is_add BitM_plus_one sub_def dup_def)
  apply (simp_all add: sub_non_positive)
  apply (simp_all add: sub_non_negative [symmetric, where ?'a = int])
  done

lemma [code, code del]:
  "(minus :: nat \<Rightarrow> _) = minus" ..

(* Truncated subtraction: the result is 0 when sub reports None (k < l). *)
lemma minus_nat_code [code]:
  "nat_of_num k - nat_of_num l = (case sub k l of None \<Rightarrow> 0 | Some j \<Rightarrow> j)"
  "m - 0 = (m::nat)"
  "0 - n = (0::nat)"
  by (simp_all add: nat_of_num_numeral sub_non_positive sub_def)

lemma [code, code del]:
  "(times :: nat \<Rightarrow> _) = times" ..

lemma times_nat_code [code]:
  "nat_of_num k * nat_of_num l = nat_of_num (k * l)"
  "m * 0 = (0::nat)"
  "0 * n = (0::nat)"
  by (simp_all add: nat_of_num_numeral)

lemma [code, code del]:
  "(HOL.equal :: nat \<Rightarrow> _) = HOL.equal" ..

lemma equal_nat_code [code]:
  "HOL.equal 0 (0::nat) \<longleftrightarrow> True"
  "HOL.equal 0 (nat_of_num l) \<longleftrightarrow> False"
  "HOL.equal (nat_of_num k) 0 \<longleftrightarrow> False"
  "HOL.equal (nat_of_num k) (nat_of_num l) \<longleftrightarrow> HOL.equal k l"
  by (simp_all add: nat_of_num_numeral equal)

(* Reflexivity of equality, registered only via the [code nbe] attribute. *)
lemma equal_nat_refl [code nbe]:
  "HOL.equal (n::nat) n \<longleftrightarrow> True"
  by (rule equal_refl)

lemma [code, code del]:
  "(less_eq :: nat \<Rightarrow> _) = less_eq" ..

lemma less_eq_nat_code [code]:
  "0 \<le> (n::nat) \<longleftrightarrow> True"
  "nat_of_num k \<le> 0 \<longleftrightarrow> False"
  "nat_of_num k \<le> nat_of_num l \<longleftrightarrow> k \<le> l"
  by (simp_all add: nat_of_num_numeral)

lemma [code, code del]:
  "(less :: nat \<Rightarrow> _) = less" ..

lemma less_nat_code [code]:
  "(m::nat) < 0 \<longleftrightarrow> False"
  "0 < nat_of_num l \<longleftrightarrow> True"
  "nat_of_num k < nat_of_num l \<longleftrightarrow> k < l"
  by (simp_all add: nat_of_num_numeral)


subsection {* Conversions *}

lemma [code, code del]:
  "of_nat = of_nat" ..

(* of_nat maps the code constructors to 0 and to the numeral of the
   target type. *)
lemma of_nat_code [code]:
  "of_nat 0 = 0"
  "of_nat (nat_of_num k) = numeral k"
  by (simp_all add: nat_of_num_numeral)


subsection {* Case analysis *}

text {*
  Case analysis on natural numbers is rephrased using a conditional
  expression:
*}

lemma [code, code_unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: fun_eq_iff dest!: gr0_implies_Suc)


subsection {* Preprocessors *}

text {*
  The term @{term "Suc n"} is no longer a valid pattern.
  Therefore, all occurrences of this term in a position
  where a pattern is expected (i.e.~on the left-hand side of a recursion
  equation) must be eliminated.
  This can be accomplished by applying the following transformation rule:
*}

lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
  f n \<equiv> if n = 0 then g else h (n - 1)"
  by (rule eq_reflection) (cases n, simp_all)

text {*
  The rule above is built into a preprocessor that is plugged into
  the code generator. Since the preprocessor for introduction rules
  does not know anything about modes, some of the modes that worked
  for the canonical representation of natural numbers may no longer work.
*}

(*<*)
setup {*
let

(* One elimination step.  Pick an equation th whose left-hand side contains
   Suc ?v, and an equation of thms that resolves against the base-case
   premise of the corresponding Suc_if_eq instance.  Replace both by the
   combined conditional equation.  Returns NONE if no step applies. *)
fun remove_suc thy thms =
  let
    (* a schematic variable of type nat whose name is fresh w.r.t. thms *)
    val vname = singleton (Name.variant_list (map fst
      (fold (Term.add_var_names o Thm.full_prop_of) thms []))) "n";
    val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    (* the two sides of an equation  lhs == rhs *)
    fun lhs_of th = snd (Thm.dest_comb
      (fst (Thm.dest_comb (cprop_of th))));
    fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
    (* every occurrence of Suc ?v in ct, as a pair: ct with that occurrence
       replaced by cv, and the argument ?v itself *)
    fun find_vars ct = (case term_of ct of
        (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
      | _ $ _ =>
        let val (ct1, ct2) = Thm.dest_comb ct
        in
          map (apfst (fn ct => Thm.apply ct ct2)) (find_vars ct1) @
          map (apfst (Thm.apply ct1)) (find_vars ct2)
        end
      | _ => []);
    val eqs = maps
      (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    fun mk_thms (th, (ct, cv')) =
      let
        (* Instantiate Suc_if_eq (in the order of its schematic variables)
           with  %cv. ct,  %cv'. rhs of th,  g left schematic,  n := cv'.
           Then discharge its first premise with th generalized over cv'.
           What remains has the premise  f 0 == ?g. *)
        val th' =
          Thm.implies_elim
           (Conv.fconv_rule (Thm.beta_conversion true)
             (Drule.instantiate'
               [SOME (ctyp_of_term ct)] [SOME (Thm.lambda cv ct),
                 SOME (Thm.lambda cv' (rhs_of th)), NONE, SOME cv']
               @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
      in
        (* Try to discharge the remaining premise with each equation in thms.
           Equations that do not resolve raise THM and are skipped.
           Variable.trade is applied to a singleton list, hence the total
           map instead of a one-element list pattern. *)
        case map_filter (fn th'' =>
            SOME (th'', singleton
              (Variable.trade (K (map (fn th''' => th''' RS th')))
                (Variable.global_thm_context th'')) th'')
          handle THM _ => NONE) thms of
            [] => NONE
          | thps =>
              (* drop th and the consumed base equations, add the results *)
              let val (ths1, ths2) = split_list thps
              in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
      end
  in get_first mk_thms eqs end;

(* Only equation sets consisting of meta-equations (==) in which some
   left-hand side mentions Suc are touched.  remove_suc is iterated via
   perhaps_loop, and the variable indices of the result are normalized. *)
fun eqn_suc_base_preproc thy thms =
  let
    val dest = fst o Logic.dest_equals o prop_of;
    val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
  in
    if forall (can dest) thms andalso exists (contains_suc o dest) thms
    then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
    else NONE
  end;

val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;

in

Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)

end;
*}
(*>*)

(* Code generated for this theory is placed in a module named Arith. *)
code_modulename SML
  Code_Nat Arith

code_modulename OCaml
  Code_Nat Arith

code_modulename Haskell
  Code_Nat Arith

hide_const (open) dup sub

end