aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-20 11:39:02 +0100
committerGitHub <noreply@github.com>2024-01-20 11:39:02 +0100
commitc18b5805006b83efbeb9fc881e1658a57be28e56 (patch)
treebe930b815a0e0335db622d81134550344193aae2
parent3a9c2446b1fdab6443365a034dd72e8939a59943 (diff)
parent77df31b8e9af00e02003f04285f24e50bea4e59a (diff)
Merge pull request #29983 from vespa-engine/balder/add-class-to-assist-fast-iteration-of-of-indexed-tensors
Add a class to assist efficient traversal of dimensions in an Indexe…
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java9
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java4
-rw-r--r--vespajlib/abi-spec.json16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java33
6 files changed, 100 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 853009873a1..28f8c4e252f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -10,9 +10,9 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
import java.util.List;
@@ -152,10 +152,15 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
String dimension = tensorType.dimensions().get(0).name();
//Iterate over the vocab dimension and find the max value for each sequence token
long [] tokens = new long[1];
+ DirectIndexedAddress directAddress = modelOutput.directAddress();
+ directAddress.setIndex(0,0);
for (int v = 0; v < vocabSize; v++) {
double maxValue = 0.0d;
+ directAddress.setIndex(2, v);
+ long increment = directAddress.getStride(1);
+ long directIndex = directAddress.getDirectIndex();
for (int s = 0; s < sequenceLength; s++) {
- double value = modelOutput.get(0, s, v); // batch, sequence, vocab
+ double value = modelOutput.get(directIndex + s * increment);
if (value > maxValue) {
maxValue = value;
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
index 9ecb0e3e162..82998b56fb5 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
@@ -49,11 +49,11 @@ public class SpladeEmbedderTest {
String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" +
" ii the project was led by the united states with the support of the united kingdom and canada";
Long now = System.currentTimeMillis();
- int n = 10;
+ int n = 1000; // Takes around 8s on Intel core i9 2.4Ghz (macbook pro, 2019)
for (int i = 0; i < n; i++) {
assertEmbed("tensor<float>(t{})", text, indexingContext);
}
- Long elapsed = (System.currentTimeMillis() - now)/1000;
+ Long elapsed = System.currentTimeMillis() - now;
System.out.println("Elapsed time: " + elapsed + " ms");
}
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 5d88b2d2829..174ce6332db 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -720,6 +720,20 @@
],
"fields" : [ ]
},
+ "com.yahoo.tensor.DirectIndexedAddress" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public void setIndex(int, int)",
+ "public long getDirectIndex()",
+ "public long getStride(int)"
+ ],
+ "fields" : [ ]
+ },
"com.yahoo.tensor.IndexedDoubleTensor$BoundDoubleBuilder" : {
"superClass" : "com.yahoo.tensor.IndexedTensor$BoundBuilder",
"interfaces" : [ ],
@@ -894,6 +908,8 @@
"public java.util.Iterator subspaceIterator(java.util.Set, com.yahoo.tensor.DimensionSizes)",
"public java.util.Iterator subspaceIterator(java.util.Set)",
"public varargs double get(long[])",
+ "public double get(com.yahoo.tensor.DirectIndexedAddress)",
+ "public com.yahoo.tensor.DirectIndexedAddress directAddress()",
"public varargs float getFloat(long[])",
"public double get(com.yahoo.tensor.TensorAddress)",
"public boolean has(com.yahoo.tensor.TensorAddress)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
new file mode 100644
index 00000000000..37752361876
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
@@ -0,0 +1,38 @@
+package com.yahoo.tensor;
+
+/**
+ * Utility class for efficient access and iteration along dimensions in indexed tensors.
+ * An instance keeps one current index per dimension together with the direct index those
+ * indexes correspond to. The direct index is adjusted incrementally each time an index is set.
+ * All indexes, and hence the direct index, start at 0.
+ *
+ * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration.
+ * Leave the index of the dimension being iterated at 0, so that the base below points at its first entry:
+ * <pre>
+ * long base = addr.getDirectIndex();
+ * long stride = addr.getStride(dimension);
+ * for (int i = 0; i &lt; size_of_dimension; i++) {
+ *     double value = tensor.get(base + i * stride);
+ * }
+ * </pre>
+ */
+public final class DirectIndexedAddress {
+ private final DimensionSizes sizes; // size of each dimension; used for bounds checks and stride computation
+ private final int [] indexes; // current index of each dimension, all 0 initially
+ private long directIndex; // direct index matching 'indexes', maintained incrementally by setIndex
+ private DirectIndexedAddress(DimensionSizes sizes) { // private: instances are created through of(...)
+ this.sizes = sizes;
+ indexes = new int[sizes.dimensions()];
+ directIndex = 0;
+ }
+ static DirectIndexedAddress of(DimensionSizes sizes) { // package-private factory; the returned address has all indexes at 0
+ return new DirectIndexedAddress(sizes);
+ }
+ /**
+  * Sets the current index of a dimension and updates the direct index accordingly.
+  *
+  * @param dimension the position of the dimension to set
+  * @param index the new index of that dimension, which must be in [0, size of the dimension>
+  * @throws IndexOutOfBoundsException if index is negative, or not less than the size of the dimension.
+  *         The address is left unchanged in that case, as the check happens before any update
+  */
+ public void setIndex(int dimension, int index) {
+ if (index < 0 || index >= sizes.size(dimension)) {
+ throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">");
+ }
+ int diff = index - indexes[dimension]; // movement relative to the index previously set for this dimension
+ directIndex += getStride(dimension) * diff; // apply only the delta, so the contribution of other dimensions is kept
+ indexes[dimension] = index;
+ }
+ /** Returns the index that can be used for direct lookup in an indexed tensor, reflecting the indexes set so far. */
+ public long getDirectIndex() { return directIndex; }
+ /**
+  * Returns the stride to be used for the given dimension, that is the change in direct index
+  * when the index of that dimension changes by one. It is obtained from
+  * sizes.productOfDimensionsAfter(dimension) on every call, so a caller iterating in a
+  * tight loop may want to read it once before the loop.
+  */
+ public long getStride(int dimension) {
+ return sizes.productOfDimensionsAfter(dimension);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 1319675f5d4..93cdc3f630f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -93,6 +93,10 @@ public abstract class IndexedTensor implements Tensor {
return get(toValueIndex(indexes, dimensionSizes));
}
+ public double get(DirectIndexedAddress address) { // returns the value at the position the given address currently points to
+ return get(address.getDirectIndex()); // look the value up by the direct index held by the address
+ }
+ public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); } // creates a new address built from this tensor's dimension sizes
/**
* Returns the value at the given indexes as a float
*
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
index 0a6c821e64e..afc95d295f0 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
@@ -9,6 +9,7 @@ import java.util.Iterator;
import java.util.Map;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -96,6 +97,38 @@ public class IndexedTensorTestCase {
}
@Test
+ public void testDirectIndexedAddress() { // verifies strides and the incrementally maintained direct index for a 5-dimensional type
+ TensorType type = new TensorType.Builder().indexed("v", 3)
+ .indexed("w", wSize)
+ .indexed("x", xSize)
+ .indexed("y", ySize)
+ .indexed("z", zSize)
+ .build();
+ var directAddress = DirectIndexedAddress.of(DimensionSizes.of(type));
+ assertThrows(ArrayIndexOutOfBoundsException.class, () -> directAddress.getStride(5)); // only dimensions 0-4 exist
+ assertThrows(IndexOutOfBoundsException.class, () -> directAddress.setIndex(4, 7)); // assumes zSize <= 7 (declared elsewhere in this class) - TODO confirm
+ assertEquals(wSize*xSize*ySize*zSize, directAddress.getStride(0)); // each stride is the product of the sizes of the dimensions after it
+ assertEquals(xSize*ySize*zSize, directAddress.getStride(1));
+ assertEquals(ySize*zSize, directAddress.getStride(2));
+ assertEquals(zSize, directAddress.getStride(3));
+ assertEquals(1, directAddress.getStride(4)); // the last dimension has stride 1
+ assertEquals(0, directAddress.getDirectIndex()); // still 0: nothing has been set, and the rejected setIndex above changed nothing
+ directAddress.setIndex(0,1);
+ assertEquals(1 * directAddress.getStride(0), directAddress.getDirectIndex());
+ directAddress.setIndex(1,1);
+ assertEquals(1 * directAddress.getStride(0) + 1 * directAddress.getStride(1), directAddress.getDirectIndex()); // contributions of separate dimensions add up
+ directAddress.setIndex(2,2);
+ directAddress.setIndex(3,2);
+ directAddress.setIndex(4,2);
+ long expected = 1 * directAddress.getStride(0) + // sum of index * stride over all five dimensions
+ 1 * directAddress.getStride(1) +
+ 2 * directAddress.getStride(2) +
+ 2 * directAddress.getStride(3) +
+ 2 * directAddress.getStride(4);
+ assertEquals(expected, directAddress.getDirectIndex());
+ }
+
+ @Test
public void testUnboundBuilding() {
TensorType type = new TensorType.Builder().indexed("w")
.indexed("v")