define coiterators
authorblanchet
Sat, 08 Sep 2012 21:04:26 +0200
changeset 49210 656fb50d33f0
parent 49209 3c0deda51b32
child 49211 239a4fa29ddf
define coiterators
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_util.ML
src/HOL/Codatatype/Tools/bnf_wrap.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -154,31 +154,34 @@
     val ns = map length ctr_Tsss;
     val mss = map (map length) ctr_Tsss;
     val Css = map2 replicate ns Cs;
-    val Cs' = flat Css;
 
-    fun mk_iter_or_rec Ts Us c =
+    fun mk_iter_like Ts Us c =
       let
         val (binders, body) = strip_type (fastype_of c);
-        val (fst_binders, last_binder) = split_last binders;
-        val Type (_, Ts0) = if lfp then last_binder else body;
-        val Us0 = map (if lfp then body_type else domain_type) fst_binders;
+        val (f_Us, prebody) = split_last binders;
+        val Type (_, Ts0) = if lfp then prebody else body;
+        val Us0 = distinct (op =) (map (if lfp then body_type else domain_type) f_Us);
       in
         Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) c
       end;
 
-    val fp_iters as fp_iter1 :: _ = map (mk_iter_or_rec As Cs) fp_iters0;
-    val fp_recs as fp_rec1 :: _ = map (mk_iter_or_rec As Cs) fp_recs0;
+    val fp_iters as fp_iter1 :: _ = map (mk_iter_like As Cs) fp_iters0;
+    val fp_recs as fp_rec1 :: _ = map (mk_iter_like As Cs) fp_recs0;
+
+    val fp_iter_f_Ts = fst (split_last (binder_types (fastype_of fp_iter1)));
+    val fp_rec_f_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
 
     fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
         if member (op =) Cs U then Us else [T]
       | dest_rec_pair T = [T];
 
-    val ((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)) =
+    val (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)),
+         (cs, (qss, q_Tss, gsss, g_Tsss), ())) =
       if lfp then
         let
           val y_Tsss =
-            map3 (fn ms => fn n => map2 dest_tupleT ms o dest_sumTN n o domain_type) mss ns
-              (fst (split_last (binder_types (fastype_of fp_iter1))));
+            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN n o domain_type)
+              ns mss fp_iter_f_Ts;
           val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
 
           val ((gss, ysss), _) =
@@ -187,28 +190,48 @@
             ||>> mk_Freesss "x" y_Tsss;
 
           val z_Tssss =
-            map3 (fn ms => fn n => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n o domain_type) mss ns
-              (fst (split_last (binder_types (fastype_of fp_rec1))));
+            map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n
+              o domain_type) ns mss fp_rec_f_Ts;
           val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
 
           val hss = map2 (map2 retype_free) gss h_Tss;
           val (zssss, _) =
             lthy
             |> mk_Freessss "x" z_Tssss;
-        in ((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)) end
+        in
+          (((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)),
+           ([], ([], [], [], []), ()))
+        end
       else
-        (([], [], [], []), ([], [], [], [])); (* ### *)
+        let
+          val q_Tss =
+            map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns;
+          val g_Tsss =
+            map4 (fn C => fn n => fn ms => map2 (map (curry (op -->) C) oo dest_tupleT) ms o
+              dest_sumTN n o range_type) Cs ns mss fp_iter_f_Ts;
 
-    fun pour_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
-          unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss)
-        no_defs_lthy =
+          val (((c, qss), gsss), _) =
+            lthy
+            |> yield_singleton (mk_Frees "c") dummyT
+            ||>> mk_Freess "p" q_Tss
+            ||>> mk_Freesss "g" g_Tsss;
+
+          val cs = map (retype_free c) Cs;
+        in
+          ((([], [], [], []), ([], [], [], [])),
+           (cs, (qss, q_Tss, gsss, g_Tsss), ()))
+        end;
+
+    fun pour_some_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
+          unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders),
+          sel_binderss) no_defs_lthy =
       let
         val n = length ctr_Tss;
         val ks = 1 upto n;
         val ms = map length ctr_Tss;
 
         val unfT = domain_type (fastype_of fld);
-        val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
+        val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
         val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
 
         val ((((u, v), fs), xss), _) =
@@ -220,7 +243,7 @@
 
         val ctr_rhss =
           map2 (fn k => fn xs =>
-            fold_rev Term.lambda xs (fld $ mk_InN prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
+            fold_rev Term.lambda xs (fld $ mk_InN ctr_prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
 
         val case_binder = Binding.suffix_name ("_" ^ caseN) b;
 
@@ -261,7 +284,7 @@
 
             val sumEN_thm' =
               Local_Defs.unfold lthy @{thms all_unit_eq}
-                (Drule.instantiate' (map (SOME o certifyT lthy) prod_Ts) [] (mk_sumEN n))
+                (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) [] (mk_sumEN n))
               |> Morphism.thm phi;
           in
             mk_exhaust_tac ctxt n ctr_defs fld_iff_unf_thm sumEN_thm'
@@ -281,7 +304,7 @@
 
         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
 
-        fun sugar_datatype no_defs_lthy =
+        fun some_lfp_sugar no_defs_lthy =
           let
             val fpT_to_C = fpT --> C;
             val iter_T = fold_rev (curry (op --->)) g_Tss fpT_to_C;
@@ -315,17 +338,87 @@
             val iter0 = Morphism.term phi raw_iter;
             val rec0 = Morphism.term phi raw_rec;
 
-            val iter = mk_iter_or_rec As Cs' iter0;
-            val recx = mk_iter_or_rec As Cs' rec0;
+            val iter = mk_iter_like As Cs iter0;
+            val recx = mk_iter_like As Cs rec0;
           in
             ((ctrs, iter, recx, xss, ctr_defs, iter_def, rec_def), lthy)
           end;
 
-        fun sugar_codatatype no_defs_lthy =
-          (([], @{term True}, @{term True}, [], [], TrueI, TrueI), no_defs_lthy);
+        fun some_gfp_sugar no_defs_lthy =
+          let
+            (* qss, q_Tss, gsss, g_Tsss *)
+            fun zip_preds_and_getters p_Ts f_Tss = p_Ts @ flat f_Tss;
+
+            val qg_Tss = map2 zip_preds_and_getters q_Tss g_Tsss;
+
+            val B_to_fpT = C --> fpT;
+            val coiter_T = fold_rev (curry (op --->)) qg_Tss B_to_fpT;
+(*
+            val corec_T = fold_rev (curry (op --->)) h_Tss fpT_to_C;
+*)
+
+            val qgss = map2 zip_preds_and_getters qss gsss;
+            val cqss = map2 (fn c => map (fn q => q $ c)) cs qss;
+            val cgsss = map2 (fn c => map (map (fn g => g $ c))) cs gsss;
+
+            val coiter_binder = Binding.suffix_name ("_" ^ coiterN) b;
+            val corec_binder = Binding.suffix_name ("_" ^ corecN) b;
+
+            val coiter_free = Free (Binding.name_of coiter_binder, coiter_T);
+(*
+            val corec_free = Free (Binding.name_of corec_binder, corec_T);
+*)
+
+            val coiter_sum_prod_Ts = map range_type fp_iter_f_Ts;
+            val coiter_prod_Tss = map2 dest_sumTN ns coiter_sum_prod_Ts;
+
+            fun mk_join c n cqs sum_prod_T prod_Ts cgss =
+              Term.lambda c (mk_IfN sum_prod_T cqs
+                (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cgss) (1 upto n)));
+
+            val coiter_spec =
+              mk_Trueprop_eq (flat_list_comb (coiter_free, qgss),
+                Term.list_comb (fp_iter,
+                  map6 mk_join cs ns cqss coiter_sum_prod_Ts coiter_prod_Tss cgsss));
+(*
+            val corec_spec =
+              mk_Trueprop_eq (flat_list_comb (corec_free, hss),
+                Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) hss zssss));
+*)
+
+            val (([raw_coiter (*, raw_corec*)], [raw_coiter_def (*, raw_corec_def*)]), (lthy', lthy)) = no_defs_lthy
+              |> apfst split_list o fold_map2 (fn b => fn spec =>
+                Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
+                #>> apsnd snd) [coiter_binder (*, corec_binder*)] [coiter_spec (*, corec_spec*)]
+              ||> `Local_Theory.restore;
+
+            (*transforms defined frees into consts (and more)*)
+            val phi = Proof_Context.export_morphism lthy lthy';
+
+            val coiter_def = Morphism.thm phi raw_coiter_def;
+(*
+            val corec_def = Morphism.thm phi raw_corec_def;
+*)
+
+            val coiter0 = Morphism.term phi raw_coiter;
+(*
+            val corec0 = Morphism.term phi raw_corec;
+*)
+
+            val coiter = mk_iter_like As Cs coiter0;
+(*
+            val corec = mk_iter_like As Cs corec0;
+*)
+
+            (*###*)
+            val corec = @{term True};
+            val corec_def = TrueI;
+          in
+            ((ctrs, coiter, corec, xss, ctr_defs, coiter_def, corec_def), lthy)
+          end;
       in
         wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
-        |> (if lfp then sugar_datatype else sugar_codatatype)
+        |> (if lfp then some_lfp_sugar else some_gfp_sugar)
       end;
 
     fun pour_more_sugar_on_datatypes ((ctrss, iters, recs, xsss, ctr_defss, iter_defs, rec_defs),
@@ -377,9 +470,9 @@
       end;
 
     val lthy' = lthy
-      |> fold_map pour_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~ fp_recs ~~
-        fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~
-        disc_binderss ~~ sel_bindersss)
+      |> fold_map pour_some_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
+        fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~
+        ctr_Tsss ~~ disc_binderss ~~ sel_bindersss)
       |>> split_list7
       |> (if lfp then pour_more_sugar_on_datatypes else snd);
 
--- a/src/HOL/Codatatype/Tools/bnf_util.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_util.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -73,6 +73,7 @@
   val mk_Card_order: term -> term
   val mk_Field: term -> term
   val mk_Gr: term -> term -> term
+  val mk_IfN: typ -> term list -> term list -> term
   val mk_Trueprop_eq: term * term -> term
   val mk_UNION: term -> term -> term
   val mk_Union: typ -> term
@@ -302,6 +303,10 @@
 
 val mk_Trueprop_eq = HOLogic.mk_Trueprop o HOLogic.mk_eq;
 
+fun mk_IfN _ _ [t] = t
+  | mk_IfN T (c :: cs) (t :: ts) =
+    Const (@{const_name If}, HOLogic.boolT --> T --> T --> T) $ c $ t $ mk_IfN T cs ts;
+
 fun mk_converse R =
   let
     val RT = dest_relT (fastype_of R);
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -152,6 +152,7 @@
 
     val q = Free (fst p', B --> HOLogic.boolT);
 
+    fun ap_v t = t $ v;
     fun mk_v_eq_v () = HOLogic.mk_eq (v, v);
 
     val xctrs = map2 (curry Term.list_comb) ctrs xss;
@@ -404,8 +405,7 @@
               let
                 val prem = HOLogic.mk_Trueprop (betapply (disc, v));
                 val concl =
-                  mk_Trueprop_eq ((null sels ? swap)
-                    (Term.list_comb (ctr, map (fn sel => sel $ v) sels), v));
+                  mk_Trueprop_eq ((null sels ? swap) (Term.list_comb (ctr, map ap_v sels), v));
               in
                 if prem aconv concl then NONE
                 else SOME (Logic.all v (Logic.mk_implies (prem, concl)))
@@ -421,12 +421,9 @@
 
         val case_eq_thm =
           let
-            fun mk_core f sels = Term.list_comb (f, map (fn sel => sel $ v) sels);
-            fun mk_rhs _ [f] [sels] = mk_core f sels
-              | mk_rhs (disc :: discs) (f :: fs) (sels :: selss) =
-                Const (@{const_name If}, HOLogic.boolT --> B --> B --> B) $
-                  betapply (disc, v) $ mk_core f sels $ mk_rhs discs fs selss;
-            val goal = mk_Trueprop_eq (fcase $ v, mk_rhs discs fs selss);
+            fun mk_body f sels = Term.list_comb (f, map ap_v sels);
+            val goal =
+              mk_Trueprop_eq (fcase $ v, mk_IfN B (map ap_v discs) (map2 mk_body fs selss));
           in
             Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
               mk_case_eq_tac ctxt n exhaust_thm' case_thms disc_thmss' sel_thmss)