summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2024-01-11 12:43:00 +0100
committerJo Kristian Bergum <bergum@yahooinc.com>2024-01-11 12:43:00 +0100
commit9535c4f8073f003cfd799c2d3a4bfebe8bb02e55 (patch)
tree51d27b30db7f36298d9f25327ea350ccb03bc2d7 /model-integration
parent52a0c0597b85dc45fc69cd01f2f9d859a11bb348 (diff)
address review
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java65
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java3
2 files changed, 25 insertions, 43 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 78fb3142704..644b1ec538f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -130,45 +130,8 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
* @return A mapped tensor with the terms from the vocab that has a score above the threshold
*/
public Tensor sparsify(IndexedTensor modelOutput, TensorType tensorType) {
- IndexedTensor vocab = customMaxReduceOverLogOfRelu(modelOutput);
var builder = Tensor.Builder.of(tensorType);
- for(int i = 0; i < vocab.size(); i++) {
- var score = vocab.get(i);
- if (score > termScoreThreshold) {
- String term = tokenizer.decode(List.of((long) i));
- builder.cell().
- label(tensorType.dimensions().get(0).name(), term)
- .value(score);
- }
- }
- return builder.build();
- }
-
- private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
- int size = input.size();
- TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
- IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
- for (int i = 0; i < size; ++i) {
- builder.cell(input.get(i), i);
- }
- return builder.build();
- }
- @Override
- public void deconstruct() {
- evaluator.close();
- tokenizer.close();
- }
-
- /**
- * Custom max reduce over the sequence dimension of the output tensor. This
- * to reduce GC pressure from generic Tensor.reduce operation.
- *
- * @param tensor the model output tensor of shape d1,dim where d1 is the sequence length and dim is size
- * of the vocabulary
- * @return A tensor of shape d1,1 where each value is the max of the log of the relu of the input tensor
- */
- private static IndexedTensor customMaxReduceOverLogOfRelu(IndexedTensor tensor) {
- long[] shape = tensor.shape();
+ long[] shape = modelOutput.shape();
if(shape.length != 3) {
throw new IllegalArgumentException("The indexed tensor must be 3-dimensional");
}
@@ -179,21 +142,39 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
long sequenceLength = shape[1];
long vocabSize = shape[2];
- TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed("vocab", vocabSize).build();
- IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
//Iterate over the vocab dimension and find the max value for each sequence token
for(int v = 0; v < vocabSize; v++) {
double maxLogOfRelu = Double.MIN_VALUE;
for(int s = 0; s < sequenceLength; s++) {
- double value = tensor.get(0, s, v); // batch, sequence, vocab
+ double value = modelOutput.get(0, s, v); // batch, sequence, vocab
double logOfRelu = Math.log(1 + Math.max(0, value));
if(logOfRelu > maxLogOfRelu) {
maxLogOfRelu = logOfRelu;
}
}
- builder.cell(maxLogOfRelu, v);
+ if (maxLogOfRelu > termScoreThreshold) {
+ String term = tokenizer.decode(List.of((long) v));
+ builder.cell().
+ label(tensorType.dimensions().get(0).name(), term)
+ .value(maxLogOfRelu);
+ }
+ }
+ return builder.build();
+ }
+
+ private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
+ int size = input.size();
+ TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
+ for (int i = 0; i < size; ++i) {
+ builder.cell(input.get(i), i);
}
return builder.build();
}
+ @Override
+ public void deconstruct() {
+ evaluator.close();
+ tokenizer.close();
+ }
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
index 5099a251f00..b28748ddd82 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
@@ -11,6 +11,7 @@ import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.List;
+import org.junit.Ignore;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
@@ -43,7 +44,7 @@ public class SpladeEmbedderTest {
assertEquals(0, result.size());
}
- @Test
+ @Ignore
public void testPerformanceNotTerrible() {
String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" +
" ii the project was led by the united states with the support of the united kingdom and canada";