diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-20 11:39:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-20 11:39:02 +0100 |
commit | c18b5805006b83efbeb9fc881e1658a57be28e56 (patch) | |
tree | be930b815a0e0335db622d81134550344193aae2 | |
parent | 3a9c2446b1fdab6443365a034dd72e8939a59943 (diff) | |
parent | 77df31b8e9af00e02003f04285f24e50bea4e59a (diff) |
Merge pull request #29983 from vespa-engine/balder/add-class-to-assist-fast-iteration-of-of-indexed-tensors
Add a class to assist efficient traversal of dimensions in an Indexe…
6 files changed, 100 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 853009873a1..28f8c4e252f 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -10,9 +10,9 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; @@ -152,10 +152,15 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token long [] tokens = new long[1]; + DirectIndexedAddress directAddress = modelOutput.directAddress(); + directAddress.setIndex(0,0); for (int v = 0; v < vocabSize; v++) { double maxValue = 0.0d; + directAddress.setIndex(2, v); + long increment = directAddress.getStride(1); + long directIndex = directAddress.getDirectIndex(); for (int s = 0; s < sequenceLength; s++) { - double value = modelOutput.get(0, s, v); // batch, sequence, vocab + double value = modelOutput.get(directIndex + s * increment); if (value > maxValue) { maxValue = value; } diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index 9ecb0e3e162..82998b56fb5 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -49,11 +49,11 @@ public class 
SpladeEmbedderTest { String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" + " ii the project was led by the united states with the support of the united kingdom and canada"; Long now = System.currentTimeMillis(); - int n = 10; + int n = 1000; // Takes around 8s on Intel core i9 2.4Ghz (macbook pro, 2019) for (int i = 0; i < n; i++) { assertEmbed("tensor<float>(t{})", text, indexingContext); } - Long elapsed = (System.currentTimeMillis() - now)/1000; + Long elapsed = System.currentTimeMillis() - now; System.out.println("Elapsed time: " + elapsed + " ms"); } diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 5d88b2d2829..174ce6332db 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -720,6 +720,20 @@ ], "fields" : [ ] }, + "com.yahoo.tensor.DirectIndexedAddress" : { + "superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void setIndex(int, int)", + "public long getDirectIndex()", + "public long getStride(int)" + ], + "fields" : [ ] + }, "com.yahoo.tensor.IndexedDoubleTensor$BoundDoubleBuilder" : { "superClass" : "com.yahoo.tensor.IndexedTensor$BoundBuilder", "interfaces" : [ ], @@ -894,6 +908,8 @@ "public java.util.Iterator subspaceIterator(java.util.Set, com.yahoo.tensor.DimensionSizes)", "public java.util.Iterator subspaceIterator(java.util.Set)", "public varargs double get(long[])", + "public double get(com.yahoo.tensor.DirectIndexedAddress)", + "public com.yahoo.tensor.DirectIndexedAddress directAddress()", "public varargs float getFloat(long[])", "public double get(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java new file mode 100644 index 00000000000..37752361876 --- /dev/null +++ 
b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor; + +/** + * Utility class for efficient access and iteration along dimensions in Indexed tensors. + * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration. + * long base = addr.getDirectIndex(); + * long stride = addr.getStride(dimension) + * i = 0...size_of_dimension + * double value = tensor.get(base + i * stride); + */ +public final class DirectIndexedAddress { + private final DimensionSizes sizes; + private final int [] indexes; + private long directIndex; + private DirectIndexedAddress(DimensionSizes sizes) { + this.sizes = sizes; + indexes = new int[sizes.dimensions()]; + directIndex = 0; + } + static DirectIndexedAddress of(DimensionSizes sizes) { + return new DirectIndexedAddress(sizes); + } + /** Sets the current index of a dimension */ + public void setIndex(int dimension, int index) { + if (index < 0 || index >= sizes.size(dimension)) { + throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">"); + } + int diff = index - indexes[dimension]; + directIndex += getStride(dimension) * diff; + indexes[dimension] = index; + } + /** Retrieve the index that can be used for direct lookup in an indexed tensor. 
*/ + public long getDirectIndex() { return directIndex; } + /** returns the stride to be used for the given dimension */ + public long getStride(int dimension) { + return sizes.productOfDimensionsAfter(dimension); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 1319675f5d4..93cdc3f630f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -93,6 +93,10 @@ public abstract class IndexedTensor implements Tensor { return get(toValueIndex(indexes, dimensionSizes)); } + public double get(DirectIndexedAddress address) { + return get(address.getDirectIndex()); + } + public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); } /** * Returns the value at the given indexes as a float * diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index 0a6c821e64e..afc95d295f0 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -9,6 +9,7 @@ import java.util.Iterator; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -96,6 +97,38 @@ public class IndexedTensorTestCase { } @Test + public void testDirectIndexedAddress() { + TensorType type = new TensorType.Builder().indexed("v", 3) + .indexed("w", wSize) + .indexed("x", xSize) + .indexed("y", ySize) + .indexed("z", zSize) + .build(); + var directAddress = DirectIndexedAddress.of(DimensionSizes.of(type)); + assertThrows(ArrayIndexOutOfBoundsException.class, () -> directAddress.getStride(5)); + assertThrows(IndexOutOfBoundsException.class, () -> directAddress.setIndex(4, 7)); + 
assertEquals(wSize*xSize*ySize*zSize, directAddress.getStride(0)); + assertEquals(xSize*ySize*zSize, directAddress.getStride(1)); + assertEquals(ySize*zSize, directAddress.getStride(2)); + assertEquals(zSize, directAddress.getStride(3)); + assertEquals(1, directAddress.getStride(4)); + assertEquals(0, directAddress.getDirectIndex()); + directAddress.setIndex(0,1); + assertEquals(1 * directAddress.getStride(0), directAddress.getDirectIndex()); + directAddress.setIndex(1,1); + assertEquals(1 * directAddress.getStride(0) + 1 * directAddress.getStride(1), directAddress.getDirectIndex()); + directAddress.setIndex(2,2); + directAddress.setIndex(3,2); + directAddress.setIndex(4,2); + long expected = 1 * directAddress.getStride(0) + + 1 * directAddress.getStride(1) + + 2 * directAddress.getStride(2) + + 2 * directAddress.getStride(3) + + 2 * directAddress.getStride(4); + assertEquals(expected, directAddress.getDirectIndex()); + } + + @Test public void testUnboundBuilding() { TensorType type = new TensorType.Builder().indexed("w") .indexed("v") |