diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-19 14:32:23 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-19 14:32:23 +0100 |
commit | 3cf8be5fe60504a02be04009b9348913ae32b564 (patch) | |
tree | 2efea86685fb8c94725e71628ddaeaa683a1faed /vespajlib | |
parent | 74cbf975a435d54eb892de0142d6cceb2d1ebc93 (diff) |
Add a class for assist efficient traversal of dimensions in an IndexedTensor.
Diffstat (limited to 'vespajlib')
4 files changed, 91 insertions, 0 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 5d88b2d2829..5b03c8b5661 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -720,6 +720,20 @@ ], "fields" : [ ] }, + "com.yahoo.tensor.DirectIndexedAddress" : { + "superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void setIndex(int, int)", + "public long getIndex()", + "public long getStride(int)" + ], + "fields" : [ ] + }, "com.yahoo.tensor.IndexedDoubleTensor$BoundDoubleBuilder" : { "superClass" : "com.yahoo.tensor.IndexedTensor$BoundBuilder", "interfaces" : [ ], @@ -894,6 +908,8 @@ "public java.util.Iterator subspaceIterator(java.util.Set, com.yahoo.tensor.DimensionSizes)", "public java.util.Iterator subspaceIterator(java.util.Set)", "public varargs double get(long[])", + "public double get(com.yahoo.tensor.DirectIndexedAddress)", + "public com.yahoo.tensor.DirectIndexedAddress directAddress()", "public varargs float getFloat(long[])", "public double get(com.yahoo.tensor.TensorAddress)", "public boolean has(com.yahoo.tensor.TensorAddress)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java new file mode 100644 index 00000000000..80cb545238c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java @@ -0,0 +1,38 @@ +package com.yahoo.tensor; + +/** + * Utility class for efficient access and iteration along dimensions in Indexed tensors. + * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration. + * long base = addr.getIndex(); + * long stride = addr.getStride(dimension) + * i = 0...size_of_dimension + * double value = tensor.get(base + i * stride); + */ +public final class DirectIndexedAddress { + private final DimensionSizes sizes; + private final int [] indexes; + private long directIndex; + private DirectIndexedAddress(DimensionSizes sizes) { + this.sizes = sizes; + indexes = new int[sizes.dimensions()]; + directIndex = 0; + } + static DirectIndexedAddress of(DimensionSizes sizes) { + return new DirectIndexedAddress(sizes); + } + /** Sets the current index of a dimension */ + public void setIndex(int dimension, int index) { + if (index < 0 || index >= sizes.size(dimension)) { + throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">"); + } + int diff = index - indexes[dimension]; + directIndex += getStride(dimension) * diff; + indexes[dimension] = index; + } + /** Retrieve the index that can be used for direct lookup in an indexed tensor. */ + public long getIndex() { return directIndex; } + /** returns the stride to be used for the given dimension */ + public long getStride(int dimension) { + return sizes.productOfDimensionsAfter(dimension); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 1319675f5d4..d43b4a03f72 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -93,6 +93,10 @@ public abstract class IndexedTensor implements Tensor { return get(toValueIndex(indexes, dimensionSizes)); } + public double get(DirectIndexedAddress address) { + return get(address.getIndex()); + } + public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); } /** * Returns the value at the given indexes as a float * diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index 0a6c821e64e..d145ad7a316 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -9,6 +9,7 @@ import java.util.Iterator; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -96,6 +97,38 @@ public class IndexedTensorTestCase { } @Test + public void testDirectIndexedAddress() { + TensorType type = new TensorType.Builder().indexed("v", 3) + .indexed("w", wSize) + .indexed("x", xSize) + .indexed("y", ySize) + .indexed("z", zSize) + .build(); + var directAddress = DirectIndexedAddress.of(DimensionSizes.of(type)); + assertThrows(ArrayIndexOutOfBoundsException.class, () -> directAddress.getStride(5)); + assertThrows(IndexOutOfBoundsException.class, () -> directAddress.setIndex(4, 7)); + assertEquals(wSize*xSize*ySize*zSize, directAddress.getStride(0)); + assertEquals(xSize*ySize*zSize, directAddress.getStride(1)); + assertEquals(ySize*zSize, directAddress.getStride(2)); + assertEquals(zSize, directAddress.getStride(3)); + assertEquals(1, directAddress.getStride(4)); + assertEquals(0, directAddress.getIndex()); + directAddress.setIndex(0,1); + assertEquals(1 * directAddress.getStride(0), directAddress.getIndex()); + directAddress.setIndex(1,1); + assertEquals(1 * directAddress.getStride(0) + 1 * directAddress.getStride(1), directAddress.getIndex()); + directAddress.setIndex(2,2); + directAddress.setIndex(3,2); + directAddress.setIndex(4,2); + long expected = 1 * directAddress.getStride(0) + + 1 * directAddress.getStride(1) + + 2 * directAddress.getStride(2) + + 2 * directAddress.getStride(3) + + 2 * directAddress.getStride(4); + assertEquals(expected, directAddress.getIndex()); + } + + @Test public void testUnboundBuilding() { TensorType type = new TensorType.Builder().indexed("w") .indexed("v") |