added "simp"s to coiter/corec theorems + export under "simps" name
authorblanchet
Thu, 20 Sep 2012 13:32:48 +0200
changeset 49479 504f0a38f608
parent 49478 416ad6e2343b
child 49480 4632b867fba7
child 49482 e6d6869eed08
added "simp"s to coiter/corec theorems + export under "simps" name
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Thu Sep 20 13:32:48 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Thu Sep 20 13:32:48 2012 +0200
@@ -66,6 +66,10 @@
 fun merge_type_args (As, As') =
   if length As = length As' then map2 merge_type_arg As As' else cannot_merge_types ();
 
+fun is_triv_implies thm =
+  op aconv (Logic.dest_implies (Thm.prop_of thm))
+  handle TERM _ => false;
+
 fun type_args_constrained_of (((cAs, _), _), _) = cAs;
 fun type_binding_of (((_, b), _), _) = b;
 fun mixfix_of ((_, mx), _) = mx;
@@ -513,13 +517,13 @@
         val args = map build_arg TUs;
       in Term.list_comb (mapx, args) end;
 
+    val mk_simp_thmss =
+      map3 (fn (_, injects, distincts, cases, _, _) => fn rec_likes => fn iter_likes =>
+        injects @ distincts @ cases @ rec_likes @ iter_likes);
+
     fun derive_induct_iter_rec_thms_for_types ((wrap_ress, ctrss, iters, recs, xsss, ctr_defss,
         iter_defs, rec_defs), lthy) =
       let
-        val inject_thmss = map #2 wrap_ress;
-        val distinct_thmss = map #3 wrap_ress;
-        val case_thmss = map #4 wrap_ress;
-
         val (((phis, phis'), vs'), names_lthy) =
           lthy
           |> mk_Frees' "P" (map mk_pred1T fpTs)
@@ -657,24 +661,22 @@
                recss_goal rec_tacss)
           end;
 
-        val simp_thmss =
-          map4 (fn injects => fn distincts => fn cases => fn recs =>
-            injects @ distincts @ cases @ recs) inject_thmss distinct_thmss case_thmss rec_thmss;
+        val simp_thmss = mk_simp_thmss wrap_ress rec_thmss iter_thmss;
 
         val induct_case_names_attr = Attrib.internal (K (Rule_Cases.case_names induct_cases));
         fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));
 
         (* TODO: Also note "recs", "simps", and "splits" if "nn > 1" (for compatibility with the
-           old package)? *)
+           old package)? And for codatatypes as well? *)
         val common_notes =
           (if nn > 1 then [(inductN, [induct_thm], [induct_case_names_attr])] else [])
           |> map (fn (thmN, thms, attrs) =>
-              ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
+            ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
 
         val notes =
           [(inductN, map single induct_thms,
             fn T_name => [induct_case_names_attr, induct_type_attr T_name]),
-           (itersN, iter_thmss, K simp_attrs),
+           (itersN, iter_thmss, K (Code.add_default_eqn_attrib :: simp_attrs)),
            (recsN, rec_thmss, K (Code.add_default_eqn_attrib :: simp_attrs)),
            (simpsN, simp_thmss, K [])]
           |> maps (fn (thmN, thmss, attrs) =>
@@ -688,7 +690,7 @@
     fun derive_coinduct_coiter_corec_thms_for_types ((wrap_ress, ctrss, coiters, corecs, _,
         ctr_defss, coiter_defs, corec_defs), lthy) =
       let
-        val selsss0 = map #1 wrap_ress;
+        val selsss = map #1 wrap_ress;
         val discIss = map #5 wrap_ress;
         val sel_thmsss = map #6 wrap_ress;
 
@@ -705,7 +707,7 @@
             `(conj_dests nn) coinduct_thm
           end;
 
-        val (coiter_thmss, corec_thmss) =
+        val (coiter_thmss, corec_thmss, safe_coiter_thmss, safe_corec_thmss) =
           let
             val z = the_single zs;
             val gcoiters = map (lists_bmoc pgss) coiters;
@@ -751,58 +753,86 @@
             val corec_tacss =
               map3 (map oo mk_coiter_like_tac corec_defs nesting_map_ids) fp_rec_thms pre_map_defs
                 ctr_defss;
-          in
-            (map2 (map2 (fn goal => fn tac =>
+
+            val coiter_thmss =
+              map2 (map2 (fn goal => fn tac =>
                  Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
-               coiterss_goal coiter_tacss,
-             map2 (map2 (fn goal => fn tac =>
+               coiterss_goal coiter_tacss;
+            val corec_thmss =
+              map2 (map2 (fn goal => fn tac =>
                  Skip_Proof.prove lthy [] [] goal (tac o #context)
                  |> unfold_defs lthy @{thms sum_case_if} |> Thm.close_derivation))
-               corecss_goal corec_tacss)
+               corecss_goal corec_tacss;
+
+            val coiter_safesss = map2 (map2 (map2 (curry (op =)))) crgsss' crgsss;
+            val corec_safesss = map2 (map2 (map2 (curry (op =)))) cshsss' cshsss;
+
+            val filter_safesss =
+              map2 (map_filter (fn (safes, thm) => if forall I safes then SOME thm else NONE) oo
+                curry (op ~~));
+
+            val safe_coiter_thmss = filter_safesss coiter_safesss coiter_thmss;
+            val safe_corec_thmss = filter_safesss corec_safesss corec_thmss;
+          in
+            (coiter_thmss, corec_thmss, safe_coiter_thmss, safe_corec_thmss)
           end;
 
-        fun mk_disc_coiter_like_thms [_] = K []
-          | mk_disc_coiter_like_thms thms = map2 (curry (op RS)) thms;
+        fun mk_disc_coiter_like_thms coiter_likes discIs =
+          map (op RS) (filter_out (is_triv_implies o snd) (coiter_likes ~~ discIs));
 
         val disc_coiter_thmss = map2 mk_disc_coiter_like_thms coiter_thmss discIss;
         val disc_corec_thmss = map2 mk_disc_coiter_like_thms corec_thmss discIss;
 
-        fun mk_sel_coiter_like_thm coiter_like_thm sel0 sel_thm =
+        fun mk_sel_coiter_like_thm coiter_like_thm sel sel_thm =
           let
-            val (domT, ranT) = dest_funT (fastype_of sel0);
+            val (domT, ranT) = dest_funT (fastype_of sel);
             val arg_cong' =
               Drule.instantiate' (map (SOME o certifyT lthy) [domT, ranT])
-                [NONE, NONE, SOME (certify lthy sel0)] arg_cong
+                [NONE, NONE, SOME (certify lthy sel)] arg_cong
               |> Thm.varifyT_global;
             val sel_thm' = sel_thm RSN (2, trans);
           in
             coiter_like_thm RS arg_cong' RS sel_thm'
           end;
 
-        val sel_coiter_thmsss =
-          map3 (map3 (map2 o mk_sel_coiter_like_thm)) coiter_thmss selsss0 sel_thmsss;
-        val sel_corec_thmsss =
-          map3 (map3 (map2 o mk_sel_coiter_like_thm)) corec_thmss selsss0 sel_thmsss;
+        fun mk_sel_coiter_like_thms coiter_likess =
+          map3 (map3 (map2 o mk_sel_coiter_like_thm)) coiter_likess selsss sel_thmsss |> map flat;
+
+        val sel_coiter_thmss = mk_sel_coiter_like_thms coiter_thmss;
+        val sel_corec_thmss = mk_sel_coiter_like_thms corec_thmss;
+
+        fun zip_coiter_like_thms coiter_likes disc_coiter_likes sel_coiter_likes =
+          coiter_likes @ disc_coiter_likes @ sel_coiter_likes;
+
+        val simp_thmss =
+          mk_simp_thmss wrap_ress
+            (map3 zip_coiter_like_thms safe_corec_thmss disc_corec_thmss sel_corec_thmss)
+            (map3 zip_coiter_like_thms safe_coiter_thmss disc_coiter_thmss sel_coiter_thmss);
+
+        val anonymous_notes =
+          [(flat safe_coiter_thmss @ flat safe_corec_thmss, simp_attrs)]
+          |> map (fn (thms, attrs) => ((Binding.empty, attrs), [(thms, [])]));
 
         val common_notes =
           (if nn > 1 then [(coinductN, [coinduct_thm], [])] (* FIXME: attribs *) else [])
           |> map (fn (thmN, thms, attrs) =>
-              ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
+            ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
 
         val notes =
           [(coinductN, map single coinduct_thms, []), (* FIXME: attribs *)
            (coitersN, coiter_thmss, []),
-           (disc_coitersN, disc_coiter_thmss, []),
-           (sel_coitersN, map flat sel_coiter_thmsss, []),
+           (disc_coitersN, disc_coiter_thmss, simp_attrs),
+           (sel_coitersN, sel_coiter_thmss, simp_attrs),
            (corecsN, corec_thmss, []),
-           (disc_corecsN, disc_corec_thmss, []),
-           (sel_corecsN, map flat sel_corec_thmsss, [])]
+           (disc_corecsN, disc_corec_thmss, simp_attrs),
+           (sel_corecsN, sel_corec_thmss, simp_attrs),
+           (simpsN, simp_thmss, [])]
           |> maps (fn (thmN, thmss, attrs) =>
             map_filter (fn (_, []) => NONE | (b, thms) =>
               SOME ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), attrs),
                 [(thms, [])])) (fp_bs ~~ thmss));
       in
-        lthy |> Local_Theory.notes (common_notes @ notes) |> snd
+        lthy |> Local_Theory.notes (anonymous_notes @ common_notes @ notes) |> snd
       end;
 
     fun wrap_types_and_define_iter_likes ((wraps, define_iter_likess), lthy) =