ssh_close for proper termination after use of database;
authorwenzelm
Sat, 29 Apr 2017 20:15:26 +0200
changeset 65636 df804cdba5f9
parent 65635 0a025b8496a2
child 65637 e9b87bf6578b
ssh_close for proper termination after use of database;
src/Pure/Admin/build_log.scala
src/Pure/General/sql.scala
src/Pure/General/ssh.scala
--- a/src/Pure/Admin/build_log.scala	Sat Apr 29 19:43:04 2017 +0200
+++ b/src/Pure/Admin/build_log.scala	Sat Apr 29 20:15:26 2017 +0200
@@ -645,7 +645,8 @@
         user = user, password = password, database = database, host = host, port = port,
         ssh =
           if (ssh_host == "") None
-          else Some(SSH.init_context(options).open_session(ssh_host, ssh_user, port)))
+          else Some(SSH.init_context(options).open_session(ssh_host, ssh_user, port)),
+        ssh_close = true)
     }
 
     def write_info(db: SQL.Database, files: List[JFile])
--- a/src/Pure/General/sql.scala	Sat Apr 29 19:43:04 2017 +0200
+++ b/src/Pure/General/sql.scala	Sat Apr 29 20:15:26 2017 +0200
@@ -355,7 +355,8 @@
     database: String = "",
     host: String = "",
     port: Int = 0,
-    ssh: Option[SSH.Session] = None): Database =
+    ssh: Option[SSH.Session] = None,
+    ssh_close: Boolean = false): Database =
   {
     init_jdbc
 
@@ -375,7 +376,8 @@
         case Some(ssh) =>
           val fw =
             ssh.port_forwarding(remote_host = db_host,
-              remote_port = if (port > 0) port else default_port)
+              remote_port = if (port > 0) port else default_port,
+              ssh_close = ssh_close)
           val url = "jdbc:postgresql://localhost:" + fw.local_port + db_name
           val name = user + "@" + fw + db_name + " via ssh " + ssh
           (url, name, Some(fw))
--- a/src/Pure/General/ssh.scala	Sat Apr 29 19:43:04 2017 +0200
+++ b/src/Pure/General/ssh.scala	Sat Apr 29 20:15:26 2017 +0200
@@ -134,16 +134,17 @@
 
   object Port_Forwarding
   {
-    def open(ssh: Session,
+    def open(ssh: Session, ssh_close: Boolean,
       local_host: String, local_port: Int, remote_host: String, remote_port: Int): Port_Forwarding =
     {
       val port = ssh.session.setPortForwardingL(local_host, local_port, remote_host, remote_port)
-      new Port_Forwarding(ssh, local_host, port, remote_host, remote_port)
+      new Port_Forwarding(ssh, ssh_close, local_host, port, remote_host, remote_port)
     }
   }
 
   class Port_Forwarding private[SSH](
     ssh: SSH.Session,
+    ssh_close: Boolean,
     val local_host: String,
     val local_port: Int,
     val remote_host: String,
@@ -152,7 +153,11 @@
     override def toString: String =
       local_host + ":" + local_port + ":" + remote_host + ":" + remote_port
 
-    def close() { ssh.session.delPortForwardingL(local_host, local_port) }
+    def close()
+    {
+      ssh.session.delPortForwardingL(local_host, local_port)
+      if (ssh_close) ssh.close()
+    }
   }
 
 
@@ -264,9 +269,11 @@
 
     /* port forwarding */
 
-    def port_forwarding(remote_port: Int, remote_host: String = "localhost",
-        local_port: Int = 0, local_host: String = "localhost"): Port_Forwarding =
-      Port_Forwarding.open(this, local_host, local_port, remote_host, remote_port)
+    def port_forwarding(
+        remote_port: Int, remote_host: String = "localhost",
+        local_port: Int = 0, local_host: String = "localhost",
+        ssh_close: Boolean = false): Port_Forwarding =
+      Port_Forwarding.open(this, ssh_close, local_host, local_port, remote_host, remote_port)
 
 
     /* sftp channel */