path: root/airlift-zstd
author    Arne Juul <arnej@yahooinc.com>  2022-12-28 11:32:05 +0000
committer gjoranv <gv@verizonmedia.com>  2023-01-03 13:05:55 +0100
commit    6ce10cc8150f5686d267e5e49e069ea2f37e4e8a (patch)
tree      cd8cd3916125b891c7d00513247e7e4e27c55fd7 /airlift-zstd
parent    d5ca7762b704300d708e46007e781add20097fc4 (diff)
add fork of airlift zstd code
Diffstat (limited to 'airlift-zstd')
-rw-r--r--  airlift-zstd/CMakeLists.txt | 2
-rw-r--r--  airlift-zstd/pom.xml | 47
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/compress/Compressor.java | 28
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/compress/Decompressor.java | 28
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/compress/IncompatibleJvmException.java | 23
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/compress/MalformedInputException.java | 36
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BitInputStream.java | 207
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BitOutputStream.java | 90
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BlockCompressionState.java | 60
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BlockCompressor.java | 21
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/CompressionContext.java | 46
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/CompressionParameters.java | 306
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Constants.java | 85
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/DoubleFastBlockCompressor.java | 261
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FiniteStateEntropy.java | 551
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FrameHeader.java | 70
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FseCompressionTable.java | 158
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FseTableReader.java | 169
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Histogram.java | 65
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Huffman.java | 323
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionContext.java | 61
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionTable.java | 437
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionTableWorkspace.java | 33
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressor.java | 137
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanTableWriterWorkspace.java | 29
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/NodeTable.java | 48
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/RepeatedOffsets.java | 49
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceEncoder.java | 351
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceEncodingContext.java | 30
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceStore.java | 160
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/UnsafeUtil.java | 64
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Util.java | 94
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/XxHash64.java | 286
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdBlockDecompressor.java | 810
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdCompressor.java | 126
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdDecompressor.java | 119
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdFrameCompressor.java | 438
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdFrameDecompressor.java | 212
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdInputStream.java | 471
39 files changed, 6531 insertions(+), 0 deletions(-)
diff --git a/airlift-zstd/CMakeLists.txt b/airlift-zstd/CMakeLists.txt
new file mode 100644
index 00000000000..c9be5ff262a
--- /dev/null
+++ b/airlift-zstd/CMakeLists.txt
@@ -0,0 +1,2 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+install_jar(airlift-zstd.jar)
diff --git a/airlift-zstd/pom.xml b/airlift-zstd/pom.xml
new file mode 100644
index 00000000000..2d2f83daed9
--- /dev/null
+++ b/airlift-zstd/pom.xml
@@ -0,0 +1,47 @@
+<?xml version="1.0"?>
+<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>parent</artifactId>
+ <version>8-SNAPSHOT</version>
+ <relativePath>../parent/pom.xml</relativePath>
+ </parent>
+ <artifactId>airlift-zstd</artifactId>
+ <packaging>jar</packaging>
+ <version>8-SNAPSHOT</version>
+ <description>
+ Fork of https://github.com/airlift/aircompressor (zstd only).
+ This module is temporary until there is an official release that includes the
+ ZstdInputStream API (two separate pull requests already implement it, but
+ neither shows any progress).
+ </description>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-deploy-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <configuration>
+ <compilerArgs>
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-serial</arg>
+ <arg>-Xlint:-try</arg>
+ <arg>-Xlint:-processing</arg>
+ </compilerArgs>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-install-plugin</artifactId>
+ <configuration>
+ <updateReleaseInfo>true</updateReleaseInfo>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/compress/Compressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/compress/Compressor.java
new file mode 100644
index 00000000000..ba0530c985f
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/compress/Compressor.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.compress;
+
+import java.nio.ByteBuffer;
+
+public interface Compressor
+{
+ int maxCompressedLength(int uncompressedSize);
+
+ /**
+ * @return number of bytes written to the output
+ */
+ int compress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength);
+
+ void compress(ByteBuffer input, ByteBuffer output);
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/compress/Decompressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/compress/Decompressor.java
new file mode 100644
index 00000000000..256df93e7c7
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/compress/Decompressor.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.compress;
+
+import java.nio.ByteBuffer;
+
+public interface Decompressor
+{
+ /**
+ * @return number of bytes written to the output
+ */
+ int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength)
+ throws MalformedInputException;
+
+ void decompress(ByteBuffer input, ByteBuffer output)
+ throws MalformedInputException;
+}
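
Usage note: per the diffstat above, these two interfaces are implemented by the ZstdCompressor and ZstdDecompressor classes added in this commit. A minimal round-trip sketch using only the interface methods shown here, and assuming both classes expose public no-argument constructors:

    import ai.vespa.airlift.compress.Compressor;
    import ai.vespa.airlift.compress.Decompressor;
    import ai.vespa.airlift.zstd.ZstdCompressor;
    import ai.vespa.airlift.zstd.ZstdDecompressor;
    import java.nio.charset.StandardCharsets;

    public class RoundTripExample
    {
        public static void main(String[] args)
        {
            byte[] input = "hello hello hello hello".getBytes(StandardCharsets.UTF_8);

            Compressor compressor = new ZstdCompressor();
            byte[] compressed = new byte[compressor.maxCompressedLength(input.length)];
            int compressedSize = compressor.compress(input, 0, input.length, compressed, 0, compressed.length);

            Decompressor decompressor = new ZstdDecompressor();
            byte[] output = new byte[input.length]; // the original size is known in this example
            int outputSize = decompressor.decompress(compressed, 0, compressedSize, output, 0, output.length);
            assert outputSize == input.length;
        }
    }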
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/compress/IncompatibleJvmException.java b/airlift-zstd/src/main/java/ai/vespa/airlift/compress/IncompatibleJvmException.java
new file mode 100644
index 00000000000..3c65f2c9cda
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/compress/IncompatibleJvmException.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.compress;
+
+public class IncompatibleJvmException
+ extends RuntimeException
+{
+ public IncompatibleJvmException(String message)
+ {
+ super(message);
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/compress/MalformedInputException.java b/airlift-zstd/src/main/java/ai/vespa/airlift/compress/MalformedInputException.java
new file mode 100644
index 00000000000..82e14e8ab19
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/compress/MalformedInputException.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.compress;
+
+public class MalformedInputException
+ extends RuntimeException
+{
+ private final long offset;
+
+ public MalformedInputException(long offset)
+ {
+ this(offset, "Malformed input");
+ }
+
+ public MalformedInputException(long offset, String reason)
+ {
+ super(reason + ": offset=" + offset);
+ this.offset = offset;
+ }
+
+ public long getOffset()
+ {
+ return offset;
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BitInputStream.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BitInputStream.java
new file mode 100644
index 00000000000..5b7234594f9
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BitInputStream.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_LONG;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.highestBit;
+import static ai.vespa.airlift.zstd.Util.verify;
+
+/**
+ * Bit streams are encoded as a byte-aligned little-endian stream. Thus, bits are laid out
+ * in the following manner, and the stream is read from right to left.
+ * <p>
+ * ... [16 17 18 19 20 21 22 23] [8 9 10 11 12 13 14 15] [0 1 2 3 4 5 6 7]
+ */
+class BitInputStream
+{
+ private BitInputStream()
+ {
+ }
+
+ public static boolean isEndOfStream(long startAddress, long currentAddress, int bitsConsumed)
+ {
+ return startAddress == currentAddress && bitsConsumed == Long.SIZE;
+ }
+
+ @SuppressWarnings("fallthrough")
+ static long readTail(Object inputBase, long inputAddress, int inputSize)
+ {
+ long bits = UNSAFE.getByte(inputBase, inputAddress) & 0xFF;
+
+ switch (inputSize) {
+ case 7:
+ bits |= (UNSAFE.getByte(inputBase, inputAddress + 6) & 0xFFL) << 48;
+ case 6:
+ bits |= (UNSAFE.getByte(inputBase, inputAddress + 5) & 0xFFL) << 40;
+ case 5:
+ bits |= (UNSAFE.getByte(inputBase, inputAddress + 4) & 0xFFL) << 32;
+ case 4:
+ bits |= (UNSAFE.getByte(inputBase, inputAddress + 3) & 0xFFL) << 24;
+ case 3:
+ bits |= (UNSAFE.getByte(inputBase, inputAddress + 2) & 0xFFL) << 16;
+ case 2:
+ bits |= (UNSAFE.getByte(inputBase, inputAddress + 1) & 0xFFL) << 8;
+ }
+
+ return bits;
+ }
+
+ /**
+ * @return numberOfBits in the low order bits of a long
+ */
+ public static long peekBits(int bitsConsumed, long bitContainer, int numberOfBits)
+ {
+ return (((bitContainer << bitsConsumed) >>> 1) >>> (63 - numberOfBits));
+ }
+
+ /**
+ * numberOfBits must be > 0
+ *
+ * @return numberOfBits in the low order bits of a long
+ */
+ public static long peekBitsFast(int bitsConsumed, long bitContainer, int numberOfBits)
+ {
+ return ((bitContainer << bitsConsumed) >>> (64 - numberOfBits));
+ }
+
+ static class Initializer
+ {
+ private final Object inputBase;
+ private final long startAddress;
+ private final long endAddress;
+ private long bits;
+ private long currentAddress;
+ private int bitsConsumed;
+
+ public Initializer(Object inputBase, long startAddress, long endAddress)
+ {
+ this.inputBase = inputBase;
+ this.startAddress = startAddress;
+ this.endAddress = endAddress;
+ }
+
+ public long getBits()
+ {
+ return bits;
+ }
+
+ public long getCurrentAddress()
+ {
+ return currentAddress;
+ }
+
+ public int getBitsConsumed()
+ {
+ return bitsConsumed;
+ }
+
+ public void initialize()
+ {
+ verify(endAddress - startAddress >= 1, startAddress, "Bitstream is empty");
+
+ int lastByte = UNSAFE.getByte(inputBase, endAddress - 1) & 0xFF;
+ verify(lastByte != 0, endAddress, "Bitstream end mark not present");
+
+ bitsConsumed = SIZE_OF_LONG - highestBit(lastByte);
+
+ int inputSize = (int) (endAddress - startAddress);
+ if (inputSize >= SIZE_OF_LONG) { /* normal case */
+ currentAddress = endAddress - SIZE_OF_LONG;
+ bits = UNSAFE.getLong(inputBase, currentAddress);
+ }
+ else {
+ currentAddress = startAddress;
+ bits = readTail(inputBase, startAddress, inputSize);
+
+ bitsConsumed += (SIZE_OF_LONG - inputSize) * 8;
+ }
+ }
+ }
+
+ static final class Loader
+ {
+ private final Object inputBase;
+ private final long startAddress;
+ private long bits;
+ private long currentAddress;
+ private int bitsConsumed;
+ private boolean overflow;
+
+ public Loader(Object inputBase, long startAddress, long currentAddress, long bits, int bitsConsumed)
+ {
+ this.inputBase = inputBase;
+ this.startAddress = startAddress;
+ this.bits = bits;
+ this.currentAddress = currentAddress;
+ this.bitsConsumed = bitsConsumed;
+ }
+
+ public long getBits()
+ {
+ return bits;
+ }
+
+ public long getCurrentAddress()
+ {
+ return currentAddress;
+ }
+
+ public int getBitsConsumed()
+ {
+ return bitsConsumed;
+ }
+
+ public boolean isOverflow()
+ {
+ return overflow;
+ }
+
+ public boolean load()
+ {
+ if (bitsConsumed > 64) {
+ overflow = true;
+ return true;
+ }
+ else if (currentAddress == startAddress) {
+ return true;
+ }
+
+ int bytes = bitsConsumed >>> 3; // divide by 8
+ if (currentAddress >= startAddress + SIZE_OF_LONG) {
+ if (bytes > 0) {
+ currentAddress -= bytes;
+ bits = UNSAFE.getLong(inputBase, currentAddress);
+ }
+ bitsConsumed &= 0b111;
+ }
+ else if (currentAddress - bytes < startAddress) {
+ bytes = (int) (currentAddress - startAddress);
+ currentAddress = startAddress;
+ bitsConsumed -= bytes * SIZE_OF_LONG;
+ bits = UNSAFE.getLong(inputBase, startAddress);
+ return true;
+ }
+ else {
+ currentAddress -= bytes;
+ bitsConsumed -= bytes * SIZE_OF_LONG;
+ bits = UNSAFE.getLong(inputBase, currentAddress);
+ }
+
+ return false;
+ }
+ }
+}
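
A worked example of the peekBits contract (illustrative values; same-package access assumed, since BitInputStream is package-private): bits are peeked from the most-significant end of the container once bitsConsumed bits are gone, and the two-step shift lets numberOfBits be zero, which peekBitsFast cannot handle because a Java shift by 64 is a no-op.

    long container = 0xB000000000000000L;                     // top bits: 1011...
    assert BitInputStream.peekBits(0, container, 4) == 0b1011;
    assert BitInputStream.peekBits(4, container, 2) == 0b00;  // next two bits after consuming four
    assert BitInputStream.peekBits(0, container, 0) == 0;     // zero-width peek is safe here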
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BitOutputStream.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BitOutputStream.java
new file mode 100644
index 00000000000..29dd168fca2
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BitOutputStream.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_LONG;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.checkArgument;
+
+class BitOutputStream
+{
+ private static final long[] BIT_MASK = {
+ 0x0, 0x1, 0x3, 0x7, 0xF, 0x1F,
+ 0x3F, 0x7F, 0xFF, 0x1FF, 0x3FF, 0x7FF,
+ 0xFFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF, 0x1FFFF,
+ 0x3FFFF, 0x7FFFF, 0xFFFFF, 0x1FFFFF, 0x3FFFFF, 0x7FFFFF,
+ 0xFFFFFF, 0x1FFFFFF, 0x3FFFFFF, 0x7FFFFFF, 0xFFFFFFF, 0x1FFFFFFF,
+ 0x3FFFFFFF, 0x7FFFFFFF}; // up to 31 bits
+
+ private final Object outputBase;
+ private final long outputAddress;
+ private final long outputLimit;
+
+ private long container;
+ private int bitCount;
+ private long currentAddress;
+
+ public BitOutputStream(Object outputBase, long outputAddress, int outputSize)
+ {
+ checkArgument(outputSize >= SIZE_OF_LONG, "Output buffer too small");
+
+ this.outputBase = outputBase;
+ this.outputAddress = outputAddress;
+ outputLimit = this.outputAddress + outputSize - SIZE_OF_LONG;
+
+ currentAddress = this.outputAddress;
+ }
+
+ public void addBits(int value, int bits)
+ {
+ container |= (value & BIT_MASK[bits]) << bitCount;
+ bitCount += bits;
+ }
+
+ /**
+ * Note: leading bits of value must be 0
+ */
+ public void addBitsFast(int value, int bits)
+ {
+ container |= ((long) value) << bitCount;
+ bitCount += bits;
+ }
+
+ public void flush()
+ {
+ int bytes = bitCount >>> 3;
+
+ UNSAFE.putLong(outputBase, currentAddress, container);
+ currentAddress += bytes;
+
+ if (currentAddress > outputLimit) {
+ currentAddress = outputLimit;
+ }
+
+ bitCount &= 7;
+ container >>>= bytes * 8;
+ }
+
+ public int close()
+ {
+ addBitsFast(1, 1); // end mark
+ flush();
+
+ if (currentAddress >= outputLimit) {
+ return 0;
+ }
+
+ return (int) ((currentAddress - outputAddress) + (bitCount > 0 ? 1 : 0));
+ }
+}
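
A small sketch of the write path (assuming same-package access and a byte[] heap buffer addressed via ARRAY_BYTE_BASE_OFFSET, as elsewhere in this package): addBits accumulates low-order bits first, flush writes the container and advances by whole bytes, and close appends the one-bit end mark that BitInputStream.Initializer later checks for.

    byte[] buffer = new byte[16];
    BitOutputStream out = new BitOutputStream(buffer, sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET, buffer.length);
    out.addBits(0b101, 3);     // stream is little-endian: low bits land first
    out.addBits(0b01, 2);
    int written = out.close(); // 3 + 2 + 1 end-mark bit = 6 bits -> 1 byte
    // written == 1, and on the little-endian platforms this code assumes, buffer[0] == 0b00101101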
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BlockCompressionState.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BlockCompressionState.java
new file mode 100644
index 00000000000..e5d15cc6a58
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BlockCompressionState.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.util.Arrays;
+
+class BlockCompressionState
+{
+ public final int[] hashTable;
+ public final int[] chainTable;
+
+ private final long baseAddress;
+
+ // starting point of the window with respect to baseAddress
+ private int windowBaseOffset;
+
+ public BlockCompressionState(CompressionParameters parameters, long baseAddress)
+ {
+ this.baseAddress = baseAddress;
+ hashTable = new int[1 << parameters.getHashLog()];
+ chainTable = new int[1 << parameters.getChainLog()]; // TODO: chain table not used by Strategy.FAST
+ }
+
+ public void reset()
+ {
+ Arrays.fill(hashTable, 0);
+ Arrays.fill(chainTable, 0);
+ }
+
+ public void enforceMaxDistance(long inputLimit, int maxDistance)
+ {
+ int distance = (int) (inputLimit - baseAddress);
+
+ int newOffset = distance - maxDistance;
+ if (windowBaseOffset < newOffset) {
+ windowBaseOffset = newOffset;
+ }
+ }
+
+ public long getBaseAddress()
+ {
+ return baseAddress;
+ }
+
+ public int getWindowBaseOffset()
+ {
+ return windowBaseOffset;
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BlockCompressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BlockCompressor.java
new file mode 100644
index 00000000000..a23fd0ae9a9
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/BlockCompressor.java
@@ -0,0 +1,21 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+interface BlockCompressor
+{
+ BlockCompressor UNSUPPORTED = (inputBase, inputAddress, inputSize, sequenceStore, blockCompressionState, offsets, parameters) -> { throw new UnsupportedOperationException(); };
+
+ int compressBlock(Object inputBase, long inputAddress, int inputSize, SequenceStore output, BlockCompressionState state, RepeatedOffsets offsets, CompressionParameters parameters);
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/CompressionContext.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/CompressionContext.java
new file mode 100644
index 00000000000..fd4b393c758
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/CompressionContext.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.MAX_BLOCK_SIZE;
+
+class CompressionContext
+{
+ public final RepeatedOffsets offsets = new RepeatedOffsets();
+ public final BlockCompressionState blockCompressionState;
+ public final SequenceStore sequenceStore;
+
+ public final SequenceEncodingContext sequenceEncodingContext = new SequenceEncodingContext();
+
+ public final HuffmanCompressionContext huffmanContext = new HuffmanCompressionContext();
+
+ public CompressionContext(CompressionParameters parameters, long baseAddress, int inputSize)
+ {
+ int windowSize = Math.max(1, Math.min(1 << parameters.getWindowLog(), inputSize));
+ int blockSize = Math.min(MAX_BLOCK_SIZE, windowSize);
+ int divider = (parameters.getSearchLength() == 3) ? 3 : 4;
+
+ int maxSequences = blockSize / divider;
+
+ sequenceStore = new SequenceStore(blockSize, maxSequences);
+
+ blockCompressionState = new BlockCompressionState(parameters, baseAddress);
+ }
+
+ public void commit()
+ {
+ offsets.commit();
+ huffmanContext.saveChanges();
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/CompressionParameters.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/CompressionParameters.java
new file mode 100644
index 00000000000..586a07a8cb2
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/CompressionParameters.java
@@ -0,0 +1,306 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.MAX_WINDOW_LOG;
+import static ai.vespa.airlift.zstd.Constants.MIN_WINDOW_LOG;
+import static ai.vespa.airlift.zstd.Util.cycleLog;
+import static ai.vespa.airlift.zstd.Util.highestBit;
+
+class CompressionParameters
+{
+ private static final int MIN_HASH_LOG = 6;
+
+ public static final int DEFAULT_COMPRESSION_LEVEL = 3;
+ private static final int MAX_COMPRESSION_LEVEL = 22;
+
+ private final int windowLog; // largest match distance : larger == more compression, more memory needed during decompression
+ private final int chainLog; // fully searched segment : larger == more compression, slower, more memory (useless for fast)
+ private final int hashLog; // dispatch table : larger == faster, more memory
+ private final int searchLog; // nb of searches : larger == more compression, slower
+ private final int searchLength; // match length searched : larger == faster decompression, sometimes less compression
+ private final int targetLength; // acceptable match size for optimal parser (only) : larger == more compression, slower
+ private final Strategy strategy;
+
+ private static final CompressionParameters[][] DEFAULT_COMPRESSION_PARAMETERS = new CompressionParameters[][] {
+ {
+ // default
+ new CompressionParameters(19, 12, 13, 1, 6, 1, Strategy.FAST), /* base for negative levels */
+ new CompressionParameters(19, 13, 14, 1, 7, 0, Strategy.FAST), /* level 1 */
+ new CompressionParameters(19, 15, 16, 1, 6, 0, Strategy.FAST), /* level 2 */
+ new CompressionParameters(20, 16, 17, 1, 5, 1, Strategy.DFAST), /* level 3 */
+ new CompressionParameters(20, 18, 18, 1, 5, 1, Strategy.DFAST), /* level 4 */
+ new CompressionParameters(20, 18, 18, 2, 5, 2, Strategy.GREEDY), /* level 5 */
+ new CompressionParameters(21, 18, 19, 2, 5, 4, Strategy.LAZY), /* level 6 */
+ new CompressionParameters(21, 18, 19, 3, 5, 8, Strategy.LAZY2), /* level 7 */
+ new CompressionParameters(21, 19, 19, 3, 5, 16, Strategy.LAZY2), /* level 8 */
+ new CompressionParameters(21, 19, 20, 4, 5, 16, Strategy.LAZY2), /* level 9 */
+ new CompressionParameters(21, 20, 21, 4, 5, 16, Strategy.LAZY2), /* level 10 */
+ new CompressionParameters(21, 21, 22, 4, 5, 16, Strategy.LAZY2), /* level 11 */
+ new CompressionParameters(22, 20, 22, 5, 5, 16, Strategy.LAZY2), /* level 12 */
+ new CompressionParameters(22, 21, 22, 4, 5, 32, Strategy.BTLAZY2), /* level 13 */
+ new CompressionParameters(22, 21, 22, 5, 5, 32, Strategy.BTLAZY2), /* level 14 */
+ new CompressionParameters(22, 22, 22, 6, 5, 32, Strategy.BTLAZY2), /* level 15 */
+ new CompressionParameters(22, 21, 22, 4, 5, 48, Strategy.BTOPT), /* level 16 */
+ new CompressionParameters(23, 22, 22, 4, 4, 64, Strategy.BTOPT), /* level 17 */
+ new CompressionParameters(23, 23, 22, 6, 3, 256, Strategy.BTOPT), /* level 18 */
+ new CompressionParameters(23, 24, 22, 7, 3, 256, Strategy.BTULTRA), /* level 19 */
+ new CompressionParameters(25, 25, 23, 7, 3, 256, Strategy.BTULTRA), /* level 20 */
+ new CompressionParameters(26, 26, 24, 7, 3, 512, Strategy.BTULTRA), /* level 21 */
+ new CompressionParameters(27, 27, 25, 9, 3, 999, Strategy.BTULTRA) /* level 22 */
+ },
+ {
+ // for size <= 256 KB
+ new CompressionParameters(18, 12, 13, 1, 5, 1, Strategy.FAST), /* base for negative levels */
+ new CompressionParameters(18, 13, 14, 1, 6, 0, Strategy.FAST), /* level 1 */
+ new CompressionParameters(18, 14, 14, 1, 5, 1, Strategy.DFAST), /* level 2 */
+ new CompressionParameters(18, 16, 16, 1, 4, 1, Strategy.DFAST), /* level 3 */
+ new CompressionParameters(18, 16, 17, 2, 5, 2, Strategy.GREEDY), /* level 4.*/
+ new CompressionParameters(18, 18, 18, 3, 5, 2, Strategy.GREEDY), /* level 5.*/
+ new CompressionParameters(18, 18, 19, 3, 5, 4, Strategy.LAZY), /* level 6.*/
+ new CompressionParameters(18, 18, 19, 4, 4, 4, Strategy.LAZY), /* level 7 */
+ new CompressionParameters(18, 18, 19, 4, 4, 8, Strategy.LAZY2), /* level 8 */
+ new CompressionParameters(18, 18, 19, 5, 4, 8, Strategy.LAZY2), /* level 9 */
+ new CompressionParameters(18, 18, 19, 6, 4, 8, Strategy.LAZY2), /* level 10 */
+ new CompressionParameters(18, 18, 19, 5, 4, 16, Strategy.BTLAZY2), /* level 11.*/
+ new CompressionParameters(18, 19, 19, 6, 4, 16, Strategy.BTLAZY2), /* level 12.*/
+ new CompressionParameters(18, 19, 19, 8, 4, 16, Strategy.BTLAZY2), /* level 13 */
+ new CompressionParameters(18, 18, 19, 4, 4, 24, Strategy.BTOPT), /* level 14.*/
+ new CompressionParameters(18, 18, 19, 4, 3, 24, Strategy.BTOPT), /* level 15.*/
+ new CompressionParameters(18, 19, 19, 6, 3, 64, Strategy.BTOPT), /* level 16.*/
+ new CompressionParameters(18, 19, 19, 8, 3, 128, Strategy.BTOPT), /* level 17.*/
+ new CompressionParameters(18, 19, 19, 10, 3, 256, Strategy.BTOPT), /* level 18.*/
+ new CompressionParameters(18, 19, 19, 10, 3, 256, Strategy.BTULTRA), /* level 19.*/
+ new CompressionParameters(18, 19, 19, 11, 3, 512, Strategy.BTULTRA), /* level 20.*/
+ new CompressionParameters(18, 19, 19, 12, 3, 512, Strategy.BTULTRA), /* level 21.*/
+ new CompressionParameters(18, 19, 19, 13, 3, 999, Strategy.BTULTRA) /* level 22.*/
+ },
+ {
+ // for size <= 128 KB
+ new CompressionParameters(17, 12, 12, 1, 5, 1, Strategy.FAST), /* base for negative levels */
+ new CompressionParameters(17, 12, 13, 1, 6, 0, Strategy.FAST), /* level 1 */
+ new CompressionParameters(17, 13, 15, 1, 5, 0, Strategy.FAST), /* level 2 */
+ new CompressionParameters(17, 15, 16, 2, 5, 1, Strategy.DFAST), /* level 3 */
+ new CompressionParameters(17, 17, 17, 2, 4, 1, Strategy.DFAST), /* level 4 */
+ new CompressionParameters(17, 16, 17, 3, 4, 2, Strategy.GREEDY), /* level 5 */
+ new CompressionParameters(17, 17, 17, 3, 4, 4, Strategy.LAZY), /* level 6 */
+ new CompressionParameters(17, 17, 17, 3, 4, 8, Strategy.LAZY2), /* level 7 */
+ new CompressionParameters(17, 17, 17, 4, 4, 8, Strategy.LAZY2), /* level 8 */
+ new CompressionParameters(17, 17, 17, 5, 4, 8, Strategy.LAZY2), /* level 9 */
+ new CompressionParameters(17, 17, 17, 6, 4, 8, Strategy.LAZY2), /* level 10 */
+ new CompressionParameters(17, 17, 17, 7, 4, 8, Strategy.LAZY2), /* level 11 */
+ new CompressionParameters(17, 18, 17, 6, 4, 16, Strategy.BTLAZY2), /* level 12 */
+ new CompressionParameters(17, 18, 17, 8, 4, 16, Strategy.BTLAZY2), /* level 13.*/
+ new CompressionParameters(17, 18, 17, 4, 4, 32, Strategy.BTOPT), /* level 14.*/
+ new CompressionParameters(17, 18, 17, 6, 3, 64, Strategy.BTOPT), /* level 15.*/
+ new CompressionParameters(17, 18, 17, 7, 3, 128, Strategy.BTOPT), /* level 16.*/
+ new CompressionParameters(17, 18, 17, 7, 3, 256, Strategy.BTOPT), /* level 17.*/
+ new CompressionParameters(17, 18, 17, 8, 3, 256, Strategy.BTOPT), /* level 18.*/
+ new CompressionParameters(17, 18, 17, 8, 3, 256, Strategy.BTULTRA), /* level 19.*/
+ new CompressionParameters(17, 18, 17, 9, 3, 256, Strategy.BTULTRA), /* level 20.*/
+ new CompressionParameters(17, 18, 17, 10, 3, 256, Strategy.BTULTRA), /* level 21.*/
+ new CompressionParameters(17, 18, 17, 11, 3, 512, Strategy.BTULTRA) /* level 22.*/
+ },
+ {
+ // for size <= 16 KB
+ new CompressionParameters(14, 12, 13, 1, 5, 1, Strategy.FAST), /* base for negative levels */
+ new CompressionParameters(14, 14, 15, 1, 5, 0, Strategy.FAST), /* level 1 */
+ new CompressionParameters(14, 14, 15, 1, 4, 0, Strategy.FAST), /* level 2 */
+ new CompressionParameters(14, 14, 14, 2, 4, 1, Strategy.DFAST), /* level 3.*/
+ new CompressionParameters(14, 14, 14, 4, 4, 2, Strategy.GREEDY), /* level 4.*/
+ new CompressionParameters(14, 14, 14, 3, 4, 4, Strategy.LAZY), /* level 5.*/
+ new CompressionParameters(14, 14, 14, 4, 4, 8, Strategy.LAZY2), /* level 6 */
+ new CompressionParameters(14, 14, 14, 6, 4, 8, Strategy.LAZY2), /* level 7 */
+ new CompressionParameters(14, 14, 14, 8, 4, 8, Strategy.LAZY2), /* level 8.*/
+ new CompressionParameters(14, 15, 14, 5, 4, 8, Strategy.BTLAZY2), /* level 9.*/
+ new CompressionParameters(14, 15, 14, 9, 4, 8, Strategy.BTLAZY2), /* level 10.*/
+ new CompressionParameters(14, 15, 14, 3, 4, 12, Strategy.BTOPT), /* level 11.*/
+ new CompressionParameters(14, 15, 14, 6, 3, 16, Strategy.BTOPT), /* level 12.*/
+ new CompressionParameters(14, 15, 14, 6, 3, 24, Strategy.BTOPT), /* level 13.*/
+ new CompressionParameters(14, 15, 15, 6, 3, 48, Strategy.BTOPT), /* level 14.*/
+ new CompressionParameters(14, 15, 15, 6, 3, 64, Strategy.BTOPT), /* level 15.*/
+ new CompressionParameters(14, 15, 15, 6, 3, 96, Strategy.BTOPT), /* level 16.*/
+ new CompressionParameters(14, 15, 15, 6, 3, 128, Strategy.BTOPT), /* level 17.*/
+ new CompressionParameters(14, 15, 15, 8, 3, 256, Strategy.BTOPT), /* level 18.*/
+ new CompressionParameters(14, 15, 15, 6, 3, 256, Strategy.BTULTRA), /* level 19.*/
+ new CompressionParameters(14, 15, 15, 8, 3, 256, Strategy.BTULTRA), /* level 20.*/
+ new CompressionParameters(14, 15, 15, 9, 3, 256, Strategy.BTULTRA), /* level 21.*/
+ new CompressionParameters(14, 15, 15, 10, 3, 512, Strategy.BTULTRA) /* level 22.*/
+ }
+ };
+
+ public enum Strategy
+ {
+ // from faster to stronger
+
+ // YC: fast is a "single probe" strategy: at every position, we attempt to find a match, and give up if we don't find any. Similar to lz4.
+ FAST(BlockCompressor.UNSUPPORTED),
+
+ // YC: double_fast is a two-attempt strategy. The attempts are not symmetrical, by the way: one is "normal" while the second looks for "long matches". It was
+ // empirically found that this was the best trade-off. As can be guessed, it's slower than single-attempt, but finds more and better matches, so compresses better.
+ DFAST(new DoubleFastBlockCompressor()),
+
+ // YC: greedy uses a hash chain strategy. Every position is hashed, and all positions with same hash are chained. The algorithm goes through all candidates. There are
+ // diminishing returns in going deeper and deeper, so after a number of attempts (which can be selected), it abandons the search. The best (longest) match wins. If there is
+ // one winner, it's immediately encoded.
+ GREEDY(BlockCompressor.UNSUPPORTED),
+
+ // YC: lazy will do something similar to greedy, but will not encode immediately. It will search again at the next position, in case it finds something better.
+ // It's actually fairly common to have a small match at position p hiding a more worthy one at position p+1. This obviously increases the search workload. But the
+ // resulting compressed stream generally contains larger matches, hence compresses better.
+ LAZY(BlockCompressor.UNSUPPORTED),
+
+ // YC: lazy2 is the same as lazy, but deeper. It will search at P, P+1 and then P+2, in case it finds something even better. More workload. Better matches.
+ LAZY2(BlockCompressor.UNSUPPORTED),
+
+ // YC: btlazy2 is like lazy2, but trades the hash chain for a binary tree. This becomes necessary, as the number of attempts becomes prohibitively expensive. The binary tree
+ // complexity increases with log of search depth, instead of proportionally with search depth. So searching deeper in history quickly becomes the dominant operation.
+ // btlazy2 cuts into that. But it costs 2x more memory. It's also relatively "slow", even when trying to cut its parameters to make it perform faster. So it's really
+ // a high compression strategy.
+ BTLAZY2(BlockCompressor.UNSUPPORTED),
+
+ // YC: btopt is, well, a hell of a lot more complex.
+ // It will compute and find multiple matches per position, will dynamically compare every path from point P to P+N, reverse the graph to find cheapest path, iterate on
+ // batches of overlapping matches, etc. It's much more expensive. But the compression ratio is also much better.
+ BTOPT(BlockCompressor.UNSUPPORTED),
+
+ // YC: btultra is about the same, but doesn't cut as many corners (btopt "abandons" unpromising little gains more quickly). Slower, stronger.
+ BTULTRA(BlockCompressor.UNSUPPORTED);
+
+ private final BlockCompressor compressor;
+
+ Strategy(BlockCompressor compressor)
+ {
+ this.compressor = compressor;
+ }
+
+ public BlockCompressor getCompressor()
+ {
+ return compressor;
+ }
+ }
+
+ public CompressionParameters(int windowLog, int chainLog, int hashLog, int searchLog, int searchLength, int targetLength, Strategy strategy)
+ {
+ this.windowLog = windowLog;
+ this.chainLog = chainLog;
+ this.hashLog = hashLog;
+ this.searchLog = searchLog;
+ this.searchLength = searchLength;
+ this.targetLength = targetLength;
+ this.strategy = strategy;
+ }
+
+ public int getWindowLog()
+ {
+ return windowLog;
+ }
+
+ public int getSearchLength()
+ {
+ return searchLength;
+ }
+
+ public int getChainLog()
+ {
+ return chainLog;
+ }
+
+ public int getHashLog()
+ {
+ return hashLog;
+ }
+
+ public int getSearchLog()
+ {
+ return searchLog;
+ }
+
+ public int getTargetLength()
+ {
+ return targetLength;
+ }
+
+ public Strategy getStrategy()
+ {
+ return strategy;
+ }
+
+ public static CompressionParameters compute(int compressionLevel, int inputSize)
+ {
+ CompressionParameters defaultParameters = getDefaultParameters(compressionLevel, inputSize);
+
+ int targetLength = defaultParameters.targetLength;
+ int windowLog = defaultParameters.windowLog;
+ int chainLog = defaultParameters.chainLog;
+ int hashLog = defaultParameters.hashLog;
+ int searchLog = defaultParameters.searchLog;
+ int searchLength = defaultParameters.searchLength;
+ Strategy strategy = defaultParameters.strategy;
+
+ if (compressionLevel < 0) {
+ targetLength = -compressionLevel; // acceleration factor
+ }
+
+ // resize windowLog if input is small enough, to use less memory
+ long maxWindowResize = 1L << (MAX_WINDOW_LOG - 1);
+ if (inputSize < maxWindowResize) {
+ int hashSizeMin = 1 << MIN_HASH_LOG;
+ int inputSizeLog = (inputSize < hashSizeMin) ? MIN_HASH_LOG : highestBit(inputSize - 1) + 1;
+ if (windowLog > inputSizeLog) {
+ windowLog = inputSizeLog;
+ }
+ }
+
+ if (hashLog > windowLog + 1) {
+ hashLog = windowLog + 1;
+ }
+
+ int cycleLog = cycleLog(chainLog, strategy);
+ if (cycleLog > windowLog) {
+ chainLog -= (cycleLog - windowLog);
+ }
+
+ if (windowLog < MIN_WINDOW_LOG) {
+ windowLog = MIN_WINDOW_LOG;
+ }
+
+ return new CompressionParameters(windowLog, chainLog, hashLog, searchLog, searchLength, targetLength, strategy);
+ }
+
+ private static CompressionParameters getDefaultParameters(int compressionLevel, long estimatedInputSize)
+ {
+ int table = 0;
+
+ if (estimatedInputSize != 0) {
+ if (estimatedInputSize <= 16 * 1024) {
+ table = 3;
+ }
+ else if (estimatedInputSize <= 128 * 1024) {
+ table = 2;
+ }
+ else if (estimatedInputSize <= 256 * 1024) {
+ table = 1;
+ }
+ }
+
+ int row = DEFAULT_COMPRESSION_LEVEL;
+
+ if (compressionLevel != 0) { // TODO: figure out better way to indicate default compression level
+ row = Math.min(Math.max(0, compressionLevel), MAX_COMPRESSION_LEVEL);
+ }
+
+ return DEFAULT_COMPRESSION_PARAMETERS[table][row];
+ }
+}
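
A usage sketch (same-package access assumed): compute() picks a row from the table matching the estimated input size, then shrinks windowLog to fit small inputs.

    // Level 3 with a 50 kB input selects the "size <= 128 KB" table (windowLog 17), and
    // since the input fits in 2^16 bytes, windowLog should come back clamped to 16.
    CompressionParameters params = CompressionParameters.compute(3, 50_000);
    System.out.println(params.getWindowLog() + " " + params.getStrategy()); // expect: 16 DFAST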
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Constants.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Constants.java
new file mode 100644
index 00000000000..8777487b8c2
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Constants.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+class Constants
+{
+ public static final int SIZE_OF_BYTE = 1;
+ public static final int SIZE_OF_SHORT = 2;
+ public static final int SIZE_OF_INT = 4;
+ public static final int SIZE_OF_LONG = 8;
+
+ public static final int MAGIC_NUMBER = 0xFD2FB528;
+ public static final int MAGIC_SKIPFRAME_MIN = 0x184D2A50;
+ public static final int MAGIC_SKIPFRAME_MAX = 0x184D2A5F;
+
+ public static final int MIN_WINDOW_LOG = 10;
+ public static final int MAX_WINDOW_LOG = 31;
+
+ public static final int SIZE_OF_BLOCK_HEADER = 3;
+
+ public static final int MIN_SEQUENCES_SIZE = 1;
+ public static final int MIN_BLOCK_SIZE = 1 // block type tag
+ + 1 // min size of raw or rle length header
+ + MIN_SEQUENCES_SIZE;
+ public static final int MAX_BLOCK_SIZE = 128 * 1024;
+
+ public static final int REPEATED_OFFSET_COUNT = 3;
+
+ // block types
+ public static final int RAW_BLOCK = 0;
+ public static final int RLE_BLOCK = 1;
+ public static final int COMPRESSED_BLOCK = 2;
+
+ // sequence encoding types
+ public static final int SEQUENCE_ENCODING_BASIC = 0;
+ public static final int SEQUENCE_ENCODING_RLE = 1;
+ public static final int SEQUENCE_ENCODING_COMPRESSED = 2;
+ public static final int SEQUENCE_ENCODING_REPEAT = 3;
+
+ public static final int MAX_LITERALS_LENGTH_SYMBOL = 35;
+ public static final int MAX_MATCH_LENGTH_SYMBOL = 52;
+ public static final int MAX_OFFSET_CODE_SYMBOL = 31;
+ public static final int DEFAULT_MAX_OFFSET_CODE_SYMBOL = 28;
+
+ public static final int LITERAL_LENGTH_TABLE_LOG = 9;
+ public static final int MATCH_LENGTH_TABLE_LOG = 9;
+ public static final int OFFSET_TABLE_LOG = 8;
+
+ // literal block types
+ public static final int RAW_LITERALS_BLOCK = 0;
+ public static final int RLE_LITERALS_BLOCK = 1;
+ public static final int COMPRESSED_LITERALS_BLOCK = 2;
+ public static final int TREELESS_LITERALS_BLOCK = 3;
+
+ public static final int LONG_NUMBER_OF_SEQUENCES = 0x7F00;
+
+ public static final int[] LITERALS_LENGTH_BITS = {0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 2, 2, 3, 3,
+ 4, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16};
+
+ public static final int[] MATCH_LENGTH_BITS = {0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 2, 2, 3, 3,
+ 4, 4, 5, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16};
+
+ private Constants()
+ {
+ }
+}
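
For illustration, a hypothetical same-package helper showing how MAGIC_NUMBER appears on the wire: a zstd frame starts with the magic number stored little-endian, i.e. the bytes 28 B5 2F FD.

    static boolean startsWithZstdMagic(byte[] data)
    {
        if (data.length < Constants.SIZE_OF_INT) {
            return false;
        }
        int magic = (data[0] & 0xFF)
                | (data[1] & 0xFF) << 8
                | (data[2] & 0xFF) << 16
                | (data[3] & 0xFF) << 24;
        return magic == Constants.MAGIC_NUMBER;
    }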
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/DoubleFastBlockCompressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/DoubleFastBlockCompressor.java
new file mode 100644
index 00000000000..c2c6b4a936d
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/DoubleFastBlockCompressor.java
@@ -0,0 +1,261 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_INT;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_LONG;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+
+class DoubleFastBlockCompressor
+ implements BlockCompressor
+{
+ private static final int MIN_MATCH = 3;
+ private static final int SEARCH_STRENGTH = 8;
+ private static final int REP_MOVE = Constants.REPEATED_OFFSET_COUNT - 1;
+
+ public int compressBlock(Object inputBase, final long inputAddress, int inputSize, SequenceStore output, BlockCompressionState state, RepeatedOffsets offsets, CompressionParameters parameters)
+ {
+ int matchSearchLength = Math.max(parameters.getSearchLength(), 4);
+
+ // Offsets in hash tables are relative to baseAddress. Hash tables can be reused across calls to compressBlock as long as
+ // baseAddress is kept constant.
+ // We don't want to generate sequences that point before the current window limit, so we "filter" out all results from looking up in the hash tables
+ // beyond that point.
+ final long baseAddress = state.getBaseAddress();
+ final long windowBaseAddress = baseAddress + state.getWindowBaseOffset();
+
+ int[] longHashTable = state.hashTable;
+ int longHashBits = parameters.getHashLog();
+
+ int[] shortHashTable = state.chainTable;
+ int shortHashBits = parameters.getChainLog();
+
+ final long inputEnd = inputAddress + inputSize;
+ final long inputLimit = inputEnd - SIZE_OF_LONG; // We read a long at a time for computing the hashes
+
+ long input = inputAddress;
+ long anchor = inputAddress;
+
+ int offset1 = offsets.getOffset0();
+ int offset2 = offsets.getOffset1();
+
+ int savedOffset = 0;
+
+ if (input - windowBaseAddress == 0) {
+ input++;
+ }
+ int maxRep = (int) (input - windowBaseAddress);
+
+ if (offset2 > maxRep) {
+ savedOffset = offset2;
+ offset2 = 0;
+ }
+
+ if (offset1 > maxRep) {
+ savedOffset = offset1;
+ offset1 = 0;
+ }
+
+ while (input < inputLimit) { // < instead of <=, because repcode check at (input+1)
+ int shortHash = hash(inputBase, input, shortHashBits, matchSearchLength);
+ long shortMatchAddress = baseAddress + shortHashTable[shortHash];
+
+ int longHash = hash8(UNSAFE.getLong(inputBase, input), longHashBits);
+ long longMatchAddress = baseAddress + longHashTable[longHash];
+
+ // update hash tables
+ int current = (int) (input - baseAddress);
+ longHashTable[longHash] = current;
+ shortHashTable[shortHash] = current;
+
+ int matchLength;
+ int offset;
+
+ if (offset1 > 0 && UNSAFE.getInt(inputBase, input + 1 - offset1) == UNSAFE.getInt(inputBase, input + 1)) {
+ // found a repeated sequence of at least 4 bytes, separated by offset1
+ matchLength = count(inputBase, input + 1 + SIZE_OF_INT, inputEnd, input + 1 + SIZE_OF_INT - offset1) + SIZE_OF_INT;
+ input++;
+ output.storeSequence(inputBase, anchor, (int) (input - anchor), 0, matchLength - MIN_MATCH);
+ }
+ else {
+ // check prefix long match
+ if (longMatchAddress > windowBaseAddress && UNSAFE.getLong(inputBase, longMatchAddress) == UNSAFE.getLong(inputBase, input)) {
+ matchLength = count(inputBase, input + SIZE_OF_LONG, inputEnd, longMatchAddress + SIZE_OF_LONG) + SIZE_OF_LONG;
+ offset = (int) (input - longMatchAddress);
+ while (input > anchor && longMatchAddress > windowBaseAddress && UNSAFE.getByte(inputBase, input - 1) == UNSAFE.getByte(inputBase, longMatchAddress - 1)) {
+ input--;
+ longMatchAddress--;
+ matchLength++;
+ }
+ }
+ else {
+ // check prefix short match
+ if (shortMatchAddress > windowBaseAddress && UNSAFE.getInt(inputBase, shortMatchAddress) == UNSAFE.getInt(inputBase, input)) {
+ int nextOffsetHash = hash8(UNSAFE.getLong(inputBase, input + 1), longHashBits);
+ long nextOffsetMatchAddress = baseAddress + longHashTable[nextOffsetHash];
+ longHashTable[nextOffsetHash] = current + 1;
+
+ // check prefix long +1 match
+ if (nextOffsetMatchAddress > windowBaseAddress && UNSAFE.getLong(inputBase, nextOffsetMatchAddress) == UNSAFE.getLong(inputBase, input + 1)) {
+ matchLength = count(inputBase, input + 1 + SIZE_OF_LONG, inputEnd, nextOffsetMatchAddress + SIZE_OF_LONG) + SIZE_OF_LONG;
+ input++;
+ offset = (int) (input - nextOffsetMatchAddress);
+ while (input > anchor && nextOffsetMatchAddress > windowBaseAddress && UNSAFE.getByte(inputBase, input - 1) == UNSAFE.getByte(inputBase, nextOffsetMatchAddress - 1)) {
+ input--;
+ nextOffsetMatchAddress--;
+ matchLength++;
+ }
+ }
+ else {
+ // if no long +1 match, explore the short match we found
+ matchLength = count(inputBase, input + SIZE_OF_INT, inputEnd, shortMatchAddress + SIZE_OF_INT) + SIZE_OF_INT;
+ offset = (int) (input - shortMatchAddress);
+ while (input > anchor && shortMatchAddress > windowBaseAddress && UNSAFE.getByte(inputBase, input - 1) == UNSAFE.getByte(inputBase, shortMatchAddress - 1)) {
+ input--;
+ shortMatchAddress--;
+ matchLength++;
+ }
+ }
+ }
+ else {
+ input += ((input - anchor) >> SEARCH_STRENGTH) + 1;
+ continue;
+ }
+ }
+
+ offset2 = offset1;
+ offset1 = offset;
+
+ output.storeSequence(inputBase, anchor, (int) (input - anchor), offset + REP_MOVE, matchLength - MIN_MATCH);
+ }
+
+ input += matchLength;
+ anchor = input;
+
+ if (input <= inputLimit) {
+ // Fill Table
+ longHashTable[hash8(UNSAFE.getLong(inputBase, baseAddress + current + 2), longHashBits)] = current + 2;
+ shortHashTable[hash(inputBase, baseAddress + current + 2, shortHashBits, matchSearchLength)] = current + 2;
+
+ longHashTable[hash8(UNSAFE.getLong(inputBase, input - 2), longHashBits)] = (int) (input - 2 - baseAddress);
+ shortHashTable[hash(inputBase, input - 2, shortHashBits, matchSearchLength)] = (int) (input - 2 - baseAddress);
+
+ while (input <= inputLimit && offset2 > 0 && UNSAFE.getInt(inputBase, input) == UNSAFE.getInt(inputBase, input - offset2)) {
+ int repetitionLength = count(inputBase, input + SIZE_OF_INT, inputEnd, input + SIZE_OF_INT - offset2) + SIZE_OF_INT;
+
+ // swap offset2 <=> offset1
+ int temp = offset2;
+ offset2 = offset1;
+ offset1 = temp;
+
+ shortHashTable[hash(inputBase, input, shortHashBits, matchSearchLength)] = (int) (input - baseAddress);
+ longHashTable[hash8(UNSAFE.getLong(inputBase, input), longHashBits)] = (int) (input - baseAddress);
+
+ output.storeSequence(inputBase, anchor, 0, 0, repetitionLength - MIN_MATCH);
+
+ input += repetitionLength;
+ anchor = input;
+ }
+ }
+ }
+
+ // save reps for next block
+ offsets.saveOffset0(offset1 != 0 ? offset1 : savedOffset);
+ offsets.saveOffset1(offset2 != 0 ? offset2 : savedOffset);
+
+ // return the last literals size
+ return (int) (inputEnd - anchor);
+ }
+
+ // TODO: same as LZ4RawCompressor.count
+
+ /**
+ * matchAddress must be < inputAddress
+ */
+ public static int count(Object inputBase, final long inputAddress, final long inputLimit, final long matchAddress)
+ {
+ long input = inputAddress;
+ long match = matchAddress;
+
+ int remaining = (int) (inputLimit - inputAddress);
+
+ // first, compare long at a time
+ int count = 0;
+ while (count < remaining - (SIZE_OF_LONG - 1)) {
+ long diff = UNSAFE.getLong(inputBase, match) ^ UNSAFE.getLong(inputBase, input);
+ if (diff != 0) {
+ return count + (Long.numberOfTrailingZeros(diff) >> 3);
+ }
+
+ count += SIZE_OF_LONG;
+ input += SIZE_OF_LONG;
+ match += SIZE_OF_LONG;
+ }
+
+ while (count < remaining && UNSAFE.getByte(inputBase, match) == UNSAFE.getByte(inputBase, input)) {
+ count++;
+ input++;
+ match++;
+ }
+
+ return count;
+ }
+
+ private static int hash(Object inputBase, long inputAddress, int bits, int matchSearchLength)
+ {
+ switch (matchSearchLength) {
+ case 8:
+ return hash8(UNSAFE.getLong(inputBase, inputAddress), bits);
+ case 7:
+ return hash7(UNSAFE.getLong(inputBase, inputAddress), bits);
+ case 6:
+ return hash6(UNSAFE.getLong(inputBase, inputAddress), bits);
+ case 5:
+ return hash5(UNSAFE.getLong(inputBase, inputAddress), bits);
+ default:
+ return hash4(UNSAFE.getInt(inputBase, inputAddress), bits);
+ }
+ }
+
+ private static final int PRIME_4_BYTES = 0x9E3779B1;
+ private static final long PRIME_5_BYTES = 0xCF1BBCDCBBL;
+ private static final long PRIME_6_BYTES = 0xCF1BBCDCBF9BL;
+ private static final long PRIME_7_BYTES = 0xCF1BBCDCBFA563L;
+ private static final long PRIME_8_BYTES = 0xCF1BBCDCB7A56463L;
+
+ private static int hash4(int value, int bits)
+ {
+ return (value * PRIME_4_BYTES) >>> (Integer.SIZE - bits);
+ }
+
+ private static int hash5(long value, int bits)
+ {
+ return (int) (((value << (Long.SIZE - 40)) * PRIME_5_BYTES) >>> (Long.SIZE - bits));
+ }
+
+ private static int hash6(long value, int bits)
+ {
+ return (int) (((value << (Long.SIZE - 48)) * PRIME_6_BYTES) >>> (Long.SIZE - bits));
+ }
+
+ private static int hash7(long value, int bits)
+ {
+ return (int) (((value << (Long.SIZE - 56)) * PRIME_7_BYTES) >>> (Long.SIZE - bits));
+ }
+
+ private static int hash8(long value, int bits)
+ {
+ return (int) ((value * PRIME_8_BYTES) >>> (Long.SIZE - bits));
+ }
+}
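
An illustration of count() with hypothetical values (same-package access assumed): it measures the common prefix of two regions, comparing a long at a time and locating the first differing byte via trailing zeros, then falling back to byte-at-a-time comparison near the limit.

    byte[] data = "abcdefgh_abcdefgX".getBytes(java.nio.charset.StandardCharsets.US_ASCII);
    long base = sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
    // match region starts at index 0, input region at index 9 (so matchAddress < inputAddress holds)
    int common = DoubleFastBlockCompressor.count(data, base + 9, base + data.length, base);
    // "abcdefg" matches and 'X' differs from 'h', so common == 7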
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FiniteStateEntropy.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FiniteStateEntropy.java
new file mode 100644
index 00000000000..5703f0200a3
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FiniteStateEntropy.java
@@ -0,0 +1,551 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.BitInputStream.peekBits;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_INT;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_LONG;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.checkArgument;
+import static ai.vespa.airlift.zstd.Util.verify;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
+class FiniteStateEntropy
+{
+ public static final int MAX_SYMBOL = 255;
+ public static final int MAX_TABLE_LOG = 12;
+ public static final int MIN_TABLE_LOG = 5;
+
+ private static final int[] REST_TO_BEAT = new int[] {0, 473195, 504333, 520860, 550000, 700000, 750000, 830000};
+ private static final short UNASSIGNED = -2;
+
+ private FiniteStateEntropy()
+ {
+ }
+
+ public static int decompress(FiniteStateEntropy.Table table, final Object inputBase, final long inputAddress, final long inputLimit, byte[] outputBuffer)
+ {
+ final Object outputBase = outputBuffer;
+ final long outputAddress = ARRAY_BYTE_BASE_OFFSET;
+ final long outputLimit = outputAddress + outputBuffer.length;
+
+ long input = inputAddress;
+ long output = outputAddress;
+
+ // initialize bit stream
+ BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, input, inputLimit);
+ initializer.initialize();
+ int bitsConsumed = initializer.getBitsConsumed();
+ long currentAddress = initializer.getCurrentAddress();
+ long bits = initializer.getBits();
+
+ // initialize first FSE stream
+ int state1 = (int) peekBits(bitsConsumed, bits, table.log2Size);
+ bitsConsumed += table.log2Size;
+
+ BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
+ loader.load();
+ bits = loader.getBits();
+ bitsConsumed = loader.getBitsConsumed();
+ currentAddress = loader.getCurrentAddress();
+
+ // initialize second FSE stream
+ int state2 = (int) peekBits(bitsConsumed, bits, table.log2Size);
+ bitsConsumed += table.log2Size;
+
+ loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
+ loader.load();
+ bits = loader.getBits();
+ bitsConsumed = loader.getBitsConsumed();
+ currentAddress = loader.getCurrentAddress();
+
+ byte[] symbols = table.symbol;
+ byte[] numbersOfBits = table.numberOfBits;
+ int[] newStates = table.newState;
+
+ // decode 4 symbols per loop
+ while (output <= outputLimit - 4) {
+ int numberOfBits;
+
+ UNSAFE.putByte(outputBase, output, symbols[state1]);
+ numberOfBits = numbersOfBits[state1];
+ state1 = (int) (newStates[state1] + peekBits(bitsConsumed, bits, numberOfBits));
+ bitsConsumed += numberOfBits;
+
+ UNSAFE.putByte(outputBase, output + 1, symbols[state2]);
+ numberOfBits = numbersOfBits[state2];
+ state2 = (int) (newStates[state2] + peekBits(bitsConsumed, bits, numberOfBits));
+ bitsConsumed += numberOfBits;
+
+ UNSAFE.putByte(outputBase, output + 2, symbols[state1]);
+ numberOfBits = numbersOfBits[state1];
+ state1 = (int) (newStates[state1] + peekBits(bitsConsumed, bits, numberOfBits));
+ bitsConsumed += numberOfBits;
+
+ UNSAFE.putByte(outputBase, output + 3, symbols[state2]);
+ numberOfBits = numbersOfBits[state2];
+ state2 = (int) (newStates[state2] + peekBits(bitsConsumed, bits, numberOfBits));
+ bitsConsumed += numberOfBits;
+
+ output += SIZE_OF_INT;
+
+ loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
+ boolean done = loader.load();
+ bitsConsumed = loader.getBitsConsumed();
+ bits = loader.getBits();
+ currentAddress = loader.getCurrentAddress();
+ if (done) {
+ break;
+ }
+ }
+
+ while (true) {
+ verify(output <= outputLimit - 2, input, "Output buffer is too small");
+ UNSAFE.putByte(outputBase, output++, symbols[state1]);
+ int numberOfBits = numbersOfBits[state1];
+ state1 = (int) (newStates[state1] + peekBits(bitsConsumed, bits, numberOfBits));
+ bitsConsumed += numberOfBits;
+
+ loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
+ loader.load();
+ bitsConsumed = loader.getBitsConsumed();
+ bits = loader.getBits();
+ currentAddress = loader.getCurrentAddress();
+
+ if (loader.isOverflow()) {
+ UNSAFE.putByte(outputBase, output++, symbols[state2]);
+ break;
+ }
+
+ verify(output <= outputLimit - 2, input, "Output buffer is too small");
+ UNSAFE.putByte(outputBase, output++, symbols[state2]);
+ int numberOfBits1 = numbersOfBits[state2];
+ state2 = (int) (newStates[state2] + peekBits(bitsConsumed, bits, numberOfBits1));
+ bitsConsumed += numberOfBits1;
+
+ loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
+ loader.load();
+ bitsConsumed = loader.getBitsConsumed();
+ bits = loader.getBits();
+ currentAddress = loader.getCurrentAddress();
+
+ if (loader.isOverflow()) {
+ UNSAFE.putByte(outputBase, output++, symbols[state1]);
+ break;
+ }
+ }
+
+ return (int) (output - outputAddress);
+ }
+
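+ // Illustrative sketch of the encoding side (assumes normalized counts produced by
+ // normalizeCounts below):
+ //
+ //   FseCompressionTable table = FseCompressionTable.newInstance(normalizedCounts, maxSymbol, tableLog);
+ //   int written = FiniteStateEntropy.compress(outputBase, outputAddress, outputSize, inputBytes, inputBytes.length, table);
+ //
+ // Symbols are consumed from the end of the input towards the start: FSE encodes in
+ // reverse so that the decoder, reading the bit stream backwards, recovers symbols in
+ // forward order.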
+ public static int compress(Object outputBase, long outputAddress, int outputSize, byte[] input, int inputSize, FseCompressionTable table)
+ {
+ return compress(outputBase, outputAddress, outputSize, input, ARRAY_BYTE_BASE_OFFSET, inputSize, table);
+ }
+
+ public static int compress(Object outputBase, long outputAddress, int outputSize, Object inputBase, long inputAddress, int inputSize, FseCompressionTable table)
+ {
+ checkArgument(outputSize >= SIZE_OF_LONG, "Output buffer too small");
+
+ final long start = inputAddress;
+ final long inputLimit = start + inputSize;
+
+ long input = inputLimit;
+
+ if (inputSize <= 2) {
+ return 0;
+ }
+
+ BitOutputStream stream = new BitOutputStream(outputBase, outputAddress, outputSize);
+
+ int state1;
+ int state2;
+
+ if ((inputSize & 1) != 0) {
+ input--;
+ state1 = table.begin(UNSAFE.getByte(inputBase, input));
+
+ input--;
+ state2 = table.begin(UNSAFE.getByte(inputBase, input));
+
+ input--;
+ state1 = table.encode(stream, state1, UNSAFE.getByte(inputBase, input));
+
+ stream.flush();
+ }
+ else {
+ input--;
+ state2 = table.begin(UNSAFE.getByte(inputBase, input));
+
+ input--;
+ state1 = table.begin(UNSAFE.getByte(inputBase, input));
+ }
+
+ // bring the remaining symbol count down to a multiple of 4
+ inputSize -= 2;
+
+ if ((SIZE_OF_LONG * 8 > MAX_TABLE_LOG * 4 + 7) && (inputSize & 2) != 0) { /* test bit 2 */
+ input--;
+ state2 = table.encode(stream, state2, UNSAFE.getByte(inputBase, input));
+
+ input--;
+ state1 = table.encode(stream, state1, UNSAFE.getByte(inputBase, input));
+
+ stream.flush();
+ }
+
+ // encode 2 or 4 symbols per loop iteration
+ while (input > start) {
+ input--;
+ state2 = table.encode(stream, state2, UNSAFE.getByte(inputBase, input));
+
+ if (SIZE_OF_LONG * 8 < MAX_TABLE_LOG * 2 + 7) {
+ stream.flush();
+ }
+
+ input--;
+ state1 = table.encode(stream, state1, UNSAFE.getByte(inputBase, input));
+
+ if (SIZE_OF_LONG * 8 > MAX_TABLE_LOG * 4 + 7) {
+ input--;
+ state2 = table.encode(stream, state2, UNSAFE.getByte(inputBase, input));
+
+ input--;
+ state1 = table.encode(stream, state1, UNSAFE.getByte(inputBase, input));
+ }
+
+ stream.flush();
+ }
+
+ table.finish(stream, state2);
+ table.finish(stream, state1);
+
+ return stream.close();
+ }
+
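+ // Worked example (assuming Util.minTableLog implements the usual FSE bound
+ // min(highestBit(inputSize - 1) + 1, highestBit(maxSymbol) + 2)): for
+ // inputSize = 1000, maxSymbol = 255 and maxTableLog = 12, the size-based cap is
+ // highestBit(999) - 2 = 7, but minTableLog(1000, 255) = min(10, 9) = 9 raises the
+ // result back to 9, which already lies inside [MIN_TABLE_LOG, MAX_TABLE_LOG] = [5, 12].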
+ public static int optimalTableLog(int maxTableLog, int inputSize, int maxSymbol)
+ {
+ if (inputSize <= 1) {
+ throw new IllegalArgumentException(); // not supported. Use RLE instead
+ }
+
+ int result = maxTableLog;
+
+ result = Math.min(result, Util.highestBit((inputSize - 1)) - 2); // we may be able to reduce accuracy if input is small
+
+ // Need a minimum to safely represent all symbol values
+ result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));
+
+ result = Math.max(result, MIN_TABLE_LOG);
+ result = Math.min(result, MAX_TABLE_LOG);
+
+ return result;
+ }
+
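+ // The normalization below works in 62-bit fixed point: step = 2^62 / total, and a
+ // symbol's scaled count is (count * step) >> (62 - tableLog). For example, with
+ // total = 1000 and tableLog = 9, a symbol seen 100 times gets roughly
+ // 100 * 512 / 1000 = 51 table cells before the REST_TO_BEAT rounding correction.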
+ public static int normalizeCounts(short[] normalizedCounts, int tableLog, int[] counts, int total, int maxSymbol)
+ {
+ checkArgument(tableLog >= MIN_TABLE_LOG, "Unsupported FSE table size");
+ checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table size too large");
+ checkArgument(tableLog >= Util.minTableLog(total, maxSymbol), "FSE table size too small");
+
+ long scale = 62 - tableLog;
+ long step = (1L << 62) / total;
+ long vstep = 1L << (scale - 20);
+
+ int stillToDistribute = 1 << tableLog;
+
+ int largest = 0;
+ short largestProbability = 0;
+ int lowThreshold = total >>> tableLog;
+
+ for (int symbol = 0; symbol <= maxSymbol; symbol++) {
+ if (counts[symbol] == total) {
+ throw new IllegalArgumentException(); // TODO: should have been RLE-compressed by upper layers
+ }
+ if (counts[symbol] == 0) {
+ normalizedCounts[symbol] = 0;
+ continue;
+ }
+ if (counts[symbol] <= lowThreshold) {
+ normalizedCounts[symbol] = -1;
+ stillToDistribute--;
+ }
+ else {
+ short probability = (short) ((counts[symbol] * step) >>> scale);
+ if (probability < 8) {
+ long restToBeat = vstep * REST_TO_BEAT[probability];
+ long delta = counts[symbol] * step - (((long) probability) << scale);
+ if (delta > restToBeat) {
+ probability++;
+ }
+ }
+ if (probability > largestProbability) {
+ largestProbability = probability;
+ largest = symbol;
+ }
+ normalizedCounts[symbol] = probability;
+ stillToDistribute -= probability;
+ }
+ }
+
+ if (-stillToDistribute >= (normalizedCounts[largest] >>> 1)) {
+ // corner case. Need another normalization method
+ // TODO size_t const errorCode = FSE_normalizeM2(normalizedCounter, tableLog, count, total, maxSymbolValue);
+ normalizeCounts2(normalizedCounts, tableLog, counts, total, maxSymbol);
+ }
+ else {
+ normalizedCounts[largest] += (short) stillToDistribute;
+ }
+
+ return tableLog;
+ }
+
+ private static int normalizeCounts2(short[] normalizedCounts, int tableLog, int[] counts, int total, int maxSymbol)
+ {
+ int distributed = 0;
+
+ int lowThreshold = total >>> tableLog; // minimum count below which frequency in the normalized table is "too small" (~ < 1)
+ int lowOne = (total * 3) >>> (tableLog + 1); // 1.5 * lowThreshold. If count in (lowThreshold, lowOne] => assign frequency 1
+
+ for (int i = 0; i <= maxSymbol; i++) {
+ if (counts[i] == 0) {
+ normalizedCounts[i] = 0;
+ }
+ else if (counts[i] <= lowThreshold) {
+ normalizedCounts[i] = -1;
+ distributed++;
+ total -= counts[i];
+ }
+ else if (counts[i] <= lowOne) {
+ normalizedCounts[i] = 1;
+ distributed++;
+ total -= counts[i];
+ }
+ else {
+ normalizedCounts[i] = UNASSIGNED;
+ }
+ }
+
+ int normalizationFactor = 1 << tableLog;
+ int toDistribute = normalizationFactor - distributed;
+
+ if ((total / toDistribute) > lowOne) {
+ /* risk of rounding to zero */
+ lowOne = ((total * 3) / (toDistribute * 2));
+ for (int i = 0; i <= maxSymbol; i++) {
+ if ((normalizedCounts[i] == UNASSIGNED) && (counts[i] <= lowOne)) {
+ normalizedCounts[i] = 1;
+ distributed++;
+ total -= counts[i];
+ }
+ }
+ toDistribute = normalizationFactor - distributed;
+ }
+
+ if (distributed == maxSymbol + 1) {
+ // all values are pretty poor;
+ // probably incompressible data (should have already been detected);
+ // find max, then give all remaining points to max
+ int maxValue = 0;
+ int maxCount = 0;
+ for (int i = 0; i <= maxSymbol; i++) {
+ if (counts[i] > maxCount) {
+ maxValue = i;
+ maxCount = counts[i];
+ }
+ }
+ normalizedCounts[maxValue] += (short) toDistribute;
+ return 0;
+ }
+
+ if (total == 0) {
+ // all of the symbols fell below the lowOne or lowThreshold cutoffs
+ for (int i = 0; toDistribute > 0; i = (i + 1) % (maxSymbol + 1)) {
+ if (normalizedCounts[i] > 0) {
+ toDistribute--;
+ normalizedCounts[i]++;
+ }
+ }
+ return 0;
+ }
+
+ // TODO: simplify/document this code
+ long vStepLog = 62 - tableLog;
+ long mid = (1L << (vStepLog - 1)) - 1;
+ long rStep = (((1L << vStepLog) * toDistribute) + mid) / total; /* scale on remaining */
+ long tmpTotal = mid;
+ for (int i = 0; i <= maxSymbol; i++) {
+ if (normalizedCounts[i] == UNASSIGNED) {
+ long end = tmpTotal + (counts[i] * rStep);
+ int sStart = (int) (tmpTotal >>> vStepLog);
+ int sEnd = (int) (end >>> vStepLog);
+ int weight = sEnd - sStart;
+
+ if (weight < 1) {
+ throw new AssertionError();
+ }
+ normalizedCounts[i] = (short) weight;
+ tmpTotal = end;
+ }
+ }
+
+ return 0;
+ }
+
+ public static int writeNormalizedCounts(Object outputBase, long outputAddress, int outputSize, short[] normalizedCounts, int maxSymbol, int tableLog)
+ {
+ checkArgument(tableLog <= MAX_TABLE_LOG, "FSE table too large");
+ checkArgument(tableLog >= MIN_TABLE_LOG, "FSE table too small");
+
+ long output = outputAddress;
+ long outputLimit = outputAddress + outputSize;
+
+ int tableSize = 1 << tableLog;
+
+ int bitCount = 0;
+
+ // encode table size
+ int bitStream = (tableLog - MIN_TABLE_LOG);
+ bitCount += 4;
+
+ int remaining = tableSize + 1; // +1 for extra accuracy
+ int threshold = tableSize;
+ int tableBitCount = tableLog + 1;
+
+ int symbol = 0;
+
+ boolean previousIs0 = false;
+ while (remaining > 1) {
+ if (previousIs0) {
+ // From RFC 8478, section 4.1.1:
+ // When a symbol has a probability of zero, it is followed by a 2-bit
+ // repeat flag. This repeat flag tells how many probabilities of zeroes
+ // follow the current one. It provides a number ranging from 0 to 3.
+ // If it is a 3, another 2-bit repeat flag follows, and so on.
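+ // For example, a run of eight additional zero-probability symbols is signalled
+ // by the repeat flags 3, 3, 2 (two bits each; 3 + 3 + 2 = 8).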
+ int start = symbol;
+
+ // find run of symbols with count 0
+ while (normalizedCounts[symbol] == 0) {
+ symbol++;
+ }
+
+ // encode in batches of 8 repeat flags in one shot (representing 24 symbols total)
+ while (symbol >= start + 24) {
+ start += 24;
+ bitStream |= (0b11_11_11_11_11_11_11_11 << bitCount);
+ checkArgument(output + SIZE_OF_SHORT <= outputLimit, "Output buffer too small");
+
+ UNSAFE.putShort(outputBase, output, (short) bitStream);
+ output += SIZE_OF_SHORT;
+
+ // flush now, so no need to increase bitCount by 16
+ bitStream >>>= Short.SIZE;
+ }
+
+ // encode remaining in batches of 3 symbols
+ while (symbol >= start + 3) {
+ start += 3;
+ bitStream |= 0b11 << bitCount;
+ bitCount += 2;
+ }
+
+ // encode tail
+ bitStream |= (symbol - start) << bitCount;
+ bitCount += 2;
+
+ // flush bitstream if necessary
+ if (bitCount > 16) {
+ checkArgument(output + SIZE_OF_SHORT <= outputLimit, "Output buffer too small");
+
+ UNSAFE.putShort(outputBase, output, (short) bitStream);
+ output += SIZE_OF_SHORT;
+
+ bitStream >>>= Short.SIZE;
+ bitCount -= Short.SIZE;
+ }
+ }
+
+ int count = normalizedCounts[symbol++];
+ int max = (2 * threshold - 1) - remaining;
+ remaining -= count < 0 ? -count : count;
+ count++; /* +1 for extra accuracy */
+ if (count >= threshold) {
+ count += max;
+ }
+ bitStream |= count << bitCount;
+ bitCount += tableBitCount;
+ bitCount -= (count < max ? 1 : 0);
+ previousIs0 = (count == 1);
+
+ if (remaining < 1) {
+ throw new AssertionError();
+ }
+
+ while (remaining < threshold) {
+ tableBitCount--;
+ threshold >>= 1;
+ }
+
+ // flush bitstream if necessary
+ if (bitCount > 16) {
+ checkArgument(output + SIZE_OF_SHORT <= outputLimit, "Output buffer too small");
+
+ UNSAFE.putShort(outputBase, output, (short) bitStream);
+ output += SIZE_OF_SHORT;
+
+ bitStream >>>= Short.SIZE;
+ bitCount -= Short.SIZE;
+ }
+ }
+
+ // flush remaining bitstream
+ checkArgument(output + SIZE_OF_SHORT <= outputLimit, "Output buffer too small");
+ UNSAFE.putShort(outputBase, output, (short) bitStream);
+ output += (bitCount + 7) / 8;
+
+ checkArgument(symbol <= maxSymbol + 1, "Emitted counts for more symbols than maxSymbol allows"); // TODO: improve error reporting
+
+ return (int) (output - outputAddress);
+ }
+
+ public static final class Table
+ {
+ int log2Size;
+ final int[] newState;
+ final byte[] symbol;
+ final byte[] numberOfBits;
+
+ public Table(int log2Capacity)
+ {
+ int capacity = 1 << log2Capacity;
+ newState = new int[capacity];
+ symbol = new byte[capacity];
+ numberOfBits = new byte[capacity];
+ }
+
+ public Table(int log2Size, int[] newState, byte[] symbol, byte[] numberOfBits)
+ {
+ int size = 1 << log2Size;
+ if (newState.length != size || symbol.length != size || numberOfBits.length != size) {
+ throw new IllegalArgumentException("Expected arrays to match provided size");
+ }
+
+ this.log2Size = log2Size;
+ this.newState = newState;
+ this.symbol = symbol;
+ this.numberOfBits = numberOfBits;
+ }
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FrameHeader.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FrameHeader.java
new file mode 100644
index 00000000000..6495939cb38
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FrameHeader.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.util.Objects;
+import java.util.StringJoiner;
+
+class FrameHeader
+{
+ final long headerSize;
+ final int windowSize;
+ final long contentSize;
+ final long dictionaryId;
+ final boolean hasChecksum;
+
+ public FrameHeader(long headerSize, int windowSize, long contentSize, long dictionaryId, boolean hasChecksum)
+ {
+ this.headerSize = headerSize;
+ this.windowSize = windowSize;
+ this.contentSize = contentSize;
+ this.dictionaryId = dictionaryId;
+ this.hasChecksum = hasChecksum;
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ FrameHeader that = (FrameHeader) o;
+ return headerSize == that.headerSize &&
+ windowSize == that.windowSize &&
+ contentSize == that.contentSize &&
+ dictionaryId == that.dictionaryId &&
+ hasChecksum == that.hasChecksum;
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(headerSize, windowSize, contentSize, dictionaryId, hasChecksum);
+ }
+
+ @Override
+ public String toString()
+ {
+ return new StringJoiner(", ", FrameHeader.class.getSimpleName() + "[", "]")
+ .add("headerSize=" + headerSize)
+ .add("windowSize=" + windowSize)
+ .add("contentSize=" + contentSize)
+ .add("dictionaryId=" + dictionaryId)
+ .add("hasChecksum=" + hasChecksum)
+ .toString();
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FseCompressionTable.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FseCompressionTable.java
new file mode 100644
index 00000000000..e360c5ea5a6
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FseCompressionTable.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.FiniteStateEntropy.MAX_SYMBOL;
+
+class FseCompressionTable
+{
+ private final short[] nextState;
+ private final int[] deltaNumberOfBits;
+ private final int[] deltaFindState;
+
+ private int log2Size;
+
+ public FseCompressionTable(int maxTableLog, int maxSymbol)
+ {
+ nextState = new short[1 << maxTableLog];
+ deltaNumberOfBits = new int[maxSymbol + 1];
+ deltaFindState = new int[maxSymbol + 1];
+ }
+
+ public static FseCompressionTable newInstance(short[] normalizedCounts, int maxSymbol, int tableLog)
+ {
+ FseCompressionTable result = new FseCompressionTable(tableLog, maxSymbol);
+ result.initialize(normalizedCounts, maxSymbol, tableLog);
+
+ return result;
+ }
+
+ public void initializeRleTable(int symbol)
+ {
+ log2Size = 0;
+
+ nextState[0] = 0;
+ nextState[1] = 0;
+
+ deltaFindState[symbol] = 0;
+ deltaNumberOfBits[symbol] = 0;
+ }
+
+ public void initialize(short[] normalizedCounts, int maxSymbol, int tableLog)
+ {
+ int tableSize = 1 << tableLog;
+
+ byte[] table = new byte[tableSize]; // TODO: allocate in workspace
+ int highThreshold = tableSize - 1;
+
+ // TODO: make sure FseCompressionTable has enough size
+ log2Size = tableLog;
+
+ // For explanations on how to distribute symbol values over the table:
+ // http://fastcompression.blogspot.fr/2014/02/fse-distributing-symbol-values.html
+
+ // symbol start positions
+ int[] cumulative = new int[MAX_SYMBOL + 2]; // TODO: allocate in workspace
+ cumulative[0] = 0;
+ for (int i = 1; i <= maxSymbol + 1; i++) {
+ if (normalizedCounts[i - 1] == -1) { // Low probability symbol
+ cumulative[i] = cumulative[i - 1] + 1;
+ table[highThreshold--] = (byte) (i - 1);
+ }
+ else {
+ cumulative[i] = cumulative[i - 1] + normalizedCounts[i - 1];
+ }
+ }
+ cumulative[maxSymbol + 1] = tableSize + 1;
+
+ // Spread symbols
+ int position = spreadSymbols(normalizedCounts, maxSymbol, tableSize, highThreshold, table);
+
+ if (position != 0) {
+ throw new AssertionError("Spread symbols failed");
+ }
+
+ // Build table
+ for (int i = 0; i < tableSize; i++) {
+ byte symbol = table[i];
+ nextState[cumulative[symbol]++] = (short) (tableSize + i); /* TableU16: sorted by symbol order; gives next state value */
+ }
+
+ // Build symbol transformation table
+ int total = 0;
+ for (int symbol = 0; symbol <= maxSymbol; symbol++) {
+ switch (normalizedCounts[symbol]) {
+ case 0:
+ deltaNumberOfBits[symbol] = ((tableLog + 1) << 16) - tableSize;
+ break;
+ case -1:
+ case 1:
+ deltaNumberOfBits[symbol] = (tableLog << 16) - tableSize;
+ deltaFindState[symbol] = total - 1;
+ total++;
+ break;
+ default:
+ int maxBitsOut = tableLog - Util.highestBit(normalizedCounts[symbol] - 1);
+ int minStatePlus = normalizedCounts[symbol] << maxBitsOut;
+ deltaNumberOfBits[symbol] = (maxBitsOut << 16) - minStatePlus;
+ deltaFindState[symbol] = total - normalizedCounts[symbol];
+ total += normalizedCounts[symbol];
+ break;
+ }
+ }
+ }
+
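+ // The deltas packed above drive the encoder's state machine: deltaNumberOfBits
+ // holds (nbBits << 16) - minStatePlus, so in encode() the expression
+ // (state + deltaNumberOfBits[symbol]) >>> 16 yields the number of low-order state
+ // bits to flush for this symbol (one fewer for small states), and deltaFindState
+ // translates the truncated state into an index of the symbol-sorted nextState table.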
+ public int begin(byte symbol)
+ {
+ int outputBits = (deltaNumberOfBits[symbol] + (1 << 15)) >>> 16;
+ int base = ((outputBits << 16) - deltaNumberOfBits[symbol]) >>> outputBits;
+ return nextState[base + deltaFindState[symbol]];
+ }
+
+ public int encode(BitOutputStream stream, int state, int symbol)
+ {
+ int outputBits = (state + deltaNumberOfBits[symbol]) >>> 16;
+ stream.addBits(state, outputBits);
+ return nextState[(state >>> outputBits) + deltaFindState[symbol]];
+ }
+
+ public void finish(BitOutputStream stream, int state)
+ {
+ stream.addBits(state, log2Size);
+ stream.flush();
+ }
+
+ private static int calculateStep(int tableSize)
+ {
+ return (tableSize >>> 1) + (tableSize >>> 3) + 3;
+ }
+
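+ // The step computed above is chosen so that repeatedly adding it modulo tableSize
+ // visits every cell exactly once before returning to the start (it is odd, hence
+ // coprime with the power-of-two table size). For example, tableSize = 32 gives
+ // step = 16 + 4 + 3 = 23.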
+ public static int spreadSymbols(short[] normalizedCounters, int maxSymbolValue, int tableSize, int highThreshold, byte[] symbols)
+ {
+ int mask = tableSize - 1;
+ int step = calculateStep(tableSize);
+
+ int position = 0;
+ for (byte symbol = 0; symbol <= maxSymbolValue; symbol++) {
+ for (int i = 0; i < normalizedCounters[symbol]; i++) {
+ symbols[position] = symbol;
+ do {
+ position = (position + step) & mask;
+ }
+ while (position > highThreshold);
+ }
+ }
+ return position;
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FseTableReader.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FseTableReader.java
new file mode 100644
index 00000000000..0b8182dbc42
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/FseTableReader.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.FiniteStateEntropy.MAX_SYMBOL;
+import static ai.vespa.airlift.zstd.FiniteStateEntropy.MIN_TABLE_LOG;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.highestBit;
+import static ai.vespa.airlift.zstd.Util.verify;
+
+class FseTableReader
+{
+ private final short[] nextSymbol = new short[MAX_SYMBOL + 1];
+ private final short[] normalizedCounters = new short[MAX_SYMBOL + 1];
+
+ public int readFseTable(FiniteStateEntropy.Table table, Object inputBase, long inputAddress, long inputLimit, int maxSymbol, int maxTableLog)
+ {
+ // read table headers
+ long input = inputAddress;
+ verify(inputLimit - inputAddress >= 4, input, "Not enough input bytes");
+
+ int threshold;
+ int symbolNumber = 0;
+ boolean previousIsZero = false;
+
+ int bitStream = UNSAFE.getInt(inputBase, input);
+
+ int tableLog = (bitStream & 0xF) + MIN_TABLE_LOG;
+
+ int numberOfBits = tableLog + 1;
+ bitStream >>>= 4;
+ int bitCount = 4;
+
+ verify(tableLog <= maxTableLog, input, "FSE table size exceeds maximum allowed size");
+
+ int remaining = (1 << tableLog) + 1;
+ threshold = 1 << tableLog;
+
+ while (remaining > 1 && symbolNumber <= maxSymbol) {
+ if (previousIsZero) {
+ int n0 = symbolNumber;
+ while ((bitStream & 0xFFFF) == 0xFFFF) {
+ n0 += 24;
+ if (input < inputLimit - 5) {
+ input += 2;
+ bitStream = (UNSAFE.getInt(inputBase, input) >>> bitCount);
+ }
+ else {
+ // end of bit stream
+ bitStream >>>= 16;
+ bitCount += 16;
+ }
+ }
+ while ((bitStream & 3) == 3) {
+ n0 += 3;
+ bitStream >>>= 2;
+ bitCount += 2;
+ }
+ n0 += bitStream & 3;
+ bitCount += 2;
+
+ verify(n0 <= maxSymbol, input, "Symbol larger than max value");
+
+ while (symbolNumber < n0) {
+ normalizedCounters[symbolNumber++] = 0;
+ }
+ if ((input <= inputLimit - 7) || (input + (bitCount >>> 3) <= inputLimit - 4)) {
+ input += bitCount >>> 3;
+ bitCount &= 7;
+ bitStream = UNSAFE.getInt(inputBase, input) >>> bitCount;
+ }
+ else {
+ bitStream >>>= 2;
+ }
+ }
+
+ short max = (short) ((2 * threshold - 1) - remaining);
+ short count;
+
+ if ((bitStream & (threshold - 1)) < max) {
+ count = (short) (bitStream & (threshold - 1));
+ bitCount += numberOfBits - 1;
+ }
+ else {
+ count = (short) (bitStream & (2 * threshold - 1));
+ if (count >= threshold) {
+ count -= max;
+ }
+ bitCount += numberOfBits;
+ }
+ count--; // extra accuracy
+
+ remaining -= Math.abs(count);
+ normalizedCounters[symbolNumber++] = count;
+ previousIsZero = count == 0;
+ while (remaining < threshold) {
+ numberOfBits--;
+ threshold >>>= 1;
+ }
+
+ if ((input <= inputLimit - 7) || (input + (bitCount >> 3) <= inputLimit - 4)) {
+ input += bitCount >>> 3;
+ bitCount &= 7;
+ }
+ else {
+ bitCount -= (int) (8 * (inputLimit - 4 - input));
+ input = inputLimit - 4;
+ }
+ bitStream = UNSAFE.getInt(inputBase, input) >>> (bitCount & 31);
+ }
+
+ verify(remaining == 1 && bitCount <= 32, input, "Input is corrupted");
+
+ maxSymbol = symbolNumber - 1;
+ verify(maxSymbol <= MAX_SYMBOL, input, "Max symbol value too large (too many symbols for FSE)");
+
+ input += (bitCount + 7) >> 3;
+
+ // populate decoding table
+ int symbolCount = maxSymbol + 1;
+ int tableSize = 1 << tableLog;
+ int highThreshold = tableSize - 1;
+
+ table.log2Size = tableLog;
+
+ for (byte symbol = 0; symbol < symbolCount; symbol++) {
+ if (normalizedCounters[symbol] == -1) {
+ table.symbol[highThreshold--] = symbol;
+ nextSymbol[symbol] = 1;
+ }
+ else {
+ nextSymbol[symbol] = normalizedCounters[symbol];
+ }
+ }
+
+ int position = FseCompressionTable.spreadSymbols(normalizedCounters, maxSymbol, tableSize, highThreshold, table.symbol);
+
+ // position must reach all cells once, otherwise normalizedCounter is incorrect
+ verify(position == 0, input, "Input is corrupted");
+
+ for (int i = 0; i < tableSize; i++) {
+ byte symbol = table.symbol[i];
+ short nextState = nextSymbol[symbol]++;
+ table.numberOfBits[i] = (byte) (tableLog - highestBit(nextState));
+ table.newState[i] = (short) ((nextState << table.numberOfBits[i]) - tableSize);
+ }
+
+ return (int) (input - inputAddress);
+ }
+
+ public static void initializeRleTable(FiniteStateEntropy.Table table, byte value)
+ {
+ table.log2Size = 0;
+ table.symbol[0] = value;
+ table.newState[0] = 0;
+ table.numberOfBits[0] = 0;
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Histogram.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Histogram.java
new file mode 100644
index 00000000000..169de8b2cfa
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Histogram.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.util.Arrays;
+
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
+class Histogram
+{
+ private Histogram()
+ {
+ }
+
+ // TODO: use a parallel counting heuristic for large inputs
+ private static void count(Object inputBase, long inputAddress, int inputSize, int[] counts)
+ {
+ long input = inputAddress;
+
+ Arrays.fill(counts, 0);
+
+ for (int i = 0; i < inputSize; i++) {
+ int symbol = UNSAFE.getByte(inputBase, input) & 0xFF;
+ input++;
+ counts[symbol]++;
+ }
+ }
+
+ public static int findLargestCount(int[] counts, int maxSymbol)
+ {
+ int max = 0;
+ for (int i = 0; i <= maxSymbol; i++) {
+ if (counts[i] > max) {
+ max = counts[i];
+ }
+ }
+
+ return max;
+ }
+
+ public static int findMaxSymbol(int[] counts, int maxSymbol)
+ {
+ while (counts[maxSymbol] == 0) {
+ maxSymbol--;
+ }
+ return maxSymbol;
+ }
+
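+ // Typical use (illustrative; the variable names are hypothetical): histogram a byte
+ // buffer before building an entropy table:
+ //
+ //   int[] counts = new int[Huffman.MAX_SYMBOL_COUNT];
+ //   Histogram.count(data, data.length, counts);
+ //   int maxSymbol = Histogram.findMaxSymbol(counts, Huffman.MAX_SYMBOL);
+ //   int largestCount = Histogram.findLargestCount(counts, maxSymbol);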
+ public static void count(byte[] input, int length, int[] counts)
+ {
+ count(input, ARRAY_BYTE_BASE_OFFSET, length, counts);
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Huffman.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Huffman.java
new file mode 100644
index 00000000000..c8ed6a1f5f0
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Huffman.java
@@ -0,0 +1,323 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.util.Arrays;
+
+import static ai.vespa.airlift.zstd.BitInputStream.isEndOfStream;
+import static ai.vespa.airlift.zstd.BitInputStream.peekBitsFast;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_INT;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.isPowerOf2;
+import static ai.vespa.airlift.zstd.Util.verify;
+
+class Huffman
+{
+ public static final int MAX_SYMBOL = 255;
+ public static final int MAX_SYMBOL_COUNT = MAX_SYMBOL + 1;
+
+ public static final int MAX_TABLE_LOG = 12;
+ public static final int MIN_TABLE_LOG = 5;
+ public static final int MAX_FSE_TABLE_LOG = 6;
+
+ // stats
+ private final byte[] weights = new byte[MAX_SYMBOL + 1];
+ private final int[] ranks = new int[MAX_TABLE_LOG + 1];
+
+ // table
+ private int tableLog = -1;
+ private final byte[] symbols = new byte[1 << MAX_TABLE_LOG];
+ private final byte[] numbersOfBits = new byte[1 << MAX_TABLE_LOG];
+
+ private final FseTableReader reader = new FseTableReader();
+ private final FiniteStateEntropy.Table fseTable = new FiniteStateEntropy.Table(MAX_FSE_TABLE_LOG);
+
+ public boolean isLoaded()
+ {
+ return tableLog != -1;
+ }
+
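+ // The table description starts with a single header byte (RFC 8478 section 4.2.1):
+ // a value >= 128 means (headerByte - 127) weights follow uncompressed, two 4-bit
+ // weights per byte; a value < 128 gives the size in bytes of an FSE-compressed weight
+ // stream. The weight of the last symbol is never transmitted; it is reconstructed
+ // below from the remainder needed to round the total weight up to a power of two.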
+ public int readTable(final Object inputBase, final long inputAddress, final int size)
+ {
+ Arrays.fill(ranks, 0);
+ long input = inputAddress;
+
+ // read table header
+ verify(size > 0, input, "Not enough input bytes");
+ int inputSize = UNSAFE.getByte(inputBase, input++) & 0xFF;
+
+ int outputSize;
+ if (inputSize >= 128) {
+ outputSize = inputSize - 127;
+ inputSize = ((outputSize + 1) / 2);
+
+ verify(inputSize + 1 <= size, input, "Not enough input bytes");
+ verify(outputSize <= MAX_SYMBOL + 1, input, "Input is corrupted");
+
+ for (int i = 0; i < outputSize; i += 2) {
+ int value = UNSAFE.getByte(inputBase, input + i / 2) & 0xFF;
+ weights[i] = (byte) (value >>> 4);
+ weights[i + 1] = (byte) (value & 0b1111);
+ }
+ }
+ else {
+ verify(inputSize + 1 <= size, input, "Not enough input bytes");
+
+ long inputLimit = input + inputSize;
+ input += reader.readFseTable(fseTable, inputBase, input, inputLimit, FiniteStateEntropy.MAX_SYMBOL, MAX_FSE_TABLE_LOG);
+ outputSize = FiniteStateEntropy.decompress(fseTable, inputBase, input, inputLimit, weights);
+ }
+
+ int totalWeight = 0;
+ for (int i = 0; i < outputSize; i++) {
+ ranks[weights[i]]++;
+ totalWeight += (1 << weights[i]) >> 1; // TODO: same as 1 << (weights[i] - 1)?
+ }
+ verify(totalWeight != 0, input, "Input is corrupted");
+
+ tableLog = Util.highestBit(totalWeight) + 1;
+ verify(tableLog <= MAX_TABLE_LOG, input, "Input is corrupted");
+
+ int total = 1 << tableLog;
+ int rest = total - totalWeight;
+ verify(isPowerOf2(rest), input, "Input is corrupted");
+
+ int lastWeight = Util.highestBit(rest) + 1;
+
+ weights[outputSize] = (byte) lastWeight;
+ ranks[lastWeight]++;
+
+ int numberOfSymbols = outputSize + 1;
+
+ // populate table
+ int nextRankStart = 0;
+ for (int i = 1; i < tableLog + 1; ++i) {
+ int current = nextRankStart;
+ nextRankStart += ranks[i] << (i - 1);
+ ranks[i] = current;
+ }
+
+ for (int n = 0; n < numberOfSymbols; n++) {
+ int weight = weights[n];
+ int length = (1 << weight) >> 1; // TODO: 1 << (weight - 1) ??
+
+ byte symbol = (byte) n;
+ byte numberOfBits = (byte) (tableLog + 1 - weight);
+ for (int i = ranks[weight]; i < ranks[weight] + length; i++) {
+ symbols[i] = symbol;
+ numbersOfBits[i] = numberOfBits;
+ }
+ ranks[weight] += length;
+ }
+
+ verify(ranks[1] >= 2 && (ranks[1] & 1) == 0, input, "Input is corrupted");
+
+ return inputSize + 1;
+ }
+
+ public void decodeSingleStream(final Object inputBase, final long inputAddress, final long inputLimit, final Object outputBase, final long outputAddress, final long outputLimit)
+ {
+ BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
+ initializer.initialize();
+
+ long bits = initializer.getBits();
+ int bitsConsumed = initializer.getBitsConsumed();
+ long currentAddress = initializer.getCurrentAddress();
+
+ int tableLog = this.tableLog;
+ byte[] numbersOfBits = this.numbersOfBits;
+ byte[] symbols = this.symbols;
+
+ // 4 symbols at a time
+ long output = outputAddress;
+ long fastOutputLimit = outputLimit - 4;
+ while (output < fastOutputLimit) {
+ BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits, bitsConsumed);
+ boolean done = loader.load();
+ bits = loader.getBits();
+ bitsConsumed = loader.getBitsConsumed();
+ currentAddress = loader.getCurrentAddress();
+ if (done) {
+ break;
+ }
+
+ bitsConsumed = decodeSymbol(outputBase, output, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ bitsConsumed = decodeSymbol(outputBase, output + 1, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ bitsConsumed = decodeSymbol(outputBase, output + 2, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ bitsConsumed = decodeSymbol(outputBase, output + 3, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ output += SIZE_OF_INT;
+ }
+
+ decodeTail(inputBase, inputAddress, currentAddress, bitsConsumed, bits, outputBase, output, outputLimit);
+ }
+
+ public void decode4Streams(final Object inputBase, final long inputAddress, final long inputLimit, final Object outputBase, final long outputAddress, final long outputLimit)
+ {
+ verify(inputLimit - inputAddress >= 10, inputAddress, "Input is corrupted"); // jump table + 1 byte per stream
+
+ long start1 = inputAddress + 3 * SIZE_OF_SHORT; // for the shorts we read below
+ long start2 = start1 + (UNSAFE.getShort(inputBase, inputAddress) & 0xFFFF);
+ long start3 = start2 + (UNSAFE.getShort(inputBase, inputAddress + 2) & 0xFFFF);
+ long start4 = start3 + (UNSAFE.getShort(inputBase, inputAddress + 4) & 0xFFFF);
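+ // Illustrative layout: jump-table shorts of 16, 18 and 14 mean stream 1 occupies the
+ // 16 bytes after the 6-byte jump table, stream 2 the next 18, stream 3 the next 14,
+ // and stream 4 runs from there to inputLimit.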
+
+ BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, start1, start2);
+ initializer.initialize();
+ int stream1bitsConsumed = initializer.getBitsConsumed();
+ long stream1currentAddress = initializer.getCurrentAddress();
+ long stream1bits = initializer.getBits();
+
+ initializer = new BitInputStream.Initializer(inputBase, start2, start3);
+ initializer.initialize();
+ int stream2bitsConsumed = initializer.getBitsConsumed();
+ long stream2currentAddress = initializer.getCurrentAddress();
+ long stream2bits = initializer.getBits();
+
+ initializer = new BitInputStream.Initializer(inputBase, start3, start4);
+ initializer.initialize();
+ int stream3bitsConsumed = initializer.getBitsConsumed();
+ long stream3currentAddress = initializer.getCurrentAddress();
+ long stream3bits = initializer.getBits();
+
+ initializer = new BitInputStream.Initializer(inputBase, start4, inputLimit);
+ initializer.initialize();
+ int stream4bitsConsumed = initializer.getBitsConsumed();
+ long stream4currentAddress = initializer.getCurrentAddress();
+ long stream4bits = initializer.getBits();
+
+ int segmentSize = (int) ((outputLimit - outputAddress + 3) / 4);
+
+ long outputStart2 = outputAddress + segmentSize;
+ long outputStart3 = outputStart2 + segmentSize;
+ long outputStart4 = outputStart3 + segmentSize;
+
+ long output1 = outputAddress;
+ long output2 = outputStart2;
+ long output3 = outputStart3;
+ long output4 = outputStart4;
+
+ long fastOutputLimit = outputLimit - 7;
+ int tableLog = this.tableLog;
+ byte[] numbersOfBits = this.numbersOfBits;
+ byte[] symbols = this.symbols;
+
+ while (output4 < fastOutputLimit) {
+ stream1bitsConsumed = decodeSymbol(outputBase, output1, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream2bitsConsumed = decodeSymbol(outputBase, output2, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream3bitsConsumed = decodeSymbol(outputBase, output3, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream4bitsConsumed = decodeSymbol(outputBase, output4, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);
+
+ stream1bitsConsumed = decodeSymbol(outputBase, output1 + 1, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream2bitsConsumed = decodeSymbol(outputBase, output2 + 1, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream3bitsConsumed = decodeSymbol(outputBase, output3 + 1, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream4bitsConsumed = decodeSymbol(outputBase, output4 + 1, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);
+
+ stream1bitsConsumed = decodeSymbol(outputBase, output1 + 2, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream2bitsConsumed = decodeSymbol(outputBase, output2 + 2, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream3bitsConsumed = decodeSymbol(outputBase, output3 + 2, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream4bitsConsumed = decodeSymbol(outputBase, output4 + 2, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);
+
+ stream1bitsConsumed = decodeSymbol(outputBase, output1 + 3, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream2bitsConsumed = decodeSymbol(outputBase, output2 + 3, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream3bitsConsumed = decodeSymbol(outputBase, output3 + 3, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream4bitsConsumed = decodeSymbol(outputBase, output4 + 3, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);
+
+ output1 += SIZE_OF_INT;
+ output2 += SIZE_OF_INT;
+ output3 += SIZE_OF_INT;
+ output4 += SIZE_OF_INT;
+
+ BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, start1, stream1currentAddress, stream1bits, stream1bitsConsumed);
+ boolean done = loader.load();
+ stream1bitsConsumed = loader.getBitsConsumed();
+ stream1bits = loader.getBits();
+ stream1currentAddress = loader.getCurrentAddress();
+
+ if (done) {
+ break;
+ }
+
+ loader = new BitInputStream.Loader(inputBase, start2, stream2currentAddress, stream2bits, stream2bitsConsumed);
+ done = loader.load();
+ stream2bitsConsumed = loader.getBitsConsumed();
+ stream2bits = loader.getBits();
+ stream2currentAddress = loader.getCurrentAddress();
+
+ if (done) {
+ break;
+ }
+
+ loader = new BitInputStream.Loader(inputBase, start3, stream3currentAddress, stream3bits, stream3bitsConsumed);
+ done = loader.load();
+ stream3bitsConsumed = loader.getBitsConsumed();
+ stream3bits = loader.getBits();
+ stream3currentAddress = loader.getCurrentAddress();
+ if (done) {
+ break;
+ }
+
+ loader = new BitInputStream.Loader(inputBase, start4, stream4currentAddress, stream4bits, stream4bitsConsumed);
+ done = loader.load();
+ stream4bitsConsumed = loader.getBitsConsumed();
+ stream4bits = loader.getBits();
+ stream4currentAddress = loader.getCurrentAddress();
+ if (done) {
+ break;
+ }
+ }
+
+ verify(output1 <= outputStart2 && output2 <= outputStart3 && output3 <= outputStart4, inputAddress, "Input is corrupted");
+
+ // finish streams one by one
+ decodeTail(inputBase, start1, stream1currentAddress, stream1bitsConsumed, stream1bits, outputBase, output1, outputStart2);
+ decodeTail(inputBase, start2, stream2currentAddress, stream2bitsConsumed, stream2bits, outputBase, output2, outputStart3);
+ decodeTail(inputBase, start3, stream3currentAddress, stream3bitsConsumed, stream3bits, outputBase, output3, outputStart4);
+ decodeTail(inputBase, start4, stream4currentAddress, stream4bitsConsumed, stream4bits, outputBase, output4, outputLimit);
+ }
+
+ private void decodeTail(final Object inputBase, final long startAddress, long currentAddress, int bitsConsumed, long bits, final Object outputBase, long outputAddress, final long outputLimit)
+ {
+ int tableLog = this.tableLog;
+ byte[] numbersOfBits = this.numbersOfBits;
+ byte[] symbols = this.symbols;
+
+ // close to the end of the stream: decode one symbol at a time, reloading between symbols
+ while (outputAddress < outputLimit) {
+ BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, startAddress, currentAddress, bits, bitsConsumed);
+ boolean done = loader.load();
+ bitsConsumed = loader.getBitsConsumed();
+ bits = loader.getBits();
+ currentAddress = loader.getCurrentAddress();
+ if (done) {
+ break;
+ }
+
+ bitsConsumed = decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ }
+
+ // not more data in bit stream, so no need to reload
+ while (outputAddress < outputLimit) {
+ bitsConsumed = decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ }
+
+ verify(isEndOfStream(startAddress, currentAddress, bitsConsumed), startAddress, "Bit stream is not fully consumed");
+ }
+
+ private static int decodeSymbol(Object outputBase, long outputAddress, long bitContainer, int bitsConsumed, int tableLog, byte[] numbersOfBits, byte[] symbols)
+ {
+ int value = (int) peekBitsFast(bitsConsumed, bitContainer, tableLog);
+ UNSAFE.putByte(outputBase, outputAddress, symbols[value]);
+ return bitsConsumed + numbersOfBits[value];
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionContext.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionContext.java
new file mode 100644
index 00000000000..a651ea2a625
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionContext.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+class HuffmanCompressionContext
+{
+ private final HuffmanTableWriterWorkspace tableWriterWorkspace = new HuffmanTableWriterWorkspace();
+ private final HuffmanCompressionTableWorkspace compressionTableWorkspace = new HuffmanCompressionTableWorkspace();
+
+ private HuffmanCompressionTable previousTable = new HuffmanCompressionTable(Huffman.MAX_SYMBOL_COUNT);
+ private HuffmanCompressionTable temporaryTable = new HuffmanCompressionTable(Huffman.MAX_SYMBOL_COUNT);
+
+ private HuffmanCompressionTable previousCandidate = previousTable;
+ private HuffmanCompressionTable temporaryCandidate = temporaryTable;
+
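+ // The two tables are double-buffered. A caller that wants to try a fresh table takes
+ // borrowTemporaryTable(), then either promotes it with saveChanges() or reverts with
+ // discardTemporaryTable(). Illustrative flow (context is a hypothetical variable):
+ //
+ //   HuffmanCompressionTable table = context.borrowTemporaryTable();
+ //   table.initialize(counts, maxSymbol, maxNumberOfBits, context.getCompressionTableWorkspace());
+ //   // ... compare estimated sizes against context.getPreviousTable() ...
+ //   context.saveChanges(); // or context.discardTemporaryTable()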
+ public HuffmanCompressionTable getPreviousTable()
+ {
+ return previousTable;
+ }
+
+ public HuffmanCompressionTable borrowTemporaryTable()
+ {
+ previousCandidate = temporaryTable;
+ temporaryCandidate = previousTable;
+
+ return temporaryTable;
+ }
+
+ public void discardTemporaryTable()
+ {
+ previousCandidate = previousTable;
+ temporaryCandidate = temporaryTable;
+ }
+
+ public void saveChanges()
+ {
+ temporaryTable = temporaryCandidate;
+ previousTable = previousCandidate;
+ }
+
+ public HuffmanCompressionTableWorkspace getCompressionTableWorkspace()
+ {
+ return compressionTableWorkspace;
+ }
+
+ public HuffmanTableWriterWorkspace getTableWriterWorkspace()
+ {
+ return tableWriterWorkspace;
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionTable.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionTable.java
new file mode 100644
index 00000000000..a18d7343b52
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionTable.java
@@ -0,0 +1,437 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.util.Arrays;
+
+import static ai.vespa.airlift.zstd.Huffman.MAX_FSE_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Huffman.MAX_SYMBOL;
+import static ai.vespa.airlift.zstd.Huffman.MAX_SYMBOL_COUNT;
+import static ai.vespa.airlift.zstd.Huffman.MAX_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Huffman.MIN_TABLE_LOG;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.checkArgument;
+import static ai.vespa.airlift.zstd.Util.minTableLog;
+
+final class HuffmanCompressionTable
+{
+ private final short[] values;
+ private final byte[] numberOfBits;
+
+ private int maxSymbol;
+ private int maxNumberOfBits;
+
+ public HuffmanCompressionTable(int capacity)
+ {
+ this.values = new short[capacity];
+ this.numberOfBits = new byte[capacity];
+ }
+
+ public static int optimalNumberOfBits(int maxNumberOfBits, int inputSize, int maxSymbol)
+ {
+ if (inputSize <= 1) {
+ throw new IllegalArgumentException(); // not supported. Use RLE instead
+ }
+
+ int result = maxNumberOfBits;
+
+ result = Math.min(result, Util.highestBit((inputSize - 1)) - 1); // we may be able to reduce accuracy if input is small
+
+ // Need a minimum to safely represent all symbol values
+ result = Math.max(result, minTableLog(inputSize, maxSymbol));
+
+ result = Math.max(result, MIN_TABLE_LOG); // absolute minimum for Huffman
+ result = Math.min(result, MAX_TABLE_LOG); // absolute maximum for Huffman
+
+ return result;
+ }
+
+ public void initialize(int[] counts, int maxSymbol, int maxNumberOfBits, HuffmanCompressionTableWorkspace workspace)
+ {
+ checkArgument(maxSymbol <= MAX_SYMBOL, "Max symbol value too large");
+
+ workspace.reset();
+
+ NodeTable nodeTable = workspace.nodeTable;
+ nodeTable.reset();
+
+ int lastNonZero = buildTree(counts, maxSymbol, nodeTable);
+
+ // enforce max table log
+ maxNumberOfBits = setMaxHeight(nodeTable, lastNonZero, maxNumberOfBits, workspace);
+ checkArgument(maxNumberOfBits <= MAX_TABLE_LOG, "Max number of bits larger than max table size");
+
+ // populate table
+ int symbolCount = maxSymbol + 1;
+ for (int node = 0; node < symbolCount; node++) {
+ int symbol = nodeTable.symbols[node];
+ numberOfBits[symbol] = nodeTable.numberOfBits[node];
+ }
+
+ short[] entriesPerRank = workspace.entriesPerRank;
+ short[] valuesPerRank = workspace.valuesPerRank;
+
+ for (int n = 0; n <= lastNonZero; n++) {
+ entriesPerRank[nodeTable.numberOfBits[n]]++;
+ }
+
+ // determine starting value per rank
+ short startingValue = 0;
+ for (int rank = maxNumberOfBits; rank > 0; rank--) {
+ valuesPerRank[rank] = startingValue; // get starting value within each rank
+ startingValue += entriesPerRank[rank];
+ startingValue >>>= 1;
+ }
+
+ for (int n = 0; n <= maxSymbol; n++) {
+ values[n] = valuesPerRank[numberOfBits[n]]++; // assign value within rank, symbol order
+ }
+
+ this.maxSymbol = maxSymbol;
+ this.maxNumberOfBits = maxNumberOfBits;
+ }
+
+ private int buildTree(int[] counts, int maxSymbol, NodeTable nodeTable)
+ {
+ // populate the leaves of the node table from the histogram of counts
+ // in descending order by count, ascending by symbol value.
+ short current = 0;
+
+ for (int symbol = 0; symbol <= maxSymbol; symbol++) {
+ int count = counts[symbol];
+
+ // simple insertion sort
+ int position = current;
+ while (position > 1 && count > nodeTable.count[position - 1]) {
+ nodeTable.copyNode(position - 1, position);
+ position--;
+ }
+
+ nodeTable.count[position] = count;
+ nodeTable.symbols[position] = symbol;
+
+ current++;
+ }
+
+ int lastNonZero = maxSymbol;
+ while (nodeTable.count[lastNonZero] == 0) {
+ lastNonZero--;
+ }
+
+ // populate the non-leaf nodes
+ short nonLeafStart = MAX_SYMBOL_COUNT;
+ current = nonLeafStart;
+
+ int currentLeaf = lastNonZero;
+
+ // combine the two smallest leaves to create the first intermediate node
+ int currentNonLeaf = current;
+ nodeTable.count[current] = nodeTable.count[currentLeaf] + nodeTable.count[currentLeaf - 1];
+ nodeTable.parents[currentLeaf] = current;
+ nodeTable.parents[currentLeaf - 1] = current;
+ current++;
+ currentLeaf -= 2;
+
+ int root = MAX_SYMBOL_COUNT + lastNonZero - 1;
+
+ // fill in sentinels
+ for (int n = current; n <= root; n++) {
+ nodeTable.count[n] = 1 << 30;
+ }
+
+ // create parents
+ while (current <= root) {
+ int child1;
+ if (currentLeaf >= 0 && nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
+ child1 = currentLeaf--;
+ }
+ else {
+ child1 = currentNonLeaf++;
+ }
+
+ int child2;
+ if (currentLeaf >= 0 && nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
+ child2 = currentLeaf--;
+ }
+ else {
+ child2 = currentNonLeaf++;
+ }
+
+ nodeTable.count[current] = nodeTable.count[child1] + nodeTable.count[child2];
+ nodeTable.parents[child1] = current;
+ nodeTable.parents[child2] = current;
+ current++;
+ }
+
+ // distribute weights
+ nodeTable.numberOfBits[root] = 0;
+ for (int n = root - 1; n >= nonLeafStart; n--) {
+ short parent = nodeTable.parents[n];
+ nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
+ }
+
+ for (int n = 0; n <= lastNonZero; n++) {
+ short parent = nodeTable.parents[n];
+ nodeTable.numberOfBits[n] = (byte) (nodeTable.numberOfBits[parent] + 1);
+ }
+
+ return lastNonZero;
+ }
+
+ // TODO: consider encoding 2 symbols at a time
+ // - need a table with 256x256 entries with
+ // - the concatenated bits for the corresponding pair of symbols
+ // - the sum of bits for the corresponding pair of symbols
+ // - read 2 symbols at a time from the input
+ public void encodeSymbol(BitOutputStream output, int symbol)
+ {
+ output.addBitsFast(values[symbol], numberOfBits[symbol]);
+ }
+
+ public int write(Object outputBase, long outputAddress, int outputSize, HuffmanTableWriterWorkspace workspace)
+ {
+ byte[] weights = workspace.weights;
+
+ long output = outputAddress;
+
+ int maxNumberOfBits = this.maxNumberOfBits;
+ int maxSymbol = this.maxSymbol;
+
+ // convert to weights per RFC 8478 section 4.2.1
+ for (int symbol = 0; symbol < maxSymbol; symbol++) {
+ int bits = numberOfBits[symbol];
+
+ if (bits == 0) {
+ weights[symbol] = 0;
+ }
+ else {
+ weights[symbol] = (byte) (maxNumberOfBits + 1 - bits);
+ }
+ }
+
+ // attempt weights compression by FSE
+ int size = compressWeights(outputBase, output + 1, outputSize - 1, weights, maxSymbol, workspace);
+
+ if (maxSymbol > 127 && size > 127) {
+ // This should never happen. Since weights are in the range [0, 12], they can be compressed optimally to ~3.7 bits per symbol for a uniform distribution.
+ // Since maxSymbol has to be <= MAX_SYMBOL (255), this is 119 bytes + FSE headers.
+ throw new AssertionError();
+ }
+
+ if (size != 0 && size != 1 && size < maxSymbol / 2) {
+ // Go with FSE only if:
+ // - the weights are compressible
+ // - the compressed size is better than what we'd get with the raw encoding below
+ // - the compressed size is <= 127 bytes, which is the most that the encoding can hold for FSE-compressed weights (see RFC 8478 section 4.2.1.1). This is implied
+ // by the maxSymbol / 2 check, since maxSymbol must be <= 255
+ UNSAFE.putByte(outputBase, output, (byte) size);
+ return size + 1; // header + size
+ }
+ else {
+ // Use raw encoding (4 bits per entry)
+
+ // #entries = #symbols - 1 since last symbol is implicit. Thus, #entries = (maxSymbol + 1) - 1 = maxSymbol
+ int entryCount = maxSymbol;
+
+ size = (entryCount + 1) / 2; // ceil(#entries / 2)
+ checkArgument(size + 1 /* header */ <= outputSize, "Output size too small"); // 2 entries per byte
+
+ // encode number of symbols
+ // header = #entries + 127 per RFC
+ UNSAFE.putByte(outputBase, output, (byte) (127 + entryCount));
+ output++;
+
+ weights[maxSymbol] = 0; // last weight is implicit, so set to 0 so that it doesn't get encoded below
+ for (int i = 0; i < entryCount; i += 2) {
+ UNSAFE.putByte(outputBase, output, (byte) ((weights[i] << 4) + weights[i + 1]));
+ output++;
+ }
+
+ return (int) (output - outputAddress);
+ }
+ }
+
+ /**
+ * Can this table encode all symbols with non-zero count?
+ */
+ public boolean isValid(int[] counts, int maxSymbol)
+ {
+ if (maxSymbol > this.maxSymbol) {
+ // some non-zero count symbols cannot be encoded by the current table
+ return false;
+ }
+
+ for (int symbol = 0; symbol <= maxSymbol; ++symbol) {
+ if (counts[symbol] != 0 && numberOfBits[symbol] == 0) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ public int estimateCompressedSize(int[] counts, int maxSymbol)
+ {
+ int numberOfBits = 0;
+ for (int symbol = 0; symbol <= Math.min(maxSymbol, this.maxSymbol); symbol++) {
+ numberOfBits += this.numberOfBits[symbol] * counts[symbol];
+ }
+
+ return numberOfBits >>> 3; // convert to bytes
+ }
+
+ // http://fastcompression.blogspot.com/2015/07/huffman-revisited-part-3-depth-limited.html
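+ // Summary of the depth-limiting pass below: every symbol whose code would exceed
+ // maxNumberOfBits is truncated to maxNumberOfBits, which makes the code over-complete;
+ // the accumulated debt (totalCost) is then repaid by lengthening selected shorter
+ // codes by one bit, preferring symbols where the extra bit is cheapest.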
+ private static int setMaxHeight(NodeTable nodeTable, int lastNonZero, int maxNumberOfBits, HuffmanCompressionTableWorkspace workspace)
+ {
+ int largestBits = nodeTable.numberOfBits[lastNonZero];
+
+ if (largestBits <= maxNumberOfBits) {
+ return largestBits; // early exit: no elements > maxNumberOfBits
+ }
+
+ // there are several too large elements (at least >= 2)
+ int totalCost = 0;
+ int baseCost = 1 << (largestBits - maxNumberOfBits);
+ int n = lastNonZero;
+
+ while (nodeTable.numberOfBits[n] > maxNumberOfBits) {
+ totalCost += baseCost - (1 << (largestBits - nodeTable.numberOfBits[n]));
+ nodeTable.numberOfBits[n] = (byte) maxNumberOfBits;
+ n--;
+ } // n stops at the last element with numberOfBits[n] <= maxNumberOfBits
+
+ while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
+ n--; // n ends at index of smallest symbol using < maxNumberOfBits
+ }
+
+ // renormalize totalCost
+ totalCost >>>= (largestBits - maxNumberOfBits); // note: totalCost is necessarily a multiple of baseCost
+
+ // repay normalized cost
+ int noSymbol = 0xF0F0F0F0;
+ int[] rankLast = workspace.rankLast;
+ Arrays.fill(rankLast, noSymbol);
+
+ // Get pos of last (smallest) symbol per rank
+ int currentNbBits = maxNumberOfBits;
+ for (int pos = n; pos >= 0; pos--) {
+ if (nodeTable.numberOfBits[pos] >= currentNbBits) {
+ continue;
+ }
+ currentNbBits = nodeTable.numberOfBits[pos]; // < maxNumberOfBits
+ rankLast[maxNumberOfBits - currentNbBits] = pos;
+ }
+
+ while (totalCost > 0) {
+ int numberOfBitsToDecrease = Util.highestBit(totalCost) + 1;
+ for (; numberOfBitsToDecrease > 1; numberOfBitsToDecrease--) {
+ int highPosition = rankLast[numberOfBitsToDecrease];
+ int lowPosition = rankLast[numberOfBitsToDecrease - 1];
+ if (highPosition == noSymbol) {
+ continue;
+ }
+ if (lowPosition == noSymbol) {
+ break;
+ }
+ int highTotal = nodeTable.count[highPosition];
+ int lowTotal = 2 * nodeTable.count[lowPosition];
+ if (highTotal <= lowTotal) {
+ break;
+ }
+ }
+
+ // only triggered when there is no rank 1 symbol left => find the closest one (note: there is necessarily at least one!)
+ // HUF_MAX_TABLELOG test just to please gcc 5+; but it should not be necessary
+ while ((numberOfBitsToDecrease <= MAX_TABLE_LOG) && (rankLast[numberOfBitsToDecrease] == noSymbol)) {
+ numberOfBitsToDecrease++;
+ }
+ totalCost -= 1 << (numberOfBitsToDecrease - 1);
+ if (rankLast[numberOfBitsToDecrease - 1] == noSymbol) {
+ rankLast[numberOfBitsToDecrease - 1] = rankLast[numberOfBitsToDecrease]; // this rank is no longer empty
+ }
+ nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]]++;
+ if (rankLast[numberOfBitsToDecrease] == 0) { /* special case, reached largest symbol */
+ rankLast[numberOfBitsToDecrease] = noSymbol;
+ }
+ else {
+ rankLast[numberOfBitsToDecrease]--;
+ if (nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]] != maxNumberOfBits - numberOfBitsToDecrease) {
+ rankLast[numberOfBitsToDecrease] = noSymbol; // this rank is now empty
+ }
+ }
+ }
+
+ while (totalCost < 0) { // sometimes the cost correction overshoots
+ if (rankLast[1] == noSymbol) { /* special case: no rank 1 symbol (using maxNumberOfBits-1); create one from the largest rank 0 (using maxNumberOfBits) */
+ while (nodeTable.numberOfBits[n] == maxNumberOfBits) {
+ n--;
+ }
+ nodeTable.numberOfBits[n + 1]--;
+ rankLast[1] = n + 1;
+ totalCost++;
+ continue;
+ }
+ nodeTable.numberOfBits[rankLast[1] + 1]--;
+ rankLast[1]++;
+ totalCost++;
+ }
+
+ return maxNumberOfBits;
+ }
+
+ /**
+ * All elements within weights must be <= Huffman.MAX_TABLE_LOG
+ */
+ private static int compressWeights(Object outputBase, long outputAddress, int outputSize, byte[] weights, int weightsLength, HuffmanTableWriterWorkspace workspace)
+ {
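+ // Returns the number of bytes written, with two special values: 0 signals that the
+ // weights are not compressible and 1 that the input is a single repeated symbol;
+ // callers are expected to fall back to an uncompressed weight encoding accordingly.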
+ if (weightsLength <= 1) {
+ return 0; // Not compressible
+ }
+
+ // Scan input and build symbol stats
+ int[] counts = workspace.counts;
+ Histogram.count(weights, weightsLength, counts);
+ int maxSymbol = Histogram.findMaxSymbol(counts, MAX_TABLE_LOG);
+ int maxCount = Histogram.findLargestCount(counts, maxSymbol);
+
+ if (maxCount == weightsLength) {
+ return 1; // only a single symbol in source
+ }
+ if (maxCount == 1) {
+ return 0; // each symbol is present at most once => not compressible
+ }
+
+ short[] normalizedCounts = workspace.normalizedCounts;
+
+ int tableLog = FiniteStateEntropy.optimalTableLog(MAX_FSE_TABLE_LOG, weightsLength, maxSymbol);
+ FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts, weightsLength, maxSymbol);
+
+ long output = outputAddress;
+ long outputLimit = outputAddress + outputSize;
+
+ // Write table description header
+ int headerSize = FiniteStateEntropy.writeNormalizedCounts(outputBase, output, outputSize, normalizedCounts, maxSymbol, tableLog);
+ output += headerSize;
+
+ // Compress
+ FseCompressionTable compressionTable = workspace.fseTable;
+ compressionTable.initialize(normalizedCounts, maxSymbol, tableLog);
+ int compressedSize = FiniteStateEntropy.compress(outputBase, output, (int) (outputLimit - output), weights, weightsLength, compressionTable);
+ if (compressedSize == 0) {
+ return 0;
+ }
+ output += compressedSize;
+
+ return (int) (output - outputAddress);
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionTableWorkspace.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionTableWorkspace.java
new file mode 100644
index 00000000000..b6ad2adaec7
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressionTableWorkspace.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.util.Arrays;
+
+class HuffmanCompressionTableWorkspace
+{
+ public final NodeTable nodeTable = new NodeTable(2 * Huffman.MAX_SYMBOL_COUNT - 1); // number of nodes in a binary tree with MAX_SYMBOL_COUNT leaves
+
+ public final short[] entriesPerRank = new short[Huffman.MAX_TABLE_LOG + 1];
+ public final short[] valuesPerRank = new short[Huffman.MAX_TABLE_LOG + 1];
+
+ // for setMaxHeight
+ public final int[] rankLast = new int[Huffman.MAX_TABLE_LOG + 2];
+
+ public void reset()
+ {
+ Arrays.fill(entriesPerRank, (short) 0);
+ Arrays.fill(valuesPerRank, (short) 0);
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressor.java
new file mode 100644
index 00000000000..6c94181a88f
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanCompressor.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_LONG;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+
+class HuffmanCompressor
+{
+ private HuffmanCompressor()
+ {
+ }
+
+ public static int compress4streams(Object outputBase, long outputAddress, int outputSize, Object inputBase, long inputAddress, int inputSize, HuffmanCompressionTable table)
+ {
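+ // Splits the input into four roughly equal segments, each compressed as an independent
+ // Huffman bitstream. The sizes of the first three streams are stored in a 6-byte jump
+ // table up front; the size of the fourth is implied by the total.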
+ long input = inputAddress;
+ long inputLimit = inputAddress + inputSize;
+ long output = outputAddress;
+ long outputLimit = outputAddress + outputSize;
+
+ int segmentSize = (inputSize + 3) / 4;
+
+ if (outputSize < 6 /* jump table */ + 1 /* first stream */ + 1 /* second stream */ + 1 /* third stream */ + 8 /* 8 bytes minimum needed by the bitstream encoder */) {
+ return 0; // minimum space to compress successfully
+ }
+
+ if (inputSize <= 6 + 1 + 1 + 1) { // jump table + one byte per stream
+ return 0; // no saving possible: input too small
+ }
+
+ output += SIZE_OF_SHORT + SIZE_OF_SHORT + SIZE_OF_SHORT; // jump table
+
+ int compressedSize;
+
+ // first segment
+ compressedSize = compressSingleStream(outputBase, output, (int) (outputLimit - output), inputBase, input, segmentSize, table);
+ if (compressedSize == 0) {
+ return 0;
+ }
+ UNSAFE.putShort(outputBase, outputAddress, (short) compressedSize);
+ output += compressedSize;
+ input += segmentSize;
+
+ // second segment
+ compressedSize = compressSingleStream(outputBase, output, (int) (outputLimit - output), inputBase, input, segmentSize, table);
+ if (compressedSize == 0) {
+ return 0;
+ }
+ UNSAFE.putShort(outputBase, outputAddress + SIZE_OF_SHORT, (short) compressedSize);
+ output += compressedSize;
+ input += segmentSize;
+
+ // third segment
+ compressedSize = compressSingleStream(outputBase, output, (int) (outputLimit - output), inputBase, input, segmentSize, table);
+ if (compressedSize == 0) {
+ return 0;
+ }
+ UNSAFE.putShort(outputBase, outputAddress + SIZE_OF_SHORT + SIZE_OF_SHORT, (short) compressedSize);
+ output += compressedSize;
+ input += segmentSize;
+
+ // fourth segment
+ compressedSize = compressSingleStream(outputBase, output, (int) (outputLimit - output), inputBase, input, (int) (inputLimit - input), table);
+ if (compressedSize == 0) {
+ return 0;
+ }
+ output += compressedSize;
+
+ return (int) (output - outputAddress);
+ }
+
+ @SuppressWarnings("fallthrough")
+ public static int compressSingleStream(Object outputBase, long outputAddress, int outputSize, Object inputBase, long inputAddress, int inputSize, HuffmanCompressionTable table)
+ {
+ if (outputSize < SIZE_OF_LONG) {
+ return 0;
+ }
+
+ BitOutputStream bitstream = new BitOutputStream(outputBase, outputAddress, outputSize);
+ long input = inputAddress;
+
+ int n = inputSize & ~3; // round length down to a multiple of 4
+
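+ // Symbols are encoded from the end of the input towards the start (the decoder reads
+ // the bitstream backwards). The switch below handles the inputSize % 4 tail first; the
+ // flush conditions are compile-time constants that only trigger when the 64-bit
+ // accumulator could not hold the worst-case number of pending bits.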
+ switch (inputSize & 3) {
+ case 3:
+ table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n + 2) & 0xFF);
+ if (SIZE_OF_LONG * 8 < Huffman.MAX_TABLE_LOG * 4 + 7) {
+ bitstream.flush();
+ }
+ // fall-through
+ case 2:
+ table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n + 1) & 0xFF);
+ if (SIZE_OF_LONG * 8 < Huffman.MAX_TABLE_LOG * 2 + 7) {
+ bitstream.flush();
+ }
+ // fall-through
+ case 1:
+ table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n + 0) & 0xFF);
+ bitstream.flush();
+ // fall-through
+ case 0: /* fall-through */
+ default:
+ break;
+ }
+
+ for (; n > 0; n -= 4) { // note: n & 3 == 0 at this stage
+ table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n - 1) & 0xFF);
+ if (SIZE_OF_LONG * 8 < Huffman.MAX_TABLE_LOG * 2 + 7) {
+ bitstream.flush();
+ }
+ table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n - 2) & 0xFF);
+ if (SIZE_OF_LONG * 8 < Huffman.MAX_TABLE_LOG * 4 + 7) {
+ bitstream.flush();
+ }
+ table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n - 3) & 0xFF);
+ if (SIZE_OF_LONG * 8 < Huffman.MAX_TABLE_LOG * 2 + 7) {
+ bitstream.flush();
+ }
+ table.encodeSymbol(bitstream, UNSAFE.getByte(inputBase, input + n - 4) & 0xFF);
+ bitstream.flush();
+ }
+
+ return bitstream.close();
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanTableWriterWorkspace.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanTableWriterWorkspace.java
new file mode 100644
index 00000000000..80f39506f07
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/HuffmanTableWriterWorkspace.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Huffman.MAX_FSE_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Huffman.MAX_SYMBOL;
+import static ai.vespa.airlift.zstd.Huffman.MAX_TABLE_LOG;
+
+class HuffmanTableWriterWorkspace
+{
+ // for encoding weights
+ public final byte[] weights = new byte[MAX_SYMBOL]; // the weight for the last symbol is implicit
+
+ // for compressing weights
+ public final int[] counts = new int[MAX_TABLE_LOG + 1];
+ public final short[] normalizedCounts = new short[MAX_TABLE_LOG + 1];
+ public final FseCompressionTable fseTable = new FseCompressionTable(MAX_FSE_TABLE_LOG, MAX_TABLE_LOG);
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/NodeTable.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/NodeTable.java
new file mode 100644
index 00000000000..4466071025d
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/NodeTable.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.util.Arrays;
+
+class NodeTable
+{
+ int[] count;
+ short[] parents;
+ int[] symbols;
+ byte[] numberOfBits;
+
+ public NodeTable(int size)
+ {
+ count = new int[size];
+ parents = new short[size];
+ symbols = new int[size];
+ numberOfBits = new byte[size];
+ }
+
+ public void reset()
+ {
+ Arrays.fill(count, 0);
+ Arrays.fill(parents, (short) 0);
+ Arrays.fill(symbols, 0);
+ Arrays.fill(numberOfBits, (byte) 0);
+ }
+
+ public void copyNode(int from, int to)
+ {
+ count[to] = count[from];
+ parents[to] = parents[from];
+ symbols[to] = symbols[from];
+ numberOfBits[to] = numberOfBits[from];
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/RepeatedOffsets.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/RepeatedOffsets.java
new file mode 100644
index 00000000000..9b6eab05611
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/RepeatedOffsets.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+class RepeatedOffsets
+{
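+ // Mirrors zstd's repeated-offset (repcode) scheme: 1 and 4 are the first two standard
+ // initial repeat offsets. The save*/commit() pair lets a block compressor stage updated
+ // offsets and apply them only once the enclosing sequence is accepted.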
+ private int offset0 = 1;
+ private int offset1 = 4;
+
+ private int tempOffset0;
+ private int tempOffset1;
+
+ public int getOffset0()
+ {
+ return offset0;
+ }
+
+ public int getOffset1()
+ {
+ return offset1;
+ }
+
+ public void saveOffset0(int offset)
+ {
+ tempOffset0 = offset;
+ }
+
+ public void saveOffset1(int offset)
+ {
+ tempOffset1 = offset;
+ }
+
+ public void commit()
+ {
+ offset0 = tempOffset0;
+ offset1 = tempOffset1;
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceEncoder.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceEncoder.java
new file mode 100644
index 00000000000..df80b08dd35
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceEncoder.java
@@ -0,0 +1,351 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.DEFAULT_MAX_OFFSET_CODE_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.LITERALS_LENGTH_BITS;
+import static ai.vespa.airlift.zstd.Constants.LITERAL_LENGTH_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Constants.LONG_NUMBER_OF_SEQUENCES;
+import static ai.vespa.airlift.zstd.Constants.MATCH_LENGTH_BITS;
+import static ai.vespa.airlift.zstd.Constants.MATCH_LENGTH_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Constants.MAX_LITERALS_LENGTH_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.MAX_MATCH_LENGTH_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.MAX_OFFSET_CODE_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.OFFSET_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_BASIC;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_COMPRESSED;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_RLE;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
+import static ai.vespa.airlift.zstd.FiniteStateEntropy.optimalTableLog;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.checkArgument;
+
+class SequenceEncoder
+{
+ private static final int DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS_LOG = 6;
+ private static final short[] DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS = {4, 3, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 1, 1, 1,
+ 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 3, 2, 1, 1, 1, 1, 1,
+ -1, -1, -1, -1};
+
+ private static final int DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS_LOG = 6;
+ private static final short[] DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS = {1, 4, 3, 2, 2, 2, 2, 2,
+ 2, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, -1, -1,
+ -1, -1, -1, -1, -1};
+
+ private static final int DEFAULT_OFFSET_NORMALIZED_COUNTS_LOG = 5;
+ private static final short[] DEFAULT_OFFSET_NORMALIZED_COUNTS = {1, 1, 1, 1, 1, 1, 2, 2,
+ 2, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1,
+ -1, -1, -1, -1, -1};
+
+ private static final FseCompressionTable DEFAULT_LITERAL_LENGTHS_TABLE = FseCompressionTable.newInstance(DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS, MAX_LITERALS_LENGTH_SYMBOL, DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS_LOG);
+ private static final FseCompressionTable DEFAULT_MATCH_LENGTHS_TABLE = FseCompressionTable.newInstance(DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS, MAX_MATCH_LENGTH_SYMBOL, DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS_LOG);
+ private static final FseCompressionTable DEFAULT_OFFSETS_TABLE = FseCompressionTable.newInstance(DEFAULT_OFFSET_NORMALIZED_COUNTS, DEFAULT_MAX_OFFSET_CODE_SYMBOL, DEFAULT_OFFSET_NORMALIZED_COUNTS_LOG);
+
+ private SequenceEncoder()
+ {
+ }
+
+ public static int compressSequences(Object outputBase, final long outputAddress, int outputSize, SequenceStore sequences, CompressionParameters.Strategy strategy, SequenceEncodingContext workspace)
+ {
+ long output = outputAddress;
+ long outputLimit = outputAddress + outputSize;
+
+ checkArgument(outputLimit - output > 3 /* max sequence count size */ + 1 /* encoding type flags */, "Output buffer too small");
+
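+ // sequence count header: 1 byte for counts below 0x7F, 2 bytes (high bit set) up to
+ // LONG_NUMBER_OF_SEQUENCES, otherwise a 0xFF marker followed by a 2-byte remainder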
+ int sequenceCount = sequences.sequenceCount;
+ if (sequenceCount < 0x7F) {
+ UNSAFE.putByte(outputBase, output, (byte) sequenceCount);
+ output++;
+ }
+ else if (sequenceCount < LONG_NUMBER_OF_SEQUENCES) {
+ UNSAFE.putByte(outputBase, output, (byte) (sequenceCount >>> 8 | 0x80));
+ UNSAFE.putByte(outputBase, output + 1, (byte) sequenceCount);
+ output += SIZE_OF_SHORT;
+ }
+ else {
+ UNSAFE.putByte(outputBase, output, (byte) 0xFF);
+ output++;
+ UNSAFE.putShort(outputBase, output, (short) (sequenceCount - LONG_NUMBER_OF_SEQUENCES));
+ output += SIZE_OF_SHORT;
+ }
+
+ if (sequenceCount == 0) {
+ return (int) (output - outputAddress);
+ }
+
+ // flags for FSE encoding type
+ long headerAddress = output++;
+
+ int maxSymbol;
+ int largestCount;
+
+ // literal lengths
+ int[] counts = workspace.counts;
+ Histogram.count(sequences.literalLengthCodes, sequenceCount, workspace.counts);
+ maxSymbol = Histogram.findMaxSymbol(counts, MAX_LITERALS_LENGTH_SYMBOL);
+ largestCount = Histogram.findLargestCount(counts, maxSymbol);
+
+ int literalsLengthEncodingType = selectEncodingType(largestCount, sequenceCount, DEFAULT_LITERAL_LENGTH_NORMALIZED_COUNTS_LOG, true, strategy);
+
+ FseCompressionTable literalLengthTable;
+ switch (literalsLengthEncodingType) {
+ case SEQUENCE_ENCODING_RLE:
+ UNSAFE.putByte(outputBase, output, sequences.literalLengthCodes[0]);
+ output++;
+ workspace.literalLengthTable.initializeRleTable(maxSymbol);
+ literalLengthTable = workspace.literalLengthTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ literalLengthTable = DEFAULT_LITERAL_LENGTHS_TABLE;
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ output += buildCompressionTable(
+ workspace.literalLengthTable,
+ outputBase,
+ output,
+ outputLimit,
+ sequenceCount,
+ LITERAL_LENGTH_TABLE_LOG,
+ sequences.literalLengthCodes,
+ workspace.counts,
+ maxSymbol,
+ workspace.normalizedCounts);
+ literalLengthTable = workspace.literalLengthTable;
+ break;
+ default:
+ throw new UnsupportedOperationException("not yet implemented");
+ }
+
+ // offsets
+ Histogram.count(sequences.offsetCodes, sequenceCount, workspace.counts);
+ maxSymbol = Histogram.findMaxSymbol(counts, MAX_OFFSET_CODE_SYMBOL);
+ largestCount = Histogram.findLargestCount(counts, maxSymbol);
+
+ // We can only use the basic table if maxSymbol is below DEFAULT_MAX_OFFSET_CODE_SYMBOL; otherwise the offsets are too large.
+ boolean defaultAllowed = maxSymbol < DEFAULT_MAX_OFFSET_CODE_SYMBOL;
+
+ int offsetEncodingType = selectEncodingType(largestCount, sequenceCount, DEFAULT_OFFSET_NORMALIZED_COUNTS_LOG, defaultAllowed, strategy);
+
+ FseCompressionTable offsetCodeTable;
+ switch (offsetEncodingType) {
+ case SEQUENCE_ENCODING_RLE:
+ UNSAFE.putByte(outputBase, output, sequences.offsetCodes[0]);
+ output++;
+ workspace.offsetCodeTable.initializeRleTable(maxSymbol);
+ offsetCodeTable = workspace.offsetCodeTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ offsetCodeTable = DEFAULT_OFFSETS_TABLE;
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ output += buildCompressionTable(
+ workspace.offsetCodeTable,
+ outputBase,
+ output,
+ outputLimit,
+ sequenceCount,
+ OFFSET_TABLE_LOG,
+ sequences.offsetCodes,
+ workspace.counts,
+ maxSymbol,
+ workspace.normalizedCounts);
+ offsetCodeTable = workspace.offsetCodeTable;
+ break;
+ default:
+ throw new UnsupportedOperationException("not yet implemented");
+ }
+
+ // match lengths
+ Histogram.count(sequences.matchLengthCodes, sequenceCount, workspace.counts);
+ maxSymbol = Histogram.findMaxSymbol(counts, MAX_MATCH_LENGTH_SYMBOL);
+ largestCount = Histogram.findLargestCount(counts, maxSymbol);
+
+ int matchLengthEncodingType = selectEncodingType(largestCount, sequenceCount, DEFAULT_MATCH_LENGTH_NORMALIZED_COUNTS_LOG, true, strategy);
+
+ FseCompressionTable matchLengthTable;
+ switch (matchLengthEncodingType) {
+ case SEQUENCE_ENCODING_RLE:
+ UNSAFE.putByte(outputBase, output, sequences.matchLengthCodes[0]);
+ output++;
+ workspace.matchLengthTable.initializeRleTable(maxSymbol);
+ matchLengthTable = workspace.matchLengthTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ matchLengthTable = DEFAULT_MATCH_LENGTHS_TABLE;
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ output += buildCompressionTable(
+ workspace.matchLengthTable,
+ outputBase,
+ output,
+ outputLimit,
+ sequenceCount,
+ MATCH_LENGTH_TABLE_LOG,
+ sequences.matchLengthCodes,
+ workspace.counts,
+ maxSymbol,
+ workspace.normalizedCounts);
+ matchLengthTable = workspace.matchLengthTable;
+ break;
+ default:
+ throw new UnsupportedOperationException("not yet implemented");
+ }
+
+ // flags
+ UNSAFE.putByte(outputBase, headerAddress, (byte) ((literalsLengthEncodingType << 6) | (offsetEncodingType << 4) | (matchLengthEncodingType << 2)));
+
+ output += encodeSequences(outputBase, output, outputLimit, matchLengthTable, offsetCodeTable, literalLengthTable, sequences);
+
+ return (int) (output - outputAddress);
+ }
+
+ private static int buildCompressionTable(FseCompressionTable table, Object outputBase, long output, long outputLimit, int sequenceCount, int maxTableLog, byte[] codes, int[] counts, int maxSymbol, short[] normalizedCounts)
+ {
+ int tableLog = optimalTableLog(maxTableLog, sequenceCount, maxSymbol);
+
+ // this is a minor optimization. The last symbol is embedded in the initial FSE state, so it's not part of the bitstream. We can omit it from the
+ // statistics (but only if its count is > 1). This makes the statistics a tiny bit more accurate.
+ if (counts[codes[sequenceCount - 1]] > 1) {
+ counts[codes[sequenceCount - 1]]--;
+ sequenceCount--;
+ }
+
+ FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts, sequenceCount, maxSymbol);
+ table.initialize(normalizedCounts, maxSymbol, tableLog);
+
+ return FiniteStateEntropy.writeNormalizedCounts(outputBase, output, (int) (outputLimit - output), normalizedCounts, maxSymbol, tableLog); // TODO: pass outputLimit directly
+ }
+
+ private static int encodeSequences(
+ Object outputBase,
+ long output,
+ long outputLimit,
+ FseCompressionTable matchLengthTable,
+ FseCompressionTable offsetsTable,
+ FseCompressionTable literalLengthTable,
+ SequenceStore sequences)
+ {
+ byte[] matchLengthCodes = sequences.matchLengthCodes;
+ byte[] offsetCodes = sequences.offsetCodes;
+ byte[] literalLengthCodes = sequences.literalLengthCodes;
+
+ BitOutputStream blockStream = new BitOutputStream(outputBase, output, (int) (outputLimit - output));
+
+ int sequenceCount = sequences.sequenceCount;
+
+ // first symbols
+ int matchLengthState = matchLengthTable.begin(matchLengthCodes[sequenceCount - 1]);
+ int offsetState = offsetsTable.begin(offsetCodes[sequenceCount - 1]);
+ int literalLengthState = literalLengthTable.begin(literalLengthCodes[sequenceCount - 1]);
+
+ blockStream.addBits(sequences.literalLengths[sequenceCount - 1], LITERALS_LENGTH_BITS[literalLengthCodes[sequenceCount - 1]]);
+ blockStream.addBits(sequences.matchLengths[sequenceCount - 1], MATCH_LENGTH_BITS[matchLengthCodes[sequenceCount - 1]]);
+ blockStream.addBits(sequences.offsets[sequenceCount - 1], offsetCodes[sequenceCount - 1]);
+ blockStream.flush();
+
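+ // Remaining sequences are emitted in reverse order (the decoder reads the bitstream
+ // backwards). The numeric comments below track the worst-case number of bits
+ // accumulated since the last flush, which the interleaved flush conditions guard against.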
+ if (sequenceCount >= 2) {
+ for (int n = sequenceCount - 2; n >= 0; n--) {
+ byte literalLengthCode = literalLengthCodes[n];
+ byte offsetCode = offsetCodes[n];
+ byte matchLengthCode = matchLengthCodes[n];
+
+ int literalLengthBits = LITERALS_LENGTH_BITS[literalLengthCode];
+ int offsetBits = offsetCode;
+ int matchLengthBits = MATCH_LENGTH_BITS[matchLengthCode];
+
+ // (7)
+ offsetState = offsetsTable.encode(blockStream, offsetState, offsetCode); // 15
+ matchLengthState = matchLengthTable.encode(blockStream, matchLengthState, matchLengthCode); // 24
+ literalLengthState = literalLengthTable.encode(blockStream, literalLengthState, literalLengthCode); // 33
+
+ if ((offsetBits + matchLengthBits + literalLengthBits >= 64 - 7 - (LITERAL_LENGTH_TABLE_LOG + MATCH_LENGTH_TABLE_LOG + OFFSET_TABLE_LOG))) {
+ blockStream.flush(); /* (7)*/
+ }
+
+ blockStream.addBits(sequences.literalLengths[n], literalLengthBits);
+ if (((literalLengthBits + matchLengthBits) > 24)) {
+ blockStream.flush();
+ }
+
+ blockStream.addBits(sequences.matchLengths[n], matchLengthBits);
+ if ((offsetBits + matchLengthBits + literalLengthBits > 56)) {
+ blockStream.flush();
+ }
+
+ blockStream.addBits(sequences.offsets[n], offsetBits); // 31
+ blockStream.flush(); // (7)
+ }
+ }
+
+ matchLengthTable.finish(blockStream, matchLengthState);
+ offsetsTable.finish(blockStream, offsetState);
+ literalLengthTable.finish(blockStream, literalLengthState);
+
+ int streamSize = blockStream.close();
+ checkArgument(streamSize > 0, "Output buffer too small");
+
+ return streamSize;
+ }
+
+ private static int selectEncodingType(
+ int largestCount,
+ int sequenceCount,
+ int defaultNormalizedCountsLog,
+ boolean isDefaultTableAllowed,
+ CompressionParameters.Strategy strategy)
+ {
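+ // RLE when every sequence uses the same code, the predefined ("basic") table for small
+ // or near-default distributions, and a freshly built FSE table otherwise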
+ if (largestCount == sequenceCount) { // => all entries are equal
+ if (isDefaultTableAllowed && sequenceCount <= 2) {
+ /* Prefer set_basic over set_rle when there are 2 or fewer symbols,
+ * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol.
+ * If basic encoding isn't possible, always choose RLE.
+ */
+ return SEQUENCE_ENCODING_BASIC;
+ }
+
+ return SEQUENCE_ENCODING_RLE;
+ }
+
+ if (strategy.ordinal() < CompressionParameters.Strategy.LAZY.ordinal()) { // TODO: more robust check. Maybe encapsulate in strategy objects
+ if (isDefaultTableAllowed) {
+ int factor = 10 - strategy.ordinal(); // TODO more robust. Move it to strategy
+ int baseLog = 3;
+ long minNumberOfSequences = ((1L << defaultNormalizedCountsLog) * factor) >> baseLog; /* 28-36 for offset, 56-72 for lengths */
+
+ if ((sequenceCount < minNumberOfSequences) || (largestCount < (sequenceCount >> (defaultNormalizedCountsLog - 1)))) {
+ /* The format allows default tables to be repeated, but it isn't useful.
+ * When using simple heuristics to select encoding type, we don't want
+ * to confuse these tables with dictionaries. When running more careful
+ * analysis, we don't need to waste time checking both repeating tables
+ * and default tables.
+ */
+ return SEQUENCE_ENCODING_BASIC;
+ }
+ }
+ }
+ else {
+ // TODO implement when other strategies are supported
+ throw new UnsupportedOperationException("not yet implemented");
+ }
+
+ return SEQUENCE_ENCODING_COMPRESSED;
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceEncodingContext.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceEncodingContext.java
new file mode 100644
index 00000000000..da5978336e8
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceEncodingContext.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.MAX_LITERALS_LENGTH_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.MAX_MATCH_LENGTH_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.MAX_OFFSET_CODE_SYMBOL;
+
+class SequenceEncodingContext
+{
+ private static final int MAX_SEQUENCES = Math.max(MAX_LITERALS_LENGTH_SYMBOL, MAX_MATCH_LENGTH_SYMBOL);
+
+ public final FseCompressionTable literalLengthTable = new FseCompressionTable(Constants.LITERAL_LENGTH_TABLE_LOG, MAX_LITERALS_LENGTH_SYMBOL);
+ public final FseCompressionTable offsetCodeTable = new FseCompressionTable(Constants.OFFSET_TABLE_LOG, MAX_OFFSET_CODE_SYMBOL);
+ public final FseCompressionTable matchLengthTable = new FseCompressionTable(Constants.MATCH_LENGTH_TABLE_LOG, MAX_MATCH_LENGTH_SYMBOL);
+
+ public final int[] counts = new int[MAX_SEQUENCES + 1];
+ public final short[] normalizedCounts = new short[MAX_SEQUENCES + 1];
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceStore.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceStore.java
new file mode 100644
index 00000000000..f01d54f0527
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/SequenceStore.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_LONG;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
+class SequenceStore
+{
+ public final byte[] literalsBuffer;
+ public int literalsLength;
+
+ public final int[] offsets;
+ public final int[] literalLengths;
+ public final int[] matchLengths;
+ public int sequenceCount;
+
+ public final byte[] literalLengthCodes;
+ public final byte[] matchLengthCodes;
+ public final byte[] offsetCodes;
+
+ public LongField longLengthField;
+ public int longLengthPosition;
+
+ public enum LongField
+ {
+ LITERAL, MATCH
+ }
+
+ private static final byte[] LITERAL_LENGTH_CODE = {0, 1, 2, 3, 4, 5, 6, 7,
+ 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 16, 17, 17, 18, 18, 19, 19,
+ 20, 20, 20, 20, 21, 21, 21, 21,
+ 22, 22, 22, 22, 22, 22, 22, 22,
+ 23, 23, 23, 23, 23, 23, 23, 23,
+ 24, 24, 24, 24, 24, 24, 24, 24,
+ 24, 24, 24, 24, 24, 24, 24, 24};
+
+ private static final byte[] MATCH_LENGTH_CODE = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+ 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 36, 36, 37, 37, 37, 37,
+ 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39,
+ 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40,
+ 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41,
+ 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42,
+ 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42};
+
+ public SequenceStore(int blockSize, int maxSequences)
+ {
+ offsets = new int[maxSequences];
+ literalLengths = new int[maxSequences];
+ matchLengths = new int[maxSequences];
+
+ literalLengthCodes = new byte[maxSequences];
+ matchLengthCodes = new byte[maxSequences];
+ offsetCodes = new byte[maxSequences];
+
+ literalsBuffer = new byte[blockSize];
+
+ reset();
+ }
+
+ public void appendLiterals(Object inputBase, long inputAddress, int inputSize)
+ {
+ UNSAFE.copyMemory(inputBase, inputAddress, literalsBuffer, ARRAY_BYTE_BASE_OFFSET + literalsLength, inputSize);
+ literalsLength += inputSize;
+ }
+
+ public void storeSequence(Object literalBase, long literalAddress, int literalLength, int offsetCode, int matchLengthBase)
+ {
+ long input = literalAddress;
+ long output = ARRAY_BYTE_BASE_OFFSET + literalsLength;
+ int copied = 0;
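+ // copy literals 8 bytes at a time; this rounds literalLength up to a multiple of 8,
+ // so it can write up to 7 bytes beyond the requested length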
+ do {
+ UNSAFE.putLong(literalsBuffer, output, UNSAFE.getLong(literalBase, input));
+ input += SIZE_OF_LONG;
+ output += SIZE_OF_LONG;
+ copied += SIZE_OF_LONG;
+ }
+ while (copied < literalLength);
+
+ literalsLength += literalLength;
+
+ if (literalLength > 65535) {
+ longLengthField = LongField.LITERAL;
+ longLengthPosition = sequenceCount;
+ }
+ literalLengths[sequenceCount] = literalLength;
+
+ offsets[sequenceCount] = offsetCode + 1;
+
+ if (matchLengthBase > 65535) {
+ longLengthField = LongField.MATCH;
+ longLengthPosition = sequenceCount;
+ }
+
+ matchLengths[sequenceCount] = matchLengthBase;
+
+ sequenceCount++;
+ }
+
+ public void reset()
+ {
+ literalsLength = 0;
+ sequenceCount = 0;
+ longLengthField = null;
+ }
+
+ public void generateCodes()
+ {
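+ // Length codes come from the lookup tables above (or highestBit for large values);
+ // the offset code is just the position of the highest set bit, with the remaining
+ // low bits emitted later as extra bits by the sequence encoder.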
+ for (int i = 0; i < sequenceCount; ++i) {
+ literalLengthCodes[i] = (byte) literalLengthToCode(literalLengths[i]);
+ offsetCodes[i] = (byte) Util.highestBit(offsets[i]);
+ matchLengthCodes[i] = (byte) matchLengthToCode(matchLengths[i]);
+ }
+
+ if (longLengthField == LongField.LITERAL) {
+ literalLengthCodes[longLengthPosition] = Constants.MAX_LITERALS_LENGTH_SYMBOL;
+ }
+ if (longLengthField == LongField.MATCH) {
+ matchLengthCodes[longLengthPosition] = Constants.MAX_MATCH_LENGTH_SYMBOL;
+ }
+ }
+
+ private static int literalLengthToCode(int literalLength)
+ {
+ if (literalLength >= 64) {
+ return Util.highestBit(literalLength) + 19;
+ }
+ else {
+ return LITERAL_LENGTH_CODE[literalLength];
+ }
+ }
+
+ /*
+ * matchLengthBase = matchLength - MINMATCH
+ * (that's how it's stored in SequenceStore)
+ */
+ private static int matchLengthToCode(int matchLengthBase)
+ {
+ if (matchLengthBase >= 128) {
+ return Util.highestBit(matchLengthBase) + 36;
+ }
+ else {
+ return MATCH_LENGTH_CODE[matchLengthBase];
+ }
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/UnsafeUtil.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/UnsafeUtil.java
new file mode 100644
index 00000000000..decde678321
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/UnsafeUtil.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import ai.vespa.airlift.compress.IncompatibleJvmException;
+import sun.misc.Unsafe;
+
+import java.lang.reflect.Field;
+import java.nio.Buffer;
+import java.nio.ByteOrder;
+
+import static java.lang.String.format;
+
+final class UnsafeUtil
+{
+ public static final Unsafe UNSAFE;
+ private static final long ADDRESS_OFFSET;
+
+ private UnsafeUtil() {}
+
+ static {
+ ByteOrder order = ByteOrder.nativeOrder();
+ if (!order.equals(ByteOrder.LITTLE_ENDIAN)) {
+ throw new IncompatibleJvmException(format("Zstandard requires a little endian platform (found %s)", order));
+ }
+
+ try {
+ Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
+ theUnsafe.setAccessible(true);
+ UNSAFE = (Unsafe) theUnsafe.get(null);
+ }
+ catch (Exception e) {
+ throw new IncompatibleJvmException("Zstandard requires access to sun.misc.Unsafe");
+ }
+
+ try {
+ // fetch the address field for direct buffers
+ ADDRESS_OFFSET = UNSAFE.objectFieldOffset(Buffer.class.getDeclaredField("address"));
+ }
+ catch (NoSuchFieldException e) {
+ throw new IncompatibleJvmException("Zstandard requires access to java.nio.Buffer raw address field");
+ }
+ }
+
+ public static long getAddress(Buffer buffer)
+ {
+ if (!buffer.isDirect()) {
+ throw new IllegalArgumentException("buffer is not direct");
+ }
+
+ return UNSAFE.getLong(buffer, ADDRESS_OFFSET);
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Util.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Util.java
new file mode 100644
index 00000000000..d0e622f02c9
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Util.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import ai.vespa.airlift.compress.MalformedInputException;
+
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+
+final class Util
+{
+ private Util()
+ {
+ }
+
+ public static int highestBit(int value)
+ {
+ return 31 - Integer.numberOfLeadingZeros(value);
+ }
+
+ public static boolean isPowerOf2(int value)
+ {
+ return (value & (value - 1)) == 0;
+ }
+
+ public static int mask(int bits)
+ {
+ return (1 << bits) - 1;
+ }
+
+ public static void verify(boolean condition, long offset, String reason)
+ {
+ if (!condition) {
+ throw new MalformedInputException(offset, reason);
+ }
+ }
+
+ public static void checkArgument(boolean condition, String reason)
+ {
+ if (!condition) {
+ throw new IllegalArgumentException(reason);
+ }
+ }
+
+ public static void checkState(boolean condition, String reason)
+ {
+ if (!condition) {
+ throw new IllegalStateException(reason);
+ }
+ }
+
+ public static MalformedInputException fail(long offset, String reason)
+ {
+ throw new MalformedInputException(offset, reason);
+ }
+
+ public static int cycleLog(int hashLog, CompressionParameters.Strategy strategy)
+ {
+ int cycleLog = hashLog;
+ if (strategy == CompressionParameters.Strategy.BTLAZY2 || strategy == CompressionParameters.Strategy.BTOPT || strategy == CompressionParameters.Strategy.BTULTRA) {
+ cycleLog = hashLog - 1;
+ }
+ return cycleLog;
+ }
+
+ public static void put24BitLittleEndian(Object outputBase, long outputAddress, int value)
+ {
+ UNSAFE.putShort(outputBase, outputAddress, (short) value);
+ UNSAFE.putByte(outputBase, outputAddress + SIZE_OF_SHORT, (byte) (value >>> Short.SIZE));
+ }
+
+ // provides the minimum logSize to safely represent a distribution
+ public static int minTableLog(int inputSize, int maxSymbolValue)
+ {
+ if (inputSize <= 1) {
+ throw new IllegalArgumentException("Not supported. RLE should be used instead"); // TODO
+ }
+
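+ // e.g. inputSize = 1000, maxSymbolValue = 35: minBitsSrc = 10, minBitsSymbols = 7 => 7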
+ int minBitsSrc = highestBit((inputSize - 1)) + 1;
+ int minBitsSymbols = highestBit(maxSymbolValue) + 2;
+ return Math.min(minBitsSrc, minBitsSymbols);
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/XxHash64.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/XxHash64.java
new file mode 100644
index 00000000000..df2c869d11b
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/XxHash64.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_LONG;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static java.lang.Long.rotateLeft;
+import static java.lang.Math.min;
+import static java.lang.String.format;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
+// forked from https://github.com/airlift/slice
+final class XxHash64
+{
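+ // Streaming xxHash64: input is consumed in 32-byte stripes across four 64-bit lanes
+ // (v1..v4); partial input is staged in a 32-byte buffer, and hash() merges the lanes,
+ // folds in the tail, and applies the final avalanche.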
+ private static final long PRIME64_1 = 0x9E3779B185EBCA87L;
+ private static final long PRIME64_2 = 0xC2B2AE3D27D4EB4FL;
+ private static final long PRIME64_3 = 0x165667B19E3779F9L;
+ private static final long PRIME64_4 = 0x85EBCA77C2B2AE63L;
+ private static final long PRIME64_5 = 0x27D4EB2F165667C5L;
+
+ private static final long DEFAULT_SEED = 0;
+
+ private final long seed;
+
+ private static final long BUFFER_ADDRESS = ARRAY_BYTE_BASE_OFFSET;
+ private final byte[] buffer = new byte[32];
+ private int bufferSize;
+
+ private long bodyLength;
+
+ private long v1;
+ private long v2;
+ private long v3;
+ private long v4;
+
+ public static long hash(long seed, Object base, long address, int length)
+ {
+ XxHash64 hasher = new XxHash64(seed);
+ hasher.updateHash(base, address, length);
+ return hasher.hash();
+ }
+
+ public XxHash64()
+ {
+ this(DEFAULT_SEED);
+ }
+
+ public XxHash64(long seed)
+ {
+ this.seed = seed;
+ this.v1 = seed + PRIME64_1 + PRIME64_2;
+ this.v2 = seed + PRIME64_2;
+ this.v3 = seed;
+ this.v4 = seed - PRIME64_1;
+ }
+
+ public XxHash64 update(byte[] data)
+ {
+ return update(data, 0, data.length);
+ }
+
+ public XxHash64 update(byte[] data, int offset, int length)
+ {
+ checkPositionIndexes(offset, offset + length, data.length);
+ updateHash(data, ARRAY_BYTE_BASE_OFFSET + offset, length);
+ return this;
+ }
+
+ public long hash()
+ {
+ long hash;
+ if (bodyLength > 0) {
+ hash = computeBody();
+ }
+ else {
+ hash = seed + PRIME64_5;
+ }
+
+ hash += bodyLength + bufferSize;
+
+ return updateTail(hash, buffer, BUFFER_ADDRESS, 0, bufferSize);
+ }
+
+ private static String badPositionIndex(long index, long size, String desc)
+ {
+ if (index < 0) {
+ return format("%s (%s) must not be negative", desc, index);
+ }
+ else if (size < 0) {
+ throw new IllegalArgumentException("negative size: " + size);
+ }
+ else { // index > size
+ return format("%s (%s) must not be greater than size (%s)", desc, index, size);
+ }
+ }
+
+ private static String badPositionIndexes(int start, int end, int size)
+ {
+ if (start < 0 || start > size) {
+ return badPositionIndex(start, size, "start index");
+ }
+ if (end < 0 || end > size) {
+ return badPositionIndex(end, size, "end index");
+ }
+ // end < start
+ return format("end index (%s) must not be less than start index (%s)", end, start);
+ }
+
+ private static void checkPositionIndexes(int start, int end, int size)
+ {
+ // Carefully optimized for execution by hotspot
+ if (start < 0 || end < start || end > size) {
+ throw new IndexOutOfBoundsException(badPositionIndexes(start, end, size));
+ }
+ }
+
+ private long computeBody()
+ {
+ long hash = rotateLeft(v1, 1) + rotateLeft(v2, 7) + rotateLeft(v3, 12) + rotateLeft(v4, 18);
+
+ hash = update(hash, v1);
+ hash = update(hash, v2);
+ hash = update(hash, v3);
+ hash = update(hash, v4);
+
+ return hash;
+ }
+
+ private void updateHash(Object base, long address, int length)
+ {
+ if (bufferSize > 0) {
+ int available = min(32 - bufferSize, length);
+
+ UNSAFE.copyMemory(base, address, buffer, BUFFER_ADDRESS + bufferSize, available);
+
+ bufferSize += available;
+ address += available;
+ length -= available;
+
+ if (bufferSize == 32) {
+ updateBody(buffer, BUFFER_ADDRESS, bufferSize);
+ bufferSize = 0;
+ }
+ }
+
+ if (length >= 32) {
+ int index = updateBody(base, address, length);
+ address += index;
+ length -= index;
+ }
+
+ if (length > 0) {
+ UNSAFE.copyMemory(base, address, buffer, BUFFER_ADDRESS, length);
+ bufferSize = length;
+ }
+ }
+
+ private int updateBody(Object base, long address, int length)
+ {
+ int remaining = length;
+ while (remaining >= 32) {
+ v1 = mix(v1, UNSAFE.getLong(base, address));
+ v2 = mix(v2, UNSAFE.getLong(base, address + 8));
+ v3 = mix(v3, UNSAFE.getLong(base, address + 16));
+ v4 = mix(v4, UNSAFE.getLong(base, address + 24));
+
+ address += 32;
+ remaining -= 32;
+ }
+
+ int index = length - remaining;
+ bodyLength += index;
+ return index;
+ }
+
+ public static long hash(long value)
+ {
+ long hash = DEFAULT_SEED + PRIME64_5 + SIZE_OF_LONG;
+ hash = updateTail(hash, value);
+ hash = finalShuffle(hash);
+
+ return hash;
+ }
+
+ private static long updateTail(long hash, Object base, long address, int index, int length)
+ {
+ while (index <= length - 8) {
+ hash = updateTail(hash, UNSAFE.getLong(base, address + index));
+ index += 8;
+ }
+
+ if (index <= length - 4) {
+ hash = updateTail(hash, UNSAFE.getInt(base, address + index));
+ index += 4;
+ }
+
+ while (index < length) {
+ hash = updateTail(hash, UNSAFE.getByte(base, address + index));
+ index++;
+ }
+
+ hash = finalShuffle(hash);
+
+ return hash;
+ }
+
+ private static long updateBody(long seed, Object base, long address, int length)
+ {
+ long v1 = seed + PRIME64_1 + PRIME64_2;
+ long v2 = seed + PRIME64_2;
+ long v3 = seed;
+ long v4 = seed - PRIME64_1;
+
+ int remaining = length;
+ while (remaining >= 32) {
+ v1 = mix(v1, UNSAFE.getLong(base, address));
+ v2 = mix(v2, UNSAFE.getLong(base, address + 8));
+ v3 = mix(v3, UNSAFE.getLong(base, address + 16));
+ v4 = mix(v4, UNSAFE.getLong(base, address + 24));
+
+ address += 32;
+ remaining -= 32;
+ }
+
+ long hash = rotateLeft(v1, 1) + rotateLeft(v2, 7) + rotateLeft(v3, 12) + rotateLeft(v4, 18);
+
+ hash = update(hash, v1);
+ hash = update(hash, v2);
+ hash = update(hash, v3);
+ hash = update(hash, v4);
+
+ return hash;
+ }
+
+ private static long mix(long current, long value)
+ {
+ return rotateLeft(current + value * PRIME64_2, 31) * PRIME64_1;
+ }
+
+ private static long update(long hash, long value)
+ {
+ long temp = hash ^ mix(0, value);
+ return temp * PRIME64_1 + PRIME64_4;
+ }
+
+ private static long updateTail(long hash, long value)
+ {
+ long temp = hash ^ mix(0, value);
+ return rotateLeft(temp, 27) * PRIME64_1 + PRIME64_4;
+ }
+
+ private static long updateTail(long hash, int value)
+ {
+ long unsigned = value & 0xFFFF_FFFFL;
+ long temp = hash ^ (unsigned * PRIME64_1);
+ return rotateLeft(temp, 23) * PRIME64_2 + PRIME64_3;
+ }
+
+ private static long updateTail(long hash, byte value)
+ {
+ int unsigned = value & 0xFF;
+ long temp = hash ^ (unsigned * PRIME64_5);
+ return rotateLeft(temp, 11) * PRIME64_1;
+ }
+
+ private static long finalShuffle(long hash)
+ {
+ hash ^= hash >>> 33;
+ hash *= PRIME64_2;
+ hash ^= hash >>> 29;
+ hash *= PRIME64_3;
+ hash ^= hash >>> 32;
+ return hash;
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdBlockDecompressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdBlockDecompressor.java
new file mode 100644
index 00000000000..08adf88cfa6
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdBlockDecompressor.java
@@ -0,0 +1,810 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.util.Arrays;
+
+import static ai.vespa.airlift.zstd.BitInputStream.peekBits;
+import static ai.vespa.airlift.zstd.Constants.COMPRESSED_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.COMPRESSED_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.DEFAULT_MAX_OFFSET_CODE_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.LITERALS_LENGTH_BITS;
+import static ai.vespa.airlift.zstd.Constants.LITERAL_LENGTH_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Constants.LONG_NUMBER_OF_SEQUENCES;
+import static ai.vespa.airlift.zstd.Constants.MATCH_LENGTH_BITS;
+import static ai.vespa.airlift.zstd.Constants.MATCH_LENGTH_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Constants.MAX_BLOCK_SIZE;
+import static ai.vespa.airlift.zstd.Constants.MAX_LITERALS_LENGTH_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.MAX_MATCH_LENGTH_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.MIN_BLOCK_SIZE;
+import static ai.vespa.airlift.zstd.Constants.MIN_SEQUENCES_SIZE;
+import static ai.vespa.airlift.zstd.Constants.OFFSET_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Constants.RAW_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.RAW_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.RLE_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.RLE_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_BASIC;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_COMPRESSED;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_REPEAT;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_RLE;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_INT;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_LONG;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
+import static ai.vespa.airlift.zstd.Constants.TREELESS_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.fail;
+import static ai.vespa.airlift.zstd.Util.mask;
+import static ai.vespa.airlift.zstd.Util.verify;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
+/**
+ * Handles decompression of all blocks in a single frame.
+ */
+class ZstdBlockDecompressor
+{
+ private static final int[] DEC_32_TABLE = {4, 1, 2, 1, 4, 4, 4, 4};
+ private static final int[] DEC_64_TABLE = {0, 0, 0, -1, 0, 1, 2, 3};
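+ // lookup tables that drive the overlapping-copy special case for matches whose
+ // offset is smaller than 8 bytes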
+
+ private static final int MAX_WINDOW_SIZE = 1 << 23;
+
+ private static final int[] LITERALS_LENGTH_BASE = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 18, 20, 22, 24, 28, 32, 40, 48, 64, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000,
+ 0x2000, 0x4000, 0x8000, 0x10000};
+
+ private static final int[] MATCH_LENGTH_BASE = {
+ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+ 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
+ 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803,
+ 0x1003, 0x2003, 0x4003, 0x8003, 0x10003};
+
+ private static final int[] OFFSET_CODES_BASE = {
+ 0, 1, 1, 5, 0xD, 0x1D, 0x3D, 0x7D,
+ 0xFD, 0x1FD, 0x3FD, 0x7FD, 0xFFD, 0x1FFD, 0x3FFD, 0x7FFD,
+ 0xFFFD, 0x1FFFD, 0x3FFFD, 0x7FFFD, 0xFFFFD, 0x1FFFFD, 0x3FFFFD, 0x7FFFFD,
+ 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD, 0xFFFFFFD};
+
+ private static final FiniteStateEntropy.Table DEFAULT_LITERALS_LENGTH_TABLE = new FiniteStateEntropy.Table(
+ 6,
+ new int[] {
+ 0, 16, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 16, 32, 0, 0, 48, 16, 32, 32, 32,
+ 32, 32, 32, 32, 32, 0, 32, 32, 32, 32, 32, 32, 0, 0, 0, 0},
+ new byte[] {
+ 0, 0, 1, 3, 4, 6, 7, 9, 10, 12, 14, 16, 18, 19, 21, 22, 24, 25, 26, 27, 29, 31, 0, 1, 2, 4, 5, 7, 8, 10, 11, 13, 16, 17, 19, 20, 22, 23, 25, 25, 26, 28, 30, 0,
+ 1, 2, 3, 5, 6, 8, 9, 11, 12, 15, 17, 18, 20, 21, 23, 24, 35, 34, 33, 32},
+ new byte[] {
+ 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+ 6, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6});
+
+ private static final FiniteStateEntropy.Table DEFAULT_OFFSET_CODES_TABLE = new FiniteStateEntropy.Table(
+ 5,
+ new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 16, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0},
+ new byte[] {0, 6, 9, 15, 21, 3, 7, 12, 18, 23, 5, 8, 14, 20, 2, 7, 11, 17, 22, 4, 8, 13, 19, 1, 6, 10, 16, 28, 27, 26, 25, 24},
+ new byte[] {5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 5, 5, 5});
+
+ private static final FiniteStateEntropy.Table DEFAULT_MATCH_LENGTH_TABLE = new FiniteStateEntropy.Table(
+ 6,
+ new int[] {
+ 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 32, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 48, 16, 32, 32, 32, 32,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ new byte[] {
+ 0, 1, 2, 3, 5, 6, 8, 10, 13, 16, 19, 22, 25, 28, 31, 33, 35, 37, 39, 41, 43, 45, 1, 2, 3, 4, 6, 7, 9, 12, 15, 18, 21, 24, 27, 30, 32, 34, 36, 38, 40, 42, 44, 1,
+ 1, 2, 4, 5, 7, 8, 11, 14, 17, 20, 23, 26, 29, 52, 51, 50, 49, 48, 47, 46},
+ new byte[] {
+ 6, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6});
+
+ private final byte[] literals = new byte[MAX_BLOCK_SIZE + SIZE_OF_LONG]; // extra space to allow for long-at-a-time copy
+
+ // current buffer containing literals
+ private Object literalsBase;
+ private long literalsAddress;
+ private long literalsLimit;
+
+ private final int[] previousOffsets = new int[3];
+
+ private final FiniteStateEntropy.Table literalsLengthTable = new FiniteStateEntropy.Table(LITERAL_LENGTH_TABLE_LOG);
+ private final FiniteStateEntropy.Table offsetCodesTable = new FiniteStateEntropy.Table(OFFSET_TABLE_LOG);
+ private final FiniteStateEntropy.Table matchLengthTable = new FiniteStateEntropy.Table(MATCH_LENGTH_TABLE_LOG);
+
+ private FiniteStateEntropy.Table currentLiteralsLengthTable;
+ private FiniteStateEntropy.Table currentOffsetCodesTable;
+ private FiniteStateEntropy.Table currentMatchLengthTable;
+
+ private final Huffman huffman = new Huffman();
+ private final FseTableReader fse = new FseTableReader();
+
+ private final FrameHeader frameHeader;
+
+ public ZstdBlockDecompressor(FrameHeader frameHeader)
+ {
+ this.frameHeader = frameHeader;
+
+ previousOffsets[0] = 1;
+ previousOffsets[1] = 4;
+ previousOffsets[2] = 8;
+ }
+
+ int decompressBlock(
+ int blockType,
+ int blockSize,
+ final Object inputBase,
+ final long inputAddress,
+ final long inputLimit,
+ final Object outputBase,
+ final long outputAddress,
+ final long outputLimit)
+ {
+ int decodedSize;
+ switch (blockType) {
+ case RAW_BLOCK:
+ verify(inputAddress + blockSize <= inputLimit, inputAddress, "Not enough input bytes");
+ decodedSize = decodeRawBlock(inputBase, inputAddress, blockSize, outputBase, outputAddress, outputLimit);
+ break;
+ case RLE_BLOCK:
+ verify(inputAddress + 1 <= inputLimit, inputAddress, "Not enough input bytes");
+ decodedSize = decodeRleBlock(blockSize, inputBase, inputAddress, outputBase, outputAddress, outputLimit);
+ break;
+ case COMPRESSED_BLOCK:
+ verify(inputAddress + blockSize <= inputLimit, inputAddress, "Not enough input bytes");
+ decodedSize = decodeCompressedBlock(inputBase, inputAddress, blockSize, outputBase, outputAddress, outputLimit, frameHeader.windowSize, outputAddress);
+ break;
+ default:
+ throw fail(inputAddress, "Invalid block type");
+ }
+ return decodedSize;
+ }
+
+ static int decodeRawBlock(Object inputBase, long inputAddress, int blockSize, Object outputBase, long outputAddress, long outputLimit)
+ {
+ verify(outputAddress + blockSize <= outputLimit, inputAddress, "Output buffer too small");
+
+ UNSAFE.copyMemory(inputBase, inputAddress, outputBase, outputAddress, blockSize);
+ return blockSize;
+ }
+
+ static int decodeRleBlock(int size, Object inputBase, long inputAddress, Object outputBase, long outputAddress, long outputLimit)
+ {
+ verify(outputAddress + size <= outputLimit, inputAddress, "Output buffer too small");
+
+ long output = outputAddress;
+ long value = UNSAFE.getByte(inputBase, inputAddress) & 0xFFL;
+
+ int remaining = size;
+ if (remaining >= SIZE_OF_LONG) {
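+ // broadcast the RLE byte into all eight lanes of a long so the loop below
+ // writes eight output bytes per iteration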
+ long packed = value
+ | (value << 8)
+ | (value << 16)
+ | (value << 24)
+ | (value << 32)
+ | (value << 40)
+ | (value << 48)
+ | (value << 56);
+
+ do {
+ UNSAFE.putLong(outputBase, output, packed);
+ output += SIZE_OF_LONG;
+ remaining -= SIZE_OF_LONG;
+ }
+ while (remaining >= SIZE_OF_LONG);
+ }
+
+ for (int i = 0; i < remaining; i++) {
+ UNSAFE.putByte(outputBase, output, (byte) value);
+ output++;
+ }
+
+ return size;
+ }
+
+ @SuppressWarnings("fallthrough")
+ int decodeCompressedBlock(Object inputBase, final long inputAddress, int blockSize, Object outputBase, long outputAddress, long outputLimit, int windowSize, long outputAbsoluteBaseAddress)
+ {
+ long inputLimit = inputAddress + blockSize;
+ long input = inputAddress;
+
+        verify(blockSize <= MAX_BLOCK_SIZE, input, "Block exceeds maximum size");
+ verify(blockSize >= MIN_BLOCK_SIZE, input, "Compressed block size too small");
+
+ // decode literals
+ int literalsBlockType = UNSAFE.getByte(inputBase, input) & 0b11;
+
+ switch (literalsBlockType) {
+ case RAW_LITERALS_BLOCK: {
+ input += decodeRawLiterals(inputBase, input, inputLimit);
+ break;
+ }
+ case RLE_LITERALS_BLOCK: {
+ input += decodeRleLiterals(inputBase, input, blockSize);
+ break;
+ }
+ case TREELESS_LITERALS_BLOCK:
+ verify(huffman.isLoaded(), input, "Dictionary is corrupted");
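+                // fall through: a treeless block reuses the Huffman table loaded by an earlier compressed block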
+ case COMPRESSED_LITERALS_BLOCK: {
+ input += decodeCompressedLiterals(inputBase, input, blockSize, literalsBlockType);
+ break;
+ }
+ default:
+ throw fail(input, "Invalid literals block encoding type");
+ }
+
+ verify(windowSize <= MAX_WINDOW_SIZE, input, "Window size too large (not yet supported)");
+
+ return decompressSequences(
+ inputBase, input, inputAddress + blockSize,
+ outputBase, outputAddress, outputLimit,
+ literalsBase, literalsAddress, literalsLimit,
+ outputAbsoluteBaseAddress);
+ }
+
+ private int decompressSequences(
+ final Object inputBase, final long inputAddress, final long inputLimit,
+ final Object outputBase, final long outputAddress, final long outputLimit,
+ final Object literalsBase, final long literalsAddress, final long literalsLimit,
+ long outputAbsoluteBaseAddress)
+ {
+ final long fastOutputLimit = outputLimit - SIZE_OF_LONG;
+ final long fastMatchOutputLimit = fastOutputLimit - SIZE_OF_LONG;
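+        // positions below fastOutputLimit may be written 8 bytes at a time with over-copy;
+        // fastMatchOutputLimit additionally leaves room for the 8-byte match head copy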
+
+ long input = inputAddress;
+ long output = outputAddress;
+
+ long literalsInput = literalsAddress;
+
+ int size = (int) (inputLimit - inputAddress);
+ verify(size >= MIN_SEQUENCES_SIZE, input, "Not enough input bytes");
+
+        // decode sequences section header: a 1-3 byte sequence count followed by one byte of symbol compression modes
+ int sequenceCount = UNSAFE.getByte(inputBase, input++) & 0xFF;
+ if (sequenceCount != 0) {
+ if (sequenceCount == 255) {
+ verify(input + SIZE_OF_SHORT <= inputLimit, input, "Not enough input bytes");
+ sequenceCount = (UNSAFE.getShort(inputBase, input) & 0xFFFF) + LONG_NUMBER_OF_SEQUENCES;
+ input += SIZE_OF_SHORT;
+ }
+ else if (sequenceCount > 127) {
+ verify(input < inputLimit, input, "Not enough input bytes");
+ sequenceCount = ((sequenceCount - 128) << 8) + (UNSAFE.getByte(inputBase, input++) & 0xFF);
+ }
+
+ verify(input + SIZE_OF_INT <= inputLimit, input, "Not enough input bytes");
+
+ byte type = UNSAFE.getByte(inputBase, input++);
+
+ int literalsLengthType = (type & 0xFF) >>> 6;
+ int offsetCodesType = (type >>> 4) & 0b11;
+ int matchLengthType = (type >>> 2) & 0b11;
+
+ input = computeLiteralsTable(literalsLengthType, inputBase, input, inputLimit);
+ input = computeOffsetsTable(offsetCodesType, inputBase, input, inputLimit);
+ input = computeMatchLengthTable(matchLengthType, inputBase, input, inputLimit);
+
+ // decompress sequences
+ BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, input, inputLimit);
+ initializer.initialize();
+ int bitsConsumed = initializer.getBitsConsumed();
+ long bits = initializer.getBits();
+ long currentAddress = initializer.getCurrentAddress();
+
+ FiniteStateEntropy.Table currentLiteralsLengthTable = this.currentLiteralsLengthTable;
+ FiniteStateEntropy.Table currentOffsetCodesTable = this.currentOffsetCodesTable;
+ FiniteStateEntropy.Table currentMatchLengthTable = this.currentMatchLengthTable;
+
+ int literalsLengthState = (int) peekBits(bitsConsumed, bits, currentLiteralsLengthTable.log2Size);
+ bitsConsumed += currentLiteralsLengthTable.log2Size;
+
+ int offsetCodesState = (int) peekBits(bitsConsumed, bits, currentOffsetCodesTable.log2Size);
+ bitsConsumed += currentOffsetCodesTable.log2Size;
+
+ int matchLengthState = (int) peekBits(bitsConsumed, bits, currentMatchLengthTable.log2Size);
+ bitsConsumed += currentMatchLengthTable.log2Size;
+
+ int[] previousOffsets = this.previousOffsets;
+
+ byte[] literalsLengthNumbersOfBits = currentLiteralsLengthTable.numberOfBits;
+ int[] literalsLengthNewStates = currentLiteralsLengthTable.newState;
+ byte[] literalsLengthSymbols = currentLiteralsLengthTable.symbol;
+
+ byte[] matchLengthNumbersOfBits = currentMatchLengthTable.numberOfBits;
+ int[] matchLengthNewStates = currentMatchLengthTable.newState;
+ byte[] matchLengthSymbols = currentMatchLengthTable.symbol;
+
+ byte[] offsetCodesNumbersOfBits = currentOffsetCodesTable.numberOfBits;
+ int[] offsetCodesNewStates = currentOffsetCodesTable.newState;
+ byte[] offsetCodesSymbols = currentOffsetCodesTable.symbol;
+
+ while (sequenceCount > 0) {
+ sequenceCount--;
+
+ BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
+ loader.load();
+ bitsConsumed = loader.getBitsConsumed();
+ bits = loader.getBits();
+ currentAddress = loader.getCurrentAddress();
+ if (loader.isOverflow()) {
+ verify(sequenceCount == 0, input, "Not all sequences were consumed");
+ break;
+ }
+
+ // decode sequence
+ int literalsLengthCode = literalsLengthSymbols[literalsLengthState];
+ int matchLengthCode = matchLengthSymbols[matchLengthState];
+ int offsetCode = offsetCodesSymbols[offsetCodesState];
+
+ int literalsLengthBits = LITERALS_LENGTH_BITS[literalsLengthCode];
+ int matchLengthBits = MATCH_LENGTH_BITS[matchLengthCode];
+ int offsetBits = offsetCode;
+
+ int offset = OFFSET_CODES_BASE[offsetCode];
+ if (offsetCode > 0) {
+ offset += peekBits(bitsConsumed, bits, offsetBits);
+ bitsConsumed += offsetBits;
+ }
+
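+                // offset codes 0 and 1 decode to the values 0..3, which select from the
+                // repeated-offset history instead of encoding a distance directly; a literals
+                // length of 0 shifts which history slot each value refers to (RFC 8478)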
+ if (offsetCode <= 1) {
+ if (literalsLengthCode == 0) {
+ offset++;
+ }
+
+ if (offset != 0) {
+ int temp;
+ if (offset == 3) {
+ temp = previousOffsets[0] - 1;
+ }
+ else {
+ temp = previousOffsets[offset];
+ }
+
+ if (temp == 0) {
+ temp = 1;
+ }
+
+ if (offset != 1) {
+ previousOffsets[2] = previousOffsets[1];
+ }
+ previousOffsets[1] = previousOffsets[0];
+ previousOffsets[0] = temp;
+
+ offset = temp;
+ }
+ else {
+ offset = previousOffsets[0];
+ }
+ }
+ else {
+ previousOffsets[2] = previousOffsets[1];
+ previousOffsets[1] = previousOffsets[0];
+ previousOffsets[0] = offset;
+ }
+
+ int matchLength = MATCH_LENGTH_BASE[matchLengthCode];
+ if (matchLengthCode > 31) {
+ matchLength += peekBits(bitsConsumed, bits, matchLengthBits);
+ bitsConsumed += matchLengthBits;
+ }
+
+ int literalsLength = LITERALS_LENGTH_BASE[literalsLengthCode];
+ if (literalsLengthCode > 15) {
+ literalsLength += peekBits(bitsConsumed, bits, literalsLengthBits);
+ bitsConsumed += literalsLengthBits;
+ }
+
+ int totalBits = literalsLengthBits + matchLengthBits + offsetBits;
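+                // reload the bit container if the extra bits consumed for this sequence could
+                // leave too few bits for the three state updates below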
+ if (totalBits > 64 - 7 - (LITERAL_LENGTH_TABLE_LOG + MATCH_LENGTH_TABLE_LOG + OFFSET_TABLE_LOG)) {
+ BitInputStream.Loader loader1 = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
+ loader1.load();
+
+ bitsConsumed = loader1.getBitsConsumed();
+ bits = loader1.getBits();
+ currentAddress = loader1.getCurrentAddress();
+ }
+
+ int numberOfBits;
+
+ numberOfBits = literalsLengthNumbersOfBits[literalsLengthState];
+ literalsLengthState = (int) (literalsLengthNewStates[literalsLengthState] + peekBits(bitsConsumed, bits, numberOfBits)); // <= 9 bits
+ bitsConsumed += numberOfBits;
+
+ numberOfBits = matchLengthNumbersOfBits[matchLengthState];
+ matchLengthState = (int) (matchLengthNewStates[matchLengthState] + peekBits(bitsConsumed, bits, numberOfBits)); // <= 9 bits
+ bitsConsumed += numberOfBits;
+
+ numberOfBits = offsetCodesNumbersOfBits[offsetCodesState];
+ offsetCodesState = (int) (offsetCodesNewStates[offsetCodesState] + peekBits(bitsConsumed, bits, numberOfBits)); // <= 8 bits
+ bitsConsumed += numberOfBits;
+
+ final long literalOutputLimit = output + literalsLength;
+ final long matchOutputLimit = literalOutputLimit + matchLength;
+
+ verify(matchOutputLimit <= outputLimit, input, "Output buffer too small");
+ long literalEnd = literalsInput + literalsLength;
+ verify(literalEnd <= literalsLimit, input, "Input is corrupted");
+
+ long matchAddress = literalOutputLimit - offset;
+ verify(matchAddress >= outputAbsoluteBaseAddress, input, "Input is corrupted");
+
+ if (literalOutputLimit > fastOutputLimit) {
+ executeLastSequence(outputBase, output, literalOutputLimit, matchOutputLimit, fastOutputLimit, literalsInput, matchAddress);
+ }
+ else {
+ // copy literals. literalOutputLimit <= fastOutputLimit, so we can copy
+ // long at a time with over-copy
+ output = copyLiterals(outputBase, literalsBase, output, literalsInput, literalOutputLimit);
+ copyMatch(outputBase, fastOutputLimit, output, offset, matchOutputLimit, matchAddress, matchLength, fastMatchOutputLimit);
+ }
+ output = matchOutputLimit;
+ literalsInput = literalEnd;
+ }
+ }
+
+ // last literal segment
+ output = copyLastLiteral(outputBase, literalsBase, literalsLimit, output, literalsInput);
+
+ return (int) (output - outputAddress);
+ }
+
+ private long copyLastLiteral(Object outputBase, Object literalsBase, long literalsLimit, long output, long literalsInput)
+ {
+ long lastLiteralsSize = literalsLimit - literalsInput;
+ UNSAFE.copyMemory(literalsBase, literalsInput, outputBase, output, lastLiteralsSize);
+ output += lastLiteralsSize;
+ return output;
+ }
+
+ private void copyMatch(Object outputBase, long fastOutputLimit, long output, int offset, long matchOutputLimit, long matchAddress, int matchLength, long fastMatchOutputLimit)
+ {
+ matchAddress = copyMatchHead(outputBase, output, offset, matchAddress);
+ output += SIZE_OF_LONG;
+ matchLength -= SIZE_OF_LONG; // first 8 bytes copied above
+
+ copyMatchTail(outputBase, fastOutputLimit, output, matchOutputLimit, matchAddress, matchLength, fastMatchOutputLimit);
+ }
+
+ private void copyMatchTail(Object outputBase, long fastOutputLimit, long output, long matchOutputLimit, long matchAddress, int matchLength, long fastMatchOutputLimit)
+ {
+ // fastMatchOutputLimit is just fastOutputLimit - SIZE_OF_LONG. It needs to be passed in so that it can be computed once for the
+ // whole invocation to decompressSequences. Otherwise, we'd just compute it here.
+ // If matchOutputLimit is < fastMatchOutputLimit, we know that even after the head (8 bytes) has been copied, the output pointer
+ // will be within fastOutputLimit, so it's safe to copy blindly before checking the limit condition
+ if (matchOutputLimit < fastMatchOutputLimit) {
+ int copied = 0;
+ do {
+ UNSAFE.putLong(outputBase, output, UNSAFE.getLong(outputBase, matchAddress));
+ output += SIZE_OF_LONG;
+ matchAddress += SIZE_OF_LONG;
+ copied += SIZE_OF_LONG;
+ }
+ while (copied < matchLength);
+ }
+ else {
+ while (output < fastOutputLimit) {
+ UNSAFE.putLong(outputBase, output, UNSAFE.getLong(outputBase, matchAddress));
+ matchAddress += SIZE_OF_LONG;
+ output += SIZE_OF_LONG;
+ }
+
+ while (output < matchOutputLimit) {
+ UNSAFE.putByte(outputBase, output++, UNSAFE.getByte(outputBase, matchAddress++));
+ }
+ }
+ }
+
+ private long copyMatchHead(Object outputBase, long output, int offset, long matchAddress)
+ {
+        // copy match
+        if (offset < 8) {
+            // overlapping match: copy the first 8 bytes piecewise, then adjust matchAddress so that
+            // source and destination end up 8 bytes apart for the long-at-a-time copies that follow
+ int increment32 = DEC_32_TABLE[offset];
+ int decrement64 = DEC_64_TABLE[offset];
+
+ UNSAFE.putByte(outputBase, output, UNSAFE.getByte(outputBase, matchAddress));
+ UNSAFE.putByte(outputBase, output + 1, UNSAFE.getByte(outputBase, matchAddress + 1));
+ UNSAFE.putByte(outputBase, output + 2, UNSAFE.getByte(outputBase, matchAddress + 2));
+ UNSAFE.putByte(outputBase, output + 3, UNSAFE.getByte(outputBase, matchAddress + 3));
+ matchAddress += increment32;
+
+ UNSAFE.putInt(outputBase, output + 4, UNSAFE.getInt(outputBase, matchAddress));
+ matchAddress -= decrement64;
+ }
+ else {
+ UNSAFE.putLong(outputBase, output, UNSAFE.getLong(outputBase, matchAddress));
+ matchAddress += SIZE_OF_LONG;
+ }
+ return matchAddress;
+ }
+
+ private long copyLiterals(Object outputBase, Object literalsBase, long output, long literalsInput, long literalOutputLimit)
+ {
+ long literalInput = literalsInput;
+ do {
+ UNSAFE.putLong(outputBase, output, UNSAFE.getLong(literalsBase, literalInput));
+ output += SIZE_OF_LONG;
+ literalInput += SIZE_OF_LONG;
+ }
+ while (output < literalOutputLimit);
+ output = literalOutputLimit; // correction in case we over-copied
+ return output;
+ }
+
+ private long computeMatchLengthTable(int matchLengthType, Object inputBase, long input, long inputLimit)
+ {
+ switch (matchLengthType) {
+ case SEQUENCE_ENCODING_RLE:
+ verify(input < inputLimit, input, "Not enough input bytes");
+
+ byte value = UNSAFE.getByte(inputBase, input++);
+ verify(value <= MAX_MATCH_LENGTH_SYMBOL, input, "Value exceeds expected maximum value");
+
+ FseTableReader.initializeRleTable(matchLengthTable, value);
+ currentMatchLengthTable = matchLengthTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ currentMatchLengthTable = DEFAULT_MATCH_LENGTH_TABLE;
+ break;
+ case SEQUENCE_ENCODING_REPEAT:
+ verify(currentMatchLengthTable != null, input, "Expected match length table to be present");
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ input += fse.readFseTable(matchLengthTable, inputBase, input, inputLimit, MAX_MATCH_LENGTH_SYMBOL, MATCH_LENGTH_TABLE_LOG);
+ currentMatchLengthTable = matchLengthTable;
+ break;
+ default:
+ throw fail(input, "Invalid match length encoding type");
+ }
+ return input;
+ }
+
+ private long computeOffsetsTable(int offsetCodesType, Object inputBase, long input, long inputLimit)
+ {
+ switch (offsetCodesType) {
+ case SEQUENCE_ENCODING_RLE:
+ verify(input < inputLimit, input, "Not enough input bytes");
+
+ byte value = UNSAFE.getByte(inputBase, input++);
+ verify(value <= DEFAULT_MAX_OFFSET_CODE_SYMBOL, input, "Value exceeds expected maximum value");
+
+ FseTableReader.initializeRleTable(offsetCodesTable, value);
+ currentOffsetCodesTable = offsetCodesTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ currentOffsetCodesTable = DEFAULT_OFFSET_CODES_TABLE;
+ break;
+ case SEQUENCE_ENCODING_REPEAT:
+                verify(currentOffsetCodesTable != null, input, "Expected offset codes table to be present");
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ input += fse.readFseTable(offsetCodesTable, inputBase, input, inputLimit, DEFAULT_MAX_OFFSET_CODE_SYMBOL, OFFSET_TABLE_LOG);
+ currentOffsetCodesTable = offsetCodesTable;
+ break;
+ default:
+ throw fail(input, "Invalid offset code encoding type");
+ }
+ return input;
+ }
+
+ private long computeLiteralsTable(int literalsLengthType, Object inputBase, long input, long inputLimit)
+ {
+ switch (literalsLengthType) {
+ case SEQUENCE_ENCODING_RLE:
+ verify(input < inputLimit, input, "Not enough input bytes");
+
+ byte value = UNSAFE.getByte(inputBase, input++);
+ verify(value <= MAX_LITERALS_LENGTH_SYMBOL, input, "Value exceeds expected maximum value");
+
+ FseTableReader.initializeRleTable(literalsLengthTable, value);
+ currentLiteralsLengthTable = literalsLengthTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ currentLiteralsLengthTable = DEFAULT_LITERALS_LENGTH_TABLE;
+ break;
+ case SEQUENCE_ENCODING_REPEAT:
+                verify(currentLiteralsLengthTable != null, input, "Expected literals length table to be present");
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ input += fse.readFseTable(literalsLengthTable, inputBase, input, inputLimit, MAX_LITERALS_LENGTH_SYMBOL, LITERAL_LENGTH_TABLE_LOG);
+ currentLiteralsLengthTable = literalsLengthTable;
+ break;
+ default:
+ throw fail(input, "Invalid literals length encoding type");
+ }
+ return input;
+ }
+
+ private void executeLastSequence(Object outputBase, long output, long literalOutputLimit, long matchOutputLimit, long fastOutputLimit, long literalInput, long matchAddress)
+ {
+ // copy literals
+ if (output < fastOutputLimit) {
+ // wild copy
+ do {
+ UNSAFE.putLong(outputBase, output, UNSAFE.getLong(literalsBase, literalInput));
+ output += SIZE_OF_LONG;
+ literalInput += SIZE_OF_LONG;
+ }
+ while (output < fastOutputLimit);
+
+ literalInput -= output - fastOutputLimit;
+ output = fastOutputLimit;
+ }
+
+ while (output < literalOutputLimit) {
+ UNSAFE.putByte(outputBase, output, UNSAFE.getByte(literalsBase, literalInput));
+ output++;
+ literalInput++;
+ }
+
+ // copy match
+ while (output < matchOutputLimit) {
+ UNSAFE.putByte(outputBase, output, UNSAFE.getByte(outputBase, matchAddress));
+ output++;
+ matchAddress++;
+ }
+ }
+
+ @SuppressWarnings("fallthrough")
+ private int decodeCompressedLiterals(Object inputBase, final long inputAddress, int blockSize, int literalsBlockType)
+ {
+ long input = inputAddress;
+ verify(blockSize >= 5, input, "Not enough input bytes");
+
+ // compressed
+ int compressedSize;
+ int uncompressedSize;
+ boolean singleStream = false;
+ int headerSize;
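+        // literals section header for compressed literals (RFC 8478, Literals_Section_Header):
+        // size formats 0 and 1 use a 3-byte header with 10-bit sizes, format 2 a 4-byte header
+        // with 14-bit sizes, and format 3 a 5-byte header with 18-bit sizes; format 0 additionally
+        // marks the literals as a single Huffman stream rather than four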
+ int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0b11;
+ switch (type) {
+ case 0:
+ singleStream = true;
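+                // fall through: size format 0 shares the 3-byte header layout with format 1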
+ case 1: {
+ int header = UNSAFE.getInt(inputBase, input);
+
+ headerSize = 3;
+ uncompressedSize = (header >>> 4) & mask(10);
+ compressedSize = (header >>> 14) & mask(10);
+ break;
+ }
+ case 2: {
+ int header = UNSAFE.getInt(inputBase, input);
+
+ headerSize = 4;
+ uncompressedSize = (header >>> 4) & mask(14);
+ compressedSize = (header >>> 18) & mask(14);
+ break;
+ }
+ case 3: {
+ // read 5 little-endian bytes
+ long header = UNSAFE.getByte(inputBase, input) & 0xFF |
+ (UNSAFE.getInt(inputBase, input + 1) & 0xFFFF_FFFFL) << 8;
+
+ headerSize = 5;
+ uncompressedSize = (int) ((header >>> 4) & mask(18));
+ compressedSize = (int) ((header >>> 22) & mask(18));
+ break;
+ }
+ default:
+ throw fail(input, "Invalid literals header size type");
+ }
+
+ verify(uncompressedSize <= MAX_BLOCK_SIZE, input, "Block exceeds maximum size");
+ verify(headerSize + compressedSize <= blockSize, input, "Input is corrupted");
+
+ input += headerSize;
+
+ long inputLimit = input + compressedSize;
+ if (literalsBlockType != TREELESS_LITERALS_BLOCK) {
+ input += huffman.readTable(inputBase, input, compressedSize);
+ }
+
+ literalsBase = literals;
+ literalsAddress = ARRAY_BYTE_BASE_OFFSET;
+ literalsLimit = ARRAY_BYTE_BASE_OFFSET + uncompressedSize;
+
+ if (singleStream) {
+ huffman.decodeSingleStream(inputBase, input, inputLimit, literals, literalsAddress, literalsLimit);
+ }
+ else {
+ huffman.decode4Streams(inputBase, input, inputLimit, literals, literalsAddress, literalsLimit);
+ }
+
+ return headerSize + compressedSize;
+ }
+
+ private int decodeRleLiterals(Object inputBase, final long inputAddress, int blockSize)
+ {
+ long input = inputAddress;
+ int outputSize;
+
+ int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0b11;
+ switch (type) {
+ case 0:
+ case 2:
+ outputSize = (UNSAFE.getByte(inputBase, input) & 0xFF) >>> 3;
+ input++;
+ break;
+ case 1:
+ outputSize = (UNSAFE.getShort(inputBase, input) & 0xFFFF) >>> 4;
+ input += 2;
+ break;
+ case 3:
+ // we need at least 4 bytes (3 for the header, 1 for the payload)
+ verify(blockSize >= SIZE_OF_INT, input, "Not enough input bytes");
+ outputSize = (UNSAFE.getInt(inputBase, input) & 0xFF_FFFF) >>> 4;
+ input += 3;
+ break;
+ default:
+ throw fail(input, "Invalid RLE literals header encoding type");
+ }
+
+ verify(outputSize <= MAX_BLOCK_SIZE, input, "Output exceeds maximum block size");
+
+ byte value = UNSAFE.getByte(inputBase, input++);
+ Arrays.fill(literals, 0, outputSize + SIZE_OF_LONG, value);
+
+ literalsBase = literals;
+ literalsAddress = ARRAY_BYTE_BASE_OFFSET;
+ literalsLimit = ARRAY_BYTE_BASE_OFFSET + outputSize;
+
+ return (int) (input - inputAddress);
+ }
+
+ private int decodeRawLiterals(Object inputBase, final long inputAddress, long inputLimit)
+ {
+ long input = inputAddress;
+ int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0b11;
+
+ int literalSize;
+ switch (type) {
+ case 0:
+ case 2:
+ literalSize = (UNSAFE.getByte(inputBase, input) & 0xFF) >>> 3;
+ input++;
+ break;
+ case 1:
+ literalSize = (UNSAFE.getShort(inputBase, input) & 0xFFFF) >>> 4;
+ input += 2;
+ break;
+ case 3:
+ // read 3 little-endian bytes
+ int header = ((UNSAFE.getByte(inputBase, input) & 0xFF) |
+ ((UNSAFE.getShort(inputBase, input + 1) & 0xFFFF) << 8));
+
+ literalSize = header >>> 4;
+ input += 3;
+ break;
+ default:
+ throw fail(input, "Invalid raw literals header encoding type");
+ }
+
+ verify(input + literalSize <= inputLimit, input, "Not enough input bytes");
+
+        // Point the literals at the input directly, but only if sequence decoding can safely read
+        // 8 bytes at a time past the literals. Otherwise, copy them into the internal buffer, which
+        // is padded to guarantee that.
+ if (literalSize > (inputLimit - input) - SIZE_OF_LONG) {
+ literalsBase = literals;
+ literalsAddress = ARRAY_BYTE_BASE_OFFSET;
+ literalsLimit = ARRAY_BYTE_BASE_OFFSET + literalSize;
+
+ UNSAFE.copyMemory(inputBase, input, literals, literalsAddress, literalSize);
+ Arrays.fill(literals, literalSize, literalSize + SIZE_OF_LONG, (byte) 0);
+ }
+ else {
+ literalsBase = inputBase;
+ literalsAddress = input;
+ literalsLimit = literalsAddress + literalSize;
+ }
+ input += literalSize;
+
+ return (int) (input - inputAddress);
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdCompressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdCompressor.java
new file mode 100644
index 00000000000..1624067f769
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdCompressor.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import ai.vespa.airlift.compress.Compressor;
+
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+
+import static ai.vespa.airlift.zstd.Constants.MAX_BLOCK_SIZE;
+import static ai.vespa.airlift.zstd.UnsafeUtil.getAddress;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
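+/**
+ * Compresses input as a single zstd frame.
+ * A minimal usage sketch of the one-shot byte[] API ({@code input} here is any byte array):
+ *
+ * <pre>{@code
+ *   ZstdCompressor compressor = new ZstdCompressor();
+ *   byte[] output = new byte[compressor.maxCompressedLength(input.length)];
+ *   int written = compressor.compress(input, 0, input.length, output, 0, output.length);
+ * }</pre>
+ */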
+public class ZstdCompressor
+ implements Compressor
+{
+ @Override
+ public int maxCompressedLength(int uncompressedSize)
+ {
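+        // worst-case bound, matching zstd's ZSTD_COMPRESSBOUND: the input size plus a 1/256
+        // margin, plus extra room for block headers when the input is smaller than one full block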
+ int result = uncompressedSize + (uncompressedSize >>> 8);
+
+ if (uncompressedSize < MAX_BLOCK_SIZE) {
+ result += (MAX_BLOCK_SIZE - uncompressedSize) >>> 11;
+ }
+
+ return result;
+ }
+
+ @Override
+ public int compress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength)
+ {
+ verifyRange(input, inputOffset, inputLength);
+ verifyRange(output, outputOffset, maxOutputLength);
+
+ long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset;
+ long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset;
+
+ return ZstdFrameCompressor.compress(input, inputAddress, inputAddress + inputLength, output, outputAddress, outputAddress + maxOutputLength, CompressionParameters.DEFAULT_COMPRESSION_LEVEL);
+ }
+
+ @Override
+ public void compress(ByteBuffer inputBuffer, ByteBuffer outputBuffer)
+ {
+ // Java 9+ added an overload of various methods in ByteBuffer. When compiling with Java 11+ and targeting Java 8 bytecode
+ // the resulting signatures are invalid for JDK 8, so accesses below result in NoSuchMethodError. Accessing the
+ // methods through the interface class works around the problem
+ // Sidenote: we can't target "javac --release 8" because Unsafe is not available in the signature data for that profile
+ Buffer input = inputBuffer;
+ Buffer output = outputBuffer;
+
+ Object inputBase;
+ long inputAddress;
+ long inputLimit;
+ if (input.isDirect()) {
+ inputBase = null;
+ long address = getAddress(input);
+ inputAddress = address + input.position();
+ inputLimit = address + input.limit();
+ }
+ else if (input.hasArray()) {
+ inputBase = input.array();
+ inputAddress = ARRAY_BYTE_BASE_OFFSET + input.arrayOffset() + input.position();
+ inputLimit = ARRAY_BYTE_BASE_OFFSET + input.arrayOffset() + input.limit();
+ }
+ else {
+ throw new IllegalArgumentException("Unsupported input ByteBuffer implementation " + input.getClass().getName());
+ }
+
+ Object outputBase;
+ long outputAddress;
+ long outputLimit;
+ if (output.isDirect()) {
+ outputBase = null;
+ long address = getAddress(output);
+ outputAddress = address + output.position();
+ outputLimit = address + output.limit();
+ }
+ else if (output.hasArray()) {
+ outputBase = output.array();
+ outputAddress = ARRAY_BYTE_BASE_OFFSET + output.arrayOffset() + output.position();
+ outputLimit = ARRAY_BYTE_BASE_OFFSET + output.arrayOffset() + output.limit();
+ }
+ else {
+ throw new IllegalArgumentException("Unsupported output ByteBuffer implementation " + output.getClass().getName());
+ }
+
+        // HACK: Assure the JVM does not collect the buffer objects while compressing, since
+        // collecting a direct ByteBuffer may trigger freeing of the underlying memory, resulting
+        // in a segfault. There is no other known way to signal to the JVM that an object should
+        // not be collected in a block, and technically, the JVM is allowed to eliminate these locks.
+ synchronized (input) {
+ synchronized (output) {
+ int written = ZstdFrameCompressor.compress(
+ inputBase,
+ inputAddress,
+ inputLimit,
+ outputBase,
+ outputAddress,
+ outputLimit,
+ CompressionParameters.DEFAULT_COMPRESSION_LEVEL);
+ output.position(output.position() + written);
+ }
+ }
+ }
+
+ private static void verifyRange(byte[] data, int offset, int length)
+ {
+ requireNonNull(data, "data is null");
+ if (offset < 0 || length < 0 || offset + length > data.length) {
+ throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length));
+ }
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdDecompressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdDecompressor.java
new file mode 100644
index 00000000000..a5c755e3685
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdDecompressor.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import ai.vespa.airlift.compress.Decompressor;
+import ai.vespa.airlift.compress.MalformedInputException;
+
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+
+import static ai.vespa.airlift.zstd.UnsafeUtil.getAddress;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
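+/**
+ * Decompresses a zstd frame produced by {@link ZstdCompressor}.
+ * A minimal usage sketch, assuming the frame header carries the content size
+ * (getDecompressedSize returns -1 when it does not):
+ *
+ * <pre>{@code
+ *   ZstdDecompressor decompressor = new ZstdDecompressor();
+ *   long size = ZstdDecompressor.getDecompressedSize(compressed, 0, compressed.length);
+ *   byte[] output = new byte[(int) size];
+ *   int read = decompressor.decompress(compressed, 0, compressed.length, output, 0, output.length);
+ * }</pre>
+ */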
+public class ZstdDecompressor
+ implements Decompressor
+{
+ private final ZstdFrameDecompressor decompressor = new ZstdFrameDecompressor();
+
+ @Override
+ public int decompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset, int maxOutputLength)
+ throws MalformedInputException
+ {
+ verifyRange(input, inputOffset, inputLength);
+ verifyRange(output, outputOffset, maxOutputLength);
+
+ long inputAddress = ARRAY_BYTE_BASE_OFFSET + inputOffset;
+ long inputLimit = inputAddress + inputLength;
+ long outputAddress = ARRAY_BYTE_BASE_OFFSET + outputOffset;
+ long outputLimit = outputAddress + maxOutputLength;
+
+ return decompressor.decompress(input, inputAddress, inputLimit, output, outputAddress, outputLimit);
+ }
+
+ @Override
+ public void decompress(ByteBuffer inputBuffer, ByteBuffer outputBuffer)
+ throws MalformedInputException
+ {
+ // Java 9+ added an overload of various methods in ByteBuffer. When compiling with Java 11+ and targeting Java 8 bytecode
+ // the resulting signatures are invalid for JDK 8, so accesses below result in NoSuchMethodError. Accessing the
+ // methods through the interface class works around the problem
+ // Sidenote: we can't target "javac --release 8" because Unsafe is not available in the signature data for that profile
+ Buffer input = inputBuffer;
+ Buffer output = outputBuffer;
+
+ Object inputBase;
+ long inputAddress;
+ long inputLimit;
+ if (input.isDirect()) {
+ inputBase = null;
+ long address = getAddress(input);
+ inputAddress = address + input.position();
+ inputLimit = address + input.limit();
+ }
+ else if (input.hasArray()) {
+ inputBase = input.array();
+ inputAddress = ARRAY_BYTE_BASE_OFFSET + input.arrayOffset() + input.position();
+ inputLimit = ARRAY_BYTE_BASE_OFFSET + input.arrayOffset() + input.limit();
+ }
+ else {
+ throw new IllegalArgumentException("Unsupported input ByteBuffer implementation " + input.getClass().getName());
+ }
+
+ Object outputBase;
+ long outputAddress;
+ long outputLimit;
+ if (output.isDirect()) {
+ outputBase = null;
+ long address = getAddress(output);
+ outputAddress = address + output.position();
+ outputLimit = address + output.limit();
+ }
+ else if (output.hasArray()) {
+ outputBase = output.array();
+ outputAddress = ARRAY_BYTE_BASE_OFFSET + output.arrayOffset() + output.position();
+ outputLimit = ARRAY_BYTE_BASE_OFFSET + output.arrayOffset() + output.limit();
+ }
+ else {
+ throw new IllegalArgumentException("Unsupported output ByteBuffer implementation " + output.getClass().getName());
+ }
+
+        // HACK: Assure the JVM does not collect the buffer objects while decompressing, since
+        // collecting a direct ByteBuffer may trigger freeing of the underlying memory, resulting
+        // in a segfault. There is no other known way to signal to the JVM that an object should
+        // not be collected in a block, and technically, the JVM is allowed to eliminate these locks.
+ synchronized (input) {
+ synchronized (output) {
+ int written = new ZstdFrameDecompressor().decompress(inputBase, inputAddress, inputLimit, outputBase, outputAddress, outputLimit);
+ output.position(output.position() + written);
+ }
+ }
+ }
+
+ public static long getDecompressedSize(byte[] input, int offset, int length)
+ {
+ int baseAddress = ARRAY_BYTE_BASE_OFFSET + offset;
+ return ZstdFrameDecompressor.getDecompressedSize(input, baseAddress, baseAddress + length);
+ }
+
+ private static void verifyRange(byte[] data, int offset, int length)
+ {
+ requireNonNull(data, "data is null");
+ if (offset < 0 || length < 0 || offset + length > data.length) {
+ throw new IllegalArgumentException(format("Invalid offset or length (%s, %s) in array of length %s", offset, length, data.length));
+ }
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdFrameCompressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdFrameCompressor.java
new file mode 100644
index 00000000000..44209b1f9e2
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdFrameCompressor.java
@@ -0,0 +1,438 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import static ai.vespa.airlift.zstd.Constants.COMPRESSED_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.COMPRESSED_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.MAGIC_NUMBER;
+import static ai.vespa.airlift.zstd.Constants.MAX_BLOCK_SIZE;
+import static ai.vespa.airlift.zstd.Constants.MIN_BLOCK_SIZE;
+import static ai.vespa.airlift.zstd.Constants.MIN_WINDOW_LOG;
+import static ai.vespa.airlift.zstd.Constants.RAW_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.RAW_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.RLE_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_BLOCK_HEADER;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_INT;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
+import static ai.vespa.airlift.zstd.Constants.TREELESS_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.Huffman.MAX_SYMBOL;
+import static ai.vespa.airlift.zstd.Huffman.MAX_SYMBOL_COUNT;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.checkArgument;
+import static ai.vespa.airlift.zstd.Util.put24BitLittleEndian;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
+class ZstdFrameCompressor
+{
+ static final int MAX_FRAME_HEADER_SIZE = 14;
+
+ private static final int CHECKSUM_FLAG = 0b100;
+ private static final int SINGLE_SEGMENT_FLAG = 0b100000;
+
+ private static final int MINIMUM_LITERALS_SIZE = 63;
+
+ // the maximum table log allowed for literal encoding per RFC 8478, section 4.2.1
+ private static final int MAX_HUFFMAN_TABLE_LOG = 11;
+
+ private ZstdFrameCompressor()
+ {
+ }
+
+ // visible for testing
+ static int writeMagic(final Object outputBase, final long outputAddress, final long outputLimit)
+ {
+ checkArgument(outputLimit - outputAddress >= SIZE_OF_INT, "Output buffer too small");
+
+ UNSAFE.putInt(outputBase, outputAddress, MAGIC_NUMBER);
+ return SIZE_OF_INT;
+ }
+
+ // visible for testing
+ static int writeFrameHeader(final Object outputBase, final long outputAddress, final long outputLimit, int inputSize, int windowSize)
+ {
+ checkArgument(outputLimit - outputAddress >= MAX_FRAME_HEADER_SIZE, "Output buffer too small");
+
+ long output = outputAddress;
+
+ int contentSizeDescriptor = (inputSize >= 256 ? 1 : 0) + (inputSize >= 65536 + 256 ? 1 : 0);
+ int frameHeaderDescriptor = (contentSizeDescriptor << 6) | CHECKSUM_FLAG; // dictionary ID missing
+
+ boolean singleSegment = windowSize >= inputSize;
+ if (singleSegment) {
+ frameHeaderDescriptor |= SINGLE_SEGMENT_FLAG;
+ }
+
+ UNSAFE.putByte(outputBase, output, (byte) frameHeaderDescriptor);
+ output++;
+
+ if (!singleSegment) {
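+            // encode the window size as an exponent/mantissa descriptor byte (RFC 8478,
+            // Window_Descriptor); e.g. windowSize = 2^20 + 2^17 gives base = 2^20, exponent = 20,
+            // mantissa = 1, so the byte is ((20 - MIN_WINDOW_LOG) << 3) | 1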
+ int base = Integer.highestOneBit(windowSize);
+
+ int exponent = 32 - Integer.numberOfLeadingZeros(base) - 1;
+ if (exponent < MIN_WINDOW_LOG) {
+ throw new IllegalArgumentException("Minimum window size is " + (1 << MIN_WINDOW_LOG));
+ }
+
+ int remainder = windowSize - base;
+ if (remainder % (base / 8) != 0) {
+ throw new IllegalArgumentException("Window size of magnitude 2^" + exponent + " must be multiple of " + (base / 8));
+ }
+
+ // mantissa is guaranteed to be between 0-7
+ int mantissa = remainder / (base / 8);
+ int encoded = ((exponent - MIN_WINDOW_LOG) << 3) | mantissa;
+
+ UNSAFE.putByte(outputBase, output, (byte) encoded);
+ output++;
+ }
+
+ switch (contentSizeDescriptor) {
+ case 0:
+ if (singleSegment) {
+ UNSAFE.putByte(outputBase, output++, (byte) inputSize);
+ }
+ break;
+ case 1:
+ UNSAFE.putShort(outputBase, output, (short) (inputSize - 256));
+ output += SIZE_OF_SHORT;
+ break;
+ case 2:
+ UNSAFE.putInt(outputBase, output, inputSize);
+ output += SIZE_OF_INT;
+ break;
+ default:
+ throw new AssertionError();
+ }
+
+ return (int) (output - outputAddress);
+ }
+
+ // visible for testing
+ static int writeChecksum(Object outputBase, long outputAddress, long outputLimit, Object inputBase, long inputAddress, long inputLimit)
+ {
+ checkArgument(outputLimit - outputAddress >= SIZE_OF_INT, "Output buffer too small");
+
+ int inputSize = (int) (inputLimit - inputAddress);
+
+ long hash = XxHash64.hash(0, inputBase, inputAddress, inputSize);
+
+ UNSAFE.putInt(outputBase, outputAddress, (int) hash);
+
+ return SIZE_OF_INT;
+ }
+
+ public static int compress(Object inputBase, long inputAddress, long inputLimit, Object outputBase, long outputAddress, long outputLimit, int compressionLevel)
+ {
+ int inputSize = (int) (inputLimit - inputAddress);
+
+ CompressionParameters parameters = CompressionParameters.compute(compressionLevel, inputSize);
+
+ long output = outputAddress;
+
+ output += writeMagic(outputBase, output, outputLimit);
+ output += writeFrameHeader(outputBase, output, outputLimit, inputSize, 1 << parameters.getWindowLog());
+ output += compressFrame(inputBase, inputAddress, inputLimit, outputBase, output, outputLimit, parameters);
+ output += writeChecksum(outputBase, output, outputLimit, inputBase, inputAddress, inputLimit);
+
+ return (int) (output - outputAddress);
+ }
+
+ private static int compressFrame(Object inputBase, long inputAddress, long inputLimit, Object outputBase, long outputAddress, long outputLimit, CompressionParameters parameters)
+ {
+ int windowSize = 1 << parameters.getWindowLog(); // TODO: store window size in parameters directly?
+ int blockSize = Math.min(MAX_BLOCK_SIZE, windowSize);
+
+ int outputSize = (int) (outputLimit - outputAddress);
+ int remaining = (int) (inputLimit - inputAddress);
+
+ long output = outputAddress;
+ long input = inputAddress;
+
+ CompressionContext context = new CompressionContext(parameters, inputAddress, remaining);
+
+ do {
+ checkArgument(outputSize >= SIZE_OF_BLOCK_HEADER + MIN_BLOCK_SIZE, "Output buffer too small");
+
+ int lastBlockFlag = blockSize >= remaining ? 1 : 0;
+ blockSize = Math.min(blockSize, remaining);
+
+ int compressedSize = 0;
+ if (remaining > 0) {
+ compressedSize = compressBlock(inputBase, input, blockSize, outputBase, output + SIZE_OF_BLOCK_HEADER, outputSize - SIZE_OF_BLOCK_HEADER, context, parameters);
+ }
+
+ if (compressedSize == 0) { // block is not compressible
+ checkArgument(blockSize + SIZE_OF_BLOCK_HEADER <= outputSize, "Output size too small");
+
+ int blockHeader = lastBlockFlag | (RAW_BLOCK << 1) | (blockSize << 3);
+ put24BitLittleEndian(outputBase, output, blockHeader);
+ UNSAFE.copyMemory(inputBase, input, outputBase, output + SIZE_OF_BLOCK_HEADER, blockSize);
+ compressedSize = SIZE_OF_BLOCK_HEADER + blockSize;
+ }
+ else {
+ int blockHeader = lastBlockFlag | (COMPRESSED_BLOCK << 1) | (compressedSize << 3);
+ put24BitLittleEndian(outputBase, output, blockHeader);
+ compressedSize += SIZE_OF_BLOCK_HEADER;
+ }
+
+ input += blockSize;
+ remaining -= blockSize;
+ output += compressedSize;
+ outputSize -= compressedSize;
+ }
+ while (remaining > 0);
+
+ return (int) (output - outputAddress);
+ }
+
+ private static int compressBlock(Object inputBase, long inputAddress, int inputSize, Object outputBase, long outputAddress, int outputSize, CompressionContext context, CompressionParameters parameters)
+ {
+ if (inputSize < MIN_BLOCK_SIZE + SIZE_OF_BLOCK_HEADER + 1) {
+ // don't even attempt compression below a certain input size
+ return 0;
+ }
+
+ context.blockCompressionState.enforceMaxDistance(inputAddress + inputSize, 1 << parameters.getWindowLog());
+ context.sequenceStore.reset();
+
+ int lastLiteralsSize = parameters.getStrategy()
+ .getCompressor()
+ .compressBlock(inputBase, inputAddress, inputSize, context.sequenceStore, context.blockCompressionState, context.offsets, parameters);
+
+ long lastLiteralsAddress = inputAddress + inputSize - lastLiteralsSize;
+
+ // append [lastLiteralsAddress .. lastLiteralsSize] to sequenceStore literals buffer
+ context.sequenceStore.appendLiterals(inputBase, lastLiteralsAddress, lastLiteralsSize);
+
+ // convert length/offsets into codes
+ context.sequenceStore.generateCodes();
+
+ long outputLimit = outputAddress + outputSize;
+ long output = outputAddress;
+
+ int compressedLiteralsSize = encodeLiterals(
+ context.huffmanContext,
+ parameters,
+ outputBase,
+ output,
+ (int) (outputLimit - output),
+ context.sequenceStore.literalsBuffer,
+ context.sequenceStore.literalsLength);
+ output += compressedLiteralsSize;
+
+ int compressedSequencesSize = SequenceEncoder.compressSequences(outputBase, output, (int) (outputLimit - output), context.sequenceStore, parameters.getStrategy(), context.sequenceEncodingContext);
+
+ int compressedSize = compressedLiteralsSize + compressedSequencesSize;
+ if (compressedSize == 0) {
+ // not compressible
+ return compressedSize;
+ }
+
+ // Check compressibility
+ int maxCompressedSize = inputSize - calculateMinimumGain(inputSize, parameters.getStrategy());
+ if (compressedSize > maxCompressedSize) {
+ return 0; // not compressed
+ }
+
+ // confirm repeated offsets and entropy tables
+ context.commit();
+
+ return compressedSize;
+ }
+
+ private static int encodeLiterals(
+ HuffmanCompressionContext context,
+ CompressionParameters parameters,
+ Object outputBase,
+ long outputAddress,
+ int outputSize,
+ byte[] literals,
+ int literalsSize)
+ {
+ // TODO: move this to Strategy
+ boolean bypassCompression = (parameters.getStrategy() == CompressionParameters.Strategy.FAST) && (parameters.getTargetLength() > 0);
+ if (bypassCompression || literalsSize <= MINIMUM_LITERALS_SIZE) {
+ return rawLiterals(outputBase, outputAddress, outputSize, literals, ARRAY_BYTE_BASE_OFFSET, literalsSize);
+ }
+
+ int headerSize = 3 + (literalsSize >= 1024 ? 1 : 0) + (literalsSize >= 16384 ? 1 : 0);
+
+ checkArgument(headerSize + 1 <= outputSize, "Output buffer too small");
+
+ int[] counts = new int[MAX_SYMBOL_COUNT]; // TODO: preallocate
+ Histogram.count(literals, literalsSize, counts);
+ int maxSymbol = Histogram.findMaxSymbol(counts, MAX_SYMBOL);
+ int largestCount = Histogram.findLargestCount(counts, maxSymbol);
+
+ long literalsAddress = ARRAY_BYTE_BASE_OFFSET;
+ if (largestCount == literalsSize) {
+ // all bytes in input are equal
+ return rleLiterals(outputBase, outputAddress, outputSize, literals, ARRAY_BYTE_BASE_OFFSET, literalsSize);
+ }
+ else if (largestCount <= (literalsSize >>> 7) + 4) {
+ // heuristic: probably not compressible enough
+ return rawLiterals(outputBase, outputAddress, outputSize, literals, ARRAY_BYTE_BASE_OFFSET, literalsSize);
+ }
+
+ HuffmanCompressionTable previousTable = context.getPreviousTable();
+ HuffmanCompressionTable table;
+ int serializedTableSize;
+ boolean reuseTable;
+
+ boolean canReuse = previousTable.isValid(counts, maxSymbol);
+
+ // heuristic: use existing table for small inputs if valid
+ // TODO: move to Strategy
+ boolean preferReuse = parameters.getStrategy().ordinal() < CompressionParameters.Strategy.LAZY.ordinal() && literalsSize <= 1024;
+ if (preferReuse && canReuse) {
+ table = previousTable;
+ reuseTable = true;
+ serializedTableSize = 0;
+ }
+ else {
+ HuffmanCompressionTable newTable = context.borrowTemporaryTable();
+
+ newTable.initialize(
+ counts,
+ maxSymbol,
+ HuffmanCompressionTable.optimalNumberOfBits(MAX_HUFFMAN_TABLE_LOG, literalsSize, maxSymbol),
+ context.getCompressionTableWorkspace());
+
+ serializedTableSize = newTable.write(outputBase, outputAddress + headerSize, outputSize - headerSize, context.getTableWriterWorkspace());
+
+ // Check if using previous huffman table is beneficial
+ if (canReuse && previousTable.estimateCompressedSize(counts, maxSymbol) <= serializedTableSize + newTable.estimateCompressedSize(counts, maxSymbol)) {
+ table = previousTable;
+ reuseTable = true;
+ serializedTableSize = 0;
+ context.discardTemporaryTable();
+ }
+ else {
+ table = newTable;
+ reuseTable = false;
+ }
+ }
+
+ int compressedSize;
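+        // literals shorter than 256 bytes use the single-stream Huffman layout; larger
+        // inputs are split into four independently decodable streams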
+ boolean singleStream = literalsSize < 256;
+ if (singleStream) {
+ compressedSize = HuffmanCompressor.compressSingleStream(outputBase, outputAddress + headerSize + serializedTableSize, outputSize - headerSize - serializedTableSize, literals, literalsAddress, literalsSize, table);
+ }
+ else {
+ compressedSize = HuffmanCompressor.compress4streams(outputBase, outputAddress + headerSize + serializedTableSize, outputSize - headerSize - serializedTableSize, literals, literalsAddress, literalsSize, table);
+ }
+
+ int totalSize = serializedTableSize + compressedSize;
+ int minimumGain = calculateMinimumGain(literalsSize, parameters.getStrategy());
+
+ if (compressedSize == 0 || totalSize >= literalsSize - minimumGain) {
+ // incompressible or no savings
+
+ // discard any temporary table we might have borrowed above
+ context.discardTemporaryTable();
+
+ return rawLiterals(outputBase, outputAddress, outputSize, literals, ARRAY_BYTE_BASE_OFFSET, literalsSize);
+ }
+
+ int encodingType = reuseTable ? TREELESS_LITERALS_BLOCK : COMPRESSED_LITERALS_BLOCK;
+
+        // Build the literals section header; the per-case comments give the field widths in bits:
+        // block type - size format - regenerated size - compressed size
+ switch (headerSize) {
+ case 3: { // 2 - 2 - 10 - 10
+ int header = encodingType | ((singleStream ? 0 : 1) << 2) | (literalsSize << 4) | (totalSize << 14);
+ put24BitLittleEndian(outputBase, outputAddress, header);
+ break;
+ }
+ case 4: { // 2 - 2 - 14 - 14
+ int header = encodingType | (2 << 2) | (literalsSize << 4) | (totalSize << 18);
+ UNSAFE.putInt(outputBase, outputAddress, header);
+ break;
+ }
+ case 5: { // 2 - 2 - 18 - 18
+ int header = encodingType | (3 << 2) | (literalsSize << 4) | (totalSize << 22);
+ UNSAFE.putInt(outputBase, outputAddress, header);
+ UNSAFE.putByte(outputBase, outputAddress + SIZE_OF_INT, (byte) (totalSize >>> 10));
+ break;
+ }
+            default: // not possible: headerSize is 3, 4 or 5
+ throw new IllegalStateException();
+ }
+
+ return headerSize + totalSize;
+ }
+
+ private static int rleLiterals(Object outputBase, long outputAddress, int outputSize, Object inputBase, long inputAddress, int inputSize)
+ {
+ int headerSize = 1 + (inputSize > 31 ? 1 : 0) + (inputSize > 4095 ? 1 : 0);
+
+ switch (headerSize) {
+ case 1: // 2 - 1 - 5
+ UNSAFE.putByte(outputBase, outputAddress, (byte) (RLE_LITERALS_BLOCK | (inputSize << 3)));
+ break;
+ case 2: // 2 - 2 - 12
+ UNSAFE.putShort(outputBase, outputAddress, (short) (RLE_LITERALS_BLOCK | (1 << 2) | (inputSize << 4)));
+ break;
+ case 3: // 2 - 2 - 20
+ UNSAFE.putInt(outputBase, outputAddress, RLE_LITERALS_BLOCK | 3 << 2 | inputSize << 4);
+ break;
+            default: // not possible: headerSize is 1, 2 or 3
+ throw new IllegalStateException();
+ }
+
+ UNSAFE.putByte(outputBase, outputAddress + headerSize, UNSAFE.getByte(inputBase, inputAddress));
+
+ return headerSize + 1;
+ }
+
+ private static int calculateMinimumGain(int inputSize, CompressionParameters.Strategy strategy)
+ {
+ // TODO: move this to Strategy to avoid hardcoding a specific strategy here
+ int minLog = strategy == CompressionParameters.Strategy.BTULTRA ? 7 : 6;
+ return (inputSize >>> minLog) + 2;
+ }
+
+ private static int rawLiterals(Object outputBase, long outputAddress, int outputSize, Object inputBase, long inputAddress, int inputSize)
+ {
+ int headerSize = 1;
+ if (inputSize >= 32) {
+ headerSize++;
+ }
+ if (inputSize >= 4096) {
+ headerSize++;
+ }
+
+ checkArgument(inputSize + headerSize <= outputSize, "Output buffer too small");
+
+ switch (headerSize) {
+ case 1:
+ UNSAFE.putByte(outputBase, outputAddress, (byte) (RAW_LITERALS_BLOCK | (inputSize << 3)));
+ break;
+ case 2:
+ UNSAFE.putShort(outputBase, outputAddress, (short) (RAW_LITERALS_BLOCK | (1 << 2) | (inputSize << 4)));
+ break;
+ case 3:
+ put24BitLittleEndian(outputBase, outputAddress, RAW_LITERALS_BLOCK | (3 << 2) | (inputSize << 4));
+ break;
+ default:
+ throw new AssertionError();
+ }
+
+ // TODO: ensure this test is correct
+ checkArgument(inputSize + 1 <= outputSize, "Output buffer too small");
+
+ UNSAFE.copyMemory(inputBase, inputAddress, outputBase, outputAddress + headerSize, inputSize);
+
+ return headerSize + inputSize;
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdFrameDecompressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdFrameDecompressor.java
new file mode 100644
index 00000000000..46b2ea2a894
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdFrameDecompressor.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import ai.vespa.airlift.compress.MalformedInputException;
+
+import static ai.vespa.airlift.zstd.Constants.COMPRESSED_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.MAGIC_NUMBER;
+import static ai.vespa.airlift.zstd.Constants.MIN_WINDOW_LOG;
+import static ai.vespa.airlift.zstd.Constants.RAW_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.RLE_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_BLOCK_HEADER;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_BYTE;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_INT;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_LONG;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.fail;
+import static ai.vespa.airlift.zstd.Util.verify;
+
+class ZstdFrameDecompressor
+{
+ private static final int V07_MAGIC_NUMBER = 0xFD2FB527;
+
+ public int decompress(
+ final Object inputBase,
+ final long inputAddress,
+ final long inputLimit,
+ final Object outputBase,
+ final long outputAddress,
+ final long outputLimit)
+ {
+ if (outputAddress == outputLimit) {
+ return 0;
+ }
+ long input = inputAddress;
+ long output = outputAddress;
+
+ while (input < inputLimit) {
+ long outputStart = output;
+ input += verifyMagic(inputBase, input, inputLimit);
+
+ FrameHeader frameHeader = readFrameHeader(inputBase, input, inputLimit);
+ input += frameHeader.headerSize;
+
+ ZstdBlockDecompressor blockDecompressor = new ZstdBlockDecompressor(frameHeader);
+ boolean lastBlock;
+ do {
+ verify(input + SIZE_OF_BLOCK_HEADER <= inputLimit, input, "Not enough input bytes");
+
+ // read block header
+ int header = UNSAFE.getInt(inputBase, input) & 0xFF_FFFF;
+ input += SIZE_OF_BLOCK_HEADER;
+
+ lastBlock = (header & 1) != 0;
+ int blockType = (header >>> 1) & 0b11;
+ int blockSize = (header >>> 3) & 0x1F_FFFF; // 21 bits
+
+ int decodedSize;
+ switch (blockType) {
+ case RAW_BLOCK:
+                        verify(input + blockSize <= inputLimit, input, "Not enough input bytes");
+ decodedSize = ZstdBlockDecompressor.decodeRawBlock(inputBase, input, blockSize, outputBase, output, outputLimit);
+ input += blockSize;
+ break;
+ case RLE_BLOCK:
+                        verify(input + 1 <= inputLimit, input, "Not enough input bytes");
+ decodedSize = ZstdBlockDecompressor.decodeRleBlock(blockSize, inputBase, input, outputBase, output, outputLimit);
+ input += 1;
+ break;
+ case COMPRESSED_BLOCK:
+                        verify(input + blockSize <= inputLimit, input, "Not enough input bytes");
+ decodedSize = blockDecompressor.decodeCompressedBlock(inputBase, input, blockSize, outputBase, output, outputLimit, frameHeader.windowSize, outputAddress);
+ input += blockSize;
+ break;
+ default:
+ throw fail(input, "Invalid block type");
+ }
+ output += decodedSize;
+ }
+ while (!lastBlock);
+
+ if (frameHeader.hasChecksum) {
+ int decodedFrameSize = (int) (output - outputStart);
+
+ long hash = XxHash64.hash(0, outputBase, outputStart, decodedFrameSize);
+
+ int checksum = UNSAFE.getInt(inputBase, input);
+ if (checksum != (int) hash) {
+ throw new MalformedInputException(input, String.format("Bad checksum. Expected: %s, actual: %s", Integer.toHexString(checksum), Integer.toHexString((int) hash)));
+ }
+
+ input += SIZE_OF_INT;
+ }
+ }
+
+ return (int) (output - outputAddress);
+ }
+
+ static FrameHeader readFrameHeader(final Object inputBase, final long inputAddress, final long inputLimit)
+ {
+ long input = inputAddress;
+ verify(input < inputLimit, input, "Not enough input bytes");
+
+ int frameHeaderDescriptor = UNSAFE.getByte(inputBase, input++) & 0xFF;
+ boolean singleSegment = (frameHeaderDescriptor & 0b100000) != 0;
+ int dictionaryDescriptor = frameHeaderDescriptor & 0b11;
+ int contentSizeDescriptor = frameHeaderDescriptor >>> 6;
+
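+        // header = descriptor byte, plus a window descriptor unless single-segment, plus a
+        // 1/2/4-byte dictionary id, plus a 0/1/2/4/8-byte content size (the 1-byte form
+        // appears only in single-segment frames)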
+ int headerSize = 1 +
+ (singleSegment ? 0 : 1) +
+ (dictionaryDescriptor == 0 ? 0 : (1 << (dictionaryDescriptor - 1))) +
+ (contentSizeDescriptor == 0 ? (singleSegment ? 1 : 0) : (1 << contentSizeDescriptor));
+
+ verify(headerSize <= inputLimit - inputAddress, input, "Not enough input bytes");
+
+ // decode window size
+ int windowSize = -1;
+ if (!singleSegment) {
+ int windowDescriptor = UNSAFE.getByte(inputBase, input++) & 0xFF;
+ int exponent = windowDescriptor >>> 3;
+ int mantissa = windowDescriptor & 0b111;
+
+ int base = 1 << (MIN_WINDOW_LOG + exponent);
+ windowSize = base + (base / 8) * mantissa;
+ }
+
+ // decode dictionary id
+ long dictionaryId = -1;
+ switch (dictionaryDescriptor) {
+ case 1:
+ dictionaryId = UNSAFE.getByte(inputBase, input) & 0xFF;
+ input += SIZE_OF_BYTE;
+ break;
+ case 2:
+ dictionaryId = UNSAFE.getShort(inputBase, input) & 0xFFFF;
+ input += SIZE_OF_SHORT;
+ break;
+ case 3:
+ dictionaryId = UNSAFE.getInt(inputBase, input) & 0xFFFF_FFFFL;
+ input += SIZE_OF_INT;
+ break;
+ }
+ verify(dictionaryId == -1, input, "Custom dictionaries not supported");
+
+ // decode content size
+ long contentSize = -1;
+ switch (contentSizeDescriptor) {
+ case 0:
+ if (singleSegment) {
+ contentSize = UNSAFE.getByte(inputBase, input) & 0xFF;
+ input += SIZE_OF_BYTE;
+ }
+ break;
+ case 1:
+ contentSize = UNSAFE.getShort(inputBase, input) & 0xFFFF;
+ contentSize += 256;
+ input += SIZE_OF_SHORT;
+ break;
+ case 2:
+ contentSize = UNSAFE.getInt(inputBase, input) & 0xFFFF_FFFFL;
+ input += SIZE_OF_INT;
+ break;
+ case 3:
+ contentSize = UNSAFE.getLong(inputBase, input);
+ input += SIZE_OF_LONG;
+ break;
+ }
+
+ boolean hasChecksum = (frameHeaderDescriptor & 0b100) != 0;
+
+ return new FrameHeader(
+ input - inputAddress,
+ windowSize,
+ contentSize,
+ dictionaryId,
+ hasChecksum);
+ }
+
+ public static long getDecompressedSize(final Object inputBase, final long inputAddress, final long inputLimit)
+ {
+ long input = inputAddress;
+ input += verifyMagic(inputBase, input, inputLimit);
+ return readFrameHeader(inputBase, input, inputLimit).contentSize;
+ }
+
+ static int verifyMagic(Object inputBase, long inputAddress, long inputLimit)
+ {
+ verify(inputLimit - inputAddress >= 4, inputAddress, "Not enough input bytes");
+
+ int magic = UNSAFE.getInt(inputBase, inputAddress);
+ if (magic != MAGIC_NUMBER) {
+ if (magic == V07_MAGIC_NUMBER) {
+ throw new MalformedInputException(inputAddress, "Data encoded in unsupported ZSTD v0.7 format");
+ }
+ throw new MalformedInputException(inputAddress, "Invalid magic prefix: " + Integer.toHexString(magic));
+ }
+
+ return SIZE_OF_INT;
+ }
+}
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdInputStream.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdInputStream.java
new file mode 100644
index 00000000000..ffee9286fdb
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdInputStream.java
@@ -0,0 +1,471 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+import static ai.vespa.airlift.zstd.Constants.COMPRESSED_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.MAGIC_NUMBER;
+import static ai.vespa.airlift.zstd.Constants.MAGIC_SKIPFRAME_MAX;
+import static ai.vespa.airlift.zstd.Constants.MAGIC_SKIPFRAME_MIN;
+import static ai.vespa.airlift.zstd.Constants.MAX_BLOCK_SIZE;
+import static ai.vespa.airlift.zstd.Constants.RAW_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.RLE_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_BLOCK_HEADER;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_BYTE;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_INT;
+import static ai.vespa.airlift.zstd.Util.fail;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
+/**
+ * Takes a compressed InputStream and decompresses it on demand.
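+ * <p>
+ * Illustrative usage (a sketch; assumes a zstd-compressed file named "data.zst"):
+ * <pre>
+ * try (InputStream in = new ZstdInputStream(new FileInputStream("data.zst"))) {
+ *     byte[] decompressed = in.readAllBytes();
+ * }
+ * </pre>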
+ * @author arnej27959
+ */
+public class ZstdInputStream
+ extends InputStream
+{
+ private static final int DEFAULT_BUFFER_SIZE = 8 * 1024;
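+ // rounds grown buffer sizes down to a multiple of DEFAULT_BUFFER_SIZE; callers
+ // add DEFAULT_BUFFER_SIZE before masking, so the result always covers the request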
+ private static final int BUFFER_SIZE_MASK = ~(DEFAULT_BUFFER_SIZE - 1);
+ private static final int MAX_WINDOW_SIZE = 1 << 23;
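+ // frames declaring a larger window (or single-segment content) than this are rejected in startFrame()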
+
+ private final InputStream inputStream;
+ private byte[] inputBuffer;
+ private int inputPosition;
+ private int inputEnd;
+ private byte[] outputBuffer;
+ private int outputPosition;
+ private int outputEnd;
+ private boolean isClosed;
+ private boolean seenEof;
+ private boolean lastBlock;
+ private boolean singleSegmentFlag;
+ private boolean contentChecksumFlag;
+ private long skipBytes;
+ private int windowSize;
+ private int blockMaximumSize = MAX_BLOCK_SIZE;
+ private int curBlockSize;
+ private int curBlockType = -1;
+ private FrameHeader curHeader;
+ private ZstdBlockDecompressor blockDecompressor;
+ private XxHash64 hasher;
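+ // number of input bytes already discarded from the front of inputBuffer,
+ // so error messages can report absolute positions in the compressed stream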
+ private long evictedInput;
+
+ public ZstdInputStream(InputStream inp, int initialBufferSize)
+ {
+ this.inputStream = inp;
+ this.inputBuffer = new byte[initialBufferSize];
+ this.outputBuffer = new byte[initialBufferSize];
+ }
+
+ public ZstdInputStream(InputStream inp)
+ {
+ this(inp, DEFAULT_BUFFER_SIZE);
+ }
+
+ @Override
+ public int available()
+ {
+ return outputAvailable();
+ }
+
+ @Override
+ public int read() throws IOException
+ {
+ throwIfClosed();
+ if (ensureGotOutput()) {
+ int b = outputBuffer[outputPosition++];
+ return (b & 0xFF);
+ }
+ else {
+ return -1;
+ }
+ }
+
+ @Override
+ public int read(byte[] b) throws IOException
+ {
+ return read(b, 0, b.length);
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException
+ {
+ throwIfClosed();
+ if (ensureGotOutput()) {
+ len = Math.min(outputAvailable(), len);
+ System.arraycopy(outputBuffer, outputPosition, b, off, len);
+ outputPosition += len;
+ return len;
+ }
+ else {
+ return -1;
+ }
+ }
+
+ @Override
+ public void close() throws IOException
+ {
+ // close() must be idempotent per the Closeable contract, and the
+ // underlying stream is closed whether or not EOF was reached
+ if (!isClosed) {
+ inputStream.close();
+ isClosed = true;
+ }
+ }
+
+ private void check(boolean condition, String reason)
+ {
+ Util.verify(condition, curInputFilePosition(), reason);
+ }
+
+ private boolean ensureGotOutput() throws IOException
+ {
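+ // each helper returns false when it needs more input; the loop then retries
+ // from the top until a block has been decompressed or EOF is reached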
+ while ((outputAvailable() == 0) && !seenEof) {
+ if (ensureGotFrameHeader() && ensureGotBlock()) {
+ decompressBlock();
+ }
+ }
+ if (outputAvailable() > 0) {
+ return true;
+ }
+ else {
+ check(seenEof, "unable to decode to EOF");
+ check(inputAvailable() == 0, "leftover input at end of file");
+ check(curHeader == null, "unfinished frame at end of file");
+ return false;
+ }
+ }
+
+ private void readMoreInput() throws IOException
+ {
+ ensureInputSpace(1024);
+ int got = inputStream.read(inputBuffer, inputEnd, inputSpace());
+ if (got == -1) {
+ seenEof = true;
+ }
+ else {
+ inputEnd += got;
+ }
+ }
+
+ private ByteBuffer inputBB()
+ {
+ ByteBuffer bb = ByteBuffer.wrap(inputBuffer, inputPosition, inputAvailable());
+ bb.order(ByteOrder.LITTLE_ENDIAN);
+ return bb;
+ }
+
+ private boolean ensureGotFrameHeader() throws IOException
+ {
+ if (curHeader != null) {
+ return true;
+ }
+ // a skip frame is minimum 8 bytes;
+ // a data frame is minimum 4 (magic) + 2 (frame header) + 3 (block header) = 9 bytes,
+ // but we only need 5 bytes to know the size of the frame header
+ if (inputAvailable() < 8) {
+ readMoreInput();
+ // retry from start
+ return false;
+ }
+ ByteBuffer bb = inputBB();
+ int magic = bb.getInt();
+ // skippable frame header magic
+ if ((magic >= MAGIC_SKIPFRAME_MIN) && (magic <= MAGIC_SKIPFRAME_MAX)) {
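+ // skippable frame layout: 4-byte magic, 4-byte little-endian size, then that many bytes of user data to ignore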
+ inputPosition += SIZE_OF_INT; // for magic
+ // the size field counts only the user data that follows it
+ skipBytes = bb.getInt() & 0xffff_ffffL;
+ inputPosition += SIZE_OF_INT; // for skipsize
+ while (skipBytes > 0) {
+ if (skipBytes <= inputAvailable()) {
+ inputPosition += skipBytes;
+ skipBytes = 0;
+ }
+ else {
+ skipBytes -= inputAvailable();
+ inputPosition = inputEnd;
+ readMoreInput();
+ if (seenEof) {
+ throw fail(curInputFilePosition(), "unfinished skip frame at end of file");
+ }
+ }
+ }
+ // entire frame skipped; retry from start
+ return false;
+ }
+ // zstd frame header magic
+ if (magic == MAGIC_NUMBER) {
+ int fhDesc = 0xFF & bb.get();
+ int frameContentSizeFlag = (fhDesc & 0b11000000) >> 6;
+ singleSegmentFlag = (fhDesc & 0b00100000) != 0;
+ contentChecksumFlag = (fhDesc & 0b00000100) != 0;
+ int dictionaryIdFlag = (fhDesc & 0b00000011);
+ // 4 byte magic + 1 byte fhDesc
+ int fhSize = SIZE_OF_INT + SIZE_OF_BYTE;
+ // add size of frameContentSize
+ if (frameContentSizeFlag == 0) {
+ fhSize += (singleSegmentFlag ? 1 : 0);
+ }
+ else {
+ fhSize += 1 << frameContentSizeFlag;
+ }
+ // add size of window descriptor
+ fhSize += (singleSegmentFlag ? 0 : 1);
+ // add size of dictionary id
+ fhSize += (1 << dictionaryIdFlag) >> 1;
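+ // e.g. fhDesc 0b11000110 means an 8-byte content size, a 1-byte window
+ // descriptor and a 2-byte dictionary id, giving fhSize = 4 + 1 + 8 + 1 + 2 = 16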
+ if (fhSize > inputAvailable()) {
+ readMoreInput();
+ // retry from start
+ return false;
+ }
+ inputPosition += SIZE_OF_INT;
+ curHeader = readFrameHeader();
+ inputPosition += fhSize - SIZE_OF_INT;
+ startFrame();
+ return true;
+ }
+ else {
+ throw fail(curInputFilePosition(), "Invalid magic prefix: " + magic);
+ }
+ }
+
+ private void startFrame()
+ {
+ blockDecompressor = new ZstdBlockDecompressor(curHeader);
+ check(outputPosition == outputEnd, "orphan output present");
+ outputPosition = 0;
+ outputEnd = 0;
+ if (singleSegmentFlag) {
+ if (curHeader.contentSize > MAX_WINDOW_SIZE) {
+ throw fail(curInputFilePosition(), "Single segment too large: " + curHeader.contentSize);
+ }
+ windowSize = (int) curHeader.contentSize;
+ blockMaximumSize = windowSize;
+ ensureOutputSpace(windowSize);
+ }
+ else {
+ if (curHeader.windowSize > MAX_WINDOW_SIZE) {
+ throw fail(curInputFilePosition(), "Window size too large: " + curHeader.windowSize);
+ }
+ windowSize = curHeader.windowSize;
+ blockMaximumSize = Math.min(windowSize, MAX_BLOCK_SIZE);
+ ensureOutputSpace(blockMaximumSize + windowSize);
+ }
+ if (contentChecksumFlag) {
+ hasher = new XxHash64();
+ }
+ }
+
+ private boolean ensureGotBlock() throws IOException
+ {
+ check(curHeader != null, "no current frame");
+ if (curBlockType == -1) {
+ // must have a block now
+ if (inputAvailable() < SIZE_OF_BLOCK_HEADER) {
+ readMoreInput();
+ // retry from start
+ return false;
+ }
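+ // 3-byte little-endian block header: bit 0 = last-block flag,
+ // bits 1-2 = block type, bits 3-23 = block size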
+ int blkHeader = nextByte() | nextByte() << 8 | nextByte() << 16;
+ lastBlock = (blkHeader & 0b001) != 0;
+ curBlockType = (blkHeader & 0b110) >> 1;
+ curBlockSize = blkHeader >> 3;
+ ensureInputSpace(curBlockSize + SIZE_OF_INT);
+ }
+ // an RLE block carries a single content byte on the wire, whatever curBlockSize says
+ int neededInput = (curBlockType == RLE_BLOCK) ? 1 : curBlockSize;
+ if (inputAvailable() < neededInput + (contentChecksumFlag ? SIZE_OF_INT : 0)) {
+ readMoreInput();
+ // retry from start
+ return false;
+ }
+ return true;
+ }
+
+ int nextByte()
+ {
+ int r = 0xFF & inputBuffer[inputPosition];
+ inputPosition++;
+ return r;
+ }
+
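+ // the block decompressor uses (base object, absolute offset) addressing as in
+ // sun.misc.Unsafe, so buffer positions are translated via ARRAY_BYTE_BASE_OFFSET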
+ long inputAddress()
+ {
+ return ARRAY_BYTE_BASE_OFFSET + inputPosition;
+ }
+
+ long inputLimit()
+ {
+ return ARRAY_BYTE_BASE_OFFSET + inputEnd;
+ }
+
+ long outputAddress()
+ {
+ return ARRAY_BYTE_BASE_OFFSET + outputEnd;
+ }
+
+ long outputLimit()
+ {
+ return ARRAY_BYTE_BASE_OFFSET + outputBuffer.length;
+ }
+
+ int decodeRaw()
+ {
+ check(inputAddress() + curBlockSize <= inputLimit(), "Not enough input bytes");
+ check(outputAddress() + curBlockSize <= outputLimit(), "Not enough output space");
+ return ZstdBlockDecompressor.decodeRawBlock(inputBuffer, inputAddress(), curBlockSize, outputBuffer, outputAddress(), outputLimit());
+ }
+
+ int decodeRle()
+ {
+ check(inputAddress() + 1 <= inputLimit(), "Not enough input bytes");
+ check(outputAddress() + curBlockSize <= outputLimit(), "Not enough output space");
+ return ZstdBlockDecompressor.decodeRleBlock(curBlockSize, inputBuffer, inputAddress(), outputBuffer, outputAddress(), outputLimit());
+ }
+
+ int decodeCompressed()
+ {
+ check(inputAddress() + curBlockSize <= inputLimit(), "Not enough input bytes");
+ check(outputAddress() + blockMaximumSize <= outputLimit(), "Not enough output space");
+ return blockDecompressor.decodeCompressedBlock(
+ inputBuffer, inputAddress(),
+ curBlockSize,
+ outputBuffer, outputAddress(), outputLimit(),
+ windowSize, ARRAY_BYTE_BASE_OFFSET);
+ }
+
+ private void decompressBlock()
+ {
+ check(outputPosition == outputEnd, "orphan output present");
+ switch (curBlockType) {
+ case RAW_BLOCK:
+ ensureOutputSpace(curBlockSize);
+ outputEnd += decodeRaw();
+ inputPosition += curBlockSize;
+ break;
+ case RLE_BLOCK:
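+ // decodeRle expands the single stored byte into curBlockSize copies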
+ ensureOutputSpace(curBlockSize);
+ outputEnd += decodeRle();
+ inputPosition += 1;
+ break;
+ case COMPRESSED_BLOCK:
+ check(curBlockSize < blockMaximumSize, "compressed block must be smaller than Block_Maximum_Size");
+ ensureOutputSpace(blockMaximumSize);
+ outputEnd += decodeCompressed();
+ inputPosition += curBlockSize;
+ break;
+ default:
+ throw fail(curInputFilePosition(), "Invalid block type " + curBlockType);
+ }
+ if (contentChecksumFlag) {
+ hasher.update(outputBuffer, outputPosition, outputAvailable());
+ }
+ curBlockType = -1;
+ if (lastBlock) {
+ curHeader = null;
+ blockDecompressor = null;
+ if (contentChecksumFlag) {
+ check(inputAvailable() >= SIZE_OF_INT, "missing checksum data");
+ long hash = hasher.hash();
+ int checksum = inputBB().getInt();
+ if (checksum != (int) hash) {
+ throw fail(curInputFilePosition(), String.format("Bad checksum. Expected: %s, actual: %s", Integer.toHexString(checksum), Integer.toHexString((int) hash)));
+ }
+ inputPosition += SIZE_OF_INT;
+ hasher = null;
+ }
+ }
+ }
+
+ private int inputAvailable()
+ {
+ return inputEnd - inputPosition;
+ }
+
+ private int inputSpace()
+ {
+ return inputBuffer.length - inputEnd;
+ }
+
+ private long curInputFilePosition()
+ {
+ return evictedInput + inputPosition;
+ }
+
+ private void ensureInputSpace(int size)
+ {
+ if (inputSpace() < size) {
+ if (size < inputPosition) {
+ System.arraycopy(inputBuffer, inputPosition, inputBuffer, 0, inputAvailable());
+ }
+ else {
+ int newSize = (inputBuffer.length + size + DEFAULT_BUFFER_SIZE) & BUFFER_SIZE_MASK;
+ byte[] newBuf = new byte[newSize];
+ System.arraycopy(inputBuffer, inputPosition, newBuf, 0, inputAvailable());
+ inputBuffer = newBuf;
+ }
+ evictedInput += inputPosition;
+ inputEnd = inputAvailable();
+ inputPosition = 0;
+ }
+ }
+
+ private int outputAvailable()
+ {
+ return outputEnd - outputPosition;
+ }
+
+ private int outputSpace()
+ {
+ return outputBuffer.length - outputEnd;
+ }
+
+ private void ensureOutputSpace(int size)
+ {
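+ // matches may reference up to windowSize bytes of history, so compaction
+ // must keep at least one full window of already-produced output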
+ if (outputSpace() < size) {
+ check(outputAvailable() == 0, "logic error");
+ byte[] newBuf;
+ if (windowSize * 4 + size < outputPosition) {
+ // plenty space in old buffer
+ newBuf = outputBuffer;
+ }
+ else {
+ int newSize = (outputBuffer.length
+ + windowSize * 4
+ + size
+ + DEFAULT_BUFFER_SIZE) & BUFFER_SIZE_MASK;
+ newBuf = new byte[newSize];
+ }
+ // keep up to one window of old data
+ int sizeToKeep = Math.min(outputPosition, windowSize);
+ System.arraycopy(outputBuffer, outputPosition - sizeToKeep, newBuf, 0, sizeToKeep);
+ outputBuffer = newBuf;
+ outputEnd = sizeToKeep;
+ outputPosition = sizeToKeep;
+ }
+ }
+
+ private void throwIfClosed() throws IOException
+ {
+ if (isClosed) {
+ throw new IOException("Input stream is already closed");
+ }
+ }
+
+ private FrameHeader readFrameHeader()
+ {
+ long base = ARRAY_BYTE_BASE_OFFSET + inputPosition;
+ long limit = ARRAY_BYTE_BASE_OFFSET + inputEnd;
+ return ZstdFrameDecompressor.readFrameHeader(inputBuffer, base, limit);
+ }
+}