|
1 (* Title: HOL/BNF/Tools/bnf_fp_n2m_sugar.ML |
|
2 Author: Jasmin Blanchette, TU Muenchen |
|
3 Copyright 2013 |
|
4 |
|
5 Sugared flattening of nested to mutual (co)recursion. |
|
6 *) |
|
7 |
|
signature BNF_FP_N2M_SUGAR =
sig
  (* Mutualize the given (co)datatype sugars into one fresh mutual (co)datatype when asked to
     or when the target types contain duplicates. The Boolean in the result tells whether a
     new (co)datatype was actually defined; if not, the input sugars are returned as is. *)
  val mutualize_fp_sugars:
    bool -> bool -> BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
    term list list list list -> BNF_FP_Def_Sugar.fp_sugar list -> local_theory ->
    (bool * BNF_FP_Def_Sugar.fp_sugar list) * local_theory

  (* Pad the per-type call lists to the given length and arrange each one per constructor
     and constructor argument. *)
  val pad_and_indexify_calls:
    BNF_FP_Def_Sugar.fp_sugar list -> int -> (term * term list list) list list ->
    term list list list list

  (* Collect the (co)datatypes that are mutually recursive with the given ones, then
     mutualize the whole family. *)
  val nested_to_mutual_fps:
    bool -> BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
    (term * term list list) list list -> local_theory ->
    (bool * typ list * int list * BNF_FP_Def_Sugar.fp_sugar list) * local_theory
end;
|
19 |
|
(* Implementation of BNF_FP_N2M_SUGAR: turning nested (co)recursion into mutual
   (co)recursion by defining a fresh mutual (co)datatype and deriving its sugar. *)
structure BNF_FP_N2M_Sugar : BNF_FP_N2M_SUGAR =
struct

open BNF_Util
open BNF_Def
open BNF_Ctr_Sugar
open BNF_FP_Util
open BNF_FP_Def_Sugar
open BNF_FP_N2M

(* Prefix of the base names of the mutualized types defined by mutualize_fp_sugars. *)
val n2mN = "n2m_"

(* TODO: test with sort constraints on As *)
(* TODO: use right sorting order for "fp_sort" w.r.t. original BNFs (?) -- treat new variables
   as deads? *)
(* Mutualize the types fpTs, described by fp_sugars0, into one fresh mutual fixpoint of kind fp.
   This happens only if mutualize is true or fpTs contains duplicates.

   Arguments:
     lose_co_rec  -- controls how argument types without recursive calls are frozen; see the
                     comment on freeze_fp_default below
     mutualize    -- forces mutualization even when fpTs has no duplicates
     fp           -- Least_FP selects the iterator/induction path; any other kind selects the
                     coiterator/coinduction path
     bs           -- one binding per element of fpTs; its name qualifies the new type name
     fpTs         -- the target types, one per element of fp_sugars0
     get_indices  -- maps a recursive call to the indices (into fpTs) of the types it targets
     callssss     -- recursive calls, indexed by type, then constructor, then constructor
                     argument; each entry is a list of calls
     fp_sugars0   -- the existing sugars for fpTs

   Returns ((true, sugars), lthy) when mutualizing. Here sugars holds one new sugar per
   element of fpTs, exported back to no_defs_lthy0, and lthy contains the new definitions.
   Otherwise returns ((false, fp_sugars0), no_defs_lthy0) unchanged.

   Fails via error if a call has two or more indices (heterogeneous call), or if the calls at
   one argument position target different types (incompatible calls). *)
fun mutualize_fp_sugars lose_co_rec mutualize fp bs fpTs get_indices callssss fp_sugars0
    no_defs_lthy0 =
  (* TODO: Also check whether there's any lost recursion? *)
  if mutualize orelse has_duplicates (op =) fpTs then
    let
      val thy = Proof_Context.theory_of no_defs_lthy0;

      val qsotm = quote o Syntax.string_of_term no_defs_lthy0;

      fun heterogeneous_call t = error ("Heterogeneous recursive call: " ^ qsotm t);
      fun incompatible_calls t1 t2 =
        error ("Incompatible recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2);

      val b_names = map Binding.name_of bs;
      val fp_b_names = map base_name_of_typ fpTs;

      val nn = length fpTs;

      (* The index-th ctr_sugar of an fp_sugar. Its schematic type variables are instantiated
         by matching the sugar's type T against fpT. *)
      fun target_ctr_sugar_of_fp_sugar fpT {T, index, ctr_sugars, ...} =
        let
          val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T, fpT) Vartab.empty) [];
          val phi = Morphism.term_morphism (Term.subst_TVars rho);
        in
          morph_ctr_sugar phi (nth ctr_sugars index)
        end;

      val ctr_defss = map (of_fp_sugar #ctr_defss) fp_sugars0;
      val ctr_sugars0 = map2 target_ctr_sugar_of_fp_sugar fpTs fp_sugars0;

      val ctrss = map #ctrs ctr_sugars0;
      val ctr_Tss = map (map fastype_of) ctrss;

      (* Type variables occurring in the constructor types. They are declared in the context
         below, presumably so the fresh type variables do not clash with them. *)
      val As' = fold (fold Term.add_tfreesT) ctr_Tss [];
      val As = map TFree As';

      (* Cs: nn fresh type variables, passed to the (co)iterator construction.
         Xs: one fresh type variable per element of fpTs, named after its base name. Xs stand
         for the recursive occurrences in the fixpoint equations fp_eqs. *)
      val ((Cs, Xs), no_defs_lthy) =
        no_defs_lthy0
        |> fold Variable.declare_typ As
        |> mk_TFrees nn
        ||>> variant_tfrees fp_b_names;

      (* If "lose_co_rec" is "true", the function "null" on "'a list" gives rise to
           'list = unit + 'a list
         instead of
           'list = unit + 'list
         resulting in a simpler (co)induction rule and (co)recursor. *)
      (* Replaces every occurrence of a member of fpTs (at any depth) by the corresponding X. *)
      fun freeze_fp_default (T as Type (s, Ts)) =
          (case find_index (curry (op =) T) fpTs of
            ~1 => Type (s, map freeze_fp_default Ts)
          | kk => nth Xs kk)
        | freeze_fp_default T = T;

      (* Like get_indices, but rejects a call that targets two or more types. *)
      fun get_indices_checked call =
        (case get_indices call of
          _ :: _ :: _ => heterogeneous_call call
        | kks => kks);

      (* Freeze the argument type T, guided by the calls made at that argument position.
         - If some calls can be decomposed by dest_map through T's type constructor s, recurse
           into s's type arguments with the decomposed calls (transposed per argument).
         - Otherwise collect the distinct target indices of all calls:
           none: keep T if lose_co_rec, else apply freeze_fp_default;
           one: use the corresponding X;
           several: fail with incompatible_calls. *)
      fun freeze_fp calls (T as Type (s, Ts)) =
          (case map_filter (try (snd o dest_map no_defs_lthy s)) calls of
            [] =>
            (case union (op = o pairself fst)
                (maps (fn call => map (rpair call) (get_indices_checked call)) calls) [] of
              [] => T |> not lose_co_rec ? freeze_fp_default
            | [(kk, _)] => nth Xs kk
            | (_, call1) :: (_, call2) :: _ => incompatible_calls call1 call2)
          | callss =>
            Type (s, map2 freeze_fp (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
              (transpose callss)) Ts))
        | freeze_fp _ T = T;

      (* Constructor argument types, indexed by type, constructor and argument. The second list
         has the recursive occurrences frozen into Xs. *)
      val ctr_Tsss = map (map binder_types) ctr_Tss;
      val ctrXs_Tsss = map2 (map2 (map2 freeze_fp)) callssss ctr_Tsss;
      val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
      val Ts = map (body_type o hd) ctr_Tss;

      (* Number of constructors per type; constructor indices 1..n; arities per constructor. *)
      val ns = map length ctr_Tsss;
      val kss = map (fn n => 1 upto n) ns;
      val mss = map (map length) ctr_Tsss;

      (* Fixpoint equations: each X equals the balanced sum of the tuples of its constructors'
         frozen argument types. *)
      val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;

      (* Bindings of the new types. Each name is n2mN followed by a variant of the type's base
         name, qualified by the name of the corresponding element of bs. *)
      val base_fp_names = Name.variant_list [] fp_b_names;
      val fp_bs = map2 (fn b_name => fn base_fp_name =>
          Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
        b_names base_fp_names;

      (* Define the mutual fixpoint and retrieve its pre-BNFs, its xtor (co)iterators and the
         associated theorems. *)
      val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct,
             dtor_injects, dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
        fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;

      val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
      val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;

      (* Only the Option fields relevant to fp are used below, via "the". *)
      val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
        mk_co_iters_prelims fp fpTs Cs ns mss xtor_co_iterss0 lthy;

      (* The binding b with "_" ^ suf appended to its name. *)
      fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;

      (* Define the (co)iterators for each new type: foldN/recN for Least_FP, otherwise
         unfoldN/corecN. *)
      val ((co_iterss, co_iter_defss), lthy) =
        fold_map2 (fn b =>
          (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
           else define_coiters [unfoldN, corecN] (the coiters_args_types))
            (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
        |>> split_list;

      (* Instantiate the constructor sugars using the substitution computed by tvar_subst from
         the constructors' result types Ts to fpTs. *)
      val rho = tvar_subst thy Ts fpTs;
      val ctr_sugar_phi =
        Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
          (Morphism.term_morphism (Term.subst_TVars rho));
      val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;

      val ctr_sugars = map inst_ctr_sugar ctr_sugars0;

      (* Derive the (co)induction rule(s) and the un/fold and co/rec theorems. The induction
         path yields a single rule; the coinduction path yields one rule per entry of
         coinduct_thms_pairs. *)
      val (co_inducts, un_fold_thmss, co_rec_thmss) =
        if fp = Least_FP then
          derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
            xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
            co_iterss co_iter_defss lthy
          |> (fn ((_, induct, _), (fold_thmss, _), (rec_thmss, _)) =>
            ([induct], fold_thmss, rec_thmss))
        else
          derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types) xtor_co_induct
            dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs kss mss ns ctr_defss
            ctr_sugars co_iterss co_iter_defss (Proof_Context.export lthy no_defs_lthy) lthy
          |> (fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _), _, _, _, _) =>
            (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss));

      (* Exports results from no_defs_lthy (with the fresh type variables) back to the
         caller's context no_defs_lthy0. *)
      val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;

      (* The new sugar for the kk-th type T. All types share the same fp_res, BNFs, constructor
         sugars and theorems; only T and index differ. *)
      fun mk_target_fp_sugar (kk, T) =
        {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
         nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
         ctr_sugars = ctr_sugars, co_inducts = co_inducts, co_iterss = co_iterss,
         co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss]}
        |> morph_fp_sugar phi;
    in
      ((true, map_index mk_target_fp_sugar fpTs), lthy)
    end
  else
    (* TODO: reorder hypotheses and predicates in (co)induction rules? *)
    ((false, fp_sugars0), no_defs_lthy0);

(* Arrange the calls of one type per constructor of fp_sugar, in constructor order.
   callsss is an association list from constructor terms (compared with Term.aconv_untyped) to
   per-argument call lists. Found calls are beta-eta contracted. A constructor without an entry
   gets one empty list per argument. *)
fun indexify_callsss fp_sugar callsss =
  let
    val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
    fun do_ctr ctr =
      (case AList.lookup Term.aconv_untyped callsss ctr of
        NONE => replicate (num_binder_types (fastype_of ctr)) []
      | SOME callss => map (map Envir.beta_eta_contract) callss);
  in
    map do_ctr ctrs
  end;

(* Pad the per-type call lists with empty lists (via pad_list) up to the given length. Then
   indexify each list against the corresponding element of fp_sugars0. *)
fun pad_and_indexify_calls fp_sugars0 = map2 indexify_callsss fp_sugars0 oo pad_list [];

(* Compute the family of (co)datatypes that must be mutualized for actual_Ts, then
   mutualize them.

   Each type is enriched with the types it is mutually recursive with. These come from the
   fp_res of its registered fp_sugar and are instantiated with the type's own arguments. All
   types must be registered fixpoints of kind fp. Each group after the first must relate, via
   exists_subtype_in, to a type seen before; otherwise the call fails.

   The family is processed in the order built by check_enrich_with_mutuals from actual_Ts
   sorted by ascending type size. Results are mapped back to the order actual_Ts @ missing_Ts.

   Returns ((nontriv, missing_Ts, perm_kks, fp_sugars), lthy):
     nontriv    -- the flag returned by mutualize_fp_sugars
     missing_Ts -- the added types that are not in actual_Ts
     perm_kks   -- presumably the index (in actual_Ts @ missing_Ts) of each type, listed in
                   processing order
     fp_sugars  -- the sugars in the order actual_Ts @ missing_Ts *)
fun nested_to_mutual_fps lose_co_rec fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
  let
    val qsoty = quote o Syntax.string_of_typ lthy;
    val qsotys = space_implode " or " o map qsoty;

    fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
    (* On the Least_FP path, give a more specific message when s is an old-style datatype. *)
    fun not_co_datatype (T as Type (s, _)) =
        if fp = Least_FP andalso
           is_some (Datatype_Data.get_info (Proof_Context.theory_of lthy) s) then
          error (qsoty T ^ " is not a new-style datatype (cf. \"datatype_new\")")
        else
          not_co_datatype0 T
      | not_co_datatype T = not_co_datatype0 T;
    fun not_mutually_nested_rec Ts1 Ts2 =
      error (qsotys Ts1 ^ " is neither mutually recursive with nor nested recursive via " ^
        qsotys Ts2);

    (* NOTE(review): this binding is non-exhaustive. It raises Bind if actual_Ts is empty or
       if its smallest element is not a Type. ty_args0 are the type arguments of the smallest
       actual type. *)
    val perm_actual_Ts as Type (_, ty_args0) :: _ =
      sort (int_ord o pairself Term.size_of_typ) actual_Ts;

    (* Walk the worklist and replace each type by its mutual group, instantiated with the
       type's arguments. Types already covered by that group are dropped from the worklist.
       seen accumulates the types processed so far. *)
    fun check_enrich_with_mutuals _ [] = []
      | check_enrich_with_mutuals seen ((T as Type (T_name, ty_args)) :: Ts) =
        (case fp_sugar_of lthy T_name of
          SOME ({fp = fp', fp_res = {Ts = Ts', ...}, ...}) =>
          if fp = fp' then
            let
              val mutual_Ts = map (fn Type (s, _) => Type (s, ty_args)) Ts';
              val _ =
                seen = [] orelse exists (exists_subtype_in seen) mutual_Ts orelse
                not_mutually_nested_rec mutual_Ts seen;
              val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts;
            in
              mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts'
            end
          else
            not_co_datatype T
        | NONE => not_co_datatype T)
      | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T;

    val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts;

    val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
    val Ts = actual_Ts @ missing_Ts;

    val nn = length Ts;
    val kks = 0 upto nn - 1;

    (* The added types all get a binding built from the common name of actual_bs. *)
    val common_name = mk_common_name (map Binding.name_of actual_bs);
    val bs = pad_list (Binding.name common_name) nn actual_bs;

    (* Convert between the order of Ts and the processing order perm_Ts. *)
    fun permute xs = permute_like (op =) Ts perm_Ts xs;
    fun unpermute perm_xs = permute_like (op =) perm_Ts Ts perm_xs;

    val perm_bs = permute bs;
    val perm_kks = permute kks;
    val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;

    (* Force mutualization if some type is instantiated differently from the smallest one. *)
    val mutualize = exists (fn Type (_, ty_args) => ty_args <> ty_args0) Ts;
    val perm_callssss = pad_and_indexify_calls perm_fp_sugars0 nn actual_callssss0;

    (* Translate each index returned by get_indices into its position within perm_kks. *)
    val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;

    val ((nontriv, perm_fp_sugars), lthy) =
      mutualize_fp_sugars lose_co_rec mutualize fp perm_bs perm_Ts get_perm_indices perm_callssss
        perm_fp_sugars0 lthy;

    val fp_sugars = unpermute perm_fp_sugars;
  in
    ((nontriv, missing_Ts, perm_kks, fp_sugars), lthy)
  end;

end;