remote database access via ssh port forwarding;
authorwenzelm
Thu Feb 09 15:40:34 2017 +0100 (2017-02-09)
changeset 65009eda9366bbfac
parent 65008 ed2eedf786f3
child 65010 a27e9908dcf7
remote database access via ssh port forwarding;
src/Pure/General/sql.scala
src/Pure/General/ssh.scala
     1.1 --- a/src/Pure/General/sql.scala	Thu Feb 09 11:44:55 2017 +0100
     1.2 +++ b/src/Pure/General/sql.scala	Thu Feb 09 15:40:34 2017 +0100
     1.3 @@ -295,23 +295,43 @@
     1.4      password: String,
     1.5      database: String = "",
     1.6      host: String = "",
     1.7 -    port: Int = default_port): Database =
     1.8 +    port: Int = default_port,
     1.9 +    ssh: Option[SSH.Session] = None): Database =
    1.10    {
    1.11      require(user != "")
    1.12 -    val spec =
    1.13 -      (if (host != "") host else "localhost") +
    1.14 -      (if (port != default_port) ":" + port else "") + "/" +
    1.15 -      (if (database != "") database else user)
    1.16 -    val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password)
    1.17 -    new Database(user + "@" + spec, connection)
    1.18 +
    1.19 +    val db_host = if (host != "") host else "localhost"
    1.20 +    val db_port = if (port != default_port) ":" + port else ""
    1.21 +    val db_name = "/" + (if (database != "") database else user)
    1.22 +
    1.23 +    val (spec, port_forwarding) =
    1.24 +      ssh match {
    1.25 +        case None => (db_host + db_port + db_name, None)
    1.26 +        case Some(ssh) =>
    1.27 +          val fw = ssh.port_forwarding(remote_host = db_host, remote_port = port)
    1.28 +          ("localhost:" + fw.port + db_name, Some(fw))
    1.29 +      }
    1.30 +    try {
    1.31 +      val connection = DriverManager.getConnection("jdbc:postgresql://" + spec, user, password)
    1.32 +      new Database(user + "@" + spec, connection, port_forwarding)
    1.33 +    }
    1.34 +    catch { case exn: Throwable => port_forwarding.foreach(_.close); throw exn }
    1.35    }
    1.36  
    1.37 -  class Database private[PostgreSQL](name: String, val connection: Connection) extends SQL.Database
    1.38 +  class Database private[PostgreSQL](
    1.39 +      name: String, val connection: Connection, port_forwarding: Option[SSH.Port_Forwarding])
    1.40 +    extends SQL.Database
    1.41    {
    1.42 -    override def toString: String = name
    1.43 +    override def toString: String =
    1.44 +      port_forwarding match {
    1.45 +        case None => name
    1.46 +        case Some(fw) => name + " via ssh " + fw.ssh
    1.47 +      }
    1.48  
    1.49      override def type_name(t: SQL.Type.Value): String =
    1.50        if (t == SQL.Type.Bytes) "BYTEA"
    1.51        else SQL.default_type_name(t)
    1.52 +
    1.53 +    override def close() { super.close; port_forwarding.foreach(_.close) }
    1.54    }
    1.55  }
     2.1 --- a/src/Pure/General/ssh.scala	Thu Feb 09 11:44:55 2017 +0100
     2.2 +++ b/src/Pure/General/ssh.scala	Thu Feb 09 15:40:34 2017 +0100
     2.3 @@ -129,6 +129,14 @@
     2.4    }
     2.5  
     2.6  
     2.7 +  /* port forwarding */
     2.8 +
     2.9 +  sealed case class Port_Forwarding(ssh: SSH.Session, host: String, port: Int)
    2.10 +  {
    2.11 +    def close() { ssh.session.delPortForwardingL(host, port) }
    2.12 +  }
    2.13 +
    2.14 +
    2.15    /* Sftp channel */
    2.16  
    2.17    type Attrs = SftpATTRS
    2.18 @@ -235,6 +243,15 @@
    2.19        user_prefix + host + port_suffix + (if (session.isConnected) "" else " (disconnected)")
    2.20  
    2.21  
    2.22 +    /* port forwarding */
    2.23 +
    2.24 +    def port_forwarding(
    2.25 +        remote_port: Int, remote_host: String = "localhost",
    2.26 +        local_port: Int = 0, local_host: String = "localhost"): Port_Forwarding =
    2.27 +      Port_Forwarding(this, local_host,
    2.28 +        session.setPortForwardingL(local_host, local_port, remote_host, remote_port))
    2.29 +
    2.30 +
    2.31      /* sftp channel */
    2.32  
    2.33      val sftp: ChannelSftp = session.openChannel("sftp").asInstanceOf[ChannelSftp]