more general, reliable N2M
authorblanchet
Tue, 22 Mar 2016 07:57:02 +0100
changeset 62688 a3cccaef566a
parent 62687 1c4842b32bfb
child 62689 9b8b3db6ac03
more general, reliable N2M
src/Benchmarks/Datatype_Benchmark/Misc_N2M.thy
src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML
--- a/src/Benchmarks/Datatype_Benchmark/Misc_N2M.thy	Tue Mar 22 07:57:02 2016 +0100
+++ b/src/Benchmarks/Datatype_Benchmark/Misc_N2M.thy	Tue Mar 22 07:57:02 2016 +0100
@@ -102,7 +102,7 @@
 codatatype 'a M = CM "('a, 'a M) M0"
 codatatype 'a N = CN "('a N, 'a) N0"
 codatatype ('a, 'b) K = CK "('a, ('a, 'b) L) K0"
-         and ('a, 'b) L = CL "('b, ('a, 'b) K) L0"
+       and ('a, 'b) L = CL "('b, ('a, 'b) K) L0"
 codatatype 'a G = CG "('a, ('a G, 'a G N) K, ('a G M, 'a G) L) G0"
 
 primcorec
@@ -144,12 +144,6 @@
 datatype_compat ttttree
 *)
 
-datatype ('a,'b)complex = 
-  C1 nat "'a ttree" 
-| C2 "('a,'b)complex list tree tree" 'b "('a,'b)complex" "('a,'b)complex2 ttree list"
-and ('a,'b)complex2 = D1 "('a,'b)complex ttree"
-datatype_compat complex complex2
-
 datatype 'a F = F1 'a | F2 "'a F"
 datatype 'a G = G1 'a | G2 "'a G F"
 datatype H = H1 | H2 "H G"
@@ -167,7 +161,6 @@
 context linorder
 begin
 
-(* slow *)
 primrec
   f1_tl :: "(nat, 'a) t l \<Rightarrow> nat" and
   f1_t :: "(nat, 'a) t \<Rightarrow> nat"
@@ -176,7 +169,7 @@
   "f1_tl (C t ts) = f1_t t + f1_tl ts" |
   "f1_t (T n _ ts) = n + f1_tl ts"
 
-(* should be fast *)
+(* should hit cache *)
 primrec
   f2_t :: "(nat, 'b) t \<Rightarrow> nat" and
   f2_tl :: "(nat, 'b) t l \<Rightarrow> nat"
@@ -187,7 +180,7 @@
 
 end
 
-(* should be fast *)
+(* should hit cache *)
 primrec
   g1_t :: "('a, int) t \<Rightarrow> nat" and
   g1_tl :: "('a, int) t l \<Rightarrow> nat"
@@ -196,7 +189,7 @@
   "g1_tl N = 0" |
   "g1_tl (C _ ts) = g1_tl ts"
 
-(* should be fast *)
+(* should hit cache *)
 primrec
   g2_t :: "(int, int) t \<Rightarrow> nat" and
   g2_tl :: "(int, int) t l \<Rightarrow> nat"
@@ -223,7 +216,6 @@
   ('a, 'b) t1 = T1 'a 'b "('a, 'b) t1 l1" and
   ('a, 'b) t2 = T2 "('a, 'b) t1"
 
-(* slow *)
 primrec
   h1_tl1 :: "(nat, 'a) t1 l1 \<Rightarrow> nat" and
   h1_tl2 :: "(nat, 'a) t1 l2 \<Rightarrow> nat" and
@@ -235,7 +227,7 @@
   "h1_tl2 (C2 t ts) = h1_t1 t + h1_tl1 ts" |
   "h1_t1 (T1 n _ ts) = n + h1_tl1 ts"
 
-(* should be fast *)
+(* should hit cache *)
 primrec
   h2_tl1 :: "(nat, 'a) t1 l1 \<Rightarrow> nat" and
   h2_tl2 :: "(nat, 'a) t1 l2 \<Rightarrow> nat" and
@@ -247,7 +239,7 @@
   "h2_tl2 (C2 t ts) = h2_t1 t + h2_tl1 ts" |
   "h2_t1 (T1 n _ ts) = n + h2_tl1 ts"
 
-(* should be fast *)
+(* should hit cache *)
 primrec
   h3_tl2 :: "(nat, 'a) t1 l2 \<Rightarrow> nat" and
   h3_tl1 :: "(nat, 'a) t1 l1 \<Rightarrow> nat" and
@@ -259,7 +251,7 @@
   "h3_tl2 (C2 t ts) = h3_t1 t + h3_tl1 ts" |
   "h3_t1 (T1 n _ ts) = n + h3_tl1 ts"
 
-(* should be fast *)
+(* should hit cache *)
 primrec
   i1_tl2 :: "(nat, 'a) t1 l2 \<Rightarrow> nat" and
   i1_tl1 :: "(nat, 'a) t1 l1 \<Rightarrow> nat" and
@@ -273,7 +265,7 @@
   "i1_t1 (T1 n _ ts) = n + i1_tl1 ts" |
   "i1_t2 (T2 t) = i1_t1 t"
 
-(* should be fast *)
+(* should hit cache *)
 primrec
   j1_t2 :: "(nat, 'a) t2 \<Rightarrow> nat" and
   j1_t1 :: "(nat, 'a) t1 \<Rightarrow> nat" and
@@ -292,7 +284,6 @@
 datatype 'a l4 = N4 | C4 'a "'a l4"
 datatype ('a, 'b) t3 = T3 'a 'b "('a, 'b) t3 l3" "('a, 'b) t3 l4"
 
-(* slow *)
 primrec
   k1_tl3 :: "(nat, 'a) t3 l3 \<Rightarrow> nat" and
   k1_tl4 :: "(nat, 'a) t3 l4 \<Rightarrow> nat" and
@@ -304,7 +295,7 @@
   "k1_tl4 (C4 t ts) = k1_t3 t + k1_tl4 ts" |
   "k1_t3 (T3 n _ ts us) = n + k1_tl3 ts + k1_tl4 us"
 
-(* should be fast *)
+(* should hit cache *)
 primrec
   k2_tl4 :: "(nat, int) t3 l4 \<Rightarrow> nat" and
   k2_tl3 :: "(nat, int) t3 l3 \<Rightarrow> nat" and
@@ -321,7 +312,6 @@
 datatype ('a, 'b) l6 = N6 | C6 'a 'b "('a, 'b) l6"
 datatype ('a, 'b, 'c) t4 = T4 'a 'b "(('a, 'b, 'c) t4, 'b) l5" "(('a, 'b, 'c) t4, 'c) l6"
 
-(* slow *)
 primrec
   l1_tl5 :: "((nat, 'a, 'b) t4, 'a) l5 \<Rightarrow> nat" and
   l1_tl6 :: "((nat, 'a, 'b) t4, 'b) l6 \<Rightarrow> nat" and
@@ -342,7 +332,6 @@
 context linorder
 begin
 
-(* slow *)
 primcorec
   f1_cotcol :: "nat \<Rightarrow> (nat, 'a) cot col" and
   f1_cot :: "nat \<Rightarrow> (nat, 'a) cot"
@@ -351,7 +340,7 @@
   "_ \<Longrightarrow> f1_cotcol n = C (f1_cot n) (f1_cotcol n)" |
   "f1_cot n = T n undefined (f1_cotcol n)"
 
-(* should be fast *)
+(* should hit cache *)
 primcorec
   f2_cotcol :: "nat \<Rightarrow> (nat, 'b) cot col" and
   f2_cot :: "nat \<Rightarrow> (nat, 'b) cot"
@@ -362,7 +351,7 @@
 
 end
 
-(* should be fast *)
+(* should hit cache *)
 primcorec
   g1_cot :: "int \<Rightarrow> (int, 'a) cot" and
   g1_cotcol :: "int \<Rightarrow> (int, 'a) cot col"
@@ -371,7 +360,7 @@
   "n = 0 \<Longrightarrow> g1_cotcol n = N" |
   "_ \<Longrightarrow> g1_cotcol n = C (g1_cot n) (g1_cotcol n)"
 
-(* should be fast *)
+(* should hit cache *)
 primcorec
   g2_cot :: "int \<Rightarrow> (int, int) cot" and
   g2_cotcol :: "int \<Rightarrow> (int, int) cot col"
@@ -389,7 +378,6 @@
   ('a, 'b) cot1 = T1 'a 'b "('a, 'b) cot1 col1" and
   ('a, 'b) cot2 = T2 "('a, 'b) cot1"
 
-(* slow *)
 primcorec
   h1_cotcol1 :: "nat \<Rightarrow> (nat, 'a) cot1 col1" and
   h1_cotcol2 :: "nat \<Rightarrow> (nat, 'a) cot1 col2" and
@@ -399,7 +387,7 @@
   "h1_cotcol2 n = C2 (h1_cot1 n) (h1_cotcol1 n)" |
   "h1_cot1 n = T1 n undefined (h1_cotcol1 n)"
 
-(* should be fast *)
+(* should hit cache *)
 primcorec
   h2_cotcol1 :: "nat \<Rightarrow> (nat, 'a) cot1 col1" and
   h2_cotcol2 :: "nat \<Rightarrow> (nat, 'a) cot1 col2" and
@@ -409,7 +397,7 @@
   "h2_cotcol2 n = C2 (h2_cot1 n) (h2_cotcol1 n)" |
   "h2_cot1 n = T1 n undefined (h2_cotcol1 n)"
 
-(* should be fast *)
+(* should hit cache *)
 primcorec
   h3_cotcol2 :: "nat \<Rightarrow> (nat, 'a) cot1 col2" and
   h3_cotcol1 :: "nat \<Rightarrow> (nat, 'a) cot1 col1" and
@@ -419,7 +407,7 @@
   "h3_cotcol2 n = C2 (h3_cot1 n) (h3_cotcol1 n)" |
   "h3_cot1 n = T1 n undefined (h3_cotcol1 n)"
 
-(* should be fast *)
+(* should hit cache *)
 primcorec
   i1_cotcol2 :: "nat \<Rightarrow> (nat, 'a) cot1 col2" and
   i1_cotcol1 :: "nat \<Rightarrow> (nat, 'a) cot1 col1" and
@@ -431,7 +419,7 @@
   "i1_cot1 n = T1 n undefined (i1_cotcol1 n)" |
   "i1_cot2 n = T2 (i1_cot1 n)"
 
-(* should be fast *)
+(* should hit cache *)
 primcorec
   j1_cot2 :: "nat \<Rightarrow> (nat, 'a) cot2" and
   j1_cot1 :: "nat \<Rightarrow> (nat, 'a) cot1" and
@@ -448,7 +436,6 @@
 codatatype 'a col4 = N4 | C4 'a "'a col4"
 codatatype ('a, 'b) cot3 = T3 'a 'b "('a, 'b) cot3 col3" "('a, 'b) cot3 col4"
 
-(* slow *)
 primcorec
   k1_cotcol3 :: "nat \<Rightarrow> (nat, 'a) cot3 col3" and
   k1_cotcol4 :: "nat \<Rightarrow> (nat, 'a) cot3 col4" and
@@ -458,7 +445,7 @@
   "k1_cotcol4 n = C4 (k1_cot3 n) (k1_cotcol4 n)" |
   "k1_cot3 n = T3 n undefined (k1_cotcol3 n) (k1_cotcol4 n)"
 
-(* should be fast *)
+(* should hit cache *)
 primcorec
   k2_cotcol4 :: "nat \<Rightarrow> (nat, 'a) cot3 col4" and
   k2_cotcol3 :: "nat \<Rightarrow> (nat, 'a) cot3 col3" and
@@ -468,4 +455,10 @@
   "k2_cotcol3 n = C3 (k2_cot3 n) (k2_cotcol3 n)" |
   "k2_cot3 n = T3 n undefined (k2_cotcol3 n) (k2_cotcol4 n)"
 
+datatype ('a,'b)complex =
+  C1 nat "'a ttree"
+| C2 "('a,'b)complex list tree tree" 'b "('a,'b)complex" "('a,'b)complex2 ttree list"
+and ('a,'b)complex2 = D1 "('a,'b)complex ttree"
+datatype_compat complex complex2
+
 end
--- a/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Tue Mar 22 07:57:02 2016 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Tue Mar 22 07:57:02 2016 +0100
@@ -400,44 +400,38 @@
         else
           not_co_datatype0 T
       | not_co_datatype T = not_co_datatype0 T;
-    fun not_mutually_nested_rec Ts1 Ts2 =
-      error (qsotys Ts1 ^ " is neither mutually " ^ co_prefix fp ^ "recursive with " ^ qsotys Ts2 ^
-        " nor nested " ^ co_prefix fp ^ "recursive through " ^
-        (if Ts1 = Ts2 andalso length Ts1 = 1 then "itself" else qsotys Ts2));
 
     val sorted_actual_Ts =
       sort (prod_ord int_ord Term_Ord.typ_ord o apply2 (`Term.size_of_typ)) actual_Ts;
 
     fun the_ctrs_of (Type (s, Ts)) = map (mk_ctr Ts) (#ctrs (the (ctr_sugar_of lthy s)));
 
+    fun gen_rhss_in gen_Ts rho (subTs as Type (_, sub_tyargs) :: _) =
+      let
+        fun maybe_insert (T, Type (_, gen_tyargs)) =
+            member (op =) subTs T ? insert (op =) gen_tyargs
+          | maybe_insert _ = I;
+
+        val gen_ctrs = maps the_ctrs_of gen_Ts;
+        val gen_ctr_Ts = maps (binder_types o fastype_of) gen_ctrs;
+        val ctr_Ts = map (Term.typ_subst_atomic rho) gen_ctr_Ts;
+      in
+        (case fold (fold_subtype_pairs maybe_insert) (ctr_Ts ~~ gen_ctr_Ts) [] of
+          [] => [map (typ_subst_nonatomic (map swap rho)) sub_tyargs]
+        | gen_tyargss => gen_tyargss)
+      end;
+
     fun the_fp_sugar_of (T as Type (T_name, _)) =
       (case fp_sugar_of lthy T_name of
         SOME (fp_sugar as {fp = fp', ...}) => if fp = fp' then fp_sugar else not_co_datatype T
       | NONE => not_co_datatype T);
 
-    fun gen_rhss_in gen_Ts rho subTs =
-      let
-        fun maybe_insert (T, Type (_, gen_tyargs)) =
-            if member (op =) subTs T then insert (op =) gen_tyargs else I
-          | maybe_insert _ = I;
-
-        val ctrs = maps the_ctrs_of gen_Ts;
-        val gen_ctr_Ts = maps (binder_types o fastype_of) ctrs;
-        val ctr_Ts = map (Term.typ_subst_atomic rho) gen_ctr_Ts;
-      in
-        fold (fold_subtype_pairs maybe_insert) (ctr_Ts ~~ gen_ctr_Ts) []
-      end;
-
     fun gather_types _ _ rev_seens gen_seen [] = (rev rev_seens, gen_seen)
       | gather_types lthy rho rev_seens gen_seen ((T as Type (_, tyargs)) :: Ts) =
         let
           val {T = Type (_, tyargs0), fp_res = {Ts = mutual_Ts0, ...}, ...} = the_fp_sugar_of T;
           val mutual_Ts = map (retypargs tyargs) mutual_Ts0;
 
-          val rev_seen = flat rev_seens;
-          val _ = null rev_seens orelse exists (exists_strict_subtype_in rev_seen) mutual_Ts orelse
-            not_mutually_nested_rec mutual_Ts rev_seen;
-
           fun fresh_tyargs () =
             let
               val (unsorted_gen_tyargs, lthy') =
@@ -451,7 +445,7 @@
             end;
 
           val (rho', gen_tyargs, gen_seen', lthy') =
-            if exists (exists_subtype_in rev_seen) mutual_Ts then
+            if exists (exists_subtype_in (flat rev_seens)) mutual_Ts then
               (case gen_rhss_in gen_seen rho mutual_Ts of
                 [] => fresh_tyargs ()
               | gen_tyargs :: gen_tyargss_tl =>