src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 51827 836257faaad5
parent 51824 27d073b0876c
child 51828 67c6d6136915
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Apr 30 03:18:07 2013 +0200
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Apr 30 09:53:56 2013 +0200
     1.3 @@ -15,12 +15,12 @@
     1.4  
     1.5    val fp_of: Proof.context -> string -> fp option
     1.6  
     1.7 -  val derive_induct_fold_rec_thms_for_types: BNF_Def.BNF list -> thm -> thm list -> thm list ->
     1.8 -    BNF_Def.BNF list -> BNF_Def.BNF list -> typ list -> typ list -> typ list list list ->
     1.9 -    int list list -> int list -> term list list -> term list list -> term list list -> term list
    1.10 -    list list -> thm list list -> term list -> term list -> thm list -> thm list -> Proof.context ->
    1.11 +  val derive_induct_fold_rec_thms_for_types: BNF_Def.BNF list -> term list -> term list -> thm ->
    1.12 +    thm list -> thm list -> BNF_Def.BNF list -> BNF_Def.BNF list -> typ list -> typ list ->
    1.13 +    typ list -> term list list -> thm list list -> term list -> term list -> thm list -> thm list ->
    1.14 +    Proof.context ->
    1.15      (thm * thm list * Args.src list) * (thm list list * Args.src list)
    1.16 -      * (thm list list * Args.src list)
    1.17 +    * (thm list list * Args.src list)
    1.18    val derive_coinduct_unfold_corec_thms_for_types: Proof.context -> Proof.context ->
    1.19      BNF_Def.BNF list -> thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list ->
    1.20      BNF_Def.BNF list -> typ list -> typ list -> typ list -> int list list -> int list list ->
    1.21 @@ -188,6 +188,31 @@
    1.22      Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
    1.23    end;
    1.24  
    1.25 +val mk_fp_rec_like_fun_types = fst o split_last o binder_types o fastype_of o hd;
    1.26 +
    1.27 +fun mk_fp_rec_like lfp As Cs fp_rec_likes0 =
    1.28 +  map (mk_rec_like lfp As Cs) fp_rec_likes0
    1.29 +  |> (fn ts => (ts, mk_fp_rec_like_fun_types ts));
    1.30 +
    1.31 +fun mk_rec_like_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
    1.32 +
    1.33 +fun project_recT fpTs proj =
    1.34 +  let
    1.35 +    fun project (Type (s as @{type_name prod}, Ts as [T, U])) =
    1.36 +        if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts)
    1.37 +      | project (Type (s, Ts)) = Type (s, map project Ts)
    1.38 +      | project T = T;
    1.39 +  in project end;
    1.40 +
    1.41 +fun unzip_recT fpTs T =
    1.42 +  if exists_subtype_in fpTs T then ([project_recT fpTs fst T], [project_recT fpTs snd T])
    1.43 +  else ([T], []);
    1.44 +
    1.45 +fun massage_rec_fun_arg_typesss fpTs = map (map (flat_rec (unzip_recT fpTs)));
    1.46 +
    1.47 +val mk_fold_fun_typess = map2 (map2 (curry (op --->)));
    1.48 +val mk_rec_fun_typess = mk_fold_fun_typess oo massage_rec_fun_arg_typesss;
    1.49 +
    1.50  fun mk_map live Ts Us t =
    1.51    let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
    1.52      Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
    1.53 @@ -243,11 +268,16 @@
    1.54      val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
    1.55    in Term.list_comb (rel, map build_arg Ts') end;
    1.56  
    1.57 -fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_induct ctor_fold_thms ctor_rec_thms
    1.58 -    nesting_bnfs nested_bnfs fpTs Cs ctr_Tsss mss ns gss hss ctrss xsss ctr_defss folds recs
    1.59 -    fold_defs rec_defs lthy =
    1.60 +fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_folds0 ctor_recs0 ctor_induct ctor_fold_thms
    1.61 +    ctor_rec_thms nesting_bnfs nested_bnfs fpTs Cs As ctrss ctr_defss folds recs fold_defs rec_defs
    1.62 +    lthy =
    1.63    let
    1.64 +    val ctr_Tsss = map (map (binder_types o fastype_of)) ctrss;
    1.65 +
    1.66      val nn = length pre_bnfs;
    1.67 +    val ns = map length ctr_Tsss;
    1.68 +    val mss = map (map length) ctr_Tsss;
    1.69 +    val Css = map2 replicate ns Cs;
    1.70  
    1.71      val pre_map_defs = map map_def_of_bnf pre_bnfs;
    1.72      val pre_set_defss = map set_defs_of_bnf pre_bnfs;
    1.73 @@ -258,11 +288,23 @@
    1.74  
    1.75      val fp_b_names = map base_name_of_typ fpTs;
    1.76  
    1.77 -    val (((ps, ps'), us'), names_lthy) =
    1.78 +    val (_, ctor_fold_fun_Ts) = mk_fp_rec_like true As Cs ctor_folds0;
    1.79 +    val (_, ctor_rec_fun_Ts) = mk_fp_rec_like true As Cs ctor_recs0;
    1.80 +
    1.81 +    val y_Tsss = map3 mk_rec_like_fun_arg_typess ns mss ctor_fold_fun_Ts;
    1.82 +    val g_Tss = mk_fold_fun_typess y_Tsss Css;
    1.83 +
    1.84 +    val z_Tsss = map3 mk_rec_like_fun_arg_typess ns mss ctor_rec_fun_Ts;
    1.85 +    val h_Tss = mk_rec_fun_typess fpTs z_Tsss Css;
    1.86 +
    1.87 +    val (((((ps, ps'), xsss), gss), us'), names_lthy) =
    1.88        lthy
    1.89        |> mk_Frees' "P" (map mk_pred1T fpTs)
    1.90 +      ||>> mk_Freesss "x" ctr_Tsss
    1.91 +      ||>> mk_Freess "f" g_Tss
    1.92        ||>> Variable.variant_fixes fp_b_names;
    1.93  
    1.94 +    val hss = map2 (map2 retype_free) h_Tss gss;
    1.95      val us = map2 (curry Free) us' fpTs;
    1.96  
    1.97      fun mk_sets_nested bnf =
    1.98 @@ -831,40 +873,24 @@
    1.99      val mss = map (map length) ctr_Tsss;
   1.100      val Css = map2 replicate ns Cs;
   1.101  
   1.102 -    val fp_folds as any_fp_fold :: _ = map (mk_rec_like lfp As Cs) fp_folds0;
   1.103 -    val fp_recs as any_fp_rec :: _ = map (mk_rec_like lfp As Cs) fp_recs0;
   1.104 +    val (fp_folds, fp_fold_fun_Ts) = mk_fp_rec_like lfp As Cs fp_folds0;
   1.105 +    val (fp_recs, fp_rec_fun_Ts) = mk_fp_rec_like lfp As Cs fp_recs0;
   1.106  
   1.107 -    val fp_fold_fun_Ts = fst (split_last (binder_types (fastype_of any_fp_fold)));
   1.108 -    val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of any_fp_rec)));
   1.109 -
   1.110 -    val (((fold_only as (gss, _, _), rec_only as (hss, _, _)),
   1.111 +    val (((fold_only, rec_only),
   1.112            (cs, cpss, unfold_only as ((pgss, crssss, cgssss), (_, g_Tsss, _)),
   1.113             corec_only as ((phss, csssss, chssss), (_, h_Tsss, _)))), names_lthy0) =
   1.114        if lfp then
   1.115          let
   1.116 -          val y_Tsss =
   1.117 -            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
   1.118 -              ns mss fp_fold_fun_Ts;
   1.119 -          val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
   1.120 +          val y_Tsss = map3 mk_rec_like_fun_arg_typess ns mss fp_fold_fun_Ts;
   1.121 +          val g_Tss = mk_fold_fun_typess y_Tsss Css;
   1.122  
   1.123            val ((gss, ysss), lthy) =
   1.124              lthy
   1.125              |> mk_Freess "f" g_Tss
   1.126              ||>> mk_Freesss "x" y_Tsss;
   1.127  
   1.128 -          fun proj_recT proj (Type (s as @{type_name prod}, Ts as [T, U])) =
   1.129 -              if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_recT proj) Ts)
   1.130 -            | proj_recT proj (Type (s, Ts)) = Type (s, map (proj_recT proj) Ts)
   1.131 -            | proj_recT _ T = T;
   1.132 -
   1.133 -          fun unzip_recT T =
   1.134 -            if exists_subtype_in fpTs T then ([proj_recT fst T], [proj_recT snd T]) else ([T], []);
   1.135 -
   1.136 -          val z_Tsss =
   1.137 -            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
   1.138 -              ns mss fp_rec_fun_Ts;
   1.139 -          val z_Tsss' = map (map (flat_rec unzip_recT)) z_Tsss;
   1.140 -          val h_Tss = map2 (map2 (curry (op --->))) z_Tsss' Css;
   1.141 +          val z_Tsss = map3 mk_rec_like_fun_arg_typess ns mss fp_rec_fun_Ts;
   1.142 +          val h_Tss = mk_rec_fun_typess fpTs z_Tsss Css;
   1.143  
   1.144            val hss = map2 (map2 retype_free) h_Tss gss;
   1.145            val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
   1.146 @@ -1252,14 +1278,14 @@
   1.147          injects @ distincts @ case_thms @ rec_likes @ fold_likes);
   1.148  
   1.149      fun derive_and_note_induct_fold_rec_thms_for_types
   1.150 -        (((ctrss, xsss, ctr_defss, ctr_wrap_ress), (folds, recs, fold_defs, rec_defs)), lthy) =
   1.151 +        (((ctrss, _, ctr_defss, ctr_wrap_ress), (folds, recs, fold_defs, rec_defs)), lthy) =
   1.152        let
   1.153          val ((induct_thm, induct_thms, induct_attrs),
   1.154               (fold_thmss, fold_attrs),
   1.155               (rec_thmss, rec_attrs)) =
   1.156 -          derive_induct_fold_rec_thms_for_types pre_bnfs fp_induct fp_fold_thms fp_rec_thms
   1.157 -            nesting_bnfs nested_bnfs fpTs Cs ctr_Tsss mss ns gss hss ctrss xsss ctr_defss folds recs
   1.158 -            fold_defs rec_defs lthy;
   1.159 +          derive_induct_fold_rec_thms_for_types pre_bnfs fp_folds0 fp_recs0 fp_induct fp_fold_thms
   1.160 +            fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs As ctrss ctr_defss folds recs fold_defs
   1.161 +            rec_defs lthy;
   1.162  
   1.163          fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));
   1.164