Diffstat (limited to 'airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdBlockDecompressor.java')
-rw-r--r--  airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdBlockDecompressor.java  810
1 file changed, 810 insertions, 0 deletions
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdBlockDecompressor.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdBlockDecompressor.java
new file mode 100644
index 00000000000..08adf88cfa6
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/ZstdBlockDecompressor.java
@@ -0,0 +1,810 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.util.Arrays;
+
+import static ai.vespa.airlift.zstd.BitInputStream.peekBits;
+import static ai.vespa.airlift.zstd.Constants.COMPRESSED_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.COMPRESSED_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.DEFAULT_MAX_OFFSET_CODE_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.LITERALS_LENGTH_BITS;
+import static ai.vespa.airlift.zstd.Constants.LITERAL_LENGTH_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Constants.LONG_NUMBER_OF_SEQUENCES;
+import static ai.vespa.airlift.zstd.Constants.MATCH_LENGTH_BITS;
+import static ai.vespa.airlift.zstd.Constants.MATCH_LENGTH_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Constants.MAX_BLOCK_SIZE;
+import static ai.vespa.airlift.zstd.Constants.MAX_LITERALS_LENGTH_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.MAX_MATCH_LENGTH_SYMBOL;
+import static ai.vespa.airlift.zstd.Constants.MIN_BLOCK_SIZE;
+import static ai.vespa.airlift.zstd.Constants.MIN_SEQUENCES_SIZE;
+import static ai.vespa.airlift.zstd.Constants.OFFSET_TABLE_LOG;
+import static ai.vespa.airlift.zstd.Constants.RAW_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.RAW_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.RLE_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.RLE_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_BASIC;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_COMPRESSED;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_REPEAT;
+import static ai.vespa.airlift.zstd.Constants.SEQUENCE_ENCODING_RLE;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_INT;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_LONG;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
+import static ai.vespa.airlift.zstd.Constants.TREELESS_LITERALS_BLOCK;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.fail;
+import static ai.vespa.airlift.zstd.Util.mask;
+import static ai.vespa.airlift.zstd.Util.verify;
+import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
+
+/**
+ * Handles decompression of all blocks in a single frame.
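+ * <p>
+ * A minimal usage sketch, assuming the enclosing frame decompressor has already parsed the
+ * frame header and the block header; the block type, block size, buffer and offset names
+ * below are illustrative, not part of this class:
+ * <pre>{@code
+ * ZstdBlockDecompressor blockDecompressor = new ZstdBlockDecompressor(frameHeader);
+ * int decoded = blockDecompressor.decompressBlock(
+ *         blockType,
+ *         blockSize,
+ *         input, ARRAY_BYTE_BASE_OFFSET + inputOffset, ARRAY_BYTE_BASE_OFFSET + input.length,
+ *         output, ARRAY_BYTE_BASE_OFFSET + outputOffset, ARRAY_BYTE_BASE_OFFSET + output.length);
+ * }</pre>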
+ **/
+class ZstdBlockDecompressor
+{
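+ // lookup tables used by copyMatchHead to reposition the match pointer for overlapping copies
+ // (offset < 8), so that the remainder of the match can be copied a long at a time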
+ private static final int[] DEC_32_TABLE = {4, 1, 2, 1, 4, 4, 4, 4};
+ private static final int[] DEC_64_TABLE = {0, 0, 0, -1, 0, 1, 2, 3};
+
+ private static final int MAX_WINDOW_SIZE = 1 << 23;
+
+ private static final int[] LITERALS_LENGTH_BASE = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 18, 20, 22, 24, 28, 32, 40, 48, 64, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000,
+ 0x2000, 0x4000, 0x8000, 0x10000};
+
+ private static final int[] MATCH_LENGTH_BASE = {
+ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+ 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
+ 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, 99, 0x83, 0x103, 0x203, 0x403, 0x803,
+ 0x1003, 0x2003, 0x4003, 0x8003, 0x10003};
+
+ private static final int[] OFFSET_CODES_BASE = {
+ 0, 1, 1, 5, 0xD, 0x1D, 0x3D, 0x7D,
+ 0xFD, 0x1FD, 0x3FD, 0x7FD, 0xFFD, 0x1FFD, 0x3FFD, 0x7FFD,
+ 0xFFFD, 0x1FFFD, 0x3FFFD, 0x7FFFD, 0xFFFFD, 0x1FFFFD, 0x3FFFFD, 0x7FFFFD,
+ 0xFFFFFD, 0x1FFFFFD, 0x3FFFFFD, 0x7FFFFFD, 0xFFFFFFD};
+
+ private static final FiniteStateEntropy.Table DEFAULT_LITERALS_LENGTH_TABLE = new FiniteStateEntropy.Table(
+ 6,
+ new int[] {
+ 0, 16, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 32, 0, 32, 0, 32, 0, 0, 16, 32, 0, 0, 48, 16, 32, 32, 32,
+ 32, 32, 32, 32, 32, 0, 32, 32, 32, 32, 32, 32, 0, 0, 0, 0},
+ new byte[] {
+ 0, 0, 1, 3, 4, 6, 7, 9, 10, 12, 14, 16, 18, 19, 21, 22, 24, 25, 26, 27, 29, 31, 0, 1, 2, 4, 5, 7, 8, 10, 11, 13, 16, 17, 19, 20, 22, 23, 25, 25, 26, 28, 30, 0,
+ 1, 2, 3, 5, 6, 8, 9, 11, 12, 15, 17, 18, 20, 21, 23, 24, 35, 34, 33, 32},
+ new byte[] {
+ 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+ 6, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6});
+
+ private static final FiniteStateEntropy.Table DEFAULT_OFFSET_CODES_TABLE = new FiniteStateEntropy.Table(
+ 5,
+ new int[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 16, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0},
+ new byte[] {0, 6, 9, 15, 21, 3, 7, 12, 18, 23, 5, 8, 14, 20, 2, 7, 11, 17, 22, 4, 8, 13, 19, 1, 6, 10, 16, 28, 27, 26, 25, 24},
+ new byte[] {5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 4, 5, 5, 5, 4, 5, 5, 5, 5, 5, 5, 5});
+
+ private static final FiniteStateEntropy.Table DEFAULT_MATCH_LENGTH_TABLE = new FiniteStateEntropy.Table(
+ 6,
+ new int[] {
+ 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 32, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 48, 16, 32, 32, 32, 32,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ new byte[] {
+ 0, 1, 2, 3, 5, 6, 8, 10, 13, 16, 19, 22, 25, 28, 31, 33, 35, 37, 39, 41, 43, 45, 1, 2, 3, 4, 6, 7, 9, 12, 15, 18, 21, 24, 27, 30, 32, 34, 36, 38, 40, 42, 44, 1,
+ 1, 2, 4, 5, 7, 8, 11, 14, 17, 20, 23, 26, 29, 52, 51, 50, 49, 48, 47, 46},
+ new byte[] {
+ 6, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6,
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6});
+
+ private final byte[] literals = new byte[MAX_BLOCK_SIZE + SIZE_OF_LONG]; // extra space to allow for long-at-a-time copy
+
+ // current buffer containing literals
+ private Object literalsBase;
+ private long literalsAddress;
+ private long literalsLimit;
+
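+ // the three most recent offsets ("repcodes"), seeded in the constructor with the values mandated by the zstd spec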
+ private final int[] previousOffsets = new int[3];
+
+ private final FiniteStateEntropy.Table literalsLengthTable = new FiniteStateEntropy.Table(LITERAL_LENGTH_TABLE_LOG);
+ private final FiniteStateEntropy.Table offsetCodesTable = new FiniteStateEntropy.Table(OFFSET_TABLE_LOG);
+ private final FiniteStateEntropy.Table matchLengthTable = new FiniteStateEntropy.Table(MATCH_LENGTH_TABLE_LOG);
+
+ private FiniteStateEntropy.Table currentLiteralsLengthTable;
+ private FiniteStateEntropy.Table currentOffsetCodesTable;
+ private FiniteStateEntropy.Table currentMatchLengthTable;
+
+ private final Huffman huffman = new Huffman();
+ private final FseTableReader fse = new FseTableReader();
+
+ private final FrameHeader frameHeader;
+
+ public ZstdBlockDecompressor(FrameHeader frameHeader)
+ {
+ this.frameHeader = frameHeader;
+
+ previousOffsets[0] = 1;
+ previousOffsets[1] = 4;
+ previousOffsets[2] = 8;
+ }
+
+ int decompressBlock(
+ int blockType,
+ int blockSize,
+ final Object inputBase,
+ final long inputAddress,
+ final long inputLimit,
+ final Object outputBase,
+ final long outputAddress,
+ final long outputLimit)
+ {
+ int decodedSize;
+ switch (blockType) {
+ case RAW_BLOCK:
+ verify(inputAddress + blockSize <= inputLimit, inputAddress, "Not enough input bytes");
+ decodedSize = decodeRawBlock(inputBase, inputAddress, blockSize, outputBase, outputAddress, outputLimit);
+ break;
+ case RLE_BLOCK:
+ verify(inputAddress + 1 <= inputLimit, inputAddress, "Not enough input bytes");
+ decodedSize = decodeRleBlock(blockSize, inputBase, inputAddress, outputBase, outputAddress, outputLimit);
+ break;
+ case COMPRESSED_BLOCK:
+ verify(inputAddress + blockSize <= inputLimit, inputAddress, "Not enough input bytes");
+ decodedSize = decodeCompressedBlock(inputBase, inputAddress, blockSize, outputBase, outputAddress, outputLimit, frameHeader.windowSize, outputAddress);
+ break;
+ default:
+ throw fail(inputAddress, "Invalid block type");
+ }
+ return decodedSize;
+ }
+
+ static int decodeRawBlock(Object inputBase, long inputAddress, int blockSize, Object outputBase, long outputAddress, long outputLimit)
+ {
+ verify(outputAddress + blockSize <= outputLimit, inputAddress, "Output buffer too small");
+
+ UNSAFE.copyMemory(inputBase, inputAddress, outputBase, outputAddress, blockSize);
+ return blockSize;
+ }
+
+ static int decodeRleBlock(int size, Object inputBase, long inputAddress, Object outputBase, long outputAddress, long outputLimit)
+ {
+ verify(outputAddress + size <= outputLimit, inputAddress, "Output buffer too small");
+
+ long output = outputAddress;
+ long value = UNSAFE.getByte(inputBase, inputAddress) & 0xFFL;
+
+ int remaining = size;
+ if (remaining >= SIZE_OF_LONG) {
+ long packed = value
+ | (value << 8)
+ | (value << 16)
+ | (value << 24)
+ | (value << 32)
+ | (value << 40)
+ | (value << 48)
+ | (value << 56);
+
+ do {
+ UNSAFE.putLong(outputBase, output, packed);
+ output += SIZE_OF_LONG;
+ remaining -= SIZE_OF_LONG;
+ }
+ while (remaining >= SIZE_OF_LONG);
+ }
+
+ for (int i = 0; i < remaining; i++) {
+ UNSAFE.putByte(outputBase, output, (byte) value);
+ output++;
+ }
+
+ return size;
+ }
+
+ @SuppressWarnings("fallthrough")
+ int decodeCompressedBlock(Object inputBase, final long inputAddress, int blockSize, Object outputBase, long outputAddress, long outputLimit, int windowSize, long outputAbsoluteBaseAddress)
+ {
+ long inputLimit = inputAddress + blockSize;
+ long input = inputAddress;
+
+ verify(blockSize <= MAX_BLOCK_SIZE, input, "Block exceeds maximum size");
+ verify(blockSize >= MIN_BLOCK_SIZE, input, "Compressed block size too small");
+
+ // decode literals
+ int literalsBlockType = UNSAFE.getByte(inputBase, input) & 0b11;
+
+ switch (literalsBlockType) {
+ case RAW_LITERALS_BLOCK: {
+ input += decodeRawLiterals(inputBase, input, inputLimit);
+ break;
+ }
+ case RLE_LITERALS_BLOCK: {
+ input += decodeRleLiterals(inputBase, input, blockSize);
+ break;
+ }
+ case TREELESS_LITERALS_BLOCK:
+ verify(huffman.isLoaded(), input, "Dictionary is corrupted");
+ case COMPRESSED_LITERALS_BLOCK: {
+ input += decodeCompressedLiterals(inputBase, input, blockSize, literalsBlockType);
+ break;
+ }
+ default:
+ throw fail(input, "Invalid literals block encoding type");
+ }
+
+ verify(windowSize <= MAX_WINDOW_SIZE, input, "Window size too large (not yet supported)");
+
+ return decompressSequences(
+ inputBase, input, inputAddress + blockSize,
+ outputBase, outputAddress, outputLimit,
+ literalsBase, literalsAddress, literalsLimit,
+ outputAbsoluteBaseAddress);
+ }
+
+ private int decompressSequences(
+ final Object inputBase, final long inputAddress, final long inputLimit,
+ final Object outputBase, final long outputAddress, final long outputLimit,
+ final Object literalsBase, final long literalsAddress, final long literalsLimit,
+ long outputAbsoluteBaseAddress)
+ {
+ final long fastOutputLimit = outputLimit - SIZE_OF_LONG;
+ final long fastMatchOutputLimit = fastOutputLimit - SIZE_OF_LONG;
+
+ long input = inputAddress;
+ long output = outputAddress;
+
+ long literalsInput = literalsAddress;
+
+ int size = (int) (inputLimit - inputAddress);
+ verify(size >= MIN_SEQUENCES_SIZE, input, "Not enough input bytes");
+
+ // decode header
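+ // the sequence count is 1 byte when below 128, 2 bytes when the first byte is 128..254,
+ // and 3 bytes (0xFF marker followed by a little-endian short) otherwise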
+ int sequenceCount = UNSAFE.getByte(inputBase, input++) & 0xFF;
+ if (sequenceCount != 0) {
+ if (sequenceCount == 255) {
+ verify(input + SIZE_OF_SHORT <= inputLimit, input, "Not enough input bytes");
+ sequenceCount = (UNSAFE.getShort(inputBase, input) & 0xFFFF) + LONG_NUMBER_OF_SEQUENCES;
+ input += SIZE_OF_SHORT;
+ }
+ else if (sequenceCount > 127) {
+ verify(input < inputLimit, input, "Not enough input bytes");
+ sequenceCount = ((sequenceCount - 128) << 8) + (UNSAFE.getByte(inputBase, input++) & 0xFF);
+ }
+
+ verify(input + SIZE_OF_INT <= inputLimit, input, "Not enough input bytes");
+
+ byte type = UNSAFE.getByte(inputBase, input++);
+
+ int literalsLengthType = (type & 0xFF) >>> 6;
+ int offsetCodesType = (type >>> 4) & 0b11;
+ int matchLengthType = (type >>> 2) & 0b11;
+
+ input = computeLiteralsTable(literalsLengthType, inputBase, input, inputLimit);
+ input = computeOffsetsTable(offsetCodesType, inputBase, input, inputLimit);
+ input = computeMatchLengthTable(matchLengthType, inputBase, input, inputLimit);
+
+ // decompress sequences
+ BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, input, inputLimit);
+ initializer.initialize();
+ int bitsConsumed = initializer.getBitsConsumed();
+ long bits = initializer.getBits();
+ long currentAddress = initializer.getCurrentAddress();
+
+ FiniteStateEntropy.Table currentLiteralsLengthTable = this.currentLiteralsLengthTable;
+ FiniteStateEntropy.Table currentOffsetCodesTable = this.currentOffsetCodesTable;
+ FiniteStateEntropy.Table currentMatchLengthTable = this.currentMatchLengthTable;
+
+ int literalsLengthState = (int) peekBits(bitsConsumed, bits, currentLiteralsLengthTable.log2Size);
+ bitsConsumed += currentLiteralsLengthTable.log2Size;
+
+ int offsetCodesState = (int) peekBits(bitsConsumed, bits, currentOffsetCodesTable.log2Size);
+ bitsConsumed += currentOffsetCodesTable.log2Size;
+
+ int matchLengthState = (int) peekBits(bitsConsumed, bits, currentMatchLengthTable.log2Size);
+ bitsConsumed += currentMatchLengthTable.log2Size;
+
+ int[] previousOffsets = this.previousOffsets;
+
+ byte[] literalsLengthNumbersOfBits = currentLiteralsLengthTable.numberOfBits;
+ int[] literalsLengthNewStates = currentLiteralsLengthTable.newState;
+ byte[] literalsLengthSymbols = currentLiteralsLengthTable.symbol;
+
+ byte[] matchLengthNumbersOfBits = currentMatchLengthTable.numberOfBits;
+ int[] matchLengthNewStates = currentMatchLengthTable.newState;
+ byte[] matchLengthSymbols = currentMatchLengthTable.symbol;
+
+ byte[] offsetCodesNumbersOfBits = currentOffsetCodesTable.numberOfBits;
+ int[] offsetCodesNewStates = currentOffsetCodesTable.newState;
+ byte[] offsetCodesSymbols = currentOffsetCodesTable.symbol;
+
+ while (sequenceCount > 0) {
+ sequenceCount--;
+
+ BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
+ loader.load();
+ bitsConsumed = loader.getBitsConsumed();
+ bits = loader.getBits();
+ currentAddress = loader.getCurrentAddress();
+ if (loader.isOverflow()) {
+ verify(sequenceCount == 0, input, "Not all sequences were consumed");
+ break;
+ }
+
+ // decode sequence
+ int literalsLengthCode = literalsLengthSymbols[literalsLengthState];
+ int matchLengthCode = matchLengthSymbols[matchLengthState];
+ int offsetCode = offsetCodesSymbols[offsetCodesState];
+
+ int literalsLengthBits = LITERALS_LENGTH_BITS[literalsLengthCode];
+ int matchLengthBits = MATCH_LENGTH_BITS[matchLengthCode];
+ int offsetBits = offsetCode;
+
+ int offset = OFFSET_CODES_BASE[offsetCode];
+ if (offsetCode > 0) {
+ offset += peekBits(bitsConsumed, bits, offsetBits);
+ bitsConsumed += offsetBits;
+ }
+
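+ // offset codes 0 and 1 select from the recent-offsets history (previousOffsets);
+ // a literals length of 0 shifts the selection by one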
+ if (offsetCode <= 1) {
+ if (literalsLengthCode == 0) {
+ offset++;
+ }
+
+ if (offset != 0) {
+ int temp;
+ if (offset == 3) {
+ temp = previousOffsets[0] - 1;
+ }
+ else {
+ temp = previousOffsets[offset];
+ }
+
+ if (temp == 0) {
+ temp = 1;
+ }
+
+ if (offset != 1) {
+ previousOffsets[2] = previousOffsets[1];
+ }
+ previousOffsets[1] = previousOffsets[0];
+ previousOffsets[0] = temp;
+
+ offset = temp;
+ }
+ else {
+ offset = previousOffsets[0];
+ }
+ }
+ else {
+ previousOffsets[2] = previousOffsets[1];
+ previousOffsets[1] = previousOffsets[0];
+ previousOffsets[0] = offset;
+ }
+
+ int matchLength = MATCH_LENGTH_BASE[matchLengthCode];
+ if (matchLengthCode > 31) {
+ matchLength += peekBits(bitsConsumed, bits, matchLengthBits);
+ bitsConsumed += matchLengthBits;
+ }
+
+ int literalsLength = LITERALS_LENGTH_BASE[literalsLengthCode];
+ if (literalsLengthCode > 15) {
+ literalsLength += peekBits(bitsConsumed, bits, literalsLengthBits);
+ bitsConsumed += literalsLengthBits;
+ }
+
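+ // refill the bit buffer if the extra bits just read plus the worst-case state-update
+ // reads (bounded by the table logs) might exceed what the 64-bit container still holds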
+ int totalBits = literalsLengthBits + matchLengthBits + offsetBits;
+ if (totalBits > 64 - 7 - (LITERAL_LENGTH_TABLE_LOG + MATCH_LENGTH_TABLE_LOG + OFFSET_TABLE_LOG)) {
+ BitInputStream.Loader loader1 = new BitInputStream.Loader(inputBase, input, currentAddress, bits, bitsConsumed);
+ loader1.load();
+
+ bitsConsumed = loader1.getBitsConsumed();
+ bits = loader1.getBits();
+ currentAddress = loader1.getCurrentAddress();
+ }
+
+ int numberOfBits;
+
+ numberOfBits = literalsLengthNumbersOfBits[literalsLengthState];
+ literalsLengthState = (int) (literalsLengthNewStates[literalsLengthState] + peekBits(bitsConsumed, bits, numberOfBits)); // <= 9 bits
+ bitsConsumed += numberOfBits;
+
+ numberOfBits = matchLengthNumbersOfBits[matchLengthState];
+ matchLengthState = (int) (matchLengthNewStates[matchLengthState] + peekBits(bitsConsumed, bits, numberOfBits)); // <= 9 bits
+ bitsConsumed += numberOfBits;
+
+ numberOfBits = offsetCodesNumbersOfBits[offsetCodesState];
+ offsetCodesState = (int) (offsetCodesNewStates[offsetCodesState] + peekBits(bitsConsumed, bits, numberOfBits)); // <= 8 bits
+ bitsConsumed += numberOfBits;
+
+ final long literalOutputLimit = output + literalsLength;
+ final long matchOutputLimit = literalOutputLimit + matchLength;
+
+ verify(matchOutputLimit <= outputLimit, input, "Output buffer too small");
+ long literalEnd = literalsInput + literalsLength;
+ verify(literalEnd <= literalsLimit, input, "Input is corrupted");
+
+ long matchAddress = literalOutputLimit - offset;
+ verify(matchAddress >= outputAbsoluteBaseAddress, input, "Input is corrupted");
+
+ if (literalOutputLimit > fastOutputLimit) {
+ executeLastSequence(outputBase, output, literalOutputLimit, matchOutputLimit, fastOutputLimit, literalsInput, matchAddress);
+ }
+ else {
+ // copy literals. literalOutputLimit <= fastOutputLimit, so we can copy
+ // long at a time with over-copy
+ output = copyLiterals(outputBase, literalsBase, output, literalsInput, literalOutputLimit);
+ copyMatch(outputBase, fastOutputLimit, output, offset, matchOutputLimit, matchAddress, matchLength, fastMatchOutputLimit);
+ }
+ output = matchOutputLimit;
+ literalsInput = literalEnd;
+ }
+ }
+
+ // last literal segment
+ output = copyLastLiteral(outputBase, literalsBase, literalsLimit, output, literalsInput);
+
+ return (int) (output - outputAddress);
+ }
+
+ private long copyLastLiteral(Object outputBase, Object literalsBase, long literalsLimit, long output, long literalsInput)
+ {
+ long lastLiteralsSize = literalsLimit - literalsInput;
+ UNSAFE.copyMemory(literalsBase, literalsInput, outputBase, output, lastLiteralsSize);
+ output += lastLiteralsSize;
+ return output;
+ }
+
+ private void copyMatch(Object outputBase, long fastOutputLimit, long output, int offset, long matchOutputLimit, long matchAddress, int matchLength, long fastMatchOutputLimit)
+ {
+ matchAddress = copyMatchHead(outputBase, output, offset, matchAddress);
+ output += SIZE_OF_LONG;
+ matchLength -= SIZE_OF_LONG; // first 8 bytes copied above
+
+ copyMatchTail(outputBase, fastOutputLimit, output, matchOutputLimit, matchAddress, matchLength, fastMatchOutputLimit);
+ }
+
+ private void copyMatchTail(Object outputBase, long fastOutputLimit, long output, long matchOutputLimit, long matchAddress, int matchLength, long fastMatchOutputLimit)
+ {
+ // fastMatchOutputLimit is just fastOutputLimit - SIZE_OF_LONG. It needs to be passed in so that it can be computed once for the
+ // whole invocation to decompressSequences. Otherwise, we'd just compute it here.
+ // If matchOutputLimit is < fastMatchOutputLimit, we know that even after the head (8 bytes) has been copied, the output pointer
+ // will be within fastOutputLimit, so it's safe to copy blindly before checking the limit condition
+ if (matchOutputLimit < fastMatchOutputLimit) {
+ int copied = 0;
+ do {
+ UNSAFE.putLong(outputBase, output, UNSAFE.getLong(outputBase, matchAddress));
+ output += SIZE_OF_LONG;
+ matchAddress += SIZE_OF_LONG;
+ copied += SIZE_OF_LONG;
+ }
+ while (copied < matchLength);
+ }
+ else {
+ while (output < fastOutputLimit) {
+ UNSAFE.putLong(outputBase, output, UNSAFE.getLong(outputBase, matchAddress));
+ matchAddress += SIZE_OF_LONG;
+ output += SIZE_OF_LONG;
+ }
+
+ while (output < matchOutputLimit) {
+ UNSAFE.putByte(outputBase, output++, UNSAFE.getByte(outputBase, matchAddress++));
+ }
+ }
+ }
+
+ private long copyMatchHead(Object outputBase, long output, int offset, long matchAddress)
+ {
+ // copy match
+ if (offset < 8) {
+ // 8 bytes apart so that we can copy long-at-a-time below
+ int increment32 = DEC_32_TABLE[offset];
+ int decrement64 = DEC_64_TABLE[offset];
+
+ UNSAFE.putByte(outputBase, output, UNSAFE.getByte(outputBase, matchAddress));
+ UNSAFE.putByte(outputBase, output + 1, UNSAFE.getByte(outputBase, matchAddress + 1));
+ UNSAFE.putByte(outputBase, output + 2, UNSAFE.getByte(outputBase, matchAddress + 2));
+ UNSAFE.putByte(outputBase, output + 3, UNSAFE.getByte(outputBase, matchAddress + 3));
+ matchAddress += increment32;
+
+ UNSAFE.putInt(outputBase, output + 4, UNSAFE.getInt(outputBase, matchAddress));
+ matchAddress -= decrement64;
+ }
+ else {
+ UNSAFE.putLong(outputBase, output, UNSAFE.getLong(outputBase, matchAddress));
+ matchAddress += SIZE_OF_LONG;
+ }
+ return matchAddress;
+ }
+
+ private long copyLiterals(Object outputBase, Object literalsBase, long output, long literalsInput, long literalOutputLimit)
+ {
+ long literalInput = literalsInput;
+ do {
+ UNSAFE.putLong(outputBase, output, UNSAFE.getLong(literalsBase, literalInput));
+ output += SIZE_OF_LONG;
+ literalInput += SIZE_OF_LONG;
+ }
+ while (output < literalOutputLimit);
+ output = literalOutputLimit; // correction in case we over-copied
+ return output;
+ }
+
+ private long computeMatchLengthTable(int matchLengthType, Object inputBase, long input, long inputLimit)
+ {
+ switch (matchLengthType) {
+ case SEQUENCE_ENCODING_RLE:
+ verify(input < inputLimit, input, "Not enough input bytes");
+
+ byte value = UNSAFE.getByte(inputBase, input++);
+ verify(value <= MAX_MATCH_LENGTH_SYMBOL, input, "Value exceeds expected maximum value");
+
+ FseTableReader.initializeRleTable(matchLengthTable, value);
+ currentMatchLengthTable = matchLengthTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ currentMatchLengthTable = DEFAULT_MATCH_LENGTH_TABLE;
+ break;
+ case SEQUENCE_ENCODING_REPEAT:
+ verify(currentMatchLengthTable != null, input, "Expected match length table to be present");
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ input += fse.readFseTable(matchLengthTable, inputBase, input, inputLimit, MAX_MATCH_LENGTH_SYMBOL, MATCH_LENGTH_TABLE_LOG);
+ currentMatchLengthTable = matchLengthTable;
+ break;
+ default:
+ throw fail(input, "Invalid match length encoding type");
+ }
+ return input;
+ }
+
+ private long computeOffsetsTable(int offsetCodesType, Object inputBase, long input, long inputLimit)
+ {
+ switch (offsetCodesType) {
+ case SEQUENCE_ENCODING_RLE:
+ verify(input < inputLimit, input, "Not enough input bytes");
+
+ byte value = UNSAFE.getByte(inputBase, input++);
+ verify(value <= DEFAULT_MAX_OFFSET_CODE_SYMBOL, input, "Value exceeds expected maximum value");
+
+ FseTableReader.initializeRleTable(offsetCodesTable, value);
+ currentOffsetCodesTable = offsetCodesTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ currentOffsetCodesTable = DEFAULT_OFFSET_CODES_TABLE;
+ break;
+ case SEQUENCE_ENCODING_REPEAT:
+ verify(currentOffsetCodesTable != null, input, "Expected offset codes table to be present");
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ input += fse.readFseTable(offsetCodesTable, inputBase, input, inputLimit, DEFAULT_MAX_OFFSET_CODE_SYMBOL, OFFSET_TABLE_LOG);
+ currentOffsetCodesTable = offsetCodesTable;
+ break;
+ default:
+ throw fail(input, "Invalid offset code encoding type");
+ }
+ return input;
+ }
+
+ private long computeLiteralsTable(int literalsLengthType, Object inputBase, long input, long inputLimit)
+ {
+ switch (literalsLengthType) {
+ case SEQUENCE_ENCODING_RLE:
+ verify(input < inputLimit, input, "Not enough input bytes");
+
+ byte value = UNSAFE.getByte(inputBase, input++);
+ verify(value <= MAX_LITERALS_LENGTH_SYMBOL, input, "Value exceeds expected maximum value");
+
+ FseTableReader.initializeRleTable(literalsLengthTable, value);
+ currentLiteralsLengthTable = literalsLengthTable;
+ break;
+ case SEQUENCE_ENCODING_BASIC:
+ currentLiteralsLengthTable = DEFAULT_LITERALS_LENGTH_TABLE;
+ break;
+ case SEQUENCE_ENCODING_REPEAT:
+ verify(currentLiteralsLengthTable != null, input, "Expected literals length table to be present");
+ break;
+ case SEQUENCE_ENCODING_COMPRESSED:
+ input += fse.readFseTable(literalsLengthTable, inputBase, input, inputLimit, MAX_LITERALS_LENGTH_SYMBOL, LITERAL_LENGTH_TABLE_LOG);
+ currentLiteralsLengthTable = literalsLengthTable;
+ break;
+ default:
+ throw fail(input, "Invalid literals length encoding type");
+ }
+ return input;
+ }
+
+ private void executeLastSequence(Object outputBase, long output, long literalOutputLimit, long matchOutputLimit, long fastOutputLimit, long literalInput, long matchAddress)
+ {
+ // copy literals
+ if (output < fastOutputLimit) {
+ // wild copy
+ do {
+ UNSAFE.putLong(outputBase, output, UNSAFE.getLong(literalsBase, literalInput));
+ output += SIZE_OF_LONG;
+ literalInput += SIZE_OF_LONG;
+ }
+ while (output < fastOutputLimit);
+
+ literalInput -= output - fastOutputLimit;
+ output = fastOutputLimit;
+ }
+
+ while (output < literalOutputLimit) {
+ UNSAFE.putByte(outputBase, output, UNSAFE.getByte(literalsBase, literalInput));
+ output++;
+ literalInput++;
+ }
+
+ // copy match
+ while (output < matchOutputLimit) {
+ UNSAFE.putByte(outputBase, output, UNSAFE.getByte(outputBase, matchAddress));
+ output++;
+ matchAddress++;
+ }
+ }
+
+ @SuppressWarnings("fallthrough")
+ private int decodeCompressedLiterals(Object inputBase, final long inputAddress, int blockSize, int literalsBlockType)
+ {
+ long input = inputAddress;
+ verify(blockSize >= 5, input, "Not enough input bytes");
+
+ // compressed
+ int compressedSize;
+ int uncompressedSize;
+ boolean singleStream = false;
+ int headerSize;
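+ // size_format (bits 2-3 of the first byte): 0 and 1 use a 3-byte header with 10-bit sizes
+ // (0 means a single Huffman stream), 2 a 4-byte header with 14-bit sizes, 3 a 5-byte header with 18-bit sizes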
+ int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0b11;
+ switch (type) {
+ case 0:
+ singleStream = true;
+ case 1: {
+ int header = UNSAFE.getInt(inputBase, input);
+
+ headerSize = 3;
+ uncompressedSize = (header >>> 4) & mask(10);
+ compressedSize = (header >>> 14) & mask(10);
+ break;
+ }
+ case 2: {
+ int header = UNSAFE.getInt(inputBase, input);
+
+ headerSize = 4;
+ uncompressedSize = (header >>> 4) & mask(14);
+ compressedSize = (header >>> 18) & mask(14);
+ break;
+ }
+ case 3: {
+ // read 5 little-endian bytes
+ long header = UNSAFE.getByte(inputBase, input) & 0xFF |
+ (UNSAFE.getInt(inputBase, input + 1) & 0xFFFF_FFFFL) << 8;
+
+ headerSize = 5;
+ uncompressedSize = (int) ((header >>> 4) & mask(18));
+ compressedSize = (int) ((header >>> 22) & mask(18));
+ break;
+ }
+ default:
+ throw fail(input, "Invalid literals header size type");
+ }
+
+ verify(uncompressedSize <= MAX_BLOCK_SIZE, input, "Block exceeds maximum size");
+ verify(headerSize + compressedSize <= blockSize, input, "Input is corrupted");
+
+ input += headerSize;
+
+ long inputLimit = input + compressedSize;
+ if (literalsBlockType != TREELESS_LITERALS_BLOCK) {
+ input += huffman.readTable(inputBase, input, compressedSize);
+ }
+
+ literalsBase = literals;
+ literalsAddress = ARRAY_BYTE_BASE_OFFSET;
+ literalsLimit = ARRAY_BYTE_BASE_OFFSET + uncompressedSize;
+
+ if (singleStream) {
+ huffman.decodeSingleStream(inputBase, input, inputLimit, literals, literalsAddress, literalsLimit);
+ }
+ else {
+ huffman.decode4Streams(inputBase, input, inputLimit, literals, literalsAddress, literalsLimit);
+ }
+
+ return headerSize + compressedSize;
+ }
+
+ private int decodeRleLiterals(Object inputBase, final long inputAddress, int blockSize)
+ {
+ long input = inputAddress;
+ int outputSize;
+
+ int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0b11;
+ switch (type) {
+ case 0:
+ case 2:
+ outputSize = (UNSAFE.getByte(inputBase, input) & 0xFF) >>> 3;
+ input++;
+ break;
+ case 1:
+ outputSize = (UNSAFE.getShort(inputBase, input) & 0xFFFF) >>> 4;
+ input += 2;
+ break;
+ case 3:
+ // we need at least 4 bytes (3 for the header, 1 for the payload)
+ verify(blockSize >= SIZE_OF_INT, input, "Not enough input bytes");
+ outputSize = (UNSAFE.getInt(inputBase, input) & 0xFF_FFFF) >>> 4;
+ input += 3;
+ break;
+ default:
+ throw fail(input, "Invalid RLE literals header encoding type");
+ }
+
+ verify(outputSize <= MAX_BLOCK_SIZE, input, "Output exceeds maximum block size");
+
+ byte value = UNSAFE.getByte(inputBase, input++);
+ Arrays.fill(literals, 0, outputSize + SIZE_OF_LONG, value);
+
+ literalsBase = literals;
+ literalsAddress = ARRAY_BYTE_BASE_OFFSET;
+ literalsLimit = ARRAY_BYTE_BASE_OFFSET + outputSize;
+
+ return (int) (input - inputAddress);
+ }
+
+ private int decodeRawLiterals(Object inputBase, final long inputAddress, long inputLimit)
+ {
+ long input = inputAddress;
+ int type = (UNSAFE.getByte(inputBase, input) >> 2) & 0b11;
+
+ int literalSize;
+ switch (type) {
+ case 0:
+ case 2:
+ literalSize = (UNSAFE.getByte(inputBase, input) & 0xFF) >>> 3;
+ input++;
+ break;
+ case 1:
+ literalSize = (UNSAFE.getShort(inputBase, input) & 0xFFFF) >>> 4;
+ input += 2;
+ break;
+ case 3:
+ // read 3 little-endian bytes
+ int header = ((UNSAFE.getByte(inputBase, input) & 0xFF) |
+ ((UNSAFE.getShort(inputBase, input + 1) & 0xFFFF) << 8));
+
+ literalSize = header >>> 4;
+ input += 3;
+ break;
+ default:
+ throw fail(input, "Invalid raw literals header encoding type");
+ }
+
+ verify(input + literalSize <= inputLimit, input, "Not enough input bytes");
+
+ // Set literals pointer to [input, literalSize], but only if we can copy 8 bytes at a time during sequence decoding
+ // Otherwise, copy literals into buffer that's big enough to guarantee that
+ if (literalSize > (inputLimit - input) - SIZE_OF_LONG) {
+ literalsBase = literals;
+ literalsAddress = ARRAY_BYTE_BASE_OFFSET;
+ literalsLimit = ARRAY_BYTE_BASE_OFFSET + literalSize;
+
+ UNSAFE.copyMemory(inputBase, input, literals, literalsAddress, literalSize);
+ Arrays.fill(literals, literalSize, literalSize + SIZE_OF_LONG, (byte) 0);
+ }
+ else {
+ literalsBase = inputBase;
+ literalsAddress = input;
+ literalsLimit = literalsAddress + literalSize;
+ }
+ input += literalSize;
+
+ return (int) (input - inputAddress);
+ }
+}