clarified socket connection;
authorwenzelm
Fri, 09 Mar 2018 12:29:56 +0100
changeset 67786 be6d69595ca7
parent 67785 ad96390ceb5d
child 67787 8335d88195c4
clarified socket connection;
src/Pure/Tools/server.scala
--- a/src/Pure/Tools/server.scala	Fri Mar 09 12:07:47 2018 +0100
+++ b/src/Pure/Tools/server.scala	Fri Mar 09 12:29:56 2018 +0100
@@ -28,6 +28,49 @@
   }
 
 
+  /* socket connection */
+
+  object Connection
+  {
+    def apply(socket: Socket): Connection =
+      new Connection(socket)
+  }
+
+  class Connection private(socket: Socket)
+  {
+    override def toString: String = socket.toString
+
+    def close() { socket.close }
+
+    val reader = new BufferedReader(new InputStreamReader(socket.getInputStream, UTF8.charset))
+    val writer = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream, UTF8.charset))
+
+    def read_line(): Option[String] =
+      reader.readLine() match {
+        case null => None
+        case line => Some(line)
+      }
+
+    def write_line(msg: String)
+    {
+      require(split_lines(msg).length <= 1)
+      writer.write(msg)
+      writer.newLine()
+      writer.flush()
+    }
+
+    def reply(r: Server.Reply.Value, t: JSON.T)
+    {
+      write_line(if (t == JSON.empty) r.toString else r.toString + " " + JSON.Format(t))
+    }
+
+    def reply_ok(t: JSON.T) { reply(Server.Reply.OK, t) }
+    def reply_error(t: JSON.T) { reply(Server.Reply.ERROR, t) }
+    def reply_error_message(message: String, more: (String, JSON.T)*): Unit =
+      reply_error(Map("message" -> message) ++ more)
+  }
+
+
   /* per-user servers */
 
   object Data
@@ -44,11 +87,11 @@
       def print: String =
         "server " + quote(name) + " = 127.0.0.1:" + port + " (password " + quote(password) + ")"
 
-      def connect(): Socket =
-        new Socket(InetAddress.getByName("127.0.0.1"), port)
+      def connection(): Connection =
+        Connection(new Socket(InetAddress.getByName("127.0.0.1"), port))
 
       def active(): Boolean =
-        try { connect().close; true }
+        try { connection().close; true }
         catch { case _: IOException => false }
     }
   }
@@ -99,9 +142,10 @@
       db.transaction {
         find(db, name) match {
           case Some(entry) =>
-            using(entry.connect())(socket =>
+            using(entry.connection())(connection =>
               {
-                using(socket.getOutputStream)(_.write(UTF8.bytes(entry.password + "\nshutdown")))
+                connection.write_line(entry.password)
+                connection.write_line("shutdown")
               })
             while(entry.active) { Thread.sleep(100) }
             true
@@ -156,59 +200,38 @@
 
   private val server_socket = new ServerSocket(_port, 50, InetAddress.getByName("127.0.0.1"))
   def port: Int = server_socket.getLocalPort
-  def shutdown() { server_socket.close }
 
   val password: String = Library.UUID()
 
-  private def handle_connection(socket: Socket)
+  private def handle(connection: Server.Connection)
   {
-    val reader = new BufferedReader(new InputStreamReader(socket.getInputStream, UTF8.charset))
-    val writer = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream, UTF8.charset))
-
-    def reply_line(msg: String)
-    {
-      require(split_lines(msg).length <= 1)
-      writer.write(msg)
-      writer.newLine()
-      writer.flush()
-    }
-
-    def reply(r: Server.Reply.Value, t: JSON.T)
-    {
-      reply_line(if (t == JSON.empty) r.toString else r.toString + " " + JSON.Format(t))
-    }
-
-    def reply_ok(t: JSON.T) { reply(Server.Reply.OK, t) }
-    def reply_error(t: JSON.T) { reply(Server.Reply.ERROR, t) }
-    def reply_error_message(message: String, more: (String, JSON.T)*): Unit =
-      reply_error(Map("message" -> message) ++ more)
-
-    reader.readLine() match {
-      case null =>
-      case bad if bad != password => reply_error("Bad password -- connection closed")
+    connection.read_line() match {
+      case None =>
+      case Some(line) if line != password =>
+        connection.reply_error("Bad password -- connection closed")
       case _ =>
         var finished = false
         while (!finished) {
-          reader.readLine() match {
-            case null => finished = true
-            case line =>
+          connection.read_line() match {
+            case None => finished = true
+            case Some(line) =>
               val cmd = line.takeWhile(c => Symbol.is_ascii_letter(c) || Symbol.is_ascii_letdig(c))
               val input = line.substring(cmd.length).dropWhile(Symbol.is_ascii_blank(_))
               Server.commands.get(cmd) match {
-                case None => reply_error("Bad command " + quote(cmd))
+                case None => connection.reply_error("Bad command " + quote(cmd))
                 case Some(body) =>
                   proper_string(input) getOrElse "{}" match {
                     case JSON.Format(arg) =>
                       if (body.isDefinedAt((server, arg))) {
-                        try { reply_ok(body(server, arg)) }
-                        catch { case ERROR(msg) => reply_error(msg) }
+                        try { connection.reply_ok(body(server, arg)) }
+                        catch { case ERROR(msg) => connection.reply_error(msg) }
                       }
                       else {
-                        reply_error_message(
+                        connection.reply_error_message(
                           "Bad argument for command", "command" -> cmd, "argument" -> arg)
                       }
                     case _ =>
-                      reply_error_message(
+                      connection.reply_error_message(
                         "Malformed command-line", "command" -> cmd, "input" -> input)
                   }
               }
@@ -218,18 +241,20 @@
     }
   }
 
-  private lazy val thread: Thread =
+  private lazy val server_thread: Thread =
     Standard_Thread.fork("server") {
       var finished = false
       while (!finished) {
         Exn.capture(server_socket.accept) match {
           case Exn.Res(socket) =>
             Standard_Thread.fork("server_connection")
-              { try { handle_connection(socket) } finally { socket.close } }
+              { using(Server.Connection(socket))(handle(_)) }
           case Exn.Exn(_) => finished = true
         }
       }
     }
 
-  def join { thread.join; shutdown() }
+  def join { server_thread.join; shutdown() }
+
+  def shutdown() { server_socket.close }
 }