centralized handling of built-in types and constants for bitvectors
authorboehmes
Tue, 07 Dec 2010 14:54:31 +0100
changeset 41061 492f8fd35fc0
parent 41060 4199fdcfa3c0
child 41062 304cfdbc6475
centralized handling of built-in types and constants for bitvectors
src/HOL/Word/Tools/smt_word.ML
--- a/src/HOL/Word/Tools/smt_word.ML	Tue Dec 07 14:53:44 2010 +0100
+++ b/src/HOL/Word/Tools/smt_word.ML	Tue Dec 07 14:54:31 2010 +0100
@@ -12,6 +12,8 @@
 structure SMT_Word: SMT_WORD =
 struct
 
+structure B = SMT_Builtin
+
 
 (* utilities *)
 
@@ -30,7 +32,6 @@
   | dest_wordT T = raise TYPE ("dest_wordT", [T], [])
 
 
-
 (* SMT-LIB logic *)
 
 fun smtlib_logic ts =
@@ -39,25 +40,37 @@
   else NONE
 
 
-
 (* SMT-LIB builtins *)
 
 local
+  val smtlibC = SMTLIB_Interface.smtlibC
+
+  val wordT = @{typ "'a::len word"}
+
   fun index1 n i = n ^ "[" ^ string_of_int i ^ "]"
   fun index2 n i j = n ^ "[" ^ string_of_int i ^ ":" ^ string_of_int j ^ "]"
 
-  fun smtlib_builtin_typ (Type (@{type_name word}, [T])) =
+  fun word_typ (Type (@{type_name word}, [T])) =
         Option.map (index1 "BitVec") (try dest_binT T)
-    | smtlib_builtin_typ _ = NONE
+    | word_typ _ = NONE
 
-  fun smtlib_builtin_num (Type (@{type_name word}, [T])) i =
+  fun word_num (Type (@{type_name word}, [T])) i =
         Option.map (index1 ("bv" ^ string_of_int i)) (try dest_binT T)
-    | smtlib_builtin_num _ _ = NONE
+    | word_num _ _ = NONE
 
   fun if_fixed n T ts =
     let val (Ts, T) = Term.strip_type T
     in if forall (can dest_wordT) (T :: Ts) then SOME (n, ts) else NONE end
 
+  fun if_fixed' n T ts =
+    if forall (can dest_wordT) (Term.binder_types T) then SOME (n, ts)
+    else NONE
+
+  fun add_word_fun f (t, n) =
+    B.add_builtin_fun smtlibC (Term.dest_Const t, K (f n))
+
+  fun add_word_fun' f (t, n) = add_word_fun f (t, n)
+
   fun dest_word_funT (Type ("fun", [T, U])) = (dest_wordT T, dest_wordT U)
     | dest_word_funT T = raise TYPE ("dest_word_funT", [T], [])
   fun dest_nat (@{const nat} $ n :: ts) = (snd (HOLogic.dest_number n), ts)
@@ -76,6 +89,10 @@
       | _ => NONE)
     end
 
+  fun extract n T ts =
+    try dest_nat_word_funT (T, ts)
+    |> Option.map (fn ((_, i), (lb, ts')) => (index2 n (i + lb - 1) lb, ts'))
+
   fun extend n T ts =
     (case try dest_word_funT T of
       SOME (i, j) => if j-i >= 0 then SOME (index1 n (j-i), ts) else NONE
@@ -84,60 +101,45 @@
   fun rotate n T ts =
     try dest_nat ts
     |> Option.map (fn (i, ts') => (index1 n i, ts'))
-
-  fun extract n T ts =
-    try dest_nat_word_funT (T, ts)
-    |> Option.map (fn ((_, i), (lb, ts')) => (index2 n (i + lb - 1) lb, ts'))
-
-  fun smtlib_builtin_func @{const_name uminus} = if_fixed "bvneg"
-    | smtlib_builtin_func @{const_name plus} = if_fixed "bvadd"
-    | smtlib_builtin_func @{const_name minus} = if_fixed "bvsub"
-    | smtlib_builtin_func @{const_name times} = if_fixed "bvmul"
-    | smtlib_builtin_func @{const_name bitNOT} = if_fixed "bvnot"
-    | smtlib_builtin_func @{const_name bitAND} = if_fixed "bvand"
-    | smtlib_builtin_func @{const_name bitOR} = if_fixed "bvor"
-    | smtlib_builtin_func @{const_name bitXOR} = if_fixed "bvxor"
-    | smtlib_builtin_func @{const_name word_cat} = if_fixed "concat"
-    | smtlib_builtin_func @{const_name shiftl} = shift "bvshl"
-    | smtlib_builtin_func @{const_name shiftr} = shift "bvlshr"
-    | smtlib_builtin_func @{const_name sshiftr} = shift "bvashr"
-    | smtlib_builtin_func @{const_name slice} = extract "extract"
-    | smtlib_builtin_func @{const_name ucast} = extend "zero_extend"
-    | smtlib_builtin_func @{const_name scast} = extend "sign_extend"
-    | smtlib_builtin_func @{const_name word_rotl} = rotate "rotate_left"
-    | smtlib_builtin_func @{const_name word_rotr} = rotate "rotate_right"
-    | smtlib_builtin_func _ = (fn _ => K NONE)
-        (* FIXME: support more builtin bitvector functions:
-             bvudiv/bvurem and bvsdiv/bvsmod/bvsrem *)
-
-  fun smtlib_builtin_pred @{const_name less} = SOME "bvult"
-    | smtlib_builtin_pred @{const_name less_eq} = SOME "bvule"
-    | smtlib_builtin_pred @{const_name word_sless} = SOME "bvslt"
-    | smtlib_builtin_pred @{const_name word_sle} = SOME "bvsle"
-    | smtlib_builtin_pred _ = NONE
-
-  fun smtlib_builtin_pred' (n, T) =
-    if can (dest_wordT o Term.domain_type) T then smtlib_builtin_pred n
-    else NONE
 in
 
-val smtlib_builtins : SMTLIB_Interface.builtins = {
-  builtin_typ = smtlib_builtin_typ,
-  builtin_num = smtlib_builtin_num,
-  builtin_func = (fn (n, T) => fn ts => smtlib_builtin_func n T ts),
-  builtin_pred = (fn c => fn ts =>
-    smtlib_builtin_pred' c |> Option.map (rpair ts)),
-  is_builtin_pred = curry (is_some o smtlib_builtin_pred') }
+val setup_builtins =
+  B.add_builtin_typ smtlibC (wordT, word_typ, word_num) #>
+  fold (add_word_fun' if_fixed) [
+    (@{term "uminus :: 'a::len word => _"}, "bvneg"),
+    (@{term "plus :: 'a::len word => _"}, "bvadd"),
+    (@{term "minus :: 'a::len word => _"}, "bvsub"),
+    (@{term "times :: 'a::len word => _"}, "bvmul"),
+    (@{term "bitNOT :: 'a::len word => _"}, "bvnot"),
+    (@{term "bitAND :: 'a::len word => _"}, "bvand"),
+    (@{term "bitOR :: 'a::len word => _"}, "bvor"),
+    (@{term "bitXOR :: 'a::len word => _"}, "bvxor"),
+    (@{term "word_cat :: 'a::len word => _"}, "concat") ] #>
+  fold (add_word_fun shift) [
+    (@{term "shiftl :: 'a::len word => _ "}, "bvshl"),
+    (@{term "shiftr :: 'a::len word => _"}, "bvlshr"),
+    (@{term "sshiftr :: 'a::len word => _"}, "bvashr") ] #>
+  add_word_fun extract
+    (@{term "slice :: _ => 'a::len word => _"}, "extract") #>
+  fold (add_word_fun extend) [
+    (@{term "ucast :: 'a::len word => _"}, "zero_extend"),
+    (@{term "scast :: 'a::len word => _"}, "sign_extend") ] #>
+  fold (add_word_fun rotate) [
+    (@{term word_rotl}, "rotate_left"),
+    (@{term word_rotr}, "rotate_right") ] #>
+  fold (add_word_fun' if_fixed') [
+    (@{term "less :: 'a::len word => _"}, "bvult"),
+    (@{term "less_eq :: 'a::len word => _"}, "bvule"),
+    (@{term word_sless}, "bvslt"),
+    (@{term word_sle}, "bvsle") ]
 
 end
 
 
-
 (* setup *)
 
 val setup = 
-  Context.theory_map (
-    SMTLIB_Interface.add_logic smtlib_logic #>
-    SMTLIB_Interface.add_builtins smtlib_builtins)
+  Context.theory_map (SMTLIB_Interface.add_logic (20, smtlib_logic)) #>
+  setup_builtins
 
 end