20 open BNF_FP_Sugar_Tactics |
20 open BNF_FP_Sugar_Tactics |
21 |
21 |
(* String constant "case"; presumably a name component for case-related
   bindings -- its uses are outside this view, confirm against callers. *)
22 val caseN = "case"; |
22 val caseN = "case"; |
23 |
23 |
(* Rebuilds a free variable with the same name s but the new type T.
   Only the Free constructor is matched, so any other kind of term falls
   through the (non-exhaustive) pattern. *)
24 fun retype_free (Free (s, _)) T = Free (s, T); |
24 fun retype_free (Free (s, _)) T = Free (s, T); |
|
25 |
|
(* Applies f to the argument list xs (Term.list_comb) and abstracts the
   resulting term over the tuple pattern x via HOLogic.tupled_lambda. *)
26 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs)); |
|
(* Special case of mk_tupled_fun: the abstracted pattern is the tuple built
   from xs itself (HOLogic.mk_tuple xs), and f is applied to those same xs. *)
27 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs; |
|
(* Nested-list variant of mk_uncurried_fun: the abstracted pattern is a tuple
   of tuples (one inner HOLogic.mk_tuple per element of xss), while f is
   applied to the flattened argument list (flat xss). *)
28 fun mk_doubly_uncurried_fun f xss = |

29 mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss); |
25 |
30 |
(* Calls error with a fixed message stating that mutually recursive types
   must share the same type parameters. Takes unit so the failure is only
   triggered when the function is actually applied. *)
26 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters"; |
31 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters"; |
27 |
32 |
28 fun merge_type_arg_constrained ctxt (T, c) (T', c') = |
33 fun merge_type_arg_constrained ctxt (T, c) (T', c') = |
29 if T = T' then |
34 if T = T' then |
240 |
245 |
(* Size information derived from ctr_Tsss:
     ns  - length of each top-level entry of ctr_Tsss;
     mss - for each top-level entry, the lengths of its inner lists;
     Css - pairs ns with Cs elementwise, repeating each element of Cs as many
           times as the corresponding count in ns.
   NOTE(review): ctr_Tsss presumably holds constructor argument types per
   datatype -- inferred from the name only, confirm against its definition. *)
241 val ns = map length ctr_Tsss; |
246 val ns = map length ctr_Tsss; |
242 val mss = map (map length) ctr_Tsss; |
247 val mss = map (map length) ctr_Tsss; |
243 val Css = map2 replicate ns Cs; |
248 val Css = map2 replicate ns Cs; |
244 |
249 |
|
(* Splits a binary product type when its second component U is a member of
   Cs: in that case both components (Us) are returned. Any other type,
   including a product whose second component is not in Cs, is returned
   unchanged as the singleton list [T]. *)
250 fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) = |

251 if member (op =) Cs U then Us else [T] |

252 | dest_rec_pair T = [T]; |
|
253 |
245 fun sugar_datatype no_defs_lthy = |
254 fun sugar_datatype no_defs_lthy = |
246 let |
255 let |
247 val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter)))); |
256 val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter)))); |
248 val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts; |
257 val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts; |
249 val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss; |
258 val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss; |
251 val iter_T = flat g_Tss ---> fp_T --> C; |
260 val iter_T = flat g_Tss ---> fp_T --> C; |
252 |
261 |
253 val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec)))); |
262 val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec)))); |
254 val z_prod_Tss = map2 dest_sumTN ns fp_z_Ts; |
263 val z_prod_Tss = map2 dest_sumTN ns fp_z_Ts; |
255 val z_Tsss = map2 (map2 dest_tupleT) mss z_prod_Tss; |
264 val z_Tsss = map2 (map2 dest_tupleT) mss z_prod_Tss; |
256 val h_Tss = map2 (map2 (curry (op --->))) z_Tsss Css; |
265 val z_Tssss = map (map (map dest_rec_pair)) z_Tsss; |
|
266 val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css; |
257 val rec_T = flat h_Tss ---> fp_T --> C; |
267 val rec_T = flat h_Tss ---> fp_T --> C; |
258 |
268 |
259 val ((gss, ysss), _) = |
269 val ((gss, ysss), _) = |
260 no_defs_lthy |
270 no_defs_lthy |
261 |> mk_Freess "f" g_Tss |
271 |> mk_Freess "f" g_Tss |
262 ||>> mk_Freesss "x" y_Tsss; |
272 ||>> mk_Freesss "x" y_Tsss; |
263 |
273 |
264 val hss = map2 (map2 retype_free) gss h_Tss; |
274 val hss = map2 (map2 retype_free) gss h_Tss; |
265 val (zsss, _) = |
275 val (zssss, _) = |
266 no_defs_lthy |
276 no_defs_lthy |
267 |> mk_Freesss "x" z_Tsss; |
277 |> mk_Freessss "x" z_Tssss; |
268 |
278 |
269 val iter_binder = Binding.suffix_name ("_" ^ iterN) b; |
279 val iter_binder = Binding.suffix_name ("_" ^ iterN) b; |
270 val rec_binder = Binding.suffix_name ("_" ^ recN) b; |
280 val rec_binder = Binding.suffix_name ("_" ^ recN) b; |
271 |
281 |
272 val iter_free = Free (Binding.name_of iter_binder, iter_T); |
282 val iter_free = Free (Binding.name_of iter_binder, iter_T); |
275 val iter_spec = |
285 val iter_spec = |
276 mk_Trueprop_eq (fold (fn gs => fn t => Term.list_comb (t, gs)) gss iter_free, |
286 mk_Trueprop_eq (fold (fn gs => fn t => Term.list_comb (t, gs)) gss iter_free, |
277 Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss)); |
287 Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss)); |
278 val rec_spec = |
288 val rec_spec = |
279 mk_Trueprop_eq (fold (fn hs => fn t => Term.list_comb (t, hs)) hss rec_free, |
289 mk_Trueprop_eq (fold (fn hs => fn t => Term.list_comb (t, hs)) hss rec_free, |
280 Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) hss zsss)); |
290 Term.list_comb (fp_rec, |
|
291 map2 (mk_sum_caseN oo map2 mk_doubly_uncurried_fun) hss zssss)); |
281 |
292 |
282 val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy |
293 val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy |
283 |> apfst split_list o fold_map (fn (b, spec) => |
294 |> apfst split_list o fold_map (fn (b, spec) => |
284 Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec)) |
295 Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec)) |
285 #>> apsnd snd) [(iter_binder, iter_spec), (rec_binder, rec_spec)] |
296 #>> apsnd snd) [(iter_binder, iter_spec), (rec_binder, rec_spec)] |