diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-03 16:23:16 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-03 16:23:16 +0100 |
commit | a837e4c028cdda5d7a84ac7fffb65f740a001416 (patch) | |
tree | d5e44c8f9212c089a6aab3e47ca9cecea8410bd4 /vespajlib | |
parent | f2dfa7062d3dd6f718bd2c6ff4bac7ab6232e92a (diff) |
Support iterating over dimensions in any order
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 115 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java | 2 |
2 files changed, 58 insertions, 59 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index f19097da6bd..e9bb121ef3e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -486,8 +486,8 @@ public class IndexedTensor implements Tensor { private final Indexes superindexes; - /** true at indexes whose dimension subspaces iterate over */ - private final boolean[] subdimensionIndexes; + /** Those indexes this should iterate over */ + private final List<Integer> subdimensionIndexes; /** * The sizes of the space we'll return values of, one value for each dimension of this tensor, @@ -500,12 +500,13 @@ public class IndexedTensor implements Tensor { private SuperspaceIterator(Set<String> superdimensionNames, int[] dimensionSizes) { this.dimensionSizes = dimensionSizes; - boolean[] superdimensionIndexes = new boolean[dimensionSizes.length]; // for outer iterator - subdimensionIndexes = new boolean [dimensionSizes.length]; // for inner iterator - for (int i = 0; i < type.dimensions().size(); i++ ) { - boolean superDimension = superdimensionNames.contains(type.dimensions().get(i).name()); - superdimensionIndexes[i] = superDimension; - subdimensionIndexes[i] = ! superDimension; + List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator + subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length) + for (int i = type.dimensions().size() - 1; i >= 0; i-- ) { // iterate inner dimensions first + if (superdimensionNames.contains(type.dimensions().get(i).name())) + superdimensionIndexes.add(i); + else + subdimensionIndexes.add(i); } superindexes = Indexes.of(dimensionSizes, superdimensionIndexes); @@ -531,7 +532,12 @@ public class IndexedTensor implements Tensor { */ public final class SubspaceIterator implements Iterator<Map.Entry<TensorAddress, Double>> { - private final boolean[] dimensionIndexes; + /** + * This iterator will iterate over the given dimensions, in the order given + * (the first dimension index given is incremented to exhaustion first (i.e is etc.). + * This may be any subset of the dimensions given by address and dimensionSizes. + */ + private final List<Integer> iterateDimensions; private final int[] address; private final int[] dimensionSizes; @@ -541,15 +547,20 @@ public class IndexedTensor implements Tensor { /** * Creates a new subspace iterator * - * @param dimensionIndexes a boolean array with a true entry for dimensions we should iterate over and false - * entries for all other dimensions + * @param iterateDimensions the dimensions to iterate over, given as indexes in the dimension order of the + * type of the tensor this iterates over. This iterator will iterate over these + * dimensions to exhaustion in the order given (the first dimension index given is + * incremented to exhaustion first (i.e is etc.), while other dimensions will be held + * at a constant position. + * This may be any subset of the dimensions given by address and dimensionSizes. + * This is treated as immutable. * @param address the address of the first cell of this subspace. */ - private SubspaceIterator(boolean[] dimensionIndexes, int[] address, int[] dimensionSizes) { - this.dimensionIndexes = dimensionIndexes; + private SubspaceIterator(List<Integer> iterateDimensions, int[] address, int[] dimensionSizes) { + this.iterateDimensions = iterateDimensions; this.address = address; this.dimensionSizes = dimensionSizes; - this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address); + this.indexes = Indexes.of(dimensionSizes, iterateDimensions, address); } /** Returns the total number of cells in this subspace */ @@ -563,7 +574,7 @@ public class IndexedTensor implements Tensor { /** Rewind this iterator to the first element */ public void reset() { this.count = 0; - this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address); + this.indexes = Indexes.of(dimensionSizes, iterateDimensions, address); } @Override @@ -591,49 +602,49 @@ public class IndexedTensor implements Tensor { protected final int[] indexes; public static Indexes of(int[] dimensionSizes) { - return of(dimensionSizes, trueArray(dimensionSizes.length)); + return of(dimensionSizes, completeIterationOrder(dimensionSizes.length)); } private static Indexes of(int[] dimensionSizes, int size) { - return of(dimensionSizes, trueArray(dimensionSizes.length), size); + return of(dimensionSizes, completeIterationOrder(dimensionSizes.length), size); } - private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions) { - return of(dimensionSizes, iteratingDimensions, computeSize(dimensionSizes, iteratingDimensions)); + private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions) { + return of(dimensionSizes, iterateDimensions, computeSize(dimensionSizes, iterateDimensions)); } - private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensionIndexes, int size) { - return of(dimensionSizes, iteratingDimensionIndexes, new int[dimensionSizes.length], size); + private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int size) { + return of(dimensionSizes, iterateDimensions, new int[dimensionSizes.length], size); } - private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes) { - return of(dimensionSizes, iteratingDimensions, initialIndexes, computeSize(dimensionSizes, iteratingDimensions)); + private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes) { + return of(dimensionSizes, iterateDimensions, initialIndexes, computeSize(dimensionSizes, iterateDimensions)); } - private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) { + private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { if (size == 0) return new EmptyIndexes(initialIndexes); // we're told explicitly there are truly no values available else if (size == 1) return new SingleValueIndexes(initialIndexes); // with no (iterating) dimensions, we still return one value, not zero else - return new MultivalueIndexes(dimensionSizes, iteratingDimensions, initialIndexes, size); + return new MultivalueIndexes(dimensionSizes, iterateDimensions, initialIndexes, size); + } + + private static List<Integer> completeIterationOrder(int length) { + List<Integer> iterationDimensions = new ArrayList<>(length); + for (int i = 0; i < length; i++) + iterationDimensions.add(length - 1 - i); + return iterationDimensions; } private Indexes(int[] indexes) { this.indexes = indexes; } - private static boolean[] trueArray(int size) { - boolean[] array = new boolean[size]; - Arrays.fill(array, true); - return array; - } - - private static int computeSize(int[] dimensionSizes, boolean[] iteratingDimensions) { + private static int computeSize(int[] dimensionSizes, List<Integer> iterateDimensions) { int size = 1; - for (int dimensionIndex = 0; dimensionIndex < dimensionSizes.length; dimensionIndex++) - if (iteratingDimensions[dimensionIndex]) - size *= dimensionSizes[dimensionIndex]; + for (int iterateDimension : iterateDimensions) + size *= dimensionSizes[iterateDimension]; return size; } @@ -707,27 +718,16 @@ public class IndexedTensor implements Tensor { private final int[] dimensionSizes; - /** Only mutate (take next in) the dimension indexes which are true */ - private final boolean[] iteratingDimensions; - - private static boolean haveIteratingDimensions(boolean[] iteratingDimensions) { - for (boolean iterating : iteratingDimensions) - if (iterating) - return true; - return false; - } + private final List<Integer> iterateDimensions; - private MultivalueIndexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) { + private MultivalueIndexes(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { super(initialIndexes); this.dimensionSizes = dimensionSizes; - this.iteratingDimensions = iteratingDimensions; + this.iterateDimensions = iterateDimensions; this.size = size; // Initialize to the (virtual) position before the first cell - int currentDimension = indexes.length - 1; - while (! iteratingDimensions[currentDimension]) - currentDimension--; - indexes[currentDimension]--; + indexes[iterateDimensions.get(0)]--; } /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @@ -739,18 +739,17 @@ public class IndexedTensor implements Tensor { /** * Advances this to the next cell in the standard indexed tensor cell order. * The first call to this will put it at the first position. + * + * @throws RuntimeException if this is called more times than its size */ @Override public void next() { - int currentDimension = indexes.length - 1; - while ( ! iteratingDimensions[currentDimension] || - indexes[currentDimension] + 1 == dimensionSizes[currentDimension]) { - if ( iteratingDimensions[currentDimension]) - indexes[currentDimension--] = 0; // carry over - else // leave this dimension as-is - currentDimension--; + int iterateDimensionsIndex = 0; + while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes[iterateDimensions.get(iterateDimensionsIndex)]) { + indexes[iterateDimensions.get(iterateDimensionsIndex)] = 0; // carry over + iterateDimensionsIndex++; } - indexes[currentDimension]++; + indexes[iterateDimensions.get(iterateDimensionsIndex)]++; } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java index 45ce5c92a1e..ace911409bc 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java @@ -38,7 +38,7 @@ public class ConcatTestCase { } @Test - public void testUnequalEqualSizesSameDimension() { + public void testUnequalSizesSameDimension() { Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }"); Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x")); |