src/HOL/Tools/SMT/smt_translate.ML
changeset 41059 d2b1fc1b8e19
parent 41057 8dbc951a291c
child 41123 3bb9be510a9d
--- a/src/HOL/Tools/SMT/smt_translate.ML	Mon Dec 06 16:54:22 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_translate.ML	Tue Dec 07 14:53:12 2010 +0100
@@ -17,17 +17,6 @@
 
   (* configuration options *)
   type prefixes = {sort_prefix: string, func_prefix: string}
-  type header = Proof.context -> term list -> string list
-  type strict = {
-    is_builtin_conn: string * typ -> bool,
-    is_builtin_pred: Proof.context -> string * typ -> bool,
-    is_builtin_distinct: bool}
-  type builtins = {
-    builtin_typ: Proof.context -> typ -> string option,
-    builtin_num: Proof.context -> typ -> int -> string option,
-    builtin_fun: Proof.context -> string * typ -> term list ->
-      (string * term list) option,
-    has_datatypes: bool }
   type sign = {
     header: string list,
     sorts: string list,
@@ -35,9 +24,9 @@
     funcs: (string * (string list * string)) list }
   type config = {
     prefixes: prefixes,
-    header: header,
-    strict: strict option,
-    builtins: builtins,
+    header: Proof.context -> term list -> string list,
+    is_fol: bool,
+    has_datatypes: bool,
     serialize: string list -> sign -> sterm list -> string }
   type recon = {
     typs: typ Symtab.table,
@@ -53,6 +42,7 @@
 struct
 
 structure U = SMT_Utils
+structure B = SMT_Builtin
 
 
 (* intermediate term structure *)
@@ -73,20 +63,6 @@
 
 type prefixes = {sort_prefix: string, func_prefix: string}
 
-type header = Proof.context -> term list -> string list
-
-type strict = {
-  is_builtin_conn: string * typ -> bool,
-  is_builtin_pred: Proof.context -> string * typ -> bool,
-  is_builtin_distinct: bool}
-
-type builtins = {
-  builtin_typ: Proof.context -> typ -> string option,
-  builtin_num: Proof.context -> typ -> int -> string option,
-  builtin_fun: Proof.context -> string * typ -> term list ->
-    (string * term list) option,
-  has_datatypes: bool }
-
 type sign = {
   header: string list,
   sorts: string list,
@@ -95,9 +71,9 @@
 
 type config = {
   prefixes: prefixes,
-  header: header,
-  strict: strict option,
-  builtins: builtins,
+  header: Proof.context -> term list -> string list,
+  is_fol: bool,
+  has_datatypes: bool,
   serialize: string list -> sign -> sterm list -> string }
 
 type recon = {
@@ -152,13 +128,20 @@
 
 
 
-(* enforce a strict separation between formulas and terms *)
+(* map HOL formulas to FOL formulas (i.e., separate formulas froms terms) *)
 
-val term_eq_rewr = @{lemma "term_eq x y == x = y" by (simp add: term_eq_def)}
+val tboolT = @{typ SMT.term_bool}
+val term_true = Const (@{const_name True}, tboolT)
+val term_false = Const (@{const_name False}, tboolT)
 
-val term_bool = @{lemma "~(term_eq True False)" by (simp add: term_eq_def)}
-val term_bool' = Simplifier.rewrite_rule [term_eq_rewr] term_bool
-
+val term_bool = @{lemma "True ~= False" by simp}
+val term_bool_prop =
+  let
+    fun replace @{const HOL.eq (bool)} = @{const HOL.eq (SMT.term_bool)}
+      | replace @{const True} = term_true
+      | replace @{const False} = term_false
+      | replace t = t
+  in Term.map_aterms replace (prop_of term_bool) end
 
 val needs_rewrite = Thm.prop_of #> Term.exists_subterm (fn
     Const (@{const_name Let}, _) => true
@@ -171,63 +154,57 @@
   @{lemma "P = True == P" by (rule eq_reflection) simp},
   @{lemma "if P then True else False == P" by (rule eq_reflection) simp}]
 
-fun rewrite ctxt = Simplifier.full_rewrite
-  (Simplifier.context ctxt empty_ss addsimps rewrite_rules)
+fun rewrite ctxt ct =
+  Conv.top_sweep_conv (fn ctxt' =>
+    Conv.rewrs_conv rewrite_rules then_conv rewrite ctxt') ctxt ct
 
 fun normalize ctxt thm =
   if needs_rewrite thm then Conv.fconv_rule (rewrite ctxt) thm else thm
 
-val unfold_rules = term_eq_rewr :: rewrite_rules
-
+fun revert_typ @{typ SMT.term_bool} = @{typ bool}
+  | revert_typ (Type (n, Ts)) = Type (n, map revert_typ Ts)
+  | revert_typ T = T
 
-val revert_types =
-  let
-    fun revert @{typ prop} = @{typ bool}
-      | revert (Type (n, Ts)) = Type (n, map revert Ts)
-      | revert T = T
-  in Term.map_types revert end
+val revert_types = Term.map_types revert_typ
 
-
-fun strictify {is_builtin_conn, is_builtin_pred, is_builtin_distinct} ctxt =
+fun folify ctxt =
   let
-    fun is_builtin_conn' (@{const_name True}, _) = false
-      | is_builtin_conn' (@{const_name False}, _) = false
-      | is_builtin_conn' c = is_builtin_conn c
+    fun is_builtin_conn (@{const_name True}, _) _ = false
+      | is_builtin_conn (@{const_name False}, _) _ = false
+      | is_builtin_conn c ts = B.is_builtin_conn ctxt c ts
 
-    fun is_builtin_pred' _ (@{const_name distinct}, _) [t] =
-          is_builtin_distinct andalso can HOLogic.dest_list t
-      | is_builtin_pred' ctxt c _ = is_builtin_pred ctxt c
+    fun as_term t = @{const HOL.eq (SMT.term_bool)} $ t $ term_true
 
-    val propT = @{typ prop} and boolT = @{typ bool}
-    val as_propT = (fn @{typ bool} => propT | T => T)
+    fun as_tbool @{typ bool} = tboolT
+      | as_tbool (Type (n, Ts)) = Type (n, map as_tbool Ts)
+      | as_tbool T = T
     fun mapTs f g = Term.strip_type #> (fn (Ts, T) => map f Ts ---> g T)
-    fun conn (n, T) = (n, mapTs as_propT as_propT T)
-    fun pred (n, T) = (n, mapTs I as_propT T)
+    fun predT T = mapTs as_tbool I T
+    fun funcT T = mapTs as_tbool as_tbool T
+    fun func (n, T) = Const (n, funcT T)
 
-    val term_eq = @{const HOL.eq (bool)} |> Term.dest_Const |> pred
-    fun as_term t = Const term_eq $ t $ @{const True}
-
-    val if_term = Const (@{const_name If}, [propT, boolT, boolT] ---> boolT)
-    fun wrap_in_if t = if_term $ t $ @{const True} $ @{const False}
+    fun map_ifT T = T |> Term.dest_funT ||> funcT |> (op -->)
+    val if_term = @{const If (bool)} |> Term.dest_Const ||> map_ifT |> Const
+    fun wrap_in_if t = if_term $ t $ term_true $ term_false
 
     fun in_list T f t = HOLogic.mk_list T (map f (HOLogic.dest_list t))
 
     fun in_term t =
       (case Term.strip_comb t of
-        (c as Const (@{const_name If}, _), [t1, t2, t3]) =>
-          c $ in_form t1 $ in_term t2 $ in_term t3
-      | (h as Const c, ts) =>
-          if is_builtin_conn' (conn c) orelse is_builtin_pred' ctxt (pred c) ts
+        (Const (c as @{const_name If}, T), [t1, t2, t3]) =>
+          Const (c, map_ifT T) $ in_form t1 $ in_term t2 $ in_term t3
+      | (Const c, ts) =>
+          if is_builtin_conn c ts orelse B.is_builtin_pred ctxt c ts
           then wrap_in_if (in_form t)
-          else Term.list_comb (h, map in_term ts)
-      | (h as Free _, ts) => Term.list_comb (h, map in_term ts)
+          else Term.list_comb (func c, map in_term ts)
+      | (Free (n, T), ts) => Term.list_comb (Free (n, funcT T), map in_term ts)
       | _ => t)
 
     and in_weight ((c as @{const SMT.weight}) $ w $ t) = c $ w $ in_form t
       | in_weight t = in_form t 
 
-    and in_pat ((c as Const (@{const_name pat}, _)) $ t) = c $ in_term t
-      | in_pat ((c as Const (@{const_name nopat}, _)) $ t) = c $ in_term t
+    and in_pat (Const (c as (@{const_name pat}, _)) $ t) = func c $ in_term t
+      | in_pat (Const (c as (@{const_name nopat}, _)) $ t) = func c $ in_term t
       | in_pat t = raise TERM ("in_pat", [t])
 
     and in_pats ps =
@@ -239,23 +216,23 @@
     and in_form t =
       (case Term.strip_comb t of
         (q as Const (qn, _), [Abs (n, T, t')]) =>
-          if is_some (quantifier qn) then q $ Abs (n, T, in_trig t')
+          if is_some (quantifier qn) then q $ Abs (n, as_tbool T, in_trig t')
           else as_term (in_term t)
-      | (Const (c as (@{const_name distinct}, T)), [t']) =>
-          if is_builtin_distinct andalso can HOLogic.dest_list t' then
-            Const (pred c) $ in_list T in_term t'
+      | (Const (c as (n as @{const_name distinct}, T)), [t']) =>
+          if B.is_builtin_fun ctxt c [t'] then
+            Const (n, predT T) $ in_list T in_term t'
           else as_term (in_term t)
-      | (Const c, ts) =>
-          if is_builtin_conn (conn c)
-          then Term.list_comb (Const (conn c), map in_form ts)
-          else if is_builtin_pred ctxt (pred c)
-          then Term.list_comb (Const (pred c), map in_term ts)
+      | (Const (c as (n, T)), ts) =>
+          if B.is_builtin_conn ctxt c ts
+          then Term.list_comb (Const c, map in_form ts)
+          else if B.is_builtin_pred ctxt c ts
+          then Term.list_comb (Const (n, predT T), map in_term ts)
           else as_term (in_term t)
       | _ => as_term (in_term t))
   in
     map (apsnd (normalize ctxt)) #> (fn irules =>
-    ((unfold_rules, (~1, term_bool') :: irules),
-     map (in_form o prop_of o snd) ((~1, term_bool) :: irules)))
+    ((rewrite_rules, (~1, term_bool) :: irules),
+     term_bool_prop :: map (in_form o prop_of o snd) irules))
   end
 
 
@@ -280,10 +257,12 @@
 fun string_of_index pre i = pre ^ string_of_int i
 
 fun new_typ sort_prefix proper T (Tidx, typs, dtyps, idx, terms) =
-  let val s = string_of_index sort_prefix Tidx
-  in (s, (Tidx+1, Typtab.update (T, (s, proper)) typs, dtyps, idx, terms)) end
+  let
+    val s = string_of_index sort_prefix Tidx
+    val U = revert_typ T
+  in (s, (Tidx+1, Typtab.update (U, (s, proper)) typs, dtyps, idx, terms)) end
 
-fun lookup_typ (_, typs, _, _, _) = Typtab.lookup typs
+fun lookup_typ (_, typs, _, _, _) = Typtab.lookup typs o revert_typ
 
 fun fresh_typ T f cx =
   (case lookup_typ cx T of
@@ -297,7 +276,7 @@
   in (f, (Tidx, typs, dtyps, idx+1, terms')) end
 
 fun fresh_fun func_prefix t ss (cx as (_, _, _, _, terms)) =
-  (case Termtab.lookup terms t of
+  (case Termtab.lookup terms (revert_types t) of
     SOME (f, _) => (f, cx)
   | NONE => new_fun func_prefix t ss cx)
 
@@ -335,15 +314,15 @@
   in ((make_sign (header ts) context, us), make_recon ths context) end
 
 
-fun translate {prefixes, strict, header, builtins, serialize} ctxt comments =
+fun translate config ctxt comments =
   let
+    val {prefixes, is_fol, header, has_datatypes, serialize} = config
     val {sort_prefix, func_prefix} = prefixes
-    val {builtin_typ, builtin_num, builtin_fun, has_datatypes} = builtins
 
     fun transT (T as TFree _) = fresh_typ T (new_typ sort_prefix true)
       | transT (T as TVar _) = (fn _ => raise TYPE ("smt_translate", [T], []))
       | transT (T as Type (n, Ts)) =
-          (case builtin_typ ctxt T of
+          (case B.builtin_typ ctxt T of
             SOME n => pair n
           | NONE => fresh_typ T (fn _ => fn cx =>
               if not has_datatypes then new_typ sort_prefix true T cx
@@ -387,17 +366,14 @@
           transT T ##>> trans t1 ##>> trans t2 #>>
           (fn ((U, u1), u2) => SLet (U, u1, u2))
       | (h as Const (c as (@{const_name distinct}, T)), ts) =>
-          (case builtin_fun ctxt c ts of
+          (case B.builtin_fun ctxt c ts of
             SOME (n, ts) => fold_map trans ts #>> app n
           | NONE => transs h T ts)
       | (h as Const (c as (_, T)), ts) =>
-          (case try HOLogic.dest_number t of
-            SOME (T, i) =>
-              (case builtin_num ctxt T i of
-                SOME n => pair (SApp (n, []))
-              | NONE => transs t T [])
+          (case B.builtin_num ctxt t of
+            SOME n => pair (SApp (n, []))
           | NONE =>
-              (case builtin_fun ctxt c ts of
+              (case B.builtin_fun ctxt c ts of
                 SOME (n, ts') => fold_map trans ts' #>> app n
               | NONE => transs h T ts))
       | (h as Free (_, T), ts) => transs h T ts
@@ -414,7 +390,7 @@
         fresh_fun func_prefix t (SOME Up) ##>> fold_map trans ts #>> SApp)
       end
   in
-    (case strict of SOME strct => strictify strct ctxt | NONE => relaxed) #>
+    (if is_fol then folify ctxt else relaxed) #>
     with_context (header ctxt) trans #>> uncurry (serialize comments)
   end