--- a/src/Pure/General/sql.scala Thu Feb 09 11:03:22 2017 +0100
+++ b/src/Pure/General/sql.scala Thu Feb 09 11:44:55 2017 +0100
@@ -39,6 +39,21 @@
def enclosure(ss: Iterable[String]): String = ss.mkString("(", ", ", ")")
+ /* types */
+
+ object Type extends Enumeration
+ {
+ val Int = Value("INTEGER")
+ val Long = Value("BIGINT")
+ val Double = Value("DOUBLE PRECISION")
+ val String = Value("TEXT")
+ val Bytes = Value("BLOB")
+ }
+
+ type Type_Name = Type.Value => String
+ def default_type_name(t: Type.Value): String = t.toString
+
+
/* columns */
object Column
@@ -52,19 +67,18 @@
def string(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[String] =
new Column_String(name, strict, primary_key)
def bytes(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Bytes] =
- new Column_Bytes(name, strict, primary_key, "BLOB") // SQL standard
- def bytea(name: String, strict: Boolean = true, primary_key: Boolean = false): Column[Bytes] =
- new Column_Bytes(name, strict, primary_key, "BYTEA") // PostgreSQL
+ new Column_Bytes(name, strict, primary_key)
}
abstract class Column[+A] private[SQL](
- val name: String, val strict: Boolean, val primary_key: Boolean)
- extends Function[ResultSet, A]
+ val name: String,
+ val strict: Boolean,
+ val primary_key: Boolean,
+ val sql_type: Type.Value) extends Function[ResultSet, A]
{
def sql_name: String = quote_ident(name)
- def sql_type: String
- def sql_decl: String =
- sql_name + " " + sql_type +
+ def sql_decl(type_name: Type_Name): String =
+ sql_name + " " + type_name(sql_type) +
(if (strict) " NOT NULL" else "") +
(if (primary_key) " PRIMARY KEY" else "")
@@ -80,34 +94,30 @@
if (rs.wasNull) None else Some(x)
}
- override def toString: String = sql_decl
+ override def toString: String = sql_decl(default_type_name)
}
class Column_Int private[SQL](name: String, strict: Boolean, primary_key: Boolean)
- extends Column[Int](name, strict, primary_key)
+ extends Column[Int](name, strict, primary_key, Type.Int)
{
- def sql_type: String = "INTEGER"
def apply(rs: ResultSet): Int = rs.getInt(name)
}
class Column_Long private[SQL](name: String, strict: Boolean, primary_key: Boolean)
- extends Column[Long](name, strict, primary_key)
+ extends Column[Long](name, strict, primary_key, Type.Long)
{
- def sql_type: String = "BIGINT"
def apply(rs: ResultSet): Long = rs.getLong(name)
}
class Column_Double private[SQL](name: String, strict: Boolean, primary_key: Boolean)
- extends Column[Double](name, strict, primary_key)
+ extends Column[Double](name, strict, primary_key, Type.Double)
{
- def sql_type: String = "DOUBLE PRECISION"
def apply(rs: ResultSet): Double = rs.getDouble(name)
}
class Column_String private[SQL](name: String, strict: Boolean, primary_key: Boolean)
- extends Column[String](name, strict, primary_key)
+ extends Column[String](name, strict, primary_key, Type.String)
{
- def sql_type: String = "TEXT"
def apply(rs: ResultSet): String =
{
val s = rs.getString(name)
@@ -115,9 +125,8 @@
}
}
- class Column_Bytes private[SQL](
- name: String, strict: Boolean, primary_key: Boolean, val sql_type: String)
- extends Column[Bytes](name, strict, primary_key)
+ class Column_Bytes private[SQL](name: String, strict: Boolean, primary_key: Boolean)
+ extends Column[Bytes](name, strict, primary_key, Type.Bytes)
{
def apply(rs: ResultSet): Bytes =
{
@@ -147,9 +156,9 @@
case _ =>
}
- def sql_create(strict: Boolean, rowid: Boolean): String =
+ def sql_create(strict: Boolean, rowid: Boolean, type_name: Type_Name): String =
"CREATE TABLE " + (if (strict) "" else "IF NOT EXISTS ") +
- quote_ident(name) + " " + enclosure(columns.map(_.sql_decl)) +
+ quote_ident(name) + " " + enclosure(columns.map(_.sql_decl(type_name))) +
(if (rowid) "" else " WITHOUT ROWID")
def sql_drop(strict: Boolean): String =
@@ -192,6 +201,11 @@
trait Database
{
+ /* types */
+
+ def type_name(t: Type.Value): String = default_type_name(t)
+
+
/* connection */
def connection: Connection
@@ -231,7 +245,7 @@
iterator(connection.getMetaData.getTables(null, null, "%", null))(_.getString(3)).toList
def create_table(table: Table, strict: Boolean = true, rowid: Boolean = true): Unit =
- using(statement(table.sql_create(strict, rowid)))(_.execute())
+ using(statement(table.sql_create(strict, rowid, type_name)))(_.execute())
def drop_table(table: Table, strict: Boolean = true): Unit =
using(statement(table.sql_drop(strict)))(_.execute())
@@ -295,5 +309,9 @@
class Database private[PostgreSQL](name: String, val connection: Connection) extends SQL.Database
{
override def toString: String = name
+
+ override def type_name(t: SQL.Type.Value): String =
+ if (t == SQL.Type.Bytes) "BYTEA"
+ else SQL.default_type_name(t)
}
}