--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Tue Sep 11 13:06:13 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Tue Sep 11 13:06:13 2012 +0200
@@ -210,12 +210,8 @@
val fp_iter_fun_Ts = fst (split_last (binder_types (fastype_of fp_iter1)));
val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
- fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
- if member (op =) Cs U then Us else [T]
- | dest_rec_pair T = [T];
-
val ((iter_only as (gss, _, _), rec_only as (hss, _, _)),
- (zs, cs, cpss, coiter_only as ((pgss, cgsss), _), corec_only as ((phss, chsss), _))) =
+ (zs, cs, cpss, coiter_only as ((pgss, cgssss), _), corec_only as ((phss, chssss), _))) =
if lfp then
let
val y_Tsss =
@@ -227,18 +223,25 @@
lthy
|> mk_Freess "f" g_Tss
||>> mk_Freesss "x" y_Tsss;
+ val yssss = map (map (map single)) ysss;
+
+ fun dest_rec_prodT (T as Type (@{type_name prod}, Us as [_, U])) =
+ if member (op =) Cs U then Us else [T]
+ | dest_rec_prodT T = [T];
val z_Tssss =
- map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o
+ map3 (fn n => fn ms => map2 (map dest_rec_prodT oo dest_tupleT) ms o
dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
val hss = map2 (map2 retype_free) gss h_Tss;
- val (zssss, _) =
+ val zssss_hd = map2 (map2 (map2 (fn y => fn T :: _ => retype_free y T))) ysss z_Tssss;
+ val (zssss_tl, _) =
lthy
- |> mk_Freessss "x" z_Tssss;
+ |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
+ val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
in
- (((gss, g_Tss, map (map (map single)) ysss), (hss, h_Tss, zssss)),
+ (((gss, g_Tss, yssss), (hss, h_Tss, zssss)),
([], [], [], (([], []), ([], [])), (([], []), ([], []))))
end
else
@@ -249,20 +252,23 @@
val p_Tss =
map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns;
- fun zip_preds_getters [] [fs] = fs
- | zip_preds_getters (p :: ps) (fs :: fss) = p :: fs @ zip_preds_getters ps fss;
+ fun zip_getters fss = flat fss;
- fun mk_types fun_Ts =
+ fun zip_preds_getters [] [fss] = zip_getters fss
+ | zip_preds_getters (p :: ps) (fss :: fsss) =
+ p :: zip_getters fss @ zip_preds_getters ps fsss;
+
+ fun mk_types maybe_dest_sumT fun_Ts =
let
val f_sum_prod_Ts = map range_type fun_Ts;
val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
val f_Tsss =
map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss' f_prod_Tss;
- val pf_Tss = map2 zip_preds_getters p_Tss f_Tsss
+ val f_Tssss = map (map (map maybe_dest_sumT)) f_Tsss;
+ val pf_Tss = map2 zip_preds_getters p_Tss f_Tssss;
in (f_sum_prod_Ts, f_Tsss, pf_Tss) end;
- val (g_sum_prod_Ts, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
- val (h_sum_prod_Ts, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
+ val (g_sum_prod_Ts, g_Tsss, pg_Tss) = mk_types single fp_iter_fun_Ts;
val ((((Free (z, _), cs), pss), gsss), _) =
lthy
@@ -270,20 +276,28 @@
||>> mk_Frees "a" Cs
||>> mk_Freess "p" p_Tss
||>> mk_Freesss "g" g_Tsss;
+ val gssss = map (map (map single)) gsss;
+
+ fun dest_corec_sumT (T as Type (@{type_name sum}, Us as [_, U])) =
+ if member (op =) Cs U then Us else [T]
+ | dest_corec_sumT T = [T];
+
+ val (h_sum_prod_Ts, h_Tsss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss;
+ val hssss = map (map (map single)) hsss; (*###*)
val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
- fun mk_terms fsss =
+ fun mk_terms fssss =
let
- val pfss = map2 zip_preds_getters pss fsss;
- val cfsss = map2 (fn c => map (map (fn f => f $ c))) cs fsss
- in (pfss, cfsss) end;
+ val pfss = map2 zip_preds_getters pss fssss;
+ val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
+ in (pfss, cfssss) end;
in
((([], [], []), ([], [], [])),
- ([z], cs, cpss, (mk_terms gsss, (g_sum_prod_Ts, pg_Tss)),
- (mk_terms hsss, (h_sum_prod_Ts, ph_Tss))))
+ ([z], cs, cpss, (mk_terms gssss, (g_sum_prod_Ts, pg_Tss)),
+ (mk_terms hssss, (h_sum_prod_Ts, ph_Tss))))
end;
fun pour_some_sugar_on_type (((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec),
@@ -383,11 +397,11 @@
map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss));
in (binder, spec) end;
- val iter_likes =
+ val iter_like_bundles =
[(iterN, fp_iter, iter_only),
(recN, fp_rec, rec_only)];
- val (binders, specs) = map generate_iter_like iter_likes |> split_list;
+ val (binders, specs) = map generate_iter_like iter_like_bundles |> split_list;
val ((csts, defs), (lthy', lthy)) = no_defs_lthy
|> apfst split_list o fold_map2 (fn b => fn spec =>
@@ -410,27 +424,29 @@
let
val B_to_fpT = C --> fpT;
- fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), (f_sum_prod_Ts, pf_Tss))) =
+ fun mk_preds_getters_join c n cps sum_prod_T cfsss =
+ Term.lambda c (mk_IfN sum_prod_T cps
+ (map2 (mk_InN_balanced sum_prod_T n) (map (HOLogic.mk_tuple o flat) cfsss) (*###*)
+ (1 upto n)));
+
+ fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfssss), (f_sum_prod_Ts,
+ pf_Tss))) =
let
val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
val binder = Binding.suffix_name ("_" ^ suf) b;
- fun mk_preds_getters_join c n cps sum_prod_T cfss =
- Term.lambda c (mk_IfN sum_prod_T cps
- (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cfss) (1 upto n)));
-
val spec =
mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
Term.list_comb (fp_iter_like,
- map5 mk_preds_getters_join cs ns cpss f_sum_prod_Ts cfsss));
+ map5 mk_preds_getters_join cs ns cpss f_sum_prod_Ts cfssss));
in (binder, spec) end;
- val coiter_likes =
+ val coiter_like_bundles =
[(coiterN, fp_iter, coiter_only),
(corecN, fp_rec, corec_only)];
- val (binders, specs) = map generate_coiter_like coiter_likes |> split_list;
+ val (binders, specs) = map generate_coiter_like coiter_like_bundles |> split_list;
val ((csts, defs), (lthy', lthy)) = no_defs_lthy
|> apfst split_list o fold_map2 (fn b => fn spec =>
@@ -490,14 +506,14 @@
~1 => build_map (build_call fiter_likes maybe_tick) T U
| j => maybe_tick (nth vs j) (nth fiter_likes j));
- fun mk_U maybe_prodT =
- typ_subst (map2 (fn fpT => fn C => (fpT, maybe_prodT fpT C)) fpTs Cs);
+ fun mk_U maybe_mk_prodT =
+ typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
- fun repair_calls fiter_likes maybe_cons maybe_tick maybe_prodT (x as Free (_, T)) =
+ fun repair_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
if member (op =) fpTs T then
maybe_cons x [build_call fiter_likes (K I) (T, mk_U (K I) T) $ x]
else if exists_subtype (member (op =) fpTs) T then
- [build_call fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
+ [build_call fiter_likes maybe_tick (T, mk_U maybe_mk_prodT T) $ x]
else
[x];
@@ -544,10 +560,10 @@
let
fun mk_goal_cond pos = HOLogic.mk_Trueprop o (not pos ? HOLogic.mk_not);
- fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfs' =
+ fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfss' =
fold_rev (fold_rev Logic.all) ([c] :: pfss)
(Logic.list_implies (seq_conds mk_goal_cond n k cps,
- mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, take m cfs'))));
+ mk_Trueprop_eq (fcoiter_like $ c, lists_bmoc (take m cfss') ctr)));
fun build_call fiter_likes maybe_tack (T, U) =
if T = U then
@@ -557,23 +573,25 @@
~1 => build_map (build_call fiter_likes maybe_tack) T U
| j => maybe_tack (nth cs j, nth vs j) (nth fiter_likes j));
- fun mk_U maybe_sumT =
- typ_subst (map2 (fn C => fn fpT => (maybe_sumT fpT C, fpT)) Cs fpTs);
+ fun mk_U maybe_mk_sumT =
+ typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
- fun repair_calls fiter_likes maybe_sumT maybe_tack
+ fun repair_calls fiter_likes maybe_mk_sumT maybe_tack
(cf as Free (_, Type (_, [_, T])) $ _) =
if exists_subtype (member (op =) Cs) T then
- build_call fiter_likes maybe_tack (T, mk_U maybe_sumT T) $ cf
+ build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cf
else
cf;
- val cgsss' = map (map (map (repair_calls gcoiters (K I) (K I)))) cgsss;
- val chsss' = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) chsss;
+ val cgssss' =
+ map (map (map (map (repair_calls gcoiters (K I) (K I))))) cgssss;
+ val chssss' =
+ map (map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z))))) chssss;
val goal_coiterss =
- map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss cgsss';
+ map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss cgssss';
val goal_corecss =
- map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss chsss';
+ map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss chssss';
val coiter_tacss =
map3 (map oo mk_coiter_like_tac coiter_defs map_ids) fp_iter_thms pre_map_defs