src/Pure/Tools/server.scala
author wenzelm
Sat Jun 09 21:52:16 2018 +0200 (12 months ago)
changeset 68410 4e27f5c361d2
parent 67941 49a34b2fa788
child 68530 a110dcc9a4c7
permissions -rw-r--r--
clarified signature: more uniform theory_message (see also d7920eb7de54);
     1 /*  Title:      Pure/Tools/server.scala
     2     Author:     Makarius
     3 
     4 Resident Isabelle servers.
     5 
     6 Message formats:
     7   - short message (single line):
     8       NAME ARGUMENT
     9   - long message (multiple lines):
    10       BYTE_LENGTH
    11       NAME ARGUMENT
    12 
    13 Argument formats:
    14   - Unit as empty string
    15   - XML.Elem in YXML notation
    16   - JSON.T in standard notation
    17 */
    18 
    19 package isabelle
    20 
    21 
    22 import java.io.{BufferedInputStream, BufferedOutputStream, InputStreamReader, OutputStreamWriter,
    23   IOException}
    24 import java.net.{Socket, SocketException, SocketTimeoutException, ServerSocket, InetAddress}
    25 
    26 
    27 object Server
    28 {
    29   /* message argument */
    30 
    31   object Argument
    32   {
    33     def is_name_char(c: Char): Boolean =
    34       Symbol.is_ascii_letter(c) || Symbol.is_ascii_digit(c) || c == '_' || c == '.'
    35 
    36     def split(msg: String): (String, String) =
    37     {
    38       val name = msg.takeWhile(is_name_char(_))
    39       val argument = msg.substring(name.length).dropWhile(Symbol.is_ascii_blank(_))
    40       (name, argument)
    41     }
    42 
    43     def print(arg: Any): String =
    44       arg match {
    45         case () => ""
    46         case t: XML.Elem => YXML.string_of_tree(t)
    47         case t: JSON.T => JSON.Format(t)
    48       }
    49 
    50     def parse(argument: String): Any =
    51       if (argument == "") ()
    52       else if (YXML.detect_elem(argument)) YXML.parse_elem(argument)
    53       else JSON.parse(argument, strict = false)
    54 
    55     def unapply(argument: String): Option[Any] =
    56       try { Some(parse(argument)) }
    57       catch { case ERROR(_) => None }
    58   }
    59 
    60 
    61   /* input command */
    62 
    63   object Command
    64   {
    65     type T = PartialFunction[(Context, Any), Any]
    66 
    67     private val table: Map[String, T] =
    68       Map(
    69         "help" -> { case (_, ()) => table.keySet.toList.sorted },
    70         "echo" -> { case (_, t) => t },
    71         "shutdown" -> { case (context, ()) => context.server.shutdown() },
    72         "cancel" ->
    73           { case (context, Server_Commands.Cancel(args)) => context.cancel_task(args.task) },
    74         "session_build" ->
    75           { case (context, Server_Commands.Session_Build(args)) =>
    76               context.make_task(task =>
    77                 Server_Commands.Session_Build.command(args, progress = task.progress)._1)
    78           },
    79         "session_start" ->
    80           { case (context, Server_Commands.Session_Start(args)) =>
    81               context.make_task(task =>
    82                 {
    83                   val (res, entry) =
    84                     Server_Commands.Session_Start.command(
    85                       args, progress = task.progress, log = context.server.log)
    86                   context.server.add_session(entry)
    87                   res
    88                 })
    89           },
    90         "session_stop" ->
    91           { case (context, Server_Commands.Session_Stop(id)) =>
    92               context.make_task(_ =>
    93                 {
    94                   val session = context.server.remove_session(id)
    95                   Server_Commands.Session_Stop.command(session)._1
    96                 })
    97           },
    98         "use_theories" ->
    99           { case (context, Server_Commands.Use_Theories(args)) =>
   100               context.make_task(task =>
   101                 {
   102                   val session = context.server.the_session(args.session_id)
   103                   Server_Commands.Use_Theories.command(
   104                     args, session, id = task.id, progress = task.progress)._1
   105                 }),
   106           },
   107         "purge_theories" ->
   108           { case (context, Server_Commands.Purge_Theories(args)) =>
   109               val session = context.server.the_session(args.session_id)
   110               Server_Commands.Purge_Theories.command(args, session)._1
   111           })
   112 
   113     def unapply(name: String): Option[T] = table.get(name)
   114   }
   115 
   116 
   117   /* output reply */
   118 
   119   class Error(val message: String, val json: JSON.Object.T = JSON.Object.empty)
   120     extends RuntimeException(message)
   121 
   122   def json_error(exn: Throwable): JSON.Object.T =
   123     exn match {
   124       case e: Error => Reply.error_message(e.message) ++ e.json
   125       case ERROR(msg) => Reply.error_message(msg)
   126       case _ if Exn.is_interrupt(exn) => Reply.error_message(Exn.message(exn))
   127       case _ => JSON.Object.empty
   128     }
   129 
   130   object Reply extends Enumeration
   131   {
   132     val OK, ERROR, FINISHED, FAILED, NOTE = Value
   133 
   134     def message(msg: String, kind: String = ""): JSON.Object.T =
   135       JSON.Object(Markup.KIND -> proper_string(kind).getOrElse(Markup.WRITELN), "message" -> msg)
   136 
   137     def error_message(msg: String): JSON.Object.T =
   138       message(msg, kind = Markup.ERROR)
   139 
   140     def unapply(msg: String): Option[(Reply.Value, Any)] =
   141     {
   142       if (msg == "") None
   143       else {
   144         val (name, argument) = Argument.split(msg)
   145         for {
   146           reply <-
   147             try { Some(withName(name)) }
   148             catch { case _: NoSuchElementException => None }
   149           arg <- Argument.unapply(argument)
   150         } yield (reply, arg)
   151       }
   152     }
   153   }
   154 
   155 
   156   /* socket connection */
   157 
   158   object Connection
   159   {
   160     def apply(socket: Socket): Connection =
   161       new Connection(socket)
   162   }
   163 
   164   class Connection private(socket: Socket)
   165   {
   166     override def toString: String = socket.toString
   167 
   168     def close() { socket.close }
   169 
   170     def set_timeout(t: Time) { socket.setSoTimeout(t.ms.toInt) }
   171 
   172     private val in = new BufferedInputStream(socket.getInputStream)
   173     private val out = new BufferedOutputStream(socket.getOutputStream)
   174     private val out_lock: AnyRef = new Object
   175 
   176     def tty_loop(interrupt: Option[() => Unit] = None): TTY_Loop =
   177       new TTY_Loop(
   178         new OutputStreamWriter(out),
   179         new InputStreamReader(in),
   180         writer_lock = out_lock,
   181         interrupt = interrupt)
   182 
   183     def read_message(): Option[String] =
   184       try {
   185         Bytes.read_line(in).map(_.text) match {
   186           case Some(Value.Int(n)) =>
   187             Bytes.read_block(in, n).map(bytes => Library.trim_line(bytes.text))
   188           case res => res
   189         }
   190       }
   191       catch { case _: SocketException => None }
   192 
   193     def write_message(msg: String): Unit = out_lock.synchronized
   194     {
   195       val b = UTF8.bytes(msg)
   196       if (b.length > 100 || b.contains(10)) {
   197         out.write(UTF8.bytes((b.length + 1).toString))
   198         out.write(10)
   199       }
   200       out.write(b)
   201       out.write(10)
   202       try { out.flush() } catch { case _: SocketException => }
   203     }
   204 
   205     def reply(r: Reply.Value, arg: Any)
   206     {
   207       val argument = Argument.print(arg)
   208       write_message(if (argument == "") r.toString else r.toString + " " + argument)
   209     }
   210 
   211     def reply_ok(arg: Any) { reply(Reply.OK, arg) }
   212     def reply_error(arg: Any) { reply(Reply.ERROR, arg) }
   213     def reply_error_message(message: String, more: JSON.Object.Entry*): Unit =
   214       reply_error(Reply.error_message(message) ++ more)
   215 
   216     def notify(arg: Any) { reply(Reply.NOTE, arg) }
   217   }
   218 
   219 
   220   /* context with output channels */
   221 
   222   class Context private[Server](val server: Server, connection: Connection)
   223   {
   224     context =>
   225 
   226     def reply(r: Reply.Value, arg: Any) { connection.reply(r, arg) }
   227     def notify(arg: Any) { connection.notify(arg) }
   228     def message(kind: String, msg: String, more: JSON.Object.Entry*): Unit =
   229       notify(Reply.message(msg, kind = kind) ++ more)
   230     def writeln(msg: String, more: JSON.Object.Entry*): Unit = message(Markup.WRITELN, msg, more:_*)
   231     def warning(msg: String, more: JSON.Object.Entry*): Unit = message(Markup.WARNING, msg, more:_*)
   232     def error_message(msg: String, more: JSON.Object.Entry*): Unit =
   233       message(Markup.ERROR, msg, more:_*)
   234 
   235     def progress(more: JSON.Object.Entry*): Connection_Progress =
   236       new Connection_Progress(context, more:_*)
   237 
   238     override def toString: String = connection.toString
   239 
   240 
   241     /* asynchronous tasks */
   242 
   243     private val _tasks = Synchronized(Set.empty[Task])
   244 
   245     def make_task(body: Task => JSON.Object.T): Task =
   246     {
   247       val task = new Task(context, body)
   248       _tasks.change(_ + task)
   249       task
   250     }
   251 
   252     def remove_task(task: Task): Unit =
   253       _tasks.change(_ - task)
   254 
   255     def cancel_task(id: UUID): Unit =
   256       _tasks.change(tasks => { tasks.find(task => task.id == id).foreach(_.cancel); tasks })
   257 
   258     def close()
   259     {
   260       while(_tasks.change_result(tasks => { tasks.foreach(_.cancel); (tasks.nonEmpty, tasks) }))
   261       { _tasks.value.foreach(_.join) }
   262     }
   263   }
   264 
   265   class Connection_Progress private[Server](context: Context, more: JSON.Object.Entry*)
   266     extends Progress
   267   {
   268     override def echo(msg: String): Unit = context.writeln(msg, more:_*)
   269     override def echo_warning(msg: String): Unit = context.warning(msg, more:_*)
   270     override def echo_error_message(msg: String): Unit = context.error_message(msg, more:_*)
   271     override def theory(session: String, theory: String): Unit =
   272       context.writeln(Progress.theory_message(session, theory),
   273         (List("session" -> session, "theory" -> theory) ::: more.toList):_*)
   274 
   275     @volatile private var is_stopped = false
   276     override def stopped: Boolean = is_stopped
   277     def stop { is_stopped = true }
   278 
   279     override def toString: String = context.toString
   280   }
   281 
   282   class Task private[Server](val context: Context, body: Task => JSON.Object.T)
   283   {
   284     task =>
   285 
   286     val id: UUID = UUID()
   287     val ident: JSON.Object.Entry = ("task" -> id.toString)
   288 
   289     val progress: Connection_Progress = context.progress(ident)
   290     def cancel { progress.stop }
   291 
   292     private lazy val thread = Standard_Thread.fork("server_task")
   293     {
   294       Exn.capture { body(task) } match {
   295         case Exn.Res(res) =>
   296           context.reply(Reply.FINISHED, res + ident)
   297         case Exn.Exn(exn) =>
   298           val err = json_error(exn)
   299           if (err.isEmpty) throw exn else context.reply(Reply.FAILED, err + ident)
   300       }
   301       progress.stop
   302       context.remove_task(task)
   303     }
   304     def start { thread }
   305     def join { thread.join }
   306   }
   307 
   308 
   309   /* server info */
   310 
   311   sealed case class Info(name: String, port: Int, password: String)
   312   {
   313     override def toString: String =
   314       "server " + quote(name) + " = " + print(port, password)
   315 
   316     def connection(): Connection =
   317     {
   318       val connection = Connection(new Socket(InetAddress.getByName("127.0.0.1"), port))
   319       connection.write_message(password)
   320       connection
   321     }
   322 
   323     def active(): Boolean =
   324       try {
   325         using(connection())(connection =>
   326           {
   327             connection.set_timeout(Time.seconds(2.0))
   328             connection.read_message() match {
   329               case Some(Reply(Reply.OK, _)) => true
   330               case _ => false
   331             }
   332           })
   333       }
   334       catch {
   335         case _: IOException => false
   336         case _: SocketException => false
   337         case _: SocketTimeoutException => false
   338       }
   339   }
   340 
   341 
   342   /* per-user servers */
   343 
   344   val default_name = "isabelle"
   345 
   346   def print(port: Int, password: String): String =
   347     "127.0.0.1:" + port + " (password " + quote(password) + ")"
   348 
   349   object Data
   350   {
   351     val database = Path.explode("$ISABELLE_HOME_USER/servers.db")
   352 
   353     val name = SQL.Column.string("name").make_primary_key
   354     val port = SQL.Column.int("port")
   355     val password = SQL.Column.string("password")
   356     val table = SQL.Table("isabelle_servers", List(name, port, password))
   357   }
   358 
   359   def list(db: SQLite.Database): List[Info] =
   360     if (db.tables.contains(Data.table.name)) {
   361       db.using_statement(Data.table.select())(stmt =>
   362         stmt.execute_query().iterator(res =>
   363           Info(
   364             res.string(Data.name),
   365             res.int(Data.port),
   366             res.string(Data.password))).toList.sortBy(_.name))
   367     }
   368     else Nil
   369 
   370   def find(db: SQLite.Database, name: String): Option[Info] =
   371     list(db).find(server_info => server_info.name == name && server_info.active)
   372 
   373   def init(
   374     name: String = default_name,
   375     port: Int = 0,
   376     existing_server: Boolean = false,
   377     log: Logger = No_Logger): (Info, Option[Server]) =
   378   {
   379     using(SQLite.open_database(Data.database))(db =>
   380       {
   381         db.transaction {
   382           Isabelle_System.bash("chmod 600 " + File.bash_path(Data.database)).check
   383           db.create_table(Data.table)
   384           list(db).filterNot(_.active).foreach(server_info =>
   385             db.using_statement(Data.table.delete(Data.name.where_equal(server_info.name)))(
   386               _.execute))
   387         }
   388         db.transaction {
   389           find(db, name) match {
   390             case Some(server_info) => (server_info, None)
   391             case None =>
   392               if (existing_server) error("Isabelle server " + quote(name) + " not running")
   393 
   394               val server = new Server(port, log)
   395               val server_info = Info(name, server.port, server.password)
   396 
   397               db.using_statement(Data.table.delete(Data.name.where_equal(name)))(_.execute)
   398               db.using_statement(Data.table.insert())(stmt =>
   399               {
   400                 stmt.string(1) = server_info.name
   401                 stmt.int(2) = server_info.port
   402                 stmt.string(3) = server_info.password
   403                 stmt.execute()
   404               })
   405 
   406               server.start
   407               (server_info, Some(server))
   408           }
   409         }
   410       })
   411   }
   412 
   413   def exit(name: String = default_name): Boolean =
   414   {
   415     using(SQLite.open_database(Data.database))(db =>
   416       db.transaction {
   417         find(db, name) match {
   418           case Some(server_info) =>
   419             using(server_info.connection())(_.write_message("shutdown"))
   420             while(server_info.active) { Thread.sleep(50) }
   421             true
   422           case None => false
   423         }
   424       })
   425   }
   426 
   427 
   428   /* Isabelle tool wrapper */
   429 
   430   val isabelle_tool =
   431     Isabelle_Tool("server", "manage resident Isabelle servers", args =>
   432     {
   433       var console = false
   434       var log_file: Option[Path] = None
   435       var operation_list = false
   436       var operation_exit = false
   437       var name = default_name
   438       var port = 0
   439       var existing_server = false
   440 
   441       val getopts =
   442         Getopts("""
   443 Usage: isabelle server [OPTIONS]
   444 
   445   Options are:
   446     -L FILE      logging on FILE
   447     -c           console interaction with specified server
   448     -l           list servers (alternative operation)
   449     -n NAME      explicit server name (default: """ + default_name + """)
   450     -p PORT      explicit server port
   451     -s           assume existing server, no implicit startup
   452     -x           exit specified server (alternative operation)
   453 
   454   Manage resident Isabelle servers.
   455 """,
   456           "L:" -> (arg => log_file = Some(Path.explode(File.standard_path(arg)))),
   457           "c" -> (_ => console = true),
   458           "l" -> (_ => operation_list = true),
   459           "n:" -> (arg => name = arg),
   460           "p:" -> (arg => port = Value.Int.parse(arg)),
   461           "s" -> (_ => existing_server = true),
   462           "x" -> (_ => operation_exit = true))
   463 
   464       val more_args = getopts(args)
   465       if (more_args.nonEmpty) getopts.usage()
   466 
   467       if (operation_list) {
   468         for {
   469           server_info <- using(SQLite.open_database(Data.database))(list(_))
   470           if server_info.active
   471         } Output.writeln(server_info.toString, stdout = true)
   472       }
   473       else if (operation_exit) {
   474         val ok = Server.exit(name)
   475         sys.exit(if (ok) 0 else 1)
   476       }
   477       else {
   478         val log = Logger.make(log_file)
   479         val (server_info, server) =
   480           init(name, port = port, existing_server = existing_server, log = log)
   481         Output.writeln(server_info.toString, stdout = true)
   482         if (console) {
   483           using(server_info.connection())(connection => connection.tty_loop().join)
   484         }
   485         server.foreach(_.join)
   486       }
   487     })
   488 }
   489 
   490 class Server private(_port: Int, val log: Logger)
   491 {
   492   server =>
   493 
   494   private val server_socket = new ServerSocket(_port, 50, InetAddress.getByName("127.0.0.1"))
   495 
   496   private val _sessions = Synchronized(Map.empty[UUID, Thy_Resources.Session])
   497   def err_session(id: UUID): Nothing = error("No session " + Library.single_quote(id.toString))
   498   def the_session(id: UUID): Thy_Resources.Session =
   499     _sessions.value.get(id) getOrElse err_session(id)
   500   def add_session(entry: (UUID, Thy_Resources.Session)) { _sessions.change(_ + entry) }
   501   def remove_session(id: UUID): Thy_Resources.Session =
   502   {
   503     _sessions.change_result(sessions =>
   504       sessions.get(id) match {
   505         case Some(session) => (session, sessions - id)
   506         case None => err_session(id)
   507       })
   508   }
   509 
   510   def shutdown()
   511   {
   512     server_socket.close
   513 
   514     val sessions = _sessions.change_result(sessions => (sessions, Map.empty))
   515     for ((_, session) <- sessions) {
   516       try {
   517         val result = session.stop()
   518         if (!result.ok) log("Session shutdown failed: return code " + result.rc)
   519       }
   520       catch { case ERROR(msg) => log("Session shutdown failed: " + msg) }
   521     }
   522   }
   523 
   524   def port: Int = server_socket.getLocalPort
   525   val password: String = UUID().toString
   526 
   527   override def toString: String = Server.print(port, password)
   528 
   529   private def handle(connection: Server.Connection)
   530   {
   531     using(new Server.Context(server, connection))(context =>
   532     {
   533       connection.read_message() match {
   534         case Some(msg) if msg == password =>
   535           connection.reply_ok(
   536             JSON.Object(
   537               "isabelle_id" -> Isabelle_System.isabelle_id(),
   538               "isabelle_version" -> Distribution.version))
   539 
   540           var finished = false
   541           while (!finished) {
   542             connection.read_message() match {
   543               case None => finished = true
   544               case Some("") => context.notify("Command 'help' provides list of commands")
   545               case Some(msg) =>
   546                 val (name, argument) = Server.Argument.split(msg)
   547                 name match {
   548                   case Server.Command(cmd) =>
   549                     argument match {
   550                       case Server.Argument(arg) =>
   551                         if (cmd.isDefinedAt((context, arg))) {
   552                           Exn.capture { cmd((context, arg)) } match {
   553                             case Exn.Res(task: Server.Task) =>
   554                               connection.reply_ok(JSON.Object(task.ident))
   555                               task.start
   556                             case Exn.Res(res) => connection.reply_ok(res)
   557                             case Exn.Exn(exn) =>
   558                               val err = Server.json_error(exn)
   559                               if (err.isEmpty) throw exn else connection.reply_error(err)
   560                           }
   561                         }
   562                         else {
   563                           connection.reply_error_message(
   564                             "Bad argument for command " + Library.single_quote(name),
   565                             "argument" -> argument)
   566                         }
   567                       case _ =>
   568                         connection.reply_error_message(
   569                           "Malformed argument for command " + Library.single_quote(name),
   570                           "argument" -> argument)
   571                     }
   572                   case _ => connection.reply_error("Bad command " + Library.single_quote(name))
   573                 }
   574             }
   575           }
   576         case _ =>
   577       }
   578     })
   579   }
   580 
   581   private lazy val server_thread: Thread =
   582     Standard_Thread.fork("server") {
   583       var finished = false
   584       while (!finished) {
   585         Exn.capture(server_socket.accept) match {
   586           case Exn.Res(socket) =>
   587             Standard_Thread.fork("server_connection")
   588               { using(Server.Connection(socket))(handle(_)) }
   589           case Exn.Exn(_) => finished = true
   590         }
   591       }
   592     }
   593 
   594   def start { server_thread }
   595 
   596   def join { server_thread.join; shutdown() }
   597 }