aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main
diff options
context:
space:
mode:
author Henning Baldersheim <balder@yahoo-inc.com> 2024-01-19 14:32:23 +0100
committer Henning Baldersheim <balder@yahoo-inc.com> 2024-01-19 14:32:23 +0100
commit 3cf8be5fe60504a02be04009b9348913ae32b564 (patch)
tree 2efea86685fb8c94725e71628ddaeaa683a1faed /model-integration/src/main
parent 74cbf975a435d54eb892de0142d6cceb2d1ebc93 (diff)
Add a class to assist efficient traversal of dimensions in an IndexedTensor.
Diffstat (limited to 'model-integration/src/main')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java9
1 files changed, 7 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 853009873a1..4b90fa0a9bf 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -10,9 +10,9 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
import java.util.List;
@@ -152,10 +152,15 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
String dimension = tensorType.dimensions().get(0).name();
//Iterate over the vocab dimension and find the max value for each sequence token
long [] tokens = new long[1];
+ DirectIndexedAddress directAddress = modelOutput.directAddress();
+ directAddress.setIndex(0,0);
for (int v = 0; v < vocabSize; v++) {
double maxValue = 0.0d;
+ directAddress.setIndex(2, v);
+ long increment = directAddress.getStride(1);
+ long directIndex = directAddress.getIndex();
for (int s = 0; s < sequenceLength; s++) {
- double value = modelOutput.get(0, s, v); // batch, sequence, vocab
+ double value = modelOutput.get(directIndex + s * increment);
if (value > maxValue) {
maxValue = value;
}