summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-19 14:32:23 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-19 14:32:23 +0100
commit3cf8be5fe60504a02be04009b9348913ae32b564 (patch)
tree2efea86685fb8c94725e71628ddaeaa683a1faed /vespajlib/src/main/java/com/yahoo/tensor
parent74cbf975a435d54eb892de0142d6cceb2d1ebc93 (diff)
Add a class for assist efficient traversal of dimensions in an IndexedTensor.
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java4
2 files changed, 42 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
new file mode 100644
index 00000000000..80cb545238c
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
@@ -0,0 +1,38 @@
+package com.yahoo.tensor;
+
+/**
+ * Utility class for efficient access and iteration along dimensions in Indexed tensors.
+ * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration.
+ * long base = addr.getIndex();
+ * long stride = addr.getStride(dimension)
+ * i = 0...size_of_dimension
+ * double value = tensor.get(base + i * stride);
+ */
+public final class DirectIndexedAddress {
+ private final DimensionSizes sizes;
+ private final int [] indexes;
+ private long directIndex;
+ private DirectIndexedAddress(DimensionSizes sizes) {
+ this.sizes = sizes;
+ indexes = new int[sizes.dimensions()];
+ directIndex = 0;
+ }
+ static DirectIndexedAddress of(DimensionSizes sizes) {
+ return new DirectIndexedAddress(sizes);
+ }
+ /** Sets the current index of a dimension */
+ public void setIndex(int dimension, int index) {
+ if (index < 0 || index >= sizes.size(dimension)) {
+ throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">");
+ }
+ int diff = index - indexes[dimension];
+ directIndex += getStride(dimension) * diff;
+ indexes[dimension] = index;
+ }
+ /** Retrieve the index that can be used for direct lookup in an indexed tensor. */
+ public long getIndex() { return directIndex; }
+ /** returns the stride to be used for the given dimension */
+ public long getStride(int dimension) {
+ return sizes.productOfDimensionsAfter(dimension);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 1319675f5d4..d43b4a03f72 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -93,6 +93,10 @@ public abstract class IndexedTensor implements Tensor {
return get(toValueIndex(indexes, dimensionSizes));
}
+ public double get(DirectIndexedAddress address) {
+ return get(address.getIndex());
+ }
+ public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); }
/**
* Returns the value at the given indexes as a float
*