diff -r 1d1bddf87353 -r 199c48ec5a09 src/Tools/nbe.ML --- a/src/Tools/nbe.ML Thu Oct 04 19:41:52 2007 +0200 +++ b/src/Tools/nbe.ML Thu Oct 04 19:41:53 2007 +0200 @@ -2,14 +2,7 @@ ID: $Id$ Authors: Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen -Evaluation mechanisms for normalization by evaluation. -*) - -(* -FIXME: -- get rid of BVar (?) - it is only used for terms to be evaluated, not for functions -- proper purge operation - preliminary for... -- really incremental code generation +Normalization by evaluation, based on generic code generator. *) signature NBE = @@ -27,7 +20,8 @@ val univs_ref: (unit -> Univ list) ref val lookup_fun: string -> Univ - val normalization_conv: cterm -> thm + val norm_conv: cterm -> thm + val norm_term: theory -> term -> term val trace: bool ref val setup: theory -> theory @@ -83,25 +77,37 @@ structure Nbe_Functions = CodeDataFun ( - type T = Univ Symtab.table; - val empty = Symtab.empty; - fun merge _ = Symtab.merge (K true); - fun purge _ _ _ = Symtab.empty; + type T = Univ Graph.T; + val empty = Graph.empty; + fun merge _ = Graph.merge (K true); + fun purge _ NONE _ = Graph.empty + | purge NONE _ _ = Graph.empty + | purge (SOME thy) (SOME cs) gr = Graph.empty + (*let + val cs_exisiting = + map_filter (CodeName.const_rev thy) (Graph.keys gr); + val dels = (Graph.all_preds gr + o map (CodeName.const thy) + o filter (member (op =) cs_exisiting) + ) cs; + in Graph.del_nodes dels gr end*); ); +fun defined gr = can (Graph.get_node gr); + (* sandbox communication *) val univs_ref = ref (fn () => [] : Univ list); local -val tab_ref = ref NONE : Univ Symtab.table option ref; +val gr_ref = ref NONE : Nbe_Functions.T option ref; in -fun lookup_fun s = case ! tab_ref +fun lookup_fun s = case ! gr_ref of NONE => error "compile_univs" - | SOME tab => (the o Symtab.lookup tab) s; + | SOME gr => Graph.get_node gr s; fun compile_univs tab ([], _) = [] | compile_univs tab (cs, raw_s) = @@ -109,11 +115,11 @@ val _ = univs_ref := (fn () => []); val s = "Nbe.univs_ref := " ^ raw_s; val _ = tracing (fn () => "\n--- generated code:\n" ^ s) (); - val _ = tab_ref := SOME tab; + val _ = gr_ref := SOME tab; val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n", Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n") (!trace) s; - val _ = tab_ref := NONE; + val _ = gr_ref := NONE; val univs = case !univs_ref () of [] => error "compile_univs" | univs => univs; in cs ~~ univs end; @@ -254,7 +260,7 @@ NONE | eqns_of_stmt ((_, CodeThingol.Classrel _), _) = NONE - | eqns_of_stmt ((_, CodeThingol.Classop _), _) = + | eqns_of_stmt ((_, CodeThingol.Classparam _), _) = NONE | eqns_of_stmt ((_, CodeThingol.Classinst _), _) = NONE; @@ -278,19 +284,27 @@ fun ensure_funs thy code = let - fun compile' stmts tab = + fun add_dep (name, dep) gr = + if can (Graph.get_node gr) name andalso can (Graph.get_node gr) dep + then Graph.add_edge (name, dep) gr else gr; + fun compile' stmts gr = let - val compiled = compile_stmts thy (Symtab.defined tab) stmts; - in Nbe_Functions.change thy (fold Symtab.update compiled) end; - val nbe_tab = Nbe_Functions.get thy; + val compiled = compile_stmts thy (defined gr) stmts; + val names = map (fst o fst) stmts; + val deps = maps snd stmts; + in + Nbe_Functions.change thy (fold Graph.new_node compiled + #> fold (fn name => fold (curry add_dep name) deps) names) + end; + val nbe_gr = Nbe_Functions.get thy; val stmtss = rev (Graph.strong_conn code) - |> (map o map_filter) (fn name => if Symtab.defined nbe_tab name + |> (map o map_filter) (fn name => if defined nbe_gr name then NONE else SOME ((name, Graph.get_node code name), Graph.imm_succs code name)) |> filter_out null - in fold compile' stmtss nbe_tab end; + in fold compile' stmtss nbe_gr end; -(* re-conversion *) +(* reification *) fun term_of_univ thy t = let @@ -333,7 +347,7 @@ ^ setmp show_types true (Sign.string_of_term thy) t); in (vs_ty_t, deps) - |> eval_term thy (Symtab.defined (ensure_funs thy code)) + |> eval_term thy (defined (ensure_funs thy code)) |> term_of_univ thy |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t) |> anno_vars @@ -346,36 +360,41 @@ (* evaluation oracle *) -exception Normalization of CodeThingol.code * term +exception Norm of CodeThingol.code * term * (CodeThingol.typscheme * CodeThingol.iterm) * string list; -fun normalization_oracle (thy, Normalization (code, t, vs_ty_t, deps)) = +fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) = Logic.mk_equals (t, eval thy code t vs_ty_t deps); -fun normalization_invoke thy code t vs_ty_t deps = - Thm.invoke_oracle_i thy "HOL.normalization" (thy, Normalization (code, t, vs_ty_t, deps)); +fun norm_invoke thy code t vs_ty_t deps = + Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps)); (*FIXME get rid of hardwired theory name*) -fun normalization_conv ct = +fun norm_conv ct = let val thy = Thm.theory_of_cterm ct; fun conv code vs_ty_t deps ct = let val t = Thm.term_of ct; - in normalization_invoke thy code t vs_ty_t deps end; + in norm_invoke thy code t vs_ty_t deps end; in CodePackage.eval_conv thy conv ct end; +fun norm_term thy = + let + fun invoke code vs_ty_t deps t = + eval thy code t vs_ty_t deps; + in CodePackage.eval_term thy invoke #> Code.postprocess_term thy end; + (* evaluation command *) fun norm_print_term ctxt modes t = let val thy = ProofContext.theory_of ctxt; - val ct = Thm.cterm_of thy t; - val (_, t') = (Logic.dest_equals o Thm.prop_of o normalization_conv) ct; - val ty = Term.type_of t'; + val t' = norm_term thy t; + val ty' = Term.type_of t'; val p = PrintMode.with_modes modes (fn () => Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk, - Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)]) (); + Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty')]) (); in Pretty.writeln p end; @@ -385,7 +404,7 @@ let val ctxt = Toplevel.context_of state in norm_print_term ctxt modes (Syntax.read_term ctxt s) end; -val setup = Theory.add_oracle ("normalization", normalization_oracle) +val setup = Theory.add_oracle ("norm", norm_oracle) local structure P = OuterParse and K = OuterKeyword in