generalized recursors, effectively reverting inductive half of c7a034d01936
authorblanchet
Wed, 29 May 2013 02:35:49 +0200
changeset 52214 4cc5a80bba80
parent 52213 f4c5c6320cce
child 52215 7facaee8586f
generalized recursors, effectively reverting inductive half of c7a034d01936
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Wed May 29 02:35:49 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Wed May 29 02:35:49 2013 +0200
@@ -32,8 +32,8 @@
   val indexify_fst: ''a list -> (int -> ''a * 'b -> 'c) -> ''a * 'b -> 'c
   val mk_un_fold_co_rec_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
     int list list -> term list -> term list -> Proof.context ->
-    (term list * term list * ((term list list * typ list list * term list list list)
-       * (term list list * typ list list * term list list list)) option
+    (term list * term list * ((term list list * typ list list * term list list list list)
+       * (term list list * typ list list * term list list list list)) option
      * (term list * term list list
         * ((term list list * term list list list list * term list list list list)
            * (typ list * typ list list list * typ list list))
@@ -44,9 +44,10 @@
 
   val mk_iter_fun_arg_types_pairsss: typ list -> int list -> int list list -> term ->
     (typ list * typ list) list list list
-  val define_fold_rec: (term list list * typ list list * term list list list)
-      * (term list list * typ list list * term list list list) -> (string -> binding) -> typ list ->
-    typ list -> term -> term -> Proof.context -> (term * term * thm * thm) * Proof.context
+  val define_fold_rec: (term list list * typ list list * term list list list list)
+      * (term list list * typ list list * term list list list list) -> (string -> binding) ->
+    typ list -> typ list -> term -> term -> Proof.context ->
+    (term * term * thm * thm) * Proof.context
   val define_unfold_corec: term list * term list list
       * ((term list list * term list list list list * term list list list list)
          * (typ list * typ list list list * typ list list))
@@ -182,17 +183,14 @@
 
 val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
 
-fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
-fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
-
 fun flat_rec unzipf xs =
   let val ps = map unzipf xs in
     (* The first line below gives the preferred order. The second line is for compatibility with the
        old datatype package: *)
-(*
     maps (op @) ps
+(* ###
+    maps fst ps @ maps snd ps
 *)
-    maps fst ps @ maps snd ps
   end;
 
 fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
@@ -201,6 +199,11 @@
   | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
     p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
 
+fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
+fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
+fun mk_uncurried2_fun f xss =
+  mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
+
 fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
   Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
 
@@ -245,8 +248,12 @@
 
 val mk_fp_iter_fun_types = fst o split_last o binder_types o fastype_of;
 
-fun meta_unzip_rec getT proj1 proj2 fpTs y =
-  if exists_subtype_in fpTs (getT y) then ([proj1 y], [proj2 y]) else ([y], []);
+fun meta_unzip_rec getT left right nested fpTs y =
+  let val T = getT y in
+    if member (op =) fpTs T then ([left y], [right y])
+    else if exists_subtype_in fpTs T then ([nested y], [])
+    else ([y], [])
+  end;
 
 fun project_co_recT special_Tname fpTs proj =
   let
@@ -259,10 +266,7 @@
 val project_recT = project_co_recT @{type_name prod};
 val project_corecT = project_co_recT @{type_name sum};
 
-fun unzip_recT fpTs = meta_unzip_rec I (project_recT fpTs fst) (project_recT fpTs snd) fpTs;
-
-fun mk_fold_fun_typess y_Tsss Cs = map2 (fn C => map (fn y_Ts => y_Ts ---> C)) Cs y_Tsss;
-val mk_rec_fun_typess = mk_fold_fun_typess oo map o map o flat_rec o unzip_recT;
+fun unzip_recT fpTs = meta_unzip_rec I (project_recT fpTs fst) (project_recT fpTs snd) I fpTs;
 
 fun mk_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
 
@@ -273,21 +277,40 @@
 
 fun mk_fold_rec_args_types fpTs Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy =
   let
+    val Css = map2 replicate ns Cs;
     val y_Tsss = map3 mk_fun_arg_typess ns mss ctor_fold_fun_Ts;
-    val g_Tss = mk_fold_fun_typess y_Tsss Cs;
+    val g_Tss = map2 (fn C => map (fn y_Ts => y_Ts ---> C)) Cs y_Tsss;
 
     val ((gss, ysss), lthy) =
       lthy
       |> mk_Freess "f" g_Tss
       ||>> mk_Freesss "x" y_Tsss;
+    val yssss = map (map (map single)) ysss;
+
+    (* ### *)
+    fun dest_rec_prodT (T as Type (@{type_name prod}, Us as [_, U])) =
+        if member (op =) Cs U then Us else [T]
+      | dest_rec_prodT T = [T];
+
+    val z_Tssss =
+      map3 (fn n => fn ms => map2 (map dest_rec_prodT oo dest_tupleT) ms o
+        dest_sumTN_balanced n o domain_type) ns mss ctor_rec_fun_Ts;
 
     val z_Tsss = map3 mk_fun_arg_typess ns mss ctor_rec_fun_Ts;
-    val h_Tss = mk_rec_fun_typess fpTs z_Tsss Cs;
+    val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
 
     val hss = map2 (map2 retype_free) h_Tss gss;
     val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
+    val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
+    val (zssss_tl, lthy) =
+      lthy
+      |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
+    val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
+
+val _ = tracing (" *** OLD:  " ^ PolyML.makestring (ysss, zsss)) (*###*)
+val _ = tracing ("  *** NEW: " ^ PolyML.makestring (yssss, zssss)) (*###*)
   in
-    (((gss, g_Tss, ysss), (hss, h_Tss, zsss)), lthy)
+    (((gss, g_Tss, yssss), (hss, h_Tss, zssss)), lthy)
   end;
 
 fun mk_unfold_corec_args_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy =
@@ -438,18 +461,12 @@
         | _ => build_simple TU);
   in build end;
 
-fun mk_iter_body lthy fpTs ctor_iter fss xsss =
+fun mk_iter_body lthy fpTs ctor_iter fss xssss =
   let
     fun build_proj sel sel_const (x as Free (_, T)) =
       build_map lthy (sel_const o fst) (T, project_recT fpTs sel T) $ x;
-
-    (* TODO: Avoid these complications; cf. corec case *)
-    val unzip_rec = meta_unzip_rec (snd o dest_Free) (build_proj fst fst_const)
-      (build_proj snd snd_const) fpTs;
-
-    fun mk_iter_arg f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f (flat_rec unzip_rec xs);
   in
-    Term.list_comb (ctor_iter, map2 (mk_sum_caseN_balanced oo map2 mk_iter_arg) fss xsss)
+    Term.list_comb (ctor_iter, map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss)
   end;
 
 fun mk_preds_getterss_join c cps sum_prod_T cqfss =
@@ -480,13 +497,13 @@
 
     val fpT_to_C as Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of ctor_fold));
 
-    fun generate_iter (suf, ctor_iter, (fss, f_Tss, xsss)) =
+    fun generate_iter (suf, ctor_iter, (fss, f_Tss, xssss)) =
       let
         val res_T = fold_rev (curry (op --->)) f_Tss fpT_to_C;
         val binding = mk_binding suf;
         val spec =
           mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binding, res_T)),
-            mk_iter_body lthy0 fpTs ctor_iter fss xsss);
+            mk_iter_body lthy0 fpTs ctor_iter fss xssss);
       in (binding, spec) end;
 
     val binding_specs =
@@ -558,7 +575,6 @@
     val pre_map_defs = map map_def_of_bnf pre_bnfs;
     val pre_set_defss = map set_defs_of_bnf pre_bnfs;
     val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs;
-    val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
     val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
     val nested_set_map's = maps set_map'_of_bnf nested_bnfs;
 
@@ -671,24 +687,47 @@
 
         val mk_U = typ_subst_nonatomic (map2 pair fpTs Cs);
 
-        fun unzip_iters fiters =
+        (* ### *)
+        fun typ_subst inst (T as Type (s, Ts)) =
+            (case AList.lookup (op =) inst T of
+              NONE => Type (s, map (typ_subst inst) Ts)
+            | SOME T' => T')
+          | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
+
+        fun mk_U' maybe_mk_prodT =
+          typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
+
+        (* ### *)
+        fun build_rec_like fiters maybe_tick (T, U) =
+          if T = U then
+            id_const T
+          else
+            (case find_index (curry (op =) T) fpTs of
+              ~1 => build_map lthy (build_rec_like fiters maybe_tick) (T, U)
+            | kk => maybe_tick (nth us kk) (nth fiters kk));
+
+        fun unzip_iters fiters maybe_tick maybe_mk_prodT =
           meta_unzip_rec (snd o dest_Free) I
             (fn x as Free (_, T) => build_map lthy (indexify_fst fpTs (K o nth fiters))
-              (T, mk_U T) $ x) fpTs;
+              (T, mk_U T) $ x)
+            (fn x as Free (_, T) => build_rec_like fiters maybe_tick (T, mk_U' maybe_mk_prodT T) $ x)
+            fpTs;
+
+        fun tick u f = Term.lambda u (HOLogic.mk_prod (u, f $ u));
 
         val gxsss = map (map (flat_rec ((fn (ts, ts') => ([hd (ts' @ ts)], [])) o
-          unzip_iters gfolds))) xsss;
-        val hxsss = map (map (flat_rec (unzip_iters hrecs))) xsss;
+          unzip_iters gfolds (K I) (K I)))) xsss;
+        val hxsss = map (map (flat_rec (unzip_iters hrecs tick (curry HOLogic.mk_prodT)))) xsss;
 
         val fold_goalss = map5 (map4 o mk_goal gss) gfolds xctrss gss xsss gxsss;
         val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;
 
         val fold_tacss =
-          map2 (map o mk_iter_tac pre_map_defs [] nesting_map_ids'' fold_defs) ctor_fold_thms
-            ctr_defss;
+          map2 (map o mk_iter_tac pre_map_defs nesting_map_ids'' fold_defs)
+            ctor_fold_thms ctr_defss;
         val rec_tacss =
-          map2 (map o mk_iter_tac pre_map_defs nested_map_comp's
-            (nested_map_ids'' @ nesting_map_ids'') rec_defs) ctor_rec_thms ctr_defss;
+          map2 (map o mk_iter_tac pre_map_defs (nested_map_ids'' @ nesting_map_ids'') rec_defs)
+            ctor_rec_thms ctr_defss;
 
         fun prove goal tac =
           Goal.prove_sorry lthy [] [] goal (tac o #context)
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML	Wed May 29 02:35:49 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML	Wed May 29 02:35:49 2013 +0200
@@ -24,8 +24,7 @@
   val mk_induct_tac: Proof.context -> int -> int list -> int list list -> int list list list ->
     thm list -> thm -> thm list -> thm list list -> tactic
   val mk_inject_tac: Proof.context -> thm -> thm -> tactic
-  val mk_iter_tac: thm list -> thm list -> thm list -> thm list -> thm -> thm -> Proof.context
-    -> tactic
+  val mk_iter_tac: thm list -> thm list -> thm list -> thm -> thm -> Proof.context -> tactic
 end;
 
 structure BNF_FP_Def_Sugar_Tactics : BNF_FP_DEF_SUGAR_TACTICS =
@@ -103,9 +102,9 @@
   @{thms comp_def convol_def fst_conv id_def prod_case_Pair_iden snd_conv
       split_conv unit_case_Unity} @ sum_prod_thms_map;
 
-fun mk_iter_tac pre_map_defs map_comp's map_ids'' iter_defs ctor_iter ctr_def ctxt =
-  unfold_thms_tac ctxt (ctr_def :: ctor_iter :: iter_defs @ pre_map_defs @ map_comp's @
-    map_ids'' @ iter_unfold_thms) THEN rtac refl 1;
+fun mk_iter_tac pre_map_defs map_ids'' iter_defs ctor_iter ctr_def ctxt =
+  unfold_thms_tac ctxt (ctr_def :: ctor_iter :: iter_defs @ pre_map_defs @ map_ids'' @
+    iter_unfold_thms) THEN rtac refl 1;
 
 val coiter_unfold_thms =
   @{thms id_def ident_o_ident sum_case_if sum_case_o_inj} @ sum_prod_thms_map;