diff options
author | Jo Kristian Bergum <bergum@yahoo-inc.com> | 2024-01-11 13:09:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-11 13:09:13 +0100 |
commit | b4b5bd584110601471abf51bc59f29752e295fca (patch) | |
tree | 070978da8a83a0039281e75b4fa4dfa0b26bd055 | |
parent | c35da2bfe42797997cff3c6d42c491c5566698e7 (diff) | |
parent | a7b7875202cf6648c7797803128dafd06382ef46 (diff) |
Merge pull request #29851 from vespa-engine/jobergum/custom-reducer
Avoid generic reduce to reduce gc pressure
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 56 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java | 14 |
2 files changed, 47 insertions, 23 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 4af7820274f..644b1ec538f 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -13,7 +13,6 @@ import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Reduce; import java.nio.file.Paths; import java.util.List; import java.util.Map; @@ -22,9 +21,7 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES /** * A SPLADE embedder that is embedding text to a 1-d mapped tensor. For interpretability, the tensor labels - * are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0). This - * instead of using the token identifier. - * + * are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0). */ @Beta public class SpladeEmbedder extends AbstractComponent implements Embedder { @@ -119,40 +116,52 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"), attentionMaskName, attentionMask.expand("d0"), tokenTypeIdsName, tokenTypeIds.expand("d0")); - - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - //Remove batch dim, batch size of 1 - Tensor output = outputs.get(outputName).reduce(Reduce.Aggregator.max, "d0"); - Tensor mappedTensor = sparsify(output, tensorType); + Tensor spladeTensor = sparsify((IndexedTensor) evaluator.evaluate(inputs).get(outputName), tensorType); runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); - return mappedTensor; + return spladeTensor; } + /** - * Sparsify the output tensor by applying a threshold on the log of the relu of the output. - * @param modelOutput the model output tensor of shape d1,dim where d1 is the sequence length and dim is size - * of the vocabulary + * Sparsify the model output tensor. + * + * @param modelOutput the model output tensor of type tensorType * @param tensorType the type of the destination tensor * @return A mapped tensor with the terms from the vocab that has a score above the threshold */ - public Tensor sparsify(Tensor modelOutput, TensorType tensorType) { - Tensor logOfRelu = modelOutput.map((x) -> Math.log(1 + Math.max(0, x))); - Tensor maxReduced = logOfRelu.reduce(Reduce.Aggregator.max, "d1"); - IndexedTensor vocab = (IndexedTensor) maxReduced; + public Tensor sparsify(IndexedTensor modelOutput, TensorType tensorType) { var builder = Tensor.Builder.of(tensorType); - for(int i = 0; i < vocab.size(); i++) { - var score = vocab.get(i); - if (score > termScoreThreshold) { - String term = tokenizer.decode(List.of((long) i)); + long[] shape = modelOutput.shape(); + if(shape.length != 3) { + throw new IllegalArgumentException("The indexed tensor must be 3-dimensional"); + } + long batch = shape[0]; + if (batch != 1) { + throw new IllegalArgumentException("Batch size must be 1"); + } + long sequenceLength = shape[1]; + long vocabSize = shape[2]; + + //Iterate over the vocab dimension and find the max value for each sequence token + for(int v = 0; v < vocabSize; v++) { + double maxLogOfRelu = Double.MIN_VALUE; + for(int s = 0; s < sequenceLength; s++) { + double value = modelOutput.get(0, s, v); // batch, sequence, vocab + double logOfRelu = Math.log(1 + Math.max(0, value)); + if(logOfRelu > maxLogOfRelu) { + maxLogOfRelu = logOfRelu; + } + } + if (maxLogOfRelu > termScoreThreshold) { + String term = tokenizer.decode(List.of((long) v)); builder.cell(). label(tensorType.dimensions().get(0).name(), term) - .value(score); + .value(maxLogOfRelu); } } return builder.build(); } - private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) { int size = input.size(); TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); @@ -167,4 +176,5 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { evaluator.close(); tokenizer.close(); } + } diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index e0d940ca5fe..9ecb0e3e162 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.util.List; +import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; @@ -43,6 +44,19 @@ public class SpladeEmbedderTest { assertEquals(0, result.size()); } + @Ignore + public void testPerformanceNotTerrible() { + String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" + + " ii the project was led by the united states with the support of the united kingdom and canada"; + Long now = System.currentTimeMillis(); + int n = 10; + for (int i = 0; i < n; i++) { + assertEmbed("tensor<float>(t{})", text, indexingContext); + } + Long elapsed = (System.currentTimeMillis() - now)/1000; + System.out.println("Elapsed time: " + elapsed + " ms"); + } + @Test public void throwsOnInvalidTensorType() { Throwable exception = assertThrows(RuntimeException.class, () -> { |