construct the right iterator theorem in the recursive case
authorblanchet
Sat, 08 Sep 2012 21:04:26 +0200
changeset 49214 2a3cb4c71b87
parent 49213 975ccb0130cb
child 49215 1c5d6e2eb0c6
construct the right iterator theorem in the recursive case
src/HOL/Codatatype/Tools/bnf_def.ML
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
--- a/src/HOL/Codatatype/Tools/bnf_def.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_def.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -28,6 +28,8 @@
   val rel_unfoldN: string
   val pred_unfoldN: string
 
+  val map_of_bnf: BNF -> term
+
   val mk_T_of_bnf: typ list -> typ list -> BNF -> typ
   val mk_bd_of_bnf: typ list -> typ list -> BNF -> term
   val mk_map_of_bnf: typ list -> typ list -> typ list -> BNF -> term
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -15,6 +15,7 @@
 
 open BNF_Util
 open BNF_Wrap
+open BNF_Def
 open BNF_FP_Util
 open BNF_LFP
 open BNF_GFP
@@ -26,7 +27,14 @@
 val itersN = "iters";
 val recsN = "recs";
 
-fun split_list7 xs = (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs);
+fun split_list8 xs =
+  (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs, map #8 xs);
+
+fun typ_subst inst (T as Type (s, Ts)) =
+    (case AList.lookup (op =) inst T of
+      NONE => Type (s, map (typ_subst inst) Ts)
+    | SOME T' => T')
+  | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
 
 fun retype_free (Free (s, _)) T = Free (s, T);
 
@@ -37,6 +45,8 @@
 fun mk_uncurried2_fun f xss =
   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
 
+fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v))
+
 fun popescu_zip [] [fs] = fs
   | popescu_zip (p :: ps) (fs :: fss) = p :: fs @ popescu_zip ps fss;
 
@@ -160,14 +170,14 @@
     val mss = map (map length) ctr_Tsss;
     val Css = map2 replicate ns Cs;
 
-    fun mk_iter_like Ts Us c =
+    fun mk_iter_like Ts Us t =
       let
-        val (binders, body) = strip_type (fastype_of c);
+        val (binders, body) = strip_type (fastype_of t);
         val (f_Us, prebody) = split_last binders;
         val Type (_, Ts0) = if lfp then prebody else body;
         val Us0 = distinct (op =) (map (if lfp then body_type else domain_type) f_Us);
       in
-        Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) c
+        Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
       end;
 
     val fp_iters as fp_iter1 :: _ = map (mk_iter_like As Cs) fp_iters0;
@@ -359,7 +369,7 @@
             val iter = mk_iter_like As Cs iter0;
             val recx = mk_iter_like As Cs rec0;
           in
-            ((ctrs, iter, recx, xss, ctr_defs, iter_def, rec_def), lthy)
+            ((ctrs, iter, recx, v, xss, ctr_defs, iter_def, rec_def), lthy)
           end;
 
         fun some_gfp_sugar no_defs_lthy =
@@ -402,14 +412,19 @@
 
             val [coiter, corec] = map (mk_iter_like As Cs o Morphism.term phi) csts;
           in
-            ((ctrs, coiter, corec, xss, ctr_defs, coiter_def, corec_def), lthy)
+            ((ctrs, coiter, corec, v, xss, ctr_defs, coiter_def, corec_def), lthy)
           end;
       in
         wrap_datatype tacss ((ctrs0, casex0), (disc_binders, sel_binderss)) lthy'
         |> (if lfp then some_lfp_sugar else some_gfp_sugar)
       end;
 
-    fun pour_more_sugar_on_lfps ((ctrss, iters, recs, xsss, ctr_defss, iter_defs, rec_defs),
+    fun mk_map Ts Us t =
+      let val (Type (_, Ts0), Type (_, Us0)) = strip_type (fastype_of t) |>> List.last in
+        Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+      end;
+
+    fun pour_more_sugar_on_lfps ((ctrss, iters, recs, vs, xsss, ctr_defss, iter_defs, rec_defs),
         lthy) =
       let
         val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
@@ -422,13 +437,40 @@
               fold_rev (fold_rev Logic.all) (xs :: fss)
                 (mk_Trueprop_eq (fiter_like $ xctr, Term.list_comb (f, fxs)));
 
-            fun repair_iter_call (x as Free (_, T)) =
-              (case find_index (curry (op =) T) fpTs of ~1 => x | j => nth giters j $ x);
+            fun build_iter_like fiter_likes maybe_tick =
+              let
+                fun build (T, U) =
+                  if T = U then
+                    Const (@{const_name id}, T --> T)
+                  else
+                    (case (find_index (curry (op =) T) fpTs, (T, U)) of
+                      (~1, (Type (s, Ts), Type (_, Us))) =>
+                      let
+                        val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
+                        val mapx = mk_map Ts Us map0;
+                        val TUs = map dest_funT (fst (split_last (binder_types (fastype_of mapx))));
+                        val args = map build TUs;
+                      in Term.list_comb (mapx, args) end
+                    | (j, _) => maybe_tick (nth vs j) (nth fiter_likes j))
+              in build end;
+
+            fun mk_U maybe_prodT =
+              typ_subst (map2 (fn fpT => fn C => (fpT, maybe_prodT fpT C)) fpTs Cs);
+
+            fun repair_calls fiter_likes maybe_cons maybe_tick maybe_prodT (x as Free (_, T)) =
+              if member (op =) fpTs T then
+                maybe_cons x [build_iter_like fiter_likes (K I) (T, mk_U (K I) T) $ x]
+              else if exists_subtype (member (op =) fpTs) T then
+                [build_iter_like fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
+              else
+                [x];
+
             fun repair_rec_call (x as Free (_, T)) =
               (case find_index (curry (op =) T) fpTs of ~1 => [x] | j => [x, nth hrecs j $ x]);
 
-            val gxsss = map (map (map repair_iter_call)) xsss;
-            val hxsss = map (map (maps repair_rec_call)) xsss;
+            val gxsss = map (map (maps (repair_calls giters (K I) (K I) (K I)))) xsss;
+            val hxsss =
+              map (map (maps (repair_calls hrecs cons tick (curry HOLogic.mk_prodT)))) xsss;
 
             val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss gxsss;
             val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss hxsss;
@@ -455,8 +497,8 @@
         lthy |> Local_Theory.notes notes |> snd
       end;
 
-    fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, xsss, ctr_defss, coiter_defs, corec_defs),
-        lthy) =
+    fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, xsss, ctr_defss, coiter_defs,
+        corec_defs), lthy) =
       let
         val gcoiters = map (lists_bmoc pgss) coiters;
         val hcorecs = map (lists_bmoc phss) corecs;
@@ -505,7 +547,7 @@
       |> fold_map pour_some_sugar_on_type (bs ~~ fpTs ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~
         fp_recs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ns ~~ kss ~~ mss ~~ ctr_binderss ~~
         ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss)
-      |>> split_list7
+      |>> split_list8
       |> (if lfp then pour_more_sugar_on_lfps else pour_more_sugar_on_gfps);
 
     val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^