split sum types in corecursor definition
authorblanchet
Tue, 11 Sep 2012 13:06:14 +0200
changeset 49275 ce87d6a901eb
parent 49274 ddd606ec45b9
child 49276 59fa53ed7507
split sum types in corecursor definition
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_util.ML
src/HOL/Codatatype/Tools/bnf_gfp_util.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 13:06:13 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 13:06:14 2012 +0200
@@ -48,6 +48,8 @@
 
 val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs))
 
+fun mk_predT T = T --> HOLogic.boolT;
+
 fun mk_id T = Const (@{const_name id}, T --> T);
 
 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
@@ -211,7 +213,8 @@
     val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
 
     val ((iter_only as (gss, _, _), rec_only as (hss, _, _)),
-         (zs, cs, cpss, coiter_only as ((pgss, cgssss), _), corec_only as ((phss, chssss), _))) =
+         (zs, cs, cpss, coiter_only as ((pgss, _, cgssss), _),
+          corec_only as ((phss, _, chssss), _))) =
       if lfp then
         let
           val y_Tsss =
@@ -242,62 +245,68 @@
           val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
         in
           (((gss, g_Tss, yssss), (hss, h_Tss, zssss)),
-           ([], [], [], (([], []), ([], [])), (([], []), ([], []))))
+           ([], [], [], (([], [], []), ([], [])), (([], [], []), ([], []))))
         end
       else
         let
           (*avoid "'a itself" arguments in coiterators and corecursors*)
           val mss' =  map (fn [0] => [1] | ms => ms) mss;
 
-          val p_Tss =
-            map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns;
+          val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_predT) ns Cs;
+
+          fun zip_getterss qss fss = maps (op @) (qss ~~ fss);
 
-          fun zip_getters fss = flat fss;
-
-          fun zip_preds_getters [] [fss] = zip_getters fss
-            | zip_preds_getters (p :: ps) (fss :: fsss) =
-              p :: zip_getters fss @ zip_preds_getters ps fsss;
+          fun zip_preds_gettersss [] [qss] [fss] = zip_getterss qss fss
+            | zip_preds_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
+              p :: zip_getterss qss fss @ zip_preds_gettersss ps qsss fsss;
 
           fun mk_types maybe_dest_sumT fun_Ts =
             let
               val f_sum_prod_Ts = map range_type fun_Ts;
               val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
-              val f_Tsss =
-                map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss' f_prod_Tss;
-              val f_Tssss = map (map (map maybe_dest_sumT)) f_Tsss;
-              val pf_Tss = map2 zip_preds_getters p_Tss f_Tssss;
-            in (f_sum_prod_Ts, f_Tsss, pf_Tss) end;
+              val f_Tssss =
+                map3 (fn C => map2 (map (map (curry (op -->) C) o maybe_dest_sumT) oo dest_tupleT))
+                  Cs mss' f_prod_Tss;
+              val q_Tssss =
+                map (map (map (fn [_] => [] | [_, C] => [mk_predT (domain_type C)]))) f_Tssss;
+              val pf_Tss = map3 zip_preds_gettersss p_Tss q_Tssss f_Tssss;
+            in (q_Tssss, f_sum_prod_Ts, f_Tssss, pf_Tss) end;
 
-          val (g_sum_prod_Ts, g_Tsss, pg_Tss) = mk_types single fp_iter_fun_Ts;
+          val (r_Tssss, g_sum_prod_Ts, g_Tssss, pg_Tss) = mk_types single fp_iter_fun_Ts;
 
-          val ((((Free (z, _), cs), pss), gsss), _) =
+          val ((((Free (z, _), cs), pss), gssss), _) =
             lthy
             |> yield_singleton (mk_Frees "z") dummyT
             ||>> mk_Frees "a" Cs
             ||>> mk_Freess "p" p_Tss
-            ||>> mk_Freesss "g" g_Tsss;
-          val gssss = map (map (map single)) gsss;
+            ||>> mk_Freessss "g" g_Tssss;
+          val rssss = map (map (map (fn [] => []))) r_Tssss;
 
           fun dest_corec_sumT (T as Type (@{type_name sum}, Us as [_, U])) =
               if member (op =) Cs U then Us else [T]
             | dest_corec_sumT T = [T];
 
-          val (h_sum_prod_Ts, h_Tsss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
+          val (s_Tssss, h_sum_prod_Ts, h_Tssss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
 
-          val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss;
-          val hssss = map (map (map single)) hsss; (*###*)
+          val hssss_hd = map2 (map2 (map2 (fn [g] => fn T :: _ => retype_free g T))) gssss h_Tssss;
+          val ((sssss, hssss_tl), _) =
+            lthy
+            |> mk_Freessss "q" s_Tssss
+            ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
+          val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
 
           val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
 
-          fun mk_terms fssss =
+          fun mk_terms qssss fssss =
             let
-              val pfss = map2 zip_preds_getters pss fssss;
+              val pfss = map3 zip_preds_gettersss pss qssss fssss;
+              val cqssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs qssss;
               val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
-            in (pfss, cfssss) end;
+            in (pfss, cqssss, cfssss) end;
         in
           ((([], [], []), ([], [], [])),
-           ([z], cs, cpss, (mk_terms gssss, (g_sum_prod_Ts, pg_Tss)),
-            (mk_terms hssss, (h_sum_prod_Ts, ph_Tss))))
+           ([z], cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, pg_Tss)),
+            (mk_terms sssss hssss, (h_sum_prod_Ts, ph_Tss))))
         end;
 
     fun pour_some_sugar_on_type (((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec),
@@ -424,12 +433,16 @@
           let
             val B_to_fpT = C --> fpT;
 
-            fun mk_preds_getters_join c n cps sum_prod_T cfsss =
+            fun mk_getters_join [] [cf] = cf
+              | mk_getters_join [cq] [cf, cf'] =
+                mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
+
+            fun mk_preds_gettersss_join c n cps sum_prod_T cqsss cfsss =
               Term.lambda c (mk_IfN sum_prod_T cps
-                (map2 (mk_InN_balanced sum_prod_T n) (map (HOLogic.mk_tuple o flat) cfsss) (*###*)
-                   (1 upto n)));
+                (map2 (mk_InN_balanced sum_prod_T n)
+                   (map2 (HOLogic.mk_tuple oo map2 mk_getters_join) cqsss cfsss) (1 upto n)));
 
-            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfssss), (f_sum_prod_Ts,
+            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cqssss, cfssss), (f_sum_prod_Ts,
                 pf_Tss))) =
               let
                 val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
@@ -439,7 +452,7 @@
                 val spec =
                   mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
                     Term.list_comb (fp_iter_like,
-                      map5 mk_preds_getters_join cs ns cpss f_sum_prod_Ts cfssss));
+                      map6 mk_preds_gettersss_join cs ns cpss f_sum_prod_Ts cqssss cfssss));
               in (binder, spec) end;
 
             val coiter_like_bundles =
--- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Tue Sep 11 13:06:13 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Tue Sep 11 13:06:14 2012 +0200
@@ -101,6 +101,7 @@
   val dest_tupleT: int -> typ -> typ list
 
   val mk_Field: term -> term
+  val mk_If: term -> term -> term -> term
   val mk_union: term * term -> term
 
   val mk_sumEN: int -> thm
@@ -258,6 +259,10 @@
 val mk_sum_caseN = Library.foldr1 mk_sum_case;
 val mk_sum_caseN_balanced = Balanced_Tree.make mk_sum_case;
 
+fun mk_If p t f =
+  let val T = fastype_of t;
+  in Const (@{const_name If}, HOLogic.boolT --> T --> T --> T) $ p $ t $ f end;
+
 fun mk_Field r =
   let val T = fst (dest_relT (fastype_of r));
   in Const (@{const_name Field}, mk_relT (T, T) --> HOLogic.mk_setT T) $ r end;
--- a/src/HOL/Codatatype/Tools/bnf_gfp_util.ML	Tue Sep 11 13:06:13 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_gfp_util.ML	Tue Sep 11 13:06:14 2012 +0200
@@ -12,7 +12,6 @@
   val dest_listT: typ -> typ
 
   val mk_Cons: term -> term -> term
-  val mk_If: term -> term -> term -> term
   val mk_Shift: term -> term -> term
   val mk_Succ: term -> term -> term
   val mk_Times: term * term -> term
@@ -122,10 +121,6 @@
 
 fun mk_size t = HOLogic.size_const (fastype_of t) $ t;
 
-fun mk_If p t f =
-  let val T = fastype_of t;
-  in Const (@{const_name If}, HOLogic.boolT --> T --> T --> T) $ p $ t $ f end;
-
 fun mk_quotient A R =
   let val T = fastype_of A;
   in Const (@{const_name quotient}, T --> fastype_of R --> HOLogic.mk_setT T) $ A $ R end;