added pre/post-processor equations
authornipkow
Mon, 09 Oct 2006 12:16:29 +0200
changeset 20920 07f279940664
parent 20919 dab803075c62
child 20921 24b8536dcf93
added pre/post-processor equations
src/Pure/Tools/nbe.ML
--- a/src/Pure/Tools/nbe.ML	Mon Oct 09 12:08:33 2006 +0200
+++ b/src/Pure/Tools/nbe.ML	Mon Oct 09 12:16:29 2006 +0200
@@ -24,6 +24,55 @@
 
 (* theory data setup *)
 
+structure NBE_Rewrite = TheoryDataFun
+(struct
+  val name = "Pure/nbe";
+  type T = thm list * thm list
+
+  val empty = ([],[])
+  val copy = I;
+  val extend = I;
+
+  fun merge _ ((pres1,posts1), (pres2,posts2)) =
+    (Library.merge eq_thm (pres1,pres2), Library.merge eq_thm (posts1,posts2))
+
+  fun print _ _ = ()
+end);
+
+val _ = Context.add_setup NBE_Rewrite.init;
+
+fun consts_of_pres thy = 
+  let val pres = fst(NBE_Rewrite.get thy);
+      val rhss = map (snd o Logic.dest_equals o prop_of) pres;
+  in (fold o fold_aterms)
+        (fn Const c => insert (op =) (CodegenConsts.norm_of_typ thy c) | _ => I)
+        rhss []
+  end
+
+
+local
+
+fun attr_pre (thy,thm) =
+ ((Context.map_theory o NBE_Rewrite.map o apfst) (insert eq_thm thm) thy, thm)
+fun attr_post (thy,thm) = 
+ ((Context.map_theory o NBE_Rewrite.map o apsnd) (insert eq_thm thm) thy, thm)
+
+in
+val _ = Context.add_setup
+  (Attrib.add_attributes
+     [("normal_pre", Attrib.no_args attr_pre, "declare pre-theorems for normalization"),
+      ("normal_post", Attrib.no_args attr_post, "declare posy-theorems for normalization")]);
+end;
+
+fun apply_pres thy =
+  let val pres = fst(NBE_Rewrite.get thy)
+  in map (CodegenData.rewrite_func pres) end
+
+fun apply_posts thy =
+  let val posts = snd(NBE_Rewrite.get thy)
+  in Tactic.rewrite false posts end
+
+
 structure NBE_Data = CodeDataFun
 (struct
   val name = "Pure/NBE"
@@ -80,14 +129,16 @@
     val (_, ct) = CodegenData.preprocess_cterm thy const_typ (Thm.cterm_of thy t)
     val t' = Thm.term_of ct;
     val (consts, cs) = CodegenConsts.consts_of thy t';
-    val funcgr = CodegenFuncgr.mk_funcgr thy consts cs;
+    val pre_consts = consts_of_pres thy;
+    val consts' = pre_consts @ consts;
+    val funcgr = CodegenFuncgr.mk_funcgr thy consts' cs;
     val nbe_tab = NBE_Data.get thy;
     val all_consts =
-      CodegenFuncgr.all_deps_of funcgr consts
+      (pre_consts :: CodegenFuncgr.all_deps_of funcgr consts')
       |> (map o filter_out) (Symtab.defined nbe_tab o CodegenNames.const thy)
       |> filter_out null;
     val funs = (map o map)
-      (fn c => (CodegenNames.const thy c, CodegenFuncgr.get_funcs funcgr c)) all_consts;
+      (fn c => (CodegenNames.const thy c, apply_pres thy (CodegenFuncgr.get_funcs funcgr c))) all_consts;
     val _ = tracing (fn () => "new definitions: " ^ (commas o maps (map fst)) funs);
     val _ = generate thy funs;
     val nt = NBE_Eval.eval thy (!tab) t';
@@ -115,15 +166,17 @@
     fun constrain ty t = Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy)
         (K NONE) (K NONE) Name.context false ([t], ty) |> fst;
     val _ = tracing (fn () => "Normalized:\n" ^ NBE_Eval.string_of_nterm nt);
-    val t' = NBE_Codegen.nterm_to_term thy nt;
-    val _ = tracing (fn () =>"Converted back:\n" ^ Display.raw_string_of_term t');
-    val t'' = anno_vars vtab t';
-    val _ = tracing (fn () =>"Vars typed:\n" ^ Display.raw_string_of_term t'');
-    val t''' = constrain ty t'';
-    val _ = if null (Term.term_tvars t''') then () else
+    val t1 = NBE_Codegen.nterm_to_term thy nt;
+    val _ = tracing (fn () =>"Converted back:\n" ^ Display.raw_string_of_term t1);
+    val t2 = anno_vars vtab t1;
+    val _ = tracing (fn () =>"Vars typed:\n" ^ Display.raw_string_of_term t2);
+    val t3 = constrain ty t2;
+    val _ = if null (Term.term_tvars t3) then () else
       error ("Illegal schematic type variables in normalized term: "
-        ^ setmp show_types true (Sign.string_of_term thy) t''');
-  in t''' end;
+        ^ setmp show_types true (Sign.string_of_term thy) t3);
+    val eq = apply_posts thy (Thm.cterm_of thy t3);
+    val t4 = snd(Logic.dest_equals(prop_of eq))
+  in t4 end;
 
 fun norm_print_term ctxt modes t =
   let