src/HOL/Tools/SMT/z3_interface.ML
author haftmann
Sat Aug 28 16:14:32 2010 +0200 (2010-08-28)
changeset 38864 4abe644fcea5
parent 38795 848be46708dc
child 39298 5aefb5bc8a93
permissions -rw-r--r--
formerly unnamed infix equality now named HOL.eq
     1 (*  Title:      HOL/Tools/SMT/z3_interface.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Interface to Z3 based on a relaxed version of SMT-LIB.
     5 *)
     6 
signature Z3_INTERFACE =
sig
  (* Z3-specific translation of a constant (name and type) applied to
     arguments into a Z3 function name and arguments, or NONE if the
     translation does not apply *)
  type builtin_fun = string * typ -> term list -> (string * term list) option
  val add_builtin_funs: builtin_fun -> Context.generic -> Context.generic
  val interface: SMT_Solver.interface

  (* constructing Isabelle types and terms from Z3 symbols *)
  datatype sym = Sym of string * sym list
  type mk_builtins = {
    mk_builtin_typ: sym -> typ option,
    mk_builtin_num: theory -> int -> typ -> cterm option,
    mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
  val add_mk_builtins: mk_builtins -> Context.generic -> Context.generic
  val mk_builtin_typ: Proof.context -> sym -> typ option
  val mk_builtin_num: Proof.context -> int -> typ -> cterm option
  val mk_builtin_fun: Proof.context -> sym -> cterm list -> cterm option

  val is_builtin_theory_term: Proof.context -> term -> bool

  (* helpers for instantiating the type variables of term patterns *)
  val mk_inst_pair: (ctyp -> 'a) -> cterm -> 'a * cterm
  val destT1: ctyp -> ctyp
  val destT2: ctyp -> ctyp
  val instT': cterm -> ctyp * cterm -> cterm
end
    30 
    31 structure Z3_Interface: Z3_INTERFACE =
    32 struct
    33 
    34 
    35 (** Z3-specific builtins **)
    36 
(* Translator for a constant (name and type) applied to arguments: yields a
   Z3 function name with its arguments, or NONE if it does not apply *)
type builtin_fun = string * typ -> term list -> (string * term list) option
    38 
    39 fun fst_int_ord ((s1, _), (s2, _)) = int_ord (s1, s2)
    40 
    41 structure Builtins = Generic_Data
    42 (
    43   type T = (int * builtin_fun) list
    44   val empty = []
    45   val extend = I
    46   fun merge (bs1, bs2) = OrdList.union fst_int_ord bs2 bs1
    47 )
    48 
    49 fun add_builtin_funs b =
    50   Builtins.map (OrdList.insert fst_int_ord (serial (), b))
    51 
    52 fun get_builtin_funs ctxt c ts =
    53   let
    54     fun chained [] = NONE
    55       | chained (b :: bs) = (case b c ts of SOME x => SOME x | _ => chained bs)
    56   in chained (map snd (Builtins.get (Context.Proof ctxt))) end
    57 
    58 fun z3_builtin_fun builtin_fun ctxt c ts =
    59   (case builtin_fun ctxt c ts of
    60     SOME x => SOME x
    61   | _ => get_builtin_funs ctxt c ts)
    62 
    63 
    64 
    65 (** interface **)
    66 
    67 local
    68   val {extra_norm, translate} = SMTLIB_Interface.interface
    69   val {prefixes, strict, header, builtins, serialize} = translate
    70   val {is_builtin_pred, ...}= the strict
    71   val {builtin_typ, builtin_num, builtin_fun} = builtins
    72 
    73   fun is_int_div_mod @{term "op div :: int => _"} = true
    74     | is_int_div_mod @{term "op mod :: int => _"} = true
    75     | is_int_div_mod _ = false
    76 
    77   fun add_div_mod thms =
    78     if exists (Term.exists_subterm is_int_div_mod o Thm.prop_of) thms
    79     then [@{thm div_by_z3div}, @{thm mod_by_z3mod}] @ thms
    80     else thms
    81 
    82   fun extra_norm' thms = extra_norm (add_div_mod thms)
    83 
    84   fun z3_builtin_fun' _ (@{const_name z3div}, _) ts = SOME ("div", ts)
    85     | z3_builtin_fun' _ (@{const_name z3mod}, _) ts = SOME ("mod", ts)
    86     | z3_builtin_fun' ctxt c ts = z3_builtin_fun builtin_fun ctxt c ts
    87 
    88   val as_propT = (fn @{typ bool} => @{typ prop} | T => T)
    89 in
    90 
    91 fun is_builtin_num ctxt (T, i) = is_some (builtin_num ctxt T i)
    92 
    93 fun is_builtin_fun ctxt (c as (n, T)) ts =
    94   is_some (z3_builtin_fun' ctxt c ts) orelse 
    95   is_builtin_pred ctxt (n, Term.strip_type T ||> as_propT |> (op --->))
    96 
    97 val interface = {
    98   extra_norm = extra_norm',
    99   translate = {
   100     prefixes = prefixes,
   101     strict = strict,
   102     header = header,
   103     builtins = {
   104       builtin_typ = builtin_typ,
   105       builtin_num = builtin_num,
   106       builtin_fun = z3_builtin_fun'},
   107     serialize = serialize}}
   108 
   109 end
   110 
   111 
   112 
   113 (** constructors **)
   114 
(* A Z3 symbol: a name (looked up by the constructors below, e.g. "and",
   "select", "+") together with a list of sub-symbols, which the constructors
   in this file ignore (presumably indices/parameters; TODO confirm) *)
datatype sym = Sym of string * sym list
   116 
   117 
   118 (* additional constructors *)
   119 
(* Additional constructors for types, numerals and function applications
   beyond the basic ones in this file; each returns NONE for symbols it does
   not handle, so that the next registered constructor is tried *)
type mk_builtins = {
  mk_builtin_typ: sym -> typ option,
  mk_builtin_num: theory -> int -> typ -> cterm option,
  mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
   124 
   125 fun chained _ [] = NONE
   126   | chained f (b :: bs) = (case f b of SOME y => SOME y | NONE => chained f bs)
   127 
   128 fun chained_mk_builtin_typ bs sym =
   129   chained (fn {mk_builtin_typ=mk, ...} : mk_builtins => mk sym) bs
   130 
   131 fun chained_mk_builtin_num ctxt bs i T =
   132   let val thy = ProofContext.theory_of ctxt
   133   in chained (fn {mk_builtin_num=mk, ...} : mk_builtins => mk thy i T) bs end
   134 
   135 fun chained_mk_builtin_fun ctxt bs s cts =
   136   let val thy = ProofContext.theory_of ctxt
   137   in chained (fn {mk_builtin_fun=mk, ...} : mk_builtins => mk thy s cts) bs end
   138 
   139 structure Mk_Builtins = Generic_Data
   140 (
   141   type T = (int * mk_builtins) list
   142   val empty = []
   143   val extend = I
   144   fun merge (bs1, bs2) = OrdList.union fst_int_ord bs2 bs1
   145 )
   146 
   147 fun add_mk_builtins mk =
   148   Mk_Builtins.map (OrdList.insert fst_int_ord (serial (), mk))
   149 
   150 fun get_mk_builtins ctxt = map snd (Mk_Builtins.get (Context.Proof ctxt))
   151 
   152 
   153 (* basic and additional constructors *)
   154 
   155 fun mk_builtin_typ _ (Sym ("bool", _)) = SOME @{typ bool}
   156   | mk_builtin_typ _ (Sym ("int", _)) = SOME @{typ int}
   157   | mk_builtin_typ ctxt sym = chained_mk_builtin_typ (get_mk_builtins ctxt) sym
   158 
   159 fun mk_builtin_num _ i @{typ int} = SOME (Numeral.mk_cnumber @{ctyp int} i)
   160   | mk_builtin_num ctxt i T =
   161       chained_mk_builtin_num ctxt (get_mk_builtins ctxt) i T
   162 
   163 fun instTs cUs (cTs, ct) = Thm.instantiate_cterm (cTs ~~ cUs, []) ct
   164 fun instT cU (cT, ct) = instTs [cU] ([cT], ct)
   165 fun instT' ct = instT (Thm.ctyp_of_term ct)
   166 fun mk_inst_pair destT cpat = (destT (Thm.ctyp_of_term cpat), cpat)
   167 val destT1 = hd o Thm.dest_ctyp
   168 val destT2 = hd o tl o Thm.dest_ctyp
   169 
   170 val mk_true = @{cterm "~False"}
   171 val mk_false = @{cterm False}
   172 val mk_not = Thm.capply @{cterm Not}
   173 val mk_implies = Thm.mk_binop @{cterm HOL.implies}
   174 val mk_iff = Thm.mk_binop @{cterm "op = :: bool => _"}
   175 
   176 fun mk_nary _ cu [] = cu
   177   | mk_nary ct _ cts = uncurry (fold_rev (Thm.mk_binop ct)) (split_last cts)
   178 
   179 val eq = mk_inst_pair destT1 @{cpat HOL.eq}
   180 fun mk_eq ct cu = Thm.mk_binop (instT' ct eq) ct cu
   181 
   182 val if_term = mk_inst_pair (destT1 o destT2) @{cpat If}
   183 fun mk_if cc ct cu = Thm.mk_binop (Thm.capply (instT' ct if_term) cc) ct cu
   184 
   185 val nil_term = mk_inst_pair destT1 @{cpat Nil}
   186 val cons_term = mk_inst_pair destT1 @{cpat Cons}
   187 fun mk_list cT cts =
   188   fold_rev (Thm.mk_binop (instT cT cons_term)) cts (instT cT nil_term)
   189 
   190 val distinct = mk_inst_pair (destT1 o destT1) @{cpat distinct}
   191 fun mk_distinct [] = mk_true
   192   | mk_distinct (cts as (ct :: _)) =
   193       Thm.capply (instT' ct distinct) (mk_list (Thm.ctyp_of_term ct) cts)
   194 
   195 val access = mk_inst_pair (Thm.dest_ctyp o destT1) @{cpat fun_app}
   196 fun mk_access array index =
   197   let val cTs = Thm.dest_ctyp (Thm.ctyp_of_term array)
   198   in Thm.mk_binop (instTs cTs access) array index end
   199 
   200 val update = mk_inst_pair (Thm.dest_ctyp o destT1) @{cpat fun_upd}
   201 fun mk_update array index value =
   202   let val cTs = Thm.dest_ctyp (Thm.ctyp_of_term array)
   203   in Thm.capply (Thm.mk_binop (instTs cTs update) array index) value end
   204 
   205 val mk_uminus = Thm.capply @{cterm "uminus :: int => _"}
   206 val mk_add = Thm.mk_binop @{cterm "op + :: int => _"}
   207 val mk_sub = Thm.mk_binop @{cterm "op - :: int => _"}
   208 val mk_mul = Thm.mk_binop @{cterm "op * :: int => _"}
   209 val mk_div = Thm.mk_binop @{cterm "z3div :: int => _"}
   210 val mk_mod = Thm.mk_binop @{cterm "z3mod :: int => _"}
   211 val mk_lt = Thm.mk_binop @{cterm "op < :: int => _"}
   212 val mk_le = Thm.mk_binop @{cterm "op <= :: int => _"}
   213 
   214 fun mk_builtin_fun ctxt sym cts =
   215   (case (sym, cts) of
   216     (Sym ("true", _), []) => SOME mk_true
   217   | (Sym ("false", _), []) => SOME mk_false
   218   | (Sym ("not", _), [ct]) => SOME (mk_not ct)
   219   | (Sym ("and", _), _) => SOME (mk_nary @{cterm HOL.conj} mk_true cts)
   220   | (Sym ("or", _), _) => SOME (mk_nary @{cterm HOL.disj} mk_false cts)
   221   | (Sym ("implies", _), [ct, cu]) => SOME (mk_implies ct cu)
   222   | (Sym ("iff", _), [ct, cu]) => SOME (mk_iff ct cu)
   223   | (Sym ("~", _), [ct, cu]) => SOME (mk_iff ct cu)
   224   | (Sym ("xor", _), [ct, cu]) => SOME (mk_not (mk_iff ct cu))
   225   | (Sym ("ite", _), [ct1, ct2, ct3]) => SOME (mk_if ct1 ct2 ct3)
   226   | (Sym ("=", _), [ct, cu]) => SOME (mk_eq ct cu)
   227   | (Sym ("distinct", _), _) => SOME (mk_distinct cts)
   228   | (Sym ("select", _), [ca, ck]) => SOME (mk_access ca ck)
   229   | (Sym ("store", _), [ca, ck, cv]) => SOME (mk_update ca ck cv)
   230   | _ =>
   231     (case (sym, try (#T o Thm.rep_cterm o hd) cts, cts) of
   232       (Sym ("+", _), SOME @{typ int}, [ct, cu]) => SOME (mk_add ct cu)
   233     | (Sym ("-", _), SOME @{typ int}, [ct]) => SOME (mk_uminus ct)
   234     | (Sym ("-", _), SOME @{typ int}, [ct, cu]) => SOME (mk_sub ct cu)
   235     | (Sym ("*", _), SOME @{typ int}, [ct, cu]) => SOME (mk_mul ct cu)
   236     | (Sym ("div", _), SOME @{typ int}, [ct, cu]) => SOME (mk_div ct cu)
   237     | (Sym ("mod", _), SOME @{typ int}, [ct, cu]) => SOME (mk_mod ct cu)
   238     | (Sym ("<", _), SOME @{typ int}, [ct, cu]) => SOME (mk_lt ct cu)
   239     | (Sym ("<=", _), SOME @{typ int}, [ct, cu]) => SOME (mk_le ct cu)
   240     | (Sym (">", _), SOME @{typ int}, [ct, cu]) => SOME (mk_lt cu ct)
   241     | (Sym (">=", _), SOME @{typ int}, [ct, cu]) => SOME (mk_le cu ct)
   242     | _ => chained_mk_builtin_fun ctxt (get_mk_builtins ctxt) sym cts))
   243 
   244 
   245 
   246 (** abstraction **)
   247 
   248 fun is_builtin_theory_term ctxt t =
   249   (case try HOLogic.dest_number t of
   250     SOME n => is_builtin_num ctxt n
   251   | NONE =>
   252       (case Term.strip_comb t of
   253         (Const c, ts) => is_builtin_fun ctxt c ts
   254       | _ => false))
   255 
   256 end