support for asynchronous tasks, with "cancel" command;
authorwenzelm
Wed, 14 Mar 2018 20:20:10 +0100
changeset 67860 5a6c483269f3
parent 67859 612846bff1ea
child 67861 cd1cac824ef8
support for asynchronous tasks, with "cancel" command;
src/Pure/Tools/server.scala
--- a/src/Pure/Tools/server.scala	Wed Mar 14 19:58:27 2018 +0100
+++ b/src/Pure/Tools/server.scala	Wed Mar 14 20:20:10 2018 +0100
@@ -67,6 +67,7 @@
         "help" -> { case (_, ()) => table.keySet.toList.sorted },
         "echo" -> { case (_, t) => t },
         "shutdown" -> { case (context, ()) => context.shutdown(); () },
+        "cancel" -> { case (context, JSON.Value.String(id)) => context.cancel_task(id) },
         "session_build" ->
           { case (context, Server_Commands.Session_Build(args)) =>
              Server_Commands.Session_Build.command(context.progress(), args)._1
@@ -83,7 +84,7 @@
 
   object Reply extends Enumeration
   {
-    val OK, ERROR, NOTE = Value
+    val OK, ERROR, FINISHED, FAILED, NOTE = Value
 
     def message(msg: String): JSON.Object.Entry = ("message" -> msg)
 
@@ -184,26 +185,56 @@
     def error_message(msg: String, more: JSON.Object.Entry*): Unit =
       message(Markup.ERROR_MESSAGE, msg, more:_*)
 
-    val logger: Connection_Logger = new Connection_Logger(context)
-    def progress(): Connection_Progress = new Connection_Progress(context)
+    def logger(more: JSON.Object.Entry*): Connection_Logger =
+      new Connection_Logger(context, more:_*)
+
+    def progress(more: JSON.Object.Entry*): Connection_Progress =
+      new Connection_Progress(context, more:_*)
 
     override def toString: String = connection.toString
+
+
+    /* asynchronous tasks */
+
+    private val _tasks = Synchronized(Set.empty[Task])
+
+    def make_task(body: Task => JSON.Object.T): Task =
+    {
+      val task = new Task(context, body)
+      _tasks.change(_ + task)
+      task
+    }
+
+    def remove_task(task: Task): Unit =
+      _tasks.change(_ - task)
+
+    def cancel_task(id: String): Unit =
+      _tasks.change(tasks => { tasks.find(task => task.id == id).foreach(_.cancel); tasks })
+
+    def close()
+    {
+      while(_tasks.change_result(tasks => { tasks.foreach(_.cancel); (tasks.nonEmpty, tasks) }))
+      { _tasks.value.foreach(_.join) }
+    }
   }
 
-  class Connection_Logger private[Server](context: Context) extends Logger
+  class Connection_Logger private[Server](context: Context, more: JSON.Object.Entry*)
+    extends Logger
   {
-    def apply(msg: => String): Unit = context.message(Markup.LOGGER, msg)
+    def apply(msg: => String): Unit = context.message(Markup.LOGGER, msg, more:_*)
 
     override def toString: String = context.toString
   }
 
-  class Connection_Progress private[Server](context: Context) extends Progress
+  class Connection_Progress private[Server](context: Context, more: JSON.Object.Entry*)
+    extends Progress
   {
-    override def echo(msg: String): Unit = context.writeln(msg)
-    override def echo_warning(msg: String): Unit = context.warning(msg)
-    override def echo_error_message(msg: String): Unit = context.error_message(msg)
+    override def echo(msg: String): Unit = context.writeln(msg, more:_*)
+    override def echo_warning(msg: String): Unit = context.warning(msg, more:_*)
+    override def echo_error_message(msg: String): Unit = context.error_message(msg, more:_*)
     override def theory(session: String, theory: String): Unit =
-      context.writeln(session + ": theory " + theory, "session" -> session, "theory" -> theory)
+      context.writeln(session + ": theory " + theory,
+        (List("session" -> session, "theory" -> theory) ::: more.toList):_*)
 
     @volatile private var is_stopped = false
     override def stopped: Boolean = is_stopped
@@ -212,6 +243,36 @@
     override def toString: String = context.toString
   }
 
+  class Task private[Server](val context: Context, body: Task => JSON.Object.T)
+  {
+    task =>
+
+    val id: String = Library.UUID()
+    val ident: JSON.Object.Entry = ("task" -> id)
+
+    val logger: Logger = context.logger(ident)
+
+    val progress: Connection_Progress = context.progress(ident)
+    def cancel { progress.stop }
+
+    private lazy val thread = Standard_Thread.fork("server_task")
+    {
+      Exn.capture { body(task) } match {
+        case Exn.Res(res) =>
+          context.reply(Reply.FINISHED, res + ident)
+        case Exn.Exn(exn: Server.Error) =>
+          context.reply(Reply.FAILED, JSON.Object(Reply.message(exn.message)) ++ exn.json + ident)
+        case Exn.Exn(ERROR(msg)) =>
+          context.reply(Reply.FAILED, JSON.Object(Reply.message(msg)) + ident)
+        case Exn.Exn(exn) => throw exn
+      }
+      progress.stop
+      context.remove_task(task)
+    }
+    def start { thread }
+    def join { thread.join }
+  }
+
 
   /* server info */
 
@@ -398,48 +459,52 @@
 
   private def handle(connection: Server.Connection)
   {
-    val context = new Server.Context(server, connection)
-
-    connection.read_message() match {
-      case Some(msg) if msg == password =>
-        connection.reply_ok(())
-        var finished = false
-        while (!finished) {
-          connection.read_message() match {
-            case None => finished = true
-            case Some("") => context.notify("Command 'help' provides list of commands")
-            case Some(msg) =>
-              val (name, argument) = Server.Argument.split(msg)
-              name match {
-                case Server.Command(cmd) =>
-                  argument match {
-                    case Server.Argument(arg) =>
-                      if (cmd.isDefinedAt((context, arg))) {
-                        Exn.capture { cmd((context, arg)) } match {
-                          case Exn.Res(res) => connection.reply_ok(res)
-                          case Exn.Exn(exn: Server.Error) =>
-                            connection.reply_error_message(exn.message, exn.json.toList:_*)
-                          case Exn.Exn(ERROR(msg)) =>
-                            connection.reply_error_message(msg)
-                          case Exn.Exn(exn) => throw exn
+    using(new Server.Context(server, connection))(context =>
+    {
+      connection.read_message() match {
+        case Some(msg) if msg == password =>
+          connection.reply_ok(())
+          var finished = false
+          while (!finished) {
+            connection.read_message() match {
+              case None => finished = true
+              case Some("") => context.notify("Command 'help' provides list of commands")
+              case Some(msg) =>
+                val (name, argument) = Server.Argument.split(msg)
+                name match {
+                  case Server.Command(cmd) =>
+                    argument match {
+                      case Server.Argument(arg) =>
+                        if (cmd.isDefinedAt((context, arg))) {
+                          Exn.capture { cmd((context, arg)) } match {
+                            case Exn.Res(task: Server.Task) =>
+                              connection.reply_ok(JSON.Object(task.ident))
+                              task.start
+                            case Exn.Res(res) => connection.reply_ok(res)
+                            case Exn.Exn(exn: Server.Error) =>
+                              connection.reply_error_message(exn.message, exn.json.toList:_*)
+                            case Exn.Exn(ERROR(msg)) =>
+                              connection.reply_error_message(msg)
+                            case Exn.Exn(exn) => throw exn
+                          }
                         }
-                      }
-                      else {
+                        else {
+                          connection.reply_error_message(
+                            "Bad argument for command " + Library.single_quote(name),
+                            "argument" -> argument)
+                        }
+                      case _ =>
                         connection.reply_error_message(
-                          "Bad argument for command " + Library.single_quote(name),
+                          "Malformed argument for command " + Library.single_quote(name),
                           "argument" -> argument)
-                      }
-                    case _ =>
-                      connection.reply_error_message(
-                        "Malformed argument for command " + Library.single_quote(name),
-                        "argument" -> argument)
-                  }
-                case _ => connection.reply_error("Bad command " + Library.single_quote(name))
-              }
+                    }
+                  case _ => connection.reply_error("Bad command " + Library.single_quote(name))
+                }
+            }
           }
-        }
-      case _ =>
-    }
+        case _ =>
+      }
+    })
   }
 
   private lazy val server_thread: Thread =