src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 49681 aa66ea552357
parent 49672 902b24e0ffb4
child 49682 f57af1c46f99
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Oct 02 01:00:18 2012 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Oct 02 01:00:18 2012 +0200
@@ -361,19 +361,19 @@
             | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
               p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
 
-          fun mk_types maybe_dest_sumT fun_Ts =
+          fun mk_types maybe_unzipT fun_Ts =
             let
               val f_sum_prod_Ts = map range_type fun_Ts;
               val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
+              val f_Tsss = map2 (map2 dest_tupleT) mss' f_prod_Tss;
               val f_Tssss =
-                map3 (fn C => map2 (map (map (curry (op -->) C) o maybe_dest_sumT) oo dest_tupleT))
-                  Cs mss' f_prod_Tss;
+                map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
               val q_Tssss =
                 map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
               val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
-            in (q_Tssss, f_sum_prod_Ts, f_Tssss, pf_Tss) end;
+            in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end;
 
-          val (r_Tssss, g_sum_prod_Ts, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
+          val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
 
           val ((((Free (z, _), cs), pss), gssss), lthy) =
             lthy
@@ -383,11 +383,16 @@
             ||>> mk_Freessss "g" g_Tssss;
           val rssss = map (map (map (fn [] => []))) r_Tssss;
 
-          fun dest_corec_sumT (T as Type (@{type_name sum}, Us as [_, U])) =
-              if member (op =) Cs U then Us else [T]
-            | dest_corec_sumT T = [T];
+          fun proj_corecT proj (Type (s as @{type_name sum}, Ts as [T, U])) =
+              if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_corecT proj) Ts)
+            | proj_corecT proj (Type (s, Ts)) = Type (s, map (proj_corecT proj) Ts)
+            | proj_corecT _ T = T;
 
-          val (s_Tssss, h_sum_prod_Ts, h_Tssss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
+          fun unzip_corecT T =
+            if exists_fp_subtype T then [proj_corecT fst T, proj_corecT snd T] else [T];
+
+          val (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss) =
+            mk_types unzip_corecT fp_rec_fun_Ts;
 
           val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
           val ((sssss, hssss_tl), lthy) =
@@ -396,23 +401,34 @@
             ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
           val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
 
-          val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
+          val cpss = map2 (map o rapp) cs pss;
 
-          fun mk_preds_getters_join [] [cf] = cf
-            | mk_preds_getters_join [cq] [cf, cf'] =
-              mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
+          fun build_sum_inj mk_inj (T, U) =
+            if T = U then
+              id_const T
+            else
+              (case (T, U) of
+                (Type (s, _), Type (s', _)) =>
+                if s = s' then build_map (build_sum_inj mk_inj) T U
+                else uncurry mk_inj (dest_sumT U)
+              | _ => uncurry mk_inj (dest_sumT U));
 
-          fun mk_terms qssss fssss =
+          fun build_dtor_corec_arg _ [] [cf] = cf
+            | build_dtor_corec_arg T [cq] [cf, cf'] =
+              mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
+                (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
+
+          fun mk_terms f_Tsss qssss fssss =
             let
               val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
-              val cqssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs qssss;
-              val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
-              val cqfsss = map2 (map2 (map2 mk_preds_getters_join)) cqssss cfssss;
+              val cqssss = map2 (map o map o map o rapp) cs qssss;
+              val cfssss = map2 (map o map o map o rapp) cs fssss;
+              val cqfsss = map3 (map3 (map3 build_dtor_corec_arg)) f_Tsss cqssss cfssss;
             in (pfss, cqfsss) end;
         in
           (((([], [], []), ([], [], [])),
-            ([z], cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, pg_Tss)),
-             (mk_terms sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
+            ([z], cs, cpss, (mk_terms g_Tsss rssss gssss, (g_sum_prod_Ts, pg_Tss)),
+             (mk_terms h_Tsss sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
         end;
 
     fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
@@ -595,15 +611,16 @@
           let
             val fpT_to_C = fpT --> C;
 
-            fun build_ctor_rec_arg mk_proj (T, U) =
+            fun build_prod_proj mk_proj (T, U) =
               if T = U then
                 id_const T
               else
                 (case (T, U) of
                   (Type (s, _), Type (s', _)) =>
-                  if s = s' then build_map (build_ctor_rec_arg mk_proj) T U else mk_proj T
+                  if s = s' then build_map (build_prod_proj mk_proj) T U else mk_proj T
                 | _ => mk_proj T);
 
+            (* TODO: Avoid these complications; cf. corec case *)
             fun mk_U proj (Type (s as @{type_name prod}, Ts as [T', U])) =
                 if member (op =) fpTs T' then proj (T', U) else Type (s, map (mk_U proj) Ts)
               | mk_U proj (Type (s, Ts)) = Type (s, map (mk_U proj) Ts)
@@ -611,8 +628,8 @@
 
             fun unzip_rec (x as Free (_, T)) =
               if exists_fp_subtype T then
-                [build_ctor_rec_arg fst_const (T, mk_U fst T) $ x,
-                 build_ctor_rec_arg snd_const (T, mk_U snd T) $ x]
+                [build_prod_proj fst_const (T, mk_U fst T) $ x,
+                 build_prod_proj snd_const (T, mk_U snd T) $ x]
               else
                 [x];