(*  Title:      HOL/Library/Efficient_Nat.thy
    ID:         $Id$
    Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
*)

header {* Implementation of natural numbers by target-language integers *}

theory Efficient_Nat
imports Main Code_Integer Code_Index
begin

text {*
  When generating code for functions on natural numbers, the
  canonical representation using @{term "0::nat"} and
  @{term "Suc"} is unsuitable for computations involving large
  numbers.  The efficiency of the generated code can be improved
  drastically by implementing natural numbers by target-language
  integers.  To do this, just include this theory.
*}

subsection {* Basic arithmetic *}

text {*
  Most standard arithmetic functions on natural numbers are implemented
  using their counterparts on the integers:
*}

(* For code generation, numerals are the only constructor of nat;
   0, 1 and Suc are expressed in terms of numerals and addition below. *)
code_datatype number_nat_inst.number_of_nat

(* 0 and 1 are unfolded to numerals; the symmetric versions are registered
   as [code post] rules so numerals are rewritten back to 0 and 1. *)
lemma zero_nat_code [code, code unfold]:
  "0 = (Numeral0 :: nat)"
  by simp
lemmas [code post] = zero_nat_code [symmetric]

lemma one_nat_code [code, code unfold]:
  "1 = (Numeral1 :: nat)"
  by simp
lemmas [code post] = one_nat_code [symmetric]

lemma Suc_code [code]:
  "Suc n = n + 1"
  by simp

(* The operations and comparisons below are reduced to their integer
   counterparts via the embedding of_nat and the truncating coercion nat. *)
lemma plus_nat_code [code]:
  "n + m = nat (of_nat n + of_nat m)"
  by simp

(* Truncated subtraction: a negative integer difference is mapped to 0 by nat. *)
lemma minus_nat_code [code]:
  "n - m = nat (of_nat n - of_nat m)"
  by simp

lemma times_nat_code [code]:
  "n * m = nat (of_nat n * of_nat m)"
  unfolding of_nat_mult [symmetric] by simp

lemma div_nat_code [code]:
  "n div m = nat (of_nat n div of_nat m)"
  unfolding zdiv_int [symmetric] by simp

lemma mod_nat_code [code]:
  "n mod m = nat (of_nat n mod of_nat m)"
  unfolding zmod_int [symmetric] by simp

lemma eq_nat_code [code]:
  "n = m \<longleftrightarrow> (of_nat n \<Colon> int) = of_nat m"
  by simp

lemma less_eq_nat_code [code]:
  "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
  by simp

lemma less_nat_code [code]:
  "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
  by simp

subsection {* Case analysis *}

text {*
  Case analysis on natural numbers is rephrased using a conditional
  expression:
*}

(* Suc is no longer a code constructor, so nat_case is replaced by a test
   against 0 and a predecessor computed by truncated subtraction. *)
lemma [code func, code unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)


subsection {* Preprocessors *}

text {*
  In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
  a constructor term. Therefore, all occurrences of this term in a position
  where a pattern is expected (i.e.\ on the left-hand side of a recursion
  equation or in the arguments of an inductive relation in an introduction
  rule) must be eliminated.
  This can be accomplished by applying the following transformation rules:
*}

(* Merges an equation for the pattern Suc n and one for 0 into a single
   equation on a plain variable n, using n - 1 in the successor case. *)
lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
  f n = (if n = 0 then g else h (n - 1))"
  by (case_tac n) simp_all

(* Turns a clause stated for Suc n into one for any n \<noteq> 0, with n - 1
   in place of the predecessor. *)
lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
  by (case_tac n) simp_all

text {*
  The rules above are built into a preprocessor that is plugged into
  the code generator. Since the preprocessor for introduction rules
  does not know anything about modes, some of the modes that worked
  for the canonical representation of natural numbers may no longer work.
*}

(*<*)

ML {*
(* Eliminates patterns "Suc ?v" from the left-hand sides of the equations
   in thms by instantiating Suc_if_eq, combining each such equation with
   the equations that discharge the remaining "0" premise.  Repeats until
   no further elimination succeeds.  Each theorem is expected to have the
   form Trueprop (lhs = rhs); eqn_suc_preproc checks this before calling. *)
fun remove_suc thy thms =
  let
    val vname = Name.variant (map fst
      (fold (Term.add_varnames o Thm.full_prop_of) thms [])) "x";
    (* fresh schematic nat variable standing for an eliminated "Suc ?v" *)
    val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
    (* Trueprop (lhs = rhs): strip Trueprop, then "op =", to reach lhs / rhs *)
    fun lhs_of th = snd (Thm.dest_comb
      (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
    fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
    (* For each occurrence of "Suc ?v" in ct: the term ct with that
       occurrence replaced by cv, paired with ?v. *)
    fun find_vars ct = (case term_of ct of
        (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
      | _ $ _ =>
        let val (ct1, ct2) = Thm.dest_comb ct
        in
          map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
          map (apfst (Thm.capply ct1)) (find_vars ct2)
        end
      | _ => []);
    val eqs = maps
      (fn th => map (pair th) (find_vars (lhs_of th))) thms;
    (* Instantiates Suc_if_eq for one Suc occurrence in th; returns the new
       equation list if some equation in thms resolves (RS) with the result,
       NONE otherwise. *)
    fun mk_thms (th, (ct, cv')) =
      let
        val th' =
          Thm.implies_elim
           (Conv.fconv_rule (Thm.beta_conversion true)
             (Drule.instantiate'
               [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
                 SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
               @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
      in
        case map_filter (fn th'' =>
            SOME (th'', singleton
              (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
          handle THM _ => NONE) thms of
            [] => NONE
          | thps =>
              let val (ths1, ths2) = split_list thps
              in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
      end
  in
    case get_first mk_thms eqs of
      NONE => thms
    | SOME x => remove_suc thy x
  end;

(* Applies remove_suc only if every theorem is an object-level equation and
   at least one left-hand side mentions Suc; otherwise returns ths unchanged. *)
fun eqn_suc_preproc thy ths =
  let
    val dest = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of;
    fun contains_suc t = member (op =) (term_consts t) @{const_name Suc};
  in
    if forall (can dest) ths andalso
      exists (contains_suc o dest) ths
    then remove_suc thy ths else ths
  end;

(* Eliminates "Suc ?v" occurrences from introduction rules using Suc_clause,
   one occurrence at a time, until none is left. *)
fun remove_suc_clause thy thms =
  let
    (* first subterm "Suc ?v" in t (left to right), with ?v *)
    fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
      | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
      | find_var _ = NONE;
    fun find_thm th =
      let val th' = Conv.fconv_rule ObjectLogic.atomize th
      in Option.map (pair (th, th')) (find_var (prop_of th')) end
  in
    case get_first find_thm thms of
      NONE => thms
    | SOME ((th, th'), (Sucv, v)) =>
        let
          val cert = cterm_of (Thm.theory_of_thm th);
          val th'' = ObjectLogic.rulify (Thm.implies_elim
            (Conv.fconv_rule (Thm.beta_conversion true)
              (Drule.instantiate' []
                [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
                   abstract_over (Sucv,
                     HOLogic.dest_Trueprop (prop_of th')))))),
                 SOME (cert v)] @{thm Suc_clause}))
            (Thm.forall_intr (cert v) th'))
        in
          (* replace th (compared by proposition) with its transformed version *)
          remove_suc_clause thy (map (fn th''' =>
            if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
        end
  end;

(* Applies remove_suc_clause only if every conclusion is a membership
   statement and Suc occurs in some membership conclusion or premise. *)
fun clause_suc_preproc thy ths =
  let
    val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
  in
    if forall (can (dest o concl_of)) ths andalso
      exists (fn th => member (op =) (foldr add_term_consts
        [] (map_filter (try dest) (concl_of th :: prems_of th))) @{const_name Suc}) ths
    then remove_suc_clause thy ths else ths
  end;

(* Adapts a preprocessor on object-level equations to meta-level equations:
   converts with meta_eq_to_obj_eq, applies f, converts back with
   eq_reflection and beta-eta-normalizes the results. *)
fun lift_obj_eq f thy =
  map (fn thm => thm RS @{thm meta_eq_to_obj_eq})
  #> f thy
  #> map (fn thm => thm RS @{thm eq_reflection})
  #> map (Conv.fconv_rule Drule.beta_eta_conversion)
*}

(* Register both preprocessors with the old code generator (Codegen) and,
   lifted to meta-equalities, with the new one (Code). *)
setup {*
  Codegen.add_preprocessor eqn_suc_preproc
  #> Codegen.add_preprocessor clause_suc_preproc
  #> Code.add_preproc ("eqn_Suc", lift_obj_eq eqn_suc_preproc)
  #> Code.add_preproc ("clause_Suc", lift_obj_eq clause_suc_preproc)
*}
(*>*)


subsection {* Target language setup *}

text {*
  We map @{typ nat} to target language integers, where we
  assert that values are always non-negative.
*}

(* NOTE(review): the SML target type is "int", while the nat serializations
   further down use IntInf operations and annotate operands as IntInf.int;
   this presumably relies on int and IntInf.int coinciding in the ML system
   used -- confirm. *)
code_type nat
  (SML "int")
  (OCaml "Big'_int.big'_int")
  (Haskell "Integer")

(* Old code generator: nat is ML int; term_of_nat rebuilds a HOL numeral,
   gen_nat draws a random value in the range 0 .. i for testing. *)
types_code
  nat ("int")
attach (term_of) {*
val term_of_nat = HOLogic.mk_number HOLogic.natT;
*}
attach (test) {*
fun gen_nat i =
  let val n = random_range 0 i
  in (n, fn () => term_of_nat n) end;
*}

text {*
  Natural numerals.
*}

(* Inline nat applied to an integer numeral as a nat numeral. *)
lemma [code inline]:
  "nat (number_of i) = number_nat_inst.number_of_nat i"
  -- {* this interacts as desired with @{thm nat_number_of_def} *}
  by (simp add: number_nat_inst.number_of_nat)

(* Serialize nat numerals in all three target languages. *)
setup {*
  fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
    false false) ["SML", "OCaml", "Haskell"]
*}

text {*
  Since natural numbers are implemented
  using integers, the coercion function @{const "of_nat"} of type
  @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
  For the @{const "nat"} function for converting an integer to a natural
  number, we give a specific implementation using an ML function that
  returns its input value, provided that it is non-negative, and otherwise
  returns @{text "0"}.
*}

(* Auxiliary monomorphic copy of of_nat :: nat => int; it is serialized as
   the identity and hidden again at the end of the theory. *)
definition
  int :: "nat \<Rightarrow> int"
where
  [code func del]: "int = of_nat"

lemma int_code' [code func]:
  "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  unfolding int_nat_number_of [folded int_def] ..

lemma nat_code' [code func]:
  "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
  by auto

lemma of_nat_int [code unfold]:
  "of_nat = int" by (simp add: int_def)

(* nat and int share one representation, so the coercion is the identity. *)
code_const int
  (SML "_")
  (OCaml "_")
  (Haskell "_")

(* nat truncates negative integers to 0. *)
code_const nat
  (SML "IntInf.max/ (/0,/ _)")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
  (Haskell "max 0")

consts_code
  int ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if i < 0 then 0 else i;
*}


text {* Conversion from and to indices. *}

code_const nat_of_index
  (SML "IntInf.fromInt")
  (OCaml "Big'_int.big'_int'_of'_int")
  (Haskell "toInteger")

code_const index_of_nat
  (SML "IntInf.toInt")
  (OCaml "Big'_int.int'_of'_big'_int")
  (Haskell "fromInteger")

text {* Using target language arithmetic operations whenever appropriate *}

code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.+ ((_), (_))")
  (OCaml "Big'_int.add'_big'_int")
  (Haskell infixl 6 "+")

code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.* ((_), (_))")
  (OCaml "Big'_int.mult'_big'_int")
  (Haskell infixl 7 "*")

(* NOTE(review): HOL defines n div 0 = 0 and n mod 0 = n, whereas these
   target operations presumably raise on a zero divisor -- confirm whether
   callers can reach div/mod by 0 in generated code. *)
code_const "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.div/ ((_),/ (_))")
  (OCaml "Big'_int.div'_big'_int")
  (Haskell "div")

code_const "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.mod/ ((_),/ (_))")
  (OCaml "Big'_int.mod'_big'_int")
  (Haskell "mod")

code_const "op = \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "!((_ : IntInf.int) = _)")
  (OCaml "Big'_int.eq'_big'_int")
  (Haskell infixl 4 "==")

code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.<= ((_), (_))")
  (OCaml "Big'_int.le'_big'_int")
  (Haskell infix 4 "<=")

code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.< ((_), (_))")
  (OCaml "Big'_int.lt'_big'_int")
  (Haskell infix 4 "<")

(* Old code generator: nat operations map directly onto ML int syntax. *)
consts_code
  0 ("0")
  Suc ("(_ +/ 1)")
  "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat" ("(_ +/ _)")
  "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat" ("(_ */ _)")
  "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat" ("(_ div/ _)")
  "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat" ("(_ mod/ _)")
  "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" ("(_ <=/ _)")
  "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" ("(_ </ _)")

text {* Module names *}

(* Code from Nat, Divides and this theory is emitted into module Integer. *)
code_modulename SML
  Nat Integer
  Divides Integer
  Efficient_Nat Integer

code_modulename OCaml
  Nat Integer
  Divides Integer
  Efficient_Nat Integer

code_modulename Haskell
  Nat Integer
  Divides Integer
  Efficient_Nat Integer

hide const int

end 