adjusted to new codegen_funcgr interface
authorhaftmann
Tue, 30 Jan 2007 08:21:23 +0100
changeset 22213 2dd23002c465
parent 22212 079de24eee65
child 22214 6e9ab159512f
adjusted to new codegen_funcgr interface
src/Pure/Tools/codegen_package.ML
src/Pure/Tools/nbe.ML
--- a/src/Pure/Tools/codegen_package.ML	Tue Jan 30 08:21:22 2007 +0100
+++ b/src/Pure/Tools/codegen_package.ML	Tue Jan 30 08:21:23 2007 +0100
@@ -36,25 +36,6 @@
 
 (* theory data *)
 
-structure Code = CodeDataFun
-(struct
-  val name = "Pure/code";
-  type T = CodegenThingol.code;
-  val empty = CodegenThingol.empty_code;
-  fun merge _ = CodegenThingol.merge_code;
-  fun purge _ NONE _ = CodegenThingol.empty_code
-    | purge NONE _ _ = CodegenThingol.empty_code
-    | purge (SOME thy) (SOME cs) code =
-        let
-          val cs_exisiting =
-            map_filter (CodegenNames.const_rev thy) (Graph.keys code);
-          val dels = (Graph.all_preds code
-              o map (CodegenNames.const thy)
-              o filter (member CodegenConsts.eq_const cs_exisiting)
-            ) cs;
-        in Graph.del_nodes dels code end;
-end);
-
 type appgen = theory -> ((sort -> sort) * Sorts.algebra) * Consts.T
   -> CodegenFuncgr.T
   -> bool * string list option
@@ -73,7 +54,7 @@
 
 structure CodegenPackageData = TheoryDataFun
 (struct
-  val name = "Pure/codegen_package";
+  val name = "Pure/codegen_package_setup";
   type T = appgens * abstypes;
   val empty = (Symtab.empty, (Symtab.empty, Consttab.empty));
   val copy = I;
@@ -83,7 +64,34 @@
   fun print _ _ = ();
 end);
 
-val _ = Context.add_setup (Code.init #> CodegenPackageData.init);
+structure Funcgr = CodegenFuncgrRetrieval (
+  val name = "Pure/codegen_package_thms";
+  fun rewrites thy = [];
+);
+
+fun print_codethms thy =
+  Pretty.writeln o CodegenFuncgr.pretty thy o Funcgr.make thy;
+
+structure Code = CodeDataFun
+(struct
+  val name = "Pure/codegen_package_code";
+  type T = CodegenThingol.code;
+  val empty = CodegenThingol.empty_code;
+  fun merge _ = CodegenThingol.merge_code;
+  fun purge _ NONE _ = CodegenThingol.empty_code
+    | purge NONE _ _ = CodegenThingol.empty_code
+    | purge (SOME thy) (SOME cs) code =
+        let
+          val cs_exisiting =
+            map_filter (CodegenNames.const_rev thy) (Graph.keys code);
+          val dels = (Graph.all_preds code
+              o map (CodegenNames.const thy)
+              o filter (member CodegenConsts.eq_const cs_exisiting)
+            ) cs;
+        in Graph.del_nodes dels code end;
+end);
+
+val _ = Context.add_setup (CodegenPackageData.init #> Funcgr.init #> Code.init);
 
 
 (* preparing defining equations *)
@@ -473,7 +481,7 @@
     val _ = if is_some (CodegenData.get_datatype_of_constr thy c2)
       then error ("Not a function: " ^ CodegenConsts.string_of_const thy c2)
       else ();
-    val funcgr = CodegenFuncgr.make thy [c1, c2];
+    val funcgr = Funcgr.make thy [c1, c2];
     val ty1 = (f o CodegenFuncgr.typ funcgr) c1;
     val ty2 = CodegenFuncgr.typ funcgr c2;
     val _ = if Sign.typ_equiv thy (ty1, ty2) then () else
@@ -545,7 +553,7 @@
   let
     val cs = map_filter (Consttab.lookup ((snd o snd o CodegenPackageData.get) thy))
       (CodegenFuncgr.all funcgr);
-    val funcgr' = CodegenFuncgr.make thy cs;
+    val funcgr' = Funcgr.make thy cs;
     val qnaming = NameSpace.qualified_names NameSpace.default_naming;
     val consttab = Consts.empty
       |> fold (fn c => Consts.declare qnaming
@@ -561,7 +569,7 @@
 fun codegen_term thy t =
   let
     val ct = Thm.cterm_of thy t;
-    val (ct', funcgr) = CodegenFuncgr.make_term thy (K K) ct;
+    val (ct', funcgr) = Funcgr.make_term thy (K (K K)) ct;
     val t' = Thm.term_of ct';
   in generate thy funcgr (SOME []) exprgen_term' t' end;
 
@@ -593,7 +601,7 @@
 
 fun filter_generatable thy targets consts =
   let
-    val (consts', funcgr) = CodegenFuncgr.make_consts thy consts;
+    val (consts', funcgr) = Funcgr.make_consts thy consts;
     val consts'' = generate thy funcgr targets (fold_map oooo perhaps_def_const) consts';
     val consts''' = map_filter (fn (const, SOME _) => SOME const | (_, NONE) => NONE)
       (consts' ~~ consts'');
@@ -622,7 +630,7 @@
     fun generate' thy = case cs
      of [] => []
       | _ =>
-          generate thy (CodegenFuncgr.make thy cs) targets
+          generate thy (Funcgr.make thy cs) targets
             (fold_map oooo ensure_def_const') cs;
     fun serialize' [] code seri =
           seri NONE code 
@@ -634,12 +642,11 @@
     (map (serialize' cs code) (map_filter snd seris'); ())
   end;
 
-val (codeK, code_abstypeK, code_axiomsK) =
-  ("code_gen", "code_abstype", "code_axioms");
+fun print_codethms_e thy =
+  print_codethms thy o map (CodegenConsts.read_const thy);
 
-in
 
-val code_bareP = (
+val code_exprP = (
     (Scan.repeat P.term
     -- Scan.repeat (P.$$$ "(" |--
         P.name -- P.arguments
@@ -647,12 +654,17 @@
     >> (fn (raw_cs, seris) => code raw_cs seris)
   );
 
+val (print_codethmsK, codeK, code_abstypeK, code_axiomsK) =
+  ("print_codethms", "code_gen", "code_abstype", "code_axioms");
+
+in
+
 val codeP =
   OuterSyntax.improper_command codeK "generate and serialize executable code for constants"
-    K.diag (P.!!! code_bareP >> (fn f => Toplevel.keep (f o Toplevel.theory_of)));
+    K.diag (P.!!! code_exprP >> (fn f => Toplevel.keep (f o Toplevel.theory_of)));
 
 fun codegen_command thy cmd =
-  case Scan.read OuterLex.stopper (P.!!! code_bareP) ((filter OuterLex.is_proper o OuterSyntax.scan) cmd)
+  case Scan.read OuterLex.stopper (P.!!! code_exprP) ((filter OuterLex.is_proper o OuterSyntax.scan) cmd)
    of SOME f => (writeln "Now generating code..."; f thy)
     | NONE => error ("Bad directive " ^ quote cmd);
 
@@ -669,7 +681,15 @@
     >> (Toplevel.theory o constsubst_e)
   );
 
-val _ = OuterSyntax.add_parsers [codeP, code_abstypeP, code_axiomsP];
+val print_codethmsP =
+  OuterSyntax.improper_command print_codethmsK "print code theorems of this theory" OuterKeyword.diag
+    (Scan.option (P.$$$ "(" |-- Scan.repeat P.term --| P.$$$ ")")
+      >> (fn NONE => CodegenData.print_thms
+           | SOME cs => fn thy => print_codethms_e thy cs)
+      >> (fn f => Toplevel.no_timing o Toplevel.unknown_theory
+      o Toplevel.keep (f o Toplevel.theory_of)));
+
+val _ = OuterSyntax.add_parsers [codeP, code_abstypeP, code_axiomsP, print_codethmsP];
 
 end; (* local *)
 
--- a/src/Pure/Tools/nbe.ML	Tue Jan 30 08:21:22 2007 +0100
+++ b/src/Pure/Tools/nbe.ML	Tue Jan 30 08:21:23 2007 +0100
@@ -57,22 +57,11 @@
     )
   end;
 
-fun consts_of_pres thy = 
-  let
-    val ctxt = ProofContext.init thy;
-    val pres = fst (NBE_Rewrite.get thy);
-    val rhss = map (snd o Logic.dest_equals o prop_of o LocalDefs.meta_rewrite_rule ctxt) pres;
-  in
-    (fold o fold_aterms)
-      (fn Const c => insert (op =) (CodegenConsts.norm_of_typ thy c) | _ => I)
-      rhss []
-  end;
-
-fun apply_pres thy =
+fun the_pres thy =
   let
     val ctxt = ProofContext.init thy;
     val pres = (map (LocalDefs.meta_rewrite_rule ctxt) o fst) (NBE_Rewrite.get thy)
-  in map (CodegenFunc.rewrite_func pres) end
+  in pres end
 
 fun apply_posts thy =
   let
@@ -80,19 +69,25 @@
     val posts = (map (LocalDefs.meta_rewrite_rule ctxt) o snd) (NBE_Rewrite.get thy)
   in MetaSimplifier.rewrite false posts end
 
+(* theorem store *)
+
+structure Funcgr = CodegenFuncgrRetrieval (
+  val name = "Pure/nbe_thms";
+  val rewrites = the_pres;
+);
 
 (* code store *)
 
 structure NBE_Data = CodeDataFun
 (struct
-  val name = "Pure/NBE"
-  type T = NBE_Eval.Univ Symtab.table
-  val empty = Symtab.empty
-  fun merge _ = Symtab.merge (K true)
-  fun purge _ _ _ = Symtab.empty
+  val name = "Pure/mbe";
+  type T = NBE_Eval.Univ Symtab.table;
+  val empty = Symtab.empty;
+  fun merge _ = Symtab.merge (K true);
+  fun purge _ _ _ = Symtab.empty;
 end);
 
-val _ = Context.add_setup NBE_Data.init;
+val _ = Context.add_setup (Funcgr.init #> NBE_Data.init);
 
 
 (** norm by eval **)
@@ -118,31 +113,27 @@
            use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
                 Output.tracing o enclose "\n--- compiler echo (with error!):\n" "\n---\n")
             (!trace) s);
-
     val _ = tracing (fn () => "new definitions: " ^ (commas o maps (map fst)) funs) ();
     val _ = tab := NBE_Data.get thy;;
     val _ = Library.seq (use_code o NBE_Codegen.generate thy
       (fn s => Symtab.defined (!tab) s)) funs;
   in NBE_Data.change thy (K (!tab)) end;
 
-fun ensure_funs thy t =
+fun ensure_funs thy funcgr t =
   let
     val consts = CodegenConsts.consts_of thy t;
-    val pre_consts = consts_of_pres thy;
-    val consts' = pre_consts @ consts;
-    val funcgr = CodegenFuncgr.make thy consts';
     val nbe_tab = NBE_Data.get thy;
-    val all_consts =
-      (pre_consts :: CodegenFuncgr.deps funcgr consts')
-      |> (map o filter_out) (Symtab.defined nbe_tab o CodegenNames.const thy)
-      |> filter_out null;
-    val funs = (map o map)
-      (fn c => (CodegenNames.const thy c, apply_pres thy (CodegenFuncgr.funcs funcgr c))) all_consts;
-  in generate thy funs end;
+  in
+    CodegenFuncgr.deps funcgr consts
+    |> (map o filter_out) (Symtab.defined nbe_tab o CodegenNames.const thy)
+    |> filter_out null
+    |> (map o map) (fn c => (CodegenNames.const thy c, CodegenFuncgr.funcs funcgr c))
+    |> generate thy
+  end;
 
 (* term evaluation *)
 
-fun eval_term thy t =
+fun eval_term thy funcgr t =
   let
     fun subst_Frees [] = I
       | subst_Frees inst =
@@ -157,7 +148,7 @@
     val ty = type_of t;
     fun constrain t = Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy)
       (K NONE) (K NONE) Name.context false ([t], ty) |> fst;
-    val _ = ensure_funs thy t;
+    val _ = ensure_funs thy funcgr t;
   in
     t
     |> tracing (fn t => "Input:\n" ^ Display.raw_string_of_term t)
@@ -173,13 +164,13 @@
 
 (* evaluation oracle *)
 
-exception Normalization of term;
+exception Normalization of CodegenFuncgr.T * term;
 
-fun normalization_oracle (thy, Normalization t) =
-  Logic.mk_equals (t, eval_term thy t);
+fun normalization_oracle (thy, Normalization (funcgr, t)) =
+  Logic.mk_equals (t, eval_term thy funcgr t);
 
-fun normalization_invoke thy t =
-  Thm.invoke_oracle_i thy "Pure.normalization" (thy, Normalization t);
+fun normalization_invoke thy funcgr t =
+  Thm.invoke_oracle_i thy "Pure.normalization" (thy, Normalization (funcgr, t));
 
 in
 
@@ -188,10 +179,10 @@
 fun normalization_conv ct =
   let
     val thy = Thm.theory_of_cterm ct;
-    fun mk drop_classes ct thm1 =
+    fun mk funcgr drop_classes ct thm1 =
       let
         val t = Thm.term_of ct;
-        val thm2 = normalization_invoke thy t;
+        val thm2 = normalization_invoke thy funcgr t;
         val thm3 = apply_posts thy ((snd o Drule.dest_equals o Thm.cprop_of) thm2);
         val thm23 = drop_classes (Thm.transitive thm2 thm3);
       in
@@ -199,7 +190,7 @@
           error ("normalization_conv - could not construct proof:\n"
           ^ (cat_lines o map string_of_thm) [thm1, thm2, thm3])
       end;
-  in fst (CodegenFuncgr.make_term thy mk ct) end;
+  in fst (Funcgr.make_term thy mk ct) end;
 
 fun norm_print_term ctxt modes t =
   let