|
1 (* Title: Pure/Tools/codegen_funcgr.ML |
|
2 ID: $Id$ |
|
3 Author: Florian Haftmann, TU Muenchen |
|
4 |
|
5 Retrieving and structuring code function theorems. |
|
6 *) |
|
7 |
|
signature CODEGEN_FUNCGR =
sig
  (*graph of code constants; each node carries the constant's type
    and its code theorems*)
  type T
  (*preprocessing pipeline applied to raw code theorems*)
  val preprocess: theory -> thm list -> thm list
  (*build the function graph for the given constants and the
    instances required by the given (name, type) pairs*)
  val mk_funcgr: theory -> CodegenConsts.const list -> (string * typ) list -> T
  (*checked code theorems of a constant*)
  val get_funcs: T -> CodegenConsts.const -> thm list
  (*all constants of the graph together with their types*)
  val get_func_typs: T -> (CodegenConsts.const * typ) list
  (*diagnostic output of the code theorems for the given constants*)
  val print_codethms: theory -> CodegenConsts.const list -> unit
end;
|
17 |
|
18 structure CodegenFuncgr: CODEGEN_FUNCGR = |
|
19 struct |
|
20 |
|
21 (** code data **) |
|
22 |
|
structure Consttab = CodegenConsts.Consttab;

(*graph whose keys are code constants, ordered by CodegenConsts.const_ord*)
structure Constgraph = GraphFun (
  type key = CodegenConsts.const;
  val ord = CodegenConsts.const_ord;
);

(*node label: the constant's type together with its code theorems*)
type T = (typ * thm list) Constgraph.T;

(*function graph kept as code data of the theory; both merge and purge
  discard the graph and start again from the empty one*)
structure Funcgr = CodeDataFun
(struct
  val name = "Pure/codegen_funcgr";
  type T = T;
  val empty = Constgraph.empty;
  fun merge _ _ = empty;
  fun purge _ _ = empty;
end);

val _ = Context.add_setup Funcgr.init;
|
41 |
|
42 |
|
43 (** theorem purification **) |
|
44 |
|
(*Eta-expand a code equation so that its left hand side is fully applied.
  For every remaining argument type of the left hand side, a fresh
  schematic variable (names invented from "x") is applied to both sides
  via Thm.combination.  Afterwards the left hand side is replaced by its
  beta normal form.  The fresh names avoid every schematic variable name
  occurring anywhere in the equation.*)
fun abs_norm thy thm =
  let
    fun expvars t =
      let
        val lhs = (fst o Logic.dest_equals) t;
        val tys = (fst o strip_type o fastype_of) lhs;
        (*collect names from the whole equation: a fresh variable must not
          clash with a schematic variable occurring only on the right hand side*)
        val used = fold_aterms (fn Var ((v, _), _) => insert (op =) v | _ => I) t [];
        val vs = Name.invent_list used "x" (length tys);
      in
        map2 (fn v => fn ty => Var ((v, 0), ty)) vs tys
      end;
    (*from  f == g  derive  f a == g a*)
    fun expand ct thm =
      Thm.combination thm (Thm.reflexive ct);
    (*rewrite the left hand side to its beta normal form*)
    fun beta_norm thm =
      thm
      |> prop_of
      |> Logic.dest_equals
      |> fst
      |> cterm_of thy
      |> Thm.beta_conversion true
      |> Thm.symmetric
      |> (fn thm' => Thm.transitive thm' thm);
  in
    thm
    |> fold (expand o cterm_of thy) ((expvars o prop_of) thm)
    |> beta_norm
  end;
|
72 |
|
(*Rename the schematic type variables of thm to purified names
  (CodegenNames.purify_var).  Variables whose name is already pure are
  kept, and each name is renamed at most once.  The new variables get
  successive indices starting two above the theorem's maxidx, so they
  are fresh.*)
fun canonical_tvars thy thm =
  let
    fun collect_tvars t =
      (fold_types o fold_atyps)
        (fn TVar (xi as (x, _), sort) => cons (xi, (CodegenNames.purify_var x, sort))
          | _ => I) t [];
    fun add_inst (xi as (x, _), (x', sort)) (state as (idx, seen, insts)) =
      if x = x' orelse member (op =) seen x then state
      else
        (idx + 1, x :: seen,
          (ctyp_of thy (TVar (xi, sort)), ctyp_of thy (TVar ((x', idx), sort))) :: insts);
    val (_, _, insts) =
      fold add_inst (collect_tvars (prop_of thm)) (Thm.maxidx_of thm + 2, [], []);
  in Thm.instantiate (insts, []) thm end;
|
89 |
|
(*Rename the schematic term variables of thm to purified names
  (CodegenNames.purify_var).  Variables whose name is already pure are
  kept, and each name is renamed at most once.  The new variables get
  successive indices starting two above the theorem's maxidx, so they
  are fresh.*)
fun canonical_vars thy thm =
  let
    (*only Vars are collected, so the instantiated term is always the Var
      itself; Thm.instantiate expects schematic variables on the left*)
    fun mk_inst (v_i as (v, _), (v', ty)) (s as (maxidx, set, acc)) =
      if v = v' orelse member (op =) set v then s
      else
        (maxidx + 1, v :: set,
          (cterm_of thy (Var (v_i, ty)), cterm_of thy (Var ((v', maxidx), ty))) :: acc);
    fun vars_of thm = fold_aterms
      (fn Var (v_i as (v, _), ty) => cons (v_i, (CodegenNames.purify_var v, ty))
        | _ => I) (prop_of thm) [];
    val (_, _, inst) = fold mk_inst (vars_of thm) (Thm.maxidx_of thm + 2, [], []);
  in Thm.instantiate ([], inst) thm end;
|
106 |
|
(*Preprocess code theorems.  The steps are:
  - apply CodegenData.preprocess;
  - eta-expand and beta-normalise each equation (abs_norm);
  - canonically rename schematic type and term variables and zero their
    indices.
  The last step works on all theorems at once (joined by
  Conjunction.intr_list), so renaming and index normalisation are
  performed jointly across the whole list.*)
fun preprocess thy thms =
  let
    fun burrow_thms f [] = []
      | burrow_thms f thms =
          thms
          |> Conjunction.intr_list
          |> f
          |> Conjunction.elim_list;
  in
    thms
    |> CodegenData.preprocess thy
    |> map (abs_norm thy)
    |> burrow_thms (
        canonical_tvars thy
        #> canonical_vars thy
        #> Drule.zero_var_indexes
      )
  end;
|
127 |
|
(*Check that thms are well-formed code equations for constant c:
  - the head of each left hand side must be c;
  - no free variable may occur twice on a left hand side;
  - every free variable of a right hand side must occur on the left.
  Raises ERROR on the first violation (all head checks run first, then the
  left hand side checks, then the right hand side checks); returns thms
  unchanged on success.*)
fun check_thms c thms =
  let
    (*names of the free variables of t, with repetitions*)
    fun frees_of t = fold_aterms (fn Free (v, _) => cons v | _ => I) t [];
    fun check_head_lhs thm (lhs, _) =
      case strip_comb lhs
       of (Const (c', _), _) => if c' = c then ()
            else error ("Illegal function equation for " ^ quote c
              ^ ", actually defining " ^ quote c' ^ ": " ^ Display.string_of_thm thm)
        | _ => error ("Illegal function equation: " ^ Display.string_of_thm thm);
    fun check_vars_lhs thm (lhs, _) =
      if has_duplicates (op =) (frees_of lhs)
      then error ("Repeated variables on left hand side of function equation: "
        ^ Display.string_of_thm thm)
      else ();
    fun check_vars_rhs thm (lhs, rhs) =
      if null (subtract (op =) (frees_of lhs) (frees_of rhs))
      then ()
      else error ("Free variables on right hand side of function equation: "
        ^ Display.string_of_thm thm);
    (*Logic.unvarify presumably turns schematic variables into free ones,
      which is what the Free-based checks above rely on*)
    val tts = map (Logic.dest_equals o Logic.unvarify o Thm.prop_of) thms;
  in
    (map2 check_head_lhs thms tts; map2 check_vars_lhs thms tts;
      map2 check_vars_rhs thms tts; thms)
  end;
|
154 |
|
155 |
|
156 |
|
157 (** retrieval **) |
|
158 |
|
(*code theorems stored in funcgr for a constant, validated by check_thms;
  a constant without a node (or any failure to fetch it) yields []*)
fun get_funcs funcgr (c_tys as (c, _)) =
  let
    val thms =
      case try (Constgraph.get_node funcgr) c_tys
       of SOME (_, thms) => thms
        | NONE => [];
  in check_thms c thms end;
|
161 |
|
(*every constant of funcgr paired with the type stored in its node*)
fun get_func_typs funcgr =
  map (fn c => (c, fst (Constgraph.get_node funcgr c))) (Constgraph.keys funcgr);
|
164 |
|
165 local |
|
166 |
|
(*Fold f over every constant occurring in the right hand sides or in the
  arguments of the left hand sides of the equations thms.  Constants are
  normalised by CodegenConsts.norm_of_typ; occurrences equal to c itself are
  skipped.  f receives the normalised constant and the original
  (name, type) pair.*)
fun add_things_of thy f (c, thms) =
  let
    fun terms_of thm =
      let
        val (lhs, rhs) = Logic.dest_equals (Drule.plain_prop_of thm);
      in rhs :: snd (strip_comb lhs) end;
    fun visit (Const c_ty) =
          let
            val c' = CodegenConsts.norm_of_typ thy c_ty;
          in if CodegenConsts.eq_const (c, c') then I else f (c', c_ty) end
      | visit _ = I;
  in fold (fold_aterms visit) (maps terms_of thms) end;
|
175 |
|
(*normalised constants used by the equations of c, without duplicates,
  in Consttab key order*)
fun rhs_of thy (c, thms) =
  Consttab.keys
    (add_things_of thy (fn (c', _) => Consttab.update (c', ())) (c, thms) Consttab.empty);
|
180 |
|
(*raw (name, type) pairs of the constants used by the equations of c,
  accumulated with cons (reverse order of visiting), duplicates retained*)
fun rhs_of' thy c_thms =
  add_things_of thy (fn (_, c_ty) => cons c_ty) c_thms [];
|
183 |
|
(*Instances needed when constant c is used at type ty.
  - The type arguments of (c, ty) are matched against those of the declared
    type.  That is the type stored in funcgr for ordinary constants, and
    CodegenConsts.typ_of_classop for class parameters.
  - Sorts.of_sort_derivation then collects, for each resulting
    (type, sort) constraint, the (type constructor, class) pairs of its
    derivation.
  The result is the flattened list of these pairs.*)
fun insts_of thy funcgr (c, ty) =
  let
    val tys = Sign.const_typargs thy (c, ty);
    val c' = CodegenConsts.norm thy (c, tys);
    val ty_decl = if (is_none o AxClass.class_of_param thy) c
      then (fst o Constgraph.get_node funcgr) c'
      else CodegenConsts.typ_of_classop thy (c, tys);
    val tys_decl = Sign.const_typargs thy (c, ty_decl);
    val pp = Sign.pp thy;
    val algebra = Sign.classes_of thy;
    fun classrel (x, _) _ = x;
    fun constructor tyco xs class =
      (tyco, class) :: maps (maps fst) xs;
    fun variable (TVar (_, sort)) = map (pair []) sort
      | variable (TFree (_, sort)) = map (pair []) sort;
    (*pair each part of the instance type with the sort of the corresponding
      type variable of the declared type; any structural mismatch is
      reported as "bad instance" rather than escaping as Match*)
    fun mk_inst ty (TVar (_, sort)) = cons (ty, sort)
      | mk_inst ty (TFree (_, sort)) = cons (ty, sort)
      | mk_inst (Type (tyco1, tys1)) (Type (tyco2, tys2)) =
          if tyco1 <> tyco2 then error "bad instance"
          else fold2 mk_inst tys1 tys2
      | mk_inst _ _ = error "bad instance";
  in
    flat (maps (Sorts.of_sort_derivation pp algebra
      { classrel = classrel, constructor = constructor, variable = variable })
      (fold2 mk_inst tys tys_decl []))
  end;
|
209 |
|
(*normalised parameters of class and of every class reachable from it in
  the class graph (Graph.all_succs), each taken at the type constructor tyco*)
fun all_classops thy tyco class =
  let
    val classes = (#classes o Sorts.rep_algebra o Sign.classes_of) thy;
    val params = maps (AxClass.params_of thy) (Graph.all_succs classes [class]);
    (*typ_of_classop is very liberal in its type arguments*)
    fun with_typ c = (c, CodegenConsts.typ_of_classop thy (c, [Type (tyco, [])]));
  in map (CodegenConsts.norm_of_typ thy o with_typ) params end;
|
216 |
|
(*Group the (tyco, class) pairs insts by type constructor.  For every group,
  collect the class operations (all_classops) of all classes reachable from
  the group's classes in the class graph.*)
fun instdefs_of thy insts =
  let
    val thy_classes = (#classes o Sorts.rep_algebra o Sign.classes_of) thy;
    fun add_inst (tyco, class) =
      Symtab.map_default (tyco, []) (insert (op =) class);
    val by_tyco = fold add_inst insts Symtab.empty;
    fun add_classops (tyco, classes) =
      append (maps (all_classops thy tyco) (Graph.all_succs thy_classes classes));
  in Symtab.fold add_classops by_tyco [] end;
|
227 |
|
(*class operations required by the instances that arise from the constants
  used in the equations c_thms*)
fun insts_of_thms thy funcgr c_thms =
  let
    fun add_insts (_, c_ty) = fold (insert (op =)) (insts_of thy funcgr c_ty);
  in instdefs_of thy (add_things_of thy add_insts c_thms []) end;
|
233 |
|
(*Make sure constant c is available.
  - If c is already in funcgr, the result is NONE.
  - Otherwise the result is SOME c, paired with auxgr extended as follows.
    A datatype constructor gets an empty node.  Any other constant gets a
    node with its preprocessed code theorems, plus, recursively, nodes for
    all constants these theorems use, with edges from c to those not
    already in funcgr.*)
fun ensure_const thy funcgr c auxgr =
  if can (Constgraph.get_node funcgr) c then (NONE, auxgr)
  else if can (Constgraph.get_node auxgr) c then (SOME c, auxgr)
  else if is_some (CodegenData.get_datatype_of_constr thy c)
  then (SOME c, Constgraph.new_node (c, []) auxgr)
  else
    let
      val thms = preprocess thy (CodegenData.these_funcs thy c);
      val rhs = rhs_of thy (c, thms);
      val (deps, auxgr') =
        auxgr
        |> Constgraph.new_node (c, thms)
        |> fold_map (ensure_const thy funcgr) rhs;
      fun add_dep (SOME c') = Constgraph.add_edge (c, c')
        | add_dep NONE = I;
    in (SOME c, fold add_dep deps auxgr') end;
|
254 |
|
(*Specialise the types of a group of code equations eqss.
  - The equations are first given disjoint schematic indices.
  - The types of the constants used in them are unified with their type
    schemes from funcgr.
  - The resulting unifier is applied to each constant's equations, followed
    by Drule.zero_var_indexes.
  Raises ERROR if a unification fails.*)
fun specialize_typs thy funcgr eqss =
  let
    val tsig = Sign.tsig_of thy;
    fun typscheme_of (c, ty) =
      try (Constgraph.get_node funcgr) (CodegenConsts.norm_of_typ thy (c, ty))
      |> Option.map fst;
    (*shift the indices of thms by maxidx and return the new upper bound*)
    fun incr_indices (c, thms) maxidx =
      let
        val thms' = map (Thm.incr_indexes maxidx) thms;
        val maxidx' = Int.max
          (maxidx, fold (curry Int.max o Thm.maxidx_of) thms' ~1 + 1);
      in ((c, thms'), maxidx') end;
    fun unify_const thms (c, ty) (env, maxidx) =
      case typscheme_of (c, ty)
       of SOME ty_decl => let
            val ty_decl' = Logic.incr_tvar maxidx ty_decl;
            val maxidx' = Int.max (Term.maxidx_of_typ ty_decl' + 1, maxidx);
          in Type.unify tsig (ty_decl', ty) (env, maxidx')
            handle TUNIFY => error ("Failed to instantiate\n"
              ^ (Sign.string_of_typ thy o Envir.norm_type env) ty_decl' ^ "\nto\n"
              ^ (Sign.string_of_typ thy o Envir.norm_type env) ty ^ ",\n"
              ^ "in function theorems\n"
              ^ cat_lines (map string_of_thm thms))
          end
        | NONE => (env, maxidx);
    fun apply_unifier _ (c, []) = (c, [])
      | apply_unifier unif (c, thms as thm :: _) =
          let
            val ty = CodegenData.typ_func thy thm;
            val ty' = Envir.norm_type unif ty;
            val env = Type.typ_match tsig (ty, ty') Vartab.empty;
            val inst = Thm.instantiate (Vartab.fold (fn (x_i, (sort, ty)) =>
              cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [], []);
          in (c, map (Drule.zero_var_indexes o inst) thms) end;
    val (eqss', maxidx) = fold_map incr_indices eqss 0;
    val (unif, _) =
      fold (fn (c, thms) => fold (unify_const thms) (rhs_of' thy (c, thms)))
        eqss' (Vartab.empty, maxidx);
  in map (apply_unifier unif) eqss' end;
|
299 |
|
(*Add the group raw_eqss of (constant, code equations) pairs to funcgr.
  - Types are specialised via specialize_typs.
  - Every constant gets a node labelled with its type and theorems.
  - Constants needed by instances arising from the equations are ensured
    (ensure_consts) and connected by edges.
  - Finally, edges to all constants used in the equations are added.*)
fun merge_eqsyss thy raw_eqss funcgr =
  let
    val eqss = specialize_typs thy funcgr raw_eqss;
    (*Type of a constant without equations.  For a class parameter, this is
      its type from the class signature, with the class type variable
      constrained to that class.  Otherwise it is the declared constant
      type.  With equations, the type is taken from the first theorem.*)
    fun typ_of_eqs ((name, _), []) =
          (case AxClass.class_of_param thy name
           of SOME class =>
                let
                  val (v, cs) = ClassPackage.the_consts_sign thy class;
                in
                  (Logic.varifyT o map_type_tfree (fn u as (w, _) =>
                    if w = v then TFree (v, [class]) else TFree u))
                    ((the o AList.lookup (op =) cs) name)
                end
            | NONE => Sign.the_const_type thy name)
      | typ_of_eqs (_, eq :: _) = CodegenData.typ_func thy eq;
    val tys = map typ_of_eqs eqss;
    val rhss = map (rhs_of thy) eqss;
  in
    funcgr
    |> fold2 (fn (c, thms) => fn ty => Constgraph.new_node (c, (ty, thms))) eqss tys
    |> `(fn funcgr => map (insts_of_thms thy funcgr) eqss)
    |-> (fn rhs_insts => fold2 (fn (c, _) => fn rhs_inst =>
          ensure_consts thy rhs_inst #> fold (curry Constgraph.add_edge c) rhs_inst)
          eqss rhs_insts)
    |> fold2 (fn (c, _) => fn rhs => fold (curry Constgraph.add_edge c) rhs) eqss rhss
  end
(*Ensure all constants cs.  Unknown ones are gathered in an auxiliary graph
  (ensure_const), whose strongly connected components are then merged into
  funcgr in reverse strong_conn order.*)
and ensure_consts thy cs funcgr =
  fold (snd oo ensure_const thy funcgr) cs Constgraph.empty
  |> (fn auxgr => fold (merge_eqsyss thy)
      (map (AList.make (Constgraph.get_node auxgr))
        (rev (Constgraph.strong_conn auxgr))) funcgr);
|
324 |
|
325 in |
|
326 |
|
(*re-bind ensure_consts from the local part; it is not listed in the
  CODEGEN_FUNCGR signature*)
val ensure_consts = ensure_consts;

(*Function graph extended by the constants consts, then by the class
  operations needed for the instances arising from the (name, type)
  pairs cs.  The update goes through Funcgr.change, which presumably stores
  the new graph as code data and returns it.*)
fun mk_funcgr thy consts cs =
  let
    fun add_instdefs funcgr =
      let
        val insts = fold (fold (insert (op =)) o insts_of thy funcgr) cs [];
      in ensure_consts thy (instdefs_of thy insts) funcgr end;
  in Funcgr.change thy (ensure_consts thy consts #> add_instdefs) end;
|
335 |
|
336 end; (*local*) |
|
337 |
|
(*write every constant of funcgr, sorted by its printed name, followed by
  its code theorems*)
fun print_funcgr thy funcgr =
  let
    fun entry c =
      (CodegenConsts.string_of_const thy c, snd (Constgraph.get_node funcgr c));
    fun pretty_entry (s, thms) =
      (Pretty.block o Pretty.fbreaks) (Pretty.str s :: map Display.pretty_thm thms);
  in
    Constgraph.keys funcgr
    |> map entry
    |> sort (string_ord o pairself fst)
    |> map pretty_entry
    |> Pretty.chunks
    |> Pretty.writeln
  end;
|
349 |
|
(*print the function graph built for consts (without extra instances)*)
fun print_codethms thy consts =
  print_funcgr thy (mk_funcgr thy consts []);
|
352 |
|
(*print_codethms with constants given as strings (CodegenConsts.read_const)*)
fun print_codethms_e thy cs =
  cs |> map (CodegenConsts.read_const thy) |> print_codethms thy;
|
355 |
|
356 |
|
357 (** Isar **) |
|
358 |
|
structure P = OuterParse;

val print_codethmsK = "print_codethms";

(*Diagnostic command print_codethms, with an optional parenthesised list of
  terms.  Without arguments it runs CodegenData.print_thms; with arguments
  it runs print_codethms_e on the given constants.*)
val print_codethmsP =
  let
    val args = Scan.option (P.$$$ "(" |-- Scan.repeat P.term --| P.$$$ ")");
    fun command NONE = CodegenData.print_thms
      | command (SOME cs) = (fn thy => print_codethms_e thy cs);
    fun transition f =
      Toplevel.no_timing o Toplevel.unknown_theory o Toplevel.keep (f o Toplevel.theory_of);
  in
    OuterSyntax.improper_command print_codethmsK "print code theorems of this theory"
      OuterKeyword.diag (args >> command >> transition)
  end;

val _ = OuterSyntax.add_parsers [print_codethmsP];
|
372 |
|
373 end; (*struct*) |