(* Mercurial annotate metadata: author haftmann; Wed, 13 Feb 2008 09:35:33 +0100;
   changeset 26064:65585de05a66; parent 26011:d55224947082;
   child 26739:947b6013e863; permissions -rw-r--r-- *)
(*  Title:      Tools/nbe.ML
    ID:         $Id$
    Authors:    Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen

Normalization by evaluation, based on generic code generator.
*)
(* Public interface of normalization by evaluation. *)
signature NBE =
sig
  val norm_conv: cterm -> thm
  val norm_term: theory -> term -> term

  (*semantic universe; argument lists are stored last argument first*)
  datatype Univ =
      Const of int * Univ list           (*named (uninterpreted) constants*)
    | Free of string * Univ list         (*free (uninterpreted) variables*)
    | DFree of string * int              (*free (uninterpreted) dictionary parameters*)
    | BVar of int * Univ list
    | Abs of (int * (Univ list -> Univ)) * Univ list;
  val apps: Univ -> Univ list -> Univ        (*explicit applications*)
  val abss: int -> (Univ list -> Univ) -> Univ
                                             (*abstractions as closures*)

  val univs_ref: (unit -> Univ list -> Univ list) option ref
  val trace: bool ref

  val setup: theory -> theory
end;

structure Nbe: NBE =
struct

(* generic nonsense *)

val trace = ref false;

(*emit f x as a tracing message when trace is set; always returns x unchanged*)
fun tracing f x = if !trace then (Output.tracing (f x); x) else x;


(** the semantical universe **)

(*
   Functions are given by their semantical function value. To avoid
   trouble with the ML-type system, these functions have the most
   generic type, that is "Univ list -> Univ". The calling convention is
   that the arguments come as a list, the last argument first. In
   other words, a function call that usually would look like

   f x_1 x_2 ... x_n   or   f(x_1,x_2, ..., x_n)

   would be in our convention called as

              f [x_n,..,x_2,x_1]

   Moreover, to handle functions that are still waiting for some
   arguments we have additionally a list of arguments collected so far
   and the number of arguments we're still waiting for.
*)

datatype Univ =
    Const of int * Univ list               (*named (uninterpreted) constants*)
  | Free of string * Univ list            (*free variables*)
  | DFree of string * int                 (*free (uninterpreted) dictionary parameters*)
  | BVar of int * Univ list               (*bound named variables*)
  | Abs of (int * (Univ list -> Univ)) * Univ list
                                          (*abstractions as closures*);


(* constructor functions *)

(*closure waiting for n arguments, none collected yet*)
fun abss n f = Abs ((n, f), []);

(*apply a value to arguments ys (last argument first).
  For an Abs still waiting for n arguments:
    - exactly n given: the closure is invoked;
    - more than n given: the closure consumes the first n arguments in
      application order and the result is applied to the surplus;
    - fewer than n given: a partial application collecting ys is returned.
  Uninterpreted heads (Const, Free, BVar) just accumulate ys.
  No clause for DFree: applying a dictionary parameter raises Match.*)
fun apps (Abs ((n, f), xs)) ys = let val k = n - length ys in
      if k = 0 then f (ys @ xs)
      else if k < 0 then
        let val (zs, ws) = chop (~ k) ys
        in apps (f (ws @ xs)) zs end
      else Abs ((k, f), ys @ xs) end (*note: reverse convention also for apps!*)
  | apps (Const (name, xs)) ys = Const (name, ys @ xs)
  | apps (Free (name, xs)) ys = Free (name, ys @ xs)
  | apps (BVar (name, xs)) ys = BVar (name, ys @ xs);


(** assembling and compiling ML code from terms **)

(* abstract ML syntax *)

infix 9 `$` `$$`;
fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";
fun e `$$` [] = e
  | e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";

(*ML case expression; clauses must be separated by " | " to be valid ML*)
fun ml_cases t cs =
  "(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")";
fun ml_Let d e = "let\n" ^ d ^ " in " ^ e ^ " end";

fun ml_list es = "[" ^ commas es ^ "]";

(*render (name, equations) groups as one mutually recursive
  "fun ... and ..." declaration; equations of one function are joined
  with "|". A single group with a single argument-less equation becomes
  a "val" declaration instead.*)
fun ml_fundefs ([(name, [([], e)])]) =
      "val " ^ name ^ " = " ^ e ^ "\n"
  | ml_fundefs (eqs :: eqss) =
      let
        fun fundef (name, eqs) =
          let
            fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
          in space_implode "\n  | " (map eqn eqs) end;
      in
        (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
        |> space_implode "\n"
        |> suffix "\n"
      end;


(* nbe specific syntax and sandbox communication *)

val univs_ref = ref (NONE : (unit -> Univ list -> Univ list) option);

local
  val prefix =      "Nbe.";
  val name_ref =    prefix ^ "univs_ref";
  val name_const =  prefix ^ "Const";
  val name_abss =   prefix ^ "abss";
  val name_apps =   prefix ^ "apps";
in

(*pairs the ML-level name of univs_ref with the reference itself, for ML_Context.evaluate*)
val univs_cookie = (name_ref, univs_ref);

(*ML identifiers for generated functions, dictionary parameters and bound variables*)
fun nbe_fun "" = "nbe_value"
  | nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
fun nbe_bound v = "v_" ^ v;

(*note: these three are the "turning spots" where proper argument order is established!*)
fun nbe_apps t [] = t
  | nbe_apps t ts = name_apps `$$` [t, ml_list (rev ts)];
fun nbe_apps_local c ts = nbe_fun c `$` ml_list (rev ts);
fun nbe_apps_constr idx ts =
  name_const `$` ("(" ^ string_of_int idx ^ ", " ^ ml_list (rev ts) ^ ")");

(*arity 0: call f on the empty argument list immediately; otherwise wrap f with Nbe.abss*)
fun nbe_abss 0 f = f `$` ml_list []
  | nbe_abss n f = name_abss `$$` [string_of_int n, f];

end;

open BasicCodeThingol;

(* code generation *)

(*ML source for the equation system eqnss, of the form
    fn [dep_1, ..., dep_k] => let fun ... in [value_1, ...] end
  where deps names already-compiled functions passed in as arguments and
  idx_of maps a name to its constructor index for uninterpreted constants.*)
fun assemble_eqnss idx_of deps eqnss =
  let
    (*dictionary parameters come first; num_args counts them plus the
      explicit arguments of the first equation*)
    fun prep_eqns (c, (vs, eqns)) =
      let
        val dicts = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
        val num_args = length dicts + (length o fst o hd) eqns;
      in (c, (num_args, (dicts, eqns))) end;
    val eqnss' = map prep_eqns eqnss;

    (*application of constant c: call a local function directly if enough
      arguments are present, else build a closure; external dependencies are
      applied by name; anything else becomes an uninterpreted Const*)
    fun assemble_constapp c dss ts =
      let
        val ts' = (maps o map) assemble_idict dss @ ts;
      in case AList.lookup (op =) eqnss' c
       of SOME (n, _) => if n <= length ts'
            then let val (ts1, ts2) = chop n ts'
            in nbe_apps (nbe_apps_local c ts1) ts2
            end else nbe_apps (nbe_abss n (nbe_fun c)) ts'
        | NONE => if member (op =) deps c
            then nbe_apps (nbe_fun c) ts'
            else nbe_apps_constr (idx_of c) ts'
      end
    and assemble_idict (DictConst (inst, dss)) =
          assemble_constapp inst dss []
      | assemble_idict (DictVar (supers, (v, (n, _)))) =
          fold_rev (fn super => assemble_constapp super [] o single) supers (nbe_dict v n);

    (*translate an intermediate term; constapp decides how constant heads are rendered*)
    fun assemble_iterm constapp =
      let
        fun of_iterm t =
          let
            val (t', ts) = CodeThingol.unfold_app t
          in of_iapp t' (fold_rev (cons o of_iterm) ts []) end
        and of_iapp (IConst (c, (dss, _))) ts = constapp c dss ts
          | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
          | of_iapp ((v, _) `|-> t) ts =
              nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
          | of_iapp (ICase (((t, _), cs), t0)) ts =
              nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
                @ [("_", of_iterm t0)])) ts
      in of_iterm end;

    (*one ML function per constant; patterns use constructor applications.
      For named constants a catch-all equation is appended that returns the
      uninterpreted Const applied to all arguments.*)
    fun assemble_eqns (c, (num_args, (dicts, eqns))) =
      let
        val assemble_arg = assemble_iterm
          (fn c => fn _ => fn ts => nbe_apps_constr (idx_of c) ts);
        val assemble_rhs = assemble_iterm assemble_constapp;
        fun assemble_eqn (args, rhs) =
          ([ml_list (rev (dicts @ map assemble_arg args))], assemble_rhs rhs);
        val default_args = map nbe_bound (Name.invent_list [] "a" num_args);
        val default_eqn = if c = "" then NONE
          else SOME ([ml_list (rev default_args)],
            nbe_apps_constr (idx_of c) default_args);
      in
        ((nbe_fun c, map assemble_eqn eqns @ the_list default_eqn),
          nbe_abss num_args (nbe_fun c))
      end;

    val (fun_vars, fun_vals) = map_split assemble_eqns eqnss';
    val deps_vars = ml_list (map nbe_fun deps);
  in ml_abs deps_vars (ml_Let (ml_fundefs fun_vars) (ml_list fun_vals)) end;


(* code compilation *)

(*compile eqnss to semantic values, paired with their constant names.
  Only those raw_deps whose graph node already carries a compiled value
  are passed to the generated code as arguments.*)
fun compile_eqnss gr raw_deps [] = []
  | compile_eqnss gr raw_deps eqnss =
      let
        val (deps, deps_vals) = split_list (map_filter
          (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node gr dep)))) raw_deps);
        val idx_of = raw_deps
          |> map (fn dep => (dep, snd (Graph.get_node gr dep)))
          |> AList.lookup (op =)
          |> (fn f => the o f);
        val s = assemble_eqnss idx_of deps eqnss;
        val cs = map fst eqnss;
      in
        s
        |> tracing (fn s => "\n code to be evaluated:\n" ^ s)
        |> ML_Context.evaluate
            (Output.tracing o enclose "\ncompiler echo:\n" "\n\n",
            Output.tracing o enclose "\n compiler echo (with error):\n" "\n\n")
            (!trace) univs_cookie
        |> (fn f => f deps_vals)
        |> (fn univs => cs ~~ univs)
      end;


(* preparing function equations *)

(*equations to be compiled for one code statement. Classes become
  selector equations (one per superclass and class operation); instances
  become a single equation building the class dictionary.*)
fun eqns_of_stmt (_, CodeThingol.Fun (_, [])) =
      []
  | eqns_of_stmt (const, CodeThingol.Fun ((vs, _), eqns)) =
      [(const, (vs, map fst eqns))]
  | eqns_of_stmt (_, CodeThingol.Datatypecons _) =
      []
  | eqns_of_stmt (_, CodeThingol.Datatype _) =
      []
  | eqns_of_stmt (class, CodeThingol.Class (v, (superclasses, classops))) =
      let
        val names = map snd superclasses @ map fst classops;
        val params = Name.invent_list [] "d" (length names);
        fun mk (k, name) =
          (name, ([(v, [])],
            [([IConst (class, ([], [])) `$$ map IVar params], IVar (nth params k))]));
      in map_index mk names end
  | eqns_of_stmt (_, CodeThingol.Classrel _) =
      []
  | eqns_of_stmt (_, CodeThingol.Classparam _) =
      []
  | eqns_of_stmt (inst, CodeThingol.Classinst ((class, (_, arities)), (superinsts, instops))) =
      [(inst, (arities, [([], IConst (class, ([], [])) `$$
        map (fn (_, (_, (inst, dicts))) => IConst (inst, (dicts, []))) superinsts
        @ map (IConst o snd o fst) instops)]))];

(*add a group of statements to the function graph: allocate fresh indices
  for names not yet present, add dependency edges, compile the equations
  and store the compiled values in the graph nodes*)
fun compile_stmts stmts_deps =
  let
    val names = map (fst o fst) stmts_deps;
    val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
    val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
    val refl_deps = names_deps
      |> maps snd
      |> distinct (op =)
      |> fold (insert (op =)) names;
    fun new_node name (gr, (maxidx, idx_tab)) = if can (Graph.get_node gr) name
      then (gr, (maxidx, idx_tab))
      else (Graph.new_node (name, (NONE, maxidx)) gr,
        (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
    fun compile gr = eqnss
      |> compile_eqnss gr refl_deps
      |> rpair gr;
  in
    fold new_node refl_deps
    #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
      #> compile
      #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
  end;

(*compile every strongly connected component of code; a component is
  skipped if any of its names is already in the graph*)
fun ensure_stmts code =
  let
    fun add_stmts names (gr, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) gr) names
      then (gr, (maxidx, idx_tab))
      else (gr, (maxidx, idx_tab))
        |> compile_stmts (map (fn name => ((name, Graph.get_node code name),
          Graph.imm_succs code name)) names);
  in fold_rev add_stmts (Graph.strong_conn code) end;


(** evaluation **)

(* term evaluation *)

(*compile t as the nameless function "" abstracted over its free variables,
  then apply it to uninterpreted Free/DFree values for those variables
  and the dictionary parameters*)
fun eval_term gr deps ((vs, _), t) =
  let
    val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
    val frees' = map (fn v => Free (v, [])) frees;
    val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
  in
    ("", (vs, [(map IVar frees, t)]))
    |> singleton (compile_eqnss gr deps)
    |> snd
    |> (fn t => apps t (rev (dict_frees @ frees')))
  end;


(* reification *)

(*read back a semantic value as a term. Constant types get fresh
  type-inference parameters (threaded as typidx); dictionary arguments are
  dropped; closures are applied to a fresh BVar to obtain their bodies.*)
fun term_of_univ thy idx_tab t =
  let
    fun take_until f [] = []
      | take_until f (x::xs) = if f x then [] else x :: take_until f xs;
    fun is_dict (Const (idx, _)) =
          let
            val c = the (Inttab.lookup idx_tab idx);
          in
            (is_some o CodeName.class_rev thy) c
            orelse (is_some o CodeName.classrel_rev thy) c
            orelse (is_some o CodeName.instance_rev thy) c
          end
      | is_dict (DFree _) = true
      | is_dict _ = false;
    fun of_apps bounds (t, ts) =
      fold_map (of_univ bounds) ts
      #>> (fn ts' => list_comb (t, rev ts'))
    and of_univ bounds (Const (idx, ts)) typidx =
          let
            val ts' = take_until is_dict ts;
            val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx;
            val T = Code.default_typ thy c;
            val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
            val typidx' = typidx + maxidx_of_typ T' + 1;
          in of_apps bounds (Term.Const (c, T'), ts') typidx' end
      | of_univ bounds (Free (name, ts)) typidx =
          of_apps bounds (Term.Free (name, dummyT), ts) typidx
      | of_univ bounds (BVar (name, ts)) typidx =
          of_apps bounds (Bound (bounds - name - 1), ts) typidx
      | of_univ bounds (t as Abs _) typidx =
          typidx
          |> of_univ (bounds + 1) (apps t [BVar (bounds, [])])
          |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
  in of_univ 0 t 0 |> fst end;


(* function store *)

(*graph of compiled statements (node data: compiled value if any, index)
  together with the next free index and the index-to-name table*)
structure Nbe_Functions = CodeDataFun
(
  type T = (Univ option * int) Graph.T * (int * string Inttab.table);
  val empty = (Graph.empty, (0, Inttab.empty));
  (*NOTE(review): tables built independently may assign the same index to
    different names; Inttab.merge (K true) then keeps one silently - confirm
    that merged caches cannot reach reification with such a collision.*)
  fun merge _ ((gr1, (maxidx1, idx_tab1)), (gr2, (maxidx2, idx_tab2))) =
    (Graph.merge (K true) (gr1, gr2), (Int.max (maxidx1, maxidx2),
      Inttab.merge (K true) (idx_tab1, idx_tab2)));
  fun purge _ NONE _ = empty
    | purge NONE _ _ = empty
    | purge (SOME thy) (SOME cs) (gr, (maxidx, idx_tab)) =
        let
          val cs_existing =
            map_filter (CodeName.const_rev thy) (Graph.keys gr);
          val dels = (Graph.all_preds gr
              o map (CodeName.const thy)
              o filter (member (op =) cs_existing)
            ) cs;
        in (Graph.del_nodes dels gr, (maxidx, idx_tab)) end;
);


(* compilation, evaluation and reification *)

fun compile_eval thy code vs_ty_t deps =
  let
    val (gr, (_, idx_tab)) = Nbe_Functions.change thy (ensure_stmts code);
  in
    vs_ty_t
    |> eval_term gr deps
    |> term_of_univ thy idx_tab
  end;


(* evaluation with type reconstruction *)

(*normalize, re-annotate the free and schematic variables of t with their
  original types, constrain the result to the type of t and run type
  inference; fails if schematic type variables remain*)
fun eval thy code t vs_ty_t deps =
  let
    val ty = type_of t;
    fun subst_Frees [] = I
      | subst_Frees inst =
          Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
                            | t => t);
    val anno_vars =
      subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
      #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []));
    fun constrain t =
      singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t);
    val string_of_term = setmp show_types true (Sign.string_of_term thy);
    fun check_tvars t = if null (Term.term_tvars t) then t else
      error ("Illegal schematic type variables in normalized term: "
        ^ string_of_term t);
  in
    compile_eval thy code vs_ty_t deps
    |> tracing (fn t => "Normalized:\n" ^ string_of_term t)
    |> anno_vars
    |> tracing (fn t => "Vars typed:\n" ^ string_of_term t)
    |> constrain
    |> tracing (fn t => "Types inferred:\n" ^ string_of_term t)
    |> tracing (fn t => "\n")
    |> check_tvars
  end;


(* evaluation oracle *)

exception Norm of CodeThingol.code * term
  * (CodeThingol.typscheme * CodeThingol.iterm) * string list;

(*oracle: the equation t == (normal form of t)*)
fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) =
  Logic.mk_equals (t, eval thy code t vs_ty_t deps);

fun norm_invoke thy code t vs_ty_t deps =
  Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps));
  (*FIXME get rid of hardwired theory name*)

fun norm_conv ct =
  let
    val thy = Thm.theory_of_cterm ct;
    fun conv code vs_ty_t deps ct =
      let
        val t = Thm.term_of ct;
      in norm_invoke thy code t vs_ty_t deps end;
  in CodePackage.evaluate_conv thy conv ct end;

fun norm_term thy =
  let
    fun invoke code vs_ty_t deps t =
      eval thy code t vs_ty_t deps;
  in CodePackage.evaluate_term thy invoke #> Code.postprocess_term thy end;


(* evaluation command *)

(*print the normal form of t together with its type, under the given print modes*)
fun norm_print_term ctxt modes t =
  let
    val thy = ProofContext.theory_of ctxt;
    val t' = norm_term thy t;
    val ty' = Term.type_of t';
    val p = PrintMode.with_modes modes (fn () =>
      Pretty.block [Pretty.quote (Syntax.pretty_term ctxt t'), Pretty.fbrk,
        Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt ty')]) ();
  in Pretty.writeln p end;


(** Isar setup **)

fun norm_print_term_cmd (modes, s) state =
  let val ctxt = Toplevel.context_of state
  in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;

val setup = Theory.add_oracle ("norm", norm_oracle);

local structure P = OuterParse and K = OuterKeyword in

val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];

(*the argument is read with Syntax.read_term, so it is parsed as a term*)
val _ =
  OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
    (opt_modes -- P.term >> (Toplevel.keep o norm_print_term_cmd));

end;

end;