src/HOL/Tools/BNF/bnf_lfp_compat.ML
changeset 58217 d81d39278d48
parent 58214 bd1754377965
child 58218 a92acec845a7
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Mon Sep 08 14:03:02 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Mon Sep 08 14:03:02 2014 +0200
@@ -119,7 +119,7 @@
     fold_map3 (define_co_rec_as Least_FP Cs) fpTs bs rhss lthy
   end;
 
-fun mk_split_rec_thmss ctxt fpTs ctrss rec0_thmss (recs as rec1 :: _) rec_defs =
+fun mk_split_rec_thmss ctxt Xs fpTs ctr_Tsss ctrss rec0_thmss (recs as rec1 :: _) rec_defs =
   let
     val f_Ts = binder_fun_types (fastype_of rec1);
     val (fs, _) = mk_Frees "f" f_Ts ctxt;
@@ -136,23 +136,20 @@
           val xg = Term.list_comb (g, map Bound (n - 1 downto 0));
         in frec $ xg end;
 
-    fun mk_rec_arg_arg g_T g =
-      g :: (if exists_subtype_in fpTs g_T then [mk_rec_call g 0 g_T] else []);
+    fun mk_rec_arg_arg ctr_T g =
+      g :: (if exists_subtype_in fpTs ctr_T then [mk_rec_call g 0 ctr_T] else []);
 
-    fun mk_goal frec ctr f =
+    fun mk_goal frec ctr_Ts ctr f =
       let
-        val g_Ts = binder_types (fastype_of ctr);
-        val (gs, _) = mk_Frees "g" g_Ts ctxt;
+        val (gs, _) = mk_Frees "g" ctr_Ts ctxt;
         val gctr = Term.list_comb (ctr, gs);
-        val fgs = flat_rec_arg_args (map2 mk_rec_arg_arg g_Ts gs);
+        val fgs = flat_rec_arg_args (map2 mk_rec_arg_arg ctr_Ts gs);
       in
         fold_rev (fold_rev Logic.all) [fs, gs]
           (mk_Trueprop_eq (frec $ gctr, Term.list_comb (f, fgs)))
       end;
 
-    fun mk_goals ctrs fs frec = map2 (mk_goal frec) ctrs fs;
-
-    val goalss = map3 mk_goals ctrss fss frecs;
+    val goalss = map4 (map3 o mk_goal) frecs ctr_Tsss ctrss fss;
 
     fun tac ctxt =
       unfold_thms_tac ctxt (@{thms o_apply fst_conv snd_conv} @ rec_defs @ flat rec0_thmss) THEN
@@ -165,7 +162,8 @@
     map (map prove) goalss
   end;
 
-fun define_split_rec_derive_induct_rec_thms fpTs ctrss inducts induct recs0 rec_thmss lthy =
+fun define_split_rec_derive_induct_rec_thms Xs fpTs ctr_Tsss ctrss inducts induct recs0 rec_thmss
+    lthy =
   let
     val thy = Proof_Context.theory_of lthy;
 
@@ -179,7 +177,7 @@
     val Cs = map ((fn TVar ((s, _), S) => TFree (s, S)) o body_type o fastype_of) recs0;
     val recs = map2 (mk_co_rec thy Least_FP Cs) fpTs recs0;
     val ((recs', rec'_defs), lthy') = define_split_recs fpTs Cs recs lthy |>> split_list;
-    val rec'_thmss = mk_split_rec_thmss lthy' fpTs ctrss rec_thmss recs' rec'_defs;
+    val rec'_thmss = mk_split_rec_thmss lthy' Xs fpTs ctr_Tsss ctrss rec_thmss recs' rec'_defs;
   in
     ((inducts', induct', recs', rec'_thmss), lthy')
   end;
@@ -265,7 +263,7 @@
     val fpTs' = Old_Datatype_Aux.get_rec_types descr;
     val nn = length fpTs';
 
-    val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) fpTs';
+    val fp_sugars = map (lfp_sugar_of o fst o dest_Type) fpTs';
     val ctr_Tsss = map (map (map dest_dtyp o snd) o #3 o snd) descr;
     val kkssss =
       map (map (map (fn Old_Datatype_Aux.DtRec kk => [kk] | _ => []) o snd) o #3 o snd) descr;
@@ -282,32 +280,34 @@
     val compat_b_names = map (prefix compat_N) b_names;
     val compat_bs = map Binding.name compat_b_names;
 
-    val ((fp_sugars, (lfp_sugar_thms, _)), lthy') =
+    val ((fp_sugars', (lfp_sugar_thms', _)), lthy') =
       if nn > nn_fp then
-        mutualize_fp_sugars Least_FP cliques compat_bs fpTs' callers callssss fp_sugars0 lthy
+        mutualize_fp_sugars Least_FP cliques compat_bs fpTs' callers callssss fp_sugars lthy
       else
-        ((fp_sugars0, (NONE, NONE)), lthy);
+        ((fp_sugars, (NONE, NONE)), lthy);
+
+    fun mk_ctr_of {ctr_sugar = {ctrs, ...}, ...} (Type (_, Ts)) = map (mk_ctr Ts) ctrs;
 
-    fun mk_ctrs_of (Type (T_name, As)) =
-      map (mk_ctr As) (#ctrs (the (ctr_sugar_of lthy' T_name)));
-
-    val ctrss' = map mk_ctrs_of fpTs';
-    val {common_co_inducts = [induct], ...} :: _ = fp_sugars;
-    val inducts = map (the_single o #co_inducts) fp_sugars;
-    val recs = map #co_rec fp_sugars;
-    val rec_thmss = map #co_rec_thms fp_sugars;
+    val Xs' = map #X fp_sugars';
+    val ctr_Tsss' = map (map (binder_types o fastype_of) o #ctrs o #ctr_sugar) fp_sugars'; (*###*)
+    val ctrss' = map2 mk_ctr_of fp_sugars' fpTs';
+    val {common_co_inducts = [induct], ...} :: _ = fp_sugars';
+    val inducts = map (the_single o #co_inducts) fp_sugars';
+    val recs = map #co_rec fp_sugars';
+    val rec_thmss = map #co_rec_thms fp_sugars';
 
     fun is_nested_rec_type (Type (@{type_name fun}, [_, T])) = member (op =) fpTs' (body_type T)
       | is_nested_rec_type _ = false;
 
-    val ((lfp_sugar_thms', (inducts', induct', recs', rec'_thmss)), lthy'') =
+    val ((lfp_sugar_thms'', (inducts', induct', recs', rec'_thmss)), lthy'') =
       if nesting_pref = Unfold_Nesting andalso
-         exists (exists (exists is_nested_rec_type)) ctr_Tsss then
-        define_split_rec_derive_induct_rec_thms fpTs' ctrss' inducts induct recs rec_thmss lthy'
+         exists (exists (exists is_nested_rec_type)) ctr_Tsss' then
+        define_split_rec_derive_induct_rec_thms Xs' fpTs' ctr_Tsss' ctrss' inducts induct recs
+          rec_thmss lthy'
         |>> `(fn (inducts', induct', _, rec'_thmss) =>
           SOME ((inducts', induct', mk_induct_attrs ctrss'), (rec'_thmss, [])))
       else
-        ((lfp_sugar_thms, (inducts, induct, recs, rec_thmss)), lthy');
+        ((lfp_sugar_thms', (inducts, induct, recs, rec_thmss)), lthy');
 
     val rec'_names = map (fst o dest_Const) recs';
     val rec'_thms = flat rec'_thmss;
@@ -321,9 +321,9 @@
         case_cong = case_cong, case_cong_weak = case_cong_weak, split = split,
         split_asm = split_asm});
 
-    val infos = map_index mk_info (take nn_fp fp_sugars);
+    val infos = map_index mk_info (take nn_fp fp_sugars');
   in
-    (nn, b_names, compat_b_names, lfp_sugar_thms', infos, lthy'')
+    (nn, b_names, compat_b_names, lfp_sugar_thms'', infos, lthy'')
   end;
 
 fun infos_of_new_datatype_mutual_cluster lthy fpT_name =