src/HOL/Tools/SMT/z3_model.ML
author boehmes
Tue Nov 30 18:22:43 2010 +0100 (2010-11-30)
changeset 40828 47ff261431c4
parent 40663 e080c9e68752
child 40840 2f97215e79bf
permissions -rw-r--r--
split up Z3 models into constraints on free variables and constant definitions;
reduce Z3 models by replacing unknowns with free variables and constants from the goal;
remove occurrences of the hidden constant fun_app from Z3 models
     1 (*  Title:      HOL/Tools/SMT/z3_model.ML
     2     Author:     Sascha Boehme and Philipp Meyer, TU Muenchen
     3 
     4 Parser for counterexamples generated by Z3.
     5 *)
     6 
     7 signature Z3_MODEL =
     8 sig
     9   val parse_counterex: Proof.context -> SMT_Translate.recon -> string list ->
    10     term list * term list
    11 end
    12 
    13 structure Z3_Model: Z3_MODEL =
    14 struct
    15 
    16 structure U = SMT_Utils
    17 
    18 
    19 (* counterexample expressions *)
    20 
(* Abstract syntax of the expressions occurring in a Z3 model.
   - True, False: the Boolean literals "true" and "false".
   - Number (i, NONE): the integer literal i.
   - Number (i, SOME j): the fraction written "i / j".
   - Value i: the abstract value written "val!i".
   - Array a: an array value, see datatype array below.
   - App (n, es): name n applied to the arguments es; a bare name is
     represented with es = []. *)
datatype expr = True | False | Number of int * int option | Value of int |
  Array of array | App of string * expr list
(* Array values:
   - Fresh e: the array written "const e", mapping every index to e.
   - Store ((a, k), v): the array written "store a k v", i.e. a updated at
     index k with value v. *)
and array = Fresh of expr | Store of (array * expr) * expr
    24 
    25 
    26 (* parsing *)
    27 
    28 val space = Scan.many Symbol.is_ascii_blank
    29 fun spaced p = p --| space
    30 fun in_parens p = spaced (Scan.$$ "(") |-- p --| spaced (Scan.$$ ")")
    31 fun in_braces p = spaced (Scan.$$ "{") |-- p --| spaced (Scan.$$ "}")
    32 
    33 val digit = (fn
    34   "0" => SOME 0 | "1" => SOME 1 | "2" => SOME 2 | "3" => SOME 3 |
    35   "4" => SOME 4 | "5" => SOME 5 | "6" => SOME 6 | "7" => SOME 7 |
    36   "8" => SOME 8 | "9" => SOME 9 | _ => NONE)
    37 
    38 val nat_num = spaced (Scan.repeat1 (Scan.some digit) >>
    39   (fn ds => fold (fn d => fn i => i * 10 + d) ds 0))
    40 val int_num = spaced (Scan.optional ($$ "-" >> K (fn i => ~i)) I :|--
    41   (fn sign => nat_num >> sign))
    42 
    43 val is_char = Symbol.is_ascii_letter orf Symbol.is_ascii_digit orf
    44   member (op =) (raw_explode "_+*-/%~=<>$&|?!.@^#")
    45 val name = spaced (Scan.many1 is_char >> implode)
    46 
    47 fun $$$ s = spaced (Scan.this_string s)
    48 
(* Parser for Z3 model expressions, producing the datatype expr.

   array_expr reads a parenthesised array value:
   - "const e" becomes Fresh e;
   - "store a k v" becomes Store ((a, k), v).

   expr tries its alternatives in the order written:
   - the keywords true and false;
   - an integer, optionally followed by "/" and a second integer;
   - "val!" followed by a natural number;
   - a bare name, giving an application with no arguments;
   - an array expression;
   - a parenthesised application "(f e1 ... en)" with at least one argument.

   Both take the input st explicitly so that they can be declared
   mutually recursively with fun ... and.
   NOTE(review): $$$ is built on Scan.this_string, which seems to accept the
   keyword as a prefix of a longer token (e.g. a name starting with "true").
   Confirm that Z3 never emits such names. *)
fun array_expr st = st |> in_parens (
  $$$ "const" |-- expr >> Fresh ||
  $$$ "store" |-- array_expr -- expr -- expr >> Store)

and expr st = st |> (
  $$$ "true" >> K True ||
  $$$ "false" >> K False ||
  int_num -- Scan.option ($$$ "/" |-- int_num) >> Number ||
  $$$ "val!" |-- nat_num >> Value ||
  name >> (App o rpair []) ||
  array_expr >> Array ||
  in_parens (name -- Scan.repeat1 expr) >> App)
    61 
(* Function tables.  A table is a braced sequence of cases "e1 ... en -> e",
   ended by a default case "else -> e".
   - args reads argument expressions up to and including "->".
   - args_case pairs those arguments with the result expression.
   - else_case reads the default result and pairs it with an empty argument
     list. *)
fun args st = ($$$ "->" >> K [] || expr ::: args) st
val args_case = args -- expr
val else_case = $$$ "else" -- $$$ "->" |-- expr >> pair ([] : expr list)

(* func: the case list of a braced table.  Parsing stops at the else case,
   which is therefore always the last element of the list. *)
val func =
  let fun cases st = (else_case >> single || args_case ::: cases) st
  in in_braces cases end

(* A whole model: leading blanks, then a sequence of entries "name -> value".
   The value is either a function table or a single expression.  A single
   expression is stored as a one-element case list with no arguments. *)
val cex = space |--
  Scan.repeat (name --| $$$ "->" -- (func || expr >> (single o pair [])))
    72 
    73 fun resolve terms ((n, k), cases) =
    74   (case Symtab.lookup terms n of
    75     NONE => NONE
    76   | SOME t => SOME ((t, k), cases))
    77 
    78 fun annotate _ (_, []) = NONE
    79   | annotate terms (n, [([], c)]) = resolve terms ((n, 0), (c, []))
    80   | annotate _ (_, [_]) = NONE
    81   | annotate terms (n, cases as (args, _) :: _) =
    82       let val (cases', (_, else_case)) = split_last cases
    83       in resolve terms ((n, length args), (else_case, cases')) end
    84 
    85 fun read_cex terms ls =
    86   maps (cons "\n" o raw_explode) ls
    87   |> try (fst o Scan.finite Symbol.stopper cex)
    88   |> the_default []
    89   |> map_filter (annotate terms)
    90 
    91 
    92 (* translation into terms *)
    93 
    94 fun max_value vs =
    95   let
    96     fun max_val_expr (Value i) = Integer.max i
    97       | max_val_expr (App (_, es)) = fold max_val_expr es
    98       | max_val_expr (Array a) = max_val_array a
    99       | max_val_expr _ = I
   100 
   101     and max_val_array (Fresh e) = max_val_expr e
   102       | max_val_array (Store ((a, e1), e2)) =
   103           max_val_array a #> max_val_expr e1 #> max_val_expr e2
   104 
   105     fun max_val (_, (ec, cs)) =
   106       max_val_expr ec #> fold (fn (es, e) => fold max_val_expr (e :: es)) cs
   107 
   108   in fold max_val vs ~1 end
   109 
   110 fun with_context terms f vs = fst (fold_map f vs (terms, max_value vs + 1))
   111 
   112 fun get_term n T es (cx as (terms, next_val)) =
   113   (case Symtab.lookup terms n of
   114     SOME t => ((t, es), cx)
   115   | NONE =>
   116       let val t = Var (("fresh", next_val), T)
   117       in ((t, []), (Symtab.update (n, t) terms, next_val + 1)) end)
   118 
(* trans_expr T e translates the model expression e into a HOL term of
   type T.  It threads the state (terms, next_val) used by get_term.
   - Numbers are built with HOLogic.mk_number at T.
   - A fraction i / j becomes divide applied to the two numerals.
   - "val!i" becomes the schematic variable Var (("value", i), T).
   - Arrays are handled by trans_array.
   - An application (n, es) uses the term get_term returns for n (possibly
     a fresh variable).  Each remaining argument is translated at the
     corresponding domain type of that term's type, and the results are
     applied to the term with list_comb. *)
fun trans_expr _ True = pair @{const True}
  | trans_expr _ False = pair @{const False}
  | trans_expr T (Number (i, NONE)) = pair (HOLogic.mk_number T i)
  | trans_expr T (Number (i, SOME j)) =
      pair (Const (@{const_name divide}, [T, T] ---> T) $
        HOLogic.mk_number T i $ HOLogic.mk_number T j)
  | trans_expr T (Value i) = pair (Var (("value", i), T))
  | trans_expr T (Array a) = trans_array T a
  | trans_expr T (App (n, es)) = get_term n T es #-> (fn (t, es') =>
      let val Ts = fst (U.dest_funT (length es') (Term.fastype_of t))
      in
        fold_map (uncurry trans_expr) (Ts ~~ es') #>> Term.list_comb o pair t
      end)

(* trans_array T a translates an array value at the function type T,
   split by U.split_type into domain dT and range rT.
   - A constant array becomes the abstraction over dT of its value.
   - A store becomes fun_upd m k v, where m, k and v are the translated
     array, index and value. *)
and trans_array T a =
  let val (dT, rT) = U.split_type T
  in
    (case a of
      Fresh e => trans_expr rT e #>> (fn t => Abs ("x", dT, t))
    | Store ((a', e1), e2) =>
        trans_array T a' ##>> trans_expr dT e1 ##>> trans_expr rT e2 #>>
        (fn ((m, k), v) =>
          Const (@{const_name fun_upd}, [T, dT, rT] ---> T) $ m $ k $ v))
  end
   143 
   144 fun trans_pattern T ([], e) = trans_expr T e #>> pair []
   145   | trans_pattern T (arg :: args, e) =
   146       trans_expr (Term.domain_type T) arg ##>>
   147       trans_pattern (Term.range_type T) (args, e) #>>
   148       (fn (arg', (args', e')) => (arg' :: args', e'))
   149 
   150 fun mk_fun_upd T U = Const (@{const_name fun_upd}, [T --> U, T, U, T] ---> U)
   151 
(* mk_update (args, u) f extends the function term f with the case
   args |-> u using fun_upd.
   - No arguments: the result is just u.
   - One argument t: the result is f(t := u).
   - Several arguments t :: ts: the value stored at t is a dummy function,
     an abstraction whose body is the constant named "_", updated
     recursively with ts |-> u.
   NOTE(review): the stored inner function starts from the dummy "_"
   rather than from f t.  A later case with the same first argument
   therefore replaces, rather than extends, the inner function built by an
   earlier one.  Confirm this is intended for display purposes. *)
fun mk_update ([], u) _ = u
  | mk_update ([t], u) f =
      uncurry mk_fun_upd (U.split_type (Term.fastype_of f)) $ f $ t $ u
  | mk_update (t :: ts, u) f =
      let
        val (dT, rT) = U.split_type (Term.fastype_of f)
        val (dT', rT') = U.split_type rT
      in
        mk_fun_upd dT rT $ f $ t $
          mk_update (ts, u) (Term.absdummy (dT', Const ("_", rT')))
      end
   163 
   164 fun mk_lambda Ts (t, pats) =
   165   fold_rev (curry Term.absdummy) Ts t |> fold mk_update pats
   166 
   167 fun translate ((t, k), (e, cs)) =
   168   let
   169     val T = Term.fastype_of t
   170     val (Us, U) = U.dest_funT k (Term.fastype_of t)
   171 
   172     fun mk_full_def u' pats =
   173       pats
   174       |> filter_out (fn (_, u) => u aconv u')
   175       |> HOLogic.mk_eq o pair t o mk_lambda Us o pair u'
   176 
   177     fun mk_eq (us, u) = HOLogic.mk_eq (Term.list_comb (t, us), u)
   178     fun mk_eqs u' [] = [HOLogic.mk_eq (t, u')]
   179       | mk_eqs _ pats = map mk_eq pats
   180   in
   181     trans_expr U e ##>>
   182     (if k = 0 then pair [] else fold_map (trans_pattern T) cs) #>>
   183     (fn (u', pats) => (mk_eqs u' pats, mk_full_def u' pats))
   184   end
   185 
   186 
   187 (* normalization *)
   188 
   189 fun partition_eqs f =
   190   let
   191     fun part t (xs, ts) =
   192       (case try HOLogic.dest_eq t of
   193         SOME (l, r) => (case f l r of SOME x => (x::xs, ts) | _ => (xs, t::ts))
   194       | NONE => (xs, t :: ts))
   195   in (fn ts => fold part ts ([], [])) end
   196 
   197 fun replace_vars tab =
   198   let
   199     fun replace (v as Var _) = the_default v (AList.lookup (op aconv) tab v)
   200       | replace t = t
   201   in map (Term.map_aterms replace) end
   202 
   203 fun remove_int_nat_coercions (eqs, defs) =
   204   let
   205     fun mk_nat_num t i =
   206       (case try HOLogic.dest_number i of
   207         SOME (_, n) => SOME (t, HOLogic.mk_number @{typ nat} n)
   208       | NONE => NONE)
   209     fun nat_of (@{const of_nat (int)} $ (t as Var _)) i = mk_nat_num t i
   210       | nat_of (@{const nat} $ i) (t as Var _) = mk_nat_num t i
   211       | nat_of _ _ = NONE
   212     val (nats, eqs') = partition_eqs nat_of eqs
   213 
   214     fun is_coercion t =
   215       (case try HOLogic.dest_eq t of
   216         SOME (@{const of_nat (int)}, _) => true
   217       | SOME (@{const nat}, _) => true
   218       | _ => false)
   219   in pairself (replace_vars nats) (eqs', filter_out is_coercion defs) end
   220 
   221 fun unfold_funapp (eqs, defs) =
   222   let
   223     fun unfold_app (Const (@{const_name SMT.fun_app}, _) $ f $ t) = f $ t
   224       | unfold_app t = t
   225     fun unfold_eq ((eq as Const (@{const_name HOL.eq}, _)) $ t $ u) =
   226           eq $ unfold_app t $ u
   227       | unfold_eq t = t
   228 
   229     fun is_fun_app t =
   230       (case try HOLogic.dest_eq t of
   231         SOME (Const (@{const_name SMT.fun_app}, _), _) => true
   232       | _ => false)
   233 
   234   in (map unfold_eq eqs, filter_out is_fun_app defs) end
   235 
   236 fun unfold_simple_eqs (eqs, defs) =
   237   let
   238     fun add_rewr (l as Const _) (r as Var _) = SOME (r, l)
   239       | add_rewr (l as Free _) (r as Var _) = SOME (r, l)
   240       | add_rewr _ _ = NONE
   241     val (rs, eqs') = partition_eqs add_rewr eqs
   242 
   243     fun is_trivial (Const (@{const_name HOL.eq}, _) $ t $ u) = t aconv u
   244       | is_trivial _ = false
   245   in pairself (replace_vars rs #> filter_out is_trivial) (eqs', defs) end
   246 
   247 fun swap_free ((eq as Const (@{const_name HOL.eq}, _)) $ t $ (u as Free _)) =
   248       eq $ u $ t
   249   | swap_free t = t
   250 
(* Replace every schematic variable in eqs and defs by a free variable with
   a newly invented name, and return the rewritten (eqs, defs).
   - Names come from Variable.variant_fixes on ctxt.  The updated context
     is threaded through, so each new name is chosen with respect to the
     previously invented ones.
   - Variables are identified by their index only.  All Vars with index i,
     whatever their base name or type, map to the same Free, whose type is
     that of the first occurrence visited.
   - Traversal order (all of eqs, then all of defs, each term left to
     right) determines which name each variable receives. *)
fun frees_for_vars ctxt (eqs, defs) =
  let
    (* Free already chosen for index i, or a new one of type T *)
    fun fresh_free i T (cx as (frees, ctxt)) =
      (case Inttab.lookup frees i of
        SOME t => (t, cx)
      | NONE =>
          let
            val (n, ctxt') = yield_singleton Variable.variant_fixes "" ctxt
            val t = Free (n, T)
          in (t, (Inttab.update (i, t) frees, ctxt')) end)

    (* structural traversal replacing every Var, keeping the state *)
    fun repl_var (Var ((_, i), T)) = fresh_free i T
      | repl_var (t $ u) = repl_var t ##>> repl_var u #>> op $
      | repl_var (Abs (n, T, t)) = repl_var t #>> (fn t' => Abs (n, T, t'))
      | repl_var t = pair t
  in
    (Inttab.empty, ctxt)
    |> fold_map repl_var eqs
    ||>> fold_map repl_var defs
    |> fst
  end
   272 
   273 
   274 (* overall procedure *)
   275 
   276 val is_free_constraint = Term.exists_subterm (fn Free _ => true | _ => false)
   277 
   278 fun is_const_def (Const (@{const_name HOL.eq}, _) $ Const _ $ _) = true
   279   | is_const_def _ = false
   280 
   281 fun parse_counterex ctxt ({terms, ...} : SMT_Translate.recon) ls =
   282   read_cex terms ls
   283   |> with_context terms translate
   284   |> apfst flat o split_list
   285   |> remove_int_nat_coercions
   286   |> unfold_funapp
   287   |> unfold_simple_eqs
   288   |>> map swap_free
   289   |>> filter is_free_constraint
   290   |> frees_for_vars ctxt
   291   ||> filter is_const_def
   292 
   293 end
   294