src/HOL/Tools/SMT/z3_model.ML
author boehmes
Mon Nov 22 15:45:43 2010 +0100 (2010-11-22)
changeset 40663 e080c9e68752
parent 40627 becf5d5187cc
child 40828 47ff261431c4
permissions -rw-r--r--
share and use more utility functions;
slightly reduced complexity for Z3 proof rule 'rewrite'
     1 (*  Title:      HOL/Tools/SMT/z3_model.ML
     2     Author:     Sascha Boehme and Philipp Meyer, TU Muenchen
     3 
      4 Parser for counterexamples generated by Z3 and their translation into HOL equations.
     5 *)
     6 
     7 signature Z3_MODEL =
     8 sig
     9   val parse_counterex: Proof.context -> SMT_Translate.recon -> string list ->
    10     term list
    11 end
    12 
    13 structure Z3_Model: Z3_MODEL =
    14 struct
    15 
    16 structure U = SMT_Utils
    17 
    18 
    19 (* counterexample expressions *)
    20 
    21 datatype expr = True | False | Number of int * int option | Value of int |
    22   Array of array | App of string * expr list
    23 and array = Fresh of expr | Store of (array * expr) * expr
    24 
    25 
    26 (* parsing *)
    27 
    28 val space = Scan.many Symbol.is_ascii_blank
    29 fun spaced p = p --| space
    30 fun in_parens p = spaced (Scan.$$ "(") |-- p --| spaced (Scan.$$ ")")
    31 fun in_braces p = spaced (Scan.$$ "{") |-- p --| spaced (Scan.$$ "}")
    32 
    33 val digit = (fn
    34   "0" => SOME 0 | "1" => SOME 1 | "2" => SOME 2 | "3" => SOME 3 |
    35   "4" => SOME 4 | "5" => SOME 5 | "6" => SOME 6 | "7" => SOME 7 |
    36   "8" => SOME 8 | "9" => SOME 9 | _ => NONE)
    37 
    38 val nat_num = spaced (Scan.repeat1 (Scan.some digit) >>
    39   (fn ds => fold (fn d => fn i => i * 10 + d) ds 0))
    40 val int_num = spaced (Scan.optional ($$ "-" >> K (fn i => ~i)) I :|--
    41   (fn sign => nat_num >> sign))
    42 
    43 val is_char = Symbol.is_ascii_letter orf Symbol.is_ascii_digit orf
    44   member (op =) (raw_explode "_+*-/%~=<>$&|?!.@^#")
    45 val name = spaced (Scan.many1 is_char >> implode)
    46 
    47 fun $$$ s = spaced (Scan.this_string s)
    48 
    49 fun array_expr st = st |> in_parens (
    50   $$$ "const" |-- expr >> Fresh ||
    51   $$$ "store" |-- array_expr -- expr -- expr >> Store)
    52 
    53 and expr st = st |> (
    54   $$$ "true" >> K True ||
    55   $$$ "false" >> K False ||
    56   int_num -- Scan.option ($$$ "/" |-- int_num) >> Number ||
    57   $$$ "val!" |-- nat_num >> Value ||
    58   name >> (App o rpair []) ||
    59   array_expr >> Array ||
    60   in_parens (name -- Scan.repeat1 expr) >> App)
    61 
    62 fun args st = ($$$ "->" >> K [] || expr ::: args) st
    63 val args_case = args -- expr
    64 val else_case = $$$ "else" -- $$$ "->" |-- expr >> pair ([] : expr list)
    65 
    66 val func =
    67   let fun cases st = (else_case >> single || args_case ::: cases) st
    68   in in_braces cases end
    69 
    70 val cex = space |--
    71   Scan.repeat (name --| $$$ "->" -- (func || expr >> (single o pair [])))
    72 
    73 fun read_cex ls =
    74   maps (cons "\n" o raw_explode) ls
    75   |> try (fst o Scan.finite Symbol.stopper cex)
    76   |> the_default []
    77 
    78 
    79 (* normalization *)
    80 
    81 local
    82   fun matches terms f n =
    83     (case Symtab.lookup terms n of
    84       NONE => false
    85     | SOME t => f t)
    86 
    87   fun subst f (n, cases) = (n, map (fn (args, v) => (map f args, f v)) cases)
    88 in
    89 
    90 fun reduce_function (n, [c]) = SOME ((n, 0), [c])
    91   | reduce_function (n, cases) =
    92       let val (patterns, else_case as (_, e)) = split_last cases
    93       in
    94         (case patterns of
    95           [] => NONE
    96         | (args, _) :: _ => SOME ((n, length args),
    97             filter_out (equal e o snd) patterns @ [else_case]))
    98       end
    99 
   100 fun drop_skolem_constants terms = filter (Symtab.defined terms o fst o fst)
   101 
   102 fun substitute_constants terms =
   103   let
   104     fun check vs1 [] = rev vs1
   105       | check vs1 ((v as ((n, k), [([], Value i)])) :: vs2) =
   106           if matches terms (fn Free _ => true | _ => false) n orelse k > 0
   107           then check (v :: vs1) vs2
   108           else
   109             let
   110               fun sub (e as Value j) = if i = j then App (n, []) else e
   111                 | sub e = e
   112             in check (map (subst sub) vs1) (map (subst sub) vs2) end
   113       | check vs1 (v :: vs2) = check (v :: vs1) vs2
   114   in check [] end
   115 
   116 fun remove_int_nat_coercions terms vs =
   117   let
   118     fun match ts ((n, _), _) = matches terms (member (op aconv) ts) n
   119 
   120     val (default_int, ints) =
   121       (case find_first (match [@{const of_nat (int)}]) vs of
   122         NONE => (NONE, [])
   123       | SOME (_, cases) =>
   124           let val (cs, (_, e)) = split_last cases
   125           in (SOME e, map (apfst hd) cs) end)
   126 
   127     fun nat_of @{typ nat} (v as Value _) =
   128           AList.lookup (op =) ints v |> the_default (the_default v default_int)
   129       | nat_of _ e = e
   130 
   131     fun subst_nat T k ([], e) =
   132           let fun app f i = if i <= 0 then I else app f (i-1) o f
   133           in ([], nat_of (app Term.range_type k T) e) end
   134       | subst_nat T k (arg :: args, e) =
   135           subst_nat (Term.range_type T) (k-1) (args, e)
   136           |> apfst (cons (nat_of (Term.domain_type T) arg))
   137 
   138     fun subst_nats (v as ((n, k), cases)) =
   139       (case Symtab.lookup terms n of
   140         NONE => v
   141       | SOME t => ((n, k), map (subst_nat (Term.fastype_of t) k) cases))
   142   in
   143     map subst_nats vs
   144     |> filter_out (match [@{const of_nat (int)}, @{const nat}])
   145   end
   146 
   147 fun filter_valid_valuations terms = map_filter (fn
   148     (_, []) => NONE
   149   | ((n, i), cases) =>
   150       let
   151         fun valid_expr (Array a) = valid_array a
   152           | valid_expr (App (n, es)) =
   153               Symtab.defined terms n andalso forall valid_expr es
   154           | valid_expr _ = true
   155         and valid_array (Fresh e) = valid_expr e
   156           | valid_array (Store ((a, e1), e2)) =
   157               valid_array a andalso valid_expr e1 andalso valid_expr e2
   158         fun valid_case (es, e) = forall valid_expr (e :: es)
   159       in
   160         if not (forall valid_case cases) then NONE
   161         else Option.map (rpair cases o rpair i) (Symtab.lookup terms n)
   162       end)
   163 
   164 end
   165 
   166 
   167 (* translation into terms *)
   168 
   169 fun with_context ctxt terms f vs =
   170   fst (fold_map f vs (ctxt, terms, Inttab.empty))
   171 
   172 fun fresh_term T (ctxt, terms, values) =
   173   let val (n, ctxt') = yield_singleton Variable.variant_fixes "" ctxt
   174   in (Free (n, T), (ctxt', terms, values)) end
   175 
   176 fun term_of_value T i (cx as (_, _, values)) =
   177   (case Inttab.lookup values i of
   178     SOME t => (t, cx)
   179   | NONE =>
   180       let val (t, (ctxt', terms', values')) = fresh_term T cx
   181       in (t, (ctxt', terms', Inttab.update (i, t) values')) end)
   182 
   183 fun get_term n (cx as (_, terms, _)) = (the (Symtab.lookup terms n), cx)
   184 
   185 fun trans_expr _ True = pair @{const True}
   186   | trans_expr _ False = pair @{const False}
   187   | trans_expr T (Number (i, NONE)) = pair (HOLogic.mk_number T i)
   188   | trans_expr T (Number (i, SOME j)) =
   189       pair (Const (@{const_name divide}, [T, T] ---> T) $
   190         HOLogic.mk_number T i $ HOLogic.mk_number T j)
   191   | trans_expr T (Value i) = term_of_value T i
   192   | trans_expr T (Array a) = trans_array T a
   193   | trans_expr _ (App (n, es)) =
   194       let val get_Ts = take (length es) o Term.binder_types o Term.fastype_of
   195       in
   196         get_term n #-> (fn t =>
   197         fold_map (uncurry trans_expr) (get_Ts t ~~ es) #>>
   198         Term.list_comb o pair t)
   199       end
   200 
   201 and trans_array T a =
   202   let val dT = Term.domain_type T and rT = Term.range_type T
   203   in
   204     (case a of
   205       Fresh e => trans_expr rT e #>> (fn t => Abs ("x", dT, t))
   206     | Store ((a', e1), e2) =>
   207         trans_array T a' ##>> trans_expr dT e1 ##>> trans_expr rT e2 #>>
   208         (fn ((m, k), v) =>
   209           Const (@{const_name fun_upd}, [T, dT, rT] ---> T) $ m $ k $ v))
   210   end
   211 
   212 fun trans_pattern T ([], e) = trans_expr T e #>> pair []
   213   | trans_pattern T (arg :: args, e) =
   214       trans_expr (Term.domain_type T) arg ##>>
   215       trans_pattern (Term.range_type T) (args, e) #>>
   216       (fn (arg', (args', e')) => (arg' :: args', e'))
   217 
   218 fun mk_fun_upd T U = Const (@{const_name fun_upd}, [T --> U, T, U, T] ---> U)
   219 
   220 fun mk_update ([], u) _ = u
   221   | mk_update ([t], u) f =
   222       uncurry mk_fun_upd (U.split_type (Term.fastype_of f)) $ f $ t $ u
   223   | mk_update (t :: ts, u) f =
   224       let
   225         val (dT, rT) = U.split_type (Term.fastype_of f)
   226         val (dT', rT') = U.split_type rT
   227       in
   228         mk_fun_upd dT rT $ f $ t $
   229           mk_update (ts, u) (Term.absdummy (dT', Const ("_", rT')))
   230       end
   231 
   232 fun mk_lambda Ts (t, pats) =
   233   fold_rev (curry Term.absdummy) Ts t |> fold mk_update pats
   234 
   235 fun translate' T i [([], e)] =
   236       if i = 0 then trans_expr T e
   237       else 
   238         let val ((Us1, Us2), U) = Term.strip_type T |>> chop i
   239         in trans_expr (Us2 ---> U) e #>> mk_lambda Us1 o rpair [] end
   240   | translate' T i cases =
   241       let
   242         val (pat_cases, def) = split_last cases |> apsnd snd
   243         val ((Us1, Us2), U) = Term.strip_type T |>> chop i
   244       in
   245         trans_expr (Us2 ---> U) def ##>>
   246         fold_map (trans_pattern T) pat_cases #>>
   247         mk_lambda Us1
   248       end
   249 
   250 fun translate ((t, i), cases) =
   251   translate' (Term.fastype_of t) i cases #>> HOLogic.mk_eq o pair t
   252 
   253 
   254 (* overall procedure *)
   255 
   256 fun parse_counterex ctxt ({terms, ...} : SMT_Translate.recon) ls =
   257   read_cex ls
   258   |> map_filter reduce_function
   259   |> drop_skolem_constants terms
   260   |> substitute_constants terms
   261   |> remove_int_nat_coercions terms
   262   |> filter_valid_valuations terms
   263   |> with_context ctxt terms translate
   264 
   265 end
   266