src/Pure/General/bytes.scala
author wenzelm
Fri, 21 Oct 2022 18:06:32 +0200
changeset 76353 3698d0f3da18
parent 76351 2cee31cd92f0
child 76358 cff0828c374f
permissions -rw-r--r--
clarified signature;

/*  Title:      Pure/General/bytes.scala
    Author:     Makarius

Immutable byte vectors versus UTF8 strings.
*/

package isabelle


import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream, InputStream, OutputStream, File as JFile}
import java.net.URL
import org.tukaani.xz
import com.github.luben.zstd


object Bytes {
  /* construction: all public constructors copy their input, so the result
     never aliases a caller-owned array */

  val empty: Bytes = new Bytes(Array[Byte](), 0, 0)

  // UTF-8 encoding of the given text (via UTF8.bytes); empty text yields empty.
  def apply(s: CharSequence): Bytes = {
    val str = s.toString
    if (str.isEmpty) empty
    else {
      val b = UTF8.bytes(str)
      new Bytes(b, 0, b.length)
    }
  }

  // Copy of the whole array.
  def apply(a: Array[Byte]): Bytes = apply(a, 0, a.length)

  // Copy of the slice a(offset until offset + length).
  def apply(a: Array[Byte], offset: Int, length: Int): Bytes =
    if (length == 0) empty
    else {
      val b = new Array[Byte](length)
      System.arraycopy(a, offset, b, 0, length)
      new Bytes(b, 0, b.length)
    }

  val newline: Bytes = apply("\n")


  /* base64 */

  // Inverse of Bytes.encode_base64, via Base64.decode.
  def decode_base64(s: String): Bytes = {
    val a = Base64.decode(s)
    new Bytes(a, 0, a.length)
  }


  /* read */

  // Read at most `limit` bytes from `stream`, stopping early at end-of-stream.
  // `hint` is only the initial buffer capacity when no explicit limit is given
  // (never below 1024). The stream is not closed here; callers wrap it in
  // `using`.
  def read_stream(stream: InputStream, limit: Int = Integer.MAX_VALUE, hint: Int = 1024): Bytes =
    if (limit == 0) empty
    else {
      val out_size = (if (limit == Integer.MAX_VALUE) hint else limit) max 1024
      val out = new ByteArrayOutputStream(out_size)
      val buf = new Array[Byte](8192)
      var m = 0

      // Never request more than the remaining budget, so that the result
      // holds at most `limit` bytes.
      while ({
        m = stream.read(buf, 0, buf.length min (limit - out.size))
        if (m != -1) out.write(buf, 0, m)
        m != -1 && limit > out.size
      }) ()

      new Bytes(out.toByteArray, 0, out.size)
    }

  // Entire content of `file`. A positive File.length is used as a read limit,
  // so the result is a snapshot of the size at the time of the call. A length
  // of 0 is not trusted: File.length reports 0 for files whose size is not
  // known in advance (e.g. pipes, Linux /proc entries), so such files are read
  // until end-of-stream instead of being returned as empty.
  def read(file: JFile): Bytes = {
    val length = file.length
    val limit =
      if (length <= 0 || length > Integer.MAX_VALUE) Integer.MAX_VALUE
      else length.toInt
    using(new FileInputStream(file))(read_stream(_, limit = limit))
  }

  def read(path: Path): Bytes = read(path.file)

  // Entire content available from the URL's stream (no size limit).
  def read(url: URL): Bytes = using(url.openStream)(read_stream(_))


  /* write */

  // Replace the content of `file` by `bytes` (FileOutputStream, non-append).
  def write(file: JFile, bytes: Bytes): Unit =
    using(new FileOutputStream(file))(bytes.write_stream(_))

  def write(path: Path, bytes: Bytes): Unit = write(path.file, bytes)
}

final class Bytes private(
  protected val bytes: Array[Byte],
  protected val offset: Int,
  val length: Int) extends CharSequence {
  /* Representation: the content is the slice bytes(offset until offset + length).
     The underlying array may be shared with other Bytes values (see
     subSequence), so every operation must respect offset and length instead
     of using `bytes` as a whole, and must never mutate it. */


  /* equality */

  // Content equality over the visible slices only.
  override def equals(that: Any): Boolean = {
    that match {
      case other: Bytes =>
        if (this eq other) true
        else if (length != other.length) false
        else (0 until length).forall(i => bytes(offset + i) == other.bytes(other.offset + i))
      case _ => false
    }
  }

  // Base-31 polynomial hash over the unsigned byte values of the slice,
  // computed once on first use.
  private lazy val hash: Int = {
    var h = 0
    for (i <- offset until offset + length) {
      val b = bytes(i).asInstanceOf[Int] & 0xFF
      h = 31 * h + b
    }
    h
  }

  override def hashCode(): Int = hash


  /* content */

  // SHA1 digest of the visible slice only. Values produced by subSequence or
  // trim_line share their parent's array; digesting `bytes` as a whole would
  // yield the digest of the parent content instead.
  lazy val sha1_digest: SHA1.Digest = SHA1.digest(exact_array)

  // The content as an array: the underlying array itself if it is exactly the
  // slice (avoids a copy), otherwise a fresh copy. Must not be mutated.
  private def exact_array: Array[Byte] =
    if (offset == 0 && length == bytes.length) bytes else array

  def is_empty: Boolean = length == 0

  def iterator: Iterator[Byte] =
    for (i <- (offset until (offset + length)).iterator)
      yield bytes(i)

  // Fresh copy of the content; safe for the caller to mutate.
  def array: Array[Byte] = {
    val a = new Array[Byte](length)
    System.arraycopy(bytes, offset, a, 0, length)
    a
  }

  // Content decoded as UTF-8 via UTF8.decode_permissive.
  def text: String = UTF8.decode_permissive(this)

  // Some(text) if re-encoding the decoded text reproduces exactly these bytes,
  // None otherwise.
  def wellformed_text: Option[String] = {
    val s = text
    if (this == Bytes(s)) Some(s) else None
  }

  def encode_base64: String = Base64.encode(exact_array)

  // (false, text) if the content round-trips as UTF-8 text,
  // (true, base64 encoding) otherwise.
  def maybe_encode_base64: (Boolean, String) =
    wellformed_text match {
      case Some(s) => (false, s)
      case None => (true, encode_base64)
    }

  // Shows only the length, never the content.
  override def toString: String = "Bytes(" + length + ")"

  def proper: Option[Bytes] = if (is_empty) None else Some(this)
  def proper_text: Option[String] = if (is_empty) None else Some(text)

  // Concatenation; an empty operand yields the other operand unchanged.
  def +(other: Bytes): Bytes =
    if (other.is_empty) this
    else if (is_empty) other
    else {
      val new_bytes = new Array[Byte](length + other.length)
      System.arraycopy(bytes, offset, new_bytes, 0, length)
      System.arraycopy(other.bytes, other.offset, new_bytes, length, other.length)
      new Bytes(new_bytes, 0, new_bytes.length)
    }


  /* CharSequence operations */

  // Each byte is viewed as a Char in the range 0..255 (no UTF-8 decoding).
  def charAt(i: Int): Char =
    if (0 <= i && i < length) (bytes(offset + i).asInstanceOf[Int] & 0xFF).asInstanceOf[Char]
    else throw new IndexOutOfBoundsException

  // Slice [i, j) sharing the underlying array (no copy).
  def subSequence(i: Int, j: Int): Bytes = {
    if (0 <= i && i <= j && j <= length) new Bytes(bytes, offset + i, j - i)
    else throw new IndexOutOfBoundsException
  }

  // Remove one trailing line terminator: CR LF, CR, or LF.
  def trim_line: Bytes =
    if (length >= 2 && charAt(length - 2) == 13 && charAt(length - 1) == 10)
      subSequence(0, length - 2)
    else if (length >= 1 && (charAt(length - 1) == 13 || charAt(length - 1) == 10))
      subSequence(0, length - 1)
    else this


  /* streams */

  // Input stream over the slice, without copying.
  def stream(): ByteArrayInputStream = new ByteArrayInputStream(bytes, offset, length)

  def write_stream(stream: OutputStream): Unit = stream.write(bytes, offset, length)


  /* XZ / Zstd data compression */

  // XZ stream header magic: FD 37 7A 58 5A 00.
  private def detect_xz: Boolean =
    length >= 6 &&
      bytes(offset)     == 0xFD.toByte &&
      bytes(offset + 1) == 0x37.toByte &&
      bytes(offset + 2) == 0x7A.toByte &&
      bytes(offset + 3) == 0x58.toByte &&
      bytes(offset + 4) == 0x5A.toByte &&
      bytes(offset + 5) == 0x00.toByte

  // Zstandard frame magic: 28 B5 2F FD.
  private def detect_zstd: Boolean =
    length >= 4 &&
      bytes(offset)     == 0x28.toByte &&
      bytes(offset + 1) == 0xB5.toByte &&
      bytes(offset + 2) == 0x2F.toByte &&
      bytes(offset + 3) == 0xFD.toByte

  private def detect_error(name: String = ""): Nothing =
    error("Cannot detect compression scheme" + (if (name.isEmpty) "" else " " + name))

  // Decompress, choosing XZ or Zstd from the magic header; fails via
  // detect_error if neither matches. The compressed length serves as the
  // initial output buffer hint.
  def uncompress(cache: Compress.Cache = Compress.Cache.none): Bytes =
    using(
      if (detect_xz) new xz.XZInputStream(stream(), cache.for_xz)
      else if (detect_zstd) {
        Zstd.init()
        new zstd.ZstdInputStream(stream(), cache.for_zstd)
      }
      else detect_error()
    )(Bytes.read_stream(_, hint = length))

  // Like uncompress, but insists on XZ format.
  def uncompress_xz(cache: Compress.Cache = Compress.Cache.none): Bytes =
    if (detect_xz) uncompress(cache = cache) else detect_error("XZ")

  // Like uncompress, but insists on Zstd format.
  def uncompress_zstd(cache: Compress.Cache = Compress.Cache.none): Bytes =
    if (detect_zstd) uncompress(cache = cache) else detect_error("Zstd")

  // Compress with the scheme selected by the type of `options`.
  def compress(
    options: Compress.Options = Compress.Options(),
    cache: Compress.Cache = Compress.Cache.none
  ): Bytes = {
    val result = new ByteArrayOutputStream(length)
    using(
      options match {
        case options_xz: Compress.Options_XZ =>
          new xz.XZOutputStream(result, options_xz.make, cache.for_xz)
        case options_zstd: Compress.Options_Zstd =>
          Zstd.init()
          new zstd.ZstdOutputStream(result, cache.for_zstd, options_zstd.level)
      })(write_stream)
    new Bytes(result.toByteArray, 0, result.size)
  }

  // (true, compressed) only if compression makes the content strictly smaller,
  // otherwise (false, this).
  def maybe_compress(
    options: Compress.Options = Compress.Options(),
    cache: Compress.Cache = Compress.Cache.none
  ) : (Boolean, Bytes) = {
    val compressed = compress(options = options, cache = cache)
    if (compressed.length < length) (true, compressed) else (false, this)
  }
}