tuned signature
authorblanchet
Tue, 30 Apr 2013 09:53:56 +0200
changeset 51827 836257faaad5
parent 51826 054a40461449
child 51828 67c6d6136915
tuned signature
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Apr 30 03:18:07 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Apr 30 09:53:56 2013 +0200
@@ -15,12 +15,12 @@
 
   val fp_of: Proof.context -> string -> fp option
 
-  val derive_induct_fold_rec_thms_for_types: BNF_Def.BNF list -> thm -> thm list -> thm list ->
-    BNF_Def.BNF list -> BNF_Def.BNF list -> typ list -> typ list -> typ list list list ->
-    int list list -> int list -> term list list -> term list list -> term list list -> term list
-    list list -> thm list list -> term list -> term list -> thm list -> thm list -> Proof.context ->
+  val derive_induct_fold_rec_thms_for_types: BNF_Def.BNF list -> term list -> term list -> thm ->
+    thm list -> thm list -> BNF_Def.BNF list -> BNF_Def.BNF list -> typ list -> typ list ->
+    typ list -> term list list -> thm list list -> term list -> term list -> thm list -> thm list ->
+    Proof.context ->
     (thm * thm list * Args.src list) * (thm list list * Args.src list)
-      * (thm list list * Args.src list)
+    * (thm list list * Args.src list)
   val derive_coinduct_unfold_corec_thms_for_types: Proof.context -> Proof.context ->
     BNF_Def.BNF list -> thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list ->
     BNF_Def.BNF list -> typ list -> typ list -> typ list -> int list list -> int list list ->
@@ -188,6 +188,31 @@
     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
   end;
 
+val mk_fp_rec_like_fun_types = fst o split_last o binder_types o fastype_of o hd;
+
+fun mk_fp_rec_like lfp As Cs fp_rec_likes0 =
+  map (mk_rec_like lfp As Cs) fp_rec_likes0
+  |> (fn ts => (ts, mk_fp_rec_like_fun_types ts));
+
+fun mk_rec_like_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
+
+fun project_recT fpTs proj =
+  let
+    fun project (Type (s as @{type_name prod}, Ts as [T, U])) =
+        if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts)
+      | project (Type (s, Ts)) = Type (s, map project Ts)
+      | project T = T;
+  in project end;
+
+fun unzip_recT fpTs T =
+  if exists_subtype_in fpTs T then ([project_recT fpTs fst T], [project_recT fpTs snd T])
+  else ([T], []);
+
+fun massage_rec_fun_arg_typesss fpTs = map (map (flat_rec (unzip_recT fpTs)));
+
+val mk_fold_fun_typess = map2 (map2 (curry (op --->)));
+val mk_rec_fun_typess = mk_fold_fun_typess oo massage_rec_fun_arg_typesss;
+
 fun mk_map live Ts Us t =
   let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
@@ -243,11 +268,16 @@
     val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
   in Term.list_comb (rel, map build_arg Ts') end;
 
-fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_induct ctor_fold_thms ctor_rec_thms
-    nesting_bnfs nested_bnfs fpTs Cs ctr_Tsss mss ns gss hss ctrss xsss ctr_defss folds recs
-    fold_defs rec_defs lthy =
+fun derive_induct_fold_rec_thms_for_types pre_bnfs ctor_folds0 ctor_recs0 ctor_induct ctor_fold_thms
+    ctor_rec_thms nesting_bnfs nested_bnfs fpTs Cs As ctrss ctr_defss folds recs fold_defs rec_defs
+    lthy =
   let
+    val ctr_Tsss = map (map (binder_types o fastype_of)) ctrss;
+
     val nn = length pre_bnfs;
+    val ns = map length ctr_Tsss;
+    val mss = map (map length) ctr_Tsss;
+    val Css = map2 replicate ns Cs;
 
     val pre_map_defs = map map_def_of_bnf pre_bnfs;
     val pre_set_defss = map set_defs_of_bnf pre_bnfs;
@@ -258,11 +288,23 @@
 
     val fp_b_names = map base_name_of_typ fpTs;
 
-    val (((ps, ps'), us'), names_lthy) =
+    val (_, ctor_fold_fun_Ts) = mk_fp_rec_like true As Cs ctor_folds0;
+    val (_, ctor_rec_fun_Ts) = mk_fp_rec_like true As Cs ctor_recs0;
+
+    val y_Tsss = map3 mk_rec_like_fun_arg_typess ns mss ctor_fold_fun_Ts;
+    val g_Tss = mk_fold_fun_typess y_Tsss Css;
+
+    val z_Tsss = map3 mk_rec_like_fun_arg_typess ns mss ctor_rec_fun_Ts;
+    val h_Tss = mk_rec_fun_typess fpTs z_Tsss Css;
+
+    val (((((ps, ps'), xsss), gss), us'), names_lthy) =
       lthy
       |> mk_Frees' "P" (map mk_pred1T fpTs)
+      ||>> mk_Freesss "x" ctr_Tsss
+      ||>> mk_Freess "f" g_Tss
       ||>> Variable.variant_fixes fp_b_names;
 
+    val hss = map2 (map2 retype_free) h_Tss gss;
     val us = map2 (curry Free) us' fpTs;
 
     fun mk_sets_nested bnf =
@@ -831,40 +873,24 @@
     val mss = map (map length) ctr_Tsss;
     val Css = map2 replicate ns Cs;
 
-    val fp_folds as any_fp_fold :: _ = map (mk_rec_like lfp As Cs) fp_folds0;
-    val fp_recs as any_fp_rec :: _ = map (mk_rec_like lfp As Cs) fp_recs0;
+    val (fp_folds, fp_fold_fun_Ts) = mk_fp_rec_like lfp As Cs fp_folds0;
+    val (fp_recs, fp_rec_fun_Ts) = mk_fp_rec_like lfp As Cs fp_recs0;
 
-    val fp_fold_fun_Ts = fst (split_last (binder_types (fastype_of any_fp_fold)));
-    val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of any_fp_rec)));
-
-    val (((fold_only as (gss, _, _), rec_only as (hss, _, _)),
+    val (((fold_only, rec_only),
           (cs, cpss, unfold_only as ((pgss, crssss, cgssss), (_, g_Tsss, _)),
            corec_only as ((phss, csssss, chssss), (_, h_Tsss, _)))), names_lthy0) =
       if lfp then
         let
-          val y_Tsss =
-            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
-              ns mss fp_fold_fun_Ts;
-          val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
+          val y_Tsss = map3 mk_rec_like_fun_arg_typess ns mss fp_fold_fun_Ts;
+          val g_Tss = mk_fold_fun_typess y_Tsss Css;
 
           val ((gss, ysss), lthy) =
             lthy
             |> mk_Freess "f" g_Tss
             ||>> mk_Freesss "x" y_Tsss;
 
-          fun proj_recT proj (Type (s as @{type_name prod}, Ts as [T, U])) =
-              if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_recT proj) Ts)
-            | proj_recT proj (Type (s, Ts)) = Type (s, map (proj_recT proj) Ts)
-            | proj_recT _ T = T;
-
-          fun unzip_recT T =
-            if exists_subtype_in fpTs T then ([proj_recT fst T], [proj_recT snd T]) else ([T], []);
-
-          val z_Tsss =
-            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
-              ns mss fp_rec_fun_Ts;
-          val z_Tsss' = map (map (flat_rec unzip_recT)) z_Tsss;
-          val h_Tss = map2 (map2 (curry (op --->))) z_Tsss' Css;
+          val z_Tsss = map3 mk_rec_like_fun_arg_typess ns mss fp_rec_fun_Ts;
+          val h_Tss = mk_rec_fun_typess fpTs z_Tsss Css;
 
           val hss = map2 (map2 retype_free) h_Tss gss;
           val zsss = map2 (map2 (map2 retype_free)) z_Tsss ysss;
@@ -1252,14 +1278,14 @@
         injects @ distincts @ case_thms @ rec_likes @ fold_likes);
 
     fun derive_and_note_induct_fold_rec_thms_for_types
-        (((ctrss, xsss, ctr_defss, ctr_wrap_ress), (folds, recs, fold_defs, rec_defs)), lthy) =
+        (((ctrss, _, ctr_defss, ctr_wrap_ress), (folds, recs, fold_defs, rec_defs)), lthy) =
       let
         val ((induct_thm, induct_thms, induct_attrs),
              (fold_thmss, fold_attrs),
              (rec_thmss, rec_attrs)) =
-          derive_induct_fold_rec_thms_for_types pre_bnfs fp_induct fp_fold_thms fp_rec_thms
-            nesting_bnfs nested_bnfs fpTs Cs ctr_Tsss mss ns gss hss ctrss xsss ctr_defss folds recs
-            fold_defs rec_defs lthy;
+          derive_induct_fold_rec_thms_for_types pre_bnfs fp_folds0 fp_recs0 fp_induct fp_fold_thms
+            fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs As ctrss ctr_defss folds recs fold_defs
+            rec_defs lthy;
 
         fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));