support large byte arrays, using multiple "chunks";
authorwenzelm
Sat, 15 Jun 2024 17:12:49 +0200
changeset 80366 ac4d53bc8f6b
parent 80365 29b761e290c5
child 80367 a6c1526600b3
support large byte arrays, using multiple "chunks"; support incremental builder; clarified "limit" (valid >= 0) vs. "hint" (valid > 0); clarified byte access: prefer unchecked acces and iterators internally;
src/Pure/General/bytes.scala
--- a/src/Pure/General/bytes.scala	Sat Jun 15 12:27:57 2024 +0200
+++ b/src/Pure/General/bytes.scala	Sat Jun 15 17:12:49 2024 +0200
@@ -16,78 +16,93 @@
 import org.tukaani.xz
 import com.github.luben.zstd
 
+import scala.collection.mutable.ArrayBuffer
+
 
 object Bytes {
-  val empty: Bytes = new Bytes(Array[Byte](), 0, 0)
+  /* internal sizes */
+
+  val array_size: Long = Int.MaxValue - 8  // see java.io.InputStream.MAX_BUFFER_SIZE
+  val chunk_size: Long = Space.MiB(100).bytes
+  val block_size: Int = 8192
+
+  class Too_Large(size: Long) extends IndexOutOfBoundsException {
+    override def getMessage: String =
+      "Bytes too large for particular operation: " +
+        Space.bytes(size).print + " > " + Space.bytes(array_size).print
+  }
+
 
-  def apply(s: CharSequence): Bytes =
-    if (s.isEmpty) empty
-    else {
-      val b = UTF8.bytes(s.toString)
-      new Bytes(b, 0, b.length)
-    }
+  /* main constructors */
+
+  private def reuse_array(bytes: Array[Byte]): Bytes =
+    if (bytes.length <= chunk_size) new Bytes(None, bytes, 0L, bytes.length.toLong)
+    else apply(bytes)
+
+  val empty: Bytes = reuse_array(new Array(0))
+
+  def apply(s: CharSequence): Bytes = {
+    val str = s.toString
+    if (str.isEmpty) empty
+    else Builder.use(hint = str.length) { builder => builder += str }
+  }
 
   def apply(a: Array[Byte]): Bytes = apply(a, 0, a.length)
 
   def apply(a: Array[Byte], offset: Int, length: Int): Bytes =
-    if (length == 0) empty
-    else {
-      val b = new Array[Byte](length)
-      System.arraycopy(a, offset, b, 0, length)
-      new Bytes(b, 0, b.length)
-    }
+    Builder.use(hint = length) { builder => builder += (a, offset, length) }
 
   val newline: Bytes = apply("\n")
 
-
-  /* base64 */
-
-  def decode_base64(s: String): Bytes = {
-    val a = Base64.decode(s)
-    new Bytes(a, 0, a.length)
-  }
+  def decode_base64(s: String): Bytes = Bytes.reuse_array(Base64.decode(s))
 
 
   /* read */
 
-  def read_stream(stream: InputStream, limit: Int = Int.MaxValue, hint: Int = 1024): Bytes =
+  def read_stream(stream: InputStream, limit: Long = -1L, hint: Long = 0L): Bytes = {
     if (limit == 0) empty
     else {
-      val out_size = (if (limit == Int.MaxValue) hint else limit) max 1024
-      val out = new ByteArrayOutputStream(out_size)
-      val buf = new Array[Byte](8192)
-      var m = 0
-
-      while ({
-        m = stream.read(buf, 0, buf.length min (limit - out.size))
-        if (m != -1) out.write(buf, 0, m)
-        m != -1 && limit > out.size
-      }) ()
-
-      new Bytes(out.toByteArray, 0, out.size)
+      Builder.use(hint = if (limit > 0) limit else hint) { builder =>
+        val buf = new Array[Byte](Bytes.block_size)
+        var m = 0
+        var n = 0L
+        while ({
+          val l = if (limit > 0) ((limit - n) min buf.length).toInt else buf.length
+          m = stream.read(buf, 0, l)
+          if (m != -1) {
+            builder += (buf, 0, m)
+            n += m
+          }
+          m != -1 && (limit < 0 || limit > n)
+        }) ()
+      }
     }
+  }
 
   def read_url(name: String): Bytes = using(Url(name).open_stream())(read_stream(_))
 
-  def read_file(path: Path, offset: Long = 0L, limit: Long = Long.MaxValue): Bytes = {
+  def read_file(path: Path, offset: Long = 0L, limit: Long = -1L): Bytes = {
     val length = File.size(path)
     val start = offset.max(0L)
-    val len = (length - start).max(0L).min(limit)
-    if (len > Int.MaxValue) error("Cannot read large file slice: " + Space.bytes(len).print)
-    else if (len == 0L) empty
+    val len = (length - start).max(0L).min(if (limit < 0) Long.MaxValue else limit)
+    if (len == 0L) empty
     else {
-      using(FileChannel.open(path.java_path, StandardOpenOption.READ)) { channel =>
-        channel.position(start)
-        val n = len.toInt
-        val buf = ByteBuffer.allocate(n)
-        var i = 0
-        var m = 0
-        while ({
-          m = channel.read(buf)
-          if (m != -1) i += m
-          m != -1 && n > i
-        }) ()
-        new Bytes(buf.array, 0, i)
+      Builder.use(hint = len) { builder =>
+        using(FileChannel.open(path.java_path, StandardOpenOption.READ)) { channel =>
+          channel.position(start)
+          val buf = ByteBuffer.allocate(Bytes.block_size)
+          var m = 0
+          var n = 0L
+          while ({
+            m = channel.read(buf)
+            if (m != -1) {
+              builder += (buf.array(), 0, m)
+              buf.clear()
+              n += m
+            }
+            m != -1 && len > n
+          }) ()
+        }
       }
     }
   }
@@ -125,37 +140,150 @@
       if (0 <= i && i < size) string(i.toInt)
       else throw new IndexOutOfBoundsException
   }
+
+
+  /* incremental builder: synchronized */
+
+  private def make_size(chunks: Array[Array[Byte]], buffer: Array[Byte]): Long =
+    chunks.foldLeft(buffer.length.toLong)((n, chunk) => n + chunk.length)
+
+  object Builder {
+    def use(hint: Long = 0L)(body: Builder => Unit): Bytes = {
+      val chunks_size = if (hint <= 0) 16 else (hint / chunk_size).toInt
+      val buffer_size = if (hint <= 0) 1024 else (hint min chunk_size min array_size).toInt
+      val builder = new Builder(chunks_size, buffer_size)
+      body(builder)
+      builder.done()
+    }
+  }
+
+  final class Builder private[Bytes](chunks_size: Int, buffer_size: Int) {
+    var chunks = new ArrayBuffer[Array[Byte]](chunks_size)
+    var buffer = new ByteArrayOutputStream(buffer_size)
+    def buffer_free(): Int = chunk_size.toInt - buffer.size()
+
+    def += (array: Array[Byte], offset: Int, length: Int): Unit = {
+      if (offset < 0 || length < 0 || offset.toLong + length.toLong > array.length) {
+        throw new IndexOutOfBoundsException
+      }
+      else if (length > 0) {
+        synchronized {
+          var i = offset
+          var n = length
+          while (n > 0) {
+            val m = buffer_free()
+            if (m > 0) {
+              val l = m min n
+              buffer.write(array, i, l)
+              i += l
+              n -= l
+            }
+            if (buffer_free() == 0) {
+              chunks += buffer.toByteArray
+              buffer = new ByteArrayOutputStream
+            }
+          }
+        }
+      }
+    }
+
+    def += (array: Array[Byte]): Unit = { this += (array, 0, array.length) }
+
+    def += (a: Subarray): Unit = { this += (a.array, a.offset, a.length) }
+
+    def += (string: String): Unit = if (string.nonEmpty) { this += UTF8.bytes(string) }
+
+    private def done(): Bytes = synchronized {
+      val cs = chunks.toArray
+      val b = buffer.toByteArray
+      chunks = null
+      buffer = null
+      new Bytes(if (cs.isEmpty) None else Some(cs), b, 0L, make_size(cs, b))
+    }
+  }
+
+
+  /* subarray */
+
+  object Subarray {
+    val empty: Subarray = new Subarray(new Array[Byte](0), 0, 0)
+
+    def apply(array: Array[Byte], offset: Int, length: Int): Subarray = {
+      val n = array.length
+      if (0 <= offset && offset < n && 0 <= length && offset + length <= n) {
+        if (length == 0) empty
+        else new Subarray(array, offset, length)
+      }
+      else throw new IndexOutOfBoundsException
+    }
+  }
+
+  final class Subarray private(
+    val array: Array[Byte],
+    val offset: Int,
+    val length: Int
+  ) {
+    override def toString: String = "Bytes.Subarray(" + Space.bytes(length).print + ")"
+
+    def byte_iterator: Iterator[Byte] =
+      if (length == 0) Iterator.empty
+      else { for (i <- (offset until (offset + length)).iterator) yield array(i) }
+  }
 }
 
 final class Bytes private(
-  protected val bytes: Array[Byte],
-  protected val offset: Int,
-  protected val length: Int
+  protected val chunks: Option[Array[Array[Byte]]],
+  protected val chunk0: Array[Byte],
+  protected val offset: Long,
+  val size: Long
 ) extends Bytes.Vec {
-
-  def size: Long = length.toLong
+  assert(
+    (chunks.isEmpty ||
+      chunks.get.nonEmpty &&
+      chunks.get.forall(chunk => chunk.length == Bytes.chunk_size)) &&
+    chunk0.length < Bytes.chunk_size)
 
   def is_empty: Boolean = size == 0
 
   def is_sliced: Boolean =
-    offset != 0L || length != bytes.length
+    offset != 0L || {
+      chunks match {
+        case None => size != chunk0.length
+        case Some(cs) => size != Bytes.make_size(cs, chunk0)
+      }
+    }
 
   override def toString: String =
     if (is_empty) "Bytes.empty"
     else "Bytes(" + Space.bytes(size).print + if_proper(is_sliced, ", sliced") + ")"
 
+  def small_size: Int =
+    if (size > Bytes.array_size) throw new Bytes.Too_Large(size)
+    else size.toInt
+
 
   /* slice */
 
   def slice(i: Long, j: Long): Bytes =
-    if (0 <= i && i <= j && j <= size) new Bytes(bytes, (offset + i).toInt, (j - i).toInt)
+    if (0 <= i && i <= j && j <= size) {
+      if (i == j) Bytes.empty
+      else new Bytes(chunks, chunk0, offset + i, j - i)
+    }
     else throw new IndexOutOfBoundsException
 
+  def unslice: Bytes =
+    if (is_sliced) {
+      Bytes.Builder.use(hint = size) { builder =>
+        for (a <- subarray_iterator) { builder += a }
+      }
+    }
+    else this
+
   def trim_line: Bytes =
-    if (size >= 2 && apply(size - 2) == 13 && apply(size - 1) == 10) {
+    if (size >= 2 && byte_unchecked(size - 2) == 13 && byte_unchecked(size - 1) == 10) {
       slice(0, size - 2)
     }
-    else if (size >= 1 && (apply(size - 1) == 13 || apply(size - 1) == 10)) {
+    else if (size >= 1 && (byte_unchecked(size - 1) == 13 || byte_unchecked(size - 1) == 10)) {
       slice(0, size - 1)
     }
     else this
@@ -163,19 +291,58 @@
 
   /* elements: signed Byte or unsigned Char */
 
-  def byte_iterator: Iterator[Byte] =
-    for (i <- (offset until (offset + length)).iterator)
-      yield bytes(i)
+  protected def byte_unchecked(i: Long): Byte = {
+    val a = offset + i
+    chunks match {
+      case None => chunk0(a.toInt)
+      case Some(cs) =>
+        val b = a % Bytes.chunk_size
+        val c = a / Bytes.chunk_size
+        if (c < cs.length) cs(c.toInt)(b.toInt) else chunk0(b.toInt)
+    }
+  }
+
+  def byte(i: Long): Byte =
+    if (0 <= i && i < size) byte_unchecked(i)
+    else throw new IndexOutOfBoundsException
+
+  def apply(i: Long): Char = (byte(i).toInt & 0xff).toChar
 
-  def apply(i: Long): Char =
-    if (0 <= i && i < size) (bytes((offset + i).toInt).asInstanceOf[Int] & 0xFF).asInstanceOf[Char]
-    else throw new IndexOutOfBoundsException
+  protected def subarray_iterator: Iterator[Bytes.Subarray] =
+    if (is_empty) Iterator.empty
+    else if (chunks.isEmpty) Iterator(Bytes.Subarray(chunk0, offset.toInt, size.toInt))
+    else {
+      val end_offset = offset + size
+      for ((array, index) <- (chunks.get.iterator ++ Iterator(chunk0)).zipWithIndex) yield {
+        val array_start = Bytes.chunk_size * index
+        val array_stop = array_start + array.length
+        if (offset < array_stop && array_start < end_offset) {
+          val i = (array_start max offset) - array_start
+          val j = (array_stop min end_offset) - array_start
+          Bytes.Subarray(array, i.toInt, (j - i).toInt)
+        }
+        else Bytes.Subarray.empty
+      }
+    }
+
+  def byte_iterator: Iterator[Byte] =
+    for {
+      a <- subarray_iterator
+      b <- a.byte_iterator
+    } yield b
 
 
   /* hash and equality */
 
   lazy val sha1_digest: SHA1.Digest =
-    if (is_empty) SHA1.digest_empty else SHA1.digest(bytes, offset, length)
+    if (is_empty) SHA1.digest_empty
+    else {
+      SHA1.make_digest { sha =>
+        for (a <- subarray_iterator if a.length > 0) {
+          sha.update(a.array, a.offset, a.length)
+        }
+      }
+    }
 
   override def hashCode(): Int = sha1_digest.hashCode()
 
@@ -184,9 +351,9 @@
       case other: Bytes =>
         if (this.eq(other)) true
         else if (size != other.size) false
-        else if (size <= 10 * SHA1.digest_length) {
-          Arrays.equals(bytes, offset, offset + length,
-            other.bytes, other.offset, other.offset + other.length)
+        else if (chunks.isEmpty && size <= 10 * SHA1.digest_length) {
+          Arrays.equals(chunk0, offset.toInt, (offset + size).toInt,
+            other.chunk0, other.offset.toInt, (other.offset + other.size).toInt)
         }
         else sha1_digest == other.sha1_digest
       case _ => false
@@ -197,29 +364,26 @@
   /* content */
 
   def array: Array[Byte] = {
-    val a = new Array[Byte](length)
-    System.arraycopy(bytes, offset, a, 0, length)
-    a
+    val buf = new ByteArrayOutputStream(small_size)
+    for (a <- subarray_iterator) { buf.write(a.array, a.offset, a.length) }
+    buf.toByteArray
   }
 
   def text: String =
     if (is_empty) ""
     else if (byte_iterator.forall(_ >= 0)) {
-      new String(bytes, offset, length, UTF8.charset)
+      new String(array, UTF8.charset)
     }
     else UTF8.decode_permissive_bytes(this)
 
-  def wellformed_text: Option[String] = {
-    val s = text
-    if (this == Bytes(s)) Some(s) else None
-  }
+  def wellformed_text: Option[String] =
+    try {
+      val s = text
+      if (this == Bytes(s)) Some(s) else None
+    }
+    catch { case ERROR(_) => None }
 
-  def encode_base64: String = {
-    val b =
-      if (offset == 0 && length == bytes.length) bytes
-      else Bytes(bytes, offset, length).bytes
-    Base64.encode(b)
-  }
+  def encode_base64: String = Base64.encode(array)
 
   def maybe_encode_base64: (Boolean, String) =
     wellformed_text match {
@@ -234,50 +398,70 @@
     if (other.is_empty) this
     else if (is_empty) other
     else {
-      val new_bytes = new Array[Byte](length + other.length)
-      System.arraycopy(bytes, offset, new_bytes, 0, length)
-      System.arraycopy(other.bytes, other.offset, new_bytes, length, other.length)
-      new Bytes(new_bytes, 0, new_bytes.length)
+      Bytes.Builder.use(hint = size + other.size) { builder =>
+        for (a <- subarray_iterator ++ other.subarray_iterator) {
+          builder += a
+        }
+      }
     }
 
 
   /* streams */
 
-  def stream(): ByteArrayInputStream = new ByteArrayInputStream(bytes, offset, length)
+  def stream(): InputStream =
+    if (chunks.isEmpty) new ByteArrayInputStream(chunk0, offset.toInt, size.toInt)
+    else {
+      new InputStream {
+        private val it = byte_iterator
+        def read(): Int = if (it.hasNext) it.next().toInt & 0xff else -1
+        override def readAllBytes(): Array[Byte] = array
+      }
+    }
 
-  def write_stream(stream: OutputStream): Unit = stream.write(bytes, offset, length)
+  def write_stream(stream: OutputStream): Unit =
+    for (a <- subarray_iterator if a.length > 0) {
+      stream.write(a.array, a.offset, a.length)
+    }
 
 
   /* XZ / Zstd data compression */
 
   def detect_xz: Boolean =
     size >= 6 &&
-      bytes(offset)     == 0xFD.toByte &&
-      bytes(offset + 1) == 0x37.toByte &&
-      bytes(offset + 2) == 0x7A.toByte &&
-      bytes(offset + 3) == 0x58.toByte &&
-      bytes(offset + 4) == 0x5A.toByte &&
-      bytes(offset + 5) == 0x00.toByte
+      byte_unchecked(0) == 0xFD.toByte &&
+      byte_unchecked(1) == 0x37.toByte &&
+      byte_unchecked(2) == 0x7A.toByte &&
+      byte_unchecked(3) == 0x58.toByte &&
+      byte_unchecked(4) == 0x5A.toByte &&
+      byte_unchecked(5) == 0x00.toByte
 
   def detect_zstd: Boolean =
     size >= 4 &&
-      bytes(offset)     == 0x28.toByte &&
-      bytes(offset + 1) == 0xB5.toByte &&
-      bytes(offset + 2) == 0x2F.toByte &&
-      bytes(offset + 3) == 0xFD.toByte
+      byte_unchecked(0) == 0x28.toByte &&
+      byte_unchecked(1) == 0xB5.toByte &&
+      byte_unchecked(2) == 0x2F.toByte &&
+      byte_unchecked(3) == 0xFD.toByte
 
   def uncompress_xz(cache: Compress.Cache = Compress.Cache.none): Bytes =
-    using(new xz.XZInputStream(stream(), cache.for_xz))(Bytes.read_stream(_, hint = length))
+    using(new xz.XZInputStream(stream(), cache.for_xz))(Bytes.read_stream(_, hint = size))
 
   def uncompress_zstd(cache: Compress.Cache = Compress.Cache.none): Bytes = {
     Zstd.init()
-    val n = zstd.Zstd.decompressedSize(bytes, offset, length)
-    if (n > 0 && n < Int.MaxValue) {
-      Bytes(zstd.Zstd.decompress(array, n.toInt))
+
+    def uncompress_stream(hint: Long): Bytes =
+      using(new zstd.ZstdInputStream(stream(), cache.for_zstd)) { inp =>
+        Bytes.read_stream(inp, hint = hint)
+      }
+
+    if (chunks.isEmpty) {
+      zstd.Zstd.decompressedSize(chunk0, offset.toInt, size.toInt) match {
+        case 0 => Bytes.empty
+        case n if n <= Bytes.array_size && !is_sliced =>
+          Bytes.reuse_array(zstd.Zstd.decompress(chunk0, n.toInt))
+        case n => uncompress_stream(n)
+      }
     }
-    else {
-      using(new zstd.ZstdInputStream(stream(), cache.for_zstd))(Bytes.read_stream(_, hint = length))
-    }
+    else uncompress_stream(size / 2)
   }
 
   def uncompress(cache: Compress.Cache = Compress.Cache.none): Bytes =
@@ -291,12 +475,13 @@
   ): Bytes = {
     options match {
       case options_xz: Compress.Options_XZ =>
-        val result = new ByteArrayOutputStream(length)
-        using(new xz.XZOutputStream(result, options_xz.make, cache.for_xz))(write_stream)
-        new Bytes(result.toByteArray, 0, result.size)
+        val out = new ByteArrayOutputStream((size min Bytes.array_size).toInt)
+        using(new xz.XZOutputStream(out, options_xz.make, cache.for_xz))(write_stream)
+        Bytes(out.toByteArray)
       case options_zstd: Compress.Options_Zstd =>
         Zstd.init()
-        Bytes(zstd.Zstd.compress(if (offset == 0) bytes else array, options_zstd.level))
+        val inp = if (chunks.isEmpty && !is_sliced) chunk0 else array
+        Bytes(zstd.Zstd.compress(inp, options_zstd.level))
     }
   }