src/Tools/nbe.ML
changeset 24839 199c48ec5a09
parent 24713 8b3b6d09ef40
child 24867 e5b55d7be9bb
--- a/src/Tools/nbe.ML	Thu Oct 04 19:41:52 2007 +0200
+++ b/src/Tools/nbe.ML	Thu Oct 04 19:41:53 2007 +0200
@@ -2,14 +2,7 @@
     ID:         $Id$
     Authors:    Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen
 
-Evaluation mechanisms for normalization by evaluation.
-*)
-
-(*
-FIXME:
-- get rid of BVar (?) - it is only used for terms to be evaluated, not for functions
-- proper purge operation - preliminary for...
-- really incremental code generation
+Normalization by evaluation, based on generic code generator.
 *)
 
 signature NBE =
@@ -27,7 +20,8 @@
   val univs_ref: (unit -> Univ list) ref 
   val lookup_fun: string -> Univ
 
-  val normalization_conv: cterm -> thm
+  val norm_conv: cterm -> thm
+  val norm_term: theory -> term -> term
 
   val trace: bool ref
   val setup: theory -> theory
@@ -83,25 +77,37 @@
 
 structure Nbe_Functions = CodeDataFun
 (
-  type T = Univ Symtab.table;
-  val empty = Symtab.empty;
-  fun merge _ = Symtab.merge (K true);
-  fun purge _ _ _ = Symtab.empty;
+  type T = Univ Graph.T;
+  val empty = Graph.empty;
+  fun merge _ = Graph.merge (K true);
+  fun purge _ NONE _ = Graph.empty
+    | purge NONE _ _ = Graph.empty
+    | purge (SOME thy) (SOME cs) gr = Graph.empty
+        (*let
+          val cs_exisiting =
+            map_filter (CodeName.const_rev thy) (Graph.keys gr);
+          val dels = (Graph.all_preds gr
+              o map (CodeName.const thy)
+              o filter (member (op =) cs_exisiting)
+            ) cs;
+        in Graph.del_nodes dels gr end*);
 );
 
+fun defined gr = can (Graph.get_node gr);
+
 (* sandbox communication *)
 
 val univs_ref = ref (fn () => [] : Univ list);
 
 local
 
-val tab_ref = ref NONE : Univ Symtab.table option ref;
+val gr_ref = ref NONE : Nbe_Functions.T option ref;
 
 in
 
-fun lookup_fun s = case ! tab_ref
+fun lookup_fun s = case ! gr_ref
  of NONE => error "compile_univs"
-  | SOME tab => (the o Symtab.lookup tab) s;
+  | SOME gr => Graph.get_node gr s;
 
 fun compile_univs tab ([], _) = []
   | compile_univs tab (cs, raw_s) =
@@ -109,11 +115,11 @@
         val _ = univs_ref := (fn () => []);
         val s = "Nbe.univs_ref := " ^ raw_s;
         val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
-        val _ = tab_ref := SOME tab;
+        val _ = gr_ref := SOME tab;
         val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
           Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
           (!trace) s;
-        val _ = tab_ref := NONE;
+        val _ = gr_ref := NONE;
         val univs = case !univs_ref () of [] => error "compile_univs" | univs => univs;
       in cs ~~ univs end;
 
@@ -254,7 +260,7 @@
       NONE
   | eqns_of_stmt ((_, CodeThingol.Classrel _), _) =
       NONE
-  | eqns_of_stmt ((_, CodeThingol.Classop _), _) =
+  | eqns_of_stmt ((_, CodeThingol.Classparam _), _) =
       NONE
   | eqns_of_stmt ((_, CodeThingol.Classinst _), _) =
       NONE;
@@ -278,19 +284,27 @@
 
 fun ensure_funs thy code =
   let
-    fun compile' stmts tab =
+    fun add_dep (name, dep) gr =
+      if can (Graph.get_node gr) name andalso can (Graph.get_node gr) dep
+      then Graph.add_edge (name, dep) gr else gr;
+    fun compile' stmts gr =
       let
-        val compiled = compile_stmts thy (Symtab.defined tab) stmts;
-      in Nbe_Functions.change thy (fold Symtab.update compiled) end;
-    val nbe_tab = Nbe_Functions.get thy;
+        val compiled = compile_stmts thy (defined gr) stmts;
+        val names = map (fst o fst) stmts;
+        val deps = maps snd stmts;
+      in
+        Nbe_Functions.change thy (fold Graph.new_node compiled
+          #> fold (fn name => fold (curry add_dep name) deps) names)
+      end;
+    val nbe_gr = Nbe_Functions.get thy;
     val stmtss = rev (Graph.strong_conn code)
-      |> (map o map_filter) (fn name => if Symtab.defined nbe_tab name
+      |> (map o map_filter) (fn name => if defined nbe_gr name
            then NONE
            else SOME ((name, Graph.get_node code name), Graph.imm_succs code name))
       |> filter_out null
-  in fold compile' stmtss nbe_tab end;
+  in fold compile' stmtss nbe_gr end;
 
-(* re-conversion *)
+(* reification *)
 
 fun term_of_univ thy t =
   let
@@ -333,7 +347,7 @@
         ^ setmp show_types true (Sign.string_of_term thy) t);
   in
     (vs_ty_t, deps)
-    |> eval_term thy (Symtab.defined (ensure_funs thy code))
+    |> eval_term thy (defined (ensure_funs thy code))
     |> term_of_univ thy
     |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
     |> anno_vars
@@ -346,36 +360,41 @@
 
 (* evaluation oracle *)
 
-exception Normalization of CodeThingol.code * term
+exception Norm of CodeThingol.code * term
   * (CodeThingol.typscheme * CodeThingol.iterm) * string list;
 
-fun normalization_oracle (thy, Normalization (code, t, vs_ty_t, deps)) =
+fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) =
   Logic.mk_equals (t, eval thy code t vs_ty_t deps);
 
-fun normalization_invoke thy code t vs_ty_t deps =
-  Thm.invoke_oracle_i thy "HOL.normalization" (thy, Normalization (code, t, vs_ty_t, deps));
+fun norm_invoke thy code t vs_ty_t deps =
+  Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps));
   (*FIXME get rid of hardwired theory name*)
 
-fun normalization_conv ct =
+fun norm_conv ct =
   let
     val thy = Thm.theory_of_cterm ct;
     fun conv code vs_ty_t deps ct =
       let
         val t = Thm.term_of ct;
-      in normalization_invoke thy code t vs_ty_t deps end;
+      in norm_invoke thy code t vs_ty_t deps end;
   in CodePackage.eval_conv thy conv ct end;
 
+fun norm_term thy =
+  let
+    fun invoke code vs_ty_t deps t =
+      eval thy code t vs_ty_t deps;
+  in CodePackage.eval_term thy invoke #> Code.postprocess_term thy end;
+
 (* evaluation command *)
 
 fun norm_print_term ctxt modes t =
   let
     val thy = ProofContext.theory_of ctxt;
-    val ct = Thm.cterm_of thy t;
-    val (_, t') = (Logic.dest_equals o Thm.prop_of o normalization_conv) ct;
-    val ty = Term.type_of t';
+    val t' = norm_term thy t;
+    val ty' = Term.type_of t';
     val p = PrintMode.with_modes modes (fn () =>
       Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
-        Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)]) ();
+        Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty')]) ();
   in Pretty.writeln p end;
 
 
@@ -385,7 +404,7 @@
   let val ctxt = Toplevel.context_of state
   in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;
 
-val setup = Theory.add_oracle ("normalization", normalization_oracle)
+val setup = Theory.add_oracle ("norm", norm_oracle)
 
 local structure P = OuterParse and K = OuterKeyword in