--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Tue Sep 11 13:06:14 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Tue Sep 11 13:06:14 2012 +0200
@@ -213,8 +213,7 @@
val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
val ((iter_only as (gss, _, _), rec_only as (hss, _, _)),
- (zs, cs, cpss, coiter_only as ((pgss, _, cgssss), _),
- corec_only as ((phss, _, chssss), _))) =
+ (zs, cs, cpss, coiter_only as ((pgss, crgsss), _), corec_only as ((phss, cshsss), _))) =
if lfp then
let
val y_Tsss =
@@ -245,7 +244,7 @@
val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
in
(((gss, g_Tss, yssss), (hss, h_Tss, zssss)),
- ([], [], [], (([], [], []), ([], [])), (([], [], []), ([], []))))
+ ([], [], [], (([], []), ([], [])), (([], []), ([], []))))
end
else
let
@@ -254,11 +253,11 @@
val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_predT) ns Cs;
- fun zip_getterss qss fss = maps (op @) (qss ~~ fss);
+ fun zip_predss_getterss qss fss = maps (op @) (qss ~~ fss);
- fun zip_preds_gettersss [] [qss] [fss] = zip_getterss qss fss
- | zip_preds_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
- p :: zip_getterss qss fss @ zip_preds_gettersss ps qsss fsss;
+ fun zip_preds_predsss_gettersss [] [qss] [fss] = zip_predss_getterss qss fss
+ | zip_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
+ p :: zip_predss_getterss qss fss @ zip_preds_predsss_gettersss ps qsss fsss;
fun mk_types maybe_dest_sumT fun_Ts =
let
@@ -269,7 +268,7 @@
Cs mss' f_prod_Tss;
val q_Tssss =
map (map (map (fn [_] => [] | [_, C] => [mk_predT (domain_type C)]))) f_Tssss;
- val pf_Tss = map3 zip_preds_gettersss p_Tss q_Tssss f_Tssss;
+ val pf_Tss = map3 zip_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
in (q_Tssss, f_sum_prod_Ts, f_Tssss, pf_Tss) end;
val (r_Tssss, g_sum_prod_Ts, g_Tssss, pg_Tss) = mk_types single fp_iter_fun_Ts;
@@ -297,12 +296,17 @@
val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
+ fun mk_preds_getters_join [] [cf] = cf
+ | mk_preds_getters_join [cq] [cf, cf'] =
+ mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
+
fun mk_terms qssss fssss =
let
- val pfss = map3 zip_preds_gettersss pss qssss fssss;
+ val pfss = map3 zip_preds_predsss_gettersss pss qssss fssss;
val cqssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs qssss;
val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
- in (pfss, cqssss, cfssss) end;
+ val cqfsss = map2 (map2 (map2 mk_preds_getters_join)) cqssss cfssss;
+ in (pfss, cqfsss) end;
in
((([], [], []), ([], [], [])),
([z], cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, pg_Tss)),
@@ -433,16 +437,11 @@
let
val B_to_fpT = C --> fpT;
- fun mk_getters_join [] [cf] = cf
- | mk_getters_join [cq] [cf, cf'] =
- mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
+ fun mk_preds_getterss_join c n cps sum_prod_T cqfss =
+ Term.lambda c (mk_IfN sum_prod_T cps
+ (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cqfss) (1 upto n)));
- fun mk_preds_gettersss_join c n cps sum_prod_T cqsss cfsss =
- Term.lambda c (mk_IfN sum_prod_T cps
- (map2 (mk_InN_balanced sum_prod_T n)
- (map2 (HOLogic.mk_tuple oo map2 mk_getters_join) cqsss cfsss) (1 upto n)));
-
- fun generate_coiter_like (suf, fp_iter_like, ((pfss, cqssss, cfssss), (f_sum_prod_Ts,
+ fun generate_coiter_like (suf, fp_iter_like, ((pfss, cqfsss), (f_sum_prod_Ts,
pf_Tss))) =
let
val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
@@ -452,7 +451,7 @@
val spec =
mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
Term.list_comb (fp_iter_like,
- map6 mk_preds_gettersss_join cs ns cpss f_sum_prod_Ts cqssss cfssss));
+ map5 mk_preds_getterss_join cs ns cpss f_sum_prod_Ts cqfsss));
in (binder, spec) end;
val coiter_like_bundles =
@@ -542,11 +541,9 @@
val rec_tacss =
map2 (map o mk_iter_like_tac pre_map_defs map_ids rec_defs) fp_rec_thms ctr_defss;
in
- (map2 (map2 (fn goal => fn tac =>
- Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
+ (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
goal_iterss iter_tacss,
- map2 (map2 (fn goal => fn tac =>
- Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
+ map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
goal_recss rec_tacss)
end;
@@ -573,10 +570,10 @@
let
fun mk_goal_cond pos = HOLogic.mk_Trueprop o (not pos ? HOLogic.mk_not);
- fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfss' =
+ fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfs' =
fold_rev (fold_rev Logic.all) ([c] :: pfss)
(Logic.list_implies (seq_conds mk_goal_cond n k cps,
- mk_Trueprop_eq (fcoiter_like $ c, lists_bmoc (take m cfss') ctr)));
+ mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, take m cfs'))));
fun build_call fiter_likes maybe_tack (T, U) =
if T = U then
@@ -589,22 +586,21 @@
fun mk_U maybe_mk_sumT =
typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
- fun repair_calls fiter_likes maybe_mk_sumT maybe_tack
- (cf as Free (_, Type (_, [_, T])) $ _) =
- if exists_subtype (member (op =) Cs) T then
- build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cf
- else
- cf;
+ fun repair_calls fiter_likes maybe_mk_sumT maybe_tack cqf =
+ let val T = fastype_of cqf in
+ if exists_subtype (member (op =) Cs) T then
+ build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
+ else
+ cqf
+ end;
- val cgssss' =
- map (map (map (map (repair_calls gcoiters (K I) (K I))))) cgssss;
- val chssss' =
- map (map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z))))) chssss;
+ val crgsss' = map (map (map (repair_calls gcoiters (K I) (K I)))) crgsss;
+ val cshsss' = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) cshsss;
val goal_coiterss =
- map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss cgssss';
+ map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss crgsss';
val goal_corecss =
- map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss chssss';
+ map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss cshsss';
val coiter_tacss =
map3 (map oo mk_coiter_like_tac coiter_defs map_ids) fp_iter_thms pre_map_defs
@@ -613,9 +609,12 @@
map3 (map oo mk_coiter_like_tac corec_defs map_ids) fp_rec_thms pre_map_defs
ctr_defss;
in
- (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
+ (map2 (map2 (fn goal => fn tac =>
+ Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
goal_coiterss coiter_tacss,
- map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
+ map2 (map2 (fn goal => fn tac =>
+ Skip_Proof.prove lthy [] [] goal (tac o #context)
+ |> Local_Defs.unfold lthy @{thms sum_case_if} |> Thm.close_derivation))
goal_corecss corec_tacss)
end;