src/HOL/Tools/SMT/smt_translate.ML
author boehmes
Tue Dec 07 14:53:12 2010 +0100 (2010-12-07)
changeset 41059 d2b1fc1b8e19
parent 41057 8dbc951a291c
child 41123 3bb9be510a9d
permissions -rw-r--r--
centralized handling of built-in types and constants;
also store types and constants which are rewritten during preprocessing;
interfaces are identified by classes (supporting inheritance, at least on the level of built-in symbols);
removed term_eq in favor of type replacements: term-level occurrences of type bool are replaced by type term_bool (only for the translation)
     1 (*  Title:      HOL/Tools/SMT/smt_translate.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Translate theorems into an SMT intermediate format and serialize them.
     5 *)
     6 
     7 signature SMT_TRANSLATE =
     8 sig
     9   (* intermediate term structure *)
    10   datatype squant = SForall | SExists
    11   datatype 'a spattern = SPat of 'a list | SNoPat of 'a list
    12   datatype sterm =
    13     SVar of int |
    14     SApp of string * sterm list |
    15     SLet of string * sterm * sterm |
    16     SQua of squant * string list * sterm spattern list * int option * sterm
    17 
    18   (* configuration options *)
    19   type prefixes = {sort_prefix: string, func_prefix: string}
    20   type sign = {
    21     header: string list,
    22     sorts: string list,
    23     dtyps: (string * (string * (string * string) list) list) list list,
    24     funcs: (string * (string list * string)) list }
    25   type config = {
    26     prefixes: prefixes,
    27     header: Proof.context -> term list -> string list,
    28     is_fol: bool,
    29     has_datatypes: bool,
    30     serialize: string list -> sign -> sterm list -> string }
    31   type recon = {
    32     typs: typ Symtab.table,
    33     terms: term Symtab.table,
    34     unfolds: thm list,
    35     assms: (int * thm) list }
    36 
    37   val translate: config -> Proof.context -> string list -> (int * thm) list ->
    38     string * recon
    39 end
    40 
    41 structure SMT_Translate: SMT_TRANSLATE =
    42 struct
    43 
    44 structure U = SMT_Utils
    45 structure B = SMT_Builtin
    46 
    47 
    48 (* intermediate term structure *)
    49 
    50 datatype squant = SForall | SExists
    51 
    52 datatype 'a spattern = SPat of 'a list | SNoPat of 'a list
    53 
    54 datatype sterm =
    55   SVar of int |
    56   SApp of string * sterm list |
    57   SLet of string * sterm * sterm |
    58   SQua of squant * string list * sterm spattern list * int option * sterm
    59 
    60 
    61 
    62 (* configuration options *)
    63 
    64 type prefixes = {sort_prefix: string, func_prefix: string}
    65 
    66 type sign = {
    67   header: string list,
    68   sorts: string list,
    69   dtyps: (string * (string * (string * string) list) list) list list,
    70   funcs: (string * (string list * string)) list }
    71 
    72 type config = {
    73   prefixes: prefixes,
    74   header: Proof.context -> term list -> string list,
    75   is_fol: bool,
    76   has_datatypes: bool,
    77   serialize: string list -> sign -> sterm list -> string }
    78 
    79 type recon = {
    80   typs: typ Symtab.table,
    81   terms: term Symtab.table,
    82   unfolds: thm list,
    83   assms: (int * thm) list }
    84 
    85 
    86 
    87 (* utility functions *)
    88 
    89 val quantifier = (fn
    90     @{const_name All} => SOME SForall
    91   | @{const_name Ex} => SOME SExists
    92   | _ => NONE)
    93 
    94 fun group_quant qname Ts (t as Const (q, _) $ Abs (_, T, u)) =
    95       if q = qname then group_quant qname (T :: Ts) u else (Ts, t)
    96   | group_quant _ Ts t = (Ts, t)
    97 
    98 fun dest_weight (@{const SMT.weight} $ w $ t) =
    99       (SOME (snd (HOLogic.dest_number w)), t)
   100   | dest_weight t = (NONE, t)
   101 
   102 fun dest_pat (Const (@{const_name pat}, _) $ t) = (t, true)
   103   | dest_pat (Const (@{const_name nopat}, _) $ t) = (t, false)
   104   | dest_pat t = raise TERM ("dest_pat", [t])
   105 
   106 fun dest_pats [] = I
   107   | dest_pats ts =
   108       (case map dest_pat ts |> split_list ||> distinct (op =) of
   109         (ps, [true]) => cons (SPat ps)
   110       | (ps, [false]) => cons (SNoPat ps)
   111       | _ => raise TERM ("dest_pats", ts))
   112 
   113 fun dest_trigger (@{const trigger} $ tl $ t) =
   114       (rev (fold (dest_pats o HOLogic.dest_list) (HOLogic.dest_list tl) []), t)
   115   | dest_trigger t = ([], t)
   116 
   117 fun dest_quant qn T t = quantifier qn |> Option.map (fn q =>
   118   let
   119     val (Ts, u) = group_quant qn [T] t
   120     val (ps, p) = dest_trigger u
   121     val (w, b) = dest_weight p
   122   in (q, rev Ts, ps, w, b) end)
   123 
   124 fun fold_map_pat f (SPat ts) = fold_map f ts #>> SPat
   125   | fold_map_pat f (SNoPat ts) = fold_map f ts #>> SNoPat
   126 
   127 fun prop_of thm = HOLogic.dest_Trueprop (Thm.prop_of thm)
   128 
   129 
   130 
   131 (* map HOL formulas to FOL formulas (i.e., separate formulas froms terms) *)
   132 
   133 val tboolT = @{typ SMT.term_bool}
   134 val term_true = Const (@{const_name True}, tboolT)
   135 val term_false = Const (@{const_name False}, tboolT)
   136 
   137 val term_bool = @{lemma "True ~= False" by simp}
   138 val term_bool_prop =
   139   let
   140     fun replace @{const HOL.eq (bool)} = @{const HOL.eq (SMT.term_bool)}
   141       | replace @{const True} = term_true
   142       | replace @{const False} = term_false
   143       | replace t = t
   144   in Term.map_aterms replace (prop_of term_bool) end
   145 
   146 val needs_rewrite = Thm.prop_of #> Term.exists_subterm (fn
   147     Const (@{const_name Let}, _) => true
   148   | @{const HOL.eq (bool)} $ _ $ @{const True} => true
   149   | Const (@{const_name If}, _) $ _ $ @{const True} $ @{const False} => true
   150   | _ => false)
   151 
   152 val rewrite_rules = [
   153   Let_def,
   154   @{lemma "P = True == P" by (rule eq_reflection) simp},
   155   @{lemma "if P then True else False == P" by (rule eq_reflection) simp}]
   156 
   157 fun rewrite ctxt ct =
   158   Conv.top_sweep_conv (fn ctxt' =>
   159     Conv.rewrs_conv rewrite_rules then_conv rewrite ctxt') ctxt ct
   160 
   161 fun normalize ctxt thm =
   162   if needs_rewrite thm then Conv.fconv_rule (rewrite ctxt) thm else thm
   163 
   164 fun revert_typ @{typ SMT.term_bool} = @{typ bool}
   165   | revert_typ (Type (n, Ts)) = Type (n, map revert_typ Ts)
   166   | revert_typ T = T
   167 
   168 val revert_types = Term.map_types revert_typ
   169 
   170 fun folify ctxt =
   171   let
   172     fun is_builtin_conn (@{const_name True}, _) _ = false
   173       | is_builtin_conn (@{const_name False}, _) _ = false
   174       | is_builtin_conn c ts = B.is_builtin_conn ctxt c ts
   175 
   176     fun as_term t = @{const HOL.eq (SMT.term_bool)} $ t $ term_true
   177 
   178     fun as_tbool @{typ bool} = tboolT
   179       | as_tbool (Type (n, Ts)) = Type (n, map as_tbool Ts)
   180       | as_tbool T = T
   181     fun mapTs f g = Term.strip_type #> (fn (Ts, T) => map f Ts ---> g T)
   182     fun predT T = mapTs as_tbool I T
   183     fun funcT T = mapTs as_tbool as_tbool T
   184     fun func (n, T) = Const (n, funcT T)
   185 
   186     fun map_ifT T = T |> Term.dest_funT ||> funcT |> (op -->)
   187     val if_term = @{const If (bool)} |> Term.dest_Const ||> map_ifT |> Const
   188     fun wrap_in_if t = if_term $ t $ term_true $ term_false
   189 
   190     fun in_list T f t = HOLogic.mk_list T (map f (HOLogic.dest_list t))
   191 
   192     fun in_term t =
   193       (case Term.strip_comb t of
   194         (Const (c as @{const_name If}, T), [t1, t2, t3]) =>
   195           Const (c, map_ifT T) $ in_form t1 $ in_term t2 $ in_term t3
   196       | (Const c, ts) =>
   197           if is_builtin_conn c ts orelse B.is_builtin_pred ctxt c ts
   198           then wrap_in_if (in_form t)
   199           else Term.list_comb (func c, map in_term ts)
   200       | (Free (n, T), ts) => Term.list_comb (Free (n, funcT T), map in_term ts)
   201       | _ => t)
   202 
   203     and in_weight ((c as @{const SMT.weight}) $ w $ t) = c $ w $ in_form t
   204       | in_weight t = in_form t 
   205 
   206     and in_pat (Const (c as (@{const_name pat}, _)) $ t) = func c $ in_term t
   207       | in_pat (Const (c as (@{const_name nopat}, _)) $ t) = func c $ in_term t
   208       | in_pat t = raise TERM ("in_pat", [t])
   209 
   210     and in_pats ps =
   211       in_list @{typ "pattern list"} (in_list @{typ pattern} in_pat) ps
   212 
   213     and in_trig ((c as @{const trigger}) $ p $ t) = c $ in_pats p $ in_weight t
   214       | in_trig t = in_weight t
   215 
   216     and in_form t =
   217       (case Term.strip_comb t of
   218         (q as Const (qn, _), [Abs (n, T, t')]) =>
   219           if is_some (quantifier qn) then q $ Abs (n, as_tbool T, in_trig t')
   220           else as_term (in_term t)
   221       | (Const (c as (n as @{const_name distinct}, T)), [t']) =>
   222           if B.is_builtin_fun ctxt c [t'] then
   223             Const (n, predT T) $ in_list T in_term t'
   224           else as_term (in_term t)
   225       | (Const (c as (n, T)), ts) =>
   226           if B.is_builtin_conn ctxt c ts
   227           then Term.list_comb (Const c, map in_form ts)
   228           else if B.is_builtin_pred ctxt c ts
   229           then Term.list_comb (Const (n, predT T), map in_term ts)
   230           else as_term (in_term t)
   231       | _ => as_term (in_term t))
   232   in
   233     map (apsnd (normalize ctxt)) #> (fn irules =>
   234     ((rewrite_rules, (~1, term_bool) :: irules),
   235      term_bool_prop :: map (in_form o prop_of o snd) irules))
   236   end
   237 
   238 
   239 
   240 (* translation from Isabelle terms into SMT intermediate terms *)
   241 
   242 val empty_context = (1, Typtab.empty, [], 1, Termtab.empty)
   243 
   244 fun make_sign header (_, typs, dtyps, _, terms) = {
   245   header = header,
   246   sorts = Typtab.fold (fn (_, (n, true)) => cons n | _ => I) typs [],
   247   funcs = Termtab.fold (fn (_, (n, SOME ss)) => cons (n,ss) | _ => I) terms [],
   248   dtyps = rev dtyps }
   249 
   250 fun make_recon (unfolds, assms) (_, typs, _, _, terms) = {
   251   typs = Symtab.make (map (apfst fst o swap) (Typtab.dest typs)),
   252     (*FIXME: don't drop the datatype information! *)
   253   terms = Symtab.make (map (fn (t, (n, _)) => (n, t)) (Termtab.dest terms)),
   254   unfolds = unfolds,
   255   assms = assms }
   256 
   257 fun string_of_index pre i = pre ^ string_of_int i
   258 
   259 fun new_typ sort_prefix proper T (Tidx, typs, dtyps, idx, terms) =
   260   let
   261     val s = string_of_index sort_prefix Tidx
   262     val U = revert_typ T
   263   in (s, (Tidx+1, Typtab.update (U, (s, proper)) typs, dtyps, idx, terms)) end
   264 
   265 fun lookup_typ (_, typs, _, _, _) = Typtab.lookup typs o revert_typ
   266 
   267 fun fresh_typ T f cx =
   268   (case lookup_typ cx T of
   269     SOME (s, _) => (s, cx)
   270   | NONE => f T cx)
   271 
   272 fun new_fun func_prefix t ss (Tidx, typs, dtyps, idx, terms) =
   273   let
   274     val f = string_of_index func_prefix idx
   275     val terms' = Termtab.update (revert_types t, (f, ss)) terms
   276   in (f, (Tidx, typs, dtyps, idx+1, terms')) end
   277 
   278 fun fresh_fun func_prefix t ss (cx as (_, _, _, _, terms)) =
   279   (case Termtab.lookup terms (revert_types t) of
   280     SOME (f, _) => (f, cx)
   281   | NONE => new_fun func_prefix t ss cx)
   282 
   283 fun mk_type (_, Tfs) (d as Datatype.DtTFree _) = the (AList.lookup (op =) Tfs d)
   284   | mk_type Ts (Datatype.DtType (n, ds)) = Type (n, map (mk_type Ts) ds)
   285   | mk_type (Tds, _) (Datatype.DtRec i) = nth Tds i
   286 
   287 fun mk_selector ctxt Ts T n (i, d) =
   288   (case Datatype_Selectors.lookup_selector ctxt (n, i+1) of
   289     NONE => raise Fail ("missing selector for datatype constructor " ^ quote n)
   290   | SOME m => mk_type Ts d |> (fn U => (Const (m, T --> U), U)))
   291 
   292 fun mk_constructor ctxt Ts T (n, args) =
   293   let val (sels, Us) = split_list (map_index (mk_selector ctxt Ts T n) args)
   294   in (Const (n, Us ---> T), sels) end
   295 
   296 fun lookup_datatype ctxt n Ts =
   297   if member (op =) [@{type_name bool}, @{type_name nat}] n then NONE
   298   else
   299     Datatype.get_info (ProofContext.theory_of ctxt) n
   300     |> Option.map (fn {descr, ...} =>
   301          let
   302            val Tds = map (fn (_, (tn, _, _)) => Type (tn, Ts))
   303              (sort (int_ord o pairself fst) descr)
   304            val Tfs = (case hd descr of (_, (_, tfs, _)) => tfs ~~ Ts)
   305          in
   306            descr |> map (fn (i, (_, _, cs)) =>
   307              (nth Tds i, map (mk_constructor ctxt (Tds, Tfs) (nth Tds i)) cs))
   308          end)
   309 
   310 fun relaxed irules = (([], irules), map (prop_of o snd) irules)
   311 
   312 fun with_context header f (ths, ts) =
   313   let val (us, context) = fold_map f ts empty_context
   314   in ((make_sign (header ts) context, us), make_recon ths context) end
   315 
   316 
   317 fun translate config ctxt comments =
   318   let
   319     val {prefixes, is_fol, header, has_datatypes, serialize} = config
   320     val {sort_prefix, func_prefix} = prefixes
   321 
   322     fun transT (T as TFree _) = fresh_typ T (new_typ sort_prefix true)
   323       | transT (T as TVar _) = (fn _ => raise TYPE ("smt_translate", [T], []))
   324       | transT (T as Type (n, Ts)) =
   325           (case B.builtin_typ ctxt T of
   326             SOME n => pair n
   327           | NONE => fresh_typ T (fn _ => fn cx =>
   328               if not has_datatypes then new_typ sort_prefix true T cx
   329               else
   330                 (case lookup_datatype ctxt n Ts of
   331                   NONE => new_typ sort_prefix true T cx
   332                 | SOME dts =>
   333                     let val cx' = new_dtyps dts cx 
   334                     in (fst (the (lookup_typ cx' T)), cx') end)))
   335 
   336     and new_dtyps dts cx =
   337       let
   338         fun new_decl i t =
   339           let val (Ts, T) = U.dest_funT i (Term.fastype_of t)
   340           in
   341             fold_map transT Ts ##>> transT T ##>>
   342             new_fun func_prefix t NONE #>> swap
   343           end
   344         fun new_dtyp_decl (con, sels) =
   345           new_decl (length sels) con ##>> fold_map (new_decl 1) sels #>>
   346           (fn ((con', _), sels') => (con', map (apsnd snd) sels'))
   347       in
   348         cx
   349         |> fold_map (new_typ sort_prefix false o fst) dts
   350         ||>> fold_map (fold_map new_dtyp_decl o snd) dts
   351         |-> (fn (ss, decls) => fn (Tidx, typs, dtyps, idx, terms) =>
   352               (Tidx, typs, (ss ~~ decls) :: dtyps, idx, terms))
   353       end
   354 
   355     fun app n ts = SApp (n, ts)
   356 
   357     fun trans t =
   358       (case Term.strip_comb t of
   359         (Const (qn, _), [Abs (_, T, t1)]) =>
   360           (case dest_quant qn T t1 of
   361             SOME (q, Ts, ps, w, b) =>
   362               fold_map transT Ts ##>> fold_map (fold_map_pat trans) ps ##>>
   363               trans b #>> (fn ((Ts', ps'), b') => SQua (q, Ts', ps', w, b'))
   364           | NONE => raise TERM ("intermediate", [t]))
   365       | (Const (@{const_name Let}, _), [t1, Abs (_, T, t2)]) =>
   366           transT T ##>> trans t1 ##>> trans t2 #>>
   367           (fn ((U, u1), u2) => SLet (U, u1, u2))
   368       | (h as Const (c as (@{const_name distinct}, T)), ts) =>
   369           (case B.builtin_fun ctxt c ts of
   370             SOME (n, ts) => fold_map trans ts #>> app n
   371           | NONE => transs h T ts)
   372       | (h as Const (c as (_, T)), ts) =>
   373           (case B.builtin_num ctxt t of
   374             SOME n => pair (SApp (n, []))
   375           | NONE =>
   376               (case B.builtin_fun ctxt c ts of
   377                 SOME (n, ts') => fold_map trans ts' #>> app n
   378               | NONE => transs h T ts))
   379       | (h as Free (_, T), ts) => transs h T ts
   380       | (Bound i, []) => pair (SVar i)
   381       | (Abs (_, _, t1 $ Bound 0), []) =>
   382         if not (loose_bvar1 (t1, 0)) then trans t1 (* eta-reduce on the fly *)
   383         else raise TERM ("smt_translate", [t])
   384       | _ => raise TERM ("smt_translate", [t]))
   385 
   386     and transs t T ts =
   387       let val (Us, U) = U.dest_funT (length ts) T
   388       in
   389         fold_map transT Us ##>> transT U #-> (fn Up =>
   390         fresh_fun func_prefix t (SOME Up) ##>> fold_map trans ts #>> SApp)
   391       end
   392   in
   393     (if is_fol then folify ctxt else relaxed) #>
   394     with_context (header ctxt) trans #>> uncurry (serialize comments)
   395   end
   396 
   397 end