--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Fri Jun 07 12:54:40 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Fri Jun 07 14:45:07 2013 +0200
@@ -28,17 +28,17 @@
val flat_rec: 'a list list -> 'a list
val mk_co_iter: theory -> BNF_FP_Util.fp_kind -> typ -> typ list -> term -> term
val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
+ val mk_map: int -> typ list -> typ list -> term -> term
+ val build_map: local_theory -> (typ * typ -> term) -> typ * typ -> term
val mk_co_iters_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
int list list -> term list list -> Proof.context ->
(term list list
* (typ list list * typ list list list list * term list list
* term list list list list) list option
- * (term list * term list list
- * ((term list list * term list list list list * term list list list list)
+ * (string * term list * term list list
+ * ((term list list * term list list list)
* (typ list * typ list list list * typ list list)) list) option)
* Proof.context
- val mk_map: int -> typ list -> typ list -> term -> term
- val build_map: local_theory -> (typ * typ -> term) -> typ * typ -> term
val mk_iter_fun_arg_typessss: typ list -> int list -> int list list -> term ->
typ list list list list
@@ -46,8 +46,8 @@
(typ list list * typ list list list list * term list list * term list list list list) list ->
(string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
(term list * thm list) * Proof.context
- val define_coiters: string list -> term list * term list list
- * ((term list list * term list list list list * term list list list list)
+ val define_coiters: string list -> string * term list * term list list
+ * ((term list list * term list list list)
* (typ list * typ list list list * typ list list)) list ->
(string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
(term list * thm list) * Proof.context
@@ -59,8 +59,7 @@
(thm list * thm * Args.src list) * (thm list list * Args.src list)
* (thm list list * Args.src list)
val derive_coinduct_coiters_thms_for_types: BNF_Def.bnf list ->
- term list * term list list
- * ((term list list * term list list list list * term list list list list) * 'a) list ->
+ string * term list * term list list * ((term list list * term list list list) * 'a) list ->
thm list -> thm list -> thm list list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list ->
typ list -> typ list -> int list list -> int list list -> int list -> thm list list ->
BNF_Ctr_Sugar.ctr_sugar list -> term list list -> thm list list -> local_theory ->
@@ -257,6 +256,79 @@
if member (op =) Cs U then Ts else [T]
| unzip_recT _ T = [T];
+fun mk_map live Ts Us t =
+ let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
+ Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+ end;
+
+fun mk_rel live Ts Us t =
+ let val [Type (_, Ts0), Type (_, Us0)] = binder_types (snd (strip_typeN live (fastype_of t))) in
+ Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+ end;
+
+fun build_map lthy build_simple =
+ let
+ fun build (TU as (T, U)) =
+ if T = U then
+ id_const T
+ else
+ (case TU of
+ (Type (s, Ts), Type (s', Us)) =>
+ if s = s' then
+ let
+ val bnf = the (bnf_of lthy s);
+ val live = live_of_bnf bnf;
+ val mapx = mk_map live Ts Us (map_of_bnf bnf);
+ val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
+ in Term.list_comb (mapx, map build TUs') end
+ else
+ build_simple TU
+ | _ => build_simple TU);
+ in build end;
+
+fun liveness_of_fp_bnf n bnf =
+ (case T_of_bnf bnf of
+ Type (_, Ts) => map (not o member (op =) (deads_of_bnf bnf)) Ts
+ | _ => replicate n false);
+
+fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
+
+fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
+
+fun merge_type_args (As, As') =
+ if length As = length As' then map2 merge_type_arg As As' else cannot_merge_types ();
+
+fun reassoc_conjs thm =
+ reassoc_conjs (thm RS @{thm conj_assoc[THEN iffD1]})
+ handle THM _ => thm;
+
+fun type_args_named_constrained_of ((((ncAs, _), _), _), _) = ncAs;
+fun type_binding_of ((((_, b), _), _), _) = b;
+fun map_binding_of (((_, (b, _)), _), _) = b;
+fun rel_binding_of (((_, (_, b)), _), _) = b;
+fun mixfix_of ((_, mx), _) = mx;
+fun ctr_specs_of (_, ctr_specs) = ctr_specs;
+
+fun disc_of ((((disc, _), _), _), _) = disc;
+fun ctr_of ((((_, ctr), _), _), _) = ctr;
+fun args_of (((_, args), _), _) = args;
+fun defaults_of ((_, ds), _) = ds;
+fun ctr_mixfix_of (_, mx) = mx;
+
+fun add_nesty_bnf_names Us =
+ let
+ fun add (Type (s, Ts)) ss =
+ let val (needs, ss') = fold_map add Ts ss in
+ if exists I needs then (true, insert (op =) s ss') else (false, ss')
+ end
+ | add T ss = (member (op =) Us T, ss);
+ in snd oo add end;
+
+fun nesty_bnfs ctxt ctr_Tsss Us =
+ map_filter (bnf_of ctxt) (fold (fold (fold (add_nesty_bnf_names Us))) ctr_Tsss []);
+
+fun indexify proj xs f p = f (find_index (curry (op =) (proj p)) xs) p;
+
fun mk_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
fun mk_iter_fun_arg_typessss Cs ns mss =
@@ -301,9 +373,9 @@
fun repair_arity [0] = [1]
| repair_arity ms = ms;
- fun unzip_corecT T =
- if exists_subtype_in Cs T then [project_corecT Cs fst T, project_corecT Cs snd T]
- else [T];
+ fun unzip_corecT (T as Type (@{type_name sum}, Ts as [_, U])) =
+ if member (op =) Cs U then Ts else [T]
+ | unzip_corecT T = [T];
val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
@@ -322,9 +394,10 @@
val (r_Tssss, g_Tssss, unfold_types) = mk_types single un_fold_of;
val (s_Tssss, h_Tssss, corec_types) = mk_types unzip_corecT co_rec_of;
- val (((cs, pss), gssss), lthy) =
+ val ((((Free (z, _), cs), pss), gssss), lthy) =
lthy
- |> mk_Frees "a" Cs
+ |> yield_singleton (mk_Frees "z") dummyT
+ ||>> mk_Frees "a" Cs
||>> mk_Freess "p" p_Tss
||>> mk_Freessss "g" g_Tssss;
val rssss = map (map (map (fn [] => []))) r_Tssss;
@@ -338,17 +411,25 @@
val cpss = map2 (map o rapp) cs pss;
- fun mk_args qssss fssss =
+ fun build_sum_inj mk_inj = build_map lthy (uncurry mk_inj o dest_sumT o snd);
+
+ fun build_dtor_coiter_arg _ [] [cf] = cf
+ | build_dtor_coiter_arg T [cq] [cf, cf'] =
+ mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
+ (build_sum_inj Inr_const (fastype_of cf', T) $ cf');
+
+ fun mk_args qssss fssss (_, f_Tsss, _) =
let
val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
val cqssss = map2 (map o map o map o rapp) cs qssss;
val cfssss = map2 (map o map o map o rapp) cs fssss;
- in (pfss, cqssss, cfssss) end;
+ val cqfsss = map3 (map3 (map3 build_dtor_coiter_arg)) f_Tsss cqssss cfssss;
+ in (pfss, cqfsss) end;
- val unfold_args = mk_args rssss gssss;
- val corec_args = mk_args sssss hssss;
+ val unfold_args = mk_args rssss gssss unfold_types;
+ val corec_args = mk_args sssss hssss corec_types;
in
- ((cs, cpss, [(unfold_args, unfold_types), (corec_args, corec_types)]), lthy)
+ ((z, cs, cpss, [(unfold_args, unfold_types), (corec_args, corec_types)]), lthy)
end;
fun mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy =
@@ -368,79 +449,6 @@
((xtor_co_iterss, iters_args_types, coiters_args_types), lthy')
end;
-fun mk_map live Ts Us t =
- let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
- Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
- end;
-
-fun mk_rel live Ts Us t =
- let val [Type (_, Ts0), Type (_, Us0)] = binder_types (snd (strip_typeN live (fastype_of t))) in
- Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
- end;
-
-fun liveness_of_fp_bnf n bnf =
- (case T_of_bnf bnf of
- Type (_, Ts) => map (not o member (op =) (deads_of_bnf bnf)) Ts
- | _ => replicate n false);
-
-fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
-
-fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
-
-fun merge_type_args (As, As') =
- if length As = length As' then map2 merge_type_arg As As' else cannot_merge_types ();
-
-fun reassoc_conjs thm =
- reassoc_conjs (thm RS @{thm conj_assoc[THEN iffD1]})
- handle THM _ => thm;
-
-fun type_args_named_constrained_of ((((ncAs, _), _), _), _) = ncAs;
-fun type_binding_of ((((_, b), _), _), _) = b;
-fun map_binding_of (((_, (b, _)), _), _) = b;
-fun rel_binding_of (((_, (_, b)), _), _) = b;
-fun mixfix_of ((_, mx), _) = mx;
-fun ctr_specs_of (_, ctr_specs) = ctr_specs;
-
-fun disc_of ((((disc, _), _), _), _) = disc;
-fun ctr_of ((((_, ctr), _), _), _) = ctr;
-fun args_of (((_, args), _), _) = args;
-fun defaults_of ((_, ds), _) = ds;
-fun ctr_mixfix_of (_, mx) = mx;
-
-fun add_nesty_bnf_names Us =
- let
- fun add (Type (s, Ts)) ss =
- let val (needs, ss') = fold_map add Ts ss in
- if exists I needs then (true, insert (op =) s ss') else (false, ss')
- end
- | add T ss = (member (op =) Us T, ss);
- in snd oo add end;
-
-fun nesty_bnfs ctxt ctr_Tsss Us =
- map_filter (bnf_of ctxt) (fold (fold (fold (add_nesty_bnf_names Us))) ctr_Tsss []);
-
-fun indexify proj xs f p = f (find_index (curry (op =) (proj p)) xs) p;
-
-fun build_map lthy build_simple =
- let
- fun build (TU as (T, U)) =
- if T = U then
- id_const T
- else
- (case TU of
- (Type (s, Ts), Type (s', Us)) =>
- if s = s' then
- let
- val bnf = the (bnf_of lthy s);
- val live = live_of_bnf bnf;
- val mapx = mk_map live Ts Us (map_of_bnf bnf);
- val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
- in Term.list_comb (mapx, map build TUs') end
- else
- build_simple TU
- | _ => build_simple TU);
- in build end;
-
fun mk_iter_body ctor_iter fss xssss =
Term.list_comb (ctor_iter, map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss);
@@ -450,19 +458,8 @@
(map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cqfss) (1 upto n)))
end;
-fun mk_coiter_body lthy cs cpss f_sum_prod_Ts f_Tsss cqssss cfssss dtor_coiter =
- let
- fun build_sum_inj mk_inj = build_map lthy (uncurry mk_inj o dest_sumT o snd);
-
- fun build_dtor_coiter_arg _ [] [cf] = cf
- | build_dtor_coiter_arg T [cq] [cf, cf'] =
- mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
- (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
-
- val cqfsss = map3 (map3 (map3 build_dtor_coiter_arg)) f_Tsss cqssss cfssss;
- in
- Term.list_comb (dtor_coiter, map4 mk_preds_getterss_join cs cpss f_sum_prod_Ts cqfsss)
- end;
+fun mk_coiter_body lthy cs cpss f_sum_prod_Ts f_Tsss cqfsss dtor_coiter =
+ Term.list_comb (dtor_coiter, map4 mk_preds_getterss_join cs cpss f_sum_prod_Ts cqfsss);
fun define_co_iters fp fpT Cs binding_specs lthy0 =
let
@@ -500,19 +497,20 @@
define_co_iters Least_FP fpT Cs (map3 generate_iter iterNs iter_args_typess' ctor_iters) lthy
end;
-fun define_coiters coiterNs (cs, cpss, coiter_args_typess') mk_binding fpTs Cs dtor_coiters lthy =
+fun define_coiters coiterNs (_, cs, cpss, coiter_args_typess') mk_binding fpTs Cs dtor_coiters
+ lthy =
let
val nn = length fpTs;
val C_to_fpT as Type (_, [_, fpT]) = snd (strip_typeN nn (fastype_of (hd dtor_coiters)));
- fun generate_coiter suf ((pfss, cqssss, cfssss), (f_sum_prod_Ts, f_Tsss, pf_Tss)) dtor_coiter =
+ fun generate_coiter suf ((pfss, cqfsss), (f_sum_prod_Ts, f_Tsss, pf_Tss)) dtor_coiter =
let
val res_T = fold_rev (curry (op --->)) pf_Tss C_to_fpT;
val b = mk_binding suf;
val spec =
mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of b, res_T)),
- mk_coiter_body lthy cs cpss f_sum_prod_Ts f_Tsss cqssss cfssss dtor_coiter);
+ mk_coiter_body lthy cs cpss f_sum_prod_Ts f_Tsss cqfsss dtor_coiter);
in (b, spec) end;
in
define_co_iters Greatest_FP fpT Cs
@@ -676,8 +674,8 @@
(fold_thmss, code_simp_attrs), (rec_thmss, code_simp_attrs))
end;
-fun derive_coinduct_coiters_thms_for_types pre_bnfs (cs, cpss,
- [(unfold_args as (pgss, _, cgssss), _), (corec_args as (phss, _, _), _)])
+fun derive_coinduct_coiters_thms_for_types pre_bnfs (z, cs, cpss,
+ [(unfold_args as (pgss, crgsss), _), (corec_args as (phss, cshsss), _)])
dtor_coinducts dtor_ctors dtor_coiter_thmss nesting_bnfs nested_bnfs fpTs Cs As kss mss ns
ctr_defss ctr_sugars coiterss coiter_defss lthy =
let
@@ -693,7 +691,6 @@
val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs;
val nesting_rel_eqs = map rel_eq_of_bnf nesting_bnfs;
val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
- val nested_map_comps'' = map ((fn thm => thm RS sym) o map_comp_of_bnf) nested_bnfs;
val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
val fp_b_names = map base_name_of_typ fpTs;
@@ -728,6 +725,7 @@
map4 (fn u => fn v => fn uvr => fn uv_eq =>
fold_rev Term.lambda [u, v] (HOLogic.mk_disj (uvr, uv_eq))) us vs uvrs uv_eqs;
+ (* TODO: generalize (cf. "build_map") *)
fun build_rel rs' T =
(case find_index (curry (op =) T) fpTs of
~1 =>
@@ -807,7 +805,7 @@
fun mk_maybe_not pos = not pos ? HOLogic.mk_not;
val fcoiterss' as [gunfolds, hcorecs] =
- map2 (fn (pfss, _, _) => map (lists_bmoc pfss)) [unfold_args, corec_args] coiterss';
+ map2 (fn (pfss, _) => map (lists_bmoc pfss)) [unfold_args, corec_args] coiterss';
val (unfold_thmss, corec_thmss, safe_unfold_thmss, safe_corec_thmss) =
let
@@ -816,25 +814,48 @@
(Logic.list_implies (seq_conds (HOLogic.mk_Trueprop oo mk_maybe_not) n k cps,
mk_Trueprop_eq (fcoiter $ c, Term.list_comb (ctr, take m cfs'))));
- val substC = typ_subst_nonatomic (map2 pair Cs fpTs);
+ fun build_coiter fcoiters maybe_tack (T, U) =
+ if T = U then
+ id_const T
+ else
+ (case find_index (curry (op =) U) fpTs of
+ ~1 => build_map lthy (build_coiter fcoiters maybe_tack) (T, U)
+ | kk => maybe_tack (nth cs kk, nth us kk) (nth fcoiters kk));
+
+ fun mk_U maybe_mk_sumT =
+ typ_subst_nonatomic (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
+
+ fun tack z_name (c, u) f =
+ let val z = Free (z_name, mk_sumT (fastype_of u, fastype_of c)) in
+ Term.lambda z (mk_sum_case (Term.lambda u u, Term.lambda c (f $ c)) $ z)
+ end;
- fun intr_coiters fcoiters [] [cf] =
- let val T = fastype_of cf in
- if exists_subtype_in Cs T then
- build_map lthy (indexify fst Cs (K o nth fcoiters)) (T, substC T) $ cf
- else
- cf
- end
- | intr_coiters fcoiters [cq] [cf, cf'] =
- mk_If cq (intr_coiters fcoiters [] [cf]) (intr_coiters fcoiters [] [cf']);
+ fun intr_coiters fcoiters maybe_mk_sumT maybe_tack cqf =
+ let val T = fastype_of cqf in
+ if exists_subtype_in Cs T then
+ build_coiter fcoiters maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
+ else
+ cqf
+ end;
- val [crgsss, cshsss] =
+ val crgsss' =
+ (fn fcoiters => fn (_, cqfsss) =>
+ map (map (map (intr_coiters fcoiters (K I) (K I)))) cqfsss)
+ (un_fold_of fcoiterss') unfold_args;
+ val cshsss' =
+ (fn fcoiters => fn (_, cqfsss) =>
+ map (map (map (intr_coiters fcoiters (curry mk_sumT) (tack z)))) cqfsss)
+ (co_rec_of fcoiterss') corec_args;
+
+(*###
+ val [crgsss', cshsss'] =
map2 (fn fcoiters => fn (_, cqssss, cfssss) =>
map2 (map2 (map2 (intr_coiters fcoiters))) cqssss cfssss)
fcoiterss' [unfold_args, corec_args];
+*)
- val unfold_goalss = map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss;
- val corec_goalss = map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss;
+ val unfold_goalss = map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss';
+ val corec_goalss = map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss';
fun mk_map_if_distrib bnf =
let
@@ -851,10 +872,10 @@
val nested_map_if_distribs = map mk_map_if_distrib nested_bnfs;
val unfold_tacss =
- map3 (map oo mk_coiter_tac unfold_defs [] [] nesting_map_ids'' [])
+ map3 (map oo mk_coiter_tac unfold_defs [] nesting_map_ids'' [])
(map un_fold_of dtor_coiter_thmss) pre_map_defs ctr_defss;
val corec_tacss =
- map3 (map oo mk_coiter_tac corec_defs nested_map_comps'' nested_map_comp's
+ map3 (map oo mk_coiter_tac corec_defs nested_map_comp's
(nested_map_ids'' @ nesting_map_ids'') nested_map_if_distribs)
(map co_rec_of dtor_coiter_thmss) pre_map_defs ctr_defss;
@@ -865,12 +886,19 @@
val unfold_thmss = map2 (map2 prove) unfold_goalss unfold_tacss;
val corec_thmss = map2 (map2 prove) corec_goalss corec_tacss;
+ val corec_thmss =
+ map2 (map2 prove) corec_goalss corec_tacss
+ |> map (map (unfold_thms lthy @{thms sum_case_if}));
+
+ val unfold_safesss = map2 (map2 (map2 (curry (op =)))) crgsss' crgsss;
+ val corec_safesss = map2 (map2 (map2 (curry (op =)))) cshsss' cshsss;
+
val filter_safesss =
map2 (map_filter (fn (safes, thm) => if forall I safes then SOME thm else NONE) oo
- curry (op ~~)) (map2 (map2 (map2 (member (op =)))) cgssss crgsss);
+ curry (op ~~));
- val safe_unfold_thmss = filter_safesss unfold_thmss;
- val safe_corec_thmss = filter_safesss corec_thmss;
+ val safe_unfold_thmss = filter_safesss unfold_safesss unfold_thmss;
+ val safe_corec_thmss = filter_safesss corec_safesss corec_thmss;
in
(unfold_thmss, corec_thmss, safe_unfold_thmss, safe_corec_thmss)
end;