src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 51852 23d938495367
parent 51850 106afdf5806c
child 51853 cce8b6ba429d
equal deleted inserted replaced
51851:7e9265a0eb01 51852:23d938495367
    16      xxfolds: term list,
    16      xxfolds: term list,
    17      xxrecs: term list,
    17      xxrecs: term list,
    18      xxfold_thmss: thm list list,
    18      xxfold_thmss: thm list list,
    19      xxrec_thmss: thm list list};
    19      xxrec_thmss: thm list list};
    20 
    20 
    21   val fp_sugar_of: Proof.context -> string -> fp_sugar option
    21   val fp_sugar_of: local_theory -> string -> fp_sugar option
       
    22 
       
    23   val build_maps: local_theory -> typ list -> (int -> typ * typ -> term) -> typ * typ -> term
    22 
    24 
    23   val mk_fp_iter_fun_types: term -> typ list
    25   val mk_fp_iter_fun_types: term -> typ list
    24   val mk_fun_arg_typess: int -> int list -> typ -> typ list list
    26   val mk_fun_arg_typess: int -> int list -> typ -> typ list list
    25   val unzip_recT: typ list -> typ -> typ list * typ list
    27   val unzip_recT: typ list -> typ -> typ list * typ list
    26   val mk_fold_fun_typess: typ list list list -> typ list list -> typ list list
    28   val mk_fold_fun_typess: typ list list list -> typ list list -> typ list list
    27   val mk_rec_fun_typess: typ list -> typ list list list -> typ list list -> typ list list
    29   val mk_rec_fun_typess: typ list -> typ list list list -> typ list list -> typ list list
    28 
    30 
    29   val derive_induct_fold_rec_thms_for_types: BNF_Def.bnf list -> term list -> term list -> thm ->
    31   val derive_induct_fold_rec_thms_for_types: BNF_Def.bnf list -> term list -> term list -> thm ->
    30     thm list -> thm list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list ->
    32     thm list -> thm list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list ->
    31     typ list -> term list list -> thm list list -> term list -> term list -> thm list -> thm list ->
    33     typ list -> term list list -> thm list list -> term list -> term list -> thm list -> thm list ->
    32     Proof.context ->
    34     local_theory ->
    33     (thm * thm list * Args.src list) * (thm list list * Args.src list)
    35     (thm * thm list * Args.src list) * (thm list list * Args.src list)
    34     * (thm list list * Args.src list)
    36     * (thm list list * Args.src list)
    35   val derive_coinduct_unfold_corec_thms_for_types: BNF_Def.bnf list -> term list -> term list ->
    37   val derive_coinduct_unfold_corec_thms_for_types: BNF_Def.bnf list -> term list -> term list ->
    36     thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.bnf list -> BNF_Def.bnf list ->
    38     thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.bnf list -> BNF_Def.bnf list ->
    37     typ list -> typ list -> typ list -> int list list -> int list list -> int list ->
    39     typ list -> typ list -> typ list -> int list list -> int list list -> int list ->
    38     thm list list -> BNF_Ctr_Sugar.ctr_sugar list -> term list -> term list -> thm list ->
    40     thm list list -> BNF_Ctr_Sugar.ctr_sugar list -> term list -> term list -> thm list ->
    39     thm list -> Proof.context ->
    41     thm list -> local_theory ->
    40     (thm * thm list * thm * thm list * Args.src list)
    42     (thm * thm list * thm * thm list * Args.src list)
    41     * (thm list list * thm list list * Args.src list)
    43     * (thm list list * thm list list * Args.src list)
    42     * (thm list list * thm list list) * (thm list list * thm list list * Args.src list)
    44     * (thm list list * thm list list) * (thm list list * thm list list * Args.src list)
    43     * (thm list list * thm list list * Args.src list)
    45     * (thm list list * thm list list * Args.src list)
    44     * (thm list list * thm list list * Args.src list)
    46     * (thm list list * thm list list * Args.src list)
   358 fun ctr_of ((((_, ctr), _), _), _) = ctr;
   360 fun ctr_of ((((_, ctr), _), _), _) = ctr;
   359 fun args_of (((_, args), _), _) = args;
   361 fun args_of (((_, args), _), _) = args;
   360 fun defaults_of ((_, ds), _) = ds;
   362 fun defaults_of ((_, ds), _) = ds;
   361 fun ctr_mixfix_of (_, mx) = mx;
   363 fun ctr_mixfix_of (_, mx) = mx;
   362 
   364 
   363 fun build_map_step lthy build_arg (Type (s, Ts)) (Type (_, Us)) =
   365 fun build_map_step lthy build_arg (Type (s, Ts), Type (_, Us)) =
   364   let
   366   let
   365     val bnf = the (bnf_of lthy s);
   367     val bnf = the (bnf_of lthy s);
   366     val live = live_of_bnf bnf;
   368     val live = live_of_bnf bnf;
   367     val mapx = mk_map live Ts Us (map_of_bnf bnf);
   369     val mapx = mk_map live Ts Us (map_of_bnf bnf);
   368     val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
   370     val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
   369   in Term.list_comb (mapx, map build_arg TUs') end;
   371   in Term.list_comb (mapx, map build_arg TUs') end;
       
   372 
       
   373 fun build_maps lthy Ts build_simple =
       
   374   let
       
   375     fun build (TU as (T, U)) =
       
   376       if T = U then
       
   377         id_const T
       
   378       else
       
   379         (case find_index (curry (op =) T) Ts of
       
   380           ~1 => build_map_step lthy build TU
       
   381         | kk => build_simple kk TU);
       
   382   in build end;
   370 
   383 
   371 fun build_rel_step lthy build_arg (Type (s, Ts)) =
   384 fun build_rel_step lthy build_arg (Type (s, Ts)) =
   372   let
   385   let
   373     val bnf = the (bnf_of lthy s);
   386     val bnf = the (bnf_of lthy s);
   374     val live = live_of_bnf bnf;
   387     val live = live_of_bnf bnf;
   499 
   512 
   500         fun mk_goal fss fiter xctr f xs fxs =
   513         fun mk_goal fss fiter xctr f xs fxs =
   501           fold_rev (fold_rev Logic.all) (xs :: fss)
   514           fold_rev (fold_rev Logic.all) (xs :: fss)
   502             (mk_Trueprop_eq (fiter $ xctr, Term.list_comb (f, fxs)));
   515             (mk_Trueprop_eq (fiter $ xctr, Term.list_comb (f, fxs)));
   503 
   516 
   504         fun build_iter fiters (T, U) =
       
   505           if T = U then
       
   506             id_const T
       
   507           else
       
   508             (case find_index (curry (op =) T) fpTs of
       
   509               ~1 => build_map_step lthy (build_iter fiters) T U
       
   510             | kk => nth fiters kk);
       
   511 
       
   512         val mk_U = typ_subst_nonatomic (map2 pair fpTs Cs);
   517         val mk_U = typ_subst_nonatomic (map2 pair fpTs Cs);
   513 
   518 
   514         fun unzip_iters fiters combine (x as Free (_, T)) =
   519         fun unzip_iters fiters combine (x as Free (_, T)) =
   515           if exists_subtype_in fpTs T then
   520           if exists_subtype_in fpTs T then
   516             combine (x, build_iter fiters (T, mk_U T) $ x)
   521             combine (x, build_maps lthy fpTs (K o nth fiters) (T, mk_U T) $ x)
   517           else
   522           else
   518             ([x], []);
   523             ([x], []);
   519 
   524 
   520         val gxsss = map (map (flat_rec (unzip_iters gfolds (fn (_, t) => ([t], []))))) xsss;
   525         val gxsss = map (map (flat_rec (unzip_iters gfolds (fn (_, t) => ([t], []))))) xsss;
   521         val hxsss = map (map (flat_rec (unzip_iters hrecs (pairself single)))) xsss;
   526         val hxsss = map (map (flat_rec (unzip_iters hrecs (pairself single)))) xsss;
   675         fun mk_goal pfss c cps fcoiter n k ctr m cfs' =
   680         fun mk_goal pfss c cps fcoiter n k ctr m cfs' =
   676           fold_rev (fold_rev Logic.all) ([c] :: pfss)
   681           fold_rev (fold_rev Logic.all) ([c] :: pfss)
   677             (Logic.list_implies (seq_conds (HOLogic.mk_Trueprop oo mk_maybe_not) n k cps,
   682             (Logic.list_implies (seq_conds (HOLogic.mk_Trueprop oo mk_maybe_not) n k cps,
   678                mk_Trueprop_eq (fcoiter $ c, Term.list_comb (ctr, take m cfs'))));
   683                mk_Trueprop_eq (fcoiter $ c, Term.list_comb (ctr, take m cfs'))));
   679 
   684 
   680         fun build_coiter fcoiters (T, U) =
       
   681           if T = U then
       
   682             id_const T
       
   683           else
       
   684             (case find_index (curry (op =) U) fpTs of
       
   685               ~1 => build_map_step lthy (build_coiter fcoiters) T U
       
   686             | kk => nth fcoiters kk);
       
   687 
       
   688         val mk_U = typ_subst_nonatomic (map2 pair Cs fpTs);
   685         val mk_U = typ_subst_nonatomic (map2 pair Cs fpTs);
   689 
   686 
   690         fun intr_coiters fcoiters [] [cf] =
   687         fun intr_coiters fcoiters [] [cf] =
   691             let val T = fastype_of cf in
   688             let val T = fastype_of cf in
   692               if exists_subtype_in Cs T then build_coiter fcoiters (T, mk_U T) $ cf else cf
   689               if exists_subtype_in Cs T then build_maps lthy Cs (K o nth fcoiters) (T, mk_U T) $ cf
       
   690               else cf
   693             end
   691             end
   694           | intr_coiters fcoiters [cq] [cf, cf'] =
   692           | intr_coiters fcoiters [cq] [cf, cf'] =
   695             mk_If cq (intr_coiters fcoiters [] [cf]) (intr_coiters fcoiters [] [cf']);
   693             mk_If cq (intr_coiters fcoiters [] [cf]) (intr_coiters fcoiters [] [cf']);
   696 
   694 
   697         val crgsss = map2 (map2 (map2 (intr_coiters gunfolds))) crssss cgssss;
   695         val crgsss = map2 (map2 (map2 (intr_coiters gunfolds))) crssss cgssss;
  1173 
  1171 
  1174         fun define_fold_rec no_defs_lthy =
  1172         fun define_fold_rec no_defs_lthy =
  1175           let
  1173           let
  1176             val fpT_to_C = fpT --> C;
  1174             val fpT_to_C = fpT --> C;
  1177 
  1175 
  1178             fun build_prod_proj mk_proj (T, U) =
  1176             fun build_prod_proj mk_proj (TU as (T, U)) =
  1179               if T = U then
  1177               if T = U then
  1180                 id_const T
  1178                 id_const T
  1181               else
  1179               else
  1182                 (case (T, U) of
  1180                 (case TU of
  1183                   (Type (s, _), Type (s', _)) =>
  1181                   (Type (s, _), Type (s', _)) =>
  1184                   if s = s' then build_map_step lthy (build_prod_proj mk_proj) T U else mk_proj T
  1182                   if s = s' then build_map_step lthy (build_prod_proj mk_proj) TU else mk_proj T
  1185                 | _ => mk_proj T);
  1183                 | _ => mk_proj T);
  1186 
  1184 
  1187             (* TODO: Avoid these complications; cf. corec case *)
  1185             (* TODO: Avoid these complications; cf. corec case *)
  1188             fun mk_U proj (Type (s as @{type_name prod}, Ts as [T', U])) =
  1186             fun mk_U proj (Type (s as @{type_name prod}, Ts as [T', U])) =
  1189                 if member (op =) fpTs T' then proj (T', U) else Type (s, map (mk_U proj) Ts)
  1187                 if member (op =) fpTs T' then proj (T', U) else Type (s, map (mk_U proj) Ts)
  1233 
  1231 
  1234         fun define_unfold_corec no_defs_lthy =
  1232         fun define_unfold_corec no_defs_lthy =
  1235           let
  1233           let
  1236             val B_to_fpT = C --> fpT;
  1234             val B_to_fpT = C --> fpT;
  1237 
  1235 
  1238             fun build_sum_inj mk_inj (T, U) =
  1236             fun build_sum_inj mk_inj (TU as (T, U)) =
  1239               if T = U then
  1237               if T = U then
  1240                 id_const T
  1238                 id_const T
  1241               else
  1239               else
  1242                 (case (T, U) of
  1240                 (case TU of
  1243                   (Type (s, _), Type (s', _)) =>
  1241                   (Type (s, _), Type (s', _)) =>
  1244                   if s = s' then build_map_step lthy (build_sum_inj mk_inj) T U
  1242                   if s = s' then build_map_step lthy (build_sum_inj mk_inj) TU
  1245                   else uncurry mk_inj (dest_sumT U)
  1243                   else uncurry mk_inj (dest_sumT U)
  1246                 | _ => uncurry mk_inj (dest_sumT U));
  1244                 | _ => uncurry mk_inj (dest_sumT U));
  1247 
  1245 
  1248             fun build_dtor_coiter_arg _ [] [cf] = cf
  1246             fun build_dtor_coiter_arg _ [] [cf] = cf
  1249               | build_dtor_coiter_arg T [cq] [cf, cf'] =
  1247               | build_dtor_coiter_arg T [cq] [cf, cf'] =