--- a/src/HOL/Word/Tools/smt_word.ML Tue Dec 07 14:53:44 2010 +0100
+++ b/src/HOL/Word/Tools/smt_word.ML Tue Dec 07 14:54:31 2010 +0100
@@ -12,6 +12,8 @@
structure SMT_Word: SMT_WORD =
struct
+structure B = SMT_Builtin
+
(* utilities *)
@@ -30,7 +32,6 @@
| dest_wordT T = raise TYPE ("dest_wordT", [T], [])
-
(* SMT-LIB logic *)
fun smtlib_logic ts =
@@ -39,25 +40,37 @@
else NONE
-
(* SMT-LIB builtins *)
local
+ val smtlibC = SMTLIB_Interface.smtlibC
+
+ val wordT = @{typ "'a::len word"}
+
fun index1 n i = n ^ "[" ^ string_of_int i ^ "]"
fun index2 n i j = n ^ "[" ^ string_of_int i ^ ":" ^ string_of_int j ^ "]"
- fun smtlib_builtin_typ (Type (@{type_name word}, [T])) =
+ fun word_typ (Type (@{type_name word}, [T])) =
Option.map (index1 "BitVec") (try dest_binT T)
- | smtlib_builtin_typ _ = NONE
+ | word_typ _ = NONE
- fun smtlib_builtin_num (Type (@{type_name word}, [T])) i =
+ fun word_num (Type (@{type_name word}, [T])) i =
Option.map (index1 ("bv" ^ string_of_int i)) (try dest_binT T)
- | smtlib_builtin_num _ _ = NONE
+ | word_num _ _ = NONE
fun if_fixed n T ts =
let val (Ts, T) = Term.strip_type T
in if forall (can dest_wordT) (T :: Ts) then SOME (n, ts) else NONE end
+ fun if_fixed' n T ts =
+ if forall (can dest_wordT) (Term.binder_types T) then SOME (n, ts)
+ else NONE
+
+ fun add_word_fun f (t, n) =
+ B.add_builtin_fun smtlibC (Term.dest_Const t, K (f n))
+
+ fun add_word_fun' f (t, n) = add_word_fun f (t, n)
+
fun dest_word_funT (Type ("fun", [T, U])) = (dest_wordT T, dest_wordT U)
| dest_word_funT T = raise TYPE ("dest_word_funT", [T], [])
fun dest_nat (@{const nat} $ n :: ts) = (snd (HOLogic.dest_number n), ts)
@@ -76,6 +89,10 @@
| _ => NONE)
end
+ fun extract n T ts =
+ try dest_nat_word_funT (T, ts)
+ |> Option.map (fn ((_, i), (lb, ts')) => (index2 n (i + lb - 1) lb, ts'))
+
fun extend n T ts =
(case try dest_word_funT T of
SOME (i, j) => if j-i >= 0 then SOME (index1 n (j-i), ts) else NONE
@@ -84,60 +101,45 @@
fun rotate n T ts =
try dest_nat ts
|> Option.map (fn (i, ts') => (index1 n i, ts'))
-
- fun extract n T ts =
- try dest_nat_word_funT (T, ts)
- |> Option.map (fn ((_, i), (lb, ts')) => (index2 n (i + lb - 1) lb, ts'))
-
- fun smtlib_builtin_func @{const_name uminus} = if_fixed "bvneg"
- | smtlib_builtin_func @{const_name plus} = if_fixed "bvadd"
- | smtlib_builtin_func @{const_name minus} = if_fixed "bvsub"
- | smtlib_builtin_func @{const_name times} = if_fixed "bvmul"
- | smtlib_builtin_func @{const_name bitNOT} = if_fixed "bvnot"
- | smtlib_builtin_func @{const_name bitAND} = if_fixed "bvand"
- | smtlib_builtin_func @{const_name bitOR} = if_fixed "bvor"
- | smtlib_builtin_func @{const_name bitXOR} = if_fixed "bvxor"
- | smtlib_builtin_func @{const_name word_cat} = if_fixed "concat"
- | smtlib_builtin_func @{const_name shiftl} = shift "bvshl"
- | smtlib_builtin_func @{const_name shiftr} = shift "bvlshr"
- | smtlib_builtin_func @{const_name sshiftr} = shift "bvashr"
- | smtlib_builtin_func @{const_name slice} = extract "extract"
- | smtlib_builtin_func @{const_name ucast} = extend "zero_extend"
- | smtlib_builtin_func @{const_name scast} = extend "sign_extend"
- | smtlib_builtin_func @{const_name word_rotl} = rotate "rotate_left"
- | smtlib_builtin_func @{const_name word_rotr} = rotate "rotate_right"
- | smtlib_builtin_func _ = (fn _ => K NONE)
- (* FIXME: support more builtin bitvector functions:
- bvudiv/bvurem and bvsdiv/bvsmod/bvsrem *)
-
- fun smtlib_builtin_pred @{const_name less} = SOME "bvult"
- | smtlib_builtin_pred @{const_name less_eq} = SOME "bvule"
- | smtlib_builtin_pred @{const_name word_sless} = SOME "bvslt"
- | smtlib_builtin_pred @{const_name word_sle} = SOME "bvsle"
- | smtlib_builtin_pred _ = NONE
-
- fun smtlib_builtin_pred' (n, T) =
- if can (dest_wordT o Term.domain_type) T then smtlib_builtin_pred n
- else NONE
in
-val smtlib_builtins : SMTLIB_Interface.builtins = {
- builtin_typ = smtlib_builtin_typ,
- builtin_num = smtlib_builtin_num,
- builtin_func = (fn (n, T) => fn ts => smtlib_builtin_func n T ts),
- builtin_pred = (fn c => fn ts =>
- smtlib_builtin_pred' c |> Option.map (rpair ts)),
- is_builtin_pred = curry (is_some o smtlib_builtin_pred') }
+val setup_builtins =
+ B.add_builtin_typ smtlibC (wordT, word_typ, word_num) #>
+ fold (add_word_fun' if_fixed) [
+ (@{term "uminus :: 'a::len word => _"}, "bvneg"),
+ (@{term "plus :: 'a::len word => _"}, "bvadd"),
+ (@{term "minus :: 'a::len word => _"}, "bvsub"),
+ (@{term "times :: 'a::len word => _"}, "bvmul"),
+ (@{term "bitNOT :: 'a::len word => _"}, "bvnot"),
+ (@{term "bitAND :: 'a::len word => _"}, "bvand"),
+ (@{term "bitOR :: 'a::len word => _"}, "bvor"),
+ (@{term "bitXOR :: 'a::len word => _"}, "bvxor"),
+ (@{term "word_cat :: 'a::len word => _"}, "concat") ] #>
+ fold (add_word_fun shift) [
+ (@{term "shiftl :: 'a::len word => _ "}, "bvshl"),
+ (@{term "shiftr :: 'a::len word => _"}, "bvlshr"),
+ (@{term "sshiftr :: 'a::len word => _"}, "bvashr") ] #>
+ add_word_fun extract
+ (@{term "slice :: _ => 'a::len word => _"}, "extract") #>
+ fold (add_word_fun extend) [
+ (@{term "ucast :: 'a::len word => _"}, "zero_extend"),
+ (@{term "scast :: 'a::len word => _"}, "sign_extend") ] #>
+ fold (add_word_fun rotate) [
+ (@{term word_rotl}, "rotate_left"),
+ (@{term word_rotr}, "rotate_right") ] #>
+ fold (add_word_fun' if_fixed') [
+ (@{term "less :: 'a::len word => _"}, "bvult"),
+ (@{term "less_eq :: 'a::len word => _"}, "bvule"),
+ (@{term word_sless}, "bvslt"),
+ (@{term word_sle}, "bvsle") ]
end
-
(* setup *)
val setup =
- Context.theory_map (
- SMTLIB_Interface.add_logic smtlib_logic #>
- SMTLIB_Interface.add_builtins smtlib_builtins)
+ Context.theory_map (SMTLIB_Interface.add_logic (20, smtlib_logic)) #>
+ setup_builtins
end