src/HOL/Library/SMT/z3_model.ML
changeset 58055 625bdd5c70b2
parent 44241 7943b69f0188
equal deleted inserted replaced
58054:1d9edd486479 58055:625bdd5c70b2
       
     1 (*  Title:      HOL/Library/SMT/z3_model.ML
       
     2     Author:     Sascha Boehme and Philipp Meyer, TU Muenchen
       
     3 
       
     4 Parser for counterexamples generated by Z3.
       
     5 *)
       
     6 
       
     7 signature Z3_MODEL =
       
     8 sig
       
     9   val parse_counterex: Proof.context -> SMT_Translate.recon -> string list ->
       
    10     term list * term list
       
    11 end
       
    12 
       
    13 structure Z3_Model: Z3_MODEL =
       
    14 struct
       
    15 
       
    16 
       
    17 (* counterexample expressions *)
       
    18 
       
    19 datatype expr = True | False | Number of int * int option | Value of int |
       
    20   Array of array | App of string * expr list
       
    21 and array = Fresh of expr | Store of (array * expr) * expr
       
    22 
       
    23 
       
    24 (* parsing *)
       
    25 
       
    26 val space = Scan.many Symbol.is_ascii_blank
       
    27 fun spaced p = p --| space
       
    28 fun in_parens p = spaced (Scan.$$ "(") |-- p --| spaced (Scan.$$ ")")
       
    29 fun in_braces p = spaced (Scan.$$ "{") |-- p --| spaced (Scan.$$ "}")
       
    30 
       
    31 val digit = (fn
       
    32   "0" => SOME 0 | "1" => SOME 1 | "2" => SOME 2 | "3" => SOME 3 |
       
    33   "4" => SOME 4 | "5" => SOME 5 | "6" => SOME 6 | "7" => SOME 7 |
       
    34   "8" => SOME 8 | "9" => SOME 9 | _ => NONE)
       
    35 
       
    36 val nat_num = spaced (Scan.repeat1 (Scan.some digit) >>
       
    37   (fn ds => fold (fn d => fn i => i * 10 + d) ds 0))
       
    38 val int_num = spaced (Scan.optional ($$ "-" >> K (fn i => ~i)) I :|--
       
    39   (fn sign => nat_num >> sign))
       
    40 
       
    41 val is_char = Symbol.is_ascii_letter orf Symbol.is_ascii_digit orf
       
    42   member (op =) (raw_explode "_+*-/%~=<>$&|?!.@^#")
       
    43 val name = spaced (Scan.many1 is_char >> implode)
       
    44 
       
    45 fun $$$ s = spaced (Scan.this_string s)
       
    46 
       
    47 fun array_expr st = st |> in_parens (
       
    48   $$$ "const" |-- expr >> Fresh ||
       
    49   $$$ "store" |-- array_expr -- expr -- expr >> Store)
       
    50 
       
    51 and expr st = st |> (
       
    52   $$$ "true" >> K True ||
       
    53   $$$ "false" >> K False ||
       
    54   int_num -- Scan.option ($$$ "/" |-- int_num) >> Number ||
       
    55   $$$ "val!" |-- nat_num >> Value ||
       
    56   name >> (App o rpair []) ||
       
    57   array_expr >> Array ||
       
    58   in_parens (name -- Scan.repeat1 expr) >> App)
       
    59 
       
    60 fun args st = ($$$ "->" >> K [] || expr ::: args) st
       
    61 val args_case = args -- expr
       
    62 val else_case = $$$ "else" -- $$$ "->" |-- expr >> pair ([] : expr list)
       
    63 
       
    64 val func =
       
    65   let fun cases st = (else_case >> single || args_case ::: cases) st
       
    66   in in_braces cases end
       
    67 
       
    68 val cex = space |--
       
    69   Scan.repeat (name --| $$$ "->" -- (func || expr >> (single o pair [])))
       
    70 
       
    71 fun resolve terms ((n, k), cases) =
       
    72   (case Symtab.lookup terms n of
       
    73     NONE => NONE
       
    74   | SOME t => SOME ((t, k), cases))
       
    75 
       
    76 fun annotate _ (_, []) = NONE
       
    77   | annotate terms (n, [([], c)]) = resolve terms ((n, 0), (c, []))
       
    78   | annotate _ (_, [_]) = NONE
       
    79   | annotate terms (n, cases as (args, _) :: _) =
       
    80       let val (cases', (_, else_case)) = split_last cases
       
    81       in resolve terms ((n, length args), (else_case, cases')) end
       
    82 
       
    83 fun read_cex terms ls =
       
    84   maps (cons "\n" o raw_explode) ls
       
    85   |> try (fst o Scan.finite Symbol.stopper cex)
       
    86   |> the_default []
       
    87   |> map_filter (annotate terms)
       
    88 
       
    89 
       
    90 (* translation into terms *)
       
    91 
       
    92 fun max_value vs =
       
    93   let
       
    94     fun max_val_expr (Value i) = Integer.max i
       
    95       | max_val_expr (App (_, es)) = fold max_val_expr es
       
    96       | max_val_expr (Array a) = max_val_array a
       
    97       | max_val_expr _ = I
       
    98 
       
    99     and max_val_array (Fresh e) = max_val_expr e
       
   100       | max_val_array (Store ((a, e1), e2)) =
       
   101           max_val_array a #> max_val_expr e1 #> max_val_expr e2
       
   102 
       
   103     fun max_val (_, (ec, cs)) =
       
   104       max_val_expr ec #> fold (fn (es, e) => fold max_val_expr (e :: es)) cs
       
   105 
       
   106   in fold max_val vs ~1 end
       
   107 
       
   108 fun with_context terms f vs = fst (fold_map f vs (terms, max_value vs + 1))
       
   109 
       
   110 fun get_term n T es (cx as (terms, next_val)) =
       
   111   (case Symtab.lookup terms n of
       
   112     SOME t => ((t, es), cx)
       
   113   | NONE =>
       
   114       let val t = Var (("skolem", next_val), T)
       
   115       in ((t, []), (Symtab.update (n, t) terms, next_val + 1)) end)
       
   116 
       
   117 fun trans_expr _ True = pair @{const True}
       
   118   | trans_expr _ False = pair @{const False}
       
   119   | trans_expr T (Number (i, NONE)) = pair (HOLogic.mk_number T i)
       
   120   | trans_expr T (Number (i, SOME j)) =
       
   121       pair (Const (@{const_name divide}, [T, T] ---> T) $
       
   122         HOLogic.mk_number T i $ HOLogic.mk_number T j)
       
   123   | trans_expr T (Value i) = pair (Var (("value", i), T))
       
   124   | trans_expr T (Array a) = trans_array T a
       
   125   | trans_expr T (App (n, es)) = get_term n T es #-> (fn (t, es') =>
       
   126       let val Ts = fst (SMT_Utils.dest_funT (length es') (Term.fastype_of t))
       
   127       in
       
   128         fold_map (uncurry trans_expr) (Ts ~~ es') #>> Term.list_comb o pair t
       
   129       end)
       
   130 
       
   131 and trans_array T a =
       
   132   let val (dT, rT) = Term.dest_funT T
       
   133   in
       
   134     (case a of
       
   135       Fresh e => trans_expr rT e #>> (fn t => Abs ("x", dT, t))
       
   136     | Store ((a', e1), e2) =>
       
   137         trans_array T a' ##>> trans_expr dT e1 ##>> trans_expr rT e2 #>>
       
   138         (fn ((m, k), v) =>
       
   139           Const (@{const_name fun_upd}, [T, dT, rT] ---> T) $ m $ k $ v))
       
   140   end
       
   141 
       
   142 fun trans_pattern T ([], e) = trans_expr T e #>> pair []
       
   143   | trans_pattern T (arg :: args, e) =
       
   144       trans_expr (Term.domain_type T) arg ##>>
       
   145       trans_pattern (Term.range_type T) (args, e) #>>
       
   146       (fn (arg', (args', e')) => (arg' :: args', e'))
       
   147 
       
   148 fun mk_fun_upd T U = Const (@{const_name fun_upd}, [T --> U, T, U, T] ---> U)
       
   149 
       
   150 fun mk_update ([], u) _ = u
       
   151   | mk_update ([t], u) f =
       
   152       uncurry mk_fun_upd (Term.dest_funT (Term.fastype_of f)) $ f $ t $ u
       
   153   | mk_update (t :: ts, u) f =
       
   154       let
       
   155         val (dT, rT) = Term.dest_funT (Term.fastype_of f)
       
   156         val (dT', rT') = Term.dest_funT rT
       
   157       in
       
   158         mk_fun_upd dT rT $ f $ t $
       
   159           mk_update (ts, u) (absdummy dT' (Const ("_", rT')))
       
   160       end
       
   161 
       
   162 fun mk_lambda Ts (t, pats) =
       
   163   fold_rev absdummy Ts t |> fold mk_update pats
       
   164 
       
   165 fun translate ((t, k), (e, cs)) =
       
   166   let
       
   167     val T = Term.fastype_of t
       
   168     val (Us, U) = SMT_Utils.dest_funT k (Term.fastype_of t)
       
   169 
       
   170     fun mk_full_def u' pats =
       
   171       pats
       
   172       |> filter_out (fn (_, u) => u aconv u')
       
   173       |> HOLogic.mk_eq o pair t o mk_lambda Us o pair u'
       
   174 
       
   175     fun mk_eq (us, u) = HOLogic.mk_eq (Term.list_comb (t, us), u)
       
   176     fun mk_eqs u' [] = [HOLogic.mk_eq (t, u')]
       
   177       | mk_eqs _ pats = map mk_eq pats
       
   178   in
       
   179     trans_expr U e ##>>
       
   180     (if k = 0 then pair [] else fold_map (trans_pattern T) cs) #>>
       
   181     (fn (u', pats) => (mk_eqs u' pats, mk_full_def u' pats))
       
   182   end
       
   183 
       
   184 
       
   185 (* normalization *)
       
   186 
       
   187 fun partition_eqs f =
       
   188   let
       
   189     fun part t (xs, ts) =
       
   190       (case try HOLogic.dest_eq t of
       
   191         SOME (l, r) => (case f l r of SOME x => (x::xs, ts) | _ => (xs, t::ts))
       
   192       | NONE => (xs, t :: ts))
       
   193   in (fn ts => fold part ts ([], [])) end
       
   194 
       
   195 fun first_eq pred =
       
   196   let
       
   197     fun part _ [] = NONE
       
   198       | part us (t :: ts) =
       
   199           (case try (pred o HOLogic.dest_eq) t of
       
   200             SOME (SOME lr) => SOME (lr, fold cons us ts)
       
   201           | _ => part (t :: us) ts)
       
   202   in (fn ts => part [] ts) end
       
   203 
       
   204 fun replace_vars tab =
       
   205   let
       
   206     fun repl v = the_default v (AList.lookup (op aconv) tab v)
       
   207     fun replace (v as Var _) = repl v
       
   208       | replace (v as Free _) = repl v
       
   209       | replace t = t
       
   210   in map (Term.map_aterms replace) end
       
   211 
       
   212 fun remove_int_nat_coercions (eqs, defs) =
       
   213   let
       
   214     fun mk_nat_num t i =
       
   215       (case try HOLogic.dest_number i of
       
   216         SOME (_, n) => SOME (t, HOLogic.mk_number @{typ nat} n)
       
   217       | NONE => NONE)
       
   218     fun nat_of (@{const of_nat (int)} $ (t as Var _)) i = mk_nat_num t i
       
   219       | nat_of (@{const nat} $ i) (t as Var _) = mk_nat_num t i
       
   220       | nat_of _ _ = NONE
       
   221     val (nats, eqs') = partition_eqs nat_of eqs
       
   222 
       
   223     fun is_coercion t =
       
   224       (case try HOLogic.dest_eq t of
       
   225         SOME (@{const of_nat (int)}, _) => true
       
   226       | SOME (@{const nat}, _) => true
       
   227       | _ => false)
       
   228   in pairself (replace_vars nats) (eqs', filter_out is_coercion defs) end
       
   229 
       
   230 fun unfold_funapp (eqs, defs) =
       
   231   let
       
   232     fun unfold_app (Const (@{const_name SMT.fun_app}, _) $ f $ t) = f $ t
       
   233       | unfold_app t = t
       
   234     fun unfold_eq ((eq as Const (@{const_name HOL.eq}, _)) $ t $ u) =
       
   235           eq $ unfold_app t $ u
       
   236       | unfold_eq t = t
       
   237 
       
   238     fun is_fun_app t =
       
   239       (case try HOLogic.dest_eq t of
       
   240         SOME (Const (@{const_name SMT.fun_app}, _), _) => true
       
   241       | _ => false)
       
   242 
       
   243   in (map unfold_eq eqs, filter_out is_fun_app defs) end
       
   244 
       
   245 val unfold_eqs =
       
   246   let
       
   247     val is_ground = not o Term.exists_subterm Term.is_Var
       
   248     fun is_non_rec (v, t) = not (Term.exists_subterm (equal v) t)
       
   249 
       
   250     fun rewr_var (l as Var _, r) = if is_ground r then SOME (l, r) else NONE
       
   251       | rewr_var (r, l as Var _) = if is_ground r then SOME (l, r) else NONE
       
   252       | rewr_var _ = NONE
       
   253 
       
   254     fun rewr_free' e = if is_non_rec e then SOME e else NONE
       
   255     fun rewr_free (e as (Free _, _)) = rewr_free' e
       
   256       | rewr_free (e as (_, Free _)) = rewr_free' (swap e)
       
   257       | rewr_free _ = NONE
       
   258 
       
   259     fun is_trivial (Const (@{const_name HOL.eq}, _) $ t $ u) = t aconv u
       
   260       | is_trivial _ = false
       
   261 
       
   262     fun replace r = replace_vars [r] #> filter_out is_trivial
       
   263 
       
   264     fun unfold_vars (es, ds) =
       
   265       (case first_eq rewr_var es of
       
   266         SOME (lr, es') => unfold_vars (pairself (replace lr) (es', ds))
       
   267       | NONE => (es, ds))
       
   268 
       
   269     fun unfold_frees ues (es, ds) =
       
   270       (case first_eq rewr_free es of
       
   271         SOME (lr, es') =>
       
   272           pairself (replace lr) (es', ds)
       
   273           |> unfold_frees (HOLogic.mk_eq lr :: replace lr ues)
       
   274       | NONE => (ues @ es, ds))
       
   275 
       
   276   in unfold_vars #> unfold_frees [] end
       
   277 
       
   278 fun swap_free ((eq as Const (@{const_name HOL.eq}, _)) $ t $ (u as Free _)) =
       
   279       eq $ u $ t
       
   280   | swap_free t = t
       
   281 
       
   282 fun frees_for_vars ctxt (eqs, defs) =
       
   283   let
       
   284     fun fresh_free i T (cx as (frees, ctxt)) =
       
   285       (case Inttab.lookup frees i of
       
   286         SOME t => (t, cx)
       
   287       | NONE =>
       
   288           let
       
   289             val (n, ctxt') = yield_singleton Variable.variant_fixes "" ctxt
       
   290             val t = Free (n, T)
       
   291           in (t, (Inttab.update (i, t) frees, ctxt')) end)
       
   292 
       
   293     fun repl_var (Var ((_, i), T)) = fresh_free i T
       
   294       | repl_var (t $ u) = repl_var t ##>> repl_var u #>> op $
       
   295       | repl_var (Abs (n, T, t)) = repl_var t #>> (fn t' => Abs (n, T, t'))
       
   296       | repl_var t = pair t
       
   297   in
       
   298     (Inttab.empty, ctxt)
       
   299     |> fold_map repl_var eqs
       
   300     ||>> fold_map repl_var defs
       
   301     |> fst
       
   302   end
       
   303 
       
   304 
       
   305 (* overall procedure *)
       
   306 
       
   307 val is_free_constraint = Term.exists_subterm (fn Free _ => true | _ => false)
       
   308 
       
   309 fun is_free_def (Const (@{const_name HOL.eq}, _) $ Free _ $ _) = true
       
   310   | is_free_def _ = false
       
   311 
       
   312 fun defined tp =
       
   313   try (pairself (fst o HOLogic.dest_eq)) tp
       
   314   |> the_default false o Option.map (op aconv)
       
   315 
       
   316 fun add_free_defs free_cs defs =
       
   317   let val (free_defs, defs') = List.partition is_free_def defs
       
   318   in (free_cs @ filter_out (member defined free_cs) free_defs, defs') end
       
   319 
       
   320 fun is_const_def (Const (@{const_name HOL.eq}, _) $ Const _ $ _) = true
       
   321   | is_const_def _ = false
       
   322 
       
   323 fun parse_counterex ctxt ({terms, ...} : SMT_Translate.recon) ls =
       
   324   read_cex terms ls
       
   325   |> with_context terms translate
       
   326   |> apfst flat o split_list
       
   327   |> remove_int_nat_coercions
       
   328   |> unfold_funapp
       
   329   |> unfold_eqs
       
   330   |>> map swap_free
       
   331   |>> filter is_free_constraint
       
   332   |-> add_free_defs
       
   333   |> frees_for_vars ctxt
       
   334   ||> filter is_const_def
       
   335 
       
   336 end
       
   337