src/HOL/Tools/SMT/smt_normalize.ML
changeset 41126 e0bd443c0fdd
parent 41072 9f9bc1bdacef
child 41173 7c6178d45cc8
child 41193 dc33b8ea4526
     1.1 --- a/src/HOL/Tools/SMT/smt_normalize.ML	Wed Dec 15 08:39:24 2010 +0100
     1.2 +++ b/src/HOL/Tools/SMT/smt_normalize.ML	Wed Dec 15 08:39:24 2010 +0100
     1.3 @@ -1,28 +1,17 @@
     1.4  (*  Title:      HOL/Tools/SMT/smt_normalize.ML
     1.5      Author:     Sascha Boehme, TU Muenchen
     1.6  
     1.7 -Normalization steps on theorems required by SMT solvers:
     1.8 -  * simplify trivial distincts (those with less than three elements),
     1.9 -  * rewrite bool case expressions as if expressions,
    1.10 -  * normalize numerals (e.g. replace negative numerals by negated positive
    1.11 -    numerals),
    1.12 -  * embed natural numbers into integers,
    1.13 -  * add extra rules specifying types and constants which occur frequently,
    1.14 -  * fully translate into object logic, add universal closure,
    1.15 -  * monomorphize (create instances of schematic rules),
    1.16 -  * lift lambda terms,
    1.17 -  * make applications explicit for functions with varying number of arguments.
    1.18 -  * add (hypothetical definitions for) missing datatype selectors,
    1.19 +Normalization steps on theorems required by SMT solvers.
    1.20  *)
    1.21  
    1.22  signature SMT_NORMALIZE =
    1.23  sig
    1.24 -  type extra_norm = bool -> (int * thm) list -> Proof.context ->
    1.25 -    (int * thm) list * Proof.context
    1.26 -  val normalize: extra_norm -> bool -> (int * thm) list -> Proof.context ->
    1.27 -    (int * thm) list * Proof.context
    1.28    val atomize_conv: Proof.context -> conv
    1.29 -  val eta_expand_conv: (Proof.context -> conv) -> Proof.context -> conv
    1.30 +  type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list
    1.31 +  val add_extra_norm: SMT_Utils.class * extra_norm -> Context.generic ->
    1.32 +    Context.generic
    1.33 +  val normalize: Proof.context -> (int * (int option * thm)) list ->
    1.34 +    (int * thm) list
    1.35    val setup: theory -> theory
    1.36  end
    1.37  
    1.38 @@ -32,12 +21,10 @@
    1.39  structure U = SMT_Utils
    1.40  structure B = SMT_Builtin
    1.41  
    1.42 -infix 2 ??
    1.43 -fun (test ?? f) x = if test x then f x else x
    1.44  
    1.45 -
    1.46 +(* general theorem normalizations *)
    1.47  
    1.48 -(* instantiate elimination rules *)
    1.49 +(** instantiate elimination rules **)
    1.50   
    1.51  local
    1.52    val (cpfalse, cfalse) = `U.mk_cprop (Thm.cterm_of @{theory} @{const False})
    1.53 @@ -56,281 +43,18 @@
    1.54  end
    1.55  
    1.56  
    1.57 -
    1.58 -(* simplification of trivial distincts (distinct should have at least
    1.59 -   three elements in the argument list) *)
    1.60 -
    1.61 -local
    1.62 -  fun is_trivial_distinct (Const (@{const_name distinct}, _) $ t) =
    1.63 -        (case try HOLogic.dest_list t of
    1.64 -          SOME [] => true
    1.65 -        | SOME [_] => true
    1.66 -        | _ => false)
    1.67 -    | is_trivial_distinct _ = false
    1.68 -
    1.69 -  val thms = map mk_meta_eq @{lemma
    1.70 -    "distinct [] = True"
    1.71 -    "distinct [x] = True"
    1.72 -    "distinct [x, y] = (x ~= y)"
    1.73 -    by simp_all}
    1.74 -  fun distinct_conv _ =
    1.75 -    U.if_true_conv is_trivial_distinct (Conv.rewrs_conv thms)
    1.76 -in
    1.77 -fun trivial_distinct ctxt =
    1.78 -  map (apsnd ((Term.exists_subterm is_trivial_distinct o Thm.prop_of) ??
    1.79 -    Conv.fconv_rule (Conv.top_conv distinct_conv ctxt)))
    1.80 -end
    1.81 -
    1.82 -
    1.83 -
    1.84 -(* rewrite bool case expressions as if expressions *)
    1.85 -
    1.86 -local
    1.87 -  val is_bool_case = (fn
    1.88 -      Const (@{const_name "bool.bool_case"}, _) $ _ $ _ $ _ => true
    1.89 -    | _ => false)
    1.90 +(** normalize definitions **)
    1.91  
    1.92 -  val thm = mk_meta_eq @{lemma
    1.93 -    "(case P of True => x | False => y) = (if P then x else y)" by simp}
    1.94 -  val unfold_conv = U.if_true_conv is_bool_case (Conv.rewr_conv thm)
    1.95 -in
    1.96 -fun rewrite_bool_cases ctxt =
    1.97 -  map (apsnd ((Term.exists_subterm is_bool_case o Thm.prop_of) ??
    1.98 -    Conv.fconv_rule (Conv.top_conv (K unfold_conv) ctxt)))
    1.99 -
   1.100 -val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"}
   1.101 -
   1.102 -end
   1.103 -
   1.104 -
   1.105 -
   1.106 -(* normalization of numerals: rewriting of negative integer numerals into
   1.107 -   positive numerals, Numeral0 into 0, Numeral1 into 1 *)
   1.108 -
   1.109 -local
   1.110 -  fun is_number_sort ctxt T =
   1.111 -    Sign.of_sort (ProofContext.theory_of ctxt) (T, @{sort number_ring})
   1.112 -
   1.113 -  fun is_strange_number ctxt (t as Const (@{const_name number_of}, _) $ _) =
   1.114 -        (case try HOLogic.dest_number t of
   1.115 -          SOME (T, i) => is_number_sort ctxt T andalso i < 2
   1.116 -        | NONE => false)
   1.117 -    | is_strange_number _ _ = false
   1.118 -
   1.119 -  val pos_numeral_ss = HOL_ss
   1.120 -    addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}]
   1.121 -    addsimps [@{thm Int.number_of_Pls}, @{thm Int.numeral_1_eq_1}]
   1.122 -    addsimps @{thms Int.pred_bin_simps}
   1.123 -    addsimps @{thms Int.normalize_bin_simps}
   1.124 -    addsimps @{lemma
   1.125 -      "Int.Min = - Int.Bit1 Int.Pls"
   1.126 -      "Int.Bit0 (- Int.Pls) = - Int.Pls"
   1.127 -      "Int.Bit0 (- k) = - Int.Bit0 k"
   1.128 -      "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)"
   1.129 -      by simp_all (simp add: pred_def)}
   1.130 -
   1.131 -  fun pos_conv ctxt = U.if_conv (is_strange_number ctxt)
   1.132 -    (Simplifier.rewrite (Simplifier.context ctxt pos_numeral_ss))
   1.133 -    Conv.no_conv
   1.134 -in
   1.135 -fun normalize_numerals ctxt =
   1.136 -  map (apsnd ((Term.exists_subterm (is_strange_number ctxt) o Thm.prop_of) ??
   1.137 -    Conv.fconv_rule (Conv.top_sweep_conv pos_conv ctxt)))
   1.138 -end
   1.139 -
   1.140 +fun norm_def thm =
   1.141 +  (case Thm.prop_of thm of
   1.142 +    @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ _ $ Abs _) =>
   1.143 +      norm_def (thm RS @{thm fun_cong})
   1.144 +  | Const (@{const_name "=="}, _) $ _ $ Abs _ =>
   1.145 +      norm_def (thm RS @{thm meta_eq_to_obj_eq})
   1.146 +  | _ => thm)
   1.147  
   1.148  
   1.149 -(* embedding of standard natural number operations into integer operations *)
   1.150 -
   1.151 -local
   1.152 -  val nat_embedding = map (pair ~1) @{lemma
   1.153 -    "nat (int n) = n"
   1.154 -    "i >= 0 --> int (nat i) = i"
   1.155 -    "i < 0 --> int (nat i) = 0"
   1.156 -    by simp_all}
   1.157 -
   1.158 -  val nat_rewriting = @{lemma
   1.159 -    "0 = nat 0"
   1.160 -    "1 = nat 1"
   1.161 -    "(number_of :: int => nat) = (%i. nat (number_of i))"
   1.162 -    "int (nat 0) = 0"
   1.163 -    "int (nat 1) = 1"
   1.164 -    "op < = (%a b. int a < int b)"
   1.165 -    "op <= = (%a b. int a <= int b)"
   1.166 -    "Suc = (%a. nat (int a + 1))"
   1.167 -    "op + = (%a b. nat (int a + int b))"
   1.168 -    "op - = (%a b. nat (int a - int b))"
   1.169 -    "op * = (%a b. nat (int a * int b))"
   1.170 -    "op div = (%a b. nat (int a div int b))"
   1.171 -    "op mod = (%a b. nat (int a mod int b))"
   1.172 -    "min = (%a b. nat (min (int a) (int b)))"
   1.173 -    "max = (%a b. nat (max (int a) (int b)))"
   1.174 -    "int (nat (int a + int b)) = int a + int b"
   1.175 -    "int (nat (int a + 1)) = int a + 1"  (* special rule due to Suc above *)
   1.176 -    "int (nat (int a * int b)) = int a * int b"
   1.177 -    "int (nat (int a div int b)) = int a div int b"
   1.178 -    "int (nat (int a mod int b)) = int a mod int b"
   1.179 -    "int (nat (min (int a) (int b))) = min (int a) (int b)"
   1.180 -    "int (nat (max (int a) (int b))) = max (int a) (int b)"
   1.181 -    by (auto intro!: ext simp add: nat_mult_distrib nat_div_distrib
   1.182 -      nat_mod_distrib int_mult[symmetric] zdiv_int[symmetric]
   1.183 -      zmod_int[symmetric])}
   1.184 -
   1.185 -  fun on_positive num f x = 
   1.186 -    (case try HOLogic.dest_number (Thm.term_of num) of
   1.187 -      SOME (_, i) => if i >= 0 then SOME (f x) else NONE
   1.188 -    | NONE => NONE)
   1.189 -
   1.190 -  val cancel_int_nat_ss = HOL_ss
   1.191 -    addsimps [@{thm Nat_Numeral.nat_number_of}]
   1.192 -    addsimps [@{thm Nat_Numeral.int_nat_number_of}]
   1.193 -    addsimps @{thms neg_simps}
   1.194 -
   1.195 -  val int_eq = Thm.cterm_of @{theory} @{const "==" (int)}
   1.196 -
   1.197 -  fun cancel_int_nat_simproc _ ss ct = 
   1.198 -    let
   1.199 -      val num = Thm.dest_arg (Thm.dest_arg ct)
   1.200 -      val goal = Thm.mk_binop int_eq ct num
   1.201 -      val simpset = Simplifier.inherit_context ss cancel_int_nat_ss
   1.202 -      fun tac _ = Simplifier.simp_tac simpset 1
   1.203 -    in on_positive num (Goal.prove_internal [] goal) tac end
   1.204 -
   1.205 -  val nat_ss = HOL_ss
   1.206 -    addsimps nat_rewriting
   1.207 -    addsimprocs [
   1.208 -      Simplifier.make_simproc {
   1.209 -        name = "cancel_int_nat_num", lhss = [@{cpat "int (nat _)"}],
   1.210 -        proc = cancel_int_nat_simproc, identifier = [] }]
   1.211 -
   1.212 -  fun conv ctxt = Simplifier.rewrite (Simplifier.context ctxt nat_ss)
   1.213 -
   1.214 -  val uses_nat_type = Term.exists_type (Term.exists_subtype (equal @{typ nat}))
   1.215 -  val uses_nat_int = Term.exists_subterm (member (op aconv)
   1.216 -    [@{const of_nat (int)}, @{const nat}])
   1.217 -
   1.218 -  val nat_ops = [
   1.219 -    @{const less (nat)}, @{const less_eq (nat)},
   1.220 -    @{const Suc}, @{const plus (nat)}, @{const minus (nat)},
   1.221 -    @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}]
   1.222 -  val nat_ops' = @{const of_nat (int)} :: @{const nat} :: nat_ops
   1.223 -in
   1.224 -fun nat_as_int ctxt =
   1.225 -  map (apsnd ((uses_nat_type o Thm.prop_of) ?? Conv.fconv_rule (conv ctxt))) #>
   1.226 -  exists (uses_nat_int o Thm.prop_of o snd) ?? append nat_embedding
   1.227 -
   1.228 -val setup_nat_as_int =
   1.229 -  B.add_builtin_typ_ext (@{typ nat}, K true) #>
   1.230 -  fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops'
   1.231 -end
   1.232 -
   1.233 -
   1.234 -
   1.235 -(* further normalizations: beta/eta, universal closure, atomize *)
   1.236 -
   1.237 -val eta_expand_eq = @{lemma "f == (%x. f x)" by (rule reflexive)}
   1.238 -
   1.239 -fun eta_expand_conv cv ctxt =
   1.240 -  Conv.rewr_conv eta_expand_eq then_conv Conv.abs_conv (cv o snd) ctxt
   1.241 -
   1.242 -local
   1.243 -  val eta_conv = eta_expand_conv
   1.244 -
   1.245 -  fun args_conv cv ct =
   1.246 -    (case Thm.term_of ct of
   1.247 -      _ $ _ => Conv.combination_conv (args_conv cv) cv
   1.248 -    | _ => Conv.all_conv) ct
   1.249 -
   1.250 -  fun eta_args_conv cv 0 = args_conv o cv
   1.251 -    | eta_args_conv cv i = eta_conv (eta_args_conv cv (i-1))
   1.252 -
   1.253 -  fun keep_conv ctxt = Conv.binder_conv (norm_conv o snd) ctxt
   1.254 -  and eta_binder_conv ctxt = Conv.arg_conv (eta_conv norm_conv ctxt)
   1.255 -  and keep_let_conv ctxt = Conv.combination_conv
   1.256 -    (Conv.arg_conv (norm_conv ctxt)) (Conv.abs_conv (norm_conv o snd) ctxt)
   1.257 -  and unfold_let_conv ctxt = Conv.combination_conv
   1.258 -    (Conv.arg_conv (norm_conv ctxt)) (eta_conv norm_conv ctxt)
   1.259 -  and unfold_conv thm ctxt = Conv.rewr_conv thm then_conv keep_conv ctxt
   1.260 -  and unfold_ex1_conv ctxt = unfold_conv @{thm Ex1_def} ctxt
   1.261 -  and unfold_ball_conv ctxt = unfold_conv (mk_meta_eq @{thm Ball_def}) ctxt
   1.262 -  and unfold_bex_conv ctxt = unfold_conv (mk_meta_eq @{thm Bex_def}) ctxt
   1.263 -  and norm_conv ctxt ct =
   1.264 -    (case Thm.term_of ct of
   1.265 -      Const (@{const_name All}, _) $ Abs _ => keep_conv
   1.266 -    | Const (@{const_name All}, _) $ _ => eta_binder_conv
   1.267 -    | Const (@{const_name All}, _) => eta_conv eta_binder_conv
   1.268 -    | Const (@{const_name Ex}, _) $ Abs _ => keep_conv
   1.269 -    | Const (@{const_name Ex}, _) $ _ => eta_binder_conv
   1.270 -    | Const (@{const_name Ex}, _) => eta_conv eta_binder_conv
   1.271 -    | Const (@{const_name Let}, _) $ _ $ Abs _ => keep_let_conv
   1.272 -    | Const (@{const_name Let}, _) $ _ $ _ => unfold_let_conv
   1.273 -    | Const (@{const_name Let}, _) $ _ => eta_conv unfold_let_conv
   1.274 -    | Const (@{const_name Let}, _) => eta_conv (eta_conv unfold_let_conv)
   1.275 -    | Const (@{const_name Ex1}, _) $ _ => unfold_ex1_conv
   1.276 -    | Const (@{const_name Ex1}, _) => eta_conv unfold_ex1_conv 
   1.277 -    | Const (@{const_name Ball}, _) $ _ $ _ => unfold_ball_conv
   1.278 -    | Const (@{const_name Ball}, _) $ _ => eta_conv unfold_ball_conv
   1.279 -    | Const (@{const_name Ball}, _) => eta_conv (eta_conv unfold_ball_conv)
   1.280 -    | Const (@{const_name Bex}, _) $ _ $ _ => unfold_bex_conv
   1.281 -    | Const (@{const_name Bex}, _) $ _ => eta_conv unfold_bex_conv
   1.282 -    | Const (@{const_name Bex}, _) => eta_conv (eta_conv unfold_bex_conv)
   1.283 -    | Abs _ => Conv.abs_conv (norm_conv o snd)
   1.284 -    | _ =>
   1.285 -        (case Term.strip_comb (Thm.term_of ct) of
   1.286 -          (Const (c as (_, T)), ts) =>
   1.287 -            if SMT_Builtin.is_builtin_fun ctxt c ts
   1.288 -            then eta_args_conv norm_conv
   1.289 -              (length (Term.binder_types T) - length ts)
   1.290 -            else args_conv o norm_conv
   1.291 -        | _ => args_conv o norm_conv)) ctxt ct
   1.292 -
   1.293 -  fun is_normed ctxt t =
   1.294 -    (case t of
   1.295 -      Const (@{const_name All}, _) $ Abs (_, _, u) => is_normed ctxt u
   1.296 -    | Const (@{const_name All}, _) $ _ => false
   1.297 -    | Const (@{const_name All}, _) => false
   1.298 -    | Const (@{const_name Ex}, _) $ Abs (_, _, u) => is_normed ctxt u
   1.299 -    | Const (@{const_name Ex}, _) $ _ => false
   1.300 -    | Const (@{const_name Ex}, _) => false
   1.301 -    | Const (@{const_name Let}, _) $ u1 $ Abs (_, _, u2) =>
   1.302 -        is_normed ctxt u1 andalso is_normed ctxt u2
   1.303 -    | Const (@{const_name Let}, _) $ _ $ _ => false
   1.304 -    | Const (@{const_name Let}, _) $ _ => false
   1.305 -    | Const (@{const_name Let}, _) => false
   1.306 -    | Const (@{const_name Ex1}, _) $ _ => false
   1.307 -    | Const (@{const_name Ex1}, _) => false
   1.308 -    | Const (@{const_name Ball}, _) $ _ $ _ => false
   1.309 -    | Const (@{const_name Ball}, _) $ _ => false
   1.310 -    | Const (@{const_name Ball}, _) => false
   1.311 -    | Const (@{const_name Bex}, _) $ _ $ _ => false
   1.312 -    | Const (@{const_name Bex}, _) $ _ => false
   1.313 -    | Const (@{const_name Bex}, _) => false
   1.314 -    | Abs (_, _, u) => is_normed ctxt u
   1.315 -    | _ =>
   1.316 -        (case Term.strip_comb t of
   1.317 -          (Const (c as (_, T)), ts) =>
   1.318 -            if SMT_Builtin.is_builtin_fun ctxt c ts
   1.319 -            then length (Term.binder_types T) = length ts andalso
   1.320 -              forall (is_normed ctxt) ts
   1.321 -            else forall (is_normed ctxt) ts
   1.322 -        | (_, ts) => forall (is_normed ctxt) ts))
   1.323 -in
   1.324 -fun norm_binder_conv ctxt =
   1.325 -  U.if_conv (is_normed ctxt) Conv.all_conv (norm_conv ctxt)
   1.326 -
   1.327 -val setup_unfolded_quants =
   1.328 -  fold B.add_builtin_fun_ext'' [@{const_name Ball}, @{const_name Bex},
   1.329 -    @{const_name Ex1}]
   1.330 -
   1.331 -end
   1.332 -
   1.333 -fun norm_def ctxt thm =
   1.334 -  (case Thm.prop_of thm of
   1.335 -    @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ _ $ Abs _) =>
   1.336 -      norm_def ctxt (thm RS @{thm fun_cong})
   1.337 -  | Const (@{const_name "=="}, _) $ _ $ Abs _ =>
   1.338 -      norm_def ctxt (thm RS @{thm meta_eq_to_obj_eq})
   1.339 -  | _ => thm)
   1.340 +(** atomization **)
   1.341  
   1.342  fun atomize_conv ctxt ct =
   1.343    (case Thm.term_of ct of
   1.344 @@ -349,243 +73,543 @@
   1.345    fold B.add_builtin_fun_ext'' [@{const_name "==>"}, @{const_name "=="},
   1.346      @{const_name all}, @{const_name Trueprop}]
   1.347  
   1.348 -fun normalize_rule ctxt =
   1.349 -  Conv.fconv_rule (
   1.350 -    (* reduce lambda abstractions, except at known binders: *)
   1.351 -    Thm.beta_conversion true then_conv
   1.352 -    Thm.eta_conversion then_conv
   1.353 -    norm_binder_conv ctxt) #>
   1.354 -  norm_def ctxt #>
   1.355 -  Drule.forall_intr_vars #>
   1.356 -  Conv.fconv_rule (atomize_conv ctxt)
   1.357  
   1.358 -
   1.359 -
   1.360 -(* lift lambda terms into additional rules *)
   1.361 +(** unfold special quantifiers **)
   1.362  
   1.363  local
   1.364 -  fun used_vars cvs ct =
   1.365 -    let
   1.366 -      val lookup = AList.lookup (op aconv) (map (` Thm.term_of) cvs)
   1.367 -      val add = (fn SOME ct => insert (op aconvc) ct | _ => I)
   1.368 -    in Term.fold_aterms (add o lookup) (Thm.term_of ct) [] end
   1.369 -
   1.370 -  fun apply cv thm = 
   1.371 -    let val thm' = Thm.combination thm (Thm.reflexive cv)
   1.372 -    in Thm.transitive thm' (Thm.beta_conversion false (Thm.rhs_of thm')) end
   1.373 -  fun apply_def cvs eq = Thm.symmetric (fold apply cvs eq)
   1.374 +  val ex1_def = mk_meta_eq @{lemma
   1.375 +    "Ex1 = (%P. EX x. P x & (ALL y. P y --> y = x))"
   1.376 +    by (rule ext) (simp only: Ex1_def)}
   1.377  
   1.378 -  fun replace_lambda cvs ct (cx as (ctxt, defs)) =
   1.379 -    let
   1.380 -      val cvs' = used_vars cvs ct
   1.381 -      val ct' = fold_rev Thm.cabs cvs' ct
   1.382 -    in
   1.383 -      (case Termtab.lookup defs (Thm.term_of ct') of
   1.384 -        SOME eq => (apply_def cvs' eq, cx)
   1.385 -      | NONE =>
   1.386 -          let
   1.387 -            val {T, ...} = Thm.rep_cterm ct' and n = Name.uu
   1.388 -            val (n', ctxt') = yield_singleton Variable.variant_fixes n ctxt
   1.389 -            val cu = U.mk_cequals (U.certify ctxt (Free (n', T))) ct'
   1.390 -            val (eq, ctxt'') = yield_singleton Assumption.add_assumes cu ctxt'
   1.391 -            val defs' = Termtab.update (Thm.term_of ct', eq) defs
   1.392 -          in (apply_def cvs' eq, (ctxt'', defs')) end)
   1.393 -    end
   1.394 +  val ball_def = mk_meta_eq @{lemma "Ball = (%A P. ALL x. x : A --> P x)"
   1.395 +    by (rule ext)+ (rule Ball_def)}
   1.396 +
   1.397 +  val bex_def = mk_meta_eq @{lemma "Bex = (%A P. EX x. x : A & P x)"
   1.398 +    by (rule ext)+ (rule Bex_def)}
   1.399  
   1.400 -  fun none ct cx = (Thm.reflexive ct, cx)
   1.401 -  fun in_comb f g ct cx =
   1.402 -    let val (cu1, cu2) = Thm.dest_comb ct
   1.403 -    in cx |> f cu1 ||>> g cu2 |>> uncurry Thm.combination end
   1.404 -  fun in_arg f = in_comb none f
   1.405 -  fun in_abs f cvs ct (ctxt, defs) =
   1.406 -    let
   1.407 -      val (n, ctxt') = yield_singleton Variable.variant_fixes Name.uu ctxt
   1.408 -      val (cv, cu) = Thm.dest_abs (SOME n) ct
   1.409 -    in  (ctxt', defs) |> f (cv :: cvs) cu |>> Thm.abstract_rule n cv end
   1.410 -
   1.411 -  fun traverse cvs ct =
   1.412 -    (case Thm.term_of ct of
   1.413 -      Const (@{const_name All}, _) $ Abs _ => in_arg (in_abs traverse cvs)
   1.414 -    | Const (@{const_name Ex}, _) $ Abs _ => in_arg (in_abs traverse cvs)
   1.415 -    | Const (@{const_name Let}, _) $ _ $ Abs _ =>
   1.416 -        in_comb (in_arg (traverse cvs)) (in_abs traverse cvs)
   1.417 -    | Abs _ => at_lambda cvs
   1.418 -    | _ $ _ => in_comb (traverse cvs) (traverse cvs)
   1.419 -    | _ => none) ct
   1.420 +  val special_quants = [(@{const_name Ex1}, ex1_def),
   1.421 +    (@{const_name Ball}, ball_def), (@{const_name Bex}, bex_def)]
   1.422 +  
   1.423 +  fun special_quant (Const (n, _)) = AList.lookup (op =) special_quants n
   1.424 +    | special_quant _ = NONE
   1.425  
   1.426 -  and at_lambda cvs ct =
   1.427 -    in_abs traverse cvs ct #-> (fn thm =>
   1.428 -    replace_lambda cvs (Thm.rhs_of thm) #>> Thm.transitive thm)
   1.429 +  fun special_quant_conv _ ct =
   1.430 +    (case special_quant (Thm.term_of ct) of
   1.431 +      SOME thm => Conv.rewr_conv thm
   1.432 +    | NONE => Conv.all_conv) ct
   1.433 +in
   1.434  
   1.435 -  fun has_free_lambdas t =
   1.436 -    (case t of
   1.437 -      Const (@{const_name All}, _) $ Abs (_, _, u) => has_free_lambdas u
   1.438 -    | Const (@{const_name Ex}, _) $ Abs (_, _, u) => has_free_lambdas u
   1.439 -    | Const (@{const_name Let}, _) $ u1 $ Abs (_, _, u2) =>
   1.440 -        has_free_lambdas u1 orelse has_free_lambdas u2
   1.441 -    | Abs _ => true
   1.442 -    | u1 $ u2 => has_free_lambdas u1 orelse has_free_lambdas u2
   1.443 -    | _ => false)
   1.444 +fun unfold_special_quants_conv ctxt =
   1.445 +  U.if_exists_conv (is_some o special_quant)
   1.446 +    (Conv.top_conv special_quant_conv ctxt)
   1.447  
   1.448 -  fun lift_lm f thm cx =
   1.449 -    if not (has_free_lambdas (Thm.prop_of thm)) then (thm, cx)
   1.450 -    else cx |> f (Thm.cprop_of thm) |>> (fn thm' => Thm.equal_elim thm' thm)
   1.451 -in
   1.452 -fun lift_lambdas irules ctxt =
   1.453 -  let
   1.454 -    val cx = (ctxt, Termtab.empty)
   1.455 -    val (idxs, thms) = split_list irules
   1.456 -    val (thms', (ctxt', defs)) = fold_map (lift_lm (traverse [])) thms cx
   1.457 -    val eqs = Termtab.fold (cons o normalize_rule ctxt' o snd) defs []
   1.458 -  in (map (pair ~1) eqs @ (idxs ~~ thms'), ctxt') end
   1.459 +val setup_unfolded_quants = fold (B.add_builtin_fun_ext'' o fst) special_quants
   1.460 +
   1.461  end
   1.462  
   1.463  
   1.464 -
   1.465 -(* make application explicit for functions with varying number of arguments *)
   1.466 +(** trigger inference **)
   1.467  
   1.468  local
   1.469 -  val const = prefix "c" and free = prefix "f"
   1.470 -  fun min i (e as (_, j)) = if i <> j then (true, Int.min (i, j)) else e
   1.471 -  fun add t i = Symtab.map_default (t, (false, i)) (min i)
   1.472 -  fun traverse t =
   1.473 +  (*** check trigger syntax ***)
   1.474 +
   1.475 +  fun dest_trigger (Const (@{const_name pat}, _) $ _) = SOME true
   1.476 +    | dest_trigger (Const (@{const_name nopat}, _) $ _) = SOME false
   1.477 +    | dest_trigger _ = NONE
   1.478 +
   1.479 +  fun eq_list [] = false
   1.480 +    | eq_list (b :: bs) = forall (equal b) bs
   1.481 +
   1.482 +  fun proper_trigger t =
   1.483 +    t
   1.484 +    |> these o try HOLogic.dest_list
   1.485 +    |> map (map_filter dest_trigger o these o try HOLogic.dest_list)
   1.486 +    |> (fn [] => false | bss => forall eq_list bss)
   1.487 +
   1.488 +  fun proper_quant inside f t =
   1.489 +    (case t of
   1.490 +      Const (@{const_name All}, _) $ Abs (_, _, u) => proper_quant true f u
   1.491 +    | Const (@{const_name Ex}, _) $ Abs (_, _, u) => proper_quant true f u
   1.492 +    | @{const trigger} $ p $ u =>
   1.493 +        (if inside then f p else false) andalso proper_quant false f u
   1.494 +    | Abs (_, _, u) => proper_quant false f u
   1.495 +    | u1 $ u2 => proper_quant false f u1 andalso proper_quant false f u2
   1.496 +    | _ => true)
   1.497 +
   1.498 +  fun check_trigger_error ctxt t =
   1.499 +    error ("SMT triggers must only occur under quantifier and multipatterns " ^
   1.500 +      "must have the same kind: " ^ Syntax.string_of_term ctxt t)
   1.501 +
   1.502 +  fun check_trigger_conv ctxt ct =
   1.503 +    if proper_quant false proper_trigger (Thm.term_of ct) then Conv.all_conv ct
   1.504 +    else check_trigger_error ctxt (Thm.term_of ct)
   1.505 +
   1.506 +
   1.507 +  (*** infer simple triggers ***)
   1.508 +
   1.509 +  fun dest_cond_eq ct =
   1.510 +    (case Thm.term_of ct of
   1.511 +      Const (@{const_name HOL.eq}, _) $ _ $ _ => Thm.dest_binop ct
   1.512 +    | @{const HOL.implies} $ _ $ _ => dest_cond_eq (Thm.dest_arg ct)
   1.513 +    | _ => raise CTERM ("no equation", [ct]))
   1.514 +
   1.515 +  fun get_constrs thy (Type (n, _)) = these (Datatype.get_constrs thy n)
   1.516 +    | get_constrs _ _ = []
   1.517 +
   1.518 +  fun is_constr thy (n, T) =
   1.519 +    let fun match (m, U) = m = n andalso Sign.typ_instance thy (T, U)
   1.520 +    in can (the o find_first match o get_constrs thy o Term.body_type) T end
   1.521 +
   1.522 +  fun is_constr_pat thy t =
   1.523 +    (case Term.strip_comb t of
   1.524 +      (Free _, []) => true
   1.525 +    | (Const c, ts) => is_constr thy c andalso forall (is_constr_pat thy) ts
   1.526 +    | _ => false)
   1.527 +
   1.528 +  fun is_simp_lhs ctxt t =
   1.529      (case Term.strip_comb t of
   1.530 -      (Const (n, _), ts) => add (const n) (length ts) #> fold traverse ts 
   1.531 -    | (Free (n, _), ts) => add (free n) (length ts) #> fold traverse ts
   1.532 -    | (Abs (_, _, u), ts) => fold traverse (u :: ts)
   1.533 -    | (_, ts) => fold traverse ts)
   1.534 -  fun prune tab = Symtab.fold (fn (n, (true, i)) =>
   1.535 -    Symtab.update (n, i) | _ => I) tab Symtab.empty
   1.536 +      (Const c, ts as _ :: _) =>
   1.537 +        not (B.is_builtin_fun_ext ctxt c ts) andalso
   1.538 +        forall (is_constr_pat (ProofContext.theory_of ctxt)) ts
   1.539 +    | _ => false)
   1.540 +
   1.541 +  fun has_all_vars vs t =
   1.542 +    subset (op aconv) (vs, map Free (Term.add_frees t []))
   1.543 +
   1.544 +  fun minimal_pats vs ct =
   1.545 +    if has_all_vars vs (Thm.term_of ct) then
   1.546 +      (case Thm.term_of ct of
   1.547 +        _ $ _ =>
   1.548 +          (case pairself (minimal_pats vs) (Thm.dest_comb ct) of
   1.549 +            ([], []) => [[ct]]
   1.550 +          | (ctss, ctss') => union (eq_set (op aconvc)) ctss ctss')
   1.551 +      | _ => [[ct]])
   1.552 +    else []
   1.553 +
   1.554 +  fun proper_mpat _ _ _ [] = false
   1.555 +    | proper_mpat thy gen u cts =
   1.556 +        let
   1.557 +          val tps = (op ~~) (`gen (map Thm.term_of cts))
   1.558 +          fun some_match u = tps |> exists (fn (t', t) =>
   1.559 +            Pattern.matches thy (t', u) andalso not (t aconv u))
   1.560 +        in not (Term.exists_subterm some_match u) end
   1.561 +
   1.562 +  val pat = U.mk_const_pat @{theory} @{const_name SMT.pat} U.destT1
   1.563 +  fun mk_pat ct = Thm.capply (U.instT' ct pat) ct
   1.564  
   1.565 -  fun binop_conv cv1 cv2 = Conv.combination_conv (Conv.arg_conv cv1) cv2
   1.566 -  fun nary_conv conv1 conv2 ct =
   1.567 -    (Conv.combination_conv (nary_conv conv1 conv2) conv2 else_conv conv1) ct
   1.568 -  fun abs_conv conv tb = Conv.abs_conv (fn (cv, cx) =>
   1.569 -    let val n = fst (Term.dest_Free (Thm.term_of cv))
   1.570 -    in conv (Symtab.update (free n, 0) tb) cx end)
   1.571 -  val fun_app_rule = @{lemma "f x == fun_app f x" by (simp add: fun_app_def)}
   1.572 +  fun mk_clist T = pairself (Thm.cterm_of @{theory})
   1.573 +    (HOLogic.cons_const T, HOLogic.nil_const T)
   1.574 +  fun mk_list (ccons, cnil) f cts = fold_rev (Thm.mk_binop ccons o f) cts cnil
   1.575 +  val mk_pat_list = mk_list (mk_clist @{typ SMT.pattern})
   1.576 +  val mk_mpat_list = mk_list (mk_clist @{typ "SMT.pattern list"})  
   1.577 +  fun mk_trigger ctss = mk_mpat_list (mk_pat_list mk_pat) ctss
   1.578 +
   1.579 +  val trigger_eq =
   1.580 +    mk_meta_eq @{lemma "p = SMT.trigger t p" by (simp add: trigger_def)}
   1.581 +
   1.582 +  fun insert_trigger_conv [] ct = Conv.all_conv ct
   1.583 +    | insert_trigger_conv ctss ct =
   1.584 +        let val (ctr, cp) = Thm.dest_binop (Thm.rhs_of trigger_eq) ||> rpair ct
   1.585 +        in Thm.instantiate ([], [cp, (ctr, mk_trigger ctss)]) trigger_eq end
   1.586 +
   1.587 +  fun infer_trigger_eq_conv outer_ctxt (ctxt, cvs) ct =
   1.588 +    let
   1.589 +      val (lhs, rhs) = dest_cond_eq ct
   1.590 +
   1.591 +      val vs = map Thm.term_of cvs
   1.592 +      val thy = ProofContext.theory_of ctxt
   1.593 +
   1.594 +      fun get_mpats ct =
   1.595 +        if is_simp_lhs ctxt (Thm.term_of ct) then minimal_pats vs ct
   1.596 +        else []
   1.597 +      val gen = Variable.export_terms ctxt outer_ctxt
   1.598 +      val filter_mpats = filter (proper_mpat thy gen (Thm.term_of rhs))
   1.599 +
   1.600 +    in insert_trigger_conv (filter_mpats (get_mpats lhs)) ct end
   1.601 +
   1.602 +  fun try_trigger_conv cv ct =
   1.603 +    if proper_quant false (K false) (Thm.term_of ct) then Conv.all_conv ct
   1.604 +    else Conv.try_conv cv ct
   1.605 +
   1.606 +  fun infer_trigger_conv ctxt =
   1.607 +    if Config.get ctxt SMT_Config.infer_triggers then
   1.608 +      try_trigger_conv (U.under_quant_conv (infer_trigger_eq_conv ctxt) ctxt)
   1.609 +    else Conv.all_conv
   1.610  in
   1.611 -fun explicit_application ctxt irules =
   1.612 -  let
   1.613 -    fun sub_conv tb ctxt ct =
   1.614 -      (case Term.strip_comb (Thm.term_of ct) of
   1.615 -        (Const (n, _), ts) => app_conv tb (const n) (length ts) ctxt
   1.616 -      | (Free (n, _), ts) => app_conv tb (free n) (length ts) ctxt
   1.617 -      | (Abs _, _) => nary_conv (abs_conv sub_conv tb ctxt) (sub_conv tb ctxt)
   1.618 -      | (_, _) => nary_conv Conv.all_conv (sub_conv tb ctxt)) ct
   1.619 -    and app_conv tb n i ctxt =
   1.620 -      (case Symtab.lookup tb n of
   1.621 -        NONE => nary_conv Conv.all_conv (sub_conv tb ctxt)
   1.622 -      | SOME j => fun_app_conv tb ctxt (i - j))
   1.623 -    and fun_app_conv tb ctxt i ct = (
   1.624 -      if i = 0 then nary_conv Conv.all_conv (sub_conv tb ctxt)
   1.625 -      else
   1.626 -        Conv.rewr_conv fun_app_rule then_conv
   1.627 -        binop_conv (fun_app_conv tb ctxt (i-1)) (sub_conv tb ctxt)) ct
   1.628 +
   1.629 +fun trigger_conv ctxt =
   1.630 +  U.prop_conv (check_trigger_conv ctxt then_conv infer_trigger_conv ctxt)
   1.631  
   1.632 -    fun needs_exp_app tab = Term.exists_subterm (fn
   1.633 -        Bound _ $ _ => true
   1.634 -      | Const (n, _) => Symtab.defined tab (const n)
   1.635 -      | Free (n, _) => Symtab.defined tab (free n)
   1.636 -      | _ => false)
   1.637 +val setup_trigger = fold B.add_builtin_fun_ext''
   1.638 +  [@{const_name SMT.pat}, @{const_name SMT.nopat}, @{const_name SMT.trigger}]
   1.639  
   1.640 -    fun rewrite tab ctxt thm =
   1.641 -      if not (needs_exp_app tab (Thm.prop_of thm)) then thm
   1.642 -      else Conv.fconv_rule (sub_conv tab ctxt) thm
   1.643 -
   1.644 -    val tab = prune (fold (traverse o Thm.prop_of o snd) irules Symtab.empty)
   1.645 -  in map (apsnd (rewrite tab ctxt)) irules end
   1.646  end
   1.647  
   1.648  
   1.649 -
   1.650 -(* add missing datatype selectors via hypothetical definitions *)
   1.651 +(** adding quantifier weights **)
   1.652  
   1.653  local
   1.654 -  val add = (fn Type (n, _) => Symtab.update (n, ()) | _ => I)
   1.655 +  (*** check weight syntax ***)
   1.656 +
   1.657 +  val has_no_weight =
   1.658 +    not o Term.exists_subterm (fn @{const SMT.weight} => true | _ => false)
   1.659  
   1.660 -  fun collect t =
   1.661 -    (case Term.strip_comb t of
   1.662 -      (Abs (_, T, t), _) => add T #> collect t
   1.663 -    | (Const (_, T), ts) => collects T ts
   1.664 -    | (Free (_, T), ts) => collects T ts
   1.665 -    | _ => I)
   1.666 -  and collects T ts =
   1.667 -    let val ((Ts, Us), U) = Term.strip_type T |> apfst (chop (length ts))
   1.668 -    in fold add Ts #> add (Us ---> U) #> fold collect ts end
   1.669 +  fun is_weight (@{const SMT.weight} $ w $ t) =
   1.670 +        (case try HOLogic.dest_number w of
   1.671 +          SOME (_, i) => i > 0 andalso has_no_weight t
   1.672 +        | _ => false)
   1.673 +    | is_weight t = has_no_weight t
   1.674 +
   1.675 +  fun proper_trigger (@{const SMT.trigger} $ _ $ t) = is_weight t
   1.676 +    | proper_trigger t = has_no_weight t
   1.677 +
   1.678 +  fun check_weight_error ctxt t =
   1.679 +    error ("SMT weight must be a positive number and must only occur " ^
   1.680 +      "under the top-most quantifier and an optional trigger: " ^
   1.681 +      Syntax.string_of_term ctxt t)
   1.682  
   1.683 -  fun add_constructors thy n =
   1.684 -    (case Datatype.get_info thy n of
   1.685 -      NONE => I
   1.686 -    | SOME {descr, ...} => fold (fn (_, (_, _, cs)) => fold (fn (n, ds) =>
   1.687 -        fold (insert (op =) o pair n) (1 upto length ds)) cs) descr)
   1.688 +  fun check_weight_conv ctxt ct =
   1.689 +    if U.under_quant proper_trigger (Thm.term_of ct) then Conv.all_conv ct
   1.690 +    else check_weight_error ctxt (Thm.term_of ct)
   1.691 +
   1.692 +
   1.693 +  (*** insertion of weights ***)
   1.694 +
   1.695 +  fun under_trigger_conv cv ct =
   1.696 +    (case Thm.term_of ct of
   1.697 +      @{const SMT.trigger} $ _ $ _ => Conv.arg_conv cv
   1.698 +    | _ => cv) ct
   1.699  
   1.700 -  fun add_selector (c as (n, i)) ctxt =
   1.701 -    (case Datatype_Selectors.lookup_selector ctxt c of
   1.702 -      SOME _ => ctxt
   1.703 -    | NONE =>
   1.704 -        let
   1.705 -          val T = Sign.the_const_type (ProofContext.theory_of ctxt) n
   1.706 -          val U = Term.body_type T --> nth (Term.binder_types T) (i-1)
   1.707 -        in
   1.708 -          ctxt
   1.709 -          |> yield_singleton Variable.variant_fixes Name.uu
   1.710 -          |>> pair ((n, T), i) o rpair U
   1.711 -          |-> Context.proof_map o Datatype_Selectors.add_selector
   1.712 -        end)
   1.713 +  val weight_eq =
   1.714 +    mk_meta_eq @{lemma "p = SMT.weight i p" by (simp add: weight_def)}
   1.715 +  fun mk_weight_eq w =
   1.716 +    let val cv = Thm.dest_arg1 (Thm.rhs_of weight_eq)
   1.717 +    in
   1.718 +      Thm.instantiate ([], [(cv, Numeral.mk_cnumber @{ctyp int} w)]) weight_eq
   1.719 +    end
   1.720 +
   1.721 +  fun add_weight_conv NONE _ = Conv.all_conv
   1.722 +    | add_weight_conv (SOME weight) ctxt =
   1.723 +        let val cv = Conv.rewr_conv (mk_weight_eq weight)
   1.724 +        in U.under_quant_conv (K (under_trigger_conv cv)) ctxt end
   1.725  in
   1.726  
   1.727 -fun datatype_selectors irules ctxt =
   1.728 -  let
   1.729 -    val ns = Symtab.keys (fold (collect o Thm.prop_of o snd) irules Symtab.empty)
   1.730 -    val cs = fold (add_constructors (ProofContext.theory_of ctxt)) ns []
   1.731 -  in (irules, fold add_selector cs ctxt) end
   1.732 -    (* FIXME: also generate hypothetical definitions for the selectors *)
   1.733 +fun weight_conv weight ctxt = 
   1.734 +  U.prop_conv (check_weight_conv ctxt then_conv add_weight_conv weight ctxt)
   1.735 +
   1.736 +val setup_weight = B.add_builtin_fun_ext'' @{const_name SMT.weight}
   1.737  
   1.738  end
   1.739  
   1.740  
   1.741 -
   1.742 -(* combined normalization *)
   1.743 +(** combined general normalizations **)
   1.744  
   1.745 -type extra_norm = bool -> (int * thm) list -> Proof.context ->
   1.746 -  (int * thm) list * Proof.context
   1.747 -
   1.748 -fun with_context f irules ctxt = (f ctxt irules, ctxt)
   1.749 +fun gen_normalize1_conv ctxt weight =
   1.750 +  atomize_conv ctxt then_conv
   1.751 +  unfold_special_quants_conv ctxt then_conv
   1.752 +  trigger_conv ctxt then_conv
   1.753 +  weight_conv weight ctxt
   1.754  
   1.755 -fun normalize extra_norm with_datatypes irules ctxt =
   1.756 -  let
   1.757 -    fun norm f ctxt' (i, thm) =
   1.758 -      if Config.get ctxt' SMT_Config.drop_bad_facts then
   1.759 -        (case try (f ctxt') thm of
   1.760 -          SOME thm' => SOME (i, thm')
   1.761 -        | NONE => (SMT_Config.verbose_msg ctxt' (prefix ("Warning: " ^
   1.762 -            "dropping assumption: ") o Display.string_of_thm ctxt') thm; NONE))
   1.763 -      else SOME (i, f ctxt' thm)
   1.764 -  in
   1.765 -    irules
   1.766 -    |> map (apsnd instantiate_elim)
   1.767 -    |> trivial_distinct ctxt
   1.768 -    |> rewrite_bool_cases ctxt
   1.769 -    |> normalize_numerals ctxt
   1.770 -    |> nat_as_int ctxt
   1.771 -    |> rpair ctxt
   1.772 -    |-> extra_norm with_datatypes
   1.773 -    |-> with_context (map_filter o norm normalize_rule)
   1.774 -    |-> SMT_Monomorph.monomorph
   1.775 -    |-> lift_lambdas
   1.776 -    |-> with_context explicit_application
   1.777 -    |-> (if with_datatypes then datatype_selectors else pair)
   1.778 -  end
   1.779 +fun gen_normalize1 ctxt weight thm =
   1.780 +  thm
   1.781 +  |> instantiate_elim
   1.782 +  |> norm_def
   1.783 +  |> Conv.fconv_rule (Thm.beta_conversion true then_conv Thm.eta_conversion)
   1.784 +  |> Drule.forall_intr_vars
   1.785 +  |> Conv.fconv_rule (gen_normalize1_conv ctxt weight)
   1.786 +
   1.787 +fun drop_fact_warning ctxt =
   1.788 +  let val pre = prefix "Warning: dropping assumption: "
   1.789 +  in SMT_Config.verbose_msg ctxt (pre o Display.string_of_thm ctxt) end
   1.790 +
   1.791 +fun gen_norm1_safe ctxt (i, (weight, thm)) =
   1.792 +  if Config.get ctxt SMT_Config.drop_bad_facts then
   1.793 +    (case try (gen_normalize1 ctxt weight) thm of
   1.794 +      SOME thm' => SOME (i, thm')
   1.795 +    | NONE => (drop_fact_warning ctxt thm; NONE))
   1.796 +  else SOME (i, gen_normalize1 ctxt weight thm)
   1.797 +
   1.798 +fun gen_normalize ctxt iwthms = map_filter (gen_norm1_safe ctxt) iwthms
   1.799  
   1.800  
   1.801  
   1.802 -(* setup *)
   1.803 +(* unfolding of definitions and theory-specific rewritings *)
   1.804 +
   1.805 +(** unfold trivial distincts **)
   1.806 +
   1.807 +local
   1.808 +  fun is_trivial_distinct (Const (@{const_name distinct}, _) $ t) =
   1.809 +        (case try HOLogic.dest_list t of
   1.810 +          SOME [] => true
   1.811 +        | SOME [_] => true
   1.812 +        | _ => false)
   1.813 +    | is_trivial_distinct _ = false
   1.814 +
   1.815 +  val thms = map mk_meta_eq @{lemma
   1.816 +    "distinct [] = True"
   1.817 +    "distinct [x] = True"
   1.818 +    "distinct [x, y] = (x ~= y)"
   1.819 +    by simp_all}
   1.820 +  fun distinct_conv _ =
   1.821 +    U.if_true_conv is_trivial_distinct (Conv.rewrs_conv thms)
   1.822 +in
   1.823 +
   1.824 +fun trivial_distinct_conv ctxt = U.if_exists_conv is_trivial_distinct
   1.825 +  (Conv.top_conv distinct_conv ctxt)
   1.826 +
   1.827 +end
   1.828 +
   1.829 +
   1.830 +(** rewrite bool case expressions as if expressions **)
   1.831 +
   1.832 +local
   1.833 +  fun is_bool_case (Const (@{const_name "bool.bool_case"}, _)) = true
   1.834 +    | is_bool_case _ = false
   1.835 +
   1.836 +  val thm = mk_meta_eq @{lemma
   1.837 +    "bool_case = (%x y P. if P then x else y)" by (rule ext)+ simp}
   1.838 +
   1.839 +  fun unfold_conv _ = U.if_true_conv is_bool_case (Conv.rewr_conv thm)
   1.840 +in
   1.841 +
   1.842 +fun rewrite_bool_case_conv ctxt = U.if_exists_conv is_bool_case
   1.843 +  (Conv.top_conv unfold_conv ctxt)
   1.844 +
   1.845 +val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"}
   1.846 +
   1.847 +end
   1.848 +
   1.849 +
   1.850 +(** unfold abs, min and max **)
   1.851 +
   1.852 +local
   1.853 +  val abs_def = mk_meta_eq @{lemma
   1.854 +    "abs = (%a::'a::abs_if. if a < 0 then - a else a)"
   1.855 +    by (rule ext) (rule abs_if)}
   1.856 +
   1.857 +  val min_def = mk_meta_eq @{lemma "min = (%a b. if a <= b then a else b)"
   1.858 +    by (rule ext)+ (rule min_def)}
   1.859 +
   1.860 +  val max_def = mk_meta_eq  @{lemma "max = (%a b. if a <= b then b else a)"
   1.861 +    by (rule ext)+ (rule max_def)}
   1.862 +
   1.863 +  val defs = [(@{const_name min}, min_def), (@{const_name max}, max_def),
   1.864 +    (@{const_name abs}, abs_def)]
   1.865 +
   1.866 +  fun is_builtinT ctxt T = B.is_builtin_typ_ext ctxt (Term.domain_type T)
   1.867 +
   1.868 +  fun abs_min_max ctxt (Const (n, T)) =
   1.869 +        (case AList.lookup (op =) defs n of
   1.870 +          NONE => NONE
   1.871 +        | SOME thm => if is_builtinT ctxt T then SOME thm else NONE)
   1.872 +    | abs_min_max _ _ = NONE
   1.873 +
   1.874 +  fun unfold_amm_conv ctxt ct =
   1.875 +    (case abs_min_max ctxt (Thm.term_of ct) of
   1.876 +      SOME thm => Conv.rewr_conv thm
   1.877 +    | NONE => Conv.all_conv) ct
   1.878 +in
   1.879 +
   1.880 +fun unfold_abs_min_max_conv ctxt =
   1.881 +  U.if_exists_conv (is_some o abs_min_max ctxt)
   1.882 +    (Conv.top_conv unfold_amm_conv ctxt)
   1.883 +  
   1.884 +val setup_abs_min_max = fold (B.add_builtin_fun_ext'' o fst) defs
   1.885 +
   1.886 +end
   1.887 +
   1.888 +
   1.889 +(** embedding of standard natural number operations into integer operations **)
   1.890 +
   1.891 +local
   1.892 +  val nat_embedding = @{lemma
   1.893 +    "ALL n. nat (int n) = n"
   1.894 +    "ALL i. i >= 0 --> int (nat i) = i"
   1.895 +    "ALL i. i < 0 --> int (nat i) = 0"
   1.896 +    by simp_all}
   1.897 +
   1.898 +  val nat_ops = [
   1.899 +    @{const less (nat)}, @{const less_eq (nat)},
   1.900 +    @{const Suc}, @{const plus (nat)}, @{const minus (nat)},
   1.901 +    @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}]
   1.902 +
   1.903 +  val nat_consts = nat_ops @ [@{const number_of (nat)},
   1.904 +    @{const zero_class.zero (nat)}, @{const one_class.one (nat)}]
   1.905 +
   1.906 +  val nat_int_coercions = [@{const of_nat (int)}, @{const nat}]
   1.907 +
   1.908 +  val nat_ops' = nat_int_coercions @ nat_ops
   1.909 +
   1.910 +  val is_nat_const = member (op aconv) nat_consts
   1.911 +
   1.912 +  val expands = map mk_meta_eq @{lemma
   1.913 +    "0 = nat 0"
   1.914 +    "1 = nat 1"
   1.915 +    "(number_of :: int => nat) = (%i. nat (number_of i))"
   1.916 +    "op < = (%a b. int a < int b)"
   1.917 +    "op <= = (%a b. int a <= int b)"
   1.918 +    "Suc = (%a. nat (int a + 1))"
   1.919 +    "op + = (%a b. nat (int a + int b))"
   1.920 +    "op - = (%a b. nat (int a - int b))"
   1.921 +    "op * = (%a b. nat (int a * int b))"
   1.922 +    "op div = (%a b. nat (int a div int b))"
   1.923 +    "op mod = (%a b. nat (int a mod int b))"
   1.924 +    by (auto intro!: ext simp add: nat_mult_distrib nat_div_distrib
   1.925 +      nat_mod_distrib)}
   1.926 +
   1.927 +  val ints = map mk_meta_eq @{lemma
   1.928 +    "int 0 = 0"
   1.929 +    "int 1 = 1"
   1.930 +    "int (Suc n) = int n + 1"
   1.931 +    "int (n + m) = int n + int m"
   1.932 +    "int (n - m) = int (nat (int n - int m))"
   1.933 +    "int (n * m) = int n * int m"
   1.934 +    "int (n div m) = int n div int m"
   1.935 +    "int (n mod m) = int n mod int m"
   1.936 +    "int (if P then n else m) = (if P then int n else int m)"
   1.937 +    by (auto simp add: int_mult zdiv_int zmod_int)}
   1.938 +
   1.939 +  fun mk_number_eq ctxt i lhs =
   1.940 +    let
   1.941 +      val eq = U.mk_cequals lhs (Numeral.mk_cnumber @{ctyp int} i)
   1.942 +      val ss = HOL_ss
   1.943 +        addsimps [@{thm Nat_Numeral.int_nat_number_of}]
   1.944 +        addsimps @{thms neg_simps}
   1.945 +      fun tac _ = Simplifier.simp_tac (Simplifier.context ctxt ss) 1       
   1.946 +    in Goal.norm_result (Goal.prove_internal [] eq tac) end
   1.947 +
   1.948 +  fun expand_head_conv cv ct =
   1.949 +    (case Thm.term_of ct of
   1.950 +      _ $ _ =>
   1.951 +        Conv.fun_conv (expand_head_conv cv) then_conv
   1.952 +        Thm.beta_conversion false
   1.953 +    | _ => cv) ct
   1.954 +
   1.955 +  fun int_conv ctxt ct =
   1.956 +    (case Thm.term_of ct of
   1.957 +      @{const of_nat (int)} $ (n as (@{const number_of (nat)} $ _)) =>
   1.958 +        Conv.rewr_conv (mk_number_eq ctxt (snd (HOLogic.dest_number n)) ct)
   1.959 +    | @{const of_nat (int)} $ _ =>
   1.960 +        (Conv.rewrs_conv ints then_conv Conv.sub_conv ints_conv ctxt) else_conv
   1.961 +        Conv.top_sweep_conv nat_conv ctxt        
   1.962 +    | _ => Conv.no_conv) ct
   1.963 +
   1.964 +  and ints_conv ctxt = Conv.top_sweep_conv int_conv ctxt
   1.965 +
   1.966 +  and expand_conv ctxt =
   1.967 +    U.if_conv (not o is_nat_const o Term.head_of) Conv.no_conv
   1.968 +      (expand_head_conv (Conv.rewrs_conv expands) then_conv ints_conv ctxt)
   1.969 +
   1.970 +  and nat_conv ctxt = U.if_exists_conv is_nat_const
   1.971 +    (Conv.top_sweep_conv expand_conv ctxt)
   1.972 +
   1.973 +  val uses_nat_int = Term.exists_subterm (member (op aconv) nat_int_coercions)
   1.974 +in
   1.975 +
   1.976 +val nat_as_int_conv = nat_conv
   1.977 +
   1.978 +fun add_nat_embedding thms =
   1.979 +  if exists (uses_nat_int o Thm.prop_of) thms then (thms, nat_embedding)
   1.980 +  else (thms, [])
   1.981 +
   1.982 +val setup_nat_as_int =
   1.983 +  B.add_builtin_typ_ext (@{typ nat}, K true) #>
   1.984 +  fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops'
   1.985 +
   1.986 +end
   1.987 +
   1.988 +
   1.989 +(** normalize numerals **)
   1.990 +
   1.991 +local
   1.992 +  (*
   1.993 +    rewrite negative numerals into positive numerals,
   1.994 +    rewrite Numeral0 into 0
   1.995 +    rewrite Numeral1 into 1
   1.996 +  *)
   1.997 +
   1.998 +  fun is_strange_number ctxt (t as Const (@{const_name number_of}, _) $ _) =
   1.999 +        (case try HOLogic.dest_number t of
  1.1000 +          SOME (_, i) => B.is_builtin_num ctxt t andalso i < 2
  1.1001 +        | NONE => false)
  1.1002 +    | is_strange_number _ _ = false
  1.1003 +
  1.1004 +  val pos_num_ss = HOL_ss
  1.1005 +    addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}]
  1.1006 +    addsimps [@{thm Int.number_of_Pls}, @{thm Int.numeral_1_eq_1}]
  1.1007 +    addsimps @{thms Int.pred_bin_simps}
  1.1008 +    addsimps @{thms Int.normalize_bin_simps}
  1.1009 +    addsimps @{lemma
  1.1010 +      "Int.Min = - Int.Bit1 Int.Pls"
  1.1011 +      "Int.Bit0 (- Int.Pls) = - Int.Pls"
  1.1012 +      "Int.Bit0 (- k) = - Int.Bit0 k"
  1.1013 +      "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)"
  1.1014 +      by simp_all (simp add: pred_def)}
  1.1015 +
  1.1016 +  fun norm_num_conv ctxt = U.if_conv (is_strange_number ctxt)
  1.1017 +    (Simplifier.rewrite (Simplifier.context ctxt pos_num_ss)) Conv.no_conv
  1.1018 +in
  1.1019 +
  1.1020 +fun normalize_numerals_conv ctxt = U.if_exists_conv (is_strange_number ctxt)
  1.1021 +  (Conv.top_sweep_conv norm_num_conv ctxt)
  1.1022 +
  1.1023 +end
  1.1024 +
  1.1025 +
  1.1026 +(** combined unfoldings and rewritings **)
  1.1027 +
  1.1028 +fun unfold_conv ctxt =
  1.1029 +  trivial_distinct_conv ctxt then_conv
  1.1030 +  rewrite_bool_case_conv ctxt then_conv
  1.1031 +  unfold_abs_min_max_conv ctxt then_conv
  1.1032 +  nat_as_int_conv ctxt then_conv
  1.1033 +  normalize_numerals_conv ctxt then_conv
  1.1034 +  Thm.beta_conversion true
  1.1035 +
  1.1036 +fun burrow_ids f ithms =
  1.1037 +  let
  1.1038 +    val (is, thms) = split_list ithms
  1.1039 +    val (thms', extra_thms) = f thms
  1.1040 +  in (is ~~ thms') @ map (pair ~1) extra_thms end
  1.1041 +
  1.1042 +fun unfold ctxt =
  1.1043 +  burrow_ids (map (Conv.fconv_rule (unfold_conv ctxt)) #> add_nat_embedding)
  1.1044 +
  1.1045 +
  1.1046 +
  1.1047 +(* overall normalization *)
  1.1048 +
  1.1049 +type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list
  1.1050 +
  1.1051 +structure Extra_Norms = Generic_Data
  1.1052 +(
  1.1053 +  type T = extra_norm U.dict
  1.1054 +  val empty = []
  1.1055 +  val extend = I
  1.1056 +  val merge = U.dict_merge fst
  1.1057 +)
  1.1058 +
  1.1059 +fun add_extra_norm (cs, norm) = Extra_Norms.map (U.dict_update (cs, norm))
  1.1060 +
  1.1061 +fun apply_extra_norms ctxt =
  1.1062 +  let
  1.1063 +    val cs = SMT_Config.solver_class_of ctxt
  1.1064 +    val es = U.dict_lookup (Extra_Norms.get (Context.Proof ctxt)) cs
  1.1065 +  in burrow_ids (fold (fn e => e ctxt) es o rpair []) end
  1.1066 +
  1.1067 +fun normalize ctxt iwthms =
  1.1068 +  iwthms
  1.1069 +  |> gen_normalize ctxt
  1.1070 +  |> unfold ctxt
  1.1071 +  |> apply_extra_norms ctxt
  1.1072  
  1.1073  val setup = Context.theory_map (
  1.1074 +  setup_atomize #>
  1.1075 +  setup_unfolded_quants #>
  1.1076 +  setup_trigger #>
  1.1077 +  setup_weight #>
  1.1078    setup_bool_case #>
  1.1079 -  setup_nat_as_int #>
  1.1080 -  setup_unfolded_quants #>
  1.1081 -  setup_atomize)
  1.1082 +  setup_abs_min_max #>
  1.1083 +  setup_nat_as_int)
  1.1084  
  1.1085  end