--- a/src/Pure/General/sql.scala Thu Feb 09 15:40:34 2017 +0100
+++ b/src/Pure/General/sql.scala Thu Feb 09 16:06:45 2017 +0100
@@ -304,16 +304,22 @@
val db_port = if (port != default_port) ":" + port else ""
val db_name = "/" + (if (database != "") database else user)
- val (spec, port_forwarding) =
+ val (url, name, port_forwarding) =
ssh match {
- case None => (db_host + db_port + db_name, None)
+ case None =>
+ val spec = db_host + db_port + db_name
+ val url = "jdbc:postgresql://" + spec
+ val name = user + "@" + spec
+ (url, name, None)
case Some(ssh) =>
val fw = ssh.port_forwarding(remote_host = db_host, remote_port = port)
- ("localhost:" + fw.port + db_name, Some(fw))
+ val url = "jdbc:postgresql://localhost:" + fw.local_port + db_name
+ val name = user + "@" + fw + db_name + " via ssh " + ssh
+ (url, name, Some(fw))
}
try {
- val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password)
- new Database(user + "@" + spec, connection, port_forwarding)
+ val connection = DriverManager.getConnection(url, user, password)
+ new Database(name, connection, port_forwarding)
}
catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
}
@@ -322,11 +328,7 @@
name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
extends SQL.Database
{
- override def toString: String =
- port_forwarding match {
- case None => name
- case Some(fw) => name + " via ssh " + fw.ssh
- }
+ override def toString: String = name
override def type_name(t: SQL.Type.Value): String =
if (t == SQL.Type.Bytes) "BYTEA"
--- a/src/Pure/General/ssh.scala Thu Feb 09 15:40:34 2017 +0100
+++ b/src/Pure/General/ssh.scala Thu Feb 09 16:06:45 2017 +0100
@@ -131,9 +131,27 @@
/* port forwarding */
- sealed case class Port_Forwarding(ssh: SSH.Session, host: String, port: Int)
+ object Port_Forwarding
{
- def close() { ssh.session.delPortForwardingL(host, port) }
+ def open(ssh: Session,
+ local_host: String, local_port: Int, remote_host: String, remote_port: Int): Port_Forwarding =
+ {
+ val port = ssh.session.setPortForwardingL(local_host, local_port, remote_host, remote_port)
+ new Port_Forwarding(ssh, local_host, port, remote_host, remote_port)
+ }
+ }
+
+ class Port_Forwarding private[SSH](
+ ssh: SSH.Session,
+ val local_host: String,
+ val local_port: Int,
+ val remote_host: String,
+ val remote_port: Int)
+ {
+ override def toString: String =
+ local_host + ":" + local_port + ":" + remote_host + ":" + remote_port
+
+ def close() { ssh.session.delPortForwardingL(local_host, local_port) }
}
@@ -245,11 +263,9 @@
/* port forwarding */
- def port_forwarding(
- remote_port: Int, remote_host: String = "localhost",
+ def port_forwarding(remote_port: Int, remote_host: String = "localhost",
local_port: Int = 0, local_host: String = "localhost"): Port_Forwarding =
- Port_Forwarding(this, local_host,
- session.setPortForwardingL(local_host, local_port, remote_host, remote_port))
+ Port_Forwarding.open(this, local_host, local_port, remote_host, remote_port)
/* sftp channel */