src/Pure/PIDE/prover.scala
author wenzelm
Sun Feb 14 14:33:32 2016 +0100 (2016-02-14)
changeset 62310 ab836dc7410e
parent 62307 ccbd48444049
child 62556 c115e69f457f
permissions -rw-r--r--
more gentle termination (like Bash.multi_kill without signal) to give prover a chance to conclude;
     1 /*  Title:      Pure/PIDE/prover.scala
     2     Author:     Makarius
     3     Options:    :folding=explicit:
     4 
     5 Prover process wrapping.
     6 */
     7 
     8 package isabelle
     9 
    10 
    11 import java.io.{InputStream, OutputStream, BufferedReader, BufferedOutputStream, IOException}
    12 
    13 
    14 object Prover
    15 {
    16   /* syntax */
    17 
    18   trait Syntax
    19   {
    20     def ++ (other: Syntax): Syntax
    21     def add_keywords(keywords: Thy_Header.Keywords): Syntax
    22     def parse_spans(input: CharSequence): List[Command_Span.Span]
    23     def load_command(name: String): Option[List[String]]
    24     def load_commands_in(text: String): Boolean
    25   }
    26 
    27 
    28   /* underlying system process */
    29 
    30   trait System_Process
    31   {
    32     def stdout: BufferedReader
    33     def stderr: BufferedReader
    34     def terminate: Unit
    35     def join: Int
    36   }
    37 
    38 
    39   /* messages */
    40 
    41   sealed abstract class Message
    42 
    43   class Input(val name: String, val args: List[String]) extends Message
    44   {
    45     override def toString: String =
    46       XML.Elem(Markup(Markup.PROVER_COMMAND, List((Markup.NAME, name))),
    47         args.map(s =>
    48           List(XML.Text("\n"), XML.elem(Markup.PROVER_ARG, YXML.parse_body(s)))).flatten).toString
    49   }
    50 
    51   class Output(val message: XML.Elem) extends Message
    52   {
    53     def kind: String = message.markup.name
    54     def properties: Properties.T = message.markup.properties
    55     def body: XML.Body = message.body
    56 
    57     def is_init = kind == Markup.INIT
    58     def is_exit = kind == Markup.EXIT
    59     def is_stdout = kind == Markup.STDOUT
    60     def is_stderr = kind == Markup.STDERR
    61     def is_system = kind == Markup.SYSTEM
    62     def is_status = kind == Markup.STATUS
    63     def is_report = kind == Markup.REPORT
    64     def is_syslog = is_init || is_exit || is_system || is_stderr
    65 
    66     override def toString: String =
    67     {
    68       val res =
    69         if (is_status || is_report) message.body.map(_.toString).mkString
    70         else Pretty.string_of(message.body)
    71       if (properties.isEmpty)
    72         kind.toString + " [[" + res + "]]"
    73       else
    74         kind.toString + " " +
    75           (for ((x, y) <- properties) yield x + "=" + y).mkString("{", ",", "}") + " [[" + res + "]]"
    76     }
    77   }
    78 
    79   class Protocol_Output(props: Properties.T, val bytes: Bytes)
    80     extends Output(XML.Elem(Markup(Markup.PROTOCOL, props), Nil))
    81   {
    82     lazy val text: String = bytes.toString
    83   }
    84 }
    85 
    86 
    87 abstract class Prover(
    88   receiver: Prover.Message => Unit,
    89   system_channel: System_Channel,
    90   system_process: Prover.System_Process) extends Protocol
    91 {
    92   /** receiver output **/
    93 
    94   val xml_cache: XML.Cache = new XML.Cache()
    95 
    96   private def system_output(text: String)
    97   {
    98     receiver(new Prover.Output(XML.Elem(Markup(Markup.SYSTEM, Nil), List(XML.Text(text)))))
    99   }
   100 
   101   private def protocol_output(props: Properties.T, bytes: Bytes)
   102   {
   103     receiver(new Prover.Protocol_Output(props, bytes))
   104   }
   105 
   106   private def output(kind: String, props: Properties.T, body: XML.Body)
   107   {
   108     if (kind == Markup.INIT) system_channel.accepted()
   109 
   110     val main = XML.Elem(Markup(kind, props), Protocol_Message.clean_reports(body))
   111     val reports = Protocol_Message.reports(props, body)
   112     for (msg <- main :: reports) receiver(new Prover.Output(xml_cache.elem(msg)))
   113   }
   114 
   115   private def exit_message(rc: Int)
   116   {
   117     output(Markup.EXIT, Markup.Return_Code(rc), List(XML.Text("Return code: " + rc.toString)))
   118   }
   119 
   120 
   121 
   122   /** process manager **/
   123 
   124   private val process_result: Future[Int] =
   125     Future.thread("process_result") { system_process.join }
   126 
   127   private def terminate_process()
   128   {
   129     try { system_process.terminate }
   130     catch {
   131       case exn @ ERROR(_) => system_output("Failed to terminate prover process: " + exn.getMessage)
   132     }
   133   }
   134 
   135   private val process_manager = Standard_Thread.fork("process_manager")
   136   {
   137     val (startup_failed, startup_errors) =
   138     {
   139       var finished: Option[Boolean] = None
   140       val result = new StringBuilder(100)
   141       while (finished.isEmpty && (system_process.stderr.ready || !process_result.is_finished)) {
   142         while (finished.isEmpty && system_process.stderr.ready) {
   143           try {
   144             val c = system_process.stderr.read
   145             if (c == 2) finished = Some(true)
   146             else result += c.toChar
   147           }
   148           catch { case _: IOException => finished = Some(false) }
   149         }
   150         Thread.sleep(10)
   151       }
   152       (finished.isEmpty || !finished.get, result.toString.trim)
   153     }
   154     if (startup_errors != "") system_output(startup_errors)
   155 
   156     if (startup_failed) {
   157       terminate_process()
   158       process_result.join
   159       exit_message(127)
   160     }
   161     else {
   162       val (command_stream, message_stream) = system_channel.rendezvous()
   163 
   164       command_input_init(command_stream)
   165       val stdout = physical_output(false)
   166       val stderr = physical_output(true)
   167       val message = message_output(message_stream)
   168 
   169       val rc = process_result.join
   170       system_output("process terminated")
   171       command_input_close()
   172       for (thread <- List(stdout, stderr, message)) thread.join
   173       system_output("process_manager terminated")
   174       exit_message(rc)
   175     }
   176     system_channel.accepted()
   177   }
   178 
   179 
   180   /* management methods */
   181 
   182   def join() { process_manager.join() }
   183 
   184   def terminate()
   185   {
   186     system_output("Terminating prover process")
   187     command_input_close()
   188 
   189     var count = 10
   190     while (!process_result.is_finished && count > 0) {
   191       Thread.sleep(100)
   192       count -= 1
   193     }
   194     if (!process_result.is_finished) terminate_process()
   195   }
   196 
   197 
   198 
   199   /** process streams **/
   200 
   201   /* command input */
   202 
   203   private var command_input: Option[Consumer_Thread[List[Bytes]]] = None
   204 
   205   private def command_input_close(): Unit = command_input.foreach(_.shutdown)
   206 
   207   private def command_input_init(raw_stream: OutputStream)
   208   {
   209     val name = "command_input"
   210     val stream = new BufferedOutputStream(raw_stream)
   211     command_input =
   212       Some(
   213         Consumer_Thread.fork(name)(
   214           consume =
   215             {
   216               case chunks =>
   217                 try {
   218                   Bytes(chunks.map(_.length).mkString("", ",", "\n")).write(stream)
   219                   chunks.foreach(_.write(stream))
   220                   stream.flush
   221                   true
   222                 }
   223                 catch { case e: IOException => system_output(name + ": " + e.getMessage); false }
   224             },
   225           finish = { case () => stream.close; system_output(name + " terminated") }
   226         )
   227       )
   228   }
   229 
   230 
   231   /* physical output */
   232 
   233   private def physical_output(err: Boolean): Thread =
   234   {
   235     val (name, reader, markup) =
   236       if (err) ("standard_error", system_process.stderr, Markup.STDERR)
   237       else ("standard_output", system_process.stdout, Markup.STDOUT)
   238 
   239     Standard_Thread.fork(name) {
   240       try {
   241         var result = new StringBuilder(100)
   242         var finished = false
   243         while (!finished) {
   244           //{{{
   245           var c = -1
   246           var done = false
   247           while (!done && (result.length == 0 || reader.ready)) {
   248             c = reader.read
   249             if (c >= 0) result.append(c.asInstanceOf[Char])
   250             else done = true
   251           }
   252           if (result.length > 0) {
   253             output(markup, Nil, List(XML.Text(decode(result.toString))))
   254             result.length = 0
   255           }
   256           else {
   257             reader.close
   258             finished = true
   259           }
   260           //}}}
   261         }
   262       }
   263       catch { case e: IOException => system_output(name + ": " + e.getMessage) }
   264       system_output(name + " terminated")
   265     }
   266   }
   267 
   268 
   269   /* message output */
   270 
   271   private def message_output(stream: InputStream): Thread =
   272   {
   273     class EOF extends Exception
   274     class Protocol_Error(msg: String) extends Exception(msg)
   275 
   276     val name = "message_output"
   277     Standard_Thread.fork(name) {
   278       val default_buffer = new Array[Byte](65536)
   279       var c = -1
   280 
   281       def read_int(): Int =
   282       //{{{
   283       {
   284         var n = 0
   285         c = stream.read
   286         if (c == -1) throw new EOF
   287         while (48 <= c && c <= 57) {
   288           n = 10 * n + (c - 48)
   289           c = stream.read
   290         }
   291         if (c != 10)
   292           throw new Protocol_Error("malformed header: expected integer followed by newline")
   293         else n
   294       }
   295       //}}}
   296 
   297       def read_chunk_bytes(): (Array[Byte], Int) =
   298       //{{{
   299       {
   300         val n = read_int()
   301         val buf =
   302           if (n <= default_buffer.length) default_buffer
   303           else new Array[Byte](n)
   304 
   305         var i = 0
   306         var m = 0
   307         do {
   308           m = stream.read(buf, i, n - i)
   309           if (m != -1) i += m
   310         }
   311         while (m != -1 && n > i)
   312 
   313         if (i != n)
   314           throw new Protocol_Error("bad chunk (unexpected EOF after " + i + " of " + n + " bytes)")
   315 
   316         (buf, n)
   317       }
   318       //}}}
   319 
   320       def read_chunk(): XML.Body =
   321       {
   322         val (buf, n) = read_chunk_bytes()
   323         YXML.parse_body_failsafe(UTF8.decode_chars(decode, buf, 0, n))
   324       }
   325 
   326       try {
   327         do {
   328           try {
   329             val header = read_chunk()
   330             header match {
   331               case List(XML.Elem(Markup(name, props), Nil)) =>
   332                 val kind = name.intern
   333                 if (kind == Markup.PROTOCOL) {
   334                   val (buf, n) = read_chunk_bytes()
   335                   protocol_output(props, Bytes(buf, 0, n))
   336                 }
   337                 else {
   338                   val body = read_chunk()
   339                   output(kind, props, body)
   340                 }
   341               case _ =>
   342                 read_chunk()
   343                 throw new Protocol_Error("bad header: " + header.toString)
   344             }
   345           }
   346           catch { case _: EOF => }
   347         }
   348         while (c != -1)
   349       }
   350       catch {
   351         case e: IOException => system_output("Cannot read message:\n" + e.getMessage)
   352         case e: Protocol_Error => system_output("Malformed message:\n" + e.getMessage)
   353       }
   354       stream.close
   355 
   356       system_output(name + " terminated")
   357     }
   358   }
   359 
   360 
   361 
   362   /** protocol commands **/
   363 
   364   def protocol_command_bytes(name: String, args: Bytes*): Unit =
   365     command_input match {
   366       case Some(thread) => thread.send(Bytes(name) :: args.toList)
   367       case None => error("Uninitialized command input thread")
   368     }
   369 
   370   def protocol_command(name: String, args: String*)
   371   {
   372     receiver(new Prover.Input(name, args.toList))
   373     protocol_command_bytes(name, args.map(Bytes(_)): _*)
   374   }
   375 }