src/Pure/General/bytes.scala
changeset 76351 2cee31cd92f0
parent 76350 978f7ca3329f
child 76353 3698d0f3da18
--- a/src/Pure/General/bytes.scala	Fri Oct 21 14:45:13 2022 +0200
+++ b/src/Pure/General/bytes.scala	Fri Oct 21 16:39:31 2022 +0200
@@ -7,10 +7,10 @@
 package isabelle
 
 
-import java.io.{File => JFile, ByteArrayOutputStream, ByteArrayInputStream,
-  OutputStream, InputStream, FileInputStream, FileOutputStream}
+import com.github.luben.zstd.{ZstdInputStream, ZstdOutputStream}
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream, InputStream, OutputStream, File as JFile}
 import java.net.URL
-
 import org.tukaani.xz.{XZInputStream, XZOutputStream}
 
 
@@ -191,20 +191,59 @@
   def write_stream(stream: OutputStream): Unit = stream.write(bytes, offset, length)
 
 
-  /* XZ data compression */
+  /* XZ / Zstd data compression */
+
+  // XZ data is recognized by its 6-byte header magic: FD 37 7A 58 5A 00
+  private def detect_xz: Boolean = {
+    val magic = List(0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00)
+    // the length test comes first: too short data is rejected without reading bytes
+    length >= magic.length &&
+      magic.zipWithIndex.forall(
+        { case (b, i) => bytes(offset + i) == b.toByte })
+  }
+
+  // Zstd data is recognized by its 4-byte frame magic: 28 B5 2F FD
+  private def detect_zstd: Boolean = {
+    val magic = List(0x28, 0xB5, 0x2F, 0xFD)
+    length >= magic.length &&
+      magic.zipWithIndex.forall({ case (b, i) => bytes(offset + i) == b.toByte })
+  }
+
+  private def detect_error(name: String = ""): Nothing =  // name: the expected scheme, if any
+    error(List("Cannot detect compression scheme", name).filter(_.nonEmpty).mkString(" "))
 
-  def uncompress(cache: XZ.Cache = XZ.Cache.none): Bytes =
-    using(new XZInputStream(stream(), cache))(Bytes.read_stream(_, hint = length))
+  def uncompress(cache: Compress.Cache = Compress.Cache.none): Bytes =
+    using(  // the decoder is chosen by the magic bytes at the start of the data
+      if (!detect_xz && !detect_zstd) detect_error()
+      else if (detect_xz) new XZInputStream(stream(), cache.xz)
+      else { Zstd.init(); new ZstdInputStream(stream(), cache.zstd) }
+    )(in => Bytes.read_stream(in, hint = length))
+
+  def uncompress_xz(cache: Compress.Cache = Compress.Cache.none): Bytes =  // XZ input only
+    if (!detect_xz) detect_error("XZ") else uncompress(cache)
+
+  def uncompress_zstd(cache: Compress.Cache = Compress.Cache.none): Bytes =  // Zstd input only
+    if (!detect_zstd) detect_error("Zstd") else uncompress(cache)
 
-  def compress(options: XZ.Options = XZ.options(), cache: XZ.Cache = XZ.Cache.none): Bytes = {
+  def compress(
+    options: Compress.Options = Compress.Options(),
+    cache: Compress.Cache = Compress.Cache.none
+  ): Bytes = {
     val result = new ByteArrayOutputStream(length)
-    using(new XZOutputStream(result, options, cache))(write_stream(_))
+    // the output codec follows the dynamic type of the given options
+    val compressor: OutputStream =
+      options match {
+        case xz: Compress.Options_XZ => new XZOutputStream(result, xz.make, cache.xz)
+        case zstd: Compress.Options_Zstd =>
+          Zstd.init(); new ZstdOutputStream(result, cache.zstd, zstd.level)
+      }
+    using(compressor)(write_stream)
     new Bytes(result.toByteArray, 0, result.size)
   }
 
   def maybe_compress(
-    options: XZ.Options = XZ.options(),
-    cache: XZ.Cache = XZ.Cache.none
+    options: Compress.Options = Compress.Options(),
+    cache: Compress.Cache = Compress.Cache.none
   ) : (Boolean, Bytes) = {
     val compressed = compress(options = options, cache = cache)
     if (compressed.length < length) (true, compressed) else (false, this)