--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML Mon Sep 08 14:03:01 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML Mon Sep 08 14:03:01 2014 +0200
@@ -43,15 +43,152 @@
open Ctr_Sugar
open BNF_Util
+open BNF_Tactics
open BNF_FP_Util
open BNF_FP_Def_Sugar
open BNF_FP_N2M_Sugar
open BNF_LFP
-val compatN = "compat_";
+val compat_N = "compat_";
+val rec_fun_N = "rec_fun_";
datatype nesting_preference = Keep_Nesting | Unfold_Nesting;
+fun mk_fun_rec_rhs ctxt fpTs Cs (recs as rec1 :: _) =
+ let
+ fun repair_rec_arg_args [] [] = []
+ | repair_rec_arg_args ((g_T as Type (@{type_name fun}, _)) :: g_Ts) (g :: gs) =
+ let
+ val (x_Ts, body_T) = strip_type g_T;
+ in
+ (case try HOLogic.dest_prodT body_T of
+ NONE => [g]
+ | SOME (fst_T, _) =>
+ if member (op =) fpTs fst_T then
+ let val (xs, _) = mk_Frees "x" x_Ts ctxt in
+ map (fn mk_proj => fold_rev Term.lambda xs (mk_proj (Term.list_comb (g, xs))))
+ [HOLogic.mk_fst, HOLogic.mk_snd]
+ end
+ else
+ [g])
+ :: repair_rec_arg_args g_Ts gs
+ end
+ | repair_rec_arg_args (g_T :: g_Ts) (g :: gs) =
+ if member (op =) fpTs g_T then
+ let
+ val j = find_index (member (op =) Cs) g_Ts;
+ val h = nth gs j;
+ val g_Ts' = nth_drop j g_Ts;
+ val gs' = nth_drop j gs;
+ in
+ [g, h] :: repair_rec_arg_args g_Ts' gs'
+ end
+ else
+ [g] :: repair_rec_arg_args g_Ts gs;
+
+ fun repair_back_rec_arg f_T f' =
+ let
+ val g_Ts = Term.binder_types f_T;
+ val (gs, _) = mk_Frees "g" g_Ts ctxt;
+ in
+ fold_rev Term.lambda gs (Term.list_comb (f',
+ flat_rec_arg_args (repair_rec_arg_args g_Ts gs)))
+ end;
+
+ val f_Ts = binder_fun_types (fastype_of rec1);
+ val (fs', _) = mk_Frees "f" (replicate (length f_Ts) Term.dummyT) ctxt;
+
+ fun mk_rec' recx =
+ fold_rev Term.lambda fs' (Term.list_comb (recx, map2 repair_back_rec_arg f_Ts fs'))
+ |> Syntax.check_term ctxt;
+ in
+ map mk_rec' recs
+ end;
+
+fun define_fun_recs fpTs Cs recs lthy =
+ let
+ val b_names = Name.variant_list [] (map base_name_of_typ fpTs);
+
+ fun mk_binding b_name =
+ Binding.qualify true (compat_N ^ b_name)
+ (Binding.prefix_name rec_fun_N (Binding.name b_name));
+
+ val bs = map mk_binding b_names;
+ val rhss = mk_fun_rec_rhs lthy fpTs Cs recs;
+ in
+ fold_map3 (define_co_rec_as Least_FP Cs) fpTs bs rhss lthy
+ end;
+
+fun mk_fun_rec_thmss ctxt rec0_thmss (recs as rec1 :: _) rec_defs =
+ let
+ val f_Ts = binder_fun_types (fastype_of rec1);
+ val (fs, _) = mk_Frees "f" f_Ts ctxt;
+ val frecs = map (fn recx => Term.list_comb (recx, fs)) recs;
+
+ fun mk_ctrs_of (Type (T_name, As)) =
+ map (mk_ctr As) (#ctrs (the (ctr_sugar_of ctxt T_name)));
+
+ val fpTs = map (domain_type o body_fun_type o fastype_of) recs;
+ val fpTs_frecs = fpTs ~~ frecs;
+ val ctrss = map mk_ctrs_of fpTs;
+ val fss = unflat ctrss fs;
+
+ fun mk_rec_call g n (Type (@{type_name fun}, [dom_T, ran_T])) =
+ Abs (Name.uu, dom_T, mk_rec_call g (n + 1) ran_T)
+ | mk_rec_call g n fpT =
+ let
+ val frec = the (AList.lookup (op =) fpTs_frecs fpT);
+ val xg = Term.list_comb (g, map Bound (n - 1 downto 0));
+ in frec $ xg end;
+
+ fun mk_rec_arg_arg g_T g =
+ g :: (if exists_subtype_in fpTs g_T then [mk_rec_call g 0 g_T] else []);
+
+ fun mk_goal frec ctr f =
+ let
+ val g_Ts = binder_types (fastype_of ctr);
+ val (gs, _) = mk_Frees "g" g_Ts ctxt;
+ val gctr = Term.list_comb (ctr, gs);
+ val fgs = flat_rec_arg_args (map2 mk_rec_arg_arg g_Ts gs);
+ in
+ fold_rev (fold_rev Logic.all) [fs, gs]
+ (mk_Trueprop_eq (frec $ gctr, Term.list_comb (f, fgs)))
+ end;
+
+ fun mk_goals ctrs fs frec = map2 (mk_goal frec) ctrs fs;
+
+ val goalss = map3 mk_goals ctrss fss frecs;
+
+ fun tac ctxt =
+ unfold_thms_tac ctxt (@{thms o_apply fst_conv snd_conv} @ rec_defs @ flat rec0_thmss) THEN
+ HEADGOAL (rtac refl);
+
+ fun prove goal =
+ Goal.prove_sorry ctxt [] [] goal (tac o #context)
+ |> Thm.close_derivation;
+ in
+ map (map prove) goalss
+ end;
+
+fun define_fun_rec_derive_thms induct inducts recs0 rec_thmss fpTs lthy =
+ let
+ val thy = Proof_Context.theory_of lthy;
+
+ (* imperfect: will not yield the expected theorem for functions taking a large number of
+ arguments *)
+ val repair_induct = unfold_thms lthy @{thms all_mem_range};
+
+ val induct' = repair_induct induct;
+ val inducts' = map repair_induct inducts;
+
+ val Cs = map ((fn TVar ((s, _), S) => TFree (s, S)) o body_type o fastype_of) recs0;
+ val recs = map2 (mk_co_rec thy Least_FP Cs) fpTs recs0;
+ val ((recs', rec'_defs), lthy') = define_fun_recs fpTs Cs recs lthy |>> split_list;
+ val rec'_thmss = mk_fun_rec_thmss lthy' rec_thmss recs' rec'_defs;
+ in
+ ((induct', inducts', recs', rec'_thmss), lthy')
+ end;
+
fun reindex_desc desc =
let
val kks = map fst desc;
@@ -130,10 +267,10 @@
val dest_dtyp = Old_Datatype_Aux.typ_of_dtyp descr;
- val Ts = Old_Datatype_Aux.get_rec_types descr;
- val nn = length Ts;
+ val fpTs' = Old_Datatype_Aux.get_rec_types descr;
+ val nn = length fpTs';
- val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
+ val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) fpTs';
val ctr_Tsss = map (map (map dest_dtyp o snd) o #3 o snd) descr;
val kkssss =
map (map (map (fn Old_Datatype_Aux.DtRec kk => [kk] | _ => []) o snd) o #3 o snd) descr;
@@ -146,34 +283,47 @@
val callssss =
map2 (map2 (map2 (fn ctr_T => map (apply_comps (num_binder_types ctr_T))))) ctr_Tsss kkssss;
- val b_names = Name.variant_list [] (map base_name_of_typ Ts);
- val compat_b_names = map (prefix compatN) b_names;
+ val b_names = Name.variant_list [] (map base_name_of_typ fpTs');
+ val compat_b_names = map (prefix compat_N) b_names;
val compat_bs = map Binding.name compat_b_names;
val ((fp_sugars, (lfp_sugar_thms, _)), lthy') =
if nn > nn_fp then
- mutualize_fp_sugars Least_FP cliques compat_bs Ts callers callssss fp_sugars0 lthy
+ mutualize_fp_sugars Least_FP cliques compat_bs fpTs' callers callssss fp_sugars0 lthy
else
((fp_sugars0, (NONE, NONE)), lthy);
- val recs = map (fst o dest_Const o #co_rec) fp_sugars;
- val rec_thms = maps #co_rec_thms fp_sugars;
-
val {common_co_inducts = [induct], ...} :: _ = fp_sugars;
val inducts = map (the_single o #co_inducts) fp_sugars;
+ val recs = map #co_rec fp_sugars;
+ val rec_thmss = map #co_rec_thms fp_sugars;
+
+ fun is_nested_rec_type (Type (@{type_name fun}, [_, T])) = member (op =) fpTs' (body_type T)
+ | is_nested_rec_type _ = false;
+
+ val ((induct', inducts', recs', rec'_thmss), lthy'') =
+ if nesting_pref = Unfold_Nesting andalso
+ exists (exists (exists is_nested_rec_type)) ctr_Tsss then
+ define_fun_rec_derive_thms induct inducts recs rec_thmss fpTs' lthy'
+ else
+ ((induct, inducts, recs, rec_thmss), lthy');
+
+ val rec'_names = map (fst o dest_Const) recs';
+ val rec'_thms = flat rec'_thmss;
+
fun mk_info (kk, {T = Type (T_name0, _), ctr_sugar = {casex, exhaust, nchotomy, injects,
distincts, case_thms, case_cong, case_cong_weak, split, split_asm, ...}, ...} : fp_sugar) =
(T_name0,
- {index = kk, descr = descr, inject = injects, distinct = distincts, induct = induct,
- inducts = inducts, exhaust = exhaust, nchotomy = nchotomy, rec_names = recs,
- rec_rewrites = rec_thms, case_name = fst (dest_Const casex), case_rewrites = case_thms,
+ {index = kk, descr = descr, inject = injects, distinct = distincts, induct = induct',
+ inducts = inducts', exhaust = exhaust, nchotomy = nchotomy, rec_names = rec'_names,
+ rec_rewrites = rec'_thms, case_name = fst (dest_Const casex), case_rewrites = case_thms,
case_cong = case_cong, case_cong_weak = case_cong_weak, split = split,
split_asm = split_asm});
val infos = map_index mk_info (take nn_fp fp_sugars);
in
- (nn, b_names, compat_b_names, lfp_sugar_thms, infos, lthy')
+ (nn, b_names, compat_b_names, lfp_sugar_thms, infos, lthy'')
end;
fun infos_of_new_datatype_mutual_cluster lthy fpT_name =
@@ -298,7 +448,7 @@
NONE => []
| SOME ((induct_thms, induct_thm, induct_attrs), (rec_thmss, _)) =>
let
- val common_name = compatN ^ mk_common_name b_names;
+ val common_name = compat_N ^ mk_common_name b_names;
val common_notes =
(if nn > 1 then [(inductN, [induct_thm], induct_attrs)] else [])