remote database access via ssh port forwarding;
authorwenzelm
Thu, 09 Feb 2017 15:40:34 +0100
changeset 65009 eda9366bbfac
parent 65008 ed2eedf786f3
child 65010 a27e9908dcf7
remote database access via ssh port forwarding;
src/Pure/General/sql.scala
src/Pure/General/ssh.scala
--- a/src/Pure/General/sql.scala	Thu Feb 09 11:44:55 2017 +0100
+++ b/src/Pure/General/sql.scala	Thu Feb 09 15:40:34 2017 +0100
@@ -295,23 +295,43 @@
     password: String,
     database: String = "",
     host: String = "",
-    port: Int = default_port): Database =
+    port: Int = default_port,
+    ssh: Option[SSH.Session] = None): Database =
   {
     require(user != "")
-    val spec =
-      (if (host != "") host else "localhost") +
-      (if (port != default_port) ":" + port else "") + "/" +
-      (if (database != "") database else user)
-    val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password)
-    new Database(user + "@" + spec, connection)
+
+    val db_host = if (host != "") host else "localhost"
+    val db_port = if (port != default_port) ":" + port else ""
+    val db_name = "/" + (if (database != "") database else user)
+
+    val (spec, port_forwarding) =
+      ssh match {
+        case None => (db_host + db_port + db_name, None)
+        case Some(ssh) =>
+          val fw = ssh.port_forwarding(remote_host = db_host, remote_port = port)
+          ("localhost:" + fw.port + db_name, Some(fw))
+      }
+    try {
+      val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password)
+      new Database(user + "@" + spec, connection, port_forwarding)
+    }
+    catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
   }
 
-  class Database private[PostgreSQL](name: String, val connection: Connection) extends SQL.Database
+  class Database private[PostgreSQL](
+      name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
+    extends SQL.Database
   {
-    override def toString: String = name
+    override def toString: String =
+      port_forwarding match {
+        case None => name
+        case Some(fw) => name + " via ssh " + fw.ssh
+      }
 
     override def type_name(t: SQL.Type.Value): String =
       if (t == SQL.Type.Bytes) "BYTEA"
       else SQL.default_type_name(t)
+
+    override def close() { super.close; port_forwarding.foreach(_.close) }
   }
 }
--- a/src/Pure/General/ssh.scala	Thu Feb 09 11:44:55 2017 +0100
+++ b/src/Pure/General/ssh.scala	Thu Feb 09 15:40:34 2017 +0100
@@ -129,6 +129,14 @@
   }
 
 
+  /* port forwarding */
+
+  sealed case class Port_Forwarding(ssh: SSH.Session, host: String, port: Int)
+  {
+    def close() { ssh.session.delPortForwardingL(host, port) }
+  }
+
+
   /* Sftp channel */
 
   type Attrs = SftpATTRS
@@ -235,6 +243,15 @@
       user_prefix + host + port_suffix + (if (session.isConnected) "" else " (disconnected)")
 
 
+    /* port forwarding */
+
+    def port_forwarding(
+        remote_port: Int, remote_host: String = "localhost",
+        local_port: Int = 0, local_host: String = "localhost"): Port_Forwarding =
+      Port_Forwarding(this, local_host,
+        session.setPortForwardingL(local_host, local_port, remote_host, remote_port))
+
+
     /* sftp channel */
 
     val sftp: ChannelSftp = session.openChannel("sftp").asInstanceOf[ChannelSftp]