first step towards splitting corecursor function arguments into (p, g, h) triples
authorblanchet
Tue, 11 Sep 2012 13:06:13 +0200
changeset 49274 ddd606ec45b9
parent 49273 f839ce127a2e
child 49275 ce87d6a901eb
first step towards splitting corecursor function arguments into (p, g, h) triples
src/HOL/Codatatype/BNF_Library.thy
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML
--- a/src/HOL/Codatatype/BNF_Library.thy	Tue Sep 11 13:06:13 2012 +0200
+++ b/src/HOL/Codatatype/BNF_Library.thy	Tue Sep 11 13:06:13 2012 +0200
@@ -843,6 +843,14 @@
   "\<lbrakk>\<forall>x. s = f (Inr (Inl x)) \<longrightarrow> P; \<forall>x. s = f (Inr (Inr x)) \<longrightarrow> P\<rbrakk> \<Longrightarrow> \<forall>x. s = f (Inr x) \<longrightarrow> P"
 by (metis obj_sumE)
 
+lemma sum_map_if:
+"sum_map f g (if p then Inl x else Inr y) = (if p then Inl (f x) else Inr (g y))"
+by simp
+
+lemma sum_case_if:
+"sum_case f g (if p then Inl x else Inr y) = (if p then f x else g y)"
+by simp
+
 lemma not_arg_cong_Inr: "x \<noteq> y \<Longrightarrow> Inr x \<noteq> Inr y"
 by simp
 
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 13:06:13 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 13:06:13 2012 +0200
@@ -210,12 +210,8 @@
     val fp_iter_fun_Ts = fst (split_last (binder_types (fastype_of fp_iter1)));
     val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
 
-    fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
-        if member (op =) Cs U then Us else [T]
-      | dest_rec_pair T = [T];
-
     val ((iter_only as (gss, _, _), rec_only as (hss, _, _)),
-         (zs, cs, cpss, coiter_only as ((pgss, cgsss), _), corec_only as ((phss, chsss), _))) =
+         (zs, cs, cpss, coiter_only as ((pgss, cgssss), _), corec_only as ((phss, chssss), _))) =
       if lfp then
         let
           val y_Tsss =
@@ -227,18 +223,25 @@
             lthy
             |> mk_Freess "f" g_Tss
             ||>> mk_Freesss "x" y_Tsss;
+          val yssss = map (map (map single)) ysss;
+
+          fun dest_rec_prodT (T as Type (@{type_name prod}, Us as [_, U])) =
+              if member (op =) Cs U then Us else [T]
+            | dest_rec_prodT T = [T];
 
           val z_Tssss =
-            map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o
+            map3 (fn n => fn ms => map2 (map dest_rec_prodT oo dest_tupleT) ms o
               dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
           val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
 
           val hss = map2 (map2 retype_free) gss h_Tss;
-          val (zssss, _) =
+          val zssss_hd = map2 (map2 (map2 (fn y => fn T :: _ => retype_free y T))) ysss z_Tssss;
+          val (zssss_tl, _) =
             lthy
-            |> mk_Freessss "x" z_Tssss;
+            |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
+          val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
         in
-          (((gss, g_Tss, map (map (map single)) ysss), (hss, h_Tss, zssss)),
+          (((gss, g_Tss, yssss), (hss, h_Tss, zssss)),
            ([], [], [], (([], []), ([], [])), (([], []), ([], []))))
         end
       else
@@ -249,20 +252,23 @@
           val p_Tss =
             map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns;
 
-          fun zip_preds_getters [] [fs] = fs
-            | zip_preds_getters (p :: ps) (fs :: fss) = p :: fs @ zip_preds_getters ps fss;
+          fun zip_getters fss = flat fss;
 
-          fun mk_types fun_Ts =
+          fun zip_preds_getters [] [fss] = zip_getters fss
+            | zip_preds_getters (p :: ps) (fss :: fsss) =
+              p :: zip_getters fss @ zip_preds_getters ps fsss;
+
+          fun mk_types maybe_dest_sumT fun_Ts =
             let
               val f_sum_prod_Ts = map range_type fun_Ts;
               val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
               val f_Tsss =
                 map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss' f_prod_Tss;
-              val pf_Tss = map2 zip_preds_getters p_Tss f_Tsss
+              val f_Tssss = map (map (map maybe_dest_sumT)) f_Tsss;
+              val pf_Tss = map2 zip_preds_getters p_Tss f_Tssss;
             in (f_sum_prod_Ts, f_Tsss, pf_Tss) end;
 
-          val (g_sum_prod_Ts, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
-          val (h_sum_prod_Ts, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
+          val (g_sum_prod_Ts, g_Tsss, pg_Tss) = mk_types single fp_iter_fun_Ts;
 
           val ((((Free (z, _), cs), pss), gsss), _) =
             lthy
@@ -270,20 +276,28 @@
             ||>> mk_Frees "a" Cs
             ||>> mk_Freess "p" p_Tss
             ||>> mk_Freesss "g" g_Tsss;
+          val gssss = map (map (map single)) gsss;
+
+          fun dest_corec_sumT (T as Type (@{type_name sum}, Us as [_, U])) =
+              if member (op =) Cs U then Us else [T]
+            | dest_corec_sumT T = [T];
+
+          val (h_sum_prod_Ts, h_Tsss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
 
           val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss;
+          val hssss = map (map (map single)) hsss; (*###*)
 
           val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
 
-          fun mk_terms fsss =
+          fun mk_terms fssss =
             let
-              val pfss = map2 zip_preds_getters pss fsss;
-              val cfsss = map2 (fn c => map (map (fn f => f $ c))) cs fsss
-            in (pfss, cfsss) end;
+              val pfss = map2 zip_preds_getters pss fssss;
+              val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
+            in (pfss, cfssss) end;
         in
           ((([], [], []), ([], [], [])),
-           ([z], cs, cpss, (mk_terms gsss, (g_sum_prod_Ts, pg_Tss)),
-            (mk_terms hsss, (h_sum_prod_Ts, ph_Tss))))
+           ([z], cs, cpss, (mk_terms gssss, (g_sum_prod_Ts, pg_Tss)),
+            (mk_terms hssss, (h_sum_prod_Ts, ph_Tss))))
         end;
 
     fun pour_some_sugar_on_type (((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec),
@@ -383,11 +397,11 @@
                       map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss));
               in (binder, spec) end;
 
-            val iter_likes =
+            val iter_like_bundles =
               [(iterN, fp_iter, iter_only),
                (recN, fp_rec, rec_only)];
 
-            val (binders, specs) = map generate_iter_like iter_likes |> split_list;
+            val (binders, specs) = map generate_iter_like iter_like_bundles |> split_list;
 
             val ((csts, defs), (lthy', lthy)) = no_defs_lthy
               |> apfst split_list o fold_map2 (fn b => fn spec =>
@@ -410,27 +424,29 @@
           let
             val B_to_fpT = C --> fpT;
 
-            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), (f_sum_prod_Ts, pf_Tss))) =
+            fun mk_preds_getters_join c n cps sum_prod_T cfsss =
+              Term.lambda c (mk_IfN sum_prod_T cps
+                (map2 (mk_InN_balanced sum_prod_T n) (map (HOLogic.mk_tuple o flat) cfsss) (*###*)
+                   (1 upto n)));
+
+            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfssss), (f_sum_prod_Ts,
+                pf_Tss))) =
               let
                 val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
 
                 val binder = Binding.suffix_name ("_" ^ suf) b;
 
-                fun mk_preds_getters_join c n cps sum_prod_T cfss =
-                  Term.lambda c (mk_IfN sum_prod_T cps
-                    (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cfss) (1 upto n)));
-
                 val spec =
                   mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
                     Term.list_comb (fp_iter_like,
-                      map5 mk_preds_getters_join cs ns cpss f_sum_prod_Ts cfsss));
+                      map5 mk_preds_getters_join cs ns cpss f_sum_prod_Ts cfssss));
               in (binder, spec) end;
 
-            val coiter_likes =
+            val coiter_like_bundles =
               [(coiterN, fp_iter, coiter_only),
                (corecN, fp_rec, corec_only)];
 
-            val (binders, specs) = map generate_coiter_like coiter_likes |> split_list;
+            val (binders, specs) = map generate_coiter_like coiter_like_bundles |> split_list;
 
             val ((csts, defs), (lthy', lthy)) = no_defs_lthy
               |> apfst split_list o fold_map2 (fn b => fn spec =>
@@ -490,14 +506,14 @@
                   ~1 => build_map (build_call fiter_likes maybe_tick) T U
                 | j => maybe_tick (nth vs j) (nth fiter_likes j));
 
-            fun mk_U maybe_prodT =
-              typ_subst (map2 (fn fpT => fn C => (fpT, maybe_prodT fpT C)) fpTs Cs);
+            fun mk_U maybe_mk_prodT =
+              typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
 
-            fun repair_calls fiter_likes maybe_cons maybe_tick maybe_prodT (x as Free (_, T)) =
+            fun repair_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
               if member (op =) fpTs T then
                 maybe_cons x [build_call fiter_likes (K I) (T, mk_U (K I) T) $ x]
               else if exists_subtype (member (op =) fpTs) T then
-                [build_call fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
+                [build_call fiter_likes maybe_tick (T, mk_U maybe_mk_prodT T) $ x]
               else
                 [x];
 
@@ -544,10 +560,10 @@
           let
             fun mk_goal_cond pos = HOLogic.mk_Trueprop o (not pos ? HOLogic.mk_not);
 
-            fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfs' =
+            fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfss' =
               fold_rev (fold_rev Logic.all) ([c] :: pfss)
                 (Logic.list_implies (seq_conds mk_goal_cond n k cps,
-                   mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, take m cfs'))));
+                   mk_Trueprop_eq (fcoiter_like $ c, lists_bmoc (take m cfss') ctr)));
 
             fun build_call fiter_likes maybe_tack (T, U) =
               if T = U then
@@ -557,23 +573,25 @@
                   ~1 => build_map (build_call fiter_likes maybe_tack) T U
                 | j => maybe_tack (nth cs j, nth vs j) (nth fiter_likes j));
 
-            fun mk_U maybe_sumT =
-              typ_subst (map2 (fn C => fn fpT => (maybe_sumT fpT C, fpT)) Cs fpTs);
+            fun mk_U maybe_mk_sumT =
+              typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
 
-            fun repair_calls fiter_likes maybe_sumT maybe_tack
+            fun repair_calls fiter_likes maybe_mk_sumT maybe_tack
                 (cf as Free (_, Type (_, [_, T])) $ _) =
               if exists_subtype (member (op =) Cs) T then
-                build_call fiter_likes maybe_tack (T, mk_U maybe_sumT T) $ cf
+                build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cf
               else
                 cf;
 
-            val cgsss' = map (map (map (repair_calls gcoiters (K I) (K I)))) cgsss;
-            val chsss' = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) chsss;
+            val cgssss' =
+              map (map (map (map (repair_calls gcoiters (K I) (K I))))) cgssss;
+            val chssss' =
+              map (map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z))))) chssss;
 
             val goal_coiterss =
-              map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss cgsss';
+              map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss cgssss';
             val goal_corecss =
-              map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss chsss';
+              map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss chssss';
 
             val coiter_tacss =
               map3 (map oo mk_coiter_like_tac coiter_defs map_ids) fp_iter_thms pre_map_defs
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML	Tue Sep 11 13:06:13 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML	Tue Sep 11 13:06:13 2012 +0200
@@ -60,7 +60,8 @@
     iter_like_thms) THEN Local_Defs.unfold_tac ctxt @{thms id_def} THEN rtac refl 1;
 
 val coiter_like_ss = ss_only @{thms if_True if_False};
-val coiter_like_thms = @{thms id_apply map_pair_def sum_map.simps prod.cases};
+val coiter_like_thms =
+  @{thms id_apply map_pair_def sum_case_if sum_map.simps sum_map_if prod.cases};
 
 fun mk_coiter_like_tac coiter_like_defs map_ids fld_unf_coiter_like pre_map_def ctr_def ctxt =
   Local_Defs.unfold_tac ctxt (ctr_def :: coiter_like_defs) THEN