--- a/src/Pure/General/sql.scala Sat May 06 11:43:43 2017 +0200
+++ b/src/Pure/General/sql.scala Sat May 06 12:45:42 2017 +0200
@@ -9,6 +9,8 @@
import java.time.OffsetDateTime
import java.sql.{DriverManager, Connection, PreparedStatement, ResultSet}
+import scala.collection.mutable
+
object SQL
{
@@ -166,17 +168,114 @@
/** SQL database operations **/
+ /* statements */
+
+ class Statement private[SQL](val db: Database, val rep: PreparedStatement)
+ {
+ stmt =>
+
+ def set_bool(i: Int, x: Boolean) { rep.setBoolean(i, x) }
+ def set_bool(i: Int, x: Option[Boolean])
+ {
+ if (x.isDefined) set_bool(i, x.get)
+ else rep.setNull(i, java.sql.Types.BOOLEAN)
+ }
+
+ def set_int(i: Int, x: Int) { rep.setInt(i, x) }
+ def set_int(i: Int, x: Option[Int])
+ {
+ if (x.isDefined) set_int(i, x.get)
+ else rep.setNull(i, java.sql.Types.INTEGER)
+ }
+
+ def set_long(i: Int, x: Long) { rep.setLong(i, x) }
+ def set_long(i: Int, x: Option[Long])
+ {
+ if (x.isDefined) set_long(i, x.get)
+ else rep.setNull(i, java.sql.Types.BIGINT)
+ }
+
+ def set_double(i: Int, x: Double) { rep.setDouble(i, x) }
+ def set_double(i: Int, x: Option[Double])
+ {
+ if (x.isDefined) set_double(i, x.get)
+ else rep.setNull(i, java.sql.Types.DOUBLE)
+ }
+
+ def set_string(i: Int, x: String) { rep.setString(i, x) }
+ def set_string(i: Int, x: Option[String]): Unit = set_string(i, x.orNull)
+
+ def set_bytes(i: Int, bytes: Bytes)
+ {
+ if (bytes == null) rep.setBytes(i, null)
+ else rep.setBinaryStream(i, bytes.stream(), bytes.length)
+ }
+ def set_bytes(i: Int, bytes: Option[Bytes]): Unit = set_bytes(i, bytes.orNull)
+
+ def set_date(i: Int, date: Date): Unit = db.set_date(stmt, i, date)
+ def set_date(i: Int, date: Option[Date]): Unit = set_date(i, date.orNull)
+
+
+ def execute(): Boolean = rep.execute()
+ def execute_query(): Result = new Result(this, rep.executeQuery())
+
+ def close(): Unit = rep.close
+ }
+
+
/* results */
- def iterator[A](rs: ResultSet)(get: ResultSet => A): Iterator[A] = new Iterator[A]
+ class Result private[SQL](val stmt: Statement, val rep: ResultSet)
{
- private var _next: Boolean = rs.next()
- def hasNext: Boolean = _next
- def next: A = { val x = get(rs); _next = rs.next(); x }
+ res =>
+
+ def next(): Boolean = rep.next()
+
+ def iterator[A](get: Result => A): Iterator[A] = new Iterator[A]
+ {
+ private var _next: Boolean = res.next()
+ def hasNext: Boolean = _next
+ def next: A = { val x = get(res); _next = res.next(); x }
+ }
+
+ def bool(column: Column): Boolean = rep.getBoolean(column.name)
+ def int(column: Column): Int = rep.getInt(column.name)
+ def long(column: Column): Long = rep.getLong(column.name)
+ def double(column: Column): Double = rep.getDouble(column.name)
+ def string(column: Column): String =
+ {
+ val s = rep.getString(column.name)
+ if (s == null) "" else s
+ }
+ def bytes(column: Column): Bytes =
+ {
+ val bs = rep.getBytes(column.name)
+ if (bs == null) Bytes.empty else Bytes(bs)
+ }
+ def date(column: Column): Date = stmt.db.date(res, column)
+
+ def get[A](column: Column, f: Column => A): Option[A] =
+ {
+ val x = f(column)
+ if (rep.wasNull) None else Some(x)
+ }
+ def get_bool(column: Column): Option[Boolean] = get(column, bool _)
+ def get_int(column: Column): Option[Int] = get(column, int _)
+ def get_long(column: Column): Option[Long] = get(column, long _)
+ def get_double(column: Column): Option[Double] = get(column, double _)
+ def get_string(column: Column): Option[String] = get(column, string _)
+ def get_bytes(column: Column): Option[Bytes] = get(column, bytes _)
+ def get_date(column: Column): Option[Date] = get(column, date _)
}
+
+ /* database */
+
trait Database
{
+ db =>
+
+
/* types */
def sql_type(T: Type.Value): Source
@@ -205,100 +304,29 @@
}
- /* statements */
+ /* statements and results */
+
+ def statement(sql: Source): Statement =
+ new Statement(db, connection.prepareStatement(sql))
- def statement(sql: Source): PreparedStatement =
- connection.prepareStatement(sql)
+ def using_statement[A](sql: Source)(f: Statement => A): A =
+ using(statement(sql))(f)
- def using_statement[A](sql: Source)(f: PreparedStatement => A): A =
- using(statement(sql))(f)
+ def set_date(stmt: Statement, i: Int, date: Date): Unit
+ def date(res: Result, column: Column): Date
def insert_permissive(table: Table, sql: Source = ""): Source
- /* input */
-
- def set_bool(stmt: PreparedStatement, i: Int, x: Boolean) { stmt.setBoolean(i, x) }
- def set_bool(stmt: PreparedStatement, i: Int, x: Option[Boolean])
- {
- if (x.isDefined) set_bool(stmt, i, x.get)
- else stmt.setNull(i, java.sql.Types.BOOLEAN)
- }
-
- def set_int(stmt: PreparedStatement, i: Int, x: Int) { stmt.setInt(i, x) }
- def set_int(stmt: PreparedStatement, i: Int, x: Option[Int])
- {
- if (x.isDefined) set_int(stmt, i, x.get)
- else stmt.setNull(i, java.sql.Types.INTEGER)
- }
-
- def set_long(stmt: PreparedStatement, i: Int, x: Long) { stmt.setLong(i, x) }
- def set_long(stmt: PreparedStatement, i: Int, x: Option[Long])
- {
- if (x.isDefined) set_long(stmt, i, x.get)
- else stmt.setNull(i, java.sql.Types.BIGINT)
- }
-
- def set_double(stmt: PreparedStatement, i: Int, x: Double) { stmt.setDouble(i, x) }
- def set_double(stmt: PreparedStatement, i: Int, x: Option[Double])
- {
- if (x.isDefined) set_double(stmt, i, x.get)
- else stmt.setNull(i, java.sql.Types.DOUBLE)
- }
-
- def set_string(stmt: PreparedStatement, i: Int, x: String) { stmt.setString(i, x) }
- def set_string(stmt: PreparedStatement, i: Int, x: Option[String]): Unit =
- set_string(stmt, i, x.orNull)
-
- def set_bytes(stmt: PreparedStatement, i: Int, bytes: Bytes)
- {
- if (bytes == null) stmt.setBytes(i, null)
- else stmt.setBinaryStream(i, bytes.stream(), bytes.length)
- }
- def set_bytes(stmt: PreparedStatement, i: Int, bytes: Option[Bytes]): Unit =
- set_bytes(stmt, i, bytes.orNull)
-
- def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit
- def set_date(stmt: PreparedStatement, i: Int, date: Option[Date]): Unit =
- set_date(stmt, i, date.orNull)
-
-
- /* output */
-
- def bool(rs: ResultSet, column: Column): Boolean = rs.getBoolean(column.name)
- def int(rs: ResultSet, column: Column): Int = rs.getInt(column.name)
- def long(rs: ResultSet, column: Column): Long = rs.getLong(column.name)
- def double(rs: ResultSet, column: Column): Double = rs.getDouble(column.name)
- def string(rs: ResultSet, column: Column): String =
- {
- val s = rs.getString(column.name)
- if (s == null) "" else s
- }
- def bytes(rs: ResultSet, column: Column): Bytes =
- {
- val bs = rs.getBytes(column.name)
- if (bs == null) Bytes.empty else Bytes(bs)
- }
- def date(rs: ResultSet, column: Column): Date
-
- def get[A](rs: ResultSet, column: Column, f: (ResultSet, Column) => A): Option[A] =
- {
- val x = f(rs, column)
- if (rs.wasNull) None else Some(x)
- }
- def get_bool(rs: ResultSet, column: Column): Option[Boolean] = get(rs, column, bool _)
- def get_int(rs: ResultSet, column: Column): Option[Int] = get(rs, column, int _)
- def get_long(rs: ResultSet, column: Column): Option[Long] = get(rs, column, long _)
- def get_double(rs: ResultSet, column: Column): Option[Double] = get(rs, column, double _)
- def get_string(rs: ResultSet, column: Column): Option[String] = get(rs, column, string _)
- def get_bytes(rs: ResultSet, column: Column): Option[Bytes] = get(rs, column, bytes _)
- def get_date(rs: ResultSet, column: Column): Option[Date] = get(rs, column, date _)
-
-
/* tables and views */
def tables: List[String] =
- iterator(connection.getMetaData.getTables(null, null, "%", null))(_.getString(3)).toList
+ {
+ val result = new mutable.ListBuffer[String]
+ val rs = connection.getMetaData.getTables(null, null, "%", null)
+ while (rs.next) { result += rs.getString(3) }
+ result.toList
+ }
def create_table(table: Table, strict: Boolean = false, sql: Source = ""): Unit =
using_statement(
@@ -345,12 +373,12 @@
def sql_type(T: SQL.Type.Value): SQL.Source = SQL.sql_type_sqlite(T)
- def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
- if (date == null) set_string(stmt, i, null: String)
- else set_string(stmt, i, date_format(date))
+ def set_date(stmt: SQL.Statement, i: Int, date: Date): Unit =
+ if (date == null) stmt.set_string(i, null: String)
+ else stmt.set_string(i, date_format(date))
- def date(rs: ResultSet, column: SQL.Column): Date =
- date_format.parse(string(rs, column))
+ def date(res: SQL.Result, column: SQL.Column): Date =
+ date_format.parse(res.string(column))
def insert_permissive(table: SQL.Table, sql: SQL.Source = ""): SQL.Source =
table.insert_cmd("INSERT OR IGNORE", sql = sql)
@@ -418,13 +446,13 @@
def sql_type(T: SQL.Type.Value): SQL.Source = SQL.sql_type_postgresql(T)
// see https://jdbc.postgresql.org/documentation/head/8-date-time.html
- def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
- if (date == null) stmt.setObject(i, null)
- else stmt.setObject(i, OffsetDateTime.from(date.to_utc.rep))
+ def set_date(stmt: SQL.Statement, i: Int, date: Date): Unit =
+ if (date == null) stmt.rep.setObject(i, null)
+ else stmt.rep.setObject(i, OffsetDateTime.from(date.to_utc.rep))
- def date(rs: ResultSet, column: SQL.Column): Date =
+ def date(res: SQL.Result, column: SQL.Column): Date =
{
- val obj = rs.getObject(column.name, classOf[OffsetDateTime])
+ val obj = res.rep.getObject(column.name, classOf[OffsetDateTime])
if (obj == null) null else Date.instant(obj.toInstant)
}