src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 51827 836257faaad5
parent 51824 27d073b0876c
child 51828 67c6d6136915
equal deleted inserted replaced
51826:054a40461449 51827:836257faaad5
    13      fp_res: BNF_FP.fp_result,
    13      fp_res: BNF_FP.fp_result,
    14      ctr_wrap_res: BNF_Ctr_Sugar.ctr_wrap_result};
    14      ctr_wrap_res: BNF_Ctr_Sugar.ctr_wrap_result};
    15 
    15 
    16   val fp_of: Proof.context -> string -> fp option
    16   val fp_of: Proof.context -> string -> fp option
    17 
    17 
    18   val derive_induct_fold_rec_thms_for_types: BNF_Def.BNF list -> thm -> thm list -> thm list ->
    18   val derive_induct_fold_rec_thms_for_types: BNF_Def.BNF list -> term list -> term list -> thm ->
    19     BNF_Def.BNF list -> BNF_Def.BNF list -> typ list -> typ list -> typ list list list ->
    19     thm list -> thm list -> BNF_Def.BNF list -> BNF_Def.BNF list -> typ list -> typ list ->
    20     int list list -> int list -> term list list -> term list list -> term list list -> term list
    20     typ list -> term list list -> thm list list -> term list -> term list -> thm list -> thm list ->
    21     list list -> thm list list -> term list -> term list -> thm list -> thm list -> Proof.context ->
    21     Proof.context ->
    22     (thm * thm list * Args.src list) * (thm list list * Args.src list)
    22     (thm * thm list * Args.src list) * (thm list list * Args.src list)
    23       * (thm list list * Args.src list)
    23     * (thm list list * Args.src list)
    24   val derive_coinduct_unfold_corec_thms_for_types: Proof.context -> Proof.context ->
    24   val derive_coinduct_unfold_corec_thms_for_types: Proof.context -> Proof.context ->
    25     BNF_Def.BNF list -> thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list ->
    25     BNF_Def.BNF list -> thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list ->
    26     BNF_Def.BNF list -> typ list -> typ list -> typ list -> int list list -> int list list ->
    26     BNF_Def.BNF list -> typ list -> typ list -> typ list -> int list list -> int list list ->
    27     int list -> term list -> term list list -> term list list -> term list list list list ->
    27     int list -> term list -> term list list -> term list list -> term list list list list ->
    28     term list list list list -> term list list -> term list list list list ->
    28     term list list list list -> term list list -> term list list list list ->
   186     val Us0 = distinct (op =) (map (if lfp then body_type else domain_type) f_Us);
   186     val Us0 = distinct (op =) (map (if lfp then body_type else domain_type) f_Us);
   187   in
   187   in
   188     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
   188     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
   189   end;
   189   end;
   190 
   190 
       
   (* Point-free pipeline, read right to left: take the head of a list of terms (hd), compute
      its type (fastype_of), list that type's binder types (binder_types), and keep the first
      component returned by split_last -- presumably all binder types except the last one;
      confirm against split_last. Because of hd, the list must be non-empty. *)
   191 val mk_fp_rec_like_fun_types = fst o split_last o binder_types o fastype_of o hd;
       
   192 
       
   (* Applies mk_rec_like lfp As Cs to every element of fp_rec_likes0 and returns a pair:
      the resulting list, together with mk_fp_rec_like_fun_types computed from that list
      (which inspects only its head element). *)
   193 fun mk_fp_rec_like lfp As Cs fp_rec_likes0 =
       
   194   map (mk_rec_like lfp As Cs) fp_rec_likes0
       
   195   |> (fn ts => (ts, mk_fp_rec_like_fun_types ts));
       
   196 
       
   (* Given n and a list ms, yields a function on a type: it takes the type's domain_type,
      passes it through dest_sumTN_balanced n, and then combines ms with that result
      element-wise via map2 dest_tupleT.
      NOTE(review): presumably this splits an n-ary balanced sum into its summands and each
      summand into a tuple of the arity given in ms -- confirm against the helper definitions. *)
   197 fun mk_rec_like_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
       
   198 
       
   (* Returns a function that rewrites a type recursively:
        - at a prod type with exactly two arguments [T, U] where T is a member of fpTs,
          the whole type is replaced by proj (T, U) and recursion stops there;
        - at any other type constructor application, the constructor is kept and the
          rewrite is applied to each argument;
        - every other type (the final catch-all clause) is returned unchanged.
      The caller in unzip_recT passes fst or snd as proj. *)
   199 fun project_recT fpTs proj =
       
   200   let
       
   201     fun project (Type (s as @{type_name prod}, Ts as [T, U])) =
       
   202         if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts)
       
   203       | project (Type (s, Ts)) = Type (s, map project Ts)
       
   204       | project T = T;
       
   205   in project end;
       
   206 
       
   (* Produces a pair of type lists for T. When exists_subtype_in fpTs T holds, both lists
      are singletons: the first holds T rewritten by project_recT with fst, the second holds
      T rewritten by project_recT with snd. Otherwise the result is ([T], []). *)
   207 fun unzip_recT fpTs T =
       
   208   if exists_subtype_in fpTs T then ([project_recT fpTs fst T], [project_recT fpTs snd T])
       
   209   else ([T], []);
       
   210 
       
   (* Applies flat_rec (unzip_recT fpTs) to every inner list of a list of lists of type lists
      (two nested maps).
      NOTE(review): flat_rec is defined outside this view; presumably it flattens the paired
      lists that unzip_recT returns for each type -- confirm. *)
   211 fun massage_rec_fun_arg_typesss fpTs = map (map (flat_rec (unzip_recT fpTs)));
       
   212 
       
   (* Combines two equally nested list-of-lists element-wise (map2 inside map2), joining each
      pair with the curried ---> operator. Elsewhere in this file it is called as
      mk_fold_fun_typess y_Tsss Css, so the first argument holds lists of argument types and
      the second holds the matching result types. *)
   213 val mk_fold_fun_typess = map2 (map2 (curry (op --->)));
       
   (* Composition of massage_rec_fun_arg_typesss followed by mk_fold_fun_typess, built with oo.
      It is called as mk_rec_fun_typess fpTs z_Tsss Css elsewhere in this file.
      NOTE(review): presumably oo passes the first two arguments (fpTs, z_Tsss) to the massage
      step and the result, together with Css, to mk_fold_fun_typess -- confirm against the
      definition of oo. *)
   214 val mk_rec_fun_typess = mk_fold_fun_typess oo massage_rec_fun_arg_typesss;
       
   215 
   (* Re-instantiates the type arguments in term t. It runs strip_typeN (live + 1) on the type
      of t, applies List.last to the first component of the result, and pattern-matches the
      resulting pair as two type constructor applications with argument lists Ts0 and Us0.
      It then substitutes atomic types in t according to the pairing Ts0 @ Us0 ~~ Ts @ Us.
      The let-pattern is refutable: it raises Bind if either component is not a Type.
      In this listing every line of the function appears twice, once with the old line
      number (191-194) and once with the new one (216-219); the code text is the same. *)
   191 fun mk_map live Ts Us t =
   216 fun mk_map live Ts Us t =
   192   let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
   217   let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
   193     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
   218     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
   194   end;
   219   end;
   195 
   220 
   241     val live = live_of_bnf bnf;
   266     val live = live_of_bnf bnf;
   242     val rel = mk_rel live Ts Ts (rel_of_bnf bnf);
   267     val rel = mk_rel live Ts Ts (rel_of_bnf bnf);
   243     val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
   268     val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
   244   in Term.list_comb (rel, map build_arg Ts') end;
   269   in Term.list_comb (rel, map build_arg Ts') end;
   245 
   270 
   246 fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_induct ctor_fold_thms ctor_rec_thms
   271 fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_folds0 ctor_recs0 ctor_induct ctor_fold_thms
   247     nesting_bnfs nested_bnfs fpTs Cs ctr_Tsss mss ns gss hss ctrss xsss ctr_defss folds recs
   272     ctor_rec_thms nesting_bnfs nested_bnfs fpTs Cs As ctrss ctr_defss folds recs fold_defs rec_defs
   248     fold_defs rec_defs lthy =
   273     lthy =
   249   let
   274   let
       
   275     val ctr_Tsss = map (map (binder_types o fastype_of)) ctrss;
       
   276 
   250     val nn = length pre_bnfs;
   277     val nn = length pre_bnfs;
       
   278     val ns = map length ctr_Tsss;
       
   279     val mss = map (map length) ctr_Tsss;
       
   280     val Css = map2 replicate ns Cs;
   251 
   281 
   252     val pre_map_defs = map map_def_of_bnf pre_bnfs;
   282     val pre_map_defs = map map_def_of_bnf pre_bnfs;
   253     val pre_set_defss = map set_defs_of_bnf pre_bnfs;
   283     val pre_set_defss = map set_defs_of_bnf pre_bnfs;
   254     val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
   284     val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
   255     val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
   285     val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
   256     val nested_set_map's = maps set_map'_of_bnf nested_bnfs;
   286     val nested_set_map's = maps set_map'_of_bnf nested_bnfs;
   257     val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs;
   287     val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs;
   258 
   288 
   259     val fp_b_names = map base_name_of_typ fpTs;
   289     val fp_b_names = map base_name_of_typ fpTs;
   260 
   290 
   261     val (((ps, ps'), us'), names_lthy) =
   291     val (_, ctor_fold_fun_Ts) = mk_fp_rec_like true As Cs ctor_folds0;
       
   292     val (_, ctor_rec_fun_Ts) = mk_fp_rec_like true As Cs ctor_recs0;
       
   293 
       
   294     val y_Tsss = map3 mk_rec_like_fun_arg_typess ns mss ctor_fold_fun_Ts;
       
   295     val g_Tss = mk_fold_fun_typess y_Tsss Css;
       
   296 
       
   297     val z_Tsss = map3 mk_rec_like_fun_arg_typess ns mss ctor_rec_fun_Ts;
       
   298     val h_Tss = mk_rec_fun_typess fpTs z_Tsss Css;
       
   299 
       
   300     val (((((ps, ps'), xsss), gss), us'), names_lthy) =
   262       lthy
   301       lthy
   263       |> mk_Frees' "P" (map mk_pred1T fpTs)
   302       |> mk_Frees' "P" (map mk_pred1T fpTs)
       
   303       ||>> mk_Freesss "x" ctr_Tsss
       
   304       ||>> mk_Freess "f" g_Tss
   264       ||>> Variable.variant_fixes fp_b_names;
   305       ||>> Variable.variant_fixes fp_b_names;
   265 
   306 
       
   307     val hss = map2 (map2 retype_free) h_Tss gss;
   266     val us = map2 (curry Free) us' fpTs;
   308     val us = map2 (curry Free) us' fpTs;
   267 
   309 
   268     fun mk_sets_nested bnf =
   310     fun mk_sets_nested bnf =
   269       let
   311       let
   270         val Type (T_name, Us) = T_of_bnf bnf;
   312         val Type (T_name, Us) = T_of_bnf bnf;
   829     val ns = map length ctr_Tsss;
   871     val ns = map length ctr_Tsss;
   830     val kss = map (fn n => 1 upto n) ns;
   872     val kss = map (fn n => 1 upto n) ns;
   831     val mss = map (map length) ctr_Tsss;
   873     val mss = map (map length) ctr_Tsss;
   832     val Css = map2 replicate ns Cs;
   874     val Css = map2 replicate ns Cs;
   833 
   875 
   834     val fp_folds as any_fp_fold :: _ = map (mk_rec_like lfp As Cs) fp_folds0;
   876     val (fp_folds, fp_fold_fun_Ts) = mk_fp_rec_like lfp As Cs fp_folds0;
   835     val fp_recs as any_fp_rec :: _ = map (mk_rec_like lfp As Cs) fp_recs0;
   877     val (fp_recs, fp_rec_fun_Ts) = mk_fp_rec_like lfp As Cs fp_recs0;
   836 
   878 
   837     val fp_fold_fun_Ts = fst (split_last (binder_types (fastype_of any_fp_fold)));
   879     val (((fold_only, rec_only),
   838     val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of any_fp_rec)));
       
   839 
       
   840     val (((fold_only as (gss, _, _), rec_only as (hss, _, _)),
       
   841           (cs, cpss, unfold_only as ((pgss, crssss, cgssss), (_, g_Tsss, _)),
   880           (cs, cpss, unfold_only as ((pgss, crssss, cgssss), (_, g_Tsss, _)),
   842            corec_only as ((phss, csssss, chssss), (_, h_Tsss, _)))), names_lthy0) =
   881            corec_only as ((phss, csssss, chssss), (_, h_Tsss, _)))), names_lthy0) =
   843       if lfp then
   882       if lfp then
   844         let
   883         let
   845           val y_Tsss =
   884           val y_Tsss = map3 mk_rec_like_fun_arg_typess ns mss fp_fold_fun_Ts;
   846             map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
   885           val g_Tss = mk_fold_fun_typess y_Tsss Css;
   847               ns mss fp_fold_fun_Ts;
       
   848           val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
       
   849 
   886 
   850           val ((gss, ysss), lthy) =
   887           val ((gss, ysss), lthy) =
   851             lthy
   888             lthy
   852             |> mk_Freess "f" g_Tss
   889             |> mk_Freess "f" g_Tss
   853             ||>> mk_Freesss "x" y_Tsss;
   890             ||>> mk_Freesss "x" y_Tsss;
   854 
   891 
   855           fun proj_recT proj (Type (s as @{type_name prod}, Ts as [T, U])) =
   892           val z_Tsss = map3 mk_rec_like_fun_arg_typess ns mss fp_rec_fun_Ts;
   856               if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_recT proj) Ts)
   893           val h_Tss = mk_rec_fun_typess fpTs z_Tsss Css;
   857             | proj_recT proj (Type (s, Ts)) = Type (s, map (proj_recT proj) Ts)
       
   858             | proj_recT _ T = T;
       
   859 
       
   860           fun unzip_recT T =
       
   861             if exists_subtype_in fpTs T then ([proj_recT fst T], [proj_recT snd T]) else ([T], []);
       
   862 
       
   863           val z_Tsss =
       
   864             map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
       
   865               ns mss fp_rec_fun_Ts;
       
   866           val z_Tsss' = map (map (flat_rec unzip_recT)) z_Tsss;
       
   867           val h_Tss = map2 (map2 (curry (op --->))) z_Tsss' Css;
       
   868 
   894 
   869           val hss = map2 (map2 retype_free) h_Tss gss;
   895           val hss = map2 (map2 retype_free) h_Tss gss;
   870           val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
   896           val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
   871         in
   897         in
   872           ((((gss, g_Tss, ysss), (hss, h_Tss, zsss)),
   898           ((((gss, g_Tss, ysss), (hss, h_Tss, zsss)),
  (* Partially applied map3: given three lists, it builds one theorem list per position by
     concatenating, in this order, the injects, distincts and case_thms fields of the record
     from the first list, then the element of the second list (rec_likes), then the element
     of the third list (fold_likes). Called below as
     mk_simp_thmss ctr_wrap_ress rec_thmss fold_thmss. *)
  1250     val mk_simp_thmss =
  1276     val mk_simp_thmss =
  1251       map3 (fn {injects, distincts, case_thms, ...} => fn rec_likes => fn fold_likes =>
  1277       map3 (fn {injects, distincts, case_thms, ...} => fn rec_likes => fn fold_likes =>
  1252         injects @ distincts @ case_thms @ rec_likes @ fold_likes);
  1278         injects @ distincts @ case_thms @ rec_likes @ fold_likes);
  1253 
  1279 
  1254     fun derive_and_note_induct_fold_rec_thms_for_types
  1280     fun derive_and_note_induct_fold_rec_thms_for_types
  1255         (((ctrss, xsss, ctr_defss, ctr_wrap_ress), (folds, recs, fold_defs, rec_defs)), lthy) =
  1281         (((ctrss, _, ctr_defss, ctr_wrap_ress), (folds, recs, fold_defs, rec_defs)), lthy) =
  1256       let
  1282       let
  1257         val ((induct_thm, induct_thms, induct_attrs),
  1283         val ((induct_thm, induct_thms, induct_attrs),
  1258              (fold_thmss, fold_attrs),
  1284              (fold_thmss, fold_attrs),
  1259              (rec_thmss, rec_attrs)) =
  1285              (rec_thmss, rec_attrs)) =
  1260           derive_induct_fold_rec_thms_for_types pre_bnfs fp_induct fp_fold_thms fp_rec_thms
  1286           derive_induct_fold_rec_thms_for_types pre_bnfs fp_folds0 fp_recs0 fp_induct fp_fold_thms
  1261             nesting_bnfs nested_bnfs fpTs Cs ctr_Tsss mss ns gss hss ctrss xsss ctr_defss folds recs
  1287             fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs As ctrss ctr_defss folds recs fold_defs
  1262             fold_defs rec_defs lthy;
  1288             rec_defs lthy;
  1263 
  1289 
  1264         fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));
  1290         fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));
  1265 
  1291 
  1266         val simp_thmss = mk_simp_thmss ctr_wrap_ress rec_thmss fold_thmss;
  1292         val simp_thmss = mk_simp_thmss ctr_wrap_ress rec_thmss fold_thmss;
  1267 
  1293