src/Pure/General/sql.scala
author wenzelm
Sun Apr 30 09:23:03 2017 +0200 (2017-04-30)
changeset 65641 3b0110e25745
parent 65636 df804cdba5f9
child 65644 7ef438495a02
permissions -rw-r--r--
tuned message;
     1 /*  Title:      Pure/General/sql.scala
     2     Author:     Makarius
     3 
     4 Support for SQL databases: SQLite and PostgreSQL.
     5 */
     6 
     7 package isabelle
     8 
     9 import java.time.OffsetDateTime
    10 import java.sql.{DriverManager, Connection, PreparedStatement, ResultSet}
    11 
    12 
    13 object SQL
    14 {
    15   /** SQL language **/
    16 
    17   /* concrete syntax */
    18 
    19   def escape_char(c: Char): String =
    20     c match {
    21       case '\u0000' => "\\0"
    22       case '\'' => "\\'"
    23       case '\"' => "\\\""
    24       case '\b' => "\\b"
    25       case '\n' => "\\n"
    26       case '\r' => "\\r"
    27       case '\t' => "\\t"
    28       case '\u001a' => "\\Z"
    29       case '\\' => "\\\\"
    30       case _ => c.toString
    31     }
    32 
    33   def quote_string(s: String): String =
    34     "'" + s.map(escape_char(_)).mkString + "'"
    35 
    36   def quote_ident(s: String): String =
    37     quote(s.replace("\"", "\"\""))
    38 
    39   def enclosure(ss: Iterable[String]): String = ss.mkString("(", ", ", ")")
    40 
    41 
    42   /* types */
    43 
    44   object Type extends Enumeration
    45   {
    46     val Boolean = Value("BOOLEAN")
    47     val Int = Value("INTEGER")
    48     val Long = Value("BIGINT")
    49     val Double = Value("DOUBLE PRECISION")
    50     val String = Value("TEXT")
    51     val Bytes = Value("BLOB")
    52     val Date = Value("TIMESTAMP WITH TIME ZONE")
    53   }
    54 
    55   def sql_type_default(T: Type.Value): String = T.toString
    56 
    57   def sql_type_sqlite(T: Type.Value): String =
    58     if (T == Type.Boolean) "INTEGER"
    59     else if (T == Type.Date) "TEXT"
    60     else sql_type_default(T)
    61 
    62   def sql_type_postgresql(T: Type.Value): String =
    63     if (T == Type.Bytes) "BYTEA"
    64     else sql_type_default(T)
    65 
    66 
    67   /* columns */
    68 
    69   object Column
    70   {
    71     def bool(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    72       Column(name, Type.Boolean, strict, primary_key)
    73     def int(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    74       Column(name, Type.Int, strict, primary_key)
    75     def long(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    76       Column(name, Type.Long, strict, primary_key)
    77     def double(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    78       Column(name, Type.Double, strict, primary_key)
    79     def string(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    80       Column(name, Type.String, strict, primary_key)
    81     def bytes(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    82       Column(name, Type.Bytes, strict, primary_key)
    83     def date(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    84       Column(name, Type.Date, strict, primary_key)
    85   }
    86 
    87   sealed case class Column(
    88     name: String, T: Type.Value, strict: Boolean = false, primary_key: Boolean = false)
    89   {
    90     def sql_name: String = quote_ident(name)
    91     def sql_decl(sql_type: Type.Value => String): String =
    92       sql_name + " " + sql_type(T) + (if (strict || primary_key) " NOT NULL" else "")
    93 
    94     def sql_where_eq: String = "WHERE " + sql_name + " = "
    95     def sql_where_equal(s: String): String = sql_where_eq + quote_string(s)
    96 
    97     override def toString: String = sql_decl(sql_type_default)
    98   }
    99 
   100 
   101   /* tables */
   102 
   103   sealed case class Table(name: String, columns: List[Column])
   104   {
   105     private val columns_index: Map[String, Int] =
   106       columns.iterator.map(_.name).zipWithIndex.toMap
   107 
   108     Library.duplicates(columns.map(_.name)) match {
   109       case Nil =>
   110       case bad => error("Duplicate column names " + commas_quote(bad) + " for table " + quote(name))
   111     }
   112 
   113     def sql_columns(sql_type: Type.Value => String): String =
   114     {
   115       val primary_key =
   116         columns.filter(_.primary_key).map(_.name) match {
   117           case Nil => Nil
   118           case keys => List("PRIMARY KEY " + enclosure(keys))
   119         }
   120       enclosure(columns.map(_.sql_decl(sql_type)) ::: primary_key)
   121     }
   122 
   123     def sql_create(strict: Boolean, sql_type: Type.Value => String): String =
   124       "CREATE TABLE " + (if (strict) "" else "IF NOT EXISTS ") +
   125         quote_ident(name) + " " + sql_columns(sql_type)
   126 
   127     def sql_drop(strict: Boolean): String =
   128       "DROP TABLE " + (if (strict) "" else "IF EXISTS ") + quote_ident(name)
   129 
   130     def sql_create_index(
   131         index_name: String, index_columns: List[Column],
   132         strict: Boolean, unique: Boolean): String =
   133       "CREATE " + (if (unique) "UNIQUE " else "") + "INDEX " +
   134         (if (strict) "" else "IF NOT EXISTS ") + quote_ident(index_name) + " ON " +
   135         quote_ident(name) + " " + enclosure(index_columns.map(_.name))
   136 
   137     def sql_drop_index(index_name: String, strict: Boolean): String =
   138       "DROP INDEX " + (if (strict) "" else "IF EXISTS ") + quote_ident(index_name)
   139 
   140     def sql_insert: String =
   141       "INSERT INTO " + quote_ident(name) + " VALUES " + enclosure(columns.map(_ => "?"))
   142 
   143     def sql_delete: String =
   144       "DELETE FROM " + quote_ident(name)
   145 
   146     def sql_select(select_columns: List[Column], distinct: Boolean): String =
   147       "SELECT " + (if (distinct) "DISTINCT " else "") +
   148       commas(select_columns.map(_.sql_name)) + " FROM " + quote_ident(name)
   149 
   150     override def toString: String =
   151       "TABLE " + quote_ident(name) + " " + sql_columns(sql_type_default)
   152   }
   153 
   154 
   155 
   156   /** SQL database operations **/
   157 
   158   /* results */
   159 
   160   def iterator[A](rs: ResultSet)(get: ResultSet => A): Iterator[A] = new Iterator[A]
   161   {
   162     private var _next: Boolean = rs.next()
   163     def hasNext: Boolean = _next
   164     def next: A = { val x = get(rs); _next = rs.next(); x }
   165   }
   166 
   167   trait Database
   168   {
   169     /* types */
   170 
   171     def sql_type(T: Type.Value): String
   172 
   173 
   174     /* connection */
   175 
   176     def connection: Connection
   177 
   178     def close() { connection.close }
   179 
   180     def transaction[A](body: => A): A =
   181     {
   182       val auto_commit = connection.getAutoCommit
   183       try {
   184         connection.setAutoCommit(false)
   185         val savepoint = connection.setSavepoint
   186         try {
   187           val result = body
   188           connection.commit
   189           result
   190         }
   191         catch { case exn: Throwable => connection.rollback(savepoint); throw exn }
   192       }
   193       finally { connection.setAutoCommit(auto_commit) }
   194     }
   195 
   196 
   197     /* statements */
   198 
   199     def statement(sql: String): PreparedStatement = connection.prepareStatement(sql)
   200 
   201     def insert(table: Table): PreparedStatement = statement(table.sql_insert)
   202 
   203     def delete(table: Table, sql: String = ""): PreparedStatement =
   204       statement(table.sql_delete + (if (sql == "") "" else " " + sql))
   205 
   206     def select(table: Table, columns: List[Column], sql: String = "", distinct: Boolean = false)
   207         : PreparedStatement =
   208       statement(table.sql_select(columns, distinct) + (if (sql == "") "" else " " + sql))
   209 
   210 
   211     /* input */
   212 
   213     def set_bool(stmt: PreparedStatement, i: Int, x: Boolean) { stmt.setBoolean(i, x) }
   214     def set_bool(stmt: PreparedStatement, i: Int, x: Option[Boolean])
   215     {
   216       if (x.isDefined) set_bool(stmt, i, x.get)
   217       else stmt.setNull(i, java.sql.Types.BOOLEAN)
   218     }
   219 
   220     def set_int(stmt: PreparedStatement, i: Int, x: Int) { stmt.setInt(i, x) }
   221     def set_int(stmt: PreparedStatement, i: Int, x: Option[Int])
   222     {
   223       if (x.isDefined) set_int(stmt, i, x.get)
   224       else stmt.setNull(i, java.sql.Types.INTEGER)
   225     }
   226 
   227     def set_long(stmt: PreparedStatement, i: Int, x: Long) { stmt.setLong(i, x) }
   228     def set_long(stmt: PreparedStatement, i: Int, x: Option[Long])
   229     {
   230       if (x.isDefined) set_long(stmt, i, x.get)
   231       else stmt.setNull(i, java.sql.Types.BIGINT)
   232     }
   233 
   234     def set_double(stmt: PreparedStatement, i: Int, x: Double) { stmt.setDouble(i, x) }
   235     def set_double(stmt: PreparedStatement, i: Int, x: Option[Double])
   236     {
   237       if (x.isDefined) set_double(stmt, i, x.get)
   238       else stmt.setNull(i, java.sql.Types.DOUBLE)
   239     }
   240 
   241     def set_string(stmt: PreparedStatement, i: Int, x: String) { stmt.setString(i, x) }
   242     def set_string(stmt: PreparedStatement, i: Int, x: Option[String]): Unit =
   243       set_string(stmt, i, x.orNull)
   244 
   245     def set_bytes(stmt: PreparedStatement, i: Int, bytes: Bytes)
   246     {
   247       if (bytes == null) stmt.setBytes(i, null)
   248       else stmt.setBinaryStream(i, bytes.stream(), bytes.length)
   249     }
   250     def set_bytes(stmt: PreparedStatement, i: Int, bytes: Option[Bytes]): Unit =
   251       set_bytes(stmt, i, bytes.orNull)
   252 
   253     def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit
   254     def set_date(stmt: PreparedStatement, i: Int, date: Option[Date]): Unit =
   255       set_date(stmt, i, date.orNull)
   256 
   257 
   258     /* output */
   259 
   260     def bool(rs: ResultSet, column: Column): Boolean = rs.getBoolean(column.name)
   261     def int(rs: ResultSet, column: Column): Int = rs.getInt(column.name)
   262     def long(rs: ResultSet, column: Column): Long = rs.getLong(column.name)
   263     def double(rs: ResultSet, column: Column): Double = rs.getDouble(column.name)
   264     def string(rs: ResultSet, column: Column): String =
   265     {
   266       val s = rs.getString(column.name)
   267       if (s == null) "" else s
   268     }
   269     def bytes(rs: ResultSet, column: Column): Bytes =
   270     {
   271       val bs = rs.getBytes(column.name)
   272       if (bs == null) Bytes.empty else Bytes(bs)
   273     }
   274     def date(rs: ResultSet, column: Column): Date
   275 
   276     def get[A](rs: ResultSet, column: Column, f: (ResultSet, Column) => A): Option[A] =
   277     {
   278       val x = f(rs, column)
   279       if (rs.wasNull) None else Some(x)
   280     }
   281 
   282 
   283     /* tables */
   284 
   285     def tables: List[String] =
   286       iterator(connection.getMetaData.getTables(null, null, "%", null))(_.getString(3)).toList
   287 
   288     def create_table(table: Table, strict: Boolean = false, sql: String = ""): Unit =
   289       using(statement(table.sql_create(strict, sql_type) + (if (sql == "") "" else " " + sql)))(
   290         _.execute())
   291 
   292     def drop_table(table: Table, strict: Boolean = false): Unit =
   293       using(statement(table.sql_drop(strict)))(_.execute())
   294 
   295     def create_index(table: Table, name: String, columns: List[Column],
   296         strict: Boolean = false, unique: Boolean = false): Unit =
   297       using(statement(table.sql_create_index(name, columns, strict, unique)))(_.execute())
   298 
   299     def drop_index(table: Table, name: String, strict: Boolean = false): Unit =
   300       using(statement(table.sql_drop_index(name, strict)))(_.execute())
   301   }
   302 }
   303 
   304 
   305 
   306 /** SQLite **/
   307 
   308 object SQLite
   309 {
   310   // see https://www.sqlite.org/lang_datefunc.html
   311   val date_format: Date.Format = Date.Format("uuuu-MM-dd HH:mm:ss.SSS x")
   312 
   313   lazy val init_jdbc: Unit = Class.forName("org.sqlite.JDBC")
   314 
   315   def open_database(path: Path): Database =
   316   {
   317     init_jdbc
   318     val path0 = path.expand
   319     val s0 = File.platform_path(path0)
   320     val s1 = if (Platform.is_windows) s0.replace('\\', '/') else s0
   321     val connection = DriverManager.getConnection("jdbc:sqlite:" + s1)
   322     new Database(path0.toString, connection)
   323   }
   324 
   325   class Database private[SQLite](name: String, val connection: Connection) extends SQL.Database
   326   {
   327     override def toString: String = name
   328 
   329     def sql_type(T: SQL.Type.Value): String = SQL.sql_type_sqlite(T)
   330 
   331     def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
   332       if (date == null) set_string(stmt, i, null: String)
   333       else set_string(stmt, i, date_format(date))
   334 
   335     def date(rs: ResultSet, column: SQL.Column): Date =
   336       date_format.parse(string(rs, column))
   337 
   338     def rebuild { using(statement("VACUUM"))(_.execute()) }
   339   }
   340 }
   341 
   342 
   343 
   344 /** PostgreSQL **/
   345 
   346 object PostgreSQL
   347 {
   348   val default_port = 5432
   349 
   350   lazy val init_jdbc: Unit = Class.forName("org.postgresql.Driver")
   351 
   352   def open_database(
   353     user: String,
   354     password: String,
   355     database: String = "",
   356     host: String = "",
   357     port: Int = 0,
   358     ssh: Option[SSH.Session] = None,
   359     ssh_close: Boolean = false): Database =
   360   {
   361     init_jdbc
   362 
   363     if (user == "") error("Undefined database user")
   364 
   365     val db_host = if (host != "") host else "localhost"
   366     val db_port = if (port > 0 && port != default_port) ":" + port else ""
   367     val db_name = "/" + (if (database != "") database else user)
   368 
   369     val (url, name, port_forwarding) =
   370       ssh match {
   371         case None =>
   372           val spec = db_host + db_port + db_name
   373           val url = "jdbc:postgresql://" + spec
   374           val name = user + "@" + spec
   375           (url, name, None)
   376         case Some(ssh) =>
   377           val fw =
   378             ssh.port_forwarding(remote_host = db_host,
   379               remote_port = if (port > 0) port else default_port,
   380               ssh_close = ssh_close)
   381           val url = "jdbc:postgresql://localhost:" + fw.local_port + db_name
   382           val name = user + "@" + fw + db_name + " via ssh " + ssh
   383           (url, name, Some(fw))
   384       }
   385     try {
   386       val connection = DriverManager.getConnection(url, user, password)
   387       new Database(name, connection, port_forwarding)
   388     }
   389     catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
   390   }
   391 
   392   class Database private[PostgreSQL](
   393       name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
   394     extends SQL.Database
   395   {
   396     override def toString: String = name
   397 
   398     def sql_type(T: SQL.Type.Value): String = SQL.sql_type_postgresql(T)
   399 
   400     // see https://jdbc.postgresql.org/documentation/head/8-date-time.html
   401     def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
   402       if (date == null) stmt.setObject(i, null)
   403       else stmt.setObject(i, OffsetDateTime.from(date.to_utc.rep))
   404 
   405     def date(rs: ResultSet, column: SQL.Column): Date =
   406       Date.instant(rs.getObject(column.name, classOf[OffsetDateTime]).toInstant)
   407 
   408     override def close() { super.close; port_forwarding.foreach(_.close) }
   409   }
   410 }