moved SMT classes and dictionary functions to SMT_Utils
authorboehmes
Wed, 15 Dec 2010 08:39:24 +0100
changeset 41124 1de17a2de5ad
parent 41123 3bb9be510a9d
child 41125 4a9eec045f2a
moved SMT classes and dictionary functions to SMT_Utils
src/HOL/SMT.thy
src/HOL/Tools/SMT/smt_builtin.ML
src/HOL/Tools/SMT/smt_config.ML
src/HOL/Tools/SMT/smt_solver.ML
src/HOL/Tools/SMT/smt_utils.ML
src/HOL/Tools/SMT/smtlib_interface.ML
src/HOL/Tools/SMT/z3_interface.ML
--- a/src/HOL/SMT.thy	Wed Dec 15 08:39:24 2010 +0100
+++ b/src/HOL/SMT.thy	Wed Dec 15 08:39:24 2010 +0100
@@ -8,9 +8,9 @@
 imports List
 uses
   "Tools/Datatype/datatype_selectors.ML"
+  "Tools/SMT/smt_utils.ML"
   "Tools/SMT/smt_failure.ML"
   "Tools/SMT/smt_config.ML"
-  "Tools/SMT/smt_utils.ML"
   "Tools/SMT/smt_monomorph.ML"
   ("Tools/SMT/smt_builtin.ML")
   ("Tools/SMT/smt_normalize.ML")
--- a/src/HOL/Tools/SMT/smt_builtin.ML	Wed Dec 15 08:39:24 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_builtin.ML	Wed Dec 15 08:39:24 2010 +0100
@@ -7,7 +7,7 @@
 signature SMT_BUILTIN =
 sig
   (*built-in types*)
-  val add_builtin_typ: SMT_Config.class ->
+  val add_builtin_typ: SMT_Utils.class ->
     typ * (typ -> string option) * (typ -> int -> string option) ->
     Context.generic -> Context.generic
   val add_builtin_typ_ext: typ * (typ -> bool) -> Context.generic ->
@@ -23,10 +23,10 @@
 
   (*built-in functions*)
   type 'a bfun = Proof.context -> typ -> term list -> 'a
-  val add_builtin_fun: SMT_Config.class ->
+  val add_builtin_fun: SMT_Utils.class ->
     (string * typ) * (string * term list) option bfun -> Context.generic ->
     Context.generic
-  val add_builtin_fun': SMT_Config.class -> term * string -> Context.generic ->
+  val add_builtin_fun': SMT_Utils.class -> term * string -> Context.generic ->
     Context.generic
   val add_builtin_fun_ext: (string * typ) * bool bfun -> Context.generic ->
     Context.generic
@@ -43,6 +43,7 @@
 structure SMT_Builtin: SMT_BUILTIN =
 struct
 
+structure U = SMT_Utils
 structure C = SMT_Config
 
 
@@ -50,7 +51,7 @@
 
 datatype ('a, 'b) kind = Ext of 'a | Int of 'b
 
-type ('a, 'b) ttab = (C.class * (typ * ('a, 'b) kind) Ord_List.T) list 
+type ('a, 'b) ttab = ((typ * ('a, 'b) kind) Ord_List.T) U.dict 
 
 fun typ_ord ((T, _), (U, _)) =
   let
@@ -63,21 +64,17 @@
   in tord (T, U) end
 
 fun insert_ttab cs T f =
-  AList.map_default (op =) (cs, [])
+  U.dict_map_default (cs, [])
     (Ord_List.insert typ_ord (perhaps (try Logic.varifyT_global) T, f))
 
 fun merge_ttab ttabp =
-  AList.join (op =) (K (uncurry (Ord_List.union typ_ord) o swap)) ttabp
+  U.dict_merge (uncurry (Ord_List.union typ_ord) o swap) ttabp
 
 fun lookup_ttab ctxt ttab T =
-  let
-    val cs = C.solver_class_of ctxt
-    fun match (U, _) = Sign.typ_instance (ProofContext.theory_of ctxt) (T, U)
-
-    fun matching (cs', Txs) =
-      if is_prefix (op =) cs' cs then find_first match Txs
-      else NONE
-  in get_first matching ttab end
+  let fun match (U, _) = Sign.typ_instance (ProofContext.theory_of ctxt) (T, U)
+  in
+    get_first (find_first match) (U.dict_lookup ttab (C.solver_class_of ctxt))
+  end
 
 type ('a, 'b) btab = ('a, 'b) ttab Symtab.table
 
@@ -109,7 +106,7 @@
   Builtin_Types.map (insert_ttab cs T (Int (f, g)))
 
 fun add_builtin_typ_ext (T, f) =
-  Builtin_Types.map (insert_ttab C.basicC T (Ext f))
+  Builtin_Types.map (insert_ttab U.basicC T (Ext f))
 
 fun lookup_builtin_typ ctxt =
   lookup_ttab ctxt (Builtin_Types.get (Context.Proof ctxt))
@@ -163,7 +160,7 @@
 
 fun basic_builtin_funcs () : builtin_funcs =
   empty_btab ()
-  |> fold (raw_add_builtin_fun_ext @{theory} C.basicC) basic_builtin_fun_names
+  |> fold (raw_add_builtin_fun_ext @{theory} U.basicC) basic_builtin_fun_names
        (* FIXME: SMT_Normalize should check that they are properly used *)
 
 structure Builtin_Funcs = Generic_Data
@@ -181,7 +178,7 @@
   add_builtin_fun cs (Term.dest_Const t, fn _ => fn _ => SOME o pair n)
 
 fun add_builtin_fun_ext ((n, T), f) =
-  Builtin_Funcs.map (insert_btab C.basicC n T (Ext f))
+  Builtin_Funcs.map (insert_btab U.basicC n T (Ext f))
 
 fun add_builtin_fun_ext' c = add_builtin_fun_ext (c, true3)
 
--- a/src/HOL/Tools/SMT/smt_config.ML	Wed Dec 15 08:39:24 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_config.ML	Wed Dec 15 08:39:24 2010 +0100
@@ -6,16 +6,13 @@
 
 signature SMT_CONFIG =
 sig
-  (*class*)
-  type class = string list
-  val basicC: class
-
   (*solver*)
-  val add_solver: string * class -> Context.generic -> Context.generic
+  val add_solver: string * SMT_Utils.class -> Context.generic ->
+    Context.generic
   val set_solver_options: string * string -> Context.generic -> Context.generic
   val select_solver: string -> Context.generic -> Context.generic
   val solver_of: Proof.context -> string
-  val solver_class_of: Proof.context -> class
+  val solver_class_of: Proof.context -> SMT_Utils.class
   val solver_options_of: Proof.context -> string list
 
   (*options*)
@@ -52,18 +49,11 @@
 structure SMT_Config: SMT_CONFIG =
 struct
 
-(* class *)
-
-type class = string list
-
-val basicC = []
-
-
 (* solver *)
 
 structure Solvers = Generic_Data
 (
-  type T = (class * string list) Symtab.table * string option
+  type T = (SMT_Utils.class * string list) Symtab.table * string option
   val empty = (Symtab.empty, NONE)
   val extend = I
   fun merge ((ss1, s), (ss2, _)) = (Symtab.merge (K true) (ss1, ss2), s)
--- a/src/HOL/Tools/SMT/smt_solver.ML	Wed Dec 15 08:39:24 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_solver.ML	Wed Dec 15 08:39:24 2010 +0100
@@ -8,7 +8,7 @@
 sig
   (*configuration*)
   type interface = {
-    class: SMT_Config.class,
+    class: SMT_Utils.class,
     extra_norm: SMT_Normalize.extra_norm,
     translate: SMT_Translate.config }
   datatype outcome = Unsat | Sat | Unknown
@@ -58,7 +58,7 @@
 (* configuration *)
 
 type interface = {
-  class: SMT_Config.class,
+  class: SMT_Utils.class,
   extra_norm: SMT_Normalize.extra_norm,
   translate: SMT_Translate.config }
 
--- a/src/HOL/Tools/SMT/smt_utils.ML	Wed Dec 15 08:39:24 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_utils.ML	Wed Dec 15 08:39:24 2010 +0100
@@ -10,6 +10,15 @@
   val repeat: ('a -> 'a option) -> 'a -> 'a
   val repeat_yield: ('a -> 'b -> ('a * 'b) option) -> 'a -> 'b -> 'a * 'b
 
+  (*class dictionaries*)
+  type class = string list
+  val basicC: class
+  type 'a dict = (class * 'a) Ord_List.T
+  val dict_map_default: class * 'a -> ('a -> 'a) -> 'a dict -> 'a dict
+  val dict_update: class * 'a -> 'a dict -> 'a dict
+  val dict_merge: ('a * 'a -> 'a) -> 'a dict * 'a dict -> 'a dict
+  val dict_lookup: 'a dict -> class -> 'a list
+
   (*types*)
   val dest_funT: int -> typ -> typ list * typ
 
@@ -57,6 +66,32 @@
   in rep end
 
 
+(* class dictionaries *)
+
+type class = string list
+
+val basicC = []
+
+type 'a dict = (class * 'a) Ord_List.T
+
+fun class_ord ((cs1, _), (cs2, _)) = list_ord fast_string_ord (cs1, cs2)
+
+fun dict_insert (cs, x) d =
+  if AList.defined (op =) d cs then d
+  else Ord_List.insert class_ord (cs, x) d
+
+fun dict_map_default (cs, x) f =
+  dict_insert (cs, x) #> AList.map_entry (op =) cs f
+
+fun dict_update (e as (_, x)) = dict_map_default e (K x)
+
+fun dict_merge val_merge = sort class_ord o AList.join (op =) (K val_merge)
+
+fun dict_lookup d cs =
+  let fun match (cs', x) = if is_prefix (op =) cs' cs then SOME x else NONE
+  in map_filter match d end
+
+
 (* types *)
 
 val dest_funT =
--- a/src/HOL/Tools/SMT/smtlib_interface.ML	Wed Dec 15 08:39:24 2010 +0100
+++ b/src/HOL/Tools/SMT/smtlib_interface.ML	Wed Dec 15 08:39:24 2010 +0100
@@ -6,7 +6,7 @@
 
 signature SMTLIB_INTERFACE =
 sig
-  val smtlibC: SMT_Config.class
+  val smtlibC: SMT_Utils.class
   val add_logic: int * (term list -> string option) -> Context.generic ->
     Context.generic
   val interface: SMT_Solver.interface
--- a/src/HOL/Tools/SMT/z3_interface.ML	Wed Dec 15 08:39:24 2010 +0100
+++ b/src/HOL/Tools/SMT/z3_interface.ML	Wed Dec 15 08:39:24 2010 +0100
@@ -6,7 +6,7 @@
 
 signature Z3_INTERFACE =
 sig
-  val smtlib_z3C: SMT_Config.class
+  val smtlib_z3C: SMT_Utils.class
   val interface: SMT_Solver.interface
   val setup: theory -> theory