src/Pure/General/sql.scala
author wenzelm
Mon Mar 25 16:45:08 2019 +0100 (2 months ago)
changeset 69980 f2e3adfd916f
parent 69393 ed0824ef337e
permissions -rw-r--r--
tuned signature;
     1 /*  Title:      Pure/General/sql.scala
     2     Author:     Makarius
     3 
     4 Support for SQL databases: SQLite and PostgreSQL.
     5 */
     6 
     7 package isabelle
     8 
     9 
    10 import java.time.OffsetDateTime
    11 import java.sql.{DriverManager, Connection, PreparedStatement, ResultSet}
    12 
    13 import scala.collection.mutable
    14 
    15 
    16 object SQL
    17 {
    18   /** SQL language **/
    19 
    /* SQL source text is represented as a plain string. */
    type Source = String


    /* concrete syntax */

    /* Backslash escape for a single character: NUL, quotes, backspace,
       newline, carriage return, tab, Ctrl-Z and backslash are mapped to a
       two-character escape sequence, anything else is returned unchanged.
       NOTE: this style of escaping is not used by "string" below. */
    def escape_char(c: Char): String =
      c match {
        case '\u0000' => "\\0"
        case '\'' => "\\'"
        case '\"' => "\\\""
        case '\b' => "\\b"
        case '\n' => "\\n"
        case '\r' => "\\r"
        case '\t' => "\\t"
        case '\u001a' => "\\Z"
        case '\\' => "\\\\"
        case _ => c.toString
      }

    /* SQL string literal: the text is enclosed in single quotes and every
       embedded single quote is doubled; all other characters are passed
       through unchanged.  Backslash sequences are deliberately not produced:
       ordinary '...' literals of SQLite and of PostgreSQL (with its default
       standard_conforming_strings = on) treat backslash as a plain character,
       so an escaped quote would terminate the literal prematurely. */
    def string(s: String): Source =
      "'" + s.replace("'", "''") + "'"
    41 
    /* SQL identifier: the long name is taken apart via Long_Name.explode,
       every component gets its double quotes doubled and is wrapped by
       quote, then the components are put together again via
       Long_Name.implode. */
    def ident(s: String): Source =
    {
      val components =
        for (a <- Long_Name.explode(s)) yield quote(a.replace("\"", "\"\""))
      Long_Name.implode(components)
    }
    44 
    /* Wrap SQL source in parentheses. */
    def enclose(s: Source): Source = s"($s)"

    /* Parenthesized list of items, separated by ", ". */
    def enclosure(ss: Iterable[Source]): Source = "(" + ss.mkString(", ") + ")"
    47 
    /* Prefix of a SELECT statement up to and including " FROM ": an empty
       column list selects "*", otherwise the column identifiers are listed. */
    def select(columns: List[Column] = Nil, distinct: Boolean = false): Source =
    {
      val modifier = if (distinct) "DISTINCT " else ""
      val projection =
        columns match {
          case Nil => "*"
          case cols => commas(cols.map(_.ident))
        }
      "SELECT " + modifier + projection + " FROM "
    }
    51 
    /* Join keywords, padded with blanks on both sides. */
    val join_outer: Source = " LEFT OUTER JOIN "
    val join_inner: Source = " INNER JOIN "

    /* Select the join keyword: outer join if requested, inner join otherwise. */
    def join(outer: Boolean): Source =
      outer match {
        case true => join_outer
        case false => join_inner
      }
    55 
    /* Membership test as SQL condition: a parenthesized disjunction of
       "x = 'a'" for each element a of the set.  The empty set yields the
       always-false condition "(1 = 0)", instead of "()" which is not a valid
       SQL expression. */
    def member(x: Source, set: Iterable[String]): Source =
      if (set.isEmpty) "(1 = 0)"
      else set.iterator.map(a => x + " = " + SQL.string(a)).mkString("(", " OR ", ")")
    58 
    59 
    60   /* types */
    61 
    /* Abstract column types; the name of each value is its default SQL
       type name (see also the sql_type_ functions for database-specific
       deviations). */
    object Type extends Enumeration
    {
      val Boolean: Value = Value("BOOLEAN")
      val Int: Value = Value("INTEGER")
      val Long: Value = Value("BIGINT")
      val Double: Value = Value("DOUBLE PRECISION")
      val String: Value = Value("TEXT")
      val Bytes: Value = Value("BLOB")
      val Date: Value = Value("TIMESTAMP WITH TIME ZONE")
    }
    72 
    /* Default SQL type name: the name of the enumeration value. */
    def sql_type_default(T: Type.Value): Source = T.toString

    /* SQLite type name: Boolean is stored as INTEGER and Date as TEXT,
       everything else uses the default name. */
    def sql_type_sqlite(T: Type.Value): Source =
      T match {
        case Type.Boolean => "INTEGER"
        case Type.Date => "TEXT"
        case _ => sql_type_default(T)
      }

    /* PostgreSQL type name: Bytes is stored as BYTEA, everything else uses
       the default name. */
    def sql_type_postgresql(T: Type.Value): Source =
      T match {
        case Type.Bytes => "BYTEA"
        case _ => sql_type_default(T)
      }
    83 
    84 
    85   /* columns */
    86 
    /* Convenience constructors for columns, one per Type value. */
    object Column
    {
      /* common constructor behind the typed variants below */
      private def typed(
          T: Type.Value, name: String, strict: Boolean, primary_key: Boolean): Column =
        Column(name, T, strict = strict, primary_key = primary_key)

      def bool(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
        typed(Type.Boolean, name, strict, primary_key)
      def int(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
        typed(Type.Int, name, strict, primary_key)
      def long(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
        typed(Type.Long, name, strict, primary_key)
      def double(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
        typed(Type.Double, name, strict, primary_key)
      def string(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
        typed(Type.String, name, strict, primary_key)
      def bytes(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
        typed(Type.Bytes, name, strict, primary_key)
      def date(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
        typed(Type.Date, name, strict, primary_key)
    }
   104 
   /* Column specification: name, abstract type, NOT NULL flags (strict or
      primary_key) and an optional defining expression. */
   sealed case class Column(
     name: String, T: Type.Value, strict: Boolean = false, primary_key: Boolean = false,
     expr: SQL.Source = "")
   {
     /* the same column, marked as part of the primary key */
     def make_primary_key: Column = copy(primary_key = true)

     /* the column with its name qualified by the table name; the expression
        is not carried over */
     def apply(table: Table): Column =
     {
       val qualified_name = Long_Name.qualify(table.name, name)
       Column(qualified_name, T, strict = strict, primary_key = primary_key)
     }

     /* quoted identifier, or "(expr) AS identifier" for a defined column */
     def ident: Source =
     {
       val quoted = SQL.ident(name)
       if (expr == "") quoted else enclose(expr) + " AS " + quoted
     }

     /* column declaration, using the given database-specific type names */
     def decl(sql_type: Type.Value => Source): Source =
     {
       val constraint = if (strict || primary_key) " NOT NULL" else ""
       ident + " " + sql_type(T) + constraint
     }

     /* NULL tests */
     def defined: String = ident + " IS NOT NULL"
     def undefined: String = ident + " IS NULL"

     /* comparison against a string literal, optionally as WHERE clause */
     def equal(s: String): Source = ident + " = " + SQL.string(s)
     def where_equal(s: String): Source = "WHERE " + equal(s)

     override def toString: Source = ident
   }
   129 
   130 
   131   /* tables */
   132 
   /* Table specification: name, columns and an optional SQL body (the
      defining query of a view or sub-query).  Construction fails for
      duplicate column names. */
   sealed case class Table(name: String, columns: List[Column], body: Source = "")
   {
     /* column names must be unique within the table */
     Library.duplicates(columns.map(_.name)) match {
       case Nil =>
       case bad => error("Duplicate column names " + commas_quote(bad) + " for table " + quote(name))
     }

     /* quoted table identifier */
     def ident: Source = SQL.ident(name)

     /* the body in parentheses; error for a table without body */
     def query: Source =
       if (body == "") error("Missing SQL body for table " + quote(name))
       else SQL.enclose(body)

     def query_named: Source = query + " AS " + SQL.ident(name)

     /* CREATE TABLE statement: "IF NOT EXISTS" is omitted in strict mode.
        Primary key columns are referenced by quoted identifier, in agreement
        with the quoted names used in the column declarations. */
     def create(strict: Boolean = false, sql_type: Type.Value => Source): Source =
     {
       val primary_key =
         columns.filter(_.primary_key).map(c => SQL.ident(c.name)) match {
           case Nil => Nil
           case keys => List("PRIMARY KEY " + enclosure(keys))
         }
       "CREATE TABLE " + (if (strict) "" else "IF NOT EXISTS ") +
         ident + " " + enclosure(columns.map(_.decl(sql_type)) ::: primary_key)
     }

     /* CREATE INDEX statement: "IF NOT EXISTS" is omitted in strict mode.
        Index columns are referenced by quoted identifier, in agreement with
        the quoted names used in the column declarations. */
     def create_index(index_name: String, index_columns: List[Column],
         strict: Boolean = false, unique: Boolean = false): Source =
       "CREATE " + (if (unique) "UNIQUE " else "") + "INDEX " +
         (if (strict) "" else "IF NOT EXISTS ") + SQL.ident(index_name) + " ON " +
         ident + " " + enclosure(index_columns.map(c => SQL.ident(c.name)))

     /* insert statement with one "?" placeholder per column, for the given
        command keyword and optional trailing SQL */
     def insert_cmd(cmd: Source, sql: Source = ""): Source =
       cmd + " INTO " + ident + " VALUES " + enclosure(columns.map(_ => "?")) +
         (if (sql == "") "" else " " + sql)

     def insert(sql: Source = ""): Source = insert_cmd("INSERT", sql)

     /* DELETE statement with optional trailing SQL (e.g. a WHERE clause) */
     def delete(sql: Source = ""): Source =
       "DELETE FROM " + ident +
         (if (sql == "") "" else " " + sql)

     /* SELECT statement on this table with optional trailing SQL */
     def select(
         select_columns: List[Column] = Nil, sql: Source = "", distinct: Boolean = false): Source =
       SQL.select(select_columns, distinct = distinct) + ident +
         (if (sql == "") "" else " " + sql)

     override def toString: Source = ident
   }
   185 
   186 
   187 
   188   /** SQL database operations **/
   189 
   190   /* statements */
   191 
   /* Wrapper for a JDBC prepared statement of a given database.  The nested
      objects provide typed parameter assignment via update syntax, e.g.
      stmt.int(1) = 42; the Option variants assign SQL NULL for None. */
   class Statement private[SQL](val db: Database, val rep: PreparedStatement)
     extends AutoCloseable
   {
     stmt =>

     object bool
     {
       def update(i: Int, x: Boolean): Unit = rep.setBoolean(i, x)
       def update(i: Int, x: Option[Boolean]): Unit =
         x.fold(rep.setNull(i, java.sql.Types.BOOLEAN))(b => rep.setBoolean(i, b))
     }
     object int
     {
       def update(i: Int, x: Int): Unit = rep.setInt(i, x)
       def update(i: Int, x: Option[Int]): Unit =
         x.fold(rep.setNull(i, java.sql.Types.INTEGER))(n => rep.setInt(i, n))
     }
     object long
     {
       def update(i: Int, x: Long): Unit = rep.setLong(i, x)
       def update(i: Int, x: Option[Long]): Unit =
         x.fold(rep.setNull(i, java.sql.Types.BIGINT))(n => rep.setLong(i, n))
     }
     object double
     {
       def update(i: Int, x: Double): Unit = rep.setDouble(i, x)
       def update(i: Int, x: Option[Double]): Unit =
         x.fold(rep.setNull(i, java.sql.Types.DOUBLE))(d => rep.setDouble(i, d))
     }
     object string
     {
       def update(i: Int, x: String): Unit = rep.setString(i, x)
       def update(i: Int, x: Option[String]): Unit =
       {
         val value: String = x.orNull
         update(i, value)
       }
     }
     object bytes
     {
       /* null is passed on as null byte array, proper bytes as binary stream */
       def update(i: Int, bytes: Bytes): Unit =
         if (bytes != null) rep.setBinaryStream(i, bytes.stream(), bytes.length)
         else rep.setBytes(i, null)
       def update(i: Int, bytes: Option[Bytes]): Unit =
       {
         val value: Bytes = bytes.orNull
         update(i, value)
       }
     }
     object date
     {
       /* the representation of dates is delegated to the database */
       def update(i: Int, date: Date): Unit = db.update_date(stmt, i, date)
       def update(i: Int, date: Option[Date]): Unit =
       {
         val value: Date = date.orNull
         update(i, value)
       }
     }

     def execute(): Boolean = rep.execute()
     def execute_query(): Result = new Result(this, rep.executeQuery())

     def close(): Unit = rep.close()
   }
   258 
   259 
   260   /* results */
   261 
   /* Wrapper for a JDBC result set, with typed access to the current row by
      column name.  The plain accessors substitute a default for SQL NULL,
      the get_ variants return None instead. */
   class Result private[SQL](val stmt: Statement, val rep: ResultSet)
   {
     res =>

     def next(): Boolean = rep.next()

     /* iterate over the rows, applying get to each one; the first row is
        fetched when the iterator is created */
     def iterator[A](get: Result => A): Iterator[A] = new Iterator[A]
     {
       private var more: Boolean = res.next()
       def hasNext: Boolean = more
       def next: A =
       {
         val value = get(res)
         more = res.next()
         value
       }
     }

     def bool(column: Column): Boolean = rep.getBoolean(column.name)
     def int(column: Column): Int = rep.getInt(column.name)
     def long(column: Column): Long = rep.getLong(column.name)
     def double(column: Column): Double = rep.getDouble(column.name)

     /* null is turned into the empty string */
     def string(column: Column): String =
       rep.getString(column.name) match {
         case null => ""
         case s => s
       }

     /* null is turned into empty bytes */
     def bytes(column: Column): Bytes =
       rep.getBytes(column.name) match {
         case null => Bytes.empty
         case bs => Bytes(bs)
       }

     /* the representation of dates is delegated to the database */
     def date(column: Column): Date = stmt.db.date(res, column)

     /* timing from three columns, each read as long and passed to Time.ms */
     def timing(c1: Column, c2: Column, c3: Column) =
     {
       val t1 = Time.ms(long(c1))
       val t2 = Time.ms(long(c2))
       val t3 = Time.ms(long(c3))
       Timing(t1, t2, t3)
     }

     /* optional access: read the column via f first, then ask the result
        set whether that value was SQL NULL */
     def get[A](column: Column, f: Column => A): Option[A] =
     {
       val value = f(column)
       if (!rep.wasNull) Some(value) else None
     }
     def get_bool(column: Column): Option[Boolean] = get(column, bool(_))
     def get_int(column: Column): Option[Int] = get(column, int(_))
     def get_long(column: Column): Option[Long] = get(column, long(_))
     def get_double(column: Column): Option[Double] = get(column, double(_))
     def get_string(column: Column): Option[String] = get(column, string(_))
     def get_bytes(column: Column): Option[Bytes] = get(column, bytes(_))
     def get_date(column: Column): Option[Date] = get(column, date(_))
   }
   307 
   308 
   309   /* database */
   310 
   /* Common operations on a database connection; type names, date
      representation and permissive insert are provided by the concrete
      database implementation. */
   trait Database extends AutoCloseable
   {
     db =>


     /* types */

     def sql_type(T: Type.Value): Source


     /* connection */

     def connection: Connection

     def close() { connection.close }

     /* Run body with auto-commit disabled: commit on success; on any
        exception roll back to the savepoint taken before the body and
        re-throw.  The previous auto-commit mode is restored afterwards. */
     def transaction[A](body: => A): A =
     {
       val auto_commit = connection.getAutoCommit
       try {
         connection.setAutoCommit(false)
         val savepoint = connection.setSavepoint
         try {
           val result = body
           connection.commit
           result
         }
         catch { case exn: Throwable => connection.rollback(savepoint); throw exn }
       }
       finally { connection.setAutoCommit(auto_commit) }
     }


     /* statements and results */

     def statement(sql: Source): Statement =
       new Statement(db, connection.prepareStatement(sql))

     /* apply f to a fresh statement, which is passed through "using" */
     def using_statement[A](sql: Source)(f: Statement => A): A =
       using(statement(sql))(f)

     def update_date(stmt: Statement, i: Int, date: Date): Unit
     def date(res: Result, column: Column): Date

     def insert_permissive(table: Table, sql: Source = ""): Source


     /* tables and views */

     /* Names reported by the connection meta data for the pattern "%"
        (third column of each row).  The meta data result set is closed
        afterwards, even if reading it fails. */
     def tables: List[String] =
     {
       val result = new mutable.ListBuffer[String]
       val rs = connection.getMetaData.getTables(null, null, "%", null)
       try { while (rs.next) { result += rs.getString(3) } }
       finally { rs.close() }
       result.toList
     }

     def create_table(table: Table, strict: Boolean = false, sql: Source = ""): Unit =
       using_statement(
         table.create(strict, sql_type) + (if (sql == "") "" else " " + sql))(_.execute())

     def create_index(table: Table, name: String, columns: List[Column],
         strict: Boolean = false, unique: Boolean = false): Unit =
       using_statement(table.create_index(name, columns, strict, unique))(_.execute())

     /* Create a view from the table body: in non-strict mode only if no
        table of that name is reported by "tables".  The call of table.query
        merely checks that the body is present. */
     def create_view(table: Table, strict: Boolean = false): Unit =
     {
       if (strict || !tables.contains(table.name)) {
         val sql = "CREATE VIEW " + table + " AS " + { table.query; table.body }
         using_statement(sql)(_.execute())
       }
     }
   }
   384 }
   385 
   386 
   387 
   388 /** SQLite **/
   389 
   object SQLite
   {
     // see https://www.sqlite.org/lang_datefunc.html
     val date_format: Date.Format = Date.Format("uuuu-MM-dd HH:mm:ss.SSS x")

     /* One-time JDBC setup: the native library is expected as the only file
        in the platform directory of $ISABELLE_SQLITE_HOME; its location and
        name are passed to the driver via system properties. */
     lazy val init_jdbc: Unit =
     {
       val lib_path = Path.explode("$ISABELLE_SQLITE_HOME/" + Platform.jvm_platform)
       val lib_name =
         File.find_files(lib_path.file) match {
           case List(file) => file.getName
           case _ => error("Exactly one file expected in directory " + lib_path.expand)
         }
       System.setProperty("org.sqlite.lib.path", File.platform_path(lib_path))
       System.setProperty("org.sqlite.lib.name", lib_name)

       Class.forName("org.sqlite.JDBC")
     }

     /* Open the database file at the given path; on Windows the platform
        path is written with forward slashes for the JDBC URL. */
     def open_database(path: Path): Database =
     {
       init_jdbc
       val path0 = path.expand
       val s0 = File.platform_path(path0)
       val s1 = if (Platform.is_windows) s0.replace('\\', '/') else s0
       val connection = DriverManager.getConnection("jdbc:sqlite:" + s1)
       new Database(path0.toString, connection)
     }

     class Database private[SQLite](name: String, val connection: Connection) extends SQL.Database
     {
       override def toString: String = name

       def sql_type(T: SQL.Type.Value): SQL.Source = SQL.sql_type_sqlite(T)

       /* dates are stored as text in date_format; null becomes SQL NULL */
       def update_date(stmt: SQL.Statement, i: Int, date: Date): Unit =
         if (date == null) stmt.string(i) = (null: String)
         else stmt.string(i) = date_format(date)

       /* Read a date stored as text.  SQL NULL arrives here as empty string
          (see SQL.Result.string) and is returned as null, instead of being
          handed to the date parser; thus get_date can produce None. */
       def date(res: SQL.Result, column: SQL.Column): Date =
       {
         val s = res.string(column)
         if (s == "") null else date_format.parse(s)
       }

       def insert_permissive(table: SQL.Table, sql: SQL.Source = ""): SQL.Source =
         table.insert_cmd("INSERT OR IGNORE", sql = sql)

       def rebuild { using_statement("VACUUM")(_.execute()) }
     }
   }
   438 
   439 
   440 
   441 /** PostgreSQL **/
   442 
   object PostgreSQL
   {
     val default_port = 5432

     lazy val init_jdbc: Unit = Class.forName("org.postgresql.Driver")

     /* Open a PostgreSQL connection, either directly or through a port
        forwarding of the given SSH session.  Defaults: host "localhost",
        database name equal to the user name, port default_port.  An empty
        user is an error.  If the connection cannot be established, the port
        forwarding (if any) is closed before the exception is re-thrown. */
     def open_database(
       user: String,
       password: String,
       database: String = "",
       host: String = "",
       port: Int = 0,
       ssh: Option[SSH.Session] = None,
       ssh_close: Boolean = false): Database =
     {
       init_jdbc

       if (user == "") error("Undefined database user")

       val db_host = proper_string(host) getOrElse "localhost"
       // the port is only spelled out if it differs from the default
       val db_port = if (port > 0 && port != default_port) ":" + port else ""
       val db_name = "/" + (proper_string(database) getOrElse user)

       val (url, name, port_forwarding) =
         ssh match {
           case None =>
             val spec = db_host + db_port + db_name
             val url = "jdbc:postgresql://" + spec
             val name = user + "@" + spec
             (url, name, None)
           case Some(ssh) =>
             // connect to the local end of the forwarding instead of db_host
             val fw =
               ssh.port_forwarding(remote_host = db_host,
                 remote_port = if (port > 0) port else default_port,
                 ssh_close = ssh_close)
             val url = "jdbc:postgresql://localhost:" + fw.local_port + db_name
             val name = user + "@" + fw + db_name + " via ssh " + ssh
             (url, name, Some(fw))
         }
       try {
         val connection = DriverManager.getConnection(url, user, password)
         new Database(name, connection, port_forwarding)
       }
       catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
     }

     class Database private[PostgreSQL](
         name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
       extends SQL.Database
     {
       override def toString: String = name

       def sql_type(T: SQL.Type.Value): SQL.Source = SQL.sql_type_postgresql(T)

       // see https://jdbc.postgresql.org/documentation/head/8-date-time.html
       def update_date(stmt: SQL.Statement, i: Int, date: Date): Unit =
         if (date == null) stmt.rep.setObject(i, null)
         else stmt.rep.setObject(i, OffsetDateTime.from(date.to(Date.timezone_utc).rep))

       /* SQL NULL is returned as null */
       def date(res: SQL.Result, column: SQL.Column): Date =
       {
         val obj = res.rep.getObject(column.name, classOf[OffsetDateTime])
         if (obj == null) null else Date.instant(obj.toInstant)
       }

       def insert_permissive(table: SQL.Table, sql: SQL.Source = ""): SQL.Source =
         table.insert_cmd("INSERT",
           sql = sql + (if (sql == "") "" else " ") + "ON CONFLICT DO NOTHING")

       /* close the connection, then the port forwarding (if any); the
          latter happens even if closing the connection fails */
       override def close()
       {
         try { super.close }
         finally { port_forwarding.foreach(_.close) }
       }
     }
   }
   514 }