--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML Mon Sep 10 17:36:02 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML Mon Sep 10 17:36:02 2012 +0200
@@ -46,6 +46,9 @@
fun pad_list x n xs = xs @ replicate (n - length xs) x;
+fun unflat_lookup _ _ [] = []
+ | unflat_lookup eq ps (xs :: xss) = map (the o AList.lookup eq ps) xs :: unflat_lookup eq ps xss;
+
fun mk_half_pairss' _ [] = []
| mk_half_pairss' indent (y :: ys) =
indent @ fold_rev (cons o single o pair y) ys (mk_half_pairss' ([] :: indent) ys);
@@ -84,15 +87,15 @@
val _ = if n > 0 then () else error "No constructors specified";
- val Type (T_name, As0) = body_type (fastype_of (hd ctrs0));
- val b = Binding.qualified_name T_name;
+ val Type (fpT_name, As0) = body_type (fastype_of (hd ctrs0));
+ val b = Binding.qualified_name fpT_name;
val (As, B) =
no_defs_lthy
|> mk_TFrees (length As0)
||> the_single o fst o mk_TFrees 1;
- val T = Type (T_name, As);
+ val fpT = Type (fpT_name, As);
val ctrs = map (mk_ctr As) ctrs0;
val ctr_Tss = map (binder_types o fastype_of) ctrs;
@@ -146,8 +149,8 @@
||>> mk_Freess "y" ctr_Tss
||>> mk_Frees "f" case_Ts
||>> mk_Frees "g" case_Ts
- ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "v") T
- ||>> yield_singleton (mk_Frees "w") T
+ ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "v") fpT
+ ||>> yield_singleton (mk_Frees "w") fpT
||>> yield_singleton (apfst (op ~~) oo mk_Frees' "P") HOLogic.boolT;
val q = Free (fst p', B --> HOLogic.boolT);
@@ -170,10 +173,7 @@
val exist_xs_v_eq_ctrs =
map2 (fn xctr => fn xs => list_exists_free xs (HOLogic.mk_eq (v, xctr))) xctrs xss;
- fun mk_sel_case_args k xs x T =
- map2 (fn Ts => fn i => if i = k then fold_rev Term.lambda xs x else mk_undef T Ts) ctr_Tss ks;
-
- fun disc_free b = Free (Binding.name_of b, T --> HOLogic.boolT);
+ fun disc_free b = Free (Binding.name_of b, fpT --> HOLogic.boolT);
fun disc_spec b exist_xs_v_eq_ctr = mk_Trueprop_eq (disc_free b $ v, exist_xs_v_eq_ctr);
@@ -186,18 +186,30 @@
fun alternate_disc k =
if n = 2 then Term.lambda v (alternate_disc_lhs (3 - k)) else error "Cannot use \"*\" here"
- fun sel_spec b x xs k =
- let val T' = fastype_of x in
- mk_Trueprop_eq (Free (Binding.name_of b, T --> T') $ v,
- Term.list_comb (mk_case As T', mk_sel_case_args k xs x T') $ v)
+ fun mk_sel_case_args proto_sels T =
+ map2 (fn Ts => fn i =>
+ case AList.lookup (op =) proto_sels i of
+ NONE => mk_undef T Ts
+ | SOME (xs, x) => fold_rev Term.lambda xs x) ctr_Tss ks;
+
+ (* TODO: check types of tail of list *)
+ fun sel_spec b (proto_sels as ((_, (_, x)) :: _)) =
+ let val T = fastype_of x in
+ mk_Trueprop_eq (Free (Binding.name_of b, fpT --> T) $ v,
+ Term.list_comb (mk_case As T, mk_sel_case_args proto_sels T) $ v)
end;
val missing_unique_disc_def = TrueI; (*arbitrary marker*)
val missing_alternate_disc_def = FalseE; (*arbitrary marker*)
- (* TODO: Allow use of same selector for several constructors *)
+ val proto_selss = map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss;
- val (((raw_discs, raw_disc_defs), (raw_selss, raw_sel_defss)), (lthy', lthy)) =
+ val sel_bundles = AList.group Binding.eq_name (flat sel_binderss ~~ flat proto_selss);
+ val sel_binders = map fst sel_bundles;
+
+ fun unflat_sels xs = unflat_lookup Binding.eq_name (sel_binders ~~ xs) sel_binderss;
+
+ val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) =
no_defs_lthy
|> apfst split_list o fold_map4 (fn k => fn m => fn exist_xs_v_eq_ctr =>
fn NONE =>
@@ -207,19 +219,19 @@
| SOME b => Specification.definition (SOME (b, NONE, NoSyn),
((Thm.def_binding b, []), disc_spec b exist_xs_v_eq_ctr)) #>> apsnd snd)
ks ms exist_xs_v_eq_ctrs disc_binders
- ||>> apfst split_list o fold_map3 (fn bs => fn xs => fn k => apfst split_list o
- fold_map2 (fn b => fn x => Specification.definition (SOME (b, NONE, NoSyn),
- ((Thm.def_binding b, []), sel_spec b x xs k)) #>> apsnd snd) bs xs) sel_binderss xss ks
+ ||>> apfst split_list o fold_map (fn (b, proto_sels) =>
+ Specification.definition (SOME (b, NONE, NoSyn),
+ ((Thm.def_binding b, []), sel_spec b proto_sels)) #>> apsnd snd) sel_bundles
||> `Local_Theory.restore;
(*transforms defined frees into consts (and more)*)
val phi = Proof_Context.export_morphism lthy lthy';
val disc_defs = map (Morphism.thm phi) raw_disc_defs;
- val sel_defss = map (map (Morphism.thm phi)) raw_sel_defss;
+ val sel_defss = unflat_sels (map (Morphism.thm phi) raw_sel_defs);
val discs0 = map (Morphism.term phi) raw_discs;
- val selss0 = map (map (Morphism.term phi)) raw_selss;
+ val selss0 = unflat_sels (map (Morphism.term phi) raw_sels);
fun mk_disc_or_sel Ts c =
Term.subst_atomic_types (snd (Term.dest_Type (domain_type (fastype_of c))) ~~ Ts) c;
@@ -288,23 +300,8 @@
end;
val sel_thmss =
- let
- fun mk_thm k xs goal_case case_thm x sel_def =
- let
- val T = fastype_of x;
- val cTs =
- map ((fn T' => certifyT lthy (if T' = B then T else T')) o TFree)
- (rev (Term.add_tfrees goal_case []));
- val cxs = map (certify lthy) (mk_sel_case_args k xs x T);
- in
- Local_Defs.fold lthy [sel_def]
- (Drule.instantiate' (map SOME cTs) (map SOME cxs) case_thm)
- end;
- fun mk_thms k xs goal_case case_thm sel_defs =
- map2 (mk_thm k xs (strip_all_body goal_case) case_thm) xs sel_defs;
- in
- map5 mk_thms ks xss goal_cases case_thms sel_defss
- end;
+ map2 (fn case_thm => map (fn sel_def => case_thm RS (sel_def RS trans))) case_thms
+ sel_defss;
fun mk_unique_disc_def () =
let