tuning
authorblanchet
Sat, 08 Sep 2012 21:04:26 +0200
changeset 49208 3f73424f86a7
parent 49207 4634c217b77b
child 49209 3c0deda51b32
tuning
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -66,7 +66,7 @@
 fun args_of ((_, args), _) = args;
 fun ctr_mixfix_of (_, mx) = mx;
 
-fun prepare_datatype prepare_typ gfp specs fake_lthy no_defs_lthy =
+fun prepare_datatype prepare_typ lfp specs fake_lthy no_defs_lthy =
   let
     val constrained_As =
       map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs
@@ -132,7 +132,7 @@
 
     val (pre_map_defs, ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects,
         fp_iter_thms, fp_rec_thms), lthy)) =
-      fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs no_defs_lthy;
+      fp_bnf (if lfp then bnf_lfp else bnf_gfp) bs mixfixes As' eqs no_defs_lthy;
 
     val timer = time (Timer.startRealTimer ());
 
@@ -160,8 +160,8 @@
       let
         val (binders, body) = strip_type (fastype_of c);
         val (fst_binders, last_binder) = split_last binders;
-        val Type (_, Ts0) = if gfp then body else last_binder;
-        val Us0 = map (if gfp then domain_type else body_type) fst_binders;
+        val Type (_, Ts0) = if lfp then last_binder else body;
+        val Us0 = map (if lfp then body_type else domain_type) fst_binders;
       in
         Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) c
       end;
@@ -169,28 +169,35 @@
     val fp_iters as fp_iter1 :: _ = map (mk_iter_or_rec As Cs) fp_iters0;
     val fp_recs as fp_rec1 :: _ = map (mk_iter_or_rec As Cs) fp_recs0;
 
-    val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter1))));
-    val y_Tsss = map3 (fn ms => map2 dest_tupleT ms oo dest_sumTN) mss ns fp_y_Ts;
-    val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
-
     fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
         if member (op =) Cs U then Us else [T]
       | dest_rec_pair T = [T];
 
-    val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec1))));
-    val z_Tssss =
-      map3 (fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms oo dest_sumTN) mss ns fp_z_Ts;
-    val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
+    val ((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)) =
+      if lfp then
+        let
+          val y_Tsss =
+            map3 (fn ms => fn n => map2 dest_tupleT ms o dest_sumTN n o domain_type) mss ns
+              (fst (split_last (binder_types (fastype_of fp_iter1))));
+          val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
+
+          val ((gss, ysss), _) =
+            lthy
+            |> mk_Freess "f" g_Tss
+            ||>> mk_Freesss "x" y_Tsss;
 
-    val ((gss, ysss), _) =
-      lthy
-      |> mk_Freess "f" g_Tss
-      ||>> mk_Freesss "x" y_Tsss;
+          val z_Tssss =
+            map3 (fn ms => fn n => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n o domain_type) mss ns
+              (fst (split_last (binder_types (fastype_of fp_rec1))));
+          val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
 
-    val hss = map2 (map2 retype_free) gss h_Tss;
-    val (zssss, _) =
-      lthy
-      |> mk_Freessss "x" z_Tssss;
+          val hss = map2 (map2 retype_free) gss h_Tss;
+          val (zssss, _) =
+            lthy
+            |> mk_Freessss "x" z_Tssss;
+        in ((gss, g_Tss, ysss, y_Tsss), (hss, h_Tss, zssss, z_Tssss)) end
+      else
+        (([], [], [], []), ([], [], [], [])); (* ### *)
 
     fun pour_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
           unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss)
@@ -276,8 +283,9 @@
 
         fun sugar_datatype no_defs_lthy =
           let
-            val iter_T = flat g_Tss ---> fpT --> C;
-            val rec_T = flat h_Tss ---> fpT --> C;
+            val fpT_to_C = fpT --> C;
+            val iter_T = fold_rev (curry (op --->)) g_Tss fpT_to_C;
+            val rec_T = fold_rev (curry (op --->)) h_Tss fpT_to_C;
 
             val iter_binder = Binding.suffix_name ("_" ^ iterN) b;
             val rec_binder = Binding.suffix_name ("_" ^ recN) b;
@@ -317,7 +325,7 @@
           (([], @{term True}, @{term True}, [], [], TrueI, TrueI), no_defs_lthy);
       in
         wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
-        |> (if gfp then sugar_codatatype else sugar_datatype)
+        |> (if lfp then sugar_datatype else sugar_codatatype)
       end;
 
     fun pour_more_sugar_on_datatypes ((ctrss, iters, recs, xsss, ctr_defss, iter_defs, rec_defs),
@@ -373,10 +381,10 @@
         fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~
         disc_binderss ~~ sel_bindersss)
       |>> split_list7
-      |> (if gfp then snd else pour_more_sugar_on_datatypes);
+      |> (if lfp then pour_more_sugar_on_datatypes else snd);
 
     val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
-      (if gfp then "co" else "") ^ "datatype"));
+      (if lfp then "" else "co") ^ "datatype"));
   in
     (timer; lthy')
   end;
@@ -406,10 +414,10 @@
 
 val _ =
   Outer_Syntax.local_theory @{command_spec "data"} "define BNF-based inductive datatypes"
-    (Parse.and_list1 parse_single_spec >> datatype_cmd false);
+    (Parse.and_list1 parse_single_spec >> datatype_cmd true);
 
 val _ =
   Outer_Syntax.local_theory @{command_spec "codata"} "define BNF-based coinductive datatypes"
-    (Parse.and_list1 parse_single_spec >> datatype_cmd true);
+    (Parse.and_list1 parse_single_spec >> datatype_cmd false);
 
 end;