--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue May 07 10:34:55 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue May 07 10:35:40 2013 +0200
@@ -31,6 +31,8 @@
(typ list * typ list) list list list
val mk_fold_recs: Proof.context -> typ list -> typ list -> typ list -> int list ->
int list list -> term list -> term list -> term list * term list
+ val mk_unfold_corecs: Proof.context -> typ list -> typ list -> typ list -> int list ->
+ int list list -> term list -> term list -> term list * term list
val derive_induct_fold_rec_thms_for_types: BNF_Def.bnf list -> term list -> term list -> thm ->
thm list -> thm list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list ->
@@ -229,7 +231,7 @@
val massage_rec_fun_arg_typesss = map o map o flat_rec o unzip_recT;
-val mk_fold_fun_typess = map2 (map2 (curry (op --->)));
+fun mk_fold_fun_typess y_Tsss Cs = map2 (fn C => map (fn y_Ts => y_Ts ---> C)) Cs y_Tsss;
val mk_rec_fun_typess = mk_fold_fun_typess oo massage_rec_fun_arg_typesss;
fun mk_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
@@ -239,10 +241,10 @@
#> map3 mk_fun_arg_typess ns mss
#> map (map (map (unzip_recT fpTs)));
-fun mk_fold_rec_args_types fpTs Css ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy =
+fun mk_fold_rec_args_types fpTs Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy =
let
val y_Tsss = map3 mk_fun_arg_typess ns mss ctor_fold_fun_Ts;
- val g_Tss = mk_fold_fun_typess y_Tsss Css;
+ val g_Tss = mk_fold_fun_typess y_Tsss Cs;
val ((gss, ysss), lthy) =
lthy
@@ -250,7 +252,7 @@
||>> mk_Freesss "x" y_Tsss;
val z_Tsss = map3 mk_fun_arg_typess ns mss ctor_rec_fun_Ts;
- val h_Tss = mk_rec_fun_typess fpTs z_Tsss Css;
+ val h_Tss = mk_rec_fun_typess fpTs z_Tsss Cs;
val hss = map2 (map2 retype_free) h_Tss gss;
val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
@@ -258,7 +260,7 @@
(((gss, g_Tss, ysss), (hss, h_Tss, zsss)), lthy)
end;
-fun mk_unfold_corec_terms_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy =
+fun mk_unfold_corec_args_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy =
let
(*avoid "'a itself" arguments in coiterators and corecursors*)
fun repair_arity [0] = [1]
@@ -282,15 +284,14 @@
val f_sum_prod_Ts = map range_type fun_Ts;
val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss;
- val f_Tssss =
- map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
+ val f_Tssss = map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
val q_Tssss =
map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
- in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end;
+ in (q_Tssss, f_Tssss, (f_sum_prod_Ts, f_Tsss, pf_Tss)) end;
- val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single dtor_unfold_fun_Ts;
- val (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss) = mk_types unzip_corecT dtor_corec_fun_Ts;
+ val (r_Tssss, g_Tssss, unfold_types) = mk_types single dtor_unfold_fun_Ts;
+ val (s_Tssss, h_Tssss, corec_types) = mk_types unzip_corecT dtor_corec_fun_Ts;
val (((cs, pss), gssss), lthy) =
lthy
@@ -308,18 +309,17 @@
val cpss = map2 (map o rapp) cs pss;
- fun mk_terms qssss fssss =
+ fun mk_args qssss fssss =
let
val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
val cqssss = map2 (map o map o map o rapp) cs qssss;
val cfssss = map2 (map o map o map o rapp) cs fssss;
in (pfss, cqssss, cfssss) end;
- val unfold_terms = mk_terms rssss gssss;
- val corec_terms = mk_terms sssss hssss;
+ val unfold_args = mk_args rssss gssss;
+ val corec_args = mk_args sssss hssss;
in
- ((cs, cpss, (unfold_terms, (g_sum_prod_Ts, g_Tsss, pg_Tss)),
- (corec_terms, (h_sum_prod_Ts, h_Tsss, ph_Tss))), lthy)
+ ((cs, cpss, (unfold_args, unfold_types), (corec_args, corec_types)), lthy)
end;
fun mk_map live Ts Us t =
@@ -407,19 +407,16 @@
fun mk_fold_recs lthy fpTs As Cs ns mss ctor_folds ctor_recs =
let
- val Css = map2 replicate ns Cs;
-
val (_, ctor_fold_fun_Ts) = mk_fp_iter true As Cs ctor_folds;
val (_, ctor_rec_fun_Ts) = mk_fp_iter true As Cs ctor_recs;
val (((gss, _, ysss), (hss, _, zsss)), _) =
- mk_fold_rec_args_types fpTs Css ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy;
+ mk_fold_rec_args_types fpTs Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy;
fun mk_term ctor_iter fss xsss =
fold_rev (fold_rev Term.lambda) fss (mk_iter_body lthy fpTs ctor_iter fss xsss);
- fun mk_terms ctor_fold ctor_rec =
- (mk_term ctor_fold gss ysss, mk_term ctor_rec hss zsss)
+ fun mk_terms ctor_fold ctor_rec = (mk_term ctor_fold gss ysss, mk_term ctor_rec hss zsss);
in
map2 mk_terms ctor_folds ctor_recs |> split_list
end;
@@ -428,15 +425,37 @@
Term.lambda c (mk_IfN sum_prod_T cps
(map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cqfss) (1 upto n)));
-fun mk_coiter_body cs ns cpss f_sum_prod_Ts cqfsss dtor_coiter =
- Term.list_comb (dtor_coiter, map5 mk_preds_getterss_join cs ns cpss f_sum_prod_Ts cqfsss);
+fun mk_coiter_body lthy cs ns cpss f_sum_prod_Ts f_Tsss cqssss cfssss dtor_coiter =
+ let
+ fun build_sum_inj mk_inj = build_map lthy (uncurry mk_inj o dest_sumT o snd);
+
+ fun build_dtor_coiter_arg _ [] [cf] = cf
+ | build_dtor_coiter_arg T [cq] [cf, cf'] =
+ mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
+ (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
+
+ val cqfsss = map3 (map3 (map3 build_dtor_coiter_arg)) f_Tsss cqssss cfssss;
+ in
+ Term.list_comb (dtor_coiter, map5 mk_preds_getterss_join cs ns cpss f_sum_prod_Ts cqfsss)
+ end;
-(*###
- fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
- fp_fold), fp_rec), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def), pre_set_defs),
- pre_rel_def), fp_map_thm), fp_set_thms), fp_rel_thm), n), ks), ms), ctr_bindings),
- ctr_mixfixes), ctr_Tss), disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =
-*)
+fun mk_unfold_corecs lthy fpTs As Cs ns mss dtor_unfolds dtor_corecs =
+ let
+ val (_, dtor_unfold_fun_Ts) = mk_fp_iter true As Cs dtor_unfolds;
+ val (_, dtor_corec_fun_Ts) = mk_fp_iter true As Cs dtor_corecs;
+
+ val ((cs, cpss, unfold_only, corec_only), _) =
+ mk_unfold_corec_args_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy;
+
+ fun mk_term dtor_coiter ((pfss, cqssss, cfssss), (f_sum_prod_Ts, f_Tsss, _)) =
+ fold_rev (fold_rev Term.lambda) pfss
+ (mk_coiter_body lthy cs ns cpss f_sum_prod_Ts f_Tsss cqssss cfssss dtor_coiter);
+
+ fun mk_terms dtor_unfold dtor_corec =
+ (mk_term dtor_unfold unfold_only, mk_term dtor_corec corec_only);
+ in
+ map2 mk_terms dtor_unfolds dtor_corecs |> split_list
+ end;
fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_folds0 ctor_recs0 ctor_induct ctor_fold_thms
ctor_rec_thms nesting_bnfs nested_bnfs fpTs Cs As ctrss ctr_defss folds recs fold_defs rec_defs
@@ -447,7 +466,6 @@
val nn = length pre_bnfs;
val ns = map length ctr_Tsss;
val mss = map (map length) ctr_Tsss;
- val Css = map2 replicate ns Cs;
val pre_map_defs = map map_def_of_bnf pre_bnfs;
val pre_set_defss = map set_defs_of_bnf pre_bnfs;
@@ -462,7 +480,7 @@
val (_, ctor_rec_fun_Ts) = mk_fp_iter true As Cs ctor_recs0;
val (((gss, _, _), (hss, _, _)), names_lthy0) =
- mk_fold_rec_args_types fpTs Css ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy;
+ mk_fold_rec_args_types fpTs Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy;
val ((((ps, ps'), xsss), us'), names_lthy) =
names_lthy0
@@ -625,7 +643,7 @@
val sel_thmsss = map #sel_thmss ctr_sugars;
val ((cs, cpss, ((pgss, crssss, cgssss), _), ((phss, csssss, chssss), _)), names_lthy0) =
- mk_unfold_corec_terms_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy;
+ mk_unfold_corec_args_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy;
val (((rs, us'), vs'), names_lthy) =
names_lthy0
@@ -752,6 +770,7 @@
val crgsss = map2 (map2 (map2 (intr_coiters gunfolds))) crssss cgssss;
val cshsss = map2 (map2 (map2 (intr_coiters hcorecs))) csssss chssss;
+val _ = tracing ("*** cshsss1: " ^ PolyML.makestring cshsss) (*###*)
val unfold_goalss = map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss;
val corec_goalss = map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss;
@@ -1029,19 +1048,16 @@
val ns = map length ctr_Tsss;
val kss = map (fn n => 1 upto n) ns;
val mss = map (map length) ctr_Tsss;
- val Css = map2 replicate ns Cs;
val (fp_folds, fp_fold_fun_Ts) = mk_fp_iter lfp As Cs fp_folds0;
val (fp_recs, fp_rec_fun_Ts) = mk_fp_iter lfp As Cs fp_recs0;
- val (((fold_only, rec_only),
- (cs, cpss, unfold_only as ((_, crssss, cgssss), (_, g_Tsss, _)),
- corec_only as ((_, csssss, chssss), (_, h_Tsss, _)))), _) =
+ val (((fold_only, rec_only), (cs, cpss, unfold_only, corec_only)), _) =
if lfp then
- mk_fold_rec_args_types fpTs Css ns mss fp_fold_fun_Ts fp_rec_fun_Ts lthy
+ mk_fold_rec_args_types fpTs Cs ns mss fp_fold_fun_Ts fp_rec_fun_Ts lthy
|>> rpair ([], [], (([], [], []), ([], [], [])), (([], [], []), ([], [], [])))
else
- mk_unfold_corec_terms_types fpTs Cs ns mss fp_fold_fun_Ts fp_rec_fun_Ts lthy
+ mk_unfold_corec_args_types fpTs Cs ns mss fp_fold_fun_Ts fp_rec_fun_Ts lthy
|>> pair (([], [], []), ([], [], []));
fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
@@ -1229,21 +1245,16 @@
fun define_fold_rec no_defs_lthy =
let
- val fpT_to_C = fpT --> C;
-
fun generate_iter (suf, ctor_iter, (fss, f_Tss, xsss)) =
let
- val res_T = fold_rev (curry (op --->)) f_Tss fpT_to_C;
+ val res_T = fold_rev (curry (op --->)) f_Tss (fpT --> C);
val binding = qualify false fp_b_name (Binding.suffix_name ("_" ^ suf) fp_b);
val spec =
mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binding, res_T)),
mk_iter_body no_defs_lthy fpTs ctor_iter fss xsss);
in (binding, spec) end;
- val iter_infos =
- [(foldN, fp_fold, fold_only),
- (recN, fp_rec, rec_only)];
-
+ val iter_infos = [(foldN, fp_fold, fold_only), (recN, fp_rec, rec_only)];
val (bindings, specs) = map generate_iter iter_infos |> split_list;
val ((csts, defs), (lthy', lthy)) = no_defs_lthy
@@ -1263,32 +1274,18 @@
fun define_unfold_corec no_defs_lthy =
let
- val B_to_fpT = C --> fpT;
-
- fun build_sum_inj mk_inj = build_map lthy (uncurry mk_inj o dest_sumT o snd);
-
- fun build_dtor_coiter_arg _ [] [cf] = cf
- | build_dtor_coiter_arg T [cq] [cf, cf'] =
- mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
- (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
-
- val crgsss = map3 (map3 (map3 build_dtor_coiter_arg)) g_Tsss crssss cgssss;
- val cshsss = map3 (map3 (map3 build_dtor_coiter_arg)) h_Tsss csssss chssss;
-
- fun generate_coiter (suf, dtor_coiter, (cqfsss, ((pfss, _, _),
- (f_sum_prod_Ts, _, pf_Tss)))) =
+ fun generate_coiter (suf, dtor_coiter, ((pfss, cqssss, cfssss),
+ (f_sum_prod_Ts, f_Tsss, pf_Tss))) =
let
- val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
+ val res_T = fold_rev (curry (op --->)) pf_Tss (C --> fpT);
val binding = qualify false fp_b_name (Binding.suffix_name ("_" ^ suf) fp_b);
val spec =
mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binding, res_T)),
- mk_coiter_body cs ns cpss f_sum_prod_Ts cqfsss dtor_coiter);
+ mk_coiter_body no_defs_lthy cs ns cpss f_sum_prod_Ts f_Tsss cqssss cfssss
+ dtor_coiter);
in (binding, spec) end;
- val coiter_infos =
- [(unfoldN, fp_fold, (crgsss, unfold_only)),
- (corecN, fp_rec, (cshsss, corec_only))];
-
+ val coiter_infos = [(unfoldN, fp_fold, unfold_only), (corecN, fp_rec, corec_only)];
val (bindings, specs) = map generate_coiter coiter_infos |> split_list;
val ((csts, defs), (lthy', lthy)) = no_defs_lthy