define corecursors
authorblanchet
Sat, 08 Sep 2012 21:04:26 +0200
changeset 49211 239a4fa29ddf
parent 49210 656fb50d33f0
child 49212 ca59649170b0
define corecursors
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -168,20 +168,20 @@
     val fp_iters as fp_iter1 :: _ = map (mk_iter_like As Cs) fp_iters0;
     val fp_recs as fp_rec1 :: _ = map (mk_iter_like As Cs) fp_recs0;
 
-    val fp_iter_f_Ts = fst (split_last (binder_types (fastype_of fp_iter1)));
-    val fp_rec_f_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
+    val fp_iter_g_Ts = fst (split_last (binder_types (fastype_of fp_iter1)));
+    val fp_rec_h_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
 
     fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
         if member (op =) Cs U then Us else [T]
       | dest_rec_pair T = [T];
 
     val (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)),
-         (cs, (qss, q_Tss, gsss, g_Tsss), ())) =
+         (cs, pss, p_Tss, coiter_extra, corec_extra)) =
       if lfp then
         let
           val y_Tsss =
             map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN n o domain_type)
-              ns mss fp_iter_f_Ts;
+              ns mss fp_iter_g_Ts;
           val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
 
           val ((gss, ysss), _) =
@@ -191,7 +191,7 @@
 
           val z_Tssss =
             map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n
-              o domain_type) ns mss fp_rec_f_Ts;
+              o domain_type) ns mss fp_rec_h_Ts;
           val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
 
           val hss = map2 (map2 retype_free) gss h_Tss;
@@ -200,26 +200,36 @@
             |> mk_Freessss "x" z_Tssss;
         in
           (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)),
-           ([], ([], [], [], []), ()))
+           ([], [], [], ([], [], [], []), ([], [], [], [])))
         end
       else
         let
-          val q_Tss =
+          fun mk_to_dest_prodT C = map2 (map (curry (op -->) C) oo dest_tupleT);
+
+          val p_Tss =
             map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns;
-          val g_Tsss =
-            map4 (fn C => fn n => fn ms => map2 (map (curry (op -->) C) oo dest_tupleT) ms o
-              dest_sumTN n o range_type) Cs ns mss fp_iter_f_Ts;
+
+          val g_sum_prod_Ts = map range_type fp_iter_g_Ts;
+          val g_prod_Tss = map2 dest_sumTN ns g_sum_prod_Ts;
+          val g_Tsss = map3 mk_to_dest_prodT Cs mss g_prod_Tss;
 
-          val (((c, qss), gsss), _) =
+          val h_sum_prod_Ts = map range_type fp_rec_h_Ts;
+          val h_prod_Tss = map2 dest_sumTN ns h_sum_prod_Ts;
+          val h_Tsss = map3 mk_to_dest_prodT Cs mss h_prod_Tss;
+
+          val (((c, pss), gsss), _) =
             lthy
             |> yield_singleton (mk_Frees "c") dummyT
-            ||>> mk_Freess "p" q_Tss
+            ||>> mk_Freess "p" p_Tss
             ||>> mk_Freesss "g" g_Tsss;
 
+          val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss;
+
           val cs = map (retype_free c) Cs;
         in
           ((([], [], [], []), ([], [], [], [])),
-           (cs, (qss, q_Tss, gsss, g_Tsss), ()))
+           (cs, pss, p_Tss, (gsss, g_sum_prod_Ts, g_prod_Tss, g_Tsss),
+            (hsss, h_sum_prod_Ts, h_prod_Tss, h_Tsss)))
         end;
 
     fun pour_some_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
@@ -313,14 +323,11 @@
             val iter_binder = Binding.suffix_name ("_" ^ iterN) b;
             val rec_binder = Binding.suffix_name ("_" ^ recN) b;
 
-            val iter_free = Free (Binding.name_of iter_binder, iter_T);
-            val rec_free = Free (Binding.name_of rec_binder, rec_T);
-
             val iter_spec =
-              mk_Trueprop_eq (flat_list_comb (iter_free, gss),
+              mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of iter_binder, iter_T), gss),
                 Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
             val rec_spec =
-              mk_Trueprop_eq (flat_list_comb (rec_free, hss),
+              mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of rec_binder, rec_T), hss),
                 Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) hss zssss));
 
             val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy
@@ -346,73 +353,48 @@
 
         fun some_gfp_sugar no_defs_lthy =
           let
-            (* qss, q_Tss, gsss, g_Tsss *)
-            fun zip_preds_and_getters p_Ts f_Tss = p_Ts @ flat f_Tss;
-
-            val qg_Tss = map2 zip_preds_and_getters q_Tss g_Tsss;
+            fun zip_preds_and_getters ps fss = ps @ flat fss;
 
             val B_to_fpT = C --> fpT;
-            val coiter_T = fold_rev (curry (op --->)) qg_Tss B_to_fpT;
-(*
-            val corec_T = fold_rev (curry (op --->)) h_Tss fpT_to_C;
-*)
+
+            val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
 
-            val qgss = map2 zip_preds_and_getters qss gsss;
-            val cqss = map2 (fn c => map (fn q => q $ c)) cs qss;
-            val cgsss = map2 (fn c => map (map (fn g => g $ c))) cs gsss;
+            fun generate_coiter_like (suf, fp_iter_like,
+                (fsss, f_sum_prod_Ts, f_prod_Tss, f_Tsss)) =
+              let
+                val pf_Tss = map2 zip_preds_and_getters p_Tss f_Tsss;
+                val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
 
-            val coiter_binder = Binding.suffix_name ("_" ^ coiterN) b;
-            val corec_binder = Binding.suffix_name ("_" ^ corecN) b;
+                val pfss = map2 zip_preds_and_getters pss fsss;
+                val cfsss = map2 (fn c => map (map (fn f => f $ c))) cs fsss;
 
-            val coiter_free = Free (Binding.name_of coiter_binder, coiter_T);
-(*
-            val corec_free = Free (Binding.name_of corec_binder, corec_T);
-*)
+                val binder = Binding.suffix_name ("_" ^ suf) b;
 
-            val coiter_sum_prod_Ts = map range_type fp_iter_f_Ts;
-            val coiter_prod_Tss = map2 dest_sumTN ns coiter_sum_prod_Ts;
-
-            fun mk_join c n cqs sum_prod_T prod_Ts cgss =
-              Term.lambda c (mk_IfN sum_prod_T cqs
-                (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cgss) (1 upto n)));
+                fun mk_join c n cps sum_prod_T prod_Ts cfss =
+                  Term.lambda c (mk_IfN sum_prod_T cps
+                    (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cfss) (1 upto n)));
 
-            val coiter_spec =
-              mk_Trueprop_eq (flat_list_comb (coiter_free, qgss),
-                Term.list_comb (fp_iter,
-                  map6 mk_join cs ns cqss coiter_sum_prod_Ts coiter_prod_Tss cgsss));
-(*
-            val corec_spec =
-              mk_Trueprop_eq (flat_list_comb (corec_free, hss),
-                Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) hss zssss));
-*)
+                val spec =
+                  mk_Trueprop_eq (flat_list_comb (Free (Binding.name_of binder, res_T), pfss),
+                    Term.list_comb (fp_iter_like,
+                      map6 mk_join cs ns cpss f_sum_prod_Ts f_prod_Tss cfsss));
+              in (binder, spec) end;
 
-            val (([raw_coiter (*, raw_corec*)], [raw_coiter_def (*, raw_corec_def*)]), (lthy', lthy)) = no_defs_lthy
+            val coiter_likes = [(coiterN, fp_iter, coiter_extra), (corecN, fp_rec, corec_extra)];
+            val (binders, specs) = map generate_coiter_like coiter_likes |> split_list;
+
+            val ((csts, defs), (lthy', lthy)) = no_defs_lthy
               |> apfst split_list o fold_map2 (fn b => fn spec =>
                 Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
-                #>> apsnd snd) [coiter_binder (*, corec_binder*)] [coiter_spec (*, corec_spec*)]
+                #>> apsnd snd) binders specs
               ||> `Local_Theory.restore;
 
             (*transforms defined frees into consts (and more)*)
             val phi = Proof_Context.export_morphism lthy lthy';
 
-            val coiter_def = Morphism.thm phi raw_coiter_def;
-(*
-            val corec_def = Morphism.thm phi raw_corec_def;
-*)
+            val [coiter_def, corec_def] = map (Morphism.thm phi) defs;
 
-            val coiter0 = Morphism.term phi raw_coiter;
-(*
-            val corec0 = Morphism.term phi raw_corec;
-*)
-
-            val coiter = mk_iter_like As Cs coiter0;
-(*
-            val corec = mk_iter_like As Cs corec0;
-*)
-
-            (*###*)
-            val corec = @{term True};
-            val corec_def = TrueI;
+            val [coiter, corec] = map (mk_iter_like As Cs o Morphism.term phi) csts;
           in
             ((ctrs, coiter, corec, xss, ctr_defs, coiter_def, corec_def), lthy)
           end;
@@ -430,7 +412,7 @@
 
         val (iter_thmss, rec_thmss) =
           let
-            fun mk_goal_iter_or_rec fss fc xctr f xs xs' =
+            fun mk_goal_iter_like fss fc xctr f xs xs' =
               fold_rev (fold_rev Logic.all) (xs :: fss)
                 (mk_Trueprop_eq (fc $ xctr, Term.list_comb (f, xs')));
 
@@ -442,15 +424,13 @@
             val iter_xsss = map (map (map fix_iter_free)) xsss;
             val rec_xsss = map (map (maps fix_rec_free)) xsss;
 
-            val goal_iterss =
-              map5 (map4 o mk_goal_iter_or_rec gss) giters xctrss gss xsss iter_xsss;
-            val goal_recss =
-              map5 (map4 o mk_goal_iter_or_rec hss) hrecs xctrss hss xsss rec_xsss;
+            val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss iter_xsss;
+            val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss rec_xsss;
 
             val iter_tacss =
-              map2 (map o mk_iter_or_rec_tac pre_map_defs iter_defs) fp_iter_thms ctr_defss;
+              map2 (map o mk_iter_like_tac pre_map_defs iter_defs) fp_iter_thms ctr_defss;
             val rec_tacss =
-              map2 (map o mk_iter_or_rec_tac pre_map_defs rec_defs) fp_rec_thms ctr_defss;
+              map2 (map o mk_iter_like_tac pre_map_defs rec_defs) fp_rec_thms ctr_defss;
           in
             (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
                goal_iterss iter_tacss,
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar_tactics.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -13,7 +13,7 @@
     -> tactic
   val mk_half_distinct_tac: Proof.context -> thm -> thm list -> tactic
   val mk_inject_tac: Proof.context -> thm -> thm -> tactic
-  val mk_iter_or_rec_tac: thm list -> thm list -> thm -> thm -> Proof.context -> tactic
+  val mk_iter_like_tac: thm list -> thm list -> thm -> thm -> Proof.context -> tactic
 end;
 
 structure BNF_FP_Sugar_Tactics : BNF_FP_SUGAR_TACTICS =
@@ -49,11 +49,11 @@
   Local_Defs.unfold_tac ctxt [ctr_def] THEN rtac (fld_inject RS ssubst) 1 THEN
   Local_Defs.unfold_tac ctxt @{thms sum.inject Pair_eq conj_assoc} THEN rtac refl 1;
 
-val iter_or_rec_thms =
+val iter_like_thms =
   @{thms sum_map.simps sum.simps(5,6) convol_def case_unit map_pair_def split_conv id_def};
 
-fun mk_iter_or_rec_tac iter_or_rec_defs fld_iter_or_recs ctr_def pre_map_def ctxt =
-  Local_Defs.unfold_tac ctxt (ctr_def :: pre_map_def :: iter_or_rec_defs @ fld_iter_or_recs) THEN
-  Local_Defs.unfold_tac ctxt iter_or_rec_thms THEN rtac refl 1;
+fun mk_iter_like_tac iter_like_defs fld_iter_likes ctr_def pre_map_def ctxt =
+  Local_Defs.unfold_tac ctxt (ctr_def :: pre_map_def :: iter_like_defs @ fld_iter_likes) THEN
+  Local_Defs.unfold_tac ctxt iter_like_thms THEN rtac refl 1;
 
 end;