clarified implementation
authorhaftmann
Thu, 25 Oct 2007 13:52:04 +0200
changeset 25190 5cd8486c5a4f
parent 25189 a1997f7a394a
child 25191 e1146aa1e3e3
clarified implementation
src/Tools/nbe.ML
--- a/src/Tools/nbe.ML	Thu Oct 25 13:52:03 2007 +0200
+++ b/src/Tools/nbe.ML	Thu Oct 25 13:52:04 2007 +0200
@@ -21,9 +21,7 @@
   val abs: int -> (Univ list -> Univ) -> Univ
                                              (*abstractions as closures*)
 
-  val univs_ref: (unit -> Univ list) ref 
-  val lookup_fun: string -> Univ
-
+  val univs_ref: (unit -> Univ list -> Univ list) option ref 
   val trace: bool ref
   val setup: theory -> theory
 end;
@@ -80,37 +78,6 @@
 type univ_gr = Univ option Graph.T;
 val compiled : univ_gr -> string -> bool = can o Graph.get_node;
 
-(* sandbox communication *)
-
-val univs_ref = ref (fn () => [] : Univ list);
-
-local
-
-val gr_ref = ref NONE : univ_gr option ref;
-
-fun compile gr raw_s = NAMED_CRITICAL "nbe" (fn () =>
-  let
-    val _ = univs_ref := (fn () => []);
-    val s = "Nbe.univs_ref := " ^ raw_s;
-    val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
-    val _ = gr_ref := SOME gr;
-    val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
-      Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
-      (!trace) s;
-    val _ = gr_ref := NONE;
-  in !univs_ref end);
-
-in
-
-fun lookup_fun s = NAMED_CRITICAL "nbe" (fn () => case ! gr_ref
- of NONE => error "compile_univs"
-  | SOME gr => the (Graph.get_node gr s));
-
-fun compile_univs gr ([], _) = []
-  | compile_univs gr (cs, raw_s) = cs ~~ compile gr raw_s ();
-
-end; (*local*)
-
 
 (** assembling and compiling ML code from terms **)
 
@@ -122,15 +89,12 @@
   | e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
 fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";
 
-fun ml_Val v s = "val " ^ v ^ " = " ^ s;
 fun ml_cases t cs =
   "(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")";
 fun ml_Let ds e = "let\n" ^ space_implode "\n" ds ^ " in " ^ e ^ " end";
 
 fun ml_list es = "[" ^ commas es ^ "]";
 
-val ml_delay = ml_abs "()"
-
 fun ml_fundefs ([(name, [([], e)])]) =
       "val " ^ name ^ " = " ^ e ^ "\n"
   | ml_fundefs (eqs :: eqss) =
@@ -171,15 +135,24 @@
 fun nbe_abss 0 f = f `$` ml_list []
   | nbe_abss n f = name_abs `$$` [string_of_int n, f];
 
-fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c);
-
 val nbe_value = "value";
 
 end;
 
 open BasicCodeThingol;
 
-(* greetings to Tarski *)
+(* sandbox communication *)
+
+val univs_ref = ref (NONE : (unit -> Univ list -> Univ list) option);
+
+val compile = 
+  tracing (fn s => "\n--- code to be evaluated:\n" ^ s)
+  #> evaluate ("Nbe.univs_ref", univs_ref) "normalization by evaluation"
+      (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
+      Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
+      (!trace);
+
+(* code generation with greetings to Tarski *)
 
 fun assemble_idict (DictConst (inst, dss)) =
       nbe_apps (nbe_fun inst) ((maps o map) assemble_idict dss)
@@ -221,18 +194,23 @@
     val default_eqn = ([ml_list default_params], nbe_const c default_params);
   in map assemble_eqn eqns @ [default_eqn] end;
 
-fun assemble_eqnss gr ([], deps) = ([], "")
-  | assemble_eqnss gr (eqnss, deps) =
+fun assemble_eqnss gr deps [] = ([], ("", []))
+  | assemble_eqnss gr deps eqnss =
       let
         val cs = map fst eqnss;
         val num_args = cs ~~ map (fn (_, (vs, (args, rhs) :: _)) =>
           length (maps snd vs) + length args) eqnss;
-        val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps);
+        val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps;
+        val bind_deps = ml_list (map nbe_fun deps');
         val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
           (assemble_fun gr (AList.lookup (op =) num_args)) eqnss);
-        val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args)
-          |> ml_delay;
-      in (cs, ml_Let (bind_deps @ [bind_locals]) result) end;
+        val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args);
+        val arg_deps = map (the o Graph.get_node gr) deps';
+      in (cs, (ml_abs bind_deps (ml_Let [bind_locals] result), arg_deps)) end;
+
+fun compile_eqnss gr deps eqnss = case assemble_eqnss gr deps eqnss
+ of ([], _) => []
+  | (cs, (s, deps)) => cs ~~ compile s deps;
 
 fun eqns_of_stmt (_, CodeThingol.Fun (_, [])) =
       []
@@ -264,8 +242,13 @@
     val names = map (fst o fst) stmts_deps;
     val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
     val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
-    val compiled_deps = names_deps |> maps snd |> distinct (op =) |> subtract (op =) names;
-    fun compile gr = (eqnss, compiled_deps) |> assemble_eqnss gr |> compile_univs gr |> rpair gr;
+    val compiled_deps = names_deps
+      |> maps snd
+      |> distinct (op =)
+      |> subtract (op =) names;
+    fun compile gr = eqnss
+      |> compile_eqnss gr compiled_deps
+      |> rpair gr;
   in
     fold (fn name => Graph.new_node (name, NONE)) names
     #> fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
@@ -280,25 +263,25 @@
           Graph.imm_succs code name)) names);
   in fold_rev add_stmts (Graph.strong_conn code) end;
 
-fun assemble_eval gr (((vs, ty), t), deps) =
+fun assemble_eval gr deps ((vs, ty), t) =
   let
     val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
-    val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps);
+    val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps;
+    val bind_deps = ml_list (map nbe_fun deps');
     val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
       |> rev;
     val bind_value = ml_fundefs [(nbe_value,
       [([ml_list (map nbe_bound frees @ dict_params)],
         assemble_iterm (is_some o Graph.get_node gr) (K NONE) t)])];
     val result = ml_list [nbe_value `$` ml_list
-      (map nbe_free frees @ map nbe_dfree dict_params)]
-      |> ml_delay;
-  in ([nbe_value], ml_Let (bind_deps @ [bind_value]) result) end;
+      (map nbe_free frees @ map nbe_dfree dict_params)];
+    val arg_deps = map (the o Graph.get_node gr) deps';
+  in (ml_abs bind_deps (ml_Let [bind_value] result), arg_deps) end;
 
-fun eval_term gr =
-  assemble_eval gr
-  #> compile_univs gr
-  #> the_single
-  #> snd;
+fun eval_term gr deps t' =
+  let
+    val (s, args) = assemble_eval gr deps t';
+  in the_single (compile s args) end;
 
 
 (** evaluation **)
@@ -359,8 +342,8 @@
 (* compilation, evaluation and reification *)
 
 fun compile_eval thy code vs_ty_t deps =
-  (vs_ty_t, deps)
-  |> eval_term (Nbe_Functions.change thy (ensure_stmts code))
+  vs_ty_t
+  |> eval_term (Nbe_Functions.change thy (ensure_stmts code)) deps
   |> term_of_univ thy;
 
 (* evaluation with type reconstruction *)