24155

1 
(* Title: Tools/Nbe/Nbe_Eval.ML


2 
ID: $Id$


3 
Authors: Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen


4 


5 
Evaluation mechanisms for normalization by evaluation.


6 
*)


7 


8 
(*


9 
FIXME:


10 
- get rid of BVar (?) - it is only used for terms to be evaluated, not for functions
- proper purge operation - preliminary for...
- really incremental code generation


13 
*)


14 


15 
(*Interface of the normalization-by-evaluation engine: the semantic
  universe, its constructor functions, the sandbox communication
  references, the normalization conversion and the theory setup.*)
signature NBE =
sig
  datatype Univ =
      Const of string * Univ list        (*named constructors*)
    | Free of string * Univ list
    | BVar of int * Univ list
    | Abs of (int * (Univ list -> Univ)) * Univ list;
  val free: string -> Univ list -> Univ  (*free (uninterpreted) variables*)
  val abs: int -> (Univ list -> Univ) -> Univ list -> Univ
                                         (*abstractions as functions*)
  val app: Univ -> Univ -> Univ          (*explicit application*)

  val univs_ref: Univ list ref
  val lookup_fun: CodegenNames.const -> Univ

  val normalization_conv: cterm -> thm

  val trace: bool ref
  val setup: theory -> theory
end;


35 


36 
structure Nbe: NBE =


37 
struct


38 


39 
(* generic nonsense *)


40 


41 
(*switch for diagnostic output; off by default*)
val trace = ref false;

(*tracing f x: yields x unchanged; while trace is set, the message
  f x is first emitted through Output.tracing*)
fun tracing f x =
  let
    val _ = if ! trace then Output.tracing (f x) else ();
  in x end;


43 


44 


45 
(** the semantical universe **)


46 


47 
(*


48 
Functions are given by their semantical function value. To avoid


49 
trouble with the ML type system, these functions have the most
generic type, that is "Univ list -> Univ". The calling convention is


51 
that the arguments come as a list, the last argument first. In


52 
other words, a function call that usually would look like


53 


54 
f x_1 x_2 ... x_n or f(x_1,x_2, ..., x_n)


55 


56 
would be in our convention called as


57 


58 
f [x_n,..,x_2,x_1]


59 


60 
Moreover, to handle functions that are still waiting for some


61 
arguments we have additionally a list of arguments collected so far


62 
and the number of arguments we're still waiting for.


63 


64 
(?) Finally, it might happen, that a function does not get all the


65 
arguments it needs. In this case the function must provide means to


66 
present itself as a string. As this might be a heavyweight


67 
operation, we delay it. (?)


68 
*)


69 


70 
(*The semantic universe.  Every variant carries the list of arguments
  collected so far; Abs additionally carries the number of arguments
  still awaited together with the function of generic type
  Univ list -> Univ.*)
datatype Univ =
    Const of string * Univ list        (*named constructors*)
  | Free of string * Univ list         (*free variables*)
  | BVar of int * Univ list            (*bound named variables*)
  | Abs of (int * (Univ list -> Univ)) * Univ list
                                       (*functions*);


76 


77 
(* constructor functions *)


78 


79 
(*free v ts: the free variable v together with its argument list ts*)
fun free v ts = Free (v, ts);

(*abs n f ts: the function f awaiting n further arguments, with the
  arguments ts collected so far*)
fun abs n f = fn ts => Abs ((n, f), ts);


81 
(*app u x: explicit application of u to the single argument x.
  An Abs awaiting exactly one argument is run on the completed argument
  list; an Abs awaiting more has its counter decremented.  In every
  case the new argument is consed onto the front of the collected
  arguments, so the list holds the last argument first.*)
fun app (Abs ((1, f), xs)) x = f (x :: xs)
  | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
  | app (Const (name, args)) x = Const (name, x :: args)
  | app (Free (name, args)) x = Free (name, x :: args)
  | app (BVar (name, args)) x = BVar (name, x :: args);


86 


87 
(* global functions store *)


88 


89 
(*store of compiled functions: a symbol table of Univ values, obtained
  by instantiating CodeDataFun*)
structure Nbe_Functions = CodeDataFun
(struct
  type T = Univ Symtab.table;
  val empty = Symtab.empty;
  (*NOTE(review): K true presumably lets entries under the same key
    count as equal during the merge - confirm against Symtab.merge*)
  fun merge _ tabs = Symtab.merge (K true) tabs;
  (*purging ignores its arguments and yields the empty table*)
  fun purge _ _ _ = Symtab.empty;
end);


96 


97 
(* sandbox communication *)


98 


99 
(*result cell: the generated code assigns its list of values here*)
val univs_ref: Univ list ref = ref [];


100 


101 
local


102 


103 
(*function table visible to lookup_fun; NONE unless a compilation is running*)
val tab_ref: Univ Symtab.table option ref = ref NONE;


104 


105 
in


106 


107 
(*lookup_fun s: fetches the entry for s from the table currently held in
  tab_ref.  Fails with error "compile_univs" when no table is installed.
  NOTE(review): "the" presumably raises if s has no entry - confirm.*)
fun lookup_fun s = case ! tab_ref
 of NONE => error "compile_univs"
  | SOME tab => (the o Symtab.lookup tab) s;


110 


111 
(*compile_univs tab (cs, raw_s): runs the ML text raw_s and pairs the
  names cs with the values it produces.
  The text is prefixed with an assignment to Nbe.univs_ref, tab is
  installed in tab_ref for the duration of use_text (so that the
  generated code can call lookup_fun) and removed afterwards.  An empty
  univs_ref after the run is reported as error "compile_univs".
  An empty list of names yields [] without running anything.*)
fun compile_univs tab ([], _) = []
  | compile_univs tab (cs, raw_s) =
      let
        val _ = univs_ref := [];
        val s = "Nbe.univs_ref := " ^ raw_s;
        val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
        val _ = tab_ref := SOME tab;
        val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
          Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
          (!trace) s;
        val _ = tab_ref := NONE;
        val univs = case !univs_ref of [] => error "compile_univs" | univs => univs;
      in cs ~~ univs end;


124 


125 
end; (*local*)


126 


127 


128 
(** assembling and compiling ML code from terms **)


129 


130 
(* abstract ML syntax *)


131 


132 
infix 9 `$` `$$`;

(*parenthesized application of the ML text e1 to the ML text e2*)
fun e1 `$` e2 = String.concat ["(", e1, " ", e2, ")"];

(*parenthesized application of e to the space-separated texts es*)
fun e `$$` es = String.concat ["(", e, " ", space_implode " ", es, ")"];

(*parenthesized fn-abstraction with pattern text v and body text e*)
fun ml_abs v e = String.concat ["(fn", v, " => ", e, ")"];


136 


137 
(*text of a val declaration binding the name v to the expression text s*)
fun ml_Val v s = String.concat ["val ", v, " = ", s];


138 
(*ml_cases t cs: text of a parenthesized case expression over the
  scrutinee text t; cs lists (pattern text, result text) pairs.
  The alternatives must be separated by " | " for the generated text to
  be a well-formed ML match.*)
fun ml_cases t cs =
  "(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")";


140 
(*text of a let expression: declaration texts ds, one per line, then body e*)
fun ml_Let ds e = String.concat ["let\n", space_implode "\n" ds, " in ", e, " end"];

(*text of an ML list literal with the element texts es*)
fun ml_list es = String.concat ["[", commas es, "]"];


143 


144 
(*ml_fundefs defs: text of a group of (mutually recursive) definitions.
  Each element of defs is (name, equations), an equation being
  (argument pattern texts, right-hand side text).
  A single definition with one argument-less equation becomes a plain
  "val"; otherwise the first definition is introduced by "fun", the
  following ones by "and", and the equations of one function are joined
  with "|".  The result ends in a newline.
  No clause covers the empty list of definitions.*)
fun ml_fundefs ([(name, [([], e)])]) =
      "val " ^ name ^ " = " ^ e ^ "\n"
  | ml_fundefs (eqs :: eqss) =
      let
        fun fundef (name, eqs) =
          let
            fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
          in space_implode "\n  | " (map eqn eqs) end;
      in
        (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
        |> space_implode "\n"
        |> suffix "\n"
      end;


157 


158 
(* nbe specific syntax *)


159 


160 
local


161 
(*names of Nbe entities as they are spelled inside the generated code*)
val prefix = "Nbe.";
fun in_nbe name = prefix ^ name;
val name_const = in_nbe "Const";
val name_free = in_nbe "free";
val name_abs = in_nbe "abs";
val name_app = in_nbe "app";
val name_lookup_fun = in_nbe "lookup_fun";


167 
in


168 


169 
(*text applying Nbe.Const to the pair of the quoted name c and the list of texts ts*)
fun nbe_const c ts =
  let
    val payload = "(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")";
  in name_const `$` payload end;


170 
(*nbe_fun c: ML identifier used for the constant c in generated code:
  "c_" followed by c with every "." replaced by "_"*)
fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;


171 
(*text applying Nbe.free to the quoted name v and an empty argument list*)
fun nbe_free v =
  let
    val quoted = ML_Syntax.print_string v;
  in name_free `$$` [quoted, ml_list []] end;

(*ML identifier standing for the bound variable v*)
fun nbe_bound v = String.concat ["v_", v];


173 


174 
(*nbe_apps e es: nested Nbe.app calls applying the text e to the
  argument texts es; the last element of es is applied innermost, the
  head of es outermost*)
fun nbe_apps e [] = e
  | nbe_apps e (s :: rest) = name_app `$$` [nbe_apps e rest, s];


176 


177 
(*nbe_abss n f: text wrapping the function text f as a value awaiting
  n arguments.  For n = 0 the function is applied to the empty argument
  list at once; otherwise Nbe.abs is applied to n, f and an empty list
  of collected arguments.*)
fun nbe_abss 0 f = f `$` ml_list []
  | nbe_abss n f = name_abs `$$` [string_of_int n, f, ml_list []];


179 


180 
(*val declaration binding the identifier of c to Nbe.lookup_fun applied
  to the quoted name c*)
fun nbe_lookup c =
  let
    val rhs = name_lookup_fun `$` ML_Syntax.print_string c;
  in ml_Val (nbe_fun c) rhs end;

(*identifier under which a term to be evaluated is bound*)
val nbe_value = "value";


183 


184 
end;


185 


186 
open BasicCodegenThingol;


187 


188 
(* greetings to Tarski *)


189 


190 
(*assemble_iterm thy is_fun num_args: translates an iterm into ML text.
  is_fun c tells whether c is bound to a function in the generated code;
  num_args c gives SOME n when c is a locally defined function of n
  arguments.
  An application is unfolded into head and arguments; the translated
  arguments are collected by fold/cons, i.e. last argument first.
  - constant with known arity n and at least n arguments: the function is
    called directly on the n innermost arguments, the remaining ones are
    applied via Nbe.app;
  - constant with known arity but too few arguments: left as Nbe.Const;
  - other constant: applied via Nbe.app if is_fun, else Nbe.Const;
  - variable, abstraction, case: translated and applied via Nbe.app; a
    case gets a final catch-all branch yielding the translation of t0.*)
fun assemble_iterm thy is_fun num_args =
  let
    fun of_iterm t =
      let
        val (t', ts) = CodegenThingol.unfold_app t
      in of_itermapp t' (fold (cons o of_iterm) ts []) end
    and of_itermapp (IConst (c, (dss, _))) ts =
          (case num_args c
           of SOME n => if n <= length ts
                then let val (args2, args1) = chop (length ts - n) ts
                in nbe_apps (nbe_fun c `$` ml_list args1) args2
                end else nbe_const c ts
            | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
                else nbe_const c ts)
      | of_itermapp (IVar v) ts = nbe_apps (nbe_bound v) ts
      | of_itermapp ((v, _) `|-> t) ts =
          nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
      | of_itermapp (ICase (((t, _), cs), t0)) ts =
          nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
            @ [("_", of_iterm t0)])) ts
  in of_iterm end;


211 


212 
(*assemble_fun thy is_fun num_args (c, eqns): the equations of the ML
  function for c, each as (argument pattern texts, right-hand side).
  Patterns are translated with no constant treated as a function and
  appear as one list pattern over the reversed arguments.  A final
  default equation over fresh variables rebuilds c as an Nbe.Const.*)
fun assemble_fun thy is_fun num_args (c, eqns) =
  let
    val pattern_of = assemble_iterm thy (K false) (K NONE);
    val body_of = assemble_iterm thy is_fun num_args;
    fun clause (args, rhs) =
      let
        val lhs = ml_list (map pattern_of (rev args));
      in ([lhs], body_of rhs) end;
    val arity = the (num_args c);
    val fallback_vars = map nbe_bound (Name.invent_list [] "a" arity);
    val fallback = ([ml_list fallback_vars], nbe_const c fallback_vars);
  in map clause eqns @ [fallback] end;


222 


223 
(*assemble_eqnss thy is_fun eqnss: ML text for a group of functions
  given as (name, equations) pairs, returned together with the names.
  The arity of each function is the argument count of its first
  equation.  Constants occurring in right-hand sides that satisfy is_fun
  are bound through Nbe.lookup_fun ahead of the local definitions; the
  result expression lists each function wrapped by nbe_abss with its
  arity.  The empty group yields ([], "").*)
fun assemble_eqnss thy is_fun [] = ([], "")
  | assemble_eqnss thy is_fun eqnss =
      let
        val cs = map fst eqnss;
        val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss;
        val funs = fold (fold (CodegenThingol.fold_constnames
          (insert (op =))) o map snd o snd) eqnss [];
        val bind_funs = map nbe_lookup (filter is_fun funs);
        val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
          (assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss);
        val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args);
      in (cs, ml_Let (bind_funs @ [bind_locals]) result) end;


235 


236 
(*assemble_eval thy is_fun t: ML text evaluating the iterm t, returned
  together with the singleton name list [nbe_value].
  The term is defined as a function named nbe_value over a list pattern
  of its unbound variables and then applied to the matching Nbe.free
  values; constants satisfying is_fun are bound through Nbe.lookup_fun.*)
fun assemble_eval thy is_fun t =
  let
    val consts = CodegenThingol.fold_constnames (insert (op =)) t [];
    val vars = CodegenThingol.fold_unbound_varnames (insert (op =)) t [];
    val lookups = map nbe_lookup (filter is_fun consts);
    val params = ml_list (map nbe_bound vars);
    val body = assemble_iterm thy is_fun (K NONE) t;
    val definition = ml_fundefs [(nbe_value, [([params], body)])];
    val call = nbe_value `$` ml_list (map nbe_free vars);
  in ([nbe_value], ml_Let (lookups @ [definition]) (ml_list [call])) end;


245 


246 
(*eqns_of_stmt (name, stmt): SOME (name, eqns) for a function statement
  with at least one equation, NONE for a function without equations and
  for every other kind of statement.  The non-function cases are listed
  one by one rather than through a wildcard.*)
fun eqns_of_stmt (name, CodegenThingol.Fun ([], _)) =
      NONE
  | eqns_of_stmt (name, CodegenThingol.Fun (eqns, _)) =
      SOME (name, eqns)
  | eqns_of_stmt (_, CodegenThingol.Datatypecons _) =
      NONE
  | eqns_of_stmt (_, CodegenThingol.Datatype _) =
      NONE
  | eqns_of_stmt (_, CodegenThingol.Class _) =
      NONE
  | eqns_of_stmt (_, CodegenThingol.Classrel _) =
      NONE
  | eqns_of_stmt (_, CodegenThingol.Classop _) =
      NONE
  | eqns_of_stmt (_, CodegenThingol.Classinst _) =
      NONE;


262 


263 
(*compile_stmts thy is_fun: keeps the function statements with
  equations, assembles their ML text and compiles it against the
  function table of thy.  The table is fetched as soon as thy and
  is_fun are supplied.*)
fun compile_stmts thy is_fun =
  let
    val tab = Nbe_Functions.get thy;
  in compile_univs tab o assemble_eqnss thy is_fun o map_filter eqns_of_stmt end;


267 


268 
(*eval_term thy is_fun: assembles and compiles the ML text for a single
  iterm against the function table of thy and returns the one resulting
  value.  The table is fetched as soon as thy and is_fun are supplied.*)
fun eval_term thy is_fun =
  let
    val tab = Nbe_Functions.get thy;
  in snd o the_single o compile_univs tab o assemble_eval thy is_fun end;


273 


274 


275 
(** compilation and evaluation **)


276 


277 
(* ensure global functions *)


278 


279 
(*ensure_funs thy code: compiles every statement of code not yet in the
  function table of thy and returns the table that results.
  The statements are taken per strongly connected component of the code
  graph, in the reverse of the order Graph.strong_conn lists them;
  statements whose name is already in the table are dropped.  Each
  component is compiled against the table built so far, and the compiled
  functions are added via Nbe_Functions.change.*)
fun ensure_funs thy code =
  let
    fun compile' stmts tab =
      let
        val compiled = compile_stmts thy (Symtab.defined tab) stmts;
      in Nbe_Functions.change thy (fold Symtab.update compiled) end;
    val nbe_tab = Nbe_Functions.get thy;
    val stmtss =
      map (AList.make (Graph.get_node code)) (rev (Graph.strong_conn code))
      |> (map o filter_out) (Symtab.defined nbe_tab o fst)
  in fold compile' stmtss nbe_tab end;


290 


291 
(* reconversion *)


292 


293 
(*term_of_univ thy t: converts the semantic value t back into a term.
  bounds counts the abstractions entered so far; typidx is threaded
  through the traversal as the offset for the type parameters that
  replace the type variables of each constant's default type.
  - argument lists are stored last argument first, hence the rev in
    of_apps;
  - BVar holds the level at which the variable was introduced, so its
    de-Bruijn index is bounds - name - 1;
  - an Abs is opened by applying it to a fresh BVar at level bounds and
    closing the resulting body with a term abstraction.
  Free variables and abstractions get dummyT as type.*)
fun term_of_univ thy t =
  let
    fun of_apps bounds (t, ts) =
      fold_map (of_univ bounds) ts
      #>> (fn ts' => list_comb (t, rev ts'))
    and of_univ bounds (Const (name, ts)) typidx =
          let
            val SOME (const as (c, _)) = CodegenNames.const_rev thy name;
            val T = CodegenData.default_typ thy const;
            val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
            val typidx' = typidx + maxidx_of_typ T' + 1;
          in of_apps bounds (Term.Const (c, T'), ts) typidx' end
      | of_univ bounds (Free (name, ts)) typidx =
          of_apps bounds (Term.Free (name, dummyT), ts) typidx
      | of_univ bounds (BVar (name, ts)) typidx =
          of_apps bounds (Bound (bounds - name - 1), ts) typidx
      | of_univ bounds (t as Abs _) typidx =
          typidx
          |> of_univ (bounds + 1) (app t (BVar (bounds, [])))
          |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
  in of_univ 0 t 0 |> fst end;


314 


315 
(* evaluation with type reconstruction *)


316 


317 
(*eval thy code t t': normalizes the iterm t' (the code-level form of
  the term t) and reconstructs a typed term from the result.
  Steps: make sure all functions of code are compiled, evaluate t',
  convert the value back into a term, re-attach the types of the free
  and schematic variables occurring in t, constrain the result to the
  type of t and run type inference.  Fails with an error if schematic
  type variables remain afterwards.  Intermediate terms are reported via
  tracing.*)
fun eval thy code t t' =
  let
    fun subst_Frees [] = I
      | subst_Frees inst =
          Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
                            | t => t);
    val anno_vars =
      subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
      #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
    fun check_tvars t = if null (Term.term_tvars t) then t else
      error ("Illegal schematic type variables in normalized term: "
        ^ setmp show_types true (Sign.string_of_term thy) t);
    val ty = type_of t;
    fun constrain t =
      singleton (ProofContext.infer_types_pats (ProofContext.init thy)) (TypeInfer.constrain t ty);
  in
    t'
    |> eval_term thy (Symtab.defined (ensure_funs thy code))
    |> term_of_univ thy
    |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
    |> tracing (fn _ => "Term type:\n" ^ Display.raw_string_of_typ ty)
    |> anno_vars
    |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
    |> tracing (fn t => setmp show_types true (Sign.string_of_term thy) t)
    |> constrain
    |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
    |> check_tvars
  end;


345 


346 
(* evaluation oracle *)


347 


348 
(*oracle argument: the code, the original term and its iterm form*)
exception Normalization of CodegenThingol.code * term * CodegenThingol.iterm;

(*oracle function: the meta-equation between t and its normal form*)
fun normalization_oracle (thy, Normalization (code, t, t')) =
  let
    val normal_form = eval thy code t t';
  in Logic.mk_equals (t, normal_form) end;


352 


353 
(*invokes the oracle "Nbe.normalization" on the packaged arguments*)
fun normalization_invoke thy code t t' =
  let
    val payload = Normalization (code, t, t');
  in Thm.invoke_oracle_i thy "Nbe.normalization" (thy, payload) end;


355 


356 
(*normalization_conv ct: hands ct to CodegenPackage.eval_conv with a
  conversion that invokes the normalization oracle on the term of the
  cterm it receives; the theory is that of ct*)
fun normalization_conv ct =
  let
    val thy = Thm.theory_of_cterm ct;
    fun conv code t' ct = normalization_invoke thy code (Thm.term_of ct) t';
  in CodegenPackage.eval_conv thy conv ct end;


364 


365 
(* evaluation command *)


366 


367 
(*norm_print_term ctxt modes t: normalizes t and prints the right-hand
  side of the resulting equation, quoted, followed by "::" and its
  quoted type.  The text is built with modes prepended to the current
  print_mode.*)
fun norm_print_term ctxt modes t =
  let
    val thy = ProofContext.theory_of ctxt;
    val eq = normalization_conv (Thm.cterm_of thy t);
    val (_, t') = Logic.dest_equals (Thm.prop_of eq);
    val ty = Term.type_of t';
    fun render () =
      Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
        Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)];
    val p = Library.setmp print_mode (modes @ ! print_mode) render ();
  in Pretty.writeln p end;


377 


378 


379 
(** Isar setup **)


380 


381 
(*command body: reads raw_t in the context of the toplevel state, then
  normalizes and prints it*)
fun norm_print_term_cmd (modes, raw_t) state =
  let
    val ctxt = Toplevel.context_of state;
    val t = ProofContext.read_term ctxt raw_t;
  in norm_print_term ctxt modes t end;


384 


385 
(*theory setup: passes normalization_oracle to Theory.add_oracle under
  the name "normalization"; normalization_invoke refers to it as
  "Nbe.normalization"*)
val setup = Theory.add_oracle ("normalization", normalization_oracle)


386 


387 
(*outer syntax of the "normal_form" diagnostic command: an optional
  parenthesized, non-empty list of print modes followed by the term to
  normalize.  Within this block K abbreviates OuterKeyword.*)
local structure P = OuterParse and K = OuterKeyword in

(*optional "(" mode ... ")"; [] when absent*)
val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];

val nbeP =
  OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
    (opt_modes -- P.typ >> (Toplevel.keep o norm_print_term_cmd));

val _ = OuterSyntax.add_parsers [nbeP];

end;


398 


399 
end;
