diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 155 |
1 files changed, 106 insertions, 49 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6e03c27af75..b89185b5131 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -81,6 +81,12 @@ public class IndexedTensor implements Tensor { return subspaceIterator(dimensions, dimensionSizes); } + /** Returns whether the dimensions sizes of this are equal to the given sizes */ + // TODO: Replace by returning immutable sizes when DimensionSizes are a class + public boolean dimensionSizesAre(int[] dimensionSizes) { + return Arrays.equals(dimensionSizes, this.dimensionSizes); + } + /** * Returns the value at the given indexes * @@ -95,7 +101,7 @@ public class IndexedTensor implements Tensor { /** Returns the value at this address, or NaN if there is no value at this address */ @Override public double get(TensorAddress address) { - // optimize for fast lookup within bounds + // optimize for fast lookup within bounds: try { return values[toValueIndex(address, dimensionSizes)]; } @@ -104,6 +110,8 @@ public class IndexedTensor implements Tensor { } } + double get(int valueIndex) { return values[valueIndex]; } + /** Returns the value at these indexes */ private double get(Indexes indexes) { return values[toValueIndex(indexes.indexesForReading(), dimensionSizes)]; @@ -153,10 +161,10 @@ public class IndexedTensor implements Tensor { return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]); ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); - Indexes indexes = Indexes.of(dimensionSizes, values.length); + Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); for (int i = 0; i < values.length; i++) { indexes.next(); - builder.put(indexes.toAddress(), values[i]); + builder.put(indexes.toAddress(i), values[i]); } return builder.build(); } @@ -209,7 +217,10 @@ public class IndexedTensor implements Tensor { } public abstract Builder cell(double value, int ... indexes); - + + /** Add a cell by internal index */ + public abstract Builder cellWithInternalIndex(int internalIndex, double value); + protected double[] arrayFor(int[] dimensionSizes) { int productSize = 1; for (int dimensionSize : dimensionSizes) @@ -281,6 +292,12 @@ public class IndexedTensor implements Tensor { return tensor; } + @Override + public Builder cellWithInternalIndex(int internalIndex, double value) { + values[internalIndex] = value; + return this; + } + } /** @@ -400,13 +417,17 @@ public class IndexedTensor implements Tensor { list.add(list.size(), null); } + @Override + public Builder cellWithInternalIndex(int internalIndex, double value) { + throw new UnsupportedOperationException("Not supoprted for unbound builders"); + } + } - // TODO: Generalize to vector cell iterator? private final class CellIterator implements Iterator<Map.Entry<TensorAddress, Double>> { private int count = 0; - private final Indexes indexes = Indexes.of(dimensionSizes, values.length); + private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); @Override public boolean hasNext() { @@ -418,7 +439,9 @@ public class IndexedTensor implements Tensor { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes); count++; indexes.next(); - return new Cell(indexes.toAddress(), get(indexes)); + int valueIndex = toValueIndex(indexes.indexesForReading(), IndexedTensor.this.dimensionSizes); + TensorAddress address = indexes.toAddress(valueIndex); + return new Cell(address, get(valueIndex)); } } @@ -493,12 +516,12 @@ public class IndexedTensor implements Tensor { * The sizes of the space we'll return values of, one value for each dimension of this tensor, * which may be equal to or smaller than the sizes of this tensor */ - private final int[] dimensionSizes; + private final int[] iterateDimensionSizes; private int count = 0; - private SuperspaceIterator(Set<String> superdimensionNames, int[] dimensionSizes) { - this.dimensionSizes = dimensionSizes; + private SuperspaceIterator(Set<String> superdimensionNames, int[] iterateDimensionSizes) { + this.iterateDimensionSizes = iterateDimensionSizes; List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length) @@ -509,7 +532,7 @@ public class IndexedTensor implements Tensor { subdimensionIndexes.add(i); } - superindexes = Indexes.of(dimensionSizes, superdimensionIndexes); + superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, superdimensionIndexes); } @Override @@ -522,7 +545,7 @@ public class IndexedTensor implements Tensor { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + superindexes); count++; superindexes.next(); - return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), dimensionSizes); + return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), iterateDimensionSizes); } } @@ -539,7 +562,7 @@ public class IndexedTensor implements Tensor { */ private final List<Integer> iterateDimensions; private final int[] address; - private final int[] dimensionSizes; + private final int[] iterateDimensionSizes; private Indexes indexes; private int count = 0; @@ -556,11 +579,11 @@ public class IndexedTensor implements Tensor { * This is treated as immutable. * @param address the address of the first cell of this subspace. */ - private SubspaceIterator(List<Integer> iterateDimensions, int[] address, int[] dimensionSizes) { + private SubspaceIterator(List<Integer> iterateDimensions, int[] address, int[] iterateDimensionSizes) { this.iterateDimensions = iterateDimensions; this.address = address; - this.dimensionSizes = dimensionSizes; - this.indexes = Indexes.of(dimensionSizes, iterateDimensions, address); + this.iterateDimensionSizes = iterateDimensionSizes; + this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, iterateDimensions, address); } /** Returns the total number of cells in this subspace */ @@ -569,12 +592,12 @@ public class IndexedTensor implements Tensor { } /** Returns the address of the cell this currently points to (which may be an invalid position) */ - public TensorAddress address() { return indexes.toAddress(); } + public TensorAddress address() { return indexes.toAddress(-1); } /** Rewind this iterator to the first element */ public void reset() { this.count = 0; - this.indexes = Indexes.of(dimensionSizes, iterateDimensions, address); + this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, iterateDimensions, address); } @Override @@ -587,10 +610,14 @@ public class IndexedTensor implements Tensor { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes); count++; indexes.next(); - return new Cell(indexes.toAddress(), get(indexes)); + int valueIndex = indexes.toValueIndex(); + TensorAddress address = indexes.toAddress(valueIndex); + return new Cell(address, get(valueIndex)); // TODO: Change type to Cell, then change Cell to work with indexes + valueIndex instead of creating an address? } } + + // TODO: Make dimensionSizes a class /** * An array of indexes into this tensor which are able to find the next index in the value order. @@ -599,37 +626,45 @@ public class IndexedTensor implements Tensor { */ public abstract static class Indexes { + private final int[] sourceDimensionSizes; + + private final int[] iterateDimensionSizes; + protected final int[] indexes; public static Indexes of(int[] dimensionSizes) { - return of(dimensionSizes, completeIterationOrder(dimensionSizes.length)); + return of(dimensionSizes, dimensionSizes); + } + + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes) { + return of(sourceDimensionSizes, iterateDimensionSizes, completeIterationOrder(iterateDimensionSizes.length)); } - private static Indexes of(int[] dimensionSizes, int size) { - return of(dimensionSizes, completeIterationOrder(dimensionSizes.length), size); + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int size) { + return of(sourceDimensionSizes, iterateDimensionSizes, completeIterationOrder(iterateDimensionSizes.length), size); } - private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions) { - return of(dimensionSizes, iterateDimensions, computeSize(dimensionSizes, iterateDimensions)); + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions) { + return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, computeSize(iterateDimensionSizes, iterateDimensions)); } - private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int size) { - return of(dimensionSizes, iterateDimensions, new int[dimensionSizes.length], size); + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int size) { + return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, new int[iterateDimensionSizes.length], size); } - private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes) { - return of(dimensionSizes, iterateDimensions, initialIndexes, computeSize(dimensionSizes, iterateDimensions)); + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes) { + return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, initialIndexes, computeSize(iterateDimensionSizes, iterateDimensions)); } - private static Indexes of(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { if (size == 0) - return new EmptyIndexes(initialIndexes); // we're told explicitly there are truly no values available + return new EmptyIndexes(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); // we're told explicitly there are truly no values available else if (size == 1) - return new SingleValueIndexes(initialIndexes); // with no (iterating) dimensions, we still return one value, not zero + return new SingleValueIndexes(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); // with no (iterating) dimensions, we still return one value, not zero else if (iterateDimensions.size() == 1) - return new SingleDimensionIndexes(iterateDimensions.get(0), initialIndexes, size); // optimization + return new SingleDimensionIndexes(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions.get(0), initialIndexes, size); // optimization else - return new MultivalueIndexes(dimensionSizes, iterateDimensions, initialIndexes, size); + return new MultivalueIndexes(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, initialIndexes, size); } private static List<Integer> completeIterationOrder(int length) { @@ -639,7 +674,9 @@ public class IndexedTensor implements Tensor { return iterationDimensions; } - private Indexes(int[] indexes) { + private Indexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) { + this.sourceDimensionSizes = sourceDimensionSizes; + this.iterateDimensionSizes = iterateDimensionSizes; this.indexes = indexes; } @@ -651,8 +688,8 @@ public class IndexedTensor implements Tensor { } /** Returns the address of the current position of these indexes */ - private TensorAddress toAddress() { - return TensorAddress.of(indexes); + private TensorAddress toAddress(int valueIndex) { + return TensorAddress.withValueIndex(valueIndex, indexes); } public int[] indexesCopy() { @@ -661,6 +698,14 @@ public class IndexedTensor implements Tensor { /** Returns a copy of the indexes of this which must not be modified */ public int[] indexesForReading() { return indexes; } + + /** Returns the value index for this in the tensor we are iterating over */ + int toValueIndex() { + return IndexedTensor.toValueIndex(indexes, sourceDimensionSizes); + } + + /** Returns the dimension sizes of this. Do not modify the return value */ + int[] dimensionSizes() { return iterateDimensionSizes; } /** Returns an immutable list containing a copy of the indexes in this */ public List<Integer> toList() { @@ -683,8 +728,8 @@ public class IndexedTensor implements Tensor { private final static class EmptyIndexes extends Indexes { - private EmptyIndexes(int[] indexes) { - super(indexes); + private EmptyIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) { + super(sourceDimensionSizes, iterateDimensionSizes, indexes); } @Override @@ -697,8 +742,8 @@ public class IndexedTensor implements Tensor { private final static class SingleValueIndexes extends Indexes { - private SingleValueIndexes(int[] indexes) { - super(indexes); + private SingleValueIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) { + super(sourceDimensionSizes, iterateDimensionSizes, indexes); } @Override @@ -713,13 +758,10 @@ public class IndexedTensor implements Tensor { private final int size; - private final int[] dimensionSizes; - private final List<Integer> iterateDimensions; - private MultivalueIndexes(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { - super(initialIndexes); - this.dimensionSizes = dimensionSizes; + private MultivalueIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { + super(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); this.iterateDimensions = iterateDimensions; this.size = size; @@ -742,7 +784,7 @@ public class IndexedTensor implements Tensor { @Override public void next() { int iterateDimensionsIndex = 0; - while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes[iterateDimensions.get(iterateDimensionsIndex)]) { + while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes()[iterateDimensions.get(iterateDimensionsIndex)]) { indexes[iterateDimensions.get(iterateDimensionsIndex)] = 0; // carry over iterateDimensionsIndex++; } @@ -756,16 +798,25 @@ public class IndexedTensor implements Tensor { private final int size; private final int iterateDimension; + + /** Maintain this directly as an optimization for 1-d iteration */ + private int currentValueIndex; - private SingleDimensionIndexes(int iterateDimension, int[] initialIndexes, int size) { - super(initialIndexes); + /** The iteration step in the value index space */ + private final int step; + + private SingleDimensionIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, + int iterateDimension, int[] initialIndexes, int size) { + super(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; + this.step = productOfDimensionsAfter(iterateDimension, sourceDimensionSizes); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; + currentValueIndex = IndexedTensor.toValueIndex(indexes, sourceDimensionSizes); } - + /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override public int size() { @@ -781,6 +832,12 @@ public class IndexedTensor implements Tensor { @Override public void next() { indexes[iterateDimension]++; + currentValueIndex += step; + } + + @Override + int toValueIndex() { + return currentValueIndex; } } |