aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-19 14:32:23 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-19 14:32:23 +0100
commit3cf8be5fe60504a02be04009b9348913ae32b564 (patch)
tree2efea86685fb8c94725e71628ddaeaa683a1faed /model-integration
parent74cbf975a435d54eb892de0142d6cceb2d1ebc93 (diff)
Add a class for assist efficient traversal of dimensions in an IndexedTensor.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java9
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java4
2 files changed, 9 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 853009873a1..4b90fa0a9bf 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -10,9 +10,9 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
import java.util.List;
@@ -152,10 +152,15 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
String dimension = tensorType.dimensions().get(0).name();
//Iterate over the vocab dimension and find the max value for each sequence token
long [] tokens = new long[1];
+ DirectIndexedAddress directAddress = modelOutput.directAddress();
+ directAddress.setIndex(0,0);
for (int v = 0; v < vocabSize; v++) {
double maxValue = 0.0d;
+ directAddress.setIndex(2, v);
+ long increment = directAddress.getStride(1);
+ long directIndex = directAddress.getIndex();
for (int s = 0; s < sequenceLength; s++) {
- double value = modelOutput.get(0, s, v); // batch, sequence, vocab
+ double value = modelOutput.get(directIndex + s * increment);
if (value > maxValue) {
maxValue = value;
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
index 9ecb0e3e162..82998b56fb5 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
@@ -49,11 +49,11 @@ public class SpladeEmbedderTest {
String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" +
" ii the project was led by the united states with the support of the united kingdom and canada";
Long now = System.currentTimeMillis();
- int n = 10;
+ int n = 1000; // Takes around 8s on Intel core i9 2.4Ghz (macbook pro, 2019)
for (int i = 0; i < n; i++) {
assertEmbed("tensor<float>(t{})", text, indexingContext);
}
- Long elapsed = (System.currentTimeMillis() - now)/1000;
+ Long elapsed = System.currentTimeMillis() - now;
System.out.println("Elapsed time: " + elapsed + " ms");
}