diff options
author | Jon Bratseth <bratseth@gmail.com> | 2024-01-19 07:52:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-19 07:52:35 +0100 |
commit | 58d5bb8337bcd9f1b7698eb2945a764824e3942e (patch) | |
tree | 789a087a0cdd3da355b2b4970861442891a70f63 | |
parent | 7d8d6cc6a568ab695522cf0de50ff4e0b12b52ce (diff) | |
parent | bceb0e5d4dd71c12a87cd15e18d31ec7ca4957e7 (diff) |
Merge pull request #29974 from vespa-engine/balder/optimize-splade-embedder
Balder/optimize splade embedder
7 files changed, 94 insertions, 86 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index a805fc79a64..da3068c3744 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -105,8 +105,9 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public Encoding encode(String text) { return encode(text, Language.UNKNOWN); } public Encoding encode(String text, Language language) { return Encoding.from(resolve(language).encode(text)); } - public String decode(List<Long> tokens) { return decode(tokens, Language.UNKNOWN); } - public String decode(List<Long> tokens, Language language) { return resolve(language).decode(toArray(tokens)); } + + public String decode(long [] tokens) { return decode(tokens, Language.UNKNOWN); } + public String decode(long [] tokens, Language language) { return resolve(language).decode(tokens); } @Override public void close() { diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 644b1ec538f..853009873a1 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -12,6 +12,7 @@ import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; @@ -139,24 +140,33 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { if (batch != 1) { throw new IllegalArgumentException("Batch size must be 1"); } - long sequenceLength = shape[1]; - long vocabSize = shape[2]; + if (shape[1] > Integer.MAX_VALUE) { + throw new IllegalArgumentException("sequenceLength=" + shape[1] + " larger than an int"); + } + if (shape[2] > Integer.MAX_VALUE) { + throw new IllegalArgumentException("vocabSize=" + shape[2] + " larger than an int"); + } + int sequenceLength = (int) shape[1]; + int vocabSize = (int) shape[2]; + String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token - for(int v = 0; v < vocabSize; v++) { - double maxLogOfRelu = Double.MIN_VALUE; - for(int s = 0; s < sequenceLength; s++) { + long [] tokens = new long[1]; + for (int v = 0; v < vocabSize; v++) { + double maxValue = 0.0d; + for (int s = 0; s < sequenceLength; s++) { double value = modelOutput.get(0, s, v); // batch, sequence, vocab - double logOfRelu = Math.log(1 + Math.max(0, value)); - if(logOfRelu > maxLogOfRelu) { - maxLogOfRelu = logOfRelu; + if (value > maxValue) { + maxValue = value; } } - if (maxLogOfRelu > termScoreThreshold) { - String term = tokenizer.decode(List.of((long) v)); - builder.cell(). - label(tensorType.dimensions().get(0).name(), term) - .value(maxLogOfRelu); + double logOfRelu = Math.log(1 + maxValue); + if (logOfRelu > termScoreThreshold) { + tokens[0] = v; + String term = tokenizer.decode(tokens); + builder.cell() + .label(dimension, term) + .value(logOfRelu); } } return builder.build(); diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 2612702e99b..07f2aea4ab6 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -53,10 +53,9 @@ class TensorConverter { static OnnxTensor toOnnxTensor(Tensor vespaTensor, TensorInfo onnxTensorInfo, OrtEnvironment environment) throws OrtException { - if ( ! (vespaTensor instanceof IndexedTensor)) { + if ( ! (vespaTensor instanceof IndexedTensor tensor)) { throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions"); } - IndexedTensor tensor = (IndexedTensor) vespaTensor; ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder()); if (onnxTensorInfo.type == OnnxJavaType.FLOAT) { for (int i = 0; i < tensor.size(); i++) @@ -103,54 +102,54 @@ class TensorConverter { } static Tensor toVespaTensor(OnnxValue onnxValue) { - if ( ! (onnxValue instanceof OnnxTensor)) { + if ( ! (onnxValue instanceof OnnxTensor onnxTensor)) { throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported"); } - OnnxTensor onnxTensor = (OnnxTensor) onnxValue; TensorInfo tensorInfo = onnxTensor.getInfo(); TensorType type = toVespaType(onnxTensor.getInfo()); DimensionSizes sizes = sizesFromType(type); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes); + long totalSize = sizes.totalSize(); if (tensorInfo.type == OnnxJavaType.FLOAT) { FloatBuffer buffer = onnxTensor.getFloatBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.DOUBLE) { DoubleBuffer buffer = onnxTensor.getDoubleBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.INT8) { ByteBuffer buffer = onnxTensor.getByteBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.INT16) { ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.INT32) { IntBuffer buffer = onnxTensor.getIntBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.INT64) { LongBuffer buffer = onnxTensor.getLongBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, buffer.get()); } else if (tensorInfo.type == OnnxJavaType.FLOAT16) { ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get())); } else if (tensorInfo.type == OnnxJavaType.BFLOAT16) { ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < totalSize; i++) builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get()))); } else { @@ -201,14 +200,14 @@ class TensorConverter { } static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) { - switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: return TensorType.Value.FLOAT; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE; - } - return TensorType.Value.DOUBLE; + return switch (onnxType) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 -> TensorType.Value.INT8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 -> TensorType.Value.BFLOAT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 -> TensorType.Value.FLOAT; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT -> TensorType.Value.FLOAT; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE -> TensorType.Value.DOUBLE; + default -> TensorType.Value.DOUBLE; + }; } static private TensorInfo toTensorInfo(ValueInfo valueInfo) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index 83a625f72ac..640fa609432 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -11,10 +11,19 @@ import java.util.Arrays; public final class DimensionSizes { private final long[] sizes; + private final long[] productOfSizesFromHereOn; + private final long totalSize; private DimensionSizes(Builder builder) { this.sizes = builder.sizes; builder.sizes = null; // invalidate builder to avoid copying the array + this.productOfSizesFromHereOn = new long[sizes.length]; + long product = 1; + for (int i = sizes.length; i-- > 0; ) { + productOfSizesFromHereOn[i] = product; + product *= sizes[i]; + } + this.totalSize = product; } /** @@ -49,10 +58,11 @@ public final class DimensionSizes { /** Returns the product of the sizes of this */ public long totalSize() { - long productSize = 1; - for (long dimensionSize : sizes ) - productSize *= dimensionSize; - return productSize; + return totalSize; + } + + long productOfDimensionsAfter(int afterIndex) { + return productOfSizesFromHereOn[afterIndex]; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6a879fa533b..1319675f5d4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -90,7 +90,7 @@ public abstract class IndexedTensor implements Tensor { * @throws IllegalArgumentException if any of the indexes are out of bound or a wrong number of indexes are given */ public double get(long ... indexes) { - return get((int)toValueIndex(indexes, dimensionSizes)); + return get(toValueIndex(indexes, dimensionSizes)); } /** @@ -108,7 +108,7 @@ public abstract class IndexedTensor implements Tensor { public double get(TensorAddress address) { // optimize for fast lookup within bounds: try { - return get((int)toValueIndex(address, dimensionSizes, type)); + return get(toValueIndex(address, dimensionSizes, type)); } catch (IllegalArgumentException e) { return 0.0; @@ -150,7 +150,7 @@ public abstract class IndexedTensor implements Tensor { for (int i = 0; i < indexes.length; i++) { if (indexes[i] >= sizes.size(i)) throw new IllegalArgumentException(Arrays.toString(indexes) + " are not within bounds"); - valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i]; + valueIndex += sizes.productOfDimensionsAfter(i) * indexes[i]; } return valueIndex; } @@ -162,18 +162,11 @@ public abstract class IndexedTensor implements Tensor { for (int i = 0; i < address.size(); i++) { if (address.numericLabel(i) >= sizes.size(i)) throw new IllegalArgumentException(address + " is not within the bounds of " + type); - valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i); + valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i); } return valueIndex; } - private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) { - long product = 1; - for (int i = afterIndex + 1; i < sizes.dimensions(); i++) - product *= sizes.size(i); - return product; - } - void throwOnIncompatibleType(TensorType type) { if ( ! this.type().isRenamableTo(type)) throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type + @@ -227,7 +220,7 @@ public abstract class IndexedTensor implements Tensor { @Override public String toAbbreviatedString(boolean withType, boolean shortForms) { - return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1))); } private String toString(boolean withType, boolean shortForms, long maxCells) { @@ -250,8 +243,7 @@ public abstract class IndexedTensor implements Tensor { b.append(", "); // start brackets - for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) - b.append("["); + b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart()))); // value switch (tensor.type().valueType()) { @@ -264,8 +256,7 @@ public abstract class IndexedTensor implements Tensor { } // end bracket and comma - for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) - b.append("]"); + b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd()))); } if (index == maxCells && index < tensor.size()) b.append(", ...]"); @@ -327,14 +318,13 @@ public abstract class IndexedTensor implements Tensor { */ public static Builder of(TensorType type, DimensionSizes sizes) { validate(type, sizes); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } /** @@ -348,14 +338,13 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, float[] values) { validate(type, sizes); validateSizes(sizes, values.length); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } /** @@ -369,14 +358,13 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, double[] values) { validate(type, sizes); validateSizes(sizes, values.length); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } private static void validateSizes(DimensionSizes sizes, int length) { @@ -518,7 +506,7 @@ public abstract class IndexedTensor implements Tensor { if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension for (long i = 0; i < currentDimension.size(); i++) fillValues(currentDimensionIndex + 1, - offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i, + offset + sizes.productOfDimensionsAfter(currentDimensionIndex) * i, (List<Object>) currentDimension.get((int)i), sizes, values); } else { // last dimension - fill values for (long i = 0; i < currentDimension.size(); i++) { @@ -1091,8 +1079,8 @@ public abstract class IndexedTensor implements Tensor { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; - this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceSizes); - this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateSizes); + this.sourceStep = sourceSizes.productOfDimensionsAfter(iterateDimension); + this.iterationStep = iterateSizes.productOfDimensionsAfter(iterateDimension); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; @@ -1156,7 +1144,7 @@ public abstract class IndexedTensor implements Tensor { super(sizes, sizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; - this.step = productOfDimensionsAfter(iterateDimension, sizes); + this.step = sizes.productOfDimensionsAfter(iterateDimension); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index e529c7f71d2..5471ea65b97 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableMap; import java.util.Iterator; import java.util.Map; import java.util.Set; -import java.util.function.DoubleBinaryOperator; /** * A sparse implementation of a tensor backed by a Map of cells to values. @@ -83,7 +82,7 @@ public class MappedTensor implements Tensor { @Override public String toAbbreviatedString(boolean withType, boolean shortForms) { - return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1))); } private String toString(boolean withType, boolean shortForms, long maxCells) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index e44df06ed20..cc8e1602adb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -20,6 +20,7 @@ import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.XwPlusB; import com.yahoo.tensor.functions.Expand; +import com.yahoo.tensor.impl.NumericTensorAddress; import java.util.ArrayList; import java.util.Arrays; @@ -623,7 +624,7 @@ public interface Tensor { public TensorType type() { return tensorBuilder.type(); } public CellBuilder label(String dimension, long label) { - return label(dimension, String.valueOf(label)); + return label(dimension, NumericTensorAddress.asString(label)); } public Builder value(double cellValue) { |