--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue Oct 02 01:00:18 2012 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue Oct 02 01:00:18 2012 +0200
@@ -113,11 +113,6 @@
Type (_, Ts) => map (not o member (op =) (deads_of_bnf bnf)) Ts
| _ => replicate n false);
-fun tack z_name (c, u) f =
- let val z = Free (z_name, mk_sumT (fastype_of u, fastype_of c)) in
- Term.lambda z (mk_sum_case (Term.lambda u u, Term.lambda c (f $ c)) $ z)
- end;
-
fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
@@ -277,6 +272,7 @@
val pre_map_defs = map map_def_of_bnf pre_bnfs;
val pre_set_defss = map set_defs_of_bnf pre_bnfs;
val pre_rel_defs = map rel_def_of_bnf pre_bnfs;
+ val nested_map_comps'' = map ((fn thm => thm RS sym) o map_comp_of_bnf) nested_bnfs;
val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs;
@@ -312,8 +308,8 @@
val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of any_fp_rec)));
val (((fold_only as (gss, _, _), rec_only as (hss, _, _)),
- (zs, cs, cpss, unfold_only as ((pgss, crgsss), _), corec_only as ((phss, cshsss), _))),
- names_lthy0) =
+ (cs, cpss, unfold_only as ((pgss, crssss, cgssss), (_, g_Tsss, _)),
+ corec_only as ((phss, csssss, chssss), (_, h_Tsss, _)))), names_lthy0) =
if lfp then
let
val y_Tsss =
@@ -344,7 +340,7 @@
val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
in
((((gss, g_Tss, ysss), (hss, h_Tss, zsss)),
- ([], [], [], (([], []), ([], [])), (([], []), ([], [])))), lthy)
+ ([], [], (([], [], []), ([], [], [])), (([], [], []), ([], [], [])))), lthy)
end
else
let
@@ -373,10 +369,9 @@
val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
- val ((((Free (z, _), cs), pss), gssss), lthy) =
+ val (((cs, pss), gssss), lthy) =
lthy
- |> yield_singleton (mk_Frees "z") dummyT
- ||>> mk_Frees "a" Cs
+ |> mk_Frees "a" Cs
||>> mk_Freess "p" p_Tss
||>> mk_Freessss "g" g_Tssss;
val rssss = map (map (map (fn [] => []))) r_Tssss;
@@ -401,32 +396,16 @@
val cpss = map2 (map o rapp) cs pss;
- fun build_sum_inj mk_inj (T, U) =
- if T = U then
- id_const T
- else
- (case (T, U) of
- (Type (s, _), Type (s', _)) =>
- if s = s' then build_map (build_sum_inj mk_inj) T U
- else uncurry mk_inj (dest_sumT U)
- | _ => uncurry mk_inj (dest_sumT U));
-
- fun build_dtor_corec_arg _ [] [cf] = cf
- | build_dtor_corec_arg T [cq] [cf, cf'] =
- mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
- (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
-
- fun mk_terms f_Tsss qssss fssss =
+ fun mk_terms qssss fssss =
let
val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
val cqssss = map2 (map o map o map o rapp) cs qssss;
val cfssss = map2 (map o map o map o rapp) cs fssss;
- val cqfsss = map3 (map3 (map3 build_dtor_corec_arg)) f_Tsss cqssss cfssss;
- in (pfss, cqfsss) end;
+ in (pfss, cqssss, cfssss) end;
in
(((([], [], []), ([], [], [])),
- ([z], cs, cpss, (mk_terms g_Tsss rssss gssss, (g_sum_prod_Ts, pg_Tss)),
- (mk_terms h_Tsss sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
+ (cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, g_Tsss, pg_Tss)),
+ (mk_terms sssss hssss, (h_sum_prod_Ts, h_Tsss, ph_Tss)))), lthy)
end;
fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
@@ -668,12 +647,30 @@
let
val B_to_fpT = C --> fpT;
+ fun build_sum_inj mk_inj (T, U) =
+ if T = U then
+ id_const T
+ else
+ (case (T, U) of
+ (Type (s, _), Type (s', _)) =>
+ if s = s' then build_map (build_sum_inj mk_inj) T U
+ else uncurry mk_inj (dest_sumT U)
+ | _ => uncurry mk_inj (dest_sumT U));
+
+ fun build_dtor_corec_like_arg _ [] [cf] = cf
+ | build_dtor_corec_like_arg T [cq] [cf, cf'] =
+ mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
+ (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
+
+ val crgsss = map3 (map3 (map3 build_dtor_corec_like_arg)) g_Tsss crssss cgssss;
+ val cshsss = map3 (map3 (map3 build_dtor_corec_like_arg)) h_Tsss csssss chssss;
+
fun mk_preds_getterss_join c n cps sum_prod_T cqfss =
Term.lambda c (mk_IfN sum_prod_T cps
(map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cqfss) (1 upto n)));
- fun generate_corec_like (suf, fp_rec_like, ((pfss, cqfsss), (f_sum_prod_Ts,
- pf_Tss))) =
+ fun generate_corec_like (suf, fp_rec_like, (cqfsss, ((pfss, _, _), (f_sum_prod_Ts, _,
+ pf_Tss)))) =
let
val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
val binding = qualify false fp_b_name (Binding.suffix_name ("_" ^ suf) fp_b);
@@ -684,8 +681,8 @@
in (binding, spec) end;
val corec_like_infos =
- [(unfoldN, fp_fold, unfold_only),
- (corecN, fp_rec, corec_only)];
+ [(unfoldN, fp_fold, (crgsss, unfold_only)),
+ (corecN, fp_rec, (cshsss, corec_only))];
val (bindings, specs) = map generate_corec_like corec_like_infos |> split_list;
@@ -919,8 +916,7 @@
fun build_rel rs' T =
(case find_index (curry (op =) T) fpTs of
~1 =>
- if exists_fp_subtype T then build_rel_step (build_rel rs') T
- else HOLogic.eq_const T
+ if exists_fp_subtype T then build_rel_step (build_rel rs') T else HOLogic.eq_const T
| kk => nth rs' kk);
fun build_rel_app rs' usel vsel =
@@ -974,7 +970,6 @@
fun mk_maybe_not pos = not pos ? HOLogic.mk_not;
- val z = the_single zs;
val gunfolds = map (lists_bmoc pgss) unfolds;
val hcorecs = map (lists_bmoc phss) corecs;
@@ -985,58 +980,66 @@
(Logic.list_implies (seq_conds (HOLogic.mk_Trueprop oo mk_maybe_not) n k cps,
mk_Trueprop_eq (fcorec_like $ c, Term.list_comb (ctr, take m cfs'))));
- fun build_corec_like fcorec_likes maybe_tack (T, U) =
+ fun build_corec_like fcorec_likes (T, U) =
if T = U then
id_const T
else
(case find_index (curry (op =) U) fpTs of
- ~1 => build_map (build_corec_like fcorec_likes maybe_tack) T U
- | kk => maybe_tack (nth cs kk, nth us kk) (nth fcorec_likes kk));
+ ~1 => build_map (build_corec_like fcorec_likes) T U
+ | kk => nth fcorec_likes kk);
+
+ val mk_U = typ_subst (map2 pair Cs fpTs);
- fun mk_U maybe_mk_sumT =
- typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
+ fun intr_corec_likes fcorec_likes [] [cf] =
+ let val T = fastype_of cf in
+ if exists_Cs_subtype T then build_corec_like fcorec_likes (T, mk_U T) $ cf else cf
+ end
+ | intr_corec_likes fcorec_likes [cq] [cf, cf'] =
+ mk_If cq (intr_corec_likes fcorec_likes [] [cf])
+ (intr_corec_likes fcorec_likes [] [cf']);
+
+ val crgsss = map2 (map2 (map2 (intr_corec_likes gunfolds))) crssss cgssss;
+ val cshsss = map2 (map2 (map2 (intr_corec_likes hcorecs))) csssss chssss;
- fun intr_corec_likes fcorec_likes maybe_mk_sumT maybe_tack cqf =
- let val T = fastype_of cqf in
- if exists_Cs_subtype T then
- build_corec_like fcorec_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
- else
- cqf
+ val unfold_goalss =
+ map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss;
+ val corec_goalss =
+ map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss;
+
+ fun mk_map_if_distrib bnf =
+ let
+ val mapx = map_of_bnf bnf;
+ val live = live_of_bnf bnf;
+ val ((Ts, T), U) = strip_typeN (live + 1) (fastype_of mapx) |>> split_last;
+ val fs = Variable.variant_frees lthy [mapx] (map (pair "f") Ts);
+ val t = Term.list_comb (mapx, map (Var o apfst (rpair 0)) fs);
+ in
+ Drule.instantiate' (map (SOME o certifyT lthy) [U, T]) [SOME (certify lthy t)]
+ @{thm if_distrib}
end;
- val crgsss' = map (map (map (intr_corec_likes gunfolds (K I) (K I)))) crgsss;
- val cshsss' =
- map (map (map (intr_corec_likes hcorecs (curry mk_sumT) (tack z)))) cshsss;
-
- val unfold_goalss =
- map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss';
- val corec_goalss =
- map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss';
+ val nested_map_if_distribs = map mk_map_if_distrib nested_bnfs;
val unfold_tacss =
- map3 (map oo mk_corec_like_tac unfold_defs nesting_map_ids'') fp_fold_thms
- pre_map_defs ctr_defss;
+ map3 (map oo mk_corec_like_tac unfold_defs [] [] nesting_map_ids'' [])
+ fp_fold_thms pre_map_defs ctr_defss;
val corec_tacss =
- map3 (map oo mk_corec_like_tac corec_defs nesting_map_ids'') fp_rec_thms pre_map_defs
- ctr_defss;
+ map3 (map oo mk_corec_like_tac corec_defs nested_map_comps'' nested_map_comp's
+ (nested_map_ids'' @ nesting_map_ids'') nested_map_if_distribs)
+ fp_rec_thms pre_map_defs ctr_defss;
fun prove goal tac =
Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation;
val unfold_thmss = map2 (map2 prove) unfold_goalss unfold_tacss;
- val corec_thmss =
- map2 (map2 prove) corec_goalss corec_tacss
- |> map (map (unfold_thms lthy @{thms sum_case_if}));
-
- val unfold_safesss = map2 (map2 (map2 (curry (op =)))) crgsss' crgsss;
- val corec_safesss = map2 (map2 (map2 (curry (op =)))) cshsss' cshsss;
+ val corec_thmss = map2 (map2 prove) corec_goalss corec_tacss;
val filter_safesss =
map2 (map_filter (fn (safes, thm) => if forall I safes then SOME thm else NONE) oo
- curry (op ~~));
+ curry (op ~~)) (map2 (map2 (map2 (member (op =)))) cgssss crgsss);
- val safe_unfold_thmss = filter_safesss unfold_safesss unfold_thmss;
- val safe_corec_thmss = filter_safesss corec_safesss corec_thmss;
+ val safe_unfold_thmss = filter_safesss unfold_thmss;
+ val safe_corec_thmss = filter_safesss corec_thmss;
in
(unfold_thmss, corec_thmss, safe_unfold_thmss, safe_corec_thmss)
end;