support induction principles with multiple occurrences of the same type in "fpTs" and (hopefully) with loss of recursion (e.g. primrec definition of is_nil, where the IH can be dropped)
authorblanchet
Thu, 06 Jun 2013 09:17:17 +0200
changeset 52310 28063e412793
parent 52309 f71d0a604e5a
child 52311 e2f6ac15d79a
support induction principles with multiple occurrences of the same type in "fpTs" and (hopefully) with loss of recursion (e.g. primrec definition of is_nil, where the IH can be dropped)
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Thu Jun 06 08:40:37 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Thu Jun 06 09:17:17 2013 +0200
@@ -62,8 +62,8 @@
     (term * term * thm * thm) * Proof.context
   val derive_induct_fold_rec_thms_for_types: BNF_Def.bnf list -> term list -> term list -> thm ->
     thm list -> thm list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list ->
-    term list list -> thm list list -> term list -> term list -> thm list -> thm list ->
-    local_theory ->
+    typ list -> typ list list list -> term list list -> thm list list -> term list -> term list ->
+    thm list -> thm list -> local_theory ->
     (thm * thm list * Args.src list) * (thm list list * Args.src list)
     * (thm list list * Args.src list)
   val derive_coinduct_unfold_corec_thms_for_types: BNF_Def.bnf list -> term list -> term list ->
@@ -551,8 +551,8 @@
   end ;
 
 fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_folds ctor_recs ctor_induct ctor_fold_thms
-    ctor_rec_thms nesting_bnfs nested_bnfs fpTs Cs ctrss ctr_defss folds recs fold_defs rec_defs
-    lthy =
+    ctor_rec_thms nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss folds recs
+    fold_defs rec_defs lthy =
   let
     val ctr_Tsss = map (map (binder_types o fastype_of)) ctrss;
 
@@ -604,24 +604,24 @@
             Term.subst_atomic_types (Ts0 ~~ Ts) t
           end;
 
-        fun mk_raw_prem_prems names_lthy (x as Free (s, T as Type (T_name, Ts0))) =
-            (case find_index (curry (op =) T) fpTs of
-              ~1 =>
-              (case AList.lookup (op =) setss_nested T_name of
-                NONE => []
-              | SOME raw_sets0 =>
-                let
-                  val (Ts, raw_sets) =
-                    split_list (filter (exists_subtype_in fpTs o fst) (Ts0 ~~ raw_sets0));
-                  val sets = map (mk_set Ts0) raw_sets;
-                  val (ys, names_lthy') = names_lthy |> mk_Frees s Ts;
-                  val xysets = map (pair x) (ys ~~ sets);
-                  val ppremss = map (mk_raw_prem_prems names_lthy') ys;
-                in
-                  flat (map2 (map o apfst o cons) xysets ppremss)
-                end)
-            | kk => [([], (kk + 1, x))])
-          | mk_raw_prem_prems _ _ = [];
+        fun mk_raw_prem_prems _ (x as Free (_, Type _)) (X as TFree _) =
+            [([], (find_index (curry (op =) X) Xs + 1, x))]
+          | mk_raw_prem_prems names_lthy (x as Free (s, Type (T_name, Ts0))) (Type (_, Xs_Ts0)) =
+            (case AList.lookup (op =) setss_nested T_name of
+              NONE => []
+            | SOME raw_sets0 =>
+              let
+                val (Xs_Ts, (Ts, raw_sets)) =
+                  filter (exists_subtype_in Xs o fst) (Xs_Ts0 ~~ (Ts0 ~~ raw_sets0))
+                  |> split_list ||> split_list;
+                val sets = map (mk_set Ts0) raw_sets;
+                val (ys, names_lthy') = names_lthy |> mk_Frees s Ts;
+                val xysets = map (pair x) (ys ~~ sets);
+                val ppremss = map2 (mk_raw_prem_prems names_lthy') ys Xs_Ts;
+              in
+                flat (map2 (map o apfst o cons) xysets ppremss)
+              end)
+          | mk_raw_prem_prems _ _ _ = [];
 
         fun close_prem_prem xs t =
           fold_rev Logic.all (map Free (drop (nn + length xs)
@@ -632,16 +632,16 @@
               HOLogic.mk_Trueprop (HOLogic.mk_mem (y, set $ x'))) xysets,
             HOLogic.mk_Trueprop (nth ps (j - 1) $ x)));
 
-        fun mk_raw_prem phi ctr ctr_Ts =
+        fun mk_raw_prem phi ctr ctr_Ts ctrXs_Ts =
           let
             val (xs, names_lthy') = names_lthy |> mk_Frees "x" ctr_Ts;
-            val pprems = maps (mk_raw_prem_prems names_lthy') xs;
+            val pprems = flat (map2 (mk_raw_prem_prems names_lthy') xs ctrXs_Ts);
           in (xs, pprems, HOLogic.mk_Trueprop (phi $ Term.list_comb (ctr, xs))) end;
 
         fun mk_prem (xs, raw_pprems, concl) =
           fold_rev Logic.all xs (Logic.list_implies (map (mk_prem_prem xs) raw_pprems, concl));
 
-        val raw_premss = map3 (map2 o mk_raw_prem) ps ctrss ctr_Tsss;
+        val raw_premss = map4 (map3 o mk_raw_prem) ps ctrss ctr_Tsss ctrXs_Tsss;
 
         val goal =
           Library.foldr (Logic.list_implies o apfst (map mk_prem)) (raw_premss,
@@ -1072,11 +1072,11 @@
         | kk => nth Xs kk)
       | freeze_fp T = T;
 
-    val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss;
-    val ctr_sum_prod_TsXs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssXs;
+    val ctrXs_Tsss = map (map (map freeze_fp)) fake_ctr_Tsss;
+    val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
 
     val fp_eqs =
-      map dest_TFree Xs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctr_sum_prod_TsXs;
+      map dest_TFree Xs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctrXs_sum_prod_Ts;
 
     val (pre_bnfs, (fp_res as {bnfs = fp_bnfs as any_fp_bnf :: _, ctors = ctors0, dtors = dtors0,
            un_folds = fp_folds0, co_recs = fp_recs0, co_induct = fp_induct,
@@ -1088,8 +1088,8 @@
 
     val timer = time (Timer.startRealTimer ());
 
-    val nesting_bnfs = nesty_bnfs lthy ctr_TsssXs As;
-    val nested_bnfs = nesty_bnfs lthy ctr_TsssXs Xs;
+    val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
+    val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
 
     val pre_map_defs = map map_def_of_bnf pre_bnfs;
     val pre_set_defss = map set_defs_of_bnf pre_bnfs;
@@ -1124,7 +1124,7 @@
             ((qualify true fp_b_name (Binding.name thmN), attrs T_name),
              [(thms, [])])) fp_b_names fpTs thmss);
 
-    val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
+    val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctrXs_Tsss;
     val ns = map length ctr_Tsss;
     val kss = map (fn n => 1 upto n) ns;
     val mss = map (map length) ctr_Tsss;
@@ -1345,8 +1345,8 @@
         val ((induct_thm, induct_thms, induct_attrs), (fold_thmss, fold_attrs),
              (rec_thmss, rec_attrs)) =
           derive_induct_fold_rec_thms_for_types pre_bnfs fp_folds fp_recs fp_induct fp_fold_thms
-            fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs ctrss ctr_defss folds recs fold_defs
-            rec_defs lthy;
+            fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss folds recs
+            fold_defs rec_defs lthy;
 
         val induct_type_attr = Attrib.internal o K o Induct.induct_type;