register induct attributes
authorblanchet
Tue, 18 Sep 2012 11:42:11 +0200
changeset 49437 c139da00fb4a
parent 49436 37cae324d73e
child 49438 5bc80d96241e
register induct attributes
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_wrap.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 18 11:41:04 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 18 11:42:11 2012 +0200
@@ -505,7 +505,7 @@
         Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
       end;
 
-    fun build_map build_arg (T as Type (s, Ts)) (U as Type (_, Us)) =
+    fun build_map build_arg (Type (s, Ts)) (Type (_, Us)) =
       let
         val bnf = the (bnf_of lthy s);
         val live = live_of_bnf bnf;
@@ -603,6 +603,8 @@
             `(conj_dests nn) induct_thm
           end;
 
+        val induct_cases = Datatype_Prop.indexify_names (maps (map base_name_of_ctr) ctrss);
+
         val (iter_thmss, rec_thmss) =
           let
             val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
@@ -651,19 +653,23 @@
                goal_recss rec_tacss)
           end;
 
+        val induct_case_names_attr = Attrib.internal (K (Rule_Cases.case_names induct_cases));
+        fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));
+
         val common_notes =
-          (if nn > 1 then [(inductN, [induct_thm], [])] (* FIXME: attribs *) else [])
+          (if nn > 1 then [(inductN, [induct_thm], [induct_case_names_attr])] else [])
           |> map (fn (thmN, thms, attrs) =>
               ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
 
         val notes =
-          [(inductN, map single induct_thms, []), (* FIXME: attribs *)
-           (itersN, iter_thmss, simp_attrs),
-           (recsN, rec_thmss, Code.add_default_eqn_attrib :: simp_attrs)]
+          [(inductN, map single induct_thms,
+            fn T_name => [induct_case_names_attr, induct_type_attr T_name]),
+           (itersN, iter_thmss, K simp_attrs),
+           (recsN, rec_thmss, K (Code.add_default_eqn_attrib :: simp_attrs))]
           |> maps (fn (thmN, thmss, attrs) =>
-            map2 (fn b => fn thms =>
-              ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), attrs),
-                [(thms, [])])) fp_bs thmss);
+            map3 (fn b => fn Type (T_name, _) => fn thms =>
+              ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), attrs T_name),
+                [(thms, [])])) fp_bs fpTs thmss);
       in
         lthy |> Local_Theory.notes (common_notes @ notes) |> snd
       end;
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 18 11:41:04 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 18 11:42:11 2012 +0200
@@ -9,6 +9,7 @@
 sig
   val mk_half_pairss: 'a list -> ('a * 'a) list list
   val mk_ctr: typ list -> term -> term
+  val base_name_of_ctr: term -> string
   val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list ->
     ((bool * term list) * term) *
       (binding list * (binding list list * (binding * term) list list)) -> local_theory ->
@@ -69,14 +70,14 @@
     Term.subst_atomic_types (Ts0 ~~ Ts) ctr
   end;
 
-fun eta_expand_case_arg xs f_xs = fold_rev Term.lambda xs f_xs;
-
 fun base_name_of_ctr c =
   Long_Name.base_name (case head_of c of
       Const (s, _) => s
     | Free (s, _) => s
     | _ => error "Cannot extract name of constructor");
 
+fun eta_expand_arg xs f_xs = fold_rev Term.lambda xs f_xs;
+
 fun prepare_wrap_datatype prep_term (((no_dests, raw_ctrs), raw_case),
     (raw_disc_bindings, (raw_sel_bindingss, raw_sel_defaultss))) no_defs_lthy =
   let
@@ -176,8 +177,8 @@
     val xfs = map2 (curry Term.list_comb) fs xss;
     val xgs = map2 (curry Term.list_comb) gs xss;
 
-    val eta_fs = map2 eta_expand_case_arg xss xfs;
-    val eta_gs = map2 eta_expand_case_arg xss xgs;
+    val eta_fs = map2 eta_expand_arg xss xfs;
+    val eta_gs = map2 eta_expand_arg xss xgs;
 
     val fcase = Term.list_comb (casex, eta_fs);
     val gcase = Term.list_comb (casex, eta_gs);
@@ -557,8 +558,8 @@
             (split_thm, split_asm_thm)
           end;
 
+        val exhaust_case_names_attr = Attrib.internal (K (Rule_Cases.case_names exhaust_cases));
         val cases_type_attr = Attrib.internal (K (Induct.cases_type dataT_name));
-        val exhaust_case_names_attr = Attrib.internal (K (Rule_Cases.case_names exhaust_cases));
 
         val notes =
           [(case_congN, [case_cong_thm], []),