summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java38
1 files changed, 38 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
new file mode 100644
index 00000000000..37752361876
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
@@ -0,0 +1,38 @@
+package com.yahoo.tensor;
+
+/**
+ * Utility class for efficient access and iteration along dimensions in Indexed tensors.
+ * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration.
+ * long base = addr.getDirectIndex();
+ * long stride = addr.getStride(dimension)
+ * i = 0...size_of_dimension
+ * double value = tensor.get(base + i * stride);
+ */
+public final class DirectIndexedAddress {
+ private final DimensionSizes sizes;
+ private final int [] indexes;
+ private long directIndex;
+ private DirectIndexedAddress(DimensionSizes sizes) {
+ this.sizes = sizes;
+ indexes = new int[sizes.dimensions()];
+ directIndex = 0;
+ }
+ static DirectIndexedAddress of(DimensionSizes sizes) {
+ return new DirectIndexedAddress(sizes);
+ }
+ /** Sets the current index of a dimension */
+ public void setIndex(int dimension, int index) {
+ if (index < 0 || index >= sizes.size(dimension)) {
+ throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">");
+ }
+ int diff = index - indexes[dimension];
+ directIndex += getStride(dimension) * diff;
+ indexes[dimension] = index;
+ }
+ /** Retrieve the index that can be used for direct lookup in an indexed tensor. */
+ public long getDirectIndex() { return directIndex; }
+ /** returns the stride to be used for the given dimension */
+ public long getStride(int dimension) {
+ return sizes.productOfDimensionsAfter(dimension);
+ }
+}