signature tuning
authorblanchet
Tue, 30 Apr 2013 10:58:25 +0200
changeset 51829 3cc93eeac8cc
parent 51828 67c6d6136915
child 51830 403f7ecd061f
signature tuning
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Apr 30 10:07:41 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue Apr 30 10:58:25 2013 +0200
@@ -21,14 +21,11 @@
     Proof.context ->
     (thm * thm list * Args.src list) * (thm list list * Args.src list)
     * (thm list list * Args.src list)
-  val derive_coinduct_unfold_corec_thms_for_types: Proof.context -> Proof.context ->
-    BNF_Def.BNF list -> thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list ->
-    BNF_Def.BNF list -> typ list -> typ list -> typ list -> int list list -> int list list ->
-    int list -> term list -> term list list -> term list list -> term list list list list ->
-    term list list list list -> term list list -> term list list list list ->
-    term list list list list -> term list list -> thm list list ->
-    BNF_Ctr_Sugar.ctr_wrap_result list -> term list -> term list -> thm list -> thm list ->
-    Proof.context ->
+  val derive_coinduct_unfold_corec_thms_for_types: BNF_Def.BNF list -> term list -> term list ->
+    thm -> thm -> thm list -> thm list -> thm list -> BNF_Def.BNF list -> BNF_Def.BNF list ->
+    typ list -> typ list -> typ list -> int list list -> int list list -> int list ->
+    term list list -> thm list list -> BNF_Ctr_Sugar.ctr_wrap_result list -> term list ->
+    term list -> thm list -> thm list -> Proof.context ->
     (thm * thm list * thm * thm list * Args.src list) * (thm list list * thm list list * 'e list)
     * (thm list list * thm list list) * (thm list list * thm list list * Args.src list)
     * (thm list list * thm list list * Args.src list)
@@ -158,6 +155,12 @@
     maps fst ps @ maps snd ps
   end;
 
+fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
+
+fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss
+  | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
+    p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
+
 fun mk_flip (x, Type (_, [T1, Type (_, [T2, T3])])) =
   Abs ("x", T1, Abs ("y", T2, Var (x, T2 --> T1 --> T3) $ Bound 0 $ Bound 1));
 
@@ -196,23 +199,86 @@
 
 fun mk_rec_like_fun_arg_typess n ms = map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type;
 
-fun project_recT fpTs proj =
+fun massage_rec_fun_arg_typesss fpTs =
   let
-    fun project (Type (s as @{type_name prod}, Ts as [T, U])) =
-        if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts)
-      | project (Type (s, Ts)) = Type (s, map project Ts)
-      | project T = T;
-  in project end;
-
-fun unzip_recT fpTs T =
-  if exists_subtype_in fpTs T then ([project_recT fpTs fst T], [project_recT fpTs snd T])
-  else ([T], []);
-
-fun massage_rec_fun_arg_typesss fpTs = map (map (flat_rec (unzip_recT fpTs)));
+    fun project_recT proj =
+      let
+        fun project (Type (s as @{type_name prod}, Ts as [T, U])) =
+            if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts)
+          | project (Type (s, Ts)) = Type (s, map project Ts)
+          | project T = T;
+      in project end;
+    fun unzip_recT T =
+      if exists_subtype_in fpTs T then ([project_recT fst T], [project_recT snd T]) else ([T], []);
+  in
+    map (map (flat_rec unzip_recT))
+  end;
 
 val mk_fold_fun_typess = map2 (map2 (curry (op --->)));
 val mk_rec_fun_typess = mk_fold_fun_typess oo massage_rec_fun_arg_typesss;
 
+fun mk_corec_like_pred_types n = replicate (Int.max (0, n - 1)) o mk_pred1T;
+
+fun mk_unfold_corec_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts =
+  let
+    (*avoid "'a itself" arguments in coiterators and corecursors*)
+    fun repair_arity [0] = [1]
+      | repair_arity ms = ms;
+
+    fun project_corecT proj =
+      let
+        fun project (Type (s as @{type_name sum}, Ts as [T, U])) =
+            if member (op =) fpTs T then proj (T, U) else Type (s, map project Ts)
+          | project (Type (s, Ts)) = Type (s, map project Ts)
+          | project T = T;
+      in project end;
+
+    fun unzip_corecT T =
+      if exists_subtype_in fpTs T then [project_corecT fst T, project_corecT snd T] else [T];
+
+    val p_Tss = map2 mk_corec_like_pred_types ns Cs;
+
+    fun mk_types maybe_unzipT fun_Ts =
+      let
+        val f_sum_prod_Ts = map range_type fun_Ts;
+        val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
+        val f_Tsss = map2 (map2 dest_tupleT o repair_arity) mss f_prod_Tss;
+        val f_Tssss =
+          map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
+        val q_Tssss =
+          map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
+        val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
+      in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end;
+  in
+    (p_Tss, mk_types single dtor_unfold_fun_Ts, mk_types unzip_corecT dtor_corec_fun_Ts)
+  end
+
+fun mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss lthy =
+  let
+    val (((cs, pss), gssss), lthy) =
+      lthy
+      |> mk_Frees "a" Cs
+      ||>> mk_Freess "p" p_Tss
+      ||>> mk_Freessss "g" g_Tssss;
+    val rssss = map (map (map (fn [] => []))) r_Tssss;
+
+    val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
+    val ((sssss, hssss_tl), lthy) =
+      lthy
+      |> mk_Freessss "q" s_Tssss
+      ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
+    val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
+  in
+    ((cs, pss, (gssss, rssss), (hssss, sssss)), lthy)
+  end;
+
+fun mk_corec_like_terms cs pss qssss fssss =
+  let
+    val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
+    val cqssss = map2 (map o map o map o rapp) cs qssss;
+    val cfssss = map2 (map o map o map o rapp) cs fssss;
+  in (pfss, cqssss, cfssss) end;
+
 fun mk_map live Ts Us t =
   let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
     Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
@@ -440,10 +506,9 @@
      (fold_thmss, code_simp_attrs), (rec_thmss, code_simp_attrs))
   end;
 
-fun derive_coinduct_unfold_corec_thms_for_types names_lthy0 no_defs_lthy pre_bnfs dtor_coinduct
+fun derive_coinduct_unfold_corec_thms_for_types pre_bnfs dtor_unfolds0 dtor_corecs0 dtor_coinduct
     dtor_strong_induct dtor_ctors dtor_unfold_thms dtor_corec_thms nesting_bnfs nested_bnfs fpTs Cs
-    As kss mss ns cs cpss pgss crssss cgssss phss csssss chssss ctrss ctr_defss ctr_wrap_ress
-    unfolds corecs unfold_defs corec_defs lthy =
+    As kss mss ns ctrss ctr_defss ctr_wrap_ress unfolds corecs unfold_defs corec_defs lthy =
   let
     val nn = length pre_bnfs;
 
@@ -457,6 +522,9 @@
 
     val fp_b_names = map base_name_of_typ fpTs;
 
+    val (_, dtor_unfold_fun_Ts) = mk_fp_rec_like false As Cs dtor_unfolds0;
+    val (_, dtor_corec_fun_Ts) = mk_fp_rec_like false As Cs dtor_corecs0;
+
     val discss = map (map (mk_disc_or_sel As) o #discs) ctr_wrap_ress;
     val selsss = map (map (map (mk_disc_or_sel As)) o #selss) ctr_wrap_ress;
     val exhausts = map #exhaust ctr_wrap_ress;
@@ -470,6 +538,15 @@
       ||>> Variable.variant_fixes fp_b_names
       ||>> Variable.variant_fixes (map (suffix "'") fp_b_names);
 
+    val (p_Tss, (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss),
+         (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss)) =
+      mk_unfold_corec_types fpTs Cs ns mss dtor_unfold_fun_Ts dtor_corec_fun_Ts;
+
+    val ((cs, pss, (gssss, rssss), (hssss, sssss)), names_lthy) =
+      mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss names_lthy;
+
+    val cpss = map2 (map o rapp) cs pss;
+
     val us = map2 (curry Free) us' fpTs;
     val udiscss = map2 (map o rapp) us discss;
     val uselsss = map2 (map o map o rapp) us selsss;
@@ -478,6 +555,9 @@
     val vdiscss = map2 (map o rapp) vs discss;
     val vselsss = map2 (map o map o rapp) vs selsss;
 
+    val (pgss, crssss, cgssss) = mk_corec_like_terms cs pss rssss gssss;
+    val (phss, csssss, chssss) = mk_corec_like_terms cs pss sssss hssss;
+
     val ((coinduct_thms, coinduct_thm), (strong_coinduct_thms, strong_coinduct_thm)) =
       let
         val uvrs = map3 (fn r => fn u => fn v => r $ u $ v) rs us vs;
@@ -652,7 +732,7 @@
 
         fun prove goal tac =
           Goal.prove_sorry lthy [] [] goal (tac o #context)
-          |> singleton (Proof_Context.export names_lthy0 no_defs_lthy)
+          |> singleton (Proof_Context.export names_lthy lthy)
           |> Thm.close_derivation;
 
         fun proves [_] [_] = []
@@ -894,68 +974,18 @@
         end
       else
         let
-          (*avoid "'a itself" arguments in coiterators and corecursors*)
-          val mss' =  map (fn [0] => [1] | ms => ms) mss;
-
-          val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_pred1T) ns Cs;
-
-          fun flat_predss_getterss qss fss = maps (op @) (qss ~~ fss);
-
-          fun flat_preds_predsss_gettersss [] [qss] [fss] = flat_predss_getterss qss fss
-            | flat_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
-              p :: flat_predss_getterss qss fss @ flat_preds_predsss_gettersss ps qsss fsss;
-
-          fun mk_types maybe_unzipT fun_Ts =
-            let
-              val f_sum_prod_Ts = map range_type fun_Ts;
-              val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
-              val f_Tsss = map2 (map2 dest_tupleT) mss' f_prod_Tss;
-              val f_Tssss =
-                map2 (fn C => map (map (map (curry (op -->) C) o maybe_unzipT))) Cs f_Tsss;
-              val q_Tssss =
-                map (map (map (fn [_] => [] | [_, C] => [mk_pred1T (domain_type C)]))) f_Tssss;
-              val pf_Tss = map3 flat_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
-            in (q_Tssss, f_sum_prod_Ts, f_Tsss, f_Tssss, pf_Tss) end;
-
-          val (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss) = mk_types single fp_fold_fun_Ts;
+          val (p_Tss, (r_Tssss, g_sum_prod_Ts, g_Tsss, g_Tssss, pg_Tss),
+               (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss)) =
+            mk_unfold_corec_types fpTs Cs ns mss fp_fold_fun_Ts fp_rec_fun_Ts;
 
-          val (((cs, pss), gssss), lthy) =
-            lthy
-            |> mk_Frees "a" Cs
-            ||>> mk_Freess "p" p_Tss
-            ||>> mk_Freessss "g" g_Tssss;
-          val rssss = map (map (map (fn [] => []))) r_Tssss;
-
-          fun proj_corecT proj (Type (s as @{type_name sum}, Ts as [T, U])) =
-              if member (op =) fpTs T then proj (T, U) else Type (s, map (proj_corecT proj) Ts)
-            | proj_corecT proj (Type (s, Ts)) = Type (s, map (proj_corecT proj) Ts)
-            | proj_corecT _ T = T;
-
-          fun unzip_corecT T =
-            if exists_subtype_in fpTs T then [proj_corecT fst T, proj_corecT snd T] else [T];
-
-          val (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss) =
-            mk_types unzip_corecT fp_rec_fun_Ts;
-
-          val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
-          val ((sssss, hssss_tl), lthy) =
-            lthy
-            |> mk_Freessss "q" s_Tssss
-            ||>> mk_Freessss "h" (map (map (map tl)) h_Tssss);
-          val hssss = map2 (map2 (map2 cons)) hssss_hd hssss_tl;
+          val ((cs, pss, (gssss, rssss), (hssss, sssss)), lthy) =
+            mk_unfold_corec_vars Cs p_Tss g_Tssss r_Tssss h_Tssss s_Tssss lthy;
 
           val cpss = map2 (map o rapp) cs pss;
-
-          fun mk_terms qssss fssss =
-            let
-              val pfss = map3 flat_preds_predsss_gettersss pss qssss fssss;
-              val cqssss = map2 (map o map o map o rapp) cs qssss;
-              val cfssss = map2 (map o map o map o rapp) cs fssss;
-            in (pfss, cqssss, cfssss) end;
         in
           (((([], [], []), ([], [], [])),
-            (cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, g_Tsss, pg_Tss)),
-             (mk_terms sssss hssss, (h_sum_prod_Ts, h_Tsss, ph_Tss)))), lthy)
+            (cs, cpss, (mk_corec_like_terms cs pss rssss gssss, (g_sum_prod_Ts, g_Tsss, pg_Tss)),
+             (mk_corec_like_terms cs pss sssss hssss, (h_sum_prod_Ts, h_Tsss, ph_Tss)))), lthy)
         end;
 
     fun define_ctrs_case_for_type (((((((((((((((((((((((((fp_bnf, fp_b), fpT), C), ctor), dtor),
@@ -1311,10 +1341,9 @@
              (disc_unfold_thmss, disc_corec_thmss, disc_corec_like_attrs),
              (disc_unfold_iff_thmss, disc_corec_iff_thmss, disc_corec_like_iff_attrs),
              (sel_unfold_thmss, sel_corec_thmss, sel_corec_like_attrs)) =
-          derive_coinduct_unfold_corec_thms_for_types names_lthy0 no_defs_lthy pre_bnfs fp_induct
+          derive_coinduct_unfold_corec_thms_for_types pre_bnfs fp_folds0 fp_recs0 fp_induct
             fp_strong_induct dtor_ctors fp_fold_thms fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs As
-            kss mss ns cs cpss pgss crssss cgssss phss csssss chssss ctrss ctr_defss ctr_wrap_ress
-            unfolds corecs unfold_defs corec_defs lthy;
+            kss mss ns ctrss ctr_defss ctr_wrap_ress unfolds corecs unfold_defs corec_defs lthy;
 
         fun coinduct_type_attr T_name = Attrib.internal (K (Induct.coinduct_type T_name));