don't wrongly destroy sum types in coiterators
authorblanchet
Fri, 13 Sep 2013 02:26:59 +0200
changeset 53591 b6e2993fd0d3
parent 53590 b6dc5403cad1
child 53592 5a7bf8c859f6
don't wrongly destroy sum types in coiterators
src/HOL/BNF/Examples/Misc_Codatatype.thy
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
--- a/src/HOL/BNF/Examples/Misc_Codatatype.thy	Fri Sep 13 00:55:44 2013 +0200
+++ b/src/HOL/BNF/Examples/Misc_Codatatype.thy	Fri Sep 13 02:26:59 2013 +0200
@@ -43,6 +43,8 @@
   ('a, 'b1, 'b2) F2 = unit + 'b1 * 'b2
 *)
 
+codatatype 'a p = P "'a + 'a p"
+
 codatatype 'a J1 = J11 'a "'a J1" | J12 'a "'a J2"
 and 'a J2 = J21 | J22 "'a J1" "'a J2"
 
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Fri Sep 13 00:55:44 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Fri Sep 13 02:26:59 2013 +0200
@@ -44,8 +44,8 @@
   val build_rel: local_theory -> (typ * typ -> term) -> typ * typ -> term
   val dest_map: Proof.context -> string -> term -> term * term list
   val dest_ctr: Proof.context -> string -> term -> term * term list
-  val mk_co_iters_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
-    int list list -> term list list -> Proof.context ->
+  val mk_co_iters_prelims: BNF_FP_Util.fp_kind -> typ list list list -> typ list -> typ list ->
+    int list -> int list list -> term list list -> Proof.context ->
     (term list list
      * (typ list list * typ list list list list * term list list
         * term list list list list) list option
@@ -55,7 +55,7 @@
 
   val mk_iter_fun_arg_types: typ list -> int list -> int list list -> term ->
     typ list list list list
-  val mk_coiter_fun_arg_types: typ list -> int list -> int list list -> term ->
+  val mk_coiter_fun_arg_types: typ list list list -> typ list -> int list -> term ->
     typ list list
     * (typ list list list list * typ list list list * typ list list list list * typ list)
   val define_iters: string list ->
@@ -268,12 +268,13 @@
 
 val mk_fp_iter_fun_types = binder_fun_types o fastype_of;
 
+(* ### FIXME? *)
 fun unzip_recT Cs (T as Type (@{type_name prod}, Ts as [_, U])) =
     if member (op =) Cs U then Ts else [T]
   | unzip_recT _ T = [T];
 
-fun unzip_corecT Cs (T as Type (@{type_name sum}, Ts as [_, U])) =
-    if member (op =) Cs U then Ts else [T]
+fun unzip_corecT (Type (@{type_name sum}, _)) T = [T]
+  | unzip_corecT _ (T as Type (@{type_name sum}, Ts)) = Ts
   | unzip_corecT _ T = [T];
 
 fun mk_map live Ts Us t =
@@ -434,16 +435,18 @@
     ([(g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)], lthy)
   end;
 
-fun mk_coiter_fun_arg_types0 Cs ns mss fun_Ts =
+fun mk_coiter_fun_arg_types0 ctr_Tsss Cs ns fun_Ts =
   let
-    (*avoid "'a itself" arguments in coiterators and corecursors*)
-    fun repair_arity [0] = [1]
-      | repair_arity ms = ms;
+    (*avoid "'a itself" arguments in coiterators*)
+    fun repair_arity [[]] = [[@{typ unit}]]
+      | repair_arity Tss = Tss;
 
+    val ctr_Tsss' = map repair_arity ctr_Tsss;
     val f_sum_prod_Ts = map range_type fun_Ts;
     val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
-    val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss;
-    val f_Tssss = map2 (fn C => map (map (map (curry op --> C) o unzip_corecT Cs))) Cs f_Tsss;
+    val f_Tsss = map2 (map2 (dest_tupleT o length)) ctr_Tsss' f_prod_Tss;
+    val f_Tssss = map3 (fn C => map2 (map2 (map (curry op --> C) oo unzip_corecT)))
+      Cs ctr_Tsss' f_Tsss;
     val q_Tssss = map (map (map (fn [_] => [] | [_, T] => [mk_pred1T (domain_type T)]))) f_Tssss;
   in
     (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts)
@@ -451,18 +454,18 @@
 
 fun mk_coiter_p_pred_types Cs ns = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
 
-fun mk_coiter_fun_arg_types Cs ns mss dtor_coiter =
+fun mk_coiter_fun_arg_types ctr_Tsss Cs ns dtor_coiter =
   (mk_coiter_p_pred_types Cs ns,
-   mk_fp_iter_fun_types dtor_coiter |> mk_coiter_fun_arg_types0 Cs ns mss);
+   mk_fp_iter_fun_types dtor_coiter |> mk_coiter_fun_arg_types0 ctr_Tsss Cs ns);
 
-fun mk_coiters_args_types Cs ns mss dtor_coiter_fun_Tss lthy =
+fun mk_coiters_args_types ctr_Tsss Cs ns mss dtor_coiter_fun_Tss lthy =
   let
     val p_Tss = mk_coiter_p_pred_types Cs ns;
 
     fun mk_types get_Ts =
       let
         val fun_Ts = map get_Ts dtor_coiter_fun_Tss;
-        val (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts) = mk_coiter_fun_arg_types0 Cs ns mss fun_Ts;
+        val (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts) = mk_coiter_fun_arg_types0 ctr_Tsss Cs ns fun_Ts;
         val pf_Tss = map3 flat_corec_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
       in
         (q_Tssss, f_Tsss, f_Tssss, (f_sum_prod_Ts, pf_Tss))
@@ -509,7 +512,7 @@
     ((z, cs, cpss, [(unfold_args, unfold_types), (corec_args, corec_types)]), lthy)
   end;
 
-fun mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy =
+fun mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy =
   let
     val thy = Proof_Context.theory_of lthy;
 
@@ -521,7 +524,7 @@
       if fp = Least_FP then
         mk_iters_args_types Cs ns mss xtor_co_iter_fun_Tss lthy |>> (rpair NONE o SOME)
       else
-        mk_coiters_args_types Cs ns mss xtor_co_iter_fun_Tss lthy |>> (pair NONE o SOME)
+        mk_coiters_args_types ctr_Tsss Cs ns mss xtor_co_iter_fun_Tss lthy |>> (pair NONE o SOME)
   in
     ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy')
   end;
@@ -1224,7 +1227,7 @@
     val mss = map (map length) ctr_Tsss;
 
     val ((xtor_co_iterss, iters_args_types, coiters_args_types), lthy') =
-      mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy;
+      mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
 
     fun define_ctrs_dtrs_for_type (((((((((((((((((((((((fp_bnf, fp_b), fpT), ctor), dtor),
             xtor_co_iters), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def), pre_set_defs),
--- a/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Fri Sep 13 00:55:44 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Fri Sep 13 02:26:59 2013 +0200
@@ -127,7 +127,7 @@
       val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
 
       val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
-        mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy;
+        mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
 
       fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
 
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Fri Sep 13 00:55:44 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Fri Sep 13 02:26:59 2013 +0200
@@ -389,12 +389,11 @@
     val nn = length perm_fpTs;
     val kks = 0 upto nn - 1;
     val perm_ns = map length perm_ctr_Tsss;
-    val perm_mss = map (map length) perm_ctr_Tsss;
 
     val perm_Cs = map (domain_type o body_fun_type o fastype_of o co_rec_of o
       of_fp_sugar (#xtor_co_iterss o #fp_res)) perm_fp_sugars;
     val (perm_p_Tss, (perm_q_Tssss, _, perm_f_Tssss, _)) =
-      mk_coiter_fun_arg_types perm_Cs perm_ns perm_mss (co_rec_of dtor_coiters1);
+      mk_coiter_fun_arg_types perm_ctr_Tsss perm_Cs perm_ns (co_rec_of dtor_coiters1);
 
     val (perm_p_hss, h) = indexedd perm_p_Tss 0;
     val (perm_q_hssss, h') = indexedddd perm_q_Tssss h;