# HG changeset patch # User blanchet # Date 1370609107 -7200 # Node ID ead18e3b2c1b15f4c0289ae25517cfc4a8939ce5 # Parent 856b3bd1d87ea1e5769732783438291486620db4 changed back type of corecursor for nested case, effectively reverting aa66ea552357 and 78a3d5006cf1 diff -r 856b3bd1d87e -r ead18e3b2c1b src/HOL/BNF/Tools/bnf_fp_def_sugar.ML --- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Fri Jun 07 12:54:40 2013 +0200 +++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Fri Jun 07 14:45:07 2013 +0200 @@ -28,17 +28,17 @@ val flat_rec: 'a list list -> 'a list val mk_co_iter: theory -> BNF_FP_Util.fp_kind -> typ -> typ list -> term -> term val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list + val mk_map: int -> typ list -> typ list -> term -> term + val build_map: local_theory -> (typ * typ -> term) -> typ * typ -> term val mk_co_iters_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list -> int list list -> term list list -> Proof.context -> (term list list * (typ list list * typ list list list list * term list list * term list list list list) list option - * (term list * term list list - * ((term list list * term list list list list * term list list list list) + * (string * term list * term list list + * ((term list list * term list list list) * (typ list * typ list list list * typ list list)) list) option) * Proof.context - val mk_map: int -> typ list -> typ list -> term -> term - val build_map: local_theory -> (typ * typ -> term) -> typ * typ -> term val mk_iter_fun_arg_typessss: typ list -> int list -> int list list -> term -> typ list list list list @@ -46,8 +46,8 @@ (typ list list * typ list list list list * term list list * term list list list list) list -> (string -> binding) -> typ list -> typ list -> term list -> Proof.context -> (term list * thm list) * Proof.context - val define_coiters: string list -> term list * term list list - * ((term list list * term list list list list * term list list list list) + val define_coiters: string list -> string * term list * term list list + * ((term list list * term list list list) * (typ list * typ list list list * typ list list)) list -> (string -> binding) -> typ list -> typ list -> term list -> Proof.context -> (term list * thm list) * Proof.context @@ -59,8 +59,7 @@ (thm list * thm * Args.src list) * (thm list list * Args.src list) * (thm list list * Args.src list) val derive_coinduct_coiters_thms_for_types: BNF_Def.bnf list -> - term list * term list list - * ((term list list * term list list list list * term list list list list) * 'a) list -> + string * term list * term list list * ((term list list * term list list list) * 'a) list -> thm list -> thm list -> thm list list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list -> typ list -> int list list -> int list list -> int list -> thm list list -> BNF_Ctr_Sugar.ctr_sugar list -> term list list -> thm list list -> local_theory -> @@ -257,6 +256,79 @@ if member (op =) Cs U then Ts else [T] | unzip_recT _ T = [T]; +fun mk_map live Ts Us t = + let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in + Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t + end; + +fun mk_rel live Ts Us t = + let val [Type (_, Ts0), Type (_, Us0)] = binder_types (snd (strip_typeN live (fastype_of t))) in + Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t + end; + +fun build_map lthy build_simple = + let + fun build (TU as (T, U)) = + if T = U then + id_const T + else + (case TU of + (Type (s, Ts), Type (s', Us)) => + if s = s' then + let + val bnf = the (bnf_of lthy s); + val live = live_of_bnf bnf; + val mapx = mk_map live Ts Us (map_of_bnf bnf); + val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx))); + in Term.list_comb (mapx, map build TUs') end + else + build_simple TU + | _ => build_simple TU); + in build end; + +fun liveness_of_fp_bnf n bnf = + (case T_of_bnf bnf of + Type (_, Ts) => map (not o member (op =) (deads_of_bnf bnf)) Ts + | _ => replicate n false); + +fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters"; + +fun merge_type_arg T T' = if T = T' then T else cannot_merge_types (); + +fun merge_type_args (As, As') = + if length As = length As' then map2 merge_type_arg As As' else cannot_merge_types (); + +fun reassoc_conjs thm = + reassoc_conjs (thm RS @{thm conj_assoc[THEN iffD1]}) + handle THM _ => thm; + +fun type_args_named_constrained_of ((((ncAs, _), _), _), _) = ncAs; +fun type_binding_of ((((_, b), _), _), _) = b; +fun map_binding_of (((_, (b, _)), _), _) = b; +fun rel_binding_of (((_, (_, b)), _), _) = b; +fun mixfix_of ((_, mx), _) = mx; +fun ctr_specs_of (_, ctr_specs) = ctr_specs; + +fun disc_of ((((disc, _), _), _), _) = disc; +fun ctr_of ((((_, ctr), _), _), _) = ctr; +fun args_of (((_, args), _), _) = args; +fun defaults_of ((_, ds), _) = ds; +fun ctr_mixfix_of (_, mx) = mx; + +fun add_nesty_bnf_names Us = + let + fun add (Type (s, Ts)) ss = + let val (needs, ss') = fold_map add Ts ss in + if exists I needs then (true, insert (op =) s ss') else (false, ss') + end + | add T ss = (member (op =) Us T, ss); + in snd oo add end; + +fun nesty_bnfs ctxt ctr_Tsss Us = + map_filter (bnf_of ctxt) (fold (fold (fold (add_nesty_bnf_names Us))) ctr_Tsss []); + +fun indexify proj xs f p = f (find_index (curry (op =) (proj p)) xs) p; + fun mk_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type; fun mk_iter_fun_arg_typessss Cs ns mss = @@ -301,9 +373,9 @@ fun repair_arity [0] = [1] | repair_arity ms = ms; - fun unzip_corecT T = - if exists_subtype_in Cs T then [project_corecT Cs fst T, project_corecT Cs snd T] - else [T]; + fun unzip_corecT (T as Type (@{type_name sum}, Ts as [_, U])) = + if member (op =) Cs U then Ts else [T] + | unzip_corecT T = [T]; val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs; @@ -322,9 +394,10 @@ val (r_Tssss, g_Tssss, unfold_types) = mk_types single un_fold_of; val (s_Tssss, h_Tssss, corec_types) = mk_types unzip_corecT co_rec_of; - val (((cs, pss), gssss), lthy) = + val ((((Free (z, _), cs), pss), gssss), lthy) = lthy - |> mk_Frees "a" Cs + |> yield_singleton (mk_Frees "z") dummyT + ||>> mk_Frees "a" Cs ||>> mk_Freess "p" p_Tss ||>> mk_Freessss "g" g_Tssss; val rssss = map (map (map (fn [] => []))) r_Tssss; @@ -338,17 +411,25 @@ val cpss = map2 (map o rapp) cs pss; - fun mk_args qssss fssss = + fun build_sum_inj mk_inj = build_map lthy (uncurry mk_inj o dest_sumT o snd); + + fun build_dtor_coiter_arg _ [] [cf] = cf + | build_dtor_coiter_arg T [cq] [cf, cf'] = + mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf) + (build_sum_inj Inr_const (fastype_of cf', T) $ cf'); + + fun mk_args qssss fssss (_, f_Tsss, _) = let val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss; val cqssss = map2 (map o map o map o rapp) cs qssss; val cfssss = map2 (map o map o map o rapp) cs fssss; - in (pfss, cqssss, cfssss) end; + val cqfsss = map3 (map3 (map3 build_dtor_coiter_arg)) f_Tsss cqssss cfssss; + in (pfss, cqfsss) end; - val unfold_args = mk_args rssss gssss; - val corec_args = mk_args sssss hssss; + val unfold_args = mk_args rssss gssss unfold_types; + val corec_args = mk_args sssss hssss corec_types; in - ((cs, cpss, [(unfold_args, unfold_types), (corec_args, corec_types)]), lthy) + ((z, cs, cpss, [(unfold_args, unfold_types), (corec_args, corec_types)]), lthy) end; fun mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy = @@ -368,79 +449,6 @@ ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy') end; -fun mk_map live Ts Us t = - let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in - Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t - end; - -fun mk_rel live Ts Us t = - let val [Type (_, Ts0), Type (_, Us0)] = binder_types (snd (strip_typeN live (fastype_of t))) in - Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t - end; - -fun liveness_of_fp_bnf n bnf = - (case T_of_bnf bnf of - Type (_, Ts) => map (not o member (op =) (deads_of_bnf bnf)) Ts - | _ => replicate n false); - -fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters"; - -fun merge_type_arg T T' = if T = T' then T else cannot_merge_types (); - -fun merge_type_args (As, As') = - if length As = length As' then map2 merge_type_arg As As' else cannot_merge_types (); - -fun reassoc_conjs thm = - reassoc_conjs (thm RS @{thm conj_assoc[THEN iffD1]}) - handle THM _ => thm; - -fun type_args_named_constrained_of ((((ncAs, _), _), _), _) = ncAs; -fun type_binding_of ((((_, b), _), _), _) = b; -fun map_binding_of (((_, (b, _)), _), _) = b; -fun rel_binding_of (((_, (_, b)), _), _) = b; -fun mixfix_of ((_, mx), _) = mx; -fun ctr_specs_of (_, ctr_specs) = ctr_specs; - -fun disc_of ((((disc, _), _), _), _) = disc; -fun ctr_of ((((_, ctr), _), _), _) = ctr; -fun args_of (((_, args), _), _) = args; -fun defaults_of ((_, ds), _) = ds; -fun ctr_mixfix_of (_, mx) = mx; - -fun add_nesty_bnf_names Us = - let - fun add (Type (s, Ts)) ss = - let val (needs, ss') = fold_map add Ts ss in - if exists I needs then (true, insert (op =) s ss') else (false, ss') - end - | add T ss = (member (op =) Us T, ss); - in snd oo add end; - -fun nesty_bnfs ctxt ctr_Tsss Us = - map_filter (bnf_of ctxt) (fold (fold (fold (add_nesty_bnf_names Us))) ctr_Tsss []); - -fun indexify proj xs f p = f (find_index (curry (op =) (proj p)) xs) p; - -fun build_map lthy build_simple = - let - fun build (TU as (T, U)) = - if T = U then - id_const T - else - (case TU of - (Type (s, Ts), Type (s', Us)) => - if s = s' then - let - val bnf = the (bnf_of lthy s); - val live = live_of_bnf bnf; - val mapx = mk_map live Ts Us (map_of_bnf bnf); - val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx))); - in Term.list_comb (mapx, map build TUs') end - else - build_simple TU - | _ => build_simple TU); - in build end; - fun mk_iter_body ctor_iter fss xssss = Term.list_comb (ctor_iter, map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss); @@ -450,19 +458,8 @@ (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cqfss) (1 upto n))) end; -fun mk_coiter_body lthy cs cpss f_sum_prod_Ts f_Tsss cqssss cfssss dtor_coiter = - let - fun build_sum_inj mk_inj = build_map lthy (uncurry mk_inj o dest_sumT o snd); - - fun build_dtor_coiter_arg _ [] [cf] = cf - | build_dtor_coiter_arg T [cq] [cf, cf'] = - mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf) - (build_sum_inj Inr_const (fastype_of cf', T) $ cf') - - val cqfsss = map3 (map3 (map3 build_dtor_coiter_arg)) f_Tsss cqssss cfssss; - in - Term.list_comb (dtor_coiter, map4 mk_preds_getterss_join cs cpss f_sum_prod_Ts cqfsss) - end; +fun mk_coiter_body lthy cs cpss f_sum_prod_Ts f_Tsss cqfsss dtor_coiter = + Term.list_comb (dtor_coiter, map4 mk_preds_getterss_join cs cpss f_sum_prod_Ts cqfsss); fun define_co_iters fp fpT Cs binding_specs lthy0 = let @@ -500,19 +497,20 @@ define_co_iters Least_FP fpT Cs (map3 generate_iter iterNs iter_args_typess' ctor_iters) lthy end; -fun define_coiters coiterNs (cs, cpss, coiter_args_typess') mk_binding fpTs Cs dtor_coiters lthy = +fun define_coiters coiterNs (_, cs, cpss, coiter_args_typess') mk_binding fpTs Cs dtor_coiters + lthy = let val nn = length fpTs; val C_to_fpT as Type (_, [_, fpT]) = snd (strip_typeN nn (fastype_of (hd dtor_coiters))); - fun generate_coiter suf ((pfss, cqssss, cfssss), (f_sum_prod_Ts, f_Tsss, pf_Tss)) dtor_coiter = + fun generate_coiter suf ((pfss, cqfsss), (f_sum_prod_Ts, f_Tsss, pf_Tss)) dtor_coiter = let val res_T = fold_rev (curry (op --->)) pf_Tss C_to_fpT; val b = mk_binding suf; val spec = mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of b, res_T)), - mk_coiter_body lthy cs cpss f_sum_prod_Ts f_Tsss cqssss cfssss dtor_coiter); + mk_coiter_body lthy cs cpss f_sum_prod_Ts f_Tsss cqfsss dtor_coiter); in (b, spec) end; in define_co_iters Greatest_FP fpT Cs @@ -676,8 +674,8 @@ (fold_thmss, code_simp_attrs), (rec_thmss, code_simp_attrs)) end; -fun derive_coinduct_coiters_thms_for_types pre_bnfs (cs, cpss, - [(unfold_args as (pgss, _, cgssss), _), (corec_args as (phss, _, _), _)]) +fun derive_coinduct_coiters_thms_for_types pre_bnfs (z, cs, cpss, + [(unfold_args as (pgss, crgsss), _), (corec_args as (phss, cshsss), _)]) dtor_coinducts dtor_ctors dtor_coiter_thmss nesting_bnfs nested_bnfs fpTs Cs As kss mss ns ctr_defss ctr_sugars coiterss coiter_defss lthy = let @@ -693,7 +691,6 @@ val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs; val nesting_rel_eqs = map rel_eq_of_bnf nesting_bnfs; val nested_map_comp's = map map_comp'_of_bnf nested_bnfs; - val nested_map_comps'' = map ((fn thm => thm RS sym) o map_comp_of_bnf) nested_bnfs; val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs; val fp_b_names = map base_name_of_typ fpTs; @@ -728,6 +725,7 @@ map4 (fn u => fn v => fn uvr => fn uv_eq => fold_rev Term.lambda [u, v] (HOLogic.mk_disj (uvr, uv_eq))) us vs uvrs uv_eqs; + (* TODO: generalize (cf. "build_map") *) fun build_rel rs' T = (case find_index (curry (op =) T) fpTs of ~1 => @@ -807,7 +805,7 @@ fun mk_maybe_not pos = not pos ? HOLogic.mk_not; val fcoiterss' as [gunfolds, hcorecs] = - map2 (fn (pfss, _, _) => map (lists_bmoc pfss)) [unfold_args, corec_args] coiterss'; + map2 (fn (pfss, _) => map (lists_bmoc pfss)) [unfold_args, corec_args] coiterss'; val (unfold_thmss, corec_thmss, safe_unfold_thmss, safe_corec_thmss) = let @@ -816,25 +814,48 @@ (Logic.list_implies (seq_conds (HOLogic.mk_Trueprop oo mk_maybe_not) n k cps, mk_Trueprop_eq (fcoiter $ c, Term.list_comb (ctr, take m cfs')))); - val substC = typ_subst_nonatomic (map2 pair Cs fpTs); + fun build_coiter fcoiters maybe_tack (T, U) = + if T = U then + id_const T + else + (case find_index (curry (op =) U) fpTs of + ~1 => build_map lthy (build_coiter fcoiters maybe_tack) (T, U) + | kk => maybe_tack (nth cs kk, nth us kk) (nth fcoiters kk)); + + fun mk_U maybe_mk_sumT = + typ_subst_nonatomic (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs); + + fun tack z_name (c, u) f = + let val z = Free (z_name, mk_sumT (fastype_of u, fastype_of c)) in + Term.lambda z (mk_sum_case (Term.lambda u u, Term.lambda c (f $ c)) $ z) + end; - fun intr_coiters fcoiters [] [cf] = - let val T = fastype_of cf in - if exists_subtype_in Cs T then - build_map lthy (indexify fst Cs (K o nth fcoiters)) (T, substC T) $ cf - else - cf - end - | intr_coiters fcoiters [cq] [cf, cf'] = - mk_If cq (intr_coiters fcoiters [] [cf]) (intr_coiters fcoiters [] [cf']); + fun intr_coiters fcoiters maybe_mk_sumT maybe_tack cqf = + let val T = fastype_of cqf in + if exists_subtype_in Cs T then + build_coiter fcoiters maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf + else + cqf + end; - val [crgsss, cshsss] = + val crgsss' = + (fn fcoiters => fn (_, cqfsss) => + map (map (map (intr_coiters fcoiters (K I) (K I)))) cqfsss) + (un_fold_of fcoiterss') unfold_args; + val cshsss' = + (fn fcoiters => fn (_, cqfsss) => + map (map (map (intr_coiters fcoiters (curry mk_sumT) (tack z)))) cqfsss) + (co_rec_of fcoiterss') corec_args; + +(*### + val [crgsss', cshsss'] = map2 (fn fcoiters => fn (_, cqssss, cfssss) => map2 (map2 (map2 (intr_coiters fcoiters))) cqssss cfssss) fcoiterss' [unfold_args, corec_args]; +*) - val unfold_goalss = map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss; - val corec_goalss = map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss; + val unfold_goalss = map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss'; + val corec_goalss = map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss'; fun mk_map_if_distrib bnf = let @@ -851,10 +872,10 @@ val nested_map_if_distribs = map mk_map_if_distrib nested_bnfs; val unfold_tacss = - map3 (map oo mk_coiter_tac unfold_defs [] [] nesting_map_ids'' []) + map3 (map oo mk_coiter_tac unfold_defs [] nesting_map_ids'' []) (map un_fold_of dtor_coiter_thmss) pre_map_defs ctr_defss; val corec_tacss = - map3 (map oo mk_coiter_tac corec_defs nested_map_comps'' nested_map_comp's + map3 (map oo mk_coiter_tac corec_defs nested_map_comp's (nested_map_ids'' @ nesting_map_ids'') nested_map_if_distribs) (map co_rec_of dtor_coiter_thmss) pre_map_defs ctr_defss; @@ -865,12 +886,19 @@ val unfold_thmss = map2 (map2 prove) unfold_goalss unfold_tacss; val corec_thmss = map2 (map2 prove) corec_goalss corec_tacss; + val corec_thmss = + map2 (map2 prove) corec_goalss corec_tacss + |> map (map (unfold_thms lthy @{thms sum_case_if})); + + val unfold_safesss = map2 (map2 (map2 (curry (op =)))) crgsss' crgsss; + val corec_safesss = map2 (map2 (map2 (curry (op =)))) cshsss' cshsss; + val filter_safesss = map2 (map_filter (fn (safes, thm) => if forall I safes then SOME thm else NONE) oo - curry (op ~~)) (map2 (map2 (map2 (member (op =)))) cgssss crgsss); + curry (op ~~)); - val safe_unfold_thmss = filter_safesss unfold_thmss; - val safe_corec_thmss = filter_safesss corec_thmss; + val safe_unfold_thmss = filter_safesss unfold_safesss unfold_thmss; + val safe_corec_thmss = filter_safesss corec_safesss corec_thmss; in (unfold_thmss, corec_thmss, safe_unfold_thmss, safe_corec_thmss) end; diff -r 856b3bd1d87e -r ead18e3b2c1b src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML --- a/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML Fri Jun 07 12:54:40 2013 +0200 +++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML Fri Jun 07 14:45:07 2013 +0200 @@ -14,8 +14,8 @@ val mk_case_tac: Proof.context -> int -> int -> int -> thm -> thm -> thm -> tactic val mk_coinduct_tac: Proof.context -> thm list -> int -> int list -> thm -> thm list -> thm list -> thm list -> thm list list -> thm list list list -> thm list list list -> tactic - val mk_coiter_tac: thm list -> thm list -> thm list -> thm list -> thm list -> thm -> thm -> - thm -> Proof.context -> tactic + val mk_coiter_tac: thm list -> thm list -> thm list -> thm list -> thm -> thm -> thm -> + Proof.context -> tactic val mk_ctor_iff_dtor_tac: Proof.context -> ctyp option list -> cterm -> cterm -> thm -> thm -> tactic val mk_disc_coiter_iff_tac: thm list -> thm list -> thm list -> Proof.context -> tactic @@ -109,13 +109,13 @@ val coiter_unfold_thms = @{thms id_def ident_o_ident sum_case_if sum_case_o_inj} @ sum_prod_thms_map; -fun mk_coiter_tac coiter_defs map_comps'' map_comp's map_ids'' map_if_distribs +fun mk_coiter_tac coiter_defs map_comp's map_ids'' map_if_distribs ctor_dtor_coiter pre_map_def ctr_def ctxt = unfold_thms_tac ctxt (ctr_def :: coiter_defs) THEN HEADGOAL (rtac (ctor_dtor_coiter RS trans) THEN' asm_simp_tac (put_simpset ss_if_True_False ctxt)) THEN_MAYBE - (unfold_thms_tac ctxt (pre_map_def :: map_comp's @ map_comps'' @ map_ids'' @ map_if_distribs @ - coiter_unfold_thms) THEN + (unfold_thms_tac ctxt (pre_map_def :: map_comp's @ map_ids'' @ map_if_distribs @ + coiter_unfold_thms) THEN HEADGOAL (rtac refl ORELSE' rtac (@{thm unit_eq} RS arg_cong))); fun mk_disc_coiter_iff_tac case_splits' coiters discs ctxt =