--- a/src/Tools/nbe.ML Thu Oct 04 19:41:52 2007 +0200
+++ b/src/Tools/nbe.ML Thu Oct 04 19:41:53 2007 +0200
@@ -2,14 +2,7 @@
ID: $Id$
Authors: Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen
-Evaluation mechanisms for normalization by evaluation.
-*)
-
-(*
-FIXME:
-- get rid of BVar (?) - it is only used for terms to be evaluated, not for functions
-- proper purge operation - preliminary for...
-- really incremental code generation
+Normalization by evaluation, based on generic code generator.
*)
signature NBE =
@@ -27,7 +20,8 @@
val univs_ref: (unit -> Univ list) ref
val lookup_fun: string -> Univ
- val normalization_conv: cterm -> thm
+ val norm_conv: cterm -> thm
+ val norm_term: theory -> term -> term
val trace: bool ref
val setup: theory -> theory
@@ -83,25 +77,37 @@
structure Nbe_Functions = CodeDataFun
(
- type T = Univ Symtab.table;
- val empty = Symtab.empty;
- fun merge _ = Symtab.merge (K true);
- fun purge _ _ _ = Symtab.empty;
+ type T = Univ Graph.T;
+ val empty = Graph.empty;
+ fun merge _ = Graph.merge (K true);
+ fun purge _ NONE _ = Graph.empty
+ | purge NONE _ _ = Graph.empty
+ | purge (SOME thy) (SOME cs) gr = Graph.empty
+ (*let
+ val cs_exisiting =
+ map_filter (CodeName.const_rev thy) (Graph.keys gr);
+ val dels = (Graph.all_preds gr
+ o map (CodeName.const thy)
+ o filter (member (op =) cs_exisiting)
+ ) cs;
+ in Graph.del_nodes dels gr end*);
);
+fun defined gr = can (Graph.get_node gr);
+
(* sandbox communication *)
val univs_ref = ref (fn () => [] : Univ list);
local
-val tab_ref = ref NONE : Univ Symtab.table option ref;
+val gr_ref = ref NONE : Nbe_Functions.T option ref;
in
-fun lookup_fun s = case ! tab_ref
+fun lookup_fun s = case ! gr_ref
of NONE => error "compile_univs"
- | SOME tab => (the o Symtab.lookup tab) s;
+ | SOME gr => Graph.get_node gr s;
fun compile_univs tab ([], _) = []
| compile_univs tab (cs, raw_s) =
@@ -109,11 +115,11 @@
val _ = univs_ref := (fn () => []);
val s = "Nbe.univs_ref := " ^ raw_s;
val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
- val _ = tab_ref := SOME tab;
+ val _ = gr_ref := SOME tab;
val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
(!trace) s;
- val _ = tab_ref := NONE;
+ val _ = gr_ref := NONE;
val univs = case !univs_ref () of [] => error "compile_univs" | univs => univs;
in cs ~~ univs end;
@@ -254,7 +260,7 @@
NONE
| eqns_of_stmt ((_, CodeThingol.Classrel _), _) =
NONE
- | eqns_of_stmt ((_, CodeThingol.Classop _), _) =
+ | eqns_of_stmt ((_, CodeThingol.Classparam _), _) =
NONE
| eqns_of_stmt ((_, CodeThingol.Classinst _), _) =
NONE;
@@ -278,19 +284,27 @@
fun ensure_funs thy code =
let
- fun compile' stmts tab =
+ fun add_dep (name, dep) gr =
+ if can (Graph.get_node gr) name andalso can (Graph.get_node gr) dep
+ then Graph.add_edge (name, dep) gr else gr;
+ fun compile' stmts gr =
let
- val compiled = compile_stmts thy (Symtab.defined tab) stmts;
- in Nbe_Functions.change thy (fold Symtab.update compiled) end;
- val nbe_tab = Nbe_Functions.get thy;
+ val compiled = compile_stmts thy (defined gr) stmts;
+ val names = map (fst o fst) stmts;
+ val deps = maps snd stmts;
+ in
+ Nbe_Functions.change thy (fold Graph.new_node compiled
+ #> fold (fn name => fold (curry add_dep name) deps) names)
+ end;
+ val nbe_gr = Nbe_Functions.get thy;
val stmtss = rev (Graph.strong_conn code)
- |> (map o map_filter) (fn name => if Symtab.defined nbe_tab name
+ |> (map o map_filter) (fn name => if defined nbe_gr name
then NONE
else SOME ((name, Graph.get_node code name), Graph.imm_succs code name))
|> filter_out null
- in fold compile' stmtss nbe_tab end;
+ in fold compile' stmtss nbe_gr end;
-(* re-conversion *)
+(* reification *)
fun term_of_univ thy t =
let
@@ -333,7 +347,7 @@
^ setmp show_types true (Sign.string_of_term thy) t);
in
(vs_ty_t, deps)
- |> eval_term thy (Symtab.defined (ensure_funs thy code))
+ |> eval_term thy (defined (ensure_funs thy code))
|> term_of_univ thy
|> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
|> anno_vars
@@ -346,36 +360,41 @@
(* evaluation oracle *)
-exception Normalization of CodeThingol.code * term
+exception Norm of CodeThingol.code * term
* (CodeThingol.typscheme * CodeThingol.iterm) * string list;
-fun normalization_oracle (thy, Normalization (code, t, vs_ty_t, deps)) =
+fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) =
Logic.mk_equals (t, eval thy code t vs_ty_t deps);
-fun normalization_invoke thy code t vs_ty_t deps =
- Thm.invoke_oracle_i thy "HOL.normalization" (thy, Normalization (code, t, vs_ty_t, deps));
+fun norm_invoke thy code t vs_ty_t deps =
+ Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps));
(*FIXME get rid of hardwired theory name*)
-fun normalization_conv ct =
+fun norm_conv ct =
let
val thy = Thm.theory_of_cterm ct;
fun conv code vs_ty_t deps ct =
let
val t = Thm.term_of ct;
- in normalization_invoke thy code t vs_ty_t deps end;
+ in norm_invoke thy code t vs_ty_t deps end;
in CodePackage.eval_conv thy conv ct end;
+fun norm_term thy =
+ let
+ fun invoke code vs_ty_t deps t =
+ eval thy code t vs_ty_t deps;
+ in CodePackage.eval_term thy invoke #> Code.postprocess_term thy end;
+
(* evaluation command *)
fun norm_print_term ctxt modes t =
let
val thy = ProofContext.theory_of ctxt;
- val ct = Thm.cterm_of thy t;
- val (_, t') = (Logic.dest_equals o Thm.prop_of o normalization_conv) ct;
- val ty = Term.type_of t';
+ val t' = norm_term thy t;
+ val ty' = Term.type_of t';
val p = PrintMode.with_modes modes (fn () =>
Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
- Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)]) ();
+ Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty')]) ();
in Pretty.writeln p end;
@@ -385,7 +404,7 @@
let val ctxt = Toplevel.context_of state
in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;
-val setup = Theory.add_oracle ("normalization", normalization_oracle)
+val setup = Theory.add_oracle ("norm", norm_oracle)
local structure P = OuterParse and K = OuterKeyword in