# HG changeset patch # User haftmann # Date 1242232896 -7200 # Node ID 1a5591ecb764b1d2a75f5e27c5925f514d9ecf88 # Parent a9f728dc5c8ec714a4c9aa0144aa57063daf9b88 tuned and generalized construction of code equations for eq diff -r a9f728dc5c8e -r 1a5591ecb764 src/HOL/Tools/datatype_codegen.ML --- a/src/HOL/Tools/datatype_codegen.ML Wed May 13 18:41:36 2009 +0200 +++ b/src/HOL/Tools/datatype_codegen.ML Wed May 13 18:41:36 2009 +0200 @@ -6,7 +6,7 @@ signature DATATYPE_CODEGEN = sig - val mk_eq: theory -> string -> thm list + val mk_eq_eqns: theory -> string -> (thm * bool) list val mk_case_cert: theory -> string -> thm val setup: theory -> theory end; @@ -309,18 +309,6 @@ (** generic code generator **) -(* specification *) - -fun add_datatype_spec vs dtco cos thy = - let - val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos; - in - thy - |> try (Code.add_datatype cs) - |> the_default thy - end; - - (* case certificates *) fun mk_case_cert thy tyco = @@ -354,88 +342,41 @@ |> Thm.varifyT end; -fun add_datatype_cases dtco thy = - let - val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco; - val cert = mk_case_cert thy dtco; - fun add_case_liberal thy = thy - |> try (Code.add_case cert) - |> the_default thy; - in - thy - |> add_case_liberal - |> fold_rev Code.add_default_eqn case_rewrites - end; - (* equality *) -local - -val not_sym = @{thm HOL.not_sym}; -val not_false_true = iffD2 OF [nth @{thms HOL.simp_thms} 7, TrueI]; -val refl = @{thm refl}; -val eqTrueI = @{thm eqTrueI}; - -fun mk_distinct cos = - let - fun sym_product [] = [] - | sym_product (x::xs) = map (pair x) xs @ sym_product xs; - fun mk_co_args (co, tys) ctxt = - let - val names = Name.invents ctxt "a" (length tys); - val ctxt' = fold Name.declare names ctxt; - val vs = map2 (curry Free) names tys; - in (vs, ctxt') end; - fun mk_dist ((co1, tys1), (co2, tys2)) = - let - val ((xs1, xs2), _) = Name.context - |> mk_co_args (co1, tys1) - ||>> mk_co_args (co2, tys2); - val prem = HOLogic.mk_eq - (list_comb (co1, xs1), list_comb (co2, xs2)); - val t = HOLogic.mk_not prem; - in HOLogic.mk_Trueprop t end; - in map mk_dist (sym_product cos) end; - -in - -fun mk_eq thy dtco = +fun mk_eq_eqns thy dtco = let - val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco; - fun mk_triv_inject co = - let - val ct' = Thm.cterm_of thy - (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs))) - val cty' = Thm.ctyp_of_term ct'; - val SOME (ct, cty) = fold_aterms (fn Var (v, ty) => - (K o SOME) (Thm.cterm_of thy (Var (v, Thm.typ_of cty')), Thm.ctyp_of thy ty) | _ => I) - (Thm.prop_of refl) NONE; - in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) refl] end; - val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs - val inject2 = (#inject o DatatypePackage.the_datatype thy) dtco; - val ctxt = ProofContext.init thy; - val simpset = Simplifier.context ctxt - (Simplifier.empty_ss addsimprocs [DatatypePackage.distinct_simproc]); - val cos = map (fn (co, tys) => - (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs; - val tac = ALLGOALS (simp_tac simpset) - THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]); - val distinct = - mk_distinct cos - |> map (fn t => Goal.prove_global thy [] [] t (K tac)) - |> (fn thms => thms @ map (fn thm => not_sym OF [thm]) thms) - in inject1 @ inject2 @ distinct end; + val (vs, cos) = DatatypePackage.the_datatype_spec thy dtco; + val { descr, index, inject = inject_thms, ... } = DatatypePackage.the_datatype thy dtco; + val ty = Type (dtco, map TFree vs); + fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT) + $ t1 $ t2; + fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const); + fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const); + val triv_injects = map_filter + (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty)))) + | _ => NONE) cos; + fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) = + trueprop $ (equiv $ mk_eq (t1, t2) $ rhs); + val injects = map prep_inject (nth (DatatypeProp.make_injs [descr] vs) index); + fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) = + [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)]; + val distincts = maps prep_distinct (snd (nth (DatatypeProp.make_distincts [descr] vs) index)); + val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty))); + val simpset = Simplifier.context (ProofContext.init thy) (HOL_basic_ss + addsimps (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms)) + addsimprocs [DatatypePackage.distinct_simproc]); + fun prove prop = Goal.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset))) + |> Simpdata.mk_eq + |> Simplifier.rewrite_rule [@{thm equals_eq}]; + in map (rpair true o prove) (triv_injects @ injects @ distincts) @ [(prove refl, false)] end; -end; - -fun add_datatypes_equality vs dtcos thy = +fun add_equality vs dtcos thy = let - val vs' = (map o apsnd) - (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq]) vs; fun add_def dtco lthy = let - val ty = Type (dtco, map TFree vs'); + val ty = Type (dtco, map TFree vs); fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT) $ Free ("x", ty) $ Free ("y", ty); val def = HOLogic.mk_Trueprop (HOLogic.mk_eq @@ -448,52 +389,60 @@ in (thm', lthy') end; fun tac thms = Class.intro_classes_tac [] THEN ALLGOALS (ProofContext.fact_tac thms); - fun mk_eq' thy dtco = mk_eq thy dtco - |> map (Code_Unit.constrain_thm thy [HOLogic.class_eq]) - |> map Simpdata.mk_eq - |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}]) - |> map (AxClass.unoverload thy); fun add_eq_thms dtco thy = let - val ty = Type (dtco, map TFree vs'); + val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco); val thy_ref = Theory.check_thy thy; - val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco); - val eq_refl = @{thm HOL.eq_refl} - |> Thm.instantiate - ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], []) - |> Simpdata.mk_eq - |> AxClass.unoverload thy; - fun mk_thms () = (eq_refl, false) - :: rev (map (rpair true) (mk_eq' (Theory.deref thy_ref) dtco)); + fun mk_thms () = rev ((mk_eq_eqns (Theory.deref thy_ref) dtco)); in Code.add_eqnl (const, Lazy.lazy mk_thms) thy end; in thy - |> TheoryTarget.instantiation (dtcos, vs', [HOLogic.class_eq]) + |> TheoryTarget.instantiation (dtcos, vs, [HOLogic.class_eq]) |> fold_map add_def dtcos - |-> (fn thms => Class.prove_instantiation_instance (K (tac thms)) - #> LocalTheory.exit_global - #> fold Code.del_eqn thms - #> fold add_eq_thms dtcos) + |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm) + (fn _ => fn def_thms => tac def_thms) def_thms) + |-> (fn def_thms => fold Code.del_eqn def_thms) + |> fold add_eq_thms dtcos + end; + + +(* liberal addition of code data for datatypes *) + +fun mk_constr_consts thy vs dtco cos = + let + val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos; + val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs; + in if is_some (try (Code_Unit.constrset_of_consts thy) cs') + then SOME cs + else NONE end; +fun add_all_code dtcos thy = + let + val (vs :: _, coss) = (split_list o map (DatatypePackage.the_datatype_spec thy)) dtcos; + val any_css = map2 (mk_constr_consts thy vs) dtcos coss; + val css = if exists is_none any_css then [] + else map_filter I any_css; + val case_rewrites = maps (#case_rewrites o DatatypePackage.the_datatype thy) dtcos; + val certs = map (mk_case_cert thy) dtcos; + in + if null css then thy + else thy + |> fold Code.add_datatype css + |> fold_rev Code.add_default_eqn case_rewrites + |> fold Code.add_case certs + |> add_equality vs dtcos + end; + + (** theory setup **) -fun add_datatype_code dtcos thy = - let - val (vs :: _, coss) = (split_list o map (DatatypePackage.the_datatype_spec thy)) dtcos; - in - thy - |> fold2 (add_datatype_spec vs) dtcos coss - |> fold add_datatype_cases dtcos - |> add_datatypes_equality vs dtcos - end; - val setup = add_codegen "datatype" datatype_codegen #> add_tycodegen "datatype" datatype_tycodegen - #> DatatypePackage.interpretation add_datatype_code + #> DatatypePackage.interpretation add_all_code end;