src/Pure/Tools/codegen_theorems.ML
changeset 19341 3414c04fbc39
parent 19280 5091dc43817b
child 19436 3f5835aac3ce
--- a/src/Pure/Tools/codegen_theorems.ML	Thu Apr 06 16:08:22 2006 +0200
+++ b/src/Pure/Tools/codegen_theorems.ML	Thu Apr 06 16:08:25 2006 +0200
@@ -1,4 +1,4 @@
-(*  Title:      Pure/Tools/CODEGEN_THEOREMS.ML
+(*  Title:      Pure/Tools/codegen_theorems.ML
     ID:         $Id$
     Author:     Florian Haftmann, TU Muenchen
 
@@ -9,11 +9,31 @@
 sig
   val add_notify: (string option -> theory -> theory) -> theory -> theory;
   val add_preproc: (theory -> thm list -> thm list) -> theory -> theory;
-  val add_funn: thm -> theory -> theory;
+  val add_fun_extr: (theory -> string * typ -> thm list) -> theory -> theory;
+  val add_datatype_extr: (theory -> string
+     -> (((string * sort) list * (string * typ list) list) * tactic) option)
+    -> theory -> theory;
+  val add_fun: thm -> theory -> theory;
   val add_pred: thm -> theory -> theory;
   val add_unfold: thm -> theory -> theory;
-  val preprocess: theory -> thm list -> thm list;
+  val del_def: thm -> theory -> theory;
+  val del_unfold: thm -> theory -> theory;
+  val purge_defs: string * typ -> theory -> theory;
+
+  val common_typ: theory -> (thm -> typ) -> thm list -> thm list;
+  val preprocess: theory -> (thm -> typ) option -> thm list -> thm list;
+  val preprocess_fun: theory -> thm list -> (typ * thm list) option;
   val preprocess_term: theory -> term -> term;
+  val get_funs: theory -> string * typ -> (typ * thm list) option;
+  val get_datatypes: theory -> string
+    -> (((string * sort) list * (string * typ list) list) * thm list) option;
+
+  val debug: bool ref;
+  val debug_msg: ('a -> string) -> 'a -> 'a;
+
+  val print_thms: theory -> unit;
+  val init_obj: theory -> string -> string * (thm list -> tactic) -> string * (thm list -> tactic)
+    -> string * (thm list -> tactic) -> string * (thm list -> tactic) -> unit;
 end;
 
 structure CodegenTheorems: CODEGEN_THEOREMS =
@@ -21,16 +41,126 @@
 
 (** auxiliary **)
 
-fun dest_funn thm =
-  case try (fst o dest_Const o fst o strip_comb o fst o Logic.dest_equals o prop_of) thm
-   of SOME c => SOME (c, thm)
+val debug = ref false;
+fun debug_msg f x = (if !debug then Output.debug (f x) else (); x);
+
+
+(** object logic **)
+
+val obj_bool_ref : string option ref = ref NONE;
+val obj_true_ref : string option ref = ref NONE;
+val obj_false_ref : string option ref = ref NONE;
+val obj_and_ref : string option ref = ref NONE;
+val obj_eq_ref : string option ref = ref NONE;
+val obj_eq_elim_ref : thm option ref = ref NONE;
+fun idem c = (the o !) c;
+
+fun mk_tf sel =
+  let
+    val bool_typ = Type (idem obj_bool_ref, []);
+    val name = idem
+      (if sel then obj_true_ref else obj_false_ref);
+  in
+    Const (name, bool_typ)
+  end handle Option => error "no object logic setup for code theorems";
+
+fun mk_obj_conj (x, y) =
+  let
+    val bool_typ = Type (idem obj_bool_ref, []);
+  in
+    Const (idem obj_and_ref, bool_typ --> bool_typ --> bool_typ) $ x $ y
+  end handle Option => error "no object logic setup for code theorems";
+
+fun mk_obj_eq (x, y) =
+  let
+    val bool_typ = Type (idem obj_bool_ref, []);
+  in
+    Const (idem obj_eq_ref, type_of x --> type_of y --> bool_typ) $ x $ y
+  end handle Option => error "no object logic setup for code theorems";
+
+fun is_obj_eq c =
+  c = idem obj_eq_ref
+    handle Option => error "no object logic setup for code theorems";
+
+fun mk_bool_eq ((x, y), rhs) =
+  let
+    val bool_typ = Type (idem obj_bool_ref, []);
+  in
+    Logic.mk_equals (
+      (mk_obj_eq (x, y)),
+      rhs
+    )
+  end handle Option => error "no object logic setup for code theorems";
+
+fun elim_obj_eq thm = rewrite_rule [idem obj_eq_elim_ref] thm
+  handle Option => error "no object logic setup for code theorems";
+
+fun init_obj thy bohl (truh, truh_tac) (fals, fals_tac) (ant, ant_tac) (eq, eq_tac) =
+  let
+    val _ = if (is_some o !) obj_bool_ref
+      then error "already set" else ()
+    val bool_typ = Type (bohl, []);
+    val free_typ  = TFree ("'a", Sign.defaultS thy);
+    val var_x = Free ("x", free_typ);
+    val var_y = Free ("y", free_typ);
+    val prop_P = Free ("P", bool_typ);
+    val prop_Q = Free ("Q", bool_typ);
+    val _ = Goal.prove thy [] []
+      (ObjectLogic.ensure_propT thy (Const (truh, bool_typ))) truh_tac;
+    val _ = Goal.prove thy ["P"] [ObjectLogic.ensure_propT thy (Const (fals, bool_typ))]
+      (ObjectLogic.ensure_propT thy prop_P) fals_tac;
+    val _ = Goal.prove thy ["P", "Q"] [ObjectLogic.ensure_propT thy prop_P, ObjectLogic.ensure_propT thy prop_Q]
+      (ObjectLogic.ensure_propT thy (Const (ant, bool_typ --> bool_typ --> bool_typ) $ prop_P $ prop_Q)) ant_tac;
+    val atomize_eq = Goal.prove thy ["x", "y"] []
+      (Logic.mk_equals (
+        Logic.mk_equals (var_x, var_y),
+        ObjectLogic.ensure_propT thy
+          (Const (eq, free_typ --> free_typ --> bool_typ) $ var_x $ var_y))) eq_tac;
+  in
+    obj_bool_ref := SOME bohl;
+    obj_true_ref := SOME truh;
+    obj_false_ref := SOME fals;
+    obj_and_ref := SOME ant;
+    obj_eq_ref := SOME eq;
+    obj_eq_elim_ref := SOME (Thm.symmetric atomize_eq)
+  end;
+
+
+(** auxiliary **)
+
+fun destr_fun thy thm =
+  case try (
+    prop_of
+    #> ObjectLogic.drop_judgment thy
+    #> Logic.dest_equals
+    #> fst
+    #> strip_comb
+    #> fst
+    #> dest_Const
+  ) (elim_obj_eq thm)
+   of SOME c_ty => SOME (c_ty, thm)
     | NONE => NONE;
 
+fun dest_fun thy thm =
+  case destr_fun thy thm
+   of SOME x => x
+    | NONE => error ("not a function equation: " ^ string_of_thm thm);
+
 fun dest_pred thm =
   case try (fst o dest_Const o fst o strip_comb o snd o Logic.dest_implies o prop_of) thm
    of SOME c => SOME (c, thm)
     | NONE => NONE;
 
+fun getf_first [] _ = NONE
+  | getf_first (f::fs) x = case f x
+     of NONE => getf_first fs x
+      | y as SOME x => y;
+
+fun getf_first_list [] x = []
+  | getf_first_list (f::fs) x = case f x
+     of [] => getf_first_list fs x
+      | xs => xs;
+      
 
 (** theory data **)
 
@@ -46,49 +176,113 @@
     mk_procs (AList.merge (op =) (K true) (preprocs1, preprocs2),
       AList.merge (op =) (K true) (notify1, notify2));
 
+datatype extrs = Extrs of {
+  funs: (serial * (theory -> string * typ -> thm list)) list,
+  datatypes: (serial * (theory -> string -> (((string * sort) list * (string * typ list) list) * tactic) option)) list
+};
+
+fun mk_extrs (funs, datatypes) = Extrs { funs = funs, datatypes = datatypes };
+fun map_extrs f (Extrs { funs, datatypes }) = mk_extrs (f (funs, datatypes));
+fun merge_extrs _ (Extrs { funs = funs1, datatypes = datatypes1 },
+  Extrs { funs = funs2, datatypes = datatypes2 }) =
+    mk_extrs (AList.merge (op =) (K true) (funs1, funs2),
+      AList.merge (op =) (K true) (datatypes1, datatypes2));
+
 datatype codethms = Codethms of {
-  funns: thm list Symtab.table,
+  funs: thm list Symtab.table,
   preds: thm list Symtab.table,
   unfolds: thm list
 };
 
-fun mk_codethms ((funns, preds), unfolds) =
-  Codethms { funns = funns, preds = preds, unfolds = unfolds };
-fun map_codethms f (Codethms { funns, preds, unfolds }) =
-  mk_codethms (f ((funns, preds), unfolds));
-fun merge_codethms _ (Codethms { funns = funns1, preds = preds1, unfolds = unfolds1 },
-  Codethms { funns = funns2, preds = preds2, unfolds = unfolds2 }) =
-    mk_codethms ((Symtab.join (K (uncurry (fold (insert eq_thm)))) (funns1, funns2),
+fun mk_codethms ((funs, preds), unfolds) =
+  Codethms { funs = funs, preds = preds, unfolds = unfolds };
+fun map_codethms f (Codethms { funs, preds, unfolds }) =
+  mk_codethms (f ((funs, preds), unfolds));
+fun merge_codethms _ (Codethms { funs = funs1, preds = preds1, unfolds = unfolds1 },
+  Codethms { funs = funs2, preds = preds2, unfolds = unfolds2 }) =
+    mk_codethms ((Symtab.join (K (uncurry (fold (insert eq_thm)))) (funs1, funs2),
         Symtab.join (K (uncurry (fold (insert eq_thm)))) (preds1, preds2)),
           fold (insert eq_thm) unfolds1 unfolds2);
 
+datatype codecache = Codecache of {
+  funs: thm list Symtab.table,
+  datatypes: (string * typ list) list Symtab.table
+};
+
+fun mk_codecache (funs, datatypes) = Codecache { funs = funs, datatypes = datatypes };
+fun map_codecache f (Extrs { funs, datatypes }) = Codecache (f (funs, datatypes));
+fun merge_codecache _ (Codecache { funs = funs1, datatypes = datatypes1 },
+  Extrs { funs = funs2, datatypes = datatypes2 }) =
+    mk_codecache (Symtab.empty, Symtab.empty);
+
 datatype T = T of {
   procs: procs,
+  extrs: extrs,
   codethms: codethms
 };
 
-fun mk_T (procs, codethms) = T { procs = procs, codethms = codethms };
-fun map_T f (T { procs, codethms }) = mk_T (f (procs, codethms));
-fun merge_T pp (T { procs = procs1, codethms = codethms1 },
-  T { procs = procs2, codethms = codethms2 }) =
-    mk_T (merge_procs pp (procs1, procs2), merge_codethms pp (codethms1, codethms2));
+fun mk_T (procs, (extrs, codethms)) = T { procs = procs, extrs = extrs, codethms = codethms };
+fun map_T f (T { procs, extrs, codethms }) = mk_T (f (procs, (extrs, codethms)));
+fun merge_T pp (T { procs = procs1, extrs = extrs1, codethms = codethms1 },
+  T { procs = procs2, extrs = extrs2, codethms = codethms2 }) =
+    mk_T (merge_procs pp (procs1, procs2), (merge_extrs pp (extrs1, extrs2), merge_codethms pp (codethms1, codethms2)));
 
 structure CodegenTheorems = TheoryDataFun
 (struct
   val name = "Pure/CodegenTheorems";
   type T = T;
   val empty = mk_T (mk_procs ([], []),
-    mk_codethms ((Symtab.empty, Symtab.empty), []));
+    (mk_extrs ([], []), mk_codethms ((Symtab.empty, Symtab.empty), [])));
   val copy = I;
   val extend = I;
   val merge = merge_T;
-  fun print _ _ = ();
+  fun print (thy : theory) (data : T) =
+    let
+      val codethms = (fn T { codethms, ... } => codethms) data;
+      val funs = (Symtab.dest o (fn Codethms { funs, ... } => funs)) codethms;
+      val preds = (Symtab.dest o (fn Codethms { preds, ... } => preds)) codethms;
+      val unfolds = (fn Codethms { unfolds, ... } => unfolds) codethms;
+    in
+      (Pretty.writeln o Pretty.block o Pretty.fbreaks) ([
+        Pretty.str "code generation theorems:",
+        Pretty.str "function theorems:" ] @
+        Pretty.fbreaks (
+          map (fn (c, thms) => 
+            (Pretty.block o Pretty.fbreaks) (
+              Pretty.str c :: map Display.pretty_thm thms
+            )
+          ) funs
+        ) @ [
+        Pretty.str "predicate theorems:" ] @
+        Pretty.fbreaks (
+          map (fn (c, thms) => 
+            (Pretty.block o Pretty.fbreaks) (
+              Pretty.str c :: map Display.pretty_thm thms
+            )
+          ) preds
+        ) @ [
+        Pretty.str "unfolding theorems:",
+        (Pretty.block o Pretty.fbreaks o map Display.pretty_thm) unfolds
+      ])
+    end;
 end);
 
 val _ = Context.add_setup CodegenTheorems.init;
-
+val print_thms = CodegenTheorems.print;
 
-(** notifiers and preprocessors **)
+local
+  val the_procs = (fn T { procs = Procs procs, ... } => procs) o CodegenTheorems.get
+  val the_extrs = (fn T { extrs = Extrs extrs, ... } => extrs) o CodegenTheorems.get
+  val the_codethms = (fn T { codethms = Codethms codethms, ... } => codethms) o CodegenTheorems.get
+in
+  val the_preprocs = (fn { preprocs, ... } => map snd preprocs) o the_procs;
+  val the_notify = (fn { notify, ... } => map snd notify) o the_procs;
+  val the_funs_extrs = (fn { funs, ... } => map snd funs) o the_extrs;
+  val the_datatypes_extrs = (fn { datatypes, ... } => map snd datatypes) o the_extrs;
+  val the_funs = (fn { funs, ... } => funs) o the_codethms;
+  val the_preds = (fn { preds, ... } => preds) o the_codethms;
+  val the_unfolds = (fn { unfolds, ... } => unfolds) o the_codethms;
+end (*local*);
 
 fun add_notify f =
   CodegenTheorems.map (map_T (fn (procs, codethms) =>
@@ -96,8 +290,7 @@
       (preprocs, (serial (), f) :: notify)), codethms)));
 
 fun notify_all c thy =
-  fold (fn f => f c) (((fn Procs { notify, ... } => map snd notify)
-    o (fn T { procs, ... } => procs) o CodegenTheorems.get) thy) thy;
+  fold (fn f => f c) (the_notify thy) thy;
 
 fun add_preproc f =
   CodegenTheorems.map (map_T (fn (procs, codethms) =>
@@ -105,44 +298,220 @@
       ((serial (), f) :: preprocs, notify)), codethms)))
   #> notify_all NONE;
 
-fun preprocess thy =
-  fold (fn f => f thy) (((fn Procs { preprocs, ... } => map snd preprocs)
-    o (fn T { procs, ... } => procs) o CodegenTheorems.get) thy);
+fun add_fun_extr f =
+  CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
+    (procs, (extrs |> map_extrs (fn (funs, datatypes) =>
+      ((serial (), f) :: funs, datatypes)), codethms))))
+  #> notify_all NONE;
+
+fun add_datatype_extr f =
+  CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
+    (procs, (extrs |> map_extrs (fn (funs, datatypes) =>
+      (funs, (serial (), f) :: datatypes)), codethms))))
+  #> notify_all NONE;
+
+fun add_fun thm thy =
+  case destr_fun thy thm
+   of SOME ((c, _), _) =>
+        thy
+        |> CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
+           (procs, (extrs, codethms |> map_codethms (fn ((funs, preds), unfolds) =>
+            ((funs |> Symtab.default (c, []) |> Symtab.map_entry c (fn thms => thms @ [thm]), preds), unfolds))))))
+        |> notify_all (SOME c)
+    | NONE => tap (fn _ => warning ("not a function equation: " ^ string_of_thm thm)) thy;
+
+fun add_pred thm thy =
+  case dest_pred thm
+   of SOME (c, _) =>
+        thy
+        |> CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
+          (procs, (extrs, codethms |> map_codethms (fn ((funs, preds), unfolds) =>
+            ((funs, preds |> Symtab.default (c, []) |> Symtab.map_entry c (fn thms => thms @ [thm])), unfolds))))))
+        |> notify_all (SOME c)
+    | NONE => tap (fn _ => warning ("not a predicate clause: " ^ string_of_thm thm)) thy;
+
+fun add_unfold thm =
+  CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
+    (procs, (extrs, codethms |> map_codethms (fn (defs, unfolds) =>
+      (defs, thm :: unfolds))))))
+  #> notify_all NONE;
+
+fun del_def thm thy =
+  case destr_fun thy thm
+   of SOME ((c, _), thm) =>
+        thy
+        |> CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
+           (procs, (extrs, codethms |> map_codethms (fn ((funs, preds), unfolds) =>
+            ((funs |> Symtab.map_entry c (remove eq_thm thm), preds), unfolds))))))
+        |> notify_all (SOME c)
+    | NONE => case dest_pred thm
+   of SOME (c, thm) =>
+        thy
+        |> CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
+           (procs, (extrs, codethms |> map_codethms (fn ((funs, preds), unfolds) =>
+            ((funs, preds |> Symtab.map_entry c (remove eq_thm thm)), unfolds))))))
+        |> notify_all (SOME c)
+    | NONE => error ("no code theorem to delete");
+
+fun del_unfold thm = 
+  CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
+    (procs, (extrs, codethms |> map_codethms (fn (defs, unfolds) =>
+      (defs, remove eq_thm thm unfolds))))))
+  #> notify_all NONE;
+
+fun purge_defs (c, ty) thy =
+  thy
+  |> CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
+      (procs, (extrs, codethms |> map_codethms (fn ((funs, preds), unfolds) =>
+        ((funs |> Symtab.map_entry c
+            (filter (fn thm => Sign.typ_instance thy ((snd o fst o dest_fun thy) thm, ty))),
+          preds |> Symtab.update (c, [])), unfolds))))))
+  |> notify_all (SOME c);
+
+
+(** preprocessing **)
+
+fun common_typ thy _ [] = []
+  | common_typ thy _ [thm] = [thm]
+  | common_typ thy extract_typ thms =
+      let
+        fun incr_thm thm max =
+          let
+            val thm' = incr_indexes max thm;
+            val max' = (maxidx_of_typ o type_of o prop_of) thm' + 1;
+          in (thm', max') end;
+        val (thms', maxidx) = fold_map incr_thm thms 0;
+        val (ty1::tys) = map extract_typ thms;
+        fun unify ty = Type.unify (Sign.tsig_of thy) (ty1, ty);
+        val (env, _) = fold unify tys (Vartab.empty, maxidx)
+        val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
+          cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
+      in map (Thm.instantiate (instT, [])) thms end;
+
+fun preprocess thy extract_typ thms =
+  thms
+  |> map (Thm.transfer thy)
+  |> fold (fn f => f thy) (the_preprocs thy)
+  |> map (rewrite_rule (the_unfolds thy))
+  |> (if is_some extract_typ then common_typ thy (the extract_typ) else I)
+  |> Drule.conj_intr_list
+  |> Drule.zero_var_indexes
+  |> Drule.conj_elim_list
+  |> map Drule.unvarifyT
+  |> map Drule.unvarify;
+
+fun preprocess_fun thy thms =
+  let
+    fun tap_typ [] = NONE
+      | tap_typ (thms as (thm::_)) = SOME ((snd o fst o dest_fun thy) thm, thms)
+  in
+    thms
+    |> map elim_obj_eq
+    |> preprocess thy (SOME (snd o fst o dest_fun thy))
+    |> tap_typ
+  end;
 
 fun preprocess_term thy t =
   let
-    val x = Free (variant (add_term_names (t, [])) "x", fastype_of t);
-    (* fake definition *)
+    val x = Free (variant (add_term_names (t, [])) "a", fastype_of t);
+    (*fake definition*)
     val eq = setmp quick_and_dirty true (SkipProof.make_thm thy)
       (Logic.mk_equals (x, t));
     fun err () = error "preprocess_term: bad preprocessor"
-  in case map prop_of (preprocess thy [eq]) of
+  in case map prop_of (preprocess thy NONE [eq]) of
       [Const ("==", _) $ x' $ t'] => if x = x' then t' else err ()
     | _ => err ()
   end;
 
-fun add_unfold thm =
-  CodegenTheorems.map (map_T (fn (procs, codethms) =>
-    (procs, codethms |> map_codethms (fn (defs, unfolds) =>
-      (defs, thm :: unfolds)))))
+
+(** retrieval **)
+
+fun get_funs thy (c, ty) =
+  let
+    val filter_typ = Library.mapfilter (fn ((_, ty'), thm) =>
+      if Sign.typ_instance thy (ty', ty)
+        orelse Sign.typ_instance thy (ty, ty')
+      then SOME thm else debug_msg (fn _ => "dropping " ^ string_of_thm thm) NONE);
+    val thms_funs = 
+      (these o Symtab.lookup (the_funs thy)) c
+      |> map (dest_fun thy)
+      |> filter_typ;
+    val thms = case thms_funs
+     of [] =>
+          Defs.specifications_of (Theory.defs_of thy) c
+          |> map (PureThy.get_thms thy o Name o fst o snd)
+          |> Library.flat
+          |> append (getf_first_list (map (fn f => f thy) (the_funs_extrs thy)) (c, ty))
+          |> map (dest_fun thy)
+          |> filter_typ
+      | thms => thms
+  in
+    thms
+    |> preprocess_fun thy
+  end;
 
-fun add_funn thm =
-  case dest_funn thm
-   of SOME (c, thm) =>
-    CodegenTheorems.map (map_T (fn (procs, codethms) =>
-      (procs, codethms |> map_codethms (fn ((funns, preds), unfolds) =>
-        ((funns |> Symtab.default (c, []) |> Symtab.map (fn thms => thms @ [thm]), preds), unfolds)))))
-    | NONE => error ("not a function equation: " ^ string_of_thm thm);
+fun get_datatypes thy dtco =
+  let
+    val truh = mk_tf true;
+    val fals = mk_tf false;
+    fun mk_lhs vs ((co1, tys1), (co2, tys2)) =
+      let
+        val dty = Type (dtco, map TFree vs);
+        val (xs1, xs2) = chop (length tys1) (Term.invent_names [] "x" (length tys1 + length tys2));
+        val frees1 = map2 (fn x => fn ty => Free (x, ty)) xs1 tys1;
+        val frees2 = map2 (fn x => fn ty => Free (x, ty)) xs2 tys2;
+        fun zip_co co xs tys = list_comb (Const (co,
+          tys ---> dty), map2 (fn x => fn ty => Free (x, ty)) xs tys);
+      in
+        ((frees1, frees2), (zip_co co1 xs1 tys1, zip_co co2 xs2 tys2))
+      end;
+    fun mk_rhs [] [] = truh
+      | mk_rhs xs ys = foldr1 mk_obj_conj (map2 (curry mk_obj_eq) xs ys);
+    fun mk_eq vs (args as ((co1, _), (co2, _))) (inj, dist) =
+      if co1 = co2
+        then let
+          val ((fs1, fs2), lhs) = mk_lhs vs args;
+          val rhs = mk_rhs fs1 fs2;
+        in (mk_bool_eq (lhs, rhs) :: inj, dist) end
+        else let
+          val (_, lhs) = mk_lhs vs args;
+        in (inj, mk_bool_eq (lhs, fals) :: dist) end;
+    fun mk_eqs (vs, cos) =
+      let val cos' = rev cos 
+      in (op @) (fold (mk_eq vs) (product cos' cos') ([], [])) end;
+    fun mk_eq_thms tac vs_cos =
+      map (fn t => (Goal.prove thy [] []
+        (ObjectLogic.ensure_propT thy t) (K tac))) (mk_eqs vs_cos);
+  in
+    case getf_first (map (fn f => f thy) (the_datatypes_extrs thy)) dtco
+     of NONE => NONE
+      | SOME (vs_cos, tac) => SOME (vs_cos, mk_eq_thms tac vs_cos)
+  end;
 
-fun add_pred thm =
-  case dest_pred thm
-   of SOME (c, thm) =>
-    CodegenTheorems.map (map_T (fn (procs, codethms) =>
-      (procs, codethms |> map_codethms (fn ((funns, preds), unfolds) =>
-        ((funns, preds |> Symtab.default (c, []) |> Symtab.map (fn thms => thms @ [thm])), unfolds)))))
-    | NONE => error ("not a predicate clause: " ^ string_of_thm thm);
+fun get_eq thy (c, ty) =
+  if is_obj_eq c
+  then case get_datatypes thy ((fst o dest_Type o hd o fst o strip_type) ty)
+   of SOME (_, thms) => thms
+    | _ => []
+  else [];
 
 
-(** isar **)
+(** code attributes and setup **)
 
-end; (* struct *)
+local
+  fun add_simple_attribute (name, f) =
+    (Codegen.add_attribute name o (Scan.succeed o Thm.declaration_attribute))
+      (Context.map_theory o f);
+in
+  val _ = map (Context.add_setup o add_simple_attribute) [
+    ("fun", add_fun),
+    ("pred", add_pred),
+    ("unfold", (fn thm => Codegen.add_unfold thm #> add_unfold thm)),
+    ("unfolt", add_unfold),
+    ("nofold", del_unfold)
+  ]
+end; (*local*)
+
+val _ = Context.add_setup (add_fun_extr get_eq);
+
+end; (*struct*)