16 end |
16 end |
17 |
17 |
18 structure SMT_Normalize: SMT_NORMALIZE = |
18 structure SMT_Normalize: SMT_NORMALIZE = |
19 struct |
19 struct |
20 |
20 |
21 structure U = SMT_Utils |
|
22 structure B = SMT_Builtin |
|
23 |
|
24 |
21 |
25 (* general theorem normalizations *) |
22 (* general theorem normalizations *) |
26 |
23 |
27 (** instantiate elimination rules **) |
24 (** instantiate elimination rules **) |
28 |
25 |
29 local |
26 local |
30 val (cpfalse, cfalse) = `U.mk_cprop (Thm.cterm_of @{theory} @{const False}) |
27 val (cpfalse, cfalse) = |
|
28 `SMT_Utils.mk_cprop (Thm.cterm_of @{theory} @{const False}) |
31 |
29 |
32 fun inst f ct thm = |
30 fun inst f ct thm = |
33 let val cv = f (Drule.strip_imp_concl (Thm.cprop_of thm)) |
31 let val cv = f (Drule.strip_imp_concl (Thm.cprop_of thm)) |
34 in Thm.instantiate ([], [(cv, ct)]) thm end |
32 in Thm.instantiate ([], [(cv, ct)]) thm end |
35 in |
33 in |
68 Conv.binder_conv (atomize_conv o snd) ctxt then_conv |
66 Conv.binder_conv (atomize_conv o snd) ctxt then_conv |
69 Conv.rewr_conv @{thm atomize_all} |
67 Conv.rewr_conv @{thm atomize_all} |
70 | _ => Conv.all_conv) ct |
68 | _ => Conv.all_conv) ct |
71 |
69 |
72 val setup_atomize = |
70 val setup_atomize = |
73 fold B.add_builtin_fun_ext'' [@{const_name "==>"}, @{const_name "=="}, |
71 fold SMT_Builtin.add_builtin_fun_ext'' [@{const_name "==>"}, |
74 @{const_name all}, @{const_name Trueprop}] |
72 @{const_name "=="}, @{const_name all}, @{const_name Trueprop}] |
75 |
73 |
76 |
74 |
77 (** unfold special quantifiers **) |
75 (** unfold special quantifiers **) |
78 |
76 |
79 local |
77 local |
98 SOME thm => Conv.rewr_conv thm |
96 SOME thm => Conv.rewr_conv thm |
99 | NONE => Conv.all_conv) ct |
97 | NONE => Conv.all_conv) ct |
100 in |
98 in |
101 |
99 |
102 fun unfold_special_quants_conv ctxt = |
100 fun unfold_special_quants_conv ctxt = |
103 U.if_exists_conv (is_some o special_quant) |
101 SMT_Utils.if_exists_conv (is_some o special_quant) |
104 (Conv.top_conv special_quant_conv ctxt) |
102 (Conv.top_conv special_quant_conv ctxt) |
105 |
103 |
106 val setup_unfolded_quants = fold (B.add_builtin_fun_ext'' o fst) special_quants |
104 val setup_unfolded_quants = |
|
105 fold (SMT_Builtin.add_builtin_fun_ext'' o fst) special_quants |
107 |
106 |
108 end |
107 end |
109 |
108 |
110 |
109 |
111 (** trigger inference **) |
110 (** trigger inference **) |
139 fun check_trigger_error ctxt t = |
138 fun check_trigger_error ctxt t = |
140 error ("SMT triggers must only occur under quantifier and multipatterns " ^ |
139 error ("SMT triggers must only occur under quantifier and multipatterns " ^ |
141 "must have the same kind: " ^ Syntax.string_of_term ctxt t) |
140 "must have the same kind: " ^ Syntax.string_of_term ctxt t) |
142 |
141 |
143 fun check_trigger_conv ctxt ct = |
142 fun check_trigger_conv ctxt ct = |
144 if proper_quant false proper_trigger (U.term_of ct) then Conv.all_conv ct |
143 if proper_quant false proper_trigger (SMT_Utils.term_of ct) then |
|
144 Conv.all_conv ct |
145 else check_trigger_error ctxt (Thm.term_of ct) |
145 else check_trigger_error ctxt (Thm.term_of ct) |
146 |
146 |
147 |
147 |
148 (*** infer simple triggers ***) |
148 (*** infer simple triggers ***) |
149 |
149 |
167 | _ => false) |
167 | _ => false) |
168 |
168 |
169 fun is_simp_lhs ctxt t = |
169 fun is_simp_lhs ctxt t = |
170 (case Term.strip_comb t of |
170 (case Term.strip_comb t of |
171 (Const c, ts as _ :: _) => |
171 (Const c, ts as _ :: _) => |
172 not (B.is_builtin_fun_ext ctxt c ts) andalso |
172 not (SMT_Builtin.is_builtin_fun_ext ctxt c ts) andalso |
173 forall (is_constr_pat (ProofContext.theory_of ctxt)) ts |
173 forall (is_constr_pat (ProofContext.theory_of ctxt)) ts |
174 | _ => false) |
174 | _ => false) |
175 |
175 |
176 fun has_all_vars vs t = |
176 fun has_all_vars vs t = |
177 subset (op aconv) (vs, map Free (Term.add_frees t [])) |
177 subset (op aconv) (vs, map Free (Term.add_frees t [])) |
192 val tps = (op ~~) (`gen (map Thm.term_of cts)) |
192 val tps = (op ~~) (`gen (map Thm.term_of cts)) |
193 fun some_match u = tps |> exists (fn (t', t) => |
193 fun some_match u = tps |> exists (fn (t', t) => |
194 Pattern.matches thy (t', u) andalso not (t aconv u)) |
194 Pattern.matches thy (t', u) andalso not (t aconv u)) |
195 in not (Term.exists_subterm some_match u) end |
195 in not (Term.exists_subterm some_match u) end |
196 |
196 |
197 val pat = U.mk_const_pat @{theory} @{const_name SMT.pat} U.destT1 |
197 val pat = |
198 fun mk_pat ct = Thm.capply (U.instT' ct pat) ct |
198 SMT_Utils.mk_const_pat @{theory} @{const_name SMT.pat} SMT_Utils.destT1 |
|
199 fun mk_pat ct = Thm.capply (SMT_Utils.instT' ct pat) ct |
199 |
200 |
200 fun mk_clist T = pairself (Thm.cterm_of @{theory}) |
201 fun mk_clist T = pairself (Thm.cterm_of @{theory}) |
201 (HOLogic.cons_const T, HOLogic.nil_const T) |
202 (HOLogic.cons_const T, HOLogic.nil_const T) |
202 fun mk_list (ccons, cnil) f cts = fold_rev (Thm.mk_binop ccons o f) cts cnil |
203 fun mk_list (ccons, cnil) f cts = fold_rev (Thm.mk_binop ccons o f) cts cnil |
203 val mk_pat_list = mk_list (mk_clist @{typ SMT.pattern}) |
204 val mk_pat_list = mk_list (mk_clist @{typ SMT.pattern}) |
229 |
230 |
230 fun has_trigger (@{const SMT.trigger} $ _ $ _) = true |
231 fun has_trigger (@{const SMT.trigger} $ _ $ _) = true |
231 | has_trigger _ = false |
232 | has_trigger _ = false |
232 |
233 |
233 fun try_trigger_conv cv ct = |
234 fun try_trigger_conv cv ct = |
234 if U.under_quant has_trigger (U.term_of ct) then Conv.all_conv ct |
235 if SMT_Utils.under_quant has_trigger (SMT_Utils.term_of ct) then |
|
236 Conv.all_conv ct |
235 else Conv.try_conv cv ct |
237 else Conv.try_conv cv ct |
236 |
238 |
237 fun infer_trigger_conv ctxt = |
239 fun infer_trigger_conv ctxt = |
238 if Config.get ctxt SMT_Config.infer_triggers then |
240 if Config.get ctxt SMT_Config.infer_triggers then |
239 try_trigger_conv (U.under_quant_conv (infer_trigger_eq_conv ctxt) ctxt) |
241 try_trigger_conv |
|
242 (SMT_Utils.under_quant_conv (infer_trigger_eq_conv ctxt) ctxt) |
240 else Conv.all_conv |
243 else Conv.all_conv |
241 in |
244 in |
242 |
245 |
243 fun trigger_conv ctxt = |
246 fun trigger_conv ctxt = |
244 U.prop_conv (check_trigger_conv ctxt then_conv infer_trigger_conv ctxt) |
247 SMT_Utils.prop_conv |
245 |
248 (check_trigger_conv ctxt then_conv infer_trigger_conv ctxt) |
246 val setup_trigger = fold B.add_builtin_fun_ext'' |
249 |
247 [@{const_name SMT.pat}, @{const_name SMT.nopat}, @{const_name SMT.trigger}] |
250 val setup_trigger = |
|
251 fold SMT_Builtin.add_builtin_fun_ext'' |
|
252 [@{const_name SMT.pat}, @{const_name SMT.nopat}, @{const_name SMT.trigger}] |
248 |
253 |
249 end |
254 end |
250 |
255 |
251 |
256 |
252 (** adding quantifier weights **) |
257 (** adding quantifier weights **) |
270 error ("SMT weight must be a non-negative number and must only occur " ^ |
275 error ("SMT weight must be a non-negative number and must only occur " ^ |
271 "under the top-most quantifier and an optional trigger: " ^ |
276 "under the top-most quantifier and an optional trigger: " ^ |
272 Syntax.string_of_term ctxt t) |
277 Syntax.string_of_term ctxt t) |
273 |
278 |
274 fun check_weight_conv ctxt ct = |
279 fun check_weight_conv ctxt ct = |
275 if U.under_quant proper_trigger (U.term_of ct) then Conv.all_conv ct |
280 if SMT_Utils.under_quant proper_trigger (SMT_Utils.term_of ct) then |
|
281 Conv.all_conv ct |
276 else check_weight_error ctxt (Thm.term_of ct) |
282 else check_weight_error ctxt (Thm.term_of ct) |
277 |
283 |
278 |
284 |
279 (*** insertion of weights ***) |
285 (*** insertion of weights ***) |
280 |
286 |
292 end |
298 end |
293 |
299 |
294 fun add_weight_conv NONE _ = Conv.all_conv |
300 fun add_weight_conv NONE _ = Conv.all_conv |
295 | add_weight_conv (SOME weight) ctxt = |
301 | add_weight_conv (SOME weight) ctxt = |
296 let val cv = Conv.rewr_conv (mk_weight_eq weight) |
302 let val cv = Conv.rewr_conv (mk_weight_eq weight) |
297 in U.under_quant_conv (K (under_trigger_conv cv)) ctxt end |
303 in SMT_Utils.under_quant_conv (K (under_trigger_conv cv)) ctxt end |
298 in |
304 in |
299 |
305 |
300 fun weight_conv weight ctxt = |
306 fun weight_conv weight ctxt = |
301 U.prop_conv (check_weight_conv ctxt then_conv add_weight_conv weight ctxt) |
307 SMT_Utils.prop_conv |
302 |
308 (check_weight_conv ctxt then_conv add_weight_conv weight ctxt) |
303 val setup_weight = B.add_builtin_fun_ext'' @{const_name SMT.weight} |
309 |
|
310 val setup_weight = SMT_Builtin.add_builtin_fun_ext'' @{const_name SMT.weight} |
304 |
311 |
305 end |
312 end |
306 |
313 |
307 |
314 |
308 (** combined general normalizations **) |
315 (** combined general normalizations **) |
353 "distinct [] = True" |
360 "distinct [] = True" |
354 "distinct [x] = True" |
361 "distinct [x] = True" |
355 "distinct [x, y] = (x ~= y)" |
362 "distinct [x, y] = (x ~= y)" |
356 by simp_all} |
363 by simp_all} |
357 fun distinct_conv _ = |
364 fun distinct_conv _ = |
358 U.if_true_conv is_trivial_distinct (Conv.rewrs_conv thms) |
365 SMT_Utils.if_true_conv is_trivial_distinct (Conv.rewrs_conv thms) |
359 in |
366 in |
360 |
367 |
361 fun trivial_distinct_conv ctxt = U.if_exists_conv is_trivial_distinct |
368 fun trivial_distinct_conv ctxt = |
362 (Conv.top_conv distinct_conv ctxt) |
369 SMT_Utils.if_exists_conv is_trivial_distinct |
|
370 (Conv.top_conv distinct_conv ctxt) |
363 |
371 |
364 end |
372 end |
365 |
373 |
366 |
374 |
367 (** rewrite bool case expressions as if expressions **) |
375 (** rewrite bool case expressions as if expressions **) |
371 | is_bool_case _ = false |
379 | is_bool_case _ = false |
372 |
380 |
373 val thm = mk_meta_eq @{lemma |
381 val thm = mk_meta_eq @{lemma |
374 "bool_case = (%x y P. if P then x else y)" by (rule ext)+ simp} |
382 "bool_case = (%x y P. if P then x else y)" by (rule ext)+ simp} |
375 |
383 |
376 fun unfold_conv _ = U.if_true_conv is_bool_case (Conv.rewr_conv thm) |
384 fun unfold_conv _ = SMT_Utils.if_true_conv is_bool_case (Conv.rewr_conv thm) |
377 in |
385 in |
378 |
386 |
379 fun rewrite_bool_case_conv ctxt = U.if_exists_conv is_bool_case |
387 fun rewrite_bool_case_conv ctxt = |
380 (Conv.top_conv unfold_conv ctxt) |
388 SMT_Utils.if_exists_conv is_bool_case (Conv.top_conv unfold_conv ctxt) |
381 |
389 |
382 val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"} |
390 val setup_bool_case = |
|
391 SMT_Builtin.add_builtin_fun_ext'' @{const_name "bool.bool_case"} |
383 |
392 |
384 end |
393 end |
385 |
394 |
386 |
395 |
387 (** unfold abs, min and max **) |
396 (** unfold abs, min and max **) |
398 by (rule ext)+ (rule max_def)} |
407 by (rule ext)+ (rule max_def)} |
399 |
408 |
400 val defs = [(@{const_name min}, min_def), (@{const_name max}, max_def), |
409 val defs = [(@{const_name min}, min_def), (@{const_name max}, max_def), |
401 (@{const_name abs}, abs_def)] |
410 (@{const_name abs}, abs_def)] |
402 |
411 |
403 fun is_builtinT ctxt T = B.is_builtin_typ_ext ctxt (Term.domain_type T) |
412 fun is_builtinT ctxt T = |
|
413 SMT_Builtin.is_builtin_typ_ext ctxt (Term.domain_type T) |
404 |
414 |
405 fun abs_min_max ctxt (Const (n, T)) = |
415 fun abs_min_max ctxt (Const (n, T)) = |
406 (case AList.lookup (op =) defs n of |
416 (case AList.lookup (op =) defs n of |
407 NONE => NONE |
417 NONE => NONE |
408 | SOME thm => if is_builtinT ctxt T then SOME thm else NONE) |
418 | SOME thm => if is_builtinT ctxt T then SOME thm else NONE) |
413 SOME thm => Conv.rewr_conv thm |
423 SOME thm => Conv.rewr_conv thm |
414 | NONE => Conv.all_conv) ct |
424 | NONE => Conv.all_conv) ct |
415 in |
425 in |
416 |
426 |
417 fun unfold_abs_min_max_conv ctxt = |
427 fun unfold_abs_min_max_conv ctxt = |
418 U.if_exists_conv (is_some o abs_min_max ctxt) |
428 SMT_Utils.if_exists_conv (is_some o abs_min_max ctxt) |
419 (Conv.top_conv unfold_amm_conv ctxt) |
429 (Conv.top_conv unfold_amm_conv ctxt) |
420 |
430 |
421 val setup_abs_min_max = fold (B.add_builtin_fun_ext'' o fst) defs |
431 val setup_abs_min_max = fold (SMT_Builtin.add_builtin_fun_ext'' o fst) defs |
422 |
432 |
423 end |
433 end |
424 |
434 |
425 |
435 |
426 (** embedding of standard natural number operations into integer operations **) |
436 (** embedding of standard natural number operations into integer operations **) |
480 "int (if P then n else m) = (if P then int n else int m)" |
490 "int (if P then n else m) = (if P then int n else int m)" |
481 by (auto simp add: int_mult zdiv_int zmod_int)} |
491 by (auto simp add: int_mult zdiv_int zmod_int)} |
482 |
492 |
483 fun mk_number_eq ctxt i lhs = |
493 fun mk_number_eq ctxt i lhs = |
484 let |
494 let |
485 val eq = U.mk_cequals lhs (Numeral.mk_cnumber @{ctyp int} i) |
495 val eq = SMT_Utils.mk_cequals lhs (Numeral.mk_cnumber @{ctyp int} i) |
486 val ss = HOL_ss |
496 val ss = HOL_ss |
487 addsimps [@{thm Nat_Numeral.int_nat_number_of}] |
497 addsimps [@{thm Nat_Numeral.int_nat_number_of}] |
488 addsimps @{thms neg_simps} |
498 addsimps @{thms neg_simps} |
489 fun tac _ = Simplifier.simp_tac (Simplifier.context ctxt ss) 1 |
499 fun tac _ = Simplifier.simp_tac (Simplifier.context ctxt ss) 1 |
490 in Goal.norm_result (Goal.prove_internal [] eq tac) end |
500 in Goal.norm_result (Goal.prove_internal [] eq tac) end |
506 | _ => Conv.no_conv) ct |
516 | _ => Conv.no_conv) ct |
507 |
517 |
508 and ints_conv ctxt = Conv.top_sweep_conv int_conv ctxt |
518 and ints_conv ctxt = Conv.top_sweep_conv int_conv ctxt |
509 |
519 |
510 and expand_conv ctxt = |
520 and expand_conv ctxt = |
511 U.if_conv (is_nat_const o Term.head_of) |
521 SMT_Utils.if_conv (is_nat_const o Term.head_of) |
512 (expand_head_conv (Conv.rewrs_conv expands) then_conv ints_conv ctxt) |
522 (expand_head_conv (Conv.rewrs_conv expands) then_conv ints_conv ctxt) |
513 (int_conv ctxt) |
523 (int_conv ctxt) |
514 |
524 |
515 and nat_conv ctxt = U.if_exists_conv is_nat_const' |
525 and nat_conv ctxt = SMT_Utils.if_exists_conv is_nat_const' |
516 (Conv.top_sweep_conv expand_conv ctxt) |
526 (Conv.top_sweep_conv expand_conv ctxt) |
517 |
527 |
518 val uses_nat_int = Term.exists_subterm (member (op aconv) nat_int_coercions) |
528 val uses_nat_int = Term.exists_subterm (member (op aconv) nat_int_coercions) |
519 in |
529 in |
520 |
530 |
523 fun add_nat_embedding thms = |
533 fun add_nat_embedding thms = |
524 if exists (uses_nat_int o Thm.prop_of) thms then (thms, nat_embedding) |
534 if exists (uses_nat_int o Thm.prop_of) thms then (thms, nat_embedding) |
525 else (thms, []) |
535 else (thms, []) |
526 |
536 |
527 val setup_nat_as_int = |
537 val setup_nat_as_int = |
528 B.add_builtin_typ_ext (@{typ nat}, K true) #> |
538 SMT_Builtin.add_builtin_typ_ext (@{typ nat}, K true) #> |
529 fold (B.add_builtin_fun_ext' o Term.dest_Const) builtin_nat_ops |
539 fold (SMT_Builtin.add_builtin_fun_ext' o Term.dest_Const) builtin_nat_ops |
530 |
540 |
531 end |
541 end |
532 |
542 |
533 |
543 |
534 (** normalize numerals **) |
544 (** normalize numerals **) |
540 rewrite Numeral1 into 1 |
550 rewrite Numeral1 into 1 |
541 *) |
551 *) |
542 |
552 |
543 fun is_strange_number ctxt (t as Const (@{const_name number_of}, _) $ _) = |
553 fun is_strange_number ctxt (t as Const (@{const_name number_of}, _) $ _) = |
544 (case try HOLogic.dest_number t of |
554 (case try HOLogic.dest_number t of |
545 SOME (_, i) => B.is_builtin_num ctxt t andalso i < 2 |
555 SOME (_, i) => SMT_Builtin.is_builtin_num ctxt t andalso i < 2 |
546 | NONE => false) |
556 | NONE => false) |
547 | is_strange_number _ _ = false |
557 | is_strange_number _ _ = false |
548 |
558 |
549 val pos_num_ss = HOL_ss |
559 val pos_num_ss = HOL_ss |
550 addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}] |
560 addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}] |
556 "Int.Bit0 (- Int.Pls) = - Int.Pls" |
566 "Int.Bit0 (- Int.Pls) = - Int.Pls" |
557 "Int.Bit0 (- k) = - Int.Bit0 k" |
567 "Int.Bit0 (- k) = - Int.Bit0 k" |
558 "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)" |
568 "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)" |
559 by simp_all (simp add: pred_def)} |
569 by simp_all (simp add: pred_def)} |
560 |
570 |
561 fun norm_num_conv ctxt = U.if_conv (is_strange_number ctxt) |
571 fun norm_num_conv ctxt = |
562 (Simplifier.rewrite (Simplifier.context ctxt pos_num_ss)) Conv.no_conv |
572 SMT_Utils.if_conv (is_strange_number ctxt) |
563 in |
573 (Simplifier.rewrite (Simplifier.context ctxt pos_num_ss)) Conv.no_conv |
564 |
574 in |
565 fun normalize_numerals_conv ctxt = U.if_exists_conv (is_strange_number ctxt) |
575 |
566 (Conv.top_sweep_conv norm_num_conv ctxt) |
576 fun normalize_numerals_conv ctxt = |
|
577 SMT_Utils.if_exists_conv (is_strange_number ctxt) |
|
578 (Conv.top_sweep_conv norm_num_conv ctxt) |
567 |
579 |
568 end |
580 end |
569 |
581 |
570 |
582 |
571 (** combined unfoldings and rewritings **) |
583 (** combined unfoldings and rewritings **) |
597 |
609 |
598 type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list |
610 type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list |
599 |
611 |
600 structure Extra_Norms = Generic_Data |
612 structure Extra_Norms = Generic_Data |
601 ( |
613 ( |
602 type T = extra_norm U.dict |
614 type T = extra_norm SMT_Utils.dict |
603 val empty = [] |
615 val empty = [] |
604 val extend = I |
616 val extend = I |
605 fun merge data = U.dict_merge fst data |
617 fun merge data = SMT_Utils.dict_merge fst data |
606 ) |
618 ) |
607 |
619 |
608 fun add_extra_norm (cs, norm) = Extra_Norms.map (U.dict_update (cs, norm)) |
620 fun add_extra_norm (cs, norm) = |
|
621 Extra_Norms.map (SMT_Utils.dict_update (cs, norm)) |
609 |
622 |
610 fun apply_extra_norms ithms ctxt = |
623 fun apply_extra_norms ithms ctxt = |
611 let |
624 let |
612 val cs = SMT_Config.solver_class_of ctxt |
625 val cs = SMT_Config.solver_class_of ctxt |
613 val es = U.dict_lookup (Extra_Norms.get (Context.Proof ctxt)) cs |
626 val es = SMT_Utils.dict_lookup (Extra_Norms.get (Context.Proof ctxt)) cs |
614 in (burrow_ids (fold (fn e => e ctxt) es o rpair []) ithms, ctxt) end |
627 in (burrow_ids (fold (fn e => e ctxt) es o rpair []) ithms, ctxt) end |
615 |
628 |
616 fun normalize iwthms ctxt = |
629 fun normalize iwthms ctxt = |
617 iwthms |
630 iwthms |
618 |> gen_normalize ctxt |
631 |> gen_normalize ctxt |