summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-03 16:23:16 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-03 16:23:16 +0100
commita837e4c028cdda5d7a84ac7fffb65f740a001416 (patch)
treed5e44c8f9212c089a6aab3e47ca9cecea8410bd4 /vespajlib
parentf2dfa7062d3dd6f718bd2c6ff4bac7ab6232e92a (diff)
Support iterating over dimensions in any order
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java115
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java2
2 files changed, 58 insertions, 59 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index f19097da6bd..e9bb121ef3e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -486,8 +486,8 @@ public class IndexedTensor implements Tensor {
private final Indexes superindexes;
- /** true at indexes whose dimension subspaces iterate over */
- private final boolean[] subdimensionIndexes;
+ /** Those indexes this should iterate over */
+ private final List<Integer> subdimensionIndexes;
/**
* The sizes of the space we'll return values of, one value for each dimension of this tensor,
@@ -500,12 +500,13 @@ public class IndexedTensor implements Tensor {
private SuperspaceIterator(Set<String> superdimensionNames, int[] dimensionSizes) {
this.dimensionSizes = dimensionSizes;
- boolean[] superdimensionIndexes = new boolean[dimensionSizes.length]; // for outer iterator
- subdimensionIndexes = new boolean [dimensionSizes.length]; // for inner iterator
- for (int i = 0; i < type.dimensions().size(); i++ ) {
- boolean superDimension = superdimensionNames.contains(type.dimensions().get(i).name());
- superdimensionIndexes[i] = superDimension;
- subdimensionIndexes[i] = ! superDimension;
+ List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator
+ subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length)
+ for (int i = type.dimensions().size() - 1; i >= 0; i-- ) { // iterate inner dimensions first
+ if (superdimensionNames.contains(type.dimensions().get(i).name()))
+ superdimensionIndexes.add(i);
+ else
+ subdimensionIndexes.add(i);
}
superindexes = Indexes.of(dimensionSizes, superdimensionIndexes);
@@ -531,7 +532,12 @@ public class IndexedTensor implements Tensor {
*/
public final class SubspaceIterator implements Iterator<Map.Entry<TensorAddress, Double>> {
- private final boolean[] dimensionIndexes;
+ /**
+ * This iterator will iterate over the given dimensions, in the order given
+ * (the first dimension index given is incremented to exhaustion first (i.e is etc.).
+ * This may be any subset of the dimensions given by address and dimensionSizes.
+ */
+ private final List<Integer> iterateDimensions;
private final int[] address;
private final int[] dimensionSizes;
@@ -541,15 +547,20 @@ public class IndexedTensor implements Tensor {
/**
* Creates a new subspace iterator
*
- * @param dimensionIndexes a boolean array with a true entry for dimensions we should iterate over and false
- * entries for all other dimensions
+ * @param iterateDimensions the dimensions to iterate over, given as indexes in the dimension order of the
+ * type of the tensor this iterates over. This iterator will iterate over these
+ * dimensions to exhaustion in the order given (the first dimension index given is
+ * incremented to exhaustion first (i.e is etc.), while other dimensions will be held
+ * at a constant position.
+ * This may be any subset of the dimensions given by address and dimensionSizes.
+ * This is treated as immutable.
* @param address the address of the first cell of this subspace.
*/
- private SubspaceIterator(boolean[] dimensionIndexes, int[] address, int[] dimensionSizes) {
- this.dimensionIndexes = dimensionIndexes;
+ private SubspaceIterator(List<Integer> iterateDimensions, int[] address, int[] dimensionSizes) {
+ this.iterateDimensions = iterateDimensions;
this.address = address;
this.dimensionSizes = dimensionSizes;
- this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address);
+ this.indexes = Indexes.of(dimensionSizes, iterateDimensions, address);
}
/** Returns the total number of cells in this subspace */
@@ -563,7 +574,7 @@ public class IndexedTensor implements Tensor {
/** Rewind this iterator to the first element */
public void reset() {
this.count = 0;
- this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address);
+ this.indexes = Indexes.of(dimensionSizes, iterateDimensions, address);
}
@Override
@@ -591,49 +602,49 @@ public class IndexedTensor implements Tensor {
protected final int[] indexes;
public static Indexes of(int[] dimensionSizes) {
- return of(dimensionSizes, trueArray(dimensionSizes.length));
+ return of(dimensionSizes, completeIterationOrder(dimensionSizes.length));
}
private static Indexes of(int[] dimensionSizes, int size) {
- return of(dimensionSizes, trueArray(dimensionSizes.length), size);
+ return of(dimensionSizes, completeIterationOrder(dimensionSizes.length), size);
}
- private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions) {
- return of(dimensionSizes, iteratingDimensions, computeSize(dimensionSizes, iteratingDimensions));
+ private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions) {
+ return of(dimensionSizes, iterateDimensions, computeSize(dimensionSizes, iterateDimensions));
}
- private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensionIndexes, int size) {
- return of(dimensionSizes, iteratingDimensionIndexes, new int[dimensionSizes.length], size);
+ private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int size) {
+ return of(dimensionSizes, iterateDimensions, new int[dimensionSizes.length], size);
}
- private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes) {
- return of(dimensionSizes, iteratingDimensions, initialIndexes, computeSize(dimensionSizes, iteratingDimensions));
+ private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes) {
+ return of(dimensionSizes, iterateDimensions, initialIndexes, computeSize(dimensionSizes, iterateDimensions));
}
- private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) {
+ private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
if (size == 0)
return new EmptyIndexes(initialIndexes); // we're told explicitly there are truly no values available
else if (size == 1)
return new SingleValueIndexes(initialIndexes); // with no (iterating) dimensions, we still return one value, not zero
else
- return new MultivalueIndexes(dimensionSizes, iteratingDimensions, initialIndexes, size);
+ return new MultivalueIndexes(dimensionSizes, iterateDimensions, initialIndexes, size);
+ }
+
+ private static List<Integer> completeIterationOrder(int length) {
+ List<Integer> iterationDimensions = new ArrayList<>(length);
+ for (int i = 0; i < length; i++)
+ iterationDimensions.add(length - 1 - i);
+ return iterationDimensions;
}
private Indexes(int[] indexes) {
this.indexes = indexes;
}
- private static boolean[] trueArray(int size) {
- boolean[] array = new boolean[size];
- Arrays.fill(array, true);
- return array;
- }
-
- private static int computeSize(int[] dimensionSizes, boolean[] iteratingDimensions) {
+ private static int computeSize(int[] dimensionSizes, List<Integer> iterateDimensions) {
int size = 1;
- for (int dimensionIndex = 0; dimensionIndex < dimensionSizes.length; dimensionIndex++)
- if (iteratingDimensions[dimensionIndex])
- size *= dimensionSizes[dimensionIndex];
+ for (int iterateDimension : iterateDimensions)
+ size *= dimensionSizes[iterateDimension];
return size;
}
@@ -707,27 +718,16 @@ public class IndexedTensor implements Tensor {
private final int[] dimensionSizes;
- /** Only mutate (take next in) the dimension indexes which are true */
- private final boolean[] iteratingDimensions;
-
- private static boolean haveIteratingDimensions(boolean[] iteratingDimensions) {
- for (boolean iterating : iteratingDimensions)
- if (iterating)
- return true;
- return false;
- }
+ private final List<Integer> iterateDimensions;
- private MultivalueIndexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) {
+ private MultivalueIndexes(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
super(initialIndexes);
this.dimensionSizes = dimensionSizes;
- this.iteratingDimensions = iteratingDimensions;
+ this.iterateDimensions = iterateDimensions;
this.size = size;
// Initialize to the (virtual) position before the first cell
- int currentDimension = indexes.length - 1;
- while (! iteratingDimensions[currentDimension])
- currentDimension--;
- indexes[currentDimension]--;
+ indexes[iterateDimensions.get(0)]--;
}
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@@ -739,18 +739,17 @@ public class IndexedTensor implements Tensor {
/**
* Advances this to the next cell in the standard indexed tensor cell order.
* The first call to this will put it at the first position.
+ *
+ * @throws RuntimeException if this is called more times than its size
*/
@Override
public void next() {
- int currentDimension = indexes.length - 1;
- while ( ! iteratingDimensions[currentDimension] ||
- indexes[currentDimension] + 1 == dimensionSizes[currentDimension]) {
- if ( iteratingDimensions[currentDimension])
- indexes[currentDimension--] = 0; // carry over
- else // leave this dimension as-is
- currentDimension--;
+ int iterateDimensionsIndex = 0;
+ while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes[iterateDimensions.get(iterateDimensionsIndex)]) {
+ indexes[iterateDimensions.get(iterateDimensionsIndex)] = 0; // carry over
+ iterateDimensionsIndex++;
}
- indexes[currentDimension]++;
+ indexes[iterateDimensions.get(iterateDimensionsIndex)]++;
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
index 45ce5c92a1e..ace911409bc 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
@@ -38,7 +38,7 @@ public class ConcatTestCase {
}
@Test
- public void testUnequalEqualSizesSameDimension() {
+ public void testUnequalSizesSameDimension() {
Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x"));