more uniform multi-language operations;
authorwenzelm
Tue, 11 Dec 2018 19:25:35 +0100
changeset 69448 51e696887b81
parent 69446 9cf0b79dfb7f
child 69449 b516fdf8005c
more uniform multi-language operations;
src/Pure/General/bytes.scala
src/Pure/General/symbol.scala
src/Pure/PIDE/byte_message.ML
src/Pure/PIDE/byte_message.scala
src/Pure/ROOT.ML
src/Pure/System/system_channel.ML
src/Pure/Tools/server.scala
src/Pure/build-jars
src/Tools/Haskell/Haskell.thy
src/Tools/VSCode/src/channel.scala
--- a/src/Pure/General/bytes.scala	Mon Dec 10 23:36:29 2018 +0100
+++ b/src/Pure/General/bytes.scala	Tue Dec 11 19:25:35 2018 +0100
@@ -64,27 +64,6 @@
       new Bytes(out.toByteArray, 0, out.size)
     }
 
-  def read_block(stream: InputStream, length: Int): Option[Bytes] =
-  {
-    val bytes = read_stream(stream, limit = length)
-    if (bytes.length == length) Some(bytes) else None
-  }
-
-  def read_line(stream: InputStream): Option[Bytes] =
-  {
-    val out = new ByteArrayOutputStream(100)
-    var c = 0
-    while ({ c = stream.read; c != -1 && c != 10 }) out.write(c)
-
-    if (c == -1 && out.size == 0) None
-    else {
-      val a = out.toByteArray
-      val n = a.length
-      val b = if (n > 0 && a(n - 1) == 13) a.take(n - 1) else a
-      Some(new Bytes(b, 0, b.length))
-    }
-  }
-
   def read(file: JFile): Bytes =
     using(new FileInputStream(file))(read_stream(_, file.length.toInt))
 
@@ -136,6 +115,12 @@
 
   lazy val sha1_digest: SHA1.Digest = SHA1.digest(bytes)
 
+  def is_empty: Boolean = length == 0
+
+  def iterator: Iterator[Byte] =
+    for (i <- (offset until (offset + length)).iterator)
+      yield bytes(i)
+
   def array: Array[Byte] =
   {
     val a = new Array[Byte](length)
@@ -190,6 +175,13 @@
     else throw new IndexOutOfBoundsException
   }
 
+  def trim_line: Bytes =
+    if (length >= 2 && charAt(length - 2) == 13 && charAt(length - 1) == 10)
+      subSequence(0, length - 2)
+    else if (length >= 1 && (charAt(length - 1) == 13 || charAt(length - 1) == 10))
+      subSequence(0, length - 1)
+    else this
+
 
   /* streams */
 
--- a/src/Pure/General/symbol.scala	Mon Dec 10 23:36:29 2018 +0100
+++ b/src/Pure/General/symbol.scala	Tue Dec 11 19:25:35 2018 +0100
@@ -48,6 +48,8 @@
 
   def is_ascii_blank(c: Char): Boolean = " \t\n\u000b\f\r".contains(c)
 
+  def is_ascii_line_terminator(c: Char): Boolean = "\r\n".contains(c)
+
   def is_ascii_letdig(c: Char): Boolean =
     is_ascii_letter(c) || is_ascii_digit(c) || is_ascii_quasi(c)
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Pure/PIDE/byte_message.ML	Tue Dec 11 19:25:35 2018 +0100
@@ -0,0 +1,31 @@
+(*  Title:      Pure/General/byte_message.ML
+    Author:     Makarius
+
+Byte-oriented messages.
+*)
+
+signature BYTE_MESSAGE =
+sig
+  val read_line: BinIO.instream -> string option
+  val read_block: BinIO.instream -> int -> string
+end;
+
+structure Byte_Message: BYTE_MESSAGE =
+struct
+
+fun read_line stream =
+  let
+    val result = trim_line o String.implode o rev;
+    fun read cs =
+      (case BinIO.input1 stream of
+        NONE => if null cs then NONE else SOME (result cs)
+      | SOME b =>
+          (case Byte.byteToChar b of
+            #"\n" => SOME (result cs)
+          | c => read (c :: cs)));
+  in read [] end;
+
+fun read_block stream n =
+  Byte.bytesToString (BinIO.inputN (stream, n));
+
+end;
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Pure/PIDE/byte_message.scala	Tue Dec 11 19:25:35 2018 +0100
@@ -0,0 +1,74 @@
+/*  Title:      Pure/General/byte_message.scala
+    Author:     Makarius
+
+Byte-oriented messages.
+*/
+
+package isabelle
+
+import java.io.{ByteArrayOutputStream, OutputStream, InputStream, IOException}
+
+
+object Byte_Message
+{
+  def read_line(stream: InputStream): Option[Bytes] =
+  {
+    val line = new ByteArrayOutputStream(100)
+    var c = 0
+    while ({ c = stream.read; c != -1 && c != 10 }) line.write(c)
+
+    if (c == -1 && line.size == 0) None
+    else {
+      val a = line.toByteArray
+      val n = a.length
+      val len = if (n > 0 && a(n - 1) == 13) n - 1 else n
+      Some(Bytes(a, 0, len))
+    }
+  }
+
+  def read_block(stream: InputStream, length: Int): Option[Bytes] =
+  {
+    val msg = Bytes.read_stream(stream, limit = length)
+    if (msg.length == length) Some(msg) else None
+  }
+
+
+  /* hybrid messages: line or length+block, with content restriction */
+
+  private def is_length(msg: Bytes): Boolean =
+    !msg.is_empty && msg.iterator.forall(b => Symbol.is_ascii_digit(b.toChar))
+
+  private def has_line_terminator(msg: Bytes): Boolean =
+  {
+    val len = msg.length
+    len > 0 && Symbol.is_ascii_line_terminator(msg.charAt(len - 1))
+  }
+
+  def write_line_message(stream: OutputStream, msg: Bytes)
+  {
+    if (is_length(msg) || has_line_terminator(msg))
+      error ("Bad content for line message:\n" ++ msg.text.take(100))
+
+    if (msg.length > 100 || msg.iterator.contains(10)) {
+      stream.write(UTF8.bytes((msg.length + 1).toString))
+      stream.write(10)
+    }
+    msg.write_stream(stream)
+    stream.write(10)
+
+    try { stream.flush() } catch { case _: IOException => }
+  }
+
+  def read_line_message(stream: InputStream): Option[Bytes] =
+  {
+    try {
+      read_line(stream) match {
+        case Some(line) =>
+          if (is_length(line)) read_block(stream, Value.Int.parse(line.text)).map(_.trim_line)
+          else Some(line)
+        case None => None
+      }
+    }
+    catch { case _: IOException => None }
+  }
+}
--- a/src/Pure/ROOT.ML	Mon Dec 10 23:36:29 2018 +0100
+++ b/src/Pure/ROOT.ML	Tue Dec 11 19:25:35 2018 +0100
@@ -83,12 +83,12 @@
 ML_file "General/file.ML";
 ML_file "General/long_name.ML";
 ML_file "General/binding.ML";
-ML_file "General/bytes.ML";
 ML_file "General/socket_io.ML";
 ML_file "General/seq.ML";
 ML_file "General/timing.ML";
 ML_file "General/sha1.ML";
 
+ML_file "PIDE/byte_message.ML";
 ML_file "PIDE/yxml.ML";
 ML_file "PIDE/document_id.ML";
 
--- a/src/Pure/System/system_channel.ML	Mon Dec 10 23:36:29 2018 +0100
+++ b/src/Pure/System/system_channel.ML	Tue Dec 11 19:25:35 2018 +0100
@@ -19,8 +19,8 @@
 
 datatype T = System_Channel of BinIO.instream * BinIO.outstream;
 
-fun input_line (System_Channel (stream, _)) = Bytes.read_line stream;
-fun inputN (System_Channel (stream, _)) n = Bytes.read_block stream n;
+fun input_line (System_Channel (stream, _)) = Byte_Message.read_line stream;
+fun inputN (System_Channel (stream, _)) n = Byte_Message.read_block stream n;
 
 fun output (System_Channel (_, stream)) s = File.output stream s;
 fun flush (System_Channel (_, stream)) = BinIO.flushOut stream;
--- a/src/Pure/Tools/server.scala	Mon Dec 10 23:36:29 2018 +0100
+++ b/src/Pure/Tools/server.scala	Tue Dec 11 19:25:35 2018 +0100
@@ -181,26 +181,10 @@
         interrupt = interrupt)
 
     def read_message(): Option[String] =
-      try {
-        Bytes.read_line(in).map(_.text) match {
-          case Some(Value.Int(n)) =>
-            Bytes.read_block(in, n).map(bytes => Library.trim_line(bytes.text))
-          case res => res
-        }
-      }
-      catch { case _: SocketException => None }
+      Byte_Message.read_line_message(in).map(_.text)
 
-    def write_message(msg: String): Unit = out_lock.synchronized
-    {
-      val b = UTF8.bytes(msg)
-      if (b.length > 100 || b.contains(10)) {
-        out.write(UTF8.bytes((b.length + 1).toString))
-        out.write(10)
-      }
-      out.write(b)
-      out.write(10)
-      try { out.flush() } catch { case _: SocketException => }
-    }
+    def write_message(msg: String): Unit =
+      out_lock.synchronized { Byte_Message.write_line_message(out, Bytes(UTF8.bytes(msg))) }
 
     def reply(r: Reply.Value, arg: Any)
     {
--- a/src/Pure/build-jars	Mon Dec 10 23:36:29 2018 +0100
+++ b/src/Pure/build-jars	Tue Dec 11 19:25:35 2018 +0100
@@ -89,6 +89,7 @@
   ML/ml_process.scala
   ML/ml_statistics.scala
   ML/ml_syntax.scala
+  PIDE/byte_message.scala
   PIDE/command.scala
   PIDE/command_span.scala
   PIDE/document.scala
--- a/src/Tools/Haskell/Haskell.thy	Mon Dec 10 23:36:29 2018 +0100
+++ b/src/Tools/Haskell/Haskell.thy	Tue Dec 11 19:25:35 2018 +0100
@@ -1368,15 +1368,18 @@
     \([], a) -> App (pair term term a)]
 \<close>
 
-generate_file "Isabelle/Bytes.hs" = \<open>
-{-  Title:      Isabelle/Bytes.hs
+generate_file "Isabelle/Byte_Message.hs" = \<open>
+{-  Title:      Isabelle/Byte_Message.hs
     Author:     Makarius
     LICENSE:    BSD 3-clause (Isabelle)
 
-Byte-vector messages.
+Byte-oriented messages.
+
+See \<^file>\<open>$ISABELLE_HOME/src/Pure/PIDE/byte_message.ML\<close>
+and \<^file>\<open>$ISABELLE_HOME/src/Pure/PIDE/byte_message.scala\<close>.
 -}
 
-module Isabelle.Bytes (read_line, read_block, read_message, write_message)
+module Isabelle.Byte_Message (read_line, read_block, trim_line, read_line_message, write_line_message)
 where
 
 import Data.ByteString (ByteString)
@@ -1384,6 +1387,7 @@
 import qualified Data.ByteString.UTF8 as UTF8
 import Data.Word (Word8)
 
+import Control.Monad (when)
 import Network.Socket (Socket)
 import qualified Network.Socket as Socket
 import qualified Network.Socket.ByteString as ByteString
@@ -1391,8 +1395,6 @@
 import qualified Isabelle.Value as Value
 
 
--- see also \<^file>\<open>$ISABELLE_HOME/src/Pure/General/bytes.ML\<close>
-
 read_line :: Socket -> IO (Maybe ByteString)
 read_line socket = read []
   where
@@ -1411,13 +1413,15 @@
             10 -> return (Just (result bs))
             b -> read (b : bs)
 
-read_block :: Socket -> Int -> IO ByteString
+read_block :: Socket -> Int -> IO (Maybe ByteString)
 read_block socket n = read 0 []
   where
-    result :: [ByteString] -> ByteString
-    result = ByteString.concat . reverse
+    result :: [ByteString] -> Maybe ByteString
+    result ss =
+      if ByteString.length s == n then Just s else Nothing
+      where s = ByteString.concat (reverse ss)
 
-    read :: Int -> [ByteString] -> IO ByteString
+    read :: Int -> [ByteString] -> IO (Maybe ByteString)
     read len ss =
       if len >= n then return (result ss)
       else
@@ -1427,27 +1431,48 @@
             0 -> return (result ss)
             m -> read (len + m) (s : ss))
 
+trim_line :: ByteString -> ByteString
+trim_line s =
+    if n >= 2 && at (n - 2) == 13 && at (n - 1) == 10 then ByteString.take (n - 2) s
+    else if n >= 1 && (at (n - 1) == 13 || at (n - 1) == 10) then ByteString.take (n - 1) s
+    else s
+  where
+    n = ByteString.length s
+    at = ByteString.index s
 
--- see also \<^file>\<open>$ISABELLE_HOME/src/Pure/Tools/server.scala\<close>
+
+
+-- hybrid messages: line or length+block (with content restriction)
+
+is_length :: ByteString -> Bool
+is_length s =
+  not (ByteString.null s) && ByteString.all (\b -> 48 <= b && b <= 57) s
 
-read_message :: Socket -> IO (Maybe ByteString)
-read_message socket = do
+has_line_terminator :: ByteString -> Bool
+has_line_terminator s =
+  not (ByteString.null s) && (ByteString.last s == 13 || ByteString.last s == 10)
+
+write_line_message :: Socket -> ByteString -> IO ()
+write_line_message socket msg = do
+  when (is_length msg || has_line_terminator msg) $
+    error ("Bad content for line message:\n" ++ take 100 (UTF8.toString msg))
+
+  let newline = ByteString.singleton 10
+  let n = ByteString.length msg
+  ByteString.sendMany socket
+    (if n > 100 || ByteString.any (== 10) msg then
+      [UTF8.fromString (Value.print_int (n + 1)), newline, msg, newline]
+     else [msg, newline])
+
+read_line_message :: Socket -> IO (Maybe ByteString)
+read_line_message socket = do
   opt_line <- read_line socket
   case opt_line of
     Nothing -> return Nothing
     Just line ->
       case Value.parse_int (UTF8.toString line) of
         Nothing -> return $ Just line
-        Just n -> Just <$> read_block socket n
-
-write_message :: Socket -> ByteString -> IO ()
-write_message socket msg = do
-  let newline = ByteString.singleton 10
-  let n = ByteString.length msg
-  ByteString.sendMany socket
-    (if n > 100 || ByteString.any (== 10) msg then
-      [UTF8.fromString (Value.print_int (n + 1)), newline, msg, newline]
-     else [msg, newline])
+        Just n -> fmap trim_line <$> read_block socket n
 \<close>
 
 end
--- a/src/Tools/VSCode/src/channel.scala	Mon Dec 10 23:36:29 2018 +0100
+++ b/src/Tools/VSCode/src/channel.scala	Tue Dec 11 19:25:35 2018 +0100
@@ -21,7 +21,7 @@
   private val Content_Length = """^\s*Content-Length:\s*(\d+)\s*$""".r
 
   private def read_line(): String =
-    Bytes.read_line(in) match {
+    Byte_Message.read_line(in) match {
       case Some(bytes) => bytes.text
       case None => ""
     }