aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java155
1 files changed, 106 insertions, 49 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 6e03c27af75..b89185b5131 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -81,6 +81,12 @@ public class IndexedTensor implements Tensor {
return subspaceIterator(dimensions, dimensionSizes);
}
+ /** Returns whether the dimensions sizes of this are equal to the given sizes */
+ // TODO: Replace by returning immutable sizes when DimensionSizes are a class
+ public boolean dimensionSizesAre(int[] dimensionSizes) {
+ return Arrays.equals(dimensionSizes, this.dimensionSizes);
+ }
+
/**
* Returns the value at the given indexes
*
@@ -95,7 +101,7 @@ public class IndexedTensor implements Tensor {
/** Returns the value at this address, or NaN if there is no value at this address */
@Override
public double get(TensorAddress address) {
- // optimize for fast lookup within bounds
+ // optimize for fast lookup within bounds:
try {
return values[toValueIndex(address, dimensionSizes)];
}
@@ -104,6 +110,8 @@ public class IndexedTensor implements Tensor {
}
}
+ double get(int valueIndex) { return values[valueIndex]; }
+
/** Returns the value at these indexes */
private double get(Indexes indexes) {
return values[toValueIndex(indexes.indexesForReading(), dimensionSizes)];
@@ -153,10 +161,10 @@ public class IndexedTensor implements Tensor {
return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]);
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
- Indexes indexes = Indexes.of(dimensionSizes, values.length);
+ Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
for (int i = 0; i < values.length; i++) {
indexes.next();
- builder.put(indexes.toAddress(), values[i]);
+ builder.put(indexes.toAddress(i), values[i]);
}
return builder.build();
}
@@ -209,7 +217,10 @@ public class IndexedTensor implements Tensor {
}
public abstract Builder cell(double value, int ... indexes);
-
+
+ /** Add a cell by internal index */
+ public abstract Builder cellWithInternalIndex(int internalIndex, double value);
+
protected double[] arrayFor(int[] dimensionSizes) {
int productSize = 1;
for (int dimensionSize : dimensionSizes)
@@ -281,6 +292,12 @@ public class IndexedTensor implements Tensor {
return tensor;
}
+ @Override
+ public Builder cellWithInternalIndex(int internalIndex, double value) {
+ values[internalIndex] = value;
+ return this;
+ }
+
}
/**
@@ -400,13 +417,17 @@ public class IndexedTensor implements Tensor {
list.add(list.size(), null);
}
+ @Override
+ public Builder cellWithInternalIndex(int internalIndex, double value) {
+ throw new UnsupportedOperationException("Not supoprted for unbound builders");
+ }
+
}
- // TODO: Generalize to vector cell iterator?
private final class CellIterator implements Iterator<Map.Entry<TensorAddress, Double>> {
private int count = 0;
- private final Indexes indexes = Indexes.of(dimensionSizes, values.length);
+ private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
@Override
public boolean hasNext() {
@@ -418,7 +439,9 @@ public class IndexedTensor implements Tensor {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes);
count++;
indexes.next();
- return new Cell(indexes.toAddress(), get(indexes));
+ int valueIndex = toValueIndex(indexes.indexesForReading(), IndexedTensor.this.dimensionSizes);
+ TensorAddress address = indexes.toAddress(valueIndex);
+ return new Cell(address, get(valueIndex));
}
}
@@ -493,12 +516,12 @@ public class IndexedTensor implements Tensor {
* The sizes of the space we'll return values of, one value for each dimension of this tensor,
* which may be equal to or smaller than the sizes of this tensor
*/
- private final int[] dimensionSizes;
+ private final int[] iterateDimensionSizes;
private int count = 0;
- private SuperspaceIterator(Set<String> superdimensionNames, int[] dimensionSizes) {
- this.dimensionSizes = dimensionSizes;
+ private SuperspaceIterator(Set<String> superdimensionNames, int[] iterateDimensionSizes) {
+ this.iterateDimensionSizes = iterateDimensionSizes;
List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator
subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length)
@@ -509,7 +532,7 @@ public class IndexedTensor implements Tensor {
subdimensionIndexes.add(i);
}
- superindexes = Indexes.of(dimensionSizes, superdimensionIndexes);
+ superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, superdimensionIndexes);
}
@Override
@@ -522,7 +545,7 @@ public class IndexedTensor implements Tensor {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + superindexes);
count++;
superindexes.next();
- return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), dimensionSizes);
+ return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), iterateDimensionSizes);
}
}
@@ -539,7 +562,7 @@ public class IndexedTensor implements Tensor {
*/
private final List<Integer> iterateDimensions;
private final int[] address;
- private final int[] dimensionSizes;
+ private final int[] iterateDimensionSizes;
private Indexes indexes;
private int count = 0;
@@ -556,11 +579,11 @@ public class IndexedTensor implements Tensor {
* This is treated as immutable.
* @param address the address of the first cell of this subspace.
*/
- private SubspaceIterator(List<Integer> iterateDimensions, int[] address, int[] dimensionSizes) {
+ private SubspaceIterator(List<Integer> iterateDimensions, int[] address, int[] iterateDimensionSizes) {
this.iterateDimensions = iterateDimensions;
this.address = address;
- this.dimensionSizes = dimensionSizes;
- this.indexes = Indexes.of(dimensionSizes, iterateDimensions, address);
+ this.iterateDimensionSizes = iterateDimensionSizes;
+ this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, iterateDimensions, address);
}
/** Returns the total number of cells in this subspace */
@@ -569,12 +592,12 @@ public class IndexedTensor implements Tensor {
}
/** Returns the address of the cell this currently points to (which may be an invalid position) */
- public TensorAddress address() { return indexes.toAddress(); }
+ public TensorAddress address() { return indexes.toAddress(-1); }
/** Rewind this iterator to the first element */
public void reset() {
this.count = 0;
- this.indexes = Indexes.of(dimensionSizes, iterateDimensions, address);
+ this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, iterateDimensions, address);
}
@Override
@@ -587,10 +610,14 @@ public class IndexedTensor implements Tensor {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes);
count++;
indexes.next();
- return new Cell(indexes.toAddress(), get(indexes));
+ int valueIndex = indexes.toValueIndex();
+ TensorAddress address = indexes.toAddress(valueIndex);
+ return new Cell(address, get(valueIndex)); // TODO: Change type to Cell, then change Cell to work with indexes + valueIndex instead of creating an address?
}
}
+
+ // TODO: Make dimensionSizes a class
/**
* An array of indexes into this tensor which are able to find the next index in the value order.
@@ -599,37 +626,45 @@ public class IndexedTensor implements Tensor {
*/
public abstract static class Indexes {
+ private final int[] sourceDimensionSizes;
+
+ private final int[] iterateDimensionSizes;
+
protected final int[] indexes;
public static Indexes of(int[] dimensionSizes) {
- return of(dimensionSizes, completeIterationOrder(dimensionSizes.length));
+ return of(dimensionSizes, dimensionSizes);
+ }
+
+ private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes) {
+ return of(sourceDimensionSizes, iterateDimensionSizes, completeIterationOrder(iterateDimensionSizes.length));
}
- private static Indexes of(int[] dimensionSizes, int size) {
- return of(dimensionSizes, completeIterationOrder(dimensionSizes.length), size);
+ private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int size) {
+ return of(sourceDimensionSizes, iterateDimensionSizes, completeIterationOrder(iterateDimensionSizes.length), size);
}
- private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions) {
- return of(dimensionSizes, iterateDimensions, computeSize(dimensionSizes, iterateDimensions));
+ private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions) {
+ return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, computeSize(iterateDimensionSizes, iterateDimensions));
}
- private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int size) {
- return of(dimensionSizes, iterateDimensions, new int[dimensionSizes.length], size);
+ private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int size) {
+ return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, new int[iterateDimensionSizes.length], size);
}
- private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes) {
- return of(dimensionSizes, iterateDimensions, initialIndexes, computeSize(dimensionSizes, iterateDimensions));
+ private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes) {
+ return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, initialIndexes, computeSize(iterateDimensionSizes, iterateDimensions));
}
- private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
if (size == 0)
- return new EmptyIndexes(initialIndexes); // we're told explicitly there are truly no values available
+ return new EmptyIndexes(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); // we're told explicitly there are truly no values available
else if (size == 1)
- return new SingleValueIndexes(initialIndexes); // with no (iterating) dimensions, we still return one value, not zero
+ return new SingleValueIndexes(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); // with no (iterating) dimensions, we still return one value, not zero
else if (iterateDimensions.size() == 1)
- return new SingleDimensionIndexes(iterateDimensions.get(0), initialIndexes, size); // optimization
+ return new SingleDimensionIndexes(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions.get(0), initialIndexes, size); // optimization
else
- return new MultivalueIndexes(dimensionSizes, iterateDimensions, initialIndexes, size);
+ return new MultivalueIndexes(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, initialIndexes, size);
}
private static List<Integer> completeIterationOrder(int length) {
@@ -639,7 +674,9 @@ public class IndexedTensor implements Tensor {
return iterationDimensions;
}
- private Indexes(int[] indexes) {
+ private Indexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) {
+ this.sourceDimensionSizes = sourceDimensionSizes;
+ this.iterateDimensionSizes = iterateDimensionSizes;
this.indexes = indexes;
}
@@ -651,8 +688,8 @@ public class IndexedTensor implements Tensor {
}
/** Returns the address of the current position of these indexes */
- private TensorAddress toAddress() {
- return TensorAddress.of(indexes);
+ private TensorAddress toAddress(int valueIndex) {
+ return TensorAddress.withValueIndex(valueIndex, indexes);
}
public int[] indexesCopy() {
@@ -661,6 +698,14 @@ public class IndexedTensor implements Tensor {
/** Returns a copy of the indexes of this which must not be modified */
public int[] indexesForReading() { return indexes; }
+
+ /** Returns the value index for this in the tensor we are iterating over */
+ int toValueIndex() {
+ return IndexedTensor.toValueIndex(indexes, sourceDimensionSizes);
+ }
+
+ /** Returns the dimension sizes of this. Do not modify the return value */
+ int[] dimensionSizes() { return iterateDimensionSizes; }
/** Returns an immutable list containing a copy of the indexes in this */
public List<Integer> toList() {
@@ -683,8 +728,8 @@ public class IndexedTensor implements Tensor {
private final static class EmptyIndexes extends Indexes {
- private EmptyIndexes(int[] indexes) {
- super(indexes);
+ private EmptyIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) {
+ super(sourceDimensionSizes, iterateDimensionSizes, indexes);
}
@Override
@@ -697,8 +742,8 @@ public class IndexedTensor implements Tensor {
private final static class SingleValueIndexes extends Indexes {
- private SingleValueIndexes(int[] indexes) {
- super(indexes);
+ private SingleValueIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) {
+ super(sourceDimensionSizes, iterateDimensionSizes, indexes);
}
@Override
@@ -713,13 +758,10 @@ public class IndexedTensor implements Tensor {
private final int size;
- private final int[] dimensionSizes;
-
private final List<Integer> iterateDimensions;
- private MultivalueIndexes(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
- super(initialIndexes);
- this.dimensionSizes = dimensionSizes;
+ private MultivalueIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ super(sourceDimensionSizes, iterateDimensionSizes, initialIndexes);
this.iterateDimensions = iterateDimensions;
this.size = size;
@@ -742,7 +784,7 @@ public class IndexedTensor implements Tensor {
@Override
public void next() {
int iterateDimensionsIndex = 0;
- while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes[iterateDimensions.get(iterateDimensionsIndex)]) {
+ while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes()[iterateDimensions.get(iterateDimensionsIndex)]) {
indexes[iterateDimensions.get(iterateDimensionsIndex)] = 0; // carry over
iterateDimensionsIndex++;
}
@@ -756,16 +798,25 @@ public class IndexedTensor implements Tensor {
private final int size;
private final int iterateDimension;
+
+ /** Maintain this directly as an optimization for 1-d iteration */
+ private int currentValueIndex;
- private SingleDimensionIndexes(int iterateDimension, int[] initialIndexes, int size) {
- super(initialIndexes);
+ /** The iteration step in the value index space */
+ private final int step;
+
+ private SingleDimensionIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes,
+ int iterateDimension, int[] initialIndexes, int size) {
+ super(sourceDimensionSizes, iterateDimensionSizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
+ this.step = productOfDimensionsAfter(iterateDimension, sourceDimensionSizes);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;
+ currentValueIndex = IndexedTensor.toValueIndex(indexes, sourceDimensionSizes);
}
-
+
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
public int size() {
@@ -781,6 +832,12 @@ public class IndexedTensor implements Tensor {
@Override
public void next() {
indexes[iterateDimension]++;
+ currentValueIndex += step;
+ }
+
+ @Override
+ int toValueIndex() {
+ return currentValueIndex;
}
}