src/HOL/Tools/SMT/smt_translate.ML
changeset 39298 5aefb5bc8a93
parent 37124 fe22fc54b876
child 39435 5d18f4c00c07
--- a/src/HOL/Tools/SMT/smt_translate.ML	Fri Sep 10 23:56:35 2010 +0200
+++ b/src/HOL/Tools/SMT/smt_translate.ML	Mon Sep 13 06:02:47 2010 +0200
@@ -26,10 +26,12 @@
     builtin_typ: Proof.context -> typ -> string option,
     builtin_num: Proof.context -> typ -> int -> string option,
     builtin_fun: Proof.context -> string * typ -> term list ->
-      (string * term list) option }
+      (string * term list) option,
+    has_datatypes: bool }
   type sign = {
     header: string list,
     sorts: string list,
+    dtyps: (string * (string * (string * string) list) list) list list,
     funcs: (string * (string list * string)) list }
   type config = {
     prefixes: prefixes,
@@ -79,11 +81,13 @@
   builtin_typ: Proof.context -> typ -> string option,
   builtin_num: Proof.context -> typ -> int -> string option,
   builtin_fun: Proof.context -> string * typ -> term list ->
-    (string * term list) option }
+    (string * term list) option,
+  has_datatypes: bool }
 
 type sign = {
   header: string list,
   sorts: string list,
+  dtyps: (string * (string * (string * string) list) list) list list,
   funcs: (string * (string list * string)) list }
 
 type config = {
@@ -248,38 +252,67 @@
 
 (* translation from Isabelle terms into SMT intermediate terms *)
 
-val empty_context = (1, Typtab.empty, 1, Termtab.empty)
+val empty_context = (1, Typtab.empty, [], 1, Termtab.empty)
 
-fun make_sign header (_, typs, _, terms) = {
+fun make_sign header (_, typs, dtyps, _, terms) = {
   header = header,
-  sorts = Typtab.fold (cons o snd) typs [],
-  funcs = Termtab.fold (cons o snd) terms [] }
+  sorts = Typtab.fold (fn (_, (n, true)) => cons n | _ => I) typs [],
+  funcs = Termtab.fold (fn (_, (n, SOME ss)) => cons (n,ss) | _ => I) terms [],
+  dtyps = dtyps }
 
-fun make_recon (unfolds, assms) (_, typs, _, terms) = {
-  typs = Symtab.make (map swap (Typtab.dest typs)),
+fun make_recon (unfolds, assms) (_, typs, _, _, terms) = {
+  typs = Symtab.make (map (apfst fst o swap) (Typtab.dest typs)),
+    (*FIXME: don't drop the datatype information! *)
   terms = Symtab.make (map (fn (t, (n, _)) => (n, t)) (Termtab.dest terms)),
   unfolds = unfolds,
   assms = assms }
 
 fun string_of_index pre i = pre ^ string_of_int i
 
-fun fresh_typ sort_prefix T (cx as (Tidx, typs, idx, terms)) =
-  (case Typtab.lookup typs T of
-    SOME s => (s, cx)
-  | NONE =>
-      let
-        val s = string_of_index sort_prefix Tidx
-        val typs' = Typtab.update (T, s) typs
-      in (s, (Tidx+1, typs', idx, terms)) end)
+fun new_typ sort_prefix proper T (Tidx, typs, dtyps, idx, terms) =
+  let val s = string_of_index sort_prefix Tidx
+  in (s, (Tidx+1, Typtab.update (T, (s, proper)) typs, dtyps, idx, terms)) end
+
+fun lookup_typ (_, typs, _, _, _) = Typtab.lookup typs
 
-fun fresh_fun func_prefix t ss (cx as (Tidx, typs, idx, terms)) =
+fun fresh_typ T f cx =
+  (case lookup_typ cx T of
+    SOME (s, _) => (s, cx)
+  | NONE => f T cx)
+
+fun new_fun func_prefix t ss (Tidx, typs, dtyps, idx, terms) =
+  let
+    val f = string_of_index func_prefix idx
+    val terms' = Termtab.update (revert_types t, (f, ss)) terms
+  in (f, (Tidx, typs, dtyps, idx+1, terms')) end
+
+fun fresh_fun func_prefix t ss (cx as (_, _, _, _, terms)) =
   (case Termtab.lookup terms t of
     SOME (f, _) => (f, cx)
-  | NONE =>
-      let
-        val f = string_of_index func_prefix idx
-        val terms' = Termtab.update (revert_types t, (f, ss)) terms
-      in (f, (Tidx, typs, idx+1, terms')) end)
+  | NONE => new_fun func_prefix t ss cx)
+
+fun inst_const f Ts t =
+  let
+    val (n, T) = Term.dest_Const (snd (Type.varify_global [] t))
+    val inst = map Term.dest_TVar (snd (Term.dest_Type (f T))) ~~ Ts
+  in Const (n, Term_Subst.instantiateT inst T) end
+
+fun inst_constructor Ts = inst_const Term.body_type Ts
+fun inst_selector Ts = inst_const Term.domain_type Ts
+
+fun lookup_datatype ctxt n Ts = (* FIXME: use Datatype/Record.get_info *)
+  if n = @{type_name prod}
+  then SOME [
+    (Type (n, Ts), [
+      (inst_constructor Ts @{term Pair},
+        map (inst_selector Ts) [@{term fst}, @{term snd}])])]
+  else if n = @{type_name list}
+  then SOME [
+    (Type (n, Ts), [
+      (inst_constructor Ts @{term Nil}, []),
+      (inst_constructor Ts @{term Cons},
+        map (inst_selector Ts) [@{term hd}, @{term tl}])])]
+  else NONE
 
 fun relaxed thms = (([], thms), map prop_of thms)
 
@@ -291,12 +324,40 @@
 fun translate {prefixes, strict, header, builtins, serialize} ctxt comments =
   let
     val {sort_prefix, func_prefix} = prefixes
-    val {builtin_typ, builtin_num, builtin_fun} = builtins
+    val {builtin_typ, builtin_num, builtin_fun, has_datatypes} = builtins
+
+    fun transT (T as TFree _) = fresh_typ T (new_typ sort_prefix true)
+      | transT (T as TVar _) = (fn _ => raise TYPE ("smt_translate", [T], []))
+      | transT (T as Type (n, Ts)) =
+          (case builtin_typ ctxt T of
+            SOME n => pair n
+          | NONE => fresh_typ T (fn _ => fn cx =>
+              if not has_datatypes then new_typ sort_prefix true T cx
+              else
+                (case lookup_datatype ctxt n Ts of
+                  NONE => new_typ sort_prefix true T cx
+                | SOME dts =>
+                    let val cx' = new_dtyps dts cx 
+                    in (fst (the (lookup_typ cx' T)), cx') end)))
 
-    fun transT T =
-      (case builtin_typ ctxt T of
-        SOME n => pair n
-      | NONE => fresh_typ sort_prefix T)
+    and new_dtyps dts cx =
+      let
+        fun new_decl i t =
+          let val (Ts, T) = dest_funT i (Term.fastype_of t)
+          in
+            fold_map transT Ts ##>> transT T ##>>
+            new_fun func_prefix t NONE #>> swap
+          end
+        fun new_dtyp_decl (con, sels) =
+          new_decl (length sels) con ##>> fold_map (new_decl 1) sels #>>
+          (fn ((con', _), sels') => (con', map (apsnd snd) sels'))
+      in
+        cx
+        |> fold_map (new_typ sort_prefix false o fst) dts
+        ||>> fold_map (fold_map new_dtyp_decl o snd) dts
+        |-> (fn (ss, decls) => fn (Tidx, typs, dtyps, idx, terms) =>
+              (Tidx, typs, (ss ~~ decls) :: dtyps, idx, terms))
+      end
 
     fun app n ts = SApp (n, ts)
 
@@ -327,13 +388,13 @@
               | NONE => transs h T ts))
       | (h as Free (_, T), ts) => transs h T ts
       | (Bound i, []) => pair (SVar i)
-      | _ => raise TERM ("intermediate", [t]))
+      | _ => raise TERM ("smt_translate", [t]))
 
     and transs t T ts =
       let val (Us, U) = dest_funT (length ts) T
       in
         fold_map transT Us ##>> transT U #-> (fn Up =>
-        fresh_fun func_prefix t Up ##>> fold_map trans ts #>> SApp)
+        fresh_fun func_prefix t (SOME Up) ##>> fold_map trans ts #>> SApp)
       end
   in
     (case strict of SOME strct => strictify strct ctxt | NONE => relaxed) #>