/*  Title:      Pure/General/sql.scala
    Author:     Makarius

Support for SQL databases: SQLite and PostgreSQL.
*/

package isabelle


import java.time.OffsetDateTime
import java.sql.{DriverManager, Connection, PreparedStatement, ResultSet}


object SQL
{
  /** SQL language **/

  /* concrete syntax */

  /** Backslash escape for a single character, following the MySQL-style
    * convention (e.g. newline becomes backslash-n, single quote becomes
    * backslash-quote); all other characters are returned unchanged.
    *
    * NOTE(review): SQLite and PostgreSQL (standard_conforming_strings = on)
    * do not interpret backslash escapes inside string literals, so this is
    * deliberately not used by `quote_string`; kept for existing callers.
    */
  def quote_char(c: Char): String =
    c match {
      case '\u0000' => "\\0"
      case '\'' => "\\'"
      case '\"' => "\\\""
      case '\b' => "\\b"
      case '\n' => "\\n"
      case '\r' => "\\r"
      case '\t' => "\\t"
      case '\u001a' => "\\Z"
      case '\\' => "\\\\"
      case _ => c.toString
    }

  /** Standard SQL string literal: `s` enclosed in single quotes, with every
    * embedded single quote doubled. This is the escaping understood by both
    * SQLite and PostgreSQL; double quotes would denote an identifier instead.
    */
  def quote_string(s: String): String =
    "'" + s.replace("'", "''") + "'"

  /** SQL identifier (table, column or index name): embedded double quotes
    * are doubled before the whole name is enclosed via `quote`. */
  def quote_ident(s: String): String =
    quote(s.replace("\"", "\"\""))

  /** Parenthesized, comma-separated list, e.g. "(a, b, c)". */
  def enclosure(ss: Iterable[String]): String = ss.mkString("(", ", ", ")")


  /* types */

  /** Abstract column types; each value's name is the generic SQL type name
    * returned by `sql_type_default`. */
  object Type extends Enumeration
  {
    val Boolean = Value("BOOLEAN")
    val Int = Value("INTEGER")
    val Long = Value("BIGINT")
    val Double = Value("DOUBLE PRECISION")
    val String = Value("TEXT")
    val Bytes = Value("BLOB")
    val Date = Value("TIMESTAMP WITH TIME ZONE")
  }

  /** Generic SQL type name: the enumeration value's own name. */
  def sql_type_default(T: Type.Value): String = T.toString

  /** SQLite type name: Boolean is declared as INTEGER and Date as TEXT
    * (dates are stored as strings, see `SQLite.date_format`); all other
    * types use the generic name. */
  def sql_type_sqlite(T: Type.Value): String =
    if (T == Type.Boolean) "INTEGER"
    else if (T == Type.Date) "TEXT"
    else sql_type_default(T)

  /** PostgreSQL type name: Bytes is declared as BYTEA; all other types use
    * the generic name. */
  def sql_type_postgresql(T: Type.Value): String =
    if (T == Type.Bytes) "BYTEA"
    else sql_type_default(T)


  /* columns */

  /** Convenience constructors for `Column`, one per `Type` value. */
  object Column
  {
    def bool(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Boolean, strict, primary_key)
    def int(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Int, strict, primary_key)
    def long(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Long, strict, primary_key)
    def double(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Double, strict, primary_key)
    def string(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.String, strict, primary_key)
    def bytes(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Bytes, strict, primary_key)
    def date(name: String, strict: Boolean = false, primary_key: Boolean = false): Column =
      Column(name, Type.Date, strict, primary_key)
  }

  /** Column declaration.
    *
    * @param name        column name (quoted as an identifier in generated SQL)
    * @param T           abstract column type
    * @param strict      adds NOT NULL to the declaration
    * @param primary_key adds PRIMARY KEY to the declaration
    */
  sealed case class Column(
    name: String, T: Type.Value, strict: Boolean = false, primary_key: Boolean = false)
  {
    /** Column name quoted as an SQL identifier. */
    def sql_name: String = quote_ident(name)

    /** Column declaration for CREATE TABLE, with the type name chosen by the
      * backend-specific `sql_type`. */
    def sql_decl(sql_type: Type.Value => String): String =
      sql_name + " " + sql_type(T) +
      (if (strict) " NOT NULL" else "") +
      (if (primary_key) " PRIMARY KEY" else "")

    override def toString: String = sql_decl(sql_type_default)
  }


  /* tables */

  /** Table schema. Construction fails (via `error`) on duplicate column
    * names or on more than one primary-key column. */
  sealed case class Table(name: String, columns: List[Column])
  {
    Library.duplicates(columns.map(_.name)) match {
      case Nil =>
      case bad => error("Duplicate column names " + commas_quote(bad) + " for table " + quote(name))
    }

    columns.filter(_.primary_key) match {
      case bad if bad.length > 1 =>
        error("Multiple primary keys " + commas_quote(bad.map(_.name)) + " for table " + quote(name))
      case _ =>
    }

    /** CREATE TABLE statement. Non-strict mode adds IF NOT EXISTS;
      * `rowid = false` appends WITHOUT ROWID.
      * NOTE(review): WITHOUT ROWID is SQLite-specific syntax — confirm that
      * PostgreSQL callers always pass `rowid = true`. */
    def sql_create(strict: Boolean, rowid: Boolean, sql_type: Type.Value => String): String =
      "CREATE TABLE " + (if (strict) "" else "IF NOT EXISTS ") +
        quote_ident(name) + " " + enclosure(columns.map(_.sql_decl(sql_type))) +
        (if (rowid) "" else " WITHOUT ROWID")

    /** DROP TABLE statement; non-strict mode adds IF EXISTS. */
    def sql_drop(strict: Boolean): String =
      "DROP TABLE " + (if (strict) "" else "IF EXISTS ") + quote_ident(name)

    /** CREATE [UNIQUE] INDEX statement on `index_columns` of this table.
      * Column names are quoted as identifiers, consistently with
      * `sql_select` and `Column.sql_decl`. */
    def sql_create_index(
        index_name: String, index_columns: List[Column],
        strict: Boolean, unique: Boolean): String =
      "CREATE " + (if (unique) "UNIQUE " else "") + "INDEX " +
        (if (strict) "" else "IF NOT EXISTS ") + quote_ident(index_name) + " ON " +
        quote_ident(name) + " " + enclosure(index_columns.map(_.sql_name))

    /** DROP INDEX statement; non-strict mode adds IF EXISTS. */
    def sql_drop_index(index_name: String, strict: Boolean): String =
      "DROP INDEX " + (if (strict) "" else "IF EXISTS ") + quote_ident(index_name)

    /** INSERT statement with one positional "?" parameter per column, in
      * declaration order. */
    def sql_insert: String =
      "INSERT INTO " + quote_ident(name) + " VALUES " + enclosure(columns.map(_ => "?"))

    /** SELECT [DISTINCT] of `select_columns` from this table (no WHERE clause). */
    def sql_select(select_columns: List[Column], distinct: Boolean): String =
      "SELECT " + (if (distinct) "DISTINCT " else "") +
        commas(select_columns.map(_.sql_name)) + " FROM " + quote_ident(name)

    override def toString: String =
      "TABLE " + quote_ident(name) + " " + enclosure(columns.map(_.toString))
  }



  /** SQL database operations **/

  /* results */

  /** Iterator over the rows of `rs`, applying `get` to each current row.
    * The cursor is advanced one row ahead (`rs.next()` is called eagerly, also
    * on construction); the result set is not closed by this iterator. */
  def iterator[A](rs: ResultSet)(get: ResultSet => A): Iterator[A] = new Iterator[A]
  {
    private var _next: Boolean = rs.next()
    def hasNext: Boolean = _next
    def next: A = { val x = get(rs); _next = rs.next(); x }
  }

  /** Common operations on a JDBC connection; backends supply the type
    * mapping and the Date representation. */
  trait Database
  {
    /* types */

    /** Backend-specific SQL type name for an abstract column type. */
    def sql_type(T: Type.Value): String


    /* connection */

    def connection: Connection

    def close() { connection.close }

    /** Evaluate `body` with auto-commit disabled and commit afterwards. On any
      * exception, roll back to a savepoint taken before `body` and rethrow.
      * The previous auto-commit mode is always restored. */
    def transaction[A](body: => A): A =
    {
      val auto_commit = connection.getAutoCommit
      try {
        connection.setAutoCommit(false)
        val savepoint = connection.setSavepoint
        try {
          val result = body
          connection.commit
          result
        }
        catch { case exn: Throwable => connection.rollback(savepoint); throw exn }
      }
      finally { connection.setAutoCommit(auto_commit) }
    }


    /* statements */

    def statement(sql: String): PreparedStatement = connection.prepareStatement(sql)

    def insert_statement(table: Table): PreparedStatement = statement(table.sql_insert)

    /** SELECT statement; a non-empty `sql` is appended after the FROM clause,
      * separated by a space (presumably WHERE/ORDER BY clauses etc.). */
    def select_statement(table: Table, columns: List[Column],
        sql: String = "", distinct: Boolean = false): PreparedStatement =
      statement(table.sql_select(columns, distinct) + (if (sql == "") "" else " " + sql))


    /* input */

    def set_bool(stmt: PreparedStatement, i: Int, x: Boolean) { stmt.setBoolean(i, x) }
    def set_int(stmt: PreparedStatement, i: Int, x: Int) { stmt.setInt(i, x) }
    def set_long(stmt: PreparedStatement, i: Int, x: Long) { stmt.setLong(i, x) }
    def set_double(stmt: PreparedStatement, i: Int, x: Double) { stmt.setDouble(i, x) }
    def set_string(stmt: PreparedStatement, i: Int, x: String) { stmt.setString(i, x) }
    def set_bytes(stmt: PreparedStatement, i: Int, bytes: Bytes)
    { stmt.setBinaryStream(i, bytes.stream(), bytes.length) }
    /** Backend-specific: SQLite stores dates as text, PostgreSQL natively. */
    def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit


    /* output */

    def bool(rs: ResultSet, name: String): Boolean = rs.getBoolean(name)
    def int(rs: ResultSet, name: String): Int = rs.getInt(name)
    def long(rs: ResultSet, name: String): Long = rs.getLong(name)
    def double(rs: ResultSet, name: String): Double = rs.getDouble(name)
    /** String column value; SQL NULL yields "". */
    def string(rs: ResultSet, name: String): String =
    {
      val s = rs.getString(name)
      if (s == null) "" else s
    }
    /** Binary column value; SQL NULL yields `Bytes.empty`. */
    def bytes(rs: ResultSet, name: String): Bytes =
    {
      val bs = rs.getBytes(name)
      if (bs == null) Bytes.empty else Bytes(bs)
    }
    /** Backend-specific inverse of `set_date`. */
    def date(rs: ResultSet, name: String): Date

    /** Read column `name` via `f`; None if the value just read was SQL NULL
      * (checked with `rs.wasNull`, so `f` must read exactly that column). */
    def get[A](rs: ResultSet, name: String, f: (ResultSet, String) => A): Option[A] =
    {
      val x = f(rs, name)
      if (rs.wasNull) None else Some(x)
    }


    /* tables */

    /** Names of all tables: column 3 (TABLE_NAME) of JDBC metadata getTables. */
    def tables: List[String] =
      iterator(connection.getMetaData.getTables(null, null, "%", null))(_.getString(3)).toList

    def create_table(table: Table, strict: Boolean = false, rowid: Boolean = true): Unit =
      using(statement(table.sql_create(strict, rowid, sql_type)))(_.execute())

    def drop_table(table: Table, strict: Boolean = false): Unit =
      using(statement(table.sql_drop(strict)))(_.execute())

    def create_index(table: Table, name: String, columns: List[Column],
        strict: Boolean = false, unique: Boolean = false): Unit =
      using(statement(table.sql_create_index(name, columns, strict, unique)))(_.execute())

    def drop_index(table: Table, name: String, strict: Boolean = false): Unit =
      using(statement(table.sql_drop_index(name, strict)))(_.execute())
  }
}



/** SQLite **/

object SQLite
{
  // see https://www.sqlite.org/lang_datefunc.html
  val date_format: Date.Format = Date.Format("uuuu-MM-dd HH:mm:ss.SSS x")

  /** Open the SQLite database file at `path` (expanded); on Windows the
    * platform path uses forward slashes in the JDBC URL. */
  def open_database(path: Path): Database =
  {
    val path0 = path.expand
    val s0 = File.platform_path(path0)
    val s1 = if (Platform.is_windows) s0.replace('\\', '/') else s0
    val connection = DriverManager.getConnection("jdbc:sqlite:" + s1)
    new Database(path0.toString, connection)
  }

  class Database private[SQLite](name: String, val connection: Connection) extends SQL.Database
  {
    override def toString: String = name

    def sql_type(T: SQL.Type.Value): String = SQL.sql_type_sqlite(T)

    /** Dates are stored as TEXT in `date_format`. */
    def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
      set_string(stmt, i, date_format(date))
    def date(rs: ResultSet, name: String): Date =
      date_format.parse(string(rs, name))

    /** Run VACUUM on the database. */
    def rebuild { using(statement("VACUUM"))(_.execute()) }
  }
}



/** PostgreSQL **/

object PostgreSQL
{
  val default_port = 5432

  /** Connect to a PostgreSQL server.
    *
    * @param user     login name (must be non-empty)
    * @param password login password
    * @param database database name; defaults to `user`
    * @param host     server host; defaults to "localhost"
    * @param port     server port; omitted from the URL when `default_port`
    * @param ssh      optional SSH session used for port forwarding to the
    *                 server; the forwarding is closed if connecting fails
    */
  def open_database(
    user: String,
    password: String,
    database: String = "",
    host: String = "",
    port: Int = default_port,
    ssh: Option[SSH.Session] = None): Database =
  {
    require(user != "")

    val db_host = if (host != "") host else "localhost"
    val db_port = if (port != default_port) ":" + port else ""
    val db_name = "/" + (if (database != "") database else user)

    val (url, name, port_forwarding) =
      ssh match {
        case None =>
          val spec = db_host + db_port + db_name
          val url = "jdbc:postgresql://" + spec
          val name = user + "@" + spec
          (url, name, None)
        case Some(ssh) =>
          val fw = ssh.port_forwarding(remote_host = db_host, remote_port = port)
          val url = "jdbc:postgresql://localhost:" + fw.local_port + db_name
          val name = user + "@" + fw + db_name + " via ssh " + ssh
          (url, name, Some(fw))
      }
    try {
      val connection = DriverManager.getConnection(url, user, password)
      new Database(name, connection, port_forwarding)
    }
    catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
  }

  class Database private[PostgreSQL](
      name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
    extends SQL.Database
  {
    override def toString: String = name

    def sql_type(T: SQL.Type.Value): String = SQL.sql_type_postgresql(T)

    // see https://jdbc.postgresql.org/documentation/head/8-date-time.html
    def set_date(stmt: PreparedStatement, i: Int, date: Date): Unit =
      stmt.setObject(i, OffsetDateTime.from(date.to_utc.rep))
    def date(rs: ResultSet, name: String): Date =
      Date.instant(rs.getObject(name, classOf[OffsetDateTime]).toInstant)

    /** Close the JDBC connection, then the SSH port forwarding (if any);
      * the forwarding is closed even if closing the connection fails. */
    override def close()
    {
      try { super.close() }
      finally { port_forwarding.foreach(_.close) }
    }
  }
}