summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2024-01-19 07:52:35 +0100
committerGitHub <noreply@github.com>2024-01-19 07:52:35 +0100
commit58d5bb8337bcd9f1b7698eb2945a764824e3942e (patch)
tree789a087a0cdd3da355b2b4970861442891a70f63
parent7d8d6cc6a568ab695522cf0de50ff4e0b12b52ce (diff)
parentbceb0e5d4dd71c12a87cd15e18d31ec7ca4957e7 (diff)
Merge pull request #29974 from vespa-engine/balder/optimize-splade-embedder
Balder/optimize splade embedder
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java36
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java39
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java76
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java3
7 files changed, 94 insertions, 86 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index a805fc79a64..da3068c3744 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -105,8 +105,9 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
public Encoding encode(String text) { return encode(text, Language.UNKNOWN); }
public Encoding encode(String text, Language language) { return Encoding.from(resolve(language).encode(text)); }
- public String decode(List<Long> tokens) { return decode(tokens, Language.UNKNOWN); }
- public String decode(List<Long> tokens, Language language) { return resolve(language).decode(toArray(tokens)); }
+
+ public String decode(long [] tokens) { return decode(tokens, Language.UNKNOWN); }
+ public String decode(long [] tokens, Language language) { return resolve(language).decode(tokens); }
@Override
public void close() {
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 644b1ec538f..853009873a1 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -12,6 +12,7 @@ import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
import java.util.List;
@@ -139,24 +140,33 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
if (batch != 1) {
throw new IllegalArgumentException("Batch size must be 1");
}
- long sequenceLength = shape[1];
- long vocabSize = shape[2];
+ if (shape[1] > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException("sequenceLength=" + shape[1] + " larger than an int");
+ }
+ if (shape[2] > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException("vocabSize=" + shape[2] + " larger than an int");
+ }
+ int sequenceLength = (int) shape[1];
+ int vocabSize = (int) shape[2];
+ String dimension = tensorType.dimensions().get(0).name();
//Iterate over the vocab dimension and find the max value for each sequence token
- for(int v = 0; v < vocabSize; v++) {
- double maxLogOfRelu = Double.MIN_VALUE;
- for(int s = 0; s < sequenceLength; s++) {
+ long [] tokens = new long[1];
+ for (int v = 0; v < vocabSize; v++) {
+ double maxValue = 0.0d;
+ for (int s = 0; s < sequenceLength; s++) {
double value = modelOutput.get(0, s, v); // batch, sequence, vocab
- double logOfRelu = Math.log(1 + Math.max(0, value));
- if(logOfRelu > maxLogOfRelu) {
- maxLogOfRelu = logOfRelu;
+ if (value > maxValue) {
+ maxValue = value;
}
}
- if (maxLogOfRelu > termScoreThreshold) {
- String term = tokenizer.decode(List.of((long) v));
- builder.cell().
- label(tensorType.dimensions().get(0).name(), term)
- .value(maxLogOfRelu);
+ double logOfRelu = Math.log(1 + maxValue);
+ if (logOfRelu > termScoreThreshold) {
+ tokens[0] = v;
+ String term = tokenizer.decode(tokens);
+ builder.cell()
+ .label(dimension, term)
+ .value(logOfRelu);
}
}
return builder.build();
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
index 2612702e99b..07f2aea4ab6 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -53,10 +53,9 @@ class TensorConverter {
static OnnxTensor toOnnxTensor(Tensor vespaTensor, TensorInfo onnxTensorInfo, OrtEnvironment environment)
throws OrtException
{
- if ( ! (vespaTensor instanceof IndexedTensor)) {
+ if ( ! (vespaTensor instanceof IndexedTensor tensor)) {
throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions");
}
- IndexedTensor tensor = (IndexedTensor) vespaTensor;
ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder());
if (onnxTensorInfo.type == OnnxJavaType.FLOAT) {
for (int i = 0; i < tensor.size(); i++)
@@ -103,54 +102,54 @@ class TensorConverter {
}
static Tensor toVespaTensor(OnnxValue onnxValue) {
- if ( ! (onnxValue instanceof OnnxTensor)) {
+ if ( ! (onnxValue instanceof OnnxTensor onnxTensor)) {
throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
}
- OnnxTensor onnxTensor = (OnnxTensor) onnxValue;
TensorInfo tensorInfo = onnxTensor.getInfo();
TensorType type = toVespaType(onnxTensor.getInfo());
DimensionSizes sizes = sizesFromType(type);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes);
+ long totalSize = sizes.totalSize();
if (tensorInfo.type == OnnxJavaType.FLOAT) {
FloatBuffer buffer = onnxTensor.getFloatBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.DOUBLE) {
DoubleBuffer buffer = onnxTensor.getDoubleBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.INT8) {
ByteBuffer buffer = onnxTensor.getByteBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.INT16) {
ShortBuffer buffer = onnxTensor.getShortBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.INT32) {
IntBuffer buffer = onnxTensor.getIntBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.INT64) {
LongBuffer buffer = onnxTensor.getLongBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, buffer.get());
}
else if (tensorInfo.type == OnnxJavaType.FLOAT16) {
ShortBuffer buffer = onnxTensor.getShortBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get()));
}
else if (tensorInfo.type == OnnxJavaType.BFLOAT16) {
ShortBuffer buffer = onnxTensor.getShortBuffer();
- for (long i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < totalSize; i++)
builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get())));
}
else {
@@ -201,14 +200,14 @@ class TensorConverter {
}
static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) {
- switch (onnxType) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: return TensorType.Value.FLOAT;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE;
- }
- return TensorType.Value.DOUBLE;
+ return switch (onnxType) {
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 -> TensorType.Value.INT8;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 -> TensorType.Value.BFLOAT16;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 -> TensorType.Value.FLOAT;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT -> TensorType.Value.FLOAT;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE -> TensorType.Value.DOUBLE;
+ default -> TensorType.Value.DOUBLE;
+ };
}
static private TensorInfo toTensorInfo(ValueInfo valueInfo) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 83a625f72ac..640fa609432 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -11,10 +11,19 @@ import java.util.Arrays;
public final class DimensionSizes {
private final long[] sizes;
+ private final long[] productOfSizesFromHereOn;
+ private final long totalSize;
private DimensionSizes(Builder builder) {
this.sizes = builder.sizes;
builder.sizes = null; // invalidate builder to avoid copying the array
+ this.productOfSizesFromHereOn = new long[sizes.length];
+ long product = 1;
+ for (int i = sizes.length; i-- > 0; ) {
+ productOfSizesFromHereOn[i] = product;
+ product *= sizes[i];
+ }
+ this.totalSize = product;
}
/**
@@ -49,10 +58,11 @@ public final class DimensionSizes {
/** Returns the product of the sizes of this */
public long totalSize() {
- long productSize = 1;
- for (long dimensionSize : sizes )
- productSize *= dimensionSize;
- return productSize;
+ return totalSize;
+ }
+
+ long productOfDimensionsAfter(int afterIndex) {
+ return productOfSizesFromHereOn[afterIndex];
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 6a879fa533b..1319675f5d4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -90,7 +90,7 @@ public abstract class IndexedTensor implements Tensor {
* @throws IllegalArgumentException if any of the indexes are out of bound or a wrong number of indexes are given
*/
public double get(long ... indexes) {
- return get((int)toValueIndex(indexes, dimensionSizes));
+ return get(toValueIndex(indexes, dimensionSizes));
}
/**
@@ -108,7 +108,7 @@ public abstract class IndexedTensor implements Tensor {
public double get(TensorAddress address) {
// optimize for fast lookup within bounds:
try {
- return get((int)toValueIndex(address, dimensionSizes, type));
+ return get(toValueIndex(address, dimensionSizes, type));
}
catch (IllegalArgumentException e) {
return 0.0;
@@ -150,7 +150,7 @@ public abstract class IndexedTensor implements Tensor {
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] >= sizes.size(i))
throw new IllegalArgumentException(Arrays.toString(indexes) + " are not within bounds");
- valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i];
+ valueIndex += sizes.productOfDimensionsAfter(i) * indexes[i];
}
return valueIndex;
}
@@ -162,18 +162,11 @@ public abstract class IndexedTensor implements Tensor {
for (int i = 0; i < address.size(); i++) {
if (address.numericLabel(i) >= sizes.size(i))
throw new IllegalArgumentException(address + " is not within the bounds of " + type);
- valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i);
+ valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i);
}
return valueIndex;
}
- private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
- long product = 1;
- for (int i = afterIndex + 1; i < sizes.dimensions(); i++)
- product *= sizes.size(i);
- return product;
- }
-
void throwOnIncompatibleType(TensorType type) {
if ( ! this.type().isRenamableTo(type))
throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type +
@@ -227,7 +220,7 @@ public abstract class IndexedTensor implements Tensor {
@Override
public String toAbbreviatedString(boolean withType, boolean shortForms) {
- return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1)));
+ return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1)));
}
private String toString(boolean withType, boolean shortForms, long maxCells) {
@@ -250,8 +243,7 @@ public abstract class IndexedTensor implements Tensor {
b.append(", ");
// start brackets
- for (int i = 0; i < indexes.nextDimensionsAtStart(); i++)
- b.append("[");
+ b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart())));
// value
switch (tensor.type().valueType()) {
@@ -264,8 +256,7 @@ public abstract class IndexedTensor implements Tensor {
}
// end bracket and comma
- for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
- b.append("]");
+ b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd())));
}
if (index == maxCells && index < tensor.size())
b.append(", ...]");
@@ -327,14 +318,13 @@ public abstract class IndexedTensor implements Tensor {
*/
public static Builder of(TensorType type, DimensionSizes sizes) {
validate(type, sizes);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
/**
@@ -348,14 +338,13 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, float[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
/**
@@ -369,14 +358,13 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, double[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
private static void validateSizes(DimensionSizes sizes, int length) {
@@ -518,7 +506,7 @@ public abstract class IndexedTensor implements Tensor {
if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension
for (long i = 0; i < currentDimension.size(); i++)
fillValues(currentDimensionIndex + 1,
- offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i,
+ offset + sizes.productOfDimensionsAfter(currentDimensionIndex) * i,
(List<Object>) currentDimension.get((int)i), sizes, values);
} else { // last dimension - fill values
for (long i = 0; i < currentDimension.size(); i++) {
@@ -1091,8 +1079,8 @@ public abstract class IndexedTensor implements Tensor {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
- this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceSizes);
- this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateSizes);
+ this.sourceStep = sourceSizes.productOfDimensionsAfter(iterateDimension);
+ this.iterationStep = iterateSizes.productOfDimensionsAfter(iterateDimension);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;
@@ -1156,7 +1144,7 @@ public abstract class IndexedTensor implements Tensor {
super(sizes, sizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
- this.step = productOfDimensionsAfter(iterateDimension, sizes);
+ this.step = sizes.productOfDimensionsAfter(iterateDimension);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index e529c7f71d2..5471ea65b97 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
-import java.util.function.DoubleBinaryOperator;
/**
* A sparse implementation of a tensor backed by a Map of cells to values.
@@ -83,7 +82,7 @@ public class MappedTensor implements Tensor {
@Override
public String toAbbreviatedString(boolean withType, boolean shortForms) {
- return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1)));
+ return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1)));
}
private String toString(boolean withType, boolean shortForms, long maxCells) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index e44df06ed20..cc8e1602adb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -20,6 +20,7 @@ import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.XwPlusB;
import com.yahoo.tensor.functions.Expand;
+import com.yahoo.tensor.impl.NumericTensorAddress;
import java.util.ArrayList;
import java.util.Arrays;
@@ -623,7 +624,7 @@ public interface Tensor {
public TensorType type() { return tensorBuilder.type(); }
public CellBuilder label(String dimension, long label) {
- return label(dimension, String.valueOf(label));
+ return label(dimension, NumericTensorAddress.asString(label));
}
public Builder value(double cellValue) {