8 signature CODE_PREPROC = |
8 signature CODE_PREPROC = |
9 sig |
9 sig |
10 val map_pre: (Proof.context -> Proof.context) -> theory -> theory |
10 val map_pre: (Proof.context -> Proof.context) -> theory -> theory |
11 val map_post: (Proof.context -> Proof.context) -> theory -> theory |
11 val map_post: (Proof.context -> Proof.context) -> theory -> theory |
12 val add_unfold: thm -> theory -> theory |
12 val add_unfold: thm -> theory -> theory |
13 val add_functrans: string * (theory -> (thm * bool) list -> (thm * bool) list option) -> theory -> theory |
13 val add_functrans: string * (Proof.context -> (thm * bool) list -> (thm * bool) list option) -> theory -> theory |
14 val del_functrans: string -> theory -> theory |
14 val del_functrans: string -> theory -> theory |
15 val simple_functrans: (theory -> thm list -> thm list option) |
15 val simple_functrans: (Proof.context -> thm list -> thm list option) |
16 -> theory -> (thm * bool) list -> (thm * bool) list option |
16 -> Proof.context -> (thm * bool) list -> (thm * bool) list option |
17 val print_codeproc: theory -> unit |
17 val print_codeproc: Proof.context -> unit |
18 |
18 |
19 type code_algebra |
19 type code_algebra |
20 type code_graph |
20 type code_graph |
21 val cert: code_graph -> string -> Code.cert |
21 val cert: code_graph -> string -> Code.cert |
22 val sortargs: code_graph -> string -> sort list |
22 val sortargs: code_graph -> string -> sort list |
23 val all: code_graph -> string list |
23 val all: code_graph -> string list |
24 val pretty: theory -> code_graph -> Pretty.T |
24 val pretty: Proof.context -> code_graph -> Pretty.T |
25 val obtain: bool -> theory -> string list -> term list -> code_algebra * code_graph |
25 val obtain: bool -> theory -> string list -> term list -> code_algebra * code_graph |
26 val dynamic_conv: theory |
26 val dynamic_conv: Proof.context |
27 -> (code_algebra -> code_graph -> (string * sort) list -> term -> conv) -> conv |
27 -> (code_algebra -> code_graph -> (string * sort) list -> term -> conv) -> conv |
28 val dynamic_value: theory -> ((term -> term) -> 'a -> 'a) |
28 val dynamic_value: Proof.context -> ((term -> term) -> 'a -> 'a) |
29 -> (code_algebra -> code_graph -> (string * sort) list -> term -> 'a) -> term -> 'a |
29 -> (code_algebra -> code_graph -> (string * sort) list -> term -> 'a) -> term -> 'a |
30 val static_conv: theory -> string list |
30 val static_conv: Proof.context -> string list |
31 -> (code_algebra -> code_graph -> (string * sort) list -> term -> conv) -> conv |
31 -> (code_algebra -> code_graph -> Proof.context -> (string * sort) list -> term -> conv) |
32 val static_value: theory -> ((term -> term) -> 'a -> 'a) -> string list |
32 -> Proof.context -> conv |
33 -> (code_algebra -> code_graph -> (string * sort) list -> term -> 'a) -> term -> 'a |
33 val static_value: Proof.context -> ((term -> term) -> 'a -> 'a) -> string list |
|
34 -> (code_algebra -> code_graph -> Proof.context -> (string * sort) list -> term -> 'a) |
|
35 -> Proof.context -> term -> 'a |
34 |
36 |
35 val setup: theory -> theory |
37 val setup: theory -> theory |
36 end |
38 end |
37 |
39 |
38 structure Code_Preproc : CODE_PREPROC = |
40 structure Code_Preproc : CODE_PREPROC = |
126 |> fold apply_beta all_vars |
129 |> fold apply_beta all_vars |
127 end; |
130 end; |
128 |
131 |
129 fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm); |
132 fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm); |
130 |
133 |
131 fun term_of_conv thy conv = |
134 fun term_of_conv ctxt conv = |
132 Thm.cterm_of thy |
135 Thm.cterm_of (Proof_Context.theory_of ctxt) |
133 #> conv |
136 #> conv ctxt |
134 #> Thm.prop_of |
137 #> Thm.prop_of |
135 #> Logic.dest_equals |
138 #> Logic.dest_equals |
136 #> snd; |
139 #> snd; |
137 |
140 |
138 fun term_of_conv_resubst thy conv t = |
141 fun term_of_conv_resubst ctxt conv t = |
139 let |
142 let |
140 val all_vars = fold_aterms (fn t as Free _ => insert (op aconv) t |
143 val all_vars = fold_aterms (fn t as Free _ => insert (op aconv) t |
141 | t as Var _ => insert (op aconv) t |
144 | t as Var _ => insert (op aconv) t |
142 | _ => I) t []; |
145 | _ => I) t []; |
143 val resubst = curry (Term.betapplys o swap) all_vars; |
146 val resubst = curry (Term.betapplys o swap) all_vars; |
144 in (resubst, term_of_conv thy conv (fold_rev lambda all_vars t)) end; |
147 in (resubst, term_of_conv ctxt conv (fold_rev lambda all_vars t)) end; |
145 |
148 |
146 fun global_simpset_context thy ss = |
149 fun preprocess_conv ctxt = |
147 Proof_Context.init_global thy |
150 let |
148 |> put_simpset ss; |
151 val thy = Proof_Context.theory_of ctxt; |
149 |
152 val ss = (#pre o the_thmproc) thy; |
150 fun preprocess_conv thy = |
153 in fn ctxt' => |
151 let |
154 Simplifier.rewrite (put_simpset ss ctxt') |
152 val pre = global_simpset_context thy ((#pre o the_thmproc) thy); |
155 #> trans_conv_rule (Axclass.unoverload_conv (Proof_Context.theory_of ctxt')) |
153 in |
156 end; |
154 Simplifier.rewrite pre |
157 |
155 #> trans_conv_rule (Axclass.unoverload_conv thy) |
158 fun preprocess_term ctxt = |
156 end; |
159 let |
157 |
160 val conv = preprocess_conv ctxt; |
158 fun preprocess_term thy = term_of_conv_resubst thy (preprocess_conv thy); |
161 in fn ctxt' => term_of_conv_resubst ctxt' conv end; |
159 |
162 |
160 fun postprocess_conv thy = |
163 fun postprocess_conv ctxt = |
161 let |
164 let |
162 val post = global_simpset_context thy ((#post o the_thmproc) thy); |
165 val thy = Proof_Context.theory_of ctxt; |
163 in |
166 val ss = (#post o the_thmproc) thy; |
164 Axclass.overload_conv thy |
167 in fn ctxt' => |
165 #> trans_conv_rule (Simplifier.rewrite post) |
168 Axclass.overload_conv (Proof_Context.theory_of ctxt') |
166 end; |
169 #> trans_conv_rule (Simplifier.rewrite (put_simpset ss ctxt')) |
167 |
170 end; |
168 fun postprocess_term thy = term_of_conv thy (postprocess_conv thy); |
171 |
169 |
172 fun postprocess_term ctxt = |
170 fun print_codeproc thy = |
173 let |
171 let |
174 val conv = postprocess_conv ctxt; |
172 val ctxt = Proof_Context.init_global thy; |
175 in fn ctxt' => term_of_conv ctxt' conv end; |
|
176 |
|
177 fun print_codeproc ctxt = |
|
178 let |
|
179 val thy = Proof_Context.theory_of ctxt; |
173 val pre = (#pre o the_thmproc) thy; |
180 val pre = (#pre o the_thmproc) thy; |
174 val post = (#post o the_thmproc) thy; |
181 val post = (#post o the_thmproc) thy; |
175 val functrans = (map fst o #functrans o the_thmproc) thy; |
182 val functrans = (map fst o #functrans o the_thmproc) thy; |
176 in |
183 in |
177 (Pretty.writeln o Pretty.chunks) [ |
184 (Pretty.writeln o Pretty.chunks) [ |
267 fun obtain_eqns ctxt eqngr c = |
278 fun obtain_eqns ctxt eqngr c = |
268 case try (Graph.get_node eqngr) c |
279 case try (Graph.get_node eqngr) c |
269 of SOME (lhs, cert) => ((lhs, []), cert) |
280 of SOME (lhs, cert) => ((lhs, []), cert) |
270 | NONE => let |
281 | NONE => let |
271 val thy = Proof_Context.theory_of ctxt; |
282 val thy = Proof_Context.theory_of ctxt; |
272 val functrans = (map (fn (_, (_, f)) => f thy) |
283 val functrans = (map (fn (_, (_, f)) => f ctxt) |
273 o #functrans o the_thmproc) thy; |
284 o #functrans o the_thmproc) thy; |
274 val cert = Code.get_cert thy { functrans = functrans, ss = simpset_of ctxt } c; (*FIXME*) |
285 val cert = Code.get_cert thy { functrans = functrans, ss = simpset_of ctxt } c; (*FIXME*) |
275 val (lhs, rhss) = |
286 val (lhs, rhss) = |
276 Code.typargs_deps_of_cert thy cert; |
287 Code.typargs_deps_of_cert thy cert; |
277 in ((lhs, rhss), cert) end; |
288 in ((lhs, rhss), cert) end; |
278 |
289 |
279 fun obtain_instance thy arities (inst as (class, tyco)) = |
290 fun obtain_instance ctxt arities (inst as (class, tyco)) = |
280 case AList.lookup (op =) arities inst |
291 case AList.lookup (op =) arities inst |
281 of SOME classess => (classess, ([], [])) |
292 of SOME classess => (classess, ([], [])) |
282 | NONE => let |
293 | NONE => let |
|
294 val thy = Proof_Context.theory_of ctxt; |
283 val all_classes = complete_proper_sort thy [class]; |
295 val all_classes = complete_proper_sort thy [class]; |
284 val super_classes = remove (op =) class all_classes; |
296 val super_classes = remove (op =) class all_classes; |
285 val classess = map (complete_proper_sort thy) |
297 val classess = map (complete_proper_sort thy) |
286 (Sign.arity_sorts thy tyco [class]); |
298 (Sign.arity_sorts thy tyco [class]); |
287 val inst_params = inst_params thy tyco all_classes; |
299 val inst_params = inst_params thy tyco all_classes; |
329 else vardeps_data (*permissive!*) |
341 else vardeps_data (*permissive!*) |
330 and ensure_inst ctxt arities eqngr (inst as (class, tyco)) (vardeps_data as (_, (_, insts))) = |
342 and ensure_inst ctxt arities eqngr (inst as (class, tyco)) (vardeps_data as (_, (_, insts))) = |
331 if member (op =) insts inst then vardeps_data |
343 if member (op =) insts inst then vardeps_data |
332 else let |
344 else let |
333 val (classess, (super_classes, inst_params)) = |
345 val (classess, (super_classes, inst_params)) = |
334 obtain_instance (Proof_Context.theory_of ctxt) arities inst; |
346 obtain_instance ctxt arities inst; |
335 in |
347 in |
336 vardeps_data |
348 vardeps_data |
337 |> (apsnd o apsnd) (insert (op =) inst) |
349 |> (apsnd o apsnd) (insert (op =) inst) |
338 |> fold_index (fn (k, _) => |
350 |> fold_index (fn (k, _) => |
339 apfst (Vargraph.new_node ((Inst (class, tyco), k), ([] ,[])))) classess |
351 apfst (Vargraph.new_node ((Inst (class, tyco), k), ([] ,[])))) classess |
393 { class_relation = K class_relation, type_constructor = type_constructor, |
406 { class_relation = K class_relation, type_constructor = type_constructor, |
394 type_variable = type_variable } (T, proj_sort sort) |
407 type_variable = type_variable } (T, proj_sort sort) |
395 handle Sorts.CLASS_ERROR _ => [] (*permissive!*)) |
408 handle Sorts.CLASS_ERROR _ => [] (*permissive!*)) |
396 end; |
409 end; |
397 |
410 |
398 fun add_arity thy vardeps (class, tyco) = |
411 fun add_arity ctxt vardeps (class, tyco) = |
399 AList.default (op =) ((class, tyco), |
412 AList.default (op =) ((class, tyco), |
400 map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) |
413 map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) |
401 (Sign.arity_number thy tyco)); |
414 (Sign.arity_number (Proof_Context.theory_of ctxt) tyco)); |
402 |
415 |
403 fun add_cert thy vardeps (c, (proto_lhs, proto_cert)) (rhss, eqngr) = |
416 fun add_cert ctxt vardeps (c, (proto_lhs, proto_cert)) (rhss, eqngr) = |
404 if can (Graph.get_node eqngr) c then (rhss, eqngr) |
417 if can (Graph.get_node eqngr) c then (rhss, eqngr) |
405 else let |
418 else let |
|
419 val thy = Proof_Context.theory_of ctxt; |
406 val lhs = map_index (fn (k, (v, _)) => |
420 val lhs = map_index (fn (k, (v, _)) => |
407 (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs; |
421 (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs; |
408 val cert = proto_cert |
422 val cert = proto_cert |
409 |> Code.constrain_cert thy (map (Sign.minimize_sort thy o snd) lhs) |
423 |> Code.constrain_cert thy (map (Sign.minimize_sort thy o snd) lhs) |
410 |> Code.conclude_cert; |
424 |> Code.conclude_cert; |
411 val (vs, rhss') = Code.typargs_deps_of_cert thy cert; |
425 val (vs, rhss') = Code.typargs_deps_of_cert thy cert; |
412 val eqngr' = Graph.new_node (c, (vs, cert)) eqngr; |
426 val eqngr' = Graph.new_node (c, (vs, cert)) eqngr; |
413 in (map (pair c) rhss' @ rhss, eqngr') end; |
427 in (map (pair c) rhss' @ rhss, eqngr') end; |
414 |
428 |
415 fun extend_arities_eqngr thy cs ts (arities, (eqngr : code_graph)) = |
429 fun extend_arities_eqngr raw_ctxt cs ts (arities, (eqngr : code_graph)) = |
416 let |
430 let |
|
431 val thy = Proof_Context.theory_of raw_ctxt; |
417 val {pre, ...} = the_thmproc thy; |
432 val {pre, ...} = the_thmproc thy; |
418 val ctxt = thy |> Proof_Context.init_global |> put_simpset pre; |
433 val ctxt = put_simpset pre raw_ctxt; |
419 val cs_rhss = (fold o fold_aterms) (fn Const (c_ty as (c, _)) => |
434 val cs_rhss = (fold o fold_aterms) (fn Const (c_ty as (c, _)) => |
420 insert (op =) (c, (map (styp_of NONE) o Sign.const_typargs thy) c_ty) | _ => I) ts []; |
435 insert (op =) (c, (map (styp_of NONE) o Sign.const_typargs thy) c_ty) | _ => I) ts []; |
421 val (vardeps, (eqntab, insts)) = empty_vardeps_data |
436 val (vardeps, (eqntab, insts)) = empty_vardeps_data |
422 |> fold (ensure_fun ctxt arities eqngr) cs |
437 |> fold (ensure_fun ctxt arities eqngr) cs |
423 |> fold (ensure_rhs ctxt arities eqngr) cs_rhss; |
438 |> fold (ensure_rhs ctxt arities eqngr) cs_rhss; |
424 val arities' = fold (add_arity thy vardeps) insts arities; |
439 val arities' = fold (add_arity ctxt vardeps) insts arities; |
425 val algebra = Sorts.subalgebra (Context.pretty_global thy) (is_proper_class thy) |
440 val algebra = Sorts.subalgebra (Context.pretty_global thy) (is_proper_class thy) |
426 (AList.lookup (op =) arities') (Sign.classes_of thy); |
441 (AList.lookup (op =) arities') (Sign.classes_of thy); |
427 val (rhss, eqngr') = Symtab.fold (add_cert thy vardeps) eqntab ([], eqngr); |
442 val (rhss, eqngr') = Symtab.fold (add_cert ctxt vardeps) eqntab ([], eqngr); |
428 fun deps_of (c, rhs) = c :: maps (dicts_of thy algebra) |
443 fun deps_of (c, rhs) = c :: maps (dicts_of ctxt algebra) |
429 (rhs ~~ sortargs eqngr' c); |
444 (rhs ~~ sortargs eqngr' c); |
430 val eqngr'' = fold (fn (c, rhs) => fold |
445 val eqngr'' = fold (fn (c, rhs) => fold |
431 (curry Graph.add_edge c) (deps_of rhs)) rhss eqngr'; |
446 (curry Graph.add_edge c) (deps_of rhs)) rhss eqngr'; |
432 in (algebra, (arities', eqngr'')) end; |
447 in (algebra, (arities', eqngr'')) end; |
433 |
448 |
442 |
457 |
443 |
458 |
444 (** retrieval and evaluation interfaces **) |
459 (** retrieval and evaluation interfaces **) |
445 |
460 |
446 fun obtain ignore_cache thy consts ts = apsnd snd |
461 fun obtain ignore_cache thy consts ts = apsnd snd |
447 (Wellsorted.change_yield (if ignore_cache then NONE else SOME thy) (extend_arities_eqngr thy consts ts)); |
462 (Wellsorted.change_yield (if ignore_cache then NONE else SOME thy) |
|
463 (extend_arities_eqngr (Proof_Context.init_global thy) consts ts)); |
448 |
464 |
449 fun dest_cterm ct = let val t = Thm.term_of ct in (Term.add_tfrees t [], t) end; |
465 fun dest_cterm ct = let val t = Thm.term_of ct in (Term.add_tfrees t [], t) end; |
450 |
466 |
451 fun dynamic_conv thy conv = no_variables_conv (fn ct => |
467 fun dynamic_conv ctxt conv = no_variables_conv ctxt (fn ct => |
452 let |
468 let |
453 val thm1 = preprocess_conv thy ct; |
469 val thm1 = preprocess_conv ctxt ctxt ct; |
454 val ct' = Thm.rhs_of thm1; |
470 val ct' = Thm.rhs_of thm1; |
455 val (vs', t') = dest_cterm ct'; |
471 val (vs', t') = dest_cterm ct'; |
456 val consts = fold_aterms |
472 val consts = fold_aterms |
457 (fn Const (c, _) => insert (op =) c | _ => I) t' []; |
473 (fn Const (c, _) => insert (op =) c | _ => I) t' []; |
458 val (algebra', eqngr') = obtain false thy consts [t']; |
474 val (algebra', eqngr') = obtain false (Proof_Context.theory_of ctxt) consts [t']; |
459 val thm2 = conv algebra' eqngr' vs' t' ct'; |
475 val thm2 = conv algebra' eqngr' vs' t' ct'; |
460 val thm3 = postprocess_conv thy (Thm.rhs_of thm2); |
476 val thm3 = postprocess_conv ctxt ctxt (Thm.rhs_of thm2); |
461 in |
477 in |
462 Thm.transitive thm1 (Thm.transitive thm2 thm3) handle THM _ => |
478 Thm.transitive thm1 (Thm.transitive thm2 thm3) handle THM _ => |
463 error ("could not construct evaluation proof:\n" |
479 error ("could not construct evaluation proof:\n" |
464 ^ (cat_lines o map (Display.string_of_thm_global thy)) [thm1, thm2, thm3]) |
480 ^ (cat_lines o map (Display.string_of_thm ctxt)) [thm1, thm2, thm3]) |
465 end); |
481 end); |
466 |
482 |
467 fun dynamic_value thy postproc evaluator t = |
483 fun dynamic_value ctxt postproc evaluator t = |
468 let |
484 let |
469 val (resubst, t') = preprocess_term thy t; |
485 val (resubst, t') = preprocess_term ctxt ctxt t; |
470 val vs' = Term.add_tfrees t' []; |
486 val vs' = Term.add_tfrees t' []; |
471 val consts = fold_aterms |
487 val consts = fold_aterms |
472 (fn Const (c, _) => insert (op =) c | _ => I) t' []; |
488 (fn Const (c, _) => insert (op =) c | _ => I) t' []; |
473 val (algebra', eqngr') = obtain false thy consts [t']; |
489 val (algebra', eqngr') = obtain false (Proof_Context.theory_of ctxt) consts [t']; |
474 in |
490 in |
475 t' |
491 t' |
476 |> evaluator algebra' eqngr' vs' |
492 |> evaluator algebra' eqngr' vs' |
477 |> postproc (postprocess_term thy o resubst) |
493 |> postproc (postprocess_term ctxt ctxt o resubst) |
478 end; |
494 end; |
479 |
495 |
480 fun static_conv thy consts conv = |
496 fun static_conv ctxt consts conv = |
481 let |
497 let |
482 val (algebra, eqngr) = obtain true thy consts []; |
498 val (algebra, eqngr) = obtain true (Proof_Context.theory_of ctxt) consts []; |
|
499 val pre_conv = preprocess_conv ctxt; |
483 val conv' = conv algebra eqngr; |
500 val conv' = conv algebra eqngr; |
484 in |
501 val post_conv = postprocess_conv ctxt; |
485 no_variables_conv ((preprocess_conv thy) |
502 in fn ctxt' => no_variables_conv ctxt' ((pre_conv ctxt') |
486 then_conv (fn ct => uncurry conv' (dest_cterm ct) ct) |
503 then_conv (fn ct => uncurry (conv' ctxt') (dest_cterm ct) ct) |
487 then_conv (postprocess_conv thy)) |
504 then_conv (post_conv ctxt')) |
488 end; |
505 end; |
489 |
506 |
490 fun static_value thy postproc consts evaluator = |
507 fun static_value ctxt postproc consts evaluator = |
491 let |
508 let |
492 val (algebra, eqngr) = obtain true thy consts []; |
509 val (algebra, eqngr) = obtain true (Proof_Context.theory_of ctxt) consts []; |
|
510 val preproc = preprocess_term ctxt; |
493 val evaluator' = evaluator algebra eqngr; |
511 val evaluator' = evaluator algebra eqngr; |
494 val postproc' = postprocess_term thy; |
512 val postproc' = postprocess_term ctxt; |
495 in |
513 in fn ctxt' => |
496 preprocess_term thy |
514 preproc ctxt' |
497 #-> (fn resubst => fn t => t |
515 #-> (fn resubst => fn t => t |
498 |> evaluator' (Term.add_tfrees t []) |
516 |> evaluator' ctxt' (Term.add_tfrees t []) |
499 |> postproc (postproc' o resubst)) |
517 |> postproc (postproc' ctxt' o resubst)) |
500 end; |
518 end; |
501 |
519 |
502 |
520 |
503 (** setup **) |
521 (** setup **) |
504 |
522 |