src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 53591 b6e2993fd0d3
parent 53569 b4db0ade27bd
child 53592 5a7bf8c859f6
equal deleted inserted replaced
53590:b6dc5403cad1 53591:b6e2993fd0d3
    42   val mk_rel: int -> typ list -> typ list -> term -> term
    42   val mk_rel: int -> typ list -> typ list -> term -> term
    43   val build_map: local_theory -> (typ * typ -> term) -> typ * typ -> term
    43   val build_map: local_theory -> (typ * typ -> term) -> typ * typ -> term
    44   val build_rel: local_theory -> (typ * typ -> term) -> typ * typ -> term
    44   val build_rel: local_theory -> (typ * typ -> term) -> typ * typ -> term
    45   val dest_map: Proof.context -> string -> term -> term * term list
    45   val dest_map: Proof.context -> string -> term -> term * term list
    46   val dest_ctr: Proof.context -> string -> term -> term * term list
    46   val dest_ctr: Proof.context -> string -> term -> term * term list
    47   val mk_co_iters_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
    47   val mk_co_iters_prelims: BNF_FP_Util.fp_kind -> typ list list list -> typ list -> typ list ->
    48     int list list -> term list list -> Proof.context ->
    48     int list -> int list list -> term list list -> Proof.context ->
    49     (term list list
    49     (term list list
    50      * (typ list list * typ list list list list * term list list
    50      * (typ list list * typ list list list list * term list list
    51         * term list list list list) list option
    51         * term list list list list) list option
    52      * (string * term list * term list list
    52      * (string * term list * term list list
    53         * ((term list list * term list list list) * (typ list * typ list list)) list) option)
    53         * ((term list list * term list list list) * (typ list * typ list list)) list) option)
    54     * Proof.context
    54     * Proof.context
    55 
    55 
    56   val mk_iter_fun_arg_types: typ list -> int list -> int list list -> term ->
    56   val mk_iter_fun_arg_types: typ list -> int list -> int list list -> term ->
    57     typ list list list list
    57     typ list list list list
    58   val mk_coiter_fun_arg_types: typ list -> int list -> int list list -> term ->
    58   val mk_coiter_fun_arg_types: typ list list list -> typ list -> int list -> term ->
    59     typ list list
    59     typ list list
    60     * (typ list list list list * typ list list list * typ list list list list * typ list)
    60     * (typ list list list list * typ list list list * typ list list list list * typ list)
    61   val define_iters: string list ->
    61   val define_iters: string list ->
    62     (typ list list * typ list list list list * term list list * term list list list list) list ->
    62     (typ list list * typ list list list list * term list list * term list list list list) list ->
    63     (string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
    63     (string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
   266     map (Term.subst_TVars rho) ts0
   266     map (Term.subst_TVars rho) ts0
   267   end;
   267   end;
   268 
   268 
   269 val mk_fp_iter_fun_types = binder_fun_types o fastype_of;
   269 val mk_fp_iter_fun_types = binder_fun_types o fastype_of;
   270 
   270 
       
   271 (* ### FIXME? *)
   271 fun unzip_recT Cs (T as Type (@{type_name prod}, Ts as [_, U])) =
   272 fun unzip_recT Cs (T as Type (@{type_name prod}, Ts as [_, U])) =
   272     if member (op =) Cs U then Ts else [T]
   273     if member (op =) Cs U then Ts else [T]
   273   | unzip_recT _ T = [T];
   274   | unzip_recT _ T = [T];
   274 
   275 
   275 fun unzip_corecT Cs (T as Type (@{type_name sum}, Ts as [_, U])) =
   276 fun unzip_corecT (Type (@{type_name sum}, _)) T = [T]
   276     if member (op =) Cs U then Ts else [T]
   277   | unzip_corecT _ (T as Type (@{type_name sum}, Ts)) = Ts
   277   | unzip_corecT _ T = [T];
   278   | unzip_corecT _ T = [T];
   278 
   279 
   279 fun mk_map live Ts Us t =
   280 fun mk_map live Ts Us t =
   280   let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
   281   let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
   281     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
   282     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
   432     val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
   433     val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
   433   in
   434   in
   434     ([(g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)], lthy)
   435     ([(g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)], lthy)
   435   end;
   436   end;
   436 
   437 
   437 fun mk_coiter_fun_arg_types0 Cs ns mss fun_Ts =
   438 fun mk_coiter_fun_arg_types0 ctr_Tsss Cs ns fun_Ts =
   438   let
   439   let
   439     (*avoid "'a itself" arguments in coiterators and corecursors*)
   440     (*avoid "'a itself" arguments in coiterators*)
   440     fun repair_arity [0] = [1]
   441     fun repair_arity [[]] = [[@{typ unit}]]
   441       | repair_arity ms = ms;
   442       | repair_arity Tss = Tss;
   442 
   443 
       
   444     val ctr_Tsss' = map repair_arity ctr_Tsss;
   443     val f_sum_prod_Ts = map range_type fun_Ts;
   445     val f_sum_prod_Ts = map range_type fun_Ts;
   444     val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
   446     val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
   445     val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss;
   447     val f_Tsss = map2 (map2 (dest_tupleT o length)) ctr_Tsss' f_prod_Tss;
   446     val f_Tssss = map2 (fn C => map (map (map (curry op --> C) o unzip_corecT Cs))) Cs f_Tsss;
   448     val f_Tssss = map3 (fn C => map2 (map2 (map (curry op --> C) oo unzip_corecT)))
       
   449       Cs ctr_Tsss' f_Tsss;
   447     val q_Tssss = map (map (map (fn [_] => [] | [_, T] => [mk_pred1T (domain_type T)]))) f_Tssss;
   450     val q_Tssss = map (map (map (fn [_] => [] | [_, T] => [mk_pred1T (domain_type T)]))) f_Tssss;
   448   in
   451   in
   449     (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts)
   452     (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts)
   450   end;
   453   end;
   451 
   454 
   452 fun mk_coiter_p_pred_types Cs ns = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
   455 fun mk_coiter_p_pred_types Cs ns = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
   453 
   456 
   454 fun mk_coiter_fun_arg_types Cs ns mss dtor_coiter =
   457 fun mk_coiter_fun_arg_types ctr_Tsss Cs ns dtor_coiter =
   455   (mk_coiter_p_pred_types Cs ns,
   458   (mk_coiter_p_pred_types Cs ns,
   456    mk_fp_iter_fun_types dtor_coiter |> mk_coiter_fun_arg_types0 Cs ns mss);
   459    mk_fp_iter_fun_types dtor_coiter |> mk_coiter_fun_arg_types0 ctr_Tsss Cs ns);
   457 
   460 
   458 fun mk_coiters_args_types Cs ns mss dtor_coiter_fun_Tss lthy =
   461 fun mk_coiters_args_types ctr_Tsss Cs ns mss dtor_coiter_fun_Tss lthy =
   459   let
   462   let
   460     val p_Tss = mk_coiter_p_pred_types Cs ns;
   463     val p_Tss = mk_coiter_p_pred_types Cs ns;
   461 
   464 
   462     fun mk_types get_Ts =
   465     fun mk_types get_Ts =
   463       let
   466       let
   464         val fun_Ts = map get_Ts dtor_coiter_fun_Tss;
   467         val fun_Ts = map get_Ts dtor_coiter_fun_Tss;
   465         val (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts) = mk_coiter_fun_arg_types0 Cs ns mss fun_Ts;
   468         val (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts) = mk_coiter_fun_arg_types0 ctr_Tsss Cs ns fun_Ts;
   466         val pf_Tss = map3 flat_corec_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
   469         val pf_Tss = map3 flat_corec_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
   467       in
   470       in
   468         (q_Tssss, f_Tsss, f_Tssss, (f_sum_prod_Ts, pf_Tss))
   471         (q_Tssss, f_Tsss, f_Tssss, (f_sum_prod_Ts, pf_Tss))
   469       end;
   472       end;
   470 
   473 
   507     val corec_args = mk_args sssss hssss h_Tsss;
   510     val corec_args = mk_args sssss hssss h_Tsss;
   508   in
   511   in
   509     ((z, cs, cpss, [(unfold_args, unfold_types), (corec_args, corec_types)]), lthy)
   512     ((z, cs, cpss, [(unfold_args, unfold_types), (corec_args, corec_types)]), lthy)
   510   end;
   513   end;
   511 
   514 
   512 fun mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy =
   515 fun mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy =
   513   let
   516   let
   514     val thy = Proof_Context.theory_of lthy;
   517     val thy = Proof_Context.theory_of lthy;
   515 
   518 
   516     val (xtor_co_iter_fun_Tss, xtor_co_iterss) =
   519     val (xtor_co_iter_fun_Tss, xtor_co_iterss) =
   517       map (mk_co_iters thy fp fpTs Cs #> `(mk_fp_iter_fun_types o hd)) (transpose xtor_co_iterss0)
   520       map (mk_co_iters thy fp fpTs Cs #> `(mk_fp_iter_fun_types o hd)) (transpose xtor_co_iterss0)
   519 
   522 
   520     val ((iters_args_types, coiters_args_types), lthy') =
   523     val ((iters_args_types, coiters_args_types), lthy') =
   521       if fp = Least_FP then
   524       if fp = Least_FP then
   522         mk_iters_args_types Cs ns mss xtor_co_iter_fun_Tss lthy |>> (rpair NONE o SOME)
   525         mk_iters_args_types Cs ns mss xtor_co_iter_fun_Tss lthy |>> (rpair NONE o SOME)
   523       else
   526       else
   524         mk_coiters_args_types Cs ns mss xtor_co_iter_fun_Tss lthy |>> (pair NONE o SOME)
   527         mk_coiters_args_types ctr_Tsss Cs ns mss xtor_co_iter_fun_Tss lthy |>> (pair NONE o SOME)
   525   in
   528   in
   526     ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy')
   529     ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy')
   527   end;
   530   end;
   528 
   531 
   529 fun mk_iter_body ctor_iter fss xssss =
   532 fun mk_iter_body ctor_iter fss xssss =
  1222     val ns = map length ctr_Tsss;
  1225     val ns = map length ctr_Tsss;
  1223     val kss = map (fn n => 1 upto n) ns;
  1226     val kss = map (fn n => 1 upto n) ns;
  1224     val mss = map (map length) ctr_Tsss;
  1227     val mss = map (map length) ctr_Tsss;
  1225 
  1228 
  1226     val ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy') =
  1229     val ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy') =
  1227       mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy;
  1230       mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
  1228 
  1231 
  1229     fun define_ctrs_dtrs_for_type (((((((((((((((((((((((fp_bnf, fp_b), fpT), ctor), dtor),
  1232     fun define_ctrs_dtrs_for_type (((((((((((((((((((((((fp_bnf, fp_b), fpT), ctor), dtor),
  1230             xtor_co_iters), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def), pre_set_defs),
  1233             xtor_co_iters), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def), pre_set_defs),
  1231           pre_rel_def), fp_map_thm), fp_set_thms), fp_rel_thm), n), ks), ms), ctr_bindings),
  1234           pre_rel_def), fp_map_thm), fp_set_thms), fp_rel_thm), n), ks), ms), ctr_bindings),
  1232         ctr_mixfixes), ctr_Tss), disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =
  1235         ctr_mixfixes), ctr_Tss), disc_bindings), sel_bindingss), raw_sel_defaultss) no_defs_lthy =