src/HOL/Tools/SMT/z3_interface.ML
author boehmes
Wed May 23 16:03:38 2012 +0200 (2012-05-23)
changeset 47965 8ba6438557bc
parent 46497 89ccf66aa73d
child 49720 6279490e0438
permissions -rw-r--r--
extend the Z3 proof parser to accept polyadic addition (on integers and reals) due to changes introduced in Z3 4.0
     1 (*  Title:      HOL/Tools/SMT/z3_interface.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Interface to Z3 based on a relaxed version of SMT-LIB.
     5 *)
     6 
     7 signature Z3_INTERFACE =
     8 sig
     9   val smtlib_z3C: SMT_Utils.class
    10   val setup: theory -> theory
    11 
    12   datatype sym = Sym of string * sym list
    13   type mk_builtins = {
    14     mk_builtin_typ: sym -> typ option,
    15     mk_builtin_num: theory -> int -> typ -> cterm option,
    16     mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
    17   val add_mk_builtins: mk_builtins -> Context.generic -> Context.generic
    18   val mk_builtin_typ: Proof.context -> sym -> typ option
    19   val mk_builtin_num: Proof.context -> int -> typ -> cterm option
    20   val mk_builtin_fun: Proof.context -> sym -> cterm list -> cterm option
    21 
    22   val is_builtin_theory_term: Proof.context -> term -> bool
    23 end
    24 
    25 structure Z3_Interface: Z3_INTERFACE =
    26 struct
    27 
    28 val smtlib_z3C = SMTLIB_Interface.smtlibC @ ["z3"]
    29 
    30 
    31 
    32 (* interface *)
    33 
    34 local
    35   fun translate_config ctxt =
    36     let
    37       val {prefixes, header, is_fol, serialize, ...} =
    38         SMTLIB_Interface.translate_config ctxt
    39     in
    40       {prefixes=prefixes, header=header, is_fol=is_fol, serialize=serialize,
    41         has_datatypes=true}
    42     end
    43 
    44   fun is_div_mod @{const div (int)} = true
    45     | is_div_mod @{const mod (int)} = true
    46     | is_div_mod _ = false
    47 
    48   val div_by_z3div = @{lemma
    49     "ALL k l. k div l = (
    50       if k = 0 | l = 0 then 0
    51       else if (0 < k & 0 < l) | (k < 0 & 0 < l) then z3div k l
    52       else z3div (-k) (-l))"
    53     by (simp add: SMT.z3div_def)}
    54 
    55   val mod_by_z3mod = @{lemma
    56     "ALL k l. k mod l = (
    57       if l = 0 then k
    58       else if k = 0 then 0
    59       else if (0 < k & 0 < l) | (k < 0 & 0 < l) then z3mod k l
    60       else - z3mod (-k) (-l))"
    61     by (simp add: z3mod_def)}
    62 
    63   val have_int_div_mod =
    64     exists (Term.exists_subterm is_div_mod o Thm.prop_of)
    65 
    66   fun add_div_mod _ (thms, extra_thms) =
    67     if have_int_div_mod thms orelse have_int_div_mod extra_thms then
    68       (thms, div_by_z3div :: mod_by_z3mod :: extra_thms)
    69     else (thms, extra_thms)
    70 
    71   val setup_builtins =
    72     SMT_Builtin.add_builtin_fun' smtlib_z3C (@{const times (int)}, "*") #>
    73     SMT_Builtin.add_builtin_fun' smtlib_z3C (@{const z3div}, "div") #>
    74     SMT_Builtin.add_builtin_fun' smtlib_z3C (@{const z3mod}, "mod")
    75 in
    76 
    77 val setup = Context.theory_map (
    78   setup_builtins #>
    79   SMT_Normalize.add_extra_norm (smtlib_z3C, add_div_mod) #>
    80   SMT_Translate.add_config (smtlib_z3C, translate_config))
    81 
    82 end
    83 
    84 
    85 
    86 (* constructors *)
    87 
    88 datatype sym = Sym of string * sym list
    89 
    90 
    91 (** additional constructors **)
    92 
    93 type mk_builtins = {
    94   mk_builtin_typ: sym -> typ option,
    95   mk_builtin_num: theory -> int -> typ -> cterm option,
    96   mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
    97 
    98 fun chained _ [] = NONE
    99   | chained f (b :: bs) = (case f b of SOME y => SOME y | NONE => chained f bs)
   100 
   101 fun chained_mk_builtin_typ bs sym =
   102   chained (fn {mk_builtin_typ=mk, ...} : mk_builtins => mk sym) bs
   103 
   104 fun chained_mk_builtin_num ctxt bs i T =
   105   let val thy = Proof_Context.theory_of ctxt
   106   in chained (fn {mk_builtin_num=mk, ...} : mk_builtins => mk thy i T) bs end
   107 
   108 fun chained_mk_builtin_fun ctxt bs s cts =
   109   let val thy = Proof_Context.theory_of ctxt
   110   in chained (fn {mk_builtin_fun=mk, ...} : mk_builtins => mk thy s cts) bs end
   111 
   112 fun fst_int_ord ((i1, _), (i2, _)) = int_ord (i1, i2)
   113 
   114 structure Mk_Builtins = Generic_Data
   115 (
   116   type T = (int * mk_builtins) list
   117   val empty = []
   118   val extend = I
   119   fun merge data = Ord_List.merge fst_int_ord data
   120 )
   121 
   122 fun add_mk_builtins mk =
   123   Mk_Builtins.map (Ord_List.insert fst_int_ord (serial (), mk))
   124 
   125 fun get_mk_builtins ctxt = map snd (Mk_Builtins.get (Context.Proof ctxt))
   126 
   127 
   128 (** basic and additional constructors **)
   129 
   130 fun mk_builtin_typ _ (Sym ("bool", _)) = SOME @{typ bool}
   131   | mk_builtin_typ _ (Sym ("Int", _)) = SOME @{typ int}
   132   | mk_builtin_typ _ (Sym ("int", _)) = SOME @{typ int}  (*FIXME: delete*)
   133   | mk_builtin_typ ctxt sym = chained_mk_builtin_typ (get_mk_builtins ctxt) sym
   134 
   135 fun mk_builtin_num _ i @{typ int} = SOME (Numeral.mk_cnumber @{ctyp int} i)
   136   | mk_builtin_num ctxt i T =
   137       chained_mk_builtin_num ctxt (get_mk_builtins ctxt) i T
   138 
   139 val mk_true = Thm.cterm_of @{theory} (@{const Not} $ @{const False})
   140 val mk_false = Thm.cterm_of @{theory} @{const False}
   141 val mk_not = Thm.apply (Thm.cterm_of @{theory} @{const Not})
   142 val mk_implies = Thm.mk_binop (Thm.cterm_of @{theory} @{const HOL.implies})
   143 val mk_iff = Thm.mk_binop (Thm.cterm_of @{theory} @{const HOL.eq (bool)})
   144 val conj = Thm.cterm_of @{theory} @{const HOL.conj}
   145 val disj = Thm.cterm_of @{theory} @{const HOL.disj}
   146 
   147 fun mk_nary _ cu [] = cu
   148   | mk_nary ct _ cts = uncurry (fold_rev (Thm.mk_binop ct)) (split_last cts)
   149 
   150 val eq = SMT_Utils.mk_const_pat @{theory} @{const_name HOL.eq} SMT_Utils.destT1
   151 fun mk_eq ct cu = Thm.mk_binop (SMT_Utils.instT' ct eq) ct cu
   152 
   153 val if_term =
   154   SMT_Utils.mk_const_pat @{theory} @{const_name If}
   155     (SMT_Utils.destT1 o SMT_Utils.destT2)
   156 fun mk_if cc ct cu =
   157   Thm.mk_binop (Thm.apply (SMT_Utils.instT' ct if_term) cc) ct cu
   158 
   159 val nil_term =
   160   SMT_Utils.mk_const_pat @{theory} @{const_name Nil} SMT_Utils.destT1
   161 val cons_term =
   162   SMT_Utils.mk_const_pat @{theory} @{const_name Cons} SMT_Utils.destT1
   163 fun mk_list cT cts =
   164   fold_rev (Thm.mk_binop (SMT_Utils.instT cT cons_term)) cts
   165     (SMT_Utils.instT cT nil_term)
   166 
   167 val distinct = SMT_Utils.mk_const_pat @{theory} @{const_name distinct}
   168   (SMT_Utils.destT1 o SMT_Utils.destT1)
   169 fun mk_distinct [] = mk_true
   170   | mk_distinct (cts as (ct :: _)) =
   171       Thm.apply (SMT_Utils.instT' ct distinct)
   172         (mk_list (Thm.ctyp_of_term ct) cts)
   173 
   174 val access =
   175   SMT_Utils.mk_const_pat @{theory} @{const_name fun_app} SMT_Utils.destT1
   176 fun mk_access array = Thm.apply (SMT_Utils.instT' array access) array
   177 
   178 val update = SMT_Utils.mk_const_pat @{theory} @{const_name fun_upd}
   179   (Thm.dest_ctyp o SMT_Utils.destT1)
   180 fun mk_update array index value =
   181   let val cTs = Thm.dest_ctyp (Thm.ctyp_of_term array)
   182   in
   183     Thm.apply (Thm.mk_binop (SMT_Utils.instTs cTs update) array index) value
   184   end
   185 
   186 val mk_uminus = Thm.apply (Thm.cterm_of @{theory} @{const uminus (int)})
   187 val add = Thm.cterm_of @{theory} @{const plus (int)}
   188 val int0 = Numeral.mk_cnumber @{ctyp int} 0
   189 val mk_sub = Thm.mk_binop (Thm.cterm_of @{theory} @{const minus (int)})
   190 val mk_mul = Thm.mk_binop (Thm.cterm_of @{theory} @{const times (int)})
   191 val mk_div = Thm.mk_binop (Thm.cterm_of @{theory} @{const z3div})
   192 val mk_mod = Thm.mk_binop (Thm.cterm_of @{theory} @{const z3mod})
   193 val mk_lt = Thm.mk_binop (Thm.cterm_of @{theory} @{const less (int)})
   194 val mk_le = Thm.mk_binop (Thm.cterm_of @{theory} @{const less_eq (int)})
   195 
   196 fun mk_builtin_fun ctxt sym cts =
   197   (case (sym, cts) of
   198     (Sym ("true", _), []) => SOME mk_true
   199   | (Sym ("false", _), []) => SOME mk_false
   200   | (Sym ("not", _), [ct]) => SOME (mk_not ct)
   201   | (Sym ("and", _), _) => SOME (mk_nary conj mk_true cts)
   202   | (Sym ("or", _), _) => SOME (mk_nary disj mk_false cts)
   203   | (Sym ("implies", _), [ct, cu]) => SOME (mk_implies ct cu)
   204   | (Sym ("iff", _), [ct, cu]) => SOME (mk_iff ct cu)
   205   | (Sym ("~", _), [ct, cu]) => SOME (mk_iff ct cu)
   206   | (Sym ("xor", _), [ct, cu]) => SOME (mk_not (mk_iff ct cu))
   207   | (Sym ("if", _), [ct1, ct2, ct3]) => SOME (mk_if ct1 ct2 ct3)
   208   | (Sym ("ite", _), [ct1, ct2, ct3]) => SOME (mk_if ct1 ct2 ct3) (* FIXME: remove *)
   209   | (Sym ("=", _), [ct, cu]) => SOME (mk_eq ct cu)
   210   | (Sym ("distinct", _), _) => SOME (mk_distinct cts)
   211   | (Sym ("select", _), [ca, ck]) => SOME (Thm.apply (mk_access ca) ck)
   212   | (Sym ("store", _), [ca, ck, cv]) => SOME (mk_update ca ck cv)
   213   | _ =>
   214     (case (sym, try (#T o Thm.rep_cterm o hd) cts, cts) of
   215       (Sym ("+", _), SOME @{typ int}, _) => SOME (mk_nary add int0 cts)
   216     | (Sym ("-", _), SOME @{typ int}, [ct]) => SOME (mk_uminus ct)
   217     | (Sym ("-", _), SOME @{typ int}, [ct, cu]) => SOME (mk_sub ct cu)
   218     | (Sym ("*", _), SOME @{typ int}, [ct, cu]) => SOME (mk_mul ct cu)
   219     | (Sym ("div", _), SOME @{typ int}, [ct, cu]) => SOME (mk_div ct cu)
   220     | (Sym ("mod", _), SOME @{typ int}, [ct, cu]) => SOME (mk_mod ct cu)
   221     | (Sym ("<", _), SOME @{typ int}, [ct, cu]) => SOME (mk_lt ct cu)
   222     | (Sym ("<=", _), SOME @{typ int}, [ct, cu]) => SOME (mk_le ct cu)
   223     | (Sym (">", _), SOME @{typ int}, [ct, cu]) => SOME (mk_lt cu ct)
   224     | (Sym (">=", _), SOME @{typ int}, [ct, cu]) => SOME (mk_le cu ct)
   225     | _ => chained_mk_builtin_fun ctxt (get_mk_builtins ctxt) sym cts))
   226 
   227 
   228 
   229 (* abstraction *)
   230 
   231 fun is_builtin_theory_term ctxt t =
   232   if SMT_Builtin.is_builtin_num ctxt t then true
   233   else
   234     (case Term.strip_comb t of
   235       (Const c, ts) => SMT_Builtin.is_builtin_fun ctxt c ts
   236     | _ => false)
   237 
   238 end