src/HOL/Tools/SMT/z3_model.ML
changeset 39536 c62359dd253d
parent 37153 8feed34275ce
child 40551 a0dd429e97d9
equal deleted inserted replaced
39535:cd1bb7125d8a 39536:c62359dd253d
     4 Parser for counterexamples generated by Z3.
     4 Parser for counterexamples generated by Z3.
     5 *)
     5 *)
     6 
     6 
     7 signature Z3_MODEL =
     7 signature Z3_MODEL =
     8 sig
     8 sig
     9   val parse_counterex: SMT_Translate.recon -> string list -> term list
     9   val parse_counterex: Proof.context -> SMT_Translate.recon -> string list ->
       
    10     term list
    10 end
    11 end
    11 
    12 
    12 structure Z3_Model: Z3_MODEL =
    13 structure Z3_Model: Z3_MODEL =
    13 struct
    14 struct
    14 
    15 
    15 (* counterexample expressions *)
    16 (* counterexample expressions *)
    16 
    17 
    17 datatype expr = True | False | Number of int * int option | Value of int |
    18 datatype expr = True | False | Number of int * int option | Value of int |
    18   Array of array
    19   Array of array | App of string * expr list
    19 and array = Fresh of expr | Store of (array * expr) * expr
    20 and array = Fresh of expr | Store of (array * expr) * expr
    20 
    21 
    21 
    22 
    22 (* parsing *)
    23 (* parsing *)
    23 
    24 
    24 val space = Scan.many Symbol.is_ascii_blank
    25 val space = Scan.many Symbol.is_ascii_blank
    25 fun in_parens p = Scan.$$ "(" |-- p --| Scan.$$ ")"
    26 fun spaced p = p --| space
    26 fun in_braces p = (space -- Scan.$$ "{") |-- p --| (space -- Scan.$$ "}")
    27 fun in_parens p = spaced (Scan.$$ "(") |-- p --| spaced (Scan.$$ ")")
       
    28 fun in_braces p = spaced (Scan.$$ "{") |-- p --| spaced (Scan.$$ "}")
    27 
    29 
    28 val digit = (fn
    30 val digit = (fn
    29   "0" => SOME 0 | "1" => SOME 1 | "2" => SOME 2 | "3" => SOME 3 |
    31   "0" => SOME 0 | "1" => SOME 1 | "2" => SOME 2 | "3" => SOME 3 |
    30   "4" => SOME 4 | "5" => SOME 5 | "6" => SOME 6 | "7" => SOME 7 |
    32   "4" => SOME 4 | "5" => SOME 5 | "6" => SOME 6 | "7" => SOME 7 |
    31   "8" => SOME 8 | "9" => SOME 9 | _ => NONE)
    33   "8" => SOME 8 | "9" => SOME 9 | _ => NONE)
    32 
    34 
    33 val nat_num = Scan.repeat1 (Scan.some digit) >>
    35 val nat_num = spaced (Scan.repeat1 (Scan.some digit) >>
    34   (fn ds => fold (fn d => fn i => i * 10 + d) ds 0)
    36   (fn ds => fold (fn d => fn i => i * 10 + d) ds 0))
    35 val int_num = Scan.optional ($$ "-" >> K (fn i => ~i)) I :|--
    37 val int_num = spaced (Scan.optional ($$ "-" >> K (fn i => ~i)) I :|--
    36   (fn sign => nat_num >> sign)
    38   (fn sign => nat_num >> sign))
    37 
    39 
    38 val is_char = Symbol.is_ascii_letter orf Symbol.is_ascii_digit orf
    40 val is_char = Symbol.is_ascii_letter orf Symbol.is_ascii_digit orf
    39   member (op =) (explode "_+*-/%~=<>$&|?!.@^#")
    41   member (op =) (explode "_+*-/%~=<>$&|?!.@^#")
    40 val name = Scan.many1 is_char >> implode
    42 val name = spaced (Scan.many1 is_char >> implode)
    41 
    43 
    42 fun array_expr st = st |>
    44 fun $$$ s = spaced (Scan.this_string s)
    43   in_parens (space |-- (
    45 
    44   Scan.this_string "const" |-- expr >> Fresh ||
    46 fun array_expr st = st |> in_parens (
    45   Scan.this_string "store" -- space |-- array_expr -- expr -- expr >> Store))
    47   $$$ "const" |-- expr >> Fresh ||
    46 
    48   $$$ "store" |-- array_expr -- expr -- expr >> Store)
    47 and expr st = st |> (space |-- (
    49 
    48   Scan.this_string "true" >> K True ||
    50 and expr st = st |> (
    49   Scan.this_string "false" >> K False ||
    51   $$$ "true" >> K True ||
    50   int_num -- Scan.option (Scan.$$ "/" |-- int_num) >> Number ||
    52   $$$ "false" >> K False ||
    51   Scan.this_string "val!" |-- nat_num >> Value ||
    53   int_num -- Scan.option ($$$ "/" |-- int_num) >> Number ||
    52   array_expr >> Array))
    54   $$$ "val!" |-- nat_num >> Value ||
    53 
    55   name >> (App o rpair []) ||
    54 val mapping = space -- Scan.this_string "->"
    56   array_expr >> Array ||
    55 val value = mapping |-- expr
    57   in_parens (name -- Scan.repeat1 expr) >> App)
    56 
    58 
    57 val args_case = Scan.repeat expr -- value
    59 fun args st = ($$$ "->" >> K [] || expr ::: args) st
    58 val else_case = space -- Scan.this_string "else" |-- value >>
    60 val args_case = args -- expr
    59   pair ([] : expr list)
    61 val else_case = $$$ "else" -- $$$ "->" |-- expr >> pair ([] : expr list)
    60 
    62 
    61 val func =
    63 val func =
    62   let fun cases st = (else_case >> single || args_case ::: cases) st
    64   let fun cases st = (else_case >> single || args_case ::: cases) st
    63   in in_braces cases end
    65   in in_braces cases end
    64 
    66 
    65 val cex = space |-- Scan.repeat (space |-- name --| mapping --
    67 val cex = space |--
    66   (func || expr >> (single o pair [])))
    68   Scan.repeat (name --| $$$ "->" -- (func || expr >> (single o pair [])))
    67 
    69 
    68 fun read_cex ls =
    70 fun read_cex ls =
    69   explode (cat_lines ls)
    71   maps (cons "\n" o explode) ls
    70   |> try (fst o Scan.finite Symbol.stopper cex)
    72   |> try (fst o Scan.finite Symbol.stopper cex)
    71   |> the_default []
    73   |> the_default []
    72 
    74 
    73 
    75 
       
    76 (* normalization *)
       
    77 
       
    78 local
       
    79   fun matches terms f n =
       
    80     (case Symtab.lookup terms n of
       
    81       NONE => false
       
    82     | SOME t => f t)
       
    83 
       
    84   fun subst f (n, cases) = (n, map (fn (args, v) => (map f args, f v)) cases)
       
    85 in
       
    86 
       
    87 fun reduce_function (n, [c]) = SOME ((n, 0), [c])
       
    88   | reduce_function (n, cases) =
       
    89       let val (patterns, else_case as (_, e)) = split_last cases
       
    90       in
       
    91         (case patterns of
       
    92           [] => NONE
       
    93         | (args, _) :: _ => SOME ((n, length args),
       
    94             filter_out (equal e o snd) patterns @ [else_case]))
       
    95       end
       
    96 
       
    97 fun drop_skolem_constants terms = filter (Symtab.defined terms o fst o fst)
       
    98 
       
    99 fun substitute_constants terms =
       
   100   let
       
   101     fun check vs1 [] = rev vs1
       
   102       | check vs1 ((v as ((n, k), [([], Value i)])) :: vs2) =
       
   103           if matches terms (fn Free _ => true | _ => false) n orelse k > 0
       
   104           then check (v :: vs1) vs2
       
   105           else
       
   106             let
       
   107               fun sub (e as Value j) = if i = j then App (n, []) else e
       
   108                 | sub e = e
       
   109             in check (map (subst sub) vs1) (map (subst sub) vs2) end
       
   110       | check vs1 (v :: vs2) = check (v :: vs1) vs2
       
   111   in check [] end
       
   112 
       
   113 fun remove_int_nat_coercions terms vs =
       
   114   let
       
   115     fun match ts ((n, _), _) = matches terms (member (op aconv) ts) n
       
   116 
       
   117     val ints =
       
   118       find_first (match [@{term int}]) vs
       
   119       |> Option.map (fn (_, cases) =>
       
   120            let val (cs, (_, e)) = split_last cases
       
   121            in (e, map (apfst hd) cs) end)
       
   122     fun nat_of (v as Value _) = 
       
   123           (case ints of
       
   124             NONE => v
       
   125           | SOME (e, tab) => the_default e (AList.lookup (op =) tab v))
       
   126       | nat_of e = e
       
   127   in
       
   128     map (subst nat_of) vs
       
   129     |> filter_out (match [@{term int}, @{term nat}])
       
   130   end
       
   131 
       
   132 fun filter_valid_valuations terms = map_filter (fn
       
   133     (_, []) => NONE
       
   134   | ((n, i), cases) =>
       
   135       let
       
   136         fun valid_expr (Array a) = valid_array a
       
   137           | valid_expr (App (n, es)) =
       
   138               Symtab.defined terms n andalso forall valid_expr es
       
   139           | valid_expr _ = true
       
   140         and valid_array (Fresh e) = valid_expr e
       
   141           | valid_array (Store ((a, e1), e2)) =
       
   142               valid_array a andalso valid_expr e1 andalso valid_expr e2
       
   143         fun valid_case (es, e) = forall valid_expr (e :: es)
       
   144       in
       
   145         if not (forall valid_case cases) then NONE
       
   146         else Option.map (rpair cases o rpair i) (Symtab.lookup terms n)
       
   147       end)
       
   148 
       
   149 end
       
   150 
       
   151 
    74 (* translation into terms *)
   152 (* translation into terms *)
    75 
   153 
    76 fun lookup_term tab (name, e) = Option.map (rpair e) (Symtab.lookup tab name)
   154 fun with_context ctxt terms f vs =
    77 
   155   fst (fold_map f vs (ctxt, terms, Inttab.empty))
    78 fun with_name_context tab f xs =
   156 
    79   let
   157 fun fresh_term T (ctxt, terms, values) =
    80     val ns = Symtab.fold (Term.add_free_names o snd) tab []
   158   let val (n, ctxt') = yield_singleton Variable.variant_fixes "" ctxt
    81     val nctxt = Name.make_context ns
   159   in (Free (n, T), (ctxt', terms, values)) end
    82   in fst (fold_map f xs (Inttab.empty, nctxt)) end
   160 
    83 
   161 fun term_of_value T i (cx as (_, _, values)) =
    84 fun fresh_term T (tab, nctxt) =
   162   (case Inttab.lookup values i of
    85   let val (n, nctxt') = yield_singleton Name.variants "" nctxt
       
    86   in (Free (n, T), (tab, nctxt')) end
       
    87 
       
    88 fun term_of_value T i (cx as (tab, _)) =
       
    89   (case Inttab.lookup tab i of
       
    90     SOME t => (t, cx)
   163     SOME t => (t, cx)
    91   | NONE =>
   164   | NONE =>
    92       let val (t, (tab', nctxt')) = fresh_term T cx
   165       let val (t, (ctxt', terms', values')) = fresh_term T cx
    93       in (t, (Inttab.update (i, t) tab', nctxt')) end)
   166       in (t, (ctxt', terms', Inttab.update (i, t) values')) end)
       
   167 
       
   168 fun get_term n (cx as (_, terms, _)) = (the (Symtab.lookup terms n), cx)
    94 
   169 
    95 fun trans_expr _ True = pair @{term True}
   170 fun trans_expr _ True = pair @{term True}
    96   | trans_expr _ False = pair @{term False}
   171   | trans_expr _ False = pair @{term False}
    97   | trans_expr T (Number (i, NONE)) = pair (HOLogic.mk_number T i)
   172   | trans_expr T (Number (i, NONE)) = pair (HOLogic.mk_number T i)
    98   | trans_expr T (Number (i, SOME j)) =
   173   | trans_expr T (Number (i, SOME j)) =
    99       pair (Const (@{const_name divide}, [T, T] ---> T) $
   174       pair (Const (@{const_name divide}, [T, T] ---> T) $
   100         HOLogic.mk_number T i $ HOLogic.mk_number T j)
   175         HOLogic.mk_number T i $ HOLogic.mk_number T j)
   101   | trans_expr T (Value i) = term_of_value T i
   176   | trans_expr T (Value i) = term_of_value T i
   102   | trans_expr T (Array a) = trans_array T a
   177   | trans_expr T (Array a) = trans_array T a
       
   178   | trans_expr _ (App (n, es)) =
       
   179       let val get_Ts = take (length es) o Term.binder_types o Term.fastype_of
       
   180       in
       
   181         get_term n #-> (fn t =>
       
   182         fold_map (uncurry trans_expr) (get_Ts t ~~ es) #>>
       
   183         Term.list_comb o pair t)
       
   184       end
   103 
   185 
   104 and trans_array T a =
   186 and trans_array T a =
   105   let val dT = Term.domain_type T and rT = Term.range_type T
   187   let val dT = Term.domain_type T and rT = Term.range_type T
   106   in
   188   in
   107     (case a of
   189     (case a of
   110         trans_array T a' ##>> trans_expr dT e1 ##>> trans_expr rT e2 #>>
   192         trans_array T a' ##>> trans_expr dT e1 ##>> trans_expr rT e2 #>>
   111         (fn ((m, k), v) =>
   193         (fn ((m, k), v) =>
   112           Const (@{const_name fun_upd}, [T, dT, rT] ---> T) $ m $ k $ v))
   194           Const (@{const_name fun_upd}, [T, dT, rT] ---> T) $ m $ k $ v))
   113   end
   195   end
   114 
   196 
   115 fun trans_pat i T f x =
   197 fun trans_pattern T ([], e) = trans_expr T e #>> pair []
   116   f (Term.domain_type T) ##>> trans (i-1) (Term.range_type T) x #>>
   198   | trans_pattern T (arg :: args, e) =
   117   (fn (u, (us, t)) => (u :: us, t))
   199       trans_expr (Term.domain_type T) arg ##>>
   118 
   200       trans_pattern (Term.range_type T) (args, e) #>>
   119 and trans i T ([], v) =
   201       (fn (arg', (args', e')) => (arg' :: args', e'))
   120       if i > 0 then trans_pat i T fresh_term ([], v)
   202 
   121       else trans_expr T v #>> pair []
   203 fun mk_fun_upd T U = Const (@{const_name fun_upd}, [T --> U, T, U, T] ---> U)
   122   | trans i T (p :: ps, v) = trans_pat i T (fn U => trans_expr U p) (ps, v)
   204 
   123 
   205 fun split_type T = (Term.domain_type T, Term.range_type T)
   124 fun mk_eq' t us u = HOLogic.mk_eq (Term.list_comb (t, us), u)
   206 
   125 fun mk_eq (Const (@{const_name fun_app}, _)) (u' :: us', u) = mk_eq' u' us' u
   207 fun mk_update ([], u) _ = u
   126   | mk_eq t (us, u) = mk_eq' t us u
   208   | mk_update ([t], u) f =
   127 
   209       uncurry mk_fun_upd (split_type (Term.fastype_of f)) $ f $ t $ u
   128 fun translate (t, cs) =
   210   | mk_update (t :: ts, u) f =
   129   let val T = Term.fastype_of t
   211       let
   130   in
   212         val (dT, rT) = split_type (Term.fastype_of f)
   131     (case (can HOLogic.dest_number t, cs) of
   213         val (dT', rT') = split_type rT
   132       (true, [c]) => trans 0 T c #>> (fn (_, u) => [mk_eq u ([], t)])
   214       in
   133     | (_, (es, _) :: _) => fold_map (trans (length es) T) cs #>> map (mk_eq t)
   215         mk_fun_upd dT rT $ f $ t $
   134     | _ => raise TERM ("translate: no cases", [t]))
   216           mk_update (ts, u) (Term.absdummy (dT', Const ("_", rT')))
   135   end
   217       end
       
   218 
       
   219 fun mk_lambda Ts (t, pats) =
       
   220   fold_rev (curry Term.absdummy) Ts t |> fold mk_update pats
       
   221 
       
   222 fun translate' T i [([], e)] =
       
   223       if i = 0 then trans_expr T e
       
   224       else 
       
   225         let val ((Us1, Us2), U) = Term.strip_type T |>> chop i
       
   226         in trans_expr (Us2 ---> U) e #>> mk_lambda Us1 o rpair [] end
       
   227   | translate' T i cases =
       
   228       let
       
   229         val (pat_cases, def) = split_last cases |> apsnd snd
       
   230         val ((Us1, Us2), U) = Term.strip_type T |>> chop i
       
   231       in
       
   232         trans_expr (Us2 ---> U) def ##>>
       
   233         fold_map (trans_pattern T) pat_cases #>>
       
   234         mk_lambda Us1
       
   235       end
       
   236 
       
   237 fun translate ((t, i), cases) =
       
   238   translate' (Term.fastype_of t) i cases #>> HOLogic.mk_eq o pair t
   136 
   239 
   137 
   240 
   138 (* overall procedure *)
   241 (* overall procedure *)
   139 
   242 
   140 fun parse_counterex ({terms, ...} : SMT_Translate.recon) ls =
   243 fun parse_counterex ctxt ({terms, ...} : SMT_Translate.recon) ls =
   141   read_cex ls
   244   read_cex ls
   142   |> map_filter (lookup_term terms)
   245   |> map_filter reduce_function
   143   |> with_name_context terms translate
   246   |> drop_skolem_constants terms
   144   |> flat
   247   |> substitute_constants terms
       
   248   |> remove_int_nat_coercions terms
       
   249   |> filter_valid_valuations terms
       
   250   |> with_context ctxt terms translate
   145 
   251 
   146 end
   252 end
       
   253