--- a/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML Wed Nov 06 21:40:41 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML Wed Nov 06 22:42:54 2013 +0100
@@ -264,14 +264,20 @@
fun indexify_callsss fp_sugar callsss =
let
val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
- fun do_ctr ctr =
+ fun indexify_ctr ctr =
(case AList.lookup Term.aconv_untyped callsss ctr of
NONE => replicate (num_binder_types (fastype_of ctr)) []
| SOME callss => map (map (Envir.beta_eta_contract o unfold_let)) callss);
in
- map do_ctr ctrs
+ map indexify_ctr ctrs
end;
+fun retypargs tyargs (Type (s, _)) = Type (s, tyargs);
+
+fun fold_subtype_pairs f (T as Type (s, Ts), U as Type (s', Us)) =
+ f (T, U) #> (if s = s' then fold (fold_subtype_pairs f) (Ts ~~ Us) else I)
+ | fold_subtype_pairs f TU = f TU;
+
fun nested_to_mutual_fps fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
let
val qsoty = quote o Syntax.string_of_typ lthy;
@@ -292,23 +298,70 @@
val perm_actual_Ts as Type (_, tyargs0) :: _ =
sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts;
+ fun the_ctrs_of (Type (s, Ts)) = map (mk_ctr Ts) (#ctrs (the (ctr_sugar_of lthy s)));
+
fun the_fp_sugar_of (T as Type (T_name, _)) =
(case fp_sugar_of lthy T_name of
SOME (fp_sugar as {fp = fp', ...}) => if fp = fp' then fp_sugar else not_co_datatype T
| NONE => not_co_datatype T);
- fun check_enrich_with_mutuals _ [] = []
- | check_enrich_with_mutuals seen ((T as Type (_, tyargs)) :: Ts) =
+ fun gen_rhss_in gen_Ts rho subTs =
+ let
+ fun maybe_insert (T, Type (_, gen_tyargs)) =
+ if member (op =) subTs T then insert (op =) gen_tyargs else I
+ | maybe_insert _ = I;
+
+ val ctrs = maps the_ctrs_of gen_Ts;
+ val gen_ctr_Ts = maps (binder_types o fastype_of) ctrs;
+ val ctr_Ts = map (Term.typ_subst_atomic rho) gen_ctr_Ts;
+ in
+ fold (fold_subtype_pairs maybe_insert) (ctr_Ts ~~ gen_ctr_Ts) []
+ end;
+
+ fun check_enrich_with_mutuals _ _ seen gen_seen [] = (seen, gen_seen)
+ | check_enrich_with_mutuals lthy rho seen gen_seen ((T as Type (_, tyargs)) :: Ts) =
let
- val {fp_res = {Ts = Ts', ...}, ...} = the_fp_sugar_of T
- val mutual_Ts = map (fn Type (s, _) => Type (s, tyargs)) Ts';
- val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts;
+ val {fp_res = {Ts = mutual_Ts0, ...}, ...} = the_fp_sugar_of T;
+ val mutual_Ts = map (retypargs tyargs) mutual_Ts0;
+
+ fun fresh_tyargs () =
+ let
+ (* The name "'z" is unlikely to clash with the context, yielding more cache hits. *)
+ val (gen_tyargs, lthy') =
+ variant_tfrees (replicate (length tyargs) "z") lthy
+ |>> map Logic.varifyT_global;
+ val rho' = (gen_tyargs ~~ tyargs) @ rho;
+ in
+ (rho', gen_tyargs, gen_seen, lthy')
+ end;
+
+ val (rho', gen_tyargs, gen_seen', lthy') =
+ if exists (exists_subtype_in seen) mutual_Ts then
+ (case gen_rhss_in gen_seen rho mutual_Ts of
+ [] => fresh_tyargs ()
+ | [gen_tyargs] => (rho, gen_tyargs, gen_seen, lthy)
+ | gen_tyargss as gen_tyargs :: gen_tyargss_tl =>
+ let
+ val unify_pairs = split_list (maps (curry (op ~~) gen_tyargs) gen_tyargss_tl);
+ val mgu = Type.raw_unifys unify_pairs Vartab.empty;
+ val gen_tyargs' = map (Envir.subst_type mgu) gen_tyargs;
+ val gen_seen' = map (Envir.subst_type mgu) gen_seen;
+ in
+ (rho, gen_tyargs', gen_seen', lthy)
+ end)
+ else
+ fresh_tyargs ();
+
+ val gen_mutual_Ts = map (retypargs gen_tyargs) mutual_Ts0;
+ val Ts' = filter_out (member (op =) mutual_Ts) Ts;
in
- mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts'
+ check_enrich_with_mutuals lthy' rho' (seen @ mutual_Ts) (gen_seen' @ gen_mutual_Ts) Ts'
end
- | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T;
+ | check_enrich_with_mutuals _ _ _ _ (T :: _) = not_co_datatype T;
- val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts;
+ val (perm_Ts, perm_gen_Ts) = check_enrich_with_mutuals lthy [] [] [] perm_actual_Ts;
+ val perm_frozen_gen_Ts = map Logic.unvarifyT_global perm_gen_Ts;
+
val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
val Ts = actual_Ts @ missing_Ts;
@@ -334,7 +387,7 @@
val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
val ((perm_fp_sugars, fp_sugar_thms), lthy) =
- mutualize_fp_sugars has_nested fp perm_bs perm_Ts get_perm_indices perm_callssss
+ mutualize_fp_sugars has_nested fp perm_bs perm_frozen_gen_Ts get_perm_indices perm_callssss
perm_fp_sugars0 lthy;
val fp_sugars = unpermute perm_fp_sugars;