moved a lot to codegen_func.ML
authorhaftmann
Tue, 09 Jan 2007 08:31:48 +0100
changeset 22033 8e19bad4125f
parent 22032 979671292fbe
child 22034 44ab6c04b3dc
moved a lot to codegen_func.ML
src/Pure/Tools/codegen_data.ML
src/Pure/Tools/codegen_func.ML
src/Pure/Tools/nbe.ML
--- a/src/Pure/Tools/codegen_data.ML	Tue Jan 09 08:31:47 2007 +0100
+++ b/src/Pure/Tools/codegen_data.ML	Tue Jan 09 08:31:48 2007 +0100
@@ -29,10 +29,8 @@
 
   val print_thms: theory -> unit
 
-  val typ_func: theory -> thm -> typ
   val typ_funcs: theory -> CodegenConsts.const * thm list -> typ
-  val rewrite_func: thm list -> thm -> thm
-  val preprocess_cterm: theory -> cterm -> thm
+  val preprocess_cterm: cterm -> thm
 
   val trace: bool ref
 end;
@@ -107,144 +105,6 @@
 
 (** code theorems **)
 
-(* making rewrite theorems *)
-
-fun bad_thm msg thm =
-  error (msg ^ ": " ^ string_of_thm thm);
-
-fun check_rew thy thm =
-  let
-    val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm;
-    fun vars_of t = fold_aterms
-     (fn Var (v, _) => insert (op =) v
-       | Free _ => bad_thm "Illegal free variable in rewrite theorem" thm
-       | _ => I) t [];
-    fun tvars_of t = fold_term_types
-     (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v
-                          | TFree _ => bad_thm "Illegal free type variable in rewrite theorem" thm)) t [];
-    val lhs_vs = vars_of lhs;
-    val rhs_vs = vars_of rhs;
-    val lhs_tvs = tvars_of lhs;
-    val rhs_tvs = tvars_of lhs;
-    val _ = if null (subtract (op =) lhs_vs rhs_vs)
-      then ()
-      else bad_thm "Free variables on right hand side of rewrite theorems" thm
-    val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
-      then ()
-      else bad_thm "Free type variables on right hand side of rewrite theorems" thm
-  in thm end;
-
-fun mk_rew thy thm =
-  let
-    val thms = (#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy thm;
-  in
-    map (check_rew thy) thms
-  end;
-
-
-(* making function theorems *)
-
-fun typ_func thy = snd o dest_Const o fst o strip_comb
-  o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of;
-
-val strict_functyp = ref true;
-
-fun dest_func thy = apfst dest_Const o strip_comb
-  o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of
-  o Drule.fconv_rule Drule.beta_eta_conversion;
-
-fun mk_head thy thm =
-  ((CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm, thm);
-
-fun check_func thy thm = case try (dest_func thy) thm
- of SOME (c_ty as (c, ty), args) =>
-      let
-        val _ =
-          if has_duplicates (op =)
-            ((fold o fold_aterms) (fn Var (v, _) => cons v
-              | _ => I
-            ) args [])
-          then bad_thm "Repeated variables on left hand side of function equation" thm
-          else ()
-        fun no_abs (Abs _) = bad_thm "Abstraction on left hand side of function equation" thm 
-          | no_abs (t1 $ t2) = (no_abs t1; no_abs t2)
-          | no_abs _ = ();
-        val _ = map no_abs args;
-        val is_classop = (is_some o AxClass.class_of_param thy) c;
-        val const = CodegenConsts.norm_of_typ thy c_ty;
-        val ty_decl = CodegenConsts.disc_typ_of_const thy
-          (snd o CodegenConsts.typ_of_inst thy) const;
-        val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
-      in if Sign.typ_equiv thy (ty_decl, ty)
-        then SOME (const, thm)
-        else (if is_classop
-            then if !strict_functyp
-              then error
-              else warning #> K NONE
-          else if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
-            then warning #> (K o SOME) (const, thm)
-          else if !strict_functyp
-            then error
-          else warning #> K NONE)
-          ("Type\n" ^ string_of_typ ty
-           ^ "\nof function theorem\n"
-           ^ string_of_thm thm
-           ^ "\nis strictly less general than declared function type\n"
-           ^ string_of_typ ty_decl)
-      end
-  | NONE => bad_thm "Not a function equation" thm;
-
-fun check_typ_classop thy thm =
-  let
-    val (c_ty as (c, ty), _) = dest_func thy thm;  
-  in case AxClass.class_of_param thy c
-   of SOME class => let
-        val const = CodegenConsts.norm_of_typ thy c_ty;
-        val ty_decl = CodegenConsts.disc_typ_of_const thy
-            (snd o CodegenConsts.typ_of_inst thy) const;
-        val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
-      in if Sign.typ_equiv thy (ty_decl, ty)
-        then thm
-        else error
-          ("Type\n" ^ string_of_typ ty
-           ^ "\nof function theorem\n"
-           ^ string_of_thm thm
-           ^ "\nis strictly less general than declared function type\n"
-           ^ string_of_typ ty_decl)
-      end
-    | NONE => thm
-  end;
-
-fun mk_func thy raw_thm =
-  mk_rew thy raw_thm
-  |> map_filter (check_func thy);
-
-fun get_prim_def_funcs thy c =
-  let
-    fun constrain thm0 thm = case AxClass.class_of_param thy (fst c)
-     of SOME _ =>
-          let
-            val ty_decl = CodegenConsts.disc_typ_of_classop thy c;
-            val max = maxidx_of_typ ty_decl + 1;
-            val thm = Thm.incr_indexes max thm;
-            val ty = typ_func thy thm;
-            val (env, _) = Sign.typ_unify thy (ty_decl, ty) (Vartab.empty, max);
-            val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
-              cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
-          in Thm.instantiate (instT, []) thm end
-      | NONE => thm
-  in case CodegenConsts.find_def thy c
-   of SOME ((_, thm), _) =>
-        thm
-        |> Thm.transfer thy
-        |> try (map snd o mk_func thy)
-        |> these
-        |> map (constrain thm)
-        |> map (CodegenFunc.expand_eta thy ~1)
-    | NONE => []
-  end;
-
-
 (* pairs of (selected, deleted) function theorems *)
 
 type sdthms = thm list Susp.T * thm list;
@@ -529,31 +389,18 @@
 
 (** theorem transformation and certification **)
 
-fun rewrite_func rewrites thm =
-  let
-    val rewrite = MetaSimplifier.rewrite false rewrites;
-    val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o Thm.cprop_of) thm;
-    val Const ("==", _) = Thm.term_of ct_eq;
-    val (ct_f, ct_args) = Drule.strip_comb ct_lhs;
-    val rhs' = rewrite ct_rhs;
-    val args' = map rewrite ct_args;
-    val lhs' = Thm.symmetric (fold (fn th1 => fn th2 => Thm.combination th2 th1)
-      args' (Thm.reflexive ct_f));
-  in
-    Thm.transitive (Thm.transitive lhs' thm) rhs'
-  end handle Bind => raise ERROR "rewrite_func"
-
-fun common_typ_funcs thy [] = []
-  | common_typ_funcs thy [thm] = [thm]
-  | common_typ_funcs thy thms =
+fun common_typ_funcs [] = []
+  | common_typ_funcs [thm] = [thm]
+  | common_typ_funcs thms =
       let
+        val thy = Thm.theory_of_thm (hd thms)
         fun incr_thm thm max =
           let
             val thm' = incr_indexes max thm;
             val max' = Thm.maxidx_of thm' + 1;
           in (thm', max') end;
         val (thms', maxidx) = fold_map incr_thm thms 0;
-        val (ty1::tys) = map (typ_func thy) thms';
+        val (ty1::tys) = map CodegenFunc.typ_func thms';
         fun unify ty env = Sign.typ_unify thy (ty1, ty) env
           handle Type.TUNIFY =>
             error ("Type unificaton failed, while unifying function equations\n"
@@ -568,8 +415,8 @@
 fun certify_const thy c c_thms =
   let
     fun cert (c', thm) = if CodegenConsts.eq_const (c, c')
-      then thm else bad_thm ("Wrong head of function equation,\nexpected constant "
-        ^ CodegenConsts.string_of_const thy c) thm
+      then thm else error ("Wrong head of function equation,\nexpected constant "
+        ^ CodegenConsts.string_of_const thy c ^ "\n" ^ string_of_thm thm)
   in map cert c_thms end;
 
 fun mk_cos tyco vs cos =
@@ -647,9 +494,9 @@
 
 (** interfaces **)
 
-fun add_func thm thy =
+fun gen_add_func mk_func thm thy =
   let
-    val thms = mk_func thy thm;
+    val thms = mk_func thm;
     val cs = map fst thms;
   in
     map_exec_purge (SOME cs) (map_funcs 
@@ -657,11 +504,12 @@
        (c, (Susp.value [], [])) (add_thm thm)) thms)) thy
   end;
 
-fun add_func_legacy thm = setmp strict_functyp false (add_func thm);
+val add_func = gen_add_func CodegenFunc.mk_func;
+val add_func_legacy = gen_add_func CodegenFunc.legacy_mk_func;
 
 fun del_func thm thy =
   let
-    val thms = mk_func thy thm;
+    val thms = CodegenFunc.mk_func thm;
     val cs = map fst thms;
   in
     map_exec_purge (SOME cs) (map_funcs
@@ -672,7 +520,7 @@
 fun add_funcl (c, lthms) thy =
   let
     val c' = CodegenConsts.norm thy c;
-    val lthms' = certificate thy (fn thy => certify_const thy c' o maps (mk_func thy)) lthms;
+    val lthms' = certificate thy (fn thy => certify_const thy c' o maps (CodegenFunc.mk_func)) lthms;
   in
     map_exec_purge (SOME [c]) (map_funcs (Consttab.map_default (c', (Susp.value [], []))
       (add_lthms lthms'))) thy
@@ -699,10 +547,10 @@
   in map_exec_purge (SOME consts) del thy end;
 
 fun add_inline thm thy =
-  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (insert eq_thm) (mk_rew thy thm)) thy;
+  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (insert eq_thm) (CodegenFunc.mk_rew thm)) thy;
 
 fun del_inline thm thy =
-  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (remove eq_thm) (mk_rew thy thm)) thy ;
+  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (remove eq_thm) (CodegenFunc.mk_rew thm)) thy ;
 
 fun add_inline_proc f =
   (map_exec_purge NONE o map_preproc o apfst o apsnd) (cons (serial (), f));
@@ -715,12 +563,12 @@
 fun gen_apply_inline_proc prep post thy f x =
   let
     val cts = prep x;
-    val rews = map (check_rew thy) (f thy cts);
+    val rews = map CodegenFunc.check_rew (f thy cts);
   in post rews x end;
 
 val apply_inline_proc = gen_apply_inline_proc (maps
   ((fn [args, rhs] => rhs :: (snd o Drule.strip_comb) args) o snd o Drule.strip_comb o Thm.cprop_of))
-  (fn rews => map (rewrite_func rews));
+  (fn rews => map (CodegenFunc.rewrite_func rews));
 val apply_inline_proc_cterm = gen_apply_inline_proc single
   (MetaSimplifier.rewrite false);
 
@@ -728,11 +576,11 @@
   | apply_preproc thy f (thms as (thm :: _)) =
       let
         val thms' = f thy thms;
-        val c = (CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm;
-      in (certify_const thy c o map (mk_head thy)) thms' end;
+        val c = (CodegenConsts.norm_of_typ thy o fst o CodegenFunc.dest_func) thm;
+      in (certify_const thy c o map CodegenFunc.mk_head) thms' end;
 
 fun cmp_thms thy =
-  make_ord (fn (thm1, thm2) => not (Sign.typ_instance thy (typ_func thy thm1, typ_func thy thm2)));
+  make_ord (fn (thm1, thm2) => not (Sign.typ_instance thy (CodegenFunc.typ_func thm1, CodegenFunc.typ_func thm2)));
 
 fun rhs_conv conv thm =
   let
@@ -744,19 +592,23 @@
 fun preprocess thy thms =
   thms
   |> fold (fn (_, f) => apply_preproc thy f) ((#preprocs o the_preproc o get_exec) thy)
-  |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy))
+  |> map (CodegenFunc.rewrite_func ((#inlines o the_preproc o get_exec) thy))
   |> fold (fn (_, f) => apply_inline_proc thy f) ((#inline_procs o the_preproc o get_exec) thy)
 (*FIXME - must check: rewrite rule, function equation, proper constant |> map (snd o check_func false thy) *)
   |> sort (cmp_thms thy)
-  |> common_typ_funcs thy;
+  |> common_typ_funcs;
 
-fun preprocess_cterm thy ct =
-  ct
-  |> Thm.reflexive
-  |> fold (rhs_conv o MetaSimplifier.rewrite false o single)
-    ((#inlines o the_preproc o get_exec) thy)
-  |> fold (fn (_, f) => rhs_conv (apply_inline_proc_cterm thy f))
-    ((#inline_procs o the_preproc o get_exec) thy)
+fun preprocess_cterm ct =
+  let
+    val thy = Thm.theory_of_cterm ct
+  in
+    ct
+    |> Thm.reflexive
+    |> fold (rhs_conv o MetaSimplifier.rewrite false o single)
+      ((#inlines o the_preproc o get_exec) thy)
+    |> fold (fn (_, f) => rhs_conv (apply_inline_proc_cterm thy f))
+      ((#inline_procs o the_preproc o get_exec) thy)
+  end;
 
 end; (*local*)
 
@@ -768,7 +620,7 @@
       |> these
       |> map (Thm.transfer thy);
     val funcs_2 = case funcs_1
-     of [] => get_prim_def_funcs thy c
+     of [] => CodegenFunc.get_prim_def_funcs thy c
       | xs => xs;
     fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
       o ObjectLogic.drop_judgment thy o Drule.plain_prop_of);
@@ -789,9 +641,9 @@
 fun typ_funcs thy (c as (name, _), []) = (case AxClass.class_of_param thy name
      of SOME class => CodegenConsts.disc_typ_of_classop thy c
       | NONE => (case Option.map (Susp.force o fst) (Consttab.lookup ((the_funcs o get_exec) thy) c)
-         of SOME [eq] => typ_func thy eq
+         of SOME [eq] => CodegenFunc.typ_func eq
           | _ => Sign.the_const_type thy name))
-  | typ_funcs thy (_, eq :: _) = typ_func thy eq;
+  | typ_funcs thy (_, eq :: _) = CodegenFunc.typ_func eq;
 
 
 (** code attributes **)
--- a/src/Pure/Tools/codegen_func.ML	Tue Jan 09 08:31:47 2007 +0100
+++ b/src/Pure/Tools/codegen_func.ML	Tue Jan 09 08:31:48 2007 +0100
@@ -2,43 +2,152 @@
     ID:         $Id$
     Author:     Florian Haftmann, TU Muenchen
 
-Handling defining equations ("func"s) for code generator framework
+Handling defining equations ("func"s) for code generator framework.
 *)
 
-(* FIXME move various stuff here *)
-
 signature CODEGEN_FUNC =
 sig
-  val expand_eta: theory -> int -> thm -> thm
+  val check_rew: thm -> thm
+  val mk_rew: thm -> thm list
+  val check_func: thm -> (CodegenConsts.const * thm) option
+  val mk_func: thm -> (CodegenConsts.const * thm) list
+  val dest_func: thm -> (string * typ) * term list
+  val mk_head: thm -> CodegenConsts.const * thm
+  val typ_func: thm -> typ
+  val legacy_mk_func: thm -> (CodegenConsts.const * thm) list
+  val expand_eta: int -> thm -> thm
+  val rewrite_func: thm list -> thm -> thm
+  val get_prim_def_funcs: theory -> string * typ list -> thm list
 end;
 
 structure CodegenFunc : CODEGEN_FUNC =
 struct
 
-(* FIXME get rid of this code duplication *)
-val purify_name =
+fun lift_thm_thy f thm = f (Thm.theory_of_thm thm) thm;
+
+fun bad_thm msg thm =
+  error (msg ^ ": " ^ string_of_thm thm);
+
+
+(* making rewrite theorems *)
+
+fun check_rew thm =
+  let
+    val thy = Thm.theory_of_thm thm;
+    val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm;
+    fun vars_of t = fold_aterms
+     (fn Var (v, _) => insert (op =) v
+       | Free _ => bad_thm "Illegal free variable in rewrite theorem" thm
+       | _ => I) t [];
+    fun tvars_of t = fold_term_types
+     (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v
+                          | TFree _ => bad_thm "Illegal free type variable in rewrite theorem" thm)) t [];
+    val lhs_vs = vars_of lhs;
+    val rhs_vs = vars_of rhs;
+    val lhs_tvs = tvars_of lhs;
+    val rhs_tvs = tvars_of lhs;
+    val _ = if null (subtract (op =) lhs_vs rhs_vs)
+      then ()
+      else bad_thm "Free variables on right hand side of rewrite theorems" thm
+    val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
+      then ()
+      else bad_thm "Free type variables on right hand side of rewrite theorems" thm
+  in thm end;
+
+fun mk_rew thm =
   let
-    fun is_valid s = Symbol.is_ascii_letter s orelse Symbol.is_ascii_digit s orelse s = "'";
-    val is_junk = not o is_valid andf Symbol.not_eof;
-    val junk = Scan.many is_junk;
-    val scan_valids = Symbol.scanner "Malformed input"
-      ((junk |--
-        (Scan.optional (Scan.one Symbol.is_ascii_letter) "x" ^^ (Scan.many is_valid >> implode)
-        --| junk))
-      -- Scan.repeat ((Scan.many1 is_valid >> implode) --| junk) >> op ::);
-  in explode #> scan_valids #> space_implode "_" end;
+    val thy = Thm.theory_of_thm thm;
+    val thms = (#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy thm;
+  in
+    map check_rew thms
+  end;
+
+
+(* making function theorems *)
+
+val typ_func = lift_thm_thy (fn thy => snd o dest_Const o fst o strip_comb
+  o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of);
+
+val dest_func = lift_thm_thy (fn thy => apfst dest_Const o strip_comb
+  o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of
+  o Drule.fconv_rule Drule.beta_eta_conversion);
+
+val mk_head = lift_thm_thy (fn thy => fn thm =>
+  ((CodegenConsts.norm_of_typ thy o fst o dest_func) thm, thm));
 
-val purify_lower =
-  explode
-  #> (fn cs => (if forall Symbol.is_ascii_upper cs
-        then map else nth_map 0) Symbol.to_ascii_lower cs)
-  #> implode;
+fun gen_check_func strict_functyp thm = case try dest_func thm
+ of SOME (c_ty as (c, ty), args) =>
+      let
+        val thy = Thm.theory_of_thm thm;
+        val _ =
+          if has_duplicates (op =)
+            ((fold o fold_aterms) (fn Var (v, _) => cons v
+              | _ => I
+            ) args [])
+          then bad_thm "Repeated variables on left hand side of function equation" thm
+          else ()
+        fun no_abs (Abs _) = bad_thm "Abstraction on left hand side of function equation" thm 
+          | no_abs (t1 $ t2) = (no_abs t1; no_abs t2)
+          | no_abs _ = ();
+        val _ = map no_abs args;
+        val is_classop = (is_some o AxClass.class_of_param thy) c;
+        val const = CodegenConsts.norm_of_typ thy c_ty;
+        val ty_decl = CodegenConsts.disc_typ_of_const thy
+          (snd o CodegenConsts.typ_of_inst thy) const;
+        val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
+        val error_warning = if strict_functyp
+          then error
+          else warning #> K NONE
+      in if Sign.typ_equiv thy (ty_decl, ty)
+        then SOME (const, thm)
+        else (if is_classop
+            then error_warning
+          else if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
+            then warning #> (K o SOME) (const, thm)
+          else error_warning)
+          ("Type\n" ^ string_of_typ ty
+           ^ "\nof function theorem\n"
+           ^ string_of_thm thm
+           ^ "\nis strictly less general than declared function type\n"
+           ^ string_of_typ ty_decl)
+      end
+  | NONE => bad_thm "Not a function equation" thm;
 
-fun purify_var "" = "x"
-  | purify_var v = (purify_name #> purify_lower) v;
+val check_func = gen_check_func true;
+val legacy_check_func = gen_check_func false;
 
-fun expand_eta thy k thm =
+fun check_typ_classop thm =
   let
+    val thy = Thm.theory_of_thm thm;
+    val (c_ty as (c, ty), _) = dest_func thm;
+  in case AxClass.class_of_param thy c
+   of SOME class => let
+        val const = CodegenConsts.norm_of_typ thy c_ty;
+        val ty_decl = CodegenConsts.disc_typ_of_const thy
+            (snd o CodegenConsts.typ_of_inst thy) const;
+        val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
+      in if Sign.typ_equiv thy (ty_decl, ty)
+        then thm
+        else error
+          ("Type\n" ^ string_of_typ ty
+           ^ "\nof function theorem\n"
+           ^ string_of_thm thm
+           ^ "\nis strictly less general than declared function type\n"
+           ^ string_of_typ ty_decl)
+      end
+    | NONE => thm
+  end;
+
+fun gen_mk_func check_func = map_filter check_func o mk_rew;
+val mk_func = gen_mk_func check_func;
+val legacy_mk_func = gen_mk_func legacy_check_func;
+
+
+(* utilities *)
+
+fun expand_eta k thm =
+  let
+    val thy = Thm.theory_of_thm thm;
     val (lhs, rhs) = (Logic.dest_equals o Drule.plain_prop_of) thm;
     val (head, args) = strip_comb lhs;
     val l = if k = ~1
@@ -48,7 +157,7 @@
     fun get_name _ 0 used = ([], used)
       | get_name (Abs (v, ty, t)) k used =
           used
-          |> Name.variants [purify_var v]
+          |> Name.variants [v]
           ||>> get_name t (k - 1)
           |>> (fn ([v'], vs') => (v', ty) :: vs')
       | get_name t k used = 
@@ -68,4 +177,43 @@
     fold (fn refl => fn thm => Thm.combination thm refl) vs_refl thm
   end;
 
+fun get_prim_def_funcs thy c =
+  let
+    fun constrain thm0 thm = case AxClass.class_of_param thy (fst c)
+     of SOME _ =>
+          let
+            val ty_decl = CodegenConsts.disc_typ_of_classop thy c;
+            val max = maxidx_of_typ ty_decl + 1;
+            val thm = Thm.incr_indexes max thm;
+            val ty = typ_func thm;
+            val (env, _) = Sign.typ_unify thy (ty_decl, ty) (Vartab.empty, max);
+            val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
+              cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
+          in Thm.instantiate (instT, []) thm end
+      | NONE => thm
+  in case CodegenConsts.find_def thy c
+   of SOME ((_, thm), _) =>
+        thm
+        |> Thm.transfer thy
+        |> try (map snd o mk_func)
+        |> these
+        |> map (constrain thm)
+        |> map (expand_eta ~1)
+    | NONE => []
+  end;
+
+fun rewrite_func rewrites thm =
+  let
+    val rewrite = MetaSimplifier.rewrite false rewrites;
+    val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o Thm.cprop_of) thm;
+    val Const ("==", _) = Thm.term_of ct_eq;
+    val (ct_f, ct_args) = Drule.strip_comb ct_lhs;
+    val rhs' = rewrite ct_rhs;
+    val args' = map rewrite ct_args;
+    val lhs' = Thm.symmetric (fold (fn th1 => fn th2 => Thm.combination th2 th1)
+      args' (Thm.reflexive ct_f));
+  in
+    Thm.transitive (Thm.transitive lhs' thm) rhs'
+  end handle Bind => raise ERROR "rewrite_func"
+
 end;
--- a/src/Pure/Tools/nbe.ML	Tue Jan 09 08:31:47 2007 +0100
+++ b/src/Pure/Tools/nbe.ML	Tue Jan 09 08:31:48 2007 +0100
@@ -72,7 +72,7 @@
   let
     val ctxt = ProofContext.init thy;
     val pres = (map (LocalDefs.meta_rewrite_rule ctxt) o fst) (NBE_Rewrite.get thy)
-  in map (CodegenData.rewrite_func pres) end
+  in map (CodegenFunc.rewrite_func pres) end
 
 fun apply_posts thy =
   let