re-ordered SMT normalization code (eta-normalization, lambda abstractions and partial functions will be dealt with on the term level);
authorboehmes
Wed Dec 15 08:39:24 2010 +0100 (2010-12-15)
changeset 41126e0bd443c0fdd
parent 41125 4a9eec045f2a
child 41127 2ea84c8535c6
re-ordered SMT normalization code (eta-normalization, lambda abstractions and partial functions will be dealt with on the term level);
added simple trigger inference mechanism;
added syntactic checks for triggers and quantifier weights;
factored out the normalization of special quantifiers (used to be in the eta-normalization part);
normalization now unfolds abs/min/max (not SMT-LIB-specific);
rules for pairs and function update are not anymore added automatically to the problem;
more aggressive rewriting of natural number operations into integer operations (minimizes the number of remaining nat-int coercions);
normalizations are now managed in a class-based manner (similar to built-in symbols)
src/HOL/SMT.thy
src/HOL/Tools/SMT/smt_builtin.ML
src/HOL/Tools/SMT/smt_normalize.ML
src/HOL/Tools/SMT/smt_solver.ML
src/HOL/Tools/SMT/smt_utils.ML
src/HOL/Tools/SMT/smtlib_interface.ML
src/HOL/Tools/SMT/z3_interface.ML
src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML
     1.1 --- a/src/HOL/SMT.thy	Wed Dec 15 08:39:24 2010 +0100
     1.2 +++ b/src/HOL/SMT.thy	Wed Dec 15 08:39:24 2010 +0100
     1.3 @@ -130,18 +130,20 @@
     1.4  definition z3mod :: "int \<Rightarrow> int \<Rightarrow> int" where
     1.5    "z3mod k l = (if 0 \<le> l then k mod l else k mod (-l))"
     1.6  
     1.7 -lemma div_by_z3div: "k div l = (
     1.8 -     if k = 0 \<or> l = 0 then 0
     1.9 -     else if (0 < k \<and> 0 < l) \<or> (k < 0 \<and> 0 < l) then z3div k l
    1.10 -     else z3div (-k) (-l))"
    1.11 -  by (auto simp add: z3div_def)
    1.12 +lemma div_by_z3div:
    1.13 +  "\<forall>k l. k div l = (
    1.14 +    if k = 0 \<or> l = 0 then 0
    1.15 +    else if (0 < k \<and> 0 < l) \<or> (k < 0 \<and> 0 < l) then z3div k l
    1.16 +    else z3div (-k) (-l))"
    1.17 +  by (auto simp add: z3div_def trigger_def)
    1.18  
    1.19 -lemma mod_by_z3mod: "k mod l = (
    1.20 -     if l = 0 then k
    1.21 -     else if k = 0 then 0
    1.22 -     else if (0 < k \<and> 0 < l) \<or> (k < 0 \<and> 0 < l) then z3mod k l
    1.23 -     else - z3mod (-k) (-l))"
    1.24 -  by (auto simp add: z3mod_def)
    1.25 +lemma mod_by_z3mod:
    1.26 +  "\<forall>k l. k mod l = (
    1.27 +    if l = 0 then k
    1.28 +    else if k = 0 then 0
    1.29 +    else if (0 < k \<and> 0 < l) \<or> (k < 0 \<and> 0 < l) then z3mod k l
    1.30 +    else - z3mod (-k) (-l))"
    1.31 +  by (auto simp add: z3mod_def trigger_def)
    1.32  
    1.33  
    1.34  
     2.1 --- a/src/HOL/Tools/SMT/smt_builtin.ML	Wed Dec 15 08:39:24 2010 +0100
     2.2 +++ b/src/HOL/Tools/SMT/smt_builtin.ML	Wed Dec 15 08:39:24 2010 +0100
     2.3 @@ -37,6 +37,7 @@
     2.4    val is_builtin_fun: Proof.context -> string * typ -> term list -> bool
     2.5    val is_builtin_pred: Proof.context -> string * typ -> term list -> bool
     2.6    val is_builtin_conn: Proof.context -> string * typ -> term list -> bool
     2.7 +  val is_builtin_fun_ext: Proof.context -> string * typ -> term list -> bool
     2.8    val is_builtin_ext: Proof.context -> string * typ -> term list -> bool
     2.9  end
    2.10  
    2.11 @@ -78,8 +79,6 @@
    2.12  
    2.13  type ('a, 'b) btab = ('a, 'b) ttab Symtab.table
    2.14  
    2.15 -fun empty_btab () = Symtab.empty
    2.16 -
    2.17  fun insert_btab cs n T f =
    2.18    Symtab.map_default (n, []) (insert_ttab cs T f)
    2.19  
    2.20 @@ -147,26 +146,10 @@
    2.21  
    2.22  type 'a bfun = Proof.context -> typ -> term list -> 'a
    2.23  
    2.24 -fun true3 _ _ _ = true
    2.25 -
    2.26 -fun raw_add_builtin_fun_ext thy cs n =
    2.27 -  insert_btab cs n (Sign.the_const_type thy n) (Ext true3)
    2.28 -
    2.29 -val basic_builtin_fun_names = [
    2.30 -  @{const_name SMT.pat}, @{const_name SMT.nopat},
    2.31 -  @{const_name SMT.trigger}, @{const_name SMT.weight}]
    2.32 -
    2.33 -type builtin_funcs = (bool bfun, (string * term list) option bfun) btab
    2.34 -
    2.35 -fun basic_builtin_funcs () : builtin_funcs =
    2.36 -  empty_btab ()
    2.37 -  |> fold (raw_add_builtin_fun_ext @{theory} U.basicC) basic_builtin_fun_names
    2.38 -       (* FIXME: SMT_Normalize should check that they are properly used *)
    2.39 -
    2.40  structure Builtin_Funcs = Generic_Data
    2.41  (
    2.42 -  type T = builtin_funcs
    2.43 -  val empty = basic_builtin_funcs ()
    2.44 +  type T = (bool bfun, (string * term list) option bfun) btab
    2.45 +  val empty = Symtab.empty
    2.46    val extend = I
    2.47    val merge = merge_btab
    2.48  )
    2.49 @@ -180,7 +163,8 @@
    2.50  fun add_builtin_fun_ext ((n, T), f) =
    2.51    Builtin_Funcs.map (insert_btab U.basicC n T (Ext f))
    2.52  
    2.53 -fun add_builtin_fun_ext' c = add_builtin_fun_ext (c, true3)
    2.54 +fun add_builtin_fun_ext' c =
    2.55 +  add_builtin_fun_ext (c, fn _ => fn _ => fn _ => true)
    2.56  
    2.57  fun add_builtin_fun_ext'' n context =
    2.58    let val thy = Context.theory_of context
     3.1 --- a/src/HOL/Tools/SMT/smt_normalize.ML	Wed Dec 15 08:39:24 2010 +0100
     3.2 +++ b/src/HOL/Tools/SMT/smt_normalize.ML	Wed Dec 15 08:39:24 2010 +0100
     3.3 @@ -1,28 +1,17 @@
     3.4  (*  Title:      HOL/Tools/SMT/smt_normalize.ML
     3.5      Author:     Sascha Boehme, TU Muenchen
     3.6  
     3.7 -Normalization steps on theorems required by SMT solvers:
     3.8 -  * simplify trivial distincts (those with less than three elements),
     3.9 -  * rewrite bool case expressions as if expressions,
    3.10 -  * normalize numerals (e.g. replace negative numerals by negated positive
    3.11 -    numerals),
    3.12 -  * embed natural numbers into integers,
    3.13 -  * add extra rules specifying types and constants which occur frequently,
    3.14 -  * fully translate into object logic, add universal closure,
    3.15 -  * monomorphize (create instances of schematic rules),
    3.16 -  * lift lambda terms,
    3.17 -  * make applications explicit for functions with varying number of arguments.
    3.18 -  * add (hypothetical definitions for) missing datatype selectors,
    3.19 +Normalization steps on theorems required by SMT solvers.
    3.20  *)
    3.21  
    3.22  signature SMT_NORMALIZE =
    3.23  sig
    3.24 -  type extra_norm = bool -> (int * thm) list -> Proof.context ->
    3.25 -    (int * thm) list * Proof.context
    3.26 -  val normalize: extra_norm -> bool -> (int * thm) list -> Proof.context ->
    3.27 -    (int * thm) list * Proof.context
    3.28    val atomize_conv: Proof.context -> conv
    3.29 -  val eta_expand_conv: (Proof.context -> conv) -> Proof.context -> conv
    3.30 +  type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list
    3.31 +  val add_extra_norm: SMT_Utils.class * extra_norm -> Context.generic ->
    3.32 +    Context.generic
    3.33 +  val normalize: Proof.context -> (int * (int option * thm)) list ->
    3.34 +    (int * thm) list
    3.35    val setup: theory -> theory
    3.36  end
    3.37  
    3.38 @@ -32,12 +21,10 @@
    3.39  structure U = SMT_Utils
    3.40  structure B = SMT_Builtin
    3.41  
    3.42 -infix 2 ??
    3.43 -fun (test ?? f) x = if test x then f x else x
    3.44  
    3.45 -
    3.46 +(* general theorem normalizations *)
    3.47  
    3.48 -(* instantiate elimination rules *)
    3.49 +(** instantiate elimination rules **)
    3.50   
    3.51  local
    3.52    val (cpfalse, cfalse) = `U.mk_cprop (Thm.cterm_of @{theory} @{const False})
    3.53 @@ -56,281 +43,18 @@
    3.54  end
    3.55  
    3.56  
    3.57 -
    3.58 -(* simplification of trivial distincts (distinct should have at least
    3.59 -   three elements in the argument list) *)
    3.60 -
    3.61 -local
    3.62 -  fun is_trivial_distinct (Const (@{const_name distinct}, _) $ t) =
    3.63 -        (case try HOLogic.dest_list t of
    3.64 -          SOME [] => true
    3.65 -        | SOME [_] => true
    3.66 -        | _ => false)
    3.67 -    | is_trivial_distinct _ = false
    3.68 -
    3.69 -  val thms = map mk_meta_eq @{lemma
    3.70 -    "distinct [] = True"
    3.71 -    "distinct [x] = True"
    3.72 -    "distinct [x, y] = (x ~= y)"
    3.73 -    by simp_all}
    3.74 -  fun distinct_conv _ =
    3.75 -    U.if_true_conv is_trivial_distinct (Conv.rewrs_conv thms)
    3.76 -in
    3.77 -fun trivial_distinct ctxt =
    3.78 -  map (apsnd ((Term.exists_subterm is_trivial_distinct o Thm.prop_of) ??
    3.79 -    Conv.fconv_rule (Conv.top_conv distinct_conv ctxt)))
    3.80 -end
    3.81 -
    3.82 -
    3.83 -
    3.84 -(* rewrite bool case expressions as if expressions *)
    3.85 -
    3.86 -local
    3.87 -  val is_bool_case = (fn
    3.88 -      Const (@{const_name "bool.bool_case"}, _) $ _ $ _ $ _ => true
    3.89 -    | _ => false)
    3.90 +(** normalize definitions **)
    3.91  
    3.92 -  val thm = mk_meta_eq @{lemma
    3.93 -    "(case P of True => x | False => y) = (if P then x else y)" by simp}
    3.94 -  val unfold_conv = U.if_true_conv is_bool_case (Conv.rewr_conv thm)
    3.95 -in
    3.96 -fun rewrite_bool_cases ctxt =
    3.97 -  map (apsnd ((Term.exists_subterm is_bool_case o Thm.prop_of) ??
    3.98 -    Conv.fconv_rule (Conv.top_conv (K unfold_conv) ctxt)))
    3.99 -
   3.100 -val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"}
   3.101 -
   3.102 -end
   3.103 -
   3.104 -
   3.105 -
   3.106 -(* normalization of numerals: rewriting of negative integer numerals into
   3.107 -   positive numerals, Numeral0 into 0, Numeral1 into 1 *)
   3.108 -
   3.109 -local
   3.110 -  fun is_number_sort ctxt T =
   3.111 -    Sign.of_sort (ProofContext.theory_of ctxt) (T, @{sort number_ring})
   3.112 -
   3.113 -  fun is_strange_number ctxt (t as Const (@{const_name number_of}, _) $ _) =
   3.114 -        (case try HOLogic.dest_number t of
   3.115 -          SOME (T, i) => is_number_sort ctxt T andalso i < 2
   3.116 -        | NONE => false)
   3.117 -    | is_strange_number _ _ = false
   3.118 -
   3.119 -  val pos_numeral_ss = HOL_ss
   3.120 -    addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}]
   3.121 -    addsimps [@{thm Int.number_of_Pls}, @{thm Int.numeral_1_eq_1}]
   3.122 -    addsimps @{thms Int.pred_bin_simps}
   3.123 -    addsimps @{thms Int.normalize_bin_simps}
   3.124 -    addsimps @{lemma
   3.125 -      "Int.Min = - Int.Bit1 Int.Pls"
   3.126 -      "Int.Bit0 (- Int.Pls) = - Int.Pls"
   3.127 -      "Int.Bit0 (- k) = - Int.Bit0 k"
   3.128 -      "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)"
   3.129 -      by simp_all (simp add: pred_def)}
   3.130 -
   3.131 -  fun pos_conv ctxt = U.if_conv (is_strange_number ctxt)
   3.132 -    (Simplifier.rewrite (Simplifier.context ctxt pos_numeral_ss))
   3.133 -    Conv.no_conv
   3.134 -in
   3.135 -fun normalize_numerals ctxt =
   3.136 -  map (apsnd ((Term.exists_subterm (is_strange_number ctxt) o Thm.prop_of) ??
   3.137 -    Conv.fconv_rule (Conv.top_sweep_conv pos_conv ctxt)))
   3.138 -end
   3.139 -
   3.140 +fun norm_def thm =
   3.141 +  (case Thm.prop_of thm of
   3.142 +    @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ _ $ Abs _) =>
   3.143 +      norm_def (thm RS @{thm fun_cong})
   3.144 +  | Const (@{const_name "=="}, _) $ _ $ Abs _ =>
   3.145 +      norm_def (thm RS @{thm meta_eq_to_obj_eq})
   3.146 +  | _ => thm)
   3.147  
   3.148  
   3.149 -(* embedding of standard natural number operations into integer operations *)
   3.150 -
   3.151 -local
   3.152 -  val nat_embedding = map (pair ~1) @{lemma
   3.153 -    "nat (int n) = n"
   3.154 -    "i >= 0 --> int (nat i) = i"
   3.155 -    "i < 0 --> int (nat i) = 0"
   3.156 -    by simp_all}
   3.157 -
   3.158 -  val nat_rewriting = @{lemma
   3.159 -    "0 = nat 0"
   3.160 -    "1 = nat 1"
   3.161 -    "(number_of :: int => nat) = (%i. nat (number_of i))"
   3.162 -    "int (nat 0) = 0"
   3.163 -    "int (nat 1) = 1"
   3.164 -    "op < = (%a b. int a < int b)"
   3.165 -    "op <= = (%a b. int a <= int b)"
   3.166 -    "Suc = (%a. nat (int a + 1))"
   3.167 -    "op + = (%a b. nat (int a + int b))"
   3.168 -    "op - = (%a b. nat (int a - int b))"
   3.169 -    "op * = (%a b. nat (int a * int b))"
   3.170 -    "op div = (%a b. nat (int a div int b))"
   3.171 -    "op mod = (%a b. nat (int a mod int b))"
   3.172 -    "min = (%a b. nat (min (int a) (int b)))"
   3.173 -    "max = (%a b. nat (max (int a) (int b)))"
   3.174 -    "int (nat (int a + int b)) = int a + int b"
   3.175 -    "int (nat (int a + 1)) = int a + 1"  (* special rule due to Suc above *)
   3.176 -    "int (nat (int a * int b)) = int a * int b"
   3.177 -    "int (nat (int a div int b)) = int a div int b"
   3.178 -    "int (nat (int a mod int b)) = int a mod int b"
   3.179 -    "int (nat (min (int a) (int b))) = min (int a) (int b)"
   3.180 -    "int (nat (max (int a) (int b))) = max (int a) (int b)"
   3.181 -    by (auto intro!: ext simp add: nat_mult_distrib nat_div_distrib
   3.182 -      nat_mod_distrib int_mult[symmetric] zdiv_int[symmetric]
   3.183 -      zmod_int[symmetric])}
   3.184 -
   3.185 -  fun on_positive num f x = 
   3.186 -    (case try HOLogic.dest_number (Thm.term_of num) of
   3.187 -      SOME (_, i) => if i >= 0 then SOME (f x) else NONE
   3.188 -    | NONE => NONE)
   3.189 -
   3.190 -  val cancel_int_nat_ss = HOL_ss
   3.191 -    addsimps [@{thm Nat_Numeral.nat_number_of}]
   3.192 -    addsimps [@{thm Nat_Numeral.int_nat_number_of}]
   3.193 -    addsimps @{thms neg_simps}
   3.194 -
   3.195 -  val int_eq = Thm.cterm_of @{theory} @{const "==" (int)}
   3.196 -
   3.197 -  fun cancel_int_nat_simproc _ ss ct = 
   3.198 -    let
   3.199 -      val num = Thm.dest_arg (Thm.dest_arg ct)
   3.200 -      val goal = Thm.mk_binop int_eq ct num
   3.201 -      val simpset = Simplifier.inherit_context ss cancel_int_nat_ss
   3.202 -      fun tac _ = Simplifier.simp_tac simpset 1
   3.203 -    in on_positive num (Goal.prove_internal [] goal) tac end
   3.204 -
   3.205 -  val nat_ss = HOL_ss
   3.206 -    addsimps nat_rewriting
   3.207 -    addsimprocs [
   3.208 -      Simplifier.make_simproc {
   3.209 -        name = "cancel_int_nat_num", lhss = [@{cpat "int (nat _)"}],
   3.210 -        proc = cancel_int_nat_simproc, identifier = [] }]
   3.211 -
   3.212 -  fun conv ctxt = Simplifier.rewrite (Simplifier.context ctxt nat_ss)
   3.213 -
   3.214 -  val uses_nat_type = Term.exists_type (Term.exists_subtype (equal @{typ nat}))
   3.215 -  val uses_nat_int = Term.exists_subterm (member (op aconv)
   3.216 -    [@{const of_nat (int)}, @{const nat}])
   3.217 -
   3.218 -  val nat_ops = [
   3.219 -    @{const less (nat)}, @{const less_eq (nat)},
   3.220 -    @{const Suc}, @{const plus (nat)}, @{const minus (nat)},
   3.221 -    @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}]
   3.222 -  val nat_ops' = @{const of_nat (int)} :: @{const nat} :: nat_ops
   3.223 -in
   3.224 -fun nat_as_int ctxt =
   3.225 -  map (apsnd ((uses_nat_type o Thm.prop_of) ?? Conv.fconv_rule (conv ctxt))) #>
   3.226 -  exists (uses_nat_int o Thm.prop_of o snd) ?? append nat_embedding
   3.227 -
   3.228 -val setup_nat_as_int =
   3.229 -  B.add_builtin_typ_ext (@{typ nat}, K true) #>
   3.230 -  fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops'
   3.231 -end
   3.232 -
   3.233 -
   3.234 -
   3.235 -(* further normalizations: beta/eta, universal closure, atomize *)
   3.236 -
   3.237 -val eta_expand_eq = @{lemma "f == (%x. f x)" by (rule reflexive)}
   3.238 -
   3.239 -fun eta_expand_conv cv ctxt =
   3.240 -  Conv.rewr_conv eta_expand_eq then_conv Conv.abs_conv (cv o snd) ctxt
   3.241 -
   3.242 -local
   3.243 -  val eta_conv = eta_expand_conv
   3.244 -
   3.245 -  fun args_conv cv ct =
   3.246 -    (case Thm.term_of ct of
   3.247 -      _ $ _ => Conv.combination_conv (args_conv cv) cv
   3.248 -    | _ => Conv.all_conv) ct
   3.249 -
   3.250 -  fun eta_args_conv cv 0 = args_conv o cv
   3.251 -    | eta_args_conv cv i = eta_conv (eta_args_conv cv (i-1))
   3.252 -
   3.253 -  fun keep_conv ctxt = Conv.binder_conv (norm_conv o snd) ctxt
   3.254 -  and eta_binder_conv ctxt = Conv.arg_conv (eta_conv norm_conv ctxt)
   3.255 -  and keep_let_conv ctxt = Conv.combination_conv
   3.256 -    (Conv.arg_conv (norm_conv ctxt)) (Conv.abs_conv (norm_conv o snd) ctxt)
   3.257 -  and unfold_let_conv ctxt = Conv.combination_conv
   3.258 -    (Conv.arg_conv (norm_conv ctxt)) (eta_conv norm_conv ctxt)
   3.259 -  and unfold_conv thm ctxt = Conv.rewr_conv thm then_conv keep_conv ctxt
   3.260 -  and unfold_ex1_conv ctxt = unfold_conv @{thm Ex1_def} ctxt
   3.261 -  and unfold_ball_conv ctxt = unfold_conv (mk_meta_eq @{thm Ball_def}) ctxt
   3.262 -  and unfold_bex_conv ctxt = unfold_conv (mk_meta_eq @{thm Bex_def}) ctxt
   3.263 -  and norm_conv ctxt ct =
   3.264 -    (case Thm.term_of ct of
   3.265 -      Const (@{const_name All}, _) $ Abs _ => keep_conv
   3.266 -    | Const (@{const_name All}, _) $ _ => eta_binder_conv
   3.267 -    | Const (@{const_name All}, _) => eta_conv eta_binder_conv
   3.268 -    | Const (@{const_name Ex}, _) $ Abs _ => keep_conv
   3.269 -    | Const (@{const_name Ex}, _) $ _ => eta_binder_conv
   3.270 -    | Const (@{const_name Ex}, _) => eta_conv eta_binder_conv
   3.271 -    | Const (@{const_name Let}, _) $ _ $ Abs _ => keep_let_conv
   3.272 -    | Const (@{const_name Let}, _) $ _ $ _ => unfold_let_conv
   3.273 -    | Const (@{const_name Let}, _) $ _ => eta_conv unfold_let_conv
   3.274 -    | Const (@{const_name Let}, _) => eta_conv (eta_conv unfold_let_conv)
   3.275 -    | Const (@{const_name Ex1}, _) $ _ => unfold_ex1_conv
   3.276 -    | Const (@{const_name Ex1}, _) => eta_conv unfold_ex1_conv 
   3.277 -    | Const (@{const_name Ball}, _) $ _ $ _ => unfold_ball_conv
   3.278 -    | Const (@{const_name Ball}, _) $ _ => eta_conv unfold_ball_conv
   3.279 -    | Const (@{const_name Ball}, _) => eta_conv (eta_conv unfold_ball_conv)
   3.280 -    | Const (@{const_name Bex}, _) $ _ $ _ => unfold_bex_conv
   3.281 -    | Const (@{const_name Bex}, _) $ _ => eta_conv unfold_bex_conv
   3.282 -    | Const (@{const_name Bex}, _) => eta_conv (eta_conv unfold_bex_conv)
   3.283 -    | Abs _ => Conv.abs_conv (norm_conv o snd)
   3.284 -    | _ =>
   3.285 -        (case Term.strip_comb (Thm.term_of ct) of
   3.286 -          (Const (c as (_, T)), ts) =>
   3.287 -            if SMT_Builtin.is_builtin_fun ctxt c ts
   3.288 -            then eta_args_conv norm_conv
   3.289 -              (length (Term.binder_types T) - length ts)
   3.290 -            else args_conv o norm_conv
   3.291 -        | _ => args_conv o norm_conv)) ctxt ct
   3.292 -
   3.293 -  fun is_normed ctxt t =
   3.294 -    (case t of
   3.295 -      Const (@{const_name All}, _) $ Abs (_, _, u) => is_normed ctxt u
   3.296 -    | Const (@{const_name All}, _) $ _ => false
   3.297 -    | Const (@{const_name All}, _) => false
   3.298 -    | Const (@{const_name Ex}, _) $ Abs (_, _, u) => is_normed ctxt u
   3.299 -    | Const (@{const_name Ex}, _) $ _ => false
   3.300 -    | Const (@{const_name Ex}, _) => false
   3.301 -    | Const (@{const_name Let}, _) $ u1 $ Abs (_, _, u2) =>
   3.302 -        is_normed ctxt u1 andalso is_normed ctxt u2
   3.303 -    | Const (@{const_name Let}, _) $ _ $ _ => false
   3.304 -    | Const (@{const_name Let}, _) $ _ => false
   3.305 -    | Const (@{const_name Let}, _) => false
   3.306 -    | Const (@{const_name Ex1}, _) $ _ => false
   3.307 -    | Const (@{const_name Ex1}, _) => false
   3.308 -    | Const (@{const_name Ball}, _) $ _ $ _ => false
   3.309 -    | Const (@{const_name Ball}, _) $ _ => false
   3.310 -    | Const (@{const_name Ball}, _) => false
   3.311 -    | Const (@{const_name Bex}, _) $ _ $ _ => false
   3.312 -    | Const (@{const_name Bex}, _) $ _ => false
   3.313 -    | Const (@{const_name Bex}, _) => false
   3.314 -    | Abs (_, _, u) => is_normed ctxt u
   3.315 -    | _ =>
   3.316 -        (case Term.strip_comb t of
   3.317 -          (Const (c as (_, T)), ts) =>
   3.318 -            if SMT_Builtin.is_builtin_fun ctxt c ts
   3.319 -            then length (Term.binder_types T) = length ts andalso
   3.320 -              forall (is_normed ctxt) ts
   3.321 -            else forall (is_normed ctxt) ts
   3.322 -        | (_, ts) => forall (is_normed ctxt) ts))
   3.323 -in
   3.324 -fun norm_binder_conv ctxt =
   3.325 -  U.if_conv (is_normed ctxt) Conv.all_conv (norm_conv ctxt)
   3.326 -
   3.327 -val setup_unfolded_quants =
   3.328 -  fold B.add_builtin_fun_ext'' [@{const_name Ball}, @{const_name Bex},
   3.329 -    @{const_name Ex1}]
   3.330 -
   3.331 -end
   3.332 -
   3.333 -fun norm_def ctxt thm =
   3.334 -  (case Thm.prop_of thm of
   3.335 -    @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ _ $ Abs _) =>
   3.336 -      norm_def ctxt (thm RS @{thm fun_cong})
   3.337 -  | Const (@{const_name "=="}, _) $ _ $ Abs _ =>
   3.338 -      norm_def ctxt (thm RS @{thm meta_eq_to_obj_eq})
   3.339 -  | _ => thm)
   3.340 +(** atomization **)
   3.341  
   3.342  fun atomize_conv ctxt ct =
   3.343    (case Thm.term_of ct of
   3.344 @@ -349,243 +73,543 @@
   3.345    fold B.add_builtin_fun_ext'' [@{const_name "==>"}, @{const_name "=="},
   3.346      @{const_name all}, @{const_name Trueprop}]
   3.347  
   3.348 -fun normalize_rule ctxt =
   3.349 -  Conv.fconv_rule (
   3.350 -    (* reduce lambda abstractions, except at known binders: *)
   3.351 -    Thm.beta_conversion true then_conv
   3.352 -    Thm.eta_conversion then_conv
   3.353 -    norm_binder_conv ctxt) #>
   3.354 -  norm_def ctxt #>
   3.355 -  Drule.forall_intr_vars #>
   3.356 -  Conv.fconv_rule (atomize_conv ctxt)
   3.357  
   3.358 -
   3.359 -
   3.360 -(* lift lambda terms into additional rules *)
   3.361 +(** unfold special quantifiers **)
   3.362  
   3.363  local
   3.364 -  fun used_vars cvs ct =
   3.365 -    let
   3.366 -      val lookup = AList.lookup (op aconv) (map (` Thm.term_of) cvs)
   3.367 -      val add = (fn SOME ct => insert (op aconvc) ct | _ => I)
   3.368 -    in Term.fold_aterms (add o lookup) (Thm.term_of ct) [] end
   3.369 -
   3.370 -  fun apply cv thm = 
   3.371 -    let val thm' = Thm.combination thm (Thm.reflexive cv)
   3.372 -    in Thm.transitive thm' (Thm.beta_conversion false (Thm.rhs_of thm')) end
   3.373 -  fun apply_def cvs eq = Thm.symmetric (fold apply cvs eq)
   3.374 +  val ex1_def = mk_meta_eq @{lemma
   3.375 +    "Ex1 = (%P. EX x. P x & (ALL y. P y --> y = x))"
   3.376 +    by (rule ext) (simp only: Ex1_def)}
   3.377  
   3.378 -  fun replace_lambda cvs ct (cx as (ctxt, defs)) =
   3.379 -    let
   3.380 -      val cvs' = used_vars cvs ct
   3.381 -      val ct' = fold_rev Thm.cabs cvs' ct
   3.382 -    in
   3.383 -      (case Termtab.lookup defs (Thm.term_of ct') of
   3.384 -        SOME eq => (apply_def cvs' eq, cx)
   3.385 -      | NONE =>
   3.386 -          let
   3.387 -            val {T, ...} = Thm.rep_cterm ct' and n = Name.uu
   3.388 -            val (n', ctxt') = yield_singleton Variable.variant_fixes n ctxt
   3.389 -            val cu = U.mk_cequals (U.certify ctxt (Free (n', T))) ct'
   3.390 -            val (eq, ctxt'') = yield_singleton Assumption.add_assumes cu ctxt'
   3.391 -            val defs' = Termtab.update (Thm.term_of ct', eq) defs
   3.392 -          in (apply_def cvs' eq, (ctxt'', defs')) end)
   3.393 -    end
   3.394 +  val ball_def = mk_meta_eq @{lemma "Ball = (%A P. ALL x. x : A --> P x)"
   3.395 +    by (rule ext)+ (rule Ball_def)}
   3.396 +
   3.397 +  val bex_def = mk_meta_eq @{lemma "Bex = (%A P. EX x. x : A & P x)"
   3.398 +    by (rule ext)+ (rule Bex_def)}
   3.399  
   3.400 -  fun none ct cx = (Thm.reflexive ct, cx)
   3.401 -  fun in_comb f g ct cx =
   3.402 -    let val (cu1, cu2) = Thm.dest_comb ct
   3.403 -    in cx |> f cu1 ||>> g cu2 |>> uncurry Thm.combination end
   3.404 -  fun in_arg f = in_comb none f
   3.405 -  fun in_abs f cvs ct (ctxt, defs) =
   3.406 -    let
   3.407 -      val (n, ctxt') = yield_singleton Variable.variant_fixes Name.uu ctxt
   3.408 -      val (cv, cu) = Thm.dest_abs (SOME n) ct
   3.409 -    in  (ctxt', defs) |> f (cv :: cvs) cu |>> Thm.abstract_rule n cv end
   3.410 -
   3.411 -  fun traverse cvs ct =
   3.412 -    (case Thm.term_of ct of
   3.413 -      Const (@{const_name All}, _) $ Abs _ => in_arg (in_abs traverse cvs)
   3.414 -    | Const (@{const_name Ex}, _) $ Abs _ => in_arg (in_abs traverse cvs)
   3.415 -    | Const (@{const_name Let}, _) $ _ $ Abs _ =>
   3.416 -        in_comb (in_arg (traverse cvs)) (in_abs traverse cvs)
   3.417 -    | Abs _ => at_lambda cvs
   3.418 -    | _ $ _ => in_comb (traverse cvs) (traverse cvs)
   3.419 -    | _ => none) ct
   3.420 +  val special_quants = [(@{const_name Ex1}, ex1_def),
   3.421 +    (@{const_name Ball}, ball_def), (@{const_name Bex}, bex_def)]
   3.422 +  
   3.423 +  fun special_quant (Const (n, _)) = AList.lookup (op =) special_quants n
   3.424 +    | special_quant _ = NONE
   3.425  
   3.426 -  and at_lambda cvs ct =
   3.427 -    in_abs traverse cvs ct #-> (fn thm =>
   3.428 -    replace_lambda cvs (Thm.rhs_of thm) #>> Thm.transitive thm)
   3.429 +  fun special_quant_conv _ ct =
   3.430 +    (case special_quant (Thm.term_of ct) of
   3.431 +      SOME thm => Conv.rewr_conv thm
   3.432 +    | NONE => Conv.all_conv) ct
   3.433 +in
   3.434  
   3.435 -  fun has_free_lambdas t =
   3.436 -    (case t of
   3.437 -      Const (@{const_name All}, _) $ Abs (_, _, u) => has_free_lambdas u
   3.438 -    | Const (@{const_name Ex}, _) $ Abs (_, _, u) => has_free_lambdas u
   3.439 -    | Const (@{const_name Let}, _) $ u1 $ Abs (_, _, u2) =>
   3.440 -        has_free_lambdas u1 orelse has_free_lambdas u2
   3.441 -    | Abs _ => true
   3.442 -    | u1 $ u2 => has_free_lambdas u1 orelse has_free_lambdas u2
   3.443 -    | _ => false)
   3.444 +fun unfold_special_quants_conv ctxt =
   3.445 +  U.if_exists_conv (is_some o special_quant)
   3.446 +    (Conv.top_conv special_quant_conv ctxt)
   3.447  
   3.448 -  fun lift_lm f thm cx =
   3.449 -    if not (has_free_lambdas (Thm.prop_of thm)) then (thm, cx)
   3.450 -    else cx |> f (Thm.cprop_of thm) |>> (fn thm' => Thm.equal_elim thm' thm)
   3.451 -in
   3.452 -fun lift_lambdas irules ctxt =
   3.453 -  let
   3.454 -    val cx = (ctxt, Termtab.empty)
   3.455 -    val (idxs, thms) = split_list irules
   3.456 -    val (thms', (ctxt', defs)) = fold_map (lift_lm (traverse [])) thms cx
   3.457 -    val eqs = Termtab.fold (cons o normalize_rule ctxt' o snd) defs []
   3.458 -  in (map (pair ~1) eqs @ (idxs ~~ thms'), ctxt') end
   3.459 +val setup_unfolded_quants = fold (B.add_builtin_fun_ext'' o fst) special_quants
   3.460 +
   3.461  end
   3.462  
   3.463  
   3.464 -
   3.465 -(* make application explicit for functions with varying number of arguments *)
   3.466 +(** trigger inference **)
   3.467  
   3.468  local
   3.469 -  val const = prefix "c" and free = prefix "f"
   3.470 -  fun min i (e as (_, j)) = if i <> j then (true, Int.min (i, j)) else e
   3.471 -  fun add t i = Symtab.map_default (t, (false, i)) (min i)
   3.472 -  fun traverse t =
   3.473 +  (*** check trigger syntax ***)
   3.474 +
   3.475 +  fun dest_trigger (Const (@{const_name pat}, _) $ _) = SOME true
   3.476 +    | dest_trigger (Const (@{const_name nopat}, _) $ _) = SOME false
   3.477 +    | dest_trigger _ = NONE
   3.478 +
   3.479 +  fun eq_list [] = false
   3.480 +    | eq_list (b :: bs) = forall (equal b) bs
   3.481 +
   3.482 +  fun proper_trigger t =
   3.483 +    t
   3.484 +    |> these o try HOLogic.dest_list
   3.485 +    |> map (map_filter dest_trigger o these o try HOLogic.dest_list)
   3.486 +    |> (fn [] => false | bss => forall eq_list bss)
   3.487 +
   3.488 +  fun proper_quant inside f t =
   3.489 +    (case t of
   3.490 +      Const (@{const_name All}, _) $ Abs (_, _, u) => proper_quant true f u
   3.491 +    | Const (@{const_name Ex}, _) $ Abs (_, _, u) => proper_quant true f u
   3.492 +    | @{const trigger} $ p $ u =>
   3.493 +        (if inside then f p else false) andalso proper_quant false f u
   3.494 +    | Abs (_, _, u) => proper_quant false f u
   3.495 +    | u1 $ u2 => proper_quant false f u1 andalso proper_quant false f u2
   3.496 +    | _ => true)
   3.497 +
   3.498 +  fun check_trigger_error ctxt t =
   3.499 +    error ("SMT triggers must only occur under quantifier and multipatterns " ^
   3.500 +      "must have the same kind: " ^ Syntax.string_of_term ctxt t)
   3.501 +
   3.502 +  fun check_trigger_conv ctxt ct =
   3.503 +    if proper_quant false proper_trigger (Thm.term_of ct) then Conv.all_conv ct
   3.504 +    else check_trigger_error ctxt (Thm.term_of ct)
   3.505 +
   3.506 +
   3.507 +  (*** infer simple triggers ***)
   3.508 +
   3.509 +  fun dest_cond_eq ct =
   3.510 +    (case Thm.term_of ct of
   3.511 +      Const (@{const_name HOL.eq}, _) $ _ $ _ => Thm.dest_binop ct
   3.512 +    | @{const HOL.implies} $ _ $ _ => dest_cond_eq (Thm.dest_arg ct)
   3.513 +    | _ => raise CTERM ("no equation", [ct]))
   3.514 +
   3.515 +  fun get_constrs thy (Type (n, _)) = these (Datatype.get_constrs thy n)
   3.516 +    | get_constrs _ _ = []
   3.517 +
   3.518 +  fun is_constr thy (n, T) =
   3.519 +    let fun match (m, U) = m = n andalso Sign.typ_instance thy (T, U)
   3.520 +    in can (the o find_first match o get_constrs thy o Term.body_type) T end
   3.521 +
   3.522 +  fun is_constr_pat thy t =
   3.523 +    (case Term.strip_comb t of
   3.524 +      (Free _, []) => true
   3.525 +    | (Const c, ts) => is_constr thy c andalso forall (is_constr_pat thy) ts
   3.526 +    | _ => false)
   3.527 +
   3.528 +  fun is_simp_lhs ctxt t =
   3.529      (case Term.strip_comb t of
   3.530 -      (Const (n, _), ts) => add (const n) (length ts) #> fold traverse ts 
   3.531 -    | (Free (n, _), ts) => add (free n) (length ts) #> fold traverse ts
   3.532 -    | (Abs (_, _, u), ts) => fold traverse (u :: ts)
   3.533 -    | (_, ts) => fold traverse ts)
   3.534 -  fun prune tab = Symtab.fold (fn (n, (true, i)) =>
   3.535 -    Symtab.update (n, i) | _ => I) tab Symtab.empty
   3.536 +      (Const c, ts as _ :: _) =>
   3.537 +        not (B.is_builtin_fun_ext ctxt c ts) andalso
   3.538 +        forall (is_constr_pat (ProofContext.theory_of ctxt)) ts
   3.539 +    | _ => false)
   3.540 +
   3.541 +  fun has_all_vars vs t =
   3.542 +    subset (op aconv) (vs, map Free (Term.add_frees t []))
   3.543 +
   3.544 +  fun minimal_pats vs ct =
   3.545 +    if has_all_vars vs (Thm.term_of ct) then
   3.546 +      (case Thm.term_of ct of
   3.547 +        _ $ _ =>
   3.548 +          (case pairself (minimal_pats vs) (Thm.dest_comb ct) of
   3.549 +            ([], []) => [[ct]]
   3.550 +          | (ctss, ctss') => union (eq_set (op aconvc)) ctss ctss')
   3.551 +      | _ => [[ct]])
   3.552 +    else []
   3.553 +
   3.554 +  fun proper_mpat _ _ _ [] = false
   3.555 +    | proper_mpat thy gen u cts =
   3.556 +        let
   3.557 +          val tps = (op ~~) (`gen (map Thm.term_of cts))
   3.558 +          fun some_match u = tps |> exists (fn (t', t) =>
   3.559 +            Pattern.matches thy (t', u) andalso not (t aconv u))
   3.560 +        in not (Term.exists_subterm some_match u) end
   3.561 +
   3.562 +  val pat = U.mk_const_pat @{theory} @{const_name SMT.pat} U.destT1
   3.563 +  fun mk_pat ct = Thm.capply (U.instT' ct pat) ct
   3.564  
   3.565 -  fun binop_conv cv1 cv2 = Conv.combination_conv (Conv.arg_conv cv1) cv2
   3.566 -  fun nary_conv conv1 conv2 ct =
   3.567 -    (Conv.combination_conv (nary_conv conv1 conv2) conv2 else_conv conv1) ct
   3.568 -  fun abs_conv conv tb = Conv.abs_conv (fn (cv, cx) =>
   3.569 -    let val n = fst (Term.dest_Free (Thm.term_of cv))
   3.570 -    in conv (Symtab.update (free n, 0) tb) cx end)
   3.571 -  val fun_app_rule = @{lemma "f x == fun_app f x" by (simp add: fun_app_def)}
   3.572 +  fun mk_clist T = pairself (Thm.cterm_of @{theory})
   3.573 +    (HOLogic.cons_const T, HOLogic.nil_const T)
   3.574 +  fun mk_list (ccons, cnil) f cts = fold_rev (Thm.mk_binop ccons o f) cts cnil
   3.575 +  val mk_pat_list = mk_list (mk_clist @{typ SMT.pattern})
   3.576 +  val mk_mpat_list = mk_list (mk_clist @{typ "SMT.pattern list"})  
   3.577 +  fun mk_trigger ctss = mk_mpat_list (mk_pat_list mk_pat) ctss
   3.578 +
   3.579 +  val trigger_eq =
   3.580 +    mk_meta_eq @{lemma "p = SMT.trigger t p" by (simp add: trigger_def)}
   3.581 +
   3.582 +  fun insert_trigger_conv [] ct = Conv.all_conv ct
   3.583 +    | insert_trigger_conv ctss ct =
   3.584 +        let val (ctr, cp) = Thm.dest_binop (Thm.rhs_of trigger_eq) ||> rpair ct
   3.585 +        in Thm.instantiate ([], [cp, (ctr, mk_trigger ctss)]) trigger_eq end
   3.586 +
   3.587 +  fun infer_trigger_eq_conv outer_ctxt (ctxt, cvs) ct =
   3.588 +    let
   3.589 +      val (lhs, rhs) = dest_cond_eq ct
   3.590 +
   3.591 +      val vs = map Thm.term_of cvs
   3.592 +      val thy = ProofContext.theory_of ctxt
   3.593 +
   3.594 +      fun get_mpats ct =
   3.595 +        if is_simp_lhs ctxt (Thm.term_of ct) then minimal_pats vs ct
   3.596 +        else []
   3.597 +      val gen = Variable.export_terms ctxt outer_ctxt
   3.598 +      val filter_mpats = filter (proper_mpat thy gen (Thm.term_of rhs))
   3.599 +
   3.600 +    in insert_trigger_conv (filter_mpats (get_mpats lhs)) ct end
   3.601 +
   3.602 +  fun try_trigger_conv cv ct =
   3.603 +    if proper_quant false (K false) (Thm.term_of ct) then Conv.all_conv ct
   3.604 +    else Conv.try_conv cv ct
   3.605 +
   3.606 +  fun infer_trigger_conv ctxt =
   3.607 +    if Config.get ctxt SMT_Config.infer_triggers then
   3.608 +      try_trigger_conv (U.under_quant_conv (infer_trigger_eq_conv ctxt) ctxt)
   3.609 +    else Conv.all_conv
   3.610  in
   3.611 -fun explicit_application ctxt irules =
   3.612 -  let
   3.613 -    fun sub_conv tb ctxt ct =
   3.614 -      (case Term.strip_comb (Thm.term_of ct) of
   3.615 -        (Const (n, _), ts) => app_conv tb (const n) (length ts) ctxt
   3.616 -      | (Free (n, _), ts) => app_conv tb (free n) (length ts) ctxt
   3.617 -      | (Abs _, _) => nary_conv (abs_conv sub_conv tb ctxt) (sub_conv tb ctxt)
   3.618 -      | (_, _) => nary_conv Conv.all_conv (sub_conv tb ctxt)) ct
   3.619 -    and app_conv tb n i ctxt =
   3.620 -      (case Symtab.lookup tb n of
   3.621 -        NONE => nary_conv Conv.all_conv (sub_conv tb ctxt)
   3.622 -      | SOME j => fun_app_conv tb ctxt (i - j))
   3.623 -    and fun_app_conv tb ctxt i ct = (
   3.624 -      if i = 0 then nary_conv Conv.all_conv (sub_conv tb ctxt)
   3.625 -      else
   3.626 -        Conv.rewr_conv fun_app_rule then_conv
   3.627 -        binop_conv (fun_app_conv tb ctxt (i-1)) (sub_conv tb ctxt)) ct
   3.628 +
   3.629 +fun trigger_conv ctxt =
   3.630 +  U.prop_conv (check_trigger_conv ctxt then_conv infer_trigger_conv ctxt)
   3.631  
   3.632 -    fun needs_exp_app tab = Term.exists_subterm (fn
   3.633 -        Bound _ $ _ => true
   3.634 -      | Const (n, _) => Symtab.defined tab (const n)
   3.635 -      | Free (n, _) => Symtab.defined tab (free n)
   3.636 -      | _ => false)
   3.637 +val setup_trigger = fold B.add_builtin_fun_ext''
   3.638 +  [@{const_name SMT.pat}, @{const_name SMT.nopat}, @{const_name SMT.trigger}]
   3.639  
   3.640 -    fun rewrite tab ctxt thm =
   3.641 -      if not (needs_exp_app tab (Thm.prop_of thm)) then thm
   3.642 -      else Conv.fconv_rule (sub_conv tab ctxt) thm
   3.643 -
   3.644 -    val tab = prune (fold (traverse o Thm.prop_of o snd) irules Symtab.empty)
   3.645 -  in map (apsnd (rewrite tab ctxt)) irules end
   3.646  end
   3.647  
   3.648  
   3.649 -
   3.650 -(* add missing datatype selectors via hypothetical definitions *)
   3.651 +(** adding quantifier weights **)
   3.652  
   3.653  local
   3.654 -  val add = (fn Type (n, _) => Symtab.update (n, ()) | _ => I)
   3.655 +  (*** check weight syntax ***)
   3.656 +
   3.657 +  val has_no_weight =
   3.658 +    not o Term.exists_subterm (fn @{const SMT.weight} => true | _ => false)
   3.659  
   3.660 -  fun collect t =
   3.661 -    (case Term.strip_comb t of
   3.662 -      (Abs (_, T, t), _) => add T #> collect t
   3.663 -    | (Const (_, T), ts) => collects T ts
   3.664 -    | (Free (_, T), ts) => collects T ts
   3.665 -    | _ => I)
   3.666 -  and collects T ts =
   3.667 -    let val ((Ts, Us), U) = Term.strip_type T |> apfst (chop (length ts))
   3.668 -    in fold add Ts #> add (Us ---> U) #> fold collect ts end
   3.669 +  fun is_weight (@{const SMT.weight} $ w $ t) =
   3.670 +        (case try HOLogic.dest_number w of
   3.671 +          SOME (_, i) => i > 0 andalso has_no_weight t
   3.672 +        | _ => false)
   3.673 +    | is_weight t = has_no_weight t
   3.674 +
   3.675 +  fun proper_trigger (@{const SMT.trigger} $ _ $ t) = is_weight t
   3.676 +    | proper_trigger t = has_no_weight t
   3.677 +
   3.678 +  fun check_weight_error ctxt t =
   3.679 +    error ("SMT weight must be a positive number and must only occur " ^
   3.680 +      "under the top-most quantifier and an optional trigger: " ^
   3.681 +      Syntax.string_of_term ctxt t)
   3.682  
   3.683 -  fun add_constructors thy n =
   3.684 -    (case Datatype.get_info thy n of
   3.685 -      NONE => I
   3.686 -    | SOME {descr, ...} => fold (fn (_, (_, _, cs)) => fold (fn (n, ds) =>
   3.687 -        fold (insert (op =) o pair n) (1 upto length ds)) cs) descr)
   3.688 +  fun check_weight_conv ctxt ct =
   3.689 +    if U.under_quant proper_trigger (Thm.term_of ct) then Conv.all_conv ct
   3.690 +    else check_weight_error ctxt (Thm.term_of ct)
   3.691 +
   3.692 +
   3.693 +  (*** insertion of weights ***)
   3.694 +
   3.695 +  fun under_trigger_conv cv ct =
   3.696 +    (case Thm.term_of ct of
   3.697 +      @{const SMT.trigger} $ _ $ _ => Conv.arg_conv cv
   3.698 +    | _ => cv) ct
   3.699  
   3.700 -  fun add_selector (c as (n, i)) ctxt =
   3.701 -    (case Datatype_Selectors.lookup_selector ctxt c of
   3.702 -      SOME _ => ctxt
   3.703 -    | NONE =>
   3.704 -        let
   3.705 -          val T = Sign.the_const_type (ProofContext.theory_of ctxt) n
   3.706 -          val U = Term.body_type T --> nth (Term.binder_types T) (i-1)
   3.707 -        in
   3.708 -          ctxt
   3.709 -          |> yield_singleton Variable.variant_fixes Name.uu
   3.710 -          |>> pair ((n, T), i) o rpair U
   3.711 -          |-> Context.proof_map o Datatype_Selectors.add_selector
   3.712 -        end)
   3.713 +  val weight_eq =
   3.714 +    mk_meta_eq @{lemma "p = SMT.weight i p" by (simp add: weight_def)}
   3.715 +  fun mk_weight_eq w =
   3.716 +    let val cv = Thm.dest_arg1 (Thm.rhs_of weight_eq)
   3.717 +    in
   3.718 +      Thm.instantiate ([], [(cv, Numeral.mk_cnumber @{ctyp int} w)]) weight_eq
   3.719 +    end
   3.720 +
   3.721 +  fun add_weight_conv NONE _ = Conv.all_conv
   3.722 +    | add_weight_conv (SOME weight) ctxt =
   3.723 +        let val cv = Conv.rewr_conv (mk_weight_eq weight)
   3.724 +        in U.under_quant_conv (K (under_trigger_conv cv)) ctxt end
   3.725  in
   3.726  
   3.727 -fun datatype_selectors irules ctxt =
   3.728 -  let
   3.729 -    val ns = Symtab.keys (fold (collect o Thm.prop_of o snd) irules Symtab.empty)
   3.730 -    val cs = fold (add_constructors (ProofContext.theory_of ctxt)) ns []
   3.731 -  in (irules, fold add_selector cs ctxt) end
   3.732 -    (* FIXME: also generate hypothetical definitions for the selectors *)
   3.733 +fun weight_conv weight ctxt = 
   3.734 +  U.prop_conv (check_weight_conv ctxt then_conv add_weight_conv weight ctxt)
   3.735 +
   3.736 +val setup_weight = B.add_builtin_fun_ext'' @{const_name SMT.weight}
   3.737  
   3.738  end
   3.739  
   3.740  
   3.741 -
   3.742 -(* combined normalization *)
   3.743 +(** combined general normalizations **)
   3.744  
   3.745 -type extra_norm = bool -> (int * thm) list -> Proof.context ->
   3.746 -  (int * thm) list * Proof.context
   3.747 -
   3.748 -fun with_context f irules ctxt = (f ctxt irules, ctxt)
   3.749 +fun gen_normalize1_conv ctxt weight =
   3.750 +  atomize_conv ctxt then_conv
   3.751 +  unfold_special_quants_conv ctxt then_conv
   3.752 +  trigger_conv ctxt then_conv
   3.753 +  weight_conv weight ctxt
   3.754  
   3.755 -fun normalize extra_norm with_datatypes irules ctxt =
   3.756 -  let
   3.757 -    fun norm f ctxt' (i, thm) =
   3.758 -      if Config.get ctxt' SMT_Config.drop_bad_facts then
   3.759 -        (case try (f ctxt') thm of
   3.760 -          SOME thm' => SOME (i, thm')
   3.761 -        | NONE => (SMT_Config.verbose_msg ctxt' (prefix ("Warning: " ^
   3.762 -            "dropping assumption: ") o Display.string_of_thm ctxt') thm; NONE))
   3.763 -      else SOME (i, f ctxt' thm)
   3.764 -  in
   3.765 -    irules
   3.766 -    |> map (apsnd instantiate_elim)
   3.767 -    |> trivial_distinct ctxt
   3.768 -    |> rewrite_bool_cases ctxt
   3.769 -    |> normalize_numerals ctxt
   3.770 -    |> nat_as_int ctxt
   3.771 -    |> rpair ctxt
   3.772 -    |-> extra_norm with_datatypes
   3.773 -    |-> with_context (map_filter o norm normalize_rule)
   3.774 -    |-> SMT_Monomorph.monomorph
   3.775 -    |-> lift_lambdas
   3.776 -    |-> with_context explicit_application
   3.777 -    |-> (if with_datatypes then datatype_selectors else pair)
   3.778 -  end
   3.779 +fun gen_normalize1 ctxt weight thm =
   3.780 +  thm
   3.781 +  |> instantiate_elim
   3.782 +  |> norm_def
   3.783 +  |> Conv.fconv_rule (Thm.beta_conversion true then_conv Thm.eta_conversion)
   3.784 +  |> Drule.forall_intr_vars
   3.785 +  |> Conv.fconv_rule (gen_normalize1_conv ctxt weight)
   3.786 +
   3.787 +fun drop_fact_warning ctxt =
   3.788 +  let val pre = prefix "Warning: dropping assumption: "
   3.789 +  in SMT_Config.verbose_msg ctxt (pre o Display.string_of_thm ctxt) end
   3.790 +
   3.791 +fun gen_norm1_safe ctxt (i, (weight, thm)) =
   3.792 +  if Config.get ctxt SMT_Config.drop_bad_facts then
   3.793 +    (case try (gen_normalize1 ctxt weight) thm of
   3.794 +      SOME thm' => SOME (i, thm')
   3.795 +    | NONE => (drop_fact_warning ctxt thm; NONE))
   3.796 +  else SOME (i, gen_normalize1 ctxt weight thm)
   3.797 +
   3.798 +fun gen_normalize ctxt iwthms = map_filter (gen_norm1_safe ctxt) iwthms
   3.799  
   3.800  
   3.801  
   3.802 -(* setup *)
   3.803 +(* unfolding of definitions and theory-specific rewritings *)
   3.804 +
   3.805 +(** unfold trivial distincts **)
   3.806 +
   3.807 +local
   3.808 +  fun is_trivial_distinct (Const (@{const_name distinct}, _) $ t) =
   3.809 +        (case try HOLogic.dest_list t of
   3.810 +          SOME [] => true
   3.811 +        | SOME [_] => true
   3.812 +        | _ => false)
   3.813 +    | is_trivial_distinct _ = false
   3.814 +
   3.815 +  val thms = map mk_meta_eq @{lemma
   3.816 +    "distinct [] = True"
   3.817 +    "distinct [x] = True"
   3.818 +    "distinct [x, y] = (x ~= y)"
   3.819 +    by simp_all}
   3.820 +  fun distinct_conv _ =
   3.821 +    U.if_true_conv is_trivial_distinct (Conv.rewrs_conv thms)
   3.822 +in
   3.823 +
   3.824 +fun trivial_distinct_conv ctxt = U.if_exists_conv is_trivial_distinct
   3.825 +  (Conv.top_conv distinct_conv ctxt)
   3.826 +
   3.827 +end
   3.828 +
   3.829 +
   3.830 +(** rewrite bool case expressions as if expressions **)
   3.831 +
   3.832 +local
   3.833 +  fun is_bool_case (Const (@{const_name "bool.bool_case"}, _)) = true
   3.834 +    | is_bool_case _ = false
   3.835 +
   3.836 +  val thm = mk_meta_eq @{lemma
   3.837 +    "bool_case = (%x y P. if P then x else y)" by (rule ext)+ simp}
   3.838 +
   3.839 +  fun unfold_conv _ = U.if_true_conv is_bool_case (Conv.rewr_conv thm)
   3.840 +in
   3.841 +
   3.842 +fun rewrite_bool_case_conv ctxt = U.if_exists_conv is_bool_case
   3.843 +  (Conv.top_conv unfold_conv ctxt)
   3.844 +
   3.845 +val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"}
   3.846 +
   3.847 +end
   3.848 +
   3.849 +
   3.850 +(** unfold abs, min and max **)
   3.851 +
   3.852 +local
   3.853 +  val abs_def = mk_meta_eq @{lemma
   3.854 +    "abs = (%a::'a::abs_if. if a < 0 then - a else a)"
   3.855 +    by (rule ext) (rule abs_if)}
   3.856 +
   3.857 +  val min_def = mk_meta_eq @{lemma "min = (%a b. if a <= b then a else b)"
   3.858 +    by (rule ext)+ (rule min_def)}
   3.859 +
   3.860 +  val max_def = mk_meta_eq  @{lemma "max = (%a b. if a <= b then b else a)"
   3.861 +    by (rule ext)+ (rule max_def)}
   3.862 +
   3.863 +  val defs = [(@{const_name min}, min_def), (@{const_name max}, max_def),
   3.864 +    (@{const_name abs}, abs_def)]
   3.865 +
   3.866 +  fun is_builtinT ctxt T = B.is_builtin_typ_ext ctxt (Term.domain_type T)
   3.867 +
   3.868 +  fun abs_min_max ctxt (Const (n, T)) =
   3.869 +        (case AList.lookup (op =) defs n of
   3.870 +          NONE => NONE
   3.871 +        | SOME thm => if is_builtinT ctxt T then SOME thm else NONE)
   3.872 +    | abs_min_max _ _ = NONE
   3.873 +
   3.874 +  fun unfold_amm_conv ctxt ct =
   3.875 +    (case abs_min_max ctxt (Thm.term_of ct) of
   3.876 +      SOME thm => Conv.rewr_conv thm
   3.877 +    | NONE => Conv.all_conv) ct
   3.878 +in
   3.879 +
   3.880 +fun unfold_abs_min_max_conv ctxt =
   3.881 +  U.if_exists_conv (is_some o abs_min_max ctxt)
   3.882 +    (Conv.top_conv unfold_amm_conv ctxt)
   3.883 +  
   3.884 +val setup_abs_min_max = fold (B.add_builtin_fun_ext'' o fst) defs
   3.885 +
   3.886 +end
   3.887 +
   3.888 +
   3.889 +(** embedding of standard natural number operations into integer operations **)
   3.890 +
   3.891 +local
   3.892 +  val nat_embedding = @{lemma
   3.893 +    "ALL n. nat (int n) = n"
   3.894 +    "ALL i. i >= 0 --> int (nat i) = i"
   3.895 +    "ALL i. i < 0 --> int (nat i) = 0"
   3.896 +    by simp_all}
   3.897 +
   3.898 +  val nat_ops = [
   3.899 +    @{const less (nat)}, @{const less_eq (nat)},
   3.900 +    @{const Suc}, @{const plus (nat)}, @{const minus (nat)},
   3.901 +    @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}]
   3.902 +
   3.903 +  val nat_consts = nat_ops @ [@{const number_of (nat)},
   3.904 +    @{const zero_class.zero (nat)}, @{const one_class.one (nat)}]
   3.905 +
   3.906 +  val nat_int_coercions = [@{const of_nat (int)}, @{const nat}]
   3.907 +
   3.908 +  val nat_ops' = nat_int_coercions @ nat_ops
   3.909 +
   3.910 +  val is_nat_const = member (op aconv) nat_consts
   3.911 +
   3.912 +  val expands = map mk_meta_eq @{lemma
   3.913 +    "0 = nat 0"
   3.914 +    "1 = nat 1"
   3.915 +    "(number_of :: int => nat) = (%i. nat (number_of i))"
   3.916 +    "op < = (%a b. int a < int b)"
   3.917 +    "op <= = (%a b. int a <= int b)"
   3.918 +    "Suc = (%a. nat (int a + 1))"
   3.919 +    "op + = (%a b. nat (int a + int b))"
   3.920 +    "op - = (%a b. nat (int a - int b))"
   3.921 +    "op * = (%a b. nat (int a * int b))"
   3.922 +    "op div = (%a b. nat (int a div int b))"
   3.923 +    "op mod = (%a b. nat (int a mod int b))"
   3.924 +    by (auto intro!: ext simp add: nat_mult_distrib nat_div_distrib
   3.925 +      nat_mod_distrib)}
   3.926 +
   3.927 +  val ints = map mk_meta_eq @{lemma
   3.928 +    "int 0 = 0"
   3.929 +    "int 1 = 1"
   3.930 +    "int (Suc n) = int n + 1"
   3.931 +    "int (n + m) = int n + int m"
   3.932 +    "int (n - m) = int (nat (int n - int m))"
   3.933 +    "int (n * m) = int n * int m"
   3.934 +    "int (n div m) = int n div int m"
   3.935 +    "int (n mod m) = int n mod int m"
   3.936 +    "int (if P then n else m) = (if P then int n else int m)"
   3.937 +    by (auto simp add: int_mult zdiv_int zmod_int)}
   3.938 +
   3.939 +  fun mk_number_eq ctxt i lhs =
   3.940 +    let
   3.941 +      val eq = U.mk_cequals lhs (Numeral.mk_cnumber @{ctyp int} i)
   3.942 +      val ss = HOL_ss
   3.943 +        addsimps [@{thm Nat_Numeral.int_nat_number_of}]
   3.944 +        addsimps @{thms neg_simps}
   3.945 +      fun tac _ = Simplifier.simp_tac (Simplifier.context ctxt ss) 1       
   3.946 +    in Goal.norm_result (Goal.prove_internal [] eq tac) end
   3.947 +
   3.948 +  fun expand_head_conv cv ct =
   3.949 +    (case Thm.term_of ct of
   3.950 +      _ $ _ =>
   3.951 +        Conv.fun_conv (expand_head_conv cv) then_conv
   3.952 +        Thm.beta_conversion false
   3.953 +    | _ => cv) ct
   3.954 +
   3.955 +  fun int_conv ctxt ct =
   3.956 +    (case Thm.term_of ct of
   3.957 +      @{const of_nat (int)} $ (n as (@{const number_of (nat)} $ _)) =>
   3.958 +        Conv.rewr_conv (mk_number_eq ctxt (snd (HOLogic.dest_number n)) ct)
   3.959 +    | @{const of_nat (int)} $ _ =>
   3.960 +        (Conv.rewrs_conv ints then_conv Conv.sub_conv ints_conv ctxt) else_conv
   3.961 +        Conv.top_sweep_conv nat_conv ctxt        
   3.962 +    | _ => Conv.no_conv) ct
   3.963 +
   3.964 +  and ints_conv ctxt = Conv.top_sweep_conv int_conv ctxt
   3.965 +
   3.966 +  and expand_conv ctxt =
   3.967 +    U.if_conv (not o is_nat_const o Term.head_of) Conv.no_conv
   3.968 +      (expand_head_conv (Conv.rewrs_conv expands) then_conv ints_conv ctxt)
   3.969 +
   3.970 +  and nat_conv ctxt = U.if_exists_conv is_nat_const
   3.971 +    (Conv.top_sweep_conv expand_conv ctxt)
   3.972 +
   3.973 +  val uses_nat_int = Term.exists_subterm (member (op aconv) nat_int_coercions)
   3.974 +in
   3.975 +
   3.976 +val nat_as_int_conv = nat_conv
   3.977 +
   3.978 +fun add_nat_embedding thms =
   3.979 +  if exists (uses_nat_int o Thm.prop_of) thms then (thms, nat_embedding)
   3.980 +  else (thms, [])
   3.981 +
   3.982 +val setup_nat_as_int =
   3.983 +  B.add_builtin_typ_ext (@{typ nat}, K true) #>
   3.984 +  fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops'
   3.985 +
   3.986 +end
   3.987 +
   3.988 +
   3.989 +(** normalize numerals **)
   3.990 +
   3.991 +local
   3.992 +  (*
   3.993 +    rewrite negative numerals into positive numerals,
   3.994 +    rewrite Numeral0 into 0
   3.995 +    rewrite Numeral1 into 1
   3.996 +  *)
   3.997 +
   3.998 +  fun is_strange_number ctxt (t as Const (@{const_name number_of}, _) $ _) =
   3.999 +        (case try HOLogic.dest_number t of
  3.1000 +          SOME (_, i) => B.is_builtin_num ctxt t andalso i < 2
  3.1001 +        | NONE => false)
  3.1002 +    | is_strange_number _ _ = false
  3.1003 +
  3.1004 +  val pos_num_ss = HOL_ss
  3.1005 +    addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}]
  3.1006 +    addsimps [@{thm Int.number_of_Pls}, @{thm Int.numeral_1_eq_1}]
  3.1007 +    addsimps @{thms Int.pred_bin_simps}
  3.1008 +    addsimps @{thms Int.normalize_bin_simps}
  3.1009 +    addsimps @{lemma
  3.1010 +      "Int.Min = - Int.Bit1 Int.Pls"
  3.1011 +      "Int.Bit0 (- Int.Pls) = - Int.Pls"
  3.1012 +      "Int.Bit0 (- k) = - Int.Bit0 k"
  3.1013 +      "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)"
  3.1014 +      by simp_all (simp add: pred_def)}
  3.1015 +
  3.1016 +  fun norm_num_conv ctxt = U.if_conv (is_strange_number ctxt)
  3.1017 +    (Simplifier.rewrite (Simplifier.context ctxt pos_num_ss)) Conv.no_conv
  3.1018 +in
  3.1019 +
  3.1020 +fun normalize_numerals_conv ctxt = U.if_exists_conv (is_strange_number ctxt)
  3.1021 +  (Conv.top_sweep_conv norm_num_conv ctxt)
  3.1022 +
  3.1023 +end
  3.1024 +
  3.1025 +
  3.1026 +(** combined unfoldings and rewritings **)
  3.1027 +
  3.1028 +fun unfold_conv ctxt =
  3.1029 +  trivial_distinct_conv ctxt then_conv
  3.1030 +  rewrite_bool_case_conv ctxt then_conv
  3.1031 +  unfold_abs_min_max_conv ctxt then_conv
  3.1032 +  nat_as_int_conv ctxt then_conv
  3.1033 +  normalize_numerals_conv ctxt then_conv
  3.1034 +  Thm.beta_conversion true
  3.1035 +
  3.1036 +fun burrow_ids f ithms =
  3.1037 +  let
  3.1038 +    val (is, thms) = split_list ithms
  3.1039 +    val (thms', extra_thms) = f thms
  3.1040 +  in (is ~~ thms') @ map (pair ~1) extra_thms end
  3.1041 +
  3.1042 +fun unfold ctxt =
  3.1043 +  burrow_ids (map (Conv.fconv_rule (unfold_conv ctxt)) #> add_nat_embedding)
  3.1044 +
  3.1045 +
  3.1046 +
  3.1047 +(* overall normalization *)
  3.1048 +
  3.1049 +type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list
  3.1050 +
  3.1051 +structure Extra_Norms = Generic_Data
  3.1052 +(
  3.1053 +  type T = extra_norm U.dict
  3.1054 +  val empty = []
  3.1055 +  val extend = I
  3.1056 +  val merge = U.dict_merge fst
  3.1057 +)
  3.1058 +
  3.1059 +fun add_extra_norm (cs, norm) = Extra_Norms.map (U.dict_update (cs, norm))
  3.1060 +
  3.1061 +fun apply_extra_norms ctxt =
  3.1062 +  let
  3.1063 +    val cs = SMT_Config.solver_class_of ctxt
  3.1064 +    val es = U.dict_lookup (Extra_Norms.get (Context.Proof ctxt)) cs
  3.1065 +  in burrow_ids (fold (fn e => e ctxt) es o rpair []) end
  3.1066 +
  3.1067 +fun normalize ctxt iwthms =
  3.1068 +  iwthms
  3.1069 +  |> gen_normalize ctxt
  3.1070 +  |> unfold ctxt
  3.1071 +  |> apply_extra_norms ctxt
  3.1072  
  3.1073  val setup = Context.theory_map (
  3.1074 +  setup_atomize #>
  3.1075 +  setup_unfolded_quants #>
  3.1076 +  setup_trigger #>
  3.1077 +  setup_weight #>
  3.1078    setup_bool_case #>
  3.1079 -  setup_nat_as_int #>
  3.1080 -  setup_unfolded_quants #>
  3.1081 -  setup_atomize)
  3.1082 +  setup_abs_min_max #>
  3.1083 +  setup_nat_as_int)
  3.1084  
  3.1085  end
     4.1 --- a/src/HOL/Tools/SMT/smt_solver.ML	Wed Dec 15 08:39:24 2010 +0100
     4.2 +++ b/src/HOL/Tools/SMT/smt_solver.ML	Wed Dec 15 08:39:24 2010 +0100
     4.3 @@ -9,7 +9,6 @@
     4.4    (*configuration*)
     4.5    type interface = {
     4.6      class: SMT_Utils.class,
     4.7 -    extra_norm: SMT_Normalize.extra_norm,
     4.8      translate: SMT_Translate.config }
     4.9    datatype outcome = Unsat | Sat | Unknown
    4.10    type solver_config = {
    4.11 @@ -26,8 +25,8 @@
    4.12      default_max_relevant: int }
    4.13  
    4.14    (*registry*)
    4.15 -  type solver = bool option -> Proof.context -> (int * thm) list ->
    4.16 -    int list * thm
    4.17 +  type solver = bool option -> Proof.context ->
    4.18 +    (int * (int option * thm)) list -> int list * thm
    4.19    val add_solver: solver_config -> theory -> theory
    4.20    val solver_name_of: Proof.context -> string
    4.21    val solver_of: Proof.context -> solver
    4.22 @@ -37,7 +36,8 @@
    4.23    val default_max_relevant: Proof.context -> string -> int
    4.24  
    4.25    (*filter*)
    4.26 -  val smt_filter: bool -> Time.time -> Proof.state -> ('a * thm) list -> int ->
    4.27 +  val smt_filter: bool -> Time.time -> Proof.state ->
    4.28 +    ('a * (int option * thm)) list -> int ->
    4.29      {outcome: SMT_Failure.failure option, used_facts: ('a * thm) list,
    4.30      run_time_in_msecs: int option}
    4.31  
    4.32 @@ -59,7 +59,6 @@
    4.33  
    4.34  type interface = {
    4.35    class: SMT_Utils.class,
    4.36 -  extra_norm: SMT_Normalize.extra_norm,
    4.37    translate: SMT_Translate.config }
    4.38  
    4.39  datatype outcome = Unsat | Sat | Unknown
    4.40 @@ -176,7 +175,7 @@
    4.41          Pretty.big_list "functions:" (map p_term (Symtab.dest terms))])) ()
    4.42    end
    4.43  
    4.44 -fun invoke translate_config name cmd options irules ctxt =
    4.45 +fun invoke translate_config name cmd options ithms ctxt =
    4.46    let
    4.47      val args = C.solver_options_of ctxt @ options ctxt
    4.48      val comments = ("solver: " ^ name) ::
    4.49 @@ -184,7 +183,7 @@
    4.50        ("random seed: " ^ string_of_int (Config.get ctxt C.random_seed)) ::
    4.51        "arguments:" :: args
    4.52    in
    4.53 -    irules
    4.54 +    ithms
    4.55      |> tap (trace_assms ctxt)
    4.56      |> SMT_Translate.translate translate_config ctxt comments
    4.57      ||> tap (trace_recon_data ctxt)
    4.58 @@ -197,22 +196,23 @@
    4.59    else discharge_definitions (@{thm reflexive} RS thm)
    4.60  
    4.61  fun set_has_datatypes with_datatypes translate =
    4.62 -  let
    4.63 -    val {prefixes, header, is_fol, has_datatypes, serialize} = translate
    4.64 -    val with_datatypes' = has_datatypes andalso with_datatypes
    4.65 -    val translate' = {prefixes=prefixes, header=header, is_fol=is_fol,
    4.66 -      has_datatypes=with_datatypes', serialize=serialize}
    4.67 -  in (with_datatypes', translate') end
    4.68 +  let val {prefixes, header, is_fol, has_datatypes, serialize} = translate
    4.69 +  in
    4.70 +   {prefixes=prefixes, header=header, is_fol=is_fol,
    4.71 +    has_datatypes=has_datatypes andalso with_datatypes, serialize=serialize}
    4.72 +  end
    4.73  
    4.74 -fun trace_assumptions ctxt irules idxs =
    4.75 +fun trace_assumptions ctxt iwthms idxs =
    4.76    let
    4.77 -    val thms = filter (fn i => i >= 0) idxs
    4.78 -      |> map_filter (AList.lookup (op =) irules)
    4.79 +    val wthms =
    4.80 +      idxs
    4.81 +      |> filter (fn i => i >= 0)
    4.82 +      |> map_filter (AList.lookup (op =) iwthms)
    4.83    in
    4.84 -    if Config.get ctxt C.trace_used_facts andalso length thms > 0
    4.85 +    if Config.get ctxt C.trace_used_facts andalso length wthms > 0
    4.86      then
    4.87        tracing (Pretty.string_of (Pretty.big_list "SMT used facts:"
    4.88 -        (map (Display.pretty_thm ctxt) thms)))
    4.89 +        (map (Display.pretty_thm ctxt o snd) wthms)))
    4.90      else ()
    4.91    end
    4.92  
    4.93 @@ -220,7 +220,8 @@
    4.94  
    4.95  (* registry *)
    4.96  
    4.97 -type solver = bool option -> Proof.context -> (int * thm) list -> int list * thm
    4.98 +type solver = bool option -> Proof.context -> (int * (int option * thm)) list ->
    4.99 +  int list * thm
   4.100  
   4.101  type solver_info = {
   4.102    env_var: string,
   4.103 @@ -231,22 +232,22 @@
   4.104      (int list * thm) * Proof.context,
   4.105    default_max_relevant: int }
   4.106  
   4.107 -fun gen_solver name (info : solver_info) rm ctxt irules =
   4.108 +fun gen_solver name (info : solver_info) rm ctxt iwthms =
   4.109    let
   4.110      val {env_var, is_remote, options, interface, reconstruct, ...} = info
   4.111 -    val {extra_norm, translate, ...} = interface
   4.112 -    val (with_datatypes, translate') =
   4.113 -      set_has_datatypes (Config.get ctxt C.datatypes) translate
   4.114 +    val {translate, ...} = interface
   4.115 +    val translate' = set_has_datatypes (Config.get ctxt C.datatypes) translate
   4.116      val cmd = (rm, env_var, is_remote, name)
   4.117    in
   4.118 -    (irules, ctxt)
   4.119 -    |-> SMT_Normalize.normalize extra_norm with_datatypes
   4.120 +    SMT_Normalize.normalize ctxt iwthms
   4.121 +    |> rpair ctxt
   4.122 +    |-> SMT_Monomorph.monomorph
   4.123      |-> invoke translate' name cmd options
   4.124      |-> reconstruct
   4.125      |-> (fn (idxs, thm) => fn ctxt' => thm
   4.126      |> singleton (ProofContext.export ctxt' ctxt)
   4.127      |> discharge_definitions
   4.128 -    |> tap (fn _ => trace_assumptions ctxt irules idxs)
   4.129 +    |> tap (fn _ => trace_assumptions ctxt iwthms idxs)
   4.130      |> pair idxs)
   4.131    end
   4.132  
   4.133 @@ -330,38 +331,45 @@
   4.134    | TVar (_, []) => true
   4.135    | _ => false))
   4.136  
   4.137 -fun smt_solver rm ctxt irules =
   4.138 +fun smt_solver rm ctxt iwthms =
   4.139    (* without this test, we would run into problems when atomizing the rules: *)
   4.140 -  if exists (has_topsort o Thm.prop_of o snd) irules then
   4.141 +  if exists (has_topsort o Thm.prop_of o snd o snd) iwthms then
   4.142      raise SMT_Failure.SMT (SMT_Failure.Other_Failure ("proof state " ^
   4.143        "contains the universal sort {}"))
   4.144 -  else solver_of ctxt rm ctxt irules
   4.145 +  else solver_of ctxt rm ctxt iwthms
   4.146  
   4.147  val cnot = Thm.cterm_of @{theory} @{const Not}
   4.148  
   4.149 -fun smt_filter run_remote time_limit st xrules i =
   4.150 +fun mk_result outcome xrules =
   4.151 +  { outcome = outcome, used_facts = xrules, run_time_in_msecs = NONE }
   4.152 +
   4.153 +fun smt_filter run_remote time_limit st xwrules i =
   4.154    let
   4.155 -    val {facts, goal, ...} = Proof.goal st
   4.156      val ctxt =
   4.157        Proof.context_of st
   4.158        |> Config.put C.oracle false
   4.159        |> Config.put C.timeout (Time.toReal time_limit)
   4.160        |> Config.put C.drop_bad_facts true
   4.161        |> Config.put C.filter_only_facts true
   4.162 +
   4.163 +    val {facts, goal, ...} = Proof.goal st
   4.164      val ({context=ctxt', prems, concl, ...}, _) = Subgoal.focus ctxt i goal
   4.165      fun negate ct = Thm.dest_comb ct ||> Thm.capply cnot |-> Thm.capply
   4.166      val cprop = negate (Thm.rhs_of (SMT_Normalize.atomize_conv ctxt' concl))
   4.167 -    val irs = map (pair ~1) (Thm.assume cprop :: prems @ facts)
   4.168 -    val rm = SOME run_remote
   4.169 +
   4.170 +    val (xs, wthms) = split_list xwrules
   4.171 +    val xrules = xs ~~ map snd wthms
   4.172    in
   4.173 -    (xrules, map snd xrules)
   4.174 -    ||> distinct (op =) o fst o smt_solver rm ctxt' o append irs o map_index I
   4.175 -    |-> map_filter o try o nth
   4.176 -    |> (fn xs => {outcome=NONE, used_facts=if solver_name_of ctxt = "z3" (* FIXME *) then xs
   4.177 -      else xrules, run_time_in_msecs=NONE})
   4.178 +    wthms
   4.179 +    |> map_index I
   4.180 +    |> append (map (pair ~1 o pair NONE) (Thm.assume cprop :: prems @ facts))
   4.181 +    |> smt_solver (SOME run_remote) ctxt'
   4.182 +    |> distinct (op =) o fst
   4.183 +    |> map_filter (try (nth xrules))
   4.184 +    |> (if solver_name_of ctxt = "z3" (* FIXME *) then I else K xrules)
   4.185 +    |> mk_result NONE
   4.186    end
   4.187 -  handle SMT_Failure.SMT fail => {outcome=SOME fail, used_facts=[],
   4.188 -    run_time_in_msecs=NONE}
   4.189 +  handle SMT_Failure.SMT fail => mk_result (SOME fail) []
   4.190    (* FIXME: measure runtime *)
   4.191  
   4.192  
   4.193 @@ -373,18 +381,18 @@
   4.194    THEN' Tactic.rtac @{thm ccontr}
   4.195    THEN' SUBPROOF (fn {context=ctxt', prems, ...} =>
   4.196      let
   4.197 -      fun solve irules = snd (smt_solver NONE ctxt' irules)
   4.198 +      fun solve iwthms = snd (smt_solver NONE ctxt' iwthms)
   4.199        val tag = "Solver " ^ C.solver_of ctxt' ^ ": "
   4.200        val str_of = prefix tag o SMT_Failure.string_of_failure ctxt'
   4.201 -      fun safe_solve irules =
   4.202 -        if pass_exns then SOME (solve irules)
   4.203 -        else (SOME (solve irules)
   4.204 +      fun safe_solve iwthms =
   4.205 +        if pass_exns then SOME (solve iwthms)
   4.206 +        else (SOME (solve iwthms)
   4.207            handle
   4.208              SMT_Failure.SMT (fail as SMT_Failure.Counterexample _) =>
   4.209                (C.verbose_msg ctxt' str_of fail; NONE)
   4.210            | SMT_Failure.SMT fail => (C.trace_msg ctxt' str_of fail; NONE))
   4.211      in
   4.212 -      safe_solve (map (pair ~1) (rules @ prems))
   4.213 +      safe_solve (map (pair ~1 o pair NONE) (rules @ prems))
   4.214        |> (fn SOME thm => Tactic.rtac thm 1 | _ => Tactical.no_tac)
   4.215      end) ctxt
   4.216  
     5.1 --- a/src/HOL/Tools/SMT/smt_utils.ML	Wed Dec 15 08:39:24 2010 +0100
     5.2 +++ b/src/HOL/Tools/SMT/smt_utils.ML	Wed Dec 15 08:39:24 2010 +0100
     5.3 @@ -25,6 +25,7 @@
     5.4    (*terms*)
     5.5    val dest_conj: term -> term * term
     5.6    val dest_disj: term -> term * term
     5.7 +  val under_quant: (term -> 'a) -> term -> 'a
     5.8  
     5.9    (*patterns and instantiations*)
    5.10    val mk_const_pat: theory -> string -> (ctyp -> 'a) -> 'a * cterm
    5.11 @@ -48,7 +49,10 @@
    5.12    (*conversions*)
    5.13    val if_conv: (term -> bool) -> conv -> conv -> conv
    5.14    val if_true_conv: (term -> bool) -> conv -> conv
    5.15 +  val if_exists_conv: (term -> bool) -> conv -> conv
    5.16    val binders_conv: (Proof.context -> conv) -> Proof.context -> conv
    5.17 +  val under_quant_conv: (Proof.context * cterm list -> conv) ->
    5.18 +    Proof.context -> conv
    5.19    val prop_conv: conv -> conv
    5.20  end
    5.21  
    5.22 @@ -110,6 +114,12 @@
    5.23  fun dest_disj (@{const HOL.disj} $ t $ u) = (t, u)
    5.24    | dest_disj t = raise TERM ("not a disjunction", [t])
    5.25  
    5.26 +fun under_quant f t =
    5.27 +  (case t of
    5.28 +    Const (@{const_name All}, _) $ Abs (_, _, u) => under_quant f u
    5.29 +  | Const (@{const_name Ex}, _) $ Abs (_, _, u) => under_quant f u
    5.30 +  | _ => f t)
    5.31 +
    5.32  
    5.33  (* patterns and instantiations *)
    5.34  
    5.35 @@ -164,9 +174,23 @@
    5.36  
    5.37  fun if_true_conv pred cv = if_conv pred cv Conv.all_conv
    5.38  
    5.39 +fun if_exists_conv pred = if_true_conv (Term.exists_subterm pred)
    5.40 +
    5.41  fun binders_conv cv ctxt =
    5.42    Conv.binder_conv (binders_conv cv o snd) ctxt else_conv cv ctxt
    5.43  
    5.44 +fun under_quant_conv cv ctxt =
    5.45 +  let
    5.46 +    fun quant_conv inside ctxt cvs ct =
    5.47 +      (case Thm.term_of ct of
    5.48 +        Const (@{const_name All}, _) $ Abs _ =>
    5.49 +          Conv.binder_conv (under_conv cvs) ctxt
    5.50 +      | Const (@{const_name Ex}, _) $ Abs _ =>
    5.51 +          Conv.binder_conv (under_conv cvs) ctxt
    5.52 +      | _ => if inside then cv (ctxt, cvs) else Conv.all_conv) ct
    5.53 +    and under_conv cvs (cv, ctxt) = quant_conv true ctxt (cv :: cvs)
    5.54 +  in quant_conv false ctxt [] end
    5.55 +
    5.56  fun prop_conv cv ct =
    5.57    (case Thm.term_of ct of
    5.58      @{const Trueprop} $ _ => Conv.arg_conv cv ct
     6.1 --- a/src/HOL/Tools/SMT/smtlib_interface.ML	Wed Dec 15 08:39:24 2010 +0100
     6.2 +++ b/src/HOL/Tools/SMT/smtlib_interface.ML	Wed Dec 15 08:39:24 2010 +0100
     6.3 @@ -23,77 +23,6 @@
     6.4  val smtlibC = ["smtlib"]
     6.5  
     6.6  
     6.7 -
     6.8 -(* facts about uninterpreted constants *)
     6.9 -
    6.10 -infix 2 ??
    6.11 -fun (ex ?? f) irules = irules |> exists (ex o Thm.prop_of o snd) irules ? f
    6.12 -
    6.13 -
    6.14 -(** pairs **)
    6.15 -
    6.16 -val pair_rules = [@{thm fst_conv}, @{thm snd_conv}, @{thm pair_collapse}]
    6.17 -
    6.18 -val pair_type = (fn Type (@{type_name Product_Type.prod}, _) => true | _ => false)
    6.19 -val exists_pair_type = Term.exists_type (Term.exists_subtype pair_type)
    6.20 -
    6.21 -val add_pair_rules = exists_pair_type ?? append (map (pair ~1) pair_rules)
    6.22 -
    6.23 -
    6.24 -(** function update **)
    6.25 -
    6.26 -val fun_upd_rules = [@{thm fun_upd_same}, @{thm fun_upd_apply}]
    6.27 -
    6.28 -val is_fun_upd = (fn Const (@{const_name fun_upd}, _) => true | _ => false)
    6.29 -val exists_fun_upd = Term.exists_subterm is_fun_upd
    6.30 -
    6.31 -val add_fun_upd_rules = exists_fun_upd ?? append (map (pair ~1) fun_upd_rules)
    6.32 -
    6.33 -
    6.34 -(** abs/min/max **)
    6.35 -
    6.36 -val exists_abs_min_max = Term.exists_subterm (fn
    6.37 -    Const (@{const_name abs}, _) => true
    6.38 -  | Const (@{const_name min}, _) => true
    6.39 -  | Const (@{const_name max}, _) => true
    6.40 -  | _ => false)
    6.41 -
    6.42 -val unfold_abs_conv = Conv.rewr_conv (mk_meta_eq @{thm abs_if})
    6.43 -val unfold_min_conv = Conv.rewr_conv (mk_meta_eq @{thm min_def})
    6.44 -val unfold_max_conv = Conv.rewr_conv (mk_meta_eq @{thm max_def})
    6.45 -
    6.46 -fun expand_conv cv = N.eta_expand_conv (K cv)
    6.47 -fun expand2_conv cv = N.eta_expand_conv (N.eta_expand_conv (K cv))
    6.48 -
    6.49 -fun unfold_def_conv ctxt ct =
    6.50 -  (case Thm.term_of ct of
    6.51 -    Const (@{const_name abs}, _) $ _ => unfold_abs_conv
    6.52 -  | Const (@{const_name abs}, _) => expand_conv unfold_abs_conv ctxt
    6.53 -  | Const (@{const_name min}, _) $ _ $ _ => unfold_min_conv
    6.54 -  | Const (@{const_name min}, _) $ _ => expand_conv unfold_min_conv ctxt
    6.55 -  | Const (@{const_name min}, _) => expand2_conv unfold_min_conv ctxt
    6.56 -  | Const (@{const_name max}, _) $ _ $ _ => unfold_max_conv
    6.57 -  | Const (@{const_name max}, _) $ _ => expand_conv unfold_max_conv ctxt
    6.58 -  | Const (@{const_name max}, _) => expand2_conv unfold_max_conv ctxt
    6.59 -  | _ => Conv.all_conv) ct
    6.60 -
    6.61 -fun unfold_abs_min_max_defs ctxt thm =
    6.62 -  if exists_abs_min_max (Thm.prop_of thm)
    6.63 -  then Conv.fconv_rule (Conv.top_conv unfold_def_conv ctxt) thm
    6.64 -  else thm
    6.65 -
    6.66 -
    6.67 -(** include additional facts **)
    6.68 -
    6.69 -fun extra_norm has_datatypes irules ctxt =
    6.70 -  irules
    6.71 -  |> not has_datatypes ? add_pair_rules
    6.72 -  |> add_fun_upd_rules
    6.73 -  |> map (apsnd (unfold_abs_min_max_defs ctxt))
    6.74 -  |> rpair ctxt
    6.75 -
    6.76 -
    6.77 -
    6.78  (* builtins *)
    6.79  
    6.80  local
    6.81 @@ -131,7 +60,6 @@
    6.82  end
    6.83  
    6.84  
    6.85 -
    6.86  (* serialization *)
    6.87  
    6.88  (** header **)
    6.89 @@ -215,12 +143,10 @@
    6.90    |> Buffer.content
    6.91  
    6.92  
    6.93 -
    6.94  (* interface *)
    6.95  
    6.96  val interface = {
    6.97    class = smtlibC,
    6.98 -  extra_norm = extra_norm,
    6.99    translate = {
   6.100      prefixes = {
   6.101        sort_prefix = "S",
     7.1 --- a/src/HOL/Tools/SMT/z3_interface.ML	Wed Dec 15 08:39:24 2010 +0100
     7.2 +++ b/src/HOL/Tools/SMT/z3_interface.ML	Wed Dec 15 08:39:24 2010 +0100
     7.3 @@ -43,12 +43,13 @@
     7.4      | is_int_div_mod @{const mod (int)} = true
     7.5      | is_int_div_mod _ = false
     7.6  
     7.7 -  fun add_div_mod irules =
     7.8 -    if exists (Term.exists_subterm is_int_div_mod o Thm.prop_of o snd) irules
     7.9 -    then [(~1, @{thm div_by_z3div}), (~1, @{thm mod_by_z3mod})] @ irules
    7.10 -    else irules
    7.11 +  val have_int_div_mod =
    7.12 +    exists (Term.exists_subterm is_int_div_mod o Thm.prop_of)
    7.13  
    7.14 -  fun extra_norm' has_datatypes = extra_norm has_datatypes o add_div_mod
    7.15 +  fun add_div_mod _ (thms, extra_thms) =
    7.16 +    if have_int_div_mod thms orelse have_int_div_mod extra_thms then
    7.17 +      (thms, @{thm div_by_z3div} :: @{thm mod_by_z3mod} :: extra_thms)
    7.18 +    else (thms, extra_thms)
    7.19  
    7.20    val setup_builtins =
    7.21      B.add_builtin_fun' smtlib_z3C (@{const z3div}, "div") #>
    7.22 @@ -57,7 +58,6 @@
    7.23  
    7.24  val interface = {
    7.25    class = smtlib_z3C,
    7.26 -  extra_norm = extra_norm',
    7.27    translate = {
    7.28      prefixes = prefixes,
    7.29      is_fol = is_fol,
    7.30 @@ -65,7 +65,9 @@
    7.31      has_datatypes = true,
    7.32      serialize = serialize}}
    7.33  
    7.34 -val setup = Context.theory_map setup_builtins
    7.35 +val setup = Context.theory_map (
    7.36 +  setup_builtins #>
    7.37 +  SMT_Normalize.add_extra_norm (smtlib_z3C, add_div_mod))
    7.38  
    7.39  end
    7.40  
     8.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Dec 15 08:39:24 2010 +0100
     8.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Wed Dec 15 08:39:24 2010 +0100
     8.3 @@ -528,7 +528,7 @@
     8.4        #> Config.put SMT_Config.monomorph_limit smt_monomorph_limit
     8.5      val state = state |> Proof.map_context repair_context
     8.6      val thy = Proof.theory_of state
     8.7 -    val facts = facts |> map (apsnd (Thm.transfer thy) o untranslated_fact)
     8.8 +    val facts = facts |> map (apsnd (pair NONE o Thm.transfer thy) o untranslated_fact)
     8.9      val {outcome, used_facts, run_time_in_msecs} =
    8.10        smt_filter_loop params remote state subgoal facts
    8.11      val (chained_lemmas, other_lemmas) = split_used_facts (map fst used_facts)