--- a/src/HOL/BNF/Tools/bnf_fp_sugar.ML Sun Sep 23 08:24:19 2012 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_sugar.ML Sun Sep 23 14:52:53 2012 +0200
@@ -7,19 +7,19 @@
signature BNF_FP_SUGAR =
sig
- val datatyp: bool ->
+ val datatypes: bool ->
(mixfix list -> (string * sort) list option -> binding list -> typ list * typ list list ->
BNF_Def.BNF list -> local_theory ->
- (term list * term list * term list * term list * thm * thm list * thm list * thm list *
- thm list * thm list) * local_theory) ->
+ (term list * term list * term list *term list * term list * thm * thm list * thm list *
+ thm list * thm list * thm list) * local_theory) ->
bool * ((((typ * sort) list * binding) * mixfix) * ((((binding * binding) *
(binding * typ) list) * (binding * term) list) * mixfix) list) list ->
local_theory -> local_theory
val parse_datatype_cmd: bool ->
(mixfix list -> (string * sort) list option -> binding list -> typ list * typ list list ->
BNF_Def.BNF list -> local_theory ->
- (term list * term list * term list * term list * thm * thm list * thm list * thm list *
- thm list * thm list) * local_theory) ->
+ (term list * term list * term list * term list * term list * thm * thm list * thm list *
+ thm list * thm list * thm list) * local_theory) ->
(local_theory -> local_theory) parser
end;
@@ -34,8 +34,9 @@
val simp_attrs = @{attributes [simp]};
-fun split_list8 xs =
- (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs, map #8 xs);
+fun split_list9 xs =
+ (map #1 xs, map #2 xs, map #3 xs, map #4 xs, map #5 xs, map #6 xs, map #7 xs, map #8 xs,
+ map #9 xs);
fun resort_tfree S (TFree (s, _)) = TFree (s, S);
@@ -52,6 +53,38 @@
fun mk_uncurried2_fun f xss =
mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
+fun mk_ctor_or_dtor get_T Ts t =
+ let val Type (_, Ts0) = get_T (fastype_of t) in
+ Term.subst_atomic_types (Ts0 ~~ Ts) t
+ end;
+
+val mk_ctor = mk_ctor_or_dtor range_type;
+val mk_dtor = mk_ctor_or_dtor domain_type;
+
+fun mk_rec_like lfp Ts Us t =
+ let
+ val (bindings, body) = strip_type (fastype_of t);
+ val (f_Us, prebody) = split_last bindings;
+ val Type (_, Ts0) = if lfp then prebody else body;
+ val Us0 = distinct (op =) (map (if lfp then body_type else domain_type) f_Us);
+ in
+ Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+ end;
+
+fun mk_map live Ts Us t =
+ let val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last in
+ Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+ end;
+
+fun mk_rel Ts Us t =
+ let
+ val ((Type (_, Ts0), Type (_, Us0)), body) =
+ strip_type (fastype_of t) |>> split_last |>> apfst List.last;
+val _ = tracing ("*** " ^ PolyML.makestring (Ts, "***", Us, "***", t, Ts0, Us0)) (*###*)
+ in
+ Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+ end;
+
fun tick u f = Term.lambda u (HOLogic.mk_prod (u, f $ u));
fun tack z_name (c, u) f =
@@ -81,7 +114,7 @@
fun defaults_of ((_, ds), _) = ds;
fun ctr_mixfix_of (_, mx) = mx;
-fun define_datatype prepare_constraint prepare_typ prepare_term lfp construct (no_dests, specs)
+fun define_datatypes prepare_constraint prepare_typ prepare_term lfp construct (no_dests, specs)
no_defs_lthy0 =
let
(* TODO: sanity checks on arguments *)
@@ -104,7 +137,7 @@
val unsorted_Ass0 = map (map (resort_tfree HOLogic.typeS)) Ass0;
val unsorted_As = Library.foldr1 merge_type_args unsorted_Ass0;
- val ((Bs, Cs), no_defs_lthy) =
+ val ((Xs, Cs), no_defs_lthy) =
no_defs_lthy0
|> fold (Variable.declare_typ o resort_tfree dummyS) unsorted_As
|> mk_TFrees nn
@@ -160,16 +193,16 @@
| eq_fpT _ _ = false;
fun freeze_fp (T as Type (s, Us)) =
- (case find_index (eq_fpT T) fake_Ts of ~1 => Type (s, map freeze_fp Us) | j => nth Bs j)
+ (case find_index (eq_fpT T) fake_Ts of ~1 => Type (s, map freeze_fp Us) | j => nth Xs j)
| freeze_fp T = T;
- val ctr_TsssBs = map (map (map freeze_fp)) fake_ctr_Tsss;
- val ctr_sum_prod_TsBs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssBs;
+ val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss;
+ val ctr_sum_prod_TsXs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssXs;
val fp_eqs =
- map dest_TFree Bs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctr_sum_prod_TsBs;
+ map dest_TFree Xs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctr_sum_prod_TsXs;
- val (pre_bnfs, ((dtors0, ctors0, fp_folds0, fp_recs0, fp_induct, dtor_ctors, ctor_dtors,
+ val (pre_bnfs, ((dtors0, ctors0, rels0, fp_folds0, fp_recs0, fp_induct, dtor_ctors, ctor_dtors,
ctor_injects, fp_fold_thms, fp_rec_thms), lthy)) =
fp_bnf construct fp_bs mixfixes (map dest_TFree unsorted_As) fp_eqs no_defs_lthy0;
@@ -183,46 +216,29 @@
in snd oo add end;
fun nesty_bnfs Us =
- map_filter (bnf_of lthy) (fold (fold (fold (add_nesty_bnf_names Us))) ctr_TsssBs []);
+ map_filter (bnf_of lthy) (fold (fold (fold (add_nesty_bnf_names Us))) ctr_TsssXs []);
val nesting_bnfs = nesty_bnfs As;
- val nested_bnfs = nesty_bnfs Bs;
+ val nested_bnfs = nesty_bnfs Xs;
val timer = time (Timer.startRealTimer ());
- fun mk_ctor_or_dtor get_T Ts t =
- let val Type (_, Ts0) = get_T (fastype_of t) in
- Term.subst_atomic_types (Ts0 ~~ Ts) t
- end;
-
- val mk_ctor = mk_ctor_or_dtor range_type;
- val mk_dtor = mk_ctor_or_dtor domain_type;
-
val ctors = map (mk_ctor As) ctors0;
val dtors = map (mk_dtor As) dtors0;
+ val rels = map (mk_rel As As) rels0; (*FIXME*)
val fpTs = map (domain_type o fastype_of) dtors;
val exists_fp_subtype = exists_subtype (member (op =) fpTs);
- val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Bs ~~ fpTs)))) ctr_TsssBs;
+ val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
val ns = map length ctr_Tsss;
val kss = map (fn n => 1 upto n) ns;
val mss = map (map length) ctr_Tsss;
val Css = map2 replicate ns Cs;
- fun mk_rec_like Ts Us t =
- let
- val (bindings, body) = strip_type (fastype_of t);
- val (f_Us, prebody) = split_last bindings;
- val Type (_, Ts0) = if lfp then prebody else body;
- val Us0 = distinct (op =) (map (if lfp then body_type else domain_type) f_Us);
- in
- Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
- end;
-
- val fp_folds as fp_fold1 :: _ = map (mk_rec_like As Cs) fp_folds0;
- val fp_recs as fp_rec1 :: _ = map (mk_rec_like As Cs) fp_recs0;
+ val fp_folds as fp_fold1 :: _ = map (mk_rec_like lfp As Cs) fp_folds0;
+ val fp_recs as fp_rec1 :: _ = map (mk_rec_like lfp As Cs) fp_recs0;
val fp_fold_fun_Ts = fst (split_last (binder_types (fastype_of fp_fold1)));
val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
@@ -340,11 +356,12 @@
val ctr_sum_prod_T = mk_sumTN_balanced ctr_prod_Ts;
val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
- val ((((w, fs), xss), u'), _) =
+ val (((((w, fs), xss), yss), u'), _) =
no_defs_lthy
|> yield_singleton (mk_Frees "w") dtorT
||>> mk_Frees "f" case_Ts
||>> mk_Freess "x" ctr_Tss
+ ||>> mk_Freess "y" ctr_Tss
||>> yield_singleton Variable.variant_fixes fp_b_name;
val u = Free (u', fpT);
@@ -443,9 +460,9 @@
val [fold_def, rec_def] = map (Morphism.thm phi) defs;
- val [foldx, recx] = map (mk_rec_like As Cs o Morphism.term phi) csts;
+ val [foldx, recx] = map (mk_rec_like lfp As Cs o Morphism.term phi) csts;
in
- ((wrap_res, ctrs, foldx, recx, xss, ctr_defs, fold_def, rec_def), lthy)
+ ((wrap_res, ctrs, foldx, recx, xss, yss, ctr_defs, fold_def, rec_def), lthy)
end;
fun define_unfold_corec (wrap_res, no_defs_lthy) =
@@ -483,9 +500,9 @@
val [unfold_def, corec_def] = map (Morphism.thm phi) defs;
- val [unfold, corec] = map (mk_rec_like As Cs o Morphism.term phi) csts;
+ val [unfold, corec] = map (mk_rec_like lfp As Cs o Morphism.term phi) csts;
in
- ((wrap_res, ctrs, unfold, corec, xss, ctr_defs, unfold_def, corec_def), lthy)
+ ((wrap_res, ctrs, unfold, corec, xss, yss, ctr_defs, unfold_def, corec_def), lthy)
end;
fun wrap lthy =
@@ -500,20 +517,13 @@
end;
fun wrap_types_and_define_rec_likes ((wraps, define_rec_likess), lthy) =
- fold_map2 (curry (op o)) define_rec_likess wraps lthy |>> split_list8
+ fold_map2 (curry (op o)) define_rec_likess wraps lthy |>> split_list9
val pre_map_defs = map map_def_of_bnf pre_bnfs;
val pre_set_defss = map set_defs_of_bnf pre_bnfs;
val nested_set_natural's = maps set_natural'_of_bnf nested_bnfs;
val nesting_map_ids = map map_id_of_bnf nesting_bnfs;
- fun mk_map live Ts Us t =
- let
- val (Type (_, Ts0), Type (_, Us0)) = strip_typeN (live + 1) (fastype_of t) |>> List.last
- in
- Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
- end;
-
fun build_map build_arg (Type (s, Ts)) (Type (_, Us)) =
let
val bnf = the (bnf_of lthy s);
@@ -527,7 +537,7 @@
map3 (fn (_, _, injects, distincts, cases, _, _, _) => fn rec_likes => fn fold_likes =>
injects @ distincts @ cases @ rec_likes @ fold_likes);
- fun derive_induct_fold_rec_thms_for_types ((wrap_ress, ctrss, folds, recs, xsss, ctr_defss,
+ fun derive_induct_fold_rec_thms_for_types ((wrap_ress, ctrss, folds, recs, xsss, _, ctr_defss,
fold_defs, rec_defs), lthy) =
let
val (((phis, phis'), us'), names_lthy) =
@@ -683,7 +693,7 @@
fn T_name => [induct_case_names_attr, induct_type_attr T_name]),
(foldsN, fold_thmss, K (Code.add_default_eqn_attrib :: simp_attrs)),
(recsN, rec_thmss, K (Code.add_default_eqn_attrib :: simp_attrs)),
- (simpsN, simp_thmss, K [])]
+ (simpsN, simp_thmss, K [])] (* TODO: Add relator simps *)
|> maps (fn (thmN, thmss, attrs) =>
map3 (fn b => fn Type (T_name, _) => fn thms =>
((Binding.qualify true (Binding.name_of b) (Binding.name thmN), attrs T_name),
@@ -692,7 +702,7 @@
lthy |> Local_Theory.notes (common_notes @ notes) |> snd
end;
- fun derive_coinduct_unfold_corec_thms_for_types ((wrap_ress, ctrss, unfolds, corecs, _,
+ fun derive_coinduct_unfold_corec_thms_for_types ((wrap_ress, ctrss, unfolds, corecs, _, _,
ctr_defss, unfold_defs, corec_defs), lthy) =
let
val discss = map (map (mk_disc_or_sel As) o #1) wrap_ress;
@@ -864,7 +874,7 @@
(disc_corecsN, disc_corec_thmss, simp_attrs),
(sel_unfoldsN, sel_unfold_thmss, simp_attrs),
(sel_corecsN, sel_corec_thmss, simp_attrs),
- (simpsN, simp_thmss, []),
+ (simpsN, simp_thmss, []), (* TODO: Add relator simps *)
(unfoldsN, unfold_thmss, [])]
|> maps (fn (thmN, thmss, attrs) =>
map_filter (fn (_, []) => NONE | (b, thms) =>
@@ -874,9 +884,76 @@
lthy |> Local_Theory.notes (anonymous_notes @ common_notes @ notes) |> snd
end;
- fun derive_pred_thms_for_types ((wrap_ress, ctrss, unfolds, corecs, _, ctr_defss, unfold_defs,
- corec_defs), lthy) =
- lthy;
+ fun derive_rel_thms_for_types ((wrap_ress, ctrss, unfolds, corecs, xsss, ysss, ctr_defss,
+ unfold_defs, corec_defs), lthy) =
+ let
+ val selsss = map #2 wrap_ress;
+
+ val theta_Ts = [] (*###*)
+
+ val (thetas, _) =
+ lthy
+ |> mk_Frees "Q" (map mk_pred1T theta_Ts);
+
+ val (rel_thmss, rel_thmsss) =
+ let
+ val xctrss = map2 (map2 (curry Term.list_comb)) ctrss xsss;
+ val yctrss = map2 (map2 (curry Term.list_comb)) ctrss ysss;
+ val threls = map (fold_rev rapp thetas) rels;
+
+ fun mk_goal threl (xctr, xs) (yctr, ys) =
+ let
+ val lhs = threl $ xctr $ yctr;
+
+ (* ### fixme: lift rel *)
+ fun mk_conjunct x y = HOLogic.mk_eq (x, y);
+
+ fun mk_rhs () =
+ Library.foldr1 HOLogic.mk_conj (map2 mk_conjunct xs ys);
+ in
+ HOLogic.mk_Trueprop
+ (if Term.head_of xctr = Term.head_of yctr then
+ if null xs then
+ lhs
+ else
+ HOLogic.mk_eq (lhs, mk_rhs ())
+ else
+ HOLogic.mk_not lhs)
+ end;
+
+(*###*)
+ (* TODO: Prove and exploit symmetry of relators to halve the number of goals. *)
+ fun mk_goals threl xctrs xss yctrs yss =
+ map_product (mk_goal threl) (xctrs ~~ xss) (yctrs ~~ yss);
+
+ val goalsss = map5 mk_goals threls xctrss xsss yctrss ysss;
+
+(*###
+ val goalsss = map6 (fn threl =>
+ map5 (fn xctr => fn xs => fn sels =>
+ map2 (mk_goal threl xctr xs sels))) threls xctrss xsss selsss yctrss ysss;
+*)
+val _ = tracing ("goalsss: " ^ PolyML.makestring goalsss) (*###*)
+ in
+ ([], [])
+ end;
+
+ val (sel_rel_thmss, sel_rel_thmsss) =
+ let
+ in
+ ([], [])
+ end;
+
+ val notes =
+ [(* (relsN, rel_thmss, []),
+ (sel_relsN, sel_rel_thmss, []) *)]
+ |> maps (fn (thmN, thmss, attrs) =>
+ map2 (fn b => fn thms =>
+ ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), attrs),
+ [(thms, [])])) fp_bs thmss);
+ in
+ lthy |> Local_Theory.notes notes |> snd
+ end;
val lthy' = lthy
|> fold_map define_ctrs_case_for_type (fp_bs ~~ fpTs ~~ Cs ~~ ctors ~~ dtors ~~ fp_folds ~~
@@ -886,7 +963,7 @@
|> `(if lfp then derive_induct_fold_rec_thms_for_types
else derive_coinduct_unfold_corec_thms_for_types)
|> swap |>> fst
- |> derive_pred_thms_for_types;
+ |> (if null rels then snd else derive_rel_thms_for_types);
val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
(if lfp then "" else "co") ^ "datatype"));
@@ -894,9 +971,9 @@
timer; lthy'
end;
-val datatyp = define_datatype (K I) (K I) (K I);
+val datatypes = define_datatypes (K I) (K I) (K I);
-val datatype_cmd = define_datatype Typedecl.read_constraint Syntax.parse_typ Syntax.read_term;
+val datatype_cmd = define_datatypes Typedecl.read_constraint Syntax.parse_typ Syntax.read_term;
val parse_ctr_arg =
@{keyword "("} |-- parse_binding_colon -- Parse.typ --| @{keyword ")"} ||