properly note theorems for split recursors
authorblanchet
Mon, 08 Sep 2014 14:03:01 +0200
changeset 58214 bd1754377965
parent 58213 6411ac1ef04d
child 58215 cccf5445e224
properly note theorems for split recursors
src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_compat.ML
--- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Mon Sep 08 14:03:01 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Mon Sep 08 14:03:01 2014 +0200
@@ -52,8 +52,7 @@
     'a list
   val nesting_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
 
-  type lfp_sugar_thms =
-    (thm list * thm * Token.src list) * (thm list list * Token.src list)
+  type lfp_sugar_thms = (thm list * thm * Token.src list) * (thm list list * Token.src list)
 
   val morph_lfp_sugar_thms: morphism -> lfp_sugar_thms -> lfp_sugar_thms
   val transfer_lfp_sugar_thms: theory -> lfp_sugar_thms -> lfp_sugar_thms
@@ -90,6 +89,7 @@
   val define_corec: 'a * term list * term list list
       * ((term list list * term list list list) * typ list) -> (string -> binding) -> 'b list ->
     typ list -> term list -> term -> local_theory -> (term * thm) * local_theory
+  val mk_induct_attrs: term list list -> Token.src list
   val derive_induct_recs_thms_for_types: BNF_Def.bnf list ->
      ('a * typ list list list list * term list list * 'b) option -> thm -> thm list ->
      BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list -> typ list ->
@@ -385,8 +385,7 @@
 
 fun indexify proj xs f p = f (find_index (curry (op =) (proj p)) xs) p;
 
-type lfp_sugar_thms =
-  (thm list * thm * Token.src list) * (thm list list * Token.src list);
+type lfp_sugar_thms = (thm list * thm * Token.src list) * (thm list list * Token.src list);
 
 fun morph_lfp_sugar_thms phi ((inducts, induct, induct_attrs), (recss, rec_attrs)) =
   ((map (Morphism.thm phi) inducts, Morphism.thm phi induct, induct_attrs),
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Mon Sep 08 14:03:01 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Mon Sep 08 14:03:01 2014 +0200
@@ -119,18 +119,13 @@
     fold_map3 (define_co_rec_as Least_FP Cs) fpTs bs rhss lthy
   end;
 
-fun mk_split_rec_thmss ctxt rec0_thmss (recs as rec1 :: _) rec_defs =
+fun mk_split_rec_thmss ctxt fpTs ctrss rec0_thmss (recs as rec1 :: _) rec_defs =
   let
     val f_Ts = binder_fun_types (fastype_of rec1);
     val (fs, _) = mk_Frees "f" f_Ts ctxt;
     val frecs = map (fn recx => Term.list_comb (recx, fs)) recs;
 
-    fun mk_ctrs_of (Type (T_name, As)) =
-      map (mk_ctr As) (#ctrs (the (ctr_sugar_of ctxt T_name)));
-
-    val fpTs = map (domain_type o body_fun_type o fastype_of) recs;
     val fpTs_frecs = fpTs ~~ frecs;
-    val ctrss = map mk_ctrs_of fpTs;
     val fss = unflat ctrss fs;
 
     fun mk_rec_call g n (Type (@{type_name fun}, [dom_T, ran_T])) =
@@ -170,7 +165,7 @@
     map (map prove) goalss
   end;
 
-fun define_split_rec_derive_induct_rec_thms induct inducts recs0 rec_thmss fpTs lthy =
+fun define_split_rec_derive_induct_rec_thms fpTs ctrss inducts induct recs0 rec_thmss lthy =
   let
     val thy = Proof_Context.theory_of lthy;
 
@@ -178,15 +173,15 @@
        arguments *)
     val repair_induct = unfold_thms lthy @{thms all_mem_range};
 
+    val inducts' = map repair_induct inducts;
     val induct' = repair_induct induct;
-    val inducts' = map repair_induct inducts;
 
     val Cs = map ((fn TVar ((s, _), S) => TFree (s, S)) o body_type o fastype_of) recs0;
     val recs = map2 (mk_co_rec thy Least_FP Cs) fpTs recs0;
     val ((recs', rec'_defs), lthy') = define_split_recs fpTs Cs recs lthy |>> split_list;
-    val rec'_thmss = mk_split_rec_thmss lthy' rec_thmss recs' rec'_defs;
+    val rec'_thmss = mk_split_rec_thmss lthy' fpTs ctrss rec_thmss recs' rec'_defs;
   in
-    ((induct', inducts', recs', rec'_thmss), lthy')
+    ((inducts', induct', recs', rec'_thmss), lthy')
   end;
 
 fun reindex_desc desc =
@@ -293,21 +288,26 @@
       else
         ((fp_sugars0, (NONE, NONE)), lthy);
 
+    fun mk_ctrs_of (Type (T_name, As)) =
+      map (mk_ctr As) (#ctrs (the (ctr_sugar_of lthy' T_name)));
+
+    val ctrss' = map mk_ctrs_of fpTs';
     val {common_co_inducts = [induct], ...} :: _ = fp_sugars;
     val inducts = map (the_single o #co_inducts) fp_sugars;
-
     val recs = map #co_rec fp_sugars;
     val rec_thmss = map #co_rec_thms fp_sugars;
 
     fun is_nested_rec_type (Type (@{type_name fun}, [_, T])) = member (op =) fpTs' (body_type T)
       | is_nested_rec_type _ = false;
 
-    val ((induct', inducts', recs', rec'_thmss), lthy'') =
+    val ((lfp_sugar_thms', (inducts', induct', recs', rec'_thmss)), lthy'') =
       if nesting_pref = Unfold_Nesting andalso
          exists (exists (exists is_nested_rec_type)) ctr_Tsss then
-        define_split_rec_derive_induct_rec_thms induct inducts recs rec_thmss fpTs' lthy'
+        define_split_rec_derive_induct_rec_thms fpTs' ctrss' inducts induct recs rec_thmss lthy'
+        |>> `(fn (inducts', induct', _, rec'_thmss) =>
+          SOME ((inducts', induct', mk_induct_attrs ctrss'), (rec'_thmss, [])))
       else
-        ((induct, inducts, recs, rec_thmss), lthy');
+        ((lfp_sugar_thms, (inducts, induct, recs, rec_thmss)), lthy');
 
     val rec'_names = map (fst o dest_Const) recs';
     val rec'_thms = flat rec'_thmss;
@@ -323,7 +323,7 @@
 
     val infos = map_index mk_info (take nn_fp fp_sugars);
   in
-    (nn, b_names, compat_b_names, lfp_sugar_thms, infos, lthy'')
+    (nn, b_names, compat_b_names, lfp_sugar_thms', infos, lthy'')
   end;
 
 fun infos_of_new_datatype_mutual_cluster lthy fpT_name =
@@ -446,18 +446,18 @@
     val all_notes =
       (case lfp_sugar_thms of
         NONE => []
-      | SOME ((induct_thms, induct_thm, induct_attrs), (rec_thmss, _)) =>
+      | SOME ((inducts, induct, induct_attrs), (rec_thmss, _)) =>
         let
           val common_name = compat_N ^ mk_common_name b_names;
 
           val common_notes =
-            (if nn > 1 then [(inductN, [induct_thm], induct_attrs)] else [])
+            (if nn > 1 then [(inductN, [induct], induct_attrs)] else [])
             |> filter_out (null o #2)
             |> map (fn (thmN, thms, attrs) =>
               ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
 
           val notes =
-            [(inductN, map single induct_thms, induct_attrs),
+            [(inductN, map single inducts, induct_attrs),
              (recN, rec_thmss, code_nitpicksimp_simp_attrs)]
             |> filter_out (null o #2)
             |> maps (fn (thmN, thmss, attrs) =>