diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-19 14:32:23 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-19 14:32:23 +0100 |
commit | 3cf8be5fe60504a02be04009b9348913ae32b564 (patch) | |
tree | 2efea86685fb8c94725e71628ddaeaa683a1faed /model-integration/src | |
parent | 74cbf975a435d54eb892de0142d6cceb2d1ebc93 (diff) |
Add a class to assist efficient traversal of dimensions in an IndexedTensor.
Diffstat (limited to 'model-integration/src')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 9 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java | 4 |
2 files changed, 9 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 853009873a1..4b90fa0a9bf 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -10,9 +10,9 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; @@ -152,10 +152,15 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token long [] tokens = new long[1]; + DirectIndexedAddress directAddress = modelOutput.directAddress(); + directAddress.setIndex(0,0); for (int v = 0; v < vocabSize; v++) { double maxValue = 0.0d; + directAddress.setIndex(2, v); + long increment = directAddress.getStride(1); + long directIndex = directAddress.getIndex(); for (int s = 0; s < sequenceLength; s++) { - double value = modelOutput.get(0, s, v); // batch, sequence, vocab + double value = modelOutput.get(directIndex + s * increment); if (value > maxValue) { maxValue = value; } diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index 9ecb0e3e162..82998b56fb5 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -49,11 +49,11 @@ public class 
SpladeEmbedderTest { String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" + " ii the project was led by the united states with the support of the united kingdom and canada"; Long now = System.currentTimeMillis(); - int n = 10; + int n = 1000; // Takes around 8s on Intel core i9 2.4Ghz (macbook pro, 2019) for (int i = 0; i < n; i++) { assertEmbed("tensor<float>(t{})", text, indexingContext); } - Long elapsed = (System.currentTimeMillis() - now)/1000; + Long elapsed = System.currentTimeMillis() - now; System.out.println("Elapsed time: " + elapsed + " ms"); } |