aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-19 23:02:04 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-19 23:02:04 +0100
commit35d59981840614bf4b877714ee88e273816c46d2 (patch)
treefba37b2e8bc9fcee46821821ab2886d371fcd696 /vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
parent067eb48b7d2fc062a74392b1c16f5538b5031d5b (diff)
Use longs for dimensions lengths in all API's
This is to be able to support tensor dimensions with more than 2B elements in the future without API change.
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java194
1 files changed, 98 insertions, 96 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 6b0d769de9f..7130c053e9f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -38,7 +38,7 @@ public class IndexedTensor implements Tensor {
}
@Override
- public int size() {
+ public long size() {
return values.length;
}
@@ -55,10 +55,10 @@ public class IndexedTensor implements Tensor {
/** Returns an iterator over all the cells in this tensor which matches the given partial address */
// TODO: Move up to Tensor and create a mixed tensor which can implement it (and subspace iterators) efficiently
public SubspaceIterator cellIterator(PartialAddress partialAddress, DimensionSizes iterationSizes) {
- int[] startAddress = new int[type().dimensions().size()];
+ long[] startAddress = new long[type().dimensions().size()];
List<Integer> iterateDimensions = new ArrayList<>();
for (int i = 0; i < type().dimensions().size(); i++) {
- int partialAddressLabel = partialAddress.intLabel(type.dimensions().get(i).name());
+ long partialAddressLabel = partialAddress.numericLabel(type.dimensions().get(i).name());
if (partialAddressLabel >= 0) // iterate at this label
startAddress[i] = partialAddressLabel;
else // iterate over this dimension
@@ -102,8 +102,8 @@ public class IndexedTensor implements Tensor {
* @param indexes the indexes into the dimensions of this. Must be one number per dimension of this
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
- public double get(int ... indexes) {
- return values[toValueIndex(indexes, dimensionSizes)];
+ public double get(long ... indexes) {
+ return values[(int)toValueIndex(indexes, dimensionSizes)];
}
/** Returns the value at this address, or NaN if there is no value at this address */
@@ -111,20 +111,20 @@ public class IndexedTensor implements Tensor {
public double get(TensorAddress address) {
// optimize for fast lookup within bounds:
try {
- return values[toValueIndex(address, dimensionSizes)];
+ return values[(int)toValueIndex(address, dimensionSizes)];
}
catch (IndexOutOfBoundsException e) {
return Double.NaN;
}
}
- private double get(int valueIndex) { return values[valueIndex]; }
+ private double get(long valueIndex) { return values[(int)valueIndex]; }
- private static int toValueIndex(int[] indexes, DimensionSizes sizes) {
+ private static long toValueIndex(long[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
if (indexes.length == 0) return 0; // for speed
- int valueIndex = 0;
+ long valueIndex = 0;
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] >= sizes.size(i)) {
throw new IndexOutOfBoundsException();
@@ -134,21 +134,21 @@ public class IndexedTensor implements Tensor {
return valueIndex;
}
- private static int toValueIndex(TensorAddress address, DimensionSizes sizes) {
+ private static long toValueIndex(TensorAddress address, DimensionSizes sizes) {
if (address.isEmpty()) return 0;
- int valueIndex = 0;
+ long valueIndex = 0;
for (int i = 0; i < address.size(); i++) {
- if (address.intLabel(i) >= sizes.size(i)) {
+ if (address.numericLabel(i) >= sizes.size(i)) {
throw new IndexOutOfBoundsException();
}
- valueIndex += productOfDimensionsAfter(i, sizes) * address.intLabel(i);
+ valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i);
}
return valueIndex;
}
- private static int productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
- int product = 1;
+ private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
+ long product = 1;
for (int i = afterIndex + 1; i < sizes.dimensions(); i++)
product *= sizes.size(i);
return product;
@@ -168,9 +168,9 @@ public class IndexedTensor implements Tensor {
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
- for (int i = 0; i < values.length; i++) {
+ for (long i = 0; i < values.length; i++) {
indexes.next();
- builder.put(indexes.toAddress(), values[i]);
+ builder.put(indexes.toAddress(), values[(int)i]);
}
return builder.build();
}
@@ -213,7 +213,7 @@ public class IndexedTensor implements Tensor {
throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
"for " + type);
for (int i = 0; i < sizes.dimensions(); i++ ) {
- Optional<Integer> size = type.dimensions().get(i).size();
+ Optional<Long> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
sizes.size(i) +
@@ -223,7 +223,7 @@ public class IndexedTensor implements Tensor {
return new BoundBuilder(type, sizes);
}
- public abstract Builder cell(double value, int ... indexes);
+ public abstract Builder cell(double value, long ... indexes);
@Override
public TensorType type() { return type; }
@@ -255,12 +255,12 @@ public class IndexedTensor implements Tensor {
if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
- values = new double[sizes.totalSize()];
+ values = new double[(int)sizes.totalSize()];
}
@Override
- public BoundBuilder cell(double value, int ... indexes) {
- values[toValueIndex(indexes, sizes)] = value;
+ public BoundBuilder cell(double value, long ... indexes) {
+ values[(int)toValueIndex(indexes, sizes)] = value;
return this;
}
@@ -271,7 +271,7 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(TensorAddress address, double value) {
- values[toValueIndex(address, sizes)] = value;
+ values[(int)toValueIndex(address, sizes)] = value;
return this;
}
@@ -286,9 +286,9 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(Cell cell, double value) {
- int directIndex = cell.getDirectIndex();
+ long directIndex = cell.getDirectIndex();
if (directIndex >= 0) // optimization
- values[directIndex] = value;
+ values[(int)directIndex] = value;
else
super.cell(cell, value);
return this;
@@ -299,8 +299,8 @@ public class IndexedTensor implements Tensor {
* This requires knowledge of the internal layout of cells in this implementation, and should therefore
* probably not be used (but when it can be used it is fast).
*/
- public void cellByDirectIndex(int index, double value) {
- values[index] = value;
+ public void cellByDirectIndex(long index, double value) {
+ values[(int)index] = value;
}
}
@@ -326,13 +326,13 @@ public class IndexedTensor implements Tensor {
return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
- double[] values = new double[dimensionSizes.totalSize()];
+ double[] values = new double[(int)dimensionSizes.totalSize()];
fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
private DimensionSizes findDimensionSizes(List<Object> firstDimension) {
- List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
+ List<Long> dimensionSizeList = new ArrayList<>(type.dimensions().size());
findDimensionSizes(0, dimensionSizeList, firstDimension);
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct
for (int i = 0; i < b.dimensions(); i++) {
@@ -343,9 +343,9 @@ public class IndexedTensor implements Tensor {
}
@SuppressWarnings("unchecked")
- private void findDimensionSizes(int currentDimensionIndex, List<Integer> dimensionSizes, List<Object> currentDimension) {
+ private void findDimensionSizes(int currentDimensionIndex, List<Long> dimensionSizes, List<Object> currentDimension) {
if (currentDimensionIndex == dimensionSizes.size())
- dimensionSizes.add(currentDimension.size());
+ dimensionSizes.add((long)currentDimension.size());
else if (dimensionSizes.get(currentDimensionIndex) != currentDimension.size())
throw new IllegalArgumentException("Missing values in dimension " +
type.dimensions().get(currentDimensionIndex) + " in " + type);
@@ -356,16 +356,16 @@ public class IndexedTensor implements Tensor {
}
@SuppressWarnings("unchecked")
- private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension,
+ private void fillValues(int currentDimensionIndex, long offset, List<Object> currentDimension,
DimensionSizes sizes, double[] values) {
if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension
- for (int i = 0; i < currentDimension.size(); i++)
+ for (long i = 0; i < currentDimension.size(); i++)
fillValues(currentDimensionIndex + 1,
offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i,
- (List<Object>) currentDimension.get(i), sizes, values);
+ (List<Object>) currentDimension.get((int)i), sizes, values);
} else { // last dimension - fill values
- for (int i = 0; i < currentDimension.size(); i++) {
- values[offset + i] = nullAsZero((Double)currentDimension.get(i)); // fill missing values as zero
+ for (long i = 0; i < currentDimension.size(); i++) {
+ values[(int)(offset + i)] = nullAsZero((Double)currentDimension.get((int)i)); // fill missing values as zero
}
}
}
@@ -382,9 +382,9 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(TensorAddress address, double value) {
- int[] indexes = new int[address.size()];
+ long[] indexes = new long[address.size()];
for (int i = 0; i < address.size(); i++) {
- indexes[i] = address.intLabel(i);
+ indexes[i] = address.numericLabel(i);
}
cell(value, indexes);
return this;
@@ -399,7 +399,7 @@ public class IndexedTensor implements Tensor {
*/
@SuppressWarnings("unchecked")
@Override
- public Builder cell(double value, int... indexes) {
+ public Builder cell(double value, long... indexes) {
if (indexes.length != type.dimensions().size())
throw new IllegalArgumentException("Wrong number of indexes (" + indexes.length + ") for " + type);
@@ -414,18 +414,18 @@ public class IndexedTensor implements Tensor {
for (int dimensionIndex = 0; dimensionIndex < indexes.length; dimensionIndex++) {
ensureCapacity(indexes[dimensionIndex], currentValues);
if (dimensionIndex == indexes.length - 1) { // last dimension
- currentValues.set(indexes[dimensionIndex], value);
+ currentValues.set((int)indexes[dimensionIndex], value);
} else {
- if (currentValues.get(indexes[dimensionIndex]) == null)
- currentValues.set(indexes[dimensionIndex], new ArrayList<>());
- currentValues = (List<Object>) currentValues.get(indexes[dimensionIndex]);
+ if (currentValues.get((int)indexes[dimensionIndex]) == null)
+ currentValues.set((int)indexes[dimensionIndex], new ArrayList<>());
+ currentValues = (List<Object>) currentValues.get((int)indexes[dimensionIndex]);
}
}
return this;
}
/** Fill the given list with nulls if necessary to make sure it has a (possibly null) value at the given index */
- private void ensureCapacity(int index, List<Object> list) {
+ private void ensureCapacity(long index, List<Object> list) {
while (list.size() <= index)
list.add(list.size(), null);
}
@@ -434,7 +434,7 @@ public class IndexedTensor implements Tensor {
private final class CellIterator implements Iterator<Cell> {
- private int count = 0;
+ private long count = 0;
private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
private final LazyCell reusedCell = new LazyCell(indexes, Double.NaN);
@@ -456,7 +456,7 @@ public class IndexedTensor implements Tensor {
private final class ValueIterator implements Iterator<Double> {
- private int count = 0;
+ private long count = 0;
@Override
public boolean hasNext() {
@@ -466,7 +466,7 @@ public class IndexedTensor implements Tensor {
@Override
public Double next() {
try {
- return values[count++];
+ return values[(int)count++];
}
catch (IndexOutOfBoundsException e) {
throw new NoSuchElementException("No element at position " + count);
@@ -479,7 +479,7 @@ public class IndexedTensor implements Tensor {
private final Indexes superindexes;
- /** Those indexes this should iterate over */
+ /** The indexes this should iterate over */
private final List<Integer> subdimensionIndexes;
/**
@@ -488,7 +488,7 @@ public class IndexedTensor implements Tensor {
*/
private final DimensionSizes iterateSizes;
- private int count = 0;
+ private long count = 0;
private SuperspaceIterator(Set<String> superdimensionNames, DimensionSizes iterateSizes) {
this.iterateSizes = iterateSizes;
@@ -533,11 +533,11 @@ public class IndexedTensor implements Tensor {
* This may be any subset of the dimensions given by address and dimensionSizes.
*/
private final List<Integer> iterateDimensions;
- private final int[] address;
+ private final long[] address;
private final DimensionSizes iterateSizes;
private Indexes indexes;
- private int count = 0;
+ private long count = 0;
/** A lazy cell for reuse */
private final LazyCell reusedCell;
@@ -554,7 +554,7 @@ public class IndexedTensor implements Tensor {
* This is treated as immutable.
* @param address the address of the first cell of this subspace.
*/
- private SubspaceIterator(List<Integer> iterateDimensions, int[] address, DimensionSizes iterateSizes) {
+ private SubspaceIterator(List<Integer> iterateDimensions, long[] address, DimensionSizes iterateSizes) {
this.iterateDimensions = iterateDimensions;
this.address = address;
this.iterateSizes = iterateSizes;
@@ -563,7 +563,7 @@ public class IndexedTensor implements Tensor {
}
/** Returns the total number of cells in this subspace */
- public int size() {
+ public long size() {
return indexes.size();
}
@@ -605,7 +605,7 @@ public class IndexedTensor implements Tensor {
}
@Override
- int getDirectIndex() { return indexes.toIterationValueIndex(); }
+ long getDirectIndex() { return indexes.toIterationValueIndex(); }
@Override
public TensorAddress getKey() {
@@ -630,7 +630,7 @@ public class IndexedTensor implements Tensor {
private final DimensionSizes iterationSizes;
- protected final int[] indexes;
+ protected final long[] indexes;
public static Indexes of(DimensionSizes sizes) {
return of(sizes, sizes);
@@ -640,7 +640,7 @@ public class IndexedTensor implements Tensor {
return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()));
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int size) {
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long size) {
return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size);
}
@@ -648,15 +648,15 @@ public class IndexedTensor implements Tensor {
return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions));
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int size) {
- return of(sourceSizes, iterateSizes, iterateDimensions, new int[iterateSizes.dimensions()], size);
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long size) {
+ return of(sourceSizes, iterateSizes, iterateDimensions, new long[iterateSizes.dimensions()], size);
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes) {
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes) {
return of(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, computeSize(iterateSizes, iterateDimensions));
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
if (size == 0) {
return new EmptyIndexes(sourceSizes, iterateSizes, initialIndexes); // we're told explicitly there are truly no values available
}
@@ -684,14 +684,14 @@ public class IndexedTensor implements Tensor {
return iterationDimensions;
}
- private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) {
+ private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, long[] indexes) {
this.sourceSizes = sourceSizes;
this.iterationSizes = iterationSizes;
this.indexes = indexes;
}
- private static int computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) {
- int size = 1;
+ private static long computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) {
+ long size = 1;
for (int iterateDimension : iterateDimensions)
size *= sizes.size(iterateDimension);
return size;
@@ -702,25 +702,25 @@ public class IndexedTensor implements Tensor {
return TensorAddress.of(indexes);
}
- public int[] indexesCopy() {
+ public long[] indexesCopy() {
return Arrays.copyOf(indexes, indexes.length);
}
/** Returns a copy of the indexes of this which must not be modified */
- public int[] indexesForReading() { return indexes; }
+ public long[] indexesForReading() { return indexes; }
- int toSourceValueIndex() {
+ long toSourceValueIndex() {
return IndexedTensor.toValueIndex(indexes, sourceSizes);
}
- int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); }
+ long toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); }
DimensionSizes dimensionSizes() { return iterationSizes; }
/** Returns an immutable list containing a copy of the indexes in this */
- public List<Integer> toList() {
- ImmutableList.Builder<Integer> builder = new ImmutableList.Builder<>();
- for (int index : indexes)
+ public List<Long> toList() {
+ ImmutableList.Builder<Long> builder = new ImmutableList.Builder<>();
+ for (long index : indexes)
builder.add(index);
return builder.build();
}
@@ -730,7 +730,7 @@ public class IndexedTensor implements Tensor {
return "indexes " + Arrays.toString(indexes);
}
- public abstract int size();
+ public abstract long size();
public abstract void next();
@@ -738,12 +738,12 @@ public class IndexedTensor implements Tensor {
private final static class EmptyIndexes extends Indexes {
- private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) {
+ private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) {
super(sourceSizes, iterateSizes, indexes);
}
@Override
- public int size() { return 0; }
+ public long size() { return 0; }
@Override
public void next() {}
@@ -752,12 +752,12 @@ public class IndexedTensor implements Tensor {
private final static class SingleValueIndexes extends Indexes {
- private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) {
+ private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) {
super(sourceSizes, iterateSizes, indexes);
}
@Override
- public int size() { return 1; }
+ public long size() { return 1; }
@Override
public void next() {}
@@ -766,11 +766,11 @@ public class IndexedTensor implements Tensor {
private static class MultiDimensionIndexes extends Indexes {
- private final int size;
+ private final long size;
private final List<Integer> iterateDimensions;
- private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimensions = iterateDimensions;
this.size = size;
@@ -781,7 +781,7 @@ public class IndexedTensor implements Tensor {
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
- public int size() {
+ public long size() {
return size;
}
@@ -806,36 +806,38 @@ public class IndexedTensor implements Tensor {
/** In this case we can reuse the source index computation for the iteration index */
private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes {
- private int lastComputedSourceValueIndex = -1;
+ private long lastComputedSourceValueIndex = -1;
- private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
super(sizes, sizes, iterateDimensions, initialIndexes, size);
}
- int toSourceValueIndex() {
+ @Override
+ long toSourceValueIndex() {
return lastComputedSourceValueIndex = super.toSourceValueIndex();
}
// NOTE: We assume the source index always gets computed first. Otherwise using this will produce a runtime exception
- int toIterationValueIndex() { return lastComputedSourceValueIndex; }
+ @Override
+ long toIterationValueIndex() { return lastComputedSourceValueIndex; }
}
/** In this case we can keep track of indexes using a step instead of using the more elaborate computation */
private final static class SingleDimensionIndexes extends Indexes {
- private final int size;
+ private final long size;
private final int iterateDimension;
/** Maintain this directly as an optimization for 1-d iteration */
- private int currentSourceValueIndex, currentIterationValueIndex;
+ private long currentSourceValueIndex, currentIterationValueIndex;
/** The iteration step in the value index space */
- private final int sourceStep, iterationStep;
+ private final long sourceStep, iterationStep;
private SingleDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes,
- int iterateDimension, int[] initialIndexes, int size) {
+ int iterateDimension, long[] initialIndexes, long size) {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
@@ -850,7 +852,7 @@ public class IndexedTensor implements Tensor {
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
- public int size() {
+ public long size() {
return size;
}
@@ -868,28 +870,28 @@ public class IndexedTensor implements Tensor {
}
@Override
- int toSourceValueIndex() { return currentSourceValueIndex; }
+ long toSourceValueIndex() { return currentSourceValueIndex; }
@Override
- int toIterationValueIndex() { return currentIterationValueIndex; }
+ long toIterationValueIndex() { return currentIterationValueIndex; }
}
/** In this case we only need to keep track of one index */
private final static class EqualSizeSingleDimensionIndexes extends Indexes {
- private final int size;
+ private final long size;
private final int iterateDimension;
/** Maintain this directly as an optimization for 1-d iteration */
- private int currentValueIndex;
+ private long currentValueIndex;
/** The iteration step in the value index space */
- private final int step;
+ private final long step;
private EqualSizeSingleDimensionIndexes(DimensionSizes sizes,
- int iterateDimension, int[] initialIndexes, int size) {
+ int iterateDimension, long[] initialIndexes, long size) {
super(sizes, sizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
@@ -902,7 +904,7 @@ public class IndexedTensor implements Tensor {
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
- public int size() {
+ public long size() {
return size;
}
@@ -919,10 +921,10 @@ public class IndexedTensor implements Tensor {
}
@Override
- int toSourceValueIndex() { return currentValueIndex; }
+ long toSourceValueIndex() { return currentValueIndex; }
@Override
- int toIterationValueIndex() { return currentValueIndex; }
+ long toIterationValueIndex() { return currentValueIndex; }
}