diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-01-11 12:43:00 +0100 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-01-11 12:43:00 +0100 |
commit | 9535c4f8073f003cfd799c2d3a4bfebe8bb02e55 (patch) | |
tree | 51d27b30db7f36298d9f25327ea350ccb03bc2d7 /model-integration | |
parent | 52a0c0597b85dc45fc69cd01f2f9d859a11bb348 (diff) |
address review
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 65 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java | 3 |
2 files changed, 25 insertions, 43 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 78fb3142704..644b1ec538f 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -130,45 +130,8 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { * @return A mapped tensor with the terms from the vocab that has a score above the threshold */ public Tensor sparsify(IndexedTensor modelOutput, TensorType tensorType) { - IndexedTensor vocab = customMaxReduceOverLogOfRelu(modelOutput); var builder = Tensor.Builder.of(tensorType); - for(int i = 0; i < vocab.size(); i++) { - var score = vocab.get(i); - if (score > termScoreThreshold) { - String term = tokenizer.decode(List.of((long) i)); - builder.cell(). - label(tensorType.dimensions().get(0).name(), term) - .value(score); - } - } - return builder.build(); - } - - private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) { - int size = input.size(); - TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); - IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); - for (int i = 0; i < size; ++i) { - builder.cell(input.get(i), i); - } - return builder.build(); - } - @Override - public void deconstruct() { - evaluator.close(); - tokenizer.close(); - } - - /** - * Custom max reduce over the sequence dimension of the output tensor. This - * to reduce GC pressure from generic Tensor.reduce operation. - * - * @param tensor the model output tensor of shape d1,dim where d1 is the sequence length and dim is size - * of the vocabulary - * @return A tensor of shape d1,1 where each value is the max of the log of the relu of the input tensor - */ - private static IndexedTensor customMaxReduceOverLogOfRelu(IndexedTensor tensor) { - long[] shape = tensor.shape(); + long[] shape = modelOutput.shape(); if(shape.length != 3) { throw new IllegalArgumentException("The indexed tensor must be 3-dimensional"); } @@ -179,21 +142,39 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { long sequenceLength = shape[1]; long vocabSize = shape[2]; - TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed("vocab", vocabSize).build(); - IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); //Iterate over the vocab dimension and find the max value for each sequence token for(int v = 0; v < vocabSize; v++) { double maxLogOfRelu = Double.MIN_VALUE; for(int s = 0; s < sequenceLength; s++) { - double value = tensor.get(0, s, v); // batch, sequence, vocab + double value = modelOutput.get(0, s, v); // batch, sequence, vocab double logOfRelu = Math.log(1 + Math.max(0, value)); if(logOfRelu > maxLogOfRelu) { maxLogOfRelu = logOfRelu; } } - builder.cell(maxLogOfRelu, v); + if (maxLogOfRelu > termScoreThreshold) { + String term = tokenizer.decode(List.of((long) v)); + builder.cell(). + label(tensorType.dimensions().get(0).name(), term) + .value(maxLogOfRelu); + } + } + return builder.build(); + } + + private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) { + int size = input.size(); + TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); + for (int i = 0; i < size; ++i) { + builder.cell(input.get(i), i); } return builder.build(); } + @Override + public void deconstruct() { + evaluator.close(); + tokenizer.close(); + } } diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index 5099a251f00..b28748ddd82 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.util.List; +import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; @@ -43,7 +44,7 @@ public class SpladeEmbedderTest { assertEquals(0, result.size()); } - @Test + @Ignore public void testPerformanceNotTerrible() { String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" + " ii the project was led by the united states with the support of the united kingdom and canada"; |