# HG changeset patch # User wenzelm # Date 1486651234 -3600 # Node ID eda9366bbfac769e583f1717fe3e5c776f3e6bb4 # Parent ed2eedf786f3914a9b0bd73f154f1a4a5d8798cf remote database access via ssh port forwarding; diff -r ed2eedf786f3 -r eda9366bbfac src/Pure/General/sql.scala --- a/src/Pure/General/sql.scala Thu Feb 09 11:44:55 2017 +0100 +++ b/src/Pure/General/sql.scala Thu Feb 09 15:40:34 2017 +0100 @@ -295,23 +295,43 @@ password: String, database: String = "", host: String = "", - port: Int = default_port): Database = + port: Int = default_port, + ssh: Option[SSH.Session] = None): Database = { require(user != "") - val spec = - (if (host != "") host else "localhost") + - (if (port != default_port) ":" + port else "") + "/" + - (if (database != "") database else user) - val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password) - new Database(user + "@" + spec, connection) + + val db_host = if (host != "") host else "localhost" + val db_port = if (port != default_port) ":" + port else "" + val db_name = "/" + (if (database != "") database else user) + + val (spec, port_forwarding) = + ssh match { + case None => (db_host + db_port + db_name, None) + case Some(ssh) => + val fw = ssh.port_forwarding(remote_host = db_host, remote_port = port) + ("localhost:" + fw.port + db_name, Some(fw)) + } + try { + val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password) + new Database(user + "@" + spec, connection, port_forwarding) + } + catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn } } - class Database private[PostgreSQL](name: String, val connection: Connection) extends SQL.Database + class Database private[PostgreSQL]( + name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding]) + extends SQL.Database { - override def toString: String = name + override def toString: String = + port_forwarding match { + case None => name + case Some(fw) => name + " via ssh " + fw.ssh + } override def type_name(t: SQL.Type.Value): String = if (t == SQL.Type.Bytes) "BYTEA" else SQL.default_type_name(t) + + override def close() { super.close; port_forwarding.foreach(_.close) } } } diff -r ed2eedf786f3 -r eda9366bbfac src/Pure/General/ssh.scala --- a/src/Pure/General/ssh.scala Thu Feb 09 11:44:55 2017 +0100 +++ b/src/Pure/General/ssh.scala Thu Feb 09 15:40:34 2017 +0100 @@ -129,6 +129,14 @@ } + /* port forwarding */ + + sealed case class Port_Forwarding(ssh: SSH.Session, host: String, port: Int) + { + def close() { ssh.session.delPortForwardingL(host, port) } + } + + /* Sftp channel */ type Attrs = SftpATTRS @@ -235,6 +243,15 @@ user_prefix + host + port_suffix + (if (session.isConnected) "" else " (disconnected)") + /* port forwarding */ + + def port_forwarding( + remote_port: Int, remote_host: String = "localhost", + local_port: Int = 0, local_host: String = "localhost"): Port_Forwarding = + Port_Forwarding(this, local_host, + session.setPortForwardingL(local_host, local_port, remote_host, remote_port)) + + /* sftp channel */ val sftp: ChannelSftp = session.openChannel("sftp").asInstanceOf[ChannelSftp]