src/HOL/Tools/SMT/z3_model.ML
author boehmes
Sun, 19 Sep 2010 11:33:39 +0200
changeset 39536 c62359dd253d
parent 37153 8feed34275ce
child 40551 a0dd429e97d9
permissions -rw-r--r--
properly parse Z3 error models, including datatypes, and represent function valuations as lambda terms; also normalize Z3 error models

(*  Title:      HOL/Tools/SMT/z3_model.ML
    Author:     Sascha Boehme and Philipp Meyer, TU Muenchen

Parser for counterexamples generated by Z3.
*)

signature Z3_MODEL =
sig
  (*parse_counterex ctxt recon lines: parse the model that Z3 printed (given
    as a list of lines), normalize it, and translate every entry that can be
    mapped back to a term of the translation recon into an equation
    "t = value"*)
  val parse_counterex: Proof.context -> SMT_Translate.recon -> string list ->
    term list
end

structure Z3_Model: Z3_MODEL =
struct

(* counterexample expressions *)

(*Raw model expressions as printed by Z3:
    True, False       -- Boolean literals
    Number (i, NONE)  -- the integer i
    Number (i, SOME j) -- the fraction i / j
    Value i           -- Z3's abstract value "val!i"
    Array a           -- an array (function) value
    App (f, es)       -- symbol f applied to es (a bare symbol if es = [])
  An array is either the constant array "Fresh e" (every index maps to e)
  or a point update "Store ((a, k), v)" of the array a at index k to v.*)
datatype expr = True | False | Number of int * int option | Value of int |
  Array of array | App of string * expr list
and array = Fresh of expr | Store of (array * expr) * expr


(* parsing *)

(*skip any amount of ASCII whitespace*)
val space = Scan.many Symbol.is_ascii_blank

(*run the parser p and then skip trailing whitespace*)
fun spaced p = p --| space

(*run the parser p between the single-symbol delimiters open_sym and
  close_sym; whitespace after either delimiter is skipped*)
fun delimited open_sym close_sym p =
  spaced (Scan.$$ open_sym) |-- p --| spaced (Scan.$$ close_sym)

fun in_parens p = delimited "(" ")" p
fun in_braces p = delimited "{" "}" p

(*the decimal value of a symbol consisting of exactly one ASCII digit;
  NONE for every other symbol*)
fun digit s =
  if String.size s = 1 andalso Char.isDigit (String.sub (s, 0))
  then SOME (Char.ord (String.sub (s, 0)) - Char.ord #"0")
  else NONE

(*a natural number: a non-empty run of decimal digits, read most significant
  digit first; trailing whitespace is skipped*)
val nat_num =
  Scan.repeat1 (Scan.some digit)
  >> (fn ds => fold (fn d => fn acc => 10 * acc + d) ds 0)
  |> spaced

(*an integer: a natural number, optionally preceded by "-"*)
val int_num =
  let val negate = $$ "-" >> K (fn n => ~ n)
  in Scan.optional negate I :|-- (fn sign => nat_num >> sign) |> spaced end

(*symbols that may occur in a Z3 identifier: ASCII letters, digits and a
  fixed set of punctuation characters*)
val is_char =
  let val punctuation = explode "_+*-/%~=<>$&|?!.@^#"
  in
    fn s =>
      Symbol.is_ascii_letter s orelse Symbol.is_ascii_digit s orelse
      member (op =) punctuation s
  end

(*an identifier: a non-empty run of identifier symbols, as one string*)
val name = Scan.many1 is_char >> implode |> spaced

(*the literal keyword s, followed by optional whitespace*)
fun $$$ s = Scan.this_string s |> spaced

(*An array term in parentheses: "(const e)" is the array mapping every index
  to e; "(store a k v)" is the array a updated at index k to value v.*)
fun array_expr st = st |> in_parens (
  $$$ "const" |-- expr >> Fresh ||
  $$$ "store" |-- array_expr -- expr -- expr >> Store)

(*A model expression. The alternatives are tried in order and the first
  successful one wins. The literal, number and "val!" alternatives must come
  before the general name alternative, because is_char also accepts letters,
  digits, "-" and "!", so name would otherwise consume those tokens.
  NOTE(review): $$$ "true" / $$$ "false" also succeed on the prefix of a
  longer identifier such as "true_x"; confirm Z3 never prints such names.*)
and expr st = st |> (
  $$$ "true" >> K True ||
  $$$ "false" >> K False ||
  (*an integer, optionally followed by "/ denominator"*)
  int_num -- Scan.option ($$$ "/" |-- int_num) >> Number ||
  (*abstract value "val!i"*)
  $$$ "val!" |-- nat_num >> Value ||
  (*a bare symbol, represented as an application without arguments*)
  name >> (App o rpair []) ||
  array_expr >> Array ||
  (*an application "(f e1 ... en)" with at least one argument*)
  in_parens (name -- Scan.repeat1 expr) >> App)

(*the arguments of one function case: expressions up to and including "->"*)
fun args st = st |> ($$$ "->" >> K [] || expr ::: args)

(*one explicit case of a function valuation: "a1 ... an -> e"*)
val args_case = args -- expr

(*the catch-all case "else -> e", represented with no arguments*)
val else_case =
  $$$ "else" |-- $$$ "->" |-- expr >> (fn e => ([] : expr list, e))

(*a function valuation in braces: explicit cases followed by the else case.
  else_case is tried first because the name parser would otherwise accept
  "else" as an argument expression.*)
val func =
  let
    fun cases st =
      st |> (else_case >> (fn c => [c]) || args_case ::: cases)
  in in_braces cases end

(*a whole model: leading whitespace, then a sequence of entries
  "name -> valuation", where the valuation is either a function in braces or
  a single expression (stored as one case without arguments)*)
val cex =
  let
    val constant = expr >> (fn e => [([], e)])
    val entry = name --| $$$ "->" -- (func || constant)
  in space |-- Scan.repeat entry end

(*Parse the lines printed by Z3 as a model. Every line is preceded by a
  newline symbol before parsing. The entries recognized before the first
  unparsable input are returned and the remaining input is ignored; if the
  parser raises an exception, the result is the empty model.*)
fun read_cex ls =
  let val symbols = maps (fn l => "\n" :: explode l) ls
  in
    (case try (Scan.finite Symbol.stopper cex) symbols of
      SOME (entries, _) => entries
    | NONE => [])
  end


(* normalization *)

local
  (*true iff n is defined in terms and its term satisfies the predicate f*)
  fun matches terms f n =
    the_default false (Option.map f (Symtab.lookup terms n))

  (*apply f to every argument and result expression of the cases of a
    valuation; the valuation's key n is kept unchanged*)
  fun subst f (n, cases) =
    let fun subst_case (args, v) = (map f args, f v)
    in (n, map subst_case cases) end
in

(*Attach an arity to a raw valuation (n, cases) and drop redundant cases.
  - An empty valuation yields NONE (previously split_last raised Empty).
  - A single case is a constant: arity 0.
  - Otherwise the last case is the else case. The arity is the number of
    arguments of the first case, and explicit cases whose result equals the
    else result are removed because the else case already covers them.*)
fun reduce_function (_, []) = NONE
  | reduce_function (n, [c]) = SOME ((n, 0), [c])
  | reduce_function (n, cases as (args, _) :: _) =
      (*cases has at least two elements here, so (args, _) is an explicit
        case, not the else case*)
      let val (patterns, else_case as (_, e)) = split_last cases
      in
        SOME ((n, length args),
          filter_out (equal e o snd) patterns @ [else_case])
      end

(*Keep only the valuations whose name is defined in terms. Entries for any
  other names (presumably Z3-introduced skolem constants, as the function
  name suggests) are discarded.*)
fun drop_skolem_constants terms =
  filter (fn ((n, _), _) => Symtab.defined terms n)

(*Replace abstract values by the constants that denote them.
  A valuation ((n, k), [([], Value i)]) is eliminated when n does not denote
  a Free in terms and k = 0. In that case every occurrence of Value i in all
  other valuations, processed or not yet processed, is replaced by the bare
  symbol App (n, []), and the valuation itself is dropped. All other
  valuations are kept in their original relative order.
  NOTE(review): presumably Frees are the variables whose values the user
  wants to see, so they are never used as names for values -- confirm.*)
fun substitute_constants terms =
  let
    (*vs1: already processed valuations, in reverse order; vs2: remaining*)
    fun check vs1 [] = rev vs1
      | check vs1 ((v as ((n, k), [([], Value i)])) :: vs2) =
          if matches terms (fn Free _ => true | _ => false) n orelse k > 0
          then check (v :: vs1) vs2
          else
            let
              (*replace Value i by the symbol n; map keeps the list order*)
              fun sub (e as Value j) = if i = j then App (n, []) else e
                | sub e = e
            in check (map (subst sub) vs1) (map (subst sub) vs2) end
      | check vs1 (v :: vs2) = check (v :: vs1) vs2
  in check [] end

(*Eliminate the int/nat coercions from the model.
  If some valuation's name denotes a term aconv to @{term int}, its explicit
  cases are read as a table from the first argument to the result, and its
  else result serves as the default. Every abstract value in all valuations
  is then replaced by its image under that table. Afterwards the valuations
  whose name denotes int or nat are removed.
  NOTE(review): this replaces every Value, not only values of type nat --
  presumably Z3 uses abstract values only for nats here; confirm.*)
fun remove_int_nat_coercions terms vs =
  let
    fun is_one_of ts ((n, _), _) = matches terms (member (op aconv) ts) n

    fun int_table (_, cases) =
      let val (pats, (_, default)) = split_last cases
      in (default, map (fn (args, v) => (hd args, v)) pats) end

    val nat_of =
      (case Option.map int_table (find_first (is_one_of [@{term int}]) vs) of
        NONE => I
      | SOME (default, tab) =>
          (fn e as Value _ => the_default default (AList.lookup (op =) tab e)
            | e => e))
  in
    vs
    |> map (subst nat_of)
    |> filter_out (is_one_of [@{term int}, @{term nat}])
  end

(*Keep only the valuations that can be translated into terms, and replace
  each valuation's name by the term it denotes in terms. A valuation is kept
  if it has at least one case, every App symbol occurring anywhere in its
  cases is defined in terms, and its own name is defined in terms.*)
fun filter_valid_valuations terms =
  let
    fun valid_expr (Array a) = valid_array a
      | valid_expr (App (f, es)) =
          Symtab.defined terms f andalso forall valid_expr es
      | valid_expr _ = true
    and valid_array (Fresh e) = valid_expr e
      | valid_array (Store ((a, e1), e2)) =
          valid_array a andalso valid_expr e1 andalso valid_expr e2

    fun valid_case (es, e) = forall valid_expr (e :: es)

    fun keep (_, []) = NONE
      | keep ((n, i), cases) =
          if forall valid_case cases
          then Option.map (fn t => ((t, i), cases)) (Symtab.lookup terms n)
          else NONE
  in map_filter keep end

end


(* translation into terms *)

(*Run the stateful translation f over vs and return only the results. The
  threaded state is the triple (proof context, name-to-term table, table
  from abstract value numbers to their terms); the value table starts
  empty.*)
fun with_context ctxt terms f vs =
  (ctxt, terms, Inttab.empty)
  |> fold_map f vs
  |> fst

(*a Free of type T whose name is freshly fixed in the proof context*)
fun fresh_term T (ctxt, terms, values) =
  let
    val (x, ctxt') = ctxt |> yield_singleton Variable.variant_fixes ""
    val t = Free (x, T)
  in (t, (ctxt', terms, values)) end

(*the term standing for the abstract value i; the first request for i
  creates a fresh Free of type T and records it for later requests*)
fun term_of_value T i cx =
  let
    val (_, _, values) = cx
    fun remember (t, (ctxt', terms', values')) =
      (t, (ctxt', terms', Inttab.update (i, t) values'))
  in
    (case Inttab.lookup values i of
      SOME t => (t, cx)
    | NONE => remember (fresh_term T cx))
  end

(*the term for the name n; raises Option if n is not in the term table*)
fun get_term n cx =
  let val (_, terms, _) = cx
  in (the (Symtab.lookup terms n), cx) end

(*trans_expr T e: translate the model expression e into a HOL term of type
  T. The translation state (proof context, term table, value table) is
  threaded through the computation.*)
fun trans_expr _ True = pair @{term True}
  | trans_expr _ False = pair @{term False}
  | trans_expr T (Number (i, NONE)) = pair (HOLogic.mk_number T i)
  (*a fraction i / j becomes an application of divide at type T*)
  | trans_expr T (Number (i, SOME j)) =
      pair (Const (@{const_name divide}, [T, T] ---> T) $
        HOLogic.mk_number T i $ HOLogic.mk_number T j)
  (*abstract values become Frees, shared by all occurrences of the same i*)
  | trans_expr T (Value i) = term_of_value T i
  | trans_expr T (Array a) = trans_array T a
  (*T is ignored for applications: the argument types come from the binder
    types of the term that n denotes*)
  | trans_expr _ (App (n, es)) =
      let val get_Ts = take (length es) o Term.binder_types o Term.fastype_of
      in
        get_term n #-> (fn t =>
        fold_map (uncurry trans_expr) (get_Ts t ~~ es) #>>
        Term.list_comb o pair t)
      end

(*trans_array T a: translate the array a into a term of function type T.
  A constant array becomes the abstraction %x. e; a store becomes fun_upd
  applied to the translated array, index and value.*)
and trans_array T a =
  let val dT = Term.domain_type T and rT = Term.range_type T
  in
    (case a of
      Fresh e => trans_expr rT e #>> (fn t => Abs ("x", dT, t))
    | Store ((a', e1), e2) =>
        trans_array T a' ##>> trans_expr dT e1 ##>> trans_expr rT e2 #>>
        (fn ((m, k), v) =>
          Const (@{const_name fun_upd}, [T, dT, rT] ---> T) $ m $ k $ v))
  end

(*trans_pattern T (args, e): translate one function case of a function of
  type T. Each argument is translated at the next domain type of T, and the
  result e at the type that remains after all arguments.*)
fun trans_pattern T ([], e) = trans_expr T e #>> (fn t => ([], t))
  | trans_pattern T (arg :: args, e) =
      let
        val (dT, rT) = (Term.domain_type T, Term.range_type T)
        fun cons_arg (t, (ts, u)) = (t :: ts, u)
      in trans_expr dT arg ##>> trans_pattern rT (args, e) #>> cons_arg end

(*the constant fun_upd at type (T => U) => T => U => (T => U)*)
fun mk_fun_upd T U =
  let val fT = T --> U
  in Const (@{const_name fun_upd}, [fT, T, U] ---> fT) end

(*domain and range of the function type T*)
fun split_type T =
  let val dT = Term.domain_type T
  in (dT, Term.range_type T) end

(*mk_update (ts, u) f: update the curried function f so that it maps the
  argument list ts to u. With several arguments, f is updated at the first
  argument with a dummy function (a constant named "_") that is itself
  updated at the remaining arguments.*)
fun mk_update ([], u) _ = u
  | mk_update (t :: ts, u) f =
      let val (dT, rT) = split_type (Term.fastype_of f)
      in
        (case ts of
          [] => mk_fun_upd dT rT $ f $ t $ u
        | _ =>
            let val (dT', rT') = split_type rT
            in
              mk_fun_upd dT rT $ f $ t $
                mk_update (ts, u) (Term.absdummy (dT', Const ("_", rT')))
            end)
      end

(*abstract t over the argument types Ts, then apply the updates in pats
  from left to right*)
fun mk_lambda Ts (t, pats) =
  let val abs_t = fold_rev (fn T => fn body => Term.absdummy (T, body)) Ts t
  in fold mk_update pats abs_t end

(*split a type T into its first i argument types and the type that remains
  after applying i arguments*)
fun split_arity T i =
  let val ((Us1, Us2), U) = Term.strip_type T |>> chop i
  in (Us1, Us2 ---> U) end

(*translate' T i cases: translate the cases of a valuation with arity i
  into a term of type T. A single argument-free case with i = 0 is a plain
  value; otherwise the else value is abstracted over the first i argument
  types and then updated with the explicit cases.*)
fun translate' T i [([], e)] =
      if i = 0 then trans_expr T e
      else
        let val (Us, U) = split_arity T i
        in trans_expr U e #>> (fn t => mk_lambda Us (t, [])) end
  | translate' T i cases =
      let
        val (pat_cases, (_, def)) = split_last cases
        val (Us, U) = split_arity T i
      in
        trans_expr U def ##>> fold_map (trans_pattern T) pat_cases
        #>> mk_lambda Us
      end

(*the equation "t = value" for one normalized valuation of the term t*)
fun translate ((t, i), cases) =
  translate' (Term.fastype_of t) i cases #>> (fn rhs => HOLogic.mk_eq (t, rhs))


(* overall procedure *)

(*Parse Z3's model output ls and then normalize it: attach arities, drop
  entries not known in terms, name abstract values by constants, remove
  int/nat coercions, and keep only translatable valuations. Finally translate
  each remaining valuation into an equation.*)
fun parse_counterex ctxt ({terms, ...} : SMT_Translate.recon) ls =
  let
    val normalize =
      map_filter reduce_function
      #> drop_skolem_constants terms
      #> substitute_constants terms
      #> remove_int_nat_coercions terms
      #> filter_valid_valuations terms
  in with_context ctxt terms translate (normalize (read_cex ls)) end

end