--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sun Sep 09 17:14:39 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sun Sep 09 18:55:10 2012 +0200
@@ -48,7 +48,12 @@
fun mk_uncurried2_fun f xss =
mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
-fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v))
+fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v));
+
+fun tack z_name (c, v) f =
+ let val z = Free (z_name, mk_sumT (fastype_of v, fastype_of c)) in
+ Term.lambda z (mk_sum_case (Term.lambda v v) (Term.lambda c (f $ c)) $ z)
+ end;
fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
@@ -204,7 +209,7 @@
| dest_rec_pair T = [T];
val ((iter_only as (gss, g_Tss, yssss), rec_only as (hss, h_Tss, zssss)),
- (cs, cpss, p_Tss, coiter_only as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss),
+ (zs, cs, cpss, p_Tss, coiter_only as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss),
corec_only as ((phss, chsss), h_sum_prod_Ts, h_prod_Tss, h_Tsss))) =
if lfp then
let
@@ -229,7 +234,7 @@
|> mk_Freessss "x" z_Tssss;
in
(((gss, g_Tss, map (map (map single)) ysss), (hss, h_Tss, zssss)),
- ([], [], [], (([], []), [], [], []), (([], []), [], [], [])))
+ ([], [], [], [], (([], []), [], [], []), (([], []), [], [], [])))
end
else
let
@@ -254,15 +259,15 @@
val (g_sum_prod_Ts, g_prod_Tss, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
val (h_sum_prod_Ts, h_prod_Tss, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
- val (((c, pss), gsss), _) =
+ val ((((Free (z, _), cs), pss), gsss), _) =
lthy
- |> yield_singleton (mk_Frees "c") dummyT
+ |> yield_singleton (mk_Frees "z") dummyT
+ ||>> mk_Frees "a" Cs
||>> mk_Freess "p" p_Tss
||>> mk_Freesss "g" g_Tsss;
val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss;
- val cs = map (retype_free c) Cs;
val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
fun mk_terms fsss =
@@ -272,7 +277,7 @@
in (pfss, cfsss) end;
in
((([], [], []), ([], [], [])),
- (cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss),
+ ([z], cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss),
(mk_terms hsss, h_sum_prod_Ts, h_prod_Tss, ph_Tss)))
end;
@@ -447,24 +452,6 @@
Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
end;
- fun build_iter_like_call vs basic_Ts fiter_likes maybe_tick =
- let
- fun build (T, U) =
- if T = U then
- Const (@{const_name id}, T --> T)
- else
- (case (find_index (curry (op =) T) basic_Ts, (T, U)) of
- (~1, (Type (s, Ts), Type (_, Us))) =>
- let
- val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
- val mapx = mk_map Ts Us map0;
- val TUs =
- map dest_funT (fst (split_last (fst (strip_map_type (fastype_of mapx)))));
- val args = map build TUs;
- in Term.list_comb (mapx, args) end
- | (j, _) => maybe_tick (nth vs j) (nth fiter_likes j))
- in build end;
-
fun pour_more_sugar_on_lfps ((ctrss, iters, recs, vs, xsss, ctr_defss, iter_defs, rec_defs),
lthy) =
let
@@ -478,14 +465,32 @@
fold_rev (fold_rev Logic.all) (xs :: fss)
(mk_Trueprop_eq (fiter_like $ xctr, Term.list_comb (f, fxs)));
+ fun build_call fiter_likes maybe_tick =
+ let
+ fun build (T, U) =
+ if T = U then
+ Const (@{const_name id}, T --> T)
+ else
+ (case (find_index (curry (op =) T) fpTs, (T, U)) of
+ (~1, (Type (s, Ts), Type (_, Us))) =>
+ let
+ val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
+ val mapx = mk_map Ts Us map0;
+ val TUs =
+ map dest_funT (fst (split_last (fst (strip_map_type (fastype_of mapx)))));
+ val args = map build TUs;
+ in Term.list_comb (mapx, args) end
+ | (j, _) => maybe_tick (nth vs j) (nth fiter_likes j))
+ in build end;
+
fun mk_U maybe_prodT =
typ_subst (map2 (fn fpT => fn C => (fpT, maybe_prodT fpT C)) fpTs Cs);
fun repair_calls fiter_likes maybe_cons maybe_tick maybe_prodT (x as Free (_, T)) =
if member (op =) fpTs T then
- maybe_cons x [build_iter_like_call vs fpTs fiter_likes (K I) (T, mk_U (K I) T) $ x]
+ maybe_cons x [build_call fiter_likes (K I) (T, mk_U (K I) T) $ x]
else if exists_subtype (member (op =) fpTs) T then
- [build_iter_like_call vs fpTs fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
+ [build_call fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
else
[x];
@@ -521,6 +526,8 @@
fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, xsss, ctr_defss, coiter_defs,
corec_defs), lthy) =
let
+ val z = the_single zs;
+
val gcoiters = map (lists_bmoc pgss) coiters;
val hcorecs = map (lists_bmoc phss) corecs;
@@ -533,32 +540,58 @@
(Logic.list_implies (seq_conds mk_goal_cond n k cps,
mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, cfs'))));
+ fun build_call fiter_likes maybe_tack =
+ let
+ fun build (T, U) =
+ if T = U then
+ Const (@{const_name id}, T --> T)
+ else
+ (case (find_index (curry (op =) U) fpTs, (T, U)) of
+ (~1, (Type (s, Ts), Type (_, Us))) =>
+ let
+ val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
+ val mapx = mk_map Ts Us map0;
+ val TUs =
+ map dest_funT (fst (split_last (fst (strip_map_type (fastype_of mapx)))));
+ val args = map build TUs;
+ in Term.list_comb (mapx, args) end
+ | (j, _) => maybe_tack (nth cs j, nth vs j) (nth fiter_likes j))
+ in build end;
+
fun mk_U maybe_sumT =
- typ_subst (map2 (fn C => fn fpT => (C, maybe_sumT C fpT)) Cs fpTs);
+ typ_subst (map2 (fn C => fn fpT => (maybe_sumT fpT C, fpT)) Cs fpTs);
fun repair_calls fiter_likes maybe_sumT maybe_tack
(cf as Free (_, Type (_, [_, T])) $ _) =
if exists_subtype (member (op =) Cs) T then
- build_iter_like_call vs Cs fiter_likes maybe_tack (T, mk_U maybe_sumT T) $ cf
+ build_call fiter_likes maybe_tack (T, mk_U maybe_sumT T) $ cf
else
cf;
val cgsss = map (map (map (repair_calls gcoiters (K I) (K I)))) cgsss;
+ val chsss = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) chsss;
val goal_coiterss =
map7 (map3 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss cgsss;
+ val goal_corecss =
+ map7 (map3 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss chsss;
val coiter_tacss =
map3 (map oo mk_coiter_like_tac coiter_defs map_ids) fp_iter_thms pre_map_defs
ctr_defss;
+ val corec_tacss =
+ map3 (map oo mk_coiter_like_tac corec_defs map_ids) fp_rec_thms pre_map_defs
+ ctr_defss;
in
(map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
goal_coiterss coiter_tacss,
- [])
+ map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
+ goal_corecss corec_tacss)
end;
val notes =
- [(coitersN, coiter_thmss)]
+ [(coitersN, coiter_thmss),
+ (corecsN, corec_thmss)]
|> maps (fn (thmN, thmss) =>
map2 (fn b => fn thms =>
((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))