summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2024-01-11 10:49:28 +0100
committerJo Kristian Bergum <bergum@yahooinc.com>2024-01-11 10:49:28 +0100
commit52a0c0597b85dc45fc69cd01f2f9d859a11bb348 (patch)
treef5fc486c2e3d8a61c10715760a437b7697e25890 /model-integration
parent04d491286aa2a6f8b3a04048936419c6cde4e3ec (diff)
Avoid generic reduce to reduce gc pressure
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java65
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java15
2 files changed, 61 insertions, 19 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 4af7820274f..78fb3142704 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -13,7 +13,6 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Reduce;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
@@ -22,9 +21,7 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES
/**
* A SPLADE embedder that is embedding text to a 1-d mapped tensor. For interpretability, the tensor labels
- * are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0). This
- * instead of using the token identifier.
- *
+ * are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0).
*/
@Beta
public class SpladeEmbedder extends AbstractComponent implements Embedder {
@@ -119,26 +116,21 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
attentionMaskName, attentionMask.expand("d0"),
tokenTypeIdsName, tokenTypeIds.expand("d0"));
-
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- //Remove batch dim, batch size of 1
- Tensor output = outputs.get(outputName).reduce(Reduce.Aggregator.max, "d0");
- Tensor mappedTensor = sparsify(output, tensorType);
+ Tensor spladeTensor = sparsify((IndexedTensor) evaluator.evaluate(inputs).get(outputName), tensorType);
runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context);
- return mappedTensor;
+ return spladeTensor;
}
+
/**
- * Sparsify the output tensor by applying a threshold on the log of the relu of the output.
- * @param modelOutput the model output tensor of shape d1,dim where d1 is the sequence length and dim is size
- * of the vocabulary
+ * Sparsify the model output tensor.
+ *
+ * @param modelOutput the model output tensor of type tensorType
* @param tensorType the type of the destination tensor
* @return A mapped tensor with the terms from the vocab that has a score above the threshold
*/
- public Tensor sparsify(Tensor modelOutput, TensorType tensorType) {
- Tensor logOfRelu = modelOutput.map((x) -> Math.log(1 + Math.max(0, x)));
- Tensor maxReduced = logOfRelu.reduce(Reduce.Aggregator.max, "d1");
- IndexedTensor vocab = (IndexedTensor) maxReduced;
+ public Tensor sparsify(IndexedTensor modelOutput, TensorType tensorType) {
+ IndexedTensor vocab = customMaxReduceOverLogOfRelu(modelOutput);
var builder = Tensor.Builder.of(tensorType);
for(int i = 0; i < vocab.size(); i++) {
var score = vocab.get(i);
@@ -152,7 +144,6 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
-
private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
int size = input.size();
TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
@@ -167,4 +158,42 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
evaluator.close();
tokenizer.close();
}
+
+ /**
+ * Custom max reduce over the sequence dimension of the output tensor. This
+ * to reduce GC pressure from generic Tensor.reduce operation.
+ *
+ * @param tensor the model output tensor of shape d1,dim where d1 is the sequence length and dim is size
+ * of the vocabulary
+ * @return A tensor of shape d1,1 where each value is the max of the log of the relu of the input tensor
+ */
+ private static IndexedTensor customMaxReduceOverLogOfRelu(IndexedTensor tensor) {
+ long[] shape = tensor.shape();
+ if(shape.length != 3) {
+ throw new IllegalArgumentException("The indexed tensor must be 3-dimensional");
+ }
+ long batch = shape[0];
+ if (batch != 1) {
+ throw new IllegalArgumentException("Batch size must be 1");
+ }
+ long sequenceLength = shape[1];
+ long vocabSize = shape[2];
+
+ TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed("vocab", vocabSize).build();
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
+ //Iterate over the vocab dimension and find the max value for each sequence token
+ for(int v = 0; v < vocabSize; v++) {
+ double maxLogOfRelu = Double.MIN_VALUE;
+ for(int s = 0; s < sequenceLength; s++) {
+ double value = tensor.get(0, s, v); // batch, sequence, vocab
+ double logOfRelu = Math.log(1 + Math.max(0, value));
+ if(logOfRelu > maxLogOfRelu) {
+ maxLogOfRelu = logOfRelu;
+ }
+ }
+ builder.cell(maxLogOfRelu, v);
+ }
+ return builder.build();
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
index e0d940ca5fe..5099a251f00 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
@@ -44,6 +44,19 @@ public class SpladeEmbedderTest {
}
@Test
+ public void testPerformanceNotTerrible() {
+ String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" +
+ " ii the project was led by the united states with the support of the united kingdom and canada";
+ Long now = System.currentTimeMillis();
+ int n = 10;
+ for (int i = 0; i < n; i++) {
+ assertEmbed("tensor<float>(t{})", text, indexingContext);
+ }
+ Long elapsed = (System.currentTimeMillis() - now)/1000;
+ System.out.println("Elapsed time: " + elapsed + " ms");
+ }
+
+ @Test
public void throwsOnInvalidTensorType() {
Throwable exception = assertThrows(RuntimeException.class, () -> {
assertEmbed("tensor<float>(d[128])", "", indexingContext);
@@ -54,7 +67,7 @@ public class SpladeEmbedderTest {
static final Embedder spladeEmbedder;
static final Embedder.Context indexingContext;
- static final Double scoreThreshold = 1.15;
+ static final Double scoreThreshold = 1.15;
static {
indexingContext = new Embedder.Context("schema.indexing");