merged
authorwenzelm
Tue, 06 Aug 2013 23:24:10 +0200
changeset 52882 45678f8e7a0f
parent 52872 fd14b0ead643 (diff)
parent 52881 4eb44754f1bb (current diff)
child 52883 0a7c97c76f46
child 52885 9e4bae21494d
merged
--- a/src/Doc/Datatypes/Datatypes.thy	Tue Aug 06 23:20:25 2013 +0200
+++ b/src/Doc/Datatypes/Datatypes.thy	Tue Aug 06 23:24:10 2013 +0200
@@ -220,7 +220,7 @@
     datatype_new nat = Zero | Suc nat
 
 text {*
-Setup to be able to write @{term 0} instead of @{const Zero}:
+Setup to be able to write @{text 0} instead of @{const Zero}:
 *}
 
     instantiation nat :: zero
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Aug 06 23:20:25 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Aug 06 23:24:10 2013 +0200
@@ -26,7 +26,9 @@
 
   val tvar_subst: theory -> typ list -> typ list -> ((string * int) * typ) list
   val exists_subtype_in: typ list -> typ -> bool
-  val flat_rec: 'a list list -> 'a list
+  val flat_rec_arg_args: 'a list list -> 'a list
+  val flat_corec_preds_predsss_gettersss: 'a list -> 'a list list list -> 'a list list list ->
+    'a list
   val mk_co_iter: theory -> BNF_FP_Util.fp_kind -> typ -> typ list -> term -> term
   val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
   val mk_map: int -> typ list -> typ list -> term -> term
@@ -40,8 +42,10 @@
         * ((term list list * term list list list) * (typ list * typ list list)) list) option)
     * Proof.context
 
-  val mk_iter_fun_arg_typessss: typ list -> int list -> int list list -> term ->
+  val mk_iter_fun_arg_types: typ list -> int list -> int list list -> term ->
     typ list list list list
+  val mk_coiter_fun_arg_types: typ list -> int list -> int list list -> term ->
+    typ list list list list * typ list list list * typ list list list list * typ list
   val define_iters: string list ->
     (typ list list * typ list list list list * term list list * term list list list list) list ->
     (string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
@@ -181,24 +185,24 @@
 
 val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
 
-fun flat_rec xss =
-  (* The first line below gives the preferred order. The second line is for compatibility with the
-     old datatype package: *)
+fun flat_rec_arg_args xss =
+  (* FIXME (once the old datatype package is phased out): The first line below gives the preferred
+     order. The second line is for compatibility with the old datatype package. *)
 (*
   flat xss
 *)
   map hd xss @ maps tl xss;
 
-fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
+fun flat_corec_predss_getterss qss fss = maps (op @) (qss ~~ fss);
 
-fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss
-  | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
-    p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
+fun flat_corec_preds_predsss_gettersss [] [qss] [fss] = flat_corec_predss_getterss qss fss
+  | flat_corec_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
+    p :: flat_corec_predss_getterss qss fss @ flat_corec_preds_predsss_gettersss ps qsss fsss;
 
 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
 fun mk_uncurried2_fun f xss =
-  mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat_rec xss);
+  mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat_rec_arg_args xss);
 
 fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
   Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
@@ -248,6 +252,10 @@
     if member (op =) Cs U then Ts else [T]
   | unzip_recT _ T = [T];
 
+fun unzip_corecT Cs (T as Type (@{type_name sum}, Ts as [_, U])) =
+    if member (op =) Cs U then Ts else [T]
+  | unzip_corecT _ T = [T];
+
 fun mk_map live Ts Us t =
   let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
@@ -321,17 +329,17 @@
 
 fun indexify proj xs f p = f (find_index (curry (op =) (proj p)) xs) p;
 
-fun mk_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
+fun mk_iter_fun_arg_types0 n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
 
-fun mk_iter_fun_arg_typessss Cs ns mss =
+fun mk_iter_fun_arg_types Cs ns mss =
   mk_fp_iter_fun_types
-  #> map3 mk_fun_arg_typess ns mss
+  #> map3 mk_iter_fun_arg_types0 ns mss
   #> map (map (map (unzip_recT Cs)));
 
 fun mk_iters_args_types Cs ns mss ctor_iter_fun_Tss lthy =
   let
     val Css = map2 replicate ns Cs;
-    val y_Tsss = map3 mk_fun_arg_typess ns mss (map un_fold_of ctor_iter_fun_Tss);
+    val y_Tsss = map3 mk_iter_fun_arg_types0 ns mss (map un_fold_of ctor_iter_fun_Tss);
     val g_Tss = map2 (fn C => map (fn y_Ts => y_Ts ---> C)) Cs y_Tsss;
 
     val ((gss, ysss), lthy) =
@@ -346,7 +354,7 @@
       map3 (fn n => fn ms => map2 (map (unzip_recT Cs) oo dest_tupleT) ms o
         dest_sumTN_balanced n o domain_type o co_rec_of) ns mss ctor_iter_fun_Tss;
 
-    val z_Tsss' = map (map flat_rec) z_Tssss;
+    val z_Tsss' = map (map flat_rec_arg_args) z_Tssss;
     val h_Tss = map2 (map2 (curry (op --->))) z_Tsss' Css;
 
     val hss = map2 (map2 retype_free) h_Tss gss;
@@ -359,32 +367,40 @@
     ([(g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)], lthy)
   end;
 
-fun mk_coiters_args_types Cs ns mss dtor_coiter_fun_Tss lthy =
+fun mk_coiter_fun_arg_types0 Cs ns mss fun_Ts =
   let
     (*avoid "'a itself" arguments in coiterators and corecursors*)
     fun repair_arity [0] = [1]
       | repair_arity ms = ms;
 
-    fun unzip_corecT (T as Type (@{type_name sum}, Ts as [_, U])) =
-        if member (op =) Cs U then Ts else [T]
-      | unzip_corecT T = [T];
+    val f_sum_prod_Ts = map range_type fun_Ts;
+    val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
+    val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss;
+    val f_Tssss = map2 (fn C => map (map (map (curry (op -->) C) o unzip_corecT Cs))) Cs f_Tsss;
+    val q_Tssss = map (map (map (fn [_] => [] | [_, T] => [mk_pred1T (domain_type T)]))) f_Tssss;
+  in
+    (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts)
+  end;
 
+fun mk_coiter_fun_arg_types Cs ns mss =
+  mk_fp_iter_fun_types
+  #> mk_coiter_fun_arg_types0 Cs ns mss;
+
+fun mk_coiters_args_types Cs ns mss dtor_coiter_fun_Tss lthy =
+  let
     val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
 
-    fun mk_types maybe_unzipT get_Ts =
+    fun mk_types get_Ts =
       let
         val fun_Ts = map get_Ts dtor_coiter_fun_Tss;
-        val f_sum_prod_Ts = map range_type fun_Ts;
-        val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
-        val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss;
-        val f_Tssss = map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
-        val q_Tssss =
-          map (map (map (fn [_] => [] | [_, T] => [mk_pred1T (domain_type T)]))) f_Tssss;
-        val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
-      in (q_Tssss, f_Tsss, f_Tssss, (f_sum_prod_Ts, pf_Tss)) end;
+        val (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts) = mk_coiter_fun_arg_types0 Cs ns mss fun_Ts;
+        val pf_Tss = map3 flat_corec_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
+      in
+        (q_Tssss, f_Tsss, f_Tssss, (f_sum_prod_Ts, pf_Tss))
+      end;
 
-    val (r_Tssss, g_Tsss, g_Tssss, unfold_types) = mk_types single un_fold_of;
-    val (s_Tssss, h_Tsss, h_Tssss, corec_types) = mk_types unzip_corecT co_rec_of;
+    val (r_Tssss, g_Tsss, g_Tssss, unfold_types) = mk_types un_fold_of;
+    val (s_Tssss, h_Tsss, h_Tssss, corec_types) = mk_types co_rec_of;
 
     val ((((Free (z, _), cs), pss), gssss), lthy) =
       lthy
@@ -412,7 +428,7 @@
 
     fun mk_args qssss fssss f_Tsss =
       let
-        val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
+        val pfss = map3 flat_corec_preds_predsss_gettersss pss qssss fssss;
         val cqssss = map2 (map o map o map o rapp) cs qssss;
         val cfssss = map2 (map o map o map o rapp) cs fssss;
         val cqfsss = map3 (map3 (map3 build_dtor_coiter_arg)) f_Tsss cqssss cfssss;
@@ -646,7 +662,7 @@
             build_map lthy (indexify (perhaps (try (snd o HOLogic.dest_prodT)) o snd) Cs
               (fn kk => fn TU => maybe_tick TU (nth us kk) (nth fiters kk))) (T, U) $ x;
 
-        val fxsss = map2 (map2 (flat_rec oo map2 (map o build_iter))) xsss x_Tssss;
+        val fxsss = map2 (map2 (flat_rec_arg_args oo map2 (map o build_iter))) xsss x_Tssss;
 
         val goalss = map5 (map4 o mk_goal fss) fiters xctrss fss xsss fxsss;
 
@@ -669,7 +685,7 @@
   end;
 
 fun derive_coinduct_coiters_thms_for_types pre_bnfs (z, cs, cpss,
-      [(unfold_args as (pgss, crgsss), _), (corec_args as (phss, cshsss), _)])
+      coiters_args_types as [((pgss, crgsss), _), ((phss, cshsss), _)])
     dtor_coinducts dtor_ctors dtor_coiter_thmss nesting_bnfs nested_bnfs fpTs Cs As kss mss ns
     ctr_defss ctr_sugars coiterss coiter_defss export_args lthy =
   let
@@ -799,7 +815,7 @@
     fun mk_maybe_not pos = not pos ? HOLogic.mk_not;
 
     val fcoiterss' as [gunfolds, hcorecs] =
-      map2 (fn (pfss, _) => map (lists_bmoc pfss)) [unfold_args, corec_args] coiterss';
+      map2 (fn (pfss, _) => map (lists_bmoc pfss)) (map fst coiters_args_types) coiterss';
 
     val (unfold_thmss, corec_thmss, safe_unfold_thmss, safe_corec_thmss) =
       let