more sugar on codatatypes
authorblanchet
Sat, 08 Sep 2012 21:04:26 +0200
changeset 49212 ca59649170b0
parent 49211 239a4fa29ddf
child 49213 975ccb0130cb
more sugar on codatatypes
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_util.ML
src/HOL/Codatatype/Tools/bnf_wrap_tactics.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -21,6 +21,8 @@
 open BNF_FP_Sugar_Tactics
 
 val caseN = "case";
+val coitersN = "iters";
+val corecsN = "recs";
 val itersN = "iters";
 val recsN = "recs";
 
@@ -28,13 +30,16 @@
 
 fun retype_free (Free (s, _)) T = Free (s, T);
 
-fun flat_list_comb (f, xss) = fold (fn xs => fn t => Term.list_comb (t, xs)) xss f
+val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs))
 
 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
 fun mk_uncurried2_fun f xss =
   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
 
+fun popescu_zip [] [fs] = fs
+  | popescu_zip (p :: ps) (fs :: fss) = p :: fs @ popescu_zip ps fss;
+
 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
 
 fun merge_type_arg_constrained ctxt (T, c) (T', c') =
@@ -148,10 +153,10 @@
     val flds = map (mk_fld As) flds0;
 
     val fpTs = map (domain_type o fastype_of) unfs;
-    val is_fpT = member (op =) fpTs;
 
     val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
     val ns = map length ctr_Tsss;
+    val kss = map (fn n => 1 upto n) ns;
     val mss = map (map length) ctr_Tsss;
     val Css = map2 replicate ns Cs;
 
@@ -168,20 +173,21 @@
     val fp_iters as fp_iter1 :: _ = map (mk_iter_like As Cs) fp_iters0;
     val fp_recs as fp_rec1 :: _ = map (mk_iter_like As Cs) fp_recs0;
 
-    val fp_iter_g_Ts = fst (split_last (binder_types (fastype_of fp_iter1)));
-    val fp_rec_h_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
+    val fp_iter_fun_Ts = fst (split_last (binder_types (fastype_of fp_iter1)));
+    val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
 
     fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
         if member (op =) Cs U then Us else [T]
       | dest_rec_pair T = [T];
 
-    val (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)),
-         (cs, pss, p_Tss, coiter_extra, corec_extra)) =
+    val (((gss, g_Tss, ysss), (hss, h_Tss, zssss)),
+         (cs, cpss, p_Tss, coiter_extra as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss),
+          corec_extra as ((phss, chsss), h_sum_prod_Ts, h_prod_Tss, h_Tsss))) =
       if lfp then
         let
           val y_Tsss =
             map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN n o domain_type)
-              ns mss fp_iter_g_Ts;
+              ns mss fp_iter_fun_Ts;
           val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
 
           val ((gss, ysss), _) =
@@ -191,7 +197,7 @@
 
           val z_Tssss =
             map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n
-              o domain_type) ns mss fp_rec_h_Ts;
+              o domain_type) ns mss fp_rec_fun_Ts;
           val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
 
           val hss = map2 (map2 retype_free) gss h_Tss;
@@ -199,23 +205,25 @@
             lthy
             |> mk_Freessss "x" z_Tssss;
         in
-          (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)),
-           ([], [], [], ([], [], [], []), ([], [], [], [])))
+          (((gss, g_Tss, ysss), (hss, h_Tss, zssss)),
+           ([], [], [], (([], []), [], [], []), (([], []), [], [], [])))
         end
       else
         let
-          fun mk_to_dest_prodT C = map2 (map (curry (op -->) C) oo dest_tupleT);
-
           val p_Tss =
             map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns;
 
-          val g_sum_prod_Ts = map range_type fp_iter_g_Ts;
-          val g_prod_Tss = map2 dest_sumTN ns g_sum_prod_Ts;
-          val g_Tsss = map3 mk_to_dest_prodT Cs mss g_prod_Tss;
+          fun mk_types fun_Ts =
+            let
+              val f_sum_prod_Ts = map range_type fun_Ts;
+              val f_prod_Tss = map2 dest_sumTN ns f_sum_prod_Ts;
+              val f_Tsss =
+                map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss f_prod_Tss;
+              val pf_Tss = map2 popescu_zip p_Tss f_Tsss
+            in (f_sum_prod_Ts, f_prod_Tss, f_Tsss, pf_Tss) end;
 
-          val h_sum_prod_Ts = map range_type fp_rec_h_Ts;
-          val h_prod_Tss = map2 dest_sumTN ns h_sum_prod_Ts;
-          val h_Tsss = map3 mk_to_dest_prodT Cs mss h_prod_Tss;
+          val (g_sum_prod_Ts, g_prod_Tss, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
+          val (h_sum_prod_Ts, h_prod_Tss, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
 
           val (((c, pss), gsss), _) =
             lthy
@@ -226,20 +234,23 @@
           val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss;
 
           val cs = map (retype_free c) Cs;
+          val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
+
+          fun mk_terms fsss =
+            let
+              val pfss = map2 popescu_zip pss fsss;
+              val cfsss = map2 (fn c => map (map (fn f => f $ c))) cs fsss
+            in (pfss, cfsss) end;
         in
-          ((([], [], [], []), ([], [], [], [])),
-           (cs, pss, p_Tss, (gsss, g_sum_prod_Ts, g_prod_Tss, g_Tsss),
-            (hsss, h_sum_prod_Ts, h_prod_Tss, h_Tsss)))
+          ((([], [], []), ([], [], [])),
+           (cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss),
+            (mk_terms hsss, h_sum_prod_Ts, h_prod_Tss, ph_Tss)))
         end;
 
-    fun pour_some_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
-          unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders),
-          sel_binderss) no_defs_lthy =
+    fun pour_some_sugar_on_type (((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec),
+          fld_unf), unf_fld), fld_inject), n), ks), ms), ctr_binders), ctr_mixfixes), ctr_Tss),
+        disc_binders), sel_binderss) no_defs_lthy =
       let
-        val n = length ctr_Tss;
-        val ks = 1 upto n;
-        val ms = map length ctr_Tss;
-
         val unfT = domain_type (fastype_of fld);
         val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
         val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
@@ -324,10 +335,10 @@
             val rec_binder = Binding.suffix_name ("_" ^ recN) b;
 
             val iter_spec =
-              mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of iter_binder, iter_T), gss),
+              mk_Trueprop_eq (lists_bmoc gss (Free (Binding.name_of iter_binder, iter_T)),
                 Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
             val rec_spec =
-              mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of rec_binder, rec_T), hss),
+              mk_Trueprop_eq (lists_bmoc hss (Free (Binding.name_of rec_binder, rec_T)),
                 Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) hss zssss));
 
             val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy
@@ -353,21 +364,13 @@
 
         fun some_gfp_sugar no_defs_lthy =
           let
-            fun zip_preds_and_getters ps fss = ps @ flat fss;
-
             val B_to_fpT = C --> fpT;
 
-            val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
-
-            fun generate_coiter_like (suf, fp_iter_like,
-                (fsss, f_sum_prod_Ts, f_prod_Tss, f_Tsss)) =
+            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), f_sum_prod_Ts, f_prod_Tss,
+                pf_Tss)) =
               let
-                val pf_Tss = map2 zip_preds_and_getters p_Tss f_Tsss;
                 val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
 
-                val pfss = map2 zip_preds_and_getters pss fsss;
-                val cfsss = map2 (fn c => map (map (fn f => f $ c))) cs fsss;
-
                 val binder = Binding.suffix_name ("_" ^ suf) b;
 
                 fun mk_join c n cps sum_prod_T prod_Ts cfss =
@@ -375,12 +378,15 @@
                     (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cfss) (1 upto n)));
 
                 val spec =
-                  mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of binder, res_T), pfss),
+                  mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
                     Term.list_comb (fp_iter_like,
                       map6 mk_join cs ns cpss f_sum_prod_Ts f_prod_Tss cfsss));
               in (binder, spec) end;
 
-            val coiter_likes = [(coiterN, fp_iter, coiter_extra), (corecN, fp_rec, corec_extra)];
+            val coiter_likes =
+              [(coiterN, fp_iter, coiter_extra),
+               (corecN, fp_rec, corec_extra)];
+
             val (binders, specs) = map generate_coiter_like coiter_likes |> split_list;
 
             val ((csts, defs), (lthy', lthy)) = no_defs_lthy
@@ -403,29 +409,29 @@
         |> (if lfp then some_lfp_sugar else some_gfp_sugar)
       end;
 
-    fun pour_more_sugar_on_datatypes ((ctrss, iters, recs, xsss, ctr_defss, iter_defs, rec_defs),
+    fun pour_more_sugar_on_lfps ((ctrss, iters, recs, xsss, ctr_defss, iter_defs, rec_defs),
         lthy) =
       let
         val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
-        val giters = map (fn iter => flat_list_comb (iter, gss)) iters;
-        val hrecs = map (fn recx => flat_list_comb (recx, hss)) recs;
+        val giters = map (lists_bmoc gss) iters;
+        val hrecs = map (lists_bmoc hss) recs;
 
         val (iter_thmss, rec_thmss) =
           let
-            fun mk_goal_iter_like fss fc xctr f xs xs' =
+            fun mk_goal_iter_like fss fiter_like xctr f xs fxs =
               fold_rev (fold_rev Logic.all) (xs :: fss)
-                (mk_Trueprop_eq (fc $ xctr, Term.list_comb (f, xs')));
+                (mk_Trueprop_eq (fiter_like $ xctr, Term.list_comb (f, fxs)));
 
-            fun fix_iter_free (x as Free (_, T)) =
-              (case find_index (eq_fpT T) fpTs of ~1 => x | j => nth giters j $ x);
-            fun fix_rec_free (x as Free (_, T)) =
-              (case find_index (eq_fpT T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]);
+            fun repair_iter_call (x as Free (_, T)) =
+              (case find_index (curry (op =) T) fpTs of ~1 => x | j => nth giters j $ x);
+            fun repair_rec_call (x as Free (_, T)) =
+              (case find_index (curry (op =) T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]);
 
-            val iter_xsss = map (map (map fix_iter_free)) xsss;
-            val rec_xsss = map (map (maps fix_rec_free)) xsss;
+            val gxsss = map (map (map repair_iter_call)) xsss;
+            val hxsss = map (map (maps repair_rec_call)) xsss;
 
-            val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss iter_xsss;
-            val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss rec_xsss;
+            val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss gxsss;
+            val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss hxsss;
 
             val iter_tacss =
               map2 (map o mk_iter_like_tac pre_map_defs iter_defs) fp_iter_thms ctr_defss;
@@ -449,12 +455,53 @@
         lthy |> Local_Theory.notes notes |> snd
       end;
 
+    fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, xsss, ctr_defss, coiter_defs, corec_defs),
+        lthy) =
+      let
+        val gcoiters = map (lists_bmoc pgss) coiters;
+        val hcorecs = map (lists_bmoc phss) corecs;
+
+        val (coiter_thmss, corec_thmss) =
+          let
+            fun mk_cond pos = HOLogic.mk_Trueprop o (not pos ? HOLogic.mk_not);
+
+            fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr cfs' =
+              fold_rev (fold_rev Logic.all) ([c] :: pfss)
+                (Logic.list_implies (seq_conds mk_cond n k cps,
+                   mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, cfs'))));
+
+            fun repair_coiter_like_call fcoiter_likes (cf as Free (_, Type (_, [_, T])) $ _) =
+              (case find_index (curry (op =) T) Cs of ~1 => cf | j => nth fcoiter_likes j $ cf);
+
+            val cgsss = map (map (map (repair_coiter_like_call gcoiters))) cgsss;
+            val chsss = map (map (map (repair_coiter_like_call hcorecs))) chsss;
+
+            val goal_coiterss =
+              map7 (map3 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss cgsss;
+            val goal_corecss =
+              map7 (map3 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss chsss;
+          in
+            (map (map (Skip_Proof.make_thm (Proof_Context.theory_of lthy))) goal_coiterss,
+             map (map (Skip_Proof.make_thm (Proof_Context.theory_of lthy))) goal_coiterss (*### goal_corecss*))
+          end;
+
+        val notes =
+          [(coitersN, coiter_thmss),
+           (corecsN, corec_thmss)]
+          |> maps (fn (thmN, thmss) =>
+            map2 (fn b => fn thms =>
+                ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
+              bs thmss);
+      in
+        lthy |> Local_Theory.notes notes |> snd
+      end;
+
     val lthy' = lthy
       |> fold_map pour_some_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
-        fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~
-        ctr_Tsss ~~ disc_binderss ~~ sel_bindersss)
+        fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_binderss ~~
+        ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss)
       |>> split_list7
-      |> (if lfp then pour_more_sugar_on_datatypes else snd);
+      |> (if lfp then pour_more_sugar_on_lfps else pour_more_sugar_on_gfps);
 
     val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
       (if lfp then "" else "co") ^ "datatype"));
--- a/src/HOL/Codatatype/Tools/bnf_util.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_util.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -42,6 +42,7 @@
     'a list -> 'b list -> 'c list -> 'd list -> 'e list -> 'f list -> 'g list -> 'h -> 'i list * 'h
   val interleave: 'a list -> 'a list -> 'a list
   val transpose: 'a list list -> 'a list list
+  val seq_conds: (bool -> 'a -> 'b) -> int -> int -> 'a list -> 'b list
 
   val mk_fresh_names: Proof.context -> int -> string -> string list * Proof.context
   val mk_TFrees: int -> Proof.context -> typ list * Proof.context
@@ -537,6 +538,14 @@
   | transpose ([] :: xss) = transpose xss
   | transpose xss = map hd xss :: transpose (map tl xss);
 
+fun seq_conds f n k xs =
+  if k = n then
+    map (f false) (take (k - 1) xs)
+  else
+    let val (negs, pos) = split_last (take k xs) in
+      map (f false) negs @ [f true pos]
+    end;
+
 fun mk_unabs_def 0 thm = thm
   | mk_unabs_def n thm = mk_unabs_def (n - 1) thm RS @{thm spec[OF iffD1[OF fun_eq_iff]]};
 
--- a/src/HOL/Codatatype/Tools/bnf_wrap_tactics.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap_tactics.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -27,13 +27,7 @@
 open BNF_Util
 open BNF_Tactics
 
-fun triangle _ [] = []
-  | triangle k (xs :: xss) = take k xs :: triangle (k + 1) xss
-
-fun mk_case_if_P_or_not_Ps n k thms =
-  let val (negs, pos) = split_last thms in
-    map (fn thm => thm RS @{thm if_not_P}) negs @ (if k = n then [] else [pos RS @{thm if_P}])
-  end;
+fun if_P_or_not_P_OF pos thm = thm RS (if pos then @{thm if_P} else @{thm if_not_P});
 
 fun ss_only thms = Simplifier.clear_ss HOL_basic_ss addsimps thms
 
@@ -80,7 +74,7 @@
    EVERY' (map3 (fn case_thm => fn if_disc_thms => fn sel_thms =>
        EVERY' [hyp_subst_tac, SELECT_GOAL (Local_Defs.unfold_tac ctxt (if_disc_thms @ sel_thms)),
          rtac case_thm])
-     case_thms (map2 (mk_case_if_P_or_not_Ps n) (1 upto n) (triangle 1 disc_thmss')) sel_thmss)) 1;
+     case_thms (map2 (seq_conds if_P_or_not_P_OF n) (1 upto n) disc_thmss') sel_thmss)) 1;
 
 fun mk_case_cong_tac exhaust' case_thms =
   (rtac exhaust' THEN'