src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 52301 7935e82a4ae4
parent 52300 4a4da43e855a
child 52302 867d5d16158c
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Wed Jun 05 11:30:24 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Wed Jun 05 12:51:16 2013 +0200
@@ -28,14 +28,16 @@
 
   val tvar_subst: theory -> typ list -> typ list -> ((string * int) * typ) list
   val exists_subtype_in: typ list -> typ -> bool
-  val flat_rec: ('a -> 'b list) -> 'a list -> 'b list
+  val flat_rec: 'a list list -> 'a list
   val mk_co_iter: theory -> BNF_FP_Util.fp_kind -> typ -> typ list -> term -> term
   val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
   val indexify_fst: ''a list -> (int -> ''a * 'b -> 'c) -> ''a * 'b -> 'c
   val mk_un_fold_co_rec_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
     int list list -> term list -> term list -> Proof.context ->
-    (term list * term list * ((term list list * typ list list * term list list list list)
-       * (term list list * typ list list * term list list list list)) option
+    (term list * term list
+       * ((typ list list * typ list list list list * term list list * term list list list list)
+          * (typ list list * typ list list list list * term list list
+             * term list list list list)) option
      * (term list * term list list
         * ((term list list * term list list list list * term list list list list)
            * (typ list * typ list list list * typ list list))
@@ -47,9 +49,10 @@
 
   val mk_iter_fun_arg_typessss: typ list -> int list -> int list list -> term ->
     typ list list list list
-  val define_fold_rec: (term list list * typ list list * term list list list list)
-      * (term list list * typ list list * term list list list list) -> (string -> binding) ->
-    typ list -> typ list -> term -> term -> Proof.context ->
+  val define_fold_rec:
+    (typ list list * typ list list list list * term list list * term list list list list)
+     * (typ list list * typ list list list list * term list list * term list list list list) ->
+    (string -> binding) -> typ list -> typ list -> term -> term -> Proof.context ->
     (term * term * thm * thm) * Proof.context
   val define_unfold_corec: term list * term list list
       * ((term list list * term list list list list * term list list list list)
@@ -188,15 +191,13 @@
 
 val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
 
-fun flat_rec unzipf xs =
-  let val ps = map unzipf xs in
-    (* The first line below gives the preferred order. The second line is for compatibility with the
-       old datatype package: *)
+fun flat_rec xss =
+  (* The first line below gives the preferred order. The second line is for compatibility with the
+     old datatype package: *)
 (*
-    flat ps
+  flat xss
 *)
-    map hd ps @ maps tl ps
-  end;
+  map hd xss @ maps tl xss;
 
 fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
 
@@ -207,7 +208,7 @@
 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
 fun mk_uncurried2_fun f xss =
-  mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat_rec I xss);
+  mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat_rec xss);
 
 fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
   Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
@@ -285,13 +286,15 @@
       lthy
       |> mk_Freess "f" g_Tss
       ||>> mk_Freesss "x" y_Tsss;
+
+    val y_Tssss = map (map (map single)) y_Tsss;
     val yssss = map (map (map single)) ysss;
 
     val z_Tssss =
       map3 (fn n => fn ms => map2 (map (unzip_recT Cs) oo dest_tupleT) ms o
         dest_sumTN_balanced n o domain_type) ns mss ctor_rec_fun_Ts;
 
-    val z_Tsss' = map (map (flat_rec I)) z_Tssss;
+    val z_Tsss' = map (map flat_rec) z_Tssss;
     val h_Tss = map2 (map2 (curry (op --->))) z_Tsss' Css;
 
     val hss = map2 (map2 retype_free) h_Tss gss;
@@ -301,7 +304,7 @@
       |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
     val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
   in
-    (((gss, g_Tss, yssss), (hss, h_Tss, zssss)), lthy)
+    (((g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)), lthy)
   end;
 
 fun mk_unfold_corec_args_types Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy =
@@ -483,7 +486,7 @@
 
     val fpT_to_C as Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of ctor_fold));
 
-    fun generate_iter (suf, ctor_iter, (fss, f_Tss, xssss)) =
+    fun generate_iter (suf, ctor_iter, (f_Tss, _, fss, xssss)) =
       let
         val res_T = fold_rev (curry (op --->)) f_Tss fpT_to_C;
         val binding = mk_binding suf;
@@ -569,7 +572,7 @@
     val ctor_fold_fun_Ts = mk_fp_iter_fun_types (hd ctor_folds);
     val ctor_rec_fun_Ts = mk_fp_iter_fun_types (hd ctor_recs);
 
-    val (((gss, _, _), (hss, _, _)), names_lthy0) =
+    val (((_, y_Tssss, gss, _), (_, z_Tssss, hss, _)), names_lthy0) =
       mk_fold_rec_args_types Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy;
 
     val ((((ps, ps'), xsss), us'), names_lthy) =
@@ -671,26 +674,15 @@
           fold_rev (fold_rev Logic.all) (xs :: fss)
             (mk_Trueprop_eq (fiter $ xctr, Term.list_comb (f, fxs)));
 
-        val mk_U = typ_subst_nonatomic (map2 pair fpTs Cs);
-
-        fun mk_nested_U maybe_mk_prodT =
-          typ_subst_nonatomic (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
-
-        fun unzip_iters fiters maybe_tick maybe_mk_prodT x =
-          let val Free (_, T) = x in
-            if member (op =) fpTs T then
-              [x, build_map lthy (indexify_fst fpTs (K o nth fiters)) (T, mk_U T) $ x]
-            else if exists_subtype_in fpTs T then
-              [build_map lthy (indexify_fst fpTs (fn kk => fn _ =>
-                 maybe_tick (nth us kk) (nth fiters kk))) (T, mk_nested_U maybe_mk_prodT T) $ x]
-            else
-              [x]
-          end;
-
         fun tick u f = Term.lambda u (HOLogic.mk_prod (u, f $ u));
 
-        val gxsss = map (map (flat_rec (single o List.last o unzip_iters gfolds (K I) (K I)))) xsss;
-        val hxsss = map (map (flat_rec (unzip_iters hrecs tick (curry HOLogic.mk_prodT)))) xsss;
+        fun unzip_iters fiters maybe_tick (x as Free (_, T)) Us =
+          map (fn U => if U = T then x else
+            build_map lthy (indexify_fst fpTs (fn kk => fn _ =>
+              nth fiters kk |> length Us = 1 ? maybe_tick (nth us kk))) (T, U) $ x) Us;
+
+        val gxsss = map2 (map2 (flat_rec oo map2 (unzip_iters gfolds (K I)))) xsss y_Tssss;
+        val hxsss = map2 (map2 (flat_rec oo map2 (unzip_iters hrecs tick))) xsss z_Tssss;
 
         val fold_goalss = map5 (map4 o mk_goal gss) gfolds xctrss gss xsss gxsss;
         val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;