merged
authorwenzelm
Wed, 14 Sep 2022 17:32:56 +0200
changeset 76152 a95196ef33f0
parent 76143 e278bf6430cf (current diff)
parent 76151 21492610ae5b (diff)
child 76153 bf9f2f4069b9
merged
--- a/etc/options	Wed Sep 14 09:15:00 2022 +0000
+++ b/etc/options	Wed Sep 14 17:32:56 2022 +0200
@@ -289,6 +289,9 @@
 
 section "Secure Shell"
 
+option ssh_multiplexing : bool = true
+  -- "enable multiplexing of SSH sessions (ignored on Windows)"
+
 option ssh_compression : bool = true
   -- "enable SSH compression"
 
--- a/src/Pure/General/ssh.scala	Wed Sep 14 09:15:00 2022 +0000
+++ b/src/Pure/General/ssh.scala	Wed Sep 14 17:32:56 2022 +0200
@@ -54,16 +54,22 @@
       (if (control_path.nonEmpty) List("ControlPath=" + control_path) else Nil)
     }
 
-    def make_command(command: String, config: List[String]): String =
-      Bash.string(command) + " " + config.map(entry => "-o " + Bash.string(entry)).mkString(" ")
+    def option(entry: String): String = "-o " + Bash.string(entry)
+    def option(x: String, y: String): String = option(entry(x, y))
+    def option(x: String, y: Int): String = option(entry(x, y))
+    def option(x: String, y: Boolean): String = option(entry(x, y))
+
+    def command(command: String, config: List[String]): String =
+      Bash.string(command) + config.map(entry => " " + option(entry)).mkString
   }
 
   def sftp_string(str: String): String = {
-    val special = Set(' ', '*', '?', '{', '}')
-    if (str.exists(special)) {
+    val special = "[]?*\\{} \"'"
+    if (str.isEmpty) "\"\""
+    else if (str.exists(special.contains)) {
       val res = new StringBuilder
       for (c <- str) {
-        if (special(c)) res += '\\'
+        if (special.contains(c)) res += '\\'
         res += c
       }
       res.toString()
@@ -78,17 +84,12 @@
     options: Options,
     host: String,
     port: Int = 0,
-    user: String = "",
-    multiplex: Boolean = !Platform.is_windows
+    user: String = ""
   ): Session = {
+    val multiplex = options.bool("ssh_multiplexing") && !Platform.is_windows
     val (control_master, control_path) =
-      if (multiplex) {
-        val file = Isabelle_System.tmp_file("ssh_socket")
-        file.delete()
-        (true, file.getPath)
-      }
+      if (multiplex) (true, Isabelle_System.tmp_file("ssh_socket", initialized = false).getPath)
       else (false, "")
-
     new Session(options, host, port, user, control_master, control_path)
   }
 
@@ -110,7 +111,7 @@
     override def rsync_prefix: String = user_prefix + host + ":"
 
 
-    /* ssh commands */
+    /* local ssh commands */
 
     def run_command(command: String,
       master: Boolean = false,
@@ -124,7 +125,7 @@
         Config.make(options, port = port, user = user,
           control_master = master, control_path = control_path)
       val cmd =
-        Config.make_command(command, config) +
+        Config.command(command, config) +
         (if (opts.nonEmpty) " " + opts else "") +
         (if (args.nonEmpty) " -- " + args else "")
       Isabelle_System.bash(cmd, progress_stdout = progress_stdout,
@@ -153,7 +154,17 @@
 
     /* init and exit */
 
-    val user_home: String = run_ssh(master = control_master, args = "printenv HOME").check.out
+    val user_home: String = {
+      val args = Bash.string("printenv HOME\nprintenv SHELL")
+      run_ssh(master = control_master, args = args).check.out_lines match {
+        case List(user_home, shell) =>
+          if (shell.endsWith("/bash")) user_home
+          else {
+            error("Bad SHELL for " + quote(toString) + " -- expected GNU bash, but found " + shell)
+          }
+        case _ => error("Malformed remote environment for " + quote(toString))
+      }
+    }
 
     val settings: JMap[String, String] = JMap.of("HOME", user_home, "USER_HOME", user_home)
 
@@ -170,8 +181,7 @@
       settings: Boolean = true,
       strict: Boolean = true
     ): Process_Result = {
-      val args1 =
-        Bash.string(host) + " export " + Bash.string("USER_HOME=\"$HOME\"") + "\n" + cmd_line
+      val args1 = Bash.string(host) + " " + Bash.string("export USER_HOME=\"$HOME\"\n" + cmd_line)
       run_command("ssh", args = args1, progress_stdout = progress_stdout,
         progress_stderr = progress_stderr, strict = strict)
     }
@@ -255,24 +265,45 @@
       local_host: String = "localhost",
       ssh_close: Boolean = false
     ): Port_Forwarding = {
-      if (control_path.isEmpty) error("SSH port forwarding requires multiplexing")
+      val port = if (local_port > 0) local_port else Isabelle_System.local_port()
+
+      val forward = List(local_host, port, remote_host, remote_port).mkString(":")
+      val forward_option = "-L " + Bash.string(forward)
 
-      val port =
-        if (local_port > 0) local_port
+      val cancel: () => Unit =
+        if (control_path.nonEmpty) {
+          run_ssh(opts = forward_option + " -O forward").check
+          () => run_ssh(opts = forward_option + " -O cancel")  // permissive
+        }
         else {
-          // FIXME race condition
-          val dummy = new ServerSocket(0)
-          val port = dummy.getLocalPort
-          dummy.close()
-          port
+          val result = Synchronized[Exn.Result[Boolean]](Exn.Res(false))
+          val thread = Isabelle_Thread.fork("port_forwarding") {
+            val opts =
+              forward_option +
+                " " + Config.option("SessionType", "none") +
+                " " + Config.option("PermitLocalCommand", true) +
+                " " + Config.option("LocalCommand", "pwd")
+            try {
+              run_command("ssh", opts = opts, args = Bash.string(host),
+                progress_stdout = _ => result.change(_ => Exn.Res(true))).check
+            }
+            catch { case exn: Throwable => result.change(_ => Exn.Exn(exn)) }
+          }
+          result.guarded_access {
+            case res@Exn.Res(ok) => if (ok) Some((), res) else None
+            case Exn.Exn(exn) => throw exn
+          }
+          () => thread.interrupt()
         }
-      val string = List(local_host, port, remote_host, remote_port).mkString(":")
-      run_ssh(opts = "-L " + Bash.string(string) + " -O forward").check
+
+      val shutdown_hook =
+        Isabelle_System.create_shutdown_hook { cancel() }
 
       new Port_Forwarding(host, port, remote_host, remote_port) {
-        override def toString: String = string
+        override def toString: String = forward
         override def close(): Unit = {
-          run_ssh(opts = "-L " + Bash.string(string) + " -O cancel").check
+          cancel()
+          Isabelle_System.remove_shutdown_hook(shutdown_hook)
           if (ssh_close) ssh.close()
         }
       }
--- a/src/Pure/System/bash.scala	Wed Sep 14 09:15:00 2022 +0000
+++ b/src/Pure/System/bash.scala	Wed Sep 14 17:32:56 2022 +0200
@@ -146,19 +146,13 @@
     }
 
 
-    // JVM shutdown hook
-
-    private val shutdown_hook = Isabelle_Thread.create(() => terminate())
-
-    try { Runtime.getRuntime.addShutdownHook(shutdown_hook) }
-    catch { case _: IllegalStateException => }
-
-
     // cleanup
 
+    private val shutdown_hook =
+      Isabelle_System.create_shutdown_hook { terminate() }
+
     private def do_cleanup(): Unit = {
-      try { Runtime.getRuntime.removeShutdownHook(shutdown_hook) }
-      catch { case _: IllegalStateException => }
+      Isabelle_System.remove_shutdown_hook(shutdown_hook)
 
       script_file.delete()
       winpid_file.foreach(_.delete())
--- a/src/Pure/System/isabelle_system.scala	Wed Sep 14 09:15:00 2022 +0000
+++ b/src/Pure/System/isabelle_system.scala	Wed Sep 14 17:32:56 2022 +0200
@@ -9,6 +9,7 @@
 
 import java.util.{Map => JMap, HashMap}
 import java.io.{File => JFile, IOException}
+import java.net.ServerSocket
 import java.nio.file.{Path => JPath, Files, SimpleFileVisitor, FileVisitResult,
   StandardCopyOption, FileSystemException}
 import java.nio.file.attribute.BasicFileAttributes
@@ -269,10 +270,15 @@
     File.platform_file(path)
   }
 
-  def tmp_file(name: String, ext: String = "", base_dir: JFile = isabelle_tmp_prefix()): JFile = {
+  def tmp_file(
+    name: String,
+    ext: String = "",
+    base_dir: JFile = isabelle_tmp_prefix(),
+    initialized: Boolean = true
+  ): JFile = {
     val suffix = if (ext == "") "" else "." + ext
     val file = Files.createTempFile(base_dir.toPath, name, suffix).toFile
-    file.deleteOnExit()
+    if (initialized) file.deleteOnExit() else file.delete()
     file
   }
 
@@ -344,6 +350,32 @@
   }
 
 
+  /* TCP/IP ports */
+
+  def local_port(): Int = {
+    val socket = new ServerSocket(0)
+    val port = socket.getLocalPort
+    socket.close()
+    port
+  }
+
+
+  /* JVM shutdown hook */
+
+  def create_shutdown_hook(body: => Unit): Thread = {
+    val shutdown_hook = Isabelle_Thread.create(new Runnable { def run: Unit = body })
+
+    try { Runtime.getRuntime.addShutdownHook(shutdown_hook) }
+    catch { case _: IllegalStateException => }
+
+    shutdown_hook
+  }
+
+  def remove_shutdown_hook(shutdown_hook: Thread): Unit =
+    try { Runtime.getRuntime.removeShutdownHook(shutdown_hook) }
+    catch { case _: IllegalStateException => }
+
+
 
   /** external processes **/