diff -r 3a6ab890832f -r 0c2ed45ece20 src/Pure/Tools/server.scala --- a/src/Pure/Tools/server.scala Mon Mar 12 11:37:30 2018 +0100 +++ b/src/Pure/Tools/server.scala Mon Mar 12 16:32:33 2018 +0100 @@ -60,13 +60,13 @@ object Command { - type T = PartialFunction[(Server, Any), Any] + type T = PartialFunction[(Context, Any), Any] private val table: Map[String, T] = Map( "echo" -> { case (_, t) => t }, "help" -> { case (_, ()) => table.keySet.toList.sorted }, - "shutdown" -> { case (server, ()) => server.close(); () }) + "shutdown" -> { case (context, ()) => context.shutdown(); () }) def unapply(name: String): Option[T] = table.get(name) } @@ -112,9 +112,14 @@ private val in = new BufferedInputStream(socket.getInputStream) private val out = new BufferedOutputStream(socket.getOutputStream) + private val out_lock: AnyRef = new Object def tty_loop(interrupt: Option[() => Unit] = None): TTY_Loop = - new TTY_Loop(new OutputStreamWriter(out), new InputStreamReader(in), interrupt = interrupt) + new TTY_Loop( + new OutputStreamWriter(out), + new InputStreamReader(in), + writer_lock = out_lock, + interrupt = interrupt) def read_message(): Option[String] = try { @@ -126,7 +131,7 @@ } catch { case _: SocketException => None } - def write_message(msg: String) + def write_message(msg: String): Unit = out_lock.synchronized { val b = UTF8.bytes(msg) if (b.length > 100 || b.contains(10)) { @@ -150,8 +155,52 @@ reply_error(Map("message" -> message) ++ more) def notify(arg: Any) { reply(Server.Reply.NOTE, arg) } - def notify_message(message: String, more: (String, JSON.T)*): Unit = - notify(Map("message" -> message) ++ more) + def notify_message(kind: String, msg: String, more: (String, JSON.T)*): Unit = + notify(Map(Markup.KIND -> kind, "message" -> msg) ++ more) + } + + + /* context with output channels */ + + class Context private[Server](server: Server, connection: Connection) + { + context => + + def shutdown() { server.close() } + + def message(kind: String, msg: String, more: (String, JSON.T)*): Unit = + connection.notify_message(kind, msg, more:_*) + def writeln(msg: String, more: (String, JSON.T)*): Unit = message(Markup.WRITELN, msg, more:_*) + def warning(msg: String, more: (String, JSON.T)*): Unit = message(Markup.WARNING, msg, more:_*) + def error_message(msg: String, more: (String, JSON.T)*): Unit = + message(Markup.ERROR_MESSAGE, msg, more:_*) + + val logger: Connection_Logger = new Connection_Logger(context) + def progress(): Connection_Progress = new Connection_Progress(context) + + override def toString: String = connection.toString + } + + class Connection_Logger private[Server](context: Context) extends Logger + { + def apply(msg: => String): Unit = context.message(Markup.LOGGER, msg) + + override def toString: String = context.toString + } + + class Connection_Progress private[Server](context: Context) extends Progress + { + override def echo(msg: String): Unit = context.writeln(msg) + override def echo_warning(msg: String): Unit = context.warning(msg) + override def echo_error_message(msg: String): Unit = context.error_message(msg) + override def theory(session: String, theory: String): Unit = + context.writeln(session + ": theory " + theory, "session" -> session, "theory" -> theory) + + @volatile private var is_stopped = false + override def stopped: Boolean = is_stopped + def stop { is_stopped = true } + + override def toString: String = context.toString } @@ -340,6 +389,8 @@ private def handle(connection: Server.Connection) { + val context = new Server.Context(server, connection) + connection.read_message() match { case Some(msg) if msg == password => connection.reply_ok(()) @@ -347,16 +398,15 @@ while (!finished) { connection.read_message() match { case None => finished = true - case Some("") => - connection.notify_message("Command 'help' provides list of commands") + case Some("") => context.writeln("Command 'help' provides list of commands") case Some(msg) => val (name, argument) = Server.Argument.split(msg) name match { case Server.Command(cmd) => argument match { case Server.Argument(arg) => - if (cmd.isDefinedAt((server, arg))) { - Exn.capture { cmd((server, arg)) } match { + if (cmd.isDefinedAt((context, arg))) { + Exn.capture { cmd((context, arg)) } match { case Exn.Res(res) => connection.reply_ok(res) case Exn.Exn(ERROR(msg)) => connection.reply_error(msg) case Exn.Exn(exn) => throw exn