continued changing type of corec type
authorblanchet
Tue, 02 Oct 2012 01:00:18 +0200
changeset 49683 78a3d5006cf1
parent 49682 f57af1c46f99
child 49684 1cf810b8f600
continued changing type of corec type
src/HOL/BNF/BNF_FP.thy
src/HOL/BNF/Examples/Infinite_Derivation_Trees/Tree.thy
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML
--- a/src/HOL/BNF/BNF_FP.thy	Tue Oct 02 01:00:18 2012 +0200
+++ b/src/HOL/BNF/BNF_FP.thy	Tue Oct 02 01:00:18 2012 +0200
@@ -47,11 +47,11 @@
 lemma Un_cong: "\<lbrakk>A = B; C = D\<rbrakk> \<Longrightarrow> A \<union> C = B \<union> D"
 by simp
 
-lemma pointfree_idE: "f o g = id \<Longrightarrow> f (g x) = x"
+lemma pointfree_idE: "f \<circ> g = id \<Longrightarrow> f (g x) = x"
 unfolding o_def fun_eq_iff by simp
 
 lemma o_bij:
-  assumes gf: "g o f = id" and fg: "f o g = id"
+  assumes gf: "g \<circ> f = id" and fg: "f \<circ> g = id"
   shows "bij f"
 unfolding bij_def inj_on_def surj_def proof safe
   fix a1 a2 assume "f a1 = f a2"
@@ -67,8 +67,8 @@
 lemma ssubst_mem: "\<lbrakk>t = s; s \<in> X\<rbrakk> \<Longrightarrow> t \<in> X" by simp
 
 lemma sum_case_step:
-  "sum_case (sum_case f' g') g (Inl p) = sum_case f' g' p"
-  "sum_case f (sum_case f' g') (Inr p) = sum_case f' g' p"
+"sum_case (sum_case f' g') g (Inl p) = sum_case f' g' p"
+"sum_case f (sum_case f' g') (Inr p) = sum_case f' g' p"
 by auto
 
 lemma one_pointE: "\<lbrakk>\<And>x. s = x \<Longrightarrow> P\<rbrakk> \<Longrightarrow> P"
@@ -100,6 +100,14 @@
 "sum_case f g (if p then Inl x else Inr y) = (if p then f x else g y)"
 by simp
 
+lemma sum_case_o_inj:
+"sum_case f g \<circ> Inl = f"
+"sum_case f g \<circ> Inr = g"
+by auto
+
+lemma ident_o_ident: "(\<lambda>x. x) \<circ> (\<lambda>x. x) = (\<lambda>x. x)"
+by (rule o_def)
+
 lemma mem_UN_compreh_eq: "(z : \<Union>{y. \<exists>x\<in>A. y = F x}) = (\<exists>x\<in>A. z : F x)"
 by blast
 
--- a/src/HOL/BNF/Examples/Infinite_Derivation_Trees/Tree.thy	Tue Oct 02 01:00:18 2012 +0200
+++ b/src/HOL/BNF/Examples/Infinite_Derivation_Trees/Tree.thy	Tue Oct 02 01:00:18 2012 +0200
@@ -69,7 +69,7 @@
 definition "Node n as \<equiv> NNode n (the_inv fset as)"
 definition "cont \<equiv> fset o ccont"
 definition "unfold rt ct \<equiv> Tree_unfold rt (the_inv fset o ct)"
-definition "corec rt ct \<equiv> Tree_corec rt (the_inv fset o ct)"
+definition "corec rt qt ct dt \<equiv> Tree_corec rt qt (the_inv fset o ct) (the_inv fset o dt)"
 
 definition lift ("_ ^#" 200) where
 "lift \<phi> as \<longleftrightarrow> (\<forall> tr. Inr tr \<in> as \<longrightarrow> \<phi> tr)"
@@ -179,9 +179,11 @@
 by (metis (no_types) fset_to_fset map_fset_image)
 
 theorem corec:
-"root (corec rt ct b) = rt b"
-"finite (ct b) \<Longrightarrow> cont (corec rt ct b) = image (id \<oplus> ([[id, corec rt ct]])) (ct b)"
-using Tree.sel_corec[of rt "the_inv fset \<circ> ct" b] unfolding corec_def
+"root (corec rt qt ct dt b) = rt b"
+"\<lbrakk>finite (ct b); finite (dt b)\<rbrakk> \<Longrightarrow>
+ cont (corec rt qt ct dt b) =
+ (if qt b then ct b else image (id \<oplus> corec rt qt ct dt) (dt b))"
+using Tree.sel_corec[of rt qt "the_inv fset \<circ> ct" "the_inv fset \<circ> dt" b] unfolding corec_def
 apply -
 apply simp
 unfolding cont_def comp_def id_def
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Oct 02 01:00:18 2012 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Oct 02 01:00:18 2012 +0200
@@ -113,11 +113,6 @@
     Type (_, Ts) => map (not o member (op =) (deads_of_bnf bnf)) Ts
   | _ => replicate n false);
 
-fun tack z_name (c, u) f =
-  let val z = Free (z_name, mk_sumT (fastype_of u, fastype_of c)) in
-    Term.lambda z (mk_sum_case (Term.lambda u u, Term.lambda c (f $ c)) $ z)
-  end;
-
 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
 
 fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
@@ -277,6 +272,7 @@
     val pre_map_defs = map map_def_of_bnf pre_bnfs;
     val pre_set_defss = map set_defs_of_bnf pre_bnfs;
     val pre_rel_defs = map rel_def_of_bnf pre_bnfs;
+    val nested_map_comps'' = map ((fn thm => thm RS sym) o map_comp_of_bnf) nested_bnfs;
     val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
     val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
     val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs;
@@ -312,8 +308,8 @@
     val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of any_fp_rec)));
 
     val (((fold_only as (gss, _, _), rec_only as (hss, _, _)),
-          (zs, cs, cpss, unfold_only as ((pgss, crgsss), _), corec_only as ((phss, cshsss), _))),
-         names_lthy0) =
+          (cs, cpss, unfold_only as ((pgss, crssss, cgssss), (_, g_Tsss, _)),
+           corec_only as ((phss, csssss, chssss), (_, h_Tsss, _)))), names_lthy0) =
       if lfp then
         let
           val y_Tsss =
@@ -344,7 +340,7 @@
           val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
         in
           ((((gss, g_Tss, ysss), (hss, h_Tss, zsss)),
-            ([], [], [], (([], []), ([], [])), (([], []), ([], [])))), lthy)
+            ([], [], (([], [], []), ([], [], [])), (([], [], []), ([], [], [])))), lthy)
         end
       else
         let
@@ -373,10 +369,9 @@
 
           val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
 
-          val ((((Free (z, _), cs), pss), gssss), lthy) =
+          val (((cs, pss), gssss), lthy) =
             lthy
-            |> yield_singleton (mk_Frees "z") dummyT
-            ||>> mk_Frees "a" Cs
+            |> mk_Frees "a" Cs
             ||>> mk_Freess "p" p_Tss
             ||>> mk_Freessss "g" g_Tssss;
           val rssss = map (map (map (fn [] => []))) r_Tssss;
@@ -401,32 +396,16 @@
 
           val cpss = map2 (map o rapp) cs pss;
 
-          fun build_sum_inj mk_inj (T, U) =
-            if T = U then
-              id_const T
-            else
-              (case (T, U) of
-                (Type (s, _), Type (s', _)) =>
-                if s = s' then build_map (build_sum_inj mk_inj) T U
-                else uncurry mk_inj (dest_sumT U)
-              | _ => uncurry mk_inj (dest_sumT U));
-
-          fun build_dtor_corec_arg _ [] [cf] = cf
-            | build_dtor_corec_arg T [cq] [cf, cf'] =
-              mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
-                (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
-
-          fun mk_terms f_Tsss qssss fssss =
+          fun mk_terms qssss fssss =
             let
               val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
               val cqssss = map2 (map o map o map o rapp) cs qssss;
               val cfssss = map2 (map o map o map o rapp) cs fssss;
-              val cqfsss = map3 (map3 (map3 build_dtor_corec_arg)) f_Tsss cqssss cfssss;
-            in (pfss, cqfsss) end;
+            in (pfss, cqssss, cfssss) end;
         in
           (((([], [], []), ([], [], [])),
-            ([z], cs, cpss, (mk_terms g_Tsss rssss gssss, (g_sum_prod_Ts, pg_Tss)),
-             (mk_terms h_Tsss sssss hssss, (h_sum_prod_Ts, ph_Tss)))), lthy)
+            (cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, g_Tsss, pg_Tss)),
+             (mk_terms sssss hssss, (h_sum_prod_Ts, h_Tsss, ph_Tss)))), lthy)
         end;
 
     fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
@@ -668,12 +647,30 @@
           let
             val B_to_fpT = C --> fpT;
 
+            fun build_sum_inj mk_inj (T, U) =
+              if T = U then
+                id_const T
+              else
+                (case (T, U) of
+                  (Type (s, _), Type (s', _)) =>
+                  if s = s' then build_map (build_sum_inj mk_inj) T U
+                  else uncurry mk_inj (dest_sumT U)
+                | _ => uncurry mk_inj (dest_sumT U));
+
+            fun build_dtor_corec_like_arg _ [] [cf] = cf
+              | build_dtor_corec_like_arg T [cq] [cf, cf'] =
+                mk_If cq (build_sum_inj Inl_const (fastype_of cf, T) $ cf)
+                  (build_sum_inj Inr_const (fastype_of cf', T) $ cf')
+
+            val crgsss = map3 (map3 (map3 build_dtor_corec_like_arg)) g_Tsss crssss cgssss;
+            val cshsss = map3 (map3 (map3 build_dtor_corec_like_arg)) h_Tsss csssss chssss;
+
             fun mk_preds_getterss_join c n cps sum_prod_T cqfss =
               Term.lambda c (mk_IfN sum_prod_T cps
                 (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cqfss) (1 upto n)));
 
-            fun generate_corec_like (suf, fp_rec_like, ((pfss, cqfsss), (f_sum_prod_Ts,
-                pf_Tss))) =
+            fun generate_corec_like (suf, fp_rec_like, (cqfsss, ((pfss, _, _), (f_sum_prod_Ts, _,
+                pf_Tss)))) =
               let
                 val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
                 val binding = qualify false fp_b_name (Binding.suffix_name ("_" ^ suf) fp_b);
@@ -684,8 +681,8 @@
               in (binding, spec) end;
 
             val corec_like_infos =
-              [(unfoldN, fp_fold, unfold_only),
-               (corecN, fp_rec, corec_only)];
+              [(unfoldN, fp_fold, (crgsss, unfold_only)),
+               (corecN, fp_rec, (cshsss, corec_only))];
 
             val (bindings, specs) = map generate_corec_like corec_like_infos |> split_list;
 
@@ -919,8 +916,7 @@
             fun build_rel rs' T =
               (case find_index (curry (op =) T) fpTs of
                 ~1 =>
-                if exists_fp_subtype T then build_rel_step (build_rel rs') T
-                else HOLogic.eq_const T
+                if exists_fp_subtype T then build_rel_step (build_rel rs') T else HOLogic.eq_const T
               | kk => nth rs' kk);
 
             fun build_rel_app rs' usel vsel =
@@ -974,7 +970,6 @@
 
         fun mk_maybe_not pos = not pos ? HOLogic.mk_not;
 
-        val z = the_single zs;
         val gunfolds = map (lists_bmoc pgss) unfolds;
         val hcorecs = map (lists_bmoc phss) corecs;
 
@@ -985,58 +980,66 @@
                 (Logic.list_implies (seq_conds (HOLogic.mk_Trueprop oo mk_maybe_not) n k cps,
                    mk_Trueprop_eq (fcorec_like $ c, Term.list_comb (ctr, take m cfs'))));
 
-            fun build_corec_like fcorec_likes maybe_tack (T, U) =
+            fun build_corec_like fcorec_likes (T, U) =
               if T = U then
                 id_const T
               else
                 (case find_index (curry (op =) U) fpTs of
-                  ~1 => build_map (build_corec_like fcorec_likes maybe_tack) T U
-                | kk => maybe_tack (nth cs kk, nth us kk) (nth fcorec_likes kk));
+                  ~1 => build_map (build_corec_like fcorec_likes) T U
+                | kk => nth fcorec_likes kk);
+
+            val mk_U = typ_subst (map2 pair Cs fpTs);
 
-            fun mk_U maybe_mk_sumT =
-              typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
+            fun intr_corec_likes fcorec_likes [] [cf] =
+                let val T = fastype_of cf in
+                  if exists_Cs_subtype T then build_corec_like fcorec_likes (T, mk_U T) $ cf else cf
+                end
+              | intr_corec_likes fcorec_likes [cq] [cf, cf'] =
+                mk_If cq (intr_corec_likes fcorec_likes [] [cf])
+                  (intr_corec_likes fcorec_likes [] [cf']);
+
+            val crgsss = map2 (map2 (map2 (intr_corec_likes gunfolds))) crssss cgssss;
+            val cshsss = map2 (map2 (map2 (intr_corec_likes hcorecs))) csssss chssss;
 
-            fun intr_corec_likes fcorec_likes maybe_mk_sumT maybe_tack cqf =
-              let val T = fastype_of cqf in
-                if exists_Cs_subtype T then
-                  build_corec_like fcorec_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
-                else
-                  cqf
+            val unfold_goalss =
+              map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss;
+            val corec_goalss =
+              map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss;
+
+            fun mk_map_if_distrib bnf =
+              let
+                val mapx = map_of_bnf bnf;
+                val live = live_of_bnf bnf;
+                val ((Ts, T), U) = strip_typeN (live + 1) (fastype_of mapx) |>> split_last;
+                val fs = Variable.variant_frees lthy [mapx] (map (pair "f") Ts);
+                val t = Term.list_comb (mapx, map (Var o apfst (rpair 0)) fs);
+              in
+                Drule.instantiate' (map (SOME o certifyT lthy) [U, T]) [SOME (certify lthy t)]
+                  @{thm if_distrib}
               end;
 
-            val crgsss' = map (map (map (intr_corec_likes gunfolds (K I) (K I)))) crgsss;
-            val cshsss' =
-              map (map (map (intr_corec_likes hcorecs (curry mk_sumT) (tack z)))) cshsss;
-
-            val unfold_goalss =
-              map8 (map4 oooo mk_goal pgss) cs cpss gunfolds ns kss ctrss mss crgsss';
-            val corec_goalss =
-              map8 (map4 oooo mk_goal phss) cs cpss hcorecs ns kss ctrss mss cshsss';
+            val nested_map_if_distribs = map mk_map_if_distrib nested_bnfs;
 
             val unfold_tacss =
-              map3 (map oo mk_corec_like_tac unfold_defs nesting_map_ids'') fp_fold_thms
-                pre_map_defs ctr_defss;
+              map3 (map oo mk_corec_like_tac unfold_defs [] [] nesting_map_ids'' [])
+                fp_fold_thms pre_map_defs ctr_defss;
             val corec_tacss =
-              map3 (map oo mk_corec_like_tac corec_defs nesting_map_ids'') fp_rec_thms pre_map_defs
-                ctr_defss;
+              map3 (map oo mk_corec_like_tac corec_defs nested_map_comps'' nested_map_comp's
+                  (nested_map_ids'' @ nesting_map_ids'') nested_map_if_distribs)
+                fp_rec_thms pre_map_defs ctr_defss;
 
             fun prove goal tac =
               Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation;
 
             val unfold_thmss = map2 (map2 prove) unfold_goalss unfold_tacss;
-            val corec_thmss =
-              map2 (map2 prove) corec_goalss corec_tacss
-              |> map (map (unfold_thms lthy @{thms sum_case_if}));
-
-            val unfold_safesss = map2 (map2 (map2 (curry (op =)))) crgsss' crgsss;
-            val corec_safesss = map2 (map2 (map2 (curry (op =)))) cshsss' cshsss;
+            val corec_thmss = map2 (map2 prove) corec_goalss corec_tacss;
 
             val filter_safesss =
               map2 (map_filter (fn (safes, thm) => if forall I safes then SOME thm else NONE) oo
-                curry (op ~~));
+                curry (op ~~)) (map2 (map2 (map2 (member (op =)))) cgssss crgsss);
 
-            val safe_unfold_thmss = filter_safesss unfold_safesss unfold_thmss;
-            val safe_corec_thmss = filter_safesss corec_safesss corec_thmss;
+            val safe_unfold_thmss = filter_safesss unfold_thmss;
+            val safe_corec_thmss = filter_safesss corec_thmss;
           in
             (unfold_thmss, corec_thmss, safe_unfold_thmss, safe_corec_thmss)
           end;
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML	Tue Oct 02 01:00:18 2012 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML	Tue Oct 02 01:00:18 2012 +0200
@@ -14,7 +14,8 @@
   val mk_case_tac: Proof.context -> int -> int -> int -> thm -> thm -> thm -> tactic
   val mk_coinduct_tac: Proof.context -> thm list -> int -> int list -> thm -> thm list ->
     thm list -> thm list -> thm list list -> thm list list list -> thm list list list -> tactic
-  val mk_corec_like_tac: thm list -> thm list -> thm -> thm -> thm -> Proof.context -> tactic
+  val mk_corec_like_tac: thm list -> thm list -> thm list -> thm list -> thm list -> thm -> thm ->
+    thm -> Proof.context -> tactic
   val mk_ctor_iff_dtor_tac: Proof.context -> ctyp option list -> cterm -> cterm -> thm -> thm ->
     tactic
   val mk_disc_corec_like_iff_tac: thm list -> thm list -> thm list -> Proof.context -> tactic
@@ -37,7 +38,7 @@
 val basic_simp_thms = @{thms simp_thms(7,8,12,14,22,24)};
 val more_simp_thms = basic_simp_thms @ @{thms simp_thms(11,15,16,21)};
 
-val sum_prod_thms_map = @{thms id_apply map_pair_simp sum_map.simps prod.cases};
+val sum_prod_thms_map = @{thms id_apply map_pair_simp prod.cases sum.cases sum_map.simps};
 val sum_prod_thms_set0 =
   @{thms SUP_empty Sup_empty Sup_insert UN_insert Un_empty_left Un_empty_right Un_iff
       Union_Un_distrib collect_def[abs_def] image_def o_apply map_pair_simp
@@ -107,10 +108,16 @@
   unfold_thms_tac ctxt (ctr_def :: ctor_rec_like :: rec_like_defs @ pre_map_defs @ map_comp's @
     map_ids'' @ rec_like_unfold_thms) THEN rtac refl 1;
 
-fun mk_corec_like_tac corec_like_defs map_ids'' ctor_dtor_corec_like pre_map_def ctr_def ctxt =
+(*TODO: sum_case_if needed?*)
+val corec_like_unfold_thms =
+  @{thms id_def ident_o_ident sum_case_if sum_case_o_inj} @ sum_prod_thms_map;
+
+fun mk_corec_like_tac corec_like_defs map_comps'' map_comp's map_ids'' map_if_distribs
+    ctor_dtor_corec_like pre_map_def ctr_def ctxt =
   unfold_thms_tac ctxt (ctr_def :: corec_like_defs) THEN
   (rtac (ctor_dtor_corec_like RS trans) THEN' asm_simp_tac ss_if_True_False) 1 THEN_MAYBE
-  (unfold_thms_tac ctxt (pre_map_def :: @{thm id_def} :: sum_prod_thms_map @ map_ids'') THEN
+  (unfold_thms_tac ctxt (pre_map_def :: map_comp's @ map_comps'' @ map_ids'' @ map_if_distribs @
+    corec_like_unfold_thms) THEN
    (rtac refl ORELSE' rtac (@{thm unit_eq} RS arg_cong)) 1);
 
 fun mk_disc_corec_like_iff_tac case_splits' corec_likes discs ctxt =