--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue Apr 30 03:18:07 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue Apr 30 09:53:56 2013 +0200
@@ -15,12 +15,12 @@
val fp_of: Proof.context -> string -> fp option
- val derive_induct_fold_rec_thms_for_types: BNF_Def.BNF list -> thm -> thm list -> thm list ->
- BNF_Def.BNF list -> BNF_Def.BNF list -> typ list -> typ list -> typ list list list ->
- int list list -> int list -> term list list -> term list list -> term list list -> term list
- list list -> thm list list -> term list -> term list -> thm list -> thm list -> Proof.context ->
+ val derive_induct_fold_rec_thms_for_types: BNF_Def.BNF list -> term list -> term list -> thm ->
+ thm list -> thm list -> BNF_Def.BNF list -> BNF_Def.BNF list -> typ list -> typ list ->
+ typ list -> term list list -> thm list list -> term list -> term list -> thm list -> thm list ->
+ Proof.context ->
(thm * thm list * Args.src list) * (thm list list * Args.src list)
- * (thm list list * Args.src list)
+ * (thm list list * Args.src list)
val derive_coinduct_unfold_corec_thms_for_types: Proof.context -> Proof.context ->
BNF_Def.BNF list -> thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list ->
BNF_Def.BNF list -> typ list -> typ list -> typ list -> int list list -> int list list ->
@@ -188,6 +188,31 @@
Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
end;
+val mk_fp_rec_like_fun_types = fst o split_last o binder_types o fastype_of o hd;
+
+fun mk_fp_rec_like lfp As Cs fp_rec_likes0 =
+ map (mk_rec_like lfp As Cs) fp_rec_likes0
+ |> (fn ts => (ts, mk_fp_rec_like_fun_types ts));
+
+fun mk_rec_like_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
+
+fun project_recT fpTs proj =
+ let
+ fun project (Type (s as @{type_name prod}, Ts as [T, U])) =
+ if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts)
+ | project (Type (s, Ts)) = Type (s, map project Ts)
+ | project T = T;
+ in project end;
+
+fun unzip_recT fpTs T =
+ if exists_subtype_in fpTs T then ([project_recT fpTs fst T], [project_recT fpTs snd T])
+ else ([T], []);
+
+fun massage_rec_fun_arg_typesss fpTs = map (map (flat_rec (unzip_recT fpTs)));
+
+val mk_fold_fun_typess = map2 (map2 (curry (op --->)));
+val mk_rec_fun_typess = mk_fold_fun_typess oo massage_rec_fun_arg_typesss;
+
fun mk_map live Ts Us t =
let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
@@ -243,11 +268,16 @@
val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
in Term.list_comb (rel, map build_arg Ts') end;
-fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_induct ctor_fold_thms ctor_rec_thms
- nesting_bnfs nested_bnfs fpTs Cs ctr_Tsss mss ns gss hss ctrss xsss ctr_defss folds recs
- fold_defs rec_defs lthy =
+fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_folds0 ctor_recs0 ctor_induct ctor_fold_thms
+ ctor_rec_thms nesting_bnfs nested_bnfs fpTs Cs As ctrss ctr_defss folds recs fold_defs rec_defs
+ lthy =
let
+ val ctr_Tsss = map (map (binder_types o fastype_of)) ctrss;
+
val nn = length pre_bnfs;
+ val ns = map length ctr_Tsss;
+ val mss = map (map length) ctr_Tsss;
+ val Css = map2 replicate ns Cs;
val pre_map_defs = map map_def_of_bnf pre_bnfs;
val pre_set_defss = map set_defs_of_bnf pre_bnfs;
@@ -258,11 +288,23 @@
val fp_b_names = map base_name_of_typ fpTs;
- val (((ps, ps'), us'), names_lthy) =
+ val (_, ctor_fold_fun_Ts) = mk_fp_rec_like true As Cs ctor_folds0;
+ val (_, ctor_rec_fun_Ts) = mk_fp_rec_like true As Cs ctor_recs0;
+
+ val y_Tsss = map3 mk_rec_like_fun_arg_typess ns mss ctor_fold_fun_Ts;
+ val g_Tss = mk_fold_fun_typess y_Tsss Css;
+
+ val z_Tsss = map3 mk_rec_like_fun_arg_typess ns mss ctor_rec_fun_Ts;
+ val h_Tss = mk_rec_fun_typess fpTs z_Tsss Css;
+
+ val (((((ps, ps'), xsss), gss), us'), names_lthy) =
lthy
|> mk_Frees' "P" (map mk_pred1T fpTs)
+ ||>> mk_Freesss "x" ctr_Tsss
+ ||>> mk_Freess "f" g_Tss
||>> Variable.variant_fixes fp_b_names;
+ val hss = map2 (map2 retype_free) h_Tss gss;
val us = map2 (curry Free) us' fpTs;
fun mk_sets_nested bnf =
@@ -831,40 +873,24 @@
val mss = map (map length) ctr_Tsss;
val Css = map2 replicate ns Cs;
- val fp_folds as any_fp_fold :: _ = map (mk_rec_like lfp As Cs) fp_folds0;
- val fp_recs as any_fp_rec :: _ = map (mk_rec_like lfp As Cs) fp_recs0;
+ val (fp_folds, fp_fold_fun_Ts) = mk_fp_rec_like lfp As Cs fp_folds0;
+ val (fp_recs, fp_rec_fun_Ts) = mk_fp_rec_like lfp As Cs fp_recs0;
- val fp_fold_fun_Ts = fst (split_last (binder_types (fastype_of any_fp_fold)));
- val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of any_fp_rec)));
-
- val (((fold_only as (gss, _, _), rec_only as (hss, _, _)),
+ val (((fold_only, rec_only),
(cs, cpss, unfold_only as ((pgss, crssss, cgssss), (_, g_Tsss, _)),
corec_only as ((phss, csssss, chssss), (_, h_Tsss, _)))), names_lthy0) =
if lfp then
let
- val y_Tsss =
- map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
- ns mss fp_fold_fun_Ts;
- val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
+ val y_Tsss = map3 mk_rec_like_fun_arg_typess ns mss fp_fold_fun_Ts;
+ val g_Tss = mk_fold_fun_typess y_Tsss Css;
val ((gss, ysss), lthy) =
lthy
|> mk_Freess "f" g_Tss
||>> mk_Freesss "x" y_Tsss;
- fun proj_recT proj (Type (s as @{type_name prod}, Ts as [T, U])) =
- if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_recT proj) Ts)
- | proj_recT proj (Type (s, Ts)) = Type (s, map (proj_recT proj) Ts)
- | proj_recT _ T = T;
-
- fun unzip_recT T =
- if exists_subtype_in fpTs T then ([proj_recT fst T], [proj_recT snd T]) else ([T], []);
-
- val z_Tsss =
- map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
- ns mss fp_rec_fun_Ts;
- val z_Tsss' = map (map (flat_rec unzip_recT)) z_Tsss;
- val h_Tss = map2 (map2 (curry (op --->))) z_Tsss' Css;
+ val z_Tsss = map3 mk_rec_like_fun_arg_typess ns mss fp_rec_fun_Ts;
+ val h_Tss = mk_rec_fun_typess fpTs z_Tsss Css;
val hss = map2 (map2 retype_free) h_Tss gss;
val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
@@ -1252,14 +1278,14 @@
injects @ distincts @ case_thms @ rec_likes @ fold_likes);
fun derive_and_note_induct_fold_rec_thms_for_types
- (((ctrss, xsss, ctr_defss, ctr_wrap_ress), (folds, recs, fold_defs, rec_defs)), lthy) =
+ (((ctrss, _, ctr_defss, ctr_wrap_ress), (folds, recs, fold_defs, rec_defs)), lthy) =
let
val ((induct_thm, induct_thms, induct_attrs),
(fold_thmss, fold_attrs),
(rec_thmss, rec_attrs)) =
- derive_induct_fold_rec_thms_for_types pre_bnfs fp_induct fp_fold_thms fp_rec_thms
- nesting_bnfs nested_bnfs fpTs Cs ctr_Tsss mss ns gss hss ctrss xsss ctr_defss folds recs
- fold_defs rec_defs lthy;
+ derive_induct_fold_rec_thms_for_types pre_bnfs fp_folds0 fp_recs0 fp_induct fp_fold_thms
+ fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs As ctrss ctr_defss folds recs fold_defs
+ rec_defs lthy;
fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));