src/HOL/Tools/SMT/smtlib_proof.ML
author nipkow
Wed Jan 10 15:25:09 2018 +0100 (19 months ago)
changeset 67399 eab6ce8368fa
parent 58061 3d060f43accb
child 69205 8050734eee3e
permissions -rw-r--r--
ran isabelle update_op on all sources
blanchet@58061
     1
(*  Title:      HOL/Tools/SMT/smtlib_proof.ML
blanchet@56078
     2
    Author:     Sascha Boehme, TU Muenchen
blanchet@57219
     3
    Author:     Mathias Fleury, ENS Rennes
blanchet@57219
     4
    Author:     Jasmin Blanchette, TU Muenchen
blanchet@56078
     5
blanchet@57219
     6
SMT-LIB-2-style proofs: parsing and abstract syntax tree.
blanchet@56078
     7
*)
blanchet@56078
     8
blanchet@58061
     9
signature SMTLIB_PROOF =
sig
  (** parser context **)

  (*what a symbol may be bound to while parsing*)
  datatype 'b shared =
    Tree of SMTLIB.tree |
    Term of term |
    Proof of 'b |
    None

  (*parser state; 'a is a client-defined payload ("extra"), 'b the payload of Proof bindings*)
  type ('a, 'b) context

  val mk_context: Proof.context -> int -> 'b shared Symtab.table -> typ Symtab.table ->
    term Symtab.table -> 'a -> ('a, 'b) context
  val empty_context: Proof.context -> typ Symtab.table -> term Symtab.table ->
    ('a list, 'b) context
  val ctxt_of: ('a, 'b) context -> Proof.context

  (*symbol bindings*)
  val lookup_binding: ('a, 'b) context -> string -> 'b shared
  val update_binding: string * 'b shared -> ('a, 'b) context -> ('a, 'b) context
  val with_bindings: (string * 'b shared) list ->
    (('a, 'b) context -> 'c * ('d, 'b) context) ->
    ('a, 'b) context -> 'c * ('d, 'b) context

  (*identifiers and fresh names*)
  val next_id: ('a, 'b) context -> int * ('a, 'b) context
  val with_fresh_names:
    (('a list, 'b) context -> term * ((string * (string * typ)) list, 'b) context) ->
    ('c, 'b) context -> term * string list

  (** type and term parsers **)

  type type_parser = SMTLIB.tree * typ list -> typ option
  type term_parser = SMTLIB.tree * term list -> term option
  val add_type_parser: type_parser -> Context.generic -> Context.generic
  val add_term_parser: term_parser -> Context.generic -> Context.generic

  (** parsing **)

  exception SMTLIB_PARSE of string * SMTLIB.tree

  val declare_fun: string -> typ ->
    ((string * typ) list, 'a) context -> ((string * typ) list, 'a) context
  val dest_binding: SMTLIB.tree -> string * 'a shared
  val type_of: ('a, 'b) context -> SMTLIB.tree -> typ
  val term_of: SMTLIB.tree ->
    ((string * (string * typ)) list, 'a) context ->
    term * ((string * (string * typ)) list, 'a) context
end;
blanchet@57219
    41
blanchet@58061
    42
structure SMTLIB_Proof: SMTLIB_PROOF =
blanchet@56078
    43
struct
blanchet@56078
    44
blanchet@57219
    45
(* proof parser context *)
blanchet@57219
    46
blanchet@58061
    47
(*Values a symbol can be bound to in the parser context:
    Tree  - an SMT-LIB tree that has not been parsed yet (e.g. a let right-hand side),
    Term  - an already parsed HOL term,
    Proof - a value of the type parameter 'b, opaque to this structure,
    None  - no binding; this is what lookup_binding yields for unknown names.*)
datatype 'b shared =
  Tree of SMTLIB.tree |
  Term of term |
  Proof of 'b |
  None

(*State threaded through proof parsing.*)
type ('a, 'b) context = {
  ctxt: Proof.context,           (*Isabelle context: fresh names and type inference*)
  id: int,                       (*counter handed out and advanced by next_id*)
  syms: 'b shared Symtab.table,  (*symbol bindings introduced by let and quantifiers*)
  typs: typ Symtab.table,        (*SMT sorts known by name*)
  funs: term Symtab.table,       (*SMT function symbols known by name*)
  extra: 'a}                     (*client payload, e.g. the list of freshly declared names*)
blanchet@57219
    56
blanchet@57219
    57
(*Assembles a parser context from its six components.*)
fun mk_context ctxt id syms typs funs extra =
  ({ctxt = ctxt, id = id, syms = syms, typs = typs, funs = funs, extra = extra}
    : ('a, 'b) context)

(*A context with no symbol bindings, identifiers starting at 1 and an empty extra list.*)
fun empty_context ctxt typs funs =
  let val no_bindings = Symtab.empty
  in mk_context ctxt 1 no_bindings typs funs [] end

(*The underlying Isabelle proof context.*)
fun ctxt_of (cx: ('a, 'b) context) = #ctxt cx
blanchet@56078
    63
blanchet@57219
    64
(*The value bound to a symbol name, or None if the name has no binding.*)
fun lookup_binding ({syms, ...}: ('a, 'b) context) name =
  (case Symtab.lookup syms name of
    SOME binding => binding
  | NONE => None)

(*Transforms the symbol-binding table and leaves every other component unchanged.*)
fun map_syms f (cx: ('a, 'b) context) =
  let val {ctxt, id, syms, typs, funs, extra} = cx
  in mk_context ctxt id (f syms) typs funs extra end

(*Binds (or rebinds) a single name.*)
fun update_binding (name, value) =
  map_syms (fn table => Symtab.update (name, value) table)
blanchet@56078
    71
blanchet@57219
    72
(*Runs f in a context where the bindings bs are installed. Afterwards, each bound name
  gets back the value it had before, which is None for names that were unbound. The
  restored None is written into the table explicitly; lookup_binding cannot tell that
  apart from an absent entry.*)
fun with_bindings bs f cx =
  let
    val saved = map (fn (name, _) => (name, lookup_binding cx name)) bs
    val (result, cx') = f (fold update_binding bs cx)
  in (result, fold update_binding saved cx') end
blanchet@56078
    80
blanchet@57221
    81
(*Returns the current identifier together with a context whose counter is one higher.*)
fun next_id (cx: ('a, 'b) context) =
  let val {ctxt, id, syms, typs, funs, extra} = cx
  in (id, mk_context ctxt (id + 1) syms typs funs extra) end
blanchet@57221
    83
blanchet@57748
    84
(*Runs the parser f on a copy of the context whose extra list starts out empty. f is
  expected to record each fresh variable it introduces there as (smt_name, (name, T));
  declare_free does exactly that.

  The term returned by f is wrapped in Trueprop. It is then closed with Logic.all_const
  over those variables: the first list entry becomes the outermost quantifier. Type
  inference runs only if some type in the result still contains dummyT or a schematic
  type variable.

  Returns the proposition and the SMT names of the closed-over variables. The context
  produced by f is not returned; only its Isabelle context is used, for inference.*)
fun with_fresh_names f ({ctxt, id, syms, typs, funs, ...}: ('a, 'b) context) =
  let
    val (body, {ctxt = ctxt', extra = fresh, ...}: ((string * (string * typ)) list, 'b) context) =
      f (mk_context ctxt id syms typs funs [])

    fun close_over (_, (n, T)) prop = Logic.all_const T $ Term.absfree (n, T) prop
    val prop = fold_rev close_over fresh (HOLogic.mk_Trueprop body)

    (*dummyT and schematic type variables are placeholders left for type inference*)
    fun is_placeholderT T = (T = Term.dummyT) orelse Term.is_TVar T
    val has_placeholders = Term.exists_type (Term.exists_subtype is_placeholderT) prop

    val prop' =
      if has_placeholders then
        prop
        |> singleton (Type_Infer_Context.infer_types ctxt')
        |> singleton (Proof_Context.standard_term_check_finish ctxt')
      else prop
  in (prop', map fst fresh) end
blanchet@57221
   101
blanchet@57219
   102
(*Look up a known SMT sort, respectively a known SMT function symbol, by name.*)
fun lookup_typ (cx: ('a, 'b) context) name = Symtab.lookup (#typs cx) name
fun lookup_fun (cx: ('a, 'b) context) name = Symtab.lookup (#funs cx) name
blanchet@56078
   104
blanchet@56078
   105
blanchet@56078
   106
(* core type and term parser *)
blanchet@56078
   107
blanchet@58061
   108
(*Built-in sorts: the nullary symbols Bool and Int. Anything else yields NONE.*)
fun core_type_parser (SMTLIB.Sym name, []) =
      (case name of
        "Bool" => SOME @{typ HOL.bool}
      | "Int" => SOME @{typ Int.int}
      | _ => NONE)
  | core_type_parser _ = NONE
blanchet@56078
   111
blanchet@56078
   112
(*The constant n : T => T applied to t, where T is the type of t.*)
fun mk_unary n t =
  let
    val T = fastype_of t
    val c = Const (n, T --> T)
  in c $ t end

(*The constant n : T => T => U applied to t1 and t2.*)
fun mk_binary' n T U t1 t2 =
  let val c = Const (n, T --> T --> U)
  in c $ t1 $ t2 end

(*Homogeneous binary operator n : T => T => T, with T taken from the first operand.*)
fun mk_binary n t1 t2 =
  let val T = fastype_of t1
  in mk_binary' n T T t1 t2 end

(*Right-nested combination: f t1 (f t2 (... (f t(n-1) tn))).*)
fun mk_rassoc _ t [] = t
  | mk_rassoc f t (u :: us) = f t (mk_rassoc f u us)

(*Left-nested combination: f (... (f (f t1 t2) t3) ...) tn.*)
fun mk_lassoc _ t [] = t
  | mk_lassoc f t (u :: us) = mk_lassoc f (f t u) us

(*Left-nested application of the binary operator named n.*)
fun mk_lassoc' n t ts = mk_lassoc (mk_binary n) t ts
blanchet@56078
   129
blanchet@56078
   130
(*The predicate n : T => T => bool applied to t1 and t2. T is the type of the first
  operand whose type is not dummyT. If both are dummyT, T is a fresh schematic type
  variable of sort S, numbered with serial ().*)
fun mk_binary_pred n S t1 t2 =
  let
    val known = List.filter (fn T => T <> Term.dummyT) [fastype_of t1, fastype_of t2]
    val argT =
      (case known of
        T :: _ => T
      | [] => TVar (("?a", serial ()), S))
  in mk_binary' n argT @{typ HOL.bool} t1 t2 end

(*Strict and non-strict order over a linorder type.*)
fun mk_less t1 t2 =
  mk_binary_pred @{const_name ord_class.less} @{sort linorder} t1 t2
fun mk_less_eq t1 t2 =
  mk_binary_pred @{const_name ord_class.less_eq} @{sort linorder} t1 t2
blanchet@56078
   142
blanchet@58061
   143
(*Built-in SMT-LIB symbols. Maps a head symbol, together with its already parsed
  arguments, to the corresponding HOL term. Yields NONE for anything not recognized, so
  that other registered parsers or the function table can handle it.*)
fun core_term_parser (SMTLIB.Num i, []) = SOME (HOLogic.mk_number @{typ Int.int} i)
  | core_term_parser (SMTLIB.Sym name, args) =
      (case (name, args) of
        (*Boolean constants accept any argument list*)
        ("true", _) => SOME @{const HOL.True}
      | ("false", _) => SOME @{const HOL.False}

        (*connectives; "and"/"or" nest to the right*)
      | ("not", [t]) => SOME (HOLogic.mk_not t)
      | ("and", t :: ts) => SOME (mk_rassoc (curry HOLogic.mk_conj) t ts)
      | ("or", t :: ts) => SOME (mk_rassoc (curry HOLogic.mk_disj) t ts)
      | ("=>", [t1, t2]) => SOME (HOLogic.mk_imp (t1, t2))
      | ("implies", [t1, t2]) => SOME (HOLogic.mk_imp (t1, t2))
      | ("ite", [t1, t2, t3]) =>
          let val T = fastype_of t2
          in SOME (Const (@{const_name HOL.If}, [@{typ HOL.bool}, T, T] ---> T) $ t1 $ t2 $ t3) end

        (*equality; a binary "~" is read as equality too*)
      | ("=", [t1, t2]) => SOME (HOLogic.mk_eq (t1, t2))
      | ("~", [t1, t2]) => SOME (HOLogic.mk_eq (t1, t2))

        (*arithmetic; unary "-" and "~" are negation, and the unary "-" case
          must precede the n-ary one*)
      | ("-", [t]) => SOME (mk_unary @{const_name uminus_class.uminus} t)
      | ("~", [t]) => SOME (mk_unary @{const_name uminus_class.uminus} t)
      | ("+", t :: ts) => SOME (mk_lassoc' @{const_name plus_class.plus} t ts)
      | ("-", t :: ts) => SOME (mk_lassoc' @{const_name minus_class.minus} t ts)
      | ("*", t :: ts) => SOME (mk_lassoc' @{const_name times_class.times} t ts)
      | ("div", [t1, t2]) => SOME (mk_binary @{const_name z3div} t1 t2)
      | ("mod", [t1, t2]) => SOME (mk_binary @{const_name z3mod} t1 t2)

        (*comparisons; ">" and ">=" are expressed by swapping the operands*)
      | ("<", [t1, t2]) => SOME (mk_less t1 t2)
      | (">", [t1, t2]) => SOME (mk_less t2 t1)
      | ("<=", [t1, t2]) => SOME (mk_less_eq t1 t2)
      | (">=", [t1, t2]) => SOME (mk_less_eq t2 t1)
      | _ => NONE)
  | core_term_parser _ = NONE
blanchet@56078
   173
blanchet@56078
   174
blanchet@57219
   175
(* custom type and term parsers *)
blanchet@56078
   176
blanchet@58061
   177
(*A custom parser receives an SMT-LIB head together with its already parsed arguments.
  It answers SOME result if it recognizes the head, and NONE otherwise.*)
type type_parser = SMTLIB.tree * typ list -> typ option
type term_parser = SMTLIB.tree * term list -> term option

(*Orders registered parsers by the serial number they are tagged with.*)
fun id_ord (entry1, entry2) = int_ord (fst entry1, fst entry2)
blanchet@56078
   182
blanchet@56078
   183
(*Registered type and term parsers, each tagged with a serial number. The lists are kept
  sorted by that number, so parsers are tried in registration order. The core parsers get
  their numbers when this data slot is created, before any add_*_parser call.*)
structure Parsers = Generic_Data
(
  type T = (int * type_parser) list * (int * term_parser) list
  val empty : T =
    let
      val type_parsers = [(serial (), core_type_parser)]
      val term_parsers = [(serial (), core_term_parser)]
    in (type_parsers, term_parsers) end
  val extend = I
  fun merge ((type_parsers1, term_parsers1), (type_parsers2, term_parsers2)) =
    let fun merge_by_id lists = Ord_List.merge id_ord lists
    in (merge_by_id (type_parsers1, type_parsers2), merge_by_id (term_parsers1, term_parsers2)) end
)
blanchet@56078
   191
blanchet@56078
   192
(*Register a custom type parser. The serial tag is drawn as soon as the parser is
  supplied, so it runs after all parsers registered earlier.*)
fun add_type_parser type_parser =
  let val entry = (serial (), type_parser)
  in Parsers.map (fn (type_ps, term_ps) => (Ord_List.insert id_ord entry type_ps, term_ps)) end

(*Register a custom term parser; same ordering discipline as add_type_parser.*)
fun add_term_parser term_parser =
  let val entry = (serial (), term_parser)
  in Parsers.map (fn (type_ps, term_ps) => (type_ps, Ord_List.insert id_ord entry term_ps)) end

(*All registered parsers of the given proof context, in the order they are tried.*)
fun get_type_parsers ctxt =
  let val (type_ps, _) = Parsers.get (Context.Proof ctxt)
  in map snd type_ps end
fun get_term_parsers ctxt =
  let val (_, term_ps) = Parsers.get (Context.Proof ctxt)
  in map snd term_ps end
blanchet@56078
   200
blanchet@56078
   201
(*Tries the parsers in list order on x and returns the first SOME result. Parsers after
  the first success are not called. Returns NONE if no parser succeeds.*)
fun apply_parsers [] _ = NONE
  | apply_parsers (parser :: rest) x =
      (case parser x of
        NONE => apply_parsers rest x
      | result => result)
blanchet@56078
   209
blanchet@56078
   210
blanchet@57219
   211
(* type and term parsing *)
blanchet@56078
   212
blanchet@58061
   213
(*Parse failure: a message together with the offending tree.*)
exception SMTLIB_PARSE of string * SMTLIB.tree

(*Drops one leading "?" (if present) from an SMT symbol and passes the rest through
  Name.desymbolize (SOME false), presumably to obtain a legal Isabelle name.*)
fun desymbolize name =
  let
    val unprefixed =
      (case try (unprefix "?") name of
        SOME stripped => stripped
      | NONE => name)
  in Name.desymbolize (SOME false) unprefixed end
blanchet@56078
   216
blanchet@57219
   217
(*Introduces a fresh free variable of type T for the SMT symbol name. The preferred
  Isabelle name n is made unique with Variable.variant_fixes. The resulting Free is stored
  in funs under name, and add records (fresh name, T) in extra. Returns the Free and the
  updated context.*)
fun fresh_fun add name n T (cx: ('a, 'b) context) =
  let
    val {ctxt, id, syms, typs, funs, extra} = cx
    val (n', ctxt') = yield_singleton Variable.variant_fixes n ctxt
    val fresh = Free (n', T)
    val funs' = Symtab.update (name, fresh) funs
    val extra' = add (n', T) extra
  in (fresh, mk_context ctxt' id syms typs funs' extra') end

(*Declares an SMT function symbol. extra collects (Isabelle name, type) pairs.*)
fun declare_fun name T cx =
  let val (_, cx') = fresh_fun (fn v => fn vs => v :: vs) name (desymbolize name) T cx
  in cx' end

(*Declares a symbol on the fly. extra collects (SMT name, (Isabelle name, type)).*)
fun declare_free name T cx =
  fresh_fun (fn v => fn vs => (name, v) :: vs) name (desymbolize name) T cx
blanchet@56078
   226
blanchet@56078
   227
(*Parses the type head ty applied to the already parsed argument types Ts. The
  registered type parsers are consulted first. Otherwise ty must be a symbol naming a
  known sort.
  NOTE(review): in that fallback Ts is not consulted, so a known sort name is accepted
  even when applied to arguments. Confirm this is intended.*)
fun parse_type cx ty Ts =
  let
    fun declared_type (SMTLIB.Sym name) =
          (case lookup_typ cx name of
            SOME T => T
          | NONE => raise SMTLIB_PARSE ("unknown SMT type", ty))
      | declared_type _ = raise SMTLIB_PARSE ("bad SMT type format", ty)
  in
    (case apply_parsers (get_type_parsers (ctxt_of cx)) (ty, Ts) of
      SOME T => T
    | NONE => declared_type ty)
  end
blanchet@56078
   237
blanchet@56078
   238
(*Parses the head t applied to the already parsed arguments ts, in this order:
    - the registered term parsers;
    - a known function symbol applied to ts;
    - for an unknown symbol without arguments, a fresh free variable of type dummyT.
  Anything else raises SMTLIB_PARSE.*)
fun parse_term t ts cx =
  let
    fun undeclared name =
      if null ts then declare_free name Term.dummyT cx
      else raise SMTLIB_PARSE ("bad SMT term", t)

    fun from_symbol (SMTLIB.Sym name) =
          (case lookup_fun cx name of
            SOME u => (Term.list_comb (u, ts), cx)
          | NONE => undeclared name)
      | from_symbol _ = raise SMTLIB_PARSE ("bad SMT term format", t)
  in
    (case apply_parsers (get_term_parsers (ctxt_of cx)) (t, ts) of
      SOME u => (u, cx)
    | NONE => from_symbol t)
  end
blanchet@56078
   250
blanchet@56078
   251
(*Parses an SMT type. It first tries ty as a nullary type; any exception there is
  swallowed by try. If that fails and ty is a compound (head args...), the arguments are
  parsed recursively and the head is applied to them.*)
fun type_of cx ty =
  let
    fun compound (SMTLIB.S (head :: args)) = parse_type cx head (map (type_of cx) args)
      | compound _ = raise SMTLIB_PARSE ("bad SMT type", ty)
  in
    (case try (parse_type cx ty) [] of
      SOME T => T
    | NONE => compound ty)
  end
blanchet@56078
   258
blanchet@58061
   259
(*A quantifier variable (name sort) becomes (name, (desymbolized name, HOL type)).*)
fun dest_var cx tree =
  (case tree of
    SMTLIB.S [SMTLIB.Sym name, ty] => (name, (desymbolize name, type_of cx ty))
  | _ => raise SMTLIB_PARSE ("bad SMT quantifier variable format", tree))

(*Strips nested "!" annotations, keeping only the annotated term.*)
fun dest_body tree =
  (case tree of
    SMTLIB.S (SMTLIB.Sym "!" :: inner :: _) => dest_body inner
  | _ => tree)

(*A let binding (name tree) becomes (name, Tree tree). The tree is parsed by term_of
  only when the name is first referenced.*)
fun dest_binding tree =
  (case tree of
    SMTLIB.S [SMTLIB.Sym name, t] => (name, Tree t)
  | _ => raise SMTLIB_PARSE ("bad SMT let binding format", tree))
blanchet@56078
   267
blanchet@56078
   268
(*Parses an SMT-LIB term, threading the parser context. Cases:
    - forall / exists become HOL quantifiers;
    - let installs its bindings only for the body;
    - "!" annotations are dropped;
    - applications parse their arguments first and then the head;
    - a symbol bound to a Tree is parsed on first use and rebound to the resulting Term.
  Everything else goes to parse_term.*)
fun term_of t cx =
  (case t of
    SMTLIB.S [SMTLIB.Sym "forall", SMTLIB.S vars, body] => quant HOLogic.mk_all vars body cx
  | SMTLIB.S [SMTLIB.Sym "exists", SMTLIB.S vars, body] => quant HOLogic.mk_exists vars body cx
  | SMTLIB.S [SMTLIB.Sym "let", SMTLIB.S bindings, body] =>
      with_bindings (map dest_binding bindings) (term_of body) cx
  | SMTLIB.S (SMTLIB.Sym "!" :: annotated :: _) => term_of annotated cx
  | SMTLIB.S (f :: args) =>
      let val (us, cx') = fold_map term_of args cx
      in parse_term f us cx' end
  | SMTLIB.Sym name =>
      (case lookup_binding cx name of
        Tree u =>
          let val (u', cx') = term_of u cx
          in (u', update_binding (name, Term u') cx') end
      | Term u => (u, cx)
      | None => parse_term t [] cx
      | Proof _ => raise SMTLIB_PARSE ("bad SMT term format", t))
  | _ => parse_term t [] cx)

(*Parses a quantified body with each variable bound to its Free. The result is wrapped
  with q, and the first variable becomes the outermost quantifier.*)
and quant q vars body cx =
  let
    val vs = map (dest_var cx) vars
    val bindings = map (fn (name, v) => (name, Term (Free v))) vs
    val (t, cx') = with_bindings bindings (term_of (dest_body body)) cx
    val t' = fold_rev (fn (_, (n, T)) => fn u => q (n, T, u)) vs t
  in (t', cx') end
blanchet@56078
   297
blanchet@57219
   298
end;