author  haftmann 
Fri, 19 Oct 2007 16:20:27 +0200  
changeset 25101  cae0f68b693b 
parent 25098  1ec53c9ae71a 
child 25167  0fd59d8e2bad 
permissions  rwrr 
24590  1 
(*  Title:      Tools/nbe.ML
    ID:         $Id$
    Authors:    Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen

Normalization by evaluation, based on generic code generator.
*)
7 

8 
(* Public interface of the normalization-by-evaluation engine. *)
signature NBE =
sig
  (*entry points: normalize a certified term / a plain term*)
  val norm_conv: cterm -> thm
  val norm_term: theory -> term -> term

  (*semantic universe; argument lists are kept with the last argument first*)
  datatype Univ =
      Const of string * Univ list            (*named (uninterpreted) constants*)
    | Free of string * Univ list
    | DFree of string                        (*free (uninterpreted) dictionary parameters*)
    | BVar of int * Univ list
    | Abs of (int * (Univ list -> Univ)) * Univ list;
  val free: string -> Univ                   (*free (uninterpreted) variables*)
  val app: Univ -> Univ -> Univ              (*explicit application*)
  val abs: int -> (Univ list -> Univ) -> Univ
                                             (*abstractions as closures*)

  (*hooks referenced by name from the generated ML code*)
  val univs_ref: (unit -> Univ list) ref
  val lookup_fun: string -> Univ

  val trace: bool ref
  val setup: theory -> theory
end;

30 

31 
structure Nbe: NBE = 

32 
struct 

33 

34 
(* generic nonsense *) 

35 

36 
(* Global switch for diagnostic output of this module. *)
val trace = ref false;

(* tracing f x: identity on x; when !trace is set, additionally emits
   the message f x via Output.tracing before returning x. *)
fun tracing f x = if !trace then (Output.tracing (f x); x) else x;

38 

39 

40 
(** the semantical universe **) 

41 

42 
(*
   Functions are given by their semantical function value. To avoid
   trouble with the ML type system, these functions have the most
   generic type, that is "Univ list -> Univ". The calling convention is
   that the arguments come as a list, the last argument first. In
   other words, a function call that usually would look like

   f x_1 x_2 ... x_n   or   f(x_1,x_2, ..., x_n)

   would be in our convention called as

              f [x_n,..,x_2,x_1]

   Moreover, to handle functions that are still waiting for some
   arguments we have additionally a list of arguments collected so far
   and the number of arguments we're still waiting for.
*)

59 

60 
(* Semantic values.  Every argument list is stored with the most recently
   applied argument first.  For Abs, the int is the number of arguments the
   closure is still waiting for and the list holds those collected so far. *)
datatype Univ =
    Const of string * Univ list        (*named (uninterpreted) constants*)
  | Free of string * Univ list         (*free variables*)
  | DFree of string                    (*free (uninterpreted) dictionary parameters*)
  | BVar of int * Univ list            (*bound named variables*)
  | Abs of (int * (Univ list -> Univ)) * Univ list
                                       (*abstractions as closures*);
24155  67 

68 
(* constructor functions *) 

69 

25101  70 
(* free v: free variable v without arguments. *)
fun free v = Free (v, []);

(* abs n f: closure f that still waits for n arguments. *)
fun abs n f = Abs ((n, f), []);

(* app u x: apply semantic value u to one further argument x.
   A closure missing exactly one argument is invoked on the collected
   arguments (newest first); any other closure just records x and decrements
   its counter.  Neutral values (Const, Free, BVar) accumulate x in front of
   their argument list.  There is no clause for DFree, so applying a
   dictionary parameter raises Match. *)
fun app (Abs ((1, f), xs)) x = f (x :: xs)
  | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
  | app (Const (name, args)) x = Const (name, x :: args)
  | app (Free (name, args)) x = Free (name, x :: args)
  | app (BVar (name, args)) x = BVar (name, x :: args);

77 

25101  78 
(* universe graph *) 
24155  79 

25101  80 
(* Graph of names; a node carries NONE until a semantic value is stored. *)
type univ_gr = Univ option Graph.T;

(* compiled gr name: true iff Graph.get_node succeeds for name, i.e. the
   node exists in gr (regardless of whether it holds NONE or SOME _). *)
val compiled : univ_gr -> string -> bool = can o Graph.get_node;

24839  82 

24155  83 
(* sandbox communication *) 
84 

24590  85 
(* Mailbox through which the generated code hands its result back:
   the generated text has the shape "Nbe.univs_ref := <expr>". *)
val univs_ref = ref (fn () => [] : Univ list);

local

(*graph visible to lookup_fun while generated code is being evaluated*)
val gr_ref = ref NONE : univ_gr option ref;

(* compile gr raw_s: evaluate the ML text raw_s via use_text inside the
   "nbe" critical section and return the content of univs_ref afterwards.
   gr is published in gr_ref for the duration of the evaluation. *)
fun compile gr raw_s = NAMED_CRITICAL "nbe" (fn () =>
  let
    val _ = univs_ref := (fn () => []);
    val s = "Nbe.univs_ref := " ^ raw_s;
    val _ = tracing (fn () => "\n generated code:\n" ^ s) ();
    val _ = gr_ref := SOME gr;
    (*reset gr_ref also when evaluation fails, so that no stale graph
      stays reachable through lookup_fun*)
    val _ = use_text "" (Output.tracing o enclose "\ncompiler echo:\n" "\n\n",
          Output.tracing o enclose "\n compiler echo (with error):\n" "\n\n")
        (!trace) s
      handle exn => (gr_ref := NONE; raise exn);
    val _ = gr_ref := NONE;
  in !univs_ref end);

in

(* lookup_fun s: semantic value stored for s in the graph that is currently
   being compiled against.  Fails with error "compile_univs" when no
   compilation is in progress; "the" raises if the node still holds NONE. *)
fun lookup_fun s = NAMED_CRITICAL "nbe" (fn () => case ! gr_ref
 of NONE => error "compile_univs"
  | SOME gr => the (Graph.get_node gr s));

(* compile_univs gr (cs, raw_s): compile raw_s and pair the names cs with the
   resulting values; nothing is compiled when cs is empty. *)
fun compile_univs gr ([], _) = []
  | compile_univs gr (cs, raw_s) = cs ~~ compile gr raw_s ();

end; (*local*)

113 

114 

115 
(** assembling and compiling ML code from terms **) 

116 

117 
(* abstract ML syntax *) 

118 

119 
(* Combinators that build ML source text as plain strings. *)

infix 9 `$` `$$`;

(*parenthesized application of e1 to e2*)
fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";

(*parenthesized application of e to several arguments; e itself if none*)
fun e `$$` [] = e
  | e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";

(*lambda abstraction over pattern v*)
fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";

(*value binding*)
fun ml_Val v s = "val " ^ v ^ " = " ^ s;

(*case expression; clauses must be separated by "|" to form valid ML*)
fun ml_cases t cs =
  "(case " ^ t ^ " of "
    ^ space_implode " | " (map (fn (pat, body) => pat ^ " => " ^ body) cs) ^ ")";

(*let expression over the declarations ds*)
fun ml_Let ds e = "let\n" ^ space_implode "\n" ds ^ " in " ^ e ^ " end";

(*list literal*)
fun ml_list es = "[" ^ commas es ^ "]";

(*suspend an expression: fn () => e*)
val ml_delay = ml_abs "()"
133 

24155  134 
(* ml_fundefs defs: text of a group of (mutually recursive) definitions.
   Each element of defs is (name, equations), an equation being
   (argument patterns, right-hand side).
   A single definition with one argument-less equation becomes a plain
   "val" binding; otherwise the first definition is introduced by "fun" and
   the remaining ones by "and", with "|" between the clauses of one
   function.  The empty list is not handled (raises Match). *)
fun ml_fundefs ([(name, [([], e)])]) =
      "val " ^ name ^ " = " ^ e ^ "\n"
  | ml_fundefs (eqs :: eqss) =
      let
        fun fundef (name, eqs) =
          let
            fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
          in space_implode "\n  | " (map eqn eqs) end;
      in
        (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
        |> space_implode "\n"
        |> suffix "\n"
      end;

147 

148 
(* nbe specific syntax *) 

149 

150 
(* Text fragments referring to the entities of structure Nbe, as they are
   spelled inside generated code. *)
local
  val prefix = "Nbe.";
  val name_const = prefix ^ "Const";
  val name_free = prefix ^ "free";
  val name_dfree = prefix ^ "DFree";
  val name_abs = prefix ^ "abs";
  val name_app = prefix ^ "app";
  val name_lookup_fun = prefix ^ "lookup_fun";
in

(*uninterpreted constant c applied to the argument texts ts*)
fun nbe_const c ts =
  name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");

(*ML identifier of the function compiled for c: "c_" plus c with "." mapped to "_"*)
fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;

(*free variable / free dictionary parameter*)
fun nbe_free v = name_free `$` ML_Syntax.print_string v;
fun nbe_dfree v = name_dfree `$` ML_Syntax.print_string v;

(*ML identifier of the n-th dictionary parameter belonging to v*)
fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;

(*ML identifier of bound variable v*)
fun nbe_bound v = "v_" ^ v;

(*iterated Nbe.app; folding from the right, so the last element of es
  is applied first*)
fun nbe_apps e es =
  Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);

(*closure waiting for n arguments; with n = 0 the function is called
  immediately on the empty argument list*)
fun nbe_abss 0 f = f `$` ml_list []
  | nbe_abss n f = name_abs `$$` [string_of_int n, f];

(*binding that fetches the already compiled value of c at evaluation time*)
fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c);

val nbe_value = "value";

end;

179 

24219  180 
open BasicCodeThingol; 
24155  181 

182 
(* greetings to Tarski *) 

183 

25101  184 
(* assemble_idict d: ML text for dictionary d.
   DictConst: the instance function applied to the texts of all its argument
   dictionaries.  DictVar: the dictionary parameter (v, n), wrapped
   successively in an application of each entry of supers. *)
fun assemble_idict (DictConst (inst, dss)) =
      nbe_apps (nbe_fun inst) ((maps o map) assemble_idict dss)
  | assemble_idict (DictVar (supers, (v, (n, _)))) =
      fold (fn super => nbe_apps (nbe_fun super) o single) supers (nbe_dict v n);

188 

189 
(* assemble_iterm is_fun num_args: translator from iterms to ML text.
     is_fun c    -- whether a compiled function is available for c
     num_args c  -- SOME n if c is defined in the current group with n arguments
   Argument lists are passed around reversed (last argument first). *)
fun assemble_iterm is_fun num_args =
  let
    (*split into head and arguments; fold/cons reverses the argument order*)
    fun of_iterm t =
      let
        val (t', ts) = CodeThingol.unfold_app t
      in of_iapp t' (fold (cons o of_iterm) ts []) end
    and of_iconst c ts = case num_args c
     of SOME n => if n <= length ts
          (*enough arguments: the trailing n entries of ts go to the function
            as one list, the remaining ones are applied via Nbe.app*)
          then let val (args2, args1) = chop (length ts - n) ts
          in nbe_apps (nbe_fun c `$` ml_list args1) args2
          end else nbe_const c ts
      | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
          else nbe_const c ts
    (*dictionaries of a constant are appended behind its ordinary arguments*)
    and of_iapp (IConst (c, (dss, _))) ts = of_iconst c
          (ts @ rev ((maps o map) assemble_idict dss))
      | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
      | of_iapp ((v, _) `|-> t) ts =
          nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
      (*case expressions get the fallback term t0 as final wildcard clause*)
      | of_iapp (ICase (((t, _), cs), t0)) ts =
          nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
            @ [("_", of_iterm t0)])) ts
  in of_iterm end;

211 

25101  212 
(* assemble_fun gr num_args (c, (vs, eqns)): clauses (patterns, rhs text) of
   the ML function for constant c.  Each clause takes a single list pattern:
   the reversed arguments followed by the dictionary parameters derived from
   vs.  A final default clause rebuilds c as uninterpreted constant over
   fresh variables; it requires num_args c to be SOME _. *)
fun assemble_fun gr num_args (c, (vs, eqns)) =
  let
    (*patterns: no constant is treated as a function*)
    val assemble_arg = assemble_iterm (K false) (K NONE);
    (*right-hand sides: constants with a value in gr are called as functions*)
    val assemble_rhs = assemble_iterm (is_some o Graph.get_node gr) num_args;
    val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
      |> rev;
    fun assemble_eqn (args, rhs) =
      ([ml_list (map assemble_arg (rev args) @ dict_params)], assemble_rhs rhs);
    val default_params = map nbe_bound (Name.invent_list [] "a" ((the o num_args) c));
    val default_eqn = ([ml_list default_params], nbe_const c default_params);
  in map assemble_eqn eqns @ [default_eqn] end;

223 

25101  224 
(* assemble_eqnss gr (eqnss, deps): names defined by eqnss together with the
   ML text that, once evaluated and forced, yields their semantic values in
   the same order.  The text binds the dependencies that already carry a
   value in gr (via lookup), defines the functions of the group, and returns
   the delayed list of closures.  An empty group yields the empty text. *)
fun assemble_eqnss gr ([], deps) = ([], "")
  | assemble_eqnss gr (eqnss, deps) =
      let
        val cs = map fst eqnss;
        (*arity = number of dictionary parameters + arguments of the first equation;
          the pattern assumes every entry has at least one equation*)
        val num_args = cs ~~ map (fn (_, (vs, (args, rhs) :: _)) =>
          length (maps snd vs) + length args) eqnss;
        val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps);
        val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
          (assemble_fun gr (AList.lookup (op =) num_args)) eqnss);
        val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args)
          |> ml_delay;
      in (cs, ml_Let (bind_deps @ [bind_locals]) result) end;
24155  236 

25101  237 
(* eqns_of_stmt (name, stmt): equations (name, (vs, [(args, rhs)])) contributed
   by a code statement.  Only functions with equations, classes and class
   instances contribute anything. *)
fun eqns_of_stmt (_, CodeThingol.Fun (_, [])) =
      []
  | eqns_of_stmt (const, CodeThingol.Fun ((vs, _), eqns)) =
      [(const, (vs, map fst eqns))]
  | eqns_of_stmt (_, CodeThingol.Datatypecons _) =
      []
  | eqns_of_stmt (_, CodeThingol.Datatype _) =
      []
  (*a class yields one projection per superclass relation and class operation:
    each selects the k-th component of a dictionary built with the class name*)
  | eqns_of_stmt (class, CodeThingol.Class (v, (superclasses, classops))) =
      let
        val names = map snd superclasses @ map fst classops;
        val params = Name.invent_list [] "d" (length names);
        fun mk (k, name) =
          (name, ([(v, [])],
            [([IConst (class, ([], [])) `$$ map IVar params], IVar (nth params k))]));
      in map_index mk names end
  | eqns_of_stmt (_, CodeThingol.Classrel _) =
      []
  | eqns_of_stmt (_, CodeThingol.Classparam _) =
      []
  (*an instance is the class name applied to its superinstances and operations*)
  | eqns_of_stmt (inst, CodeThingol.Classinst ((class, (_, arities)), (superinsts, instops))) =
      [(inst, (arities, [([], IConst (class, ([], [])) `$$
        map (fn (_, (_, (inst, dicts))) => IConst (inst, (dicts, []))) superinsts
        @ map (IConst o snd o fst) instops)]))];

24155  261 

25101  262 
(* compile_stmts stmts_deps: graph transformer that compiles one group of
   statements.  stmts_deps lists ((name, stmt), dependencies).  The names are
   entered as nodes holding NONE, dependency edges are added, the equations
   of the group are compiled, and every resulting value is stored in its node. *)
fun compile_stmts stmts_deps =
  let
    val names = map (fst o fst) stmts_deps;
    val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
    val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
    (*dependencies outside of the group itself*)
    val compiled_deps = names_deps |> maps snd |> distinct (op =) |> subtract (op =) names;
    fun compile gr = (eqnss, compiled_deps) |> assemble_eqnss gr |> compile_univs gr |> rpair gr;
  in
    fold (fn name => Graph.new_node (name, NONE)) names
    #> fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
    #> compile
    (*compile returns (values, graph), hence #-> rather than #>*)
    #-> fold (fn (name, univ) => Graph.map_node name (K (SOME univ)))
  end;

24155  275 

25101  276 
(* ensure_stmts code: graph transformer that walks the strongly connected
   components of code (via fold_rev over Graph.strong_conn) and compiles each
   component unless one of its names is already a node of the graph. *)
fun ensure_stmts code =
  let
    fun add_stmts names gr = if exists (compiled gr) names then gr else gr
      |> compile_stmts (map (fn name => ((name, Graph.get_node code name),
          Graph.imm_succs code name)) names);
  in fold_rev add_stmts (Graph.strong_conn code) end;

24155  282 

25101  283 
(* assemble_eval gr (((vs, ty), t), deps): ML text for evaluating term t.
   The term becomes the body of a function "value" over its unbound variables
   and the dictionary parameters derived from vs; the result applies that
   function to the corresponding free variables and free dictionary
   parameters.  Returns (["value"], text) in the format of compile_univs. *)
fun assemble_eval gr (((vs, ty), t), deps) =
  let
    val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
    val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps);
    val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
      |> rev;
    val bind_value = ml_fundefs [(nbe_value,
      [([ml_list (map nbe_bound frees @ dict_params)],
        assemble_iterm (is_some o Graph.get_node gr) (K NONE) t)])];
    val result = ml_list [nbe_value `$` ml_list
      (map nbe_free frees @ map nbe_dfree dict_params)]
      |> ml_delay;
  in ([nbe_value], ml_Let (bind_deps @ [bind_value]) result) end;

296 

297 
(* eval_term gr: assemble the evaluation text for a term, compile it against
   gr, and return the semantic value of the single result. *)
fun eval_term gr =
  assemble_eval gr
  #> compile_univs gr
  #> the_single
  #> snd;

302 

303 

25101  304 
(** evaluation **) 
24155  305 

24839  306 
(* reification *) 
24155  307 

308 
(* term_of_univ thy t: convert semantic value t back into a term.
   Constants get their default type with type variables replaced by fresh
   inference parameters (typidx is the running index); free variables and
   abstractions are typed with dummyT.  "bounds" counts the enclosing
   abstractions.  There is no clause for DFree at top level (raises Match). *)
fun term_of_univ thy t =
  let
    (*longest prefix of elements not satisfying f*)
    fun take_until f [] = []
      | take_until f (x::xs) = if f x then [] else x :: take_until f xs;
    (*dictionary values: class, class relation and instance names, DFree*)
    fun is_dict (Const (c, _)) =
          (is_some o CodeName.class_rev thy) c
          orelse (is_some o CodeName.classrel_rev thy) c
          orelse (is_some o CodeName.instance_rev thy) c
      | is_dict (DFree _) = true
      | is_dict _ = false;
    (*arguments are stored last-first, hence the rev*)
    fun of_apps bounds (t, ts) =
      fold_map (of_univ bounds) ts
      #>> (fn ts' => list_comb (t, rev ts'))
    and of_univ bounds (Const (name, ts)) typidx =
          let
            (*cut the argument list at the first dictionary value*)
            val ts' = take_until is_dict ts;
            (*raises Bind if name is not the code name of a constant*)
            val SOME c = CodeName.const_rev thy name;
            val T = Code.default_typ thy c;
            val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
            val typidx' = typidx + maxidx_of_typ T' + 1;
          in of_apps bounds (Term.Const (c, T'), ts') typidx' end
      | of_univ bounds (Free (name, ts)) typidx =
          of_apps bounds (Term.Free (name, dummyT), ts) typidx
      (*levels count from the outside, de-Bruijn indices from the inside*)
      | of_univ bounds (BVar (name, ts)) typidx =
          of_apps bounds (Bound (bounds - name - 1), ts) typidx
      (*probe the closure with a fresh bound variable and abstract over it*)
      | of_univ bounds (t as Abs _) typidx =
          typidx
          |> of_univ (bounds + 1) (app t (BVar (bounds, [])))
          |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
  in of_univ 0 t 0 |> fst end;

338 

25101  339 
(* function store *) 
340 

341 
(* Store of compiled semantic values, kept as code data. *)
structure Nbe_Functions = CodeDataFun
(
  type T = univ_gr;
  val empty = Graph.empty;
  fun merge _ = Graph.merge (K true);
  (*without a theory or without an explicit list of constants the whole
    store is dropped; otherwise only the nodes determined by Graph.all_preds
    for the affected constants present in the graph are deleted*)
  fun purge _ NONE _ = Graph.empty
    | purge NONE _ _ = Graph.empty
    | purge (SOME thy) (SOME cs) gr =
        let
          val cs_existing =
            map_filter (CodeName.const_rev thy) (Graph.keys gr);
          val dels = (Graph.all_preds gr
              o map (CodeName.const thy)
              o filter (member (op =) cs_existing)
            ) cs;
        in Graph.del_nodes dels gr end;
);

358 

359 
(* compilation, evaluation and reification *) 

360 

361 
(* compile_eval thy code vs_ty_t deps: bring the function store of thy up to
   date with code, evaluate the term against it, and reify the result. *)
fun compile_eval thy code vs_ty_t deps =
  (vs_ty_t, deps)
  |> eval_term (Nbe_Functions.change thy (ensure_stmts code))
  |> term_of_univ thy;

365 

24155  366 
(* evaluation with type reconstruction *) 
367 

24381  368 
(* eval thy code t vs_ty_t deps: normal form of t with types reconstructed.
   The reified result is re-annotated with the typed free and schematic
   variables of t, constrained to the type of t, and type-checked.  Fails
   with an error if schematic type variables remain. *)
fun eval thy code t vs_ty_t deps =
  let
    val ty = type_of t;
    (*replace free variables by name, according to inst*)
    fun subst_Frees [] = I
      | subst_Frees inst =
          Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
                            | t => t);
    (*restore the original typing of the variables occurring in t*)
    val anno_vars =
      subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
      #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
    fun constrain t =
      singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t);
    fun check_tvars t = if null (Term.term_tvars t) then t else
      error ("Illegal schematic type variables in normalized term: "
        ^ setmp show_types true (Sign.string_of_term thy) t);
  in
    compile_eval thy code vs_ty_t deps
    |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
    |> anno_vars
    |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
    |> constrain
    |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
    |> tracing (fn _ => "\n")
    |> check_tvars
  end;

393 

394 
(* evaluation oracle *) 

395 

24839  396 
(* Oracle argument: code, original term, its translation, dependencies. *)
exception Norm of CodeThingol.code * term
  * (CodeThingol.typscheme * CodeThingol.iterm) * string list;

(* Oracle function: the proposition "t == normal form of t". *)
fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) =
  Logic.mk_equals (t, eval thy code t vs_ty_t deps);

(* Invoke the oracle registered under the name "HOL.norm". *)
fun norm_invoke thy code t vs_ty_t deps =
  Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps));
  (*FIXME get rid of hardwired theory name*)
24155  405 

24839  406 
(* norm_conv ct: theorem produced by the normalization oracle for ct, obtained
   through CodePackage.eval_conv in the theory of ct. *)
fun norm_conv ct =
  let
    val thy = Thm.theory_of_cterm ct;
    fun conv code vs_ty_t deps ct =
      let
        val t = Thm.term_of ct;
      in norm_invoke thy code t vs_ty_t deps end;
  in CodePackage.eval_conv thy conv ct end;
24155  414 

24839  415 
(* norm_term thy: normalize a term by evaluation (no oracle involved) and
   apply the code postprocessor of thy to the result. *)
fun norm_term thy =
  let
    (*adapt the argument order of eval to what CodePackage.eval_term passes*)
    fun invoke code vs_ty_t deps t =
      eval thy code t vs_ty_t deps;
  in CodePackage.eval_term thy invoke #> Code.postprocess_term thy end;

420 

24155  421 
(* evaluation command *) 
422 

423 
(* norm_print_term ctxt modes t: normalize t and print the result together
   with its type, both quoted, under the given print modes. *)
fun norm_print_term ctxt modes t =
  let
    val thy = ProofContext.theory_of ctxt;
    val t' = norm_term thy t;
    val ty' = Term.type_of t';
    val p = PrintMode.with_modes modes (fn () =>
      Pretty.block [Pretty.quote (Syntax.pretty_term ctxt t'), Pretty.fbrk,
        Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt ty')]) ();
  in Pretty.writeln p end;
432 

433 

434 
(** Isar setup **) 

435 

24508
c8b82fec6447
replaced ProofContext.read_term/prop by general Syntax.read_term/prop;
wenzelm
parents:
24493
diff
changeset

436 
(* Command body: read the term text s in the context of the toplevel state
   and print its normal form. *)
fun norm_print_term_cmd (modes, s) state =
  let val ctxt = Toplevel.context_of state
  in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;

(* Theory setup: register the normalization oracle under the name "norm". *)
val setup = Theory.add_oracle ("norm", norm_oracle)

local structure P = OuterParse and K = OuterKeyword in

(*optional list of print modes, written "(mode1 mode2 ...)"*)
val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];

(*NOTE(review): the argument is read with P.typ although it is subsequently
  parsed as a term by Syntax.read_term -- P.term looks intended; confirm*)
val _ =
  OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
    (opt_modes -- P.typ >> (Toplevel.keep o norm_print_term_cmd));

end;

451 

452 
end; 