--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sat Sep 08 21:04:26 2012 +0200
@@ -128,7 +128,7 @@
val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
- val ((raw_unfs, raw_flds, raw_fp_iters, raw_fp_recs, unf_flds, fld_unfs, fld_injects), lthy') =
+ val ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects), lthy') =
fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs lthy;
val timer = time (Timer.startRealTimer ());
@@ -141,23 +141,29 @@
val mk_unf = mk_unf_or_fld domain_type;
val mk_fld = mk_unf_or_fld range_type;
- val unfs = map (mk_unf As) raw_unfs;
- val flds = map (mk_fld As) raw_flds;
+ val unfs = map (mk_unf As) unfs0;
+ val flds = map (mk_fld As) flds0;
val fpTs = map (domain_type o fastype_of) unfs;
val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
- fun mk_fp_iter_or_rec Ts Us c =
+ val ns = map length ctr_Tsss;
+ val mss = map (map length) ctr_Tsss;
+ val Css = map2 replicate ns Cs;
+ val Cs' = flat Css;
+
+ fun mk_iter_or_rec Ts Us c =
let
val (binders, body) = strip_type (fastype_of c);
- val Type (_, Ts0) = if gfp then body else List.last binders;
- val Us0 = map (if gfp then domain_type else body_type) (fst (split_last binders));
+ val (fst_binders, last_binder) = split_last binders;
+ val Type (_, Ts0) = if gfp then body else last_binder;
+ val Us0 = map (if gfp then domain_type else body_type) fst_binders;
in
Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) c
end;
- val fp_iters = map (mk_fp_iter_or_rec As Cs) raw_fp_iters;
- val fp_recs = map (mk_fp_iter_or_rec As Cs) raw_fp_recs;
+ val fp_iters = map (mk_iter_or_rec As Cs) fp_iters0;
+ val fp_recs = map (mk_iter_or_rec As Cs) fp_recs0;
fun pour_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss)
@@ -199,8 +205,10 @@
val ctr_defs = map (Morphism.thm phi) raw_ctr_defs;
val case_def = Morphism.thm phi raw_case_def;
- val ctrs = map (Morphism.term phi) raw_ctrs;
- val casex = Morphism.term phi raw_case;
+ val ctrs0 = map (Morphism.term phi) raw_ctrs;
+ val casex0 = Morphism.term phi raw_case;
+
+ val ctrs = map (mk_ctr As) ctrs0;
fun exhaust_tac {context = ctxt, ...} =
let
@@ -245,10 +253,6 @@
val is_fpT = member (op =) fpTs;
- val ns = map length ctr_Tsss;
- val mss = map (map length) ctr_Tsss;
- val Css = map2 replicate ns Cs;
-
fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
if member (op =) Cs U then Us else [T]
| dest_rec_pair T = [T];
@@ -303,15 +307,18 @@
val iter_def = Morphism.thm phi raw_iter_def;
val rec_def = Morphism.thm phi raw_rec_def;
- val iter = Morphism.term phi raw_iter;
- val recx = Morphism.term phi raw_rec;
+ val iter0 = Morphism.term phi raw_iter;
+ val rec0 = Morphism.term phi raw_rec;
+
+ val iter = mk_iter_or_rec As Cs' iter0;
+ val recx = mk_iter_or_rec As Cs' rec0;
in
([[ctrs], [[iter]], [[recx]], xss, gss, hss], lthy)
end;
fun sugar_codatatype no_defs_lthy = ([], no_defs_lthy);
in
- wrap_datatype tacss ((ctrs, casex), (disc_binders, sel_binderss)) lthy'
+ wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
|> (if gfp then sugar_codatatype else sugar_datatype)
end;
@@ -327,9 +334,12 @@
mk_Trueprop_eq (fc $ xctr, fc $ xctr);
val goal_iterss = map2 (fn giter => map (mk_goal_iter_or_rec giter)) giters xctrss;
- val goal_recss = [];
- val iter_tacss = []; (* ### map (map mk_iter_or_rec_tac); (* needs ctr_def, iter_def, fld_iter *) *)
- val rec_tacss = [];
+ val goal_recss = map2 (fn hrec => map (mk_goal_iter_or_rec hrec)) hrecs xctrss;
+ val iter_tacss =
+ map (map (K (fn _ => Skip_Proof.cheat_tac (Proof_Context.theory_of lthy)))) goal_iterss;
+ (* ### map (map mk_iter_or_rec_tac); (* needs ctr_def, iter_def, fld_iter *) *)
+ val rec_tacss =
+ map (map (K (fn _ => Skip_Proof.cheat_tac (Proof_Context.theory_of lthy)))) goal_recss;
in
(map2 (map2 (Skip_Proof.prove lthy [] [])) goal_iterss iter_tacss,
map2 (map2 (Skip_Proof.prove lthy [] [])) goal_recss rec_tacss)
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML Sat Sep 08 21:04:26 2012 +0200
@@ -9,6 +9,7 @@
sig
val no_binder: binding
val mk_half_pairss: 'a list -> ('a * 'a) list list
+ val mk_ctr: typ list -> term -> term
val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list ->
(term list * term) * (binding list * binding list list) -> local_theory -> local_theory
end;
@@ -54,10 +55,15 @@
(* TODO: provide a way to have a different default value, e.g. "tl Nil = Nil" *)
fun mk_undef T Ts = Const (@{const_name undefined}, Ts ---> T);
+fun mk_ctr Ts ctr =
+ let val Type (_, Ts0) = body_type (fastype_of ctr) in
+ Term.subst_atomic_types (Ts0 ~~ Ts) ctr
+ end;
+
fun eta_expand_case_arg xs f_xs = fold_rev Term.lambda xs f_xs;
-fun name_of_ctr t =
- case head_of t of
+fun name_of_ctr c =
+ case head_of c of
Const (s, _) => s
| Free (s, _) => s
| _ => error "Cannot extract name of constructor";
@@ -86,11 +92,6 @@
|> mk_TFrees (length As0)
||> the_single o fst o mk_TFrees 1;
- fun mk_ctr Ts ctr =
- let val Type (_, Ts0) = body_type (fastype_of ctr) in
- Term.subst_atomic_types (Ts0 ~~ Ts) ctr
- end;
-
val T = Type (T_name, As);
val ctrs = map (mk_ctr As) ctrs0;
val ctr_Tss = map (binder_types o fastype_of) ctrs;
@@ -220,8 +221,8 @@
val discs0 = map (Morphism.term phi) raw_discs;
val selss0 = map (map (Morphism.term phi)) raw_selss;
- fun mk_disc_or_sel Ts t =
- Term.subst_atomic_types (snd (Term.dest_Type (domain_type (fastype_of t))) ~~ Ts) t;
+ fun mk_disc_or_sel Ts c =
+ Term.subst_atomic_types (snd (Term.dest_Type (domain_type (fastype_of c))) ~~ Ts) c;
val discs = map (mk_disc_or_sel As) discs0;
val selss = map (map (mk_disc_or_sel As)) selss0;
@@ -245,9 +246,9 @@
val goal_half_distinctss =
let
- fun mk_goal ((xs, t), (xs', t')) =
+ fun mk_goal ((xs, xc), (xs', xc')) =
fold_rev Logic.all (xs @ xs')
- (HOLogic.mk_Trueprop (HOLogic.mk_not (HOLogic.mk_eq (t, t'))));
+ (HOLogic.mk_Trueprop (HOLogic.mk_not (HOLogic.mk_eq (xc, xc'))));
in
map (map mk_goal) (mk_half_pairss (xss ~~ xctrs))
end;