tuning
authorblanchet
Sat, 08 Sep 2012 21:04:26 +0200
changeset 49201 c69c2c18dccb
parent 49200 73f9aede57a4
child 49202 f493cd25737f
tuning
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_wrap.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -108,18 +108,18 @@
       |> mk_TFrees N
       ||>> mk_TFrees N;
 
-    fun is_same_recT (T as Type (s, Us)) (Type (s', Us')) =
+    fun is_same_fpT (T as Type (s, Us)) (Type (s', Us')) =
         s = s' andalso (Us = Us' orelse error ("Illegal occurrence of recursive type " ^
           quote (Syntax.string_of_typ fake_lthy T)))
-      | is_same_recT _ _ = false;
+      | is_same_fpT _ _ = false;
 
-    fun freeze_recXs (T as Type (s, Us)) =
-        (case find_index (is_same_recT T) fake_Ts of
-          ~1 => Type (s, map freeze_recXs Us)
+    fun freeze_fpXs (T as Type (s, Us)) =
+        (case find_index (is_same_fpT T) fake_Ts of
+          ~1 => Type (s, map freeze_fpXs Us)
         | i => nth Xs i)
-      | freeze_recXs T = T;
+      | freeze_fpXs T = T;
 
-    val ctr_TsssXs = map (map (map freeze_recXs)) fake_ctr_Tsss;
+    val ctr_TsssXs = map (map (map freeze_fpXs)) fake_ctr_Tsss;
     val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs;
 
     val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
@@ -140,22 +140,22 @@
     val unfs = map (mk_unf As) raw_unfs;
     val flds = map (mk_fld As) raw_flds;
 
-    val fp_Ts = map (domain_type o fastype_of) unfs;
-    val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fp_Ts)))) ctr_TsssXs;
+    val fpTs = map (domain_type o fastype_of) unfs;
+    val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
 
-    fun mk_fp_iter_or_rec Ts Us t =
+    fun mk_fp_iter_or_rec Ts Us c =
       let
-        val (binders, body) = strip_type (fastype_of t);
+        val (binders, body) = strip_type (fastype_of c);
         val Type (_, Ts0) = if gfp then body else List.last binders;
         val Us0 = map (if gfp then domain_type else body_type) (fst (split_last binders));
       in
-        Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+        Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) c
       end;
 
     val fp_iters = map (mk_fp_iter_or_rec As Cs) raw_fp_iters;
     val fp_recs = map (mk_fp_iter_or_rec As Cs) raw_fp_recs;
 
-    fun pour_sugar_on_type ((((((((((((((b, fp_T), C), fld), unf), fp_iter), fp_rec), fld_unf),
+    fun pour_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
           unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss)
         no_defs_lthy =
       let
@@ -163,14 +163,14 @@
         val ks = 1 upto n;
         val ms = map length ctr_Tss;
 
-        val unf_T = domain_type (fastype_of fld);
+        val unfT = domain_type (fastype_of fld);
         val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
         val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
 
         val ((((u, v), fs), xss), _) =
           lthy
-          |> yield_singleton (mk_Frees "u") unf_T
-          ||>> yield_singleton (mk_Frees "v") fp_T
+          |> yield_singleton (mk_Frees "u") unfT
+          ||>> yield_singleton (mk_Frees "v") fpT
           ||>> mk_Frees "f" case_Ts
           ||>> mk_Freess "x" ctr_Tss;
 
@@ -183,12 +183,10 @@
         val case_rhs =
           fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (map2 mk_uncurried_fun fs xss) $ (unf $ v));
 
-        val (((raw_ctrs, raw_ctr_defs), (raw_case, raw_case_def)), (lthy', lthy)) = no_defs_lthy
+        val ((raw_case :: raw_ctrs, raw_case_def :: raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
           |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
                Local_Theory.define ((b, mx), ((Thm.def_binding b, []), rhs)) #>> apsnd snd)
-             ctr_binders ctr_mixfixes ctr_rhss
-          ||>> (Local_Theory.define ((case_binder, NoSyn), ((Thm.def_binding case_binder, []),
-             case_rhs)) #>> apsnd snd)
+             (case_binder :: ctr_binders) (NoSyn :: ctr_mixfixes) (case_rhs :: ctr_rhss)
           ||> `Local_Theory.restore;
 
         (*transforms defined frees into consts (and more)*)
@@ -209,7 +207,7 @@
                     (mk_Trueprop_eq (HOLogic.mk_eq (v, fld $ u), HOLogic.mk_eq (unf $ v, u)));
               in
                 Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
-                  mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, fp_T])
+                  mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unfT, fpT])
                     (certify lthy fld) (certify lthy unf) fld_unf unf_fld)
                 |> Thm.close_derivation
                 |> Morphism.thm phi
@@ -241,7 +239,7 @@
 
         (* (co)iterators, (co)recursors, (co)induction *)
 
-        val is_recT = member (op =) fp_Ts;
+        val is_fpT = member (op =) fpTs;
 
         val ns = map length ctr_Tsss;
         val mss = map (map length) ctr_Tsss;
@@ -257,14 +255,14 @@
             val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts;
             val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss;
             val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
-            val iter_T = flat g_Tss ---> fp_T --> C;
+            val iter_T = flat g_Tss ---> fpT --> C;
 
             val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec))));
             val z_prod_Tss = map2 dest_sumTN ns fp_z_Ts;
             val z_Tsss = map2 (map2 dest_tupleT) mss z_prod_Tss;
             val z_Tssss = map (map (map dest_rec_pair)) z_Tsss;
             val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
-            val rec_T = flat h_Tss ---> fp_T --> C;
+            val rec_T = flat h_Tss ---> fpT --> C;
 
             val ((gss, ysss), _) =
               no_defs_lthy
@@ -291,24 +289,112 @@
                   map2 (mk_sum_caseN oo map2 mk_doubly_uncurried_fun) hss zssss));
 
             val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy
-              |> apfst split_list o fold_map (fn (b, spec) =>
+              |> apfst split_list o fold_map2 (fn b => fn spec =>
                 Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
-                #>> apsnd snd) [(iter_binder, iter_spec), (rec_binder, rec_spec)]
+                #>> apsnd snd) [iter_binder, rec_binder] [iter_spec, rec_spec]
               ||> `Local_Theory.restore;
+
+            (*transforms defined frees into consts (and more)*)
+            val phi = Proof_Context.export_morphism lthy lthy';
+
+            val iter_def = Morphism.thm phi raw_iter_def;
+            val rec_def = Morphism.thm phi raw_rec_def;
+
+            val iter = Morphism.term phi raw_iter;
+            val recx = Morphism.term phi raw_rec;
           in
-            lthy
+            ((iter, recx), lthy)
           end;
 
-        fun sugar_codatatype no_defs_lthy = no_defs_lthy;
+        fun sugar_codatatype no_defs_lthy =
+          let
+(*###
+            val fp_y_Ts = map range_type (fst (split_last (binder_types (fastype_of fp_iter))));
+            val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts;
+            val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss;
+            val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
+            val coiter_T = flat g_Tss ---> fpT --> C;
+
+            val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec))));
+            val z_prod_Tss = map2 dest_sumTN ns fp_z_Ts;
+            val z_Tsss = map2 (map2 dest_tupleT) mss z_prod_Tss;
+            val z_Tssss = map (map (map dest_rec_pair)) z_Tsss;
+            val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
+            val corec_T = flat h_Tss ---> fpT --> C;
+
+            val ((gss, ysss), _) =
+              no_defs_lthy
+              |> mk_Freess "f" g_Tss
+              ||>> mk_Freesss "x" y_Tsss;
+
+            val hss = map2 (map2 retype_free) gss h_Tss;
+            val (zssss, _) =
+              no_defs_lthy
+              |> mk_Freessss "x" z_Tssss;
+
+            val coiter_binder = Binding.suffix_name ("_" ^ coiterN) b;
+            val corec_binder = Binding.suffix_name ("_" ^ corecN) b;
+
+            val coiter_free = Free (Binding.name_of coiter_binder, coiter_T);
+            val corec_free = Free (Binding.name_of corec_binder, corec_T);
+
+            val coiter_spec =
+              mk_Trueprop_eq (fold (fn gs => fn t => Term.list_comb (t, gs)) gss coiter_free,
+                Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
+            val corec_spec =
+              mk_Trueprop_eq (fold (fn hs => fn t => Term.list_comb (t, hs)) hss corec_free,
+                Term.list_comb (fp_rec,
+                  map2 (mk_sum_caseN oo map2 mk_doubly_uncurried_fun) hss zssss));
+
+            val (([raw_coiter, raw_corec], [raw_coiter_def, raw_corec_def]), (lthy', lthy)) =
+              no_defs_lthy
+              |> apfst split_list o fold_map2 (fn b => fn spec =>
+                Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
+                #>> apsnd snd) [coiter_binder, corec_binder] [coiter_spec, corec_spec]
+              ||> `Local_Theory.restore;
+
+            (*transforms defined frees into consts (and more)*)
+            val phi = Proof_Context.export_morphism lthy lthy';
+
+            val coiter_def = Morphism.thm phi raw_coiter_def;
+            val corec_def = Morphism.thm phi raw_corec_def;
+
+            val coiter = Morphism.term phi raw_coiter;
+            val corec = Morphism.term phi raw_corec;
+*)
+            val coiter = @{term True}; (*###*)
+            val corec = @{term True}; (*###*)
+          in
+            ((coiter, corec), lthy)
+          end;
       in
         wrap_datatype tacss ((ctrs, casex), (disc_binders, sel_binderss)) lthy'
         |> (if gfp then sugar_codatatype else sugar_datatype)
       end;
 
-    val lthy'' =
-      fold pour_sugar_on_type (bs ~~ fp_Ts ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~ fp_recs ~~
+(* ###
+            val (iter_thms, rec_thms) =
+              let
+                fun mk_goal_iter_or_rec fc xctr f xs =
+                  mk_Trueprop_eq (fc $ xctr, Term.list_comb (f, ));
+
+                val giter = Term.list_comb (iter, gs);
+                val hrec = Term.list_comb (rec, hs);
+
+                val goal_iters = map2 (mk_goal_iter_or_rec iter) gss xctrs;
+                val goal_recs = map2 (mk_goal_iter_or_rec recx) hss xctrs;
+                val iter_tacs = [];
+                val rec_tacs = [];
+              in
+                (map2 (Skip_Proof.prove lthy [] []) goal_iters iter_tacs,
+                 map2 (Skip_Proof.prove lthy [] []) goal_recs rec_tacs)
+              end;
+*)
+
+    val ((iters, recs), lthy'') =
+      fold_map pour_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~ fp_recs ~~
         fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~
-        disc_binderss ~~ sel_bindersss) lthy';
+        disc_binderss ~~ sel_bindersss) lthy' |>> split_list;
 
     val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
       (if gfp then "co" else "") ^ "datatype"));
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -137,16 +137,14 @@
         val Type (_, Ts0) = List.last binders
       in Term.subst_atomic_types ((body, T) :: (Ts0 ~~ Ts)) case0 end;
 
-    val caseB = mk_case As B;
-    val caseB_Ts = map (fn Ts => Ts ---> B) ctr_Tss;
-
-    fun mk_caseB_term eta_fs = Term.list_comb (caseB, eta_fs);
+    val casex = mk_case As B;
+    val case_Ts = map (fn Ts => Ts ---> B) ctr_Tss;
 
     val (((((((xss, yss), fs), gs), (v, v')), w), (p, p')), names_lthy) = no_defs_lthy |>
       mk_Freess "x" ctr_Tss
       ||>> mk_Freess "y" ctr_Tss
-      ||>> mk_Frees "f" caseB_Ts
-      ||>> mk_Frees "g" caseB_Ts
+      ||>> mk_Frees "f" case_Ts
+      ||>> mk_Frees "g" case_Ts
       ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "v") T
       ||>> yield_singleton (mk_Frees "w") T
       ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "P") HOLogic.boolT;
@@ -164,7 +162,8 @@
     val eta_fs = map2 eta_expand_case_arg xss xfs;
     val eta_gs = map2 eta_expand_case_arg xss xgs;
 
-    val caseB_fs = Term.list_comb (caseB, eta_fs);
+    val fcase = Term.list_comb (casex, eta_fs);
+    val gcase = Term.list_comb (casex, eta_gs);
 
     val exist_xs_v_eq_ctrs =
       map2 (fn xctr => fn xs => list_exists_free xs (HOLogic.mk_eq (v, xctr))) xctrs xss;
@@ -255,7 +254,7 @@
 
     val goal_cases =
       map3 (fn xs => fn xctr => fn xf =>
-        fold_rev Logic.all (fs @ xs) (mk_Trueprop_eq (caseB_fs $ xctr, xf))) xss xctrs xfs;
+        fold_rev Logic.all (fs @ xs) (mk_Trueprop_eq (fcase $ xctr, xf))) xss xctrs xfs;
 
     val goalss = [goal_exhaust] :: goal_injectss @ goal_half_distinctss @ [goal_cases];
 
@@ -426,7 +425,7 @@
               | mk_rhs (disc :: discs) (f :: fs) (sels :: selss) =
                 Const (@{const_name If}, HOLogic.boolT --> B --> B --> B) $
                   betapply (disc, v) $ mk_core f sels $ mk_rhs discs fs selss;
-            val goal = mk_Trueprop_eq (caseB_fs $ v, mk_rhs discs fs selss);
+            val goal = mk_Trueprop_eq (fcase $ v, mk_rhs discs fs selss);
           in
             Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
               mk_case_eq_tac ctxt n exhaust_thm' case_thms disc_thmss' sel_thmss)
@@ -440,13 +439,11 @@
                 mk_Trueprop_eq (f, g)));
 
             val v_eq_w = mk_Trueprop_eq (v, w);
-            val case_fs = mk_caseB_term eta_fs;
-            val case_gs = mk_caseB_term eta_gs;
 
             val goal =
               Logic.list_implies (v_eq_w :: map4 mk_prem xctrs xss fs gs,
-                 mk_Trueprop_eq (case_fs $ v, case_gs $ w));
-            val goal_weak = Logic.mk_implies (v_eq_w, mk_Trueprop_eq (case_fs $ v, case_fs $ w));
+                 mk_Trueprop_eq (fcase $ v, gcase $ w));
+            val goal_weak = Logic.mk_implies (v_eq_w, mk_Trueprop_eq (fcase $ v, fcase $ w));
           in
             (Skip_Proof.prove lthy [] [] goal (fn _ => mk_case_cong_tac exhaust_thm' case_thms),
              Skip_Proof.prove lthy [] [] goal_weak (K (etac arg_cong 1)))
@@ -461,7 +458,7 @@
               list_exists_free xs (HOLogic.mk_conj (HOLogic.mk_eq (v, xctr),
                 HOLogic.mk_not (q $ f_xs)));
 
-            val lhs = q $ (mk_caseB_term eta_fs $ v);
+            val lhs = q $ (fcase $ v);
 
             val goal =
               mk_Trueprop_eq (lhs, Library.foldr1 HOLogic.mk_conj (map3 mk_conjunct xctrs xss xfs));