--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sat Sep 08 21:04:26 2012 +0200
@@ -23,6 +23,11 @@
fun retype_free (Free (s, _)) T = Free (s, T);
+fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
+fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
+fun mk_doubly_uncurried_fun f xss =
+ mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
+
fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
fun merge_type_arg_constrained ctxt (T, c) (T', c') =
@@ -242,6 +247,10 @@
val mss = map (map length) ctr_Tsss;
val Css = map2 replicate ns Cs;
+ fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
+ if member (op =) Cs U then Us else [T]
+ | dest_rec_pair T = [T];
+
fun sugar_datatype no_defs_lthy =
let
val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter))));
@@ -253,7 +262,8 @@
val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec))));
val z_prod_Tss = map2 dest_sumTN ns fp_z_Ts;
val z_Tsss = map2 (map2 dest_tupleT) mss z_prod_Tss;
- val h_Tss = map2 (map2 (curry (op --->))) z_Tsss Css;
+ val z_Tssss = map (map (map dest_rec_pair)) z_Tsss;
+ val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
val rec_T = flat h_Tss ---> fp_T --> C;
val ((gss, ysss), _) =
@@ -262,9 +272,9 @@
||>> mk_Freesss "x" y_Tsss;
val hss = map2 (map2 retype_free) gss h_Tss;
- val (zsss, _) =
+ val (zssss, _) =
no_defs_lthy
- |> mk_Freesss "x" z_Tsss;
+ |> mk_Freessss "x" z_Tssss;
val iter_binder = Binding.suffix_name ("_" ^ iterN) b;
val rec_binder = Binding.suffix_name ("_" ^ recN) b;
@@ -277,7 +287,8 @@
Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
val rec_spec =
mk_Trueprop_eq (fold (fn hs => fn t => Term.list_comb (t, hs)) hss rec_free,
- Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) hss zsss));
+ Term.list_comb (fp_rec,
+ map2 (mk_sum_caseN oo map2 mk_doubly_uncurried_fun) hss zssss));
val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy
|> apfst split_list o fold_map (fn (b, spec) =>