src/Pure/General/sql.scala
author wenzelm
Fri Mar 17 20:33:27 2017 +0100 (2017-03-17)
changeset 65292 e3bd1e7ddd23
parent 65280 ef37f5236794
child 65319 64da14387b2c
permissions -rw-r--r--
more robust JDBC initialization, e.g. required for Isabelle/jEdit startup;
     1 /*  Title:      Pure/General/sql.scala
     2     Author:     Makarius
     3 
     4 Support for SQL databases: SQLite and PostgreSQL.
     5 */
     6 
     7 package isabelle
     8 
     9 import java.time.OffsetDateTime
    10 import java.sql.{DriverManager, Connection, PreparedStatement, ResultSet}
    11 
    12 
    13 object SQL
    14 {
    15   /** SQL language **/
    16 
    17   /* concrete syntax */
    18 
    19   def quote_char(c: Char): String =
    20     c match {
    21       case '\u0000' => "\\0"
    22       case '\'' => "\\'"
    23       case '\"' => "\\\""
    24       case '\b' => "\\b"
    25       case '\n' => "\\n"
    26       case '\r' => "\\r"
    27       case '\t' => "\\t"
    28       case '\u001a' => "\\Z"
    29       case '\\' => "\\\\"
    30       case _ => c.toString
    31     }
    32 
    33   def quote_string(s: String): String =
    34     quote(s.map(quote_char(_)).mkString)
    35 
    36   def quote_ident(s: String): String =
    37     quote(s.replace("\"", "\"\""))
    38 
    39   def enclosure(ss: Iterable[String]): String = ss.mkString("(", ", ", ")")
    40 
    41 
    42   /* types */
    43 
    44   object Type extends Enumeration
    45   {
    46     val Boolean = Value("BOOLEAN")
    47     val Int = Value("INTEGER")
    48     val Long = Value("BIGINT")
    49     val Double = Value("DOUBLE PRECISION")
    50     val String = Value("TEXT")
    51     val Bytes = Value("BLOB")
    52     val Date = Value("TIMESTAMP WITH TIME ZONE")
    53   }
    54 
    55   def sql_type_default(T: Type.Value): String = T.toString
    56 
    57   def sql_type_sqlite(T: Type.Value): String =
    58     if (T == Type.Boolean) "INTEGER"
    59     else if (T == Type.Date) "TEXT"
    60     else sql_type_default(T)
    61 
    62   def sql_type_postgresql(T: Type.Value): String =
    63     if (T == Type.Bytes) "BYTEA"
    64     else sql_type_default(T)
    65 
    66 
    67   /* columns */
    68 
    69   object Column
    70   {
    71     def bool(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    72       Column(name, Type.Boolean, strict, primary_key)
    73     def int(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    74       Column(name, Type.Int, strict, primary_key)
    75     def long(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    76       Column(name, Type.Long, strict, primary_key)
    77     def double(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    78       Column(name, Type.Double, strict, primary_key)
    79     def string(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    80       Column(name, Type.String, strict, primary_key)
    81     def bytes(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    82       Column(name, Type.Bytes, strict, primary_key)
    83     def date(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    84       Column(name, Type.Date, strict, primary_key)
    85   }
    86 
    87   sealed case class Column(
    88     name: String, T: Type.Value, strict: Boolean = false, primary_key: Boolean = false)
    89   {
    90     def sql_name: String = quote_ident(name)
    91     def sql_decl(sql_type: Type.Value => String): String =
    92       sql_name + " " + sql_type(T) +
    93       (if (strict) " NOT NULL" else "") +
    94       (if (primary_key) " PRIMARY KEY" else "")
    95 
    96     override def toString: String = sql_decl(sql_type_default)
    97   }
    98 
    99 
   100   /* tables */
   101 
   102   sealed case class Table(name: String, columns: List[Column])
   103   {
   104     private val columns_index: Map[String, Int] =
   105       columns.iterator.map(_.name).zipWithIndex.toMap
   106 
   107     Library.duplicates(columns.map(_.name)) match {
   108       case Nil =>
   109       case bad => error("Duplicate column names " + commas_quote(bad) + " for table " + quote(name))
   110     }
   111 
   112     columns.filter(_.primary_key) match {
   113       case bad if bad.length > 1 =>
   114         error("Multiple primary keys " + commas_quote(bad.map(_.name)) + " for table " + quote(name))
   115       case _ =>
   116     }
   117 
   118     def sql_create(strict: Boolean, rowid: Boolean, sql_type: Type.Value => String): String =
   119       "CREATE TABLE " + (if (strict) "" else "IF NOT EXISTS ") +
   120         quote_ident(name) + " " + enclosure(columns.map(_.sql_decl(sql_type))) +
   121         (if (rowid) "" else " WITHOUT ROWID")
   122 
   123     def sql_drop(strict: Boolean): String =
   124       "DROP TABLE " + (if (strict) "" else "IF EXISTS ") + quote_ident(name)
   125 
   126     def sql_create_index(
   127         index_name: String, index_columns: List[Column],
   128         strict: Boolean, unique: Boolean): String =
   129       "CREATE " + (if (unique) "UNIQUE " else "") + "INDEX " +
   130         (if (strict) "" else "IF NOT EXISTS ") + quote_ident(index_name) + " ON " +
   131         quote_ident(name) + " " + enclosure(index_columns.map(_.name))
   132 
   133     def sql_drop_index(index_name: String, strict: Boolean): String =
   134       "DROP INDEX " + (if (strict) "" else "IF EXISTS ") + quote_ident(index_name)
   135 
   136     def sql_insert: String =
   137       "INSERT INTO " + quote_ident(name) + " VALUES " + enclosure(columns.map(_ => "?"))
   138 
   139     def sql_select(select_columns: List[Column], distinct: Boolean): String =
   140       "SELECT " + (if (distinct) "DISTINCT " else "") +
   141       commas(select_columns.map(_.sql_name)) + " FROM " + quote_ident(name)
   142 
   143     override def toString: String =
   144       "TABLE " + quote_ident(name) + " " + enclosure(columns.map(_.toString))
   145   }
   146 
   147 
   148 
   149   /** SQL database operations **/
   150 
   151   /* results */
   152 
   153   def iterator[A](rs: ResultSet)(get: ResultSet => A): Iterator[A] = new Iterator[A]
   154   {
   155     private var _next: Boolean = rs.next()
   156     def hasNext: Boolean = _next
   157     def next: A = { val x = get(rs); _next = rs.next(); x }
   158   }
   159 
   160   trait Database
   161   {
   162     /* types */
   163 
   164     def sql_type(T: Type.Value): String
   165 
   166 
   167     /* connection */
   168 
   169     def connection: Connection
   170 
   171     def close() { connection.close }
   172 
   173     def transaction[A](body: => A): A =
   174     {
   175       val auto_commit = connection.getAutoCommit
   176       try {
   177         connection.setAutoCommit(false)
   178         val savepoint = connection.setSavepoint
   179         try {
   180           val result = body
   181           connection.commit
   182           result
   183         }
   184         catch { case exn: Throwable => connection.rollback(savepoint); throw exn }
   185       }
   186       finally { connection.setAutoCommit(auto_commit) }
   187     }
   188 
   189 
   190     /* statements */
   191 
   192     def statement(sql: String): PreparedStatement = connection.prepareStatement(sql)
   193 
   194     def insert_statement(table: Table): PreparedStatement = statement(table.sql_insert)
   195 
   196     def select_statement(table: Table, columns: List[Column],
   197         sql: String = "", distinct: Boolean = false): PreparedStatement =
   198       statement(table.sql_select(columns, distinct) + (if (sql == "") "" else " " + sql))
   199 
   200 
   201     /* input */
   202 
   203     def set_bool(stmt: PreparedStatement, i: Int, x: Boolean) { stmt.setBoolean(i, x) }
   204     def set_int(stmt: PreparedStatement, i: Int, x: Int) { stmt.setInt(i, x) }
   205     def set_long(stmt: PreparedStatement, i: Int, x: Long) { stmt.setLong(i, x) }
   206     def set_double(stmt: PreparedStatement, i: Int, x: Double) { stmt.setDouble(i, x) }
   207     def set_string(stmt: PreparedStatement, i: Int, x: String) { stmt.setString(i, x) }
   208     def set_bytes(stmt: PreparedStatement, i: Int, bytes: Bytes)
   209     { stmt.setBinaryStream(i, bytes.stream(), bytes.length) }
   210     def set_date(stmt: PreparedStatement, i: Int, date: Date)
   211 
   212 
   213     /* output */
   214 
   215     def bool(rs: ResultSet, name: String): Boolean = rs.getBoolean(name)
   216     def int(rs: ResultSet, name: String): Int = rs.getInt(name)
   217     def long(rs: ResultSet, name: String): Long = rs.getLong(name)
   218     def double(rs: ResultSet, name: String): Double = rs.getDouble(name)
   219     def string(rs: ResultSet, name: String): String =
   220     {
   221       val s = rs.getString(name)
   222       if (s == null) "" else s
   223     }
   224     def bytes(rs: ResultSet, name: String): Bytes =
   225     {
   226       val bs = rs.getBytes(name)
   227       if (bs == null) Bytes.empty else Bytes(bs)
   228     }
   229     def date(rs: ResultSet, name: String): Date
   230 
   231     def get[A](rs: ResultSet, name: String, f: (ResultSet, String) => A): Option[A] =
   232     {
   233       val x = f(rs, name)
   234       if (rs.wasNull) None else Some(x)
   235     }
   236 
   237 
   238     /* tables */
   239 
   240     def tables: List[String] =
   241       iterator(connection.getMetaData.getTables(null, null, "%", null))(_.getString(3)).toList
   242 
   243     def create_table(table: Table, strict: Boolean = false, rowid: Boolean = true): Unit =
   244       using(statement(table.sql_create(strict, rowid, sql_type)))(_.execute())
   245 
   246     def drop_table(table: Table, strict: Boolean = false): Unit =
   247       using(statement(table.sql_drop(strict)))(_.execute())
   248 
   249     def create_index(table: Table, name: String, columns: List[Column],
   250         strict: Boolean = false, unique: Boolean = false): Unit =
   251       using(statement(table.sql_create_index(name, columns, strict, unique)))(_.execute())
   252 
   253     def drop_index(table: Table, name: String, strict: Boolean = false): Unit =
   254       using(statement(table.sql_drop_index(name, strict)))(_.execute())
   255   }
   256 }
   257 
   258 
   259 
   260 /** SQLite **/
   261 
   262 object SQLite
   263 {
   264   // see https://www.sqlite.org/lang_datefunc.html
   265   val date_format: Date.Format = Date.Format("uuuu-MM-dd HH:mm:ss.SSS x")
   266 
   267   lazy val init_jdbc: Unit = Class.forName("org.sqlite.JDBC")
   268 
   269   def open_database(path: Path): Database =
   270   {
   271     init_jdbc
   272     val path0 = path.expand
   273     val s0 = File.platform_path(path0)
   274     val s1 = if (Platform.is_windows) s0.replace('\\', '/') else s0
   275     val connection = DriverManager.getConnection("jdbc:sqlite:" + s1)
   276     new Database(path0.toString, connection)
   277   }
   278 
   279   class Database private[SQLite](name: String, val connection: Connection) extends SQL.Database
   280   {
   281     override def toString: String = name
   282 
   283     def sql_type(T: SQL.Type.Value): String = SQL.sql_type_sqlite(T)
   284 
   285     def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
   286       set_string(stmt, i, date_format(date))
   287     def date(rs: ResultSet, name: String): Date =
   288       date_format.parse(string(rs, name))
   289 
   290     def rebuild { using(statement("VACUUM"))(_.execute()) }
   291   }
   292 }
   293 
   294 
   295 
   296 /** PostgreSQL **/
   297 
   298 object PostgreSQL
   299 {
   300   val default_port = 5432
   301 
   302   lazy val init_jdbc: Unit = Class.forName("org.postgresql.Driver")
   303 
   304   def open_database(
   305     user: String,
   306     password: String,
   307     database: String = "",
   308     host: String = "",
   309     port: Int = default_port,
   310     ssh: Option[SSH.Session] = None): Database =
   311   {
   312     init_jdbc
   313 
   314     require(user != "")
   315 
   316     val db_host = if (host != "") host else "localhost"
   317     val db_port = if (port != default_port) ":" + port else ""
   318     val db_name = "/" + (if (database != "") database else user)
   319 
   320     val (url, name, port_forwarding) =
   321       ssh match {
   322         case None =>
   323           val spec = db_host + db_port + db_name
   324           val url = "jdbc:postgresql://" + spec
   325           val name = user + "@" + spec
   326           (url, name, None)
   327         case Some(ssh) =>
   328           val fw = ssh.port_forwarding(remote_host = db_host, remote_port = port)
   329           val url = "jdbc:postgresql://localhost:" + fw.local_port + db_name
   330           val name = user + "@" + fw + db_name + " via ssh " + ssh
   331           (url, name, Some(fw))
   332       }
   333     try {
   334       val connection = DriverManager.getConnection(url, user, password)
   335       new Database(name, connection, port_forwarding)
   336     }
   337     catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
   338   }
   339 
   340   class Database private[PostgreSQL](
   341       name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
   342     extends SQL.Database
   343   {
   344     override def toString: String = name
   345 
   346     def sql_type(T: SQL.Type.Value): String = SQL.sql_type_postgresql(T)
   347 
   348     // see https://jdbc.postgresql.org/documentation/head/8-date-time.html
   349     def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
   350       stmt.setObject(i, OffsetDateTime.from(date.to_utc.rep))
   351     def date(rs: ResultSet, name: String): Date =
   352       Date.instant(rs.getObject(name, classOf[OffsetDateTime]).toInstant)
   353 
   354     override def close() { super.close; port_forwarding.foreach(_.close) }
   355   }
   356 }