src/Pure/General/sql.scala
changeset 65009 eda9366bbfac
parent 65008 ed2eedf786f3
child 65010 a27e9908dcf7
--- a/src/Pure/General/sql.scala	Thu Feb 09 11:44:55 2017 +0100
+++ b/src/Pure/General/sql.scala	Thu Feb 09 15:40:34 2017 +0100
@@ -295,23 +295,43 @@
     password: String,
     database: String = "",
     host: String = "",
-    port: Int = default_port): Database =
+    port: Int = default_port,
+    ssh: Option[SSH.Session] = None): Database =
   {
     require(user != "")
-    val spec =
-      (if (host != "") host else "localhost") +
-      (if (port != default_port) ":" + port else "") + "/" +
-      (if (database != "") database else user)
-    val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password)
-    new Database(user + "@" + spec, connection)
+
+    val db_host = if (host != "") host else "localhost"
+    val db_port = if (port != default_port) ":" + port else ""
+    val db_name = "/" + (if (database != "") database else user)
+
+    val (spec, port_forwarding) =
+      ssh match {
+        case None => (db_host + db_port + db_name, None)
+        case Some(ssh) =>
+          val fw = ssh.port_forwarding(remote_host = db_host, remote_port = port)
+          ("localhost:" + fw.port + db_name, Some(fw))
+      }
+    try {
+      val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password)
+      new Database(user + "@" + spec, connection, port_forwarding)
+    }
+    catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
   }
 
-  class Database private[PostgreSQL](name: String, val connection: Connection) extends SQL.Database
+  class Database private[PostgreSQL](
+      name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
+    extends SQL.Database
   {
-    override def toString: String = name
+    override def toString: String =
+      port_forwarding match {
+        case None => name
+        case Some(fw) => name + " via ssh " + fw.ssh
+      }
 
     override def type_name(t: SQL.Type.Value): String =
       if (t == SQL.Type.Bytes) "BYTEA"
       else SQL.default_type_name(t)
+
+    override def close() { super.close; port_forwarding.foreach(_.close) }
   }
 }