--- a/src/Pure/Tools/server.scala Fri Mar 09 12:07:47 2018 +0100
+++ b/src/Pure/Tools/server.scala Fri Mar 09 12:29:56 2018 +0100
@@ -28,6 +28,49 @@
}
+ /* socket connection */
+
+ object Connection
+ {
+ def apply(socket: Socket): Connection =
+ new Connection(socket)
+ }
+
+ class Connection private(socket: Socket)
+ {
+ override def toString: String = socket.toString
+
+ def close() { socket.close }
+
+ val reader = new BufferedReader(new InputStreamReader(socket.getInputStream, UTF8.charset))
+ val writer = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream, UTF8.charset))
+
+ def read_line(): Option[String] =
+ reader.readLine() match {
+ case null => None
+ case line => Some(line)
+ }
+
+ def write_line(msg: String)
+ {
+ require(split_lines(msg).length <= 1)
+ writer.write(msg)
+ writer.newLine()
+ writer.flush()
+ }
+
+ def reply(r: Server.Reply.Value, t: JSON.T)
+ {
+ write_line(if (t == JSON.empty) r.toString else r.toString + " " + JSON.Format(t))
+ }
+
+ def reply_ok(t: JSON.T) { reply(Server.Reply.OK, t) }
+ def reply_error(t: JSON.T) { reply(Server.Reply.ERROR, t) }
+ def reply_error_message(message: String, more: (String, JSON.T)*): Unit =
+ reply_error(Map("message" -> message) ++ more)
+ }
+
+
/* per-user servers */
object Data
@@ -44,11 +87,11 @@
def print: String =
"server " + quote(name) + " = 127.0.0.1:" + port + " (password " + quote(password) + ")"
- def connect(): Socket =
- new Socket(InetAddress.getByName("127.0.0.1"), port)
+ def connection(): Connection =
+ Connection(new Socket(InetAddress.getByName("127.0.0.1"), port))
def active(): Boolean =
- try { connect().close; true }
+ try { connection().close; true }
catch { case _: IOException => false }
}
}
@@ -99,9 +142,10 @@
db.transaction {
find(db, name) match {
case Some(entry) =>
- using(entry.connect())(socket =>
+ using(entry.connection())(connection =>
{
- using(socket.getOutputStream)(_.write(UTF8.bytes(entry.password + "\nshutdown")))
+ connection.write_line(entry.password)
+ connection.write_line("shutdown")
})
while(entry.active) { Thread.sleep(100) }
true
@@ -156,59 +200,38 @@
private val server_socket = new ServerSocket(_port, 50, InetAddress.getByName("127.0.0.1"))
def port: Int = server_socket.getLocalPort
- def shutdown() { server_socket.close }
val password: String = Library.UUID()
- private def handle_connection(socket: Socket)
+ private def handle(connection: Server.Connection)
{
- val reader = new BufferedReader(new InputStreamReader(socket.getInputStream, UTF8.charset))
- val writer = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream, UTF8.charset))
-
- def reply_line(msg: String)
- {
- require(split_lines(msg).length <= 1)
- writer.write(msg)
- writer.newLine()
- writer.flush()
- }
-
- def reply(r: Server.Reply.Value, t: JSON.T)
- {
- reply_line(if (t == JSON.empty) r.toString else r.toString + " " + JSON.Format(t))
- }
-
- def reply_ok(t: JSON.T) { reply(Server.Reply.OK, t) }
- def reply_error(t: JSON.T) { reply(Server.Reply.ERROR, t) }
- def reply_error_message(message: String, more: (String, JSON.T)*): Unit =
- reply_error(Map("message" -> message) ++ more)
-
- reader.readLine() match {
- case null =>
- case bad if bad != password => reply_error("Bad password -- connection closed")
+ connection.read_line() match {
+ case None =>
+ case Some(line) if line != password =>
+ connection.reply_error("Bad password -- connection closed")
case _ =>
var finished = false
while (!finished) {
- reader.readLine() match {
- case null => finished = true
- case line =>
+ connection.read_line() match {
+ case None => finished = true
+ case Some(line) =>
val cmd = line.takeWhile(c => Symbol.is_ascii_letter(c) || Symbol.is_ascii_letdig(c))
val input = line.substring(cmd.length).dropWhile(Symbol.is_ascii_blank(_))
Server.commands.get(cmd) match {
- case None => reply_error("Bad command " + quote(cmd))
+ case None => connection.reply_error("Bad command " + quote(cmd))
case Some(body) =>
proper_string(input) getOrElse "{}" match {
case JSON.Format(arg) =>
if (body.isDefinedAt((server, arg))) {
- try { reply_ok(body(server, arg)) }
- catch { case ERROR(msg) => reply_error(msg) }
+ try { connection.reply_ok(body(server, arg)) }
+ catch { case ERROR(msg) => connection.reply_error(msg) }
}
else {
- reply_error_message(
+ connection.reply_error_message(
"Bad argument for command", "command" -> cmd, "argument" -> arg)
}
case _ =>
- reply_error_message(
+ connection.reply_error_message(
"Malformed command-line", "command" -> cmd, "input" -> input)
}
}
@@ -218,18 +241,20 @@
}
}
- private lazy val thread: Thread =
+ private lazy val server_thread: Thread =
Standard_Thread.fork("server") {
var finished = false
while (!finished) {
Exn.capture(server_socket.accept) match {
case Exn.Res(socket) =>
Standard_Thread.fork("server_connection")
- { try { handle_connection(socket) } finally { socket.close } }
+ { using(Server.Connection(socket))(handle(_)) }
case Exn.Exn(_) => finished = true
}
}
}
- def join { thread.join; shutdown() }
+ def join { server_thread.join; shutdown() }
+
+ def shutdown() { server_socket.close }
}