src/HOL/Tools/SMT/smt_translate.ML
changeset 41281 679118e35378
parent 41250 41f86829e22f
child 41328 6792a5c92a58
     1.1 --- a/src/HOL/Tools/SMT/smt_translate.ML	Sun Dec 19 17:55:56 2010 +0100
     1.2 +++ b/src/HOL/Tools/SMT/smt_translate.ML	Sun Dec 19 18:54:29 2010 +0100
     1.3 @@ -120,12 +120,12 @@
     1.4    dtyps = dtyps,
     1.5    funcs = Termtab.fold (fn (_, (n, SOME ss)) => cons (n,ss) | _ => I) terms []}
     1.6  
     1.7 -fun recon_of ctxt rules thms ithms revertT revert (_, _, typs, _, _, terms) =
     1.8 +fun recon_of ctxt rules thms ithms (_, _, typs, _, _, terms) =
     1.9    let
    1.10 -    fun add_typ (T, (n, _)) = Symtab.update (n, revertT T)
    1.11 +    fun add_typ (T, (n, _)) = Symtab.update (n, T)
    1.12      val typs' = Typtab.fold add_typ typs Symtab.empty
    1.13  
    1.14 -    fun add_fun (t, (n, _)) = Symtab.update (n, revert t)
    1.15 +    fun add_fun (t, (n, _)) = Symtab.update (n, t)
    1.16      val terms' = Termtab.fold add_fun terms Symtab.empty
    1.17  
    1.18      val assms = map (pair ~1) thms @ ithms
    1.19 @@ -137,43 +137,11 @@
    1.20  
    1.21  (* preprocessing *)
    1.22  
    1.23 -(** mark built-in constants as Var **)
    1.24 -
    1.25 -fun mark_builtins ctxt =
    1.26 -  let
    1.27 -    (*
    1.28 -      Note: schematic terms cannot occur anymore in terms at this stage.
    1.29 -    *)
    1.30 -    fun mark t =
    1.31 -      (case Term.strip_comb t of
    1.32 -        (u as Const (@{const_name If}, _), ts) => marks u ts
    1.33 -      | (u as @{const SMT.weight}, [t1, t2]) =>
    1.34 -          mark t2 #>> (fn t2' => u $ t1 $ t2')
    1.35 -      | (u as Const c, ts) =>
    1.36 -          (case B.builtin_num ctxt t of
    1.37 -            SOME (n, T) =>
    1.38 -              let val v = ((n, 0), T)
    1.39 -              in Vartab.update v #> pair (Var v) end
    1.40 -          | NONE =>
    1.41 -              (case B.builtin_fun ctxt c ts of
    1.42 -                SOME ((ni, T), us, U) =>
    1.43 -                  Vartab.update (ni, U) #> marks (Var (ni, T)) us
    1.44 -              | NONE => marks u ts))
    1.45 -      | (Abs (n, T, u), ts) => mark u #-> (fn u' => marks (Abs (n, T, u')) ts)
    1.46 -      | (u, ts) => marks u ts)
    1.47 - 
    1.48 -    and marks t ts = fold_map mark ts #>> Term.list_comb o pair t
    1.49 -
    1.50 -  in (fn ts => swap (fold_map mark ts Vartab.empty)) end
    1.51 -
    1.52 -fun mark_builtins' ctxt t = hd (snd (mark_builtins ctxt [t]))
    1.53 -
    1.54 -
    1.55  (** FIXME **)
    1.56  
    1.57  local
    1.58    (*
    1.59 -    mark constructors and selectors as Vars (forcing eta-expansion),
    1.60 +    force eta-expansion for constructors and selectors,
    1.61      add missing datatype selectors via hypothetical definitions,
    1.62      also return necessary datatype and record theorems
    1.63    *)
    1.64 @@ -200,38 +168,44 @@
    1.65      let val (U1, U2) = Term.dest_funT T ||> Term.domain_type
    1.66      in Abs (Name.uu, U1, eta U2 (l $ Bound 0)) end
    1.67  
    1.68 -  fun expf t i T ts =
    1.69 -    let val Ts = U.dest_funT i T |> fst |> drop (length ts)
    1.70 +  fun expf k i T t =
    1.71 +    let val Ts = drop i (fst (U.dest_funT k T))
    1.72      in
    1.73 -      Term.list_comb (t, ts)
    1.74 -      |> Term.incr_boundvars (length Ts)
    1.75 +      Term.incr_boundvars (length Ts) t
    1.76        |> fold_index (fn (i, _) => fn u => u $ Bound i) Ts
    1.77        |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts
    1.78      end
    1.79 -
    1.80 -  fun expand ((q as Const (@{const_name All}, _)) $ Abs a) = q $ abs_expand a
    1.81 -    | expand ((q as Const (@{const_name All}, T)) $ t) = q $ exp T t
    1.82 -    | expand (q as Const (@{const_name All}, T)) = exp2 T q
    1.83 -    | expand ((q as Const (@{const_name Ex}, _)) $ Abs a) = q $ abs_expand a
    1.84 -    | expand ((q as Const (@{const_name Ex}, T)) $ t) = q $ exp T t
    1.85 -    | expand (q as Const (@{const_name Ex}, T)) = exp2 T q
    1.86 -    | expand ((l as Const (@{const_name Let}, _)) $ t $ Abs a) =
    1.87 -        l $ expand t $ abs_expand a
    1.88 -    | expand ((l as Const (@{const_name Let}, T)) $ t $ u) =
    1.89 -        l $ expand t $ exp (Term.range_type T) u
    1.90 -    | expand ((l as Const (@{const_name Let}, T)) $ t) = exp2 T (l $ expand t)
    1.91 -    | expand (l as Const (@{const_name Let}, T)) = exp2' T l
    1.92 -    | expand (Abs a) = abs_expand a
    1.93 -    | expand t =
    1.94 -        (case Term.strip_comb t of
    1.95 -          (u as Const (@{const_name If}, T), ts) => expf u 3 T (map expand ts)
    1.96 -        | (u as Var ((_, i), T), ts) => expf u i T (map expand ts)
    1.97 -        | (u, ts) => Term.list_comb (u, map expand ts))
    1.98 -
    1.99 -  and abs_expand (n, T, t) = Abs (n, T, expand t)
   1.100  in
   1.101  
   1.102 -val eta_expand = map expand
   1.103 +fun eta_expand ctxt =
   1.104 +  let
   1.105 +    fun expand ((q as Const (@{const_name All}, _)) $ Abs a) = q $ abs_expand a
   1.106 +      | expand ((q as Const (@{const_name All}, T)) $ t) = q $ exp T t
   1.107 +      | expand (q as Const (@{const_name All}, T)) = exp2 T q
   1.108 +      | expand ((q as Const (@{const_name Ex}, _)) $ Abs a) = q $ abs_expand a
   1.109 +      | expand ((q as Const (@{const_name Ex}, T)) $ t) = q $ exp T t
   1.110 +      | expand (q as Const (@{const_name Ex}, T)) = exp2 T q
   1.111 +      | expand ((l as Const (@{const_name Let}, _)) $ t $ Abs a) =
   1.112 +          l $ expand t $ abs_expand a
   1.113 +      | expand ((l as Const (@{const_name Let}, T)) $ t $ u) =
   1.114 +          l $ expand t $ exp (Term.range_type T) u
   1.115 +      | expand ((l as Const (@{const_name Let}, T)) $ t) =
   1.116 +          exp2 T (l $ expand t)
   1.117 +      | expand (l as Const (@{const_name Let}, T)) = exp2' T l
   1.118 +      | expand t =
   1.119 +          (case Term.strip_comb t of
   1.120 +            (u as Const (c as (_, T)), ts) =>
   1.121 +              (case B.dest_builtin ctxt c ts of
   1.122 +                SOME (_, k, us, mk) =>
   1.123 +                  if k = length us then mk (map expand us)
   1.124 +                  else expf k (length ts) T (mk (map expand us))
   1.125 +              | NONE => Term.list_comb (u, map expand ts))
   1.126 +          | (Abs a, ts) => Term.list_comb (abs_expand a, map expand ts)
   1.127 +          | (u, ts) => Term.list_comb (u, map expand ts))
   1.128 +
   1.129 +    and abs_expand (n, T, t) = Abs (n, T, expand t)
   1.130 +  
   1.131 +  in map expand end
   1.132  
   1.133  end
   1.134  
   1.135 @@ -354,118 +328,92 @@
   1.136  
   1.137  (** map HOL formulas to FOL formulas (i.e., separate formulas froms terms) **)
   1.138  
   1.139 -val tboolT = @{typ SMT.term_bool}
   1.140 -val term_true = Const (@{const_name True}, tboolT)
   1.141 -val term_false = Const (@{const_name False}, tboolT)
   1.142 -
   1.143 -val term_bool = @{lemma "True ~= False" by simp}
   1.144 -val term_bool_prop =
   1.145 -  let
   1.146 -    fun replace @{const HOL.eq (bool)} = @{const HOL.eq (SMT.term_bool)}
   1.147 -      | replace @{const True} = term_true
   1.148 -      | replace @{const False} = term_false
   1.149 -      | replace t = t
   1.150 -  in Term.map_aterms replace (U.prop_of term_bool) end
   1.151 +local
   1.152 +  val term_bool = @{lemma "SMT.term_true ~= SMT.term_false"
   1.153 +    by (simp add: SMT.term_true_def SMT.term_false_def)}
   1.154  
   1.155 -val fol_rules = [
   1.156 -  Let_def,
   1.157 -  @{lemma "P = True == P" by (rule eq_reflection) simp},
   1.158 -  @{lemma "if P then True else False == P" by (rule eq_reflection) simp}]
   1.159 +  val fol_rules = [
   1.160 +    Let_def,
   1.161 +    mk_meta_eq @{thm SMT.term_true_def},
   1.162 +    mk_meta_eq @{thm SMT.term_false_def},
   1.163 +    @{lemma "P = True == P" by (rule eq_reflection) simp},
   1.164 +    @{lemma "if P then True else False == P" by (rule eq_reflection) simp}]
   1.165  
   1.166 -fun reduce_let (Const (@{const_name Let}, _) $ t $ u) =
   1.167 -      reduce_let (Term.betapply (u, t))
   1.168 -  | reduce_let (t $ u) = reduce_let t $ reduce_let u
   1.169 -  | reduce_let (Abs (n, T, t)) = Abs (n, T, reduce_let t)
   1.170 -  | reduce_let t = t
   1.171 +  fun reduce_let (Const (@{const_name Let}, _) $ t $ u) =
   1.172 +        reduce_let (Term.betapply (u, t))
   1.173 +    | reduce_let (t $ u) = reduce_let t $ reduce_let u
   1.174 +    | reduce_let (Abs (n, T, t)) = Abs (n, T, reduce_let t)
   1.175 +    | reduce_let t = t
   1.176  
   1.177 -fun is_pred_type NONE = false
   1.178 -  | is_pred_type (SOME T) = (Term.body_type T = @{typ bool})
   1.179 +  fun as_term t = @{const HOL.eq (bool)} $ t $ @{const SMT.term_true}
   1.180  
   1.181 -fun is_conn_type NONE = false
   1.182 -  | is_conn_type (SOME T) =
   1.183 -      forall (equal @{typ bool}) (Term.body_type T :: Term.binder_types T)
   1.184 +  fun wrap_in_if t =
   1.185 +    @{const If (bool)} $ t $ @{const SMT.term_true} $ @{const SMT.term_false}
   1.186 +
   1.187 +  fun is_builtin_conn_or_pred ctxt c ts =
   1.188 +    is_some (B.dest_builtin_conn ctxt c ts) orelse
   1.189 +    is_some (B.dest_builtin_pred ctxt c ts)
   1.190  
   1.191 -fun revert_typ @{typ SMT.term_bool} = @{typ bool}
   1.192 -  | revert_typ (Type (n, Ts)) = Type (n, map revert_typ Ts)
   1.193 -  | revert_typ T = T
   1.194 +  fun builtin b ctxt c ts =
   1.195 +    (case (Const c, ts) of
   1.196 +      (@{const HOL.eq (bool)}, [t, u]) =>
   1.197 +        if t = @{const SMT.term_true} orelse u = @{const SMT.term_true} then
   1.198 +          B.dest_builtin_eq ctxt t u
   1.199 +        else b ctxt c ts
   1.200 +    | _ => b ctxt c ts)
   1.201 +in
   1.202  
   1.203 -val revert_types = Term.map_types revert_typ
   1.204 -
   1.205 -fun folify ctxt builtins =
   1.206 +fun folify ctxt =
   1.207    let
   1.208 -    fun as_term t = @{const HOL.eq (SMT.term_bool)} $ t $ term_true
   1.209 -
   1.210 -    fun as_tbool @{typ bool} = tboolT
   1.211 -      | as_tbool (Type (n, Ts)) = Type (n, map as_tbool Ts)
   1.212 -      | as_tbool T = T
   1.213 -    fun mapTs f g i = U.dest_funT i #> (fn (Ts, T) => map f Ts ---> g T)
   1.214 -    fun predT i = mapTs as_tbool I i
   1.215 -    fun funcT i = mapTs as_tbool as_tbool i
   1.216 -    fun func i (n, T) = (n, funcT i T)
   1.217 -
   1.218 -    fun map_ifT T = T |> Term.dest_funT ||> funcT 2 |> (op -->)
   1.219 -    val if_term = @{const If (bool)} |> Term.dest_Const ||> map_ifT |> Const
   1.220 -    fun wrap_in_if t = if_term $ t $ term_true $ term_false
   1.221 -
   1.222      fun in_list T f t = HOLogic.mk_list T (map f (HOLogic.dest_list t))
   1.223  
   1.224      fun in_term t =
   1.225        (case Term.strip_comb t of
   1.226 -        (Const (n as @{const_name If}, T), [t1, t2, t3]) =>
   1.227 -          Const (n, map_ifT T) $ in_form t1 $ in_term t2 $ in_term t3
   1.228 -      | (Const (@{const_name HOL.eq}, _), _) => wrap_in_if (in_form t)
   1.229 -      | (Var (ni as (_, i), T), ts) =>
   1.230 -          let val U = Vartab.lookup builtins ni
   1.231 -          in
   1.232 -            if is_conn_type U orelse is_pred_type U then wrap_in_if (in_form t)
   1.233 -            else Term.list_comb (Var (ni, funcT i T), map in_term ts)
   1.234 -          end
   1.235 +        (@{const True}, []) => @{const SMT.term_true}
   1.236 +      | (@{const False}, []) => @{const SMT.term_false}
   1.237 +      | (u as Const (@{const_name If}, _), [t1, t2, t3]) =>
   1.238 +          u $ in_form t1 $ in_term t2 $ in_term t3
   1.239        | (Const c, ts) =>
   1.240 -          Term.list_comb (Const (func (length ts) c), map in_term ts)
   1.241 -      | (Free c, ts) =>
   1.242 -          Term.list_comb (Free (func (length ts) c), map in_term ts)
   1.243 +          if is_builtin_conn_or_pred ctxt c ts then wrap_in_if (in_form t)
   1.244 +          else Term.list_comb (Const c, map in_term ts)
   1.245 +      | (Free c, ts) => Term.list_comb (Free c, map in_term ts)
   1.246        | _ => t)
   1.247  
   1.248      and in_weight ((c as @{const SMT.weight}) $ w $ t) = c $ w $ in_form t
   1.249        | in_weight t = in_form t 
   1.250  
   1.251 -    and in_pat (Const (c as (@{const_name SMT.pat}, _)) $ t) =
   1.252 -          Const (func 1 c) $ in_term t
   1.253 -      | in_pat (Const (c as (@{const_name SMT.nopat}, _)) $ t) =
   1.254 -          Const (func 1 c) $ in_term t
   1.255 +    and in_pat ((p as Const (@{const_name SMT.pat}, _)) $ t) = p $ in_term t
   1.256 +      | in_pat ((p as Const (@{const_name SMT.nopat}, _)) $ t) = p $ in_term t
   1.257        | in_pat t = raise TERM ("bad pattern", [t])
   1.258  
   1.259      and in_pats ps =
   1.260        in_list @{typ "SMT.pattern list"} (in_list @{typ SMT.pattern} in_pat) ps
   1.261  
   1.262 -    and in_trig ((c as @{const SMT.trigger}) $ p $ t) =
   1.263 +    and in_trigger ((c as @{const SMT.trigger}) $ p $ t) =
   1.264            c $ in_pats p $ in_weight t
   1.265 -      | in_trig t = in_weight t
   1.266 +      | in_trigger t = in_weight t
   1.267  
   1.268      and in_form t =
   1.269        (case Term.strip_comb t of
   1.270          (q as Const (qn, _), [Abs (n, T, u)]) =>
   1.271            if member (op =) [@{const_name All}, @{const_name Ex}] qn then
   1.272 -            q $ Abs (n, as_tbool T, in_trig u)
   1.273 +            q $ Abs (n, T, in_trigger u)
   1.274            else as_term (in_term t)
   1.275 -      | (u as Const (@{const_name If}, _), ts) =>
   1.276 -          Term.list_comb (u, map in_form ts)
   1.277 -      | (b as @{const HOL.eq (bool)}, ts) => Term.list_comb (b, map in_form ts)
   1.278 -      | (Const (n as @{const_name HOL.eq}, T), ts) =>
   1.279 -          Term.list_comb (Const (n, predT 2 T), map in_term ts)
   1.280 -      | (b as Var (ni as (_, i), T), ts) =>
   1.281 -          if is_conn_type (Vartab.lookup builtins ni) then
   1.282 -            Term.list_comb (b, map in_form ts)
   1.283 -          else if is_pred_type (Vartab.lookup builtins ni) then
   1.284 -            Term.list_comb (Var (ni, predT i T), map in_term ts)
   1.285 -          else as_term (in_term t)
   1.286 +      | (Const c, ts) =>
   1.287 +          (case B.dest_builtin_conn ctxt c ts of
   1.288 +            SOME (_, _, us, mk) => mk (map in_form us)
   1.289 +          | NONE =>
   1.290 +              (case B.dest_builtin_pred ctxt c ts of
   1.291 +                SOME (_, _, us, mk) => mk (map in_term us)
   1.292 +              | NONE => as_term (in_term t)))
   1.293        | _ => as_term (in_term t))
   1.294    in
   1.295      map (reduce_let #> in_form) #>
   1.296 -    cons (mark_builtins' ctxt term_bool_prop) #>
   1.297 -    pair (fol_rules, [term_bool])
   1.298 +    cons (U.prop_of term_bool) #>
   1.299 +    pair (fol_rules, [term_bool], builtin)
   1.300    end
   1.301  
   1.302 +end
   1.303  
   1.304  
   1.305  (* translation into intermediate format *)
   1.306 @@ -513,17 +461,15 @@
   1.307  
   1.308  (** translation from Isabelle terms into SMT intermediate terms **)
   1.309  
   1.310 -fun intermediate header dtyps ctxt ts trx =
   1.311 +fun intermediate header dtyps builtin ctxt ts trx =
   1.312    let
   1.313      fun transT (T as TFree _) = add_typ T true
   1.314        | transT (T as TVar _) = (fn _ => raise TYPE ("bad SMT type", [T], []))
   1.315        | transT (T as Type _) =
   1.316 -          (case B.builtin_typ ctxt T of
   1.317 +          (case B.dest_builtin_typ ctxt T of
   1.318              SOME n => pair n
   1.319            | NONE => add_typ T true)
   1.320  
   1.321 -    val unmarked_builtins = [@{const_name If}, @{const_name HOL.eq}]
   1.322 -
   1.323      fun app n ts = SApp (n, ts)
   1.324  
   1.325      fun trans t =
   1.326 @@ -537,13 +483,10 @@
   1.327        | (Const (@{const_name Let}, _), [t1, Abs (_, T, t2)]) =>
   1.328            transT T ##>> trans t1 ##>> trans t2 #>>
   1.329            (fn ((U, u1), u2) => SLet (U, u1, u2))
   1.330 -      | (Var ((n, _), _), ts) => fold_map trans ts #>> app n
   1.331 -      | (u as Const (c as (n, T)), ts) =>
   1.332 -          if member (op =) unmarked_builtins n then
   1.333 -            (case B.builtin_fun ctxt c ts of
   1.334 -              SOME (((m, _), _), us, _) => fold_map trans us #>> app m
   1.335 -            | NONE => raise TERM ("not a built-in symbol", [t]))
   1.336 -          else transs u T ts
   1.337 +      | (u as Const (c as (_, T)), ts) =>
   1.338 +          (case builtin ctxt c ts of
   1.339 +            SOME (n, _, us, _) => fold_map trans us #>> app n
   1.340 +          | NONE => transs u T ts)
   1.341        | (u as Free (_, T), ts) => transs u T ts
   1.342        | (Bound i, []) => pair (SVar i)
   1.343        | _ => raise TERM ("bad SMT term", [t]))
   1.344 @@ -590,10 +533,7 @@
   1.345  
   1.346      fun no_dtyps (tr_context, ctxt) ts = (([], tr_context, ctxt), ts)
   1.347  
   1.348 -    val (builtins, ts1) =
   1.349 -      ithms
   1.350 -      |> map (Envir.beta_eta_contract o U.prop_of o snd)
   1.351 -      |> mark_builtins ctxt
   1.352 +    val ts1 = map (Envir.beta_eta_contract o U.prop_of o snd) ithms
   1.353  
   1.354      val ((dtyps, tr_context, ctxt1), ts2) =
   1.355        ((make_tr_context prefixes, ctxt), ts1)
   1.356 @@ -601,19 +541,19 @@
   1.357  
   1.358      val (ctxt2, ts3) =
   1.359        ts2
   1.360 -      |> eta_expand
   1.361 +      |> eta_expand ctxt1
   1.362        |> lift_lambdas ctxt1
   1.363        ||> intro_explicit_application
   1.364  
   1.365 -    val ((rewrite_rules, extra_thms), ts4) =
   1.366 -      (if is_fol then folify ctxt2 builtins else pair ([], [])) ts3
   1.367 +    val ((rewrite_rules, extra_thms, builtin), ts4) =
   1.368 +      (if is_fol then folify ctxt2 else pair ([], [], I)) ts3
   1.369  
   1.370      val rewrite_rules' = fun_app_eq :: rewrite_rules
   1.371    in
   1.372      (ts4, tr_context)
   1.373 -    |-> intermediate header dtyps ctxt2
   1.374 +    |-> intermediate header dtyps (builtin B.dest_builtin) ctxt2
   1.375      |>> uncurry (serialize comments)
   1.376 -    ||> recon_of ctxt2 rewrite_rules' extra_thms ithms revert_typ revert_types
   1.377 +    ||> recon_of ctxt2 rewrite_rules' extra_thms ithms
   1.378    end
   1.379  
   1.380  end