/*  Title:      Pure/General/sql.scala
    Author:     Makarius

Support for SQL databases: SQLite and PostgreSQL.
*/

package isabelle


import java.time.OffsetDateTime
import java.sql.{DriverManager, Connection, PreparedStatement, ResultSet}


object SQL
{
  /** SQL language **/

  /* concrete syntax */

  /* Backslash escape for a single character (MySQL-style escape sequences,
     e.g. \Z for U+001A).  SQLite and PostgreSQL do not interpret these
     escapes inside '...' literals, so quote_string does not use this;
     retained unchanged for existing callers. */
  def escape_char(c: Char): String =
    c match {
      case '\u0000' => "\\0"
      case '\'' => "\\'"
      case '\"' => "\\\""
      case '\b' => "\\b"
      case '\n' => "\\n"
      case '\r' => "\\r"
      case '\t' => "\\t"
      case '\u001a' => "\\Z"
      case '\\' => "\\\\"
      case _ => c.toString
    }

  /* String literal in standard SQL syntax, as accepted by SQLite and
     PostgreSQL (standard_conforming_strings): the quote character is
     doubled and all other characters, including backslash, are taken
     literally.  NUL cannot be represented in a text literal and is
     rejected with an error instead of being silently mangled. */
  def quote_string(s: String): String =
  {
    val q = '\''
    val result = new StringBuilder(s.length + 10)
    result += q
    for (c <- s.iterator) {
      if (c == '\u0000') error("Illegal NUL character in SQL string literal")
      if (c == q) result += q
      result += c
    }
    result += q
    result.toString
  }

  /* delimited identifier: double quotes, with embedded double quotes doubled */
  def quote_ident(s: String): String =
    quote(s.replace("\"", "\"\""))

  /* parenthesized, comma-separated list */
  def enclosure(ss: Iterable[String]): String = ss.mkString("(", ", ", ")")


  /* types */

  object Type extends Enumeration
  {
    val Boolean = Value("BOOLEAN")
    val Int = Value("INTEGER")
    val Long = Value("BIGINT")
    val Double = Value("DOUBLE PRECISION")
    val String = Value("TEXT")
    val Bytes = Value("BLOB")
    val Date = Value("TIMESTAMP WITH TIME ZONE")
  }

  def sql_type_default(T: Type.Value): String = T.toString

  /* SQLite has no native BOOLEAN or timestamp type: store them as INTEGER / TEXT */
  def sql_type_sqlite(T: Type.Value): String =
    if (T == Type.Boolean) "INTEGER"
    else if (T == Type.Date) "TEXT"
    else sql_type_default(T)

  /* PostgreSQL names its binary type BYTEA */
  def sql_type_postgresql(T: Type.Value): String =
    if (T == Type.Bytes) "BYTEA"
    else sql_type_default(T)


  /* columns */

  object Column
  {
    def bool(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Boolean, strict, primary_key)
    def int(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Int, strict, primary_key)
    def long(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Long, strict, primary_key)
    def double(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Double, strict, primary_key)
    def string(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.String, strict, primary_key)
    def bytes(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Bytes, strict, primary_key)
    def date(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Date, strict, primary_key)
  }

  sealed case class Column(
    name: String, T: Type.Value, strict: Boolean = false, primary_key: Boolean = false)
  {
    def sql_name: String = quote_ident(name)

    /* column declaration; strict and primary-key columns are NOT NULL */
    def sql_decl(sql_type: Type.Value => String): String =
      sql_name + " " + sql_type(T) + (if (strict || primary_key) " NOT NULL" else "")

    def sql_where_eq: String = "WHERE " + sql_name + " = "

    /* WHERE clause comparing this column against a string literal */
    def sql_where_equal(s: String): String = sql_where_eq + quote_string(s)

    override def toString: String = sql_decl(sql_type_default)
  }


  /* tables */

  sealed case class Table(name: String, columns: List[Column])
  {
    Library.duplicates(columns.map(_.name)) match {
      case Nil =>
      case bad => error("Duplicate column names " + commas_quote(bad) + " for table " + quote(name))
    }

    /* column declarations followed by an optional PRIMARY KEY constraint;
       key columns are referenced by the same quoted identifiers that are
       used in their declarations */
    def sql_columns(sql_type: Type.Value => String): String =
    {
      val primary_key =
        columns.filter(_.primary_key).map(_.sql_name) match {
          case Nil => Nil
          case keys => List("PRIMARY KEY " + enclosure(keys))
        }
      enclosure(columns.map(_.sql_decl(sql_type)) ::: primary_key)
    }

    def sql_create(strict: Boolean, sql_type: Type.Value => String): String =
      "CREATE TABLE " + (if (strict) "" else "IF NOT EXISTS ") +
        quote_ident(name) + " " + sql_columns(sql_type)

    def sql_drop(strict: Boolean): String =
      "DROP TABLE " + (if (strict) "" else "IF EXISTS ") + quote_ident(name)

    /* index columns are quoted like their declarations in sql_columns */
    def sql_create_index(
        index_name: String, index_columns: List[Column],
        strict: Boolean, unique: Boolean): String =
      "CREATE " + (if (unique) "UNIQUE " else "") + "INDEX " +
        (if (strict) "" else "IF NOT EXISTS ") + quote_ident(index_name) + " ON " +
        quote_ident(name) + " " + enclosure(index_columns.map(_.sql_name))

    def sql_drop_index(index_name: String, strict: Boolean): String =
      "DROP INDEX " + (if (strict) "" else "IF EXISTS ") + quote_ident(index_name)

    /* INSERT with one "?" placeholder per column, in declaration order */
    def sql_insert: String =
      "INSERT INTO " + quote_ident(name) + " VALUES " + enclosure(columns.map(_ => "?"))

    def sql_delete: String =
      "DELETE FROM " + quote_ident(name)

    def sql_select(select_columns: List[Column], distinct: Boolean): String =
      "SELECT " + (if (distinct) "DISTINCT " else "") +
      commas(select_columns.map(_.sql_name)) + " FROM " + quote_ident(name)

    override def toString: String =
      "TABLE " + quote_ident(name) + " " + sql_columns(sql_type_default)
  }



  /** SQL database operations **/

  /* results */

  /* lazy iterator over the rows of rs; get reads the current row */
  def iterator[A](rs: ResultSet)(get: ResultSet => A): Iterator[A] = new Iterator[A]
  {
    private var _next: Boolean = rs.next()
    def hasNext: Boolean = _next
    def next: A = { val x = get(rs); _next = rs.next(); x }
  }

  trait Database
  {
    /* types */

    def sql_type(T: Type.Value): String


    /* connection */

    def connection: Connection

    def close() { connection.close }

    /* run body with auto-commit disabled: commit on success, roll back to
       the initial savepoint on any exception; the previous auto-commit
       mode is restored in all cases */
    def transaction[A](body: => A): A =
    {
      val auto_commit = connection.getAutoCommit
      try {
        connection.setAutoCommit(false)
        val savepoint = connection.setSavepoint
        try {
          val result = body
          connection.commit
          result
        }
        catch { case exn: Throwable => connection.rollback(savepoint); throw exn }
      }
      finally { connection.setAutoCommit(auto_commit) }
    }


    /* statements */

    def statement(sql: String): PreparedStatement = connection.prepareStatement(sql)

    def insert(table: Table): PreparedStatement = statement(table.sql_insert)

    def delete(table: Table, sql: String = ""): PreparedStatement =
      statement(table.sql_delete + (if (sql == "") "" else " " + sql))

    def select(table: Table, columns: List[Column], sql: String = "", distinct: Boolean = false)
        : PreparedStatement =
      statement(table.sql_select(columns, distinct) + (if (sql == "") "" else " " + sql))


    /* input: the Option variants bind SQL NULL for None */

    def set_bool(stmt: PreparedStatement, i: Int, x: Boolean) { stmt.setBoolean(i, x) }
    def set_bool(stmt: PreparedStatement, i: Int, x: Option[Boolean])
    {
      if (x.isDefined) set_bool(stmt, i, x.get)
      else stmt.setNull(i, java.sql.Types.BOOLEAN)
    }

    def set_int(stmt: PreparedStatement, i: Int, x: Int) { stmt.setInt(i, x) }
    def set_int(stmt: PreparedStatement, i: Int, x: Option[Int])
    {
      if (x.isDefined) set_int(stmt, i, x.get)
      else stmt.setNull(i, java.sql.Types.INTEGER)
    }

    def set_long(stmt: PreparedStatement, i: Int, x: Long) { stmt.setLong(i, x) }
    def set_long(stmt: PreparedStatement, i: Int, x: Option[Long])
    {
      if (x.isDefined) set_long(stmt, i, x.get)
      else stmt.setNull(i, java.sql.Types.BIGINT)
    }

    def set_double(stmt: PreparedStatement, i: Int, x: Double) { stmt.setDouble(i, x) }
    def set_double(stmt: PreparedStatement, i: Int, x: Option[Double])
    {
      if (x.isDefined) set_double(stmt, i, x.get)
      else stmt.setNull(i, java.sql.Types.DOUBLE)
    }

    def set_string(stmt: PreparedStatement, i: Int, x: String) { stmt.setString(i, x) }
    def set_string(stmt: PreparedStatement, i: Int, x: Option[String]): Unit =
      set_string(stmt, i, x.orNull)

    def set_bytes(stmt: PreparedStatement, i: Int, bytes: Bytes)
    {
      if (bytes == null) stmt.setBytes(i, null)
      else stmt.setBinaryStream(i, bytes.stream(), bytes.length)
    }
    def set_bytes(stmt: PreparedStatement, i: Int, bytes: Option[Bytes]): Unit =
      set_bytes(stmt, i, bytes.orNull)

    def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit
    def set_date(stmt: PreparedStatement, i: Int, date: Option[Date]): Unit =
      set_date(stmt, i, date.orNull)


    /* output: string/bytes map SQL NULL to an empty value; use get for Option */

    def bool(rs: ResultSet, column: Column): Boolean = rs.getBoolean(column.name)
    def int(rs: ResultSet, column: Column): Int = rs.getInt(column.name)
    def long(rs: ResultSet, column: Column): Long = rs.getLong(column.name)
    def double(rs: ResultSet, column: Column): Double = rs.getDouble(column.name)
    def string(rs: ResultSet, column: Column): String =
    {
      val s = rs.getString(column.name)
      if (s == null) "" else s
    }
    def bytes(rs: ResultSet, column: Column): Bytes =
    {
      val bs = rs.getBytes(column.name)
      if (bs == null) Bytes.empty else Bytes(bs)
    }
    def date(rs: ResultSet, column: Column): Date

    /* read a column via f and report SQL NULL (per rs.wasNull) as None */
    def get[A](rs: ResultSet, column: Column, f: (ResultSet, Column) => A): Option[A] =
    {
      val x = f(rs, column)
      if (rs.wasNull) None else Some(x)
    }


    /* tables */

    /* table names from the JDBC metadata; the ResultSet is closed afterwards */
    def tables: List[String] =
      using(connection.getMetaData.getTables(null, null, "%", null))(rs =>
        iterator(rs)(_.getString(3)).toList)

    def create_table(table: Table, strict: Boolean = false, sql: String = ""): Unit =
      using(statement(table.sql_create(strict, sql_type) + (if (sql == "") "" else " " + sql)))(
        _.execute())

    def drop_table(table: Table, strict: Boolean = false): Unit =
      using(statement(table.sql_drop(strict)))(_.execute())

    def create_index(table: Table, name: String, columns: List[Column],
        strict: Boolean = false, unique: Boolean = false): Unit =
      using(statement(table.sql_create_index(name, columns, strict, unique)))(_.execute())

    def drop_index(table: Table, name: String, strict: Boolean = false): Unit =
      using(statement(table.sql_drop_index(name, strict)))(_.execute())
  }
}



/** SQLite **/

object SQLite
{
  // see https://www.sqlite.org/lang_datefunc.html
  val date_format: Date.Format = Date.Format("uuuu-MM-dd HH:mm:ss.SSS x")

  lazy val init_jdbc: Unit = Class.forName("org.sqlite.JDBC")

  def open_database(path: Path): Database =
  {
    init_jdbc
    val path0 = path.expand
    val s0 = File.platform_path(path0)
    val s1 = if (Platform.is_windows) s0.replace('\\', '/') else s0
    val connection = DriverManager.getConnection("jdbc:sqlite:" + s1)
    new Database(path0.toString, connection)
  }

  class Database private[SQLite](name: String, val connection: Connection) extends SQL.Database
  {
    override def toString: String = name

    def sql_type(T: SQL.Type.Value): String = SQL.sql_type_sqlite(T)

    /* dates are stored as TEXT in date_format; null is stored as SQL NULL */
    def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
      if (date == null) set_string(stmt, i, null: String)
      else set_string(stmt, i, date_format(date))

    /* SQL NULL (reported by string as "") yields null rather than a parse
       error, so that get(rs, column, date) can return None */
    def date(rs: ResultSet, column: SQL.Column): Date =
    {
      val s = string(rs, column)
      if (s == "") null else date_format.parse(s)
    }

    def rebuild { using(statement("VACUUM"))(_.execute()) }
  }
}



/** PostgreSQL **/

object PostgreSQL
{
  val default_port = 5432

  lazy val init_jdbc: Unit = Class.forName("org.postgresql.Driver")

  def open_database(
    user: String,
    password: String,
    database: String = "",
    host: String = "",
    port: Int = 0,
    ssh: Option[SSH.Session] = None,
    ssh_close: Boolean = false): Database =
  {
    init_jdbc

    if (user == "") error("Undefined database user")

    val db_host = if (host != "") host else "localhost"
    val db_port = if (port > 0 && port != default_port) ":" + port else ""
    val db_name = "/" + (if (database != "") database else user)

    val (url, name, port_forwarding) =
      ssh match {
        case None =>
          val spec = db_host + db_port + db_name
          val url = "jdbc:postgresql://" + spec
          val name = user + "@" + spec
          (url, name, None)
        case Some(ssh) =>
          val fw =
            ssh.port_forwarding(remote_host = db_host,
              remote_port = if (port > 0) port else default_port,
              ssh_close = ssh_close)
          val url = "jdbc:postgresql://localhost:" + fw.local_port + db_name
          val name = user + "@" + fw + db_name + " via ssh " + ssh
          (url, name, Some(fw))
      }
    try {
      val connection = DriverManager.getConnection(url, user, password)
      new Database(name, connection, port_forwarding)
    }
    catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
  }

  class Database private[PostgreSQL](
      name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
    extends SQL.Database
  {
    override def toString: String = name

    def sql_type(T: SQL.Type.Value): String = SQL.sql_type_postgresql(T)

    // see https://jdbc.postgresql.org/documentation/head/8-date-time.html
    def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
      if (date == null) stmt.setObject(i, null)
      else stmt.setObject(i, OffsetDateTime.from(date.to_utc.rep))

    /* SQL NULL yields null instead of a NullPointerException, so that
       get(rs, column, date) can return None */
    def date(rs: ResultSet, column: SQL.Column): Date =
    {
      val obj = rs.getObject(column.name, classOf[OffsetDateTime])
      if (obj == null) null else Date.instant(obj.toInstant)
    }

    /* the port forwarding is closed even if closing the connection fails */
    override def close() { try { super.close } finally { port_forwarding.foreach(_.close) } }
  }
}