allow defaults for one datatype to involve the constructor of another one in the mutually recursive case
authorblanchet
Tue, 11 Sep 2012 18:58:29 +0200
changeset 49287 ebe2a5cec4bf
parent 49286 dde4967c9233
child 49291 66058a677ddd
allow defaults for one datatype to involve the constructor of another one in the mutually recursive case
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 18:39:47 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 18:58:29 2012 +0200
@@ -319,7 +319,7 @@
             (mk_terms sssss hssss, (h_sum_prod_Ts, ph_Tss))))
         end;
 
-    fun pour_some_sugar_on_type ((((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec),
+    fun define_ctrs_case_for_type ((((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec),
           fld_unf), unf_fld), fld_inject), n), ks), ms), ctr_binders), ctr_mixfixes), ctr_Tss),
         disc_binders), sel_binderss), raw_sel_defaultss) no_defs_lthy =
       let
@@ -400,7 +400,7 @@
 
         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
 
-        fun some_lfp_sugar ((selss0, discIs, sel_thmss), no_defs_lthy) =
+        fun define_iter_rec ((selss0, discIs, sel_thmss), no_defs_lthy) =
           let
             val fpT_to_C = fpT --> C;
 
@@ -439,7 +439,7 @@
              lthy)
           end;
 
-        fun some_gfp_sugar ((selss0, discIs, sel_thmss), no_defs_lthy) =
+        fun define_coiter_corec ((selss0, discIs, sel_thmss), no_defs_lthy) =
           let
             val B_to_fpT = C --> fpT;
 
@@ -483,11 +483,15 @@
               corec_def), lthy)
           end;
 
-        val sel_defaultss = map (map (apsnd (prepare_term lthy'))) raw_sel_defaultss;
+        fun wrap lthy =
+          let val sel_defaultss = map (map (apsnd (prepare_term lthy))) raw_sel_defaultss in
+            wrap_datatype tacss (((no_dests, ctrs0), casex0), (disc_binders, (sel_binderss,
+              sel_defaultss))) lthy
+          end;
+
+        val define_iter_likes = if lfp then define_iter_rec else define_coiter_corec;
       in
-        wrap_datatype tacss (((no_dests, ctrs0), casex0), (disc_binders, (sel_binderss,
-          sel_defaultss))) lthy'
-        |> (if lfp then some_lfp_sugar else some_gfp_sugar)
+        ((wrap, define_iter_likes), lthy')
       end;
 
     val pre_map_defs = map map_def_of_bnf pre_bnfs;
@@ -506,7 +510,7 @@
         val args = map build_arg TUs;
       in Term.list_comb (mapx, args) end;
 
-    fun pour_more_sugar_on_lfps ((ctrss, _, iters, recs, vs, xsss, ctr_defss, _, _, iter_defs,
+    fun derive_iter_rec_thms_for_types ((ctrss, _, iters, recs, vs, xsss, ctr_defss, _, _, iter_defs,
         rec_defs), lthy) =
       let
         val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
@@ -567,8 +571,8 @@
         lthy |> Local_Theory.notes notes |> snd
       end;
 
-    fun pour_more_sugar_on_gfps ((ctrss, selsss, coiters, corecs, vs, _, ctr_defss, discIss,
-        sel_thmsss, coiter_defs, corec_defs), lthy) =
+    fun derive_coiter_corec_thms_for_types ((ctrss, selsss, coiters, corecs, vs, _, ctr_defss,
+        discIss, sel_thmsss, coiter_defs, corec_defs), lthy) =
       let
         val z = the_single zs;
 
@@ -665,12 +669,15 @@
         lthy |> Local_Theory.notes notes |> snd
       end;
 
+    fun wrap_types_and_define_iter_likes ((wraps, define_iter_likess), lthy) =
+      fold_map2 (curry (op o)) define_iter_likess wraps lthy |>> split_list11
+
     val lthy' = lthy
-      |> fold_map pour_some_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
+      |> fold_map define_ctrs_case_for_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
         fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_binderss ~~
         ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss ~~ raw_sel_defaultsss)
-      |>> split_list11
-      |> (if lfp then pour_more_sugar_on_lfps else pour_more_sugar_on_gfps);
+      |>> split_list |> wrap_types_and_define_iter_likes
+      |> (if lfp then derive_iter_rec_thms_for_types else derive_coiter_corec_thms_for_types);
 
     val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
       (if lfp then "" else "co") ^ "datatype"));