# HG changeset patch # User haftmann # Date 1192803627 -7200 # Node ID cae0f68b693b864cfdb45e7b141a8b306da63db2 # Parent fe9632d914c75267b1f9ee512e6babd5c84ae132 now employing dictionaries diff -r fe9632d914c7 -r cae0f68b693b src/Tools/nbe.ML --- a/src/Tools/nbe.ML Fri Oct 19 16:18:00 2007 +0200 +++ b/src/Tools/nbe.ML Fri Oct 19 16:20:27 2007 +0200 @@ -7,22 +7,23 @@ signature NBE = sig + val norm_conv: cterm -> thm + val norm_term: theory -> term -> term + datatype Univ = Const of string * Univ list (*named (uninterpreted) constants*) | Free of string * Univ list + | DFree of string (*free (uninterpreted) dictionary parameters*) | BVar of int * Univ list | Abs of (int * (Univ list -> Univ)) * Univ list; - val free: string -> Univ list -> Univ (*free (uninterpreted) variables*) - val abs: int -> (Univ list -> Univ) -> Univ list -> Univ - (*abstractions as functions*) + val free: string -> Univ (*free (uninterpreted) variables*) val app: Univ -> Univ -> Univ (*explicit application*) + val abs: int -> (Univ list -> Univ) -> Univ + (*abstractions as closures*) val univs_ref: (unit -> Univ list) ref val lookup_fun: string -> Univ - val norm_conv: cterm -> thm - val norm_term: theory -> term -> term - val trace: bool ref val setup: theory -> theory end; @@ -59,41 +60,25 @@ datatype Univ = Const of string * Univ list (*named (uninterpreted) constants*) | Free of string * Univ list (*free variables*) + | DFree of string (*free (uninterpreted) dictionary parameters*) | BVar of int * Univ list (*bound named variables*) | Abs of (int * (Univ list -> Univ)) * Univ list (*abstractions as closures*); (* constructor functions *) -val free = curry Free; -fun abs n f ts = Abs ((n, f), ts); +fun free v = Free (v, []); +fun abs n f = Abs ((n, f), []); fun app (Abs ((1, f), xs)) x = f (x :: xs) | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs) | app (Const (name, args)) x = Const (name, x :: args) | app (Free (name, args)) x = Free (name, x :: args) | app (BVar (name, args)) x = BVar (name, x :: args); -(* global functions store *) +(* universe graph *) -structure Nbe_Functions = CodeDataFun -( - type T = Univ Graph.T; - val empty = Graph.empty; - fun merge _ = Graph.merge (K true); - fun purge _ NONE _ = Graph.empty - | purge NONE _ _ = Graph.empty - | purge (SOME thy) (SOME cs) gr = Graph.empty - (*let - val cs_exisiting = - map_filter (CodeName.const_rev thy) (Graph.keys gr); - val dels = (Graph.all_preds gr - o map (CodeName.const thy) - o filter (member (op =) cs_exisiting) - ) cs; - in Graph.del_nodes dels gr end*); -); - -fun defined gr = can (Graph.get_node gr); +type univ_gr = Univ option Graph.T; +val compiled : univ_gr -> string -> bool = can o Graph.get_node; (* sandbox communication *) @@ -101,30 +86,28 @@ local -val gr_ref = ref NONE : Nbe_Functions.T option ref; +val gr_ref = ref NONE : univ_gr option ref; -fun compile tab raw_s = NAMED_CRITICAL "nbe" (fn () => +fun compile gr raw_s = NAMED_CRITICAL "nbe" (fn () => let val _ = univs_ref := (fn () => []); val s = "Nbe.univs_ref := " ^ raw_s; val _ = tracing (fn () => "\n--- generated code:\n" ^ s) (); - val _ = gr_ref := SOME tab; + val _ = gr_ref := SOME gr; val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n", Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n") (!trace) s; val _ = gr_ref := NONE; in !univs_ref end); + in -fun lookup_fun s = NAMED_CRITICAL "nbe" (fn () => - case ! gr_ref +fun lookup_fun s = NAMED_CRITICAL "nbe" (fn () => case ! gr_ref of NONE => error "compile_univs" - | SOME gr => Graph.get_node gr s); + | SOME gr => the (Graph.get_node gr s)); -fun compile_univs tab ([], _) = [] - | compile_univs tab (cs, raw_s) = case compile tab raw_s () - of [] => error "compile_univs" - | univs => cs ~~ univs; +fun compile_univs gr ([], _) = [] + | compile_univs gr (cs, raw_s) = cs ~~ compile gr raw_s (); end; (*local*) @@ -135,7 +118,8 @@ infix 9 `$` `$$`; fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")"; -fun e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")"; +fun e `$$` [] = e + | e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")"; fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")"; fun ml_Val v s = "val " ^ v ^ " = " ^ s; @@ -167,21 +151,25 @@ val prefix = "Nbe."; val name_const = prefix ^ "Const"; val name_free = prefix ^ "free"; + val name_dfree = prefix ^ "DFree"; val name_abs = prefix ^ "abs"; val name_app = prefix ^ "app"; val name_lookup_fun = prefix ^ "lookup_fun"; in -fun nbe_const c ts = name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")"); +fun nbe_const c ts = + name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")"); fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c; -fun nbe_free v = name_free `$$` [ML_Syntax.print_string v, ml_list []]; +fun nbe_free v = name_free `$` ML_Syntax.print_string v; +fun nbe_dfree v = name_dfree `$` ML_Syntax.print_string v; +fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n; fun nbe_bound v = "v_" ^ v; fun nbe_apps e es = Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e); fun nbe_abss 0 f = f `$` ml_list [] - | nbe_abss n f = name_abs `$$` [string_of_int n, f, ml_list []]; + | nbe_abss n f = name_abs `$$` [string_of_int n, f]; fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c); @@ -193,7 +181,12 @@ (* greetings to Tarski *) -fun assemble_iterm thy is_fun num_args = +fun assemble_idict (DictConst (inst, dss)) = + nbe_apps (nbe_fun inst) ((maps o map) assemble_idict dss) + | assemble_idict (DictVar (supers, (v, (n, _)))) = + fold (fn super => nbe_apps (nbe_fun super) o single) supers (nbe_dict v n); + +fun assemble_iterm is_fun num_args = let fun of_iterm t = let @@ -206,7 +199,8 @@ end else nbe_const c ts | NONE => if is_fun c then nbe_apps (nbe_fun c) ts else nbe_const c ts - and of_iapp (IConst (c, (dss, _))) ts = of_iconst c ts + and of_iapp (IConst (c, (dss, _))) ts = of_iconst c + (ts @ rev ((maps o map) assemble_idict dss)) | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts | of_iapp ((v, _) `|-> t) ts = nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts @@ -215,112 +209,123 @@ @ [("_", of_iterm t0)])) ts in of_iterm end; -fun assemble_fun thy is_fun num_args (c, eqns) = +fun assemble_fun gr num_args (c, (vs, eqns)) = let - val assemble_arg = assemble_iterm thy (K false) (K NONE); - val assemble_rhs = assemble_iterm thy is_fun num_args; + val assemble_arg = assemble_iterm (K false) (K NONE); + val assemble_rhs = assemble_iterm (is_some o Graph.get_node gr) num_args; + val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs + |> rev; fun assemble_eqn (args, rhs) = - ([ml_list (map assemble_arg (rev args))], assemble_rhs rhs); - val default_params = map nbe_bound - (Name.invent_list [] "a" ((the o num_args) c)); + ([ml_list (map assemble_arg (rev args) @ dict_params)], assemble_rhs rhs); + val default_params = map nbe_bound (Name.invent_list [] "a" ((the o num_args) c)); val default_eqn = ([ml_list default_params], nbe_const c default_params); in map assemble_eqn eqns @ [default_eqn] end; -fun assemble_eqnss thy is_fun ([], deps) = ([], "") - | assemble_eqnss thy is_fun (eqnss, deps) = +fun assemble_eqnss gr ([], deps) = ([], "") + | assemble_eqnss gr (eqnss, deps) = let val cs = map fst eqnss; - val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss; - val funs = fold (fold (CodeThingol.fold_constnames - (insert (op =))) o map snd o snd) eqnss []; - val bind_funs = map nbe_lookup (filter is_fun funs); + val num_args = cs ~~ map (fn (_, (vs, (args, rhs) :: _)) => + length (maps snd vs) + length args) eqnss; + val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps); val bind_locals = ml_fundefs (map nbe_fun cs ~~ map - (assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss); + (assemble_fun gr (AList.lookup (op =) num_args)) eqnss); val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args) |> ml_delay; - in (cs, ml_Let (bind_funs @ [bind_locals]) result) end; + in (cs, ml_Let (bind_deps @ [bind_locals]) result) end; -fun assemble_eval thy is_fun (((vs, ty), t), deps) = - let - val funs = CodeThingol.fold_constnames (insert (op =)) t []; - val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t []; - val bind_funs = map nbe_lookup (filter is_fun funs); - val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees)], - assemble_iterm thy is_fun (K NONE) t)])]; - val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)] - |> ml_delay; - in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end; +fun eqns_of_stmt (_, CodeThingol.Fun (_, [])) = + [] + | eqns_of_stmt (const, CodeThingol.Fun ((vs, _), eqns)) = + [(const, (vs, map fst eqns))] + | eqns_of_stmt (_, CodeThingol.Datatypecons _) = + [] + | eqns_of_stmt (_, CodeThingol.Datatype _) = + [] + | eqns_of_stmt (class, CodeThingol.Class (v, (superclasses, classops))) = + let + val names = map snd superclasses @ map fst classops; + val params = Name.invent_list [] "d" (length names); + fun mk (k, name) = + (name, ([(v, [])], + [([IConst (class, ([], [])) `$$ map IVar params], IVar (nth params k))])); + in map_index mk names end + | eqns_of_stmt (_, CodeThingol.Classrel _) = + [] + | eqns_of_stmt (_, CodeThingol.Classparam _) = + [] + | eqns_of_stmt (inst, CodeThingol.Classinst ((class, (_, arities)), (superinsts, instops))) = + [(inst, (arities, [([], IConst (class, ([], [])) `$$ + map (fn (_, (_, (inst, dicts))) => IConst (inst, (dicts, []))) superinsts + @ map (IConst o snd o fst) instops)]))]; -fun eqns_of_stmt ((_, CodeThingol.Fun (_, [])), _) = - NONE - | eqns_of_stmt ((name, CodeThingol.Fun (_, eqns)), deps) = - SOME ((name, map fst eqns), deps) - | eqns_of_stmt ((_, CodeThingol.Datatypecons _), _) = - NONE - | eqns_of_stmt ((_, CodeThingol.Datatype _), _) = - NONE - | eqns_of_stmt ((_, CodeThingol.Class _), _) = - NONE - | eqns_of_stmt ((_, CodeThingol.Classrel _), _) = - NONE - | eqns_of_stmt ((_, CodeThingol.Classparam _), _) = - NONE - | eqns_of_stmt ((_, CodeThingol.Classinst _), _) = - NONE; +fun compile_stmts stmts_deps = + let + val names = map (fst o fst) stmts_deps; + val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps; + val eqnss = maps (eqns_of_stmt o fst) stmts_deps; + val compiled_deps = names_deps |> maps snd |> distinct (op =) |> subtract (op =) names; + fun compile gr = (eqnss, compiled_deps) |> assemble_eqnss gr |> compile_univs gr |> rpair gr; + in + fold (fn name => Graph.new_node (name, NONE)) names + #> fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps + #> compile + #-> fold (fn (name, univ) => Graph.map_node name (K (SOME univ))) + end; -fun compile_stmts thy is_fun = - map_filter eqns_of_stmt - #> split_list - #> assemble_eqnss thy is_fun - #> compile_univs (Nbe_Functions.get thy); +fun ensure_stmts code = + let + fun add_stmts names gr = if exists (compiled gr) names then gr else gr + |> compile_stmts (map (fn name => ((name, Graph.get_node code name), + Graph.imm_succs code name)) names); + in fold_rev add_stmts (Graph.strong_conn code) end; -fun eval_term thy is_fun = - assemble_eval thy is_fun - #> compile_univs (Nbe_Functions.get thy) +fun assemble_eval gr (((vs, ty), t), deps) = + let + val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t []; + val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps); + val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs + |> rev; + val bind_value = ml_fundefs [(nbe_value, + [([ml_list (map nbe_bound frees @ dict_params)], + assemble_iterm (is_some o Graph.get_node gr) (K NONE) t)])]; + val result = ml_list [nbe_value `$` ml_list + (map nbe_free frees @ map nbe_dfree dict_params)] + |> ml_delay; + in ([nbe_value], ml_Let (bind_deps @ [bind_value]) result) end; + +fun eval_term gr = + assemble_eval gr + #> compile_univs gr #> the_single #> snd; -(** compilation and evaluation **) - -(* ensure global functions *) - -fun ensure_funs thy code = - let - fun add_dep (name, dep) gr = - if can (Graph.get_node gr) name andalso can (Graph.get_node gr) dep - then Graph.add_edge (name, dep) gr else gr; - fun compile' stmts gr = - let - val compiled = compile_stmts thy (defined gr) stmts; - val names = map (fst o fst) stmts; - val deps = maps snd stmts; - in - Nbe_Functions.change thy (fold Graph.new_node compiled - #> fold (fn name => fold (curry add_dep name) deps) names) - end; - val nbe_gr = Nbe_Functions.get thy; - val stmtss = rev (Graph.strong_conn code) - |> (map o map_filter) (fn name => if defined nbe_gr name - then NONE - else SOME ((name, Graph.get_node code name), Graph.imm_succs code name)) - |> filter_out null - in fold compile' stmtss nbe_gr end; +(** evaluation **) (* reification *) fun term_of_univ thy t = let + fun take_until f [] = [] + | take_until f (x::xs) = if f x then [] else x :: take_until f xs; + fun is_dict (Const (c, _)) = + (is_some o CodeName.class_rev thy) c + orelse (is_some o CodeName.classrel_rev thy) c + orelse (is_some o CodeName.instance_rev thy) c + | is_dict (DFree _) = true + | is_dict _ = false; fun of_apps bounds (t, ts) = fold_map (of_univ bounds) ts #>> (fn ts' => list_comb (t, rev ts')) and of_univ bounds (Const (name, ts)) typidx = let + val ts' = take_until is_dict ts; val SOME c = CodeName.const_rev thy name; val T = Code.default_typ thy c; val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T; val typidx' = typidx + maxidx_of_typ T' + 1; - in of_apps bounds (Term.Const (c, T'), ts) typidx' end + in of_apps bounds (Term.Const (c, T'), ts') typidx' end | of_univ bounds (Free (name, ts)) typidx = of_apps bounds (Term.Free (name, dummyT), ts) typidx | of_univ bounds (BVar (name, ts)) typidx = @@ -331,6 +336,33 @@ |-> (fn t' => pair (Term.Abs ("u", dummyT, t'))) in of_univ 0 t 0 |> fst end; +(* function store *) + +structure Nbe_Functions = CodeDataFun +( + type T = univ_gr; + val empty = Graph.empty; + fun merge _ = Graph.merge (K true); + fun purge _ NONE _ = Graph.empty + | purge NONE _ _ = Graph.empty + | purge (SOME thy) (SOME cs) gr = + let + val cs_exisiting = + map_filter (CodeName.const_rev thy) (Graph.keys gr); + val dels = (Graph.all_preds gr + o map (CodeName.const thy) + o filter (member (op =) cs_exisiting) + ) cs; + in Graph.del_nodes dels gr end; +); + +(* compilation, evaluation and reification *) + +fun compile_eval thy code vs_ty_t deps = + (vs_ty_t, deps) + |> eval_term (Nbe_Functions.change thy (ensure_stmts code)) + |> term_of_univ thy; + (* evaluation with type reconstruction *) fun eval thy code t vs_ty_t deps = @@ -349,16 +381,14 @@ error ("Illegal schematic type variables in normalized term: " ^ setmp show_types true (Sign.string_of_term thy) t); in - (vs_ty_t, deps) - |> eval_term thy (defined (ensure_funs thy code)) - |> term_of_univ thy + compile_eval thy code vs_ty_t deps |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t) |> anno_vars |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t) |> constrain |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t) + |> tracing (fn t => "---\n") |> check_tvars - |> tracing (fn _ => "---\n") end; (* evaluation oracle *)