finished splitting sum types for corecursors
authorblanchet
Tue, 11 Sep 2012 13:06:14 +0200
changeset 49276 59fa53ed7507
parent 49275 ce87d6a901eb
child 49277 aee77001243f
finished splitting sum types for corecursors
src/HOL/Codatatype/BNF_Library.thy
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML
--- a/src/HOL/Codatatype/BNF_Library.thy	Tue Sep 11 13:06:14 2012 +0200
+++ b/src/HOL/Codatatype/BNF_Library.thy	Tue Sep 11 13:06:14 2012 +0200
@@ -843,10 +843,6 @@
   "\<lbrakk>\<forall>x. s = f (Inr (Inl x)) \<longrightarrow> P; \<forall>x. s = f (Inr (Inr x)) \<longrightarrow> P\<rbrakk> \<Longrightarrow> \<forall>x. s = f (Inr x) \<longrightarrow> P"
 by (metis obj_sumE)
 
-lemma sum_map_if:
-"sum_map f g (if p then Inl x else Inr y) = (if p then Inl (f x) else Inr (g y))"
-by simp
-
 lemma sum_case_if:
 "sum_case f g (if p then Inl x else Inr y) = (if p then f x else g y)"
 by simp
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 13:06:14 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 13:06:14 2012 +0200
@@ -213,8 +213,7 @@
     val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
 
     val ((iter_only as (gss, _, _), rec_only as (hss, _, _)),
-         (zs, cs, cpss, coiter_only as ((pgss, _, cgssss), _),
-          corec_only as ((phss, _, chssss), _))) =
+         (zs, cs, cpss, coiter_only as ((pgss, crgsss), _), corec_only as ((phss, cshsss), _))) =
       if lfp then
         let
           val y_Tsss =
@@ -245,7 +244,7 @@
           val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
         in
           (((gss, g_Tss, yssss), (hss, h_Tss, zssss)),
-           ([], [], [], (([], [], []), ([], [])), (([], [], []), ([], []))))
+           ([], [], [], (([], []), ([], [])), (([], []), ([], []))))
         end
       else
         let
@@ -254,11 +253,11 @@
 
           val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_predT) ns Cs;
 
-          fun zip_getterss qss fss = maps (op @) (qss ~~ fss);
+          fun zip_predss_getterss qss fss = maps (op @) (qss ~~ fss);
 
-          fun zip_preds_gettersss [] [qss] [fss] = zip_getterss qss fss
-            | zip_preds_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
-              p :: zip_getterss qss fss @ zip_preds_gettersss ps qsss fsss;
+          fun zip_preds_predsss_gettersss [] [qss] [fss] = zip_predss_getterss qss fss
+            | zip_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
+              p :: zip_predss_getterss qss fss @ zip_preds_predsss_gettersss ps qsss fsss;
 
           fun mk_types maybe_dest_sumT fun_Ts =
             let
@@ -269,7 +268,7 @@
                   Cs mss' f_prod_Tss;
               val q_Tssss =
                 map (map (map (fn [_] => [] | [_, C] => [mk_predT (domain_type C)]))) f_Tssss;
-              val pf_Tss = map3 zip_preds_gettersss p_Tss q_Tssss f_Tssss;
+              val pf_Tss = map3 zip_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
             in (q_Tssss, f_sum_prod_Ts, f_Tssss, pf_Tss) end;
 
           val (r_Tssss, g_sum_prod_Ts, g_Tssss, pg_Tss) = mk_types single fp_iter_fun_Ts;
@@ -297,12 +296,17 @@
 
           val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
 
+          fun mk_preds_getters_join [] [cf] = cf
+            | mk_preds_getters_join [cq] [cf, cf'] =
+              mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
+
           fun mk_terms qssss fssss =
             let
-              val pfss = map3 zip_preds_gettersss pss qssss fssss;
+              val pfss = map3 zip_preds_predsss_gettersss pss qssss fssss;
               val cqssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs qssss;
               val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
-            in (pfss, cqssss, cfssss) end;
+              val cqfsss = map2 (map2 (map2 mk_preds_getters_join)) cqssss cfssss;
+            in (pfss, cqfsss) end;
         in
           ((([], [], []), ([], [], [])),
            ([z], cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, pg_Tss)),
@@ -433,16 +437,11 @@
           let
             val B_to_fpT = C --> fpT;
 
-            fun mk_getters_join [] [cf] = cf
-              | mk_getters_join [cq] [cf, cf'] =
-                mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
+            fun mk_preds_getterss_join c n cps sum_prod_T cqfss =
+              Term.lambda c (mk_IfN sum_prod_T cps
+                (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cqfss) (1 upto n)));
 
-            fun mk_preds_gettersss_join c n cps sum_prod_T cqsss cfsss =
-              Term.lambda c (mk_IfN sum_prod_T cps
-                (map2 (mk_InN_balanced sum_prod_T n)
-                   (map2 (HOLogic.mk_tuple oo map2 mk_getters_join) cqsss cfsss) (1 upto n)));
-
-            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cqssss, cfssss), (f_sum_prod_Ts,
+            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cqfsss), (f_sum_prod_Ts,
                 pf_Tss))) =
               let
                 val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
@@ -452,7 +451,7 @@
                 val spec =
                   mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
                     Term.list_comb (fp_iter_like,
-                      map6 mk_preds_gettersss_join cs ns cpss f_sum_prod_Ts cqssss cfssss));
+                      map5 mk_preds_getterss_join cs ns cpss f_sum_prod_Ts cqfsss));
               in (binder, spec) end;
 
             val coiter_like_bundles =
@@ -542,11 +541,9 @@
             val rec_tacss =
               map2 (map o mk_iter_like_tac pre_map_defs map_ids rec_defs) fp_rec_thms ctr_defss;
           in
-            (map2 (map2 (fn goal => fn tac =>
-                 Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
+            (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
                goal_iterss iter_tacss,
-             map2 (map2 (fn goal => fn tac =>
-                 Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
+             map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
                goal_recss rec_tacss)
           end;
 
@@ -573,10 +570,10 @@
           let
             fun mk_goal_cond pos = HOLogic.mk_Trueprop o (not pos ? HOLogic.mk_not);
 
-            fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfss' =
+            fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfs' =
               fold_rev (fold_rev Logic.all) ([c] :: pfss)
                 (Logic.list_implies (seq_conds mk_goal_cond n k cps,
-                   mk_Trueprop_eq (fcoiter_like $ c, lists_bmoc (take m cfss') ctr)));
+                   mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, take m cfs'))));
 
             fun build_call fiter_likes maybe_tack (T, U) =
               if T = U then
@@ -589,22 +586,21 @@
             fun mk_U maybe_mk_sumT =
               typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
 
-            fun repair_calls fiter_likes maybe_mk_sumT maybe_tack
-                (cf as Free (_, Type (_, [_, T])) $ _) =
-              if exists_subtype (member (op =) Cs) T then
-                build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cf
-              else
-                cf;
+            fun repair_calls fiter_likes maybe_mk_sumT maybe_tack cqf =
+              let val T = fastype_of cqf in
+                if exists_subtype (member (op =) Cs) T then
+                  build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
+                else
+                  cqf
+              end;
 
-            val cgssss' =
-              map (map (map (map (repair_calls gcoiters (K I) (K I))))) cgssss;
-            val chssss' =
-              map (map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z))))) chssss;
+            val crgsss' = map (map (map (repair_calls gcoiters (K I) (K I)))) crgsss;
+            val cshsss' = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) cshsss;
 
             val goal_coiterss =
-              map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss cgssss';
+              map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss crgsss';
             val goal_corecss =
-              map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss chssss';
+              map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss cshsss';
 
             val coiter_tacss =
               map3 (map oo mk_coiter_like_tac coiter_defs map_ids) fp_iter_thms pre_map_defs
@@ -613,9 +609,12 @@
               map3 (map oo mk_coiter_like_tac corec_defs map_ids) fp_rec_thms pre_map_defs
                 ctr_defss;
           in
-            (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
+            (map2 (map2 (fn goal => fn tac =>
+                 Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
                goal_coiterss coiter_tacss,
-             map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
+             map2 (map2 (fn goal => fn tac =>
+                 Skip_Proof.prove lthy [] [] goal (tac o #context)
+                 |> Local_Defs.unfold lthy @{thms sum_case_if} |> Thm.close_derivation))
                goal_corecss corec_tacss)
           end;
 
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML	Tue Sep 11 13:06:14 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML	Tue Sep 11 13:06:14 2012 +0200
@@ -60,8 +60,7 @@
     iter_like_thms) THEN Local_Defs.unfold_tac ctxt @{thms id_def} THEN rtac refl 1;
 
 val coiter_like_ss = ss_only @{thms if_True if_False};
-val coiter_like_thms =
-  @{thms id_apply map_pair_def sum_case_if sum_map.simps sum_map_if prod.cases};
+val coiter_like_thms = @{thms id_apply map_pair_def sum_map.simps prod.cases};
 
 fun mk_coiter_like_tac coiter_like_defs map_ids fld_unf_coiter_like pre_map_def ctr_def ctxt =
   Local_Defs.unfold_tac ctxt (ctr_def :: coiter_like_defs) THEN