From 52a0c0597b85dc45fc69cd01f2f9d859a11bb348 Mon Sep 17 00:00:00 2001 From: Jo Kristian Bergum Date: Thu, 11 Jan 2024 10:49:28 +0100 Subject: Avoid generic reduce to reduce gc pressure --- .../java/ai/vespa/embedding/SpladeEmbedder.java | 65 ++++++++++++++++------ .../ai/vespa/embedding/SpladeEmbedderTest.java | 15 ++++- 2 files changed, 61 insertions(+), 19 deletions(-) diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 4af7820274f..78fb3142704 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -13,7 +13,6 @@ import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Reduce; import java.nio.file.Paths; import java.util.List; import java.util.Map; @@ -22,9 +21,7 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES /** * A SPLADE embedder that is embedding text to a 1-d mapped tensor. For interpretability, the tensor labels - * are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0). This - * instead of using the token identifier. - * + * are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0). */ @Beta public class SpladeEmbedder extends AbstractComponent implements Embedder { @@ -119,26 +116,21 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { Map inputs = Map.of(inputIdsName, inputSequence.expand("d0"), attentionMaskName, attentionMask.expand("d0"), tokenTypeIdsName, tokenTypeIds.expand("d0")); - - Map outputs = evaluator.evaluate(inputs); - //Remove batch dim, batch size of 1 - Tensor output = outputs.get(outputName).reduce(Reduce.Aggregator.max, "d0"); - Tensor mappedTensor = sparsify(output, tensorType); + Tensor spladeTensor = sparsify((IndexedTensor) evaluator.evaluate(inputs).get(outputName), tensorType); runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); - return mappedTensor; + return spladeTensor; } + /** - * Sparsify the output tensor by applying a threshold on the log of the relu of the output. - * @param modelOutput the model output tensor of shape d1,dim where d1 is the sequence length and dim is size - * of the vocabulary + * Sparsify the model output tensor. + * + * @param modelOutput the model output tensor of type tensorType * @param tensorType the type of the destination tensor * @return A mapped tensor with the terms from the vocab that has a score above the threshold */ - public Tensor sparsify(Tensor modelOutput, TensorType tensorType) { - Tensor logOfRelu = modelOutput.map((x) -> Math.log(1 + Math.max(0, x))); - Tensor maxReduced = logOfRelu.reduce(Reduce.Aggregator.max, "d1"); - IndexedTensor vocab = (IndexedTensor) maxReduced; + public Tensor sparsify(IndexedTensor modelOutput, TensorType tensorType) { + IndexedTensor vocab = customMaxReduceOverLogOfRelu(modelOutput); var builder = Tensor.Builder.of(tensorType); for(int i = 0; i < vocab.size(); i++) { var score = vocab.get(i); @@ -152,7 +144,6 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - private IndexedTensor createTensorRepresentation(List input, String dimension) { int size = input.size(); TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); @@ -167,4 +158,42 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { evaluator.close(); tokenizer.close(); } + + /** + * Custom max reduce over the sequence dimension of the output tensor. This + * to reduce GC pressure from generic Tensor.reduce operation. + * + * @param tensor the model output tensor of shape d1,dim where d1 is the sequence length and dim is size + * of the vocabulary + * @return A tensor of shape d1,1 where each value is the max of the log of the relu of the input tensor + */ + private static IndexedTensor customMaxReduceOverLogOfRelu(IndexedTensor tensor) { + long[] shape = tensor.shape(); + if(shape.length != 3) { + throw new IllegalArgumentException("The indexed tensor must be 3-dimensional"); + } + long batch = shape[0]; + if (batch != 1) { + throw new IllegalArgumentException("Batch size must be 1"); + } + long sequenceLength = shape[1]; + long vocabSize = shape[2]; + + TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed("vocab", vocabSize).build(); + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); + //Iterate over the vocab dimension and find the max value for each sequence token + for(int v = 0; v < vocabSize; v++) { + double maxLogOfRelu = Double.MIN_VALUE; + for(int s = 0; s < sequenceLength; s++) { + double value = tensor.get(0, s, v); // batch, sequence, vocab + double logOfRelu = Math.log(1 + Math.max(0, value)); + if(logOfRelu > maxLogOfRelu) { + maxLogOfRelu = logOfRelu; + } + } + builder.cell(maxLogOfRelu, v); + } + return builder.build(); + } + } diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index e0d940ca5fe..5099a251f00 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -43,6 +43,19 @@ public class SpladeEmbedderTest { assertEquals(0, result.size()); } + @Test + public void testPerformanceNotTerrible() { + String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" + + " ii the project was led by the united states with the support of the united kingdom and canada"; + Long now = System.currentTimeMillis(); + int n = 10; + for (int i = 0; i < n; i++) { + assertEmbed("tensor(t{})", text, indexingContext); + } + Long elapsed = (System.currentTimeMillis() - now)/1000; + System.out.println("Elapsed time: " + elapsed + " ms"); + } + @Test public void throwsOnInvalidTensorType() { Throwable exception = assertThrows(RuntimeException.class, () -> { @@ -54,7 +67,7 @@ public class SpladeEmbedderTest { static final Embedder spladeEmbedder; static final Embedder.Context indexingContext; - static final Double scoreThreshold = 1.15; + static final Double scoreThreshold = 1.15; static { indexingContext = new Embedder.Context("schema.indexing"); -- cgit v1.2.3