src/Pure/PIDE/prover.scala
author wenzelm
Tue Aug 12 18:36:43 2014 +0200 (2014-08-12)
changeset 57916 2c2c24dbf0a4
parent 57915 448325de6e4f
child 57917 8ce97e5d545f
permissions -rw-r--r--
generic process wrapping in Prover;
clarified module arrangement;
     1 /*  Title:      Pure/PIDE/prover.scala
     2     Author:     Makarius
     3 
     4 General prover operations and process wrapping.
     5 */
     6 
     7 package isabelle
     8 
     9 
    10 import java.io.{InputStream, OutputStream, BufferedReader, BufferedOutputStream, IOException}
    11 
    12 
    13 object Prover
    14 {
    15   /* syntax */
    16 
    17   trait Syntax
    18   {
    19     def add_keywords(keywords: Thy_Header.Keywords): Syntax
    20     def parse_spans(input: CharSequence): List[Command_Span.Span]
    21     def load_command(name: String): Option[List[String]]
    22     def load_commands_in(text: String): Boolean
    23   }
    24 
    25 
    26   /* underlying system process */
    27 
    28   trait System_Process
    29   {
    30     def channel: System_Channel
    31     def stdout: BufferedReader
    32     def stderr: BufferedReader
    33     def terminate: Unit
    34     def join: Int
    35   }
    36 
    37 
    38   /* messages */
    39 
    40   sealed abstract class Message
    41 
    42   class Input(val name: String, val args: List[String]) extends Message
    43   {
    44     override def toString: String =
    45       XML.Elem(Markup(Markup.PROVER_COMMAND, List((Markup.NAME, name))),
    46         args.map(s =>
    47           List(XML.Text("\n"), XML.elem(Markup.PROVER_ARG, YXML.parse_body(s)))).flatten).toString
    48   }
    49 
    50   class Output(val message: XML.Elem) extends Message
    51   {
    52     def kind: String = message.markup.name
    53     def properties: Properties.T = message.markup.properties
    54     def body: XML.Body = message.body
    55 
    56     def is_init = kind == Markup.INIT
    57     def is_exit = kind == Markup.EXIT
    58     def is_stdout = kind == Markup.STDOUT
    59     def is_stderr = kind == Markup.STDERR
    60     def is_system = kind == Markup.SYSTEM
    61     def is_status = kind == Markup.STATUS
    62     def is_report = kind == Markup.REPORT
    63     def is_syslog = is_init || is_exit || is_system || is_stderr
    64 
    65     override def toString: String =
    66     {
    67       val res =
    68         if (is_status || is_report) message.body.map(_.toString).mkString
    69         else Pretty.string_of(message.body)
    70       if (properties.isEmpty)
    71         kind.toString + " [[" + res + "]]"
    72       else
    73         kind.toString + " " +
    74           (for ((x, y) <- properties) yield x + "=" + y).mkString("{", ",", "}") + " [[" + res + "]]"
    75     }
    76   }
    77 
    78   class Protocol_Output(props: Properties.T, val bytes: Bytes)
    79     extends Output(XML.Elem(Markup(Markup.PROTOCOL, props), Nil))
    80   {
    81     lazy val text: String = bytes.toString
    82   }
    83 }
    84 
    85 
    86 abstract class Prover(
    87   receiver: Prover.Message => Unit,
    88   system_process: Prover.System_Process) extends Protocol
    89 {
    90   /* output */
    91 
    92   val xml_cache: XML.Cache = new XML.Cache()
    93 
    94   private def system_output(text: String)
    95   {
    96     receiver(new Prover.Output(XML.Elem(Markup(Markup.SYSTEM, Nil), List(XML.Text(text)))))
    97   }
    98 
    99   private def protocol_output(props: Properties.T, bytes: Bytes)
   100   {
   101     receiver(new Prover.Protocol_Output(props, bytes))
   102   }
   103 
   104   private def output(kind: String, props: Properties.T, body: XML.Body)
   105   {
   106     if (kind == Markup.INIT) system_process.channel.accepted()
   107 
   108     val main = XML.Elem(Markup(kind, props), Protocol.clean_message(body))
   109     val reports = Protocol.message_reports(props, body)
   110     for (msg <- main :: reports) receiver(new Prover.Output(xml_cache.elem(msg)))
   111   }
   112 
   113   private def exit_message(rc: Int)
   114   {
   115     output(Markup.EXIT, Markup.Return_Code(rc), List(XML.Text("Return code: " + rc.toString)))
   116   }
   117 
   118 
   119 
   120   /** process manager **/
   121 
   122   private val (_, process_result) =
   123     Simple_Thread.future("process_result") { system_process.join }
   124 
   125   private def terminate_process()
   126   {
   127     try { system_process.terminate }
   128     catch {
   129       case exn @ ERROR(_) => system_output("Failed to terminate prover process: " + exn.getMessage)
   130     }
   131   }
   132 
   133   private val process_manager = Simple_Thread.fork("process_manager")
   134   {
   135     val (startup_failed, startup_errors) =
   136     {
   137       var finished: Option[Boolean] = None
   138       val result = new StringBuilder(100)
   139       while (finished.isEmpty && (system_process.stderr.ready || !process_result.is_finished)) {
   140         while (finished.isEmpty && system_process.stderr.ready) {
   141           try {
   142             val c = system_process.stderr.read
   143             if (c == 2) finished = Some(true)
   144             else result += c.toChar
   145           }
   146           catch { case _: IOException => finished = Some(false) }
   147         }
   148         Thread.sleep(10)
   149       }
   150       (finished.isEmpty || !finished.get, result.toString.trim)
   151     }
   152     if (startup_errors != "") system_output(startup_errors)
   153 
   154     if (startup_failed) {
   155       terminate_process()
   156       process_result.join
   157       exit_message(127)
   158     }
   159     else {
   160       val (command_stream, message_stream) = system_process.channel.rendezvous()
   161 
   162       command_input_init(command_stream)
   163       val stdout = physical_output(false)
   164       val stderr = physical_output(true)
   165       val message = message_output(message_stream)
   166 
   167       val rc = process_result.join
   168       system_output("process terminated")
   169       command_input_close()
   170       for (thread <- List(stdout, stderr, message)) thread.join
   171       system_output("process_manager terminated")
   172       exit_message(rc)
   173     }
   174     system_process.channel.accepted()
   175   }
   176 
   177 
   178   /* management methods */
   179 
   180   def join() { process_manager.join() }
   181 
   182   def terminate()
   183   {
   184     command_input_close()
   185     system_output("Terminating prover process")
   186     terminate_process()
   187   }
   188 
   189 
   190 
   191   /** process streams **/
   192 
   193   /* command input */
   194 
   195   private var command_input: Option[Consumer_Thread[List[Bytes]]] = None
   196 
   197   private def command_input_close(): Unit = command_input.foreach(_.shutdown)
   198 
   199   private def command_input_init(raw_stream: OutputStream)
   200   {
   201     val name = "command_input"
   202     val stream = new BufferedOutputStream(raw_stream)
   203     command_input =
   204       Some(
   205         Consumer_Thread.fork(name)(
   206           consume =
   207             {
   208               case chunks =>
   209                 try {
   210                   Bytes(chunks.map(_.length).mkString("", ",", "\n")).write(stream)
   211                   chunks.foreach(_.write(stream))
   212                   stream.flush
   213                   true
   214                 }
   215                 catch { case e: IOException => system_output(name + ": " + e.getMessage); false }
   216             },
   217           finish = { case () => stream.close; system_output(name + " terminated") }
   218         )
   219       )
   220   }
   221 
   222 
   223   /* physical output */
   224 
   225   private def physical_output(err: Boolean): Thread =
   226   {
   227     val (name, reader, markup) =
   228       if (err) ("standard_error", system_process.stderr, Markup.STDERR)
   229       else ("standard_output", system_process.stdout, Markup.STDOUT)
   230 
   231     Simple_Thread.fork(name) {
   232       try {
   233         var result = new StringBuilder(100)
   234         var finished = false
   235         while (!finished) {
   236           //{{{
   237           var c = -1
   238           var done = false
   239           while (!done && (result.length == 0 || reader.ready)) {
   240             c = reader.read
   241             if (c >= 0) result.append(c.asInstanceOf[Char])
   242             else done = true
   243           }
   244           if (result.length > 0) {
   245             output(markup, Nil, List(XML.Text(decode(result.toString))))
   246             result.length = 0
   247           }
   248           else {
   249             reader.close
   250             finished = true
   251           }
   252           //}}}
   253         }
   254       }
   255       catch { case e: IOException => system_output(name + ": " + e.getMessage) }
   256       system_output(name + " terminated")
   257     }
   258   }
   259 
   260 
   261   /* message output */
   262 
   263   private def message_output(stream: InputStream): Thread =
   264   {
   265     class EOF extends Exception
   266     class Protocol_Error(msg: String) extends Exception(msg)
   267 
   268     val name = "message_output"
   269     Simple_Thread.fork(name) {
   270       val default_buffer = new Array[Byte](65536)
   271       var c = -1
   272 
   273       def read_int(): Int =
   274       //{{{
   275       {
   276         var n = 0
   277         c = stream.read
   278         if (c == -1) throw new EOF
   279         while (48 <= c && c <= 57) {
   280           n = 10 * n + (c - 48)
   281           c = stream.read
   282         }
   283         if (c != 10)
   284           throw new Protocol_Error("malformed header: expected integer followed by newline")
   285         else n
   286       }
   287       //}}}
   288 
   289       def read_chunk_bytes(): (Array[Byte], Int) =
   290       //{{{
   291       {
   292         val n = read_int()
   293         val buf =
   294           if (n <= default_buffer.size) default_buffer
   295           else new Array[Byte](n)
   296 
   297         var i = 0
   298         var m = 0
   299         do {
   300           m = stream.read(buf, i, n - i)
   301           if (m != -1) i += m
   302         }
   303         while (m != -1 && n > i)
   304 
   305         if (i != n)
   306           throw new Protocol_Error("bad chunk (unexpected EOF after " + i + " of " + n + " bytes)")
   307 
   308         (buf, n)
   309       }
   310       //}}}
   311 
   312       def read_chunk(): XML.Body =
   313       {
   314         val (buf, n) = read_chunk_bytes()
   315         YXML.parse_body_failsafe(UTF8.decode_chars(decode, buf, 0, n))
   316       }
   317 
   318       try {
   319         do {
   320           try {
   321             val header = read_chunk()
   322             header match {
   323               case List(XML.Elem(Markup(name, props), Nil)) =>
   324                 val kind = name.intern
   325                 if (kind == Markup.PROTOCOL) {
   326                   val (buf, n) = read_chunk_bytes()
   327                   protocol_output(props, Bytes(buf, 0, n))
   328                 }
   329                 else {
   330                   val body = read_chunk()
   331                   output(kind, props, body)
   332                 }
   333               case _ =>
   334                 read_chunk()
   335                 throw new Protocol_Error("bad header: " + header.toString)
   336             }
   337           }
   338           catch { case _: EOF => }
   339         }
   340         while (c != -1)
   341       }
   342       catch {
   343         case e: IOException => system_output("Cannot read message:\n" + e.getMessage)
   344         case e: Protocol_Error => system_output("Malformed message:\n" + e.getMessage)
   345       }
   346       stream.close
   347 
   348       system_output(name + " terminated")
   349     }
   350   }
   351 
   352 
   353 
   354   /** protocol commands **/
   355 
   356   def protocol_command_bytes(name: String, args: Bytes*): Unit =
   357     command_input match {
   358       case Some(thread) => thread.send(Bytes(name) :: args.toList)
   359       case None => error("Uninitialized command input thread")
   360     }
   361 
   362   def protocol_command(name: String, args: String*)
   363   {
   364     receiver(new Prover.Input(name, args.toList))
   365     protocol_command_bytes(name, args.map(Bytes(_)): _*)
   366   }
   367 }
   368