preserve case names in '(co)induct' theorems generated by prim(co)rec'
authorblanchet
Tue, 09 Sep 2014 20:51:36 +0200
changeset 58283 71d74e641538
parent 58282 48e16d74845b
child 58284 f9b6af3017fd
preserve case names in '(co)induct' theorems generated by prim(co)rec'
src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML
--- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Tue Sep 09 20:51:36 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Tue Sep 09 20:51:36 2014 +0200
@@ -91,6 +91,8 @@
       * ((term list list * term list list list) * typ list) -> (string -> binding) -> 'b list ->
     typ list -> term list -> term -> local_theory -> (term * thm) * local_theory
   val mk_induct_attrs: term list list -> Token.src list
+  val mk_coinduct_attrs: typ list -> term list list -> term list list -> int list list ->
+    Token.src list * Token.src list
   val derive_induct_recs_thms_for_types: BNF_Def.bnf list ->
      ('a * typ list list list list * term list list * 'b) option -> thm -> thm list ->
      BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list -> typ list ->
@@ -768,7 +770,7 @@
      (rec_thmss, code_nitpicksimp_attrs @ simp_attrs))
   end;
 
-fun mk_coinduct_attributes fpTs ctrss discss mss =
+fun mk_coinduct_attrs fpTs ctrss discss mss =
   let
     val nn = length fpTs;
     val fp_b_names = map base_name_of_typ fpTs;
@@ -876,7 +878,7 @@
   in
     (postproc_co_induct lthy (length fpA_Ts) @{thm predicate2D} @{thm predicate2D_conj}
        rel_coinduct0_thm,
-     mk_coinduct_attributes fpA_Ts (map #ctrs ctr_sugars) (map #discs ctr_sugars) mss)
+     mk_coinduct_attrs fpA_Ts (map #ctrs ctr_sugars) (map #discs ctr_sugars) mss)
   end;
 
 fun derive_set_induct_thms_for_types lthy nn fpTs ctrss setss dtor_set_inducts exhausts
@@ -1164,7 +1166,7 @@
     val corec_sel_thmsss = mk_corec_sel_thms corec_thmss;
   in
     ((coinduct_thms_pairs,
-      mk_coinduct_attributes fpTs (map #ctrs ctr_sugars) (map #discs ctr_sugars) mss),
+      mk_coinduct_attrs fpTs (map #ctrs ctr_sugars) (map #discs ctr_sugars) mss),
      corec_thmss,
      corec_disc_thmss,
      (corec_disc_iff_thmss, simp_attrs),
--- a/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Tue Sep 09 20:51:36 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Tue Sep 09 20:51:36 2014 +0200
@@ -41,7 +41,8 @@
 
   val corec_specs_of: binding list -> typ list -> typ list -> term list ->
     (term * term list list) list list -> local_theory ->
-    (bool * corec_spec list * typ list * thm * thm * thm list * thm list) * local_theory
+    corec_spec list * typ list * thm * thm * thm list * thm list * (Token.src list * Token.src list)
+    * bool * local_theory
   val add_primcorecursive_cmd: primcorec_option list ->
     (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
     Proof.context -> Proof.state
@@ -410,6 +411,8 @@
           common_co_inducts = common_coinduct_thms, ...} :: _, (_, gfp_sugar_thms)), lthy) =
       nested_to_mutual_fps Greatest_FP bs res_Ts callers callssss0 lthy0;
 
+    val coinduct_attrs_pair = (case gfp_sugar_thms of SOME ((_, _, pair), _) => pair | NONE => []);
+
     val perm_fp_sugars = sort (int_ord o pairself #fp_res_index) fp_sugars;
 
     val indices = map #fp_res_index fp_sugars;
@@ -502,9 +505,10 @@
        ctr_specs = mk_ctr_specs ctr_sugar p_is q_isss f_isss f_Tsss corec_thms corec_discs
          corec_selss};
   in
-    ((is_some gfp_sugar_thms, map5 mk_spec fp_sugars p_iss q_issss f_issss f_Tssss, missing_res_Ts,
-      co_induct_of common_coinduct_thms, strong_co_induct_of common_coinduct_thms,
-      co_induct_of coinduct_thmss, strong_co_induct_of coinduct_thmss), lthy)
+    (map5 mk_spec fp_sugars p_iss q_issss f_issss f_Tssss, missing_res_Ts,
+     co_induct_of common_coinduct_thms, strong_co_induct_of common_coinduct_thms,
+     co_induct_of coinduct_thmss, strong_co_induct_of coinduct_thmss, coinduct_attrs_pair,
+     is_some gfp_sugar_thms, lthy)
   end;
 
 val undef_const = Const (@{const_name undefined}, dummyT);
@@ -1024,8 +1028,8 @@
       |> map2 (curry (op |>)) (map (map (fn {ctr, sels, ...} =>
         (ctr, map (K []) sels))) basic_ctr_specss);
 
-    val ((n2m, corec_specs', _, coinduct_thm, coinduct_strong_thm, coinduct_thms,
-          coinduct_strong_thms), lthy') =
+    val (corec_specs', _, coinduct_thm, coinduct_strong_thm, coinduct_thms, coinduct_strong_thms,
+         (coinduct_attrs, common_coinduct_attrs), n2m, lthy') =
       corec_specs_of bs arg_Ts res_Ts frees callssss lthy;
     val corec_specs = take actual_nn corec_specs';
     val ctr_specss = map #ctr_specs corec_specs;
@@ -1406,8 +1410,17 @@
           [(flat disc_iff_or_disc_thmss, simp_attrs)]
           |> map (fn (thms, attrs) => ((Binding.empty, attrs), [(thms, [])]));
 
+        val common_notes =
+          [(coinductN, if n2m then [coinduct_thm] else [], common_coinduct_attrs),
+           (coinduct_strongN, if n2m then [coinduct_strong_thm] else [], common_coindut_attrs)]
+          |> filter_out (null o #2)
+          |> map (fn (thmN, thms, attrs) =>
+            ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
+
         val notes =
-          [(coinductN, map (if n2m then single else K []) coinduct_thms, []),
+          [(coinductN, map (if n2m then single else K []) coinduct_thms, coinduct_attrs),
+           (coinduct_strongN, map (if n2m then single else K []) coinduct_strong_thms,
+            coinduct_attrs),
            (codeN, code_thmss, code_nitpicksimp_attrs),
            (ctrN, ctr_thmss, []),
            (discN, disc_thmss, []),
@@ -1415,26 +1428,18 @@
            (excludeN, exclude_thmss, []),
            (exhaustN, nontriv_exhaust_thmss, []),
            (selN, sel_thmss, simp_attrs),
-           (simpsN, simp_thmss, []),
-           (coinduct_strongN, map (if n2m then single else K []) coinduct_strong_thms, [])]
+           (simpsN, simp_thmss, [])]
           |> maps (fn (thmN, thmss, attrs) =>
             map2 (fn fun_name => fn thms =>
                 ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])]))
               fun_names (take actual_nn thmss))
           |> filter_out (null o fst o hd o snd);
-
-        val common_notes =
-          [(coinductN, if n2m then [coinduct_thm] else [], []),
-           (coinduct_strongN, if n2m then [coinduct_strong_thm] else [], [])]
-          |> filter_out (null o #2)
-          |> map (fn (thmN, thms, attrs) =>
-            ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
       in
         lthy
         |> Spec_Rules.add Spec_Rules.Equational (map fst def_infos, flat sel_thmss)
         |> Spec_Rules.add Spec_Rules.Equational (map fst def_infos, flat ctr_thmss)
         |> Spec_Rules.add Spec_Rules.Equational (map fst def_infos, flat code_thmss)
-        |> Local_Theory.notes (anonymous_notes @ notes @ common_notes)
+        |> Local_Theory.notes (anonymous_notes @ common_notes @ notes)
         |> snd
       end;
 
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Tue Sep 09 20:51:36 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Tue Sep 09 20:51:36 2014 +0200
@@ -42,7 +42,8 @@
      is_new_datatype: Proof.context -> string -> bool,
      get_basic_lfp_sugars: binding list -> typ list -> term list ->
        (term * term list list) list list -> local_theory ->
-       typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool * local_theory,
+       typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * Token.src list
+       * bool * local_theory,
      rewrite_nested_rec_call: Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
        term -> term -> term -> term};
 
@@ -51,7 +52,7 @@
   val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory
   val rec_specs_of: binding list -> typ list -> typ list -> term list ->
     (term * term list list) list list -> local_theory ->
-    (bool * rec_spec list * typ list * thm * thm list) * local_theory
+    (bool * rec_spec list * typ list * thm * thm list * Token.src list) * local_theory
 
   val add_primrec: (binding * typ option * mixfix) list ->
     (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
@@ -118,7 +119,8 @@
    is_new_datatype: Proof.context -> string -> bool,
    get_basic_lfp_sugars: binding list -> typ list -> term list ->
      (term * term list list) list list -> local_theory ->
-     typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool * local_theory,
+     typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * Token.src list * bool
+     * local_theory,
    rewrite_nested_rec_call: Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
      term -> term -> term -> term};
 
@@ -156,7 +158,7 @@
     val thy = Proof_Context.theory_of lthy0;
 
     val (missing_arg_Ts, perm0_kks, basic_lfp_sugars, fp_nesting_map_ident0s, fp_nesting_map_comps,
-         common_induct, n2m, lthy) =
+         common_induct, induct_attrs, n2m, lthy) =
       get_basic_lfp_sugars bs arg_Ts callers callssss0 lthy0;
 
     val perm_basic_lfp_sugars = sort (int_ord o pairself #fp_res_index) basic_lfp_sugars;
@@ -218,7 +220,8 @@
        fp_nesting_map_ident0s = fp_nesting_map_ident0s, fp_nesting_map_comps = fp_nesting_map_comps,
        ctr_specs = mk_ctr_specs fp_res_index ctr_offset ctrs rec_thms};
   in
-    ((n2m, map2 mk_spec ctr_offsets basic_lfp_sugars, missing_arg_Ts, common_induct, inducts), lthy)
+    ((n2m, map2 mk_spec ctr_offsets basic_lfp_sugars, missing_arg_Ts, common_induct, inducts,
+      induct_attrs), lthy)
   end;
 
 val undef_const = Const (@{const_name undefined}, dummyT);
@@ -472,7 +475,7 @@
         [] => ()
       | (b, _) :: _ => raise PRIMREC ("type of " ^ Binding.print b ^ " contains top sort", []));
 
-    val ((n2m, rec_specs, _, common_induct, inducts), lthy) =
+    val ((n2m, rec_specs, _, common_induct, inducts, induct_attrs), lthy) =
       rec_specs_of bs arg_Ts res_Ts frees callssss lthy0;
 
     val actual_nn = length funs_data;
@@ -513,7 +516,8 @@
 
     val notes =
       (if n2m then
-         map2 (fn name => fn thm => (name, inductN, [thm], [])) fun_names (take actual_nn inducts)
+         map2 (fn name => fn thm => (name, inductN, [thm], induct_attrs)) fun_names
+           (take actual_nn inducts)
        else
          [])
       |> map (fn (prefix, thmN, thms, attrs) =>
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML	Tue Sep 09 20:51:36 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML	Tue Sep 09 20:51:36 2014 +0200
@@ -33,6 +33,8 @@
           (lfp_sugar_thms, _)), lthy) =
       nested_to_mutual_fps Least_FP bs arg_Ts callers callssss0 lthy0;
 
+    val induct_attrs = (case lfp_sugar_thms of SOME ((_, _, attrs), _) => attrs | NONE => []);
+
     val Ts = map #T fp_sugars;
     val Xs = map #X fp_sugars;
     val Cs = map (body_type o fastype_of o #co_rec) fp_sugars;
@@ -51,7 +53,8 @@
     val fp_nesting_map_comps = map map_comp_of_bnf fp_nesting_bnfs;
   in
     (missing_arg_Ts, perm0_kks, map3 basic_lfp_sugar_of Cs fun_arg_Tssss fp_sugars,
-     fp_nesting_map_ident0s, fp_nesting_map_comps, common_induct, is_some lfp_sugar_thms, lthy)
+     fp_nesting_map_ident0s, fp_nesting_map_comps, common_induct, induct_attrs,
+     is_some lfp_sugar_thms, lthy)
   end;
 
 exception NOT_A_MAP of term;