--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Mon Apr 29 09:45:14 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML Mon Apr 29 10:37:23 2013 +0200
@@ -56,6 +56,8 @@
let val (xs1, xs2, xs3, xs4) = split_list4 xs;
in (x1 :: xs1, x2 :: xs2, x3 :: xs3, x4 :: xs4) end;
+fun exists_subtype_in Ts = exists_subtype (member (op =) Ts);
+
fun resort_tfree S (TFree (s, _)) = TFree (s, S);
fun typ_subst inst (T as Type (s, Ts)) =
@@ -155,6 +157,175 @@
fun defaults_of ((_, ds), _) = ds;
fun ctr_mixfix_of (_, mx) = mx;
+fun build_map lthy build_arg (Type (s, Ts)) (Type (_, Us)) =
+ let
+ val bnf = the (bnf_of lthy s);
+ val live = live_of_bnf bnf;
+ val mapx = mk_map live Ts Us (map_of_bnf bnf);
+ val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
+ in Term.list_comb (mapx, map build_arg TUs') end;
+
+fun build_rel_step lthy build_arg (Type (s, Ts)) =
+ let
+ val bnf = the (bnf_of lthy s);
+ val live = live_of_bnf bnf;
+ val rel = mk_rel live Ts Ts (rel_of_bnf bnf);
+ val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
+ in Term.list_comb (rel, map build_arg Ts') end;
+
+fun derive_induct_fold_rec_thms_for_types
+ nn fp_b_names pre_bnfs fp_induct fp_fold_thms fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs
+ ctr_Tsss mss ns gss hss ((ctrss, xsss, ctr_defss, _), (folds, recs, fold_defs, rec_defs)) lthy =
+ let
+ val pre_map_defs = map map_def_of_bnf pre_bnfs;
+ val pre_set_defss = map set_defs_of_bnf pre_bnfs;
+ val nested_map_comp's = map map_comp'_of_bnf nested_bnfs;
+ val nested_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nested_bnfs;
+ val nesting_map_ids'' = map (unfold_thms lthy @{thms id_def} o map_id_of_bnf) nesting_bnfs;
+ val nested_set_map's = maps set_map'_of_bnf nested_bnfs;
+
+ val (((ps, ps'), us'), names_lthy) =
+ lthy
+ |> mk_Frees' "P" (map mk_pred1T fpTs)
+ ||>> Variable.variant_fixes fp_b_names;
+
+ val us = map2 (curry Free) us' fpTs;
+
+ fun mk_sets_nested bnf =
+ let
+ val Type (T_name, Us) = T_of_bnf bnf;
+ val lives = lives_of_bnf bnf;
+ val sets = sets_of_bnf bnf;
+ fun mk_set U =
+ (case find_index (curry (op =) U) lives of
+ ~1 => Term.dummy
+ | i => nth sets i);
+ in
+ (T_name, map mk_set Us)
+ end;
+
+ val setss_nested = map mk_sets_nested nested_bnfs;
+
+ val (induct_thms, induct_thm) =
+ let
+ fun mk_set Ts t =
+ let val Type (_, Ts0) = domain_type (fastype_of t) in
+ Term.subst_atomic_types (Ts0 ~~ Ts) t
+ end;
+
+ fun mk_raw_prem_prems names_lthy (x as Free (s, T as Type (T_name, Ts0))) =
+ (case find_index (curry (op =) T) fpTs of
+ ~1 =>
+ (case AList.lookup (op =) setss_nested T_name of
+ NONE => []
+ | SOME raw_sets0 =>
+ let
+ val (Ts, raw_sets) =
+ split_list (filter (exists_subtype_in fpTs o fst) (Ts0 ~~ raw_sets0));
+ val sets = map (mk_set Ts0) raw_sets;
+ val (ys, names_lthy') = names_lthy |> mk_Frees s Ts;
+ val xysets = map (pair x) (ys ~~ sets);
+ val ppremss = map (mk_raw_prem_prems names_lthy') ys;
+ in
+ flat (map2 (map o apfst o cons) xysets ppremss)
+ end)
+ | kk => [([], (kk + 1, x))])
+ | mk_raw_prem_prems _ _ = [];
+
+ fun close_prem_prem xs t =
+ fold_rev Logic.all (map Free (drop (nn + length xs)
+ (rev (Term.add_frees t (map dest_Free xs @ ps'))))) t;
+
+ fun mk_prem_prem xs (xysets, (j, x)) =
+ close_prem_prem xs (Logic.list_implies (map (fn (x', (y, set)) =>
+ HOLogic.mk_Trueprop (HOLogic.mk_mem (y, set $ x'))) xysets,
+ HOLogic.mk_Trueprop (nth ps (j - 1) $ x)));
+
+ fun mk_raw_prem phi ctr ctr_Ts =
+ let
+ val (xs, names_lthy') = names_lthy |> mk_Frees "x" ctr_Ts;
+ val pprems = maps (mk_raw_prem_prems names_lthy') xs;
+ in (xs, pprems, HOLogic.mk_Trueprop (phi $ Term.list_comb (ctr, xs))) end;
+
+ fun mk_prem (xs, raw_pprems, concl) =
+ fold_rev Logic.all xs (Logic.list_implies (map (mk_prem_prem xs) raw_pprems, concl));
+
+ val raw_premss = map3 (map2 o mk_raw_prem) ps ctrss ctr_Tsss;
+
+ val goal =
+ Library.foldr (Logic.list_implies o apfst (map mk_prem)) (raw_premss,
+ HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map2 (curry (op $)) ps us)));
+
+ val kksss = map (map (map (fst o snd) o #2)) raw_premss;
+
+ val ctor_induct' = fp_induct OF (map mk_sumEN_tupled_balanced mss);
+
+ val thm =
+ Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, ...} =>
+ mk_induct_tac ctxt nn ns mss kksss (flat ctr_defss) ctor_induct' nested_set_map's
+ pre_set_defss)
+ |> singleton (Proof_Context.export names_lthy lthy)
+ |> Thm.close_derivation;
+ in
+ `(conj_dests nn) thm
+ end;
+
+ val induct_cases = quasi_unambiguous_case_names (maps (map name_of_ctr) ctrss);
+
+ val (fold_thmss, rec_thmss) =
+ let
+ val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
+ val gfolds = map (lists_bmoc gss) folds;
+ val hrecs = map (lists_bmoc hss) recs;
+
+ fun mk_goal fss frec_like xctr f xs fxs =
+ fold_rev (fold_rev Logic.all) (xs :: fss)
+ (mk_Trueprop_eq (frec_like $ xctr, Term.list_comb (f, fxs)));
+
+ fun build_rec_like frec_likes (T, U) =
+ if T = U then
+ id_const T
+ else
+ (case find_index (curry (op =) T) fpTs of
+ ~1 => build_map lthy (build_rec_like frec_likes) T U
+ | kk => nth frec_likes kk);
+
+ val mk_U = typ_subst (map2 pair fpTs Cs);
+
+ fun unzip_rec_likes frec_likes combine (x as Free (_, T)) =
+ if exists_subtype_in fpTs T then
+ combine (x, build_rec_like frec_likes (T, mk_U T) $ x)
+ else
+ ([x], []);
+
+ val gxsss = map (map (flat_rec (unzip_rec_likes gfolds (fn (_, t) => ([t], []))))) xsss;
+ val hxsss = map (map (flat_rec (unzip_rec_likes hrecs (pairself single)))) xsss;
+
+ val fold_goalss = map5 (map4 o mk_goal gss) gfolds xctrss gss xsss gxsss;
+ val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;
+
+ val fold_tacss =
+ map2 (map o mk_rec_like_tac pre_map_defs [] nesting_map_ids'' fold_defs) fp_fold_thms
+ ctr_defss;
+ val rec_tacss =
+ map2 (map o mk_rec_like_tac pre_map_defs nested_map_comp's
+ (nested_map_ids'' @ nesting_map_ids'') rec_defs) fp_rec_thms ctr_defss;
+
+ fun prove goal tac =
+ Goal.prove_sorry lthy [] [] goal (tac o #context)
+ |> Thm.close_derivation;
+ in
+ (map2 (map2 prove) fold_goalss fold_tacss, map2 (map2 prove) rec_goalss rec_tacss)
+ end;
+
+ val induct_case_names_attr = Attrib.internal (K (Rule_Cases.case_names induct_cases));
+ fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));
+ in
+ ((induct_thm, [induct_case_names_attr]),
+ (induct_thms, fn T_name => [induct_case_names_attr, induct_type_attr T_name]),
+ (fold_thmss, code_simp_attrs), (rec_thmss, code_simp_attrs))
+ end;
+
fun define_datatypes prepare_constraint prepare_typ prepare_term lfp construct_fp
(wrap_opts as (no_dests, rep_compat), specs) no_defs_lthy0 =
let
@@ -264,22 +435,6 @@
val timer = time (Timer.startRealTimer ());
- fun build_map build_arg (Type (s, Ts)) (Type (_, Us)) =
- let
- val bnf = the (bnf_of lthy s);
- val live = live_of_bnf bnf;
- val mapx = mk_map live Ts Us (map_of_bnf bnf);
- val TUs' = map dest_funT (fst (strip_typeN live (fastype_of mapx)));
- in Term.list_comb (mapx, map build_arg TUs') end;
-
- fun build_rel_step build_arg (Type (s, Ts)) =
- let
- val bnf = the (bnf_of lthy s);
- val live = live_of_bnf bnf;
- val rel = mk_rel live Ts Ts (rel_of_bnf bnf);
- val Ts' = map domain_type (fst (strip_typeN live (fastype_of rel)));
- in Term.list_comb (rel, map build_arg Ts') end;
-
fun add_nesty_bnf_names Us =
let
fun add (Type (s, Ts)) ss =
@@ -332,9 +487,6 @@
((qualify true fp_b_name (Binding.name thmN), attrs T_name),
[(thms, [])])) fp_b_names fpTs thmss);
- val exists_fp_subtype = exists_subtype (member (op =) fpTs);
- val exists_Cs_subtype = exists_subtype (member (op =) Cs);
-
val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
val ns = map length ctr_Tsss;
val kss = map (fn n => 1 upto n) ns;
@@ -368,7 +520,7 @@
| proj_recT _ T = T;
fun unzip_recT T =
- if exists_fp_subtype T then ([proj_recT fst T], [proj_recT snd T]) else ([T], []);
+ if exists_subtype_in fpTs T then ([proj_recT fst T], [proj_recT snd T]) else ([T], []);
val z_Tsss =
map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
@@ -422,7 +574,7 @@
| proj_corecT _ T = T;
fun unzip_corecT T =
- if exists_fp_subtype T then [proj_corecT fst T, proj_corecT snd T] else [T];
+ if exists_subtype_in fpTs T then [proj_corecT fst T, proj_corecT snd T] else [T];
val (s_Tssss, h_sum_prod_Ts, h_Tsss, h_Tssss, ph_Tss) =
mk_types unzip_corecT fp_rec_fun_Ts;
@@ -636,7 +788,7 @@
else
(case (T, U) of
(Type (s, _), Type (s', _)) =>
- if s = s' then build_map (build_prod_proj mk_proj) T U else mk_proj T
+ if s = s' then build_map lthy (build_prod_proj mk_proj) T U else mk_proj T
| _ => mk_proj T);
(* TODO: Avoid these complications; cf. corec case *)
@@ -646,7 +798,7 @@
| mk_U _ T = T;
fun unzip_rec (x as Free (_, T)) =
- if exists_fp_subtype T then
+ if exists_subtype_in fpTs T then
([build_prod_proj fst_const (T, mk_U fst T) $ x],
[build_prod_proj snd_const (T, mk_U snd T) $ x])
else
@@ -696,7 +848,7 @@
else
(case (T, U) of
(Type (s, _), Type (s', _)) =>
- if s = s' then build_map (build_sum_inj mk_inj) T U
+ if s = s' then build_map lthy (build_sum_inj mk_inj) T U
else uncurry mk_inj (dest_sumT U)
| _ => uncurry mk_inj (dest_sumT U));
@@ -761,165 +913,31 @@
map3 (fn (_, _, _, injects, distincts, cases, _, _, _) => fn rec_likes => fn fold_likes =>
injects @ distincts @ cases @ rec_likes @ fold_likes);
- fun derive_induct_fold_rec_thms_for_types (((ctrss, xsss, ctr_defss, wrap_ress), (folds, recs,
- fold_defs, rec_defs)), lthy) =
+ fun derive_and_note_induct_fold_rec_thms_for_types (info as ((_, _, _, wrap_ress), _), lthy) =
let
- val (((ps, ps'), us'), names_lthy) =
- lthy
- |> mk_Frees' "P" (map mk_pred1T fpTs)
- ||>> Variable.variant_fixes fp_b_names;
-
- val us = map2 (curry Free) us' fpTs;
-
- fun mk_sets_nested bnf =
- let
- val Type (T_name, Us) = T_of_bnf bnf;
- val lives = lives_of_bnf bnf;
- val sets = sets_of_bnf bnf;
- fun mk_set U =
- (case find_index (curry (op =) U) lives of
- ~1 => Term.dummy
- | i => nth sets i);
- in
- (T_name, map mk_set Us)
- end;
-
- val setss_nested = map mk_sets_nested nested_bnfs;
-
- val (induct_thms, induct_thm) =
- let
- fun mk_set Ts t =
- let val Type (_, Ts0) = domain_type (fastype_of t) in
- Term.subst_atomic_types (Ts0 ~~ Ts) t
- end;
-
- fun mk_raw_prem_prems names_lthy (x as Free (s, T as Type (T_name, Ts0))) =
- (case find_index (curry (op =) T) fpTs of
- ~1 =>
- (case AList.lookup (op =) setss_nested T_name of
- NONE => []
- | SOME raw_sets0 =>
- let
- val (Ts, raw_sets) =
- split_list (filter (exists_fp_subtype o fst) (Ts0 ~~ raw_sets0));
- val sets = map (mk_set Ts0) raw_sets;
- val (ys, names_lthy') = names_lthy |> mk_Frees s Ts;
- val xysets = map (pair x) (ys ~~ sets);
- val ppremss = map (mk_raw_prem_prems names_lthy') ys;
- in
- flat (map2 (map o apfst o cons) xysets ppremss)
- end)
- | kk => [([], (kk + 1, x))])
- | mk_raw_prem_prems _ _ = [];
-
- fun close_prem_prem xs t =
- fold_rev Logic.all (map Free (drop (nn + length xs)
- (rev (Term.add_frees t (map dest_Free xs @ ps'))))) t;
-
- fun mk_prem_prem xs (xysets, (j, x)) =
- close_prem_prem xs (Logic.list_implies (map (fn (x', (y, set)) =>
- HOLogic.mk_Trueprop (HOLogic.mk_mem (y, set $ x'))) xysets,
- HOLogic.mk_Trueprop (nth ps (j - 1) $ x)));
-
- fun mk_raw_prem phi ctr ctr_Ts =
- let
- val (xs, names_lthy') = names_lthy |> mk_Frees "x" ctr_Ts;
- val pprems = maps (mk_raw_prem_prems names_lthy') xs;
- in (xs, pprems, HOLogic.mk_Trueprop (phi $ Term.list_comb (ctr, xs))) end;
-
- fun mk_prem (xs, raw_pprems, concl) =
- fold_rev Logic.all xs (Logic.list_implies (map (mk_prem_prem xs) raw_pprems, concl));
-
- val raw_premss = map3 (map2 o mk_raw_prem) ps ctrss ctr_Tsss;
-
- val goal =
- Library.foldr (Logic.list_implies o apfst (map mk_prem)) (raw_premss,
- HOLogic.mk_Trueprop (Library.foldr1 HOLogic.mk_conj (map2 (curry (op $)) ps us)));
-
- val kksss = map (map (map (fst o snd) o #2)) raw_premss;
-
- val ctor_induct' = fp_induct OF (map mk_sumEN_tupled_balanced mss);
-
- val thm =
- Goal.prove_sorry lthy [] [] goal (fn {context = ctxt, ...} =>
- mk_induct_tac ctxt nn ns mss kksss (flat ctr_defss) ctor_induct' nested_set_map's
- pre_set_defss)
- |> singleton (Proof_Context.export names_lthy lthy)
- |> Thm.close_derivation;
- in
- `(conj_dests nn) thm
- end;
-
- val induct_cases = quasi_unambiguous_case_names (maps (map name_of_ctr) ctrss);
-
- val (fold_thmss, rec_thmss) =
- let
- val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
- val gfolds = map (lists_bmoc gss) folds;
- val hrecs = map (lists_bmoc hss) recs;
-
- fun mk_goal fss frec_like xctr f xs fxs =
- fold_rev (fold_rev Logic.all) (xs :: fss)
- (mk_Trueprop_eq (frec_like $ xctr, Term.list_comb (f, fxs)));
-
- fun build_rec_like frec_likes (T, U) =
- if T = U then
- id_const T
- else
- (case find_index (curry (op =) T) fpTs of
- ~1 => build_map (build_rec_like frec_likes) T U
- | kk => nth frec_likes kk);
-
- val mk_U = typ_subst (map2 pair fpTs Cs);
-
- fun unzip_rec_likes frec_likes combine (x as Free (_, T)) =
- if exists_fp_subtype T then
- combine (x, build_rec_like frec_likes (T, mk_U T) $ x)
- else
- ([x], []);
-
- val gxsss = map (map (flat_rec (unzip_rec_likes gfolds (fn (_, t) => ([t], []))))) xsss;
- val hxsss = map (map (flat_rec (unzip_rec_likes hrecs (pairself single)))) xsss;
-
- val fold_goalss = map5 (map4 o mk_goal gss) gfolds xctrss gss xsss gxsss;
- val rec_goalss = map5 (map4 o mk_goal hss) hrecs xctrss hss xsss hxsss;
-
- val fold_tacss =
- map2 (map o mk_rec_like_tac pre_map_defs [] nesting_map_ids'' fold_defs) fp_fold_thms
- ctr_defss;
- val rec_tacss =
- map2 (map o mk_rec_like_tac pre_map_defs nested_map_comp's
- (nested_map_ids'' @ nesting_map_ids'') rec_defs) fp_rec_thms ctr_defss;
-
- fun prove goal tac =
- Goal.prove_sorry lthy [] [] goal (tac o #context)
- |> Thm.close_derivation;
- in
- (map2 (map2 prove) fold_goalss fold_tacss, map2 (map2 prove) rec_goalss rec_tacss)
- end;
+ val ((induct_thm, induct_attrs), (induct_thms, inducts_attrs), (fold_thmss, fold_attrs),
+ (rec_thmss, rec_attrs)) =
+ derive_induct_fold_rec_thms_for_types nn fp_b_names pre_bnfs fp_induct fp_fold_thms
+ fp_rec_thms nesting_bnfs nested_bnfs fpTs Cs ctr_Tsss mss ns gss hss info lthy;
val simp_thmss = mk_simp_thmss wrap_ress rec_thmss fold_thmss;
- val induct_case_names_attr = Attrib.internal (K (Rule_Cases.case_names induct_cases));
- fun induct_type_attr T_name = Attrib.internal (K (Induct.induct_type T_name));
-
val common_notes =
- (if nn > 1 then [(inductN, [induct_thm], [induct_case_names_attr])] else [])
+ (if nn > 1 then [(inductN, [induct_thm], induct_attrs)] else [])
|> massage_simple_notes fp_common_name;
val notes =
- [(foldN, fold_thmss, K code_simp_attrs),
- (inductN, map single induct_thms,
- fn T_name => [induct_case_names_attr, induct_type_attr T_name]),
- (recN, rec_thmss, K code_simp_attrs),
+ [(foldN, fold_thmss, K fold_attrs),
+ (inductN, map single induct_thms, inducts_attrs),
+ (recN, rec_thmss, K rec_attrs),
(simpsN, simp_thmss, K [])]
|> massage_multi_notes;
in
lthy |> Local_Theory.notes (common_notes @ notes) |> snd
end;
- fun derive_coinduct_unfold_corec_thms_for_types (((ctrss, _, ctr_defss, wrap_ress), (unfolds,
- corecs, unfold_defs, corec_defs)), lthy) =
+ fun derive_and_note_coinduct_unfold_corec_thms_for_types (((ctrss, _, ctr_defss, wrap_ress),
+ (unfolds, corecs, unfold_defs, corec_defs)), lthy) =
let
val nesting_rel_eqs = map rel_eq_of_bnf nesting_bnfs;
@@ -955,7 +973,8 @@
fun build_rel rs' T =
(case find_index (curry (op =) T) fpTs of
~1 =>
- if exists_fp_subtype T then build_rel_step (build_rel rs') T else HOLogic.eq_const T
+ if exists_subtype_in fpTs T then build_rel_step lthy (build_rel rs') T
+ else HOLogic.eq_const T
| kk => nth rs' kk);
fun build_rel_app rs' usel vsel =
@@ -1039,14 +1058,15 @@
id_const T
else
(case find_index (curry (op =) U) fpTs of
- ~1 => build_map (build_corec_like fcorec_likes) T U
+ ~1 => build_map lthy (build_corec_like fcorec_likes) T U
| kk => nth fcorec_likes kk);
val mk_U = typ_subst (map2 pair Cs fpTs);
fun intr_corec_likes fcorec_likes [] [cf] =
let val T = fastype_of cf in
- if exists_Cs_subtype T then build_corec_like fcorec_likes (T, mk_U T) $ cf else cf
+ if exists_subtype_in Cs T then build_corec_like fcorec_likes (T, mk_U T) $ cf
+ else cf
end
| intr_corec_likes fcorec_likes [cq] [cf, cf'] =
mk_If cq (intr_corec_likes fcorec_likes [] [cf])
@@ -1211,8 +1231,8 @@
mss ~~ ctr_bindingss ~~ ctr_mixfixess ~~ ctr_Tsss ~~ disc_bindingss ~~ sel_bindingsss ~~
raw_sel_defaultsss)
|> wrap_types_and_more
- |> (if lfp then derive_induct_fold_rec_thms_for_types
- else derive_coinduct_unfold_corec_thms_for_types);
+ |> (if lfp then derive_and_note_induct_fold_rec_thms_for_types
+ else derive_and_note_coinduct_unfold_corec_thms_for_types);
val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
(if lfp then "" else "co") ^ "datatype"));