server commands may access Server;
authorwenzelm
Fri, 09 Mar 2018 12:07:47 +0100
changeset 67785 ad96390ceb5d
parent 67784 543e36ae489c
child 67786 be6d69595ca7
server commands may access Server; Server.stop: proper shutdown; clarified signature;
src/Pure/Tools/server.scala
--- a/src/Pure/Tools/server.scala	Thu Mar 08 21:09:22 2018 +0100
+++ b/src/Pure/Tools/server.scala	Fri Mar 09 12:07:47 2018 +0100
@@ -16,10 +16,11 @@
 {
   /* protocol */
 
-  val commands: Map[String, PartialFunction[JSON.T, JSON.T]] =
+  val commands: Map[String, PartialFunction[(Server, JSON.T), JSON.T]] =
     Map(
-      "help" -> { case JSON.empty => commands.keySet.toList.sorted },
-      "echo" -> { case t => t })
+      "echo" -> { case (_, t) => t },
+      "help" -> { case (_, JSON.empty) => commands.keySet.toList.sorted },
+      "shutdown" -> { case (server, JSON.empty) => server.shutdown(); JSON.empty })
 
   object Reply extends Enumeration
   {
@@ -43,8 +44,11 @@
       def print: String =
         "server " + quote(name) + " = 127.0.0.1:" + port + " (password " + quote(password) + ")"
 
-      def active: Boolean =
-        try { (new Socket(InetAddress.getByName("127.0.0.1"), port)).close; true }
+      def connect(): Socket =
+        new Socket(InetAddress.getByName("127.0.0.1"), port)
+
+      def active(): Boolean =
+        try { connect().close; true }
         catch { case _: IOException => false }
     }
   }
@@ -63,7 +67,7 @@
   def find(db: SQLite.Database, name: String): Option[Data.Entry] =
     list(db).find(entry => entry.name == name && entry.active)
 
-  def start(name: String = "", port: Int = 0): (Data.Entry, Option[Thread]) =
+  def start(name: String = "", port: Int = 0): (Data.Entry, Option[Server]) =
   {
     using(SQLite.open_database(Data.database))(db =>
       db.transaction {
@@ -84,7 +88,7 @@
               stmt.execute()
             })
 
-            (entry, Some(server.thread))
+            (entry, Some(server))
         }
       })
   }
@@ -95,11 +99,13 @@
       db.transaction {
         find(db, name) match {
           case Some(entry) =>
-            // FIXME shutdown server
-            db.using_statement(Data.table.delete(Data.name.where_equal(name)))(_.execute)
+            using(entry.connect())(socket =>
+              {
+                using(socket.getOutputStream)(_.write(UTF8.bytes(entry.password + "\nshutdown")))
+              })
+            while(entry.active) { Thread.sleep(100) }
             true
-          case None =>
-            false
+          case None => false
         }
       })
   }
@@ -137,18 +143,20 @@
           Output.writeln(entry.print, stdout = true)
       }
       else {
-        val (entry, thread) = start(name, port)
+        val (entry, server) = start(name, port)
         Output.writeln(entry.print, stdout = true)
-        thread.foreach(_.join)
+        server.foreach(_.join)
       }
     })
 }
 
 class Server private(_port: Int)
 {
+  server =>
+
   private val server_socket = new ServerSocket(_port, 50, InetAddress.getByName("127.0.0.1"))
   def port: Int = server_socket.getLocalPort
-  def close { server_socket.close }
+  def shutdown() { server_socket.close }
 
   val password: String = Library.UUID()
 
@@ -191,8 +199,8 @@
                 case Some(body) =>
                   proper_string(input) getOrElse "{}" match {
                     case JSON.Format(arg) =>
-                      if (body.isDefinedAt(arg)) {
-                        try { reply_ok(body(arg)) }
+                      if (body.isDefinedAt((server, arg))) {
+                        try { reply_ok(body(server, arg)) }
                         catch { case ERROR(msg) => reply_error(msg) }
                       }
                       else {
@@ -210,7 +218,7 @@
     }
   }
 
-  lazy val thread: Thread =
+  private lazy val thread: Thread =
     Standard_Thread.fork("server") {
       var finished = false
       while (!finished) {
@@ -222,4 +230,6 @@
         }
       }
     }
+
+  def join { thread.join; shutdown() }
 }