5 Wrapping existing datatypes. |
5 Wrapping existing datatypes. |
6 *) |
6 *) |
7 |
7 |
8 signature BNF_WRAP = |
8 signature BNF_WRAP = |
9 sig |
9 sig |
10 val no_binder: binding |
10 val no_binding: binding |
11 val mk_half_pairss: 'a list -> ('a * 'a) list list |
11 val mk_half_pairss: 'a list -> ('a * 'a) list list |
12 val mk_ctr: typ list -> term -> term |
12 val mk_ctr: typ list -> term -> term |
13 val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list -> |
13 val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list -> |
14 ((bool * term list) * term) * |
14 ((bool * term list) * term) * |
15 (binding list * (binding list list * (binding * term) list list)) -> local_theory -> |
15 (binding list * (binding list list * (binding * term) list list)) -> local_theory -> |
(* Name strings "sels", "split", "split_asm" and "weak_case_cong"; presumably used to name the
   generated facts -- their uses are not visible in this excerpt, confirm.
   NOTE(review): this excerpt is a side-by-side diff rendering, so every statement appears twice
   (old revision / new revision). *)
43 val selsN = "sels"; |
43 val selsN = "sels"; |
44 val splitN = "split"; |
44 val splitN = "split"; |
45 val split_asmN = "split_asm"; |
45 val split_asmN = "split_asm"; |
46 val weak_case_cong_thmsN = "weak_case_cong"; |
46 val weak_case_cong_thmsN = "weak_case_cong"; |
47 |
47 |
(* Sentinel bindings for user-supplied discriminator/selector names, compared via Binding.eq_name.
   - Empty binding ("no name given"): a discriminator is omitted when can_omit_disc_binding allows it.
     Otherwise it gets the standard name (isN prefix + constructor base name).
   - "_": explicitly requests the standard discriminator name.
   - Selectors: both sentinels yield the standard selector name built by mk_unN.
   The old revision calls these no_binder/std_binder; the new one calls them no_binding/std_binding. *)
48 val no_binder = @{binding ""}; |
48 val no_binding = @{binding ""}; |
49 val std_binder = @{binding _}; |
49 val std_binding = @{binding _}; |
50 |
50 |
(* Attribute lists (induct_simp, cong, iff, elim!) built with the @{attributes} antiquotation;
   where each list is attached is not visible in this excerpt. *)
51 val induct_simp_attrs = @{attributes [induct_simp]}; |
51 val induct_simp_attrs = @{attributes [induct_simp]}; |
52 val cong_attrs = @{attributes [cong]}; |
52 val cong_attrs = @{attributes [cong]}; |
53 val iff_attrs = @{attributes [iff]}; |
53 val iff_attrs = @{attributes [iff]}; |
54 val safe_elim_attrs = @{attributes [elim!]}; |
54 val safe_elim_attrs = @{attributes [elim!]}; |
78 Const (s, _) => s |
78 Const (s, _) => s |
79 | Free (s, _) => s |
79 | Free (s, _) => s |
80 | _ => error "Cannot extract name of constructor"); |
80 | _ => error "Cannot extract name of constructor"); |
81 |
81 |
82 fun prepare_wrap_datatype prep_term (((no_dests, raw_ctrs), raw_case), |
82 fun prepare_wrap_datatype prep_term (((no_dests, raw_ctrs), raw_case), |
83 (raw_disc_binders, (raw_sel_binderss, raw_sel_defaultss))) no_defs_lthy = |
83 (raw_disc_bindings, (raw_sel_bindingss, raw_sel_defaultss))) no_defs_lthy = |
84 let |
84 let |
85 (* TODO: sanity checks on arguments *) |
85 (* TODO: sanity checks on arguments *) |
86 (* TODO: attributes (simp, case_names, etc.) *) |
86 (* TODO: attributes (simp, case_names, etc.) *) |
87 (* TODO: case syntax *) |
87 (* TODO: case syntax *) |
88 (* TODO: integration with function package ("size") *) |
88 (* TODO: integration with function package ("size") *) |
109 val ctrs = map (mk_ctr As) ctrs0; |
109 val ctrs = map (mk_ctr As) ctrs0; |
110 val ctr_Tss = map (binder_types o fastype_of) ctrs; |
110 val ctr_Tss = map (binder_types o fastype_of) ctrs; |
111 |
111 |
112 val ms = map length ctr_Tss; |
112 val ms = map length ctr_Tss; |
113 |
113 |
114 val raw_disc_binders' = pad_list no_binder n raw_disc_binders; |
114 val raw_disc_bindings' = pad_list no_binding n raw_disc_bindings; |
115 |
115 |
116 fun can_really_rely_on_disc k = |
116 fun can_really_rely_on_disc k = |
117 not (Binding.eq_name (nth raw_disc_binders' (k - 1), no_binder)) orelse nth ms (k - 1) = 0; |
117 not (Binding.eq_name (nth raw_disc_bindings' (k - 1), no_binding)) orelse nth ms (k - 1) = 0; |
118 fun can_rely_on_disc k = |
118 fun can_rely_on_disc k = |
119 can_really_rely_on_disc k orelse (k = 1 andalso not (can_really_rely_on_disc 2)); |
119 can_really_rely_on_disc k orelse (k = 1 andalso not (can_really_rely_on_disc 2)); |
120 fun can_omit_disc_binder k m = |
120 fun can_omit_disc_binding k m = |
121 n = 1 orelse m = 0 orelse (n = 2 andalso can_rely_on_disc (3 - k)); |
121 n = 1 orelse m = 0 orelse (n = 2 andalso can_rely_on_disc (3 - k)); |
122 |
122 |
123 val std_disc_binder = |
123 val std_disc_binding = |
124 Binding.qualify false (Binding.name_of data_b) o Binding.name o prefix isN o base_name_of_ctr; |
124 Binding.qualify false (Binding.name_of data_b) o Binding.name o prefix isN o base_name_of_ctr; |
125 |
125 |
126 val disc_binders = |
126 val disc_bindings = |
127 raw_disc_binders' |
127 raw_disc_bindings' |
128 |> map4 (fn k => fn m => fn ctr => fn disc => |
128 |> map4 (fn k => fn m => fn ctr => fn disc => |
129 Option.map (Binding.qualify false (Binding.name_of data_b)) |
129 Option.map (Binding.qualify false (Binding.name_of data_b)) |
130 (if Binding.eq_name (disc, no_binder) then |
130 (if Binding.eq_name (disc, no_binding) then |
131 if can_omit_disc_binder k m then NONE else SOME (std_disc_binder ctr) |
131 if can_omit_disc_binding k m then NONE else SOME (std_disc_binding ctr) |
132 else if Binding.eq_name (disc, std_binder) then |
132 else if Binding.eq_name (disc, std_binding) then |
133 SOME (std_disc_binder ctr) |
133 SOME (std_disc_binding ctr) |
134 else |
134 else |
135 SOME disc)) ks ms ctrs0; |
135 SOME disc)) ks ms ctrs0; |
136 |
136 |
137 val no_discs = map is_none disc_binders; |
137 val no_discs = map is_none disc_bindings; |
138 val no_discs_at_all = forall I no_discs; |
138 val no_discs_at_all = forall I no_discs; |
139 |
139 |
140 fun std_sel_binder m l = Binding.name o mk_unN m l o base_name_of_ctr; |
140 fun std_sel_binding m l = Binding.name o mk_unN m l o base_name_of_ctr; |
141 |
141 |
142 val sel_binderss = |
142 val sel_bindingss = |
143 pad_list [] n raw_sel_binderss |
143 pad_list [] n raw_sel_bindingss |
144 |> map3 (fn ctr => fn m => map2 (fn l => fn sel => |
144 |> map3 (fn ctr => fn m => map2 (fn l => fn sel => |
145 Binding.qualify false (Binding.name_of data_b) |
145 Binding.qualify false (Binding.name_of data_b) |
146 (if Binding.eq_name (sel, no_binder) orelse Binding.eq_name (sel, std_binder) then |
146 (if Binding.eq_name (sel, no_binding) orelse Binding.eq_name (sel, std_binding) then |
147 std_sel_binder m l ctr |
147 std_sel_binding m l ctr |
148 else |
148 else |
149 sel)) (1 upto m) o pad_list no_binder m) ctrs0 ms; |
149 sel)) (1 upto m) o pad_list no_binding m) ctrs0 ms; |
150 |
150 |
151 fun mk_case Ts T = |
151 fun mk_case Ts T = |
152 let |
152 let |
153 val (binders, body) = strip_type (fastype_of case0) |
153 val (bindings, body) = strip_type (fastype_of case0) |
154 val Type (_, Ts0) = List.last binders |
154 val Type (_, Ts0) = List.last bindings |
155 in Term.subst_atomic_types ((body, T) :: (Ts0 ~~ Ts)) case0 end; |
155 in Term.subst_atomic_types ((body, T) :: (Ts0 ~~ Ts)) case0 end; |
156 |
156 |
157 val casex = mk_case As B; |
157 val casex = mk_case As B; |
158 val case_Ts = map (fn Ts => Ts ---> B) ctr_Tss; |
158 val case_Ts = map (fn Ts => Ts ---> B) ctr_Tss; |
159 |
159 |
189 val unique_disc_no_def = TrueI; (*arbitrary marker*) |
189 val unique_disc_no_def = TrueI; (*arbitrary marker*) |
190 val alternate_disc_no_def = FalseE; (*arbitrary marker*) |
190 val alternate_disc_no_def = FalseE; (*arbitrary marker*) |
191 |
191 |
192 fun alternate_disc_lhs get_disc k = |
192 fun alternate_disc_lhs get_disc k = |
193 HOLogic.mk_not |
193 HOLogic.mk_not |
194 (case nth disc_binders (k - 1) of |
194 (case nth disc_bindings (k - 1) of |
195 NONE => nth exist_xs_v_eq_ctrs (k - 1) |
195 NONE => nth exist_xs_v_eq_ctrs (k - 1) |
196 | SOME b => get_disc b (k - 1) $ v); |
196 | SOME b => get_disc b (k - 1) $ v); |
197 |
197 |
198 val (all_sels_distinct, discs, selss, disc_defs, sel_defs, sel_defss, lthy') = |
198 val (all_sels_distinct, discs, selss, disc_defs, sel_defs, sel_defss, lthy') = |
199 if no_dests then |
199 if no_dests then |
235 in |
235 in |
236 mk_Trueprop_eq (Free (Binding.name_of b, dataT --> T) $ v, |
236 mk_Trueprop_eq (Free (Binding.name_of b, dataT --> T) $ v, |
237 Term.list_comb (mk_case As T, mk_sel_case_args b proto_sels T) $ v) |
237 Term.list_comb (mk_case As T, mk_sel_case_args b proto_sels T) $ v) |
238 end; |
238 end; |
239 |
239 |
240 val sel_binders = flat sel_binderss; |
240 val sel_bindings = flat sel_bindingss; |
241 val uniq_sel_binders = distinct Binding.eq_name sel_binders; |
241 val uniq_sel_bindings = distinct Binding.eq_name sel_bindings; |
242 val all_sels_distinct = (length uniq_sel_binders = length sel_binders); |
242 val all_sels_distinct = (length uniq_sel_bindings = length sel_bindings); |
243 |
243 |
244 val sel_binder_index = |
244 val sel_binding_index = |
245 if all_sels_distinct then 1 upto length sel_binders |
245 if all_sels_distinct then 1 upto length sel_bindings |
246 else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_binders) sel_binders; |
246 else map (fn b => find_index (curry Binding.eq_name b) uniq_sel_bindings) sel_bindings; |
247 |
247 |
248 val proto_sels = flat (map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss); |
248 val proto_sels = flat (map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss); |
249 val sel_infos = |
249 val sel_infos = |
250 AList.group (op =) (sel_binder_index ~~ proto_sels) |
250 AList.group (op =) (sel_binding_index ~~ proto_sels) |
251 |> sort (int_ord o pairself fst) |
251 |> sort (int_ord o pairself fst) |
252 |> map snd |> curry (op ~~) uniq_sel_binders; |
252 |> map snd |> curry (op ~~) uniq_sel_bindings; |
253 val sel_binders = map fst sel_infos; |
253 val sel_bindings = map fst sel_infos; |
254 |
254 |
255 fun unflat_selss xs = unflat_lookup Binding.eq_name sel_binders xs sel_binderss; |
255 fun unflat_selss xs = unflat_lookup Binding.eq_name sel_bindings xs sel_bindingss; |
256 |
256 |
257 val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) = |
257 val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) = |
258 no_defs_lthy |
258 no_defs_lthy |
259 |> apfst split_list o fold_map4 (fn k => fn m => fn exist_xs_v_eq_ctr => |
259 |> apfst split_list o fold_map4 (fn k => fn m => fn exist_xs_v_eq_ctr => |
260 fn NONE => |
260 fn NONE => |
261 if n = 1 then pair (Term.lambda v (mk_v_eq_v ()), unique_disc_no_def) |
261 if n = 1 then pair (Term.lambda v (mk_v_eq_v ()), unique_disc_no_def) |
262 else if m = 0 then pair (Term.lambda v exist_xs_v_eq_ctr, refl) |
262 else if m = 0 then pair (Term.lambda v exist_xs_v_eq_ctr, refl) |
263 else pair (alternate_disc k, alternate_disc_no_def) |
263 else pair (alternate_disc k, alternate_disc_no_def) |
264 | SOME b => Specification.definition (SOME (b, NONE, NoSyn), |
264 | SOME b => Specification.definition (SOME (b, NONE, NoSyn), |
265 ((Thm.def_binding b, []), disc_spec b exist_xs_v_eq_ctr)) #>> apsnd snd) |
265 ((Thm.def_binding b, []), disc_spec b exist_xs_v_eq_ctr)) #>> apsnd snd) |
266 ks ms exist_xs_v_eq_ctrs disc_binders |
266 ks ms exist_xs_v_eq_ctrs disc_bindings |
267 ||>> apfst split_list o fold_map (fn (b, proto_sels) => |
267 ||>> apfst split_list o fold_map (fn (b, proto_sels) => |
268 Specification.definition (SOME (b, NONE, NoSyn), |
268 Specification.definition (SOME (b, NONE, NoSyn), |
269 ((Thm.def_binding b, []), sel_spec b proto_sels)) #>> apsnd snd) sel_infos |
269 ((Thm.def_binding b, []), sel_spec b proto_sels)) #>> apsnd snd) sel_infos |
270 ||> `Local_Theory.restore; |
270 ||> `Local_Theory.restore; |
271 |
271 |