diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-14 08:34:09 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-14 08:34:09 +0100 |
commit | f5ccf036b4f7368f217a6bcbffc1699aac5eac2d (patch) | |
tree | 749afd3b29f52b918c67099c1742cb9db50211cf /vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | |
parent | 3954dbe2403bdbb21e9a558fbc55fd137afa40f8 (diff) |
Interpret dimensions in written order
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 95 |
1 files changed, 69 insertions, 26 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 30923976fa5..ba3a35e8eda 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor { indexes.next(); // start brackets - for (int i = 0; i < indexes.rightDimensionsAtStart(); i++) + for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) b.append("["); // value @@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalStateException("Unexpected value type " + type.valueType()); // end bracket and comma - for (int i = 0; i < indexes.rightDimensionsAtEnd(); i++) + for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) b.append("]"); if (index < size() - 1) b.append(", "); @@ -777,6 +777,10 @@ public abstract class IndexedTensor implements Tensor { return of(DimensionSizes.of(type)); } + public static Indexes of(TensorType type, List<String> iterateDimensionOrder) { + return of(DimensionSizes.of(type), toIterationOrder(iterateDimensionOrder, type)); + } + public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); } @@ -789,6 +793,10 @@ public abstract class IndexedTensor implements Tensor { return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size); } + private static Indexes of(DimensionSizes sizes, List<Integer> iterateDimensions) { + return of(sizes, sizes, iterateDimensions); + } + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions) { return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions)); } @@ -822,6 +830,16 @@ public abstract class IndexedTensor implements Tensor { } } + private static List<Integer> toIterationOrder(List<String> dimensionNames, TensorType type) { + if (dimensionNames == null) return completeIterationOrder(type.rank()); + + List<Integer> iterationDimensions = new ArrayList<>(type.rank()); + for (int i = 0; i < type.rank(); i++) + iterationDimensions.add(type.rank() - 1 - type.indexOfDimension(dimensionNames.get(i)).get()); + return iterationDimensions; + } + + /** Since the right dimensions binds closest, iteration order is the opposite of the tensor order */ private static List<Integer> completeIterationOrder(int length) { List<Integer> iterationDimensions = new ArrayList<>(length); for (int i = 0; i < length; i++) @@ -854,7 +872,7 @@ public abstract class IndexedTensor implements Tensor { /** Returns a copy of the indexes of this which must not be modified */ public long[] indexesForReading() { return indexes; } - long toSourceValueIndex() { + public long toSourceValueIndex() { return IndexedTensor.toValueIndex(indexes, sourceSizes); } @@ -882,27 +900,12 @@ public abstract class IndexedTensor implements Tensor { /** Returns whether further values are available by calling next() */ public abstract boolean hasNext(); - /** Returns the number of dimensions from the right which are currently at the start position (0) */ - int rightDimensionsAtStart() { - int dimension = indexes.length - 1; - int atStartCount = 0; - while (dimension >= 0 && indexes[dimension] == 0) { - atStartCount++; - dimension--; - } - return atStartCount; - } + /** Returns the number of dimensions in iteration order which are currently at the start position (0) */ + abstract int nextDimensionsAtStart(); + + /** Returns the number of dimensions in iteration order which are currently at their end position */ + abstract int nextDimensionsAtEnd(); - /** Returns the number of dimensions from the right which are currently at the end position */ - int rightDimensionsAtEnd() { - int dimension = indexes.length - 1; - int atEndCount = 0; - while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) { - atEndCount++; - dimension--; - } - return atEndCount; - } } private final static class EmptyIndexes extends Indexes { @@ -920,6 +923,12 @@ public abstract class IndexedTensor implements Tensor { @Override public boolean hasNext() { return false; } + @Override + int nextDimensionsAtStart() { return 0; } + + @Override + int nextDimensionsAtEnd() { return 0; } + } private final static class SingleValueIndexes extends Indexes { @@ -939,6 +948,12 @@ public abstract class IndexedTensor implements Tensor { @Override public boolean hasNext() { return ! exhausted; } + @Override + int nextDimensionsAtStart() { return 1; } + + @Override + int nextDimensionsAtEnd() { return 1; } + } private static class MultiDimensionIndexes extends Indexes { @@ -987,6 +1002,22 @@ public abstract class IndexedTensor implements Tensor { return false; } + @Override + int nextDimensionsAtStart() { + int dimension = 0; + while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == 0) + dimension++; + return dimension; + } + + @Override + int nextDimensionsAtEnd() { + int dimension = 0; + while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == dimensionSizes().size(iterateDimensions.get(dimension)) - 1) + dimension++; + return dimension; + } + } /** In this case we can reuse the source index computation for the iteration index */ @@ -999,7 +1030,7 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { + public long toSourceValueIndex() { return lastComputedSourceValueIndex = super.toSourceValueIndex(); } @@ -1056,7 +1087,7 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { return currentSourceValueIndex; } + public long toSourceValueIndex() { return currentSourceValueIndex; } @Override long toIterationValueIndex() { return currentIterationValueIndex; } @@ -1066,6 +1097,12 @@ public abstract class IndexedTensor implements Tensor { return indexes[iterateDimension] + 1 < size; } + @Override + int nextDimensionsAtStart() { return currentSourceValueIndex == 0 ? 1 : 0; } + + @Override + int nextDimensionsAtEnd() { return currentSourceValueIndex == size - 1 ? 1 : 0; } + } /** In this case we only need to keep track of one index */ @@ -1117,11 +1154,17 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { return currentValueIndex; } + public long toSourceValueIndex() { return currentValueIndex; } @Override long toIterationValueIndex() { return currentValueIndex; } + @Override + int nextDimensionsAtStart() { return currentValueIndex == 0 ? 1 : 0; } + + @Override + int nextDimensionsAtEnd() { return currentValueIndex == size - 1 ? 1 : 0; } + } } |