summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java194
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java73
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java30
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java40
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java36
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java13
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java7
18 files changed, 256 insertions, 259 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index f6237a1977a..01bf082d32f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -13,7 +13,7 @@ import java.util.Arrays;
@Beta
public final class DimensionSizes {
- private final int[] sizes;
+ private final long[] sizes;
private DimensionSizes(Builder builder) {
this.sizes = builder.sizes;
@@ -25,15 +25,15 @@ public final class DimensionSizes {
*
* @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one
*/
- public int size(int dimensionIndex) { return sizes[dimensionIndex]; }
+ public long size(int dimensionIndex) { return sizes[dimensionIndex]; }
/** Returns the number of dimensions this provides the size of */
public int dimensions() { return sizes.length; }
/** Returns the product of the sizes of this */
- public int totalSize() {
- int productSize = 1;
- for (int dimensionSize : sizes )
+ public long totalSize() {
+ long productSize = 1;
+ for (long dimensionSize : sizes )
productSize *= dimensionSize;
return productSize;
}
@@ -54,13 +54,13 @@ public final class DimensionSizes {
*/
public final static class Builder {
- private int[] sizes;
+ private long[] sizes;
public Builder(int dimensions) {
- this.sizes = new int[dimensions];
+ this.sizes = new long[dimensions];
}
- public Builder set(int dimensionIndex, int size) {
+ public Builder set(int dimensionIndex, long size) {
sizes[dimensionIndex] = size;
return this;
}
@@ -70,7 +70,7 @@ public final class DimensionSizes {
*
* @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one
*/
- public int size(int dimensionIndex) { return sizes[dimensionIndex]; }
+ public long size(int dimensionIndex) { return sizes[dimensionIndex]; }
/** Returns the number of dimensions this provides the size of */
public int dimensions() { return sizes.length; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 6b0d769de9f..7130c053e9f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -38,7 +38,7 @@ public class IndexedTensor implements Tensor {
}
@Override
- public int size() {
+ public long size() {
return values.length;
}
@@ -55,10 +55,10 @@ public class IndexedTensor implements Tensor {
/** Returns an iterator over all the cells in this tensor which matches the given partial address */
// TODO: Move up to Tensor and create a mixed tensor which can implement it (and subspace iterators) efficiently
public SubspaceIterator cellIterator(PartialAddress partialAddress, DimensionSizes iterationSizes) {
- int[] startAddress = new int[type().dimensions().size()];
+ long[] startAddress = new long[type().dimensions().size()];
List<Integer> iterateDimensions = new ArrayList<>();
for (int i = 0; i < type().dimensions().size(); i++) {
- int partialAddressLabel = partialAddress.intLabel(type.dimensions().get(i).name());
+ long partialAddressLabel = partialAddress.numericLabel(type.dimensions().get(i).name());
if (partialAddressLabel >= 0) // iterate at this label
startAddress[i] = partialAddressLabel;
else // iterate over this dimension
@@ -102,8 +102,8 @@ public class IndexedTensor implements Tensor {
* @param indexes the indexes into the dimensions of this. Must be one number per dimension of this
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
- public double get(int ... indexes) {
- return values[toValueIndex(indexes, dimensionSizes)];
+ public double get(long ... indexes) {
+ return values[(int)toValueIndex(indexes, dimensionSizes)];
}
/** Returns the value at this address, or NaN if there is no value at this address */
@@ -111,20 +111,20 @@ public class IndexedTensor implements Tensor {
public double get(TensorAddress address) {
// optimize for fast lookup within bounds:
try {
- return values[toValueIndex(address, dimensionSizes)];
+ return values[(int)toValueIndex(address, dimensionSizes)];
}
catch (IndexOutOfBoundsException e) {
return Double.NaN;
}
}
- private double get(int valueIndex) { return values[valueIndex]; }
+ private double get(long valueIndex) { return values[(int)valueIndex]; }
- private static int toValueIndex(int[] indexes, DimensionSizes sizes) {
+ private static long toValueIndex(long[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
if (indexes.length == 0) return 0; // for speed
- int valueIndex = 0;
+ long valueIndex = 0;
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] >= sizes.size(i)) {
throw new IndexOutOfBoundsException();
@@ -134,21 +134,21 @@ public class IndexedTensor implements Tensor {
return valueIndex;
}
- private static int toValueIndex(TensorAddress address, DimensionSizes sizes) {
+ private static long toValueIndex(TensorAddress address, DimensionSizes sizes) {
if (address.isEmpty()) return 0;
- int valueIndex = 0;
+ long valueIndex = 0;
for (int i = 0; i < address.size(); i++) {
- if (address.intLabel(i) >= sizes.size(i)) {
+ if (address.numericLabel(i) >= sizes.size(i)) {
throw new IndexOutOfBoundsException();
}
- valueIndex += productOfDimensionsAfter(i, sizes) * address.intLabel(i);
+ valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i);
}
return valueIndex;
}
- private static int productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
- int product = 1;
+ private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
+ long product = 1;
for (int i = afterIndex + 1; i < sizes.dimensions(); i++)
product *= sizes.size(i);
return product;
@@ -168,9 +168,9 @@ public class IndexedTensor implements Tensor {
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
- for (int i = 0; i < values.length; i++) {
+ for (long i = 0; i < values.length; i++) {
indexes.next();
- builder.put(indexes.toAddress(), values[i]);
+ builder.put(indexes.toAddress(), values[(int)i]);
}
return builder.build();
}
@@ -213,7 +213,7 @@ public class IndexedTensor implements Tensor {
throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
"for " + type);
for (int i = 0; i < sizes.dimensions(); i++ ) {
- Optional<Integer> size = type.dimensions().get(i).size();
+ Optional<Long> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
sizes.size(i) +
@@ -223,7 +223,7 @@ public class IndexedTensor implements Tensor {
return new BoundBuilder(type, sizes);
}
- public abstract Builder cell(double value, int ... indexes);
+ public abstract Builder cell(double value, long ... indexes);
@Override
public TensorType type() { return type; }
@@ -255,12 +255,12 @@ public class IndexedTensor implements Tensor {
if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
- values = new double[sizes.totalSize()];
+ values = new double[(int)sizes.totalSize()];
}
@Override
- public BoundBuilder cell(double value, int ... indexes) {
- values[toValueIndex(indexes, sizes)] = value;
+ public BoundBuilder cell(double value, long ... indexes) {
+ values[(int)toValueIndex(indexes, sizes)] = value;
return this;
}
@@ -271,7 +271,7 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(TensorAddress address, double value) {
- values[toValueIndex(address, sizes)] = value;
+ values[(int)toValueIndex(address, sizes)] = value;
return this;
}
@@ -286,9 +286,9 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(Cell cell, double value) {
- int directIndex = cell.getDirectIndex();
+ long directIndex = cell.getDirectIndex();
if (directIndex >= 0) // optimization
- values[directIndex] = value;
+ values[(int)directIndex] = value;
else
super.cell(cell, value);
return this;
@@ -299,8 +299,8 @@ public class IndexedTensor implements Tensor {
* This requires knowledge of the internal layout of cells in this implementation, and should therefore
* probably not be used (but when it can be used it is fast).
*/
- public void cellByDirectIndex(int index, double value) {
- values[index] = value;
+ public void cellByDirectIndex(long index, double value) {
+ values[(int)index] = value;
}
}
@@ -326,13 +326,13 @@ public class IndexedTensor implements Tensor {
return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
- double[] values = new double[dimensionSizes.totalSize()];
+ double[] values = new double[(int)dimensionSizes.totalSize()];
fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
private DimensionSizes findDimensionSizes(List<Object> firstDimension) {
- List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
+ List<Long> dimensionSizeList = new ArrayList<>(type.dimensions().size());
findDimensionSizes(0, dimensionSizeList, firstDimension);
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct
for (int i = 0; i < b.dimensions(); i++) {
@@ -343,9 +343,9 @@ public class IndexedTensor implements Tensor {
}
@SuppressWarnings("unchecked")
- private void findDimensionSizes(int currentDimensionIndex, List<Integer> dimensionSizes, List<Object> currentDimension) {
+ private void findDimensionSizes(int currentDimensionIndex, List<Long> dimensionSizes, List<Object> currentDimension) {
if (currentDimensionIndex == dimensionSizes.size())
- dimensionSizes.add(currentDimension.size());
+ dimensionSizes.add((long)currentDimension.size());
else if (dimensionSizes.get(currentDimensionIndex) != currentDimension.size())
throw new IllegalArgumentException("Missing values in dimension " +
type.dimensions().get(currentDimensionIndex) + " in " + type);
@@ -356,16 +356,16 @@ public class IndexedTensor implements Tensor {
}
@SuppressWarnings("unchecked")
- private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension,
+ private void fillValues(int currentDimensionIndex, long offset, List<Object> currentDimension,
DimensionSizes sizes, double[] values) {
if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension
- for (int i = 0; i < currentDimension.size(); i++)
+ for (long i = 0; i < currentDimension.size(); i++)
fillValues(currentDimensionIndex + 1,
offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i,
- (List<Object>) currentDimension.get(i), sizes, values);
+ (List<Object>) currentDimension.get((int)i), sizes, values);
} else { // last dimension - fill values
- for (int i = 0; i < currentDimension.size(); i++) {
- values[offset + i] = nullAsZero((Double)currentDimension.get(i)); // fill missing values as zero
+ for (long i = 0; i < currentDimension.size(); i++) {
+ values[(int)(offset + i)] = nullAsZero((Double)currentDimension.get((int)i)); // fill missing values as zero
}
}
}
@@ -382,9 +382,9 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(TensorAddress address, double value) {
- int[] indexes = new int[address.size()];
+ long[] indexes = new long[address.size()];
for (int i = 0; i < address.size(); i++) {
- indexes[i] = address.intLabel(i);
+ indexes[i] = address.numericLabel(i);
}
cell(value, indexes);
return this;
@@ -399,7 +399,7 @@ public class IndexedTensor implements Tensor {
*/
@SuppressWarnings("unchecked")
@Override
- public Builder cell(double value, int... indexes) {
+ public Builder cell(double value, long... indexes) {
if (indexes.length != type.dimensions().size())
throw new IllegalArgumentException("Wrong number of indexes (" + indexes.length + ") for " + type);
@@ -414,18 +414,18 @@ public class IndexedTensor implements Tensor {
for (int dimensionIndex = 0; dimensionIndex < indexes.length; dimensionIndex++) {
ensureCapacity(indexes[dimensionIndex], currentValues);
if (dimensionIndex == indexes.length - 1) { // last dimension
- currentValues.set(indexes[dimensionIndex], value);
+ currentValues.set((int)indexes[dimensionIndex], value);
} else {
- if (currentValues.get(indexes[dimensionIndex]) == null)
- currentValues.set(indexes[dimensionIndex], new ArrayList<>());
- currentValues = (List<Object>) currentValues.get(indexes[dimensionIndex]);
+ if (currentValues.get((int)indexes[dimensionIndex]) == null)
+ currentValues.set((int)indexes[dimensionIndex], new ArrayList<>());
+ currentValues = (List<Object>) currentValues.get((int)indexes[dimensionIndex]);
}
}
return this;
}
/** Fill the given list with nulls if necessary to make sure it has a (possibly null) value at the given index */
- private void ensureCapacity(int index, List<Object> list) {
+ private void ensureCapacity(long index, List<Object> list) {
while (list.size() <= index)
list.add(list.size(), null);
}
@@ -434,7 +434,7 @@ public class IndexedTensor implements Tensor {
private final class CellIterator implements Iterator<Cell> {
- private int count = 0;
+ private long count = 0;
private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
private final LazyCell reusedCell = new LazyCell(indexes, Double.NaN);
@@ -456,7 +456,7 @@ public class IndexedTensor implements Tensor {
private final class ValueIterator implements Iterator<Double> {
- private int count = 0;
+ private long count = 0;
@Override
public boolean hasNext() {
@@ -466,7 +466,7 @@ public class IndexedTensor implements Tensor {
@Override
public Double next() {
try {
- return values[count++];
+ return values[(int)count++];
}
catch (IndexOutOfBoundsException e) {
throw new NoSuchElementException("No element at position " + count);
@@ -479,7 +479,7 @@ public class IndexedTensor implements Tensor {
private final Indexes superindexes;
- /** Those indexes this should iterate over */
+ /** The indexes this should iterate over */
private final List<Integer> subdimensionIndexes;
/**
@@ -488,7 +488,7 @@ public class IndexedTensor implements Tensor {
*/
private final DimensionSizes iterateSizes;
- private int count = 0;
+ private long count = 0;
private SuperspaceIterator(Set<String> superdimensionNames, DimensionSizes iterateSizes) {
this.iterateSizes = iterateSizes;
@@ -533,11 +533,11 @@ public class IndexedTensor implements Tensor {
* This may be any subset of the dimensions given by address and dimensionSizes.
*/
private final List<Integer> iterateDimensions;
- private final int[] address;
+ private final long[] address;
private final DimensionSizes iterateSizes;
private Indexes indexes;
- private int count = 0;
+ private long count = 0;
/** A lazy cell for reuse */
private final LazyCell reusedCell;
@@ -554,7 +554,7 @@ public class IndexedTensor implements Tensor {
* This is treated as immutable.
* @param address the address of the first cell of this subspace.
*/
- private SubspaceIterator(List<Integer> iterateDimensions, int[] address, DimensionSizes iterateSizes) {
+ private SubspaceIterator(List<Integer> iterateDimensions, long[] address, DimensionSizes iterateSizes) {
this.iterateDimensions = iterateDimensions;
this.address = address;
this.iterateSizes = iterateSizes;
@@ -563,7 +563,7 @@ public class IndexedTensor implements Tensor {
}
/** Returns the total number of cells in this subspace */
- public int size() {
+ public long size() {
return indexes.size();
}
@@ -605,7 +605,7 @@ public class IndexedTensor implements Tensor {
}
@Override
- int getDirectIndex() { return indexes.toIterationValueIndex(); }
+ long getDirectIndex() { return indexes.toIterationValueIndex(); }
@Override
public TensorAddress getKey() {
@@ -630,7 +630,7 @@ public class IndexedTensor implements Tensor {
private final DimensionSizes iterationSizes;
- protected final int[] indexes;
+ protected final long[] indexes;
public static Indexes of(DimensionSizes sizes) {
return of(sizes, sizes);
@@ -640,7 +640,7 @@ public class IndexedTensor implements Tensor {
return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()));
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int size) {
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long size) {
return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size);
}
@@ -648,15 +648,15 @@ public class IndexedTensor implements Tensor {
return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions));
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int size) {
- return of(sourceSizes, iterateSizes, iterateDimensions, new int[iterateSizes.dimensions()], size);
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long size) {
+ return of(sourceSizes, iterateSizes, iterateDimensions, new long[iterateSizes.dimensions()], size);
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes) {
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes) {
return of(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, computeSize(iterateSizes, iterateDimensions));
}
- private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
if (size == 0) {
return new EmptyIndexes(sourceSizes, iterateSizes, initialIndexes); // we're told explicitly there are truly no values available
}
@@ -684,14 +684,14 @@ public class IndexedTensor implements Tensor {
return iterationDimensions;
}
- private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) {
+ private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, long[] indexes) {
this.sourceSizes = sourceSizes;
this.iterationSizes = iterationSizes;
this.indexes = indexes;
}
- private static int computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) {
- int size = 1;
+ private static long computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) {
+ long size = 1;
for (int iterateDimension : iterateDimensions)
size *= sizes.size(iterateDimension);
return size;
@@ -702,25 +702,25 @@ public class IndexedTensor implements Tensor {
return TensorAddress.of(indexes);
}
- public int[] indexesCopy() {
+ public long[] indexesCopy() {
return Arrays.copyOf(indexes, indexes.length);
}
/** Returns a copy of the indexes of this which must not be modified */
- public int[] indexesForReading() { return indexes; }
+ public long[] indexesForReading() { return indexes; }
- int toSourceValueIndex() {
+ long toSourceValueIndex() {
return IndexedTensor.toValueIndex(indexes, sourceSizes);
}
- int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); }
+ long toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); }
DimensionSizes dimensionSizes() { return iterationSizes; }
/** Returns an immutable list containing a copy of the indexes in this */
- public List<Integer> toList() {
- ImmutableList.Builder<Integer> builder = new ImmutableList.Builder<>();
- for (int index : indexes)
+ public List<Long> toList() {
+ ImmutableList.Builder<Long> builder = new ImmutableList.Builder<>();
+ for (long index : indexes)
builder.add(index);
return builder.build();
}
@@ -730,7 +730,7 @@ public class IndexedTensor implements Tensor {
return "indexes " + Arrays.toString(indexes);
}
- public abstract int size();
+ public abstract long size();
public abstract void next();
@@ -738,12 +738,12 @@ public class IndexedTensor implements Tensor {
private final static class EmptyIndexes extends Indexes {
- private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) {
+ private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) {
super(sourceSizes, iterateSizes, indexes);
}
@Override
- public int size() { return 0; }
+ public long size() { return 0; }
@Override
public void next() {}
@@ -752,12 +752,12 @@ public class IndexedTensor implements Tensor {
private final static class SingleValueIndexes extends Indexes {
- private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) {
+ private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) {
super(sourceSizes, iterateSizes, indexes);
}
@Override
- public int size() { return 1; }
+ public long size() { return 1; }
@Override
public void next() {}
@@ -766,11 +766,11 @@ public class IndexedTensor implements Tensor {
private static class MultiDimensionIndexes extends Indexes {
- private final int size;
+ private final long size;
private final List<Integer> iterateDimensions;
- private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimensions = iterateDimensions;
this.size = size;
@@ -781,7 +781,7 @@ public class IndexedTensor implements Tensor {
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
- public int size() {
+ public long size() {
return size;
}
@@ -806,36 +806,38 @@ public class IndexedTensor implements Tensor {
/** In this case we can reuse the source index computation for the iteration index */
private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes {
- private int lastComputedSourceValueIndex = -1;
+ private long lastComputedSourceValueIndex = -1;
- private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
super(sizes, sizes, iterateDimensions, initialIndexes, size);
}
- int toSourceValueIndex() {
+ @Override
+ long toSourceValueIndex() {
return lastComputedSourceValueIndex = super.toSourceValueIndex();
}
// NOTE: We assume the source index always gets computed first. Otherwise using this will produce a runtime exception
- int toIterationValueIndex() { return lastComputedSourceValueIndex; }
+ @Override
+ long toIterationValueIndex() { return lastComputedSourceValueIndex; }
}
/** In this case we can keep track of indexes using a step instead of using the more elaborate computation */
private final static class SingleDimensionIndexes extends Indexes {
- private final int size;
+ private final long size;
private final int iterateDimension;
/** Maintain this directly as an optimization for 1-d iteration */
- private int currentSourceValueIndex, currentIterationValueIndex;
+ private long currentSourceValueIndex, currentIterationValueIndex;
/** The iteration step in the value index space */
- private final int sourceStep, iterationStep;
+ private final long sourceStep, iterationStep;
private SingleDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes,
- int iterateDimension, int[] initialIndexes, int size) {
+ int iterateDimension, long[] initialIndexes, long size) {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
@@ -850,7 +852,7 @@ public class IndexedTensor implements Tensor {
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
- public int size() {
+ public long size() {
return size;
}
@@ -868,28 +870,28 @@ public class IndexedTensor implements Tensor {
}
@Override
- int toSourceValueIndex() { return currentSourceValueIndex; }
+ long toSourceValueIndex() { return currentSourceValueIndex; }
@Override
- int toIterationValueIndex() { return currentIterationValueIndex; }
+ long toIterationValueIndex() { return currentIterationValueIndex; }
}
/** In this case we only need to keep track of one index */
private final static class EqualSizeSingleDimensionIndexes extends Indexes {
- private final int size;
+ private final long size;
private final int iterateDimension;
/** Maintain this directly as an optimization for 1-d iteration */
- private int currentValueIndex;
+ private long currentValueIndex;
/** The iteration step in the value index space */
- private final int step;
+ private final long step;
private EqualSizeSingleDimensionIndexes(DimensionSizes sizes,
- int iterateDimension, int[] initialIndexes, int size) {
+ int iterateDimension, long[] initialIndexes, long size) {
super(sizes, sizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
@@ -902,7 +904,7 @@ public class IndexedTensor implements Tensor {
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
- public int size() {
+ public long size() {
return size;
}
@@ -919,10 +921,10 @@ public class IndexedTensor implements Tensor {
}
@Override
- int toSourceValueIndex() { return currentValueIndex; }
+ long toSourceValueIndex() { return currentValueIndex; }
@Override
- int toIterationValueIndex() { return currentValueIndex; }
+ long toIterationValueIndex() { return currentValueIndex; }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index aba61478e69..15993072c37 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -29,7 +29,7 @@ public class MappedTensor implements Tensor {
public TensorType type() { return type; }
@Override
- public int size() { return cells.size(); }
+ public long size() { return cells.size(); }
@Override
public double get(TensorAddress address) { return cells.getOrDefault(address, Double.NaN); }
@@ -80,7 +80,7 @@ public class MappedTensor implements Tensor {
}
@Override
- public Builder cell(double value, int... labels) {
+ public Builder cell(double value, long... labels) {
cells.put(TensorAddress.of(labels), value);
return this;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 9a751e078e0..0c9ed769c0d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -47,13 +47,13 @@ public class MixedTensor implements Tensor {
/** Returns the size of the tensor measured in number of cells */
@Override
- public int size() { return cells.size(); }
+ public long size() { return cells.size(); }
/** Returns the value at the given address */
@Override
public double get(TensorAddress address) {
- int cellIndex = index.indexOf(address);
- Cell cell = cells.get(cellIndex);
+ long cellIndex = index.indexOf(address);
+ Cell cell = cells.get((int)cellIndex);
if (!address.equals(cell.getKey())) {
throw new IllegalStateException("Unable to find correct cell by direct index.");
}
@@ -113,7 +113,7 @@ public class MixedTensor implements Tensor {
}
/** Returns the size of dense subspaces */
- public int denseSubspaceSize() {
+ public long denseSubspaceSize() {
return index.denseSubspaceSize();
}
@@ -148,7 +148,7 @@ public class MixedTensor implements Tensor {
}
@Override
- public Tensor.Builder cell(double value, int... labels) {
+ public Tensor.Builder cell(double value, long... labels) {
throw new UnsupportedOperationException("Not implemented.");
}
@@ -179,13 +179,13 @@ public class MixedTensor implements Tensor {
index = indexBuilder.index();
}
- public int denseSubspaceSize() {
+ public long denseSubspaceSize() {
return index.denseSubspaceSize();
}
private double[] denseSubspace(TensorAddress sparsePartial) {
if (!denseSubspaceMap.containsKey(sparsePartial)) {
- denseSubspaceMap.put(sparsePartial, new double[denseSubspaceSize()]);
+ denseSubspaceMap.put(sparsePartial, new double[(int)denseSubspaceSize()]);
}
return denseSubspaceMap.get(sparsePartial);
}
@@ -193,21 +193,21 @@ public class MixedTensor implements Tensor {
@Override
public Tensor.Builder cell(TensorAddress address, double value) {
TensorAddress sparsePart = index.sparsePartialAddress(address);
- int denseOffset = index.denseOffset(address);
+ long denseOffset = index.denseOffset(address);
double[] denseSubspace = denseSubspace(sparsePart);
- denseSubspace[denseOffset] = value;
+ denseSubspace[(int)denseOffset] = value;
return this;
}
public Tensor.Builder block(TensorAddress sparsePart, double[] values) {
double[] denseSubspace = denseSubspace(sparsePart);
- System.arraycopy(values, 0, denseSubspace, 0, denseSubspaceSize());
+ System.arraycopy(values, 0, denseSubspace, 0, (int)denseSubspaceSize());
return this;
}
@Override
public MixedTensor build() {
- int count = 0;
+ long count = 0;
ImmutableList.Builder<Cell> builder = new ImmutableList.Builder<>();
for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) {
@@ -215,9 +215,9 @@ public class MixedTensor implements Tensor {
indexBuilder.put(sparsePart, count);
double[] denseSubspace = entry.getValue();
- for (int offset = 0; offset < denseSubspace.length; ++offset) {
+ for (long offset = 0; offset < denseSubspace.length; ++offset) {
TensorAddress cellAddress = index.addressOf(sparsePart, offset);
- double value = denseSubspace[offset];
+ double value = denseSubspace[(int)offset];
builder.add(new Cell(cellAddress, value));
count++;
}
@@ -239,12 +239,12 @@ public class MixedTensor implements Tensor {
public static class UnboundBuilder extends Builder {
private Map<TensorAddress, Double> cells;
- private final int[] dimensionBounds;
+ private final long[] dimensionBounds;
private UnboundBuilder(TensorType type) {
super(type);
cells = new HashMap<>();
- dimensionBounds = new int[type.dimensions().size()];
+ dimensionBounds = new long[type.dimensions().size()];
}
@Override
@@ -268,7 +268,7 @@ public class MixedTensor implements Tensor {
for (int i = 0; i < type.dimensions().size(); ++i) {
TensorType.Dimension dimension = type.dimensions().get(i);
if (dimension.isIndexed()) {
- dimensionBounds[i] = Math.max(address.intLabel(i), dimensionBounds[i]);
+ dimensionBounds[i] = Math.max(address.numericLabel(i), dimensionBounds[i]);
}
}
}
@@ -280,7 +280,7 @@ public class MixedTensor implements Tensor {
if (!dimension.isIndexed()) {
typeBuilder.mapped(dimension.name());
} else {
- int size = dimension.size().orElse(dimensionBounds[i] + 1);
+ long size = dimension.size().orElse(dimensionBounds[i] + 1);
typeBuilder.indexed(dimension.name(), size);
}
}
@@ -303,8 +303,8 @@ public class MixedTensor implements Tensor {
private final List<TensorType.Dimension> mappedDimensions;
private final List<TensorType.Dimension> indexedDimensions;
- private ImmutableMap<TensorAddress, Integer> sparseMap;
- private int denseSubspaceSize = -1;
+ private ImmutableMap<TensorAddress, Long> sparseMap;
+ private long denseSubspaceSize = -1;
private Index(TensorType type) {
this.type = type;
@@ -314,26 +314,27 @@ public class MixedTensor implements Tensor {
this.denseType = createPartialType(indexedDimensions);
}
- public int indexOf(TensorAddress address) {
+ public long indexOf(TensorAddress address) {
TensorAddress sparsePart = sparsePartialAddress(address);
- if (!sparseMap.containsKey(sparsePart)) {
+ if ( ! sparseMap.containsKey(sparsePart)) {
throw new IllegalArgumentException("Address not found");
}
- int base = sparseMap.get(sparsePart);
- int offset = denseOffset(address);
+ long base = sparseMap.get(sparsePart);
+ long offset = denseOffset(address);
return base + offset;
}
public static class Builder {
+
private final Index index;
- private final ImmutableMap.Builder<TensorAddress, Integer> builder;
+ private final ImmutableMap.Builder<TensorAddress, Long> builder;
public Builder(TensorType type) {
index = new Index(type);
builder = new ImmutableMap.Builder<>();
}
- public void put(TensorAddress address, int index) {
+ public void put(TensorAddress address, long index) {
builder.put(address, index);
}
@@ -347,7 +348,7 @@ public class MixedTensor implements Tensor {
}
}
- public int denseSubspaceSize() {
+ public long denseSubspaceSize() {
if (denseSubspaceSize == -1) {
denseSubspaceSize = 1;
for (int i = 0; i < type.dimensions().size(); ++i) {
@@ -375,13 +376,13 @@ public class MixedTensor implements Tensor {
return builder.build();
}
- private int denseOffset(TensorAddress address) {
- int innerSize = 1;
- int offset = 0;
+ private long denseOffset(TensorAddress address) {
+ long innerSize = 1;
+ long offset = 0;
for (int i = type.dimensions().size(); --i >= 0; ) {
TensorType.Dimension dimension = type.dimensions().get(i);
if (dimension.isIndexed()) {
- int label = address.intLabel(i);
+ long label = address.numericLabel(i);
offset += label * innerSize;
innerSize *= dimension.size().orElseThrow(() ->
new IllegalArgumentException("Unknown size of indexed dimension."));
@@ -390,18 +391,18 @@ public class MixedTensor implements Tensor {
return offset;
}
- private TensorAddress denseOffsetToAddress(int denseOffset) {
+ private TensorAddress denseOffsetToAddress(long denseOffset) {
if (denseOffset < 0 || denseOffset > denseSubspaceSize) {
throw new IllegalArgumentException("Offset out of bounds");
}
- int restSize = denseOffset;
- int innerSize = denseSubspaceSize;
- int[] labels = new int[indexedDimensions.size()];
+ long restSize = denseOffset;
+ long innerSize = denseSubspaceSize;
+ long[] labels = new long[indexedDimensions.size()];
for (int i = 0; i < labels.length; ++i) {
TensorType.Dimension dimension = indexedDimensions.get(i);
- int dimensionSize = dimension.size().orElseThrow(() ->
+ long dimensionSize = dimension.size().orElseThrow(() ->
new IllegalArgumentException("Unknown size of indexed dimension."));
innerSize /= dimensionSize;
@@ -411,7 +412,7 @@ public class MixedTensor implements Tensor {
return TensorAddress.of(labels);
}
- private TensorAddress addressOf(TensorAddress sparsePart, int denseOffset) {
+ private TensorAddress addressOf(TensorAddress sparsePart, long denseOffset) {
TensorAddress densePart = denseOffsetToAddress(denseOffset);
String[] labels = new String[type.dimensions().size()];
int mappedIndex = 0;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
index e3398850373..23ef0772aea 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
@@ -6,11 +6,11 @@ import com.google.common.annotations.Beta;
/**
* An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors
* dimensions.
- *
+ *
* @author bratseth
*/
-// Implementation notes:
-// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation.
+// Implementation notes:
+// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation.
// We also avoid non-essential error checking.
// - We can add support for string labels later without breaking the API
@Beta
@@ -19,7 +19,7 @@ public class PartialAddress {
// Two arrays which contains corresponding dimension=label pairs.
// The sizes of these are always equal.
private final String[] dimensionNames;
- private final int[] labels;
+ private final long[] labels;
private PartialAddress(Builder builder) {
this.dimensionNames = builder.dimensionNames;
@@ -27,36 +27,36 @@ public class PartialAddress {
builder.dimensionNames = null; // invalidate builder to safely take over array ownership
builder.labels = null;
}
-
+
/** Returns the int label of this dimension, or -1 if no label is specified for it */
- int intLabel(String dimensionName) {
+ long numericLabel(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
return labels[i];
return -1;
}
-
+
public static class Builder {
private String[] dimensionNames;
- private int[] labels;
+ private long[] labels;
private int index = 0;
-
+
public Builder(int size) {
dimensionNames = new String[size];
- labels = new int[size];
+ labels = new long[size];
}
-
- public void add(String dimensionName, int label) {
+
+ public void add(String dimensionName, long label) {
dimensionNames[index] = dimensionName;
labels[index] = label;
index++;
}
-
+
public PartialAddress build() {
return new PartialAddress(this);
}
-
+
}
-
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 1b60e01cf7e..0c948f1fbee 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -59,7 +59,7 @@ public interface Tensor {
default boolean isEmpty() { return size() == 0; }
/** Returns the number of cells in this */
- int size();
+ long size();
/** Returns the value of a cell, or NaN if this cell does not exist/have no value */
double get(TensorAddress address);
@@ -124,7 +124,7 @@ public interface Tensor {
return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
}
- static Tensor generate(TensorType type, Function<List<Integer>, Double> valueSupplier) {
+ static Tensor generate(TensorType type, Function<List<Long>, Double> valueSupplier) {
return new Generate(type, valueSupplier).evaluate();
}
@@ -333,7 +333,7 @@ public interface Tensor {
* This is for optimizations mapping between tensors where this is possible without creating a
* TensorAddress.
*/
- int getDirectIndex() { return -1; }
+ long getDirectIndex() { return -1; }
@Override
public Double getValue() { return value; }
@@ -396,7 +396,7 @@ public interface Tensor {
Builder cell(TensorAddress address, double value);
/** Add a cell */
- Builder cell(double value, int ... labels);
+ Builder cell(double value, long ... labels);
/**
* Add a cell
@@ -425,7 +425,7 @@ public interface Tensor {
return this;
}
- public CellBuilder label(String dimension, int label) {
+ public CellBuilder label(String dimension, long label) {
return label(dimension, String.valueOf(label));
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index ff1202463f2..38553497478 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -2,16 +2,10 @@
package com.yahoo.tensor;
import com.google.common.annotations.Beta;
-import com.google.common.collect.ImmutableList;
-import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
import java.util.Objects;
import java.util.Optional;
-import java.util.Set;
/**
* An immutable address to a tensor cell. This simply supplies a value to each dimension
@@ -26,8 +20,8 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return new StringTensorAddress(labels);
}
- public static TensorAddress of(int ... labels) {
- return new IntTensorAddress(labels);
+ public static TensorAddress of(long ... labels) {
+ return new NumericTensorAddress(labels);
}
/** Returns the number of labels in this */
@@ -41,14 +35,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
public abstract String label(int i);
/**
- * Returns the i'th label in this as an int.
- * Prefer this if you know that this is an integer address, but not otherwise.
+ * Returns the i'th label in this as a long.
+ * Prefer this if you know that this is a numeric address, but not otherwise.
*
* @throws IllegalArgumentException if there is no label at this index
*/
- public abstract int intLabel(int i);
+ public abstract long numericLabel(int i);
- public abstract TensorAddress withLabel(int labelIndex, int label);
+ public abstract TensorAddress withLabel(int labelIndex, long label);
public final boolean isEmpty() { return size() == 0; }
@@ -110,17 +104,17 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
public String label(int i) { return labels[i]; }
@Override
- public int intLabel(int i) {
+ public long numericLabel(int i) {
try {
- return Integer.parseInt(labels[i]);
+ return Long.parseLong(labels[i]);
}
catch (NumberFormatException e) {
- throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i);
+ throw new IllegalArgumentException("Expected a long label in " + this + " at position " + i);
}
}
@Override
- public TensorAddress withLabel(int index, int label) {
+ public TensorAddress withLabel(int index, long label) {
String[] labels = Arrays.copyOf(this.labels, this.labels.length);
labels[index] = String.valueOf(label);
return new StringTensorAddress(labels);
@@ -133,11 +127,11 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
}
- private static final class IntTensorAddress extends TensorAddress {
+ private static final class NumericTensorAddress extends TensorAddress {
- private final int[] labels;
+ private final long[] labels;
- private IntTensorAddress(int[] labels) {
+ private NumericTensorAddress(long[] labels) {
this.labels = Arrays.copyOf(labels, labels.length);
}
@@ -148,13 +142,13 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
public String label(int i) { return String.valueOf(labels[i]); }
@Override
- public int intLabel(int i) { return labels[i]; }
+ public long numericLabel(int i) { return labels[i]; }
@Override
- public TensorAddress withLabel(int index, int label) {
- int[] labels = Arrays.copyOf(this.labels, this.labels.length);
+ public TensorAddress withLabel(int index, long label) {
+ long[] labels = Arrays.copyOf(this.labels, this.labels.length);
labels[index] = label;
- return new IntTensorAddress(labels);
+ return new NumericTensorAddress(labels);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 914d853aeca..b396f831de0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -139,7 +139,7 @@ public class TensorType {
public final String name() { return name; }
/** Returns the size of this dimension if it is bound, empty otherwise */
- public abstract Optional<Integer> size();
+ public abstract Optional<Long> size();
public abstract Type type();
@@ -189,7 +189,7 @@ public class TensorType {
return this.name.compareTo(other.name);
}
- public static Dimension indexed(String name, int size) {
+ public static Dimension indexed(String name, long size) {
return new IndexedBoundDimension(name, size);
}
@@ -197,17 +197,19 @@ public class TensorType {
public static class IndexedBoundDimension extends TensorType.Dimension {
- private final Integer size;
+ private final Long size;
- private IndexedBoundDimension(String name, int size) {
+ private IndexedBoundDimension(String name, long size) {
super(name);
if (size < 1)
throw new IllegalArgumentException("Size of bound dimension '" + name + "' must be at least 1");
+ if (size > Integer.MAX_VALUE)
+ throw new IllegalArgumentException("Size of bound dimension '" + name + "' cannot be larger than " + Integer.MAX_VALUE);
this.size = size;
}
@Override
- public Optional<Integer> size() { return Optional.of(size); }
+ public Optional<Long> size() { return Optional.of(size); }
@Override
public Type type() { return Type.indexedBound; }
@@ -248,7 +250,7 @@ public class TensorType {
}
@Override
- public Optional<Integer> size() { return Optional.empty(); }
+ public Optional<Long> size() { return Optional.empty(); }
@Override
public Type type() { return Type.indexedUnbound; }
@@ -269,7 +271,7 @@ public class TensorType {
}
@Override
- public Optional<Integer> size() { return Optional.empty(); }
+ public Optional<Long> size() { return Optional.empty(); }
@Override
public Type type() { return Type.mapped; }
@@ -357,7 +359,7 @@ public class TensorType {
*
* @throws IllegalArgumentException if the dimension is already present
*/
- public Builder indexed(String name, int size) { return add(new IndexedBoundDimension(name, size)); }
+ public Builder indexed(String name, long size) { return add(new IndexedBoundDimension(name, size)); }
/**
* Adds an unbound indexed dimension to this
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index faa0ca36cb6..d4affe0ef9b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -67,7 +67,7 @@ public class Concat extends PrimitiveTensorFunction {
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
- int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
+ long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
int[] aToIndexes = mapIndexes(a.type(), concatType);
int[] bToIndexes = mapIndexes(b.type(), concatType);
concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
@@ -75,7 +75,7 @@ public class Concat extends PrimitiveTensorFunction {
return builder.build();
}
- private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
+ private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
@@ -129,8 +129,8 @@ public class Concat extends PrimitiveTensorFunction {
DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
for (int i = 0; i < concatSizes.dimensions(); i++) {
String currentDimension = concatType.dimensions().get(i).name();
- int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0);
- int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0);
+ long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
+ long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
if (currentDimension.equals(concatDimension))
concatSizes.set(i, aSize + bSize);
else if (aSize != 0 && bSize != 0 && aSize!=bSize )
@@ -148,8 +148,8 @@ public class Concat extends PrimitiveTensorFunction {
* (in some other dimension than the concat dimension)
*/
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
- TensorType concatType, int concatOffset, String concatDimension) {
- int[] combinedLabels = new int[concatType.dimensions().size()];
+ TensorType concatType, long concatOffset, String concatDimension) {
+ long[] combinedLabels = new long[concatType.dimensions().size()];
Arrays.fill(combinedLabels, -1);
int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
@@ -179,15 +179,15 @@ public class Concat extends PrimitiveTensorFunction {
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private boolean mapContent(TensorAddress from, int[] to, int[] indexMap, int concatDimension, int concatOffset) {
+ private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
for (int i = 0; i < from.size(); i++) {
int toIndex = indexMap[i];
if (concatDimension == toIndex) {
- to[toIndex] = from.intLabel(i) + concatOffset;
+ to[toIndex] = from.numericLabel(i) + concatOffset;
}
else {
- if (to[toIndex] != -1 && to[toIndex] != from.intLabel(i)) return false;
- to[toIndex] = from.intLabel(i);
+ if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
+ to[toIndex] = from.numericLabel(i);
}
}
return true;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index c75d8ee4753..653be8dacf0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -17,7 +17,7 @@ import java.util.stream.Stream;
public class Diag extends CompositeTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> diagFunction;
+ private final Function<List<Long>, Double> diagFunction;
public Diag(TensorType type) {
this.type = type;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index e42d25197e2..ef2770c04f5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -22,17 +22,17 @@ import java.util.function.Function;
public class Generate extends PrimitiveTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> generator;
+ private final Function<List<Long>, Double> generator;
/**
* Creates a generated tensor
*
* @param type the type of the tensor
- * @param generator the function generating values from a list of ints specifying the indexes of the
+ * @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, Function<List<Integer>, Double> generator) {
+ public Generate(TensorType type, Function<List<Long>, Double> generator) {
Objects.requireNonNull(type, "The argument tensor type cannot be null");
Objects.requireNonNull(generator, "The argument function cannot be null");
validateType(type);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index ff887e3e9a6..174a8e4c435 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -56,8 +56,8 @@ public class Join extends PrimitiveTensorFunction {
if (aDim.name().equals(bDim.name())) { // include
if (aDim.isIndexed() && bDim.isIndexed()) {
if (aDim.size().isPresent() || bDim.size().isPresent())
- typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
- bDim.size().orElse(Integer.MAX_VALUE)));
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Long.MAX_VALUE),
+ bDim.size().orElse(Long.MAX_VALUE)));
else
typeBuilder.indexed(aDim.name());
}
@@ -118,11 +118,11 @@ public class Join extends PrimitiveTensorFunction {
}
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
- int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
+ long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
- IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build());
- for (int i = 0; i < joinedLength; i++)
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
+ for (int i = 0; i < joinedRank; i++)
builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
return builder.build();
}
@@ -169,10 +169,10 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
- private void joinSubspaces(Iterator<Double> subspace, int subspaceSize,
- Iterator<Tensor.Cell> superspace, int superspaceSize,
+ private void joinSubspaces(Iterator<Double> subspace, long subspaceSize,
+ Iterator<Tensor.Cell> superspace, long superspaceSize,
boolean reversedArgumentOrder, IndexedTensor.Builder builder) {
- int joinedLength = Math.min(subspaceSize, superspaceSize);
+ long joinedLength = Math.min(subspaceSize, superspaceSize);
if (reversedArgumentOrder) {
for (int i = 0; i < joinedLength; i++) {
Tensor.Cell supercell = superspace.next();
@@ -281,7 +281,7 @@ public class Join extends PrimitiveTensorFunction {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
if (retainDimensions.contains(addressType.dimensions().get(i).name()))
- builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
+ builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i));
return builder.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index a56f82b026a..8e7f4e4c773 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -18,7 +18,7 @@ import java.util.stream.Stream;
public class Range extends CompositeTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> rangeFunction;
+ private final Function<List<Long>, Double> rangeFunction;
public Range(TensorType type) {
this.type = type;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index fb5029fbfd6..f1dadba2a29 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -14,8 +14,8 @@ import java.util.stream.Collectors;
/**
* Factory of scalar Java functions.
* The purpose of this is to embellish anonymous functions with a runtime type
- * such that they can be inspected and will return a parseable toString.
- *
+ * such that they can be inspected and will return a parsable toString.
+ *
* @author bratseth
*/
@Beta
@@ -31,9 +31,9 @@ public class ScalarFunctions {
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
- public static Function<List<Integer>, Double> random() { return new Random(); }
- public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
- public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
+ public static Function<List<Long>, Double> random() { return new Random(); }
+ public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
+ public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
// Binary operators -----------------------------------------------------------------------------
@@ -60,7 +60,7 @@ public class ScalarFunctions {
public static class Multiply implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left * right; }
+ public double applyAsDouble(double left, double right) { return left * right; }
@Override
public String toString() { return "f(a,b)(a * b)"; }
}
@@ -100,26 +100,26 @@ public class ScalarFunctions {
// Variable-length operators -----------------------------------------------------------------------------
- public static class EqualElements implements Function<List<Integer>, Double> {
- private final ImmutableList<String> argumentNames;
+ public static class EqualElements implements Function<List<Long>, Double> {
+ private final ImmutableList<String> argumentNames;
private EqualElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@Override
- public Double apply(List<Integer> values) {
+ public Double apply(List<Long> values) {
if (values.isEmpty()) return 1.0;
- for (Integer value : values)
+ for (Long value : values)
if ( ! value.equals(values.get(0)))
return 0.0;
return 1.0;
}
@Override
- public String toString() {
+ public String toString() {
if (argumentNames.size() == 0) return "1";
if (argumentNames.size() == 1) return "1";
if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1);
-
+
StringBuilder b = new StringBuilder();
for (int i = 0; i < argumentNames.size() -1; i++) {
b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")");
@@ -130,25 +130,25 @@ public class ScalarFunctions {
}
}
- public static class Random implements Function<List<Integer>, Double> {
+ public static class Random implements Function<List<Long>, Double> {
@Override
- public Double apply(List<Integer> values) {
+ public Double apply(List<Long> values) {
return ThreadLocalRandom.current().nextDouble();
}
@Override
public String toString() { return "random"; }
}
- public static class SumElements implements Function<List<Integer>, Double> {
+ public static class SumElements implements Function<List<Long>, Double> {
private final ImmutableList<String> argumentNames;
private SumElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@Override
- public Double apply(List<Integer> values) {
- int sum = 0;
- for (Integer value : values)
+ public Double apply(List<Long> values) {
+ long sum = 0;
+ for (Long value : values)
sum += value;
return (double)sum;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index aabb53d1c67..1e830bac461 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -36,7 +36,7 @@ public class DenseBinaryFormat implements BinaryFormat {
buffer.putInt1_4Bytes(tensor.type().dimensions().size());
for (int i = 0; i < tensor.type().dimensions().size(); i++) {
buffer.putUtf8String(tensor.type().dimensions().get(i).name());
- buffer.putInt1_4Bytes(tensor.dimensionSizes().size(i));
+ buffer.putInt1_4Bytes((int)tensor.dimensionSizes().size(i)); // XXX: Size truncation
}
}
@@ -71,7 +71,7 @@ public class DenseBinaryFormat implements BinaryFormat {
int dimensionCount = buffer.getInt1_4Bytes();
TensorType.Builder builder = new TensorType.Builder();
for (int i = 0; i < dimensionCount; i++)
- builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes());
+ builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation
return builder.build();
}
@@ -84,7 +84,7 @@ public class DenseBinaryFormat implements BinaryFormat {
}
private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
- for (int i = 0; i < sizes.totalSize(); i++)
+ for (long i = 0; i < sizes.totalSize(); i++)
builder.cellByDirectIndex(i, buffer.getDouble());
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index 61dfa888567..34e6cccf0f0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -46,16 +46,16 @@ class MixedBinaryFormat implements BinaryFormat {
buffer.putInt1_4Bytes(denseDimensions.size());
for (TensorType.Dimension dimension : denseDimensions) {
buffer.putUtf8String(dimension.name());
- buffer.putInt1_4Bytes(dimension.size().orElseThrow(() ->
- new IllegalArgumentException("Unknown size of indexed dimension.")));
+ buffer.putInt1_4Bytes((int)dimension.size().orElseThrow(() ->
+ new IllegalArgumentException("Unknown size of indexed dimension.")).longValue()); // XXX: Size truncation
}
}
private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor) {
List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
- int denseSubspaceSize = tensor.denseSubspaceSize();
+ long denseSubspaceSize = tensor.denseSubspaceSize();
if (sparseDimensions.size() > 0) {
- buffer.putInt1_4Bytes(tensor.size() / denseSubspaceSize);
+ buffer.putInt1_4Bytes((int)(tensor.size() / denseSubspaceSize)); // XXX: Size truncation
}
Iterator<Tensor.Cell> cellIterator = tensor.cellIterator();
while (cellIterator.hasNext()) {
@@ -98,7 +98,7 @@ class MixedBinaryFormat implements BinaryFormat {
}
int numIndexedDimensions = buffer.getInt1_4Bytes();
for (int i = 0; i < numIndexedDimensions; ++i) {
- builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes());
+ builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation
}
return builder.build();
}
@@ -106,21 +106,21 @@ class MixedBinaryFormat implements BinaryFormat {
private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) {
List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
TensorType sparseType = MixedTensor.createPartialType(sparseDimensions);
- int denseSubspaceSize = builder.denseSubspaceSize();
+ long denseSubspaceSize = builder.denseSubspaceSize();
int numBlocks = 1;
if (sparseDimensions.size() > 0) {
numBlocks = buffer.getInt1_4Bytes();
}
- double[] denseSubspace = new double[denseSubspaceSize];
+ double[] denseSubspace = new double[(int)denseSubspaceSize];
for (int i = 0; i < numBlocks; ++i) {
TensorAddress.Builder sparseAddress = new TensorAddress.Builder(sparseType);
for (TensorType.Dimension sparseDimension : sparseDimensions) {
sparseAddress.add(sparseDimension.name(), buffer.getUtf8String());
}
- for (int denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) {
- denseSubspace[denseOffset] = buffer.getDouble();
+ for (long denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) {
+ denseSubspace[(int)denseOffset] = buffer.getDouble();
}
builder.block(sparseAddress.build(), denseSubspace);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 19969506eca..0cd3ff77aca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -3,13 +3,14 @@ package com.yahoo.tensor.serialization;
import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.text.Utf8;
-import java.util.*;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
/**
* Implementation of a sparse binary format for a tensor on the form:
@@ -39,7 +40,7 @@ class SparseBinaryFormat implements BinaryFormat {
}
private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
- buffer.putInt1_4Bytes(tensor.size());
+ buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
encodeAddress(buffer, cell.getKey());
@@ -79,8 +80,8 @@ class SparseBinaryFormat implements BinaryFormat {
}
private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
- int numCells = buffer.getInt1_4Bytes();
- for (int i = 0; i < numCells; ++i) {
+ long numCells = buffer.getInt1_4Bytes(); // XXX: Size truncation
+ for (long i = 0; i < numCells; ++i) {
Tensor.Builder.CellBuilder cellBuilder = builder.cell();
decodeAddress(buffer, cellBuilder, type);
cellBuilder.value(buffer.getDouble());
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 693b0f09351..38a8329bff1 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -4,7 +4,6 @@ package com.yahoo.tensor;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Argmax;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Reduce;
@@ -12,14 +11,12 @@ import com.yahoo.tensor.functions.TensorFunction;
import org.junit.Test;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
-import java.util.stream.Collectors;
-import static org.junit.Assert.assertEquals;
import static com.yahoo.tensor.TensorType.Dimension.Type;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -99,7 +96,7 @@ public class TensorTestCase {
ImmutableList.of("y", "x")));
assertEquals(Tensor.from("{ {x:0,y:0}:0, {x:0,y:1}:0, {x:1,y:0}:0, {x:1,y:1}:1, {x:2,y:0}:0, {x:2,y:1}:2, }"),
Tensor.generate(new TensorType.Builder().indexed("x", 3).indexed("y", 2).build(),
- (List<Integer> indexes) -> (double)indexes.get(0)*indexes.get(1)));
+ (List<Long> indexes) -> (double)indexes.get(0)*indexes.get(1)));
assertEquals(Tensor.from("{ {x:0,y:0,z:0}:0, {x:0,y:1,z:0}:1, {x:1,y:0,z:0}:1, {x:1,y:1,z:0}:2, {x:2,y:0,z:0}:2, {x:2,y:1,z:0}:3, "+
" {x:0,y:0,z:1}:1, {x:0,y:1,z:1}:2, {x:1,y:0,z:1}:2, {x:1,y:1,z:1}:3, {x:2,y:0,z:1}:3, {x:2,y:1,z:1}:4 }"),
Tensor.range(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build()));