src/HOL/Tools/SMT/smt_translate.ML
author boehmes
Wed May 12 23:54:04 2010 +0200 (2010-05-12)
changeset 36899 bcd6fce5bf06
parent 36898 8e55aa1306c5
child 37124 fe22fc54b876
permissions -rw-r--r--
layered SMT setup, adapted SMT clients, added further tests, made Z3 proof abstraction configurable
     1 (*  Title:      HOL/Tools/SMT/smt_translate.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Translate theorems into an SMT intermediate format and serialize them.
     5 *)
     6 
     7 signature SMT_TRANSLATE =
     8 sig
     9   (* intermediate term structure *)
    10   datatype squant = SForall | SExists
    11   datatype 'a spattern = SPat of 'a list | SNoPat of 'a list
    12   datatype sterm =
    13     SVar of int |
    14     SApp of string * sterm list |
    15     SLet of string * sterm * sterm |
    16     SQua of squant * string list * sterm spattern list * sterm
    17 
    18   (* configuration options *)
    19   type prefixes = {sort_prefix: string, func_prefix: string}
    20   type header = Proof.context -> term list -> string list
    21   type strict = {
    22     is_builtin_conn: string * typ -> bool,
    23     is_builtin_pred: Proof.context -> string * typ -> bool,
    24     is_builtin_distinct: bool}
    25   type builtins = {
    26     builtin_typ: Proof.context -> typ -> string option,
    27     builtin_num: Proof.context -> typ -> int -> string option,
    28     builtin_fun: Proof.context -> string * typ -> term list ->
    29       (string * term list) option }
    30   type sign = {
    31     header: string list,
    32     sorts: string list,
    33     funcs: (string * (string list * string)) list }
    34   type config = {
    35     prefixes: prefixes,
    36     header: header,
    37     strict: strict option,
    38     builtins: builtins,
    39     serialize: string list -> sign -> sterm list -> string }
    40   type recon = {
    41     typs: typ Symtab.table,
    42     terms: term Symtab.table,
    43     unfolds: thm list,
    44     assms: thm list }
    45 
    46   val translate: config -> Proof.context -> string list -> thm list ->
    47     string * recon
    48 end
    49 
    50 structure SMT_Translate: SMT_TRANSLATE =
    51 struct
    52 
    53 (* intermediate term structure *)
    54 
    55 datatype squant = SForall | SExists
    56 
    57 datatype 'a spattern = SPat of 'a list | SNoPat of 'a list
    58 
    59 datatype sterm =
    60   SVar of int |
    61   SApp of string * sterm list |
    62   SLet of string * sterm * sterm |
    63   SQua of squant * string list * sterm spattern list * sterm
    64 
    65 
    66 
    67 (* configuration options *)
    68 
    69 type prefixes = {sort_prefix: string, func_prefix: string}
    70 
    71 type header = Proof.context -> term list -> string list
    72 
    73 type strict = {
    74   is_builtin_conn: string * typ -> bool,
    75   is_builtin_pred: Proof.context -> string * typ -> bool,
    76   is_builtin_distinct: bool}
    77 
    78 type builtins = {
    79   builtin_typ: Proof.context -> typ -> string option,
    80   builtin_num: Proof.context -> typ -> int -> string option,
    81   builtin_fun: Proof.context -> string * typ -> term list ->
    82     (string * term list) option }
    83 
    84 type sign = {
    85   header: string list,
    86   sorts: string list,
    87   funcs: (string * (string list * string)) list }
    88 
    89 type config = {
    90   prefixes: prefixes,
    91   header: header,
    92   strict: strict option,
    93   builtins: builtins,
    94   serialize: string list -> sign -> sterm list -> string }
    95 
    96 type recon = {
    97   typs: typ Symtab.table,
    98   terms: term Symtab.table,
    99   unfolds: thm list,
   100   assms: thm list }
   101 
   102 
   103 
   104 (* utility functions *)
   105 
   106 val dest_funT =
   107   let
   108     fun dest Ts 0 T = (rev Ts, T)
   109       | dest Ts i (Type ("fun", [T, U])) = dest (T::Ts) (i-1) U
   110       | dest _ _ T = raise TYPE ("dest_funT", [T], [])
   111   in dest [] end
   112 
   113 val quantifier = (fn
   114     @{const_name All} => SOME SForall
   115   | @{const_name Ex} => SOME SExists
   116   | _ => NONE)
   117 
   118 fun group_quant qname Ts (t as Const (q, _) $ Abs (_, T, u)) =
   119       if q = qname then group_quant qname (T :: Ts) u else (Ts, t)
   120   | group_quant _ Ts t = (Ts, t)
   121 
   122 fun dest_pat ts (Const (@{const_name pat}, _) $ t) = SPat (rev (t :: ts))
   123   | dest_pat ts (Const (@{const_name nopat}, _) $ t) = SNoPat (rev (t :: ts))
   124   | dest_pat ts (Const (@{const_name andpat}, _) $ p $ t) = dest_pat (t::ts) p
   125   | dest_pat _ t = raise TERM ("dest_pat", [t])
   126 
   127 fun dest_trigger (@{term trigger} $ tl $ t) =
   128       (map (dest_pat []) (HOLogic.dest_list tl), t)
   129   | dest_trigger t = ([], t)
   130 
   131 fun dest_quant qn T t = quantifier qn |> Option.map (fn q =>
   132   let
   133     val (Ts, u) = group_quant qn [T] t
   134     val (ps, b) = dest_trigger u
   135   in (q, rev Ts, ps, b) end)
   136 
   137 fun fold_map_pat f (SPat ts) = fold_map f ts #>> SPat
   138   | fold_map_pat f (SNoPat ts) = fold_map f ts #>> SNoPat
   139 
   140 fun prop_of thm = HOLogic.dest_Trueprop (Thm.prop_of thm)
   141 
   142 
   143 
   144 (* enforce a strict separation between formulas and terms *)
   145 
   146 val term_eq_rewr = @{lemma "x term_eq y == x = y" by (simp add: term_eq_def)}
   147 
   148 val term_bool = @{lemma "~(True term_eq False)" by (simp add: term_eq_def)}
   149 val term_bool' = Simplifier.rewrite_rule [term_eq_rewr] term_bool
   150 
   151 
   152 val needs_rewrite = Thm.prop_of #> Term.exists_subterm (fn
   153     Const (@{const_name Let}, _) => true
   154   | @{term "op = :: bool => _"} $ _ $ @{term True} => true
   155   | Const (@{const_name If}, _) $ _ $ @{term True} $ @{term False} => true
   156   | _ => false)
   157 
   158 val rewrite_rules = [
   159   Let_def,
   160   @{lemma "P = True == P" by (rule eq_reflection) simp},
   161   @{lemma "if P then True else False == P" by (rule eq_reflection) simp}]
   162 
   163 fun rewrite ctxt = Simplifier.full_rewrite
   164   (Simplifier.context ctxt empty_ss addsimps rewrite_rules)
   165 
   166 fun normalize ctxt thm =
   167   if needs_rewrite thm then Conv.fconv_rule (rewrite ctxt) thm else thm
   168 
   169 val unfold_rules = term_eq_rewr :: rewrite_rules
   170 
   171 
   172 val revert_types =
   173   let
   174     fun revert @{typ prop} = @{typ bool}
   175       | revert (Type (n, Ts)) = Type (n, map revert Ts)
   176       | revert T = T
   177   in Term.map_types revert end
   178 
   179 
   180 fun strictify {is_builtin_conn, is_builtin_pred, is_builtin_distinct} ctxt =
   181   let
   182     fun is_builtin_conn' (@{const_name True}, _) = false
   183       | is_builtin_conn' (@{const_name False}, _) = false
   184       | is_builtin_conn' c = is_builtin_conn c
   185 
   186     val propT = @{typ prop} and boolT = @{typ bool}
   187     val as_propT = (fn @{typ bool} => propT | T => T)
   188     fun mapTs f g = Term.strip_type #> (fn (Ts, T) => map f Ts ---> g T)
   189     fun conn (n, T) = (n, mapTs as_propT as_propT T)
   190     fun pred (n, T) = (n, mapTs I as_propT T)
   191 
   192     val term_eq = @{term "op = :: bool => _"} |> Term.dest_Const |> pred
   193     fun as_term t = Const term_eq $ t $ @{term True}
   194 
   195     val if_term = Const (@{const_name If}, [propT, boolT, boolT] ---> boolT)
   196     fun wrap_in_if t = if_term $ t $ @{term True} $ @{term False}
   197 
   198     fun in_list T f t = HOLogic.mk_list T (map f (HOLogic.dest_list t))
   199 
   200     fun in_term t =
   201       (case Term.strip_comb t of
   202         (c as Const (@{const_name If}, _), [t1, t2, t3]) =>
   203           c $ in_form t1 $ in_term t2 $ in_term t3
   204       | (h as Const c, ts) =>
   205           if is_builtin_conn' (conn c) orelse is_builtin_pred ctxt (pred c)
   206           then wrap_in_if (in_form t)
   207           else Term.list_comb (h, map in_term ts)
   208       | (h as Free _, ts) => Term.list_comb (h, map in_term ts)
   209       | _ => t)
   210 
   211     and in_pat ((c as Const (@{const_name pat}, _)) $ t) = c $ in_term t
   212       | in_pat ((c as Const (@{const_name nopat}, _)) $ t) = c $ in_term t
   213       | in_pat ((c as Const (@{const_name andpat}, _)) $ p $ t) =
   214           c $ in_pat p $ in_term t
   215       | in_pat t = raise TERM ("in_pat", [t])
   216 
   217     and in_pats p = in_list @{typ pattern} in_pat p
   218 
   219     and in_trig ((c as @{term trigger}) $ p $ t) = c $ in_pats p $ in_form t
   220       | in_trig t = in_form t
   221 
   222     and in_form t =
   223       (case Term.strip_comb t of
   224         (q as Const (qn, _), [Abs (n, T, t')]) =>
   225           if is_some (quantifier qn) then q $ Abs (n, T, in_trig t')
   226           else as_term (in_term t)
   227       | (Const (c as (@{const_name distinct}, T)), [t']) =>
   228           if is_builtin_distinct then Const (pred c) $ in_list T in_term t'
   229           else as_term (in_term t)
   230       | (Const c, ts) =>
   231           if is_builtin_conn (conn c)
   232           then Term.list_comb (Const (conn c), map in_form ts)
   233           else if is_builtin_pred ctxt (pred c)
   234           then Term.list_comb (Const (pred c), map in_term ts)
   235           else as_term (in_term t)
   236       | _ => as_term (in_term t))
   237   in
   238     map (normalize ctxt) #> (fn thms => ((unfold_rules, term_bool' :: thms),
   239     map (in_form o prop_of) (term_bool :: thms)))
   240   end
   241 
   242 
   243 
   244 (* translation from Isabelle terms into SMT intermediate terms *)
   245 
   246 val empty_context = (1, Typtab.empty, 1, Termtab.empty)
   247 
   248 fun make_sign header (_, typs, _, terms) = {
   249   header = header,
   250   sorts = Typtab.fold (cons o snd) typs [],
   251   funcs = Termtab.fold (cons o snd) terms [] }
   252 
   253 fun make_recon (unfolds, assms) (_, typs, _, terms) = {
   254   typs = Symtab.make (map swap (Typtab.dest typs)),
   255   terms = Symtab.make (map (fn (t, (n, _)) => (n, t)) (Termtab.dest terms)),
   256   unfolds = unfolds,
   257   assms = assms }
   258 
   259 fun string_of_index pre i = pre ^ string_of_int i
   260 
   261 fun fresh_typ sort_prefix T (cx as (Tidx, typs, idx, terms)) =
   262   (case Typtab.lookup typs T of
   263     SOME s => (s, cx)
   264   | NONE =>
   265       let
   266         val s = string_of_index sort_prefix Tidx
   267         val typs' = Typtab.update (T, s) typs
   268       in (s, (Tidx+1, typs', idx, terms)) end)
   269 
   270 fun fresh_fun func_prefix t ss (cx as (Tidx, typs, idx, terms)) =
   271   (case Termtab.lookup terms t of
   272     SOME (f, _) => (f, cx)
   273   | NONE =>
   274       let
   275         val f = string_of_index func_prefix idx
   276         val terms' = Termtab.update (revert_types t, (f, ss)) terms
   277       in (f, (Tidx, typs, idx+1, terms')) end)
   278 
   279 fun relaxed thms = (([], thms), map prop_of thms)
   280 
   281 fun with_context header f (ths, ts) =
   282   let val (us, context) = fold_map f ts empty_context
   283   in ((make_sign (header ts) context, us), make_recon ths context) end
   284 
   285 
   286 fun translate {prefixes, strict, header, builtins, serialize} ctxt comments =
   287   let
   288     val {sort_prefix, func_prefix} = prefixes
   289     val {builtin_typ, builtin_num, builtin_fun} = builtins
   290 
   291     fun transT T =
   292       (case builtin_typ ctxt T of
   293         SOME n => pair n
   294       | NONE => fresh_typ sort_prefix T)
   295 
   296     fun app n ts = SApp (n, ts)
   297 
   298     fun trans t =
   299       (case Term.strip_comb t of
   300         (Const (qn, _), [Abs (_, T, t1)]) =>
   301           (case dest_quant qn T t1 of
   302             SOME (q, Ts, ps, b) =>
   303               fold_map transT Ts ##>> fold_map (fold_map_pat trans) ps ##>>
   304               trans b #>> (fn ((Ts', ps'), b') => SQua (q, Ts', ps', b'))
   305           | NONE => raise TERM ("intermediate", [t]))
   306       | (Const (@{const_name Let}, _), [t1, Abs (_, T, t2)]) =>
   307           transT T ##>> trans t1 ##>> trans t2 #>>
   308           (fn ((U, u1), u2) => SLet (U, u1, u2))
   309       | (h as Const (c as (@{const_name distinct}, T)), [t1]) =>
   310           (case builtin_fun ctxt c (HOLogic.dest_list t1) of
   311             SOME (n, ts) => fold_map trans ts #>> app n
   312           | NONE => transs h T [t1])
   313       | (h as Const (c as (_, T)), ts) =>
   314           (case try HOLogic.dest_number t of
   315             SOME (T, i) =>
   316               (case builtin_num ctxt T i of
   317                 SOME n => pair (SApp (n, []))
   318               | NONE => transs t T [])
   319           | NONE =>
   320               (case builtin_fun ctxt c ts of
   321                 SOME (n, ts') => fold_map trans ts' #>> app n
   322               | NONE => transs h T ts))
   323       | (h as Free (_, T), ts) => transs h T ts
   324       | (Bound i, []) => pair (SVar i)
   325       | _ => raise TERM ("intermediate", [t]))
   326 
   327     and transs t T ts =
   328       let val (Us, U) = dest_funT (length ts) T
   329       in
   330         fold_map transT Us ##>> transT U #-> (fn Up =>
   331         fresh_fun func_prefix t Up ##>> fold_map trans ts #>> SApp)
   332       end
   333   in
   334     (case strict of SOME strct => strictify strct ctxt | NONE => relaxed) #>
   335     with_context (header ctxt) trans #>> uncurry (serialize comments)
   336   end
   337 
   338 end