src/Pure/General/sql.scala
author wenzelm
Mon May 01 15:42:26 2017 +0200 (2017-05-01)
changeset 65668 366bc4e6a238
parent 65649 0818da4f67bb
child 65669 d2f19b4a16ae
permissions -rw-r--r--
more operations;
tuned;
     1 /*  Title:      Pure/General/sql.scala
     2     Author:     Makarius
     3 
     4 Support for SQL databases: SQLite and PostgreSQL.
     5 */
     6 
     7 package isabelle
     8 
     9 import java.time.OffsetDateTime
    10 import java.sql.{DriverManager, Connection, PreparedStatement, ResultSet}
    11 
    12 
    13 object SQL
    14 {
    15   /** SQL language **/
    16 
    17   /* concrete syntax */
    18 
    19   def escape_char(c: Char): String =
    20     c match {
    21       case '\u0000' => "\\0"
    22       case '\'' => "\\'"
    23       case '\"' => "\\\""
    24       case '\b' => "\\b"
    25       case '\n' => "\\n"
    26       case '\r' => "\\r"
    27       case '\t' => "\\t"
    28       case '\u001a' => "\\Z"
    29       case '\\' => "\\\\"
    30       case _ => c.toString
    31     }
    32 
    33   def string(s: String): String =
    34     "'" + s.map(escape_char(_)).mkString + "'"
    35 
    36   def ident(s: String): String =
    37     if (Long_Name.is_qualified(s)) s
    38     else quote(s.replace("\"", "\"\""))
    39 
    40   def enclose(s: String): String = "(" + s + ")"
    41   def enclosure(ss: Iterable[String]): String = ss.mkString("(", ", ", ")")
    42 
    43   def select(columns: List[Column], distinct: Boolean = false): String =
    44     "SELECT " + (if (distinct) "DISTINCT " else "") + commas(columns.map(_.sql)) + " FROM "
    45 
    46   def join(table1: Table, table2: Table, sql: String = "", outer: Boolean = false): String =
    47     table1.sql + (if (outer) " LEFT OUTER JOIN " else " INNER JOIN ") + table2.sql +
    48       (if (sql == "") "" else " ON " + sql)
    49 
    50   def join_outer(table1: Table, table2: Table, sql: String = ""): String =
    51     join(table1, table2, sql, outer = true)
    52 
    53 
    54   /* types */
    55 
    56   object Type extends Enumeration
    57   {
    58     val Boolean = Value("BOOLEAN")
    59     val Int = Value("INTEGER")
    60     val Long = Value("BIGINT")
    61     val Double = Value("DOUBLE PRECISION")
    62     val String = Value("TEXT")
    63     val Bytes = Value("BLOB")
    64     val Date = Value("TIMESTAMP WITH TIME ZONE")
    65   }
    66 
    67   def sql_type_default(T: Type.Value): String = T.toString
    68 
    69   def sql_type_sqlite(T: Type.Value): String =
    70     if (T == Type.Boolean) "INTEGER"
    71     else if (T == Type.Date) "TEXT"
    72     else sql_type_default(T)
    73 
    74   def sql_type_postgresql(T: Type.Value): String =
    75     if (T == Type.Bytes) "BYTEA"
    76     else sql_type_default(T)
    77 
    78 
    79   /* columns */
    80 
    81   object Column
    82   {
    83     def bool(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    84       Column(name, Type.Boolean, strict, primary_key)
    85     def int(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    86       Column(name, Type.Int, strict, primary_key)
    87     def long(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    88       Column(name, Type.Long, strict, primary_key)
    89     def double(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    90       Column(name, Type.Double, strict, primary_key)
    91     def string(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    92       Column(name, Type.String, strict, primary_key)
    93     def bytes(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    94       Column(name, Type.Bytes, strict, primary_key)
    95     def date(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
    96       Column(name, Type.Date, strict, primary_key)
    97   }
    98 
    99   sealed case class Column(
   100     name: String, T: Type.Value, strict: Boolean = false, primary_key: Boolean = false)
   101   {
   102     def apply(table: Table): Column =
   103       Column(Long_Name.qualify(table.name, name), T, strict = strict, primary_key = primary_key)
   104 
   105     def sql: String = ident(name)
   106     def sql_decl(sql_type: Type.Value => String): String =
   107       sql + " " + sql_type(T) + (if (strict || primary_key) " NOT NULL" else "")
   108 
   109     def sql_where_eq: String = "WHERE " + sql + " = "
   110     def sql_where_equal(s: String): String = sql_where_eq + string(s)
   111 
   112     override def toString: String = sql_decl(sql_type_default)
   113   }
   114 
   115 
   116   /* tables */
   117 
   118   sealed case class Table(name: String, columns: List[Column])
   119   {
   120     private val columns_index: Map[String, Int] =
   121       columns.iterator.map(_.name).zipWithIndex.toMap
   122 
   123     Library.duplicates(columns.map(_.name)) match {
   124       case Nil =>
   125       case bad => error("Duplicate column names " + commas_quote(bad) + " for table " + quote(name))
   126     }
   127 
   128     def sql: String = ident(name)
   129 
   130     def sql_columns(sql_type: Type.Value => String): String =
   131     {
   132       val primary_key =
   133         columns.filter(_.primary_key).map(_.name) match {
   134           case Nil => Nil
   135           case keys => List("PRIMARY KEY " + enclosure(keys))
   136         }
   137       enclosure(columns.map(_.sql_decl(sql_type)) ::: primary_key)
   138     }
   139 
   140     def sql_create(strict: Boolean, sql_type: Type.Value => String): String =
   141       "CREATE TABLE " + (if (strict) "" else "IF NOT EXISTS ") +
   142         ident(name) + " " + sql_columns(sql_type)
   143 
   144     def sql_drop(strict: Boolean): String =
   145       "DROP TABLE " + (if (strict) "" else "IF EXISTS ") + ident(name)
   146 
   147     def sql_create_index(
   148         index_name: String, index_columns: List[Column],
   149         strict: Boolean, unique: Boolean): String =
   150       "CREATE " + (if (unique) "UNIQUE " else "") + "INDEX " +
   151         (if (strict) "" else "IF NOT EXISTS ") + ident(index_name) + " ON " +
   152         ident(name) + " " + enclosure(index_columns.map(_.name))
   153 
   154     def sql_drop_index(index_name: String, strict: Boolean): String =
   155       "DROP INDEX " + (if (strict) "" else "IF EXISTS ") + ident(index_name)
   156 
   157     def sql_insert: String =
   158       "INSERT INTO " + ident(name) + " VALUES " + enclosure(columns.map(_ => "?"))
   159 
   160     def sql_delete: String =
   161       "DELETE FROM " + ident(name)
   162 
   163     def sql_select(select_columns: List[Column], distinct: Boolean = false): String =
   164       select(select_columns, distinct = distinct) + ident(name)
   165 
   166     override def toString: String =
   167       "TABLE " + ident(name) + " " + sql_columns(sql_type_default)
   168   }
   169 
   170 
   171 
   172   /** SQL database operations **/
   173 
   174   /* results */
   175 
   176   def iterator[A](rs: ResultSet)(get: ResultSet => A): Iterator[A] = new Iterator[A]
   177   {
   178     private var _next: Boolean = rs.next()
   179     def hasNext: Boolean = _next
   180     def next: A = { val x = get(rs); _next = rs.next(); x }
   181   }
   182 
   183   trait Database
   184   {
   185     /* types */
   186 
   187     def sql_type(T: Type.Value): String
   188 
   189 
   190     /* connection */
   191 
   192     def connection: Connection
   193 
   194     def close() { connection.close }
   195 
   196     def transaction[A](body: => A): A =
   197     {
   198       val auto_commit = connection.getAutoCommit
   199       try {
   200         connection.setAutoCommit(false)
   201         val savepoint = connection.setSavepoint
   202         try {
   203           val result = body
   204           connection.commit
   205           result
   206         }
   207         catch { case exn: Throwable => connection.rollback(savepoint); throw exn }
   208       }
   209       finally { connection.setAutoCommit(auto_commit) }
   210     }
   211 
   212 
   213     /* statements */
   214 
   215     def statement(sql: String): PreparedStatement = connection.prepareStatement(sql)
   216 
   217     def insert(table: Table): PreparedStatement = statement(table.sql_insert)
   218 
   219     def delete(table: Table, sql: String = ""): PreparedStatement =
   220       statement(table.sql_delete + (if (sql == "") "" else " " + sql))
   221 
   222     def select(table: Table, columns: List[Column], sql: String = "", distinct: Boolean = false)
   223         : PreparedStatement =
   224       statement(table.sql_select(columns, distinct = distinct) + (if (sql == "") "" else " " + sql))
   225 
   226 
   227     /* input */
   228 
   229     def set_bool(stmt: PreparedStatement, i: Int, x: Boolean) { stmt.setBoolean(i, x) }
   230     def set_bool(stmt: PreparedStatement, i: Int, x: Option[Boolean])
   231     {
   232       if (x.isDefined) set_bool(stmt, i, x.get)
   233       else stmt.setNull(i, java.sql.Types.BOOLEAN)
   234     }
   235 
   236     def set_int(stmt: PreparedStatement, i: Int, x: Int) { stmt.setInt(i, x) }
   237     def set_int(stmt: PreparedStatement, i: Int, x: Option[Int])
   238     {
   239       if (x.isDefined) set_int(stmt, i, x.get)
   240       else stmt.setNull(i, java.sql.Types.INTEGER)
   241     }
   242 
   243     def set_long(stmt: PreparedStatement, i: Int, x: Long) { stmt.setLong(i, x) }
   244     def set_long(stmt: PreparedStatement, i: Int, x: Option[Long])
   245     {
   246       if (x.isDefined) set_long(stmt, i, x.get)
   247       else stmt.setNull(i, java.sql.Types.BIGINT)
   248     }
   249 
   250     def set_double(stmt: PreparedStatement, i: Int, x: Double) { stmt.setDouble(i, x) }
   251     def set_double(stmt: PreparedStatement, i: Int, x: Option[Double])
   252     {
   253       if (x.isDefined) set_double(stmt, i, x.get)
   254       else stmt.setNull(i, java.sql.Types.DOUBLE)
   255     }
   256 
   257     def set_string(stmt: PreparedStatement, i: Int, x: String) { stmt.setString(i, x) }
   258     def set_string(stmt: PreparedStatement, i: Int, x: Option[String]): Unit =
   259       set_string(stmt, i, x.orNull)
   260 
   261     def set_bytes(stmt: PreparedStatement, i: Int, bytes: Bytes)
   262     {
   263       if (bytes == null) stmt.setBytes(i, null)
   264       else stmt.setBinaryStream(i, bytes.stream(), bytes.length)
   265     }
   266     def set_bytes(stmt: PreparedStatement, i: Int, bytes: Option[Bytes]): Unit =
   267       set_bytes(stmt, i, bytes.orNull)
   268 
   269     def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit
   270     def set_date(stmt: PreparedStatement, i: Int, date: Option[Date]): Unit =
   271       set_date(stmt, i, date.orNull)
   272 
   273 
   274     /* output */
   275 
   276     def bool(rs: ResultSet, column: Column): Boolean = rs.getBoolean(column.name)
   277     def int(rs: ResultSet, column: Column): Int = rs.getInt(column.name)
   278     def long(rs: ResultSet, column: Column): Long = rs.getLong(column.name)
   279     def double(rs: ResultSet, column: Column): Double = rs.getDouble(column.name)
   280     def string(rs: ResultSet, column: Column): String =
   281     {
   282       val s = rs.getString(column.name)
   283       if (s == null) "" else s
   284     }
   285     def bytes(rs: ResultSet, column: Column): Bytes =
   286     {
   287       val bs = rs.getBytes(column.name)
   288       if (bs == null) Bytes.empty else Bytes(bs)
   289     }
   290     def date(rs: ResultSet, column: Column): Date
   291 
   292     def get[A](rs: ResultSet, column: Column, f: (ResultSet, Column) => A): Option[A] =
   293     {
   294       val x = f(rs, column)
   295       if (rs.wasNull) None else Some(x)
   296     }
   297 
   298 
   299     /* tables */
   300 
   301     def tables: List[String] =
   302       iterator(connection.getMetaData.getTables(null, null, "%", null))(_.getString(3)).toList
   303 
   304     def create_table(table: Table, strict: Boolean = false, sql: String = ""): Unit =
   305       using(statement(table.sql_create(strict, sql_type) + (if (sql == "") "" else " " + sql)))(
   306         _.execute())
   307 
   308     def drop_table(table: Table, strict: Boolean = false): Unit =
   309       using(statement(table.sql_drop(strict)))(_.execute())
   310 
   311     def create_index(table: Table, name: String, columns: List[Column],
   312         strict: Boolean = false, unique: Boolean = false): Unit =
   313       using(statement(table.sql_create_index(name, columns, strict, unique)))(_.execute())
   314 
   315     def drop_index(table: Table, name: String, strict: Boolean = false): Unit =
   316       using(statement(table.sql_drop_index(name, strict)))(_.execute())
   317   }
   318 }
   319 
   320 
   321 
   322 /** SQLite **/
   323 
   324 object SQLite
   325 {
   326   // see https://www.sqlite.org/lang_datefunc.html
   327   val date_format: Date.Format = Date.Format("uuuu-MM-dd HH:mm:ss.SSS x")
   328 
   329   lazy val init_jdbc: Unit = Class.forName("org.sqlite.JDBC")
   330 
   331   def open_database(path: Path): Database =
   332   {
   333     init_jdbc
   334     val path0 = path.expand
   335     val s0 = File.platform_path(path0)
   336     val s1 = if (Platform.is_windows) s0.replace('\\', '/') else s0
   337     val connection = DriverManager.getConnection("jdbc:sqlite:" + s1)
   338     new Database(path0.toString, connection)
   339   }
   340 
   341   class Database private[SQLite](name: String, val connection: Connection) extends SQL.Database
   342   {
   343     override def toString: String = name
   344 
   345     def sql_type(T: SQL.Type.Value): String = SQL.sql_type_sqlite(T)
   346 
   347     def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
   348       if (date == null) set_string(stmt, i, null: String)
   349       else set_string(stmt, i, date_format(date))
   350 
   351     def date(rs: ResultSet, column: SQL.Column): Date =
   352       date_format.parse(string(rs, column))
   353 
   354     def rebuild { using(statement("VACUUM"))(_.execute()) }
   355   }
   356 }
   357 
   358 
   359 
   360 /** PostgreSQL **/
   361 
   362 object PostgreSQL
   363 {
   364   val default_port = 5432
   365 
   366   lazy val init_jdbc: Unit = Class.forName("org.postgresql.Driver")
   367 
   368   def open_database(
   369     user: String,
   370     password: String,
   371     database: String = "",
   372     host: String = "",
   373     port: Int = 0,
   374     ssh: Option[SSH.Session] = None,
   375     ssh_close: Boolean = false): Database =
   376   {
   377     init_jdbc
   378 
   379     if (user == "") error("Undefined database user")
   380 
   381     val db_host = if (host != "") host else "localhost"
   382     val db_port = if (port > 0 && port != default_port) ":" + port else ""
   383     val db_name = "/" + (if (database != "") database else user)
   384 
   385     val (url, name, port_forwarding) =
   386       ssh match {
   387         case None =>
   388           val spec = db_host + db_port + db_name
   389           val url = "jdbc:postgresql://" + spec
   390           val name = user + "@" + spec
   391           (url, name, None)
   392         case Some(ssh) =>
   393           val fw =
   394             ssh.port_forwarding(remote_host = db_host,
   395               remote_port = if (port > 0) port else default_port,
   396               ssh_close = ssh_close)
   397           val url = "jdbc:postgresql://localhost:" + fw.local_port + db_name
   398           val name = user + "@" + fw + db_name + " via ssh " + ssh
   399           (url, name, Some(fw))
   400       }
   401     try {
   402       val connection = DriverManager.getConnection(url, user, password)
   403       new Database(name, connection, port_forwarding)
   404     }
   405     catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
   406   }
   407 
   408   class Database private[PostgreSQL](
   409       name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
   410     extends SQL.Database
   411   {
   412     override def toString: String = name
   413 
   414     def sql_type(T: SQL.Type.Value): String = SQL.sql_type_postgresql(T)
   415 
   416     // see https://jdbc.postgresql.org/documentation/head/8-date-time.html
   417     def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
   418       if (date == null) stmt.setObject(i, null)
   419       else stmt.setObject(i, OffsetDateTime.from(date.to_utc.rep))
   420 
   421     def date(rs: ResultSet, column: SQL.Column): Date =
   422       Date.instant(rs.getObject(column.name, classOf[OffsetDateTime]).toInstant)
   423 
   424     override def close() { super.close; port_forwarding.foreach(_.close) }
   425   }
   426 }