--- a/src/HOL/Codatatype/Tools/bnf_sugar.ML Fri Aug 31 22:25:06 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_sugar.ML Fri Aug 31 22:34:37 2012 +0200
@@ -16,32 +16,48 @@
open BNF_FP_Util
open BNF_Sugar_Tactics
-val case_congN = "case_cong"
-val case_discsN = "case_discs"
-val casesN = "cases"
-val ctr_selsN = "ctr_sels"
-val disc_disjointN = "disc_disjoint"
-val disc_exhaustN = "disc_exhaust"
-val discsN = "discs"
-val distinctN = "distinct"
-val selsN = "sels"
-val splitN = "split"
-val split_asmN = "split_asm"
-val weak_case_cong_thmsN = "weak_case_cong"
+val is_N = "is_";
+val un_N = "un_";
+fun mk_un_N 1 1 suf = un_N ^ suf
+ | mk_un_N _ l suf = un_N ^ suf ^ string_of_int l;
-fun mk_half_pairs [] = []
- | mk_half_pairs (x :: xs) = fold_rev (cons o pair x) xs (mk_half_pairs xs);
+val case_congN = "case_cong";
+val case_discsN = "case_discs";
+val casesN = "cases";
+val ctr_selsN = "ctr_sels";
+val disc_exclusN = "disc_exclus";
+val disc_exhaustN = "disc_exhaust";
+val discsN = "discs";
+val distinctN = "distinct";
+val selsN = "sels";
+val splitN = "split";
+val split_asmN = "split_asm";
+val weak_case_cong_thmsN = "weak_case_cong";
-fun index_of_half_row _ 0 = 0
- | index_of_half_row n j = index_of_half_row n (j - 1) + n - j;
+val default_name = @{binding _};
+
+fun pad_list x n xs = xs @ replicate (n - length xs) x;
-fun index_of_half_cell n j k = index_of_half_row n j + k - (j + 1);
+fun mk_half_pairss' _ [] = []
+ | mk_half_pairss' indent (y :: ys) =
+ indent @ fold_rev (cons o single o pair y) ys (mk_half_pairss' ([] :: indent) ys);
+
+fun mk_half_pairss ys = mk_half_pairss' [[]] ys;
val mk_Trueprop_eq = HOLogic.mk_Trueprop o HOLogic.mk_eq;
-fun eta_expand_caseof_arg f xs = fold_rev Term.lambda xs (Term.list_comb (f, xs));
+fun mk_undef T Ts = Const (@{const_name undefined}, Ts ---> T);
+
+fun eta_expand_caseof_arg xs f_xs = fold_rev Term.lambda xs f_xs;
-fun prepare_sugar prep_term (((raw_ctrs, raw_caseof), disc_names), sel_namess) no_defs_lthy =
+fun name_of_ctr t =
+ case head_of t of
+ Const (s, _) => s
+ | Free (s, _) => s
+ | _ => error "Cannot extract name of constructor";
+
+fun prepare_sugar prep_term ((raw_ctrs, raw_caseof), (raw_disc_names, raw_sel_namess))
+ no_defs_lthy =
let
(* TODO: sanity checks on arguments *)
@@ -61,41 +77,64 @@
|> mk_TFrees (length As0)
||> the_single o fst o mk_TFrees 1;
- fun mk_undef T Ts = Const (@{const_name undefined}, Ts ---> T);
-
fun mk_ctr Ts ctr =
let val Ts0 = snd (dest_Type (body_type (fastype_of ctr))) in
Term.subst_atomic_types (Ts0 ~~ Ts) ctr
end;
- fun mk_caseof Ts T =
- let val (binders, body) = strip_type (fastype_of caseof0) in
- Term.subst_atomic_types ((body, T) :: (snd (dest_Type (List.last binders)) ~~ Ts)) caseof0
- end;
-
val T = Type (T_name, As);
val ctrs = map (mk_ctr As) ctrs0;
val ctr_Tss = map (binder_types o fastype_of) ctrs;
val ms = map length ctr_Tss;
+ val disc_names =
+ pad_list default_name n raw_disc_names
+ |> map2 (fn ctr => fn disc =>
+ if Binding.eq_name (disc, default_name) then
+ Binding.name (prefix is_N (Long_Name.base_name (name_of_ctr ctr)))
+ else
+ disc) ctrs0;
+
+ val sel_namess =
+ pad_list [] n raw_sel_namess
+ |> map3 (fn ctr => fn m => map2 (fn l => fn sel =>
+ if Binding.eq_name (sel, default_name) then
+ Binding.name (mk_un_N m l (Long_Name.base_name (name_of_ctr ctr)))
+ else
+ sel) (1 upto m) o pad_list default_name m) ctrs0 ms;
+
+ fun mk_caseof Ts T =
+ let val (binders, body) = strip_type (fastype_of caseof0) in
+ Term.subst_atomic_types ((body, T) :: (snd (dest_Type (List.last binders)) ~~ Ts)) caseof0
+ end;
+
val caseofB = mk_caseof As B;
val caseofB_Ts = map (fn Ts => Ts ---> B) ctr_Tss;
- val (((((((xss, yss), fs), gs), (v, v')), w), p), names_lthy) = no_defs_lthy |>
+ fun mk_caseofB_term eta_fs = Term.list_comb (caseofB, eta_fs);
+
+ val (((((((xss, yss), fs), gs), (v, v')), w), (p, p')), names_lthy) = no_defs_lthy |>
mk_Freess "x" ctr_Tss
||>> mk_Freess "y" ctr_Tss
||>> mk_Frees "f" caseofB_Ts
||>> mk_Frees "g" caseofB_Ts
||>> yield_singleton (apfst (op ~~) oo mk_Frees' "v") T
||>> yield_singleton (mk_Frees "w") T
- ||>> yield_singleton (mk_Frees "P") HOLogic.boolT;
+ ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "P") HOLogic.boolT;
+
+ val q = Free (fst p', B --> HOLogic.boolT);
val xctrs = map2 (curry Term.list_comb) ctrs xss;
val yctrs = map2 (curry Term.list_comb) ctrs yss;
- val eta_fs = map2 eta_expand_caseof_arg fs xss;
- val eta_gs = map2 eta_expand_caseof_arg gs xss;
+ val xfs = map2 (curry Term.list_comb) fs xss;
+ val xgs = map2 (curry Term.list_comb) gs xss;
+
+ val eta_fs = map2 eta_expand_caseof_arg xss xfs;
+ val eta_gs = map2 eta_expand_caseof_arg xss xgs;
+
+ val caseofB_fs = Term.list_comb (caseofB, eta_fs);
val exist_xs_v_eq_ctrs =
map2 (fn xctr => fn xs => list_exists_free xs (HOLogic.mk_eq (v, xctr))) xctrs xss;
@@ -155,23 +194,17 @@
map4 mk_goal xctrs yctrs xss yss
end;
- val goal_half_distincts =
- map (HOLogic.mk_Trueprop o HOLogic.mk_not o HOLogic.mk_eq) (mk_half_pairs xctrs);
+ val goal_half_distinctss =
+ map (map (HOLogic.mk_Trueprop o HOLogic.mk_not o HOLogic.mk_eq)) (mk_half_pairss xctrs);
- val goal_cases =
- let
- val lhs0 = Term.list_comb (caseofB, eta_fs);
- fun mk_goal xctr xs f = mk_Trueprop_eq (lhs0 $ xctr, Term.list_comb (f, xs));
- in
- map3 mk_goal xctrs xss fs
- end;
+ val goal_cases = map2 (fn xctr => fn xf => mk_Trueprop_eq (caseofB_fs $ xctr, xf)) xctrs xfs;
- val goals = [goal_exhaust] :: goal_injectss @ [goal_half_distincts, goal_cases];
+ val goals = [goal_exhaust] :: goal_injectss @ goal_half_distinctss @ [goal_cases];
fun after_qed thmss lthy =
let
- val ([exhaust_thm], (inject_thmss, [half_distinct_thms, case_thms])) =
- (hd thmss, chop n (tl thmss));
+ val ([exhaust_thm], (inject_thmss, (half_distinct_thmss, [case_thms]))) =
+ (hd thmss, apsnd (chop (n * n)) (chop n (tl thmss)));
val exhaust_thm' =
let val Tinst = map (pairself (certifyT lthy)) (map Logic.varifyT_global As ~~ As) in
@@ -179,7 +212,13 @@
(Thm.instantiate (Tinst, []) (Drule.zero_var_indexes exhaust_thm))
end;
- val other_half_distinct_thms = map (fn thm => thm RS not_sym) half_distinct_thms;
+ val other_half_distinct_thmss = map (map (fn thm => thm RS not_sym)) half_distinct_thmss;
+
+ val (distinct_thmsss', distinct_thmsss) =
+ map2 (map2 append) (Library.chop_groups n half_distinct_thmss)
+ (transpose (Library.chop_groups n other_half_distinct_thmss))
+ |> `transpose;
+ val distinct_thms = interleave (flat half_distinct_thmss) (flat other_half_distinct_thmss);
val nchotomy_thm =
let
@@ -217,39 +256,38 @@
(Local_Defs.unfold lthy @{thms not_ex} (def RS @{thm ssubst[of _ _ Not]})))
ms disc_defs;
- val disc_thms =
+ val (disc_thmss', disc_thmss) =
let
- fun get_distinct_thm k k' =
- if k > k' then nth half_distinct_thms (index_of_half_cell n (k' - 1) (k - 1))
- else nth other_half_distinct_thms (index_of_half_cell n (k' - 1) (k' - 1))
- fun mk_thm ((k, discI), not_disc) k' =
- if k = k' then refl RS discI else get_distinct_thm k k' RS not_disc;
+ fun mk_thm discI _ [] = refl RS discI
+ | mk_thm _ not_disc [distinct] = distinct RS not_disc;
+ fun mk_thms discI not_disc distinctss = map (mk_thm discI not_disc) distinctss;
in
- map_product mk_thm (ks ~~ discI_thms ~~ not_disc_thms) ks
+ map3 mk_thms discI_thms not_disc_thms distinct_thmsss'
+ |> `transpose
end;
- val disc_disjoint_thms =
+ val disc_exclus_thms =
let
- fun get_disc_thm k k' = nth disc_thms ((k' - 1) * n + (k - 1));
fun mk_goal ((_, disc), (_, disc')) =
Logic.all v (Logic.mk_implies (HOLogic.mk_Trueprop (disc $ v),
HOLogic.mk_Trueprop (HOLogic.mk_not (disc' $ v))));
fun prove tac goal = Skip_Proof.prove lthy [] [] goal (K tac);
- val bundles = ks ~~ ms ~~ discD_thms ~~ discs;
- val half_pairs = mk_half_pairs bundles;
+ val bundles = ms ~~ discD_thms ~~ discs;
+ val half_pairss = mk_half_pairss bundles;
- val goal_halves = map mk_goal half_pairs;
- val half_thms =
- map2 (fn ((((k, m), discD), _), (((k', _), _), _)) =>
- prove (mk_half_disc_disjoint_tac m discD (get_disc_thm k k')))
- half_pairs goal_halves;
+ val goal_halvess = map (map mk_goal) half_pairss;
+ val half_thmss =
+ map3 (fn [] => K (K [])
+ | [(((m, discD), _), _)] => fn disc_thm => fn [goal] =>
+ [prove (mk_half_disc_exclus_tac m discD disc_thm) goal])
+ half_pairss (flat disc_thmss') goal_halvess;
- val goal_other_halves = map (mk_goal o swap) half_pairs;
- val other_half_thms =
- map2 (prove o mk_other_half_disc_disjoint_tac) half_thms goal_other_halves;
+ val goal_other_halvess = map (map (mk_goal o swap)) half_pairss;
+ val other_half_thmss =
+ map2 (map2 (prove o mk_other_half_disc_exclus_tac)) half_thmss goal_other_halvess;
in
- half_thms @ other_half_thms
+ interleave (flat half_thmss) (flat other_half_thmss)
end;
val disc_exhaust_thm =
@@ -281,26 +319,22 @@
| mk_rhs (disc :: discs) (f :: fs) (sels :: selss) =
Const (@{const_name If}, HOLogic.boolT --> B --> B --> B) $
(disc $ v) $ mk_core f sels $ mk_rhs discs fs selss;
-
- val lhs = Term.list_comb (caseofB, eta_fs) $ v;
- val rhs = mk_rhs discs fs selss;
- val goal = mk_Trueprop_eq (lhs, rhs);
+ val goal = mk_Trueprop_eq (caseofB_fs $ v, mk_rhs discs fs selss);
in
Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
- mk_case_disc_tac ctxt exhaust_thm' case_thms disc_thms sel_thmss)
+ mk_case_disc_tac ctxt exhaust_thm' case_thms disc_thmss' sel_thmss)
|> singleton (Proof_Context.export names_lthy lthy)
end;
val (case_cong_thm, weak_case_cong_thm) =
let
fun mk_prem xctr xs f g =
- fold_rev Logic.all xs (Logic.mk_implies (mk_Trueprop_eq (v, xctr),
+ fold_rev Logic.all xs (Logic.mk_implies (mk_Trueprop_eq (w, xctr),
mk_Trueprop_eq (f, g)));
- fun mk_caseof_term fs = Term.list_comb (caseofB, fs);
val v_eq_w = mk_Trueprop_eq (v, w);
- val caseof_fs = mk_caseof_term eta_fs;
- val caseof_gs = mk_caseof_term eta_gs;
+ val caseof_fs = mk_caseofB_term eta_fs;
+ val caseof_gs = mk_caseofB_term eta_gs;
val goal =
Logic.list_implies (v_eq_w :: map4 mk_prem xctrs xss fs gs,
@@ -308,44 +342,70 @@
val goal_weak =
Logic.mk_implies (v_eq_w, mk_Trueprop_eq (caseof_fs $ v, caseof_fs $ w));
in
- (Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
- mk_case_cong_tac ctxt exhaust_thm' case_thms),
+ (Skip_Proof.prove lthy [] [] goal (fn _ => mk_case_cong_tac exhaust_thm' case_thms),
Skip_Proof.prove lthy [] [] goal_weak (K (etac arg_cong 1)))
|> pairself (singleton (Proof_Context.export names_lthy lthy))
end;
- val split_thms = [];
+ val (split_thm, split_asm_thm) =
+ let
+ fun mk_conjunct xctr xs f_xs =
+ list_all_free xs (HOLogic.mk_imp (HOLogic.mk_eq (v, xctr), q $ f_xs));
+ fun mk_disjunct xctr xs f_xs =
+ list_exists_free xs (HOLogic.mk_conj (HOLogic.mk_eq (v, xctr),
+ HOLogic.mk_not (q $ f_xs)));
- val split_asm_thms = [];
+ val lhs = q $ (mk_caseofB_term eta_fs $ v);
+
+ val goal =
+ mk_Trueprop_eq (lhs, Library.foldr1 HOLogic.mk_conj (map3 mk_conjunct xctrs xss xfs));
+ val goal_asm =
+ mk_Trueprop_eq (lhs, HOLogic.mk_not (Library.foldr1 HOLogic.mk_disj
+ (map3 mk_disjunct xctrs xss xfs)));
- (* case syntax *)
+ val split_thm =
+ Skip_Proof.prove lthy [] [] goal
+ (fn _ => mk_split_tac exhaust_thm' case_thms inject_thmss distinct_thmsss)
+ |> singleton (Proof_Context.export names_lthy lthy)
+ val split_asm_thm =
+ Skip_Proof.prove lthy [] [] goal_asm (fn {context = ctxt, ...} =>
+ mk_split_asm_tac ctxt split_thm)
+ |> singleton (Proof_Context.export names_lthy lthy)
+ in
+ (split_thm, split_asm_thm)
+ end;
+
+ (* TODO: case syntax *)
+ (* TODO: attributes (simp, case_names, etc.) *)
- fun note thmN thms =
- snd o Local_Theory.note
- ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), thms);
+ val notes =
+ [(case_congN, [case_cong_thm]),
+ (case_discsN, [case_disc_thm]),
+ (casesN, case_thms),
+ (ctr_selsN, ctr_sel_thms),
+ (discsN, (flat disc_thmss)),
+ (disc_exclusN, disc_exclus_thms),
+ (disc_exhaustN, [disc_exhaust_thm]),
+ (distinctN, distinct_thms),
+ (exhaustN, [exhaust_thm]),
+ (injectN, (flat inject_thmss)),
+ (nchotomyN, [nchotomy_thm]),
+ (selsN, (flat sel_thmss)),
+ (splitN, [split_thm]),
+ (split_asmN, [split_asm_thm]),
+ (weak_case_cong_thmsN, [weak_case_cong_thm])]
+ |> map (fn (thmN, thms) =>
+ ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]));
in
- lthy
- |> note case_congN [case_cong_thm]
- |> note case_discsN [case_disc_thm]
- |> note casesN case_thms
- |> note ctr_selsN ctr_sel_thms
- |> note discsN disc_thms
- |> note disc_disjointN disc_disjoint_thms
- |> note disc_exhaustN [disc_exhaust_thm]
- |> note distinctN (half_distinct_thms @ other_half_distinct_thms)
- |> note exhaustN [exhaust_thm]
- |> note injectN (flat inject_thmss)
- |> note nchotomyN [nchotomy_thm]
- |> note selsN (flat sel_thmss)
- |> note splitN split_thms
- |> note split_asmN split_asm_thms
- |> note weak_case_cong_thmsN [weak_case_cong_thm]
+ lthy |> Local_Theory.notes notes |> snd
end;
in
(goals, after_qed, lthy')
end;
-val parse_binding_list = Parse.$$$ "[" |-- Parse.list Parse.binding --| Parse.$$$ "]";
+val parse_bindings = Parse.$$$ "[" |-- Parse.list Parse.binding --| Parse.$$$ "]";
+
+val parse_bindingss = Parse.$$$ "[" |-- Parse.list parse_bindings --| Parse.$$$ "]";
val bnf_sugar_cmd = (fn (goalss, after_qed, lthy) =>
Proof.theorem NONE after_qed (map (map (rpair [])) goalss) lthy) oo
@@ -354,7 +414,7 @@
val _ =
Outer_Syntax.local_theory_to_proof @{command_spec "bnf_sugar"} "adds sugar on top of a BNF"
(((Parse.$$$ "[" |-- Parse.list Parse.term --| Parse.$$$ "]") -- Parse.term --
- parse_binding_list -- (Parse.$$$ "[" |-- Parse.list parse_binding_list --| Parse.$$$ "]"))
- >> bnf_sugar_cmd);
+ Scan.optional (parse_bindings -- Scan.optional parse_bindingss []) ([], []))
+ >> bnf_sugar_cmd);
end;