--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Wed May 29 02:35:49 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Wed May 29 02:35:49 2013 +0200
@@ -32,8 +32,8 @@
val indexify_fst: ''a list -> (int -> ''a * 'b -> 'c) -> ''a * 'b -> 'c
val mk_un_fold_co_rec_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
int list list -> term list -> term list -> Proof.context ->
- (term list * term list * ((term list list * typ list list * term list list list)
- * (term list list * typ list list * term list list list)) option
+ (term list * term list * ((term list list * typ list list * term list list list list)
+ * (term list list * typ list list * term list list list list)) option
* (term list * term list list
* ((term list list * term list list list list * term list list list list)
* (typ list * typ list list list * typ list list))
@@ -44,9 +44,10 @@
val mk_iter_fun_arg_types_pairsss: typ list -> int list -> int list list -> term ->
(typ list * typ list) list list list
- val define_fold_rec: (term list list * typ list list * term list list list)
- * (term list list * typ list list * term list list list) -> (string -> binding) -> typ list ->
- typ list -> term -> term -> Proof.context -> (term * term * thm * thm) * Proof.context
+ val define_fold_rec: (term list list * typ list list * term list list list list)
+ * (term list list * typ list list * term list list list list) -> (string -> binding) ->
+ typ list -> typ list -> term -> term -> Proof.context ->
+ (term * term * thm * thm) * Proof.context
val define_unfold_corec: term list * term list list
* ((term list list * term list list list list * term list list list list)
* (typ list * typ list list list * typ list list))
@@ -182,17 +183,14 @@
val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
-fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
-fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
-
fun flat_rec unzipf xs =
let val ps = map unzipf xs in
(* The first line below gives the preferred order. The second line is for compatibility with the
old datatype package: *)
-(*
maps (op @) ps
+(* ###
+ maps fst ps @ maps snd ps
*)
- maps fst ps @ maps snd ps
end;
fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
@@ -201,6 +199,11 @@
| flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
+fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
+fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
+fun mk_uncurried2_fun f xss =
+ mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
+
fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
@@ -245,8 +248,12 @@
val mk_fp_iter_fun_types = fst o split_last o binder_types o fastype_of;
-fun meta_unzip_rec getT proj1 proj2 fpTs y =
- if exists_subtype_in fpTs (getT y) then ([proj1 y], [proj2 y]) else ([y], []);
+fun meta_unzip_rec getT left right nested fpTs y =
+ let val T = getT y in
+ if member (op =) fpTs T then ([left y], [right y])
+ else if exists_subtype_in fpTs T then ([nested y], [])
+ else ([y], [])
+ end;
fun project_co_recT special_Tname fpTs proj =
let
@@ -259,10 +266,7 @@
val project_recT = project_co_recT @{type_name prod};
val project_corecT = project_co_recT @{type_name sum};
-fun unzip_recT fpTs = meta_unzip_rec I (project_recT fpTs fst) (project_recT fpTs snd) fpTs;
-
-fun mk_fold_fun_typess y_Tsss Cs = map2 (fn C => map (fn y_Ts => y_Ts ---> C)) Cs y_Tsss;
-val mk_rec_fun_typess = mk_fold_fun_typess oo map o map o flat_rec o unzip_recT;
+fun unzip_recT fpTs = meta_unzip_rec I (project_recT fpTs fst) (project_recT fpTs snd) I fpTs;
fun mk_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
@@ -273,21 +277,40 @@
fun mk_fold_rec_args_types fpTs Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy =
let
+ val Css = map2 replicate ns Cs;
val y_Tsss = map3 mk_fun_arg_typess ns mss ctor_fold_fun_Ts;
- val g_Tss = mk_fold_fun_typess y_Tsss Cs;
+ val g_Tss = map2 (fn C => map (fn y_Ts => y_Ts ---> C)) Cs y_Tsss;
val ((gss, ysss), lthy) =
lthy
|> mk_Freess "f" g_Tss
||>> mk_Freesss "x" y_Tsss;
+ val yssss = map (map (map single)) ysss;
+
+ (* ### *)
+ fun dest_rec_prodT (T as Type (@{type_name prod}, Us as [_, U])) =
+ if member (op =) Cs U then Us else [T]
+ | dest_rec_prodT T = [T];
+
+ val z_Tssss =
+ map3 (fn n => fn ms => map2 (map dest_rec_prodT oo dest_tupleT) ms o
+ dest_sumTN_balanced n o domain_type) ns mss ctor_rec_fun_Ts;
val z_Tsss = map3 mk_fun_arg_typess ns mss ctor_rec_fun_Ts;
- val h_Tss = mk_rec_fun_typess fpTs z_Tsss Cs;
+ val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
val hss = map2 (map2 retype_free) h_Tss gss;
val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
+ val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
+ val (zssss_tl, lthy) =
+ lthy
+ |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
+ val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
+
+val _ = tracing (" *** OLD: " ^ PolyML.makestring (ysss, zsss)) (*###*)
+val _ = tracing (" *** NEW: " ^ PolyML.makestring (yssss, zssss)) (*###*)
in
- (((gss, g_Tss, ysss), (hss, h_Tss, zsss)), lthy)
+ (((gss, g_Tss, yssss), (hss, h_Tss, zssss)), lthy)
end;
fun mk_unfold_corec_args_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy =
@@ -438,18 +461,12 @@
| _ => build_simple TU);
in build end;
-fun mk_iter_body lthy fpTs ctor_iter fss xsss =
+fun mk_iter_body lthy fpTs ctor_iter fss xssss =
let
fun build_proj sel sel_const (x as Free (_, T)) =
build_map lthy (sel_const o fst) (T, project_recT fpTs sel T) $ x;
-
- (* TODO: Avoid these complications; cf. corec case *)
- val unzip_rec = meta_unzip_rec (snd o dest_Free) (build_proj fst fst_const)
- (build_proj snd snd_const) fpTs;
-
- fun mk_iter_arg f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f (flat_rec unzip_rec xs);
in
- Term.list_comb (ctor_iter, map2 (mk_sum_caseN_balanced oo map2 mk_iter_arg) fss xsss)
+ Term.list_comb (ctor_iter, map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss)
end;
fun mk_preds_getterss_join c cps sum_prod_T cqfss =
@@ -480,13 +497,13 @@
val fpT_to_C as Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of ctor_fold));
- fun generate_iter (suf, ctor_iter, (fss, f_Tss, xsss)) =
+ fun generate_iter (suf, ctor_iter, (fss, f_Tss, xssss)) =
let
val res_T = fold_rev (curry (op --->)) f_Tss fpT_to_C;
val binding = mk_binding suf;
val spec =
mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binding, res_T)),
- mk_iter_body lthy0 fpTs ctor_iter fss xsss);
+ mk_iter_body lthy0 fpTs ctor_iter fss xssss);
in (binding, spec) end;
val binding_specs =
@@ -558,7 +575,6 @@
val pre_map_defs = map map_def_of_bnf pre_bnfs;
val pre_set_defss = map set_defs_of_bnf pre_bnfs;
val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs;
- val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
val nested_set_map's = maps set_map'_of_bnf nested_bnfs;
@@ -671,24 +687,47 @@
val mk_U = typ_subst_nonatomic (map2 pair fpTs Cs);
- fun unzip_iters fiters =
+ (* ### *)
+ fun typ_subst inst (T as Type (s, Ts)) =
+ (case AList.lookup (op =) inst T of
+ NONE => Type (s, map (typ_subst inst) Ts)
+ | SOME T' => T')
+ | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
+
+ fun mk_U' maybe_mk_prodT =
+ typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
+
+ (* ### *)
+ fun build_rec_like fiters maybe_tick (T, U) =
+ if T = U then
+ id_const T
+ else
+ (case find_index (curry (op =) T) fpTs of
+ ~1 => build_map lthy (build_rec_like fiters maybe_tick) (T, U)
+ | kk => maybe_tick (nth us kk) (nth fiters kk));
+
+ fun unzip_iters fiters maybe_tick maybe_mk_prodT =
meta_unzip_rec (snd o dest_Free) I
(fn x as Free (_, T) => build_map lthy (indexify_fst fpTs (K o nth fiters))
- (T, mk_U T) $ x) fpTs;
+ (T, mk_U T) $ x)
+ (fn x as Free (_, T) => build_rec_like fiters maybe_tick (T, mk_U' maybe_mk_prodT T) $ x)
+ fpTs;
+
+ fun tick u f = Term.lambda u (HOLogic.mk_prod (u, f $ u));
val gxsss = map (map (flat_rec ((fn (ts, ts') => ([hd (ts' @ ts)], [])) o
- unzip_iters gfolds))) xsss;
- val hxsss = map (map (flat_rec (unzip_iters hrecs))) xsss;
+ unzip_iters gfolds (K I) (K I)))) xsss;
+ val hxsss = map (map (flat_rec (unzip_iters hrecs tick (curry HOLogic.mk_prodT)))) xsss;
val fold_goalss = map5 (map4 o mk_goal gss) gfolds xctrss gss xsss gxsss;
val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;
val fold_tacss =
- map2 (map o mk_iter_tac pre_map_defs [] nesting_map_ids'' fold_defs) ctor_fold_thms
- ctr_defss;
+ map2 (map o mk_iter_tac pre_map_defs nesting_map_ids'' fold_defs)
+ ctor_fold_thms ctr_defss;
val rec_tacss =
- map2 (map o mk_iter_tac pre_map_defs nested_map_comp's
- (nested_map_ids'' @ nesting_map_ids'') rec_defs) ctor_rec_thms ctr_defss;
+ map2 (map o mk_iter_tac pre_map_defs (nested_map_ids'' @ nesting_map_ids'') rec_defs)
+ ctor_rec_thms ctr_defss;
fun prove goal tac =
Goal.prove_sorry lthy [] [] goal (tac o #context)