src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 51831 a5137cd2c2c2
parent 51830 403f7ecd061f
child 51832 35911d5acfa9
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Apr 30 11:28:43 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Apr 30 11:59:20 2013 +0200
@@ -217,9 +217,7 @@
 val mk_fold_fun_typess = map2 (map2 (curry (op --->)));
 val mk_rec_fun_typess = mk_fold_fun_typess oo massage_rec_fun_arg_typesss;
 
-fun mk_corec_like_pred_types n = replicate (Int.max (0, n - 1)) o mk_pred1T;
-
-fun mk_unfold_corec_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts =
+fun mk_unfold_corec_terms_and_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy =
   let
     (*avoid "'a itself" arguments in coiterators and corecursors*)
     fun repair_arity [0] = [1]
@@ -236,7 +234,7 @@
     fun unzip_corecT T =
       if exists_subtype_in fpTs T then [project_corecT fst T, project_corecT snd T] else [T];
 
-    val p_Tss = map2 mk_corec_like_pred_types ns Cs;
+    val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
 
     fun mk_types maybe_unzipT fun_Ts =
       let
@@ -249,12 +247,10 @@
           map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
         val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
       in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end;
-  in
-    (p_Tss, mk_types single dtor_unfold_fun_Ts, mk_types unzip_corecT dtor_corec_fun_Ts)
-  end
 
-fun mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss lthy =
-  let
+    val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single dtor_unfold_fun_Ts;
+    val (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss) = mk_types unzip_corecT dtor_corec_fun_Ts;
+
     val (((cs, pss), gssss), lthy) =
       lthy
       |> mk_Frees "a" Cs
@@ -268,16 +264,22 @@
       |> mk_Freessss "q" s_Tssss
       ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
     val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
-  in
-    ((cs, pss, (gssss, rssss), (hssss, sssss)), lthy)
-  end;
+
+    val cpss = map2 (map o rapp) cs pss;
 
-fun mk_corec_like_terms cs pss qssss fssss =
-  let
-    val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
-    val cqssss = map2 (map o map o map o rapp) cs qssss;
-    val cfssss = map2 (map o map o map o rapp) cs fssss;
-  in (pfss, cqssss, cfssss) end;
+    fun mk_terms qssss fssss =
+      let
+        val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
+        val cqssss = map2 (map o map o map o rapp) cs qssss;
+        val cfssss = map2 (map o map o map o rapp) cs fssss;
+      in (pfss, cqssss, cfssss) end;
+
+    val unfold_terms = mk_terms rssss gssss;
+    val corec_terms = mk_terms sssss hssss;
+  in
+    ((cs, cpss, (unfold_terms, (g_sum_prod_Ts, g_Tsss, pg_Tss)),
+      (corec_terms, (h_sum_prod_Ts, h_Tsss, ph_Tss))), lthy)
+  end;
 
 fun mk_map live Ts Us t =
   let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
@@ -532,20 +534,15 @@
     val discIss = map #discIs ctr_wrap_ress;
     val sel_thmsss = map #sel_thmss ctr_wrap_ress;
 
+    val ((cs, cpss, ((pgss, crssss, cgssss), _), ((phss, csssss, chssss), _)), names_lthy0) =
+      mk_unfold_corec_terms_and_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy;
+
     val (((rs, us'), vs'), names_lthy) =
-      lthy
+      names_lthy0
       |> mk_Frees "R" (map (fn T => mk_pred2T T T) fpTs)
       ||>> Variable.variant_fixes fp_b_names
       ||>> Variable.variant_fixes (map (suffix "'") fp_b_names);
 
-    val (p_Tss, (r_Tssss, _, _, g_Tssss, _), (s_Tssss, _, _, h_Tssss, _)) =
-      mk_unfold_corec_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts;
-
-    val ((cs, pss, (gssss, rssss), (hssss, sssss)), names_lthy) =
-      mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss names_lthy;
-
-    val cpss = map2 (map o rapp) cs pss;
-
     val us = map2 (curry Free) us' fpTs;
     val udiscss = map2 (map o rapp) us discss;
     val uselsss = map2 (map o map o rapp) us selsss;
@@ -554,9 +551,6 @@
     val vdiscss = map2 (map o rapp) vs discss;
     val vselsss = map2 (map o map o rapp) vs selsss;
 
-    val (pgss, crssss, cgssss) = mk_corec_like_terms cs pss rssss gssss;
-    val (phss, csssss, chssss) = mk_corec_like_terms cs pss sssss hssss;
-
     val ((coinduct_thms, coinduct_thm), (strong_coinduct_thms, strong_coinduct_thm)) =
       let
         val uvrs = map3 (fn r => fn u => fn v => r $ u $ v) rs us vs;
@@ -972,20 +966,8 @@
             ([], [], (([], [], []), ([], [], [])), (([], [], []), ([], [], [])))), lthy)
         end
       else
-        let
-          val (p_Tss, (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss),
-               (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss)) =
-            mk_unfold_corec_types fpTs Cs ns mss fp_fold_fun_Ts fp_rec_fun_Ts;
-
-          val ((cs, pss, (gssss, rssss), (hssss, sssss)), lthy) =
-            mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss lthy;
-
-          val cpss = map2 (map o rapp) cs pss;
-        in
-          (((([], [], []), ([], [], [])),
-            (cs, cpss, (mk_corec_like_terms cs pss rssss gssss, (g_sum_prod_Ts, g_Tsss, pg_Tss)),
-             (mk_corec_like_terms cs pss sssss hssss, (h_sum_prod_Ts, h_Tsss, ph_Tss)))), lthy)
-        end;
+        mk_unfold_corec_terms_and_types fpTs Cs ns mss fp_fold_fun_Ts fp_rec_fun_Ts lthy
+        |>> pair (([], [], []), ([], [], []));
 
     fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
             fp_fold), fp_rec), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def), pre_set_defs),