summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-14 08:34:09 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-14 08:34:09 +0100
commitf5ccf036b4f7368f217a6bcbffc1699aac5eac2d (patch)
tree749afd3b29f52b918c67099c1742cb9db50211cf /vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
parent3954dbe2403bdbb21e9a558fbc55fd137afa40f8 (diff)
Interpret dimensions in written order
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java95
1 files changed, 69 insertions, 26 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 30923976fa5..ba3a35e8eda 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor {
indexes.next();
// start brackets
- for (int i = 0; i < indexes.rightDimensionsAtStart(); i++)
+ for (int i = 0; i < indexes.nextDimensionsAtStart(); i++)
b.append("[");
// value
@@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor {
throw new IllegalStateException("Unexpected value type " + type.valueType());
// end bracket and comma
- for (int i = 0; i < indexes.rightDimensionsAtEnd(); i++)
+ for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
b.append("]");
if (index < size() - 1)
b.append(", ");
@@ -777,6 +777,10 @@ public abstract class IndexedTensor implements Tensor {
return of(DimensionSizes.of(type));
}
+ public static Indexes of(TensorType type, List<String> iterateDimensionOrder) {
+ return of(DimensionSizes.of(type), toIterationOrder(iterateDimensionOrder, type));
+ }
+
public static Indexes of(DimensionSizes sizes) {
return of(sizes, sizes);
}
@@ -789,6 +793,10 @@ public abstract class IndexedTensor implements Tensor {
return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size);
}
+ private static Indexes of(DimensionSizes sizes, List<Integer> iterateDimensions) {
+ return of(sizes, sizes, iterateDimensions);
+ }
+
private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions) {
return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions));
}
@@ -822,6 +830,16 @@ public abstract class IndexedTensor implements Tensor {
}
}
+ private static List<Integer> toIterationOrder(List<String> dimensionNames, TensorType type) {
+ if (dimensionNames == null) return completeIterationOrder(type.rank());
+
+ List<Integer> iterationDimensions = new ArrayList<>(type.rank());
+ for (int i = 0; i < type.rank(); i++)
+ iterationDimensions.add(type.rank() - 1 - type.indexOfDimension(dimensionNames.get(i)).get());
+ return iterationDimensions;
+ }
+
+ /** Since the right dimensions binds closest, iteration order is the opposite of the tensor order */
private static List<Integer> completeIterationOrder(int length) {
List<Integer> iterationDimensions = new ArrayList<>(length);
for (int i = 0; i < length; i++)
@@ -854,7 +872,7 @@ public abstract class IndexedTensor implements Tensor {
/** Returns a copy of the indexes of this which must not be modified */
public long[] indexesForReading() { return indexes; }
- long toSourceValueIndex() {
+ public long toSourceValueIndex() {
return IndexedTensor.toValueIndex(indexes, sourceSizes);
}
@@ -882,27 +900,12 @@ public abstract class IndexedTensor implements Tensor {
/** Returns whether further values are available by calling next() */
public abstract boolean hasNext();
- /** Returns the number of dimensions from the right which are currently at the start position (0) */
- int rightDimensionsAtStart() {
- int dimension = indexes.length - 1;
- int atStartCount = 0;
- while (dimension >= 0 && indexes[dimension] == 0) {
- atStartCount++;
- dimension--;
- }
- return atStartCount;
- }
+ /** Returns the number of dimensions in iteration order which are currently at the start position (0) */
+ abstract int nextDimensionsAtStart();
+
+ /** Returns the number of dimensions in iteration order which are currently at their end position */
+ abstract int nextDimensionsAtEnd();
- /** Returns the number of dimensions from the right which are currently at the end position */
- int rightDimensionsAtEnd() {
- int dimension = indexes.length - 1;
- int atEndCount = 0;
- while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) {
- atEndCount++;
- dimension--;
- }
- return atEndCount;
- }
}
private final static class EmptyIndexes extends Indexes {
@@ -920,6 +923,12 @@ public abstract class IndexedTensor implements Tensor {
@Override
public boolean hasNext() { return false; }
+ @Override
+ int nextDimensionsAtStart() { return 0; }
+
+ @Override
+ int nextDimensionsAtEnd() { return 0; }
+
}
private final static class SingleValueIndexes extends Indexes {
@@ -939,6 +948,12 @@ public abstract class IndexedTensor implements Tensor {
@Override
public boolean hasNext() { return ! exhausted; }
+ @Override
+ int nextDimensionsAtStart() { return 1; }
+
+ @Override
+ int nextDimensionsAtEnd() { return 1; }
+
}
private static class MultiDimensionIndexes extends Indexes {
@@ -987,6 +1002,22 @@ public abstract class IndexedTensor implements Tensor {
return false;
}
+ @Override
+ int nextDimensionsAtStart() {
+ int dimension = 0;
+ while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == 0)
+ dimension++;
+ return dimension;
+ }
+
+ @Override
+ int nextDimensionsAtEnd() {
+ int dimension = 0;
+ while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == dimensionSizes().size(iterateDimensions.get(dimension)) - 1)
+ dimension++;
+ return dimension;
+ }
+
}
/** In this case we can reuse the source index computation for the iteration index */
@@ -999,7 +1030,7 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- long toSourceValueIndex() {
+ public long toSourceValueIndex() {
return lastComputedSourceValueIndex = super.toSourceValueIndex();
}
@@ -1056,7 +1087,7 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- long toSourceValueIndex() { return currentSourceValueIndex; }
+ public long toSourceValueIndex() { return currentSourceValueIndex; }
@Override
long toIterationValueIndex() { return currentIterationValueIndex; }
@@ -1066,6 +1097,12 @@ public abstract class IndexedTensor implements Tensor {
return indexes[iterateDimension] + 1 < size;
}
+ @Override
+ int nextDimensionsAtStart() { return currentSourceValueIndex == 0 ? 1 : 0; }
+
+ @Override
+ int nextDimensionsAtEnd() { return currentSourceValueIndex == size - 1 ? 1 : 0; }
+
}
/** In this case we only need to keep track of one index */
@@ -1117,11 +1154,17 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- long toSourceValueIndex() { return currentValueIndex; }
+ public long toSourceValueIndex() { return currentValueIndex; }
@Override
long toIterationValueIndex() { return currentValueIndex; }
+ @Override
+ int nextDimensionsAtStart() { return currentValueIndex == 0 ? 1 : 0; }
+
+ @Override
+ int nextDimensionsAtEnd() { return currentValueIndex == size - 1 ? 1 : 0; }
+
}
}