summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahoo-inc.com>2024-01-11 13:09:13 +0100
committerGitHub <noreply@github.com>2024-01-11 13:09:13 +0100
commitb4b5bd584110601471abf51bc59f29752e295fca (patch)
tree070978da8a83a0039281e75b4fa4dfa0b26bd055
parentc35da2bfe42797997cff3c6d42c491c5566698e7 (diff)
parenta7b7875202cf6648c7797803128dafd06382ef46 (diff)
Merge pull request #29851 from vespa-engine/jobergum/custom-reducer
Avoid generic reduce to reduce gc pressure
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java56
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java14
2 files changed, 47 insertions, 23 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 4af7820274f..644b1ec538f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -13,7 +13,6 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Reduce;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
@@ -22,9 +21,7 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES
/**
* A SPLADE embedder that is embedding text to a 1-d mapped tensor. For interpretability, the tensor labels
- * are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0). This
- * instead of using the token identifier.
- *
+ * are the subword strings from the wordpiece vocabulary that have a score above a threshold (default 0.0).
*/
@Beta
public class SpladeEmbedder extends AbstractComponent implements Embedder {
@@ -119,40 +116,52 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
attentionMaskName, attentionMask.expand("d0"),
tokenTypeIdsName, tokenTypeIds.expand("d0"));
-
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- //Remove batch dim, batch size of 1
- Tensor output = outputs.get(outputName).reduce(Reduce.Aggregator.max, "d0");
- Tensor mappedTensor = sparsify(output, tensorType);
+ Tensor spladeTensor = sparsify((IndexedTensor) evaluator.evaluate(inputs).get(outputName), tensorType);
runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context);
- return mappedTensor;
+ return spladeTensor;
}
+
/**
- * Sparsify the output tensor by applying a threshold on the log of the relu of the output.
- * @param modelOutput the model output tensor of shape d1,dim where d1 is the sequence length and dim is size
- * of the vocabulary
+ * Sparsify the model output tensor.
+ *
+ * @param modelOutput the raw model output tensor with dimensions (batch, sequence, vocab)
* @param tensorType the type of the destination tensor
* @return A mapped tensor with the terms from the vocab that has a score above the threshold
*/
- public Tensor sparsify(Tensor modelOutput, TensorType tensorType) {
- Tensor logOfRelu = modelOutput.map((x) -> Math.log(1 + Math.max(0, x)));
- Tensor maxReduced = logOfRelu.reduce(Reduce.Aggregator.max, "d1");
- IndexedTensor vocab = (IndexedTensor) maxReduced;
+ public Tensor sparsify(IndexedTensor modelOutput, TensorType tensorType) {
var builder = Tensor.Builder.of(tensorType);
- for(int i = 0; i < vocab.size(); i++) {
- var score = vocab.get(i);
- if (score > termScoreThreshold) {
- String term = tokenizer.decode(List.of((long) i));
+ long[] shape = modelOutput.shape();
+ if(shape.length != 3) {
+ throw new IllegalArgumentException("The indexed tensor must be 3-dimensional");
+ }
+ long batch = shape[0];
+ if (batch != 1) {
+ throw new IllegalArgumentException("Batch size must be 1");
+ }
+ long sequenceLength = shape[1];
+ long vocabSize = shape[2];
+
+ //Iterate over the vocab dimension and, for each vocab entry, find the max value across the sequence tokens
+ for(int v = 0; v < vocabSize; v++) {
+ double maxLogOfRelu = Double.MIN_VALUE;
+ for(int s = 0; s < sequenceLength; s++) {
+ double value = modelOutput.get(0, s, v); // batch, sequence, vocab
+ double logOfRelu = Math.log(1 + Math.max(0, value));
+ if(logOfRelu > maxLogOfRelu) {
+ maxLogOfRelu = logOfRelu;
+ }
+ }
+ if (maxLogOfRelu > termScoreThreshold) {
+ String term = tokenizer.decode(List.of((long) v));
builder.cell().
label(tensorType.dimensions().get(0).name(), term)
- .value(score);
+ .value(maxLogOfRelu);
}
}
return builder.build();
}
-
private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
int size = input.size();
TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
@@ -167,4 +176,5 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
evaluator.close();
tokenizer.close();
}
+
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
index e0d940ca5fe..9ecb0e3e162 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
@@ -11,6 +11,7 @@ import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.List;
+import org.junit.Ignore;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
@@ -43,6 +44,19 @@ public class SpladeEmbedderTest {
assertEquals(0, result.size());
}
+ @Ignore
+ public void testPerformanceNotTerrible() {
+ String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" +
+ " ii the project was led by the united states with the support of the united kingdom and canada";
+ Long now = System.currentTimeMillis();
+ int n = 10;
+ for (int i = 0; i < n; i++) {
+ assertEmbed("tensor<float>(t{})", text, indexingContext);
+ }
+ Long elapsed = (System.currentTimeMillis() - now)/1000;
+ System.out.println("Elapsed time: " + elapsed + " ms");
+ }
+
@Test
public void throwsOnInvalidTensorType() {
Throwable exception = assertThrows(RuntimeException.class, () -> {