# HG changeset patch # User blanchet # Date 1367312305 -7200 # Node ID 3cc93eeac8cc83ea10a1532bb8d1cfb2bc2afd50 # Parent 67c6d6136915d84ce81a9cefc0769f96a0ce8d6f signature tuning diff -r 67c6d6136915 -r 3cc93eeac8cc src/HOL/BNF/Tools/bnf_fp_def_sugar.ML --- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue Apr 30 10:07:41 2013 +0200 +++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Tue Apr 30 10:58:25 2013 +0200 @@ -21,14 +21,11 @@ Proof.context -> (thm * thm list * Args.src list) * (thm list list * Args.src list) * (thm list list * Args.src list) - val derive_coinduct_unfold_corec_thms_for_types: Proof.context -> Proof.context -> - BNF_Def.BNF list -> thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list -> - BNF_Def.BNF list -> typ list -> typ list -> typ list -> int list list -> int list list -> - int list -> term list -> term list list -> term list list -> term list list list list -> - term list list list list -> term list list -> term list list list list -> - term list list list list -> term list list -> thm list list -> - BNF_Ctr_Sugar.ctr_wrap_result list -> term list -> term list -> thm list -> thm list -> - Proof.context -> + val derive_coinduct_unfold_corec_thms_for_types: BNF_Def.BNF list -> term list -> term list -> + thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list -> BNF_Def.BNF list -> + typ list -> typ list -> typ list -> int list list -> int list list -> int list -> + term list list -> thm list list -> BNF_Ctr_Sugar.ctr_wrap_result list -> term list -> + term list -> thm list -> thm list -> Proof.context -> (thm * thm list * thm * thm list * Args.src list) * (thm list list * thm list list * 'e list) * (thm list list * thm list list) * (thm list list * thm list list * Args.src list) * (thm list list * thm list list * Args.src list) @@ -158,6 +155,12 @@ maps fst ps @ maps snd ps end; +fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss); + +fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss + | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) = + p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss; + fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) = Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1)); @@ -196,23 +199,86 @@ fun mk_rec_like_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type; -fun project_recT fpTs proj = +fun massage_rec_fun_arg_typesss fpTs = let - fun project (Type (s as @{type_name prod}, Ts as [T, U])) = - if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts) - | project (Type (s, Ts)) = Type (s, map project Ts) - | project T = T; - in project end; - -fun unzip_recT fpTs T = - if exists_subtype_in fpTs T then ([project_recT fpTs fst T], [project_recT fpTs snd T]) - else ([T], []); - -fun massage_rec_fun_arg_typesss fpTs = map (map (flat_rec (unzip_recT fpTs))); + fun project_recT proj = + let + fun project (Type (s as @{type_name prod}, Ts as [T, U])) = + if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts) + | project (Type (s, Ts)) = Type (s, map project Ts) + | project T = T; + in project end; + fun unzip_recT T = + if exists_subtype_in fpTs T then ([project_recT fst T], [project_recT snd T]) else ([T], []); + in + map (map (flat_rec unzip_recT)) + end; val mk_fold_fun_typess = map2 (map2 (curry (op --->))); val mk_rec_fun_typess = mk_fold_fun_typess oo massage_rec_fun_arg_typesss; +fun mk_corec_like_pred_types n = replicate (Int.max (0, n - 1)) o mk_pred1T; + +fun mk_unfold_corec_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts = + let + (*avoid "'a itself" arguments in coiterators and corecursors*) + fun repair_arity [0] = [1] + | repair_arity ms = ms; + + fun project_corecT proj = + let + fun project (Type (s as @{type_name sum}, Ts as [T, U])) = + if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts) + | project (Type (s, Ts)) = Type (s, map project Ts) + | project T = T; + in project end; + + fun unzip_corecT T = + if exists_subtype_in fpTs T then [project_corecT fst T, project_corecT snd T] else [T]; + + val p_Tss = map2 mk_corec_like_pred_types ns Cs; + + fun mk_types maybe_unzipT fun_Ts = + let + val f_sum_prod_Ts = map range_type fun_Ts; + val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts; + val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss; + val f_Tssss = + map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss; + val q_Tssss = + map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss; + val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss; + in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end; + in + (p_Tss, mk_types single dtor_unfold_fun_Ts, mk_types unzip_corecT dtor_corec_fun_Ts) + end + +fun mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss lthy = + let + val (((cs, pss), gssss), lthy) = + lthy + |> mk_Frees "a" Cs + ||>> mk_Freess "p" p_Tss + ||>> mk_Freessss "g" g_Tssss; + val rssss = map (map (map (fn [] => []))) r_Tssss; + + val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss; + val ((sssss, hssss_tl), lthy) = + lthy + |> mk_Freessss "q" s_Tssss + ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss); + val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl; + in + ((cs, pss, (gssss, rssss), (hssss, sssss)), lthy) + end; + +fun mk_corec_like_terms cs pss qssss fssss = + let + val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss; + val cqssss = map2 (map o map o map o rapp) cs qssss; + val cfssss = map2 (map o map o map o rapp) cs fssss; + in (pfss, cqssss, cfssss) end; + fun mk_map live Ts Us t = let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t @@ -440,10 +506,9 @@ (fold_thmss, code_simp_attrs), (rec_thmss, code_simp_attrs)) end; -fun derive_coinduct_unfold_corec_thms_for_types names_lthy0 no_defs_lthy pre_bnfs dtor_coinduct +fun derive_coinduct_unfold_corec_thms_for_types pre_bnfs dtor_unfolds0 dtor_corecs0 dtor_coinduct dtor_strong_induct dtor_ctors dtor_unfold_thms dtor_corec_thms nesting_bnfs nested_bnfs fpTs Cs - As kss mss ns cs cpss pgss crssss cgssss phss csssss chssss ctrss ctr_defss ctr_wrap_ress - unfolds corecs unfold_defs corec_defs lthy = + As kss mss ns ctrss ctr_defss ctr_wrap_ress unfolds corecs unfold_defs corec_defs lthy = let val nn = length pre_bnfs; @@ -457,6 +522,9 @@ val fp_b_names = map base_name_of_typ fpTs; + val (_, dtor_unfold_fun_Ts) = mk_fp_rec_like false As Cs dtor_unfolds0; + val (_, dtor_corec_fun_Ts) = mk_fp_rec_like false As Cs dtor_corecs0; + val discss = map (map (mk_disc_or_sel As) o #discs) ctr_wrap_ress; val selsss = map (map (map (mk_disc_or_sel As)) o #selss) ctr_wrap_ress; val exhausts = map #exhaust ctr_wrap_ress; @@ -470,6 +538,15 @@ ||>> Variable.variant_fixes fp_b_names ||>> Variable.variant_fixes (map (suffix "'") fp_b_names); + val (p_Tss, (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss), + (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss)) = + mk_unfold_corec_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts; + + val ((cs, pss, (gssss, rssss), (hssss, sssss)), names_lthy) = + mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss names_lthy; + + val cpss = map2 (map o rapp) cs pss; + val us = map2 (curry Free) us' fpTs; val udiscss = map2 (map o rapp) us discss; val uselsss = map2 (map o map o rapp) us selsss; @@ -478,6 +555,9 @@ val vdiscss = map2 (map o rapp) vs discss; val vselsss = map2 (map o map o rapp) vs selsss; + val (pgss, crssss, cgssss) = mk_corec_like_terms cs pss rssss gssss; + val (phss, csssss, chssss) = mk_corec_like_terms cs pss sssss hssss; + val ((coinduct_thms, coinduct_thm), (strong_coinduct_thms, strong_coinduct_thm)) = let val uvrs = map3 (fn r => fn u => fn v => r $ u $ v) rs us vs; @@ -652,7 +732,7 @@ fun prove goal tac = Goal.prove_sorry lthy [] [] goal (tac o #context) - |> singleton (Proof_Context.export names_lthy0 no_defs_lthy) + |> singleton (Proof_Context.export names_lthy lthy) |> Thm.close_derivation; fun proves [_] [_] = [] @@ -894,68 +974,18 @@ end else let - (*avoid "'a itself" arguments in coiterators and corecursors*) - val mss' = map (fn [0] => [1] | ms => ms) mss; - - val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs; - - fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss); - - fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss - | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) = - p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss; - - fun mk_types maybe_unzipT fun_Ts = - let - val f_sum_prod_Ts = map range_type fun_Ts; - val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts; - val f_Tsss = map2 (map2 dest_tupleT) mss' f_prod_Tss; - val f_Tssss = - map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss; - val q_Tssss = - map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss; - val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss; - in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end; - - val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts; + val (p_Tss, (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss), + (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss)) = + mk_unfold_corec_types fpTs Cs ns mss fp_fold_fun_Ts fp_rec_fun_Ts; - val (((cs, pss), gssss), lthy) = - lthy - |> mk_Frees "a" Cs - ||>> mk_Freess "p" p_Tss - ||>> mk_Freessss "g" g_Tssss; - val rssss = map (map (map (fn [] => []))) r_Tssss; - - fun proj_corecT proj (Type (s as @{type_name sum}, Ts as [T, U])) = - if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_corecT proj) Ts) - | proj_corecT proj (Type (s, Ts)) = Type (s, map (proj_corecT proj) Ts) - | proj_corecT _ T = T; - - fun unzip_corecT T = - if exists_subtype_in fpTs T then [proj_corecT fst T, proj_corecT snd T] else [T]; - - val (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss) = - mk_types unzip_corecT fp_rec_fun_Ts; - - val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss; - val ((sssss, hssss_tl), lthy) = - lthy - |> mk_Freessss "q" s_Tssss - ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss); - val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl; + val ((cs, pss, (gssss, rssss), (hssss, sssss)), lthy) = + mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss lthy; val cpss = map2 (map o rapp) cs pss; - - fun mk_terms qssss fssss = - let - val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss; - val cqssss = map2 (map o map o map o rapp) cs qssss; - val cfssss = map2 (map o map o map o rapp) cs fssss; - in (pfss, cqssss, cfssss) end; in (((([], [], []), ([], [], [])), - (cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, g_Tsss, pg_Tss)), - (mk_terms sssss hssss, (h_sum_prod_Ts, h_Tsss, ph_Tss)))), lthy) + (cs, cpss, (mk_corec_like_terms cs pss rssss gssss, (g_sum_prod_Ts, g_Tsss, pg_Tss)), + (mk_corec_like_terms cs pss sssss hssss, (h_sum_prod_Ts, h_Tsss, ph_Tss)))), lthy) end; fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor), @@ -1311,10 +1341,9 @@ (disc_unfold_thmss, disc_corec_thmss, disc_corec_like_attrs), (disc_unfold_iff_thmss, disc_corec_iff_thmss, disc_corec_like_iff_attrs), (sel_unfold_thmss, sel_corec_thmss, sel_corec_like_attrs)) = - derive_coinduct_unfold_corec_thms_for_types names_lthy0 no_defs_lthy pre_bnfs fp_induct + derive_coinduct_unfold_corec_thms_for_types pre_bnfs fp_folds0 fp_recs0 fp_induct fp_strong_induct dtor_ctors fp_fold_thms fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs As - kss mss ns cs cpss pgss crssss cgssss phss csssss chssss ctrss ctr_defss ctr_wrap_ress - unfolds corecs unfold_defs corec_defs lthy; + kss mss ns ctrss ctr_defss ctr_wrap_ress unfolds corecs unfold_defs corec_defs lthy; fun coinduct_type_attr T_name = Attrib.internal (K (Induct.coinduct_type T_name));