src/HOL/Tools/SMT/smt_builtin.ML
changeset 58061 3d060f43accb
parent 57230 ec5ff6bb2a92
child 66298 5ff9fe3fee66
equal deleted inserted replaced
58060:835b5443b978 58061:3d060f43accb
       
     1 (*  Title:      HOL/Tools/SMT/smt_builtin.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3 
       
     4 Tables of types and terms directly supported by SMT solvers.
       
     5 *)
       
     6 
       
     7 signature SMT_BUILTIN =
       
     8 sig
       
     9   (*for experiments*)
       
    10   val filter_builtins: (typ -> bool) -> Proof.context -> Proof.context
       
    11 
       
    12   (*built-in types*)
       
    13   val add_builtin_typ: SMT_Util.class ->
       
    14     typ * (typ -> string option) * (typ -> int -> string option) -> Context.generic ->
       
    15     Context.generic
       
    16   val add_builtin_typ_ext: typ * (typ -> bool) -> Context.generic ->
       
    17     Context.generic
       
    18   val dest_builtin_typ: Proof.context -> typ -> string option
       
    19   val is_builtin_typ_ext: Proof.context -> typ -> bool
       
    20 
       
    21   (*built-in numbers*)
       
    22   val dest_builtin_num: Proof.context -> term -> (string * typ) option
       
    23   val is_builtin_num: Proof.context -> term -> bool
       
    24   val is_builtin_num_ext: Proof.context -> term -> bool
       
    25 
       
    26   (*built-in functions*)
       
    27   type 'a bfun = Proof.context -> typ -> term list -> 'a
       
    28   type bfunr = string * int * term list * (term list -> term)
       
    29   val add_builtin_fun: SMT_Util.class -> (string * typ) * bfunr option bfun -> Context.generic ->
       
    30     Context.generic
       
    31   val add_builtin_fun': SMT_Util.class -> term * string -> Context.generic -> Context.generic
       
    32   val add_builtin_fun_ext: (string * typ) * term list bfun -> Context.generic -> Context.generic
       
    33   val add_builtin_fun_ext': string * typ -> Context.generic -> Context.generic
       
    34   val add_builtin_fun_ext'': string -> Context.generic -> Context.generic
       
    35   val dest_builtin_fun: Proof.context -> string * typ -> term list -> bfunr option
       
    36   val dest_builtin_eq: Proof.context -> term -> term -> bfunr option
       
    37   val dest_builtin_pred: Proof.context -> string * typ -> term list -> bfunr option
       
    38   val dest_builtin_conn: Proof.context -> string * typ -> term list -> bfunr option
       
    39   val dest_builtin: Proof.context -> string * typ -> term list -> bfunr option
       
    40   val dest_builtin_ext: Proof.context -> string * typ -> term list -> term list option
       
    41   val is_builtin_fun: Proof.context -> string * typ -> term list -> bool
       
    42   val is_builtin_fun_ext: Proof.context -> string * typ -> term list -> bool
       
    43 end;
       
    44 
       
    45 structure SMT_Builtin: SMT_BUILTIN =
       
    46 struct
       
    47 
       
    48 
       
    49 (* built-in tables *)
       
    50 
       
    51 datatype ('a, 'b) kind = Ext of 'a | Int of 'b
       
    52 
       
    53 type ('a, 'b) ttab = ((typ * ('a, 'b) kind) Ord_List.T) SMT_Util.dict
       
    54 
       
    55 fun typ_ord ((T, _), (U, _)) =
       
    56   let
       
    57     fun tord (TVar _, Type _) = GREATER
       
    58       | tord (Type _, TVar _) = LESS
       
    59       | tord (Type (n, Ts), Type (m, Us)) =
       
    60           if n = m then list_ord tord (Ts, Us)
       
    61           else Term_Ord.typ_ord (T, U)
       
    62       | tord TU = Term_Ord.typ_ord TU
       
    63   in tord (T, U) end
       
    64 
       
    65 fun insert_ttab cs T f =
       
    66   SMT_Util.dict_map_default (cs, [])
       
    67     (Ord_List.insert typ_ord (perhaps (try Logic.varifyT_global) T, f))
       
    68 
       
    69 fun merge_ttab ttabp = SMT_Util.dict_merge (Ord_List.merge typ_ord) ttabp
       
    70 
       
    71 fun lookup_ttab ctxt ttab T =
       
    72   let fun match (U, _) = Sign.typ_instance (Proof_Context.theory_of ctxt) (T, U)
       
    73   in
       
    74     get_first (find_first match) (SMT_Util.dict_lookup ttab (SMT_Config.solver_class_of ctxt))
       
    75   end
       
    76 
       
    77 type ('a, 'b) btab = ('a, 'b) ttab Symtab.table
       
    78 
       
    79 fun insert_btab cs n T f =
       
    80   Symtab.map_default (n, []) (insert_ttab cs T f)
       
    81 
       
    82 fun merge_btab btabp = Symtab.join (K merge_ttab) btabp
       
    83 
       
    84 fun lookup_btab ctxt btab (n, T) =
       
    85   (case Symtab.lookup btab n of
       
    86     NONE => NONE
       
    87   | SOME ttab => lookup_ttab ctxt ttab T)
       
    88 
       
    89 type 'a bfun = Proof.context -> typ -> term list -> 'a
       
    90 
       
    91 type bfunr = string * int * term list * (term list -> term)
       
    92 
       
    93 structure Builtins = Generic_Data
       
    94 (
       
    95   type T =
       
    96     (typ -> bool, (typ -> string option) * (typ -> int -> string option)) ttab *
       
    97     (term list bfun, bfunr option bfun) btab
       
    98   val empty = ([], Symtab.empty)
       
    99   val extend = I
       
   100   fun merge ((t1, b1), (t2, b2)) = (merge_ttab (t1, t2), merge_btab (b1, b2))
       
   101 )
       
   102 
       
   103 fun filter_ttab keep_T = map (apsnd (filter (keep_T o fst)))
       
   104 
       
   105 fun filter_builtins keep_T =
       
   106   Context.proof_map (Builtins.map (fn (ttab, btab) =>
       
   107     (filter_ttab keep_T ttab, Symtab.map (K (filter_ttab keep_T)) btab)))
       
   108 
       
   109 
       
   110 (* built-in types *)
       
   111 
       
   112 fun add_builtin_typ cs (T, f, g) =
       
   113   Builtins.map (apfst (insert_ttab cs T (Int (f, g))))
       
   114 
       
   115 fun add_builtin_typ_ext (T, f) = Builtins.map (apfst (insert_ttab SMT_Util.basicC T (Ext f)))
       
   116 
       
   117 fun lookup_builtin_typ ctxt =
       
   118   lookup_ttab ctxt (fst (Builtins.get (Context.Proof ctxt)))
       
   119 
       
   120 fun dest_builtin_typ ctxt T =
       
   121   (case lookup_builtin_typ ctxt T of
       
   122     SOME (_, Int (f, _)) => f T
       
   123   | _ => NONE)
       
   124 
       
   125 fun is_builtin_typ_ext ctxt T =
       
   126   (case lookup_builtin_typ ctxt T of
       
   127     SOME (_, Int (f, _)) => is_some (f T)
       
   128   | SOME (_, Ext f) => f T
       
   129   | NONE => false)
       
   130 
       
   131 
       
   132 (* built-in numbers *)
       
   133 
       
   134 fun dest_builtin_num ctxt t =
       
   135   (case try HOLogic.dest_number t of
       
   136     NONE => NONE
       
   137   | SOME (T, i) =>
       
   138       if i < 0 then NONE else
       
   139         (case lookup_builtin_typ ctxt T of
       
   140           SOME (_, Int (_, g)) => g T i |> Option.map (rpair T)
       
   141         | _ => NONE))
       
   142 
       
   143 val is_builtin_num = is_some oo dest_builtin_num
       
   144 
       
   145 fun is_builtin_num_ext ctxt t =
       
   146   (case try HOLogic.dest_number t of
       
   147     NONE => false
       
   148   | SOME (T, _) => is_builtin_typ_ext ctxt T)
       
   149 
       
   150 
       
   151 (* built-in functions *)
       
   152 
       
   153 fun add_builtin_fun cs ((n, T), f) =
       
   154   Builtins.map (apsnd (insert_btab cs n T (Int f)))
       
   155 
       
   156 fun add_builtin_fun' cs (t, n) =
       
   157   let
       
   158     val c as (m, T) = Term.dest_Const t
       
   159     fun app U ts = Term.list_comb (Const (m, U), ts)
       
   160     fun bfun _ U ts = SOME (n, length (Term.binder_types T), ts, app U)
       
   161   in add_builtin_fun cs (c, bfun) end
       
   162 
       
   163 fun add_builtin_fun_ext ((n, T), f) =
       
   164   Builtins.map (apsnd (insert_btab SMT_Util.basicC n T (Ext f)))
       
   165 
       
   166 fun add_builtin_fun_ext' c = add_builtin_fun_ext (c, fn _ => fn _ => I)
       
   167 
       
   168 fun add_builtin_fun_ext'' n context =
       
   169   let val thy = Context.theory_of context
       
   170   in add_builtin_fun_ext' (n, Sign.the_const_type thy n) context end
       
   171 
       
   172 fun lookup_builtin_fun ctxt =
       
   173   lookup_btab ctxt (snd (Builtins.get (Context.Proof ctxt)))
       
   174 
       
   175 fun dest_builtin_fun ctxt (c as (_, T)) ts =
       
   176   (case lookup_builtin_fun ctxt c of
       
   177     SOME (_, Int f) => f ctxt T ts
       
   178   | _ => NONE)
       
   179 
       
   180 fun dest_builtin_eq ctxt t u =
       
   181   let
       
   182     val aT = TFree (Name.aT, @{sort type})
       
   183     val c = (@{const_name HOL.eq}, aT --> aT --> @{typ bool})
       
   184     fun mk ts = Term.list_comb (HOLogic.eq_const (Term.fastype_of (hd ts)), ts)
       
   185   in
       
   186     dest_builtin_fun ctxt c []
       
   187     |> Option.map (fn (n, i, _, _) => (n, i, [t, u], mk))
       
   188   end
       
   189 
       
   190 fun special_builtin_fun pred ctxt (c as (_, T)) ts =
       
   191   if pred (Term.body_type T, Term.binder_types T) then
       
   192     dest_builtin_fun ctxt c ts
       
   193   else NONE
       
   194 
       
   195 fun dest_builtin_pred ctxt = special_builtin_fun (equal @{typ bool} o fst) ctxt
       
   196 
       
   197 fun dest_builtin_conn ctxt =
       
   198   special_builtin_fun (forall (equal @{typ bool}) o (op ::)) ctxt
       
   199 
       
   200 fun dest_builtin ctxt c ts =
       
   201   let val t = Term.list_comb (Const c, ts)
       
   202   in
       
   203     (case dest_builtin_num ctxt t of
       
   204       SOME (n, _) => SOME (n, 0, [], K t)
       
   205     | NONE => dest_builtin_fun ctxt c ts)
       
   206   end
       
   207 
       
   208 fun dest_builtin_fun_ext ctxt (c as (_, T)) ts =
       
   209   (case lookup_builtin_fun ctxt c of
       
   210     SOME (_, Int f) => f ctxt T ts |> Option.map (fn (_, _, us, _) => us)
       
   211   | SOME (_, Ext f) => SOME (f ctxt T ts)
       
   212   | NONE => NONE)
       
   213 
       
   214 fun dest_builtin_ext ctxt c ts =
       
   215   if is_builtin_num_ext ctxt (Term.list_comb (Const c, ts)) then SOME []
       
   216   else dest_builtin_fun_ext ctxt c ts
       
   217 
       
   218 fun is_builtin_fun ctxt c ts = is_some (dest_builtin_fun ctxt c ts)
       
   219 
       
   220 fun is_builtin_fun_ext ctxt c ts = is_some (dest_builtin_fun_ext ctxt c ts)
       
   221 
       
   222 end;