src/HOL/Tools/SMT/z3_interface.ML
author boehmes
Wed Nov 24 15:33:35 2010 +0100 (2010-11-24)
changeset 40686 4725ed462387
parent 40681 872b08416fb4
child 41059 d2b1fc1b8e19
permissions -rw-r--r--
swap names for built-in tester functions (to better reflect the intuition of what they do);
eta-expand all built-in functions (even those which are only partially supported)
     1 (*  Title:      HOL/Tools/SMT/z3_interface.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Interface to Z3 based on a relaxed version of SMT-LIB.
     5 *)
     6 
     7 signature Z3_INTERFACE =
     8 sig
     9   type builtin_fun = string * typ -> term list -> (string * term list) option
    10   val add_builtin_funs: builtin_fun -> Context.generic -> Context.generic
    11   val interface: SMT_Solver.interface
    12 
    13   datatype sym = Sym of string * sym list
    14   type mk_builtins = {
    15     mk_builtin_typ: sym -> typ option,
    16     mk_builtin_num: theory -> int -> typ -> cterm option,
    17     mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
    18   val add_mk_builtins: mk_builtins -> Context.generic -> Context.generic
    19   val mk_builtin_typ: Proof.context -> sym -> typ option
    20   val mk_builtin_num: Proof.context -> int -> typ -> cterm option
    21   val mk_builtin_fun: Proof.context -> sym -> cterm list -> cterm option
    22 
    23   val is_builtin_theory_term: Proof.context -> term -> bool
    24 end
    25 
    26 structure Z3_Interface: Z3_INTERFACE =
    27 struct
    28 
    29 structure U = SMT_Utils
    30 
    31 
    32 (** Z3-specific builtins **)
    33 
    34 type builtin_fun = string * typ -> term list -> (string * term list) option
    35 
    36 fun fst_int_ord ((s1, _), (s2, _)) = int_ord (s1, s2)
    37 
    38 structure Builtins = Generic_Data
    39 (
    40   type T = (int * builtin_fun) list
    41   val empty = []
    42   val extend = I
    43   fun merge (bs1, bs2) = Ord_List.union fst_int_ord bs2 bs1
    44 )
    45 
    46 fun add_builtin_funs b =
    47   Builtins.map (Ord_List.insert fst_int_ord (serial (), b))
    48 
    49 fun get_builtin_funs ctxt c ts =
    50   let
    51     fun chained [] = NONE
    52       | chained (b :: bs) = (case b c ts of SOME x => SOME x | _ => chained bs)
    53   in chained (map snd (Builtins.get (Context.Proof ctxt))) end
    54 
    55 fun z3_builtin_fun builtin_fun ctxt c ts =
    56   (case builtin_fun ctxt c ts of
    57     SOME x => SOME x
    58   | _ => get_builtin_funs ctxt c ts)
    59 
    60 
    61 
    62 (** interface **)
    63 
    64 local
    65   val {translate, ...} = SMTLIB_Interface.interface
    66   val {prefixes, strict, header, builtins, serialize} = translate
    67   val {is_builtin_pred, ...}= the strict
    68   val {builtin_typ, builtin_num, builtin_fun, ...} = builtins
    69 
    70   fun is_int_div_mod @{const div (int)} = true
    71     | is_int_div_mod @{const mod (int)} = true
    72     | is_int_div_mod _ = false
    73 
    74   fun add_div_mod irules =
    75     if exists (Term.exists_subterm is_int_div_mod o Thm.prop_of o snd) irules
    76     then [(~1, @{thm div_by_z3div}), (~1, @{thm mod_by_z3mod})] @ irules
    77     else irules
    78 
    79   fun extra_norm' has_datatypes =
    80     SMTLIB_Interface.extra_norm has_datatypes o add_div_mod
    81 
    82   fun z3_builtin_fun' _ (@{const_name z3div}, _) ts = SOME ("div", ts)
    83     | z3_builtin_fun' _ (@{const_name z3mod}, _) ts = SOME ("mod", ts)
    84     | z3_builtin_fun' ctxt c ts = z3_builtin_fun builtin_fun ctxt c ts
    85 
    86   val as_propT = (fn @{typ bool} => @{typ prop} | T => T)
    87 in
    88 
    89 fun is_builtin_num ctxt (T, i) = is_some (builtin_num ctxt T i)
    90 
    91 fun is_builtin_fun ctxt (c as (n, T)) ts =
    92   is_some (z3_builtin_fun' ctxt c ts) orelse 
    93   is_builtin_pred ctxt (n, Term.strip_type T ||> as_propT |> (op --->))
    94 
    95 val interface = {
    96   extra_norm = extra_norm',
    97   translate = {
    98     prefixes = prefixes,
    99     strict = strict,
   100     header = header,
   101     builtins = {
   102       builtin_typ = builtin_typ,
   103       builtin_num = builtin_num,
   104       builtin_fun = z3_builtin_fun',
   105       has_datatypes = true},
   106     serialize = serialize}}
   107 
   108 end
   109 
   110 
   111 
   112 (** constructors **)
   113 
   114 datatype sym = Sym of string * sym list
   115 
   116 
   117 (* additional constructors *)
   118 
   119 type mk_builtins = {
   120   mk_builtin_typ: sym -> typ option,
   121   mk_builtin_num: theory -> int -> typ -> cterm option,
   122   mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
   123 
   124 fun chained _ [] = NONE
   125   | chained f (b :: bs) = (case f b of SOME y => SOME y | NONE => chained f bs)
   126 
   127 fun chained_mk_builtin_typ bs sym =
   128   chained (fn {mk_builtin_typ=mk, ...} : mk_builtins => mk sym) bs
   129 
   130 fun chained_mk_builtin_num ctxt bs i T =
   131   let val thy = ProofContext.theory_of ctxt
   132   in chained (fn {mk_builtin_num=mk, ...} : mk_builtins => mk thy i T) bs end
   133 
   134 fun chained_mk_builtin_fun ctxt bs s cts =
   135   let val thy = ProofContext.theory_of ctxt
   136   in chained (fn {mk_builtin_fun=mk, ...} : mk_builtins => mk thy s cts) bs end
   137 
   138 structure Mk_Builtins = Generic_Data
   139 (
   140   type T = (int * mk_builtins) list
   141   val empty = []
   142   val extend = I
   143   fun merge (bs1, bs2) = Ord_List.union fst_int_ord bs2 bs1
   144 )
   145 
   146 fun add_mk_builtins mk =
   147   Mk_Builtins.map (Ord_List.insert fst_int_ord (serial (), mk))
   148 
   149 fun get_mk_builtins ctxt = map snd (Mk_Builtins.get (Context.Proof ctxt))
   150 
   151 
   152 (* basic and additional constructors *)
   153 
   154 fun mk_builtin_typ _ (Sym ("bool", _)) = SOME @{typ bool}
   155   | mk_builtin_typ _ (Sym ("Int", _)) = SOME @{typ int}
   156   | mk_builtin_typ _ (Sym ("int", _)) = SOME @{typ int}  (*FIXME: delete*)
   157   | mk_builtin_typ ctxt sym = chained_mk_builtin_typ (get_mk_builtins ctxt) sym
   158 
   159 fun mk_builtin_num _ i @{typ int} = SOME (Numeral.mk_cnumber @{ctyp int} i)
   160   | mk_builtin_num ctxt i T =
   161       chained_mk_builtin_num ctxt (get_mk_builtins ctxt) i T
   162 
   163 val mk_true = Thm.cterm_of @{theory} (@{const Not} $ @{const False})
   164 val mk_false = Thm.cterm_of @{theory} @{const False}
   165 val mk_not = Thm.capply (Thm.cterm_of @{theory} @{const Not})
   166 val mk_implies = Thm.mk_binop (Thm.cterm_of @{theory} @{const HOL.implies})
   167 val mk_iff = Thm.mk_binop (Thm.cterm_of @{theory} @{const HOL.eq (bool)})
   168 val conj = Thm.cterm_of @{theory} @{const HOL.conj}
   169 val disj = Thm.cterm_of @{theory} @{const HOL.disj}
   170 
   171 fun mk_nary _ cu [] = cu
   172   | mk_nary ct _ cts = uncurry (fold_rev (Thm.mk_binop ct)) (split_last cts)
   173 
   174 val eq = U.mk_const_pat @{theory} @{const_name HOL.eq} U.destT1
   175 fun mk_eq ct cu = Thm.mk_binop (U.instT' ct eq) ct cu
   176 
   177 val if_term = U.mk_const_pat @{theory} @{const_name If} (U.destT1 o U.destT2)
   178 fun mk_if cc ct cu = Thm.mk_binop (Thm.capply (U.instT' ct if_term) cc) ct cu
   179 
   180 val nil_term = U.mk_const_pat @{theory} @{const_name Nil} U.destT1
   181 val cons_term = U.mk_const_pat @{theory} @{const_name Cons} U.destT1
   182 fun mk_list cT cts =
   183   fold_rev (Thm.mk_binop (U.instT cT cons_term)) cts (U.instT cT nil_term)
   184 
   185 val distinct = U.mk_const_pat @{theory} @{const_name distinct}
   186   (U.destT1 o U.destT1)
   187 fun mk_distinct [] = mk_true
   188   | mk_distinct (cts as (ct :: _)) =
   189       Thm.capply (U.instT' ct distinct) (mk_list (Thm.ctyp_of_term ct) cts)
   190 
   191 val access = U.mk_const_pat @{theory} @{const_name fun_app}
   192   (Thm.dest_ctyp o U.destT1)
   193 fun mk_access array index =
   194   let val cTs = Thm.dest_ctyp (Thm.ctyp_of_term array)
   195   in Thm.mk_binop (U.instTs cTs access) array index end
   196 
   197 val update = U.mk_const_pat @{theory} @{const_name fun_upd}
   198   (Thm.dest_ctyp o U.destT1)
   199 fun mk_update array index value =
   200   let val cTs = Thm.dest_ctyp (Thm.ctyp_of_term array)
   201   in Thm.capply (Thm.mk_binop (U.instTs cTs update) array index) value end
   202 
   203 val mk_uminus = Thm.capply (Thm.cterm_of @{theory} @{const uminus (int)})
   204 val mk_add = Thm.mk_binop (Thm.cterm_of @{theory} @{const plus (int)})
   205 val mk_sub = Thm.mk_binop (Thm.cterm_of @{theory} @{const minus (int)})
   206 val mk_mul = Thm.mk_binop (Thm.cterm_of @{theory} @{const times (int)})
   207 val mk_div = Thm.mk_binop (Thm.cterm_of @{theory} @{const z3div})
   208 val mk_mod = Thm.mk_binop (Thm.cterm_of @{theory} @{const z3mod})
   209 val mk_lt = Thm.mk_binop (Thm.cterm_of @{theory} @{const less (int)})
   210 val mk_le = Thm.mk_binop (Thm.cterm_of @{theory} @{const less_eq (int)})
   211 
   212 fun mk_builtin_fun ctxt sym cts =
   213   (case (sym, cts) of
   214     (Sym ("true", _), []) => SOME mk_true
   215   | (Sym ("false", _), []) => SOME mk_false
   216   | (Sym ("not", _), [ct]) => SOME (mk_not ct)
   217   | (Sym ("and", _), _) => SOME (mk_nary conj mk_true cts)
   218   | (Sym ("or", _), _) => SOME (mk_nary disj mk_false cts)
   219   | (Sym ("implies", _), [ct, cu]) => SOME (mk_implies ct cu)
   220   | (Sym ("iff", _), [ct, cu]) => SOME (mk_iff ct cu)
   221   | (Sym ("~", _), [ct, cu]) => SOME (mk_iff ct cu)
   222   | (Sym ("xor", _), [ct, cu]) => SOME (mk_not (mk_iff ct cu))
   223   | (Sym ("ite", _), [ct1, ct2, ct3]) => SOME (mk_if ct1 ct2 ct3)
   224   | (Sym ("=", _), [ct, cu]) => SOME (mk_eq ct cu)
   225   | (Sym ("distinct", _), _) => SOME (mk_distinct cts)
   226   | (Sym ("select", _), [ca, ck]) => SOME (mk_access ca ck)
   227   | (Sym ("store", _), [ca, ck, cv]) => SOME (mk_update ca ck cv)
   228   | _ =>
   229     (case (sym, try (#T o Thm.rep_cterm o hd) cts, cts) of
   230       (Sym ("+", _), SOME @{typ int}, [ct, cu]) => SOME (mk_add ct cu)
   231     | (Sym ("-", _), SOME @{typ int}, [ct]) => SOME (mk_uminus ct)
   232     | (Sym ("-", _), SOME @{typ int}, [ct, cu]) => SOME (mk_sub ct cu)
   233     | (Sym ("*", _), SOME @{typ int}, [ct, cu]) => SOME (mk_mul ct cu)
   234     | (Sym ("div", _), SOME @{typ int}, [ct, cu]) => SOME (mk_div ct cu)
   235     | (Sym ("mod", _), SOME @{typ int}, [ct, cu]) => SOME (mk_mod ct cu)
   236     | (Sym ("<", _), SOME @{typ int}, [ct, cu]) => SOME (mk_lt ct cu)
   237     | (Sym ("<=", _), SOME @{typ int}, [ct, cu]) => SOME (mk_le ct cu)
   238     | (Sym (">", _), SOME @{typ int}, [ct, cu]) => SOME (mk_lt cu ct)
   239     | (Sym (">=", _), SOME @{typ int}, [ct, cu]) => SOME (mk_le cu ct)
   240     | _ => chained_mk_builtin_fun ctxt (get_mk_builtins ctxt) sym cts))
   241 
   242 
   243 
   244 (** abstraction **)
   245 
   246 fun is_builtin_theory_term ctxt t =
   247   (case try HOLogic.dest_number t of
   248     SOME n => is_builtin_num ctxt n
   249   | NONE =>
   250       (case Term.strip_comb t of
   251         (Const c, ts) => is_builtin_fun ctxt c ts
   252       | _ => false))
   253 
   254 end