src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 49233 7f412734fbb3
parent 49232 9ea11f0c53e4
child 49234 4626ff7cbd2c
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sun Sep 09 17:14:39 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sun Sep 09 18:55:10 2012 +0200
@@ -48,7 +48,12 @@
 fun mk_uncurried2_fun f xss =
   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
 
-fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v))
+fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v));
+
+fun tack z_name (c, v) f =
+  let val z = Free (z_name, mk_sumT (fastype_of v, fastype_of c)) in
+    Term.lambda z (mk_sum_case (Term.lambda v v) (Term.lambda c (f $ c)) $ z)
+  end;
 
 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
 
@@ -204,7 +209,7 @@
       | dest_rec_pair T = [T];
 
     val ((iter_only as (gss, g_Tss, yssss), rec_only as (hss, h_Tss, zssss)),
-         (cs, cpss, p_Tss, coiter_only as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss),
+         (zs, cs, cpss, p_Tss, coiter_only as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss),
           corec_only as ((phss, chsss), h_sum_prod_Ts, h_prod_Tss, h_Tsss))) =
       if lfp then
         let
@@ -229,7 +234,7 @@
             |> mk_Freessss "x" z_Tssss;
         in
           (((gss, g_Tss, map (map (map single)) ysss), (hss, h_Tss, zssss)),
-           ([], [], [], (([], []), [], [], []), (([], []), [], [], [])))
+           ([], [], [], [], (([], []), [], [], []), (([], []), [], [], [])))
         end
       else
         let
@@ -254,15 +259,15 @@
           val (g_sum_prod_Ts, g_prod_Tss, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
           val (h_sum_prod_Ts, h_prod_Tss, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
 
-          val (((c, pss), gsss), _) =
+          val ((((Free (z, _), cs), pss), gsss), _) =
             lthy
-            |> yield_singleton (mk_Frees "c") dummyT
+            |> yield_singleton (mk_Frees "z") dummyT
+            ||>> mk_Frees "a" Cs
             ||>> mk_Freess "p" p_Tss
             ||>> mk_Freesss "g" g_Tsss;
 
           val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss;
 
-          val cs = map (retype_free c) Cs;
           val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
 
           fun mk_terms fsss =
@@ -272,7 +277,7 @@
             in (pfss, cfsss) end;
         in
           ((([], [], []), ([], [], [])),
-           (cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss),
+           ([z], cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss),
             (mk_terms hsss, h_sum_prod_Ts, h_prod_Tss, ph_Tss)))
         end;
 
@@ -447,24 +452,6 @@
         Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
       end;
 
-    fun build_iter_like_call vs basic_Ts fiter_likes maybe_tick =
-      let
-        fun build (T, U) =
-          if T = U then
-            Const (@{const_name id}, T --> T)
-          else
-            (case (find_index (curry (op =) T) basic_Ts, (T, U)) of
-              (~1, (Type (s, Ts), Type (_, Us))) =>
-              let
-                val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
-                val mapx = mk_map Ts Us map0;
-                val TUs =
-                  map dest_funT (fst (split_last (fst (strip_map_type (fastype_of mapx)))));
-                val args = map build TUs;
-              in Term.list_comb (mapx, args) end
-            | (j, _) => maybe_tick (nth vs j) (nth fiter_likes j))
-      in build end;
-
     fun pour_more_sugar_on_lfps ((ctrss, iters, recs, vs, xsss, ctr_defss, iter_defs, rec_defs),
         lthy) =
       let
@@ -478,14 +465,32 @@
               fold_rev (fold_rev Logic.all) (xs :: fss)
                 (mk_Trueprop_eq (fiter_like $ xctr, Term.list_comb (f, fxs)));
 
+            fun build_call fiter_likes maybe_tick =
+              let
+                fun build (T, U) =
+                  if T = U then
+                    Const (@{const_name id}, T --> T)
+                  else
+                    (case (find_index (curry (op =) T) fpTs, (T, U)) of
+                      (~1, (Type (s, Ts), Type (_, Us))) =>
+                      let
+                        val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
+                        val mapx = mk_map Ts Us map0;
+                        val TUs =
+                          map dest_funT (fst (split_last (fst (strip_map_type (fastype_of mapx)))));
+                        val args = map build TUs;
+                      in Term.list_comb (mapx, args) end
+                    | (j, _) => maybe_tick (nth vs j) (nth fiter_likes j))
+              in build end;
+
             fun mk_U maybe_prodT =
               typ_subst (map2 (fn fpT => fn C => (fpT, maybe_prodT fpT C)) fpTs Cs);
 
             fun repair_calls fiter_likes maybe_cons maybe_tick maybe_prodT (x as Free (_, T)) =
               if member (op =) fpTs T then
-                maybe_cons x [build_iter_like_call vs fpTs fiter_likes (K I) (T, mk_U (K I) T) $ x]
+                maybe_cons x [build_call fiter_likes (K I) (T, mk_U (K I) T) $ x]
               else if exists_subtype (member (op =) fpTs) T then
-                [build_iter_like_call vs fpTs fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
+                [build_call fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
               else
                 [x];
 
@@ -521,6 +526,8 @@
     fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, xsss, ctr_defss, coiter_defs,
         corec_defs), lthy) =
       let
+        val z = the_single zs;
+
         val gcoiters = map (lists_bmoc pgss) coiters;
         val hcorecs = map (lists_bmoc phss) corecs;
 
@@ -533,32 +540,58 @@
                 (Logic.list_implies (seq_conds mk_goal_cond n k cps,
                    mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, cfs'))));
 
+            fun build_call fiter_likes maybe_tack =
+              let
+                fun build (T, U) =
+                  if T = U then
+                    Const (@{const_name id}, T --> T)
+                  else
+                    (case (find_index (curry (op =) U) fpTs, (T, U)) of
+                      (~1, (Type (s, Ts), Type (_, Us))) =>
+                      let
+                        val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
+                        val mapx = mk_map Ts Us map0;
+                        val TUs =
+                          map dest_funT (fst (split_last (fst (strip_map_type (fastype_of mapx)))));
+                        val args = map build TUs;
+                      in Term.list_comb (mapx, args) end
+                    | (j, _) => maybe_tack (nth cs j, nth vs j) (nth fiter_likes j))
+              in build end;
+
             fun mk_U maybe_sumT =
-              typ_subst (map2 (fn C => fn fpT => (C, maybe_sumT C fpT)) Cs fpTs);
+              typ_subst (map2 (fn C => fn fpT => (maybe_sumT fpT C, fpT)) Cs fpTs);
 
             fun repair_calls fiter_likes maybe_sumT maybe_tack
                 (cf as Free (_, Type (_, [_, T])) $ _) =
               if exists_subtype (member (op =) Cs) T then
-                build_iter_like_call vs Cs fiter_likes maybe_tack (T, mk_U maybe_sumT T) $ cf
+                build_call fiter_likes maybe_tack (T, mk_U maybe_sumT T) $ cf
               else
                 cf;
 
             val cgsss = map (map (map (repair_calls gcoiters (K I) (K I)))) cgsss;
+            val chsss = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) chsss;
 
             val goal_coiterss =
               map7 (map3 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss cgsss;
+            val goal_corecss =
+              map7 (map3 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss chsss;
 
             val coiter_tacss =
               map3 (map oo mk_coiter_like_tac coiter_defs map_ids) fp_iter_thms pre_map_defs
                 ctr_defss;
+            val corec_tacss =
+              map3 (map oo mk_coiter_like_tac corec_defs map_ids) fp_rec_thms pre_map_defs
+                ctr_defss;
           in
             (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
                goal_coiterss coiter_tacss,
-             [])
+             map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
+               goal_corecss corec_tacss)
           end;
 
         val notes =
-          [(coitersN, coiter_thmss)]
+          [(coitersN, coiter_thmss),
+           (corecsN, corec_thmss)]
           |> maps (fn (thmN, thmss) =>
             map2 (fn b => fn thms =>
                 ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))