--- a/src/Pure/Tools/server.scala Thu Mar 08 21:09:22 2018 +0100
+++ b/src/Pure/Tools/server.scala Fri Mar 09 12:07:47 2018 +0100
@@ -16,10 +16,11 @@
{
/* protocol */
- val commands: Map[String, PartialFunction[JSON.T, JSON.T]] =
+ val commands: Map[String, PartialFunction[(Server, JSON.T), JSON.T]] =
Map(
- "help" -> { case JSON.empty => commands.keySet.toList.sorted },
- "echo" -> { case t => t })
+ "echo" -> { case (_, t) => t },
+ "help" -> { case (_, JSON.empty) => commands.keySet.toList.sorted },
+ "shutdown" -> { case (server, JSON.empty) => server.shutdown(); JSON.empty })
object Reply extends Enumeration
{
@@ -43,8 +44,11 @@
def print: String =
"server " + quote(name) + " = 127.0.0.1:" + port + " (password " + quote(password) + ")"
- def active: Boolean =
- try { (new Socket(InetAddress.getByName("127.0.0.1"), port)).close; true }
+ def connect(): Socket =
+ new Socket(InetAddress.getByName("127.0.0.1"), port)
+
+ def active(): Boolean =
+ try { connect().close; true }
catch { case _: IOException => false }
}
}
@@ -63,7 +67,7 @@
def find(db: SQLite.Database, name: String): Option[Data.Entry] =
list(db).find(entry => entry.name == name && entry.active)
- def start(name: String = "", port: Int = 0): (Data.Entry, Option[Thread]) =
+ def start(name: String = "", port: Int = 0): (Data.Entry, Option[Server]) =
{
using(SQLite.open_database(Data.database))(db =>
db.transaction {
@@ -84,7 +88,7 @@
stmt.execute()
})
- (entry, Some(server.thread))
+ (entry, Some(server))
}
})
}
@@ -95,11 +99,13 @@
db.transaction {
find(db, name) match {
case Some(entry) =>
- // FIXME shutdown server
- db.using_statement(Data.table.delete(Data.name.where_equal(name)))(_.execute)
+ using(entry.connect())(socket =>
+ {
+ using(socket.getOutputStream)(_.write(UTF8.bytes(entry.password + "\nshutdown")))
+ })
+ while(entry.active) { Thread.sleep(100) }
true
- case None =>
- false
+ case None => false
}
})
}
@@ -137,18 +143,20 @@
Output.writeln(entry.print, stdout = true)
}
else {
- val (entry, thread) = start(name, port)
+ val (entry, server) = start(name, port)
Output.writeln(entry.print, stdout = true)
- thread.foreach(_.join)
+ server.foreach(_.join)
}
})
}
class Server private(_port: Int)
{
+ server =>
+
private val server_socket = new ServerSocket(_port, 50, InetAddress.getByName("127.0.0.1"))
def port: Int = server_socket.getLocalPort
- def close { server_socket.close }
+ def shutdown() { server_socket.close }
val password: String = Library.UUID()
@@ -191,8 +199,8 @@
case Some(body) =>
proper_string(input) getOrElse "{}" match {
case JSON.Format(arg) =>
- if (body.isDefinedAt(arg)) {
- try { reply_ok(body(arg)) }
+ if (body.isDefinedAt((server, arg))) {
+ try { reply_ok(body(server, arg)) }
catch { case ERROR(msg) => reply_error(msg) }
}
else {
@@ -210,7 +218,7 @@
}
}
- lazy val thread: Thread =
+ private lazy val thread: Thread =
Standard_Thread.fork("server") {
var finished = false
while (!finished) {
@@ -222,4 +230,6 @@
}
}
}
+
+ def join { thread.join; shutdown() }
}