--- a/src/HOL/Tools/Datatype/datatype_codegen.ML Thu Apr 22 09:30:39 2010 +0200
+++ b/src/HOL/Tools/Datatype/datatype_codegen.ML Thu Apr 22 12:07:00 2010 +0200
@@ -12,6 +12,137 @@
structure Datatype_Codegen : DATATYPE_CODEGEN =
struct
+(** generic code generator **)
+
+(* liberal addition of code data for datatypes *)
+
+fun mk_constr_consts thy vs dtco cos =
+ let
+ val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
+ val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs;
+ in if is_some (try (Code.constrset_of_consts thy) cs')
+ then SOME cs
+ else NONE
+ end;
+
+
+(* case certificates *)
+
+fun mk_case_cert thy tyco =
+ let
+ val raw_thms =
+ (#case_rewrites o Datatype_Data.the_info thy) tyco;
+ val thms as hd_thm :: _ = raw_thms
+ |> Conjunction.intr_balanced
+ |> Thm.unvarify_global
+ |> Conjunction.elim_balanced (length raw_thms)
+ |> map Simpdata.mk_meta_eq
+ |> map Drule.zero_var_indexes
+ val params = fold_aterms (fn (Free (v, _)) => insert (op =) v
+ | _ => I) (Thm.prop_of hd_thm) [];
+ val rhs = hd_thm
+ |> Thm.prop_of
+ |> Logic.dest_equals
+ |> fst
+ |> Term.strip_comb
+ |> apsnd (fst o split_last)
+ |> list_comb;
+ val lhs = Free (Name.variant params "case", Term.fastype_of rhs);
+ val asm = (Thm.cterm_of thy o Logic.mk_equals) (lhs, rhs);
+ in
+ thms
+ |> Conjunction.intr_balanced
+ |> MetaSimplifier.rewrite_rule [(Thm.symmetric o Thm.assume) asm]
+ |> Thm.implies_intr asm
+ |> Thm.generalize ([], params) 0
+ |> AxClass.unoverload thy
+ |> Thm.varifyT_global
+ end;
+
+
+(* equality *)
+
+fun mk_eq_eqns thy dtco =
+ let
+ val (vs, cos) = Datatype_Data.the_spec thy dtco;
+ val { descr, index, inject = inject_thms, distinct = distinct_thms, ... } =
+ Datatype_Data.the_info thy dtco;
+ val ty = Type (dtco, map TFree vs);
+ fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT)
+ $ t1 $ t2;
+ fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const);
+ fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const);
+ val triv_injects = map_filter
+ (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty))))
+ | _ => NONE) cos;
+ fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) =
+ trueprop $ (equiv $ mk_eq (t1, t2) $ rhs);
+ val injects = map prep_inject (nth (Datatype_Prop.make_injs [descr] vs) index);
+ fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) =
+ [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)];
+ val distincts = maps prep_distinct (snd (nth (Datatype_Prop.make_distincts [descr] vs) index));
+ val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty)));
+ val simpset = Simplifier.global_context thy (HOL_basic_ss addsimps
+ (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms @ distinct_thms)));
+ fun prove prop = Skip_Proof.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
+ |> Simpdata.mk_eq;
+ in (map prove (triv_injects @ injects @ distincts), prove refl) end;
+
+fun add_equality vs dtcos thy =
+ let
+ fun add_def dtco lthy =
+ let
+ val ty = Type (dtco, map TFree vs);
+ fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
+ $ Free ("x", ty) $ Free ("y", ty);
+ val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
+ (mk_side @{const_name eq_class.eq}, mk_side @{const_name "op ="}));
+ val def' = Syntax.check_term lthy def;
+ val ((_, (_, thm)), lthy') = Specification.definition
+ (NONE, (Attrib.empty_binding, def')) lthy;
+ val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy);
+ val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
+ in (thm', lthy') end;
+ fun tac thms = Class.intro_classes_tac []
+ THEN ALLGOALS (ProofContext.fact_tac thms);
+ fun add_eq_thms dtco =
+ Theory.checkpoint
+ #> `(fn thy => mk_eq_eqns thy dtco)
+ #-> (fn (thms, thm) =>
+ Code.add_nbe_eqn thm
+ #> fold_rev Code.add_eqn thms);
+ in
+ thy
+ |> Theory_Target.instantiation (dtcos, vs, [HOLogic.class_eq])
+ |> fold_map add_def dtcos
+ |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm)
+ (fn _ => fn def_thms => tac def_thms) def_thms)
+ |-> (fn def_thms => fold Code.del_eqn def_thms)
+ |> fold add_eq_thms dtcos
+ end;
+
+
+(* register a datatype etc. *)
+
+fun add_all_code config dtcos thy =
+ let
+ val (vs :: _, coss) = (split_list o map (Datatype_Data.the_spec thy)) dtcos;
+ val any_css = map2 (mk_constr_consts thy vs) dtcos coss;
+ val css = if exists is_none any_css then []
+ else map_filter I any_css;
+ val case_rewrites = maps (#case_rewrites o Datatype_Data.the_info thy) dtcos;
+ val certs = map (mk_case_cert thy) dtcos;
+ in
+ if null css then thy
+ else thy
+ |> tap (fn _ => Datatype_Aux.message config "Registering datatype for code generator ...")
+ |> fold Code.add_datatype css
+ |> fold_rev Code.add_default_eqn case_rewrites
+ |> fold Code.add_case certs
+ |> add_equality vs dtcos
+ end;
+
+
(** SML code generator **)
open Codegen;
@@ -288,142 +419,11 @@
| datatype_tycodegen _ _ _ _ _ _ _ = NONE;
-(** generic code generator **)
-
-(* liberal addition of code data for datatypes *)
-
-fun mk_constr_consts thy vs dtco cos =
- let
- val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
- val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs;
- in if is_some (try (Code.constrset_of_consts thy) cs')
- then SOME cs
- else NONE
- end;
-
-
-(* case certificates *)
-
-fun mk_case_cert thy tyco =
- let
- val raw_thms =
- (#case_rewrites o Datatype_Data.the_info thy) tyco;
- val thms as hd_thm :: _ = raw_thms
- |> Conjunction.intr_balanced
- |> Thm.unvarify_global
- |> Conjunction.elim_balanced (length raw_thms)
- |> map Simpdata.mk_meta_eq
- |> map Drule.zero_var_indexes
- val params = fold_aterms (fn (Free (v, _)) => insert (op =) v
- | _ => I) (Thm.prop_of hd_thm) [];
- val rhs = hd_thm
- |> Thm.prop_of
- |> Logic.dest_equals
- |> fst
- |> Term.strip_comb
- |> apsnd (fst o split_last)
- |> list_comb;
- val lhs = Free (Name.variant params "case", Term.fastype_of rhs);
- val asm = (Thm.cterm_of thy o Logic.mk_equals) (lhs, rhs);
- in
- thms
- |> Conjunction.intr_balanced
- |> MetaSimplifier.rewrite_rule [(Thm.symmetric o Thm.assume) asm]
- |> Thm.implies_intr asm
- |> Thm.generalize ([], params) 0
- |> AxClass.unoverload thy
- |> Thm.varifyT_global
- end;
-
-
-(* equality *)
-
-fun mk_eq_eqns thy dtco =
- let
- val (vs, cos) = Datatype_Data.the_spec thy dtco;
- val { descr, index, inject = inject_thms, distinct = distinct_thms, ... } =
- Datatype_Data.the_info thy dtco;
- val ty = Type (dtco, map TFree vs);
- fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT)
- $ t1 $ t2;
- fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const);
- fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const);
- val triv_injects = map_filter
- (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty))))
- | _ => NONE) cos;
- fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) =
- trueprop $ (equiv $ mk_eq (t1, t2) $ rhs);
- val injects = map prep_inject (nth (Datatype_Prop.make_injs [descr] vs) index);
- fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) =
- [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)];
- val distincts = maps prep_distinct (snd (nth (Datatype_Prop.make_distincts [descr] vs) index));
- val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty)));
- val simpset = Simplifier.global_context thy (HOL_basic_ss addsimps
- (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms @ distinct_thms)));
- fun prove prop = Skip_Proof.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
- |> Simpdata.mk_eq;
- in (map prove (triv_injects @ injects @ distincts), prove refl) end;
-
-fun add_equality vs dtcos thy =
- let
- fun add_def dtco lthy =
- let
- val ty = Type (dtco, map TFree vs);
- fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
- $ Free ("x", ty) $ Free ("y", ty);
- val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
- (mk_side @{const_name eq_class.eq}, mk_side @{const_name "op ="}));
- val def' = Syntax.check_term lthy def;
- val ((_, (_, thm)), lthy') = Specification.definition
- (NONE, (Attrib.empty_binding, def')) lthy;
- val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy);
- val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
- in (thm', lthy') end;
- fun tac thms = Class.intro_classes_tac []
- THEN ALLGOALS (ProofContext.fact_tac thms);
- fun add_eq_thms dtco =
- Theory.checkpoint
- #> `(fn thy => mk_eq_eqns thy dtco)
- #-> (fn (thms, thm) =>
- Code.add_nbe_eqn thm
- #> fold_rev Code.add_eqn thms);
- in
- thy
- |> Theory_Target.instantiation (dtcos, vs, [HOLogic.class_eq])
- |> fold_map add_def dtcos
- |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm)
- (fn _ => fn def_thms => tac def_thms) def_thms)
- |-> (fn def_thms => fold Code.del_eqn def_thms)
- |> fold add_eq_thms dtcos
- end;
-
-
-(* register a datatype etc. *)
-
-fun add_all_code config dtcos thy =
- let
- val (vs :: _, coss) = (split_list o map (Datatype_Data.the_spec thy)) dtcos;
- val any_css = map2 (mk_constr_consts thy vs) dtcos coss;
- val css = if exists is_none any_css then []
- else map_filter I any_css;
- val case_rewrites = maps (#case_rewrites o Datatype_Data.the_info thy) dtcos;
- val certs = map (mk_case_cert thy) dtcos;
- in
- if null css then thy
- else thy
- |> tap (fn _ => Datatype_Aux.message config "Registering datatype for code generator ...")
- |> fold Code.add_datatype css
- |> fold_rev Code.add_default_eqn case_rewrites
- |> fold Code.add_case certs
- |> add_equality vs dtcos
- end;
-
-
(** theory setup **)
val setup =
- add_codegen "datatype" datatype_codegen
- #> add_tycodegen "datatype" datatype_tycodegen
- #> Datatype_Data.interpretation add_all_code
+ Datatype_Data.interpretation add_all_code
+ #> add_codegen "datatype" datatype_codegen
+ #> add_tycodegen "datatype" datatype_tycodegen;
end;