--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue Apr 30 10:07:41 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue Apr 30 10:58:25 2013 +0200
@@ -21,14 +21,11 @@
Proof.context ->
(thm * thm list * Args.src list) * (thm list list * Args.src list)
* (thm list list * Args.src list)
- val derive_coinduct_unfold_corec_thms_for_types: Proof.context -> Proof.context ->
- BNF_Def.BNF list -> thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list ->
- BNF_Def.BNF list -> typ list -> typ list -> typ list -> int list list -> int list list ->
- int list -> term list -> term list list -> term list list -> term list list list list ->
- term list list list list -> term list list -> term list list list list ->
- term list list list list -> term list list -> thm list list ->
- BNF_Ctr_Sugar.ctr_wrap_result list -> term list -> term list -> thm list -> thm list ->
- Proof.context ->
+ val derive_coinduct_unfold_corec_thms_for_types: BNF_Def.BNF list -> term list -> term list ->
+ thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list -> BNF_Def.BNF list ->
+ typ list -> typ list -> typ list -> int list list -> int list list -> int list ->
+ term list list -> thm list list -> BNF_Ctr_Sugar.ctr_wrap_result list -> term list ->
+ term list -> thm list -> thm list -> Proof.context ->
(thm * thm list * thm * thm list * Args.src list) * (thm list list * thm list list * 'e list)
* (thm list list * thm list list) * (thm list list * thm list list * Args.src list)
* (thm list list * thm list list * Args.src list)
@@ -158,6 +155,12 @@
maps fst ps @ maps snd ps
end;
+fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
+
+fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss
+ | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
+ p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
+
fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
@@ -196,23 +199,86 @@
fun mk_rec_like_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
-fun project_recT fpTs proj =
+fun massage_rec_fun_arg_typesss fpTs =
let
- fun project (Type (s as @{type_name prod}, Ts as [T, U])) =
- if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts)
- | project (Type (s, Ts)) = Type (s, map project Ts)
- | project T = T;
- in project end;
-
-fun unzip_recT fpTs T =
- if exists_subtype_in fpTs T then ([project_recT fpTs fst T], [project_recT fpTs snd T])
- else ([T], []);
-
-fun massage_rec_fun_arg_typesss fpTs = map (map (flat_rec (unzip_recT fpTs)));
+ fun project_recT proj =
+ let
+ fun project (Type (s as @{type_name prod}, Ts as [T, U])) =
+ if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts)
+ | project (Type (s, Ts)) = Type (s, map project Ts)
+ | project T = T;
+ in project end;
+ fun unzip_recT T =
+ if exists_subtype_in fpTs T then ([project_recT fst T], [project_recT snd T]) else ([T], []);
+ in
+ map (map (flat_rec unzip_recT))
+ end;
val mk_fold_fun_typess = map2 (map2 (curry (op --->)));
val mk_rec_fun_typess = mk_fold_fun_typess oo massage_rec_fun_arg_typesss;
+fun mk_corec_like_pred_types n = replicate (Int.max (0, n - 1)) o mk_pred1T;
+
+fun mk_unfold_corec_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts =
+ let
+ (*avoid "'a itself" arguments in coiterators and corecursors*)
+ fun repair_arity [0] = [1]
+ | repair_arity ms = ms;
+
+ fun project_corecT proj =
+ let
+ fun project (Type (s as @{type_name sum}, Ts as [T, U])) =
+ if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts)
+ | project (Type (s, Ts)) = Type (s, map project Ts)
+ | project T = T;
+ in project end;
+
+ fun unzip_corecT T =
+ if exists_subtype_in fpTs T then [project_corecT fst T, project_corecT snd T] else [T];
+
+ val p_Tss = map2 mk_corec_like_pred_types ns Cs;
+
+ fun mk_types maybe_unzipT fun_Ts =
+ let
+ val f_sum_prod_Ts = map range_type fun_Ts;
+ val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
+ val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss;
+ val f_Tssss =
+ map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
+ val q_Tssss =
+ map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
+ val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
+ in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end;
+ in
+ (p_Tss, mk_types single dtor_unfold_fun_Ts, mk_types unzip_corecT dtor_corec_fun_Ts)
+ end
+
+fun mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss lthy =
+ let
+ val (((cs, pss), gssss), lthy) =
+ lthy
+ |> mk_Frees "a" Cs
+ ||>> mk_Freess "p" p_Tss
+ ||>> mk_Freessss "g" g_Tssss;
+ val rssss = map (map (map (fn [] => []))) r_Tssss;
+
+ val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
+ val ((sssss, hssss_tl), lthy) =
+ lthy
+ |> mk_Freessss "q" s_Tssss
+ ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
+ val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
+ in
+ ((cs, pss, (gssss, rssss), (hssss, sssss)), lthy)
+ end;
+
+fun mk_corec_like_terms cs pss qssss fssss =
+ let
+ val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
+ val cqssss = map2 (map o map o map o rapp) cs qssss;
+ val cfssss = map2 (map o map o map o rapp) cs fssss;
+ in (pfss, cqssss, cfssss) end;
+
fun mk_map live Ts Us t =
let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
@@ -440,10 +506,9 @@
(fold_thmss, code_simp_attrs), (rec_thmss, code_simp_attrs))
end;
-fun derive_coinduct_unfold_corec_thms_for_types names_lthy0 no_defs_lthy pre_bnfs dtor_coinduct
+fun derive_coinduct_unfold_corec_thms_for_types pre_bnfs dtor_unfolds0 dtor_corecs0 dtor_coinduct
dtor_strong_induct dtor_ctors dtor_unfold_thms dtor_corec_thms nesting_bnfs nested_bnfs fpTs Cs
- As kss mss ns cs cpss pgss crssss cgssss phss csssss chssss ctrss ctr_defss ctr_wrap_ress
- unfolds corecs unfold_defs corec_defs lthy =
+ As kss mss ns ctrss ctr_defss ctr_wrap_ress unfolds corecs unfold_defs corec_defs lthy =
let
val nn = length pre_bnfs;
@@ -457,6 +522,9 @@
val fp_b_names = map base_name_of_typ fpTs;
+ val (_, dtor_unfold_fun_Ts) = mk_fp_rec_like false As Cs dtor_unfolds0;
+ val (_, dtor_corec_fun_Ts) = mk_fp_rec_like false As Cs dtor_corecs0;
+
val discss = map (map (mk_disc_or_sel As) o #discs) ctr_wrap_ress;
val selsss = map (map (map (mk_disc_or_sel As)) o #selss) ctr_wrap_ress;
val exhausts = map #exhaust ctr_wrap_ress;
@@ -470,6 +538,15 @@
||>> Variable.variant_fixes fp_b_names
||>> Variable.variant_fixes (map (suffix "'") fp_b_names);
+ val (p_Tss, (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss),
+ (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss)) =
+ mk_unfold_corec_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts;
+
+ val ((cs, pss, (gssss, rssss), (hssss, sssss)), names_lthy) =
+ mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss names_lthy;
+
+ val cpss = map2 (map o rapp) cs pss;
+
val us = map2 (curry Free) us' fpTs;
val udiscss = map2 (map o rapp) us discss;
val uselsss = map2 (map o map o rapp) us selsss;
@@ -478,6 +555,9 @@
val vdiscss = map2 (map o rapp) vs discss;
val vselsss = map2 (map o map o rapp) vs selsss;
+ val (pgss, crssss, cgssss) = mk_corec_like_terms cs pss rssss gssss;
+ val (phss, csssss, chssss) = mk_corec_like_terms cs pss sssss hssss;
+
val ((coinduct_thms, coinduct_thm), (strong_coinduct_thms, strong_coinduct_thm)) =
let
val uvrs = map3 (fn r => fn u => fn v => r $ u $ v) rs us vs;
@@ -652,7 +732,7 @@
fun prove goal tac =
Goal.prove_sorry lthy [] [] goal (tac o #context)
- |> singleton (Proof_Context.export names_lthy0 no_defs_lthy)
+ |> singleton (Proof_Context.export names_lthy lthy)
|> Thm.close_derivation;
fun proves [_] [_] = []
@@ -894,68 +974,18 @@
end
else
let
- (*avoid "'a itself" arguments in coiterators and corecursors*)
- val mss' = map (fn [0] => [1] | ms => ms) mss;
-
- val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
-
- fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
-
- fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss
- | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
- p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
-
- fun mk_types maybe_unzipT fun_Ts =
- let
- val f_sum_prod_Ts = map range_type fun_Ts;
- val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
- val f_Tsss = map2 (map2 dest_tupleT) mss' f_prod_Tss;
- val f_Tssss =
- map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
- val q_Tssss =
- map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
- val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
- in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end;
-
- val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
+ val (p_Tss, (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss),
+ (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss)) =
+ mk_unfold_corec_types fpTs Cs ns mss fp_fold_fun_Ts fp_rec_fun_Ts;
- val (((cs, pss), gssss), lthy) =
- lthy
- |> mk_Frees "a" Cs
- ||>> mk_Freess "p" p_Tss
- ||>> mk_Freessss "g" g_Tssss;
- val rssss = map (map (map (fn [] => []))) r_Tssss;
-
- fun proj_corecT proj (Type (s as @{type_name sum}, Ts as [T, U])) =
- if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_corecT proj) Ts)
- | proj_corecT proj (Type (s, Ts)) = Type (s, map (proj_corecT proj) Ts)
- | proj_corecT _ T = T;
-
- fun unzip_corecT T =
- if exists_subtype_in fpTs T then [proj_corecT fst T, proj_corecT snd T] else [T];
-
- val (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss) =
- mk_types unzip_corecT fp_rec_fun_Ts;
-
- val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
- val ((sssss, hssss_tl), lthy) =
- lthy
- |> mk_Freessss "q" s_Tssss
- ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
- val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
+ val ((cs, pss, (gssss, rssss), (hssss, sssss)), lthy) =
+ mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss lthy;
val cpss = map2 (map o rapp) cs pss;
-
- fun mk_terms qssss fssss =
- let
- val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
- val cqssss = map2 (map o map o map o rapp) cs qssss;
- val cfssss = map2 (map o map o map o rapp) cs fssss;
- in (pfss, cqssss, cfssss) end;
in
(((([], [], []), ([], [], [])),
- (cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, g_Tsss, pg_Tss)),
- (mk_terms sssss hssss, (h_sum_prod_Ts, h_Tsss, ph_Tss)))), lthy)
+ (cs, cpss, (mk_corec_like_terms cs pss rssss gssss, (g_sum_prod_Ts, g_Tsss, pg_Tss)),
+ (mk_corec_like_terms cs pss sssss hssss, (h_sum_prod_Ts, h_Tsss, ph_Tss)))), lthy)
end;
fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
@@ -1311,10 +1341,9 @@
(disc_unfold_thmss, disc_corec_thmss, disc_corec_like_attrs),
(disc_unfold_iff_thmss, disc_corec_iff_thmss, disc_corec_like_iff_attrs),
(sel_unfold_thmss, sel_corec_thmss, sel_corec_like_attrs)) =
- derive_coinduct_unfold_corec_thms_for_types names_lthy0 no_defs_lthy pre_bnfs fp_induct
+ derive_coinduct_unfold_corec_thms_for_types pre_bnfs fp_folds0 fp_recs0 fp_induct
fp_strong_induct dtor_ctors fp_fold_thms fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs As
- kss mss ns cs cpss pgss crssss cgssss phss csssss chssss ctrss ctr_defss ctr_wrap_ress
- unfolds corecs unfold_defs corec_defs lthy;
+ kss mss ns ctrss ctr_defss ctr_wrap_ress unfolds corecs unfold_defs corec_defs lthy;
fun coinduct_type_attr T_name = Attrib.internal (K (Induct.coinduct_type T_name));