src/HOL/Tools/SMT/z3_interface.ML
author haftmann
Sat Aug 28 16:14:32 2010 +0200 (2010-08-28)
changeset 38864 4abe644fcea5
parent 38795 848be46708dc
child 39298 5aefb5bc8a93
permissions -rw-r--r--
formerly unnamed infix equality now named HOL.eq
boehmes@36898
     1
(*  Title:      HOL/Tools/SMT/z3_interface.ML
    Author:     Sascha Boehme, TU Muenchen

Interface to Z3 based on a relaxed version of SMT-LIB.
*)
boehmes@36898
     6
boehmes@36898
     7
signature Z3_INTERFACE =
sig
  (* Z3-specific builtin functions, tried in addition to the SMT-LIB ones *)
  type builtin_fun = string * typ -> term list -> (string * term list) option
  val add_builtin_funs: builtin_fun -> Context.generic -> Context.generic

  (* the solver interface handed to the SMT framework *)
  val interface: SMT_Solver.interface

  (* symbolic names (e.g. "and", "select") and the constructors that map
     them back to Isabelle types and certified terms *)
  datatype sym = Sym of string * sym list
  type mk_builtins = {
    mk_builtin_typ: sym -> typ option,
    mk_builtin_num: theory -> int -> typ -> cterm option,
    mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
  val add_mk_builtins: mk_builtins -> Context.generic -> Context.generic
  val mk_builtin_typ: Proof.context -> sym -> typ option
  val mk_builtin_num: Proof.context -> int -> typ -> cterm option
  val mk_builtin_fun: Proof.context -> sym -> cterm list -> cterm option

  (* abstraction support *)
  val is_builtin_theory_term: Proof.context -> term -> bool

  (* helpers for type-instantiating certified term patterns *)
  val destT1: ctyp -> ctyp
  val destT2: ctyp -> ctyp
  val mk_inst_pair: (ctyp -> 'a) -> cterm -> 'a * cterm
  val instT': cterm -> ctyp * cterm -> cterm
end
boehmes@36898
    30
boehmes@36898
    31
structure Z3_Interface: Z3_INTERFACE =
boehmes@36898
    32
struct
boehmes@36898
    33
boehmes@36899
    34
boehmes@36899
    35
(** Z3-specific builtins **)
boehmes@36899
    36
boehmes@36899
    37
(* A builtin function inspects a constant (name and type) together with its
   arguments and, when it recognizes them, yields a (name, arguments) pair. *)
type builtin_fun =
  string * typ -> term list -> (string * term list) option

(* compare two pairs by their first (integer) component only *)
fun fst_int_ord (p, q) = int_ord (fst p, fst q)
boehmes@36899
    40
boehmes@36899
    41
structure Builtins = Generic_Data
(
  (* registered builtin functions, each tagged with a serial number and kept
     ordered by it *)
  type T = (int * builtin_fun) list
  val empty = []
  val extend = I
  fun merge (left, right) = OrdList.union fst_int_ord right left
)
boehmes@36899
    48
boehmes@36899
    49
(* register builtin function b, tagged with a fresh serial number *)
fun add_builtin_funs b =
  let val tagged = (serial (), b)
  in Builtins.map (OrdList.insert fst_int_ord tagged) end
boehmes@36899
    51
boehmes@36899
    52
(* apply the registered builtin functions in order and return the first
   result that is not NONE *)
fun get_builtin_funs ctxt c ts =
  let
    fun first_match [] = NONE
      | first_match ((_, f) :: rest) =
          (case f c ts of
            NONE => first_match rest
          | result => result)
  in first_match (Builtins.get (Context.Proof ctxt)) end
boehmes@36899
    57
boehmes@36899
    58
(* try the given builtin function first; fall back to the registered ones *)
fun z3_builtin_fun builtin_fun ctxt c ts =
  (case builtin_fun ctxt c ts of
    NONE => get_builtin_funs ctxt c ts
  | result => result)
boehmes@36898
    62
boehmes@36898
    63
boehmes@36899
    64
boehmes@36899
    65
(** interface **)
boehmes@36899
    66
boehmes@36899
    67
local
  (* take the SMT-LIB interface apart; Z3 reuses most of it unchanged *)
  val {extra_norm, translate = {prefixes, strict, header, serialize,
    builtins = {builtin_typ, builtin_num, builtin_fun}}} =
    SMTLIB_Interface.interface
  val {is_builtin_pred, ...} = the strict

  fun is_int_div_mod t =
    (case t of
      @{term "op div :: int => _"} => true
    | @{term "op mod :: int => _"} => true
    | _ => false)

  (* if some theorem mentions integer div or mod, prepend the rewrite rules
     div_by_z3div and mod_by_z3mod (presumably so that these operations
     reach the translation as z3div/z3mod) *)
  fun add_div_mod thms =
    let val mentions = Term.exists_subterm is_int_div_mod o Thm.prop_of
    in
      if exists mentions thms
      then @{thm div_by_z3div} :: @{thm mod_by_z3mod} :: thms
      else thms
    end

  fun extra_norm' thms = thms |> add_div_mod |> extra_norm

  (* z3div/z3mod are Z3's native div/mod; everything else goes through the
     SMT-LIB builtins and then the registered Z3-specific ones *)
  fun z3_builtin_fun' ctxt (c as (n, _)) ts =
    if n = @{const_name z3div} then SOME ("div", ts)
    else if n = @{const_name z3mod} then SOME ("mod", ts)
    else z3_builtin_fun builtin_fun ctxt c ts

  fun as_propT T = if T = @{typ bool} then @{typ prop} else T
in

(* a numeral is builtin iff the SMT-LIB builtin_num accepts it *)
fun is_builtin_num ctxt (T, i) = builtin_num ctxt T i |> is_some

fun is_builtin_fun ctxt (c as (n, T)) ts =
  let
    val (argTs, resT) = Term.strip_type T
    (* the same constant with a bool result type replaced by prop *)
    val pred = (n, argTs ---> as_propT resT)
  in is_some (z3_builtin_fun' ctxt c ts) orelse is_builtin_pred ctxt pred end

(* the SMT-LIB interface with Z3-specific preprocessing and builtins *)
val interface =
  let
    val z3_builtins = {
      builtin_typ = builtin_typ,
      builtin_num = builtin_num,
      builtin_fun = z3_builtin_fun'}
  in
    {extra_norm = extra_norm',
     translate = {
       prefixes = prefixes,
       strict = strict,
       header = header,
       builtins = z3_builtins,
       serialize = serialize}}
  end

end
boehmes@36899
   110
boehmes@36899
   111
boehmes@36899
   112
boehmes@36899
   113
(** constructors **)
boehmes@36899
   114
boehmes@36899
   115
(* a symbol: a name together with a list of argument symbols *)
datatype sym = Sym of string * sym list


(* additional constructors *)

(* user-supplied constructors for types, numerals and functions that the
   basic constructors of this module do not cover; each yields NONE when it
   does not apply, so the next one is tried *)
type mk_builtins = {
  mk_builtin_typ: sym -> typ option,
  mk_builtin_num: theory -> int -> typ -> cterm option,
  mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
boehmes@36899
   124
boehmes@36899
   125
(* the first result of f over the list that is not NONE, or NONE *)
fun chained f =
  let
    fun go [] = NONE
      | go (b :: bs) = (case f b of NONE => go bs | found => found)
  in go end

fun chained_mk_builtin_typ bs sym =
  chained (fn b : mk_builtins => #mk_builtin_typ b sym) bs

(* the theory is taken from ctxt once, before any constructor is tried *)
fun chained_mk_builtin_num ctxt bs i T =
  let val theory = ProofContext.theory_of ctxt
  in chained (fn b : mk_builtins => #mk_builtin_num b theory i T) bs end

fun chained_mk_builtin_fun ctxt bs s cts =
  let val theory = ProofContext.theory_of ctxt
  in chained (fn b : mk_builtins => #mk_builtin_fun b theory s cts) bs end
boehmes@36899
   138
boehmes@36899
   139
structure Mk_Builtins = Generic_Data
(
  (* registered additional constructors, each tagged with a serial number
     and kept ordered by it *)
  type T = (int * mk_builtins) list
  val empty = []
  val extend = I
  fun merge (left, right) = OrdList.union fst_int_ord right left
)
boehmes@36899
   146
boehmes@36899
   147
(* register constructors mk, tagged with a fresh serial number *)
fun add_mk_builtins mk =
  let val tagged = (serial (), mk)
  in Mk_Builtins.map (OrdList.insert fst_int_ord tagged) end

(* the registered constructors, in serial-number order *)
fun get_mk_builtins ctxt =
  Context.Proof ctxt |> Mk_Builtins.get |> map snd
boehmes@36899
   151
boehmes@36899
   152
boehmes@36899
   153
(* basic and additional constructors *)
boehmes@36899
   154
boehmes@36899
   155
(* bool and int are built in; other symbols go to the registered
   constructors *)
fun mk_builtin_typ ctxt sym =
  (case sym of
    Sym ("bool", _) => SOME @{typ bool}
  | Sym ("int", _) => SOME @{typ int}
  | _ => chained_mk_builtin_typ (get_mk_builtins ctxt) sym)

(* integer numerals are built in; numerals of other types go to the
   registered constructors *)
fun mk_builtin_num ctxt i T =
  if T = @{typ int} then SOME (Numeral.mk_cnumber @{ctyp int} i)
  else chained_mk_builtin_num ctxt (get_mk_builtins ctxt) i T
boehmes@36899
   162
boehmes@36899
   163
(* instantiate the ctyps cTs of ct by the corresponding ctyps in cUs *)
fun instTs cUs (cTs, ct) =
  let val tyinst = cTs ~~ cUs
  in Thm.instantiate_cterm (tyinst, []) ct end

fun instT cU (cT, ct) = instTs [cU] ([cT], ct)

(* instantiate a pattern pair at the ctyp of the cterm ct *)
fun instT' ct pat = instT (Thm.ctyp_of_term ct) pat

(* pair a pattern with the part of its ctyp selected by destT *)
fun mk_inst_pair destT cpat = (cpat |> Thm.ctyp_of_term |> destT, cpat)

(* first and second argument ctyp of a type constructor application *)
fun destT1 cT = hd (Thm.dest_ctyp cT)
fun destT2 cT = hd (tl (Thm.dest_ctyp cT))
boehmes@36899
   169
boehmes@36899
   170
val mk_true = @{cterm "~False"}
val mk_false = @{cterm False}
fun mk_not ct = Thm.capply @{cterm Not} ct
fun mk_implies ct cu = Thm.mk_binop @{cterm HOL.implies} ct cu
fun mk_iff ct cu = Thm.mk_binop @{cterm "op = :: bool => _"} ct cu

(* combine cts right-associatively with the binary operator ct;
   the empty list yields cu *)
fun mk_nary _ cu [] = cu
  | mk_nary ct _ cts =
      let val (front, last) = split_last cts
      in fold_rev (Thm.mk_binop ct) front last end
boehmes@36899
   178
haftmann@38864
   179
(* equality, instantiated at the type of its first argument *)
val eq = mk_inst_pair destT1 @{cpat HOL.eq}
fun mk_eq ct cu =
  let val eq_at = instT' ct eq
  in Thm.mk_binop eq_at ct cu end

(* if-then-else, instantiated at the type of its branches *)
val if_term = mk_inst_pair (destT1 o destT2) @{cpat If}
fun mk_if cc ct cu =
  let val if_at = instT' ct if_term
  in Thm.mk_binop (Thm.capply if_at cc) ct cu end

(* explicit list of cts, with element ctyp cT *)
val nil_term = mk_inst_pair destT1 @{cpat Nil}
val cons_term = mk_inst_pair destT1 @{cpat Cons}
fun mk_list cT cts =
  let
    val cons_at = Thm.mk_binop (instT cT cons_term)
    val empty_at = instT cT nil_term
  in fold_rev cons_at cts empty_at end

(* pairwise distinctness of cts; trivially true for no arguments *)
val distinct = mk_inst_pair (destT1 o destT1) @{cpat distinct}
fun mk_distinct cts =
  (case cts of
    [] => mk_true
  | ct :: _ =>
      let val cT = Thm.ctyp_of_term ct
      in Thm.capply (instT' ct distinct) (mk_list cT cts) end)
boehmes@36899
   194
boehmes@37153
   195
(* arrays are modelled as functions: "select" becomes fun_app and "store"
   becomes fun_upd, instantiated at the argument ctyps of the array's type *)
val access = mk_inst_pair (Thm.dest_ctyp o destT1) @{cpat fun_app}
fun mk_access array index =
  let
    val cTs = array |> Thm.ctyp_of_term |> Thm.dest_ctyp
    val access_at = instTs cTs access
  in Thm.mk_binop access_at array index end

val update = mk_inst_pair (Thm.dest_ctyp o destT1) @{cpat fun_upd}
fun mk_update array index value =
  let
    val cTs = array |> Thm.ctyp_of_term |> Thm.dest_ctyp
    val update_at = instTs cTs update
  in Thm.capply (Thm.mk_binop update_at array index) value end
boehmes@36899
   204
boehmes@36899
   205
(* integer arithmetic; div and mod use z3div/z3mod *)
fun mk_uminus ct = Thm.capply @{cterm "uminus :: int => _"} ct
fun mk_add ct cu = Thm.mk_binop @{cterm "op + :: int => _"} ct cu
fun mk_sub ct cu = Thm.mk_binop @{cterm "op - :: int => _"} ct cu
fun mk_mul ct cu = Thm.mk_binop @{cterm "op * :: int => _"} ct cu
fun mk_div ct cu = Thm.mk_binop @{cterm "z3div :: int => _"} ct cu
fun mk_mod ct cu = Thm.mk_binop @{cterm "z3mod :: int => _"} ct cu
fun mk_lt ct cu = Thm.mk_binop @{cterm "op < :: int => _"} ct cu
fun mk_le ct cu = Thm.mk_binop @{cterm "op <= :: int => _"} ct cu
boehmes@36899
   213
boehmes@36899
   214
(* Build the certified term for symbol sym applied to the arguments cts.
   Constructors are tried in this order:
   - logical connectives, equality, if-then-else, distinct and array
     select/store;
   - integer arithmetic and comparisons, recognized by the type of the
     first argument;
   - the registered additional constructors.
   Returns NONE when no constructor applies. *)
fun mk_builtin_fun ctxt sym cts =
  (case (sym, cts) of
    (Sym ("true", _), []) => SOME mk_true
  | (Sym ("false", _), []) => SOME mk_false
  | (Sym ("not", _), [ct]) => SOME (mk_not ct)
  | (Sym ("and", _), _) => SOME (mk_nary @{cterm HOL.conj} mk_true cts)
  | (Sym ("or", _), _) => SOME (mk_nary @{cterm HOL.disj} mk_false cts)
  | (Sym ("implies", _), [ct, cu]) => SOME (mk_implies ct cu)
  | (Sym ("iff", _), [ct, cu]) => SOME (mk_iff ct cu)
  | (Sym ("~", _), [ct, cu]) => SOME (mk_iff ct cu)
  | (Sym ("xor", _), [ct, cu]) => SOME (mk_not (mk_iff ct cu))
  | (Sym ("ite", _), [ct1, ct2, ct3]) => SOME (mk_if ct1 ct2 ct3)
  | (Sym ("=", _), [ct, cu]) => SOME (mk_eq ct cu)
  | (Sym ("distinct", _), _) => SOME (mk_distinct cts)
  | (Sym ("select", _), [ca, ck]) => SOME (mk_access ca ck)
  | (Sym ("store", _), [ca, ck, cv]) => SOME (mk_update ca ck cv)
  | _ =>
    (* the type of the first argument (NONE for no arguments) selects the
       integer arithmetic constructors *)
    (case (sym, try (#T o Thm.rep_cterm o hd) cts, cts) of
      (* SMT-LIB's "+" is n-ary; any non-empty argument list is accepted and
         nested to the right (two arguments give exactly mk_add ct cu).
         cts is non-empty here because the type of its head was found. *)
      (Sym ("+", _), SOME @{typ int}, _) =>
        SOME (uncurry (fold_rev mk_add) (split_last cts))
    | (Sym ("-", _), SOME @{typ int}, [ct]) => SOME (mk_uminus ct)
    | (Sym ("-", _), SOME @{typ int}, [ct, cu]) => SOME (mk_sub ct cu)
    | (Sym ("*", _), SOME @{typ int}, [ct, cu]) => SOME (mk_mul ct cu)
    | (Sym ("div", _), SOME @{typ int}, [ct, cu]) => SOME (mk_div ct cu)
    | (Sym ("mod", _), SOME @{typ int}, [ct, cu]) => SOME (mk_mod ct cu)
    | (Sym ("<", _), SOME @{typ int}, [ct, cu]) => SOME (mk_lt ct cu)
    | (Sym ("<=", _), SOME @{typ int}, [ct, cu]) => SOME (mk_le ct cu)
    | (Sym (">", _), SOME @{typ int}, [ct, cu]) => SOME (mk_lt cu ct)
    | (Sym (">=", _), SOME @{typ int}, [ct, cu]) => SOME (mk_le cu ct)
    | _ => chained_mk_builtin_fun ctxt (get_mk_builtins ctxt) sym cts))
boehmes@36899
   243
boehmes@36899
   244
boehmes@36899
   245
boehmes@36899
   246
(** abstraction **)
boehmes@36899
   247
boehmes@36899
   248
(* A term is a builtin theory term if it is a builtin numeral, or an
   application of a constant that is a builtin function or predicate. *)
fun is_builtin_theory_term ctxt t =
  let
    fun builtin_application () =
      (case Term.strip_comb t of
        (Const c, ts) => is_builtin_fun ctxt c ts
      | _ => false)
  in
    (case try HOLogic.dest_number t of
      NONE => builtin_application ()
    | SOME n => is_builtin_num ctxt n)
  end
boehmes@36899
   255
boehmes@36899
   256
end