src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 53592 5a7bf8c859f6
parent 53591 b6e2993fd0d3
child 53645 44f15d386aae
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Fri Sep 13 02:26:59 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Fri Sep 13 02:55:04 2013 +0200
@@ -53,7 +53,7 @@
         * ((term list list * term list list list) * (typ list * typ list list)) list) option)
     * Proof.context
 
-  val mk_iter_fun_arg_types: typ list -> int list -> int list list -> term ->
+  val mk_iter_fun_arg_types: typ list list list -> int list -> int list list -> term ->
     typ list list list list
   val mk_coiter_fun_arg_types: typ list list list -> typ list -> int list -> term ->
     typ list list
@@ -268,9 +268,8 @@
 
 val mk_fp_iter_fun_types = binder_fun_types o fastype_of;
 
-(* ### FIXME? *)
-fun unzip_recT Cs (T as Type (@{type_name prod}, Ts as [_, U])) =
-    if member (op =) Cs U then Ts else [T]
+fun unzip_recT (Type (@{type_name prod}, _)) T = [T]
+  | unzip_recT _ (T as Type (@{type_name prod}, Ts)) = Ts
   | unzip_recT _ T = [T];
 
 fun unzip_corecT (Type (@{type_name sum}, _)) T = [T]
@@ -399,12 +398,12 @@
 
 fun mk_iter_fun_arg_types0 n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
 
-fun mk_iter_fun_arg_types Cs ns mss =
+fun mk_iter_fun_arg_types ctr_Tsss ns mss =
   mk_fp_iter_fun_types
   #> map3 mk_iter_fun_arg_types0 ns mss
-  #> map (map (map (unzip_recT Cs)));
+  #> map2 (map2 (map2 unzip_recT)) ctr_Tsss;
 
-fun mk_iters_args_types Cs ns mss ctor_iter_fun_Tss lthy =
+fun mk_iters_args_types ctr_Tsss Cs ns mss ctor_iter_fun_Tss lthy =
   let
     val Css = map2 replicate ns Cs;
     val y_Tsss = map3 mk_iter_fun_arg_types0 ns mss (map un_fold_of ctor_iter_fun_Tss);
@@ -419,8 +418,11 @@
     val yssss = map (map (map single)) ysss;
 
     val z_Tssss =
-      map3 (fn n => fn ms => map2 (map (unzip_recT Cs) oo dest_tupleT) ms o
-        dest_sumTN_balanced n o domain_type o co_rec_of) ns mss ctor_iter_fun_Tss;
+      map4 (fn n => fn ms => fn ctr_Tss => fn ctor_iter_fun_Ts =>
+          map3 (fn m => fn ctr_Ts => fn ctor_iter_fun_T =>
+              map2 unzip_recT ctr_Ts (dest_tupleT m ctor_iter_fun_T))
+            ms ctr_Tss (dest_sumTN_balanced n (domain_type (co_rec_of ctor_iter_fun_Ts))))
+        ns mss ctr_Tsss ctor_iter_fun_Tss;
 
     val z_Tsss' = map (map flat_rec_arg_args) z_Tssss;
     val h_Tss = map2 (map2 (curry op --->)) z_Tsss' Css;
@@ -522,7 +524,7 @@
 
     val ((iters_args_types, coiters_args_types), lthy') =
       if fp = Least_FP then
-        mk_iters_args_types Cs ns mss xtor_co_iter_fun_Tss lthy |>> (rpair NONE o SOME)
+        mk_iters_args_types ctr_Tsss Cs ns mss xtor_co_iter_fun_Tss lthy |>> (rpair NONE o SOME)
       else
         mk_coiters_args_types ctr_Tsss Cs ns mss xtor_co_iter_fun_Tss lthy |>> (pair NONE o SOME)
   in