(* Title: Tools/nbe.ML 
Normalization by evaluation, based on generic code generator. 
*) 
signature NBE =
sig
  val norm_conv: cterm -> thm
  val norm_term: theory -> term -> term

  (*semantic values; argument lists are kept in reverse order,
    last argument first*)
  datatype Univ =
      Const of int * Univ list           (*named (uninterpreted) constants*)
    | Free of string * Univ list         (*free (uninterpreted) variables*)
    | DFree of string * int              (*free (uninterpreted) dictionary parameters*)
    | BVar of int * Univ list            (*bound variables*)
    | Abs of (int * (Univ list -> Univ)) * Univ list;
  val apps: Univ -> Univ list -> Univ    (*explicit applications*)
  val abss: int -> (Univ list -> Univ) -> Univ
                                         (*abstractions as closures*)

  val univs_ref: (unit -> Univ list -> Univ list) option ref
  val trace: bool ref

  val setup: theory -> theory
end;

structure Nbe: NBE = 

struct 

(* generic nonsense *) 

(*global switch for diagnostic output of the evaluator*)
val trace = ref false;

(*tracing f x: when trace is set, emit the message f x via Output.tracing;
  in any case x is returned unchanged, so the function can be placed in
  the middle of a pipeline*)
fun tracing f x =
  (if ! trace then Output.tracing (f x) else ();
   x);

(** the semantical universe **) 

(*

   Functions are given by their semantical function value. To avoid
   trouble with the ML-type system, these functions have the most
   generic type, that is "Univ list -> Univ". The calling convention is
   that the arguments come as a list, the last argument first. In
   other words, a function call that usually would look like

     f x_1 x_2 ... x_n   or   f(x_1,x_2, ..., x_n)

   would be in our convention called as

     f [x_n,..,x_2,x_1]

   Moreover, to handle functions that are still waiting for some
   arguments we have additionally a list of arguments collected so far
   and the number of arguments we're still waiting for.

*)

(*Abs ((n, f), xs): closure f still waiting for n arguments, with the
  arguments xs collected so far (in reverse order, last one first)*)
datatype Univ =
    Const of int * Univ list           (*named (uninterpreted) constants*)
  | Free of string * Univ list         (*free variables*)
  | DFree of string * int              (*free (uninterpreted) dictionary parameters*)
  | BVar of int * Univ list            (*bound named variables*)
  | Abs of (int * (Univ list -> Univ)) * Univ list
                                       (*abstractions as closures*);
(* constructor functions *) 

(*abss n f: closure f waiting for n arguments, none collected yet*)
fun abss n f = Abs ((n, f), []);

(*apps t ys: apply t to the arguments ys, given in reverse order (last
  argument first).  An abstraction fires as soon as it has exactly as
  many arguments as it waits for; surplus arguments are applied to the
  result of the call, while too few arguments are merely collected.
  Const, Free and BVar just accumulate their arguments.
  NOTE(review): DFree is not matched and raises Match here; presumably
  dictionary parameters never occur in head position -- confirm.*)
fun apps (Abs ((n, f), xs)) ys =
      let val k = n - length ys in
        if k = 0 then f (ys @ xs)
        else if k < 0 then
          (*the earliest n arguments sit at the end of ys; the -k most
            recent ones are applied to the result afterwards*)
          let val (zs, ws) = chop (~ k) ys
          in apps (f (ws @ xs)) zs end
        else Abs ((k, f), ys @ xs)
      end (*note: reverse convention also for apps!*)
  | apps (Const (name, xs)) ys = Const (name, ys @ xs)
  | apps (Free (name, xs)) ys = Free (name, ys @ xs)
  | apps (BVar (name, xs)) ys = BVar (name, ys @ xs);

(** assembling and compiling ML code from terms **) 

(* abstract ML syntax *) 

infix 9 `$` `$$`;

(*e1 `$` e2: ML source text for the application of e1 to e2*)
fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";

(*e `$$` es: ML source text for the application of e to all of es;
  without arguments this is e itself.  The non-empty case is needed by
  callers such as nbe_abss, which pass two arguments.*)
fun e `$$` [] = e
  | e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
(*ml_abs v e: ML source text for the anonymous function "fn v => e"*)
fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";

(*ml_cases t cs: ML source text for a case distinction on t with the
  (pattern, body) branches cs; ML requires the branches to be separated
  by a bar*)
fun ml_cases t cs =
  "(case " ^ t ^ " of "
  ^ space_implode " | " (map (fn (pat, rhs) => pat ^ " => " ^ rhs) cs) ^ ")";

(*ml_Let d e: ML source text for "let d in e end"*)
fun ml_Let d e = "let\n" ^ d ^ " in " ^ e ^ " end";
94 
fun ml_fundefs ([(name, [([], e)])]) = 

"val " ^ name ^ " = " ^ e ^ "\n" 

 ml_fundefs (eqs :: eqss) = 

let 

fun fundef (name, eqs) = 

let 

fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e 

in space_implode "\n  " (map eqn eqs) end; 

105 
106 
107 
108 
(* nbe specific syntax and sandbox communication *) 
(*result slot shared with the compiled code; it is handed to
  ML_Context.evaluate via univs_cookie (presumably filled in by the
  evaluated code -- confirm against ML_Context)*)
val univs_ref = ref (NONE : (unit -> Univ list -> Univ list) option);

local 

(*fully qualified names of members of structure Nbe that are spliced into
  the generated ML source (e.g. name_abss in nbe_abss, name_const for
  constructor nodes).  NOTE(review): the qualification is presumably needed
  because the source is evaluated outside this structure -- confirm.*)
val prefix = "Nbe."; 

val name_ref = prefix ^ "univs_ref";
val name_const = prefix ^ "Const";
val name_abss = prefix ^ "abss";
val name_apps = prefix ^ "apps";
in 
25944  122 
(*qualified name of univs_ref paired with the reference itself; passed to
  ML_Context.evaluate by compile_eqnss*)
val univs_cookie = (name_ref, univs_ref);
25935  124 
125 
25101  126 
24155  127 
(*ML identifier used for the variable named v in generated code*)
fun nbe_bound name = String.concat ["v_", name];
25924  129 
130 
131 
132 
26064  133 
134 
name_const `$` ("(" ^ string_of_int idx ^ ", " ^ ml_list (rev ts) ^ ")"); 

(*nbe_abss n f: ML source text for a closure over f waiting for n
  arguments; with n = 0 the function f is called right away on the
  empty argument list*)
fun nbe_abss 0 f = f `$` ml_list []
  | nbe_abss n f = name_abss `$$` [string_of_int n, f];
139 
24219  141 
24155  142 

(* code generation *) 
26064  145 
25944  146 
147 
25924  148 
25944  149 
150 
151 
152 
153 

(*assemble_constapp c dss ts: ML source text for constant c applied to the
  dictionaries dss (flattened and placed in front) and the arguments ts.
  - c has equations in eqnss' expecting n arguments: with at least n
    arguments the first n go to nbe_apps_local and the rest are applied
    to its result; with fewer, c is wrapped by nbe_abss.
  - c is a dependency (deps): its universe nbe_fun c is applied.
  - otherwise a constructor node is built via nbe_apps_constr.*)
fun assemble_constapp c dss ts =
  let
    val ts' = (maps o map) assemble_idict dss @ ts;
  in case AList.lookup (op =) eqnss' c
   of SOME (n, _) =>
        if n <= length ts'
        then let val (ts1, ts2) = chop n ts'
             in nbe_apps (nbe_apps_local c ts1) ts2 end
        else nbe_apps (nbe_abss n (nbe_fun c)) ts'
    | NONE =>
        if member (op =) deps c
        then nbe_apps (nbe_fun c) ts'
        else nbe_apps_constr (idx_of c) ts'
  end
(*assemble_idict d: ML source text for an instance dictionary: either an
  instance constant applied to its own dictionaries, or a dictionary
  variable projected through its chain of superclass selectors*)
and assemble_idict (DictConst (inst, dss)) =
      assemble_constapp inst dss []
  | assemble_idict (DictVar (supers, (v, (n, _)))) =
      fold_rev (fn super => assemble_constapp super [] o single) supers (nbe_dict v n);

25944  171 
24155  172 
25944  173 
174 
175 
176 
177 
178 
179 
180 
181 
 of_iapp (ICase (((t, _), cs), t0)) ts = 

182 
nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs 

183 
@ [("_", of_iterm t0)])) ts 

184 
in of_iterm end; 

24155  185 

25944  186 
(*assemble_eqns (c, (num_args, (dicts, eqns))): the ML clauses of the
  function implementing c, paired with the ML expression for its closure.
  Each clause takes dictionaries and argument patterns as one reversed
  list.  Unless c is "" (the anonymous system built by eval_term), a
  catch-all clause is appended that rebuilds c applied to its arguments
  as a constructor node.*)
fun assemble_eqns (c, (num_args, (dicts, eqns))) =
  let
    (*patterns: constants always become constructor nodes*)
    val assemble_arg = assemble_iterm
      (fn c => fn _ => fn ts => nbe_apps_constr (idx_of c) ts);
    val assemble_rhs = assemble_iterm assemble_constapp;
    fun assemble_eqn (args, rhs) =
      ([ml_list (rev (dicts @ map assemble_arg args))], assemble_rhs rhs);
    val default_args = map nbe_bound (Name.invent_list [] "a" num_args);
    val default_eqn = if c = "" then NONE
      else SOME ([ml_list (rev default_args)],
        nbe_apps_constr (idx_of c) default_args);
  in
    ((nbe_fun c, map assemble_eqn eqns @ the_list default_eqn),
      nbe_abss num_args (nbe_fun c))
  end;

24155  201 

25944  202 
val (fun_vars, fun_vals) = map_split assemble_eqns eqnss'; 
203 
val deps_vars = ml_list (map nbe_fun deps); 

204 
in ml_abs deps_vars (ml_Let (ml_fundefs fun_vars) (ml_list fun_vals)) end; 

205 

206 
(* code compilation *) 

207 

208 
(*compile_eqnss gr raw_deps eqnss: translate the equation systems eqnss to
  ML source (assemble_eqnss), evaluate it with ML_Context.evaluate and pair
  each system's name with the resulting universe.  Only dependencies whose
  graph node already carries a universe (SOME univ) are passed to the
  compiled code; idx_of maps every raw dependency to the index stored in
  its node.  An empty list of systems compiles to nothing.*)
fun compile_eqnss _ _ [] = []
  | compile_eqnss gr raw_deps eqnss =
      let
        val (deps, deps_vals) = split_list (map_filter
          (fn dep => Option.map (fn univ => (dep, univ)) (fst (Graph.get_node gr dep)))
          raw_deps);
        val idx_of = raw_deps
          |> map (fn dep => (dep, snd (Graph.get_node gr dep)))
          |> AList.lookup (op =)
          |> (fn f => the o f);
        val s = assemble_eqnss idx_of deps eqnss;
        val cs = map fst eqnss;
      in
        s
        |> tracing (fn s => "\n code to be evaluated:\n" ^ s)
        |> ML_Context.evaluate
            (Output.tracing o enclose "\ncompiler echo:\n" "\n\n",
              Output.tracing o enclose "\n compiler echo (with error):\n" "\n\n")
            (!trace) univs_cookie
        |> (fn f => f deps_vals)
        |> (fn univs => cs ~~ univs)
      end;

25190  229 

25944  230 
(* preparing function equations *) 
24155  231 

25101  232 
(*eqns_of_stmt (name, stmt): the equation systems a statement contributes.
  - a function with equations yields them unchanged;
  - a class yields one selector per superclass and class operation, whose
    single equation picks the k-th parameter of the class dictionary;
  - an instance yields one equation building its dictionary from the
    superclass instances and the instance operations;
  - everything else (including functions without equations) yields [].*)
fun eqns_of_stmt (_, CodeThingol.Fun (_, [])) =
      []
  | eqns_of_stmt (const, CodeThingol.Fun ((vs, _), eqns)) =
      [(const, (vs, map fst eqns))]
  | eqns_of_stmt (_, CodeThingol.Datatypecons _) =
      []
  | eqns_of_stmt (_, CodeThingol.Datatype _) =
      []
  | eqns_of_stmt (class, CodeThingol.Class (v, (superclasses, classops))) =
      let
        val names = map snd superclasses @ map fst classops;
        val params = Name.invent_list [] "d" (length names);
        fun mk (k, name) =
          (name, ([(v, [])],
            [([IConst (class, ([], [])) `$$ map IVar params], IVar (nth params k))]));
      in map_index mk names end
  | eqns_of_stmt (_, CodeThingol.Classrel _) =
      []
  | eqns_of_stmt (_, CodeThingol.Classparam _) =
      []
  | eqns_of_stmt (inst, CodeThingol.Classinst ((class, (_, arities)), (superinsts, instops))) =
      [(inst, (arities, [([], IConst (class, ([], [])) `$$
        map (fn (_, (_, (inst, dicts))) => IConst (inst, (dicts, []))) superinsts
        @ map (IConst o snd o fst) instops)]))];

24155  256 

25101  257 
(*compile_stmts stmts_deps: a transformation of (gr, (maxidx, idx_tab)) that
  registers all statements and their dependencies as graph nodes (unknown
  names get the next free index), records the dependency edges, compiles
  the equations of the statements and stores each resulting universe in
  the node of its constant.*)
fun compile_stmts stmts_deps =
  let
    val names = map (fst o fst) stmts_deps;
    val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
    val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
    val refl_deps = names_deps
      |> maps snd
      |> distinct (op =)
      |> fold (insert (op =)) names;
    fun new_node name (gr, (maxidx, idx_tab)) = if can (Graph.get_node gr) name
      then (gr, (maxidx, idx_tab))
      else (Graph.new_node (name, (NONE, maxidx)) gr,
        (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
    fun compile gr = eqnss
      |> compile_eqnss gr refl_deps
      |> rpair gr;
  in
    fold new_node refl_deps
    #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
      #> compile
      (*compile yields (univs, gr): #-> feeds both components to fold*)
      #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
  end;
24155  279 

25101  280 
(*ensure_stmts code: compile every strongly connected component of the code
  graph that is not yet present in the function store.  A component is
  skipped as soon as one of its names is already a node.
  NOTE(review): fold_rev over Graph.strong_conn presumably compiles
  dependencies before their users -- confirm the order of strong_conn.*)
fun ensure_stmts code =
  let
    fun add_stmts names (gr, (maxidx, idx_tab)) =
      if exists ((can o Graph.get_node) gr) names
      then (gr, (maxidx, idx_tab))
      else (gr, (maxidx, idx_tab))
        |> compile_stmts (map (fn name => ((name, Graph.get_node code name),
          Graph.imm_succs code name)) names);
  in fold_rev add_stmts (Graph.strong_conn code) end;

24155  288 

25944  289 

290 
(** evaluation **) 

291 

292 
(* term evaluation *) 

293 

25924  294 
(*eval_term gr deps ((vs, ty), t): evaluate t to a universe.  The term is
  compiled as an anonymous equation system (name "") abstracted over its
  free variables; the closure is then applied to fresh DFree dictionary
  parameters (one per sort component of each type variable in vs) and to
  uninterpreted Free values for the free variables, in reverse order.*)
fun eval_term gr deps ((vs, _), t) =
  let
    val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
    val frees' = map (fn v => Free (v, [])) frees;
    val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
  in
    ("", (vs, [(map IVar frees, t)]))
    |> singleton (compile_eqnss gr deps)
    |> snd
    |> (fn t => apps t (rev (dict_frees @ frees')))
  end;

24155  305 

24839  306 
(* reification *) 
24155  307 

26064  308 
fun term_of_univ thy idx_tab t = 
24155  309 
let 
25101  310 
fun take_until f [] = [] 
311 
 take_until f (x::xs) = if f x then [] else x :: take_until f xs; 

26064  312 
fun is_dict (Const (idx, _)) = 
313 
let 

314 
val c = the (Inttab.lookup idx_tab idx); 

315 
in 

316 
(is_some o CodeName.class_rev thy) c 

317 
orelse (is_some o CodeName.classrel_rev thy) c 

318 
orelse (is_some o CodeName.instance_rev thy) c 

319 
end 

25101  320 
 is_dict (DFree _) = true 
321 
 is_dict _ = false; 

24155  322 
fun of_apps bounds (t, ts) = 
323 
fold_map (of_univ bounds) ts 

324 
#>> (fn ts' => list_comb (t, rev ts')) 

26064  325 
and of_univ bounds (Const (idx, ts)) typidx = 
24155  326 
let 
25101  327 
val ts' = take_until is_dict ts; 
26064  328 
val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx; 
val T = Code.default_typ thy c; 
val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T; 
25101  332 
24155  333 
334 
335 
336 
337 
338 
> of_univ (bounds + 1) (apps t [BVar (bounds, [])]) 
> (fn t' => pair (Term.Abs ("u", dummyT, t'))) 
(* function store *) 
structure Nbe_Functions = CodeDataFun 

( 

type T = (Univ option * int) Graph.T * (int * string Inttab.table); 
val empty = (Graph.empty, (0, Inttab.empty)); 

fun merge _ ((gr1, (maxidx1, idx_tab1)), (gr2, (maxidx2, idx_tab2))) = 

(Graph.merge (K true) (gr1, gr2), (IntInf.max (maxidx1, maxidx2), 

Inttab.merge (K true) (idx_tab1, idx_tab2))); 

fun purge _ NONE _ = empty 

 purge NONE _ _ = empty 

 purge (SOME thy) (SOME cs) (gr, (maxidx, idx_tab)) = 

356 
357 
358 
359 
360 
361 
26064  362 
25101  363 
365 
367 
26064  368 
369 
370 
vs_ty_t 

> eval_term gr deps 

> term_of_univ thy idx_tab 

24155  376 
377 

(*eval thy code t vs_ty_t deps: the normal form of t.  After evaluation
  (compile_eval), free and schematic variables are re-annotated with their
  types from t, the result is constrained to the type of t and checked by
  type inference.  A result still containing schematic type variables is
  rejected with an error.*)
fun eval thy code t vs_ty_t deps =
  let
    val ty = type_of t;
    fun subst_Frees [] = I
      | subst_Frees inst =
          Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
                            | t => t);
    val anno_vars =
      subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
      #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []));
    fun constrain t =
      singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t);
    val string_of_term = setmp show_types true (Sign.string_of_term thy);
    fun check_tvars t = if null (Term.term_tvars t) then t else
      error ("Illegal schematic type variables in normalized term: "
        ^ string_of_term t);
  in
    compile_eval thy code vs_ty_t deps
    |> tracing (fn t => "Normalized:\n" ^ string_of_term t)
    |> anno_vars
    |> tracing (fn t => "Vars typed:\n" ^ string_of_term t)
    |> constrain
    |> tracing (fn t => "Types inferred:\n" ^ string_of_term t)
    |> tracing (fn _ => "\n")
    |> check_tvars
  end;

(* evaluation oracle *) 

24839  407 
24381  408 
24155  409 

(*oracle body: the equation between t and its normal form as computed by
  eval; only the Norm exception is accepted as argument*)
fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) =
  let
    val normal_form = eval thy code t vs_ty_t deps;
  in Logic.mk_equals (t, normal_form) end;
24839  413 
414 
24283  415 
24155  416 

(*norm_conv ct: conversion built by CodePackage.evaluate_conv, using
  norm_invoke on the term of the (preprocessed) cterm as evaluation step*)
fun norm_conv ct =
  let
    val thy = Thm.theory_of_cterm ct;
    fun conv code vs_ty_t deps ct' =
      norm_invoke thy code (Thm.term_of ct') vs_ty_t deps;
  in CodePackage.evaluate_conv thy conv ct end;
24839  426 
427 
fun invoke code vs_ty_t deps t = 

eval thy code t vs_ty_t deps; 

in CodePackage.evaluate_term thy invoke #> Code.postprocess_term thy end; 
(* evaluation command *) 
434 
435 
val thy = ProofContext.theory_of ctxt; 

val t' = norm_term thy t; 
val ty' = Term.type_of t'; 

val p = PrintMode.with_modes modes (fn () => 
Pretty.block [Pretty.quote (Syntax.pretty_term ctxt t'), Pretty.fbrk, 
Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt ty')]) (); 

in Pretty.writeln p end; 
(** Isar setup **) 

(*command implementation: read s in the current context and print its
  normal form in the given print modes*)
fun norm_print_term_cmd (modes, s) state =
  let
    val ctxt = Toplevel.context_of state;
    val t = Syntax.read_term ctxt s;
  in norm_print_term ctxt modes t end;
24839  451 
24155  452 

local structure P = OuterParse and K = OuterKeyword in 

455 
456 

(*the "normal_form" command: optional print modes followed by a term; the
  argument is read with Syntax.read_term, hence it is parsed as a term*)
val _ =
  OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
    (opt_modes -- P.term >> (Toplevel.keep o norm_print_term_cmd));

461 
462 

end; 