split up Z3 models into constraints on free variables and constant definitions;
authorboehmes
Tue, 30 Nov 2010 18:22:43 +0100
changeset 40828 47ff261431c4
parent 40807 eeaa59fb5ad8
child 40829 edd1e0764da1
split up Z3 models into constraints on free variables and constant definitions; reduce Z3 models by replacing unknowns with free variables and constants from the goal; remove occurrences of the hidden constant fun_app from Z3 models
src/HOL/Tools/SMT/smt_failure.ML
src/HOL/Tools/SMT/smt_solver.ML
src/HOL/Tools/SMT/z3_model.ML
--- a/src/HOL/Tools/SMT/smt_failure.ML	Tue Nov 30 00:12:29 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_failure.ML	Tue Nov 30 18:22:43 2010 +0100
@@ -6,12 +6,17 @@
 
 signature SMT_FAILURE =
 sig
+  type counterexample = {
+    is_real_cex: bool,
+    free_constraints: term list,
+    const_defs: term list}
   datatype failure =
-    Counterexample of bool * term list |
+    Counterexample of counterexample |
     Time_Out |
     Out_Of_Memory |
     Abnormal_Termination of int |
     Other_Failure of string
+  val pretty_counterexample: Proof.context -> counterexample -> Pretty.T
   val string_of_failure: Proof.context -> failure -> string
   exception SMT of failure
 end
@@ -19,23 +24,32 @@
 structure SMT_Failure: SMT_FAILURE =
 struct
 
+type counterexample = {
+  is_real_cex: bool,
+  free_constraints: term list,
+  const_defs: term list}
+
 datatype failure =
-  Counterexample of bool * term list |
+  Counterexample of counterexample |
   Time_Out |
   Out_Of_Memory |
   Abnormal_Termination of int |
   Other_Failure of string
 
-fun string_of_failure ctxt (Counterexample (real, ex)) =
-      let
-        val msg =
-          if real then "Counterexample found (possibly spurious)"
-          else "Potential counterexample found"
-      in
-        if null ex then msg
-        else Pretty.string_of (Pretty.big_list (msg ^ ":")
-          (map (Syntax.pretty_term ctxt) ex))
-      end
+fun pretty_counterexample ctxt {is_real_cex, free_constraints, const_defs} =
+  let
+    val msg =
+      if is_real_cex then "Counterexample found (possibly spurious)"
+      else "Potential counterexample found"
+  in
+    if null free_constraints andalso null const_defs then Pretty.str msg
+    else
+      Pretty.big_list (msg ^ ":")
+        (map (Syntax.pretty_term ctxt) (free_constraints @ const_defs))
+  end
+
+fun string_of_failure ctxt (Counterexample cex) =
+      Pretty.string_of (pretty_counterexample ctxt cex)
   | string_of_failure _ Time_Out = "Timed out"
   | string_of_failure _ Out_Of_Memory = "Ran out of memory"
   | string_of_failure _ (Abnormal_Termination err) =
--- a/src/HOL/Tools/SMT/smt_solver.ML	Tue Nov 30 00:12:29 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_solver.ML	Tue Nov 30 18:22:43 2010 +0100
@@ -19,7 +19,7 @@
     interface: interface,
     outcome: string -> string list -> outcome * string list,
     cex_parser: (Proof.context -> SMT_Translate.recon -> string list ->
-      term list) option,
+      term list * term list) option,
     reconstruct: (Proof.context -> SMT_Translate.recon -> string list ->
       (int list * thm) * Proof.context) option }
 
@@ -65,7 +65,7 @@
   interface: interface,
   outcome: string -> string list -> outcome * string list,
   cex_parser: (Proof.context -> SMT_Translate.recon -> string list ->
-    term list) option,
+    term list * term list) option,
   reconstruct: (Proof.context -> SMT_Translate.recon -> string list ->
     (int list * thm) * Proof.context) option }
 
@@ -260,9 +260,14 @@
         then the reconstruct ctxt recon ls
         else (([], ocl ()), ctxt)
     | (result, ls) =>
-        let val ts = (case cex_parser of SOME f => f ctxt recon ls | _ => [])
-        in
-          raise SMT_Failure.SMT (SMT_Failure.Counterexample (result = Sat, ts))
+        let
+          val (ts, us) =
+            (case cex_parser of SOME f => f ctxt recon ls | _ => ([], []))
+         in
+          raise SMT_Failure.SMT (SMT_Failure.Counterexample {
+            is_real_cex = (result = Sat),
+            free_constraints = ts,
+            const_defs = us})
         end)
 
   val cfalse = Thm.cterm_of @{theory} (@{const Trueprop} $ @{const False})
@@ -351,15 +356,14 @@
     let
       fun solve irules = snd (smt_solver NONE ctxt' irules)
       val tag = "Solver " ^ C.solver_of ctxt' ^ ": "
-      val str_of = SMT_Failure.string_of_failure ctxt'
+      val str_of = prefix tag o SMT_Failure.string_of_failure ctxt'
       fun safe_solve irules =
         if pass_exns then SOME (solve irules)
         else (SOME (solve irules)
           handle
             SMT_Failure.SMT (fail as SMT_Failure.Counterexample _) =>
-              (C.verbose_msg ctxt' (prefix tag o str_of) fail; NONE)
-          | SMT_Failure.SMT fail =>
-              (C.trace_msg ctxt' (prefix tag o str_of) fail; NONE))
+              (C.verbose_msg ctxt' str_of fail; NONE)
+          | SMT_Failure.SMT fail => (C.trace_msg ctxt' str_of fail; NONE))
     in
       safe_solve (map (pair ~1) (rules @ prems))
       |> (fn SOME thm => Tactic.rtac thm 1 | _ => Tactical.no_tac)
--- a/src/HOL/Tools/SMT/z3_model.ML	Tue Nov 30 00:12:29 2010 +0100
+++ b/src/HOL/Tools/SMT/z3_model.ML	Tue Nov 30 18:22:43 2010 +0100
@@ -7,7 +7,7 @@
 signature Z3_MODEL =
 sig
   val parse_counterex: Proof.context -> SMT_Translate.recon -> string list ->
-    term list
+    term list * term list
 end
 
 structure Z3_Model: Z3_MODEL =
@@ -70,117 +70,51 @@
 val cex = space |--
   Scan.repeat (name --| $$$ "->" -- (func || expr >> (single o pair [])))
 
-fun read_cex ls =
+fun resolve terms ((n, k), cases) =
+  (case Symtab.lookup terms n of
+    NONE => NONE
+  | SOME t => SOME ((t, k), cases))
+
+fun annotate _ (_, []) = NONE
+  | annotate terms (n, [([], c)]) = resolve terms ((n, 0), (c, []))
+  | annotate _ (_, [_]) = NONE
+  | annotate terms (n, cases as (args, _) :: _) =
+      let val (cases', (_, else_case)) = split_last cases
+      in resolve terms ((n, length args), (else_case, cases')) end
+
+fun read_cex terms ls =
   maps (cons "\n" o raw_explode) ls
   |> try (fst o Scan.finite Symbol.stopper cex)
   |> the_default []
-
-
-(* normalization *)
-
-local
-  fun matches terms f n =
-    (case Symtab.lookup terms n of
-      NONE => false
-    | SOME t => f t)
-
-  fun subst f (n, cases) = (n, map (fn (args, v) => (map f args, f v)) cases)
-in
-
-fun reduce_function (n, [c]) = SOME ((n, 0), [c])
-  | reduce_function (n, cases) =
-      let val (patterns, else_case as (_, e)) = split_last cases
-      in
-        (case patterns of
-          [] => NONE
-        | (args, _) :: _ => SOME ((n, length args),
-            filter_out (equal e o snd) patterns @ [else_case]))
-      end
-
-fun drop_skolem_constants terms = filter (Symtab.defined terms o fst o fst)
-
-fun substitute_constants terms =
-  let
-    fun check vs1 [] = rev vs1
-      | check vs1 ((v as ((n, k), [([], Value i)])) :: vs2) =
-          if matches terms (fn Free _ => true | _ => false) n orelse k > 0
-          then check (v :: vs1) vs2
-          else
-            let
-              fun sub (e as Value j) = if i = j then App (n, []) else e
-                | sub e = e
-            in check (map (subst sub) vs1) (map (subst sub) vs2) end
-      | check vs1 (v :: vs2) = check (v :: vs1) vs2
-  in check [] end
-
-fun remove_int_nat_coercions terms vs =
-  let
-    fun match ts ((n, _), _) = matches terms (member (op aconv) ts) n
-
-    val (default_int, ints) =
-      (case find_first (match [@{const of_nat (int)}]) vs of
-        NONE => (NONE, [])
-      | SOME (_, cases) =>
-          let val (cs, (_, e)) = split_last cases
-          in (SOME e, map (apfst hd) cs) end)
-
-    fun nat_of @{typ nat} (v as Value _) =
-          AList.lookup (op =) ints v |> the_default (the_default v default_int)
-      | nat_of _ e = e
-
-    fun subst_nat T k ([], e) =
-          let fun app f i = if i <= 0 then I else app f (i-1) o f
-          in ([], nat_of (app Term.range_type k T) e) end
-      | subst_nat T k (arg :: args, e) =
-          subst_nat (Term.range_type T) (k-1) (args, e)
-          |> apfst (cons (nat_of (Term.domain_type T) arg))
-
-    fun subst_nats (v as ((n, k), cases)) =
-      (case Symtab.lookup terms n of
-        NONE => v
-      | SOME t => ((n, k), map (subst_nat (Term.fastype_of t) k) cases))
-  in
-    map subst_nats vs
-    |> filter_out (match [@{const of_nat (int)}, @{const nat}])
-  end
-
-fun filter_valid_valuations terms = map_filter (fn
-    (_, []) => NONE
-  | ((n, i), cases) =>
-      let
-        fun valid_expr (Array a) = valid_array a
-          | valid_expr (App (n, es)) =
-              Symtab.defined terms n andalso forall valid_expr es
-          | valid_expr _ = true
-        and valid_array (Fresh e) = valid_expr e
-          | valid_array (Store ((a, e1), e2)) =
-              valid_array a andalso valid_expr e1 andalso valid_expr e2
-        fun valid_case (es, e) = forall valid_expr (e :: es)
-      in
-        if not (forall valid_case cases) then NONE
-        else Option.map (rpair cases o rpair i) (Symtab.lookup terms n)
-      end)
-
-end
+  |> map_filter (annotate terms)
 
 
 (* translation into terms *)
 
-fun with_context ctxt terms f vs =
-  fst (fold_map f vs (ctxt, terms, Inttab.empty))
+fun max_value vs =
+  let
+    fun max_val_expr (Value i) = Integer.max i
+      | max_val_expr (App (_, es)) = fold max_val_expr es
+      | max_val_expr (Array a) = max_val_array a
+      | max_val_expr _ = I
 
-fun fresh_term T (ctxt, terms, values) =
-  let val (n, ctxt') = yield_singleton Variable.variant_fixes "" ctxt
-  in (Free (n, T), (ctxt', terms, values)) end
+    and max_val_array (Fresh e) = max_val_expr e
+      | max_val_array (Store ((a, e1), e2)) =
+          max_val_array a #> max_val_expr e1 #> max_val_expr e2
 
-fun term_of_value T i (cx as (_, _, values)) =
-  (case Inttab.lookup values i of
-    SOME t => (t, cx)
+    fun max_val (_, (ec, cs)) =
+      max_val_expr ec #> fold (fn (es, e) => fold max_val_expr (e :: es)) cs
+
+  in fold max_val vs ~1 end
+
+fun with_context terms f vs = fst (fold_map f vs (terms, max_value vs + 1))
+
+fun get_term n T es (cx as (terms, next_val)) =
+  (case Symtab.lookup terms n of
+    SOME t => ((t, es), cx)
   | NONE =>
-      let val (t, (ctxt', terms', values')) = fresh_term T cx
-      in (t, (ctxt', terms', Inttab.update (i, t) values')) end)
-
-fun get_term n (cx as (_, terms, _)) = (the (Symtab.lookup terms n), cx)
+      let val t = Var (("fresh", next_val), T)
+      in ((t, []), (Symtab.update (n, t) terms, next_val + 1)) end)
 
 fun trans_expr _ True = pair @{const True}
   | trans_expr _ False = pair @{const False}
@@ -188,18 +122,16 @@
   | trans_expr T (Number (i, SOME j)) =
       pair (Const (@{const_name divide}, [T, T] ---> T) $
         HOLogic.mk_number T i $ HOLogic.mk_number T j)
-  | trans_expr T (Value i) = term_of_value T i
+  | trans_expr T (Value i) = pair (Var (("value", i), T))
   | trans_expr T (Array a) = trans_array T a
-  | trans_expr _ (App (n, es)) =
-      let val get_Ts = take (length es) o Term.binder_types o Term.fastype_of
+  | trans_expr T (App (n, es)) = get_term n T es #-> (fn (t, es') =>
+      let val Ts = fst (U.dest_funT (length es') (Term.fastype_of t))
       in
-        get_term n #-> (fn t =>
-        fold_map (uncurry trans_expr) (get_Ts t ~~ es) #>>
-        Term.list_comb o pair t)
-      end
+        fold_map (uncurry trans_expr) (Ts ~~ es') #>> Term.list_comb o pair t
+      end)
 
 and trans_array T a =
-  let val dT = Term.domain_type T and rT = Term.range_type T
+  let val (dT, rT) = U.split_type T
   in
     (case a of
       Fresh e => trans_expr rT e #>> (fn t => Abs ("x", dT, t))
@@ -232,35 +164,131 @@
 fun mk_lambda Ts (t, pats) =
   fold_rev (curry Term.absdummy) Ts t |> fold mk_update pats
 
-fun translate' T i [([], e)] =
-      if i = 0 then trans_expr T e
-      else 
-        let val ((Us1, Us2), U) = Term.strip_type T |>> chop i
-        in trans_expr (Us2 ---> U) e #>> mk_lambda Us1 o rpair [] end
-  | translate' T i cases =
-      let
-        val (pat_cases, def) = split_last cases |> apsnd snd
-        val ((Us1, Us2), U) = Term.strip_type T |>> chop i
-      in
-        trans_expr (Us2 ---> U) def ##>>
-        fold_map (trans_pattern T) pat_cases #>>
-        mk_lambda Us1
-      end
+fun translate ((t, k), (e, cs)) =
+  let
+    val T = Term.fastype_of t
+    val (Us, U) = U.dest_funT k (Term.fastype_of t)
+
+    fun mk_full_def u' pats =
+      pats
+      |> filter_out (fn (_, u) => u aconv u')
+      |> HOLogic.mk_eq o pair t o mk_lambda Us o pair u'
+
+    fun mk_eq (us, u) = HOLogic.mk_eq (Term.list_comb (t, us), u)
+    fun mk_eqs u' [] = [HOLogic.mk_eq (t, u')]
+      | mk_eqs _ pats = map mk_eq pats
+  in
+    trans_expr U e ##>>
+    (if k = 0 then pair [] else fold_map (trans_pattern T) cs) #>>
+    (fn (u', pats) => (mk_eqs u' pats, mk_full_def u' pats))
+  end
+
+
+(* normalization *)
+
+fun partition_eqs f =
+  let
+    fun part t (xs, ts) =
+      (case try HOLogic.dest_eq t of
+        SOME (l, r) => (case f l r of SOME x => (x::xs, ts) | _ => (xs, t::ts))
+      | NONE => (xs, t :: ts))
+  in (fn ts => fold part ts ([], [])) end
+
+fun replace_vars tab =
+  let
+    fun replace (v as Var _) = the_default v (AList.lookup (op aconv) tab v)
+      | replace t = t
+  in map (Term.map_aterms replace) end
+
+fun remove_int_nat_coercions (eqs, defs) =
+  let
+    fun mk_nat_num t i =
+      (case try HOLogic.dest_number i of
+        SOME (_, n) => SOME (t, HOLogic.mk_number @{typ nat} n)
+      | NONE => NONE)
+    fun nat_of (@{const of_nat (int)} $ (t as Var _)) i = mk_nat_num t i
+      | nat_of (@{const nat} $ i) (t as Var _) = mk_nat_num t i
+      | nat_of _ _ = NONE
+    val (nats, eqs') = partition_eqs nat_of eqs
 
-fun translate ((t, i), cases) =
-  translate' (Term.fastype_of t) i cases #>> HOLogic.mk_eq o pair t
+    fun is_coercion t =
+      (case try HOLogic.dest_eq t of
+        SOME (@{const of_nat (int)}, _) => true
+      | SOME (@{const nat}, _) => true
+      | _ => false)
+  in pairself (replace_vars nats) (eqs', filter_out is_coercion defs) end
+
+fun unfold_funapp (eqs, defs) =
+  let
+    fun unfold_app (Const (@{const_name SMT.fun_app}, _) $ f $ t) = f $ t
+      | unfold_app t = t
+    fun unfold_eq ((eq as Const (@{const_name HOL.eq}, _)) $ t $ u) =
+          eq $ unfold_app t $ u
+      | unfold_eq t = t
+
+    fun is_fun_app t =
+      (case try HOLogic.dest_eq t of
+        SOME (Const (@{const_name SMT.fun_app}, _), _) => true
+      | _ => false)
+
+  in (map unfold_eq eqs, filter_out is_fun_app defs) end
+
+fun unfold_simple_eqs (eqs, defs) =
+  let
+    fun add_rewr (l as Const _) (r as Var _) = SOME (r, l)
+      | add_rewr (l as Free _) (r as Var _) = SOME (r, l)
+      | add_rewr _ _ = NONE
+    val (rs, eqs') = partition_eqs add_rewr eqs
+
+    fun is_trivial (Const (@{const_name HOL.eq}, _) $ t $ u) = t aconv u
+      | is_trivial _ = false
+  in pairself (replace_vars rs #> filter_out is_trivial) (eqs', defs) end
+
+fun swap_free ((eq as Const (@{const_name HOL.eq}, _)) $ t $ (u as Free _)) =
+      eq $ u $ t
+  | swap_free t = t
+
+fun frees_for_vars ctxt (eqs, defs) =
+  let
+    fun fresh_free i T (cx as (frees, ctxt)) =
+      (case Inttab.lookup frees i of
+        SOME t => (t, cx)
+      | NONE =>
+          let
+            val (n, ctxt') = yield_singleton Variable.variant_fixes "" ctxt
+            val t = Free (n, T)
+          in (t, (Inttab.update (i, t) frees, ctxt')) end)
+
+    fun repl_var (Var ((_, i), T)) = fresh_free i T
+      | repl_var (t $ u) = repl_var t ##>> repl_var u #>> op $
+      | repl_var (Abs (n, T, t)) = repl_var t #>> (fn t' => Abs (n, T, t'))
+      | repl_var t = pair t
+  in
+    (Inttab.empty, ctxt)
+    |> fold_map repl_var eqs
+    ||>> fold_map repl_var defs
+    |> fst
+  end
 
 
 (* overall procedure *)
 
+val is_free_constraint = Term.exists_subterm (fn Free _ => true | _ => false)
+
+fun is_const_def (Const (@{const_name HOL.eq}, _) $ Const _ $ _) = true
+  | is_const_def _ = false
+
 fun parse_counterex ctxt ({terms, ...} : SMT_Translate.recon) ls =
-  read_cex ls
-  |> map_filter reduce_function
-  |> drop_skolem_constants terms
-  |> substitute_constants terms
-  |> remove_int_nat_coercions terms
-  |> filter_valid_valuations terms
-  |> with_context ctxt terms translate
+  read_cex terms ls
+  |> with_context terms translate
+  |> apfst flat o split_list
+  |> remove_int_nat_coercions
+  |> unfold_funapp
+  |> unfold_simple_eqs
+  |>> map swap_free
+  |>> filter is_free_constraint
+  |> frees_for_vars ctxt
+  ||> filter is_const_def
 
 end