fourth attempt at generalizing N2M types (to leverage caching)
authorblanchet
Wed, 06 Nov 2013 22:42:54 +0100
changeset 54283 6f0a49ed1bb1
parent 54282 32b5c4821d9d
child 54284 0b53378080d9
fourth attempt at generalizing N2M types (to leverage caching)
src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Wed Nov 06 21:40:41 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML	Wed Nov 06 22:42:54 2013 +0100
@@ -264,14 +264,20 @@
 fun indexify_callsss fp_sugar callsss =
   let
     val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
-    fun do_ctr ctr =
+    fun indexify_ctr ctr =
       (case AList.lookup Term.aconv_untyped callsss ctr of
         NONE => replicate (num_binder_types (fastype_of ctr)) []
       | SOME callss => map (map (Envir.beta_eta_contract o unfold_let)) callss);
   in
-    map do_ctr ctrs
+    map indexify_ctr ctrs
   end;
 
+fun retypargs tyargs (Type (s, _)) = Type (s, tyargs);
+
+fun fold_subtype_pairs f (T as Type (s, Ts), U as Type (s', Us)) =
+    f (T, U) #> (if s = s' then fold (fold_subtype_pairs f) (Ts ~~ Us) else I)
+  | fold_subtype_pairs f TU = f TU;
+
 fun nested_to_mutual_fps fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
   let
     val qsoty = quote o Syntax.string_of_typ lthy;
@@ -292,23 +298,70 @@
     val perm_actual_Ts as Type (_, tyargs0) :: _ =
       sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts;
 
+    fun the_ctrs_of (Type (s, Ts)) = map (mk_ctr Ts) (#ctrs (the (ctr_sugar_of lthy s)));
+
     fun the_fp_sugar_of (T as Type (T_name, _)) =
       (case fp_sugar_of lthy T_name of
         SOME (fp_sugar as {fp = fp', ...}) => if fp = fp' then fp_sugar else not_co_datatype T
       | NONE => not_co_datatype T);
 
-    fun check_enrich_with_mutuals _ [] = []
-      | check_enrich_with_mutuals seen ((T as Type (_, tyargs)) :: Ts) =
+    fun gen_rhss_in gen_Ts rho subTs =
+      let
+        fun maybe_insert (T, Type (_, gen_tyargs)) =
+            if member (op =) subTs T then insert (op =) gen_tyargs else I
+          | maybe_insert _ = I;
+
+        val ctrs = maps the_ctrs_of gen_Ts;
+        val gen_ctr_Ts = maps (binder_types o fastype_of) ctrs;
+        val ctr_Ts = map (Term.typ_subst_atomic rho) gen_ctr_Ts;
+      in
+        fold (fold_subtype_pairs maybe_insert) (ctr_Ts ~~ gen_ctr_Ts) []
+      end;
+
+    fun check_enrich_with_mutuals _ _ seen gen_seen [] = (seen, gen_seen)
+      | check_enrich_with_mutuals lthy rho seen gen_seen ((T as Type (_, tyargs)) :: Ts) =
         let
-          val {fp_res = {Ts = Ts', ...}, ...} = the_fp_sugar_of T
-          val mutual_Ts = map (fn Type (s, _) => Type (s, tyargs)) Ts';
-          val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts;
+          val {fp_res = {Ts = mutual_Ts0, ...}, ...} = the_fp_sugar_of T;
+          val mutual_Ts = map (retypargs tyargs) mutual_Ts0;
+
+          fun fresh_tyargs () =
+            let
+              (* The name "'z" is unlikely to clash with the context, yielding more cache hits. *)
+              val (gen_tyargs, lthy') =
+                variant_tfrees (replicate (length tyargs) "z") lthy
+                |>> map Logic.varifyT_global;
+              val rho' = (gen_tyargs ~~ tyargs) @ rho;
+            in
+              (rho', gen_tyargs, gen_seen, lthy')
+            end;
+
+          val (rho', gen_tyargs, gen_seen', lthy') =
+            if exists (exists_subtype_in seen) mutual_Ts then
+              (case gen_rhss_in gen_seen rho mutual_Ts of
+                [] => fresh_tyargs ()
+              | [gen_tyargs] => (rho, gen_tyargs, gen_seen, lthy)
+              | gen_tyargss as gen_tyargs :: gen_tyargss_tl =>
+                let
+                  val unify_pairs = split_list (maps (curry (op ~~) gen_tyargs) gen_tyargss_tl);
+                  val mgu = Type.raw_unifys unify_pairs Vartab.empty;
+                  val gen_tyargs' = map (Envir.subst_type mgu) gen_tyargs;
+                  val gen_seen' = map (Envir.subst_type mgu) gen_seen;
+                in
+                  (rho, gen_tyargs', gen_seen', lthy)
+                end)
+            else
+              fresh_tyargs ();
+
+          val gen_mutual_Ts = map (retypargs gen_tyargs) mutual_Ts0;
+          val Ts' = filter_out (member (op =) mutual_Ts) Ts;
         in
-          mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts'
+          check_enrich_with_mutuals lthy' rho' (seen @ mutual_Ts) (gen_seen' @ gen_mutual_Ts) Ts'
         end
-      | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T;
+      | check_enrich_with_mutuals _ _ _ _ (T :: _) = not_co_datatype T;
 
-    val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts;
+    val (perm_Ts, perm_gen_Ts) = check_enrich_with_mutuals lthy [] [] [] perm_actual_Ts;
+    val perm_frozen_gen_Ts = map Logic.unvarifyT_global perm_gen_Ts;
+
     val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
     val Ts = actual_Ts @ missing_Ts;
 
@@ -334,7 +387,7 @@
     val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
 
     val ((perm_fp_sugars, fp_sugar_thms), lthy) =
-      mutualize_fp_sugars has_nested fp perm_bs perm_Ts get_perm_indices perm_callssss
+      mutualize_fp_sugars has_nested fp perm_bs perm_frozen_gen_Ts get_perm_indices perm_callssss
         perm_fp_sugars0 lthy;
 
     val fp_sugars = unpermute perm_fp_sugars;