src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 52310 28063e412793
parent 52309 f71d0a604e5a
child 52311 e2f6ac15d79a
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Thu Jun 06 08:40:37 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Thu Jun 06 09:17:17 2013 +0200
@@ -62,8 +62,8 @@
     (term * term * thm * thm) * Proof.context
   val derive_induct_fold_rec_thms_for_types: BNF_Def.bnf list -> term list -> term list -> thm ->
     thm list -> thm list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list ->
-    term list list -> thm list list -> term list -> term list -> thm list -> thm list ->
-    local_theory ->
+    typ list -> typ list list list -> term list list -> thm list list -> term list -> term list ->
+    thm list -> thm list -> local_theory ->
     (thm * thm list * Args.src list) * (thm list list * Args.src list)
     * (thm list list * Args.src list)
   val derive_coinduct_unfold_corec_thms_for_types: BNF_Def.bnf list -> term list -> term list ->
@@ -551,8 +551,8 @@
   end ;
 
 fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_folds ctor_recs ctor_induct ctor_fold_thms
-    ctor_rec_thms nesting_bnfs nested_bnfs fpTs Cs ctrss ctr_defss folds recs fold_defs rec_defs
-    lthy =
+    ctor_rec_thms nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss folds recs
+    fold_defs rec_defs lthy =
   let
     val ctr_Tsss = map (map (binder_types o fastype_of)) ctrss;
 
@@ -604,24 +604,24 @@
             Term.subst_atomic_types (Ts0 ~~ Ts) t
           end;
 
-        fun mk_raw_prem_prems names_lthy (x as Free (s, T as Type (T_name, Ts0))) =
-            (case find_index (curry (op =) T) fpTs of
-              ~1 =>
-              (case AList.lookup (op =) setss_nested T_name of
-                NONE => []
-              | SOME raw_sets0 =>
-                let
-                  val (Ts, raw_sets) =
-                    split_list (filter (exists_subtype_in fpTs o fst) (Ts0 ~~ raw_sets0));
-                  val sets = map (mk_set Ts0) raw_sets;
-                  val (ys, names_lthy') = names_lthy |> mk_Frees s Ts;
-                  val xysets = map (pair x) (ys ~~ sets);
-                  val ppremss = map (mk_raw_prem_prems names_lthy') ys;
-                in
-                  flat (map2 (map o apfst o cons) xysets ppremss)
-                end)
-            | kk => [([], (kk + 1, x))])
-          | mk_raw_prem_prems _ _ = [];
+        fun mk_raw_prem_prems _ (x as Free (_, Type _)) (X as TFree _) =
+            [([], (find_index (curry (op =) X) Xs + 1, x))]
+          | mk_raw_prem_prems names_lthy (x as Free (s, Type (T_name, Ts0))) (Type (_, Xs_Ts0)) =
+            (case AList.lookup (op =) setss_nested T_name of
+              NONE => []
+            | SOME raw_sets0 =>
+              let
+                val (Xs_Ts, (Ts, raw_sets)) =
+                  filter (exists_subtype_in Xs o fst) (Xs_Ts0 ~~ (Ts0 ~~ raw_sets0))
+                  |> split_list ||> split_list;
+                val sets = map (mk_set Ts0) raw_sets;
+                val (ys, names_lthy') = names_lthy |> mk_Frees s Ts;
+                val xysets = map (pair x) (ys ~~ sets);
+                val ppremss = map2 (mk_raw_prem_prems names_lthy') ys Xs_Ts;
+              in
+                flat (map2 (map o apfst o cons) xysets ppremss)
+              end)
+          | mk_raw_prem_prems _ _ _ = [];
 
         fun close_prem_prem xs t =
           fold_rev Logic.all (map Free (drop (nn + length xs)
@@ -632,16 +632,16 @@
               HOLogic.mk_Trueprop (HOLogic.mk_mem (y, set $ x'))) xysets,
             HOLogic.mk_Trueprop (nth ps (j - 1) $ x)));
 
-        fun mk_raw_prem phi ctr ctr_Ts =
+        fun mk_raw_prem phi ctr ctr_Ts ctrXs_Ts =
           let
             val (xs, names_lthy') = names_lthy |> mk_Frees "x" ctr_Ts;
-            val pprems = maps (mk_raw_prem_prems names_lthy') xs;
+            val pprems = flat (map2 (mk_raw_prem_prems names_lthy') xs ctrXs_Ts);
           in (xs, pprems, HOLogic.mk_Trueprop (phi $ Term.list_comb (ctr, xs))) end;
 
         fun mk_prem (xs, raw_pprems, concl) =
           fold_rev Logic.all xs (Logic.list_implies (map (mk_prem_prem xs) raw_pprems, concl));
 
-        val raw_premss = map3 (map2 o mk_raw_prem) ps ctrss ctr_Tsss;
+        val raw_premss = map4 (map3 o mk_raw_prem) ps ctrss ctr_Tsss ctrXs_Tsss;
 
         val goal =
           Library.foldr (Logic.list_implies o apfst (map mk_prem)) (raw_premss,
@@ -1072,11 +1072,11 @@
         | kk => nth Xs kk)
       | freeze_fp T = T;
 
-    val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss;
-    val ctr_sum_prod_TsXs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssXs;
+    val ctrXs_Tsss = map (map (map freeze_fp)) fake_ctr_Tsss;
+    val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
 
     val fp_eqs =
-      map dest_TFree Xs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctr_sum_prod_TsXs;
+      map dest_TFree Xs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctrXs_sum_prod_Ts;
 
     val (pre_bnfs, (fp_res as {bnfs = fp_bnfs as any_fp_bnf :: _, ctors = ctors0, dtors = dtors0,
            un_folds = fp_folds0, co_recs = fp_recs0, co_induct = fp_induct,
@@ -1088,8 +1088,8 @@
 
     val timer = time (Timer.startRealTimer ());
 
-    val nesting_bnfs = nesty_bnfs lthy ctr_TsssXs As;
-    val nested_bnfs = nesty_bnfs lthy ctr_TsssXs Xs;
+    val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
+    val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
 
     val pre_map_defs = map map_def_of_bnf pre_bnfs;
     val pre_set_defss = map set_defs_of_bnf pre_bnfs;
@@ -1124,7 +1124,7 @@
             ((qualify true fp_b_name (Binding.name thmN), attrs T_name),
              [(thms, [])])) fp_b_names fpTs thmss);
 
-    val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
+    val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctrXs_Tsss;
     val ns = map length ctr_Tsss;
     val kss = map (fn n => 1 upto n) ns;
     val mss = map (map length) ctr_Tsss;
@@ -1345,8 +1345,8 @@
         val ((induct_thm, induct_thms, induct_attrs), (fold_thmss, fold_attrs),
              (rec_thmss, rec_attrs)) =
           derive_induct_fold_rec_thms_for_types pre_bnfs fp_folds fp_recs fp_induct fp_fold_thms
-            fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs ctrss ctr_defss folds recs fold_defs
-            rec_defs lthy;
+            fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss folds recs
+            fold_defs rec_defs lthy;
 
         val induct_type_attr = Attrib.internal o K o Induct.induct_type;