--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Tue Sep 11 11:29:28 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Tue Sep 11 11:53:34 2012 +0200
@@ -24,11 +24,16 @@
val caseN = "case";
val coitersN = "coiters";
val corecsN = "corecs";
+val disc_coitersN = "disc_coiters";
+val disc_corecsN = "disc_corecs";
val itersN = "iters";
val recsN = "recs";
+val sel_coitersN = "sel_coiters";
+val sel_corecsN = "sel_corecs";
-fun split_list8 xs =
- (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs, map #8 xs);
+fun split_list11 xs =
+ (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs, map #8 xs,
+ map #9 xs, map #10 xs, map #11 xs);
fun strip_map_type (Type (@{type_name fun}, [T as Type _, T'])) = strip_map_type T' |>> cons T
| strip_map_type T = ([], T);
@@ -53,9 +58,10 @@
fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v));
fun tack z_name (c, v) f =
- let val z = Free (z_name, mk_sumT (fastype_of v, fastype_of c)) in
- Term.lambda z (mk_sum_case (Term.lambda v v) (Term.lambda c (f $ c)) $ z)
- end;
+ let
+ val T = fastype_of v;
+ val z = Free (z_name, mk_sumT (T, fastype_of c))
+ in Term.lambda z (mk_sum_case (mk_id T, Term.lambda c (f $ c)) $ z) end;
fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
@@ -148,9 +154,9 @@
| freeze_fp T = T;
val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss;
- val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs;
+ val ctr_sum_prod_TsXs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssXs;
- val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
+ val eqs = map dest_TFree Xs ~~ ctr_sum_prod_TsXs;
val (pre_bnfs, ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects,
fp_iter_thms, fp_rec_thms), lthy)) =
@@ -209,13 +215,12 @@
if member (op =) Cs U then Us else [T]
| dest_rec_pair T = [T];
- val ((iter_only as (gss, g_Tss, yssss), rec_only as (hss, h_Tss, zssss)),
- (zs, cs, cpss, p_Tss, coiter_only as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss),
- corec_only as ((phss, chsss), h_sum_prod_Ts, h_prod_Tss, h_Tsss))) =
+ val ((iter_only as (gss, _, _), rec_only as (hss, _, _)),
+ (zs, cs, cpss, coiter_only as ((pgss, cgsss), _), corec_only as ((phss, chsss), _))) =
if lfp then
let
val y_Tsss =
- map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN n o domain_type)
+ map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
ns mss fp_iter_fun_Ts;
val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
@@ -225,8 +230,8 @@
||>> mk_Freesss "x" y_Tsss;
val z_Tssss =
- map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n
- o domain_type) ns mss fp_rec_fun_Ts;
+ map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o
+ dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
val hss = map2 (map2 retype_free) gss h_Tss;
@@ -235,7 +240,7 @@
|> mk_Freessss "x" z_Tssss;
in
(((gss, g_Tss, map (map (map single)) ysss), (hss, h_Tss, zssss)),
- ([], [], [], [], (([], []), [], [], []), (([], []), [], [], [])))
+ ([], [], [], (([], []), ([], [])), (([], []), ([], []))))
end
else
let
@@ -245,20 +250,20 @@
val p_Tss =
map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns;
- fun popescu_zip [] [fs] = fs
- | popescu_zip (p :: ps) (fs :: fss) = p :: fs @ popescu_zip ps fss;
+ fun zip_preds_getters [] [fs] = fs
+ | zip_preds_getters (p :: ps) (fs :: fss) = p :: fs @ zip_preds_getters ps fss;
fun mk_types fun_Ts =
let
val f_sum_prod_Ts = map range_type fun_Ts;
- val f_prod_Tss = map2 dest_sumTN ns f_sum_prod_Ts;
+ val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
val f_Tsss =
map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss' f_prod_Tss;
- val pf_Tss = map2 popescu_zip p_Tss f_Tsss
- in (f_sum_prod_Ts, f_prod_Tss, f_Tsss, pf_Tss) end;
+ val pf_Tss = map2 zip_preds_getters p_Tss f_Tsss
+ in (f_sum_prod_Ts, f_Tsss, pf_Tss) end;
- val (g_sum_prod_Ts, g_prod_Tss, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
- val (h_sum_prod_Ts, h_prod_Tss, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
+ val (g_sum_prod_Ts, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
+ val (h_sum_prod_Ts, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
val ((((Free (z, _), cs), pss), gsss), _) =
lthy
@@ -273,13 +278,13 @@
fun mk_terms fsss =
let
- val pfss = map2 popescu_zip pss fsss;
+ val pfss = map2 zip_preds_getters pss fsss;
val cfsss = map2 (fn c => map (map (fn f => f $ c))) cs fsss
in (pfss, cfsss) end;
in
((([], [], []), ([], [], [])),
- ([z], cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss),
- (mk_terms hsss, h_sum_prod_Ts, h_prod_Tss, ph_Tss)))
+ ([z], cs, cpss, (mk_terms gsss, (g_sum_prod_Ts, pg_Tss)),
+ (mk_terms hsss, (h_sum_prod_Ts, ph_Tss))))
end;
fun pour_some_sugar_on_type (((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec),
@@ -288,6 +293,7 @@
let
val unfT = domain_type (fastype_of fld);
val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
+ val ctr_sum_prod_T = mk_sumTN_balanced ctr_prod_Ts;
val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
val ((((u, v), fs), xss), _) =
@@ -298,13 +304,14 @@
||>> mk_Freess "x" ctr_Tss;
val ctr_rhss =
- map2 (fn k => fn xs =>
- fold_rev Term.lambda xs (fld $ mk_InN ctr_prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
+ map2 (fn k => fn xs => fold_rev Term.lambda xs (fld $
+ mk_InN_balanced ctr_sum_prod_T n (HOLogic.mk_tuple xs) k)) ks xss;
val case_binder = Binding.suffix_name ("_" ^ caseN) b;
val case_rhs =
- fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (map2 mk_uncurried_fun fs xss) $ (unf $ v));
+ fold_rev Term.lambda (fs @ [v])
+ (mk_sum_caseN_balanced (map2 mk_uncurried_fun fs xss) $ (unf $ v));
val ((raw_case :: raw_ctrs, raw_case_def :: raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
|> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
@@ -340,7 +347,8 @@
val sumEN_thm' =
Local_Defs.unfold lthy @{thms all_unit_eq}
- (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) [] (mk_sumEN n))
+ (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) []
+ (mk_sumEN_balanced n))
|> Morphism.thm phi;
in
mk_exhaust_tac ctxt n ctr_defs fld_iff_unf_thm sumEN_thm'
@@ -360,7 +368,7 @@
val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
- fun some_lfp_sugar no_defs_lthy =
+ fun some_lfp_sugar ((selss0, discIs, sel_thmss), no_defs_lthy) =
let
val fpT_to_C = fpT --> C;
@@ -373,7 +381,7 @@
val spec =
mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binder, res_T)),
Term.list_comb (fp_iter_like,
- map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) fss xssss));
+ map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss));
in (binder, spec) end;
val iter_likes =
@@ -395,28 +403,28 @@
val [iter, recx] = map (mk_iter_like As Cs o Morphism.term phi) csts;
in
- ((ctrs, iter, recx, v, xss, ctr_defs, iter_def, rec_def), lthy)
+ ((ctrs, selss0, iter, recx, v, xss, ctr_defs, discIs, sel_thmss, iter_def, rec_def),
+ lthy)
end;
- fun some_gfp_sugar no_defs_lthy =
+ fun some_gfp_sugar ((selss0, discIs, sel_thmss), no_defs_lthy) =
let
val B_to_fpT = C --> fpT;
- fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), f_sum_prod_Ts, f_prod_Tss,
- pf_Tss)) =
+ fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), (f_sum_prod_Ts, pf_Tss))) =
let
val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
val binder = Binding.suffix_name ("_" ^ suf) b;
- fun mk_popescu_join c n cps sum_prod_T prod_Ts cfss =
+ fun mk_preds_getters_join c n cps sum_prod_T cfss =
Term.lambda c (mk_IfN sum_prod_T cps
- (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cfss) (1 upto n)));
+ (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cfss) (1 upto n)));
val spec =
mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
Term.list_comb (fp_iter_like,
- map6 mk_popescu_join cs ns cpss f_sum_prod_Ts f_prod_Tss cfsss));
+ map5 mk_preds_getters_join cs ns cpss f_sum_prod_Ts cfsss));
in (binder, spec) end;
val coiter_likes =
@@ -438,7 +446,8 @@
val [coiter, corec] = map (mk_iter_like As Cs o Morphism.term phi) csts;
in
- ((ctrs, coiter, corec, v, xss, ctr_defs, coiter_def, corec_def), lthy)
+ ((ctrs, selss0, coiter, corec, v, xss, ctr_defs, discIs, sel_thmss, coiter_def,
+ corec_def), lthy)
end;
in
wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
@@ -461,8 +470,8 @@
val args = map build_arg TUs;
in Term.list_comb (mapx, args) end;
- fun pour_more_sugar_on_lfps ((ctrss, iters, recs, vs, xsss, ctr_defss, iter_defs, rec_defs),
- lthy) =
+ fun pour_more_sugar_on_lfps ((ctrss, _, iters, recs, vs, xsss, ctr_defss, _, _, iter_defs,
+ rec_defs), lthy) =
let
val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
val giters = map (lists_bmoc gss) iters;
@@ -505,9 +514,11 @@
val rec_tacss =
map2 (map o mk_iter_like_tac pre_map_defs map_ids rec_defs) fp_rec_thms ctr_defss;
in
- (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
+ (map2 (map2 (fn goal => fn tac =>
+ Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
goal_iterss iter_tacss,
- map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
+ map2 (map2 (fn goal => fn tac =>
+ Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
goal_recss rec_tacss)
end;
@@ -522,8 +533,8 @@
lthy |> Local_Theory.notes notes |> snd
end;
- fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, xsss, ctr_defss, coiter_defs,
- corec_defs), lthy) =
+ fun pour_more_sugar_on_gfps ((ctrss, selsss, coiters, corecs, vs, _, ctr_defss, discIss,
+ sel_thmsss, coiter_defs, corec_defs), lthy) =
let
val z = the_single zs;
@@ -578,13 +589,40 @@
goal_corecss corec_tacss)
end;
+ fun mk_disc_coiter_like_thms [_] = K []
+ | mk_disc_coiter_like_thms thms = map2 (curry (op RS)) thms;
+
+ val disc_coiter_thmss = map2 mk_disc_coiter_like_thms coiter_thmss discIss;
+ val disc_corec_thmss = map2 mk_disc_coiter_like_thms corec_thmss discIss;
+
+ fun mk_sel_coiter_like_thm coiter_like_thm sel0 sel_thm =
+ let
+ val (domT, ranT) = dest_funT (fastype_of sel0);
+ val arg_cong' =
+ Drule.instantiate' (map (SOME o certifyT lthy) [domT, ranT])
+ [NONE, NONE, SOME (certify lthy sel0)] arg_cong
+ |> Thm.varifyT_global;
+ val sel_thm' = sel_thm RSN (2, trans);
+ in
+ coiter_like_thm RS arg_cong' RS sel_thm'
+ end;
+
+ val sel_coiter_thmsss =
+ map3 (map3 (map2 o mk_sel_coiter_like_thm)) coiter_thmss selsss sel_thmsss;
+ val sel_corec_thmsss =
+ map3 (map3 (map2 o mk_sel_coiter_like_thm)) corec_thmss selsss sel_thmsss;
+
val notes =
[(coitersN, coiter_thmss),
- (corecsN, corec_thmss)]
+ (disc_coitersN, disc_coiter_thmss),
+ (sel_coitersN, map flat sel_coiter_thmsss),
+ (corecsN, corec_thmss),
+ (disc_corecsN, disc_corec_thmss),
+ (sel_corecsN, map flat sel_corec_thmsss)]
|> maps (fn (thmN, thmss) =>
- map2 (fn b => fn thms =>
- ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
- bs thmss);
+ map_filter (fn (_, []) => NONE | (b, thms) =>
+ SOME ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []),
+ [(thms, [])])) (bs ~~ thmss));
in
lthy |> Local_Theory.notes notes |> snd
end;
@@ -593,7 +631,7 @@
|> fold_map pour_some_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_binderss ~~
ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss)
- |>> split_list8
+ |>> split_list11
|> (if lfp then pour_more_sugar_on_lfps else pour_more_sugar_on_gfps);
val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML Tue Sep 11 11:29:28 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML Tue Sep 11 11:53:34 2012 +0200
@@ -11,7 +11,8 @@
val mk_half_pairss: 'a list -> ('a * 'a) list list
val mk_ctr: typ list -> term -> term
val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list ->
- (term list * term) * (binding list * binding list list) -> local_theory -> local_theory
+ (term list * term) * (binding list * binding list list) -> local_theory ->
+ (term list list * thm list * thm list list) * local_theory
end;
structure BNF_Wrap : BNF_WRAP =
@@ -46,6 +47,8 @@
fun pad_list x n xs = xs @ replicate (n - length xs) x;
+fun unflat_lookup eq ys zs = map (map (fn x => nth zs (find_index (curry eq x) ys)));
+
fun mk_half_pairss' _ [] = []
| mk_half_pairss' indent (y :: ys) =
indent @ fold_rev (cons o single o pair y) ys (mk_half_pairss' ([] :: indent) ys);
@@ -84,15 +87,15 @@
val _ = if n > 0 then () else error "No constructors specified";
- val Type (T_name, As0) = body_type (fastype_of (hd ctrs0));
- val b = Binding.qualified_name T_name;
+ val Type (fpT_name, As0) = body_type (fastype_of (hd ctrs0));
+ val b = Binding.qualified_name fpT_name;
val (As, B) =
no_defs_lthy
|> mk_TFrees (length As0)
||> the_single o fst o mk_TFrees 1;
- val T = Type (T_name, As);
+ val fpT = Type (fpT_name, As);
val ctrs = map (mk_ctr As) ctrs0;
val ctr_Tss = map (binder_types o fastype_of) ctrs;
@@ -146,8 +149,8 @@
||>> mk_Freess "y" ctr_Tss
||>> mk_Frees "f" case_Ts
||>> mk_Frees "g" case_Ts
- ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "v") T
- ||>> yield_singleton (mk_Frees "w") T
+ ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "v") fpT
+ ||>> yield_singleton (mk_Frees "w") fpT
||>> yield_singleton (apfst (op ~~) oo mk_Frees' "P") HOLogic.boolT;
val q = Free (fst p', B --> HOLogic.boolT);
@@ -170,10 +173,7 @@
val exist_xs_v_eq_ctrs =
map2 (fn xctr => fn xs => list_exists_free xs (HOLogic.mk_eq (v, xctr))) xctrs xss;
- fun mk_sel_case_args k xs x T =
- map2 (fn Ts => fn i => if i = k then fold_rev Term.lambda xs x else mk_undef T Ts) ctr_Tss ks;
-
- fun disc_free b = Free (Binding.name_of b, T --> HOLogic.boolT);
+ fun disc_free b = Free (Binding.name_of b, fpT --> HOLogic.boolT);
fun disc_spec b exist_xs_v_eq_ctr = mk_Trueprop_eq (disc_free b $ v, exist_xs_v_eq_ctr);
@@ -186,19 +186,40 @@
fun alternate_disc k =
if n = 2 then Term.lambda v (alternate_disc_lhs (3 - k)) else error "Cannot use \"*\" here"
- fun sel_spec b x xs k =
- let val T' = fastype_of x in
- mk_Trueprop_eq (Free (Binding.name_of b, T --> T') $ v,
- Term.list_comb (mk_case As T', mk_sel_case_args k xs x T') $ v)
+ fun mk_sel_case_args proto_sels T =
+ map2 (fn Ts => fn i =>
+ case AList.lookup (op =) proto_sels i of
+ NONE => mk_undef T Ts
+ | SOME (xs, x) => fold_rev Term.lambda xs x) ctr_Tss ks;
+
+ fun sel_spec b proto_sels =
+ let
+ val _ =
+ (case duplicates (op =) (map fst proto_sels) of
+ k :: _ => error ("Duplicate selector name " ^ quote (Binding.name_of b) ^
+ " for constructor " ^ quote (Syntax.string_of_term no_defs_lthy (nth ctrs (k - 1))))
+ | [] => ())
+ val T =
+ (case distinct (op =) (map (fastype_of o snd o snd) proto_sels) of
+ [T] => T
+ | T :: T' :: _ => error ("Inconsistent range type for selector " ^
+ quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ no_defs_lthy T) ^
+ " vs. " ^ quote (Syntax.string_of_typ no_defs_lthy T')));
+ in
+ mk_Trueprop_eq (Free (Binding.name_of b, fpT --> T) $ v,
+ Term.list_comb (mk_case As T, mk_sel_case_args proto_sels T) $ v)
end;
val missing_unique_disc_def = TrueI; (*arbitrary marker*)
val missing_alternate_disc_def = FalseE; (*arbitrary marker*)
- (* TODO: Allow use of same selector for several constructors *)
- (* TODO: Allow use of same name for datatype and for constructor, e.g. "data L = L" *)
+ val proto_selss = map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss;
+ val sel_bundles = AList.group Binding.eq_name (flat sel_binderss ~~ flat proto_selss);
+ val sel_binders = map fst sel_bundles;
- val (((raw_discs, raw_disc_defs), (raw_selss, raw_sel_defss)), (lthy', lthy)) =
+ fun unflat_selss xs = unflat_lookup Binding.eq_name sel_binders xs sel_binderss;
+
+ val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) =
no_defs_lthy
|> apfst split_list o fold_map4 (fn k => fn m => fn exist_xs_v_eq_ctr =>
fn NONE =>
@@ -208,19 +229,19 @@
| SOME b => Specification.definition (SOME (b, NONE, NoSyn),
((Thm.def_binding b, []), disc_spec b exist_xs_v_eq_ctr)) #>> apsnd snd)
ks ms exist_xs_v_eq_ctrs disc_binders
- ||>> apfst split_list o fold_map3 (fn bs => fn xs => fn k => apfst split_list o
- fold_map2 (fn b => fn x => Specification.definition (SOME (b, NONE, NoSyn),
- ((Thm.def_binding b, []), sel_spec b x xs k)) #>> apsnd snd) bs xs) sel_binderss xss ks
+ ||>> apfst split_list o fold_map (fn (b, proto_sels) =>
+ Specification.definition (SOME (b, NONE, NoSyn),
+ ((Thm.def_binding b, []), sel_spec b proto_sels)) #>> apsnd snd) sel_bundles
||> `Local_Theory.restore;
(*transforms defined frees into consts (and more)*)
val phi = Proof_Context.export_morphism lthy lthy';
val disc_defs = map (Morphism.thm phi) raw_disc_defs;
- val sel_defss = map (map (Morphism.thm phi)) raw_sel_defss;
+ val sel_defss = unflat_selss (map (Morphism.thm phi) raw_sel_defs);
val discs0 = map (Morphism.term phi) raw_discs;
- val selss0 = map (map (Morphism.term phi)) raw_selss;
+ val selss0 = unflat_selss (map (Morphism.term phi) raw_sels);
fun mk_disc_or_sel Ts c =
Term.subst_atomic_types (snd (Term.dest_Type (domain_type (fastype_of c))) ~~ Ts) c;
@@ -289,23 +310,8 @@
end;
val sel_thmss =
- let
- fun mk_thm k xs goal_case case_thm x sel_def =
- let
- val T = fastype_of x;
- val cTs =
- map ((fn T' => certifyT lthy (if T' = B then T else T')) o TFree)
- (rev (Term.add_tfrees goal_case []));
- val cxs = map (certify lthy) (mk_sel_case_args k xs x T);
- in
- Local_Defs.fold lthy [sel_def]
- (Drule.instantiate' (map SOME cTs) (map SOME cxs) case_thm)
- end;
- fun mk_thms k xs goal_case case_thm sel_defs =
- map2 (mk_thm k xs (strip_all_body goal_case) case_thm) xs sel_defs;
- in
- map5 mk_thms ks xss goal_cases case_thms sel_defss
- end;
+ map2 (fn case_thm => map (fn sel_def => case_thm RS (sel_def RS trans))) case_thms
+ sel_defss;
fun mk_unique_disc_def () =
let
@@ -496,7 +502,7 @@
|> map (fn (thmN, thms) =>
((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]));
in
- lthy |> Local_Theory.notes notes |> snd
+ ((selss, discI_thms, sel_thmss), lthy |> Local_Theory.notes notes |> snd)
end;
in
(goalss, after_qed, lthy')
@@ -511,7 +517,7 @@
val parse_bindingss = Parse.$$$ "[" |-- Parse.list parse_bindings --| Parse.$$$ "]";
val wrap_datatype_cmd = (fn (goalss, after_qed, lthy) =>
- Proof.theorem NONE after_qed (map (map (rpair [])) goalss) lthy) oo
+ Proof.theorem NONE (snd oo after_qed) (map (map (rpair [])) goalss) lthy) oo
prepare_wrap_datatype Syntax.read_term;
val _ =