src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 52302 867d5d16158c
parent 52301 7935e82a4ae4
child 52303 16d7708aba40
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Wed Jun 05 12:51:16 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Wed Jun 05 13:13:35 2013 +0200
@@ -31,7 +31,6 @@
   val flat_rec: 'a list list -> 'a list
   val mk_co_iter: theory -> BNF_FP_Util.fp_kind -> typ -> typ list -> term -> term
   val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
-  val indexify_fst: ''a list -> (int -> ''a * 'b -> 'c) -> ''a * 'b -> 'c
   val mk_un_fold_co_rec_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
     int list list -> term list -> term list -> Proof.context ->
     (term list * term list
@@ -433,7 +432,7 @@
 fun nesty_bnfs ctxt ctr_Tsss Us =
   map_filter (bnf_of ctxt) (fold (fold (fold (add_nesty_bnf_names Us))) ctr_Tsss []);
 
-fun indexify_fst xs f (x, y) = f (find_index (curry (op =) x) xs) (x, y);
+fun indexify proj xs f p = f (find_index (curry (op =) (proj p)) xs) p;
 
 fun build_map lthy build_simple =
   let
@@ -674,15 +673,19 @@
           fold_rev (fold_rev Logic.all) (xs :: fss)
             (mk_Trueprop_eq (fiter $ xctr, Term.list_comb (f, fxs)));
 
-        fun tick u f = Term.lambda u (HOLogic.mk_prod (u, f $ u));
+        fun maybe_tick (T, U) u f =
+          if try (fst o HOLogic.dest_prodT) U = SOME T then
+            Term.lambda u (HOLogic.mk_prod (u, f $ u))
+          else
+            f;
 
-        fun unzip_iters fiters maybe_tick (x as Free (_, T)) Us =
+        fun unzip_iters tick fiters (x as Free (_, T)) =
           map (fn U => if U = T then x else
-            build_map lthy (indexify_fst fpTs (fn kk => fn _ =>
-              nth fiters kk |> length Us = 1 ? maybe_tick (nth us kk))) (T, U) $ x) Us;
+            build_map lthy (indexify fst fpTs (fn kk => fn TU =>
+              nth fiters kk |> maybe_tick TU (nth us kk))) (T, U) $ x);
 
-        val gxsss = map2 (map2 (flat_rec oo map2 (unzip_iters gfolds (K I)))) xsss y_Tssss;
-        val hxsss = map2 (map2 (flat_rec oo map2 (unzip_iters hrecs tick))) xsss z_Tssss;
+        val gxsss = map2 (map2 (flat_rec oo map2 (unzip_iters false gfolds))) xsss y_Tssss;
+        val hxsss = map2 (map2 (flat_rec oo map2 (unzip_iters true hrecs))) xsss z_Tssss;
 
         val fold_goalss = map5 (map4 o mk_goal gss) gfolds xctrss gss xsss gxsss;
         val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;
@@ -853,7 +856,7 @@
         fun intr_coiters fcoiters [] [cf] =
             let val T = fastype_of cf in
               if exists_subtype_in Cs T then
-                build_map lthy (indexify_fst Cs (K o nth fcoiters)) (T, mk_U T) $ cf
+                build_map lthy (indexify fst Cs (K o nth fcoiters)) (T, mk_U T) $ cf
               else
                 cf
             end