--- a/src/Pure/General/sql.scala Thu Feb 09 11:44:55 2017 +0100
+++ b/src/Pure/General/sql.scala Thu Feb 09 15:40:34 2017 +0100
@@ -295,23 +295,43 @@
password: String,
database: String = "",
host: String = "",
- port: Int = default_port): Database =
+ port: Int = default_port,
+ ssh: Option[SSH.Session] = None): Database =
{
require(user != "")
- val spec =
- (if (host != "") host else "localhost") +
- (if (port != default_port) ":" + port else "") + "/" +
- (if (database != "") database else user)
- val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password)
- new Database(user + "@" + spec, connection)
+
+ val db_host = if (host != "") host else "localhost"
+ val db_port = if (port != default_port) ":" + port else ""
+ val db_name = "/" + (if (database != "") database else user)
+
+ val (spec, port_forwarding) =
+ ssh match {
+ case None => (db_host + db_port + db_name, None)
+ case Some(ssh) =>
+ val fw = ssh.port_forwarding(remote_host = db_host, remote_port = port)
+ ("localhost:" + fw.port + db_name, Some(fw))
+ }
+ try {
+ val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password)
+ new Database(user + "@" + spec, connection, port_forwarding)
+ }
+ catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
}
- class Database private[PostgreSQL](name: String, val connection: Connection) extends SQL.Database
+ class Database private[PostgreSQL](
+ name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
+ extends SQL.Database
{
- override def toString: String = name
+ override def toString: String =
+ port_forwarding match {
+ case None => name
+ case Some(fw) => name + " via ssh " + fw.ssh
+ }
override def type_name(t: SQL.Type.Value): String =
if (t == SQL.Type.Bytes) "BYTEA"
else SQL.default_type_name(t)
+
+ override def close() { super.close; port_forwarding.foreach(_.close) }
}
}
--- a/src/Pure/General/ssh.scala Thu Feb 09 11:44:55 2017 +0100
+++ b/src/Pure/General/ssh.scala Thu Feb 09 15:40:34 2017 +0100
@@ -129,6 +129,14 @@
}
+ /* port forwarding */
+
+ sealed case class Port_Forwarding(ssh: SSH.Session, host: String, port: Int)
+ {
+ def close() { ssh.session.delPortForwardingL(host, port) }
+ }
+
+
/* Sftp channel */
type Attrs = SftpATTRS
@@ -235,6 +243,15 @@
user_prefix + host + port_suffix + (if (session.isConnected) "" else " (disconnected)")
+ /* port forwarding */
+
+ def port_forwarding(
+ remote_port: Int, remote_host: String = "localhost",
+ local_port: Int = 0, local_host: String = "localhost"): Port_Forwarding =
+ Port_Forwarding(this, local_host,
+ session.setPortForwardingL(local_host, local_port, remote_host, remote_port))
+
+
/* sftp channel */
val sftp: ChannelSftp = session.openChannel("sftp").asInstanceOf[ChannelSftp]