# HG changeset patch # User blanchet # Date 1347131066 -7200 # Node ID ca59649170b0738e7cd65c0dced3f8a28398a9c0 # Parent 239a4fa29ddfd5d033ff62bd985b94835381160d more sugar on codatatypes diff -r 239a4fa29ddf -r ca59649170b0 src/HOL/Codatatype/Tools/bnf_fp_sugar.ML --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sat Sep 08 21:04:26 2012 +0200 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Sat Sep 08 21:04:26 2012 +0200 @@ -21,6 +21,8 @@ open BNF_FP_Sugar_Tactics val caseN = "case"; +val coitersN = "iters"; +val corecsN = "recs"; val itersN = "iters"; val recsN = "recs"; @@ -28,13 +30,16 @@ fun retype_free (Free (s, _)) T = Free (s, T); -fun flat_list_comb (f, xss) = fold (fn xs => fn t => Term.list_comb (t, xs)) xss f +val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs)) fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs)); fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs; fun mk_uncurried2_fun f xss = mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss); +fun popescu_zip [] [fs] = fs + | popescu_zip (p :: ps) (fs :: fss) = p :: fs @ popescu_zip ps fss; + fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters"; fun merge_type_arg_constrained ctxt (T, c) (T', c') = @@ -148,10 +153,10 @@ val flds = map (mk_fld As) flds0; val fpTs = map (domain_type o fastype_of) unfs; - val is_fpT = member (op =) fpTs; val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs; val ns = map length ctr_Tsss; + val kss = map (fn n => 1 upto n) ns; val mss = map (map length) ctr_Tsss; val Css = map2 replicate ns Cs; @@ -168,20 +173,21 @@ val fp_iters as fp_iter1 :: _ = map (mk_iter_like As Cs) fp_iters0; val fp_recs as fp_rec1 :: _ = map (mk_iter_like As Cs) fp_recs0; - val fp_iter_g_Ts = fst (split_last (binder_types (fastype_of fp_iter1))); - val fp_rec_h_Ts = fst (split_last (binder_types (fastype_of fp_rec1))); + val fp_iter_fun_Ts = fst (split_last (binder_types (fastype_of fp_iter1))); + val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of fp_rec1))); fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) = if member (op =) Cs U then Us else [T] | dest_rec_pair T = [T]; - val (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)), - (cs, pss, p_Tss, coiter_extra, corec_extra)) = + val (((gss, g_Tss, ysss), (hss, h_Tss, zssss)), + (cs, cpss, p_Tss, coiter_extra as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss), + corec_extra as ((phss, chsss), h_sum_prod_Ts, h_prod_Tss, h_Tsss))) = if lfp then let val y_Tsss = map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN n o domain_type) - ns mss fp_iter_g_Ts; + ns mss fp_iter_fun_Ts; val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css; val ((gss, ysss), _) = @@ -191,7 +197,7 @@ val z_Tssss = map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n - o domain_type) ns mss fp_rec_h_Ts; + o domain_type) ns mss fp_rec_fun_Ts; val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css; val hss = map2 (map2 retype_free) gss h_Tss; @@ -199,23 +205,25 @@ lthy |> mk_Freessss "x" z_Tssss; in - (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)), - ([], [], [], ([], [], [], []), ([], [], [], []))) + (((gss, g_Tss, ysss), (hss, h_Tss, zssss)), + ([], [], [], (([], []), [], [], []), (([], []), [], [], []))) end else let - fun mk_to_dest_prodT C = map2 (map (curry (op -->) C) oo dest_tupleT); - val p_Tss = map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns; - val g_sum_prod_Ts = map range_type fp_iter_g_Ts; - val g_prod_Tss = map2 dest_sumTN ns g_sum_prod_Ts; - val g_Tsss = map3 mk_to_dest_prodT Cs mss g_prod_Tss; + fun mk_types fun_Ts = + let + val f_sum_prod_Ts = map range_type fun_Ts; + val f_prod_Tss = map2 dest_sumTN ns f_sum_prod_Ts; + val f_Tsss = + map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss f_prod_Tss; + val pf_Tss = map2 popescu_zip p_Tss f_Tsss + in (f_sum_prod_Ts, f_prod_Tss, f_Tsss, pf_Tss) end; - val h_sum_prod_Ts = map range_type fp_rec_h_Ts; - val h_prod_Tss = map2 dest_sumTN ns h_sum_prod_Ts; - val h_Tsss = map3 mk_to_dest_prodT Cs mss h_prod_Tss; + val (g_sum_prod_Ts, g_prod_Tss, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts; + val (h_sum_prod_Ts, h_prod_Tss, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts; val (((c, pss), gsss), _) = lthy @@ -226,20 +234,23 @@ val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss; val cs = map (retype_free c) Cs; + val cpss = map2 (fn c => map (fn p => p $ c)) cs pss; + + fun mk_terms fsss = + let + val pfss = map2 popescu_zip pss fsss; + val cfsss = map2 (fn c => map (map (fn f => f $ c))) cs fsss + in (pfss, cfsss) end; in - ((([], [], [], []), ([], [], [], [])), - (cs, pss, p_Tss, (gsss, g_sum_prod_Ts, g_prod_Tss, g_Tsss), - (hsss, h_sum_prod_Ts, h_prod_Tss, h_Tsss))) + ((([], [], []), ([], [], [])), + (cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss), + (mk_terms hsss, h_sum_prod_Ts, h_prod_Tss, ph_Tss))) end; - fun pour_some_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf), - unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), - sel_binderss) no_defs_lthy = + fun pour_some_sugar_on_type (((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), + fld_unf), unf_fld), fld_inject), n), ks), ms), ctr_binders), ctr_mixfixes), ctr_Tss), + disc_binders), sel_binderss) no_defs_lthy = let - val n = length ctr_Tss; - val ks = 1 upto n; - val ms = map length ctr_Tss; - val unfT = domain_type (fastype_of fld); val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss; val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss; @@ -324,10 +335,10 @@ val rec_binder = Binding.suffix_name ("_" ^ recN) b; val iter_spec = - mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of iter_binder, iter_T), gss), + mk_Trueprop_eq (lists_bmoc gss (Free (Binding.name_of iter_binder, iter_T)), Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss)); val rec_spec = - mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of rec_binder, rec_T), hss), + mk_Trueprop_eq (lists_bmoc hss (Free (Binding.name_of rec_binder, rec_T)), Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) hss zssss)); val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy @@ -353,21 +364,13 @@ fun some_gfp_sugar no_defs_lthy = let - fun zip_preds_and_getters ps fss = ps @ flat fss; - val B_to_fpT = C --> fpT; - val cpss = map2 (fn c => map (fn p => p $ c)) cs pss; - - fun generate_coiter_like (suf, fp_iter_like, - (fsss, f_sum_prod_Ts, f_prod_Tss, f_Tsss)) = + fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), f_sum_prod_Ts, f_prod_Tss, + pf_Tss)) = let - val pf_Tss = map2 zip_preds_and_getters p_Tss f_Tsss; val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT; - val pfss = map2 zip_preds_and_getters pss fsss; - val cfsss = map2 (fn c => map (map (fn f => f $ c))) cs fsss; - val binder = Binding.suffix_name ("_" ^ suf) b; fun mk_join c n cps sum_prod_T prod_Ts cfss = @@ -375,12 +378,15 @@ (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cfss) (1 upto n))); val spec = - mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of binder, res_T), pfss), + mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)), Term.list_comb (fp_iter_like, map6 mk_join cs ns cpss f_sum_prod_Ts f_prod_Tss cfsss)); in (binder, spec) end; - val coiter_likes = [(coiterN, fp_iter, coiter_extra), (corecN, fp_rec, corec_extra)]; + val coiter_likes = + [(coiterN, fp_iter, coiter_extra), + (corecN, fp_rec, corec_extra)]; + val (binders, specs) = map generate_coiter_like coiter_likes |> split_list; val ((csts, defs), (lthy', lthy)) = no_defs_lthy @@ -403,29 +409,29 @@ |> (if lfp then some_lfp_sugar else some_gfp_sugar) end; - fun pour_more_sugar_on_datatypes ((ctrss, iters, recs, xsss, ctr_defss, iter_defs, rec_defs), + fun pour_more_sugar_on_lfps ((ctrss, iters, recs, xsss, ctr_defss, iter_defs, rec_defs), lthy) = let val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss; - val giters = map (fn iter => flat_list_comb (iter, gss)) iters; - val hrecs = map (fn recx => flat_list_comb (recx, hss)) recs; + val giters = map (lists_bmoc gss) iters; + val hrecs = map (lists_bmoc hss) recs; val (iter_thmss, rec_thmss) = let - fun mk_goal_iter_like fss fc xctr f xs xs' = + fun mk_goal_iter_like fss fiter_like xctr f xs fxs = fold_rev (fold_rev Logic.all) (xs :: fss) - (mk_Trueprop_eq (fc $ xctr, Term.list_comb (f, xs'))); + (mk_Trueprop_eq (fiter_like $ xctr, Term.list_comb (f, fxs))); - fun fix_iter_free (x as Free (_, T)) = - (case find_index (eq_fpT T) fpTs of ~1 => x | j => nth giters j $ x); - fun fix_rec_free (x as Free (_, T)) = - (case find_index (eq_fpT T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]); + fun repair_iter_call (x as Free (_, T)) = + (case find_index (curry (op =) T) fpTs of ~1 => x | j => nth giters j $ x); + fun repair_rec_call (x as Free (_, T)) = + (case find_index (curry (op =) T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]); - val iter_xsss = map (map (map fix_iter_free)) xsss; - val rec_xsss = map (map (maps fix_rec_free)) xsss; + val gxsss = map (map (map repair_iter_call)) xsss; + val hxsss = map (map (maps repair_rec_call)) xsss; - val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss iter_xsss; - val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss rec_xsss; + val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss gxsss; + val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss hxsss; val iter_tacss = map2 (map o mk_iter_like_tac pre_map_defs iter_defs) fp_iter_thms ctr_defss; @@ -449,12 +455,53 @@ lthy |> Local_Theory.notes notes |> snd end; + fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, xsss, ctr_defss, coiter_defs, corec_defs), + lthy) = + let + val gcoiters = map (lists_bmoc pgss) coiters; + val hcorecs = map (lists_bmoc phss) corecs; + + val (coiter_thmss, corec_thmss) = + let + fun mk_cond pos = HOLogic.mk_Trueprop o (not pos ? HOLogic.mk_not); + + fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr cfs' = + fold_rev (fold_rev Logic.all) ([c] :: pfss) + (Logic.list_implies (seq_conds mk_cond n k cps, + mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, cfs')))); + + fun repair_coiter_like_call fcoiter_likes (cf as Free (_, Type (_, [_, T])) $ _) = + (case find_index (curry (op =) T) Cs of ~1 => cf | j => nth fcoiter_likes j $ cf); + + val cgsss = map (map (map (repair_coiter_like_call gcoiters))) cgsss; + val chsss = map (map (map (repair_coiter_like_call hcorecs))) chsss; + + val goal_coiterss = + map7 (map3 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss cgsss; + val goal_corecss = + map7 (map3 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss chsss; + in + (map (map (Skip_Proof.make_thm (Proof_Context.theory_of lthy))) goal_coiterss, + map (map (Skip_Proof.make_thm (Proof_Context.theory_of lthy))) goal_coiterss (*### goal_corecss*)) + end; + + val notes = + [(coitersN, coiter_thmss), + (corecsN, corec_thmss)] + |> maps (fn (thmN, thmss) => + map2 (fn b => fn thms => + ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])])) + bs thmss); + in + lthy |> Local_Theory.notes notes |> snd + end; + val lthy' = lthy |> fold_map pour_some_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~ - fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~ - ctr_Tsss ~~ disc_binderss ~~ sel_bindersss) + fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_binderss ~~ + ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss) |>> split_list7 - |> (if lfp then pour_more_sugar_on_datatypes else snd); + |> (if lfp then pour_more_sugar_on_lfps else pour_more_sugar_on_gfps); val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^ (if lfp then "" else "co") ^ "datatype")); diff -r 239a4fa29ddf -r ca59649170b0 src/HOL/Codatatype/Tools/bnf_util.ML --- a/src/HOL/Codatatype/Tools/bnf_util.ML Sat Sep 08 21:04:26 2012 +0200 +++ b/src/HOL/Codatatype/Tools/bnf_util.ML Sat Sep 08 21:04:26 2012 +0200 @@ -42,6 +42,7 @@ 'a list -> 'b list -> 'c list -> 'd list -> 'e list -> 'f list -> 'g list -> 'h -> 'i list * 'h val interleave: 'a list -> 'a list -> 'a list val transpose: 'a list list -> 'a list list + val seq_conds: (bool -> 'a -> 'b) -> int -> int -> 'a list -> 'b list val mk_fresh_names: Proof.context -> int -> string -> string list * Proof.context val mk_TFrees: int -> Proof.context -> typ list * Proof.context @@ -537,6 +538,14 @@ | transpose ([] :: xss) = transpose xss | transpose xss = map hd xss :: transpose (map tl xss); +fun seq_conds f n k xs = + if k = n then + map (f false) (take (k - 1) xs) + else + let val (negs, pos) = split_last (take k xs) in + map (f false) negs @ [f true pos] + end; + fun mk_unabs_def 0 thm = thm | mk_unabs_def n thm = mk_unabs_def (n - 1) thm RS @{thm spec[OF iffD1[OF fun_eq_iff]]}; diff -r 239a4fa29ddf -r ca59649170b0 src/HOL/Codatatype/Tools/bnf_wrap_tactics.ML --- a/src/HOL/Codatatype/Tools/bnf_wrap_tactics.ML Sat Sep 08 21:04:26 2012 +0200 +++ b/src/HOL/Codatatype/Tools/bnf_wrap_tactics.ML Sat Sep 08 21:04:26 2012 +0200 @@ -27,13 +27,7 @@ open BNF_Util open BNF_Tactics -fun triangle _ [] = [] - | triangle k (xs :: xss) = take k xs :: triangle (k + 1) xss - -fun mk_case_if_P_or_not_Ps n k thms = - let val (negs, pos) = split_last thms in - map (fn thm => thm RS @{thm if_not_P}) negs @ (if k = n then [] else [pos RS @{thm if_P}]) - end; +fun if_P_or_not_P_OF pos thm = thm RS (if pos then @{thm if_P} else @{thm if_not_P}); fun ss_only thms = Simplifier.clear_ss HOL_basic_ss addsimps thms @@ -80,7 +74,7 @@ EVERY' (map3 (fn case_thm => fn if_disc_thms => fn sel_thms => EVERY' [hyp_subst_tac, SELECT_GOAL (Local_Defs.unfold_tac ctxt (if_disc_thms @ sel_thms)), rtac case_thm]) - case_thms (map2 (mk_case_if_P_or_not_Ps n) (1 upto n) (triangle 1 disc_thmss')) sel_thmss)) 1; + case_thms (map2 (seq_conds if_P_or_not_P_OF n) (1 upto n) disc_thmss') sel_thmss)) 1; fun mk_case_cong_tac exhaust' case_thms = (rtac exhaust' THEN'