src/Pure/General/sql.scala
changeset 65740 83388f09e9ab
parent 65739 3f206cfca625
child 65741 cf42659364c9
--- a/src/Pure/General/sql.scala	Sat May 06 11:43:43 2017 +0200
+++ b/src/Pure/General/sql.scala	Sat May 06 12:45:42 2017 +0200
@@ -9,6 +9,8 @@
 import java.time.OffsetDateTime
 import java.sql.{DriverManager, Connection, PreparedStatement, ResultSet}
 
+import scala.collection.mutable
+
 
 object SQL
 {
@@ -166,17 +168,114 @@
 
   /** SQL database operations **/
 
+  /* statements */
+
+  class Statement private[SQL](val db: Database, val rep: PreparedStatement)
+  {
+    stmt =>
+
+    def set_bool(i: Int, x: Boolean) { rep.setBoolean(i, x) }
+    def set_bool(i: Int, x: Option[Boolean])
+    {
+      if (x.isDefined) set_bool(i, x.get)
+      else rep.setNull(i, java.sql.Types.BOOLEAN)
+    }
+
+    def set_int(i: Int, x: Int) { rep.setInt(i, x) }
+    def set_int(i: Int, x: Option[Int])
+    {
+      if (x.isDefined) set_int(i, x.get)
+      else rep.setNull(i, java.sql.Types.INTEGER)
+    }
+
+    def set_long(i: Int, x: Long) { rep.setLong(i, x) }
+    def set_long(i: Int, x: Option[Long])
+    {
+      if (x.isDefined) set_long(i, x.get)
+      else rep.setNull(i, java.sql.Types.BIGINT)
+    }
+
+    def set_double(i: Int, x: Double) { rep.setDouble(i, x) }
+    def set_double(i: Int, x: Option[Double])
+    {
+      if (x.isDefined) set_double(i, x.get)
+      else rep.setNull(i, java.sql.Types.DOUBLE)
+    }
+
+    def set_string(i: Int, x: String) { rep.setString(i, x) }
+    def set_string(i: Int, x: Option[String]): Unit = set_string(i, x.orNull)
+
+    def set_bytes(i: Int, bytes: Bytes)
+    {
+      if (bytes == null) rep.setBytes(i, null)
+      else rep.setBinaryStream(i, bytes.stream(), bytes.length)
+    }
+    def set_bytes(i: Int, bytes: Option[Bytes]): Unit = set_bytes(i, bytes.orNull)
+
+    def set_date(i: Int, date: Date): Unit = db.set_date(stmt, i, date)
+    def set_date(i: Int, date: Option[Date]): Unit = set_date(i, date.orNull)
+
+
+    def execute(): Boolean = rep.execute()
+    def execute_query(): Result = new Result(this, rep.executeQuery())
+
+    def close(): Unit = rep.close
+  }
+
+
   /* results */
 
-  def iterator[A](rs: ResultSet)(get: ResultSet => A): Iterator[A] = new Iterator[A]
+  class Result private[SQL](val stmt: Statement, val rep: ResultSet)
   {
-    private var _next: Boolean = rs.next()
-    def hasNext: Boolean = _next
-    def next: A = { val x = get(rs); _next = rs.next(); x }
+    res =>
+
+    def next(): Boolean = rep.next()
+
+    def iterator[A](get: Result => A): Iterator[A] = new Iterator[A]
+    {
+      private var _next: Boolean = res.next()
+      def hasNext: Boolean = _next
+      def next: A = { val x = get(res); _next = res.next(); x }
+    }
+
+    def bool(column: Column): Boolean = rep.getBoolean(column.name)
+    def int(column: Column): Int = rep.getInt(column.name)
+    def long(column: Column): Long = rep.getLong(column.name)
+    def double(column: Column): Double = rep.getDouble(column.name)
+    def string(column: Column): String =
+    {
+      val s = rep.getString(column.name)
+      if (s == null) "" else s
+    }
+    def bytes(column: Column): Bytes =
+    {
+      val bs = rep.getBytes(column.name)
+      if (bs == null) Bytes.empty else Bytes(bs)
+    }
+    def date(column: Column): Date = stmt.db.date(res, column)
+
+    def get[A](column: Column, f: Column => A): Option[A] =
+    {
+      val x = f(column)
+      if (rep.wasNull) None else Some(x)
+    }
+    def get_bool(column: Column): Option[Boolean] = get(column, bool _)
+    def get_int(column: Column): Option[Int] = get(column, int _)
+    def get_long(column: Column): Option[Long] = get(column, long _)
+    def get_double(column: Column): Option[Double] = get(column, double _)
+    def get_string(column: Column): Option[String] = get(column, string _)
+    def get_bytes(column: Column): Option[Bytes] = get(column, bytes _)
+    def get_date(column: Column): Option[Date] = get(column, date _)
   }
 
+
+  /* database */
+
   trait Database
   {
+    db =>
+
+
     /* types */
 
     def sql_type(T: Type.Value): Source
@@ -205,100 +304,29 @@
     }
 
 
-    /* statements */
+    /* statements and results */
+
+    def statement(sql: Source): Statement =
+      new Statement(db, connection.prepareStatement(sql))
 
-    def statement(sql: Source): PreparedStatement =
-      connection.prepareStatement(sql)
+    def using_statement[A](sql: Source)(f: Statement => A): A =
+      using(statement(sql))(f)
 
-    def using_statement[A](sql: Source)(f: PreparedStatement => A): A =
-      using(statement(sql))(f)
+    def set_date(stmt: Statement, i: Int, date: Date): Unit
+    def date(res: Result, column: Column): Date
 
     def insert_permissive(table: Table, sql: Source = ""): Source
 
 
-    /* input */
-
-    def set_bool(stmt: PreparedStatement, i: Int, x: Boolean) { stmt.setBoolean(i, x) }
-    def set_bool(stmt: PreparedStatement, i: Int, x: Option[Boolean])
-    {
-      if (x.isDefined) set_bool(stmt, i, x.get)
-      else stmt.setNull(i, java.sql.Types.BOOLEAN)
-    }
-
-    def set_int(stmt: PreparedStatement, i: Int, x: Int) { stmt.setInt(i, x) }
-    def set_int(stmt: PreparedStatement, i: Int, x: Option[Int])
-    {
-      if (x.isDefined) set_int(stmt, i, x.get)
-      else stmt.setNull(i, java.sql.Types.INTEGER)
-    }
-
-    def set_long(stmt: PreparedStatement, i: Int, x: Long) { stmt.setLong(i, x) }
-    def set_long(stmt: PreparedStatement, i: Int, x: Option[Long])
-    {
-      if (x.isDefined) set_long(stmt, i, x.get)
-      else stmt.setNull(i, java.sql.Types.BIGINT)
-    }
-
-    def set_double(stmt: PreparedStatement, i: Int, x: Double) { stmt.setDouble(i, x) }
-    def set_double(stmt: PreparedStatement, i: Int, x: Option[Double])
-    {
-      if (x.isDefined) set_double(stmt, i, x.get)
-      else stmt.setNull(i, java.sql.Types.DOUBLE)
-    }
-
-    def set_string(stmt: PreparedStatement, i: Int, x: String) { stmt.setString(i, x) }
-    def set_string(stmt: PreparedStatement, i: Int, x: Option[String]): Unit =
-      set_string(stmt, i, x.orNull)
-
-    def set_bytes(stmt: PreparedStatement, i: Int, bytes: Bytes)
-    {
-      if (bytes == null) stmt.setBytes(i, null)
-      else stmt.setBinaryStream(i, bytes.stream(), bytes.length)
-    }
-    def set_bytes(stmt: PreparedStatement, i: Int, bytes: Option[Bytes]): Unit =
-      set_bytes(stmt, i, bytes.orNull)
-
-    def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit
-    def set_date(stmt: PreparedStatement, i: Int, date: Option[Date]): Unit =
-      set_date(stmt, i, date.orNull)
-
-
-    /* output */
-
-    def bool(rs: ResultSet, column: Column): Boolean = rs.getBoolean(column.name)
-    def int(rs: ResultSet, column: Column): Int = rs.getInt(column.name)
-    def long(rs: ResultSet, column: Column): Long = rs.getLong(column.name)
-    def double(rs: ResultSet, column: Column): Double = rs.getDouble(column.name)
-    def string(rs: ResultSet, column: Column): String =
-    {
-      val s = rs.getString(column.name)
-      if (s == null) "" else s
-    }
-    def bytes(rs: ResultSet, column: Column): Bytes =
-    {
-      val bs = rs.getBytes(column.name)
-      if (bs == null) Bytes.empty else Bytes(bs)
-    }
-    def date(rs: ResultSet, column: Column): Date
-
-    def get[A](rs: ResultSet, column: Column, f: (ResultSet, Column) => A): Option[A] =
-    {
-      val x = f(rs, column)
-      if (rs.wasNull) None else Some(x)
-    }
-    def get_bool(rs: ResultSet, column: Column): Option[Boolean] = get(rs, column, bool _)
-    def get_int(rs: ResultSet, column: Column): Option[Int] = get(rs, column, int _)
-    def get_long(rs: ResultSet, column: Column): Option[Long] = get(rs, column, long _)
-    def get_double(rs: ResultSet, column: Column): Option[Double] = get(rs, column, double _)
-    def get_string(rs: ResultSet, column: Column): Option[String] = get(rs, column, string _)
-    def get_bytes(rs: ResultSet, column: Column): Option[Bytes] = get(rs, column, bytes _)
-    def get_date(rs: ResultSet, column: Column): Option[Date] = get(rs, column, date _)
-
-
     /* tables and views */
 
     def tables: List[String] =
-      iterator(connection.getMetaData.getTables(null, null, "%", null))(_.getString(3)).toList
+    {
+      val result = new mutable.ListBuffer[String]
+      val rs = connection.getMetaData.getTables(null, null, "%", null)
+      while (rs.next) { result += rs.getString(3) }
+      result.toList
+    }
 
     def create_table(table: Table, strict: Boolean = false, sql: Source = ""): Unit =
       using_statement(
@@ -345,12 +373,12 @@
 
     def sql_type(T: SQL.Type.Value): SQL.Source = SQL.sql_type_sqlite(T)
 
-    def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
-      if (date == null) set_string(stmt, i, null: String)
-      else set_string(stmt, i, date_format(date))
+    def set_date(stmt: SQL.Statement, i: Int, date: Date): Unit =
+      if (date == null) stmt.set_string(i, null: String)
+      else stmt.set_string(i, date_format(date))
 
-    def date(rs: ResultSet, column: SQL.Column): Date =
-      date_format.parse(string(rs, column))
+    def date(res: SQL.Result, column: SQL.Column): Date =
+      date_format.parse(res.string(column))
 
     def insert_permissive(table: SQL.Table, sql: SQL.Source = ""): SQL.Source =
       table.insert_cmd("INSERT OR IGNORE", sql = sql)
@@ -418,13 +446,13 @@
     def sql_type(T: SQL.Type.Value): SQL.Source = SQL.sql_type_postgresql(T)
 
     // see https://jdbc.postgresql.org/documentation/head/8-date-time.html
-    def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
-      if (date == null) stmt.setObject(i, null)
-      else stmt.setObject(i, OffsetDateTime.from(date.to_utc.rep))
+    def set_date(stmt: SQL.Statement, i: Int, date: Date): Unit =
+      if (date == null) stmt.rep.setObject(i, null)
+      else stmt.rep.setObject(i, OffsetDateTime.from(date.to_utc.rep))
 
-    def date(rs: ResultSet, column: SQL.Column): Date =
+    def date(res: SQL.Result, column: SQL.Column): Date =
     {
-      val obj = rs.getObject(column.name, classOf[OffsetDateTime])
+      val obj = res.rep.getObject(column.name, classOf[OffsetDateTime])
       if (obj == null) null else Date.instant(obj.toInstant)
     }