src/HOL/Tools/SMT/smtlib_proof.ML
author nipkow
Wed Jan 10 15:25:09 2018 +0100 (19 months ago)
changeset 67399 eab6ce8368fa
parent 58061 3d060f43accb
child 69205 8050734eee3e
permissions -rw-r--r--
ran isabelle update_op on all sources
     1 (*  Title:      HOL/Tools/SMT/smtlib_proof.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3     Author:     Mathias Fleury, ENS Rennes
     4     Author:     Jasmin Blanchette, TU Muenchen
     5 
     6 SMT-LIB-2-style proofs: parsing and abstract syntax tree.
     7 *)
     8 
signature SMTLIB_PROOF =
sig
  (*what an SMT-LIB symbol is bound to during parsing: an unparsed tree, an already parsed
    term, a client-specific proof object, or nothing*)
  datatype 'b shared = Tree of SMTLIB.tree | Term of term | Proof of 'b | None
  (*parser state; 'a is client data (the "extra" component), 'b the payload of Proof bindings*)
  type ('a, 'b) context

  (*mk_context ctxt id syms typs funs extra: assemble a context from its components*)
  val mk_context: Proof.context -> int -> 'b shared Symtab.table -> typ Symtab.table ->
    term Symtab.table -> 'a -> ('a, 'b) context
  (*context with id 1, no symbol bindings and an empty extra list*)
  val empty_context: Proof.context -> typ Symtab.table -> term Symtab.table -> ('a list, 'b) context
  val ctxt_of: ('a, 'b) context -> Proof.context
  (*binding of a symbol, or None if the symbol is unbound*)
  val lookup_binding: ('a, 'b) context -> string -> 'b shared
  val update_binding: string * 'b shared -> ('a, 'b) context -> ('a, 'b) context
  (*run a parser with the given bindings installed; the previous bindings of those names
    are restored in the context the parser returns*)
  val with_bindings: (string * 'b shared) list -> (('a, 'b) context -> 'c * ('d, 'b) context) ->
    ('a, 'b) context -> 'c * ('d, 'b) context
  (*current id paired with a context whose id is incremented by one*)
  val next_id: ('a, 'b) context -> int * ('a, 'b) context
  (*run a term parser starting from an empty list of fresh variables, turn its result into a
    proposition that is meta-quantified over the recorded variables, infer types if needed,
    and return it together with the SMT-LIB names of those variables*)
  val with_fresh_names: (('a list, 'b) context ->
    term * ((string * (string * typ)) list, 'b) context) -> ('c, 'b) context -> (term * string list)

  (*type and term parsers*)
  (*a parser receives the head of an SMT-LIB expression and its already parsed arguments;
    NONE means the parser does not handle that expression*)
  type type_parser = SMTLIB.tree * typ list -> typ option
  type term_parser = SMTLIB.tree * term list -> term option
  val add_type_parser: type_parser -> Context.generic -> Context.generic
  val add_term_parser: term_parser -> Context.generic -> Context.generic

  (*error message together with the offending SMT-LIB tree*)
  exception SMTLIB_PARSE of string * SMTLIB.tree

  (*declare an SMT-LIB function symbol of the given type as a fresh free variable*)
  val declare_fun: string -> typ -> ((string * typ) list, 'a) context ->
    ((string * typ) list, 'a) context
  (*let binding "(name tree)"; the tree is kept unparsed*)
  val dest_binding: SMTLIB.tree -> string * 'a shared
  (*HOL type of an SMT-LIB sort expression*)
  val type_of: ('a, 'b) context -> SMTLIB.tree -> typ
  (*HOL term of an SMT-LIB term tree; unknown symbols without arguments become fresh free
    variables that are recorded in the extra list*)
  val term_of: SMTLIB.tree -> ((string * (string * typ)) list, 'a) context ->
    term * ((string * (string * typ)) list, 'a) context
end;
    41 
    42 structure SMTLIB_Proof: SMTLIB_PROOF =
    43 struct
    44 
    45 (* proof parser context *)
    46 
(*what an SMT-LIB symbol is bound to: an unparsed tree (from let bindings), a parsed term,
  a client-specific proof object, or nothing*)
datatype 'b shared = Tree of SMTLIB.tree | Term of term | Proof of 'b | None

(*parser state threaded through all parsing functions*)
type ('a, 'b) context = {
  ctxt: Proof.context,  (*Isabelle context; fresh variable names are fixed in it*)
  id: int,  (*counter handed out by next_id*)
  syms: 'b shared Symtab.table,  (*symbol bindings*)
  typs: typ Symtab.table,  (*SMT-LIB sort names mapped to HOL types*)
  funs: term Symtab.table,  (*SMT-LIB function names mapped to HOL terms*)
  extra: 'a}  (*client data, e.g. the list of freshly declared variables*)
    56 
    57 fun mk_context ctxt id syms typs funs extra: ('a, 'b) context =
    58   {ctxt = ctxt, id = id, syms = syms, typs = typs, funs = funs, extra = extra}
    59 
    60 fun empty_context ctxt typs funs = mk_context ctxt 1 Symtab.empty typs funs []
    61 
    62 fun ctxt_of ({ctxt, ...}: ('a, 'b) context) = ctxt
    63 
    64 fun lookup_binding ({syms, ...}: ('a, 'b) context) =
    65   the_default None o Symtab.lookup syms
    66 
    67 fun map_syms f ({ctxt, id, syms, typs, funs, extra}: ('a, 'b) context) =
    68   mk_context ctxt id (f syms) typs funs extra
    69 
    70 fun update_binding b = map_syms (Symtab.update b)
    71 
    72 fun with_bindings bs f cx =
    73   let val bs' = map (lookup_binding cx o fst) bs
    74   in
    75     cx
    76     |> fold update_binding bs
    77     |> f
    78     ||> fold2 (fn (name, _) => update_binding o pair name) bs bs'
    79   end
    80 
    81 fun next_id ({ctxt, id, syms, typs, funs, extra}: ('a, 'b) context) =
    82   (id, mk_context ctxt (id + 1) syms typs funs extra)
    83 
    84 fun with_fresh_names f ({ctxt, id, syms, typs, funs, ...}: ('a, 'b) context) =
    85   let
    86     fun bind (_, v as (_, T)) t = Logic.all_const T $ Term.absfree v t
    87 
    88     val needs_inferT = equal Term.dummyT orf Term.is_TVar
    89     val needs_infer = Term.exists_type (Term.exists_subtype needs_inferT)
    90     fun infer_types ctxt =
    91       singleton (Type_Infer_Context.infer_types ctxt) #>
    92       singleton (Proof_Context.standard_term_check_finish ctxt)
    93     fun infer ctxt t = if needs_infer t then infer_types ctxt t else t
    94 
    95     val (t, {ctxt = ctxt', extra = names, ...}: ((string * (string * typ)) list, 'b) context) =
    96       f (mk_context ctxt id syms typs funs [])
    97     val t' = infer ctxt' (fold_rev bind names (HOLogic.mk_Trueprop t))
    98   in
    99     (t', map fst names)
   100   end
   101 
   102 fun lookup_typ ({typs, ...}: ('a, 'b) context) = Symtab.lookup typs
   103 fun lookup_fun ({funs, ...}: ('a, 'b) context) = Symtab.lookup funs
   104 
   105 
   106 (* core type and term parser *)
   107 
   108 fun core_type_parser (SMTLIB.Sym "Bool", []) = SOME @{typ HOL.bool}
   109   | core_type_parser (SMTLIB.Sym "Int", []) = SOME @{typ Int.int}
   110   | core_type_parser _ = NONE
   111 
   112 fun mk_unary n t =
   113   let val T = fastype_of t
   114   in Const (n, T --> T) $ t end
   115 
   116 fun mk_binary' n T U t1 t2 = Const (n, [T, T] ---> U) $ t1 $ t2
   117 
   118 fun mk_binary n t1 t2 =
   119   let val T = fastype_of t1
   120   in mk_binary' n T T t1 t2 end
   121 
   122 fun mk_rassoc f t ts =
   123   let val us = rev (t :: ts)
   124   in fold f (tl us) (hd us) end
   125 
   126 fun mk_lassoc f t ts = fold (fn u1 => fn u2 => f u2 u1) ts t
   127 
   128 fun mk_lassoc' n = mk_lassoc (mk_binary n)
   129 
   130 fun mk_binary_pred n S t1 t2 =
   131   let
   132     val T1 = fastype_of t1
   133     val T2 = fastype_of t2
   134     val T =
   135       if T1 <> Term.dummyT then T1
   136       else if T2 <> Term.dummyT then T2
   137       else TVar (("?a", serial ()), S)
   138   in mk_binary' n T @{typ HOL.bool} t1 t2 end
   139 
   140 fun mk_less t1 t2 = mk_binary_pred @{const_name ord_class.less} @{sort linorder} t1 t2
   141 fun mk_less_eq t1 t2 = mk_binary_pred @{const_name ord_class.less_eq} @{sort linorder} t1 t2
   142 
   143 fun core_term_parser (SMTLIB.Sym "true", _) = SOME @{const HOL.True}
   144   | core_term_parser (SMTLIB.Sym "false", _) = SOME @{const HOL.False}
   145   | core_term_parser (SMTLIB.Sym "not", [t]) = SOME (HOLogic.mk_not t)
   146   | core_term_parser (SMTLIB.Sym "and", t :: ts) = SOME (mk_rassoc (curry HOLogic.mk_conj) t ts)
   147   | core_term_parser (SMTLIB.Sym "or", t :: ts) = SOME (mk_rassoc (curry HOLogic.mk_disj) t ts)
   148   | core_term_parser (SMTLIB.Sym "=>", [t1, t2]) = SOME (HOLogic.mk_imp (t1, t2))
   149   | core_term_parser (SMTLIB.Sym "implies", [t1, t2]) = SOME (HOLogic.mk_imp (t1, t2))
   150   | core_term_parser (SMTLIB.Sym "=", [t1, t2]) = SOME (HOLogic.mk_eq (t1, t2))
   151   | core_term_parser (SMTLIB.Sym "~", [t1, t2]) = SOME (HOLogic.mk_eq (t1, t2))
   152   | core_term_parser (SMTLIB.Sym "ite", [t1, t2, t3]) =
   153       let
   154         val T = fastype_of t2
   155         val c = Const (@{const_name HOL.If}, [@{typ HOL.bool}, T, T] ---> T)
   156       in SOME (c $ t1 $ t2 $ t3) end
   157   | core_term_parser (SMTLIB.Num i, []) = SOME (HOLogic.mk_number @{typ Int.int} i)
   158   | core_term_parser (SMTLIB.Sym "-", [t]) = SOME (mk_unary @{const_name uminus_class.uminus} t)
   159   | core_term_parser (SMTLIB.Sym "~", [t]) = SOME (mk_unary @{const_name uminus_class.uminus} t)
   160   | core_term_parser (SMTLIB.Sym "+", t :: ts) =
   161       SOME (mk_lassoc' @{const_name plus_class.plus} t ts)
   162   | core_term_parser (SMTLIB.Sym "-", t :: ts) =
   163       SOME (mk_lassoc' @{const_name minus_class.minus} t ts)
   164   | core_term_parser (SMTLIB.Sym "*", t :: ts) =
   165       SOME (mk_lassoc' @{const_name times_class.times} t ts)
   166   | core_term_parser (SMTLIB.Sym "div", [t1, t2]) = SOME (mk_binary @{const_name z3div} t1 t2)
   167   | core_term_parser (SMTLIB.Sym "mod", [t1, t2]) = SOME (mk_binary @{const_name z3mod} t1 t2)
   168   | core_term_parser (SMTLIB.Sym "<", [t1, t2]) = SOME (mk_less t1 t2)
   169   | core_term_parser (SMTLIB.Sym ">", [t1, t2]) = SOME (mk_less t2 t1)
   170   | core_term_parser (SMTLIB.Sym "<=", [t1, t2]) = SOME (mk_less_eq t1 t2)
   171   | core_term_parser (SMTLIB.Sym ">=", [t1, t2]) = SOME (mk_less_eq t2 t1)
   172   | core_term_parser _ = NONE
   173 
   174 
   175 (* custom type and term parsers *)
   176 
   177 type type_parser = SMTLIB.tree * typ list -> typ option
   178 
   179 type term_parser = SMTLIB.tree * term list -> term option
   180 
   181 fun id_ord ((id1, _), (id2, _)) = int_ord (id1, id2)
   182 
   183 structure Parsers = Generic_Data
   184 (
   185   type T = (int * type_parser) list * (int * term_parser) list
   186   val empty : T = ([(serial (), core_type_parser)], [(serial (), core_term_parser)])
   187   val extend = I
   188   fun merge ((tys1, ts1), (tys2, ts2)) =
   189     (Ord_List.merge id_ord (tys1, tys2), Ord_List.merge id_ord (ts1, ts2))
   190 )
   191 
   192 fun add_type_parser type_parser =
   193   Parsers.map (apfst (Ord_List.insert id_ord (serial (), type_parser)))
   194 
   195 fun add_term_parser term_parser =
   196   Parsers.map (apsnd (Ord_List.insert id_ord (serial (), term_parser)))
   197 
   198 fun get_type_parsers ctxt = map snd (fst (Parsers.get (Context.Proof ctxt)))
   199 fun get_term_parsers ctxt = map snd (snd (Parsers.get (Context.Proof ctxt)))
   200 
   201 fun apply_parsers parsers x =
   202   let
   203     fun apply [] = NONE
   204       | apply (parser :: parsers) =
   205           (case parser x of
   206             SOME y => SOME y
   207           | NONE => apply parsers)
   208   in apply parsers end
   209 
   210 
   211 (* type and term parsing *)
   212 
   213 exception SMTLIB_PARSE of string * SMTLIB.tree
   214 
   215 val desymbolize = Name.desymbolize (SOME false) o perhaps (try (unprefix "?"))
   216 
   217 fun fresh_fun add name n T ({ctxt, id, syms, typs, funs, extra}: ('a, 'b) context) =
   218   let
   219     val (n', ctxt') = yield_singleton Variable.variant_fixes n ctxt
   220     val t = Free (n', T)
   221     val funs' = Symtab.update (name, t) funs
   222   in (t, mk_context ctxt' id syms typs funs' (add (n', T) extra)) end
   223 
   224 fun declare_fun name = snd oo fresh_fun cons name (desymbolize name)
   225 fun declare_free name = fresh_fun (cons o pair name) name (desymbolize name)
   226 
   227 fun parse_type cx ty Ts =
   228   (case apply_parsers (get_type_parsers (ctxt_of cx)) (ty, Ts) of
   229     SOME T => T
   230   | NONE =>
   231       (case ty of
   232         SMTLIB.Sym name =>
   233           (case lookup_typ cx name of
   234             SOME T => T
   235           | NONE => raise SMTLIB_PARSE ("unknown SMT type", ty))
   236       | _ => raise SMTLIB_PARSE ("bad SMT type format", ty)))
   237 
   238 fun parse_term t ts cx =
   239   (case apply_parsers (get_term_parsers (ctxt_of cx)) (t, ts) of
   240     SOME u => (u, cx)
   241   | NONE =>
   242       (case t of
   243         SMTLIB.Sym name =>
   244           (case lookup_fun cx name of
   245             SOME u => (Term.list_comb (u, ts), cx)
   246           | NONE =>
   247               if null ts then declare_free name Term.dummyT cx
   248               else raise SMTLIB_PARSE ("bad SMT term", t))
   249       | _ => raise SMTLIB_PARSE ("bad SMT term format", t)))
   250 
   251 fun type_of cx ty =
   252   (case try (parse_type cx ty) [] of
   253     SOME T => T
   254   | NONE =>
   255       (case ty of
   256         SMTLIB.S (ty' :: tys) => parse_type cx ty' (map (type_of cx) tys)
   257       | _ => raise SMTLIB_PARSE ("bad SMT type", ty)))
   258 
   259 fun dest_var cx (SMTLIB.S [SMTLIB.Sym name, ty]) = (name, (desymbolize name, type_of cx ty))
   260   | dest_var _ v = raise SMTLIB_PARSE ("bad SMT quantifier variable format", v)
   261 
   262 fun dest_body (SMTLIB.S (SMTLIB.Sym "!" :: body :: _)) = dest_body body
   263   | dest_body body = body
   264 
   265 fun dest_binding (SMTLIB.S [SMTLIB.Sym name, t]) = (name, Tree t)
   266   | dest_binding b = raise SMTLIB_PARSE ("bad SMT let binding format", b)
   267 
(*parse an SMT-LIB term tree into a HOL term, threading the parser context*)
fun term_of t cx =
  (case t of
    (*quantifiers*)
    SMTLIB.S [SMTLIB.Sym "forall", SMTLIB.S vars, body] => quant HOLogic.mk_all vars body cx
  | SMTLIB.S [SMTLIB.Sym "exists", SMTLIB.S vars, body] => quant HOLogic.mk_exists vars body cx
    (*let: the names are bound to their unparsed trees while the body is parsed*)
  | SMTLIB.S [SMTLIB.Sym "let", SMTLIB.S bindings, body] =>
      with_bindings (map dest_binding bindings) (term_of body) cx
    (*annotation "(! t ...)": the attributes are ignored*)
  | SMTLIB.S (SMTLIB.Sym "!" :: t :: _) => term_of t cx
    (*application: the arguments are parsed left to right, then the head*)
  | SMTLIB.S (f :: args) =>
      cx
      |> fold_map term_of args
      |-> parse_term f
  | SMTLIB.Sym name =>
      (case lookup_binding cx name of
        (*a let-bound tree is parsed on first use, and the binding is replaced by the
          resulting term so later uses reuse it.
          NOTE(review): u is parsed in the context at the point of use, not at the let*)
        Tree u =>
          cx
          |> term_of u
          |-> (fn u' => pair u' o update_binding (name, Term u'))
      | Term u => (u, cx)
      | None => parse_term t [] cx
        (*Proof bindings cannot stand for terms*)
      | _ => raise SMTLIB_PARSE ("bad SMT term format", t))
  | _ => parse_term t [] cx)

(*quantifier q over vars: the variables are bound to Free terms while the body (with its
  annotations stripped) is parsed. The result is then quantified over them with the first
  variable outermost.*)
and quant q vars body cx =
  let val vs = map (dest_var cx) vars
  in
    cx
    |> with_bindings (map (apsnd (Term o Free)) vs) (term_of (dest_body body))
    |>> fold_rev (fn (_, (n, T)) => fn t => q (n, T, t)) vs
  end
   297 
   298 end;