src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 49337 538687a77075
parent 49336 a2e6473145e4
child 49338 4a922800531d
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Wed Sep 12 17:26:05 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Wed Sep 12 17:26:05 2012 +0200
@@ -10,16 +10,16 @@
   val datatyp: bool ->
     (mixfix list -> (string * sort) list option -> binding list -> typ list * typ list list ->
       BNF_Def.BNF list -> local_theory ->
-      (term list * term list * term list * term list * thm list * thm list * thm list * thm list *
-         thm list) * local_theory) ->
+      (term list * term list * term list * term list * thm * thm list * thm list * thm list *
+         thm list * thm list) * local_theory) ->
     bool * ((((typ * sort) list * binding) * mixfix) * ((((binding * binding) *
       (binding * typ) list) * (binding * term) list) * mixfix) list) list ->
     local_theory -> local_theory
   val parse_datatype_cmd: bool ->
     (mixfix list -> (string * sort) list option -> binding list -> typ list * typ list list ->
       BNF_Def.BNF list -> local_theory ->
-      (term list * term list * term list * term list * thm list * thm list * thm list * thm list *
-         thm list) * local_theory) ->
+      (term list * term list * term list * term list * thm * thm list * thm list * thm list *
+         thm list * thm list) * local_theory) ->
     (local_theory -> local_theory) parser
 end;
 
@@ -33,10 +33,12 @@
 open BNF_FP_Sugar_Tactics
 
 val caseN = "case";
+val coinductsN = "coinducts";
 val coitersN = "coiters";
 val corecsN = "corecs";
 val disc_coitersN = "disc_coiters";
 val disc_corecsN = "disc_corecs";
+val inductsN = "inducts";
 val itersN = "iters";
 val recsN = "recs";
 val sel_coitersN = "sel_coiters";
@@ -131,6 +133,8 @@
         unsorted_As);
 
     val fp_bs = map type_binding_of specs;
+    val fp_common_name = mk_common_name fp_bs;
+
     val fake_Ts = map mk_fake_T fp_bs;
 
     val mixfixes = map mixfix_of specs;
@@ -179,7 +183,7 @@
     val fp_eqs =
       map dest_TFree Bs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctr_sum_prod_TsBs;
 
-    val (pre_bnfs, ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects,
+    val (pre_bnfs, ((unfs0, flds0, fp_iters0, fp_recs0, fp_induct, unf_flds, fld_unfs, fld_injects,
         fp_iter_thms, fp_rec_thms), lthy)) =
       fp_bnf construct fp_bs mixfixes (map dest_TFree unsorted_As) fp_eqs no_defs_lthy0;
 
@@ -517,15 +521,22 @@
         val args = map build_arg TUs;
       in Term.list_comb (mapx, args) end;
 
-    fun derive_iter_rec_thms_for_types ((ctrss, _, iters, recs, vs, xsss, ctr_defss, _, _, iter_defs,
-        rec_defs), lthy) =
+    fun derive_induct_iter_rec_thms_for_types ((ctrss, _, iters, recs, vs, xsss, ctr_defss, _, _,
+        iter_defs, rec_defs), lthy) =
       let
-        val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
-        val giters = map (lists_bmoc gss) iters;
-        val hrecs = map (lists_bmoc hss) recs;
+        val (induct_thms, induct_thm) =
+          let
+            val induct_thm = fp_induct;
+          in
+            `(conj_dests N) induct_thm
+          end;
 
         val (iter_thmss, rec_thmss) =
           let
+            val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
+            val giters = map (lists_bmoc gss) iters;
+            val hrecs = map (lists_bmoc hss) recs;
+
             fun mk_goal_iter_like fss fiter_like xctr f xs fxs =
               fold_rev (fold_rev Logic.all) (xs :: fss)
                 (mk_Trueprop_eq (fiter_like $ xctr, Term.list_comb (f, fxs)));
@@ -567,6 +578,12 @@
                goal_recss rec_tacss)
           end;
 
+        val common_notes =
+          [(inductN, [induct_thm], []), (*### attribs *)
+           (inductsN, induct_thms, [])] (*### attribs *)
+          |> map (fn (thmN, thms, attrs) =>
+              ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
+
         val notes =
           [(itersN, iter_thmss, simp_attrs),
            (recsN, rec_thmss, Code.add_default_eqn_attrib :: simp_attrs)]
@@ -575,19 +592,25 @@
               ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), attrs),
                 [(thms, [])])) fp_bs thmss);
       in
-        lthy |> Local_Theory.notes notes |> snd
+        lthy |> Local_Theory.notes (common_notes @ notes) |> snd
       end;
 
-    fun derive_coiter_corec_thms_for_types ((ctrss, selsss, coiters, corecs, vs, _, ctr_defss,
-        discIss, sel_thmsss, coiter_defs, corec_defs), lthy) =
+    fun derive_coinduct_coiter_corec_thms_for_types ((ctrss, selsss, coiters, corecs, vs, _,
+        ctr_defss, discIss, sel_thmsss, coiter_defs, corec_defs), lthy) =
       let
-        val z = the_single zs;
-
-        val gcoiters = map (lists_bmoc pgss) coiters;
-        val hcorecs = map (lists_bmoc phss) corecs;
+        val (coinduct_thms, coinduct_thm) =
+          let
+            val coinduct_thm = fp_induct;
+          in
+            `(conj_dests N) coinduct_thm
+          end;
 
         val (coiter_thmss, corec_thmss) =
           let
+            val z = the_single zs;
+            val gcoiters = map (lists_bmoc pgss) coiters;
+            val hcorecs = map (lists_bmoc phss) corecs;
+
             fun mk_goal_cond pos = HOLogic.mk_Trueprop o (not pos ? HOLogic.mk_not);
 
             fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfs' =
@@ -684,7 +707,8 @@
         fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_bindingss ~~
         ctr_mixfixess ~~ ctr_Tsss ~~ disc_bindingss ~~ sel_bindingsss ~~ raw_sel_defaultsss)
       |>> split_list |> wrap_types_and_define_iter_likes
-      |> (if lfp then derive_iter_rec_thms_for_types else derive_coiter_corec_thms_for_types);
+      |> (if lfp then derive_induct_iter_rec_thms_for_types
+          else derive_coinduct_coiter_corec_thms_for_types);
 
     val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
       (if lfp then "" else "co") ^ "datatype"));