--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue Oct 02 01:00:18 2012 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue Oct 02 01:00:18 2012 +0200
@@ -361,19 +361,19 @@
| flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
- fun mk_types maybe_dest_sumT fun_Ts =
+ fun mk_types maybe_unzipT fun_Ts =
let
val f_sum_prod_Ts = map range_type fun_Ts;
val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
+ val f_Tsss = map2 (map2 dest_tupleT) mss' f_prod_Tss;
val f_Tssss =
- map3 (fn C => map2 (map (map (curry (op -->) C) o maybe_dest_sumT) oo dest_tupleT))
- Cs mss' f_prod_Tss;
+ map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
val q_Tssss =
map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
- in (q_Tssss, f_sum_prod_Ts, f_Tssss, pf_Tss) end;
+ in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end;
- val (r_Tssss, g_sum_prod_Ts, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
+ val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
val ((((Free (z, _), cs), pss), gssss), lthy) =
lthy
@@ -383,11 +383,16 @@
||>> mk_Freessss "g" g_Tssss;
val rssss = map (map (map (fn [] => []))) r_Tssss;
- fun dest_corec_sumT (T as Type (@{type_name sum}, Us as [_, U])) =
- if member (op =) Cs U then Us else [T]
- | dest_corec_sumT T = [T];
+ fun proj_corecT proj (Type (s as @{type_name sum}, Ts as [T, U])) =
+ if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_corecT proj) Ts)
+ | proj_corecT proj (Type (s, Ts)) = Type (s, map (proj_corecT proj) Ts)
+ | proj_corecT _ T = T;
- val (s_Tssss, h_sum_prod_Ts, h_Tssss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
+ fun unzip_corecT T =
+ if exists_fp_subtype T then [proj_corecT fst T, proj_corecT snd T] else [T];
+
+ val (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss) =
+ mk_types unzip_corecT fp_rec_fun_Ts;
val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
val ((sssss, hssss_tl), lthy) =
@@ -396,23 +401,34 @@
||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
- val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
+ val cpss = map2 (map o rapp) cs pss;
- fun mk_preds_getters_join [] [cf] = cf
- | mk_preds_getters_join [cq] [cf, cf'] =
- mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
+ fun build_sum_inj mk_inj (T, U) =
+ if T = U then
+ id_const T
+ else
+ (case (T, U) of
+ (Type (s, _), Type (s', _)) =>
+ if s = s' then build_map (build_sum_inj mk_inj) T U
+ else uncurry mk_inj (dest_sumT U)
+ | _ => uncurry mk_inj (dest_sumT U));
- fun mk_terms qssss fssss =
+ fun build_dtor_corec_arg _ [] [cf] = cf
+ | build_dtor_corec_arg T [cq] [cf, cf'] =
+ mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
+ (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
+
+ fun mk_terms f_Tsss qssss fssss =
let
val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
- val cqssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs qssss;
- val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
- val cqfsss = map2 (map2 (map2 mk_preds_getters_join)) cqssss cfssss;
+ val cqssss = map2 (map o map o map o rapp) cs qssss;
+ val cfssss = map2 (map o map o map o rapp) cs fssss;
+ val cqfsss = map3 (map3 (map3 build_dtor_corec_arg)) f_Tsss cqssss cfssss;
in (pfss, cqfsss) end;
in
(((([], [], []), ([], [], [])),
- ([z], cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, pg_Tss)),
- (mk_terms sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
+ ([z], cs, cpss, (mk_terms g_Tsss rssss gssss, (g_sum_prod_Ts, pg_Tss)),
+ (mk_terms h_Tsss sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
end;
fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
@@ -595,15 +611,16 @@
let
val fpT_to_C = fpT --> C;
- fun build_ctor_rec_arg mk_proj (T, U) =
+ fun build_prod_proj mk_proj (T, U) =
if T = U then
id_const T
else
(case (T, U) of
(Type (s, _), Type (s', _)) =>
- if s = s' then build_map (build_ctor_rec_arg mk_proj) T U else mk_proj T
+ if s = s' then build_map (build_prod_proj mk_proj) T U else mk_proj T
| _ => mk_proj T);
+ (* TODO: Avoid these complications; cf. corec case *)
fun mk_U proj (Type (s as @{type_name prod}, Ts as [T', U])) =
if member (op =) fpTs T' then proj (T', U) else Type (s, map (mk_U proj) Ts)
| mk_U proj (Type (s, Ts)) = Type (s, map (mk_U proj) Ts)
@@ -611,8 +628,8 @@
fun unzip_rec (x as Free (_, T)) =
if exists_fp_subtype T then
- [build_ctor_rec_arg fst_const (T, mk_U fst T) $ x,
- build_ctor_rec_arg snd_const (T, mk_U snd T) $ x]
+ [build_prod_proj fst_const (T, mk_U fst T) $ x,
+ build_prod_proj snd_const (T, mk_U snd T) $ x]
else
[x];