simplified infrastructure for code generator operational equality
authorhaftmann
Wed Dec 05 14:15:45 2007 +0100 (2007-12-05)
changeset 25534d0b74fdd6067
parent 25533 0140cc7b26ad
child 25535 4975b7529a14
simplified infrastructure for code generator operational equality
src/HOL/Code_Setup.thy
src/HOL/Datatype.thy
src/HOL/HOL.thy
src/HOL/Inductive.thy
src/HOL/Nat.thy
src/HOL/Product_Type.thy
src/HOL/Sum_Type.thy
src/HOL/Tools/datatype_codegen.ML
src/HOL/Tools/typecopy_package.ML
     1.1 --- a/src/HOL/Code_Setup.thy	Wed Dec 05 14:15:39 2007 +0100
     1.2 +++ b/src/HOL/Code_Setup.thy	Wed Dec 05 14:15:45 2007 +0100
     1.3 @@ -91,12 +91,6 @@
     1.4    (OCaml "bool")
     1.5    (Haskell "Bool")
     1.6  
     1.7 -code_instance bool :: eq
     1.8 -  (Haskell -)
     1.9 -
    1.10 -code_const "op = \<Colon> bool \<Rightarrow> bool \<Rightarrow> bool"
    1.11 -  (Haskell infixl 4 "==")
    1.12 -
    1.13  code_const True and False and Not and "op &" and "op |" and If
    1.14    (SML "true" and "false" and "not"
    1.15      and infixl 1 "andalso" and infixl 0 "orelse"
     2.1 --- a/src/HOL/Datatype.thy	Wed Dec 05 14:15:39 2007 +0100
     2.2 +++ b/src/HOL/Datatype.thy	Wed Dec 05 14:15:45 2007 +0100
     2.3 @@ -10,7 +10,6 @@
     2.4  
     2.5  theory Datatype
     2.6  imports Finite_Set
     2.7 -uses "Tools/datatype_codegen.ML"
     2.8  begin
     2.9  
    2.10  typedef (Node)
    2.11 @@ -682,8 +681,6 @@
    2.12  
    2.13  subsubsection {* Code generator setup *}
    2.14  
    2.15 -setup DatatypeCodegen.setup
    2.16 -
    2.17  definition
    2.18    is_none :: "'a option \<Rightarrow> bool" where
    2.19    is_none_none [code post, symmetric, code inline]: "is_none x \<longleftrightarrow> x = None"
     3.1 --- a/src/HOL/HOL.thy	Wed Dec 05 14:15:39 2007 +0100
     3.2 +++ b/src/HOL/HOL.thy	Wed Dec 05 14:15:45 2007 +0100
     3.3 @@ -1636,8 +1636,6 @@
     3.4  
     3.5  class eq (attach "op =") = type
     3.6  
     3.7 -code_datatype True False
     3.8 -
     3.9  lemma [code func]:
    3.10    shows "False \<and> x \<longleftrightarrow> False"
    3.11      and "True \<and> x \<longleftrightarrow> x"
    3.12 @@ -1654,17 +1652,9 @@
    3.13    shows "\<not> True \<longleftrightarrow> False"
    3.14      and "\<not> False \<longleftrightarrow> True" by (rule HOL.simp_thms)+
    3.15  
    3.16 -instance bool :: eq ..
    3.17 -
    3.18 -lemma [code func]:
    3.19 -  shows "False = P \<longleftrightarrow> \<not> P"
    3.20 -    and "True = P \<longleftrightarrow> P" 
    3.21 -    and "P = False \<longleftrightarrow> \<not> P" 
    3.22 -    and "P = True \<longleftrightarrow> P" by simp_all
    3.23 -
    3.24  code_datatype Trueprop "prop"
    3.25  
    3.26 -code_datatype "TYPE('a)"
    3.27 +code_datatype "TYPE('a\<Colon>{})"
    3.28  
    3.29  lemma Let_case_cert:
    3.30    assumes "CASE \<equiv> (\<lambda>x. Let x f)"
     4.1 --- a/src/HOL/Inductive.thy	Wed Dec 05 14:15:39 2007 +0100
     4.2 +++ b/src/HOL/Inductive.thy	Wed Dec 05 14:15:45 2007 +0100
     4.3 @@ -18,6 +18,7 @@
     4.4    ("Tools/datatype_case.ML")
     4.5    ("Tools/datatype_package.ML")
     4.6    ("Tools/primrec_package.ML")
     4.7 +  ("Tools/datatype_codegen.ML")
     4.8  begin
     4.9  
    4.10  subsection {* Least and greatest fixed points *}
    4.11 @@ -329,6 +330,9 @@
    4.12  setup DatatypePackage.setup
    4.13  use "Tools/primrec_package.ML"
    4.14  
    4.15 +use "Tools/datatype_codegen.ML"
    4.16 +setup DatatypeCodegen.setup
    4.17 +
    4.18  use "Tools/inductive_codegen.ML"
    4.19  setup InductiveCodegen.setup
    4.20  
     5.1 --- a/src/HOL/Nat.thy	Wed Dec 05 14:15:39 2007 +0100
     5.2 +++ b/src/HOL/Nat.thy	Wed Dec 05 14:15:45 2007 +0100
     5.3 @@ -1143,15 +1143,6 @@
     5.4  
     5.5  subsection {* Code generator setup *}
     5.6  
     5.7 -instance nat :: eq ..
     5.8 -
     5.9 -lemma [code func]:
    5.10 -  "(0\<Colon>nat) = 0 \<longleftrightarrow> True"
    5.11 -  "Suc n = Suc m \<longleftrightarrow> n = m"
    5.12 -  "Suc n = 0 \<longleftrightarrow> False"
    5.13 -  "0 = Suc m \<longleftrightarrow> False"
    5.14 -by auto
    5.15 -
    5.16  lemma [code func]:
    5.17    "(0\<Colon>nat) \<le> m \<longleftrightarrow> True"
    5.18    "Suc (n\<Colon>nat) \<le> m \<longleftrightarrow> n < m"
     6.1 --- a/src/HOL/Product_Type.thy	Wed Dec 05 14:15:39 2007 +0100
     6.2 +++ b/src/HOL/Product_Type.thy	Wed Dec 05 14:15:45 2007 +0100
     6.3 @@ -24,6 +24,17 @@
     6.4  declare case_split [cases type: bool]
     6.5    -- "prefer plain propositional version"
     6.6  
     6.7 +lemma [code func]:
     6.8 +  shows "False = P \<longleftrightarrow> \<not> P"
     6.9 +    and "True = P \<longleftrightarrow> P" 
    6.10 +    and "P = False \<longleftrightarrow> \<not> P" 
    6.11 +    and "P = True \<longleftrightarrow> P" by simp_all
    6.12 +
    6.13 +code_const "op = \<Colon> bool \<Rightarrow> bool \<Rightarrow> bool"
    6.14 +  (Haskell infixl 4 "==")
    6.15 +
    6.16 +code_instance bool :: eq
    6.17 +  (Haskell -)
    6.18  
    6.19  subsection {* Unit *}
    6.20  
     7.1 --- a/src/HOL/Sum_Type.thy	Wed Dec 05 14:15:39 2007 +0100
     7.2 +++ b/src/HOL/Sum_Type.thy	Wed Dec 05 14:15:45 2007 +0100
     7.3 @@ -215,26 +215,6 @@
     7.4  by blast
     7.5  
     7.6  
     7.7 -subsection {* Code generator setup *}
     7.8 -
     7.9 -instance "+" :: (eq, eq) eq ..
    7.10 -
    7.11 -lemma [code func]:
    7.12 -  "(Inl x \<Colon> 'a\<Colon>eq + 'b\<Colon>eq) = Inl y \<longleftrightarrow> x = y"
    7.13 -  unfolding Inl_eq ..
    7.14 -
    7.15 -lemma [code func]:
    7.16 -  "(Inr x \<Colon> 'a\<Colon>eq + 'b\<Colon>eq) = Inr y \<longleftrightarrow> x = y"
    7.17 -  unfolding Inr_eq ..
    7.18 -
    7.19 -lemma [code func]:
    7.20 -  "Inl (x\<Colon>'a\<Colon>eq) = Inr (y\<Colon>'b\<Colon>eq) \<longleftrightarrow> False"
    7.21 -  using Inl_not_Inr by auto
    7.22 -
    7.23 -lemma [code func]:
    7.24 -  "Inr (x\<Colon>'b\<Colon>eq) = Inl (y\<Colon>'a\<Colon>eq) \<longleftrightarrow> False"
    7.25 -  using Inr_not_Inl by auto
    7.26 -
    7.27  ML
    7.28  {*
    7.29  val Inl_RepI = thm "Inl_RepI";
     8.1 --- a/src/HOL/Tools/datatype_codegen.ML	Wed Dec 05 14:15:39 2007 +0100
     8.2 +++ b/src/HOL/Tools/datatype_codegen.ML	Wed Dec 05 14:15:45 2007 +0100
     8.3 @@ -2,32 +2,21 @@
     8.4      ID:         $Id$
     8.5      Author:     Stefan Berghofer & Florian Haftmann, TU Muenchen
     8.6  
     8.7 -Code generator for inductive datatypes.
     8.8 +Code generator facilities for inductive datatypes.
     8.9  *)
    8.10  
    8.11  signature DATATYPE_CODEGEN =
    8.12  sig
    8.13    val get_eq: theory -> string -> thm list
    8.14 -  val get_eq_datatype: theory -> string -> thm list
    8.15    val get_case_cert: theory -> string -> thm
    8.16 -
    8.17 -  type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
    8.18 -    -> theory -> theory
    8.19 -  val add_codetypes_hook: hook -> theory -> theory
    8.20 -  val get_codetypes_arities: theory -> (string * bool) list -> sort
    8.21 -    -> (string * (arity * term list)) list
    8.22 -  val prove_codetypes_arities: tactic -> (string * bool) list -> sort
    8.23 -    -> (arity list -> (string * term list) list -> theory
    8.24 -      -> ((bstring * Attrib.src list) * term) list * theory)
    8.25 -    -> (arity list -> (string * term list) list -> thm list -> theory -> theory)
    8.26 -    -> theory -> theory
    8.27 -
    8.28    val setup: theory -> theory
    8.29  end;
    8.30  
    8.31  structure DatatypeCodegen : DATATYPE_CODEGEN =
    8.32  struct
    8.33  
    8.34 +(** SML code generator **)
    8.35 +
    8.36  open Codegen;
    8.37  
    8.38  fun mk_tuple [p] = p
    8.39 @@ -310,66 +299,21 @@
    8.40    | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
    8.41  
    8.42  
    8.43 -(** datatypes for code 2nd generation **)
    8.44 -
    8.45 -local
    8.46 -
    8.47 -val not_sym = thm "HOL.not_sym";
    8.48 -val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
    8.49 -val refl = thm "refl";
    8.50 -val eqTrueI = thm "eqTrueI";
    8.51 +(** generic code generator **)
    8.52  
    8.53 -fun mk_distinct cos =
    8.54 -  let
    8.55 -    fun sym_product [] = []
    8.56 -      | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
    8.57 -    fun mk_co_args (co, tys) ctxt =
    8.58 -      let
    8.59 -        val names = Name.invents ctxt "a" (length tys);
    8.60 -        val ctxt' = fold Name.declare names ctxt;
    8.61 -        val vs = map2 (curry Free) names tys;
    8.62 -      in (vs, ctxt') end;
    8.63 -    fun mk_dist ((co1, tys1), (co2, tys2)) =
    8.64 -      let
    8.65 -        val ((xs1, xs2), _) = Name.context
    8.66 -          |> mk_co_args (co1, tys1)
    8.67 -          ||>> mk_co_args (co2, tys2);
    8.68 -        val prem = HOLogic.mk_eq
    8.69 -          (list_comb (co1, xs1), list_comb (co2, xs2));
    8.70 -        val t = HOLogic.mk_not prem;
    8.71 -      in HOLogic.mk_Trueprop t end;
    8.72 -  in map mk_dist (sym_product cos) end;
    8.73 +(* specification *)
    8.74  
    8.75 -in
    8.76 -
    8.77 -fun get_eq_datatype thy dtco =
    8.78 +fun add_datatype_spec vs dtco cos thy =
    8.79    let
    8.80 -    val SOME (vs, cs) = DatatypePackage.get_datatype_spec thy dtco;
    8.81 -    fun mk_triv_inject co =
    8.82 -      let
    8.83 -        val ct' = Thm.cterm_of thy
    8.84 -          (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs)))
    8.85 -        val cty' = Thm.ctyp_of_term ct';
    8.86 -        val SOME (ct, cty) = fold_aterms (fn Var (v, ty) =>
    8.87 -          (K o SOME) (Thm.cterm_of thy (Var (v, Thm.typ_of cty')), Thm.ctyp_of thy ty) | _ => I)
    8.88 -          (Thm.prop_of refl) NONE;
    8.89 -      in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) refl] end;
    8.90 -    val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs
    8.91 -    val inject2 = (#inject o DatatypePackage.the_datatype thy) dtco;
    8.92 -    val ctxt = ProofContext.init thy;
    8.93 -    val simpset = Simplifier.context ctxt
    8.94 -      (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
    8.95 -    val cos = map (fn (co, tys) =>
    8.96 -        (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs;
    8.97 -    val tac = ALLGOALS (simp_tac simpset)
    8.98 -      THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]);
    8.99 -    val distinct =
   8.100 -      mk_distinct cos
   8.101 -      |> map (fn t => Goal.prove_global thy [] [] t (K tac))
   8.102 -      |> (fn thms => thms @ map (fn thm => not_sym OF [thm]) thms)
   8.103 -  in inject1 @ inject2 @ distinct end;
   8.104 +    val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
   8.105 +  in
   8.106 +    thy
   8.107 +    |> try (Code.add_datatype cs)
   8.108 +    |> the_default thy
   8.109 +  end;
   8.110  
   8.111 -end;
   8.112 +
   8.113 +(* case certificates *)
   8.114  
   8.115  fun get_case_cert thy tyco =
   8.116    let
   8.117 @@ -402,170 +346,116 @@
   8.118      |> Thm.varifyT
   8.119    end;
   8.120  
   8.121 -
   8.122 -
   8.123 -(** codetypes for code 2nd generation **)
   8.124 -
   8.125 -(* abstraction over datatypes vs. type copies *)
   8.126 -
   8.127 -fun get_typecopy_spec thy tyco =
   8.128 +fun add_datatype_cases dtco thy =
   8.129    let
   8.130 -    val SOME { vs, constr, typ, ... } = TypecopyPackage.get_info thy tyco
   8.131 -  in (vs, [(constr, [typ])]) end;
   8.132 -
   8.133 -
   8.134 -fun get_spec thy (dtco, true) =
   8.135 -      (the o DatatypePackage.get_datatype_spec thy) dtco
   8.136 -  | get_spec thy (tyco, false) =
   8.137 -      get_typecopy_spec thy tyco;
   8.138 -
   8.139 -local
   8.140 -  fun get_eq_thms thy tyco = case DatatypePackage.get_datatype thy tyco
   8.141 -   of SOME _ => get_eq_datatype thy tyco
   8.142 -    | NONE => [TypecopyPackage.get_eq thy tyco];
   8.143 -  fun constrain_op_eq_thms thy thms =
   8.144 -    let
   8.145 -      fun add_eq (Const ("op =", ty)) =
   8.146 -            fold (insert (eq_fst (op =))) (Term.add_tvarsT ty [])
   8.147 -        | add_eq _ =
   8.148 -            I
   8.149 -      val eqs = fold (fold_aterms add_eq o Thm.prop_of) thms [];
   8.150 -      val instT = map (fn (v_i, sort) =>
   8.151 -        (Thm.ctyp_of thy (TVar (v_i, sort)),
   8.152 -           Thm.ctyp_of thy (TVar (v_i, Sorts.inter_sort (Sign.classes_of thy)
   8.153 -             (sort, [HOLogic.class_eq]))))) eqs;
   8.154 -    in
   8.155 -      thms
   8.156 -      |> map (Thm.instantiate (instT, []))
   8.157 -    end;
   8.158 -in
   8.159 -  fun get_eq thy tyco =
   8.160 -    get_eq_thms thy tyco
   8.161 -    |> maps ((#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy)
   8.162 -    |> constrain_op_eq_thms thy
   8.163 -end;
   8.164 -
   8.165 -type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
   8.166 -  -> theory -> theory;
   8.167 -
   8.168 -fun add_codetypes_hook hook thy =
   8.169 -  let
   8.170 -    fun add_spec thy (tyco, is_dt) =
   8.171 -      (tyco, (is_dt, get_spec thy (tyco, is_dt)));
   8.172 -    fun datatype_hook dtcos thy =
   8.173 -      hook (map (add_spec thy) (map (rpair true) dtcos)) thy;
   8.174 -    fun typecopy_hook tyco thy =
   8.175 -      hook ([(tyco, (false, get_typecopy_spec thy tyco))]) thy;
   8.176 +    val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
   8.177 +    val certs = get_case_cert thy dtco;
   8.178    in
   8.179      thy
   8.180 -    |> DatatypePackage.interpretation datatype_hook
   8.181 -    |> TypecopyPackage.interpretation typecopy_hook
   8.182 +    |> Code.add_case certs
   8.183 +    |> fold_rev Code.add_default_func case_rewrites
   8.184    end;
   8.185  
   8.186 -fun the_codetypes_mut_specs thy ([(tyco, is_dt)]) =
   8.187 -      let
   8.188 -        val (vs, cs) = get_spec thy (tyco, is_dt)
   8.189 -      in (vs, [(tyco, (is_dt, cs))]) end
   8.190 -  | the_codetypes_mut_specs thy (tycos' as (tyco, true) :: _) =
   8.191 -      let
   8.192 -        val tycos = map fst tycos';
   8.193 -        val tycos'' = (map (#1 o snd) o #descr o DatatypePackage.the_datatype thy) tyco;
   8.194 -        val _ = if gen_subset (op =) (tycos, tycos'') then () else
   8.195 -          error ("type constructors are not mutually recursive: " ^ (commas o map quote) tycos);
   8.196 -        val (vs::_, css) = split_list (map (the o DatatypePackage.get_datatype_spec thy) tycos);
   8.197 -      in (vs, map2 (fn (tyco, is_dt) => fn cs => (tyco, (is_dt, cs))) tycos' css) end;
   8.198 +
   8.199 +(* equality *)
   8.200 +
   8.201 +local
   8.202  
   8.203 -
   8.204 -(* instrumentalizing the sort algebra *)
   8.205 +val not_sym = thm "HOL.not_sym";
   8.206 +val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
   8.207 +val refl = thm "refl";
   8.208 +val eqTrueI = thm "eqTrueI";
   8.209  
   8.210 -fun get_codetypes_arities thy tycos sort =
   8.211 +fun mk_distinct cos =
   8.212    let
   8.213 -    val pp = Sign.pp thy;
   8.214 -    val algebra = Sign.classes_of thy;
   8.215 -    val (vs_proto, css_proto) = the_codetypes_mut_specs thy tycos;
   8.216 -    val vs = map (fn (v, vsort) => (v, Sorts.inter_sort algebra (vsort, sort))) vs_proto;
   8.217 -    val css = map (fn (tyco, (_, cs)) => (tyco, cs)) css_proto;
   8.218 -    val algebra' = algebra
   8.219 -      |> fold (fn (tyco, _) =>
   8.220 -           Sorts.add_arities pp (tyco, map (fn class => (class, map snd vs)) sort)) css;
   8.221 -    fun typ_sort_inst ty = CodeUnit.typ_sort_inst algebra' (Logic.varifyT ty, sort);
   8.222 -    val venv = Vartab.empty
   8.223 -      |> fold (fn (v, sort) => Vartab.update_new ((v, 0), sort)) vs
   8.224 -      |> fold (fn (_, cs) => fold (fn (_, tys) => fold typ_sort_inst tys) cs) css;
   8.225 -    fun inst (v, _) = (v, (the o Vartab.lookup venv) (v, 0));
   8.226 -    val vs' = map inst vs;
   8.227 -    fun mk_arity tyco = (tyco, map snd vs', sort);
   8.228 -    fun mk_cons tyco (c, tys) =
   8.229 +    fun sym_product [] = []
   8.230 +      | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
   8.231 +    fun mk_co_args (co, tys) ctxt =
   8.232 +      let
   8.233 +        val names = Name.invents ctxt "a" (length tys);
   8.234 +        val ctxt' = fold Name.declare names ctxt;
   8.235 +        val vs = map2 (curry Free) names tys;
   8.236 +      in (vs, ctxt') end;
   8.237 +    fun mk_dist ((co1, tys1), (co2, tys2)) =
   8.238        let
   8.239 -        val tys' = (map o Term.map_type_tfree) (TFree o inst) tys;
   8.240 -        val ts = Name.names Name.context "a" tys';
   8.241 -        val ty = (tys' ---> Type (tyco, map TFree vs'));
   8.242 -      in list_comb (Const (c, ty), map Free ts) end;
   8.243 -  in
   8.244 -    map (fn (tyco, cs) => (tyco, (mk_arity tyco, map (mk_cons tyco) cs))) css
   8.245 -  end;
   8.246 +        val ((xs1, xs2), _) = Name.context
   8.247 +          |> mk_co_args (co1, tys1)
   8.248 +          ||>> mk_co_args (co2, tys2);
   8.249 +        val prem = HOLogic.mk_eq
   8.250 +          (list_comb (co1, xs1), list_comb (co2, xs2));
   8.251 +        val t = HOLogic.mk_not prem;
   8.252 +      in HOLogic.mk_Trueprop t end;
   8.253 +  in map mk_dist (sym_product cos) end;
   8.254 +
   8.255 +in
   8.256  
   8.257 -fun prove_codetypes_arities tac tycos sort f after_qed thy =
   8.258 -  case try (get_codetypes_arities thy tycos) sort
   8.259 -   of NONE => thy
   8.260 -    | SOME insts => let
   8.261 -        fun proven (tyco, asorts, sort) =
   8.262 -          Sorts.of_sort (Sign.classes_of thy)
   8.263 -            (Type (tyco, map TFree (Name.names Name.context "'a" asorts)), sort);
   8.264 -        val (arities, css) = (split_list o map_filter
   8.265 -          (fn (tyco, (arity, cs)) => if proven arity
   8.266 -            then NONE else SOME (arity, (tyco, cs)))) insts;
   8.267 -      in
   8.268 -        thy
   8.269 -        |> not (null arities) ? (
   8.270 -            f arities css
   8.271 -            #-> (fn defs =>
   8.272 -              Instance.prove_instance tac arities defs
   8.273 -            #-> (fn defs =>
   8.274 -              after_qed arities css defs)))
   8.275 -      end;
   8.276 +fun get_eq thy dtco =
   8.277 +  let
   8.278 +    val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco;
   8.279 +    fun mk_triv_inject co =
   8.280 +      let
   8.281 +        val ct' = Thm.cterm_of thy
   8.282 +          (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs)))
   8.283 +        val cty' = Thm.ctyp_of_term ct';
   8.284 +        val SOME (ct, cty) = fold_aterms (fn Var (v, ty) =>
   8.285 +          (K o SOME) (Thm.cterm_of thy (Var (v, Thm.typ_of cty')), Thm.ctyp_of thy ty) | _ => I)
   8.286 +          (Thm.prop_of refl) NONE;
   8.287 +      in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) refl] end;
   8.288 +    val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs
   8.289 +    val inject2 = (#inject o DatatypePackage.the_datatype thy) dtco;
   8.290 +    val ctxt = ProofContext.init thy;
   8.291 +    val simpset = Simplifier.context ctxt
   8.292 +      (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
   8.293 +    val cos = map (fn (co, tys) =>
   8.294 +        (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs;
   8.295 +    val tac = ALLGOALS (simp_tac simpset)
   8.296 +      THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]);
   8.297 +    val distinct =
   8.298 +      mk_distinct cos
   8.299 +      |> map (fn t => Goal.prove_global thy [] [] t (K tac))
   8.300 +      |> (fn thms => thms @ map (fn thm => not_sym OF [thm]) thms)
   8.301 +  in inject1 @ inject2 @ distinct end;
   8.302  
   8.303 -
   8.304 -(* operational equality *)
   8.305 +end;
   8.306  
   8.307 -fun eq_hook specs =
   8.308 +fun add_datatypes_equality vs dtcos thy =
   8.309    let
   8.310 -    fun add_eq_thms (dtco, (_, (vs, cs))) thy =
   8.311 +    fun get_eq' thy dtco = get_eq thy dtco
   8.312 +      |> map (CodeUnit.constrain_thm [HOLogic.class_eq])
   8.313 +      |> maps ((#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy);
   8.314 +    fun add_eq_thms dtco thy =
   8.315        let
   8.316          val thy_ref = Theory.check_thy thy;
   8.317          val const = Class.param_of_inst thy ("op =", dtco);
   8.318 -        val get_thms = (fn () => get_eq (Theory.deref thy_ref) dtco |> rev);
   8.319 +        val get_thms = (fn () => get_eq' (Theory.deref thy_ref) dtco |> rev);
   8.320        in
   8.321          Code.add_funcl (const, Susp.delay get_thms) thy
   8.322        end;
   8.323 +    val sorts_eq =
   8.324 +      map (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq] o snd) vs;
   8.325    in
   8.326 -    prove_codetypes_arities (Class.intro_classes_tac [])
   8.327 -      (map (fn (tyco, (is_dt, _)) => (tyco, is_dt)) specs)
   8.328 -      [HOLogic.class_eq] ((K o K o pair) []) ((K o K o K) (fold add_eq_thms specs))
   8.329 +    thy
   8.330 +    |> Instance.instantiate (dtcos, sorts_eq, [HOLogic.class_eq]) (pair ())
   8.331 +         ((K o K) (Class.intro_classes_tac []))
   8.332 +    |> fold add_eq_thms dtcos
   8.333    end;
   8.334  
   8.335  
   8.336 -
   8.337  (** theory setup **)
   8.338  
   8.339 -fun add_datatype_spec dtco thy =
   8.340 +fun add_datatype_code dtcos thy =
   8.341    let
   8.342 -    val SOME (vs, cos) = DatatypePackage.get_datatype_spec thy dtco;
   8.343 -    val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
   8.344 -    val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
   8.345 -    val certs = get_case_cert thy dtco;
   8.346 +    val (vs :: _, coss) = (split_list o map (DatatypePackage.the_datatype_spec thy)) dtcos;
   8.347    in
   8.348      thy
   8.349 -    |> try (Code.add_datatype cs)
   8.350 -    |> the_default thy
   8.351 -    |> Code.add_case certs
   8.352 -    |> fold_rev Code.add_default_func case_rewrites
   8.353 +    |> fold2 (add_datatype_spec vs) dtcos coss
   8.354 +    |> fold add_datatype_cases dtcos
   8.355 +    |> add_datatypes_equality vs dtcos
   8.356    end;
   8.357  
   8.358  val setup = 
   8.359    add_codegen "datatype" datatype_codegen
   8.360    #> add_tycodegen "datatype" datatype_tycodegen
   8.361 -  #> DatatypePackage.interpretation (fold add_datatype_spec)
   8.362 -  #> add_codetypes_hook eq_hook
   8.363 +  #> DatatypePackage.interpretation add_datatype_code
   8.364  
   8.365  end;
     9.1 --- a/src/HOL/Tools/typecopy_package.ML	Wed Dec 05 14:15:39 2007 +0100
     9.2 +++ b/src/HOL/Tools/typecopy_package.ML	Wed Dec 05 14:15:45 2007 +0100
     9.3 @@ -20,7 +20,6 @@
     9.4    val get_typecopies: theory -> string list
     9.5    val get_info: theory -> string -> info option
     9.6    val interpretation: (string -> theory -> theory) -> theory -> theory
     9.7 -  val get_eq: theory -> string -> thm
     9.8    val print_typecopies: theory -> unit
     9.9    val setup: theory -> theory
    9.10  end;
    9.11 @@ -122,16 +121,19 @@
    9.12  
    9.13  (* code generator setup *)
    9.14  
    9.15 -fun get_eq thy = #inject o the o get_info thy;
    9.16 -
    9.17  fun add_typecopy_spec tyco thy =
    9.18    let
    9.19 -    val SOME { constr, proj_def, inject, ... } = get_info thy tyco;
    9.20 +    val SOME { constr, proj_def, inject, vs, ... } = get_info thy tyco;
    9.21 +    val sorts_eq =
    9.22 +      map (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq] o snd) vs;
    9.23      val ty = Logic.unvarifyT (Sign.the_const_type thy constr);
    9.24    in
    9.25      thy
    9.26      |> Code.add_datatype [(constr, ty)]
    9.27      |> Code.add_func proj_def
    9.28 +    |> Instance.instantiate ([tyco], sorts_eq, [HOLogic.class_eq]) (pair ())
    9.29 +         ((K o K) (Class.intro_classes_tac []))
    9.30 +    |> Code.add_func (CodeUnit.constrain_thm [HOLogic.class_eq] inject)
    9.31    end;
    9.32  
    9.33  val setup =