--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Fri Jun 07 22:13:04 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Fri Jun 07 22:17:19 2013 -0400
@@ -1,6 +1,6 @@
(* Title: HOL/BNF/Tools/bnf_fp_def_sugar.ML
Author: Jasmin Blanchette, TU Muenchen
- Copyright 2012
+ Copyright 2012, 2013
Sugared datatype and codatatype constructions.
*)
@@ -15,12 +15,9 @@
fp_res: BNF_FP_Util.fp_result,
ctr_defss: thm list list,
ctr_sugars: BNF_Ctr_Sugar.ctr_sugar list,
- un_folds: term list,
- co_recs: term list,
- co_induct: thm,
- strong_co_induct: thm,
- un_fold_thmss: thm list list,
- co_rec_thmss: thm list list};
+ co_iterss: term list list,
+ co_inducts: thm list,
+ co_iter_thmsss: thm list list list};
val of_fp_sugar: (fp_sugar -> 'a list) -> fp_sugar -> 'a
val morph_fp_sugar: morphism -> fp_sugar -> fp_sugar
@@ -31,17 +28,16 @@
val flat_rec: 'a list list -> 'a list
val mk_co_iter: theory -> BNF_FP_Util.fp_kind -> typ -> typ list -> term -> term
val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
- val mk_un_fold_co_rec_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
+ val mk_map: int -> typ list -> typ list -> term -> term
+ val build_map: local_theory -> (typ * typ -> term) -> typ * typ -> term
+ val mk_co_iters_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
int list list -> term list list -> Proof.context ->
(term list list
* (typ list list * typ list list list list * term list list
* term list list list list) list option
- * (term list * term list list
- * ((term list list * term list list list list * term list list list list)
- * (typ list * typ list list list * typ list list)) list) option)
+ * (string * term list * term list list
+ * ((term list list * term list list list) * (typ list * typ list list)) list) option)
* Proof.context
- val mk_map: int -> typ list -> typ list -> term -> term
- val build_map: local_theory -> (typ * typ -> term) -> typ * typ -> term
val mk_iter_fun_arg_typessss: typ list -> int list -> int list list -> term ->
typ list list list list
@@ -49,23 +45,24 @@
(typ list list * typ list list list list * term list list * term list list list list) list ->
(string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
(term list * thm list) * Proof.context
- val define_coiters: string list -> term list * term list list
- * ((term list list * term list list list list * term list list list list)
- * (typ list * typ list list list * typ list list)) list ->
+ val define_coiters: string list -> string * term list * term list list
+ * ((term list list * term list list list) * (typ list * typ list list)) list ->
(string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
(term list * thm list) * Proof.context
- val derive_induct_iters_thms_for_types: BNF_Def.bnf list -> term list list ->
+ val derive_induct_iters_thms_for_types: BNF_Def.bnf list ->
(typ list list * typ list list list list * term list list * term list list list list) list ->
- thm -> thm list list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list ->
+ thm list -> thm list list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list ->
typ list -> typ list list list -> term list list -> thm list list -> term list list ->
thm list list -> local_theory ->
- (thm * thm list * Args.src list) * (thm list list * Args.src list)
+ (thm list * thm * Args.src list) * (thm list list * Args.src list)
* (thm list list * Args.src list)
- val derive_coinduct_coiters_thms_for_types: BNF_Def.bnf list -> term list list -> thm ->
- thm -> thm list -> thm list list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list ->
+ val derive_coinduct_coiters_thms_for_types: BNF_Def.bnf list ->
+ string * term list * term list list * ((term list list * term list list list)
+ * (typ list * typ list list)) list ->
+ thm list -> thm list -> thm list list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list ->
typ list -> typ list -> int list list -> int list list -> int list -> thm list list ->
BNF_Ctr_Sugar.ctr_sugar list -> term list list -> thm list list -> local_theory ->
- (thm * thm list * thm * thm list * Args.src list)
+ ((thm list * thm) list * Args.src list)
* (thm list list * thm list list * Args.src list)
* (thm list list * thm list list) * (thm list list * thm list list * Args.src list)
* (thm list list * thm list list * Args.src list)
@@ -103,12 +100,9 @@
fp_res: fp_result,
ctr_defss: thm list list,
ctr_sugars: ctr_sugar list,
- un_folds: term list,
- co_recs: term list,
- co_induct: thm,
- strong_co_induct: thm,
- un_fold_thmss: thm list list,
- co_rec_thmss: thm list list};
+ co_iterss: term list list,
+ co_inducts: thm list,
+ co_iter_thmsss: thm list list list};
fun of_fp_sugar f (fp_sugar as {index, ...}) = nth (f fp_sugar) index;
@@ -116,16 +110,15 @@
{T = T2, fp = fp2, index = index2, fp_res = fp_res2, ...} : fp_sugar) =
T1 = T2 andalso fp1 = fp2 andalso index1 = index2 andalso eq_fp_result (fp_res1, fp_res2);
-fun morph_fp_sugar phi {T, fp, index, pre_bnfs, fp_res, ctr_defss, ctr_sugars, un_folds,
- co_recs, co_induct, strong_co_induct, un_fold_thmss, co_rec_thmss} =
+fun morph_fp_sugar phi {T, fp, index, pre_bnfs, fp_res, ctr_defss, ctr_sugars, co_iterss,
+ co_inducts, co_iter_thmsss} =
{T = Morphism.typ phi T, fp = fp, index = index, pre_bnfs = map (morph_bnf phi)
pre_bnfs, fp_res = morph_fp_result phi fp_res,
ctr_defss = map (map (Morphism.thm phi)) ctr_defss,
- ctr_sugars = map (morph_ctr_sugar phi) ctr_sugars, un_folds = map (Morphism.term phi) un_folds,
- co_recs = map (Morphism.term phi) co_recs, co_induct = Morphism.thm phi co_induct,
- strong_co_induct = Morphism.thm phi strong_co_induct,
- un_fold_thmss = map (map (Morphism.thm phi)) un_fold_thmss,
- co_rec_thmss = map (map (Morphism.thm phi)) co_rec_thmss};
+ ctr_sugars = map (morph_ctr_sugar phi) ctr_sugars,
+ co_iterss = map (map (Morphism.term phi)) co_iterss,
+ co_inducts = map (Morphism.thm phi) co_inducts,
+ co_iter_thmsss = map (map (map (Morphism.thm phi))) co_iter_thmsss};
structure Data = Generic_Data
(
@@ -141,14 +134,14 @@
Local_Theory.declaration {syntax = false, pervasive = true}
(fn phi => Data.map (Symtab.update_new (key, morph_fp_sugar phi fp_sugar)));
-fun register_fp_sugars fp pre_bnfs (fp_res as {Ts, ...}) ctr_defss ctr_sugars [un_folds, co_recs]
- co_induct strong_co_induct [un_fold_thmss, co_rec_thmss] lthy =
+fun register_fp_sugars fp pre_bnfs (fp_res as {Ts, ...}) ctr_defss ctr_sugars co_iterss co_inducts
+ co_iter_thmsss lthy =
(0, lthy)
|> fold (fn T as Type (s, _) => fn (kk, lthy) => (kk + 1,
register_fp_sugar s {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, fp_res = fp_res,
- ctr_defss = ctr_defss, ctr_sugars = ctr_sugars, un_folds = un_folds, co_recs = co_recs,
- co_induct = co_induct, strong_co_induct = strong_co_induct, un_fold_thmss = un_fold_thmss,
- co_rec_thmss = co_rec_thmss} lthy)) Ts
+ ctr_defss = ctr_defss, ctr_sugars = ctr_sugars, co_iterss = co_iterss,
+ co_inducts = co_inducts, co_iter_thmsss = co_iter_thmsss}
+ lthy)) Ts
|> snd;
(* This function could produce clashes in contrived examples (e.g., "x.A", "x.x_A", "y.A"). *)
@@ -247,131 +240,10 @@
val mk_fp_iter_fun_types = fst o split_last o binder_types o fastype_of;
-fun project_co_recT special_Tname Cs proj =
- let
- fun project (Type (s, Ts as [T, U])) =
- if s = special_Tname andalso member (op =) Cs U then proj (T, U)
- else Type (s, map project Ts)
- | project (Type (s, Ts)) = Type (s, map project Ts)
- | project T = T;
- in project end;
-
-val project_corecT = project_co_recT @{type_name sum};
-
fun unzip_recT Cs (T as Type (@{type_name prod}, Ts as [_, U])) =
if member (op =) Cs U then Ts else [T]
| unzip_recT _ T = [T];
-fun mk_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
-
-fun mk_iter_fun_arg_typessss Cs ns mss =
- mk_fp_iter_fun_types
- #> map3 mk_fun_arg_typess ns mss
- #> map (map (map (unzip_recT Cs)));
-
-fun mk_iters_args_types Cs ns mss [ctor_fold_fun_Ts, ctor_rec_fun_Ts] lthy =
- let
- val Css = map2 replicate ns Cs;
- val y_Tsss = map3 mk_fun_arg_typess ns mss ctor_fold_fun_Ts;
- val g_Tss = map2 (fn C => map (fn y_Ts => y_Ts ---> C)) Cs y_Tsss;
-
- val ((gss, ysss), lthy) =
- lthy
- |> mk_Freess "f" g_Tss
- ||>> mk_Freesss "x" y_Tsss;
-
- val y_Tssss = map (map (map single)) y_Tsss;
- val yssss = map (map (map single)) ysss;
-
- val z_Tssss =
- map3 (fn n => fn ms => map2 (map (unzip_recT Cs) oo dest_tupleT) ms o
- dest_sumTN_balanced n o domain_type) ns mss ctor_rec_fun_Ts;
-
- val z_Tsss' = map (map flat_rec) z_Tssss;
- val h_Tss = map2 (map2 (curry (op --->))) z_Tsss' Css;
-
- val hss = map2 (map2 retype_free) h_Tss gss;
- val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
- val (zssss_tl, lthy) =
- lthy
- |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
- val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
- in
- ([(g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)], lthy)
- end;
-
-fun mk_coiters_args_types Cs ns mss [dtor_unfold_fun_Ts, dtor_corec_fun_Ts] lthy =
- let
- (*avoid "'a itself" arguments in coiterators and corecursors*)
- fun repair_arity [0] = [1]
- | repair_arity ms = ms;
-
- fun unzip_corecT T =
- if exists_subtype_in Cs T then [project_corecT Cs fst T, project_corecT Cs snd T]
- else [T];
-
- val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
-
- fun mk_types maybe_unzipT fun_Ts =
- let
- val f_sum_prod_Ts = map range_type fun_Ts;
- val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
- val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss;
- val f_Tssss = map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
- val q_Tssss =
- map (map (map (fn [_] => [] | [_, T] => [mk_pred1T (domain_type T)]))) f_Tssss;
- val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
- in (q_Tssss, f_Tssss, (f_sum_prod_Ts, f_Tsss, pf_Tss)) end;
-
- val (r_Tssss, g_Tssss, unfold_types) = mk_types single dtor_unfold_fun_Ts;
- val (s_Tssss, h_Tssss, corec_types) = mk_types unzip_corecT dtor_corec_fun_Ts;
-
- val (((cs, pss), gssss), lthy) =
- lthy
- |> mk_Frees "a" Cs
- ||>> mk_Freess "p" p_Tss
- ||>> mk_Freessss "g" g_Tssss;
- val rssss = map (map (map (fn [] => []))) r_Tssss;
-
- val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
- val ((sssss, hssss_tl), lthy) =
- lthy
- |> mk_Freessss "q" s_Tssss
- ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
- val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
-
- val cpss = map2 (map o rapp) cs pss;
-
- fun mk_args qssss fssss =
- let
- val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
- val cqssss = map2 (map o map o map o rapp) cs qssss;
- val cfssss = map2 (map o map o map o rapp) cs fssss;
- in (pfss, cqssss, cfssss) end;
-
- val unfold_args = mk_args rssss gssss;
- val corec_args = mk_args sssss hssss;
- in
- ((cs, cpss, [(unfold_args, unfold_types), (corec_args, corec_types)]), lthy)
- end;
-
-fun mk_un_fold_co_rec_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy =
- let
- val thy = Proof_Context.theory_of lthy;
-
- val (xtor_co_iter_fun_Tss', xtor_co_iterss') =
- map (mk_co_iters thy fp fpTs Cs #> `(mk_fp_iter_fun_types o hd)) (transpose xtor_co_iterss0)
- |> split_list;
-
- val ((iters_args_types, coiters_args_types), lthy') =
- if fp = Least_FP then
- mk_iters_args_types Cs ns mss xtor_co_iter_fun_Tss' lthy |>> (rpair NONE o SOME)
- else
- mk_coiters_args_types Cs ns mss xtor_co_iter_fun_Tss' lthy |>> (pair NONE o SOME)
- in
- ((xtor_co_iterss', iters_args_types, coiters_args_types), lthy')
- end;
-
fun mk_map live Ts Us t =
let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
@@ -382,6 +254,26 @@
Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
end;
+fun build_map lthy build_simple =
+ let
+ fun build (TU as (T, U)) =
+ if T = U then
+ id_const T
+ else
+ (case TU of
+ (Type (s, Ts), Type (s', Us)) =>
+ if s = s' then
+ let
+ val bnf = the (bnf_of lthy s);
+ val live = live_of_bnf bnf;
+ val mapx = mk_map live Ts Us (map_of_bnf bnf);
+ val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
+ in Term.list_comb (mapx, map build TUs') end
+ else
+ build_simple TU
+ | _ => build_simple TU);
+ in build end;
+
fun liveness_of_fp_bnf n bnf =
(case T_of_bnf bnf of
Type (_, Ts) => map (not o member (op =) (deads_of_bnf bnf)) Ts
@@ -425,25 +317,125 @@
fun indexify proj xs f p = f (find_index (curry (op =) (proj p)) xs) p;
-fun build_map lthy build_simple =
+fun mk_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
+
+fun mk_iter_fun_arg_typessss Cs ns mss =
+ mk_fp_iter_fun_types
+ #> map3 mk_fun_arg_typess ns mss
+ #> map (map (map (unzip_recT Cs)));
+
+fun mk_iters_args_types Cs ns mss ctor_iter_fun_Tss lthy =
+ let
+ val Css = map2 replicate ns Cs;
+ val y_Tsss = map3 mk_fun_arg_typess ns mss (map un_fold_of ctor_iter_fun_Tss);
+ val g_Tss = map2 (fn C => map (fn y_Ts => y_Ts ---> C)) Cs y_Tsss;
+
+ val ((gss, ysss), lthy) =
+ lthy
+ |> mk_Freess "f" g_Tss
+ ||>> mk_Freesss "x" y_Tsss;
+
+ val y_Tssss = map (map (map single)) y_Tsss;
+ val yssss = map (map (map single)) ysss;
+
+ val z_Tssss =
+ map3 (fn n => fn ms => map2 (map (unzip_recT Cs) oo dest_tupleT) ms o
+ dest_sumTN_balanced n o domain_type o co_rec_of) ns mss ctor_iter_fun_Tss;
+
+ val z_Tsss' = map (map flat_rec) z_Tssss;
+ val h_Tss = map2 (map2 (curry (op --->))) z_Tsss' Css;
+
+ val hss = map2 (map2 retype_free) h_Tss gss;
+ val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
+ val (zssss_tl, lthy) =
+ lthy
+ |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
+ val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
+ in
+ ([(g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)], lthy)
+ end;
+
+fun mk_coiters_args_types Cs ns mss dtor_coiter_fun_Tss lthy =
let
- fun build (TU as (T, U)) =
- if T = U then
- id_const T
+ (*avoid "'a itself" arguments in coiterators and corecursors*)
+ fun repair_arity [0] = [1]
+ | repair_arity ms = ms;
+
+ fun unzip_corecT (T as Type (@{type_name sum}, Ts as [_, U])) =
+ if member (op =) Cs U then Ts else [T]
+ | unzip_corecT T = [T];
+
+ val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
+
+ fun mk_types maybe_unzipT get_Ts =
+ let
+ val fun_Ts = map get_Ts dtor_coiter_fun_Tss;
+ val f_sum_prod_Ts = map range_type fun_Ts;
+ val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
+ val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss;
+ val f_Tssss = map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
+ val q_Tssss =
+ map (map (map (fn [_] => [] | [_, T] => [mk_pred1T (domain_type T)]))) f_Tssss;
+ val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
+ in (q_Tssss, f_Tsss, f_Tssss, (f_sum_prod_Ts, pf_Tss)) end;
+
+ val (r_Tssss, g_Tsss, g_Tssss, unfold_types) = mk_types single un_fold_of;
+ val (s_Tssss, h_Tsss, h_Tssss, corec_types) = mk_types unzip_corecT co_rec_of;
+
+ val ((((Free (z, _), cs), pss), gssss), lthy) =
+ lthy
+ |> yield_singleton (mk_Frees "z") dummyT
+ ||>> mk_Frees "a" Cs
+ ||>> mk_Freess "p" p_Tss
+ ||>> mk_Freessss "g" g_Tssss;
+ val rssss = map (map (map (fn [] => []))) r_Tssss;
+
+ val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
+ val ((sssss, hssss_tl), lthy) =
+ lthy
+ |> mk_Freessss "q" s_Tssss
+ ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
+ val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
+
+ val cpss = map2 (map o rapp) cs pss;
+
+ fun build_sum_inj mk_inj = build_map lthy (uncurry mk_inj o dest_sumT o snd);
+
+ fun build_dtor_coiter_arg _ [] [cf] = cf
+ | build_dtor_coiter_arg T [cq] [cf, cf'] =
+ mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
+ (build_sum_inj Inr_const (fastype_of cf', T) $ cf');
+
+ fun mk_args qssss fssss f_Tsss =
+ let
+ val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
+ val cqssss = map2 (map o map o map o rapp) cs qssss;
+ val cfssss = map2 (map o map o map o rapp) cs fssss;
+ val cqfsss = map3 (map3 (map3 build_dtor_coiter_arg)) f_Tsss cqssss cfssss;
+ in (pfss, cqfsss) end;
+
+ val unfold_args = mk_args rssss gssss g_Tsss;
+ val corec_args = mk_args sssss hssss h_Tsss;
+ in
+ ((z, cs, cpss, [(unfold_args, unfold_types), (corec_args, corec_types)]), lthy)
+ end;
+
+fun mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy =
+ let
+ val thy = Proof_Context.theory_of lthy;
+
+ val (xtor_co_iter_fun_Tss, xtor_co_iterss) =
+ map (mk_co_iters thy fp fpTs Cs #> `(mk_fp_iter_fun_types o hd)) (transpose xtor_co_iterss0)
+ |> apsnd transpose o apfst transpose o split_list;
+
+ val ((iters_args_types, coiters_args_types), lthy') =
+ if fp = Least_FP then
+ mk_iters_args_types Cs ns mss xtor_co_iter_fun_Tss lthy |>> (rpair NONE o SOME)
else
- (case TU of
- (Type (s, Ts), Type (s', Us)) =>
- if s = s' then
- let
- val bnf = the (bnf_of lthy s);
- val live = live_of_bnf bnf;
- val mapx = mk_map live Ts Us (map_of_bnf bnf);
- val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
- in Term.list_comb (mapx, map build TUs') end
- else
- build_simple TU
- | _ => build_simple TU);
- in build end;
+ mk_coiters_args_types Cs ns mss xtor_co_iter_fun_Tss lthy |>> (pair NONE o SOME)
+ in
+ ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy')
+ end;
fun mk_iter_body ctor_iter fss xssss =
Term.list_comb (ctor_iter, map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss);
@@ -454,19 +446,8 @@
(map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cqfss) (1 upto n)))
end;
-fun mk_coiter_body lthy cs cpss f_sum_prod_Ts f_Tsss cqssss cfssss dtor_coiter =
- let
- fun build_sum_inj mk_inj = build_map lthy (uncurry mk_inj o dest_sumT o snd);
-
- fun build_dtor_coiter_arg _ [] [cf] = cf
- | build_dtor_coiter_arg T [cq] [cf, cf'] =
- mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
- (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
-
- val cqfsss = map3 (map3 (map3 build_dtor_coiter_arg)) f_Tsss cqssss cfssss;
- in
- Term.list_comb (dtor_coiter, map4 mk_preds_getterss_join cs cpss f_sum_prod_Ts cqfsss)
- end;
+fun mk_coiter_body cs cpss f_sum_prod_Ts cqfsss dtor_coiter =
+ Term.list_comb (dtor_coiter, map4 mk_preds_getterss_join cs cpss f_sum_prod_Ts cqfsss);
fun define_co_iters fp fpT Cs binding_specs lthy0 =
let
@@ -504,29 +485,36 @@
define_co_iters Least_FP fpT Cs (map3 generate_iter iterNs iter_args_typess' ctor_iters) lthy
end;
-fun define_coiters coiterNs (cs, cpss, coiter_args_typess') mk_binding fpTs Cs dtor_coiters lthy =
+fun define_coiters coiterNs (_, cs, cpss, coiter_args_typess') mk_binding fpTs Cs dtor_coiters
+ lthy =
let
val nn = length fpTs;
val C_to_fpT as Type (_, [_, fpT]) = snd (strip_typeN nn (fastype_of (hd dtor_coiters)));
- fun generate_coiter suf ((pfss, cqssss, cfssss), (f_sum_prod_Ts, f_Tsss, pf_Tss)) dtor_coiter =
+ fun generate_coiter suf ((pfss, cqfsss), (f_sum_prod_Ts, pf_Tss)) dtor_coiter =
let
val res_T = fold_rev (curry (op --->)) pf_Tss C_to_fpT;
val b = mk_binding suf;
val spec =
mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of b, res_T)),
- mk_coiter_body lthy cs cpss f_sum_prod_Ts f_Tsss cqssss cfssss dtor_coiter);
+ mk_coiter_body cs cpss f_sum_prod_Ts cqfsss dtor_coiter);
in (b, spec) end;
in
define_co_iters Greatest_FP fpT Cs
(map3 generate_coiter coiterNs coiter_args_typess' dtor_coiters) lthy
end;
-fun derive_induct_iters_thms_for_types pre_bnfs (ctor_iters1 :: _) [fold_args_types, rec_args_types]
- ctor_induct [ctor_fold_thms, ctor_rec_thms] nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss
- ctr_defss [folds, recs] [fold_defs, rec_defs] lthy =
+fun derive_induct_iters_thms_for_types pre_bnfs [fold_args_types, rec_args_types] [ctor_induct]
+ ctor_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss iterss iter_defss
+ lthy =
let
+ val iterss' = transpose iterss;
+ val iter_defss' = transpose iter_defss;
+
+ val [folds, recs] = iterss';
+ val [fold_defs, rec_defs] = iter_defss';
+
val ctr_Tsss = map (map (binder_types o fastype_of)) ctrss;
val nn = length pre_bnfs;
@@ -541,9 +529,6 @@
val fp_b_names = map base_name_of_typ fpTs;
- val ctor_fold_fun_Ts = mk_fp_iter_fun_types (un_fold_of ctor_iters1);
- val ctor_rec_fun_Ts = mk_fp_iter_fun_types (co_rec_of ctor_iters1);
-
val ((((ps, ps'), xsss), us'), names_lthy) =
lthy
|> mk_Frees' "P" (map mk_pred1T fpTs)
@@ -670,17 +655,23 @@
map2 (map2 prove) goalss tacss
end;
- val fold_thmss = mk_iter_thmss fold_args_types folds fold_defs ctor_fold_thms;
- val rec_thmss = mk_iter_thmss rec_args_types recs rec_defs ctor_rec_thms;
+ val fold_thmss = mk_iter_thmss fold_args_types folds fold_defs (map un_fold_of ctor_iter_thmss);
+ val rec_thmss = mk_iter_thmss rec_args_types recs rec_defs (map co_rec_of ctor_iter_thmss);
in
- ((induct_thm, induct_thms, [induct_case_names_attr]),
+ ((induct_thms, induct_thm, [induct_case_names_attr]),
(fold_thmss, code_simp_attrs), (rec_thmss, code_simp_attrs))
end;
-fun derive_coinduct_coiters_thms_for_types pre_bnfs (dtor_coiters1 :: _) dtor_coinduct
- dtor_strong_induct dtor_ctors [dtor_unfold_thms, dtor_corec_thms] nesting_bnfs nested_bnfs fpTs
- Cs As kss mss ns ctr_defss ctr_sugars [unfolds, corecs] [unfold_defs, corec_defs] lthy =
+fun derive_coinduct_coiters_thms_for_types pre_bnfs (z, cs, cpss,
+ [(unfold_args as (pgss, crgsss), _), (corec_args as (phss, cshsss), _)])
+ dtor_coinducts dtor_ctors dtor_coiter_thmss nesting_bnfs nested_bnfs fpTs Cs As kss mss ns
+ ctr_defss ctr_sugars coiterss coiter_defss lthy =
let
+ val coiterss' = transpose coiterss;
+ val coiter_defss' = transpose coiter_defss;
+
+ val [unfold_defs, corec_defs] = coiter_defss';
+
val nn = length pre_bnfs;
val pre_map_defs = map map_def_of_bnf pre_bnfs;
@@ -688,14 +679,10 @@
val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs;
val nesting_rel_eqs = map rel_eq_of_bnf nesting_bnfs;
val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
- val nested_map_comps'' = map ((fn thm => thm RS sym) o map_comp_of_bnf) nested_bnfs;
val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
val fp_b_names = map base_name_of_typ fpTs;
- val dtor_unfold_fun_Ts = mk_fp_iter_fun_types (un_fold_of dtor_coiters1);
- val dtor_corec_fun_Ts = mk_fp_iter_fun_types (co_rec_of dtor_coiters1);
-
val ctrss = map (map (mk_ctr As) o #ctrs) ctr_sugars;
val discss = map (map (mk_disc_or_sel As) o #discs) ctr_sugars;
val selsss = map (map (map (mk_disc_or_sel As)) o #selss) ctr_sugars;
@@ -704,11 +691,8 @@
val discIss = map #discIs ctr_sugars;
val sel_thmsss = map #sel_thmss ctr_sugars;
- val ((cs, cpss, [((pgss, crssss, cgssss), _), ((phss, csssss, chssss), _)]), names_lthy0) =
- mk_coiters_args_types Cs ns mss [dtor_unfold_fun_Ts, dtor_corec_fun_Ts] lthy;
-
val (((rs, us'), vs'), names_lthy) =
- names_lthy0
+ lthy
|> mk_Frees "R" (map (fn T => mk_pred2T T T) fpTs)
||>> Variable.variant_fixes fp_b_names
||>> Variable.variant_fixes (map (suffix "'") fp_b_names);
@@ -721,7 +705,7 @@
val vdiscss = map2 (map o rapp) vs discss;
val vselsss = map2 (map o map o rapp) vs selsss;
- val ((coinduct_thms, coinduct_thm), (strong_coinduct_thms, strong_coinduct_thm)) =
+ val coinduct_thms_pairs =
let
val uvrs = map3 (fn r => fn u => fn v => r $ u $ v) rs us vs;
val uv_eqs = map2 (curry HOLogic.mk_eq) us vs;
@@ -729,6 +713,7 @@
map4 (fn u => fn v => fn uvr => fn uv_eq =>
fold_rev Term.lambda [u, v] (HOLogic.mk_disj (uvr, uv_eq))) us vs uvrs uv_eqs;
+ (* TODO: generalize (cf. "build_map") *)
fun build_rel rs' T =
(case find_index (curry (op =) T) fpTs of
~1 =>
@@ -772,8 +757,7 @@
Logic.list_implies (map8 (mk_prem rs') uvrs us vs ns udiscss uselsss vdiscss vselsss,
concl);
- val goal = mk_goal rs;
- val strong_goal = mk_goal strong_rs;
+ val goals = map mk_goal [rs, strong_rs];
fun prove dtor_coinduct' goal =
Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, ...} =>
@@ -788,7 +772,7 @@
|> Drule.zero_var_indexes
|> `(conj_dests nn);
in
- (postproc nn (prove dtor_coinduct goal), postproc nn (prove dtor_strong_induct strong_goal))
+ map2 (postproc nn oo prove) dtor_coinducts goals
end;
fun mk_coinduct_concls ms discs ctrs =
@@ -808,8 +792,8 @@
fun mk_maybe_not pos = not pos ? HOLogic.mk_not;
- val gunfolds = map (lists_bmoc pgss) unfolds;
- val hcorecs = map (lists_bmoc phss) corecs;
+ val fcoiterss' as [gunfolds, hcorecs] =
+ map2 (fn (pfss, _) => map (lists_bmoc pfss)) [unfold_args, corec_args] coiterss';
val (unfold_thmss, corec_thmss, safe_unfold_thmss, safe_corec_thmss) =
let
@@ -818,24 +802,36 @@
(Logic.list_implies (seq_conds (HOLogic.mk_Trueprop oo mk_maybe_not) n k cps,
mk_Trueprop_eq (fcoiter $ c, Term.list_comb (ctr, take m cfs'))));
- (* TODO: get rid of "mk_U" *)
- val mk_U = typ_subst_nonatomic (map2 pair Cs fpTs);
+ fun build_coiter fcoiters maybe_tack (T, U) =
+ if T = U then
+ id_const T
+ else
+ (case find_index (curry (op =) U) fpTs of
+ ~1 => build_map lthy (build_coiter fcoiters maybe_tack) (T, U)
+ | kk => maybe_tack (nth cs kk, nth us kk) (nth fcoiters kk));
+
+ fun mk_U maybe_mk_sumT =
+ typ_subst_nonatomic (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
- fun intr_coiters fcoiters [] [cf] =
- let val T = fastype_of cf in
- if exists_subtype_in Cs T then
- build_map lthy (indexify fst Cs (K o nth fcoiters)) (T, mk_U T) $ cf
- else
- cf
- end
- | intr_coiters fcoiters [cq] [cf, cf'] =
- mk_If cq (intr_coiters fcoiters [] [cf]) (intr_coiters fcoiters [] [cf']);
+ fun tack z_name (c, u) f =
+ let val z = Free (z_name, mk_sumT (fastype_of u, fastype_of c)) in
+ Term.lambda z (mk_sum_case (Term.lambda u u, Term.lambda c (f $ c)) $ z)
+ end;
- val crgsss = map2 (map2 (map2 (intr_coiters gunfolds))) crssss cgssss;
- val cshsss = map2 (map2 (map2 (intr_coiters hcorecs))) csssss chssss;
+ fun intr_coiters fcoiters maybe_mk_sumT maybe_tack cqf =
+ let val T = fastype_of cqf in
+ if exists_subtype_in Cs T then
+ build_coiter fcoiters maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
+ else
+ cqf
+ end;
- val unfold_goalss = map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss;
- val corec_goalss = map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss;
+ val crgsss' = map (map (map (intr_coiters (un_fold_of fcoiterss') (K I) (K I)))) crgsss;
+ val cshsss' = map (map (map (intr_coiters (co_rec_of fcoiterss') (curry mk_sumT) (tack z))))
+ cshsss;
+
+ val unfold_goalss = map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss';
+ val corec_goalss = map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss';
fun mk_map_if_distrib bnf =
let
@@ -852,26 +848,30 @@
val nested_map_if_distribs = map mk_map_if_distrib nested_bnfs;
val unfold_tacss =
- map3 (map oo mk_coiter_tac unfold_defs [] [] nesting_map_ids'' [])
- dtor_unfold_thms pre_map_defs ctr_defss;
+ map3 (map oo mk_coiter_tac unfold_defs nesting_map_ids'')
+ (map un_fold_of dtor_coiter_thmss) pre_map_defs ctr_defss;
val corec_tacss =
- map3 (map oo mk_coiter_tac corec_defs nested_map_comps'' nested_map_comp's
- (nested_map_ids'' @ nesting_map_ids'') nested_map_if_distribs)
- dtor_corec_thms pre_map_defs ctr_defss;
+ map3 (map oo mk_coiter_tac corec_defs nesting_map_ids'')
+ (map co_rec_of dtor_coiter_thmss) pre_map_defs ctr_defss;
fun prove goal tac =
Goal.prove_sorry lthy [] [] goal (tac o #context)
|> Thm.close_derivation;
val unfold_thmss = map2 (map2 prove) unfold_goalss unfold_tacss;
- val corec_thmss = map2 (map2 prove) corec_goalss corec_tacss;
+ val corec_thmss =
+ map2 (map2 prove) corec_goalss corec_tacss
+ |> map (map (unfold_thms lthy @{thms sum_case_if}));
+
+ val unfold_safesss = map2 (map2 (map2 (curry (op =)))) crgsss' crgsss;
+ val corec_safesss = map2 (map2 (map2 (curry (op =)))) cshsss' cshsss;
val filter_safesss =
map2 (map_filter (fn (safes, thm) => if forall I safes then SOME thm else NONE) oo
- curry (op ~~)) (map2 (map2 (map2 (member (op =)))) cgssss crgsss);
+ curry (op ~~));
- val safe_unfold_thmss = filter_safesss unfold_thmss;
- val safe_corec_thmss = filter_safesss corec_thmss;
+ val safe_unfold_thmss = filter_safesss unfold_safesss unfold_thmss;
+ val safe_corec_thmss = filter_safesss corec_safesss corec_thmss;
in
(unfold_thmss, corec_thmss, safe_unfold_thmss, safe_corec_thmss)
end;
@@ -941,7 +941,7 @@
val coinduct_case_attrs =
coinduct_consumes_attr :: coinduct_case_names_attr :: coinduct_case_concl_attrs;
in
- ((coinduct_thm, coinduct_thms, strong_coinduct_thm, strong_coinduct_thms, coinduct_case_attrs),
+ ((coinduct_thms_pairs, coinduct_case_attrs),
(unfold_thmss, corec_thmss, []),
(safe_unfold_thmss, safe_corec_thmss),
(disc_unfold_thmss, disc_corec_thmss, simp_attrs),
@@ -1050,9 +1050,8 @@
map dest_TFree Xs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctrXs_sum_prod_Ts;
val (pre_bnfs, (fp_res as {bnfs = fp_bnfs as any_fp_bnf :: _, ctors = ctors0, dtors = dtors0,
- xtor_co_iterss = xtor_co_iterss0, xtor_co_induct, xtor_strong_co_induct, dtor_ctors,
- ctor_dtors, ctor_injects, xtor_map_thms, xtor_set_thmss, xtor_rel_thms,
- xtor_co_iter_thmss, ...}, lthy)) =
+ xtor_co_iterss = xtor_co_iterss0, xtor_co_inducts, dtor_ctors, ctor_dtors, ctor_injects,
+ xtor_map_thms, xtor_set_thmss, xtor_rel_thms, xtor_co_iter_thmss, ...}, lthy)) =
fp_bnf (construct_fp mixfixes map_bs rel_bs set_bss) fp_bs (map dest_TFree unsorted_As) fp_eqs
no_defs_lthy0;
@@ -1099,14 +1098,13 @@
val kss = map (fn n => 1 upto n) ns;
val mss = map (map length) ctr_Tsss;
- val ((xtor_co_iterss', iters_args_types, coiters_args_types), lthy) =
- mk_un_fold_co_rec_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy;
- val xtor_co_iterss = transpose xtor_co_iterss';
+ val ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy) =
+ mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy;
fun define_ctrs_case_for_type ((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
- xtor_co_iters), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def), pre_set_defs),
- pre_rel_def), fp_map_thm), fp_set_thms), fp_rel_thm), n), ks), ms), ctr_bindings),
- ctr_mixfixes), ctr_Tss), disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =
+ xtor_co_iters), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def), pre_set_defs),
+ pre_rel_def), fp_map_thm), fp_set_thms), fp_rel_thm), n), ks), ms), ctr_bindings),
+ ctr_mixfixes), ctr_Tss), disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =
let
val fp_b_name = Binding.name_of fp_b;
@@ -1302,8 +1300,8 @@
fun wrap_types_etc (wrap_types_etcs, lthy) =
fold_map I wrap_types_etcs lthy
- |>> apsnd (apsnd transpose o apfst transpose o split_list)
- o apfst (apsnd split_list4 o apfst split_list4 o split_list) o split_list;
+ |>> apsnd split_list o apfst (apsnd split_list4 o apfst split_list4 o split_list)
+ o split_list;
val mk_simp_thmss =
map7 (fn {injects, distincts, case_thms, ...} => fn un_folds => fn co_recs =>
@@ -1313,13 +1311,13 @@
fun derive_and_note_induct_iters_thms_for_types
((((mapsx, rel_injects, rel_distincts, setss), (ctrss, _, ctr_defss, ctr_sugars)),
- (iterss', iter_defss')), lthy) =
+ (iterss, iter_defss)), lthy) =
let
- val ((induct_thm, induct_thms, induct_attrs), (fold_thmss, fold_attrs),
+ val ((induct_thms, induct_thm, induct_attrs), (fold_thmss, fold_attrs),
(rec_thmss, rec_attrs)) =
- derive_induct_iters_thms_for_types pre_bnfs xtor_co_iterss (the iters_args_types)
- xtor_co_induct (transpose xtor_co_iter_thmss) nesting_bnfs nested_bnfs fpTs Cs Xs
- ctrXs_Tsss ctrss ctr_defss iterss' iter_defss' lthy;
+ derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_inducts
+ xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss iterss
+ iter_defss lthy;
val induct_type_attr = Attrib.internal o K o Induct.induct_type;
@@ -1339,24 +1337,24 @@
in
lthy
|> Local_Theory.notes (common_notes @ notes) |> snd
- |> register_fp_sugars Least_FP pre_bnfs fp_res ctr_defss ctr_sugars iterss' induct_thm
- induct_thm [fold_thmss, rec_thmss]
+ |> register_fp_sugars Least_FP pre_bnfs fp_res ctr_defss ctr_sugars iterss [induct_thm]
+ (transpose [fold_thmss, rec_thmss])
end;
fun derive_and_note_coinduct_coiters_thms_for_types
((((mapsx, rel_injects, rel_distincts, setss), (_, _, ctr_defss, ctr_sugars)),
- (coiterss', coiter_defss')), lthy) =
+ (coiterss, coiter_defss)), lthy) =
let
- val ((coinduct_thm, coinduct_thms, strong_coinduct_thm, strong_coinduct_thms,
+ val (([(coinduct_thms, coinduct_thm), (strong_coinduct_thms, strong_coinduct_thm)],
coinduct_attrs),
(unfold_thmss, corec_thmss, coiter_attrs),
(safe_unfold_thmss, safe_corec_thmss),
(disc_unfold_thmss, disc_corec_thmss, disc_coiter_attrs),
(disc_unfold_iff_thmss, disc_corec_iff_thmss, disc_coiter_iff_attrs),
(sel_unfold_thmss, sel_corec_thmss, sel_coiter_attrs)) =
- derive_coinduct_coiters_thms_for_types pre_bnfs xtor_co_iterss xtor_co_induct
- xtor_strong_co_induct dtor_ctors (transpose xtor_co_iter_thmss) nesting_bnfs nested_bnfs
- fpTs Cs As kss mss ns ctr_defss ctr_sugars coiterss' coiter_defss' lthy;
+ derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types) xtor_co_inducts
+ dtor_ctors xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs As kss mss ns ctr_defss
+ ctr_sugars coiterss coiter_defss lthy;
val coinduct_type_attr = Attrib.internal o K o Induct.coinduct_type;
@@ -1398,8 +1396,8 @@
in
lthy
|> Local_Theory.notes (anonymous_notes @ common_notes @ notes) |> snd
- |> register_fp_sugars Greatest_FP pre_bnfs fp_res ctr_defss ctr_sugars coiterss'
- coinduct_thm strong_coinduct_thm [unfold_thmss, corec_thmss]
+ |> register_fp_sugars Greatest_FP pre_bnfs fp_res ctr_defss ctr_sugars coiterss
+ [coinduct_thm, strong_coinduct_thm] (transpose [unfold_thmss, corec_thmss])
end;
val lthy' = lthy