src/HOL/Tools/SMT2/z3_new_interface.ML
changeset 58061 3d060f43accb
parent 58060 835b5443b978
child 58062 f4d8987656b9
equal deleted inserted replaced
58060:835b5443b978 58061:3d060f43accb
     1 (*  Title:      HOL/Tools/SMT2/z3_new_interface.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3 
       
     4 Interface to Z3 based on a relaxed version of SMT-LIB.
       
     5 *)
       
     6 
       
     7 signature Z3_NEW_INTERFACE =
       
     8 sig
       
     9   val smtlib2_z3C: SMT2_Util.class
       
    10 
       
    11   datatype sym = Sym of string * sym list
       
    12   type mk_builtins = {
       
    13     mk_builtin_typ: sym -> typ option,
       
    14     mk_builtin_num: theory -> int -> typ -> cterm option,
       
    15     mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
       
    16   val add_mk_builtins: mk_builtins -> Context.generic -> Context.generic
       
    17   val mk_builtin_typ: Proof.context -> sym -> typ option
       
    18   val mk_builtin_num: Proof.context -> int -> typ -> cterm option
       
    19   val mk_builtin_fun: Proof.context -> sym -> cterm list -> cterm option
       
    20 
       
    21   val is_builtin_theory_term: Proof.context -> term -> bool
       
    22 end;
       
    23 
       
    24 structure Z3_New_Interface: Z3_NEW_INTERFACE =
       
    25 struct
       
    26 
       
    27 val smtlib2_z3C = SMTLIB2_Interface.smtlib2C @ ["z3"]
       
    28 
       
    29 
       
    30 (* interface *)
       
    31 
       
    32 local
       
    33   fun translate_config ctxt =
       
    34     {logic = K "", has_datatypes = true,
       
    35      serialize = #serialize (SMTLIB2_Interface.translate_config ctxt)}
       
    36 
       
    37   fun is_div_mod @{const div (int)} = true
       
    38     | is_div_mod @{const mod (int)} = true
       
    39     | is_div_mod _ = false
       
    40 
       
    41   val have_int_div_mod = exists (Term.exists_subterm is_div_mod o Thm.prop_of)
       
    42 
       
    43   fun add_div_mod _ (thms, extra_thms) =
       
    44     if have_int_div_mod thms orelse have_int_div_mod extra_thms then
       
    45       (thms, @{thms div_as_z3div mod_as_z3mod} @ extra_thms)
       
    46     else (thms, extra_thms)
       
    47 
       
    48   val setup_builtins =
       
    49     SMT2_Builtin.add_builtin_fun' smtlib2_z3C (@{const times (int)}, "*") #>
       
    50     SMT2_Builtin.add_builtin_fun' smtlib2_z3C (@{const z3div}, "div") #>
       
    51     SMT2_Builtin.add_builtin_fun' smtlib2_z3C (@{const z3mod}, "mod")
       
    52 in
       
    53 
       
    54 val _ = Theory.setup (Context.theory_map (
       
    55   setup_builtins #>
       
    56   SMT2_Normalize.add_extra_norm (smtlib2_z3C, add_div_mod) #>
       
    57   SMT2_Translate.add_config (smtlib2_z3C, translate_config)))
       
    58 
       
    59 end
       
    60 
       
    61 
       
    62 (* constructors *)
       
    63 
       
    64 datatype sym = Sym of string * sym list
       
    65 
       
    66 
       
    67 (** additional constructors **)
       
    68 
       
    69 type mk_builtins = {
       
    70   mk_builtin_typ: sym -> typ option,
       
    71   mk_builtin_num: theory -> int -> typ -> cterm option,
       
    72   mk_builtin_fun: theory -> sym -> cterm list -> cterm option }
       
    73 
       
    74 fun chained _ [] = NONE
       
    75   | chained f (b :: bs) = (case f b of SOME y => SOME y | NONE => chained f bs)
       
    76 
       
    77 fun chained_mk_builtin_typ bs sym =
       
    78   chained (fn {mk_builtin_typ=mk, ...} : mk_builtins => mk sym) bs
       
    79 
       
    80 fun chained_mk_builtin_num ctxt bs i T =
       
    81   let val thy = Proof_Context.theory_of ctxt
       
    82   in chained (fn {mk_builtin_num=mk, ...} : mk_builtins => mk thy i T) bs end
       
    83 
       
    84 fun chained_mk_builtin_fun ctxt bs s cts =
       
    85   let val thy = Proof_Context.theory_of ctxt
       
    86   in chained (fn {mk_builtin_fun=mk, ...} : mk_builtins => mk thy s cts) bs end
       
    87 
       
    88 fun fst_int_ord ((i1, _), (i2, _)) = int_ord (i1, i2)
       
    89 
       
    90 structure Mk_Builtins = Generic_Data
       
    91 (
       
    92   type T = (int * mk_builtins) list
       
    93   val empty = []
       
    94   val extend = I
       
    95   fun merge data = Ord_List.merge fst_int_ord data
       
    96 )
       
    97 
       
    98 fun add_mk_builtins mk = Mk_Builtins.map (Ord_List.insert fst_int_ord (serial (), mk))
       
    99 
       
   100 fun get_mk_builtins ctxt = map snd (Mk_Builtins.get (Context.Proof ctxt))
       
   101 
       
   102 
       
   103 (** basic and additional constructors **)
       
   104 
       
   105 fun mk_builtin_typ _ (Sym ("Bool", _)) = SOME @{typ bool}
       
   106   | mk_builtin_typ _ (Sym ("Int", _)) = SOME @{typ int}
       
   107   | mk_builtin_typ _ (Sym ("bool", _)) = SOME @{typ bool}  (*FIXME: legacy*)
       
   108   | mk_builtin_typ _ (Sym ("int", _)) = SOME @{typ int}  (*FIXME: legacy*)
       
   109   | mk_builtin_typ ctxt sym = chained_mk_builtin_typ (get_mk_builtins ctxt) sym
       
   110 
       
   111 fun mk_builtin_num _ i @{typ int} = SOME (Numeral.mk_cnumber @{ctyp int} i)
       
   112   | mk_builtin_num ctxt i T =
       
   113       chained_mk_builtin_num ctxt (get_mk_builtins ctxt) i T
       
   114 
       
   115 val mk_true = Thm.cterm_of @{theory} (@{const Not} $ @{const False})
       
   116 val mk_false = Thm.cterm_of @{theory} @{const False}
       
   117 val mk_not = Thm.apply (Thm.cterm_of @{theory} @{const Not})
       
   118 val mk_implies = Thm.mk_binop (Thm.cterm_of @{theory} @{const HOL.implies})
       
   119 val mk_iff = Thm.mk_binop (Thm.cterm_of @{theory} @{const HOL.eq (bool)})
       
   120 val conj = Thm.cterm_of @{theory} @{const HOL.conj}
       
   121 val disj = Thm.cterm_of @{theory} @{const HOL.disj}
       
   122 
       
   123 fun mk_nary _ cu [] = cu
       
   124   | mk_nary ct _ cts = uncurry (fold_rev (Thm.mk_binop ct)) (split_last cts)
       
   125 
       
   126 val eq = SMT2_Util.mk_const_pat @{theory} @{const_name HOL.eq} SMT2_Util.destT1
       
   127 fun mk_eq ct cu = Thm.mk_binop (SMT2_Util.instT' ct eq) ct cu
       
   128 
       
   129 val if_term =
       
   130   SMT2_Util.mk_const_pat @{theory} @{const_name If} (SMT2_Util.destT1 o SMT2_Util.destT2)
       
   131 fun mk_if cc ct = Thm.mk_binop (Thm.apply (SMT2_Util.instT' ct if_term) cc) ct
       
   132 
       
   133 val access = SMT2_Util.mk_const_pat @{theory} @{const_name fun_app} SMT2_Util.destT1
       
   134 fun mk_access array = Thm.apply (SMT2_Util.instT' array access) array
       
   135 
       
   136 val update =
       
   137   SMT2_Util.mk_const_pat @{theory} @{const_name fun_upd} (Thm.dest_ctyp o SMT2_Util.destT1)
       
   138 fun mk_update array index value =
       
   139   let val cTs = Thm.dest_ctyp (Thm.ctyp_of_term array)
       
   140   in Thm.apply (Thm.mk_binop (SMT2_Util.instTs cTs update) array index) value end
       
   141 
       
   142 val mk_uminus = Thm.apply (Thm.cterm_of @{theory} @{const uminus (int)})
       
   143 val add = Thm.cterm_of @{theory} @{const plus (int)}
       
   144 val int0 = Numeral.mk_cnumber @{ctyp int} 0
       
   145 val mk_sub = Thm.mk_binop (Thm.cterm_of @{theory} @{const minus (int)})
       
   146 val mk_mul = Thm.mk_binop (Thm.cterm_of @{theory} @{const times (int)})
       
   147 val mk_div = Thm.mk_binop (Thm.cterm_of @{theory} @{const z3div})
       
   148 val mk_mod = Thm.mk_binop (Thm.cterm_of @{theory} @{const z3mod})
       
   149 val mk_lt = Thm.mk_binop (Thm.cterm_of @{theory} @{const less (int)})
       
   150 val mk_le = Thm.mk_binop (Thm.cterm_of @{theory} @{const less_eq (int)})
       
   151 
       
   152 fun mk_builtin_fun ctxt sym cts =
       
   153   (case (sym, cts) of
       
   154     (Sym ("true", _), []) => SOME mk_true
       
   155   | (Sym ("false", _), []) => SOME mk_false
       
   156   | (Sym ("not", _), [ct]) => SOME (mk_not ct)
       
   157   | (Sym ("and", _), _) => SOME (mk_nary conj mk_true cts)
       
   158   | (Sym ("or", _), _) => SOME (mk_nary disj mk_false cts)
       
   159   | (Sym ("implies", _), [ct, cu]) => SOME (mk_implies ct cu)
       
   160   | (Sym ("iff", _), [ct, cu]) => SOME (mk_iff ct cu)
       
   161   | (Sym ("~", _), [ct, cu]) => SOME (mk_iff ct cu)
       
   162   | (Sym ("xor", _), [ct, cu]) => SOME (mk_not (mk_iff ct cu))
       
   163   | (Sym ("if", _), [ct1, ct2, ct3]) => SOME (mk_if ct1 ct2 ct3)
       
   164   | (Sym ("ite", _), [ct1, ct2, ct3]) => SOME (mk_if ct1 ct2 ct3) (* FIXME: remove *)
       
   165   | (Sym ("=", _), [ct, cu]) => SOME (mk_eq ct cu)
       
   166   | (Sym ("select", _), [ca, ck]) => SOME (Thm.apply (mk_access ca) ck)
       
   167   | (Sym ("store", _), [ca, ck, cv]) => SOME (mk_update ca ck cv)
       
   168   | _ =>
       
   169     (case (sym, try (#T o Thm.rep_cterm o hd) cts, cts) of
       
   170       (Sym ("+", _), SOME @{typ int}, _) => SOME (mk_nary add int0 cts)
       
   171     | (Sym ("-", _), SOME @{typ int}, [ct]) => SOME (mk_uminus ct)
       
   172     | (Sym ("-", _), SOME @{typ int}, [ct, cu]) => SOME (mk_sub ct cu)
       
   173     | (Sym ("*", _), SOME @{typ int}, [ct, cu]) => SOME (mk_mul ct cu)
       
   174     | (Sym ("div", _), SOME @{typ int}, [ct, cu]) => SOME (mk_div ct cu)
       
   175     | (Sym ("mod", _), SOME @{typ int}, [ct, cu]) => SOME (mk_mod ct cu)
       
   176     | (Sym ("<", _), SOME @{typ int}, [ct, cu]) => SOME (mk_lt ct cu)
       
   177     | (Sym ("<=", _), SOME @{typ int}, [ct, cu]) => SOME (mk_le ct cu)
       
   178     | (Sym (">", _), SOME @{typ int}, [ct, cu]) => SOME (mk_lt cu ct)
       
   179     | (Sym (">=", _), SOME @{typ int}, [ct, cu]) => SOME (mk_le cu ct)
       
   180     | _ => chained_mk_builtin_fun ctxt (get_mk_builtins ctxt) sym cts))
       
   181 
       
   182 
       
   183 (* abstraction *)
       
   184 
       
   185 fun is_builtin_theory_term ctxt t =
       
   186   if SMT2_Builtin.is_builtin_num ctxt t then true
       
   187   else
       
   188     (case Term.strip_comb t of
       
   189       (Const c, ts) => SMT2_Builtin.is_builtin_fun ctxt c ts
       
   190     | _ => false)
       
   191 
       
   192 end;