src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 53591 b6e2993fd0d3
parent 53569 b4db0ade27bd
child 53592 5a7bf8c859f6
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Fri Sep 13 00:55:44 2013 +0200
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Fri Sep 13 02:26:59 2013 +0200
     1.3 @@ -44,8 +44,8 @@
     1.4    val build_rel: local_theory -> (typ * typ -> term) -> typ * typ -> term
     1.5    val dest_map: Proof.context -> string -> term -> term * term list
     1.6    val dest_ctr: Proof.context -> string -> term -> term * term list
     1.7 -  val mk_co_iters_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
     1.8 -    int list list -> term list list -> Proof.context ->
     1.9 +  val mk_co_iters_prelims: BNF_FP_Util.fp_kind -> typ list list list -> typ list -> typ list ->
    1.10 +    int list -> int list list -> term list list -> Proof.context ->
    1.11      (term list list
    1.12       * (typ list list * typ list list list list * term list list
    1.13          * term list list list list) list option
    1.14 @@ -55,7 +55,7 @@
    1.15  
    1.16    val mk_iter_fun_arg_types: typ list -> int list -> int list list -> term ->
    1.17      typ list list list list
    1.18 -  val mk_coiter_fun_arg_types: typ list -> int list -> int list list -> term ->
    1.19 +  val mk_coiter_fun_arg_types: typ list list list -> typ list -> int list -> term ->
    1.20      typ list list
    1.21      * (typ list list list list * typ list list list * typ list list list list * typ list)
    1.22    val define_iters: string list ->
    1.23 @@ -268,12 +268,13 @@
    1.24  
    1.25  val mk_fp_iter_fun_types = binder_fun_types o fastype_of;
    1.26  
    1.27 +(* ### FIXME? *)
    1.28  fun unzip_recT Cs (T as Type (@{type_name prod}, Ts as [_, U])) =
    1.29      if member (op =) Cs U then Ts else [T]
    1.30    | unzip_recT _ T = [T];
    1.31  
    1.32 -fun unzip_corecT Cs (T as Type (@{type_name sum}, Ts as [_, U])) =
    1.33 -    if member (op =) Cs U then Ts else [T]
    1.34 +fun unzip_corecT (Type (@{type_name sum}, _)) T = [T]
    1.35 +  | unzip_corecT _ (T as Type (@{type_name sum}, Ts)) = Ts
    1.36    | unzip_corecT _ T = [T];
    1.37  
    1.38  fun mk_map live Ts Us t =
    1.39 @@ -434,16 +435,18 @@
    1.40      ([(g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)], lthy)
    1.41    end;
    1.42  
    1.43 -fun mk_coiter_fun_arg_types0 Cs ns mss fun_Ts =
    1.44 +fun mk_coiter_fun_arg_types0 ctr_Tsss Cs ns fun_Ts =
    1.45    let
    1.46 -    (*avoid "'a itself" arguments in coiterators and corecursors*)
    1.47 -    fun repair_arity [0] = [1]
    1.48 -      | repair_arity ms = ms;
    1.49 +    (*avoid "'a itself" arguments in coiterators*)
    1.50 +    fun repair_arity [[]] = [[@{typ unit}]]
    1.51 +      | repair_arity Tss = Tss;
    1.52  
    1.53 +    val ctr_Tsss' = map repair_arity ctr_Tsss;
    1.54      val f_sum_prod_Ts = map range_type fun_Ts;
    1.55      val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
    1.56 -    val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss;
    1.57 -    val f_Tssss = map2 (fn C => map (map (map (curry op --> C) o unzip_corecT Cs))) Cs f_Tsss;
    1.58 +    val f_Tsss = map2 (map2 (dest_tupleT o length)) ctr_Tsss' f_prod_Tss;
    1.59 +    val f_Tssss = map3 (fn C => map2 (map2 (map (curry op --> C) oo unzip_corecT)))
    1.60 +      Cs ctr_Tsss' f_Tsss;
    1.61      val q_Tssss = map (map (map (fn [_] => [] | [_, T] => [mk_pred1T (domain_type T)]))) f_Tssss;
    1.62    in
    1.63      (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts)
    1.64 @@ -451,18 +454,18 @@
    1.65  
    1.66  fun mk_coiter_p_pred_types Cs ns = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
    1.67  
    1.68 -fun mk_coiter_fun_arg_types Cs ns mss dtor_coiter =
    1.69 +fun mk_coiter_fun_arg_types ctr_Tsss Cs ns dtor_coiter =
    1.70    (mk_coiter_p_pred_types Cs ns,
    1.71 -   mk_fp_iter_fun_types dtor_coiter |> mk_coiter_fun_arg_types0 Cs ns mss);
    1.72 +   mk_fp_iter_fun_types dtor_coiter |> mk_coiter_fun_arg_types0 ctr_Tsss Cs ns);
    1.73  
    1.74 -fun mk_coiters_args_types Cs ns mss dtor_coiter_fun_Tss lthy =
    1.75 +fun mk_coiters_args_types ctr_Tsss Cs ns mss dtor_coiter_fun_Tss lthy =
    1.76    let
    1.77      val p_Tss = mk_coiter_p_pred_types Cs ns;
    1.78  
    1.79      fun mk_types get_Ts =
    1.80        let
    1.81          val fun_Ts = map get_Ts dtor_coiter_fun_Tss;
    1.82 -        val (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts) = mk_coiter_fun_arg_types0 Cs ns mss fun_Ts;
    1.83 +        val (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts) = mk_coiter_fun_arg_types0 ctr_Tsss Cs ns fun_Ts;
    1.84          val pf_Tss = map3 flat_corec_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
    1.85        in
    1.86          (q_Tssss, f_Tsss, f_Tssss, (f_sum_prod_Ts, pf_Tss))
    1.87 @@ -509,7 +512,7 @@
    1.88      ((z, cs, cpss, [(unfold_args, unfold_types), (corec_args, corec_types)]), lthy)
    1.89    end;
    1.90  
    1.91 -fun mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy =
    1.92 +fun mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy =
    1.93    let
    1.94      val thy = Proof_Context.theory_of lthy;
    1.95  
    1.96 @@ -521,7 +524,7 @@
    1.97        if fp = Least_FP then
    1.98          mk_iters_args_types Cs ns mss xtor_co_iter_fun_Tss lthy |>> (rpair NONE o SOME)
    1.99        else
   1.100 -        mk_coiters_args_types Cs ns mss xtor_co_iter_fun_Tss lthy |>> (pair NONE o SOME)
   1.101 +        mk_coiters_args_types ctr_Tsss Cs ns mss xtor_co_iter_fun_Tss lthy |>> (pair NONE o SOME)
   1.102    in
   1.103      ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy')
   1.104    end;
   1.105 @@ -1224,7 +1227,7 @@
   1.106      val mss = map (map length) ctr_Tsss;
   1.107  
   1.108      val ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy') =
   1.109 -      mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy;
   1.110 +      mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
   1.111  
   1.112      fun define_ctrs_dtrs_for_type (((((((((((((((((((((((fp_bnf, fp_b), fpT), ctor), dtor),
   1.113              xtor_co_iters), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def), pre_set_defs),