--- a/src/HOL/Tools/SMT/smt_normalize.ML Wed Dec 15 08:39:24 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_normalize.ML Wed Dec 15 08:39:24 2010 +0100
@@ -1,28 +1,17 @@
(* Title: HOL/Tools/SMT/smt_normalize.ML
Author: Sascha Boehme, TU Muenchen
-Normalization steps on theorems required by SMT solvers:
- * simplify trivial distincts (those with less than three elements),
- * rewrite bool case expressions as if expressions,
- * normalize numerals (e.g. replace negative numerals by negated positive
- numerals),
- * embed natural numbers into integers,
- * add extra rules specifying types and constants which occur frequently,
- * fully translate into object logic, add universal closure,
- * monomorphize (create instances of schematic rules),
- * lift lambda terms,
- * make applications explicit for functions with varying number of arguments.
- * add (hypothetical definitions for) missing datatype selectors,
+Normalization steps on theorems required by SMT solvers.
*)
signature SMT_NORMALIZE =
sig
- type extra_norm = bool -> (int * thm) list -> Proof.context ->
- (int * thm) list * Proof.context
- val normalize: extra_norm -> bool -> (int * thm) list -> Proof.context ->
- (int * thm) list * Proof.context
val atomize_conv: Proof.context -> conv
- val eta_expand_conv: (Proof.context -> conv) -> Proof.context -> conv
+ type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list
+ val add_extra_norm: SMT_Utils.class * extra_norm -> Context.generic ->
+ Context.generic
+ val normalize: Proof.context -> (int * (int option * thm)) list ->
+ (int * thm) list
val setup: theory -> theory
end
@@ -32,12 +21,10 @@
structure U = SMT_Utils
structure B = SMT_Builtin
-infix 2 ??
-fun (test ?? f) x = if test x then f x else x
-
+(* general theorem normalizations *)
-(* instantiate elimination rules *)
+(** instantiate elimination rules **)
local
val (cpfalse, cfalse) = `U.mk_cprop (Thm.cterm_of @{theory} @{const False})
@@ -56,281 +43,18 @@
end
-
-(* simplification of trivial distincts (distinct should have at least
- three elements in the argument list) *)
-
-local
- fun is_trivial_distinct (Const (@{const_name distinct}, _) $ t) =
- (case try HOLogic.dest_list t of
- SOME [] => true
- | SOME [_] => true
- | _ => false)
- | is_trivial_distinct _ = false
-
- val thms = map mk_meta_eq @{lemma
- "distinct [] = True"
- "distinct [x] = True"
- "distinct [x, y] = (x ~= y)"
- by simp_all}
- fun distinct_conv _ =
- U.if_true_conv is_trivial_distinct (Conv.rewrs_conv thms)
-in
-fun trivial_distinct ctxt =
- map (apsnd ((Term.exists_subterm is_trivial_distinct o Thm.prop_of) ??
- Conv.fconv_rule (Conv.top_conv distinct_conv ctxt)))
-end
-
-
-
-(* rewrite bool case expressions as if expressions *)
-
-local
- val is_bool_case = (fn
- Const (@{const_name "bool.bool_case"}, _) $ _ $ _ $ _ => true
- | _ => false)
+(** normalize definitions **)
- val thm = mk_meta_eq @{lemma
- "(case P of True => x | False => y) = (if P then x else y)" by simp}
- val unfold_conv = U.if_true_conv is_bool_case (Conv.rewr_conv thm)
-in
-fun rewrite_bool_cases ctxt =
- map (apsnd ((Term.exists_subterm is_bool_case o Thm.prop_of) ??
- Conv.fconv_rule (Conv.top_conv (K unfold_conv) ctxt)))
-
-val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"}
-
-end
-
-
-
-(* normalization of numerals: rewriting of negative integer numerals into
- positive numerals, Numeral0 into 0, Numeral1 into 1 *)
-
-local
- fun is_number_sort ctxt T =
- Sign.of_sort (ProofContext.theory_of ctxt) (T, @{sort number_ring})
-
- fun is_strange_number ctxt (t as Const (@{const_name number_of}, _) $ _) =
- (case try HOLogic.dest_number t of
- SOME (T, i) => is_number_sort ctxt T andalso i < 2
- | NONE => false)
- | is_strange_number _ _ = false
-
- val pos_numeral_ss = HOL_ss
- addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}]
- addsimps [@{thm Int.number_of_Pls}, @{thm Int.numeral_1_eq_1}]
- addsimps @{thms Int.pred_bin_simps}
- addsimps @{thms Int.normalize_bin_simps}
- addsimps @{lemma
- "Int.Min = - Int.Bit1 Int.Pls"
- "Int.Bit0 (- Int.Pls) = - Int.Pls"
- "Int.Bit0 (- k) = - Int.Bit0 k"
- "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)"
- by simp_all (simp add: pred_def)}
-
- fun pos_conv ctxt = U.if_conv (is_strange_number ctxt)
- (Simplifier.rewrite (Simplifier.context ctxt pos_numeral_ss))
- Conv.no_conv
-in
-fun normalize_numerals ctxt =
- map (apsnd ((Term.exists_subterm (is_strange_number ctxt) o Thm.prop_of) ??
- Conv.fconv_rule (Conv.top_sweep_conv pos_conv ctxt)))
-end
-
+fun norm_def thm =
+ (case Thm.prop_of thm of
+ @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ _ $ Abs _) =>
+ norm_def (thm RS @{thm fun_cong})
+ | Const (@{const_name "=="}, _) $ _ $ Abs _ =>
+ norm_def (thm RS @{thm meta_eq_to_obj_eq})
+ | _ => thm)
-(* embedding of standard natural number operations into integer operations *)
-
-local
- val nat_embedding = map (pair ~1) @{lemma
- "nat (int n) = n"
- "i >= 0 --> int (nat i) = i"
- "i < 0 --> int (nat i) = 0"
- by simp_all}
-
- val nat_rewriting = @{lemma
- "0 = nat 0"
- "1 = nat 1"
- "(number_of :: int => nat) = (%i. nat (number_of i))"
- "int (nat 0) = 0"
- "int (nat 1) = 1"
- "op < = (%a b. int a < int b)"
- "op <= = (%a b. int a <= int b)"
- "Suc = (%a. nat (int a + 1))"
- "op + = (%a b. nat (int a + int b))"
- "op - = (%a b. nat (int a - int b))"
- "op * = (%a b. nat (int a * int b))"
- "op div = (%a b. nat (int a div int b))"
- "op mod = (%a b. nat (int a mod int b))"
- "min = (%a b. nat (min (int a) (int b)))"
- "max = (%a b. nat (max (int a) (int b)))"
- "int (nat (int a + int b)) = int a + int b"
- "int (nat (int a + 1)) = int a + 1" (* special rule due to Suc above *)
- "int (nat (int a * int b)) = int a * int b"
- "int (nat (int a div int b)) = int a div int b"
- "int (nat (int a mod int b)) = int a mod int b"
- "int (nat (min (int a) (int b))) = min (int a) (int b)"
- "int (nat (max (int a) (int b))) = max (int a) (int b)"
- by (auto intro!: ext simp add: nat_mult_distrib nat_div_distrib
- nat_mod_distrib int_mult[symmetric] zdiv_int[symmetric]
- zmod_int[symmetric])}
-
- fun on_positive num f x =
- (case try HOLogic.dest_number (Thm.term_of num) of
- SOME (_, i) => if i >= 0 then SOME (f x) else NONE
- | NONE => NONE)
-
- val cancel_int_nat_ss = HOL_ss
- addsimps [@{thm Nat_Numeral.nat_number_of}]
- addsimps [@{thm Nat_Numeral.int_nat_number_of}]
- addsimps @{thms neg_simps}
-
- val int_eq = Thm.cterm_of @{theory} @{const "==" (int)}
-
- fun cancel_int_nat_simproc _ ss ct =
- let
- val num = Thm.dest_arg (Thm.dest_arg ct)
- val goal = Thm.mk_binop int_eq ct num
- val simpset = Simplifier.inherit_context ss cancel_int_nat_ss
- fun tac _ = Simplifier.simp_tac simpset 1
- in on_positive num (Goal.prove_internal [] goal) tac end
-
- val nat_ss = HOL_ss
- addsimps nat_rewriting
- addsimprocs [
- Simplifier.make_simproc {
- name = "cancel_int_nat_num", lhss = [@{cpat "int (nat _)"}],
- proc = cancel_int_nat_simproc, identifier = [] }]
-
- fun conv ctxt = Simplifier.rewrite (Simplifier.context ctxt nat_ss)
-
- val uses_nat_type = Term.exists_type (Term.exists_subtype (equal @{typ nat}))
- val uses_nat_int = Term.exists_subterm (member (op aconv)
- [@{const of_nat (int)}, @{const nat}])
-
- val nat_ops = [
- @{const less (nat)}, @{const less_eq (nat)},
- @{const Suc}, @{const plus (nat)}, @{const minus (nat)},
- @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}]
- val nat_ops' = @{const of_nat (int)} :: @{const nat} :: nat_ops
-in
-fun nat_as_int ctxt =
- map (apsnd ((uses_nat_type o Thm.prop_of) ?? Conv.fconv_rule (conv ctxt))) #>
- exists (uses_nat_int o Thm.prop_of o snd) ?? append nat_embedding
-
-val setup_nat_as_int =
- B.add_builtin_typ_ext (@{typ nat}, K true) #>
- fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops'
-end
-
-
-
-(* further normalizations: beta/eta, universal closure, atomize *)
-
-val eta_expand_eq = @{lemma "f == (%x. f x)" by (rule reflexive)}
-
-fun eta_expand_conv cv ctxt =
- Conv.rewr_conv eta_expand_eq then_conv Conv.abs_conv (cv o snd) ctxt
-
-local
- val eta_conv = eta_expand_conv
-
- fun args_conv cv ct =
- (case Thm.term_of ct of
- _ $ _ => Conv.combination_conv (args_conv cv) cv
- | _ => Conv.all_conv) ct
-
- fun eta_args_conv cv 0 = args_conv o cv
- | eta_args_conv cv i = eta_conv (eta_args_conv cv (i-1))
-
- fun keep_conv ctxt = Conv.binder_conv (norm_conv o snd) ctxt
- and eta_binder_conv ctxt = Conv.arg_conv (eta_conv norm_conv ctxt)
- and keep_let_conv ctxt = Conv.combination_conv
- (Conv.arg_conv (norm_conv ctxt)) (Conv.abs_conv (norm_conv o snd) ctxt)
- and unfold_let_conv ctxt = Conv.combination_conv
- (Conv.arg_conv (norm_conv ctxt)) (eta_conv norm_conv ctxt)
- and unfold_conv thm ctxt = Conv.rewr_conv thm then_conv keep_conv ctxt
- and unfold_ex1_conv ctxt = unfold_conv @{thm Ex1_def} ctxt
- and unfold_ball_conv ctxt = unfold_conv (mk_meta_eq @{thm Ball_def}) ctxt
- and unfold_bex_conv ctxt = unfold_conv (mk_meta_eq @{thm Bex_def}) ctxt
- and norm_conv ctxt ct =
- (case Thm.term_of ct of
- Const (@{const_name All}, _) $ Abs _ => keep_conv
- | Const (@{const_name All}, _) $ _ => eta_binder_conv
- | Const (@{const_name All}, _) => eta_conv eta_binder_conv
- | Const (@{const_name Ex}, _) $ Abs _ => keep_conv
- | Const (@{const_name Ex}, _) $ _ => eta_binder_conv
- | Const (@{const_name Ex}, _) => eta_conv eta_binder_conv
- | Const (@{const_name Let}, _) $ _ $ Abs _ => keep_let_conv
- | Const (@{const_name Let}, _) $ _ $ _ => unfold_let_conv
- | Const (@{const_name Let}, _) $ _ => eta_conv unfold_let_conv
- | Const (@{const_name Let}, _) => eta_conv (eta_conv unfold_let_conv)
- | Const (@{const_name Ex1}, _) $ _ => unfold_ex1_conv
- | Const (@{const_name Ex1}, _) => eta_conv unfold_ex1_conv
- | Const (@{const_name Ball}, _) $ _ $ _ => unfold_ball_conv
- | Const (@{const_name Ball}, _) $ _ => eta_conv unfold_ball_conv
- | Const (@{const_name Ball}, _) => eta_conv (eta_conv unfold_ball_conv)
- | Const (@{const_name Bex}, _) $ _ $ _ => unfold_bex_conv
- | Const (@{const_name Bex}, _) $ _ => eta_conv unfold_bex_conv
- | Const (@{const_name Bex}, _) => eta_conv (eta_conv unfold_bex_conv)
- | Abs _ => Conv.abs_conv (norm_conv o snd)
- | _ =>
- (case Term.strip_comb (Thm.term_of ct) of
- (Const (c as (_, T)), ts) =>
- if SMT_Builtin.is_builtin_fun ctxt c ts
- then eta_args_conv norm_conv
- (length (Term.binder_types T) - length ts)
- else args_conv o norm_conv
- | _ => args_conv o norm_conv)) ctxt ct
-
- fun is_normed ctxt t =
- (case t of
- Const (@{const_name All}, _) $ Abs (_, _, u) => is_normed ctxt u
- | Const (@{const_name All}, _) $ _ => false
- | Const (@{const_name All}, _) => false
- | Const (@{const_name Ex}, _) $ Abs (_, _, u) => is_normed ctxt u
- | Const (@{const_name Ex}, _) $ _ => false
- | Const (@{const_name Ex}, _) => false
- | Const (@{const_name Let}, _) $ u1 $ Abs (_, _, u2) =>
- is_normed ctxt u1 andalso is_normed ctxt u2
- | Const (@{const_name Let}, _) $ _ $ _ => false
- | Const (@{const_name Let}, _) $ _ => false
- | Const (@{const_name Let}, _) => false
- | Const (@{const_name Ex1}, _) $ _ => false
- | Const (@{const_name Ex1}, _) => false
- | Const (@{const_name Ball}, _) $ _ $ _ => false
- | Const (@{const_name Ball}, _) $ _ => false
- | Const (@{const_name Ball}, _) => false
- | Const (@{const_name Bex}, _) $ _ $ _ => false
- | Const (@{const_name Bex}, _) $ _ => false
- | Const (@{const_name Bex}, _) => false
- | Abs (_, _, u) => is_normed ctxt u
- | _ =>
- (case Term.strip_comb t of
- (Const (c as (_, T)), ts) =>
- if SMT_Builtin.is_builtin_fun ctxt c ts
- then length (Term.binder_types T) = length ts andalso
- forall (is_normed ctxt) ts
- else forall (is_normed ctxt) ts
- | (_, ts) => forall (is_normed ctxt) ts))
-in
-fun norm_binder_conv ctxt =
- U.if_conv (is_normed ctxt) Conv.all_conv (norm_conv ctxt)
-
-val setup_unfolded_quants =
- fold B.add_builtin_fun_ext'' [@{const_name Ball}, @{const_name Bex},
- @{const_name Ex1}]
-
-end
-
-fun norm_def ctxt thm =
- (case Thm.prop_of thm of
- @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ _ $ Abs _) =>
- norm_def ctxt (thm RS @{thm fun_cong})
- | Const (@{const_name "=="}, _) $ _ $ Abs _ =>
- norm_def ctxt (thm RS @{thm meta_eq_to_obj_eq})
- | _ => thm)
+(** atomization **)
fun atomize_conv ctxt ct =
(case Thm.term_of ct of
@@ -349,243 +73,543 @@
fold B.add_builtin_fun_ext'' [@{const_name "==>"}, @{const_name "=="},
@{const_name all}, @{const_name Trueprop}]
-fun normalize_rule ctxt =
- Conv.fconv_rule (
- (* reduce lambda abstractions, except at known binders: *)
- Thm.beta_conversion true then_conv
- Thm.eta_conversion then_conv
- norm_binder_conv ctxt) #>
- norm_def ctxt #>
- Drule.forall_intr_vars #>
- Conv.fconv_rule (atomize_conv ctxt)
-
-
-(* lift lambda terms into additional rules *)
+(** unfold special quantifiers **)
local
- fun used_vars cvs ct =
- let
- val lookup = AList.lookup (op aconv) (map (` Thm.term_of) cvs)
- val add = (fn SOME ct => insert (op aconvc) ct | _ => I)
- in Term.fold_aterms (add o lookup) (Thm.term_of ct) [] end
-
- fun apply cv thm =
- let val thm' = Thm.combination thm (Thm.reflexive cv)
- in Thm.transitive thm' (Thm.beta_conversion false (Thm.rhs_of thm')) end
- fun apply_def cvs eq = Thm.symmetric (fold apply cvs eq)
+ val ex1_def = mk_meta_eq @{lemma
+ "Ex1 = (%P. EX x. P x & (ALL y. P y --> y = x))"
+ by (rule ext) (simp only: Ex1_def)}
- fun replace_lambda cvs ct (cx as (ctxt, defs)) =
- let
- val cvs' = used_vars cvs ct
- val ct' = fold_rev Thm.cabs cvs' ct
- in
- (case Termtab.lookup defs (Thm.term_of ct') of
- SOME eq => (apply_def cvs' eq, cx)
- | NONE =>
- let
- val {T, ...} = Thm.rep_cterm ct' and n = Name.uu
- val (n', ctxt') = yield_singleton Variable.variant_fixes n ctxt
- val cu = U.mk_cequals (U.certify ctxt (Free (n', T))) ct'
- val (eq, ctxt'') = yield_singleton Assumption.add_assumes cu ctxt'
- val defs' = Termtab.update (Thm.term_of ct', eq) defs
- in (apply_def cvs' eq, (ctxt'', defs')) end)
- end
+ val ball_def = mk_meta_eq @{lemma "Ball = (%A P. ALL x. x : A --> P x)"
+ by (rule ext)+ (rule Ball_def)}
+
+ val bex_def = mk_meta_eq @{lemma "Bex = (%A P. EX x. x : A & P x)"
+ by (rule ext)+ (rule Bex_def)}
- fun none ct cx = (Thm.reflexive ct, cx)
- fun in_comb f g ct cx =
- let val (cu1, cu2) = Thm.dest_comb ct
- in cx |> f cu1 ||>> g cu2 |>> uncurry Thm.combination end
- fun in_arg f = in_comb none f
- fun in_abs f cvs ct (ctxt, defs) =
- let
- val (n, ctxt') = yield_singleton Variable.variant_fixes Name.uu ctxt
- val (cv, cu) = Thm.dest_abs (SOME n) ct
- in (ctxt', defs) |> f (cv :: cvs) cu |>> Thm.abstract_rule n cv end
-
- fun traverse cvs ct =
- (case Thm.term_of ct of
- Const (@{const_name All}, _) $ Abs _ => in_arg (in_abs traverse cvs)
- | Const (@{const_name Ex}, _) $ Abs _ => in_arg (in_abs traverse cvs)
- | Const (@{const_name Let}, _) $ _ $ Abs _ =>
- in_comb (in_arg (traverse cvs)) (in_abs traverse cvs)
- | Abs _ => at_lambda cvs
- | _ $ _ => in_comb (traverse cvs) (traverse cvs)
- | _ => none) ct
+ val special_quants = [(@{const_name Ex1}, ex1_def),
+ (@{const_name Ball}, ball_def), (@{const_name Bex}, bex_def)]
+
+ fun special_quant (Const (n, _)) = AList.lookup (op =) special_quants n
+ | special_quant _ = NONE
- and at_lambda cvs ct =
- in_abs traverse cvs ct #-> (fn thm =>
- replace_lambda cvs (Thm.rhs_of thm) #>> Thm.transitive thm)
+ fun special_quant_conv _ ct =
+ (case special_quant (Thm.term_of ct) of
+ SOME thm => Conv.rewr_conv thm
+ | NONE => Conv.all_conv) ct
+in
- fun has_free_lambdas t =
- (case t of
- Const (@{const_name All}, _) $ Abs (_, _, u) => has_free_lambdas u
- | Const (@{const_name Ex}, _) $ Abs (_, _, u) => has_free_lambdas u
- | Const (@{const_name Let}, _) $ u1 $ Abs (_, _, u2) =>
- has_free_lambdas u1 orelse has_free_lambdas u2
- | Abs _ => true
- | u1 $ u2 => has_free_lambdas u1 orelse has_free_lambdas u2
- | _ => false)
+fun unfold_special_quants_conv ctxt =
+ U.if_exists_conv (is_some o special_quant)
+ (Conv.top_conv special_quant_conv ctxt)
- fun lift_lm f thm cx =
- if not (has_free_lambdas (Thm.prop_of thm)) then (thm, cx)
- else cx |> f (Thm.cprop_of thm) |>> (fn thm' => Thm.equal_elim thm' thm)
-in
-fun lift_lambdas irules ctxt =
- let
- val cx = (ctxt, Termtab.empty)
- val (idxs, thms) = split_list irules
- val (thms', (ctxt', defs)) = fold_map (lift_lm (traverse [])) thms cx
- val eqs = Termtab.fold (cons o normalize_rule ctxt' o snd) defs []
- in (map (pair ~1) eqs @ (idxs ~~ thms'), ctxt') end
+val setup_unfolded_quants = fold (B.add_builtin_fun_ext'' o fst) special_quants
+
end
-
-(* make application explicit for functions with varying number of arguments *)
+(** trigger inference **)
local
- val const = prefix "c" and free = prefix "f"
- fun min i (e as (_, j)) = if i <> j then (true, Int.min (i, j)) else e
- fun add t i = Symtab.map_default (t, (false, i)) (min i)
- fun traverse t =
+ (*** check trigger syntax ***)
+
+ fun dest_trigger (Const (@{const_name pat}, _) $ _) = SOME true
+ | dest_trigger (Const (@{const_name nopat}, _) $ _) = SOME false
+ | dest_trigger _ = NONE
+
+ fun eq_list [] = false
+ | eq_list (b :: bs) = forall (equal b) bs
+
+ fun proper_trigger t =
+ t
+ |> these o try HOLogic.dest_list
+ |> map (map_filter dest_trigger o these o try HOLogic.dest_list)
+ |> (fn [] => false | bss => forall eq_list bss)
+
+ fun proper_quant inside f t =
+ (case t of
+ Const (@{const_name All}, _) $ Abs (_, _, u) => proper_quant true f u
+ | Const (@{const_name Ex}, _) $ Abs (_, _, u) => proper_quant true f u
+ | @{const trigger} $ p $ u =>
+ (if inside then f p else false) andalso proper_quant false f u
+ | Abs (_, _, u) => proper_quant false f u
+ | u1 $ u2 => proper_quant false f u1 andalso proper_quant false f u2
+ | _ => true)
+
+ fun check_trigger_error ctxt t =
+ error ("SMT triggers must only occur under quantifier and multipatterns " ^
+ "must have the same kind: " ^ Syntax.string_of_term ctxt t)
+
+ fun check_trigger_conv ctxt ct =
+ if proper_quant false proper_trigger (Thm.term_of ct) then Conv.all_conv ct
+ else check_trigger_error ctxt (Thm.term_of ct)
+
+
+ (*** infer simple triggers ***)
+
+ fun dest_cond_eq ct =
+ (case Thm.term_of ct of
+ Const (@{const_name HOL.eq}, _) $ _ $ _ => Thm.dest_binop ct
+ | @{const HOL.implies} $ _ $ _ => dest_cond_eq (Thm.dest_arg ct)
+ | _ => raise CTERM ("no equation", [ct]))
+
+ fun get_constrs thy (Type (n, _)) = these (Datatype.get_constrs thy n)
+ | get_constrs _ _ = []
+
+ fun is_constr thy (n, T) =
+ let fun match (m, U) = m = n andalso Sign.typ_instance thy (T, U)
+ in can (the o find_first match o get_constrs thy o Term.body_type) T end
+
+ fun is_constr_pat thy t =
+ (case Term.strip_comb t of
+ (Free _, []) => true
+ | (Const c, ts) => is_constr thy c andalso forall (is_constr_pat thy) ts
+ | _ => false)
+
+ fun is_simp_lhs ctxt t =
(case Term.strip_comb t of
- (Const (n, _), ts) => add (const n) (length ts) #> fold traverse ts
- | (Free (n, _), ts) => add (free n) (length ts) #> fold traverse ts
- | (Abs (_, _, u), ts) => fold traverse (u :: ts)
- | (_, ts) => fold traverse ts)
- fun prune tab = Symtab.fold (fn (n, (true, i)) =>
- Symtab.update (n, i) | _ => I) tab Symtab.empty
+ (Const c, ts as _ :: _) =>
+ not (B.is_builtin_fun_ext ctxt c ts) andalso
+ forall (is_constr_pat (ProofContext.theory_of ctxt)) ts
+ | _ => false)
+
+ fun has_all_vars vs t =
+ subset (op aconv) (vs, map Free (Term.add_frees t []))
+
+ fun minimal_pats vs ct =
+ if has_all_vars vs (Thm.term_of ct) then
+ (case Thm.term_of ct of
+ _ $ _ =>
+ (case pairself (minimal_pats vs) (Thm.dest_comb ct) of
+ ([], []) => [[ct]]
+ | (ctss, ctss') => union (eq_set (op aconvc)) ctss ctss')
+ | _ => [[ct]])
+ else []
+
+ fun proper_mpat _ _ _ [] = false
+ | proper_mpat thy gen u cts =
+ let
+ val tps = (op ~~) (`gen (map Thm.term_of cts))
+ fun some_match u = tps |> exists (fn (t', t) =>
+ Pattern.matches thy (t', u) andalso not (t aconv u))
+ in not (Term.exists_subterm some_match u) end
+
+ val pat = U.mk_const_pat @{theory} @{const_name SMT.pat} U.destT1
+ fun mk_pat ct = Thm.capply (U.instT' ct pat) ct
- fun binop_conv cv1 cv2 = Conv.combination_conv (Conv.arg_conv cv1) cv2
- fun nary_conv conv1 conv2 ct =
- (Conv.combination_conv (nary_conv conv1 conv2) conv2 else_conv conv1) ct
- fun abs_conv conv tb = Conv.abs_conv (fn (cv, cx) =>
- let val n = fst (Term.dest_Free (Thm.term_of cv))
- in conv (Symtab.update (free n, 0) tb) cx end)
- val fun_app_rule = @{lemma "f x == fun_app f x" by (simp add: fun_app_def)}
+ fun mk_clist T = pairself (Thm.cterm_of @{theory})
+ (HOLogic.cons_const T, HOLogic.nil_const T)
+ fun mk_list (ccons, cnil) f cts = fold_rev (Thm.mk_binop ccons o f) cts cnil
+ val mk_pat_list = mk_list (mk_clist @{typ SMT.pattern})
+ val mk_mpat_list = mk_list (mk_clist @{typ "SMT.pattern list"})
+ fun mk_trigger ctss = mk_mpat_list (mk_pat_list mk_pat) ctss
+
+ val trigger_eq =
+ mk_meta_eq @{lemma "p = SMT.trigger t p" by (simp add: trigger_def)}
+
+ fun insert_trigger_conv [] ct = Conv.all_conv ct
+ | insert_trigger_conv ctss ct =
+ let val (ctr, cp) = Thm.dest_binop (Thm.rhs_of trigger_eq) ||> rpair ct
+ in Thm.instantiate ([], [cp, (ctr, mk_trigger ctss)]) trigger_eq end
+
+ fun infer_trigger_eq_conv outer_ctxt (ctxt, cvs) ct =
+ let
+ val (lhs, rhs) = dest_cond_eq ct
+
+ val vs = map Thm.term_of cvs
+ val thy = ProofContext.theory_of ctxt
+
+ fun get_mpats ct =
+ if is_simp_lhs ctxt (Thm.term_of ct) then minimal_pats vs ct
+ else []
+ val gen = Variable.export_terms ctxt outer_ctxt
+ val filter_mpats = filter (proper_mpat thy gen (Thm.term_of rhs))
+
+ in insert_trigger_conv (filter_mpats (get_mpats lhs)) ct end
+
+ fun try_trigger_conv cv ct =
+ if proper_quant false (K false) (Thm.term_of ct) then Conv.all_conv ct
+ else Conv.try_conv cv ct
+
+ fun infer_trigger_conv ctxt =
+ if Config.get ctxt SMT_Config.infer_triggers then
+ try_trigger_conv (U.under_quant_conv (infer_trigger_eq_conv ctxt) ctxt)
+ else Conv.all_conv
in
-fun explicit_application ctxt irules =
- let
- fun sub_conv tb ctxt ct =
- (case Term.strip_comb (Thm.term_of ct) of
- (Const (n, _), ts) => app_conv tb (const n) (length ts) ctxt
- | (Free (n, _), ts) => app_conv tb (free n) (length ts) ctxt
- | (Abs _, _) => nary_conv (abs_conv sub_conv tb ctxt) (sub_conv tb ctxt)
- | (_, _) => nary_conv Conv.all_conv (sub_conv tb ctxt)) ct
- and app_conv tb n i ctxt =
- (case Symtab.lookup tb n of
- NONE => nary_conv Conv.all_conv (sub_conv tb ctxt)
- | SOME j => fun_app_conv tb ctxt (i - j))
- and fun_app_conv tb ctxt i ct = (
- if i = 0 then nary_conv Conv.all_conv (sub_conv tb ctxt)
- else
- Conv.rewr_conv fun_app_rule then_conv
- binop_conv (fun_app_conv tb ctxt (i-1)) (sub_conv tb ctxt)) ct
+
+fun trigger_conv ctxt =
+ U.prop_conv (check_trigger_conv ctxt then_conv infer_trigger_conv ctxt)
- fun needs_exp_app tab = Term.exists_subterm (fn
- Bound _ $ _ => true
- | Const (n, _) => Symtab.defined tab (const n)
- | Free (n, _) => Symtab.defined tab (free n)
- | _ => false)
+val setup_trigger = fold B.add_builtin_fun_ext''
+ [@{const_name SMT.pat}, @{const_name SMT.nopat}, @{const_name SMT.trigger}]
- fun rewrite tab ctxt thm =
- if not (needs_exp_app tab (Thm.prop_of thm)) then thm
- else Conv.fconv_rule (sub_conv tab ctxt) thm
-
- val tab = prune (fold (traverse o Thm.prop_of o snd) irules Symtab.empty)
- in map (apsnd (rewrite tab ctxt)) irules end
end
-
-(* add missing datatype selectors via hypothetical definitions *)
+(** adding quantifier weights **)
local
- val add = (fn Type (n, _) => Symtab.update (n, ()) | _ => I)
+ (*** check weight syntax ***)
+
+ val has_no_weight =
+ not o Term.exists_subterm (fn @{const SMT.weight} => true | _ => false)
- fun collect t =
- (case Term.strip_comb t of
- (Abs (_, T, t), _) => add T #> collect t
- | (Const (_, T), ts) => collects T ts
- | (Free (_, T), ts) => collects T ts
- | _ => I)
- and collects T ts =
- let val ((Ts, Us), U) = Term.strip_type T |> apfst (chop (length ts))
- in fold add Ts #> add (Us ---> U) #> fold collect ts end
+ fun is_weight (@{const SMT.weight} $ w $ t) =
+ (case try HOLogic.dest_number w of
+ SOME (_, i) => i > 0 andalso has_no_weight t
+ | _ => false)
+ | is_weight t = has_no_weight t
+
+ fun proper_trigger (@{const SMT.trigger} $ _ $ t) = is_weight t
+ | proper_trigger t = has_no_weight t
+
+ fun check_weight_error ctxt t =
+ error ("SMT weight must be a positive number and must only occur " ^
+ "under the top-most quantifier and an optional trigger: " ^
+ Syntax.string_of_term ctxt t)
- fun add_constructors thy n =
- (case Datatype.get_info thy n of
- NONE => I
- | SOME {descr, ...} => fold (fn (_, (_, _, cs)) => fold (fn (n, ds) =>
- fold (insert (op =) o pair n) (1 upto length ds)) cs) descr)
+ fun check_weight_conv ctxt ct =
+ if U.under_quant proper_trigger (Thm.term_of ct) then Conv.all_conv ct
+ else check_weight_error ctxt (Thm.term_of ct)
+
+
+ (*** insertion of weights ***)
+
+ fun under_trigger_conv cv ct =
+ (case Thm.term_of ct of
+ @{const SMT.trigger} $ _ $ _ => Conv.arg_conv cv
+ | _ => cv) ct
- fun add_selector (c as (n, i)) ctxt =
- (case Datatype_Selectors.lookup_selector ctxt c of
- SOME _ => ctxt
- | NONE =>
- let
- val T = Sign.the_const_type (ProofContext.theory_of ctxt) n
- val U = Term.body_type T --> nth (Term.binder_types T) (i-1)
- in
- ctxt
- |> yield_singleton Variable.variant_fixes Name.uu
- |>> pair ((n, T), i) o rpair U
- |-> Context.proof_map o Datatype_Selectors.add_selector
- end)
+ val weight_eq =
+ mk_meta_eq @{lemma "p = SMT.weight i p" by (simp add: weight_def)}
+ fun mk_weight_eq w =
+ let val cv = Thm.dest_arg1 (Thm.rhs_of weight_eq)
+ in
+ Thm.instantiate ([], [(cv, Numeral.mk_cnumber @{ctyp int} w)]) weight_eq
+ end
+
+ fun add_weight_conv NONE _ = Conv.all_conv
+ | add_weight_conv (SOME weight) ctxt =
+ let val cv = Conv.rewr_conv (mk_weight_eq weight)
+ in U.under_quant_conv (K (under_trigger_conv cv)) ctxt end
in
-fun datatype_selectors irules ctxt =
- let
- val ns = Symtab.keys (fold (collect o Thm.prop_of o snd) irules Symtab.empty)
- val cs = fold (add_constructors (ProofContext.theory_of ctxt)) ns []
- in (irules, fold add_selector cs ctxt) end
- (* FIXME: also generate hypothetical definitions for the selectors *)
+fun weight_conv weight ctxt =
+ U.prop_conv (check_weight_conv ctxt then_conv add_weight_conv weight ctxt)
+
+val setup_weight = B.add_builtin_fun_ext'' @{const_name SMT.weight}
end
-
-(* combined normalization *)
+(** combined general normalizations **)
-type extra_norm = bool -> (int * thm) list -> Proof.context ->
- (int * thm) list * Proof.context
-
-fun with_context f irules ctxt = (f ctxt irules, ctxt)
+fun gen_normalize1_conv ctxt weight =
+ atomize_conv ctxt then_conv
+ unfold_special_quants_conv ctxt then_conv
+ trigger_conv ctxt then_conv
+ weight_conv weight ctxt
-fun normalize extra_norm with_datatypes irules ctxt =
- let
- fun norm f ctxt' (i, thm) =
- if Config.get ctxt' SMT_Config.drop_bad_facts then
- (case try (f ctxt') thm of
- SOME thm' => SOME (i, thm')
- | NONE => (SMT_Config.verbose_msg ctxt' (prefix ("Warning: " ^
- "dropping assumption: ") o Display.string_of_thm ctxt') thm; NONE))
- else SOME (i, f ctxt' thm)
- in
- irules
- |> map (apsnd instantiate_elim)
- |> trivial_distinct ctxt
- |> rewrite_bool_cases ctxt
- |> normalize_numerals ctxt
- |> nat_as_int ctxt
- |> rpair ctxt
- |-> extra_norm with_datatypes
- |-> with_context (map_filter o norm normalize_rule)
- |-> SMT_Monomorph.monomorph
- |-> lift_lambdas
- |-> with_context explicit_application
- |-> (if with_datatypes then datatype_selectors else pair)
- end
+fun gen_normalize1 ctxt weight thm =
+ thm
+ |> instantiate_elim
+ |> norm_def
+ |> Conv.fconv_rule (Thm.beta_conversion true then_conv Thm.eta_conversion)
+ |> Drule.forall_intr_vars
+ |> Conv.fconv_rule (gen_normalize1_conv ctxt weight)
+
+fun drop_fact_warning ctxt =
+ let val pre = prefix "Warning: dropping assumption: "
+ in SMT_Config.verbose_msg ctxt (pre o Display.string_of_thm ctxt) end
+
+fun gen_norm1_safe ctxt (i, (weight, thm)) =
+ if Config.get ctxt SMT_Config.drop_bad_facts then
+ (case try (gen_normalize1 ctxt weight) thm of
+ SOME thm' => SOME (i, thm')
+ | NONE => (drop_fact_warning ctxt thm; NONE))
+ else SOME (i, gen_normalize1 ctxt weight thm)
+
+fun gen_normalize ctxt iwthms = map_filter (gen_norm1_safe ctxt) iwthms
-(* setup *)
+(* unfolding of definitions and theory-specific rewritings *)
+
+(** unfold trivial distincts **)
+
+local
+ fun is_trivial_distinct (Const (@{const_name distinct}, _) $ t) =
+ (case try HOLogic.dest_list t of
+ SOME [] => true
+ | SOME [_] => true
+ | _ => false)
+ | is_trivial_distinct _ = false
+
+ val thms = map mk_meta_eq @{lemma
+ "distinct [] = True"
+ "distinct [x] = True"
+ "distinct [x, y] = (x ~= y)"
+ by simp_all}
+ fun distinct_conv _ =
+ U.if_true_conv is_trivial_distinct (Conv.rewrs_conv thms)
+in
+
+fun trivial_distinct_conv ctxt = U.if_exists_conv is_trivial_distinct
+ (Conv.top_conv distinct_conv ctxt)
+
+end
+
+
+(** rewrite bool case expressions as if expressions **)
+
+local
+ fun is_bool_case (Const (@{const_name "bool.bool_case"}, _)) = true
+ | is_bool_case _ = false
+
+ val thm = mk_meta_eq @{lemma
+ "bool_case = (%x y P. if P then x else y)" by (rule ext)+ simp}
+
+ fun unfold_conv _ = U.if_true_conv is_bool_case (Conv.rewr_conv thm)
+in
+
+fun rewrite_bool_case_conv ctxt = U.if_exists_conv is_bool_case
+ (Conv.top_conv unfold_conv ctxt)
+
+val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"}
+
+end
+
+
+(** unfold abs, min and max **)
+
+local
+ val abs_def = mk_meta_eq @{lemma
+ "abs = (%a::'a::abs_if. if a < 0 then - a else a)"
+ by (rule ext) (rule abs_if)}
+
+ val min_def = mk_meta_eq @{lemma "min = (%a b. if a <= b then a else b)"
+ by (rule ext)+ (rule min_def)}
+
+ val max_def = mk_meta_eq @{lemma "max = (%a b. if a <= b then b else a)"
+ by (rule ext)+ (rule max_def)}
+
+ val defs = [(@{const_name min}, min_def), (@{const_name max}, max_def),
+ (@{const_name abs}, abs_def)]
+
+ fun is_builtinT ctxt T = B.is_builtin_typ_ext ctxt (Term.domain_type T)
+
+ fun abs_min_max ctxt (Const (n, T)) =
+ (case AList.lookup (op =) defs n of
+ NONE => NONE
+ | SOME thm => if is_builtinT ctxt T then SOME thm else NONE)
+ | abs_min_max _ _ = NONE
+
+ fun unfold_amm_conv ctxt ct =
+ (case abs_min_max ctxt (Thm.term_of ct) of
+ SOME thm => Conv.rewr_conv thm
+ | NONE => Conv.all_conv) ct
+in
+
+fun unfold_abs_min_max_conv ctxt =
+ U.if_exists_conv (is_some o abs_min_max ctxt)
+ (Conv.top_conv unfold_amm_conv ctxt)
+
+val setup_abs_min_max = fold (B.add_builtin_fun_ext'' o fst) defs
+
+end
+
+
+(** embedding of standard natural number operations into integer operations **)
+
+local
+ val nat_embedding = @{lemma
+ "ALL n. nat (int n) = n"
+ "ALL i. i >= 0 --> int (nat i) = i"
+ "ALL i. i < 0 --> int (nat i) = 0"
+ by simp_all}
+
+ val nat_ops = [
+ @{const less (nat)}, @{const less_eq (nat)},
+ @{const Suc}, @{const plus (nat)}, @{const minus (nat)},
+ @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}]
+
+ val nat_consts = nat_ops @ [@{const number_of (nat)},
+ @{const zero_class.zero (nat)}, @{const one_class.one (nat)}]
+
+ val nat_int_coercions = [@{const of_nat (int)}, @{const nat}]
+
+ val nat_ops' = nat_int_coercions @ nat_ops
+
+ val is_nat_const = member (op aconv) nat_consts
+
+ val expands = map mk_meta_eq @{lemma
+ "0 = nat 0"
+ "1 = nat 1"
+ "(number_of :: int => nat) = (%i. nat (number_of i))"
+ "op < = (%a b. int a < int b)"
+ "op <= = (%a b. int a <= int b)"
+ "Suc = (%a. nat (int a + 1))"
+ "op + = (%a b. nat (int a + int b))"
+ "op - = (%a b. nat (int a - int b))"
+ "op * = (%a b. nat (int a * int b))"
+ "op div = (%a b. nat (int a div int b))"
+ "op mod = (%a b. nat (int a mod int b))"
+ by (auto intro!: ext simp add: nat_mult_distrib nat_div_distrib
+ nat_mod_distrib)}
+
+ val ints = map mk_meta_eq @{lemma
+ "int 0 = 0"
+ "int 1 = 1"
+ "int (Suc n) = int n + 1"
+ "int (n + m) = int n + int m"
+ "int (n - m) = int (nat (int n - int m))"
+ "int (n * m) = int n * int m"
+ "int (n div m) = int n div int m"
+ "int (n mod m) = int n mod int m"
+ "int (if P then n else m) = (if P then int n else int m)"
+ by (auto simp add: int_mult zdiv_int zmod_int)}
+
+ fun mk_number_eq ctxt i lhs =
+ let
+ val eq = U.mk_cequals lhs (Numeral.mk_cnumber @{ctyp int} i)
+ val ss = HOL_ss
+ addsimps [@{thm Nat_Numeral.int_nat_number_of}]
+ addsimps @{thms neg_simps}
+ fun tac _ = Simplifier.simp_tac (Simplifier.context ctxt ss) 1
+ in Goal.norm_result (Goal.prove_internal [] eq tac) end
+
+ fun expand_head_conv cv ct =
+ (case Thm.term_of ct of
+ _ $ _ =>
+ Conv.fun_conv (expand_head_conv cv) then_conv
+ Thm.beta_conversion false
+ | _ => cv) ct
+
+ fun int_conv ctxt ct =
+ (case Thm.term_of ct of
+ @{const of_nat (int)} $ (n as (@{const number_of (nat)} $ _)) =>
+ Conv.rewr_conv (mk_number_eq ctxt (snd (HOLogic.dest_number n)) ct)
+ | @{const of_nat (int)} $ _ =>
+ (Conv.rewrs_conv ints then_conv Conv.sub_conv ints_conv ctxt) else_conv
+ Conv.top_sweep_conv nat_conv ctxt
+ | _ => Conv.no_conv) ct
+
+ and ints_conv ctxt = Conv.top_sweep_conv int_conv ctxt
+
+ and expand_conv ctxt =
+ U.if_conv (not o is_nat_const o Term.head_of) Conv.no_conv
+ (expand_head_conv (Conv.rewrs_conv expands) then_conv ints_conv ctxt)
+
+ and nat_conv ctxt = U.if_exists_conv is_nat_const
+ (Conv.top_sweep_conv expand_conv ctxt)
+
+ val uses_nat_int = Term.exists_subterm (member (op aconv) nat_int_coercions)
+in
+
+val nat_as_int_conv = nat_conv
+
+fun add_nat_embedding thms =
+ if exists (uses_nat_int o Thm.prop_of) thms then (thms, nat_embedding)
+ else (thms, [])
+
+val setup_nat_as_int =
+ B.add_builtin_typ_ext (@{typ nat}, K true) #>
+ fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops'
+
+end
+
+
+(** normalize numerals **)
+
+local
+ (*
+ rewrite negative numerals into positive numerals,
+ rewrite Numeral0 into 0
+ rewrite Numeral1 into 1
+ *)
+
+ fun is_strange_number ctxt (t as Const (@{const_name number_of}, _) $ _) =
+ (case try HOLogic.dest_number t of
+ SOME (_, i) => B.is_builtin_num ctxt t andalso i < 2
+ | NONE => false)
+ | is_strange_number _ _ = false
+
+ val pos_num_ss = HOL_ss
+ addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}]
+ addsimps [@{thm Int.number_of_Pls}, @{thm Int.numeral_1_eq_1}]
+ addsimps @{thms Int.pred_bin_simps}
+ addsimps @{thms Int.normalize_bin_simps}
+ addsimps @{lemma
+ "Int.Min = - Int.Bit1 Int.Pls"
+ "Int.Bit0 (- Int.Pls) = - Int.Pls"
+ "Int.Bit0 (- k) = - Int.Bit0 k"
+ "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)"
+ by simp_all (simp add: pred_def)}
+
+ fun norm_num_conv ctxt = U.if_conv (is_strange_number ctxt)
+ (Simplifier.rewrite (Simplifier.context ctxt pos_num_ss)) Conv.no_conv
+in
+
+fun normalize_numerals_conv ctxt = U.if_exists_conv (is_strange_number ctxt)
+ (Conv.top_sweep_conv norm_num_conv ctxt)
+
+end
+
+
+(** combined unfoldings and rewritings **)
+
+fun unfold_conv ctxt =
+ trivial_distinct_conv ctxt then_conv
+ rewrite_bool_case_conv ctxt then_conv
+ unfold_abs_min_max_conv ctxt then_conv
+ nat_as_int_conv ctxt then_conv
+ normalize_numerals_conv ctxt then_conv
+ Thm.beta_conversion true
+
+fun burrow_ids f ithms =
+ let
+ val (is, thms) = split_list ithms
+ val (thms', extra_thms) = f thms
+ in (is ~~ thms') @ map (pair ~1) extra_thms end
+
+fun unfold ctxt =
+ burrow_ids (map (Conv.fconv_rule (unfold_conv ctxt)) #> add_nat_embedding)
+
+
+
+(* overall normalization *)
+
+type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list
+
+structure Extra_Norms = Generic_Data
+(
+ type T = extra_norm U.dict
+ val empty = []
+ val extend = I
+ val merge = U.dict_merge fst
+)
+
+fun add_extra_norm (cs, norm) = Extra_Norms.map (U.dict_update (cs, norm))
+
+fun apply_extra_norms ctxt =
+ let
+ val cs = SMT_Config.solver_class_of ctxt
+ val es = U.dict_lookup (Extra_Norms.get (Context.Proof ctxt)) cs
+ in burrow_ids (fold (fn e => e ctxt) es o rpair []) end
+
+fun normalize ctxt iwthms =
+ iwthms
+ |> gen_normalize ctxt
+ |> unfold ctxt
+ |> apply_extra_norms ctxt
val setup = Context.theory_map (
+ setup_atomize #>
+ setup_unfolded_quants #>
+ setup_trigger #>
+ setup_weight #>
setup_bool_case #>
- setup_nat_as_int #>
- setup_unfolded_quants #>
- setup_atomize)
+ setup_abs_min_max #>
+ setup_nat_as_int)
end
--- a/src/HOL/Tools/SMT/smt_solver.ML Wed Dec 15 08:39:24 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_solver.ML Wed Dec 15 08:39:24 2010 +0100
@@ -9,7 +9,6 @@
(*configuration*)
type interface = {
class: SMT_Utils.class,
- extra_norm: SMT_Normalize.extra_norm,
translate: SMT_Translate.config }
datatype outcome = Unsat | Sat | Unknown
type solver_config = {
@@ -26,8 +25,8 @@
default_max_relevant: int }
(*registry*)
- type solver = bool option -> Proof.context -> (int * thm) list ->
- int list * thm
+ type solver = bool option -> Proof.context ->
+ (int * (int option * thm)) list -> int list * thm
val add_solver: solver_config -> theory -> theory
val solver_name_of: Proof.context -> string
val solver_of: Proof.context -> solver
@@ -37,7 +36,8 @@
val default_max_relevant: Proof.context -> string -> int
(*filter*)
- val smt_filter: bool -> Time.time -> Proof.state -> ('a * thm) list -> int ->
+ val smt_filter: bool -> Time.time -> Proof.state ->
+ ('a * (int option * thm)) list -> int ->
{outcome: SMT_Failure.failure option, used_facts: ('a * thm) list,
run_time_in_msecs: int option}
@@ -59,7 +59,6 @@
type interface = {
class: SMT_Utils.class,
- extra_norm: SMT_Normalize.extra_norm,
translate: SMT_Translate.config }
datatype outcome = Unsat | Sat | Unknown
@@ -176,7 +175,7 @@
Pretty.big_list "functions:" (map p_term (Symtab.dest terms))])) ()
end
-fun invoke translate_config name cmd options irules ctxt =
+fun invoke translate_config name cmd options ithms ctxt =
let
val args = C.solver_options_of ctxt @ options ctxt
val comments = ("solver: " ^ name) ::
@@ -184,7 +183,7 @@
("random seed: " ^ string_of_int (Config.get ctxt C.random_seed)) ::
"arguments:" :: args
in
- irules
+ ithms
|> tap (trace_assms ctxt)
|> SMT_Translate.translate translate_config ctxt comments
||> tap (trace_recon_data ctxt)
@@ -197,22 +196,23 @@
else discharge_definitions (@{thm reflexive} RS thm)
fun set_has_datatypes with_datatypes translate =
- let
- val {prefixes, header, is_fol, has_datatypes, serialize} = translate
- val with_datatypes' = has_datatypes andalso with_datatypes
- val translate' = {prefixes=prefixes, header=header, is_fol=is_fol,
- has_datatypes=with_datatypes', serialize=serialize}
- in (with_datatypes', translate') end
+ let val {prefixes, header, is_fol, has_datatypes, serialize} = translate
+ in
+ {prefixes=prefixes, header=header, is_fol=is_fol,
+ has_datatypes=has_datatypes andalso with_datatypes, serialize=serialize}
+ end
-fun trace_assumptions ctxt irules idxs =
+fun trace_assumptions ctxt iwthms idxs =
let
- val thms = filter (fn i => i >= 0) idxs
- |> map_filter (AList.lookup (op =) irules)
+ val wthms =
+ idxs
+ |> filter (fn i => i >= 0)
+ |> map_filter (AList.lookup (op =) iwthms)
in
- if Config.get ctxt C.trace_used_facts andalso length thms > 0
+ if Config.get ctxt C.trace_used_facts andalso length wthms > 0
then
tracing (Pretty.string_of (Pretty.big_list "SMT used facts:"
- (map (Display.pretty_thm ctxt) thms)))
+ (map (Display.pretty_thm ctxt o snd) wthms)))
else ()
end
@@ -220,7 +220,8 @@
(* registry *)
-type solver = bool option -> Proof.context -> (int * thm) list -> int list * thm
+type solver = bool option -> Proof.context -> (int * (int option * thm)) list ->
+ int list * thm
type solver_info = {
env_var: string,
@@ -231,22 +232,22 @@
(int list * thm) * Proof.context,
default_max_relevant: int }
-fun gen_solver name (info : solver_info) rm ctxt irules =
+fun gen_solver name (info : solver_info) rm ctxt iwthms =
let
val {env_var, is_remote, options, interface, reconstruct, ...} = info
- val {extra_norm, translate, ...} = interface
- val (with_datatypes, translate') =
- set_has_datatypes (Config.get ctxt C.datatypes) translate
+ val {translate, ...} = interface
+ val translate' = set_has_datatypes (Config.get ctxt C.datatypes) translate
val cmd = (rm, env_var, is_remote, name)
in
- (irules, ctxt)
- |-> SMT_Normalize.normalize extra_norm with_datatypes
+ SMT_Normalize.normalize ctxt iwthms
+ |> rpair ctxt
+ |-> SMT_Monomorph.monomorph
|-> invoke translate' name cmd options
|-> reconstruct
|-> (fn (idxs, thm) => fn ctxt' => thm
|> singleton (ProofContext.export ctxt' ctxt)
|> discharge_definitions
- |> tap (fn _ => trace_assumptions ctxt irules idxs)
+ |> tap (fn _ => trace_assumptions ctxt iwthms idxs)
|> pair idxs)
end
@@ -330,38 +331,45 @@
| TVar (_, []) => true
| _ => false))
-fun smt_solver rm ctxt irules =
+fun smt_solver rm ctxt iwthms =
(* without this test, we would run into problems when atomizing the rules: *)
- if exists (has_topsort o Thm.prop_of o snd) irules then
+ if exists (has_topsort o Thm.prop_of o snd o snd) iwthms then
raise SMT_Failure.SMT (SMT_Failure.Other_Failure ("proof state " ^
"contains the universal sort {}"))
- else solver_of ctxt rm ctxt irules
+ else solver_of ctxt rm ctxt iwthms
val cnot = Thm.cterm_of @{theory} @{const Not}
-fun smt_filter run_remote time_limit st xrules i =
+fun mk_result outcome xrules =
+ { outcome = outcome, used_facts = xrules, run_time_in_msecs = NONE }
+
+fun smt_filter run_remote time_limit st xwrules i =
let
- val {facts, goal, ...} = Proof.goal st
val ctxt =
Proof.context_of st
|> Config.put C.oracle false
|> Config.put C.timeout (Time.toReal time_limit)
|> Config.put C.drop_bad_facts true
|> Config.put C.filter_only_facts true
+
+ val {facts, goal, ...} = Proof.goal st
val ({context=ctxt', prems, concl, ...}, _) = Subgoal.focus ctxt i goal
fun negate ct = Thm.dest_comb ct ||> Thm.capply cnot |-> Thm.capply
val cprop = negate (Thm.rhs_of (SMT_Normalize.atomize_conv ctxt' concl))
- val irs = map (pair ~1) (Thm.assume cprop :: prems @ facts)
- val rm = SOME run_remote
+
+ val (xs, wthms) = split_list xwrules
+ val xrules = xs ~~ map snd wthms
in
- (xrules, map snd xrules)
- ||> distinct (op =) o fst o smt_solver rm ctxt' o append irs o map_index I
- |-> map_filter o try o nth
- |> (fn xs => {outcome=NONE, used_facts=if solver_name_of ctxt = "z3" (* FIXME *) then xs
- else xrules, run_time_in_msecs=NONE})
+ wthms
+ |> map_index I
+ |> append (map (pair ~1 o pair NONE) (Thm.assume cprop :: prems @ facts))
+ |> smt_solver (SOME run_remote) ctxt'
+ |> distinct (op =) o fst
+ |> map_filter (try (nth xrules))
+ |> (if solver_name_of ctxt = "z3" (* FIXME *) then I else K xrules)
+ |> mk_result NONE
end
- handle SMT_Failure.SMT fail => {outcome=SOME fail, used_facts=[],
- run_time_in_msecs=NONE}
+ handle SMT_Failure.SMT fail => mk_result (SOME fail) []
(* FIXME: measure runtime *)
@@ -373,18 +381,18 @@
THEN' Tactic.rtac @{thm ccontr}
THEN' SUBPROOF (fn {context=ctxt', prems, ...} =>
let
- fun solve irules = snd (smt_solver NONE ctxt' irules)
+ fun solve iwthms = snd (smt_solver NONE ctxt' iwthms)
val tag = "Solver " ^ C.solver_of ctxt' ^ ": "
val str_of = prefix tag o SMT_Failure.string_of_failure ctxt'
- fun safe_solve irules =
- if pass_exns then SOME (solve irules)
- else (SOME (solve irules)
+ fun safe_solve iwthms =
+ if pass_exns then SOME (solve iwthms)
+ else (SOME (solve iwthms)
handle
SMT_Failure.SMT (fail as SMT_Failure.Counterexample _) =>
(C.verbose_msg ctxt' str_of fail; NONE)
| SMT_Failure.SMT fail => (C.trace_msg ctxt' str_of fail; NONE))
in
- safe_solve (map (pair ~1) (rules @ prems))
+ safe_solve (map (pair ~1 o pair NONE) (rules @ prems))
|> (fn SOME thm => Tactic.rtac thm 1 | _ => Tactical.no_tac)
end) ctxt