changed the type of the recursor for nested recursion
authorblanchet
Mon, 01 Oct 2012 10:34:58 +0200
changeset 49670 c7a034d01936
parent 49669 620fa6272c48
child 49671 61729b149397
changed the type of the recursor for nested recursion
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Sun Sep 30 23:45:03 2012 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Mon Oct 01 10:34:58 2012 +0200
@@ -67,8 +67,6 @@
 
 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
-fun mk_uncurried2_fun f xss =
-  mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
 
 fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
   Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
@@ -247,6 +245,22 @@
 
     val timer = time (Timer.startRealTimer ());
 
+    fun build_map build_arg (Type (s, Ts)) (Type (_, Us)) =
+      let
+        val bnf = the (bnf_of lthy s);
+        val live = live_of_bnf bnf;
+        val mapx = mk_map live Ts Us (map_of_bnf bnf);
+        val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
+      in Term.list_comb (mapx, map build_arg TUs') end;
+
+    fun build_rel_step build_arg (Type (s, Ts)) =
+      let
+        val bnf = the (bnf_of lthy s);
+        val live = live_of_bnf bnf;
+        val rel = mk_rel live Ts Ts (rel_of_bnf bnf);
+        val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
+      in Term.list_comb (rel, map build_arg Ts') end;
+
     fun add_nesty_bnf_names Us =
       let
         fun add (Type (s, Ts)) ss =
@@ -265,8 +279,11 @@
     val pre_map_defs = map map_def_of_bnf pre_bnfs;
     val pre_set_defss = map set_defs_of_bnf pre_bnfs;
     val pre_rel_defs = map rel_def_of_bnf pre_bnfs;
+    val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
+    val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
+    val nesting_map_ids = map map_id_of_bnf nesting_bnfs;
+    val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def}) nesting_map_ids;
     val nested_set_natural's = maps set_natural'_of_bnf nested_bnfs;
-    val nesting_map_ids = map map_id_of_bnf nesting_bnfs;
     val nesting_set_natural's = maps set_natural'_of_bnf nesting_bnfs;
 
     val live = live_of_bnf any_fp_bnf;
@@ -283,6 +300,7 @@
     val fpTs = map (domain_type o fastype_of) dtors;
 
     val exists_fp_subtype = exists_subtype (member (op =) fpTs);
+    val exists_Cs_subtype = exists_subtype (member (op =) Cs);
 
     val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
     val ns = map length ctr_Tsss;
@@ -310,25 +328,25 @@
             lthy
             |> mk_Freess "f" g_Tss
             ||>> mk_Freesss "x" y_Tsss;
-          val yssss = map (map (map single)) ysss;
+
+          fun proj_recT proj (Type (s as @{type_name prod}, Ts as [T, U])) =
+              if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_recT proj) Ts)
+            | proj_recT proj (Type (s, Ts)) = Type (s, map (proj_recT proj) Ts)
+            | proj_recT _ T = T;
 
-          fun dest_rec_prodT (T as Type (@{type_name prod}, Us as [_, U])) =
-              if member (op =) Cs U then Us else [T]
-            | dest_rec_prodT T = [T];
+          fun unzip_recT T =
+            if exists_fp_subtype T then [proj_recT fst T, proj_recT snd T] else [T];
 
-          val z_Tssss =
-            map3 (fn n => fn ms => map2 (map dest_rec_prodT oo dest_tupleT) ms o
-              dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
+          val z_Tsss =
+            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
+              ns mss fp_rec_fun_Ts;
+          val z_Tssss = map (map (map unzip_recT)) z_Tsss;
           val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
 
           val hss = map2 (map2 retype_free) h_Tss gss;
-          val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
-          val (zssss_tl, lthy) =
-            lthy
-            |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
-          val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
+          val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
         in
-          ((((gss, g_Tss, yssss), (hss, h_Tss, zssss)),
+          ((((gss, g_Tss, ysss), (hss, h_Tss, zsss)),
             ([], [], [], (([], []), ([], [])), (([], []), ([], [])))), lthy)
         end
       else
@@ -578,14 +596,37 @@
           let
             val fpT_to_C = fpT --> C;
 
-            fun generate_rec_like (suf, fp_rec_like, (fss, f_Tss, xssss)) =
+            fun build_ctor_rec_arg mk_proj (T, U) =
+              if T = U then
+                id_const T
+              else
+                (case (T, U) of
+                  (Type (s, _), Type (s', _)) =>
+                  if s = s' then build_map (build_ctor_rec_arg mk_proj) T U else mk_proj T
+                | _ => mk_proj T);
+
+            fun mk_U proj (T as Type (@{type_name prod}, [T', U])) =
+                if member (op =) fpTs T' then proj (T', U) else T
+              | mk_U proj (Type (s, Ts)) = Type (s, map (mk_U proj) Ts)
+              | mk_U _ T = T;
+
+            fun unzip_rec (x as Free (_, T)) =
+              if exists_fp_subtype T then
+                [build_ctor_rec_arg fst_const (T, mk_U fst T) $ x,
+                 build_ctor_rec_arg snd_const (T, mk_U snd T) $ x]
+              else
+                [x];
+
+            fun mk_rec_like_arg f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f (maps unzip_rec xs);
+
+            fun generate_rec_like (suf, fp_rec_like, (fss, f_Tss, xsss)) =
               let
                 val res_T = fold_rev (curry (op --->)) f_Tss fpT_to_C;
                 val binding = qualify false fp_b_name (Binding.suffix_name ("_" ^ suf) fp_b);
                 val spec =
                   mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binding, res_T)),
                     Term.list_comb (fp_rec_like,
-                      map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss));
+                      map2 (mk_sum_caseN_balanced oo map2 mk_rec_like_arg) fss xsss));
               in (binding, spec) end;
 
             val rec_like_infos =
@@ -661,14 +702,6 @@
       fold_map I wrap_types_and_mores lthy
       |>> apsnd split_list4 o apfst split_list4 o split_list;
 
-    fun build_map build_arg (Type (s, Ts)) (Type (_, Us)) =
-      let
-        val bnf = the (bnf_of lthy s);
-        val live = live_of_bnf bnf;
-        val mapx = mk_map live Ts Us (map_of_bnf bnf);
-        val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
-      in Term.list_comb (mapx, map build_arg TUs') end;
-
     (* TODO: Add map, sets, rel simps *)
     val mk_simp_thmss =
       map3 (fn (_, _, _, injects, distincts, cases, _, _, _) => fn rec_likes => fn fold_likes =>
@@ -787,10 +820,8 @@
               typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
 
             fun intr_rec_likes frec_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
-              if member (op =) fpTs T then
+              if exists_fp_subtype T then
                 maybe_cons x [build_rec_like frec_likes (K I) (T, mk_U (K I) T) $ x]
-              else if exists_fp_subtype T then
-                [build_rec_like frec_likes maybe_tick (T, mk_U maybe_mk_prodT T) $ x]
               else
                 [x];
 
@@ -802,11 +833,11 @@
             val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;
 
             val fold_tacss =
-              map2 (map o mk_rec_like_tac pre_map_defs nesting_map_ids fold_defs) fp_fold_thms
+              map2 (map o mk_rec_like_tac pre_map_defs [] nesting_map_ids'' fold_defs) fp_fold_thms
                 ctr_defss;
             val rec_tacss =
-              map2 (map o mk_rec_like_tac pre_map_defs nesting_map_ids rec_defs) fp_rec_thms
-                ctr_defss;
+              map2 (map o mk_rec_like_tac pre_map_defs nested_map_comp's
+                (nested_map_ids'' @ nesting_map_ids'') rec_defs) fp_rec_thms ctr_defss;
 
             fun prove goal tac =
               Skip_Proof.prove lthy [] [] goal (tac o #context)
@@ -873,14 +904,6 @@
               map4 (fn u => fn v => fn uvr => fn uv_eq =>
                 fold_rev Term.lambda [u, v] (HOLogic.mk_disj (uvr, uv_eq))) us vs uvrs uv_eqs;
 
-            fun build_rel_step build_arg (Type (s, Ts)) =
-              let
-                val bnf = the (bnf_of lthy s);
-                val live = live_of_bnf bnf;
-                val rel = mk_rel live Ts Ts (rel_of_bnf bnf);
-                val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
-              in Term.list_comb (rel, map build_arg Ts') end;
-
             fun build_rel rs' T =
               (case find_index (curry (op =) T) fpTs of
                 ~1 =>
@@ -963,7 +986,7 @@
 
             fun intr_corec_likes fcorec_likes maybe_mk_sumT maybe_tack cqf =
               let val T = fastype_of cqf in
-                if exists_subtype (member (op =) Cs) T then
+                if exists_Cs_subtype T then
                   build_corec_like fcorec_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
                 else
                   cqf
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML	Sun Sep 30 23:45:03 2012 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML	Mon Oct 01 10:34:58 2012 +0200
@@ -23,7 +23,8 @@
   val mk_induct_tac: Proof.context -> int -> int list -> int list list -> int list list list ->
     thm list -> thm -> thm list -> thm list list -> tactic
   val mk_inject_tac: Proof.context -> thm -> thm -> tactic
-  val mk_rec_like_tac: thm list -> thm list -> thm list -> thm -> thm -> Proof.context -> tactic
+  val mk_rec_like_tac: thm list -> thm list -> thm list -> thm list -> thm -> thm -> Proof.context
+    -> tactic
 end;
 
 structure BNF_FP_Def_Sugar_Tactics : BNF_FP_DEF_SUGAR_TACTICS =
@@ -99,12 +100,12 @@
 
 (*TODO: Try "sum_prod_thms_map" here, enriched with a few theorems*)
 val rec_like_unfold_thms =
-  @{thms comp_def convol_def id_apply map_pair_def prod_case_Pair_iden sum.simps(5,6) sum_map.simps
-      split_conv unit_case_Unity};
+  @{thms comp_def convol_def fst_conv id_def map_pair_def prod_case_Pair_iden snd_conv split_conv
+      sum.simps(5,6) sum_map.simps unit_case_Unity};
 
-fun mk_rec_like_tac pre_map_defs map_ids rec_like_defs ctor_rec_like ctr_def ctxt =
-  unfold_thms_tac ctxt (ctr_def :: ctor_rec_like :: rec_like_defs @ pre_map_defs @ map_ids @
-    rec_like_unfold_thms) THEN unfold_thms_tac ctxt @{thms id_def} THEN rtac refl 1;
+fun mk_rec_like_tac pre_map_defs map_comp's map_ids'' rec_like_defs ctor_rec_like ctr_def ctxt =
+  unfold_thms_tac ctxt (ctr_def :: ctor_rec_like :: rec_like_defs @ pre_map_defs @ map_comp's @
+    map_ids'' @ rec_like_unfold_thms) THEN rtac refl 1;
 
 fun mk_corec_like_tac corec_like_defs map_ids ctor_dtor_corec_like pre_map_def ctr_def ctxt =
   unfold_thms_tac ctxt (ctr_def :: corec_like_defs) THEN