aboutsummaryrefslogtreecommitdiffstats
path: root/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Huffman.java
diff options
context:
space:
mode:
Diffstat (limited to 'airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Huffman.java')
-rw-r--r--airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Huffman.java323
1 files changed, 323 insertions, 0 deletions
diff --git a/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Huffman.java b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Huffman.java
new file mode 100644
index 00000000000..c8ed6a1f5f0
--- /dev/null
+++ b/airlift-zstd/src/main/java/ai/vespa/airlift/zstd/Huffman.java
@@ -0,0 +1,323 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.vespa.airlift.zstd;
+
+import java.util.Arrays;
+
+import static ai.vespa.airlift.zstd.BitInputStream.isEndOfStream;
+import static ai.vespa.airlift.zstd.BitInputStream.peekBitsFast;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_INT;
+import static ai.vespa.airlift.zstd.Constants.SIZE_OF_SHORT;
+import static ai.vespa.airlift.zstd.UnsafeUtil.UNSAFE;
+import static ai.vespa.airlift.zstd.Util.isPowerOf2;
+import static ai.vespa.airlift.zstd.Util.verify;
+
+class Huffman
+{
+ public static final int MAX_SYMBOL = 255;
+ public static final int MAX_SYMBOL_COUNT = MAX_SYMBOL + 1;
+
+ public static final int MAX_TABLE_LOG = 12;
+ public static final int MIN_TABLE_LOG = 5;
+ public static final int MAX_FSE_TABLE_LOG = 6;
+
+ // stats
+ private final byte[] weights = new byte[MAX_SYMBOL + 1];
+ private final int[] ranks = new int[MAX_TABLE_LOG + 1];
+
+ // table
+ private int tableLog = -1;
+ private final byte[] symbols = new byte[1 << MAX_TABLE_LOG];
+ private final byte[] numbersOfBits = new byte[1 << MAX_TABLE_LOG];
+
+ private final FseTableReader reader = new FseTableReader();
+ private final FiniteStateEntropy.Table fseTable = new FiniteStateEntropy.Table(MAX_FSE_TABLE_LOG);
+
+ public boolean isLoaded()
+ {
+ return tableLog != -1;
+ }
+
+ public int readTable(final Object inputBase, final long inputAddress, final int size)
+ {
+ Arrays.fill(ranks, 0);
+ long input = inputAddress;
+
+ // read table header
+ verify(size > 0, input, "Not enough input bytes");
+ int inputSize = UNSAFE.getByte(inputBase, input++) & 0xFF;
+
+ int outputSize;
+ if (inputSize >= 128) {
+ outputSize = inputSize - 127;
+ inputSize = ((outputSize + 1) / 2);
+
+ verify(inputSize + 1 <= size, input, "Not enough input bytes");
+ verify(outputSize <= MAX_SYMBOL + 1, input, "Input is corrupted");
+
+ for (int i = 0; i < outputSize; i += 2) {
+ int value = UNSAFE.getByte(inputBase, input + i / 2) & 0xFF;
+ weights[i] = (byte) (value >>> 4);
+ weights[i + 1] = (byte) (value & 0b1111);
+ }
+ }
+ else {
+ verify(inputSize + 1 <= size, input, "Not enough input bytes");
+
+ long inputLimit = input + inputSize;
+ input += reader.readFseTable(fseTable, inputBase, input, inputLimit, FiniteStateEntropy.MAX_SYMBOL, MAX_FSE_TABLE_LOG);
+ outputSize = FiniteStateEntropy.decompress(fseTable, inputBase, input, inputLimit, weights);
+ }
+
+ int totalWeight = 0;
+ for (int i = 0; i < outputSize; i++) {
+ ranks[weights[i]]++;
+ totalWeight += (1 << weights[i]) >> 1; // TODO same as 1 << (weights[n] - 1)?
+ }
+ verify(totalWeight != 0, input, "Input is corrupted");
+
+ tableLog = Util.highestBit(totalWeight) + 1;
+ verify(tableLog <= MAX_TABLE_LOG, input, "Input is corrupted");
+
+ int total = 1 << tableLog;
+ int rest = total - totalWeight;
+ verify(isPowerOf2(rest), input, "Input is corrupted");
+
+ int lastWeight = Util.highestBit(rest) + 1;
+
+ weights[outputSize] = (byte) lastWeight;
+ ranks[lastWeight]++;
+
+ int numberOfSymbols = outputSize + 1;
+
+ // populate table
+ int nextRankStart = 0;
+ for (int i = 1; i < tableLog + 1; ++i) {
+ int current = nextRankStart;
+ nextRankStart += ranks[i] << (i - 1);
+ ranks[i] = current;
+ }
+
+ for (int n = 0; n < numberOfSymbols; n++) {
+ int weight = weights[n];
+ int length = (1 << weight) >> 1; // TODO: 1 << (weight - 1) ??
+
+ byte symbol = (byte) n;
+ byte numberOfBits = (byte) (tableLog + 1 - weight);
+ for (int i = ranks[weight]; i < ranks[weight] + length; i++) {
+ symbols[i] = symbol;
+ numbersOfBits[i] = numberOfBits;
+ }
+ ranks[weight] += length;
+ }
+
+ verify(ranks[1] >= 2 && (ranks[1] & 1) == 0, input, "Input is corrupted");
+
+ return inputSize + 1;
+ }
+
+ public void decodeSingleStream(final Object inputBase, final long inputAddress, final long inputLimit, final Object outputBase, final long outputAddress, final long outputLimit)
+ {
+ BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, inputAddress, inputLimit);
+ initializer.initialize();
+
+ long bits = initializer.getBits();
+ int bitsConsumed = initializer.getBitsConsumed();
+ long currentAddress = initializer.getCurrentAddress();
+
+ int tableLog = this.tableLog;
+ byte[] numbersOfBits = this.numbersOfBits;
+ byte[] symbols = this.symbols;
+
+ // 4 symbols at a time
+ long output = outputAddress;
+ long fastOutputLimit = outputLimit - 4;
+ while (output < fastOutputLimit) {
+ BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, inputAddress, currentAddress, bits, bitsConsumed);
+ boolean done = loader.load();
+ bits = loader.getBits();
+ bitsConsumed = loader.getBitsConsumed();
+ currentAddress = loader.getCurrentAddress();
+ if (done) {
+ break;
+ }
+
+ bitsConsumed = decodeSymbol(outputBase, output, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ bitsConsumed = decodeSymbol(outputBase, output + 1, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ bitsConsumed = decodeSymbol(outputBase, output + 2, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ bitsConsumed = decodeSymbol(outputBase, output + 3, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ output += SIZE_OF_INT;
+ }
+
+ decodeTail(inputBase, inputAddress, currentAddress, bitsConsumed, bits, outputBase, output, outputLimit);
+ }
+
+ public void decode4Streams(final Object inputBase, final long inputAddress, final long inputLimit, final Object outputBase, final long outputAddress, final long outputLimit)
+ {
+ verify(inputLimit - inputAddress >= 10, inputAddress, "Input is corrupted"); // jump table + 1 byte per stream
+
+ long start1 = inputAddress + 3 * SIZE_OF_SHORT; // for the shorts we read below
+ long start2 = start1 + (UNSAFE.getShort(inputBase, inputAddress) & 0xFFFF);
+ long start3 = start2 + (UNSAFE.getShort(inputBase, inputAddress + 2) & 0xFFFF);
+ long start4 = start3 + (UNSAFE.getShort(inputBase, inputAddress + 4) & 0xFFFF);
+
+ BitInputStream.Initializer initializer = new BitInputStream.Initializer(inputBase, start1, start2);
+ initializer.initialize();
+ int stream1bitsConsumed = initializer.getBitsConsumed();
+ long stream1currentAddress = initializer.getCurrentAddress();
+ long stream1bits = initializer.getBits();
+
+ initializer = new BitInputStream.Initializer(inputBase, start2, start3);
+ initializer.initialize();
+ int stream2bitsConsumed = initializer.getBitsConsumed();
+ long stream2currentAddress = initializer.getCurrentAddress();
+ long stream2bits = initializer.getBits();
+
+ initializer = new BitInputStream.Initializer(inputBase, start3, start4);
+ initializer.initialize();
+ int stream3bitsConsumed = initializer.getBitsConsumed();
+ long stream3currentAddress = initializer.getCurrentAddress();
+ long stream3bits = initializer.getBits();
+
+ initializer = new BitInputStream.Initializer(inputBase, start4, inputLimit);
+ initializer.initialize();
+ int stream4bitsConsumed = initializer.getBitsConsumed();
+ long stream4currentAddress = initializer.getCurrentAddress();
+ long stream4bits = initializer.getBits();
+
+ int segmentSize = (int) ((outputLimit - outputAddress + 3) / 4);
+
+ long outputStart2 = outputAddress + segmentSize;
+ long outputStart3 = outputStart2 + segmentSize;
+ long outputStart4 = outputStart3 + segmentSize;
+
+ long output1 = outputAddress;
+ long output2 = outputStart2;
+ long output3 = outputStart3;
+ long output4 = outputStart4;
+
+ long fastOutputLimit = outputLimit - 7;
+ int tableLog = this.tableLog;
+ byte[] numbersOfBits = this.numbersOfBits;
+ byte[] symbols = this.symbols;
+
+ while (output4 < fastOutputLimit) {
+ stream1bitsConsumed = decodeSymbol(outputBase, output1, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream2bitsConsumed = decodeSymbol(outputBase, output2, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream3bitsConsumed = decodeSymbol(outputBase, output3, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream4bitsConsumed = decodeSymbol(outputBase, output4, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);
+
+ stream1bitsConsumed = decodeSymbol(outputBase, output1 + 1, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream2bitsConsumed = decodeSymbol(outputBase, output2 + 1, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream3bitsConsumed = decodeSymbol(outputBase, output3 + 1, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream4bitsConsumed = decodeSymbol(outputBase, output4 + 1, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);
+
+ stream1bitsConsumed = decodeSymbol(outputBase, output1 + 2, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream2bitsConsumed = decodeSymbol(outputBase, output2 + 2, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream3bitsConsumed = decodeSymbol(outputBase, output3 + 2, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream4bitsConsumed = decodeSymbol(outputBase, output4 + 2, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);
+
+ stream1bitsConsumed = decodeSymbol(outputBase, output1 + 3, stream1bits, stream1bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream2bitsConsumed = decodeSymbol(outputBase, output2 + 3, stream2bits, stream2bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream3bitsConsumed = decodeSymbol(outputBase, output3 + 3, stream3bits, stream3bitsConsumed, tableLog, numbersOfBits, symbols);
+ stream4bitsConsumed = decodeSymbol(outputBase, output4 + 3, stream4bits, stream4bitsConsumed, tableLog, numbersOfBits, symbols);
+
+ output1 += SIZE_OF_INT;
+ output2 += SIZE_OF_INT;
+ output3 += SIZE_OF_INT;
+ output4 += SIZE_OF_INT;
+
+ BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, start1, stream1currentAddress, stream1bits, stream1bitsConsumed);
+ boolean done = loader.load();
+ stream1bitsConsumed = loader.getBitsConsumed();
+ stream1bits = loader.getBits();
+ stream1currentAddress = loader.getCurrentAddress();
+
+ if (done) {
+ break;
+ }
+
+ loader = new BitInputStream.Loader(inputBase, start2, stream2currentAddress, stream2bits, stream2bitsConsumed);
+ done = loader.load();
+ stream2bitsConsumed = loader.getBitsConsumed();
+ stream2bits = loader.getBits();
+ stream2currentAddress = loader.getCurrentAddress();
+
+ if (done) {
+ break;
+ }
+
+ loader = new BitInputStream.Loader(inputBase, start3, stream3currentAddress, stream3bits, stream3bitsConsumed);
+ done = loader.load();
+ stream3bitsConsumed = loader.getBitsConsumed();
+ stream3bits = loader.getBits();
+ stream3currentAddress = loader.getCurrentAddress();
+ if (done) {
+ break;
+ }
+
+ loader = new BitInputStream.Loader(inputBase, start4, stream4currentAddress, stream4bits, stream4bitsConsumed);
+ done = loader.load();
+ stream4bitsConsumed = loader.getBitsConsumed();
+ stream4bits = loader.getBits();
+ stream4currentAddress = loader.getCurrentAddress();
+ if (done) {
+ break;
+ }
+ }
+
+ verify(output1 <= outputStart2 && output2 <= outputStart3 && output3 <= outputStart4, inputAddress, "Input is corrupted");
+
+ /// finish streams one by one
+ decodeTail(inputBase, start1, stream1currentAddress, stream1bitsConsumed, stream1bits, outputBase, output1, outputStart2);
+ decodeTail(inputBase, start2, stream2currentAddress, stream2bitsConsumed, stream2bits, outputBase, output2, outputStart3);
+ decodeTail(inputBase, start3, stream3currentAddress, stream3bitsConsumed, stream3bits, outputBase, output3, outputStart4);
+ decodeTail(inputBase, start4, stream4currentAddress, stream4bitsConsumed, stream4bits, outputBase, output4, outputLimit);
+ }
+
+ private void decodeTail(final Object inputBase, final long startAddress, long currentAddress, int bitsConsumed, long bits, final Object outputBase, long outputAddress, final long outputLimit)
+ {
+ int tableLog = this.tableLog;
+ byte[] numbersOfBits = this.numbersOfBits;
+ byte[] symbols = this.symbols;
+
+ // closer to the end
+ while (outputAddress < outputLimit) {
+ BitInputStream.Loader loader = new BitInputStream.Loader(inputBase, startAddress, currentAddress, bits, bitsConsumed);
+ boolean done = loader.load();
+ bitsConsumed = loader.getBitsConsumed();
+ bits = loader.getBits();
+ currentAddress = loader.getCurrentAddress();
+ if (done) {
+ break;
+ }
+
+ bitsConsumed = decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ }
+
+ // not more data in bit stream, so no need to reload
+ while (outputAddress < outputLimit) {
+ bitsConsumed = decodeSymbol(outputBase, outputAddress++, bits, bitsConsumed, tableLog, numbersOfBits, symbols);
+ }
+
+ verify(isEndOfStream(startAddress, currentAddress, bitsConsumed), startAddress, "Bit stream is not fully consumed");
+ }
+
+ private static int decodeSymbol(Object outputBase, long outputAddress, long bitContainer, int bitsConsumed, int tableLog, byte[] numbersOfBits, byte[] symbols)
+ {
+ int value = (int) peekBitsFast(bitsConsumed, bitContainer, tableLog);
+ UNSAFE.putByte(outputBase, outputAddress, symbols[value]);
+ return bitsConsumed + numbersOfBits[value];
+ }
+}