src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
changeset 52315 fafab8eac3ee
parent 52314 9606cf677021
child 52317 7132de305921
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Thu Jun 06 11:47:11 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Thu Jun 06 12:16:35 2013 +0200
@@ -32,11 +32,10 @@
   val mk_co_iter: theory -> BNF_FP_Util.fp_kind -> typ -> typ list -> term -> term
   val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
   val mk_un_fold_co_rec_prelims: BNF_FP_Util.fp_kind -> typ list -> typ list -> int list ->
-    int list list -> term list -> term list -> Proof.context ->
-    (term list * term list
-       * ((typ list list * typ list list list list * term list list * term list list list list)
-          * (typ list list * typ list list list list * term list list
-             * term list list list list)) option
+    int list list -> term list list -> Proof.context ->
+    (term list list
+     * (typ list list * typ list list list list * term list list
+        * term list list list list) list option
      * (term list * term list list
         * ((term list list * term list list list list * term list list list list)
            * (typ list * typ list list list * typ list list))
@@ -49,8 +48,7 @@
   val mk_iter_fun_arg_typessss: typ list -> int list -> int list list -> term ->
     typ list list list list
   val define_fold_rec:
-    (typ list list * typ list list list list * term list list * term list list list list)
-     * (typ list list * typ list list list list * term list list * term list list list list) ->
+    (typ list list * typ list list list list * term list list * term list list list list) list ->
     (string -> binding) -> typ list -> typ list -> term -> term -> Proof.context ->
     (term * term * thm * thm) * Proof.context
   val define_unfold_corec: term list * term list list
@@ -60,10 +58,11 @@
          * (typ list * typ list list list * typ list list)) ->
     (string -> binding) -> typ list -> typ list -> term -> term -> Proof.context ->
     (term * term * thm * thm) * Proof.context
-  val derive_induct_fold_rec_thms_for_types: BNF_Def.bnf list -> term list list -> thm ->
-    thm list list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list -> typ list ->
-    typ list list list -> term list list -> thm list list -> term list list -> thm list list ->
-    local_theory ->
+  val derive_induct_fold_rec_thms_for_types: BNF_Def.bnf list -> term list list ->
+    (typ list list * typ list list list list * term list list * term list list list list) list ->
+    thm -> thm list list -> BNF_Def.bnf list -> BNF_Def.bnf list -> typ list -> typ list ->
+    typ list -> typ list list list -> term list list -> thm list list -> term list list ->
+    thm list list -> local_theory ->
     (thm * thm list * Args.src list) * (thm list list * Args.src list)
     * (thm list list * Args.src list)
   val derive_coinduct_unfold_corec_thms_for_types: BNF_Def.bnf list -> term list -> term list ->
@@ -275,7 +274,7 @@
   #> map3 mk_fun_arg_typess ns mss
   #> map (map (map (unzip_recT Cs)));
 
-fun mk_fold_rec_args_types Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy =
+fun mk_fold_rec_args_types Cs ns mss [ctor_fold_fun_Ts, ctor_rec_fun_Ts] lthy =
   let
     val Css = map2 replicate ns Cs;
     val y_Tsss = map3 mk_fun_arg_typess ns mss ctor_fold_fun_Ts;
@@ -303,10 +302,10 @@
       |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
     val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
   in
-    (((g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)), lthy)
+    ([(g_Tss, y_Tssss, gss, yssss), (h_Tss, z_Tssss, hss, zssss)], lthy)
   end;
 
-fun mk_unfold_corec_args_types Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy =
+fun mk_unfold_corec_args_types Cs ns mss [dtor_unfold_fun_Ts, dtor_corec_fun_Ts] lthy =
   let
     (*avoid "'a itself" arguments in coiterators and corecursors*)
     fun repair_arity [0] = [1]
@@ -361,7 +360,7 @@
     ((cs, cpss, (unfold_args, unfold_types), (corec_args, corec_types)), lthy)
   end;
 
-fun mk_un_fold_co_rec_prelims fp fpTs Cs ns mss xtor_un_folds0 xtor_co_recs0 lthy =
+fun mk_un_fold_co_rec_prelims fp fpTs Cs ns mss [xtor_un_folds0, xtor_co_recs0] lthy =
   let
     val thy = Proof_Context.theory_of lthy;
 
@@ -372,13 +371,13 @@
 
     val ((fold_rec_args_types, unfold_corec_args_types), lthy') =
       if fp = Least_FP then
-        mk_fold_rec_args_types Cs ns mss xtor_un_fold_fun_Ts xtor_co_rec_fun_Ts lthy
+        mk_fold_rec_args_types Cs ns mss [xtor_un_fold_fun_Ts, xtor_co_rec_fun_Ts] lthy
         |>> (rpair NONE o SOME)
       else
-        mk_unfold_corec_args_types Cs ns mss xtor_un_fold_fun_Ts xtor_co_rec_fun_Ts lthy
+        mk_unfold_corec_args_types Cs ns mss [xtor_un_fold_fun_Ts, xtor_co_rec_fun_Ts] lthy
         |>> (pair NONE o SOME)
   in
-    ((xtor_un_folds, xtor_co_recs, fold_rec_args_types, unfold_corec_args_types), lthy')
+    (([xtor_un_folds, xtor_co_recs], fold_rec_args_types, unfold_corec_args_types), lthy')
   end;
 
 fun mk_map live Ts Us t =
@@ -477,7 +476,7 @@
     Term.list_comb (dtor_coiter, map4 mk_preds_getterss_join cs cpss f_sum_prod_Ts cqfsss)
   end;
 
-fun define_fold_rec (fold_only, rec_only) mk_binding fpTs Cs ctor_fold ctor_rec lthy0 =
+fun define_fold_rec [fold_args_types, rec_args_types] mk_binding fpTs Cs ctor_fold ctor_rec lthy0 =
   let
     val thy = Proof_Context.theory_of lthy0;
 
@@ -495,7 +494,7 @@
       in (b, spec) end;
 
     val binding_specs =
-      map generate_iter [(foldN, ctor_fold, fold_only), (recN, ctor_rec, rec_only)];
+      map generate_iter [(foldN, ctor_fold, fold_args_types), (recN, ctor_rec, rec_args_types)];
 
     val ((csts, defs), (lthy', lthy)) = lthy0
       |> apfst split_list o fold_map (fn (b, spec) =>
@@ -513,8 +512,8 @@
   end;
 
 (* TODO: merge with above function? *)
-fun define_unfold_corec (cs, cpss, unfold_only, corec_only) mk_binding fpTs Cs dtor_unfold
-    dtor_corec lthy0 =
+fun define_unfold_corec (cs, cpss, unfold_args_types, corec_args_types) mk_binding fpTs Cs
+    dtor_unfold dtor_corec lthy0 =
   let
     val thy = Proof_Context.theory_of lthy0;
 
@@ -533,7 +532,8 @@
       in (b, spec) end;
 
     val binding_specs =
-      map generate_coiter [(unfoldN, dtor_unfold, unfold_only), (corecN, dtor_corec, corec_only)];
+      map generate_coiter [(unfoldN, dtor_unfold, unfold_args_types),
+        (corecN, dtor_corec, corec_args_types)];
 
     val ((csts, defs), (lthy', lthy)) = lthy0
       |> apfst split_list o fold_map (fn (b, spec) =>
@@ -550,9 +550,9 @@
     ((unfold, corec, unfold_def, corec_def), lthy')
   end ;
 
-fun derive_induct_fold_rec_thms_for_types pre_bnfs [ctor_folds, ctor_recs] ctor_induct
-    [ctor_fold_thms, ctor_rec_thms] nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
-    [folds, recs] [fold_defs, rec_defs] lthy =
+fun derive_induct_fold_rec_thms_for_types pre_bnfs [ctor_folds, ctor_recs]
+    [fold_args_types, rec_args_types] ctor_induct [ctor_fold_thms, ctor_rec_thms] nesting_bnfs
+    nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss [folds, recs] [fold_defs, rec_defs] lthy =
   let
     val ctr_Tsss = map (map (binder_types o fastype_of)) ctrss;
 
@@ -571,11 +571,8 @@
     val ctor_fold_fun_Ts = mk_fp_iter_fun_types (hd ctor_folds);
     val ctor_rec_fun_Ts = mk_fp_iter_fun_types (hd ctor_recs);
 
-    val ((fold_only, rec_only), names_lthy0) =
-      mk_fold_rec_args_types Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy;
-
     val ((((ps, ps'), xsss), us'), names_lthy) =
-      names_lthy0
+      lthy
       |> mk_Frees' "P" (map mk_pred1T fpTs)
       ||>> mk_Freesss "x" ctr_Tsss
       ||>> Variable.variant_fixes fp_b_names;
@@ -700,8 +697,8 @@
         map2 (map2 prove) goalss tacss
       end;
 
-    val fold_thmss = mk_iter_thmss fold_only folds fold_defs ctor_fold_thms;
-    val rec_thmss = mk_iter_thmss rec_only recs rec_defs ctor_rec_thms;
+    val fold_thmss = mk_iter_thmss fold_args_types folds fold_defs ctor_fold_thms;
+    val rec_thmss = mk_iter_thmss rec_args_types recs rec_defs ctor_rec_thms;
   in
     ((induct_thm, induct_thms, [induct_case_names_attr]),
      (fold_thmss, code_simp_attrs), (rec_thmss, code_simp_attrs))
@@ -735,7 +732,7 @@
     val sel_thmsss = map #sel_thmss ctr_sugars;
 
     val ((cs, cpss, ((pgss, crssss, cgssss), _), ((phss, csssss, chssss), _)), names_lthy0) =
-      mk_unfold_corec_args_types Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts lthy;
+      mk_unfold_corec_args_types Cs ns mss [dtor_unfold_fun_Ts, dtor_corec_fun_Ts] lthy;
 
     val (((rs, us'), vs'), names_lthy) =
       names_lthy0
@@ -1128,8 +1125,8 @@
     val kss = map (fn n => 1 upto n) ns;
     val mss = map (map length) ctr_Tsss;
 
-    val ((xtor_un_folds, xtor_co_recs, fold_rec_args_types, unfold_corec_args_types), lthy) =
-      mk_un_fold_co_rec_prelims fp fpTs Cs ns mss xtor_un_folds0 xtor_co_recs0 lthy;
+    val (([xtor_un_folds, xtor_co_recs], fold_rec_args_types, unfold_corec_args_types), lthy) =
+      mk_un_fold_co_rec_prelims fp fpTs Cs ns mss [xtor_un_folds0, xtor_co_recs0] lthy;
 
     fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
               xtor_un_fold), xtor_co_rec), ctor_dtor), dtor_ctor), ctor_inject), pre_map_def),
@@ -1347,8 +1344,9 @@
         val ((induct_thm, induct_thms, induct_attrs), (fold_thmss, fold_attrs),
              (rec_thmss, rec_attrs)) =
           derive_induct_fold_rec_thms_for_types pre_bnfs [xtor_un_folds, xtor_co_recs]
-            xtor_co_induct [xtor_un_fold_thms, xtor_co_rec_thms] nesting_bnfs nested_bnfs fpTs Cs Xs
-            ctrXs_Tsss ctrss ctr_defss [folds, recs] [fold_defs, rec_defs] lthy;
+            (the fold_rec_args_types) xtor_co_induct [xtor_un_fold_thms, xtor_co_rec_thms]
+            nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss [folds, recs]
+            [fold_defs, rec_defs] lthy;
 
         val induct_type_attr = Attrib.internal o K o Induct.induct_type;