src/HOL/Tools/SMT/smt_translate.ML
author boehmes
Mon Dec 06 15:38:02 2010 +0100 (2010-12-06)
changeset 41057 8dbc951a291c
parent 40697 c3979dd80a50
child 41059 d2b1fc1b8e19
permissions -rw-r--r--
tuned
     1 (*  Title:      HOL/Tools/SMT/smt_translate.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Translate theorems into an SMT intermediate format and serialize them.
     5 *)
     6 
signature SMT_TRANSLATE =
sig
  (* intermediate term structure *)
  (*kind of a quantifier*)
  datatype squant = SForall | SExists
  (*quantifier patterns: SPat for terms marked by pat, SNoPat for terms
    marked by nopat*)
  datatype 'a spattern = SPat of 'a list | SNoPat of 'a list
  (*SVar: bound variable given by its de Bruijn index;
    SApp: named symbol applied to a list of arguments;
    SLet: let binding (sort name of the bound variable, bound term, body);
    SQua: quantifier (kind, sort names of the bound variables, patterns,
      optional weight, body)*)
  datatype sterm =
    SVar of int |
    SApp of string * sterm list |
    SLet of string * sterm * sterm |
    SQua of squant * string list * sterm spattern list * int option * sterm

  (* configuration options *)
  (*prefixes of the generated sort and function names*)
  type prefixes = {sort_prefix: string, func_prefix: string}
  (*computes header lines from the context and the propositions to be
    translated*)
  type header = Proof.context -> term list -> string list
  (*information used to separate formulas from terms*)
  type strict = {
    is_builtin_conn: string * typ -> bool,
    is_builtin_pred: Proof.context -> string * typ -> bool,
    is_builtin_distinct: bool}
  (*solver built-in types, numerals and functions; has_datatypes enables
    the translation of datatypes into datatype declarations*)
  type builtins = {
    builtin_typ: Proof.context -> typ -> string option,
    builtin_num: Proof.context -> typ -> int -> string option,
    builtin_fun: Proof.context -> string * typ -> term list ->
      (string * term list) option,
    has_datatypes: bool }
  (*translated signature: header lines, uninterpreted sorts, groups of
    datatype declarations and function declarations*)
  type sign = {
    header: string list,
    sorts: string list,
    dtyps: (string * (string * (string * string) list) list) list list,
    funcs: (string * (string list * string)) list }
  (*complete configuration; serialize turns comments, the signature and the
    translated propositions into the solver input*)
  type config = {
    prefixes: prefixes,
    header: header,
    strict: strict option,
    builtins: builtins,
    serialize: string list -> sign -> sterm list -> string }
  (*data for proof reconstruction: sort and function names mapped back to
    Isabelle types and terms, unfold rules and the indexed assumptions*)
  type recon = {
    typs: typ Symtab.table,
    terms: term Symtab.table,
    unfolds: thm list,
    assms: (int * thm) list }

  (*translate config ctxt comments assumptions yields the serialized problem
    and the reconstruction data*)
  val translate: config -> Proof.context -> string list -> (int * thm) list ->
    string * recon
end
    51 
    52 structure SMT_Translate: SMT_TRANSLATE =
    53 struct
    54 
(*short local name for the SMT utilities (used for dest_funT below)*)
structure U = SMT_Utils
    56 
    57 
(* intermediate term structure *)

(*kind of a quantifier*)
datatype squant = SForall | SExists

(*quantifier patterns: SPat collects terms marked by pat, SNoPat collects
  terms marked by nopat*)
datatype 'a spattern = SPat of 'a list | SNoPat of 'a list

(*SVar: bound variable given by its de Bruijn index;
  SApp: named symbol applied to arguments (constants have none);
  SLet: let binding with the sort name of the bound variable, the bound
    term and the body;
  SQua: quantifier with the sort names of its bound variables, its
    patterns, an optional weight and its body*)
datatype sterm =
  SVar of int |
  SApp of string * sterm list |
  SLet of string * sterm * sterm |
  SQua of squant * string list * sterm spattern list * int option * sterm
    69 
    70 
    71 
(* configuration options *)

(*prefixes of the generated sort and function names*)
type prefixes = {sort_prefix: string, func_prefix: string}

(*computes header lines from the context and the propositions to be
  translated*)
type header = Proof.context -> term list -> string list

(*information used by strictify to separate formulas from terms*)
type strict = {
  is_builtin_conn: string * typ -> bool,
  is_builtin_pred: Proof.context -> string * typ -> bool,
  is_builtin_distinct: bool}

(*solver built-in types, numerals and functions; has_datatypes enables the
  translation of datatypes into datatype declarations*)
type builtins = {
  builtin_typ: Proof.context -> typ -> string option,
  builtin_num: Proof.context -> typ -> int -> string option,
  builtin_fun: Proof.context -> string * typ -> term list ->
    (string * term list) option,
  has_datatypes: bool }

(*translated signature: header lines; uninterpreted sort names; groups of
  datatype declarations, each a list of (sort name, constructors), where a
  constructor is (name, [(selector name, selector result sort)]); and
  function declarations (name, (argument sorts, result sort))*)
type sign = {
  header: string list,
  sorts: string list,
  dtyps: (string * (string * (string * string) list) list) list list,
  funcs: (string * (string list * string)) list }

(*complete configuration; serialize receives the comments, the signature
  and the translated propositions and produces the solver input*)
type config = {
  prefixes: prefixes,
  header: header,
  strict: strict option,
  builtins: builtins,
  serialize: string list -> sign -> sterm list -> string }

(*data for proof reconstruction: sort names mapped to types, function names
  mapped to terms (with prop reverted to bool), the unfold rules and the
  indexed assumptions*)
type recon = {
  typs: typ Symtab.table,
  terms: term Symtab.table,
  unfolds: thm list,
  assms: (int * thm) list }
   108 
   109 
   110 
   111 (* utility functions *)
   112 
   113 val quantifier = (fn
   114     @{const_name All} => SOME SForall
   115   | @{const_name Ex} => SOME SExists
   116   | _ => NONE)
   117 
   118 fun group_quant qname Ts (t as Const (q, _) $ Abs (_, T, u)) =
   119       if q = qname then group_quant qname (T :: Ts) u else (Ts, t)
   120   | group_quant _ Ts t = (Ts, t)
   121 
   122 fun dest_weight (@{const SMT.weight} $ w $ t) =
   123       (SOME (snd (HOLogic.dest_number w)), t)
   124   | dest_weight t = (NONE, t)
   125 
   126 fun dest_pat (Const (@{const_name pat}, _) $ t) = (t, true)
   127   | dest_pat (Const (@{const_name nopat}, _) $ t) = (t, false)
   128   | dest_pat t = raise TERM ("dest_pat", [t])
   129 
   130 fun dest_pats [] = I
   131   | dest_pats ts =
   132       (case map dest_pat ts |> split_list ||> distinct (op =) of
   133         (ps, [true]) => cons (SPat ps)
   134       | (ps, [false]) => cons (SNoPat ps)
   135       | _ => raise TERM ("dest_pats", ts))
   136 
   137 fun dest_trigger (@{const trigger} $ tl $ t) =
   138       (rev (fold (dest_pats o HOLogic.dest_list) (HOLogic.dest_list tl) []), t)
   139   | dest_trigger t = ([], t)
   140 
   141 fun dest_quant qn T t = quantifier qn |> Option.map (fn q =>
   142   let
   143     val (Ts, u) = group_quant qn [T] t
   144     val (ps, p) = dest_trigger u
   145     val (w, b) = dest_weight p
   146   in (q, rev Ts, ps, w, b) end)
   147 
   148 fun fold_map_pat f (SPat ts) = fold_map f ts #>> SPat
   149   | fold_map_pat f (SNoPat ts) = fold_map f ts #>> SNoPat
   150 
   151 fun prop_of thm = HOLogic.dest_Trueprop (Thm.prop_of thm)
   152 
   153 
   154 
   155 (* enforce a strict separation between formulas and terms *)
   156 
(*unfolds term_eq to HOL equality; the first of the unfold rules*)
val term_eq_rewr = @{lemma "term_eq x y == x = y" by (simp add: term_eq_def)}

(*True and False are different terms, phrased with term_eq; strictify adds
  it to the translated propositions with index ~1*)
val term_bool = @{lemma "~(term_eq True False)" by (simp add: term_eq_def)}
(*the same fact with term_eq unfolded; the variant kept for reconstruction*)
val term_bool' = Simplifier.rewrite_rule [term_eq_rewr] term_bool
   161 
   162 
   163 val needs_rewrite = Thm.prop_of #> Term.exists_subterm (fn
   164     Const (@{const_name Let}, _) => true
   165   | @{const HOL.eq (bool)} $ _ $ @{const True} => true
   166   | Const (@{const_name If}, _) $ _ $ @{const True} $ @{const False} => true
   167   | _ => false)
   168 
(*rules used by normalize: unfold Let and remove the redundant wrappers
  "P = True" and "if P then True else False"*)
val rewrite_rules = [
  Let_def,
  @{lemma "P = True == P" by (rule eq_reflection) simp},
  @{lemma "if P then True else False == P" by (rule eq_reflection) simp}]
   173 
   174 fun rewrite ctxt = Simplifier.full_rewrite
   175   (Simplifier.context ctxt empty_ss addsimps rewrite_rules)
   176 
   177 fun normalize ctxt thm =
   178   if needs_rewrite thm then Conv.fconv_rule (rewrite ctxt) thm else thm
   179 
   180 val unfold_rules = term_eq_rewr :: rewrite_rules
   181 
   182 
   183 val revert_types =
   184   let
   185     fun revert @{typ prop} = @{typ bool}
   186       | revert (Type (n, Ts)) = Type (n, map revert Ts)
   187       | revert T = T
   188   in Term.map_types revert end
   189 
   190 
   191 fun strictify {is_builtin_conn, is_builtin_pred, is_builtin_distinct} ctxt =
   192   let
   193     fun is_builtin_conn' (@{const_name True}, _) = false
   194       | is_builtin_conn' (@{const_name False}, _) = false
   195       | is_builtin_conn' c = is_builtin_conn c
   196 
   197     fun is_builtin_pred' _ (@{const_name distinct}, _) [t] =
   198           is_builtin_distinct andalso can HOLogic.dest_list t
   199       | is_builtin_pred' ctxt c _ = is_builtin_pred ctxt c
   200 
   201     val propT = @{typ prop} and boolT = @{typ bool}
   202     val as_propT = (fn @{typ bool} => propT | T => T)
   203     fun mapTs f g = Term.strip_type #> (fn (Ts, T) => map f Ts ---> g T)
   204     fun conn (n, T) = (n, mapTs as_propT as_propT T)
   205     fun pred (n, T) = (n, mapTs I as_propT T)
   206 
   207     val term_eq = @{const HOL.eq (bool)} |> Term.dest_Const |> pred
   208     fun as_term t = Const term_eq $ t $ @{const True}
   209 
   210     val if_term = Const (@{const_name If}, [propT, boolT, boolT] ---> boolT)
   211     fun wrap_in_if t = if_term $ t $ @{const True} $ @{const False}
   212 
   213     fun in_list T f t = HOLogic.mk_list T (map f (HOLogic.dest_list t))
   214 
   215     fun in_term t =
   216       (case Term.strip_comb t of
   217         (c as Const (@{const_name If}, _), [t1, t2, t3]) =>
   218           c $ in_form t1 $ in_term t2 $ in_term t3
   219       | (h as Const c, ts) =>
   220           if is_builtin_conn' (conn c) orelse is_builtin_pred' ctxt (pred c) ts
   221           then wrap_in_if (in_form t)
   222           else Term.list_comb (h, map in_term ts)
   223       | (h as Free _, ts) => Term.list_comb (h, map in_term ts)
   224       | _ => t)
   225 
   226     and in_weight ((c as @{const SMT.weight}) $ w $ t) = c $ w $ in_form t
   227       | in_weight t = in_form t 
   228 
   229     and in_pat ((c as Const (@{const_name pat}, _)) $ t) = c $ in_term t
   230       | in_pat ((c as Const (@{const_name nopat}, _)) $ t) = c $ in_term t
   231       | in_pat t = raise TERM ("in_pat", [t])
   232 
   233     and in_pats ps =
   234       in_list @{typ "pattern list"} (in_list @{typ pattern} in_pat) ps
   235 
   236     and in_trig ((c as @{const trigger}) $ p $ t) = c $ in_pats p $ in_weight t
   237       | in_trig t = in_weight t
   238 
   239     and in_form t =
   240       (case Term.strip_comb t of
   241         (q as Const (qn, _), [Abs (n, T, t')]) =>
   242           if is_some (quantifier qn) then q $ Abs (n, T, in_trig t')
   243           else as_term (in_term t)
   244       | (Const (c as (@{const_name distinct}, T)), [t']) =>
   245           if is_builtin_distinct andalso can HOLogic.dest_list t' then
   246             Const (pred c) $ in_list T in_term t'
   247           else as_term (in_term t)
   248       | (Const c, ts) =>
   249           if is_builtin_conn (conn c)
   250           then Term.list_comb (Const (conn c), map in_form ts)
   251           else if is_builtin_pred ctxt (pred c)
   252           then Term.list_comb (Const (pred c), map in_term ts)
   253           else as_term (in_term t)
   254       | _ => as_term (in_term t))
   255   in
   256     map (apsnd (normalize ctxt)) #> (fn irules =>
   257     ((unfold_rules, (~1, term_bool') :: irules),
   258      map (in_form o prop_of o snd) ((~1, term_bool) :: irules)))
   259   end
   260 
   261 
   262 
   263 (* translation from Isabelle terms into SMT intermediate terms *)
   264 
(*translation context: (next sort index, types translated so far, datatype
  declaration groups, next function index, terms translated so far); both
  indices start at 1*)
val empty_context = (1, Typtab.empty, [], 1, Termtab.empty)
   266 
   267 fun make_sign header (_, typs, dtyps, _, terms) = {
   268   header = header,
   269   sorts = Typtab.fold (fn (_, (n, true)) => cons n | _ => I) typs [],
   270   funcs = Termtab.fold (fn (_, (n, SOME ss)) => cons (n,ss) | _ => I) terms [],
   271   dtyps = rev dtyps }
   272 
   273 fun make_recon (unfolds, assms) (_, typs, _, _, terms) = {
   274   typs = Symtab.make (map (apfst fst o swap) (Typtab.dest typs)),
   275     (*FIXME: don't drop the datatype information! *)
   276   terms = Symtab.make (map (fn (t, (n, _)) => (n, t)) (Termtab.dest terms)),
   277   unfolds = unfolds,
   278   assms = assms }
   279 
   280 fun string_of_index pre i = pre ^ string_of_int i
   281 
   282 fun new_typ sort_prefix proper T (Tidx, typs, dtyps, idx, terms) =
   283   let val s = string_of_index sort_prefix Tidx
   284   in (s, (Tidx+1, Typtab.update (T, (s, proper)) typs, dtyps, idx, terms)) end
   285 
   286 fun lookup_typ (_, typs, _, _, _) = Typtab.lookup typs
   287 
   288 fun fresh_typ T f cx =
   289   (case lookup_typ cx T of
   290     SOME (s, _) => (s, cx)
   291   | NONE => f T cx)
   292 
   293 fun new_fun func_prefix t ss (Tidx, typs, dtyps, idx, terms) =
   294   let
   295     val f = string_of_index func_prefix idx
   296     val terms' = Termtab.update (revert_types t, (f, ss)) terms
   297   in (f, (Tidx, typs, dtyps, idx+1, terms')) end
   298 
   299 fun fresh_fun func_prefix t ss (cx as (_, _, _, _, terms)) =
   300   (case Termtab.lookup terms t of
   301     SOME (f, _) => (f, cx)
   302   | NONE => new_fun func_prefix t ss cx)
   303 
   304 fun mk_type (_, Tfs) (d as Datatype.DtTFree _) = the (AList.lookup (op =) Tfs d)
   305   | mk_type Ts (Datatype.DtType (n, ds)) = Type (n, map (mk_type Ts) ds)
   306   | mk_type (Tds, _) (Datatype.DtRec i) = nth Tds i
   307 
   308 fun mk_selector ctxt Ts T n (i, d) =
   309   (case Datatype_Selectors.lookup_selector ctxt (n, i+1) of
   310     NONE => raise Fail ("missing selector for datatype constructor " ^ quote n)
   311   | SOME m => mk_type Ts d |> (fn U => (Const (m, T --> U), U)))
   312 
   313 fun mk_constructor ctxt Ts T (n, args) =
   314   let val (sels, Us) = split_list (map_index (mk_selector ctxt Ts T n) args)
   315   in (Const (n, Us ---> T), sels) end
   316 
(*Datatype information for type constructor n applied to Ts.  Returns NONE
  for bool and nat and for types without datatype information.  Otherwise
  there is one entry per datatype of the group described by descr: its
  instantiated type paired with its constructors, each given as the
  constructor constant and its selector constants.*)
fun lookup_datatype ctxt n Ts =
  if member (op =) [@{type_name bool}, @{type_name nat}] n then NONE
  else
    Datatype.get_info (ProofContext.theory_of ctxt) n
    |> Option.map (fn {descr, ...} =>
         let
           (*all datatypes of the group instantiated with Ts, ordered by
             their index; nth Tds i presumably relies on the indices being
             0, 1, ... -- confirm*)
           val Tds = map (fn (_, (tn, _, _)) => Type (tn, Ts))
             (sort (int_ord o pairself fst) descr)
           (*type variables of the first entry of descr paired with Ts;
             assumes all datatypes of the group share these parameters*)
           val Tfs = (case hd descr of (_, (_, tfs, _)) => tfs ~~ Ts)
         in
           descr |> map (fn (i, (_, _, cs)) =>
             (nth Tds i, map (mk_constructor ctxt (Tds, Tfs) (nth Tds i)) cs))
         end)
   330 
   331 fun relaxed irules = (([], irules), map (prop_of o snd) irules)
   332 
   333 fun with_context header f (ths, ts) =
   334   let val (us, context) = fold_map f ts empty_context
   335   in ((make_sign (header ts) context, us), make_recon ths context) end
   336 
   337 
   338 fun translate {prefixes, strict, header, builtins, serialize} ctxt comments =
   339   let
   340     val {sort_prefix, func_prefix} = prefixes
   341     val {builtin_typ, builtin_num, builtin_fun, has_datatypes} = builtins
   342 
   343     fun transT (T as TFree _) = fresh_typ T (new_typ sort_prefix true)
   344       | transT (T as TVar _) = (fn _ => raise TYPE ("smt_translate", [T], []))
   345       | transT (T as Type (n, Ts)) =
   346           (case builtin_typ ctxt T of
   347             SOME n => pair n
   348           | NONE => fresh_typ T (fn _ => fn cx =>
   349               if not has_datatypes then new_typ sort_prefix true T cx
   350               else
   351                 (case lookup_datatype ctxt n Ts of
   352                   NONE => new_typ sort_prefix true T cx
   353                 | SOME dts =>
   354                     let val cx' = new_dtyps dts cx 
   355                     in (fst (the (lookup_typ cx' T)), cx') end)))
   356 
   357     and new_dtyps dts cx =
   358       let
   359         fun new_decl i t =
   360           let val (Ts, T) = U.dest_funT i (Term.fastype_of t)
   361           in
   362             fold_map transT Ts ##>> transT T ##>>
   363             new_fun func_prefix t NONE #>> swap
   364           end
   365         fun new_dtyp_decl (con, sels) =
   366           new_decl (length sels) con ##>> fold_map (new_decl 1) sels #>>
   367           (fn ((con', _), sels') => (con', map (apsnd snd) sels'))
   368       in
   369         cx
   370         |> fold_map (new_typ sort_prefix false o fst) dts
   371         ||>> fold_map (fold_map new_dtyp_decl o snd) dts
   372         |-> (fn (ss, decls) => fn (Tidx, typs, dtyps, idx, terms) =>
   373               (Tidx, typs, (ss ~~ decls) :: dtyps, idx, terms))
   374       end
   375 
   376     fun app n ts = SApp (n, ts)
   377 
   378     fun trans t =
   379       (case Term.strip_comb t of
   380         (Const (qn, _), [Abs (_, T, t1)]) =>
   381           (case dest_quant qn T t1 of
   382             SOME (q, Ts, ps, w, b) =>
   383               fold_map transT Ts ##>> fold_map (fold_map_pat trans) ps ##>>
   384               trans b #>> (fn ((Ts', ps'), b') => SQua (q, Ts', ps', w, b'))
   385           | NONE => raise TERM ("intermediate", [t]))
   386       | (Const (@{const_name Let}, _), [t1, Abs (_, T, t2)]) =>
   387           transT T ##>> trans t1 ##>> trans t2 #>>
   388           (fn ((U, u1), u2) => SLet (U, u1, u2))
   389       | (h as Const (c as (@{const_name distinct}, T)), ts) =>
   390           (case builtin_fun ctxt c ts of
   391             SOME (n, ts) => fold_map trans ts #>> app n
   392           | NONE => transs h T ts)
   393       | (h as Const (c as (_, T)), ts) =>
   394           (case try HOLogic.dest_number t of
   395             SOME (T, i) =>
   396               (case builtin_num ctxt T i of
   397                 SOME n => pair (SApp (n, []))
   398               | NONE => transs t T [])
   399           | NONE =>
   400               (case builtin_fun ctxt c ts of
   401                 SOME (n, ts') => fold_map trans ts' #>> app n
   402               | NONE => transs h T ts))
   403       | (h as Free (_, T), ts) => transs h T ts
   404       | (Bound i, []) => pair (SVar i)
   405       | (Abs (_, _, t1 $ Bound 0), []) =>
   406         if not (loose_bvar1 (t1, 0)) then trans t1 (* eta-reduce on the fly *)
   407         else raise TERM ("smt_translate", [t])
   408       | _ => raise TERM ("smt_translate", [t]))
   409 
   410     and transs t T ts =
   411       let val (Us, U) = U.dest_funT (length ts) T
   412       in
   413         fold_map transT Us ##>> transT U #-> (fn Up =>
   414         fresh_fun func_prefix t (SOME Up) ##>> fold_map trans ts #>> SApp)
   415       end
   416   in
   417     (case strict of SOME strct => strictify strct ctxt | NONE => relaxed) #>
   418     with_context (header ctxt) trans #>> uncurry (serialize comments)
   419   end
   420 
   421 end