diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-19 14:32:23 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-19 14:32:23 +0100 |
commit | 3cf8be5fe60504a02be04009b9348913ae32b564 (patch) | |
tree | 2efea86685fb8c94725e71628ddaeaa683a1faed /model-integration/src/main | |
parent | 74cbf975a435d54eb892de0142d6cceb2d1ebc93 (diff) |
Add a class for assist efficient traversal of dimensions in an IndexedTensor.
Diffstat (limited to 'model-integration/src/main')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 853009873a1..4b90fa0a9bf 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -10,9 +10,9 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; @@ -152,10 +152,15 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token long [] tokens = new long[1]; + DirectIndexedAddress directAddress = modelOutput.directAddress(); + directAddress.setIndex(0,0); for (int v = 0; v < vocabSize; v++) { double maxValue = 0.0d; + directAddress.setIndex(2, v); + long increment = directAddress.getStride(1); + long directIndex = directAddress.getIndex(); for (int s = 0; s < sequenceLength; s++) { - double value = modelOutput.get(0, s, v); // batch, sequence, vocab + double value = modelOutput.get(directIndex + s * increment); if (value > maxValue) { maxValue = value; } |