summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-19 14:32:23 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-19 14:32:23 +0100
commit3cf8be5fe60504a02be04009b9348913ae32b564 (patch)
tree2efea86685fb8c94725e71628ddaeaa683a1faed /vespajlib
parent74cbf975a435d54eb892de0142d6cceb2d1ebc93 (diff)
Add a class for assist efficient traversal of dimensions in an IndexedTensor.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java33
4 files changed, 91 insertions, 0 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 5d88b2d2829..5b03c8b5661 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -720,6 +720,20 @@
],
"fields" : [ ]
},
+ "com.yahoo.tensor.DirectIndexedAddress" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public void setIndex(int, int)",
+ "public long getIndex()",
+ "public long getStride(int)"
+ ],
+ "fields" : [ ]
+ },
"com.yahoo.tensor.IndexedDoubleTensor$BoundDoubleBuilder" : {
"superClass" : "com.yahoo.tensor.IndexedTensor$BoundBuilder",
"interfaces" : [ ],
@@ -894,6 +908,8 @@
"public java.util.Iterator subspaceIterator(java.util.Set, com.yahoo.tensor.DimensionSizes)",
"public java.util.Iterator subspaceIterator(java.util.Set)",
"public varargs double get(long[])",
+ "public double get(com.yahoo.tensor.DirectIndexedAddress)",
+ "public com.yahoo.tensor.DirectIndexedAddress directAddress()",
"public varargs float getFloat(long[])",
"public double get(com.yahoo.tensor.TensorAddress)",
"public boolean has(com.yahoo.tensor.TensorAddress)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
new file mode 100644
index 00000000000..80cb545238c
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
@@ -0,0 +1,38 @@
+package com.yahoo.tensor;
+
+/**
+ * Utility class for efficient access and iteration along dimensions in Indexed tensors.
+ * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration.
+ * long base = addr.getIndex();
+ * long stride = addr.getStride(dimension)
+ * i = 0...size_of_dimension
+ * double value = tensor.get(base + i * stride);
+ */
+public final class DirectIndexedAddress {
+ private final DimensionSizes sizes;
+ private final int [] indexes;
+ private long directIndex;
+ private DirectIndexedAddress(DimensionSizes sizes) {
+ this.sizes = sizes;
+ indexes = new int[sizes.dimensions()];
+ directIndex = 0;
+ }
+ static DirectIndexedAddress of(DimensionSizes sizes) {
+ return new DirectIndexedAddress(sizes);
+ }
+ /** Sets the current index of a dimension */
+ public void setIndex(int dimension, int index) {
+ if (index < 0 || index >= sizes.size(dimension)) {
+ throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">");
+ }
+ int diff = index - indexes[dimension];
+ directIndex += getStride(dimension) * diff;
+ indexes[dimension] = index;
+ }
+ /** Retrieve the index that can be used for direct lookup in an indexed tensor. */
+ public long getIndex() { return directIndex; }
+ /** returns the stride to be used for the given dimension */
+ public long getStride(int dimension) {
+ return sizes.productOfDimensionsAfter(dimension);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 1319675f5d4..d43b4a03f72 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -93,6 +93,10 @@ public abstract class IndexedTensor implements Tensor {
return get(toValueIndex(indexes, dimensionSizes));
}
+ public double get(DirectIndexedAddress address) {
+ return get(address.getIndex());
+ }
+ public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); }
/**
* Returns the value at the given indexes as a float
*
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
index 0a6c821e64e..d145ad7a316 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
@@ -9,6 +9,7 @@ import java.util.Iterator;
import java.util.Map;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -96,6 +97,38 @@ public class IndexedTensorTestCase {
}
@Test
+ public void testDirectIndexedAddress() {
+ TensorType type = new TensorType.Builder().indexed("v", 3)
+ .indexed("w", wSize)
+ .indexed("x", xSize)
+ .indexed("y", ySize)
+ .indexed("z", zSize)
+ .build();
+ var directAddress = DirectIndexedAddress.of(DimensionSizes.of(type));
+ assertThrows(ArrayIndexOutOfBoundsException.class, () -> directAddress.getStride(5));
+ assertThrows(IndexOutOfBoundsException.class, () -> directAddress.setIndex(4, 7));
+ assertEquals(wSize*xSize*ySize*zSize, directAddress.getStride(0));
+ assertEquals(xSize*ySize*zSize, directAddress.getStride(1));
+ assertEquals(ySize*zSize, directAddress.getStride(2));
+ assertEquals(zSize, directAddress.getStride(3));
+ assertEquals(1, directAddress.getStride(4));
+ assertEquals(0, directAddress.getIndex());
+ directAddress.setIndex(0,1);
+ assertEquals(1 * directAddress.getStride(0), directAddress.getIndex());
+ directAddress.setIndex(1,1);
+ assertEquals(1 * directAddress.getStride(0) + 1 * directAddress.getStride(1), directAddress.getIndex());
+ directAddress.setIndex(2,2);
+ directAddress.setIndex(3,2);
+ directAddress.setIndex(4,2);
+ long expected = 1 * directAddress.getStride(0) +
+ 1 * directAddress.getStride(1) +
+ 2 * directAddress.getStride(2) +
+ 2 * directAddress.getStride(3) +
+ 2 * directAddress.getStride(4);
+ assertEquals(expected, directAddress.getIndex());
+ }
+
+ @Test
public void testUnboundBuilding() {
TensorType type = new TensorType.Builder().indexed("w")
.indexed("v")