properly detect when to perform n2m -- e.g. handle the case of two independent functions on irrelevant types being defined in parallel
authorblanchet
Thu, 07 Nov 2013 00:37:18 +0100
changeset 54286 22616f65d4ea
parent 54285 578371ba74cc
child 54287 7f096d8eb3d0
properly detect when to perform n2m -- e.g. handle the case of two independent functions on irrelevant types being defined in parallel
src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
src/HOL/BNF/Tools/bnf_lfp_compat.ML
--- a/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Wed Nov 06 23:05:44 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Thu Nov 07 00:37:18 2013 +0100
@@ -10,9 +10,8 @@
   val unfold_let: term -> term
   val dest_map: Proof.context -> string -> term -> term * term list
 
-  val mutualize_fp_sugars: bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
-    (term -> int list) -> term list list list list -> BNF_FP_Def_Sugar.fp_sugar list ->
-    local_theory ->
+  val mutualize_fp_sugars: BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
+    term list list list list -> BNF_FP_Def_Sugar.fp_sugar list -> local_theory ->
     (BNF_FP_Def_Sugar.fp_sugar list
      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
     * local_theory
@@ -109,157 +108,150 @@
   Type (fp_case fp "l" "g", fpTs @ maps (fn (x, T) => [TFree x, T]) fp_eqs);
 
 (* TODO: test with sort constraints on As *)
-(* TODO: use right sorting order for "fp_sort" w.r.t. original BNFs (?) -- treat new variables
-   as deads? *)
-fun mutualize_fp_sugars has_nested fp bs fpTs get_indices callssss fp_sugars0 no_defs_lthy0 =
-  if has_nested then
-    let
-      val thy = Proof_Context.theory_of no_defs_lthy0;
+fun mutualize_fp_sugars fp bs fpTs get_indices callssss fp_sugars0 no_defs_lthy0 =
+  let
+    val thy = Proof_Context.theory_of no_defs_lthy0;
 
-      val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
+    val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
 
-      fun incompatible_calls t1 t2 =
-        error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ qsotm t1 ^ " vs. " ^
-          qsotm t2);
+    fun incompatible_calls t1 t2 =
+      error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2);
 
-      val b_names = map Binding.name_of bs;
-      val fp_b_names = map base_name_of_typ fpTs;
+    val b_names = map Binding.name_of bs;
+    val fp_b_names = map base_name_of_typ fpTs;
 
-      val nn = length fpTs;
+    val nn = length fpTs;
 
-      fun target_ctr_sugar_of_fp_sugar fpT ({T, index, ctr_sugars, ...} : fp_sugar) =
-        let
-          val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T, fpT) Vartab.empty) [];
-          val phi = Morphism.term_morphism (Term.subst_TVars rho);
-        in
-          morph_ctr_sugar phi (nth ctr_sugars index)
-        end;
+    fun target_ctr_sugar_of_fp_sugar fpT ({T, index, ctr_sugars, ...} : fp_sugar) =
+      let
+        val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T, fpT) Vartab.empty) [];
+        val phi = Morphism.term_morphism (Term.subst_TVars rho);
+      in
+        morph_ctr_sugar phi (nth ctr_sugars index)
+      end;
 
-      val ctr_defss = map (of_fp_sugar #ctr_defss) fp_sugars0;
-      val mapss = map (of_fp_sugar #mapss) fp_sugars0;
-      val ctr_sugars0 = map2 target_ctr_sugar_of_fp_sugar fpTs fp_sugars0;
+    val ctr_defss = map (of_fp_sugar #ctr_defss) fp_sugars0;
+    val mapss = map (of_fp_sugar #mapss) fp_sugars0;
+    val ctr_sugars0 = map2 target_ctr_sugar_of_fp_sugar fpTs fp_sugars0;
 
-      val ctrss = map #ctrs ctr_sugars0;
-      val ctr_Tss = map (map fastype_of) ctrss;
+    val ctrss = map #ctrs ctr_sugars0;
+    val ctr_Tss = map (map fastype_of) ctrss;
 
-      val As' = fold (fold Term.add_tfreesT) ctr_Tss [];
-      val As = map TFree As';
+    val As' = fold (fold Term.add_tfreesT) ctr_Tss [];
+    val As = map TFree As';
 
-      val ((Cs, Xs), no_defs_lthy) =
-        no_defs_lthy0
-        |> fold Variable.declare_typ As
-        |> mk_TFrees nn
-        ||>> variant_tfrees fp_b_names;
+    val ((Cs, Xs), no_defs_lthy) =
+      no_defs_lthy0
+      |> fold Variable.declare_typ As
+      |> mk_TFrees nn
+      ||>> variant_tfrees fp_b_names;
 
-      fun check_call_dead live_call call =
-        if null (get_indices call) then () else incompatible_calls live_call call;
+    fun check_call_dead live_call call =
+      if null (get_indices call) then () else incompatible_calls live_call call;
 
-      fun freeze_fpTs_simple (T as Type (s, Ts)) =
-          (case find_index (curry (op =) T) fpTs of
-            ~1 => Type (s, map freeze_fpTs_simple Ts)
-          | kk => nth Xs kk)
-        | freeze_fpTs_simple T = T;
+    fun freeze_fpTs_simple (T as Type (s, Ts)) =
+        (case find_index (curry (op =) T) fpTs of
+          ~1 => Type (s, map freeze_fpTs_simple Ts)
+        | kk => nth Xs kk)
+      | freeze_fpTs_simple T = T;
 
-      fun freeze_fpTs_map (callss, (live_call :: _, dead_calls)) s Ts =
-        (List.app (check_call_dead live_call) dead_calls;
-         Type (s, map2 freeze_fpTs (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
-           (transpose callss)) Ts))
-      and freeze_fpTs calls (T as Type (s, Ts)) =
-          (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
-            ([], _) =>
-            (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
-              ([], _) => freeze_fpTs_simple T
-            | callsp => freeze_fpTs_map callsp s Ts)
+    fun freeze_fpTs_map (callss, (live_call :: _, dead_calls)) s Ts =
+      (List.app (check_call_dead live_call) dead_calls;
+       Type (s, map2 freeze_fpTs (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
+         (transpose callss)) Ts))
+    and freeze_fpTs calls (T as Type (s, Ts)) =
+        (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
+          ([], _) =>
+          (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
+            ([], _) => freeze_fpTs_simple T
           | callsp => freeze_fpTs_map callsp s Ts)
-        | freeze_fpTs _ T = T;
+        | callsp => freeze_fpTs_map callsp s Ts)
+      | freeze_fpTs _ T = T;
 
-      val ctr_Tsss = map (map binder_types) ctr_Tss;
-      val ctrXs_Tsss = map2 (map2 (map2 freeze_fpTs)) callssss ctr_Tsss;
-      val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
-      val Ts = map (body_type o hd) ctr_Tss;
+    val ctr_Tsss = map (map binder_types) ctr_Tss;
+    val ctrXs_Tsss = map2 (map2 (map2 freeze_fpTs)) callssss ctr_Tsss;
+    val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
+    val Ts = map (body_type o hd) ctr_Tss;
 
-      val ns = map length ctr_Tsss;
-      val kss = map (fn n => 1 upto n) ns;
-      val mss = map (map length) ctr_Tsss;
+    val ns = map length ctr_Tsss;
+    val kss = map (fn n => 1 upto n) ns;
+    val mss = map (map length) ctr_Tsss;
 
-      val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;
-      val key = key_of_fp_eqs fp fpTs fp_eqs;
-    in
-      (case n2m_sugar_of no_defs_lthy key of
-        SOME n2m_sugar => (n2m_sugar, no_defs_lthy)
-      | NONE =>
-        let
-          val base_fp_names = Name.variant_list [] fp_b_names;
-          val fp_bs = map2 (fn b_name => fn base_fp_name =>
-              Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
-            b_names base_fp_names;
+    val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;
+    val key = key_of_fp_eqs fp fpTs fp_eqs;
+  in
+    (case n2m_sugar_of no_defs_lthy key of
+      SOME n2m_sugar => (n2m_sugar, no_defs_lthy)
+    | NONE =>
+      let
+        val base_fp_names = Name.variant_list [] fp_b_names;
+        val fp_bs = map2 (fn b_name => fn base_fp_name =>
+            Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
+          b_names base_fp_names;
 
-          val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct,
-                 dtor_injects, dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
-            fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
+        val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct, dtor_injects,
+               dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
+          fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
 
-          val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
-          val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
+        val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
+        val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
 
-          val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
-            mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
+        val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
+          mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
 
-          fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
+        fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
 
-          val ((co_iterss, co_iter_defss), lthy) =
-            fold_map2 (fn b =>
-              (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
-               else define_coiters [unfoldN, corecN] (the coiters_args_types))
-                (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
-            |>> split_list;
+        val ((co_iterss, co_iter_defss), lthy) =
+          fold_map2 (fn b =>
+            (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
+             else define_coiters [unfoldN, corecN] (the coiters_args_types))
+              (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
+          |>> split_list;
 
-          val rho = tvar_subst thy Ts fpTs;
-          val ctr_sugar_phi =
-            Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
-              (Morphism.term_morphism (Term.subst_TVars rho));
-          val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
+        val rho = tvar_subst thy Ts fpTs;
+        val ctr_sugar_phi = Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
+            (Morphism.term_morphism (Term.subst_TVars rho));
+        val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
 
-          val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
+        val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
 
-          val ((co_inducts, un_fold_thmss, co_rec_thmss, disc_unfold_thmss, disc_corec_thmss,
-                sel_unfold_thmsss, sel_corec_thmsss), fp_sugar_thms) =
-            if fp = Least_FP then
-              derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
-                xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
-                co_iterss co_iter_defss lthy
-              |> `(fn ((_, induct, _), (fold_thmss, rec_thmss, _)) =>
-                ([induct], fold_thmss, rec_thmss, [], [], [], []))
-              ||> (fn info => (SOME info, NONE))
-            else
-              derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types)
-                xtor_co_induct dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs Xs
-                ctrXs_Tsss kss mss ns ctr_defss ctr_sugars co_iterss co_iter_defss
-                (Proof_Context.export lthy no_defs_lthy) lthy
-              |> `(fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _),
-                      (disc_unfold_thmss, disc_corec_thmss, _), _,
-                      (sel_unfold_thmsss, sel_corec_thmsss, _)) =>
-                (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss, disc_unfold_thmss,
-                 disc_corec_thmss, sel_unfold_thmsss, sel_corec_thmsss))
-              ||> (fn info => (NONE, SOME info));
+        val ((co_inducts, un_fold_thmss, co_rec_thmss, disc_unfold_thmss, disc_corec_thmss,
+              sel_unfold_thmsss, sel_corec_thmsss), fp_sugar_thms) =
+          if fp = Least_FP then
+            derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
+              xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
+              co_iterss co_iter_defss lthy
+            |> `(fn ((_, induct, _), (fold_thmss, rec_thmss, _)) =>
+              ([induct], fold_thmss, rec_thmss, [], [], [], []))
+            ||> (fn info => (SOME info, NONE))
+          else
+            derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types) xtor_co_induct
+              dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs Xs ctrXs_Tsss kss mss
+              ns ctr_defss ctr_sugars co_iterss co_iter_defss
+              (Proof_Context.export lthy no_defs_lthy) lthy
+            |> `(fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _),
+                    (disc_unfold_thmss, disc_corec_thmss, _), _,
+                    (sel_unfold_thmsss, sel_corec_thmsss, _)) =>
+              (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss, disc_unfold_thmss,
+               disc_corec_thmss, sel_unfold_thmsss, sel_corec_thmsss))
+            ||> (fn info => (NONE, SOME info));
 
-          val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
+        val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
 
-          fun mk_target_fp_sugar (kk, T) =
-            {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
-             nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
-             ctr_sugars = ctr_sugars, co_iterss = co_iterss, mapss = mapss, co_inducts = co_inducts,
-             co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss],
-             disc_co_itersss = transpose [disc_unfold_thmss, disc_corec_thmss],
-             sel_co_iterssss = transpose [sel_unfold_thmsss, sel_corec_thmsss]}
-            |> morph_fp_sugar phi;
+        fun mk_target_fp_sugar (kk, T) =
+          {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
+           nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
+           ctr_sugars = ctr_sugars, co_iterss = co_iterss, mapss = mapss, co_inducts = co_inducts,
+           co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss],
+           disc_co_itersss = transpose [disc_unfold_thmss, disc_corec_thmss],
+           sel_co_iterssss = transpose [sel_unfold_thmsss, sel_corec_thmsss]}
+          |> morph_fp_sugar phi;
 
-          val n2m_sugar = (map_index mk_target_fp_sugar fpTs, fp_sugar_thms);
-        in
-          (n2m_sugar, lthy |> register_n2m_sugar key n2m_sugar)
-        end)
-    end
-  else
-    ((fp_sugars0, (NONE, NONE)), no_defs_lthy0);
+        val n2m_sugar = (map_index mk_target_fp_sugar fpTs, fp_sugar_thms);
+      in
+        (n2m_sugar, lthy |> register_n2m_sugar key n2m_sugar)
+      end)
+  end;
 
 fun indexify_callsss fp_sugar callsss =
   let
@@ -295,7 +287,7 @@
 
     val _ = (case Library.duplicates (op =) actual_Ts of [] => () | T :: _ => duplicate_datatype T);
 
-    val perm_actual_Ts as Type (_, tyargs0) :: _ =
+    val perm_actual_Ts =
       sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts;
 
     fun the_ctrs_of (Type (s, Ts)) = map (mk_ctr Ts) (#ctrs (the (ctr_sugar_of lthy s)));
@@ -318,8 +310,8 @@
         fold (fold_subtype_pairs maybe_insert) (ctr_Ts ~~ gen_ctr_Ts) []
       end;
 
-    fun check_enrich_with_mutuals _ _ seen gen_seen [] = (seen, gen_seen)
-      | check_enrich_with_mutuals lthy rho seen gen_seen ((T as Type (_, tyargs)) :: Ts) =
+    fun gather_types _ _ num_groups seen gen_seen [] = (num_groups, seen, gen_seen)
+      | gather_types lthy rho num_groups seen gen_seen ((T as Type (_, tyargs)) :: Ts) =
         let
           val {fp_res = {Ts = mutual_Ts0, ...}, ...} = the_fp_sugar_of T;
           val mutual_Ts = map (retypargs tyargs) mutual_Ts0;
@@ -354,11 +346,12 @@
           val gen_mutual_Ts = map (retypargs gen_tyargs) mutual_Ts0;
           val Ts' = filter_out (member (op =) mutual_Ts) Ts;
         in
-          check_enrich_with_mutuals lthy' rho' (seen @ mutual_Ts) (gen_seen' @ gen_mutual_Ts) Ts'
+          gather_types lthy' rho' (num_groups + 1) (seen @ mutual_Ts) (gen_seen' @ gen_mutual_Ts)
+            Ts'
         end
-      | check_enrich_with_mutuals _ _ _ _ (T :: _) = not_co_datatype T;
+      | gather_types _ _ _ _ _ (T :: _) = not_co_datatype T;
 
-    val (perm_Ts, perm_gen_Ts) = check_enrich_with_mutuals lthy [] [] [] perm_actual_Ts;
+    val (num_groups, perm_Ts, perm_gen_Ts) = gather_types lthy [] 0 [] [] perm_actual_Ts;
     val perm_frozen_gen_Ts = map Logic.unvarifyT_global perm_gen_Ts;
 
     val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
@@ -380,14 +373,16 @@
     val perm_callssss0 = permute callssss0;
     val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;
 
-    val has_nested = exists (fn Type (_, tyargs) => tyargs <> tyargs0) Ts;
     val perm_callssss = map2 indexify_callsss perm_fp_sugars0 perm_callssss0;
 
     val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
 
     val ((perm_fp_sugars, fp_sugar_thms), lthy) =
-      mutualize_fp_sugars has_nested fp perm_bs perm_frozen_gen_Ts get_perm_indices perm_callssss
-        perm_fp_sugars0 lthy;
+      if num_groups > 1 then
+        mutualize_fp_sugars fp perm_bs perm_frozen_gen_Ts get_perm_indices perm_callssss
+          perm_fp_sugars0 lthy
+      else
+        ((perm_fp_sugars0, (NONE, NONE)), lthy);
 
     val fp_sugars = unpermute perm_fp_sugars;
   in
--- a/src/HOL/BNF/Tools/bnf_lfp_compat.ML	Wed Nov 06 23:05:44 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_lfp_compat.ML	Thu Nov 07 00:37:18 2013 +0100
@@ -94,10 +94,12 @@
     val get_indices = K [];
     val fp_sugars0 = if nn = 1 then [fp_sugar0] else map (lfp_sugar_of o fst o dest_Type) Ts;
     val callssss = map (fn fp_sugar0 => indexify_callsss fp_sugar0 []) fp_sugars0;
-    val has_nested = nn > nn_fp;
 
     val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
-      mutualize_fp_sugars has_nested Least_FP compat_bs Ts get_indices callssss fp_sugars0 lthy;
+      if nn > nn_fp then
+        mutualize_fp_sugars Least_FP compat_bs Ts get_indices callssss fp_sugars0 lthy
+      else
+        ((fp_sugars0, (NONE, NONE)), lthy);
 
     val {ctr_sugars, co_inducts = [induct], co_iterss, co_iter_thmsss = iter_thmsss, ...} :: _ =
       fp_sugars;