--- a/src/HOL/Tools/datatype_codegen.ML Wed Dec 05 14:15:39 2007 +0100
+++ b/src/HOL/Tools/datatype_codegen.ML Wed Dec 05 14:15:45 2007 +0100
@@ -2,32 +2,21 @@
ID: $Id$
Author: Stefan Berghofer & Florian Haftmann, TU Muenchen
-Code generator for inductive datatypes.
+Code generator facilities for inductive datatypes.
*)
signature DATATYPE_CODEGEN =
sig
val get_eq: theory -> string -> thm list
- val get_eq_datatype: theory -> string -> thm list
val get_case_cert: theory -> string -> thm
-
- type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
- -> theory -> theory
- val add_codetypes_hook: hook -> theory -> theory
- val get_codetypes_arities: theory -> (string * bool) list -> sort
- -> (string * (arity * term list)) list
- val prove_codetypes_arities: tactic -> (string * bool) list -> sort
- -> (arity list -> (string * term list) list -> theory
- -> ((bstring * Attrib.src list) * term) list * theory)
- -> (arity list -> (string * term list) list -> thm list -> theory -> theory)
- -> theory -> theory
-
val setup: theory -> theory
end;
structure DatatypeCodegen : DATATYPE_CODEGEN =
struct
+(** SML code generator **)
+
open Codegen;
fun mk_tuple [p] = p
@@ -310,66 +299,21 @@
| datatype_tycodegen _ _ _ _ _ _ _ = NONE;
-(** datatypes for code 2nd generation **)
-
-local
-
-val not_sym = thm "HOL.not_sym";
-val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
-val refl = thm "refl";
-val eqTrueI = thm "eqTrueI";
+(** generic code generator **)
-fun mk_distinct cos =
- let
- fun sym_product [] = []
- | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
- fun mk_co_args (co, tys) ctxt =
- let
- val names = Name.invents ctxt "a" (length tys);
- val ctxt' = fold Name.declare names ctxt;
- val vs = map2 (curry Free) names tys;
- in (vs, ctxt') end;
- fun mk_dist ((co1, tys1), (co2, tys2)) =
- let
- val ((xs1, xs2), _) = Name.context
- |> mk_co_args (co1, tys1)
- ||>> mk_co_args (co2, tys2);
- val prem = HOLogic.mk_eq
- (list_comb (co1, xs1), list_comb (co2, xs2));
- val t = HOLogic.mk_not prem;
- in HOLogic.mk_Trueprop t end;
- in map mk_dist (sym_product cos) end;
+(* specification *)
-in
-
-fun get_eq_datatype thy dtco =
+fun add_datatype_spec vs dtco cos thy =
let
- val SOME (vs, cs) = DatatypePackage.get_datatype_spec thy dtco;
- fun mk_triv_inject co =
- let
- val ct' = Thm.cterm_of thy
- (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs)))
- val cty' = Thm.ctyp_of_term ct';
- val SOME (ct, cty) = fold_aterms (fn Var (v, ty) =>
- (K o SOME) (Thm.cterm_of thy (Var (v, Thm.typ_of cty')), Thm.ctyp_of thy ty) | _ => I)
- (Thm.prop_of refl) NONE;
- in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) refl] end;
- val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs
- val inject2 = (#inject o DatatypePackage.the_datatype thy) dtco;
- val ctxt = ProofContext.init thy;
- val simpset = Simplifier.context ctxt
- (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
- val cos = map (fn (co, tys) =>
- (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs;
- val tac = ALLGOALS (simp_tac simpset)
- THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]);
- val distinct =
- mk_distinct cos
- |> map (fn t => Goal.prove_global thy [] [] t (K tac))
- |> (fn thms => thms @ map (fn thm => not_sym OF [thm]) thms)
- in inject1 @ inject2 @ distinct end;
+ val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
+ in
+ thy
+ |> try (Code.add_datatype cs)
+ |> the_default thy
+ end;
-end;
+
+(* case certificates *)
fun get_case_cert thy tyco =
let
@@ -402,170 +346,116 @@
|> Thm.varifyT
end;
-
-
-(** codetypes for code 2nd generation **)
-
-(* abstraction over datatypes vs. type copies *)
-
-fun get_typecopy_spec thy tyco =
+fun add_datatype_cases dtco thy =
let
- val SOME { vs, constr, typ, ... } = TypecopyPackage.get_info thy tyco
- in (vs, [(constr, [typ])]) end;
-
-
-fun get_spec thy (dtco, true) =
- (the o DatatypePackage.get_datatype_spec thy) dtco
- | get_spec thy (tyco, false) =
- get_typecopy_spec thy tyco;
-
-local
- fun get_eq_thms thy tyco = case DatatypePackage.get_datatype thy tyco
- of SOME _ => get_eq_datatype thy tyco
- | NONE => [TypecopyPackage.get_eq thy tyco];
- fun constrain_op_eq_thms thy thms =
- let
- fun add_eq (Const ("op =", ty)) =
- fold (insert (eq_fst (op =))) (Term.add_tvarsT ty [])
- | add_eq _ =
- I
- val eqs = fold (fold_aterms add_eq o Thm.prop_of) thms [];
- val instT = map (fn (v_i, sort) =>
- (Thm.ctyp_of thy (TVar (v_i, sort)),
- Thm.ctyp_of thy (TVar (v_i, Sorts.inter_sort (Sign.classes_of thy)
- (sort, [HOLogic.class_eq]))))) eqs;
- in
- thms
- |> map (Thm.instantiate (instT, []))
- end;
-in
- fun get_eq thy tyco =
- get_eq_thms thy tyco
- |> maps ((#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy)
- |> constrain_op_eq_thms thy
-end;
-
-type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
- -> theory -> theory;
-
-fun add_codetypes_hook hook thy =
- let
- fun add_spec thy (tyco, is_dt) =
- (tyco, (is_dt, get_spec thy (tyco, is_dt)));
- fun datatype_hook dtcos thy =
- hook (map (add_spec thy) (map (rpair true) dtcos)) thy;
- fun typecopy_hook tyco thy =
- hook ([(tyco, (false, get_typecopy_spec thy tyco))]) thy;
+ val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
+ val certs = get_case_cert thy dtco;
in
thy
- |> DatatypePackage.interpretation datatype_hook
- |> TypecopyPackage.interpretation typecopy_hook
+ |> Code.add_case certs
+ |> fold_rev Code.add_default_func case_rewrites
end;
-fun the_codetypes_mut_specs thy ([(tyco, is_dt)]) =
- let
- val (vs, cs) = get_spec thy (tyco, is_dt)
- in (vs, [(tyco, (is_dt, cs))]) end
- | the_codetypes_mut_specs thy (tycos' as (tyco, true) :: _) =
- let
- val tycos = map fst tycos';
- val tycos'' = (map (#1 o snd) o #descr o DatatypePackage.the_datatype thy) tyco;
- val _ = if gen_subset (op =) (tycos, tycos'') then () else
- error ("type constructors are not mutually recursive: " ^ (commas o map quote) tycos);
- val (vs::_, css) = split_list (map (the o DatatypePackage.get_datatype_spec thy) tycos);
- in (vs, map2 (fn (tyco, is_dt) => fn cs => (tyco, (is_dt, cs))) tycos' css) end;
+
+(* equality *)
+
+local
-
-(* instrumentalizing the sort algebra *)
+val not_sym = thm "HOL.not_sym";
+val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
+val refl = thm "refl";
+val eqTrueI = thm "eqTrueI";
-fun get_codetypes_arities thy tycos sort =
+fun mk_distinct cos =
let
- val pp = Sign.pp thy;
- val algebra = Sign.classes_of thy;
- val (vs_proto, css_proto) = the_codetypes_mut_specs thy tycos;
- val vs = map (fn (v, vsort) => (v, Sorts.inter_sort algebra (vsort, sort))) vs_proto;
- val css = map (fn (tyco, (_, cs)) => (tyco, cs)) css_proto;
- val algebra' = algebra
- |> fold (fn (tyco, _) =>
- Sorts.add_arities pp (tyco, map (fn class => (class, map snd vs)) sort)) css;
- fun typ_sort_inst ty = CodeUnit.typ_sort_inst algebra' (Logic.varifyT ty, sort);
- val venv = Vartab.empty
- |> fold (fn (v, sort) => Vartab.update_new ((v, 0), sort)) vs
- |> fold (fn (_, cs) => fold (fn (_, tys) => fold typ_sort_inst tys) cs) css;
- fun inst (v, _) = (v, (the o Vartab.lookup venv) (v, 0));
- val vs' = map inst vs;
- fun mk_arity tyco = (tyco, map snd vs', sort);
- fun mk_cons tyco (c, tys) =
+ fun sym_product [] = []
+ | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
+ fun mk_co_args (co, tys) ctxt =
+ let
+ val names = Name.invents ctxt "a" (length tys);
+ val ctxt' = fold Name.declare names ctxt;
+ val vs = map2 (curry Free) names tys;
+ in (vs, ctxt') end;
+ fun mk_dist ((co1, tys1), (co2, tys2)) =
let
- val tys' = (map o Term.map_type_tfree) (TFree o inst) tys;
- val ts = Name.names Name.context "a" tys';
- val ty = (tys' ---> Type (tyco, map TFree vs'));
- in list_comb (Const (c, ty), map Free ts) end;
- in
- map (fn (tyco, cs) => (tyco, (mk_arity tyco, map (mk_cons tyco) cs))) css
- end;
+ val ((xs1, xs2), _) = Name.context
+ |> mk_co_args (co1, tys1)
+ ||>> mk_co_args (co2, tys2);
+ val prem = HOLogic.mk_eq
+ (list_comb (co1, xs1), list_comb (co2, xs2));
+ val t = HOLogic.mk_not prem;
+ in HOLogic.mk_Trueprop t end;
+ in map mk_dist (sym_product cos) end;
+
+in
-fun prove_codetypes_arities tac tycos sort f after_qed thy =
- case try (get_codetypes_arities thy tycos) sort
- of NONE => thy
- | SOME insts => let
- fun proven (tyco, asorts, sort) =
- Sorts.of_sort (Sign.classes_of thy)
- (Type (tyco, map TFree (Name.names Name.context "'a" asorts)), sort);
- val (arities, css) = (split_list o map_filter
- (fn (tyco, (arity, cs)) => if proven arity
- then NONE else SOME (arity, (tyco, cs)))) insts;
- in
- thy
- |> not (null arities) ? (
- f arities css
- #-> (fn defs =>
- Instance.prove_instance tac arities defs
- #-> (fn defs =>
- after_qed arities css defs)))
- end;
+fun get_eq thy dtco =
+ let
+ val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco;
+ fun mk_triv_inject co =
+ let
+ val ct' = Thm.cterm_of thy
+ (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs)))
+ val cty' = Thm.ctyp_of_term ct';
+ val SOME (ct, cty) = fold_aterms (fn Var (v, ty) =>
+ (K o SOME) (Thm.cterm_of thy (Var (v, Thm.typ_of cty')), Thm.ctyp_of thy ty) | _ => I)
+ (Thm.prop_of refl) NONE;
+ in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) refl] end;
+ val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs
+ val inject2 = (#inject o DatatypePackage.the_datatype thy) dtco;
+ val ctxt = ProofContext.init thy;
+ val simpset = Simplifier.context ctxt
+ (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
+ val cos = map (fn (co, tys) =>
+ (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs;
+ val tac = ALLGOALS (simp_tac simpset)
+ THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]);
+ val distinct =
+ mk_distinct cos
+ |> map (fn t => Goal.prove_global thy [] [] t (K tac))
+ |> (fn thms => thms @ map (fn thm => not_sym OF [thm]) thms)
+ in inject1 @ inject2 @ distinct end;
-
-(* operational equality *)
+end;
-fun eq_hook specs =
+fun add_datatypes_equality vs dtcos thy =
let
- fun add_eq_thms (dtco, (_, (vs, cs))) thy =
+ fun get_eq' thy dtco = get_eq thy dtco
+ |> map (CodeUnit.constrain_thm [HOLogic.class_eq])
+ |> maps ((#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy);
+ fun add_eq_thms dtco thy =
let
val thy_ref = Theory.check_thy thy;
val const = Class.param_of_inst thy ("op =", dtco);
- val get_thms = (fn () => get_eq (Theory.deref thy_ref) dtco |> rev);
+ val get_thms = (fn () => get_eq' (Theory.deref thy_ref) dtco |> rev);
in
Code.add_funcl (const, Susp.delay get_thms) thy
end;
+ val sorts_eq =
+ map (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq] o snd) vs;
in
- prove_codetypes_arities (Class.intro_classes_tac [])
- (map (fn (tyco, (is_dt, _)) => (tyco, is_dt)) specs)
- [HOLogic.class_eq] ((K o K o pair) []) ((K o K o K) (fold add_eq_thms specs))
+ thy
+ |> Instance.instantiate (dtcos, sorts_eq, [HOLogic.class_eq]) (pair ())
+ ((K o K) (Class.intro_classes_tac []))
+ |> fold add_eq_thms dtcos
end;
-
(** theory setup **)
-fun add_datatype_spec dtco thy =
+fun add_datatype_code dtcos thy =
let
- val SOME (vs, cos) = DatatypePackage.get_datatype_spec thy dtco;
- val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
- val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
- val certs = get_case_cert thy dtco;
+ val (vs :: _, coss) = (split_list o map (DatatypePackage.the_datatype_spec thy)) dtcos;
in
thy
- |> try (Code.add_datatype cs)
- |> the_default thy
- |> Code.add_case certs
- |> fold_rev Code.add_default_func case_rewrites
+ |> fold2 (add_datatype_spec vs) dtcos coss
+ |> fold add_datatype_cases dtcos
+ |> add_datatypes_equality vs dtcos
end;
val setup =
add_codegen "datatype" datatype_codegen
#> add_tycodegen "datatype" datatype_tycodegen
- #> DatatypePackage.interpretation (fold add_datatype_spec)
- #> add_codetypes_hook eq_hook
+ #> DatatypePackage.interpretation add_datatype_code
end;