support for proxy connection, similar to ProxyCommand in ssh config;
authorwenzelm
Fri, 02 Mar 2018 18:45:11 +0100
changeset 67745 d83efbe52438
parent 67744 5c781dcd5864
child 67746 cb0f0f5f8876
support for proxy connection, similar to ProxyCommand in ssh config;
src/Pure/General/ssh.scala
--- a/src/Pure/General/ssh.scala	Fri Mar 02 15:16:10 2018 +0100
+++ b/src/Pure/General/ssh.scala	Fri Mar 02 18:45:11 2018 +0100
@@ -37,6 +37,7 @@
   }
 
   val default_port = 22
+  def make_port(port: Int): Int = if (port > 0) port else default_port
 
   def connect_timeout(options: Options): Int =
     options.seconds("ssh_connect_timeout").ms.toInt
@@ -73,32 +74,52 @@
     new Context(options, jsch)
   }
 
-  def open_session(options: Options, host: String, user: String = "", port: Int = 0): Session =
-    init_context(options).open_session(host = host, user = user, port = port)
+  def open_session(options: Options, host: String, user: String = "", port: Int = 0,
+      proxy_host: String = "", proxy_user: String = "", proxy_port: Int = 0): Session =
+    init_context(options).open_session(host = host, user = user, port = port,
+      proxy_host = proxy_host, proxy_user = proxy_user, proxy_port = proxy_port)
 
   class Context private[SSH](val options: Options, val jsch: JSch)
   {
     def update_options(new_options: Options): Context = new Context(new_options, jsch)
 
-    def open_session(host: String, user: String = "", port: Int = 0): Session =
+    def connect_session(host: String, user: String = "", port: Int = 0,
+      host_key_alias: String = "", on_close: () => Unit = () => ()): Session =
     {
-      val session =
-        jsch.getSession(proper_string(user) getOrElse null, host,
-          if (port > 0) port else default_port)
+      val session = jsch.getSession(proper_string(user) getOrElse null, host, make_port(port))
 
       session.setUserInfo(No_User_Info)
       session.setServerAliveInterval(alive_interval(options))
       session.setServerAliveCountMax(alive_count_max(options))
       session.setConfig("MaxAuthTries", "3")
+      if (host_key_alias != "") session.setHostKeyAlias(host_key_alias)
 
       if (options.bool("ssh_compression")) {
         session.setConfig("compression.s2c", "zlib@openssh.com,zlib,none")
         session.setConfig("compression.c2s", "zlib@openssh.com,zlib,none")
         session.setConfig("compression_level", "9")
       }
+      session.connect(connect_timeout(options))
+      new Session(options, session, on_close)
+    }
 
-      session.connect(connect_timeout(options))
-      new Session(options, session)
+    def open_session(host: String, user: String = "", port: Int = 0,
+      proxy_host: String = "", proxy_user: String = "", proxy_port: Int = 0): Session =
+    {
+      if (proxy_host == "") connect_session(host = host, user = user, port = port)
+      else {
+        val proxy = connect_session(host = proxy_host, port = proxy_port, user = proxy_user)
+
+        val fw =
+          try { proxy.port_forwarding(remote_host = host, remote_port = make_port(port)) }
+          catch { case exn: Throwable => proxy.close; throw exn }
+
+        try {
+          connect_session(host = fw.local_host, port = fw.local_port, host_key_alias = host,
+            user = user, on_close = () => { fw.close; proxy.close })
+        }
+        catch { case exn: Throwable => fw.close; proxy.close; throw exn }
+      }
     }
   }
 
@@ -262,9 +283,13 @@
 
   /* session */
 
-  class Session private[SSH](val options: Options, val session: JSch_Session) extends System
+  class Session private[SSH](
+    val options: Options,
+    val session: JSch_Session,
+    on_close: () => Unit) extends System
   {
-    def update_options(new_options: Options): Session = new Session(new_options, session)
+    def update_options(new_options: Options): Session =
+      new Session(new_options, session, on_close)
 
     def user_prefix: String = if (session.getUserName == null) "" else session.getUserName + "@"
     def host: String = if (session.getHost == null) "" else session.getHost
@@ -292,7 +317,7 @@
     val sftp: ChannelSftp = session.openChannel("sftp").asInstanceOf[ChannelSftp]
     sftp.connect(connect_timeout(options))
 
-    def close() { sftp.disconnect; session.disconnect }
+    def close() { sftp.disconnect; session.disconnect; on_close() }
 
     val settings: Map[String, String] =
     {