src/HOL/Tools/SMT/z3_model.ML
changeset 39536 c62359dd253d
parent 37153 8feed34275ce
child 40551 a0dd429e97d9
--- a/src/HOL/Tools/SMT/z3_model.ML	Sun Sep 19 00:29:13 2010 +0200
+++ b/src/HOL/Tools/SMT/z3_model.ML	Sun Sep 19 11:33:39 2010 +0200
@@ -6,7 +6,8 @@
 
 signature Z3_MODEL =
 sig
-  val parse_counterex: SMT_Translate.recon -> string list -> term list
+  val parse_counterex: Proof.context -> SMT_Translate.recon -> string list ->
+    term list
 end
 
 structure Z3_Model: Z3_MODEL =
@@ -15,82 +16,156 @@
 (* counterexample expressions *)
 
 datatype expr = True | False | Number of int * int option | Value of int |
-  Array of array
+  Array of array | App of string * expr list
 and array = Fresh of expr | Store of (array * expr) * expr
 
 
 (* parsing *)
 
 val space = Scan.many Symbol.is_ascii_blank
-fun in_parens p = Scan.$$ "(" |-- p --| Scan.$$ ")"
-fun in_braces p = (space -- Scan.$$ "{") |-- p --| (space -- Scan.$$ "}")
+fun spaced p = p --| space
+fun in_parens p = spaced (Scan.$$ "(") |-- p --| spaced (Scan.$$ ")")
+fun in_braces p = spaced (Scan.$$ "{") |-- p --| spaced (Scan.$$ "}")
 
 val digit = (fn
   "0" => SOME 0 | "1" => SOME 1 | "2" => SOME 2 | "3" => SOME 3 |
   "4" => SOME 4 | "5" => SOME 5 | "6" => SOME 6 | "7" => SOME 7 |
   "8" => SOME 8 | "9" => SOME 9 | _ => NONE)
 
-val nat_num = Scan.repeat1 (Scan.some digit) >>
-  (fn ds => fold (fn d => fn i => i * 10 + d) ds 0)
-val int_num = Scan.optional ($$ "-" >> K (fn i => ~i)) I :|--
-  (fn sign => nat_num >> sign)
+val nat_num = spaced (Scan.repeat1 (Scan.some digit) >>
+  (fn ds => fold (fn d => fn i => i * 10 + d) ds 0))
+val int_num = spaced (Scan.optional ($$ "-" >> K (fn i => ~i)) I :|--
+  (fn sign => nat_num >> sign))
 
 val is_char = Symbol.is_ascii_letter orf Symbol.is_ascii_digit orf
   member (op =) (explode "_+*-/%~=<>$&|?!.@^#")
-val name = Scan.many1 is_char >> implode
+val name = spaced (Scan.many1 is_char >> implode)
+
+fun $$$ s = spaced (Scan.this_string s)
 
-fun array_expr st = st |>
-  in_parens (space |-- (
-  Scan.this_string "const" |-- expr >> Fresh ||
-  Scan.this_string "store" -- space |-- array_expr -- expr -- expr >> Store))
+fun array_expr st = st |> in_parens (
+  $$$ "const" |-- expr >> Fresh ||
+  $$$ "store" |-- array_expr -- expr -- expr >> Store)
 
-and expr st = st |> (space |-- (
-  Scan.this_string "true" >> K True ||
-  Scan.this_string "false" >> K False ||
-  int_num -- Scan.option (Scan.$$ "/" |-- int_num) >> Number ||
-  Scan.this_string "val!" |-- nat_num >> Value ||
-  array_expr >> Array))
+and expr st = st |> (
+  $$$ "true" >> K True ||
+  $$$ "false" >> K False ||
+  int_num -- Scan.option ($$$ "/" |-- int_num) >> Number ||
+  $$$ "val!" |-- nat_num >> Value ||
+  name >> (App o rpair []) ||
+  array_expr >> Array ||
+  in_parens (name -- Scan.repeat1 expr) >> App)
 
-val mapping = space -- Scan.this_string "->"
-val value = mapping |-- expr
-
-val args_case = Scan.repeat expr -- value
-val else_case = space -- Scan.this_string "else" |-- value >>
-  pair ([] : expr list)
+fun args st = ($$$ "->" >> K [] || expr ::: args) st
+val args_case = args -- expr
+val else_case = $$$ "else" -- $$$ "->" |-- expr >> pair ([] : expr list)
 
 val func =
   let fun cases st = (else_case >> single || args_case ::: cases) st
   in in_braces cases end
 
-val cex = space |-- Scan.repeat (space |-- name --| mapping --
-  (func || expr >> (single o pair [])))
+val cex = space |--
+  Scan.repeat (name --| $$$ "->" -- (func || expr >> (single o pair [])))
 
 fun read_cex ls =
-  explode (cat_lines ls)
+  maps (cons "\n" o explode) ls
   |> try (fst o Scan.finite Symbol.stopper cex)
   |> the_default []
 
 
+(* normalization *)
+
+local
+  fun matches terms f n =
+    (case Symtab.lookup terms n of
+      NONE => false
+    | SOME t => f t)
+
+  fun subst f (n, cases) = (n, map (fn (args, v) => (map f args, f v)) cases)
+in
+
+fun reduce_function (n, [c]) = SOME ((n, 0), [c])
+  | reduce_function (n, cases) =
+      let val (patterns, else_case as (_, e)) = split_last cases
+      in
+        (case patterns of
+          [] => NONE
+        | (args, _) :: _ => SOME ((n, length args),
+            filter_out (equal e o snd) patterns @ [else_case]))
+      end
+
+fun drop_skolem_constants terms = filter (Symtab.defined terms o fst o fst)
+
+fun substitute_constants terms =
+  let
+    fun check vs1 [] = rev vs1
+      | check vs1 ((v as ((n, k), [([], Value i)])) :: vs2) =
+          if matches terms (fn Free _ => true | _ => false) n orelse k > 0
+          then check (v :: vs1) vs2
+          else
+            let
+              fun sub (e as Value j) = if i = j then App (n, []) else e
+                | sub e = e
+            in check (map (subst sub) vs1) (map (subst sub) vs2) end
+      | check vs1 (v :: vs2) = check (v :: vs1) vs2
+  in check [] end
+
+fun remove_int_nat_coercions terms vs =
+  let
+    fun match ts ((n, _), _) = matches terms (member (op aconv) ts) n
+
+    val ints =
+      find_first (match [@{term int}]) vs
+      |> Option.map (fn (_, cases) =>
+           let val (cs, (_, e)) = split_last cases
+           in (e, map (apfst hd) cs) end)
+    fun nat_of (v as Value _) = 
+          (case ints of
+            NONE => v
+          | SOME (e, tab) => the_default e (AList.lookup (op =) tab v))
+      | nat_of e = e
+  in
+    map (subst nat_of) vs
+    |> filter_out (match [@{term int}, @{term nat}])
+  end
+
+fun filter_valid_valuations terms = map_filter (fn
+    (_, []) => NONE
+  | ((n, i), cases) =>
+      let
+        fun valid_expr (Array a) = valid_array a
+          | valid_expr (App (n, es)) =
+              Symtab.defined terms n andalso forall valid_expr es
+          | valid_expr _ = true
+        and valid_array (Fresh e) = valid_expr e
+          | valid_array (Store ((a, e1), e2)) =
+              valid_array a andalso valid_expr e1 andalso valid_expr e2
+        fun valid_case (es, e) = forall valid_expr (e :: es)
+      in
+        if not (forall valid_case cases) then NONE
+        else Option.map (rpair cases o rpair i) (Symtab.lookup terms n)
+      end)
+
+end
+
+
 (* translation into terms *)
 
-fun lookup_term tab (name, e) = Option.map (rpair e) (Symtab.lookup tab name)
+fun with_context ctxt terms f vs =
+  fst (fold_map f vs (ctxt, terms, Inttab.empty))
 
-fun with_name_context tab f xs =
-  let
-    val ns = Symtab.fold (Term.add_free_names o snd) tab []
-    val nctxt = Name.make_context ns
-  in fst (fold_map f xs (Inttab.empty, nctxt)) end
+fun fresh_term T (ctxt, terms, values) =
+  let val (n, ctxt') = yield_singleton Variable.variant_fixes "" ctxt
+  in (Free (n, T), (ctxt', terms, values)) end
 
-fun fresh_term T (tab, nctxt) =
-  let val (n, nctxt') = yield_singleton Name.variants "" nctxt
-  in (Free (n, T), (tab, nctxt')) end
-
-fun term_of_value T i (cx as (tab, _)) =
-  (case Inttab.lookup tab i of
+fun term_of_value T i (cx as (_, _, values)) =
+  (case Inttab.lookup values i of
     SOME t => (t, cx)
   | NONE =>
-      let val (t, (tab', nctxt')) = fresh_term T cx
-      in (t, (Inttab.update (i, t) tab', nctxt')) end)
+      let val (t, (ctxt', terms', values')) = fresh_term T cx
+      in (t, (ctxt', terms', Inttab.update (i, t) values')) end)
+
+fun get_term n (cx as (_, terms, _)) = (the (Symtab.lookup terms n), cx)
 
 fun trans_expr _ True = pair @{term True}
   | trans_expr _ False = pair @{term False}
@@ -100,6 +175,13 @@
         HOLogic.mk_number T i $ HOLogic.mk_number T j)
   | trans_expr T (Value i) = term_of_value T i
   | trans_expr T (Array a) = trans_array T a
+  | trans_expr _ (App (n, es)) =
+      let val get_Ts = take (length es) o Term.binder_types o Term.fastype_of
+      in
+        get_term n #-> (fn t =>
+        fold_map (uncurry trans_expr) (get_Ts t ~~ es) #>>
+        Term.list_comb o pair t)
+      end
 
 and trans_array T a =
   let val dT = Term.domain_type T and rT = Term.range_type T
@@ -112,35 +194,60 @@
           Const (@{const_name fun_upd}, [T, dT, rT] ---> T) $ m $ k $ v))
   end
 
-fun trans_pat i T f x =
-  f (Term.domain_type T) ##>> trans (i-1) (Term.range_type T) x #>>
-  (fn (u, (us, t)) => (u :: us, t))
+fun trans_pattern T ([], e) = trans_expr T e #>> pair []
+  | trans_pattern T (arg :: args, e) =
+      trans_expr (Term.domain_type T) arg ##>>
+      trans_pattern (Term.range_type T) (args, e) #>>
+      (fn (arg', (args', e')) => (arg' :: args', e'))
 
-and trans i T ([], v) =
-      if i > 0 then trans_pat i T fresh_term ([], v)
-      else trans_expr T v #>> pair []
-  | trans i T (p :: ps, v) = trans_pat i T (fn U => trans_expr U p) (ps, v)
+fun mk_fun_upd T U = Const (@{const_name fun_upd}, [T --> U, T, U, T] ---> U)
+
+fun split_type T = (Term.domain_type T, Term.range_type T)
 
-fun mk_eq' t us u = HOLogic.mk_eq (Term.list_comb (t, us), u)
-fun mk_eq (Const (@{const_name fun_app}, _)) (u' :: us', u) = mk_eq' u' us' u
-  | mk_eq t (us, u) = mk_eq' t us u
+fun mk_update ([], u) _ = u
+  | mk_update ([t], u) f =
+      uncurry mk_fun_upd (split_type (Term.fastype_of f)) $ f $ t $ u
+  | mk_update (t :: ts, u) f =
+      let
+        val (dT, rT) = split_type (Term.fastype_of f)
+        val (dT', rT') = split_type rT
+      in
+        mk_fun_upd dT rT $ f $ t $
+          mk_update (ts, u) (Term.absdummy (dT', Const ("_", rT')))
+      end
+
+fun mk_lambda Ts (t, pats) =
+  fold_rev (curry Term.absdummy) Ts t |> fold mk_update pats
 
-fun translate (t, cs) =
-  let val T = Term.fastype_of t
-  in
-    (case (can HOLogic.dest_number t, cs) of
-      (true, [c]) => trans 0 T c #>> (fn (_, u) => [mk_eq u ([], t)])
-    | (_, (es, _) :: _) => fold_map (trans (length es) T) cs #>> map (mk_eq t)
-    | _ => raise TERM ("translate: no cases", [t]))
-  end
+fun translate' T i [([], e)] =
+      if i = 0 then trans_expr T e
+      else 
+        let val ((Us1, Us2), U) = Term.strip_type T |>> chop i
+        in trans_expr (Us2 ---> U) e #>> mk_lambda Us1 o rpair [] end
+  | translate' T i cases =
+      let
+        val (pat_cases, def) = split_last cases |> apsnd snd
+        val ((Us1, Us2), U) = Term.strip_type T |>> chop i
+      in
+        trans_expr (Us2 ---> U) def ##>>
+        fold_map (trans_pattern T) pat_cases #>>
+        mk_lambda Us1
+      end
+
+fun translate ((t, i), cases) =
+  translate' (Term.fastype_of t) i cases #>> HOLogic.mk_eq o pair t
 
 
 (* overall procedure *)
 
-fun parse_counterex ({terms, ...} : SMT_Translate.recon) ls =
+fun parse_counterex ctxt ({terms, ...} : SMT_Translate.recon) ls =
   read_cex ls
-  |> map_filter (lookup_term terms)
-  |> with_name_context terms translate
-  |> flat
+  |> map_filter reduce_function
+  |> drop_skolem_constants terms
+  |> substitute_constants terms
+  |> remove_int_nat_coercions terms
+  |> filter_valid_valuations terms
+  |> with_context ctxt terms translate
 
 end
+