# HG changeset patch # User haftmann # Date 1193313124 -7200 # Node ID 5cd8486c5a4f93c440187c33685905267623338c # Parent a1997f7a394a17b743afe31f24bc8504b1de4ee5 clarified implementation diff -r a1997f7a394a -r 5cd8486c5a4f src/Tools/nbe.ML --- a/src/Tools/nbe.ML Thu Oct 25 13:52:03 2007 +0200 +++ b/src/Tools/nbe.ML Thu Oct 25 13:52:04 2007 +0200 @@ -21,9 +21,7 @@ val abs: int -> (Univ list -> Univ) -> Univ (*abstractions as closures*) - val univs_ref: (unit -> Univ list) ref - val lookup_fun: string -> Univ - + val univs_ref: (unit -> Univ list -> Univ list) option ref val trace: bool ref val setup: theory -> theory end; @@ -80,37 +78,6 @@ type univ_gr = Univ option Graph.T; val compiled : univ_gr -> string -> bool = can o Graph.get_node; -(* sandbox communication *) - -val univs_ref = ref (fn () => [] : Univ list); - -local - -val gr_ref = ref NONE : univ_gr option ref; - -fun compile gr raw_s = NAMED_CRITICAL "nbe" (fn () => - let - val _ = univs_ref := (fn () => []); - val s = "Nbe.univs_ref := " ^ raw_s; - val _ = tracing (fn () => "\n--- generated code:\n" ^ s) (); - val _ = gr_ref := SOME gr; - val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n", - Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n") - (!trace) s; - val _ = gr_ref := NONE; - in !univs_ref end); - -in - -fun lookup_fun s = NAMED_CRITICAL "nbe" (fn () => case ! gr_ref - of NONE => error "compile_univs" - | SOME gr => the (Graph.get_node gr s)); - -fun compile_univs gr ([], _) = [] - | compile_univs gr (cs, raw_s) = cs ~~ compile gr raw_s (); - -end; (*local*) - (** assembling and compiling ML code from terms **) @@ -122,15 +89,12 @@ | e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")"; fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")"; -fun ml_Val v s = "val " ^ v ^ " = " ^ s; fun ml_cases t cs = "(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")"; fun ml_Let ds e = "let\n" ^ space_implode "\n" ds ^ " in " ^ e ^ " end"; fun ml_list es = "[" ^ commas es ^ "]"; -val ml_delay = ml_abs "()" - fun ml_fundefs ([(name, [([], e)])]) = "val " ^ name ^ " = " ^ e ^ "\n" | ml_fundefs (eqs :: eqss) = @@ -171,15 +135,24 @@ fun nbe_abss 0 f = f `$` ml_list [] | nbe_abss n f = name_abs `$$` [string_of_int n, f]; -fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c); - val nbe_value = "value"; end; open BasicCodeThingol; -(* greetings to Tarski *) +(* sandbox communication *) + +val univs_ref = ref (NONE : (unit -> Univ list -> Univ list) option); + +val compile = + tracing (fn s => "\n--- code to be evaluated:\n" ^ s) + #> evaluate ("Nbe.univs_ref", univs_ref) "normalization by evaluation" + (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n", + Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n") + (!trace); + +(* code generation with greetings to Tarski *) fun assemble_idict (DictConst (inst, dss)) = nbe_apps (nbe_fun inst) ((maps o map) assemble_idict dss) @@ -221,18 +194,23 @@ val default_eqn = ([ml_list default_params], nbe_const c default_params); in map assemble_eqn eqns @ [default_eqn] end; -fun assemble_eqnss gr ([], deps) = ([], "") - | assemble_eqnss gr (eqnss, deps) = +fun assemble_eqnss gr deps [] = ([], ("", [])) + | assemble_eqnss gr deps eqnss = let val cs = map fst eqnss; val num_args = cs ~~ map (fn (_, (vs, (args, rhs) :: _)) => length (maps snd vs) + length args) eqnss; - val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps); + val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps; + val bind_deps = ml_list (map nbe_fun deps'); val bind_locals = ml_fundefs (map nbe_fun cs ~~ map (assemble_fun gr (AList.lookup (op =) num_args)) eqnss); - val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args) - |> ml_delay; - in (cs, ml_Let (bind_deps @ [bind_locals]) result) end; + val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args); + val arg_deps = map (the o Graph.get_node gr) deps'; + in (cs, (ml_abs bind_deps (ml_Let [bind_locals] result), arg_deps)) end; + +fun compile_eqnss gr deps eqnss = case assemble_eqnss gr deps eqnss + of ([], _) => [] + | (cs, (s, deps)) => cs ~~ compile s deps; fun eqns_of_stmt (_, CodeThingol.Fun (_, [])) = [] @@ -264,8 +242,13 @@ val names = map (fst o fst) stmts_deps; val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps; val eqnss = maps (eqns_of_stmt o fst) stmts_deps; - val compiled_deps = names_deps |> maps snd |> distinct (op =) |> subtract (op =) names; - fun compile gr = (eqnss, compiled_deps) |> assemble_eqnss gr |> compile_univs gr |> rpair gr; + val compiled_deps = names_deps + |> maps snd + |> distinct (op =) + |> subtract (op =) names; + fun compile gr = eqnss + |> compile_eqnss gr compiled_deps + |> rpair gr; in fold (fn name => Graph.new_node (name, NONE)) names #> fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps @@ -280,25 +263,25 @@ Graph.imm_succs code name)) names); in fold_rev add_stmts (Graph.strong_conn code) end; -fun assemble_eval gr (((vs, ty), t), deps) = +fun assemble_eval gr deps ((vs, ty), t) = let val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t []; - val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps); + val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps; + val bind_deps = ml_list (map nbe_fun deps'); val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs |> rev; val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees @ dict_params)], assemble_iterm (is_some o Graph.get_node gr) (K NONE) t)])]; val result = ml_list [nbe_value `$` ml_list - (map nbe_free frees @ map nbe_dfree dict_params)] - |> ml_delay; - in ([nbe_value], ml_Let (bind_deps @ [bind_value]) result) end; + (map nbe_free frees @ map nbe_dfree dict_params)]; + val arg_deps = map (the o Graph.get_node gr) deps'; + in (ml_abs bind_deps (ml_Let [bind_value] result), arg_deps) end; -fun eval_term gr = - assemble_eval gr - #> compile_univs gr - #> the_single - #> snd; +fun eval_term gr deps t' = + let + val (s, args) = assemble_eval gr deps t'; + in the_single (compile s args) end; (** evaluation **) @@ -359,8 +342,8 @@ (* compilation, evaluation and reification *) fun compile_eval thy code vs_ty_t deps = - (vs_ty_t, deps) - |> eval_term (Nbe_Functions.change thy (ensure_stmts code)) + vs_ty_t + |> eval_term (Nbe_Functions.change thy (ensure_stmts code)) deps |> term_of_univ thy; (* evaluation with type reconstruction *)