--- a/src/Tools/nbe.ML Thu Oct 25 13:52:03 2007 +0200
+++ b/src/Tools/nbe.ML Thu Oct 25 13:52:04 2007 +0200
@@ -21,9 +21,7 @@
val abs: int -> (Univ list -> Univ) -> Univ
(*abstractions as closures*)
- val univs_ref: (unit -> Univ list) ref
- val lookup_fun: string -> Univ
-
+ val univs_ref: (unit -> Univ list -> Univ list) option ref
val trace: bool ref
val setup: theory -> theory
end;
@@ -80,37 +78,6 @@
type univ_gr = Univ option Graph.T;
val compiled : univ_gr -> string -> bool = can o Graph.get_node;
-(* sandbox communication *)
-
-val univs_ref = ref (fn () => [] : Univ list);
-
-local
-
-val gr_ref = ref NONE : univ_gr option ref;
-
-fun compile gr raw_s = NAMED_CRITICAL "nbe" (fn () =>
- let
- val _ = univs_ref := (fn () => []);
- val s = "Nbe.univs_ref := " ^ raw_s;
- val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
- val _ = gr_ref := SOME gr;
- val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
- Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
- (!trace) s;
- val _ = gr_ref := NONE;
- in !univs_ref end);
-
-in
-
-fun lookup_fun s = NAMED_CRITICAL "nbe" (fn () => case ! gr_ref
- of NONE => error "compile_univs"
- | SOME gr => the (Graph.get_node gr s));
-
-fun compile_univs gr ([], _) = []
- | compile_univs gr (cs, raw_s) = cs ~~ compile gr raw_s ();
-
-end; (*local*)
-
(** assembling and compiling ML code from terms **)
@@ -122,15 +89,12 @@
| e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";
-fun ml_Val v s = "val " ^ v ^ " = " ^ s;
fun ml_cases t cs =
"(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")";
fun ml_Let ds e = "let\n" ^ space_implode "\n" ds ^ " in " ^ e ^ " end";
fun ml_list es = "[" ^ commas es ^ "]";
-val ml_delay = ml_abs "()"
-
fun ml_fundefs ([(name, [([], e)])]) =
"val " ^ name ^ " = " ^ e ^ "\n"
| ml_fundefs (eqs :: eqss) =
@@ -171,15 +135,24 @@
fun nbe_abss 0 f = f `$` ml_list []
| nbe_abss n f = name_abs `$$` [string_of_int n, f];
-fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c);
-
val nbe_value = "value";
end;
open BasicCodeThingol;
-(* greetings to Tarski *)
+(* sandbox communication *)
+
+val univs_ref = ref (NONE : (unit -> Univ list -> Univ list) option);
+
+val compile =
+ tracing (fn s => "\n--- code to be evaluated:\n" ^ s)
+ #> evaluate ("Nbe.univs_ref", univs_ref) "normalization by evaluation"
+ (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
+ Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
+ (!trace);
+
+(* code generation with greetings to Tarski *)
fun assemble_idict (DictConst (inst, dss)) =
nbe_apps (nbe_fun inst) ((maps o map) assemble_idict dss)
@@ -221,18 +194,23 @@
val default_eqn = ([ml_list default_params], nbe_const c default_params);
in map assemble_eqn eqns @ [default_eqn] end;
-fun assemble_eqnss gr ([], deps) = ([], "")
- | assemble_eqnss gr (eqnss, deps) =
+fun assemble_eqnss gr deps [] = ([], ("", []))
+ | assemble_eqnss gr deps eqnss =
let
val cs = map fst eqnss;
val num_args = cs ~~ map (fn (_, (vs, (args, rhs) :: _)) =>
length (maps snd vs) + length args) eqnss;
- val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps);
+ val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps;
+ val bind_deps = ml_list (map nbe_fun deps');
val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
(assemble_fun gr (AList.lookup (op =) num_args)) eqnss);
- val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args)
- |> ml_delay;
- in (cs, ml_Let (bind_deps @ [bind_locals]) result) end;
+ val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args);
+ val arg_deps = map (the o Graph.get_node gr) deps';
+ in (cs, (ml_abs bind_deps (ml_Let [bind_locals] result), arg_deps)) end;
+
+fun compile_eqnss gr deps eqnss = case assemble_eqnss gr deps eqnss
+ of ([], _) => []
+ | (cs, (s, deps)) => cs ~~ compile s deps;
fun eqns_of_stmt (_, CodeThingol.Fun (_, [])) =
[]
@@ -264,8 +242,13 @@
val names = map (fst o fst) stmts_deps;
val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
- val compiled_deps = names_deps |> maps snd |> distinct (op =) |> subtract (op =) names;
- fun compile gr = (eqnss, compiled_deps) |> assemble_eqnss gr |> compile_univs gr |> rpair gr;
+ val compiled_deps = names_deps
+ |> maps snd
+ |> distinct (op =)
+ |> subtract (op =) names;
+ fun compile gr = eqnss
+ |> compile_eqnss gr compiled_deps
+ |> rpair gr;
in
fold (fn name => Graph.new_node (name, NONE)) names
#> fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
@@ -280,25 +263,25 @@
Graph.imm_succs code name)) names);
in fold_rev add_stmts (Graph.strong_conn code) end;
-fun assemble_eval gr (((vs, ty), t), deps) =
+fun assemble_eval gr deps ((vs, ty), t) =
let
val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
- val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps);
+ val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps;
+ val bind_deps = ml_list (map nbe_fun deps');
val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
|> rev;
val bind_value = ml_fundefs [(nbe_value,
[([ml_list (map nbe_bound frees @ dict_params)],
assemble_iterm (is_some o Graph.get_node gr) (K NONE) t)])];
val result = ml_list [nbe_value `$` ml_list
- (map nbe_free frees @ map nbe_dfree dict_params)]
- |> ml_delay;
- in ([nbe_value], ml_Let (bind_deps @ [bind_value]) result) end;
+ (map nbe_free frees @ map nbe_dfree dict_params)];
+ val arg_deps = map (the o Graph.get_node gr) deps';
+ in (ml_abs bind_deps (ml_Let [bind_value] result), arg_deps) end;
-fun eval_term gr =
- assemble_eval gr
- #> compile_univs gr
- #> the_single
- #> snd;
+fun eval_term gr deps t' =
+ let
+ val (s, args) = assemble_eval gr deps t';
+ in the_single (compile s args) end;
(** evaluation **)
@@ -359,8 +342,8 @@
(* compilation, evaluation and reification *)
fun compile_eval thy code vs_ty_t deps =
- (vs_ty_t, deps)
- |> eval_term (Nbe_Functions.change thy (ensure_stmts code))
+ vs_ty_t
+ |> eval_term (Nbe_Functions.change thy (ensure_stmts code)) deps
|> term_of_univ thy;
(* evaluation with type reconstruction *)