generate iter/rec goals
authorblanchet
Sat, 08 Sep 2012 21:04:26 +0200
changeset 49204 0b735fb2602e
parent 49203 262ab1ac38b9
child 49205 674f04c737e0
generate iter/rec goals
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -63,16 +63,17 @@
 fun args_of ((_, args), _) = args;
 fun ctr_mixfix_of (_, mx) = mx;
 
-fun prepare_datatype prepare_typ gfp specs fake_lthy lthy =
+fun prepare_datatype prepare_typ gfp specs fake_lthy no_defs_lthy =
   let
     val constrained_As =
       map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs
-      |> Library.foldr1 (merge_type_args_constrained lthy);
+      |> Library.foldr1 (merge_type_args_constrained no_defs_lthy);
     val As = map fst constrained_As;
     val As' = map dest_TFree As;
 
     val _ = (case duplicates (op =) As of [] => ()
-      | A :: _ => error ("Duplicate type parameter " ^ quote (Syntax.string_of_typ lthy A)));
+      | A :: _ => error ("Duplicate type parameter " ^
+          quote (Syntax.string_of_typ no_defs_lthy A)));
 
     (* TODO: use sort constraints on type args *)
 
@@ -83,7 +84,7 @@
         As);
 
     val bs = map type_binder_of specs;
-    val fake_Ts = map mk_fake_T bs;
+    val fakeTs = map mk_fake_T bs;
 
     val mixfixes = map mixfix_of specs;
 
@@ -104,32 +105,30 @@
     val _ = (case subtract (op =) As' rhs_As' of
         [] => ()
       | A' :: _ => error ("Extra type variables on rhs: " ^
-          quote (Syntax.string_of_typ lthy (TFree A'))));
+          quote (Syntax.string_of_typ no_defs_lthy (TFree A'))));
 
     val ((Cs, Xs), _) =
-      lthy
+      no_defs_lthy
       |> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs
       |> mk_TFrees N
       ||>> mk_TFrees N;
 
-    fun is_same_fpT (T as Type (s, Us)) (Type (s', Us')) =
+    fun eq_fpT (T as Type (s, Us)) (Type (s', Us')) =
         s = s' andalso (Us = Us' orelse error ("Illegal occurrence of recursive type " ^
           quote (Syntax.string_of_typ fake_lthy T)))
-      | is_same_fpT _ _ = false;
+      | eq_fpT _ _ = false;
 
-    fun freeze_fpXs (T as Type (s, Us)) =
-        (case find_index (is_same_fpT T) fake_Ts of
-          ~1 => Type (s, map freeze_fpXs Us)
-        | i => nth Xs i)
-      | freeze_fpXs T = T;
+    fun freeze_fp (T as Type (s, Us)) =
+        (case find_index (eq_fpT T) fakeTs of ~1 => Type (s, map freeze_fp Us) | j => nth Xs j)
+      | freeze_fp T = T;
 
-    val ctr_TsssXs = map (map (map freeze_fpXs)) fake_ctr_Tsss;
+    val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss;
     val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs;
 
     val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
 
-    val ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects), lthy') =
-      fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs lthy;
+    val ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects), lthy) =
+      fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs no_defs_lthy;
 
     val timer = time (Timer.startRealTimer ());
 
@@ -145,8 +144,9 @@
     val flds = map (mk_fld As) flds0;
 
     val fpTs = map (domain_type o fastype_of) unfs;
+    val is_fpT = member (op =) fpTs;
+
     val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
-
     val ns = map length ctr_Tsss;
     val mss = map (map length) ctr_Tsss;
     val Css = map2 replicate ns Cs;
@@ -162,8 +162,31 @@
         Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) c
       end;
 
-    val fp_iters = map (mk_iter_or_rec As Cs) fp_iters0;
-    val fp_recs = map (mk_iter_or_rec As Cs) fp_recs0;
+    val fp_iters as fp_iter1 :: _ = map (mk_iter_or_rec As Cs) fp_iters0;
+    val fp_recs as fp_rec1 :: _ = map (mk_iter_or_rec As Cs) fp_recs0;
+
+    val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter1))));
+    val y_Tsss = map3 (fn ms => map2 dest_tupleT ms oo dest_sumTN) mss ns fp_y_Ts;
+    val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
+
+    fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
+        if member (op =) Cs U then Us else [T]
+      | dest_rec_pair T = [T];
+
+    val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec1))));
+    val z_Tssss =
+      map3 (fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms oo dest_sumTN) mss ns fp_z_Ts;
+    val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
+
+    val ((gss, ysss), _) =
+      lthy
+      |> mk_Freess "f" g_Tss
+      ||>> mk_Freesss "x" y_Tsss;
+
+    val hss = map2 (map2 retype_free) gss h_Tss;
+    val (zssss, _) =
+      lthy
+      |> mk_Freessss "x" z_Tssss;
 
     fun pour_sugar_on_type ((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec), fld_unf),
           unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss)
@@ -178,7 +201,7 @@
         val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
 
         val ((((u, v), fs), xss), _) =
-          lthy
+          no_defs_lthy
           |> yield_singleton (mk_Frees "u") unfT
           ||>> yield_singleton (mk_Frees "v") fpT
           ||>> mk_Frees "f" case_Ts
@@ -249,39 +272,11 @@
 
         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
 
-        (* (co)iterators, (co)recursors, (co)induction *)
-
-        val is_fpT = member (op =) fpTs;
-
-        fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
-            if member (op =) Cs U then Us else [T]
-          | dest_rec_pair T = [T];
-
         fun sugar_datatype no_defs_lthy =
           let
-            val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter))));
-            val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts;
-            val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss;
-            val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
             val iter_T = flat g_Tss ---> fpT --> C;
-
-            val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec))));
-            val z_prod_Tss = map2 dest_sumTN ns fp_z_Ts;
-            val z_Tsss = map2 (map2 dest_tupleT) mss z_prod_Tss;
-            val z_Tssss = map (map (map dest_rec_pair)) z_Tsss;
-            val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
             val rec_T = flat h_Tss ---> fpT --> C;
 
-            val ((gss, ysss), _) =
-              no_defs_lthy
-              |> mk_Freess "f" g_Tss
-              ||>> mk_Freesss "x" y_Tsss;
-
-            val hss = map2 (map2 retype_free) gss h_Tss;
-            val (zssss, _) =
-              no_defs_lthy
-              |> mk_Freessss "x" z_Tssss;
-
             val iter_binder = Binding.suffix_name ("_" ^ iterN) b;
             val rec_binder = Binding.suffix_name ("_" ^ recN) b;
 
@@ -313,7 +308,7 @@
             val iter = mk_iter_or_rec As Cs' iter0;
             val recx = mk_iter_or_rec As Cs' rec0;
           in
-            ([[ctrs], [[iter]], [[recx]], xss, gss, hss], lthy)
+            ([[ctrs], [[iter]], [[recx]], xss], lthy)
           end;
 
         fun sugar_codatatype no_defs_lthy = ([], no_defs_lthy);
@@ -322,19 +317,30 @@
         |> (if gfp then sugar_codatatype else sugar_datatype)
       end;
 
-    fun pour_more_sugar_on_datatypes ([[ctrss], [[iters]], [[recs]], xsss, gsss, hsss], lthy) =
+    fun pour_more_sugar_on_datatypes ([[ctrss], [[iters]], [[recs]], xsss], lthy) =
       let
         val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
-        val giters = map2 (curry flat_list_comb) iters gsss;
-        val hrecs = map2 (curry flat_list_comb) recs hsss;
+        val giters = map (fn iter => flat_list_comb (iter, gss)) iters;
+        val hrecs = map (fn recx => flat_list_comb (recx, hss)) recs;
 
         val (iter_thmss, rec_thmss) =
           let
-            fun mk_goal_iter_or_rec fc xctr =
-              mk_Trueprop_eq (fc $ xctr, fc $ xctr);
+            fun mk_goal_iter_or_rec fss fc xctr f xs xs' =
+              mk_Trueprop_eq (fc $ xctr, Term.list_comb (f, xs'));
+
+            fun fix_iter_free (x as Free (_, T)) =
+              (case find_index (eq_fpT T) fpTs of ~1 => x | j => nth giters j $ x);
+            fun fix_rec_free (x as Free (_, T)) =
+              (case find_index (eq_fpT T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]);
 
-            val goal_iterss = map2 (fn giter => map (mk_goal_iter_or_rec giter)) giters xctrss;
-            val goal_recss = map2 (fn hrec => map (mk_goal_iter_or_rec hrec)) hrecs xctrss;
+            val iter_xsss = map (map (map fix_iter_free)) xsss;
+            val rec_xsss = map (map (maps fix_rec_free)) xsss;
+
+            val goal_iterss =
+              map5 (map4 o mk_goal_iter_or_rec gss) giters xctrss gss xsss iter_xsss;
+            val goal_recss =
+              map5 (map4 o mk_goal_iter_or_rec hss) hrecs xctrss hss xsss rec_xsss;
+
             val iter_tacss =
               map (map (K (fn _ => Skip_Proof.cheat_tac (Proof_Context.theory_of lthy)))) goal_iterss;
               (* ### map (map mk_iter_or_rec_tac); (* needs ctr_def, iter_def, fld_iter *) *)
@@ -356,7 +362,7 @@
         lthy |> Local_Theory.notes notes |> snd
       end;
 
-    val lthy'' = lthy'
+    val lthy' = lthy
       |> fold_map pour_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~ fp_recs ~~
         fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~
         disc_binderss ~~ sel_bindersss)
@@ -365,7 +371,7 @@
     val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
       (if gfp then "co" else "") ^ "datatype"));
   in
-    (timer; lthy'')
+    (timer; lthy')
   end;
 
 fun datatype_cmd info specs lthy =