summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-05 16:44:53 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-05 16:44:53 +0100
commite25d723262ed8702be60ade30d87c2da75fbadf2 (patch)
treefbfb8cc3327b9abab638fc513cb6fd93b69d8ab9 /vespajlib
parentfd22e7e254528bea682a2e585f5cbb1fc625c93d (diff)
Type DimensionSizes
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java71
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java251
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java22
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java4
8 files changed, 220 insertions, 164 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
new file mode 100644
index 00000000000..76340bb7d8f
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -0,0 +1,71 @@
+package com.yahoo.tensor;
+
+import java.util.Arrays;
+
+/**
+ * The sizes of a set of dimensions.
+ *
+ * @author bratseth
+ */
+public final class DimensionSizes {
+
+ private final int[] sizes;
+
+ private DimensionSizes(Builder builder) {
+ this.sizes = builder.sizes;
+ builder.sizes = null; // invalidate builder to avoid copying the array
+ }
+
+ /**
+ * Returns the length of this in the nth dimension
+ *
+ * @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one
+ */
+ public int size(int dimensionIndex) { return sizes[dimensionIndex]; }
+
+ /** Returns the number of dimensions this provides the size of */
+ public int dimensions() { return sizes.length; }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == this) return true;
+ if (!(o instanceof DimensionSizes)) return false;
+ return Arrays.equals(((DimensionSizes) o).sizes, this.sizes);
+ }
+
+ @Override
+ public int hashCode() { return Arrays.hashCode(sizes); }
+
+ /**
+ * Builder of a set of dimension sizes.
+ * Dimensions whose size is not set before building will get size 0.
+ */
+ public final static class Builder {
+
+ private int[] sizes;
+
+ public Builder(int dimensions) {
+ this.sizes = new int[dimensions];
+ }
+
+ public Builder set(int dimensionIndex, int size) {
+ sizes[dimensionIndex] = size;
+ return this;
+ }
+
+ /**
+ * Returns the length of this in the nth dimension
+ *
+ * @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one
+ */
+ public int size(int dimensionIndex) { return sizes[dimensionIndex]; }
+
+ /** Returns the number of dimensions this provides the size of */
+ public int dimensions() { return sizes.length; }
+
+ /** Build this. This builder becomes invalid after calling this. */
+ public DimensionSizes build() { return new DimensionSizes(this); }
+
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 2c3cb6ebde2..d69cf65ee8d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -27,11 +27,11 @@ public class IndexedTensor implements Tensor {
private final TensorType type;
/** The sizes of the dimensions of this in the order of the dimensions of the type */
- private final int[] dimensionSizes;
+ private final DimensionSizes dimensionSizes;
private final double[] values;
- private IndexedTensor(TensorType type, int[] dimensionSizes, double[] values) {
+ private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) {
this.type = type;
this.dimensionSizes = dimensionSizes;
this.values = values;
@@ -68,12 +68,12 @@ public class IndexedTensor implements Tensor {
* other iterator.
*
* @param dimensions the names of the dimensions of the superspace
- * @param dimensionSizes the size of each dimension in the space we are returning values for, containing
- * one value per dimension of this tensor (in order). Each size may be the same or smaller
- * than the corresponding size of this tensor
+ * @param sizes the size of each dimension in the space we are returning values for, containing
+ * one value per dimension of this tensor (in order). Each size may be the same or smaller
+ * than the corresponding size of this tensor
*/
- public Iterator<SubspaceIterator> subspaceIterator(Set<String> dimensions, int[] dimensionSizes) {
- return new SuperspaceIterator(dimensions, dimensionSizes);
+ public Iterator<SubspaceIterator> subspaceIterator(Set<String> dimensions, DimensionSizes sizes) {
+ return new SuperspaceIterator(dimensions, sizes);
}
/** Returns a subspace iterator having the sizes of the dimensions of this tensor */
@@ -81,12 +81,6 @@ public class IndexedTensor implements Tensor {
return subspaceIterator(dimensions, dimensionSizes);
}
- /** Returns whether the dimensions sizes of this are equal to the given sizes */
- // TODO: Replace by returning immutable sizes when DimensionSizes are a class
- public boolean dimensionSizesAre(int[] dimensionSizes) {
- return Arrays.equals(dimensionSizes, this.dimensionSizes);
- }
-
/**
* Returns the value at the given indexes
*
@@ -110,54 +104,44 @@ public class IndexedTensor implements Tensor {
}
}
- double get(int valueIndex) { return values[valueIndex]; }
+ private double get(int valueIndex) { return values[valueIndex]; }
- /** Returns the value at these indexes */
- private double get(Indexes indexes) {
- return values[toValueIndex(indexes.indexesForReading(), dimensionSizes)];
- }
-
- private static int toValueIndex(int[] indexes, int[] dimensionSizes) {
+ private static int toValueIndex(int[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
if (indexes.length == 0) return 0; // for speed
int valueIndex = 0;
for (int i = 0; i < indexes.length; i++)
- valueIndex += productOfDimensionsAfter(i, dimensionSizes) * indexes[i];
+ valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i];
return valueIndex;
}
- private static int toValueIndex(TensorAddress address, int[] dimensionSizes) {
+ private static int toValueIndex(TensorAddress address, DimensionSizes sizes) {
if (address.isEmpty()) return 0;
int valueIndex = 0;
for (int i = 0; i < address.size(); i++)
- valueIndex += productOfDimensionsAfter(i, dimensionSizes) * address.intLabel(i);
+ valueIndex += productOfDimensionsAfter(i, sizes) * address.intLabel(i);
return valueIndex;
}
- private static int productOfDimensionsAfter(int afterIndex, int[] dimensionSizes) {
+ private static int productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
int product = 1;
- for (int i = afterIndex + 1; i < dimensionSizes.length; i++)
- product *= dimensionSizes[i];
+ for (int i = afterIndex + 1; i < sizes.dimensions(); i++)
+ product *= sizes.size(i);
return product;
}
@Override
public TensorType type() { return type; }
- /**
- * Returns the length of this in the nth dimension
- *
- * @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one
- */
- public int size(int dimension) {
- return dimensionSizes[dimension];
+ public DimensionSizes dimensionSizes() {
+ return dimensionSizes;
}
@Override
public Map<TensorAddress, Double> cells() {
- if (dimensionSizes.length == 0)
+ if (dimensionSizes.dimensions() == 0)
return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]);
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
@@ -201,27 +185,27 @@ public class IndexedTensor implements Tensor {
* and, agree with the type size information when specified in the type.
* If sizes are completely specified in the type this size information is redundant.
*/
- public static Builder of(TensorType type, int[] dimensionSizes) {
+ public static Builder of(TensorType type, DimensionSizes sizes) {
// validate
- if (dimensionSizes.length != type.dimensions().size())
- throw new IllegalArgumentException(dimensionSizes.length + " is the wrong number of dimension sizes " +
- " for " + type);
- for (int i = 0; i < dimensionSizes.length; i++ ) {
+ if (sizes.dimensions() != type.dimensions().size())
+ throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
+ "for " + type);
+ for (int i = 0; i < sizes.dimensions(); i++ ) {
Optional<Integer> size = type.dimensions().get(i).size();
- if (size.isPresent() && size.get() < dimensionSizes[i])
- throw new IllegalArgumentException("Size of " + type.dimensions() + " is " + dimensionSizes[i] +
+ if (size.isPresent() && size.get() < sizes.size(i))
+ throw new IllegalArgumentException("Size of " + type.dimensions() + " is " + sizes.size(i) +
" but cannot be larger than " + size.get());
}
- return new BoundBuilder(type, dimensionSizes);
+ return new BoundBuilder(type, sizes);
}
public abstract Builder cell(double value, int ... indexes);
- protected double[] arrayFor(int[] dimensionSizes) {
+ protected double[] arrayFor(DimensionSizes sizes) {
int productSize = 1;
- for (int dimensionSize : dimensionSizes)
- productSize *= dimensionSize;
+ for (int i = 0; i < sizes.dimensions(); i++ )
+ productSize *= sizes.size(i);
return new double[productSize];
}
@@ -236,32 +220,32 @@ public class IndexedTensor implements Tensor {
/** A bound builder can create the double array directly */
private static class BoundBuilder extends Builder {
- private int[] dimensionSizes;
+ private DimensionSizes sizes;
private double[] values;
private BoundBuilder(TensorType type) {
this(type, dimensionSizesOf(type));
}
- private BoundBuilder(TensorType type, int[] dimensionSizes) {
+ public static DimensionSizes dimensionSizesOf(TensorType type) {
+ DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
+ for (int i = 0; i < type.dimensions().size(); i++)
+ b.set(i, type.dimensions().get(i).size().get());
+ return b.build();
+ }
+
+ private BoundBuilder(TensorType type, DimensionSizes sizes) {
super(type);
- if ( dimensionSizes.length != type.dimensions().size())
+ if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
- this.dimensionSizes = dimensionSizes;
- values = arrayFor(dimensionSizes);
+ this.sizes = sizes;
+ values = arrayFor(sizes);
Arrays.fill(values, Double.NaN);
}
- private static int[] dimensionSizesOf(TensorType type) {
- int[] dimensionSizes = new int[type.dimensions().size()];
- for (int i = 0; i < type.dimensions().size(); i++)
- dimensionSizes[i] = type.dimensions().get(i).size().get();
- return dimensionSizes;
- }
-
@Override
public BoundBuilder cell(double value, int ... indexes) {
- values[toValueIndex(indexes, dimensionSizes)] = value;
+ values[toValueIndex(indexes, sizes)] = value;
return this;
}
@@ -272,7 +256,7 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(TensorAddress address, double value) {
- values[toValueIndex(address, dimensionSizes)] = value;
+ values[toValueIndex(address, sizes)] = value;
return this;
}
@@ -282,9 +266,9 @@ public class IndexedTensor implements Tensor {
// NaN's don't get lost so leaving them in place should be quite benign
if (values.length == 1 && Double.isNaN(values[0]))
values = new double[0];
- IndexedTensor tensor = new IndexedTensor(type, dimensionSizes, values);
+ IndexedTensor tensor = new IndexedTensor(type, sizes, values);
// prevent further modification
- dimensionSizes = null;
+ sizes = null;
values = null;
return tensor;
}
@@ -320,23 +304,23 @@ public class IndexedTensor implements Tensor {
@Override
public IndexedTensor build() {
if (firstDimension == null) // empty
- return new IndexedTensor(type, new int[type.dimensions().size()], new double[] {});
+ return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {});
if (type.dimensions().isEmpty()) // single number
- return new IndexedTensor(type, new int[type.dimensions().size()], new double[] {(Double) firstDimension.get(0) });
+ return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
- int[] dimensionSizes = findDimensionSizes(firstDimension);
+ DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
double[] values = arrayFor(dimensionSizes);
fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
- private int[] findDimensionSizes(List<Object> firstDimension) {
+ private DimensionSizes findDimensionSizes(List<Object> firstDimension) {
List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
findDimensionSizes(0, dimensionSizeList, firstDimension);
- int[] dimensionSizes = new int[type.dimensions().size()]; // may be longer than the list but that's correct
- for (int i = 0; i < dimensionSizes.length; i++)
- dimensionSizes[i] = dimensionSizeList.get(i);
- return dimensionSizes;
+ DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct
+ for (int i = 0; i < b.dimensions(); i++)
+ b.set(i, dimensionSizeList.get(i));
+ return b.build();
}
@SuppressWarnings("unchecked")
@@ -354,12 +338,12 @@ public class IndexedTensor implements Tensor {
@SuppressWarnings("unchecked")
private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension,
- int[] dimensionSizes, double[] values) {
- if (currentDimensionIndex < dimensionSizes.length - 1) { // recurse to next dimension
+ DimensionSizes sizes, double[] values) {
+ if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension
for (int i = 0; i < currentDimension.size(); i++)
fillValues(currentDimensionIndex + 1,
- offset + productOfDimensionsAfter(currentDimensionIndex, dimensionSizes) * i,
- (List<Object>) currentDimension.get(i), dimensionSizes, values);
+ offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i,
+ (List<Object>) currentDimension.get(i), sizes, values);
} else { // last dimension - fill values
for (int i = 0; i < currentDimension.size(); i++)
values[offset + i] = (double) currentDimension.get(i);
@@ -477,12 +461,12 @@ public class IndexedTensor implements Tensor {
* The sizes of the space we'll return values of, one value for each dimension of this tensor,
* which may be equal to or smaller than the sizes of this tensor
*/
- private final int[] iterateDimensionSizes;
+ private final DimensionSizes iterateSizes;
private int count = 0;
- private SuperspaceIterator(Set<String> superdimensionNames, int[] iterateDimensionSizes) {
- this.iterateDimensionSizes = iterateDimensionSizes;
+ private SuperspaceIterator(Set<String> superdimensionNames, DimensionSizes iterateSizes) {
+ this.iterateSizes = iterateSizes;
List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator
subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length)
@@ -493,7 +477,7 @@ public class IndexedTensor implements Tensor {
subdimensionIndexes.add(i);
}
- superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, superdimensionIndexes);
+ superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, superdimensionIndexes);
}
@Override
@@ -506,7 +490,7 @@ public class IndexedTensor implements Tensor {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + superindexes);
count++;
superindexes.next();
- return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), iterateDimensionSizes);
+ return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), iterateSizes);
}
}
@@ -525,7 +509,7 @@ public class IndexedTensor implements Tensor {
*/
private final List<Integer> iterateDimensions;
private final int[] address;
- private final int[] iterateDimensionSizes;
+ private final DimensionSizes iterateSizes;
private Indexes indexes;
private int count = 0;
@@ -545,11 +529,11 @@ public class IndexedTensor implements Tensor {
* This is treated as immutable.
* @param address the address of the first cell of this subspace.
*/
- private SubspaceIterator(List<Integer> iterateDimensions, int[] address, int[] iterateDimensionSizes) {
+ private SubspaceIterator(List<Integer> iterateDimensions, int[] address, DimensionSizes iterateSizes) {
this.iterateDimensions = iterateDimensions;
this.address = address;
- this.iterateDimensionSizes = iterateDimensionSizes;
- this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, iterateDimensions, address);
+ this.iterateSizes = iterateSizes;
+ this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
reusedCell = new LazyCell(indexes, Double.NaN);
}
@@ -564,7 +548,7 @@ public class IndexedTensor implements Tensor {
/** Rewind this iterator to the first element */
public void reset() {
this.count = 0;
- this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, iterateDimensions, address);
+ this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
}
@Override
@@ -617,54 +601,54 @@ public class IndexedTensor implements Tensor {
*/
public abstract static class Indexes {
- private final int[] sourceDimensionSizes;
+ private final DimensionSizes sourceSizes;
- private final int[] iterationDimensionSizes;
+ private final DimensionSizes iterationSizes;
protected final int[] indexes;
- public static Indexes of(int[] dimensionSizes) {
- return of(dimensionSizes, dimensionSizes);
+ public static Indexes of(DimensionSizes sizes) {
+ return of(sizes, sizes);
}
- private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes) {
- return of(sourceDimensionSizes, iterateDimensionSizes, completeIterationOrder(iterateDimensionSizes.length));
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes) {
+ return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()));
}
- private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int size) {
- return of(sourceDimensionSizes, iterateDimensionSizes, completeIterationOrder(iterateDimensionSizes.length), size);
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int size) {
+ return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size);
}
- private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions) {
- return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, computeSize(iterateDimensionSizes, iterateDimensions));
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions) {
+ return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions));
}
- private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int size) {
- return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, new int[iterateDimensionSizes.length], size);
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int size) {
+ return of(sourceSizes, iterateSizes, iterateDimensions, new int[iterateSizes.dimensions()], size);
}
- private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes) {
- return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, initialIndexes, computeSize(iterateDimensionSizes, iterateDimensions));
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes) {
+ return of(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, computeSize(iterateSizes, iterateDimensions));
}
- private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
if (size == 0) {
- return new EmptyIndexes(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); // we're told explicitly there are truly no values available
+ return new EmptyIndexes(sourceSizes, iterateSizes, initialIndexes); // we're told explicitly there are truly no values available
}
else if (size == 1) {
- return new SingleValueIndexes(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); // with no (iterating) dimensions, we still return one value, not zero
+ return new SingleValueIndexes(sourceSizes, iterateSizes, initialIndexes); // with no (iterating) dimensions, we still return one value, not zero
}
else if (iterateDimensions.size() == 1) {
- if (Arrays.equals(sourceDimensionSizes, iterateDimensionSizes))
- return new EqualSizeSingleDimensionIndexes(sourceDimensionSizes, iterateDimensions.get(0), initialIndexes, size);
+ if (sourceSizes.equals(iterateSizes))
+ return new EqualSizeSingleDimensionIndexes(sourceSizes, iterateDimensions.get(0), initialIndexes, size);
else
- return new SingleDimensionIndexes(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions.get(0), initialIndexes, size); // optimization
+ return new SingleDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions.get(0), initialIndexes, size); // optimization
}
else {
- if (Arrays.equals(sourceDimensionSizes, iterateDimensionSizes))
- return new EqualSizeMultiDimensionIndexes(sourceDimensionSizes, iterateDimensions, initialIndexes, size);
+ if (sourceSizes.equals(iterateSizes))
+ return new EqualSizeMultiDimensionIndexes(sourceSizes, iterateDimensions, initialIndexes, size);
else
- return new MultiDimensionIndexes(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, initialIndexes, size);
+ return new MultiDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, size);
}
}
@@ -675,16 +659,16 @@ public class IndexedTensor implements Tensor {
return iterationDimensions;
}
- private Indexes(int[] sourceDimensionSizes, int[] iterationDimensionSizes, int[] indexes) {
- this.sourceDimensionSizes = sourceDimensionSizes;
- this.iterationDimensionSizes = iterationDimensionSizes;
+ private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) {
+ this.sourceSizes = sourceSizes;
+ this.iterationSizes = iterationSizes;
this.indexes = indexes;
}
- private static int computeSize(int[] dimensionSizes, List<Integer> iterateDimensions) {
+ private static int computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) {
int size = 1;
for (int iterateDimension : iterateDimensions)
- size *= dimensionSizes[iterateDimension];
+ size *= sizes.size(iterateDimension);
return size;
}
@@ -701,13 +685,12 @@ public class IndexedTensor implements Tensor {
public int[] indexesForReading() { return indexes; }
int toSourceValueIndex() {
- return IndexedTensor.toValueIndex(indexes, sourceDimensionSizes);
+ return IndexedTensor.toValueIndex(indexes, sourceSizes);
}
- int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationDimensionSizes); }
+ int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); }
- /** Returns the dimension sizes of this. Do not modify the return value */
- int[] dimensionSizes() { return iterationDimensionSizes; }
+ DimensionSizes dimensionSizes() { return iterationSizes; }
/** Returns an immutable list containing a copy of the indexes in this */
public List<Integer> toList() {
@@ -730,8 +713,8 @@ public class IndexedTensor implements Tensor {
private final static class EmptyIndexes extends Indexes {
- private EmptyIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) {
- super(sourceDimensionSizes, iterateDimensionSizes, indexes);
+ private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) {
+ super(sourceSizes, iterateSizes, indexes);
}
@Override
@@ -744,8 +727,8 @@ public class IndexedTensor implements Tensor {
private final static class SingleValueIndexes extends Indexes {
- private SingleValueIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) {
- super(sourceDimensionSizes, iterateDimensionSizes, indexes);
+ private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) {
+ super(sourceSizes, iterateSizes, indexes);
}
@Override
@@ -762,8 +745,8 @@ public class IndexedTensor implements Tensor {
private final List<Integer> iterateDimensions;
- private MultiDimensionIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
- super(sourceDimensionSizes, iterateDimensionSizes, initialIndexes);
+ private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimensions = iterateDimensions;
this.size = size;
@@ -786,7 +769,7 @@ public class IndexedTensor implements Tensor {
@Override
public void next() {
int iterateDimensionsIndex = 0;
- while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes()[iterateDimensions.get(iterateDimensionsIndex)]) {
+ while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes().size(iterateDimensions.get(iterateDimensionsIndex))) {
indexes[iterateDimensions.get(iterateDimensionsIndex)] = 0; // carry over
iterateDimensionsIndex++;
}
@@ -800,8 +783,8 @@ public class IndexedTensor implements Tensor {
private int lastComputedSourceValueIndex = -1;
- private EqualSizeMultiDimensionIndexes(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
- super(dimensionSizes, dimensionSizes, iterateDimensions, initialIndexes, size);
+ private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
+ super(sizes, sizes, iterateDimensions, initialIndexes, size);
}
int toSourceValueIndex() {
@@ -826,18 +809,18 @@ public class IndexedTensor implements Tensor {
/** The iteration step in the value index space */
private final int sourceStep, iterationStep;
- private SingleDimensionIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes,
+ private SingleDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes,
int iterateDimension, int[] initialIndexes, int size) {
- super(sourceDimensionSizes, iterateDimensionSizes, initialIndexes);
+ super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
- this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceDimensionSizes);
- this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateDimensionSizes);
+ this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceSizes);
+ this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateSizes);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;
- currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceDimensionSizes);
- currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateDimensionSizes);
+ currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceSizes);
+ currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateSizes);
}
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@@ -880,16 +863,16 @@ public class IndexedTensor implements Tensor {
/** The iteration step in the value index space */
private final int step;
- private EqualSizeSingleDimensionIndexes(int[] dimensionSizes,
+ private EqualSizeSingleDimensionIndexes(DimensionSizes sizes,
int iterateDimension, int[] initialIndexes, int size) {
- super(dimensionSizes, dimensionSizes, initialIndexes);
+ super(sizes, sizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
- this.step = productOfDimensionsAfter(iterateDimension, dimensionSizes);
+ this.step = productOfDimensionsAfter(iterateDimension, sizes);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;
- currentValueIndex = IndexedTensor.toValueIndex(indexes, dimensionSizes);
+ currentValueIndex = IndexedTensor.toValueIndex(indexes, sizes);
}
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 800de360369..51d40a89f3b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -346,7 +346,7 @@ public interface Tensor {
}
/** Creates a suitable builder for the given type */
- static Builder of(TensorType type, int[] dimensionSizes) {
+ static Builder of(TensorType type, DimensionSizes dimensionSizes) {
boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
if (containsIndexed && containsMapped)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index f212e66fc86..05999ff1240 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -2,6 +2,7 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -62,10 +63,10 @@ public class Concat extends PrimitiveTensorFunction {
IndexedTensor bIndexed = (IndexedTensor) b;
TensorType concatType = concatType(a, b);
- int[] concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
+ DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
- int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(aIndexed::size).orElseThrow(RuntimeException::new);
+ int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
int[] aToIndexes = mapIndexes(a.type(), concatType);
int[] bToIndexes = mapIndexes(b.type(), concatType);
concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
@@ -123,22 +124,22 @@ public class Concat extends PrimitiveTensorFunction {
}
/** Returns the concrete (not type) dimension sizes resulting from combining a and b */
- private int[] concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
- int[] joinedSizes = new int[concatType.dimensions().size()];
- for (int i = 0; i < joinedSizes.length; i++) {
+ private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
+ DimensionSizes.Builder joinedSizes = new DimensionSizes.Builder(concatType.dimensions().size());
+ for (int i = 0; i < joinedSizes.dimensions(); i++) {
String currentDimension = concatType.dimensions().get(i).name();
- int aSize = a.type().indexOfDimension(currentDimension).map(a::size).orElse(0);
- int bSize = b.type().indexOfDimension(currentDimension).map(b::size).orElse(0);
+ int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0);
+ int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0);
if (currentDimension.equals(concatDimension))
- joinedSizes[i] = aSize + bSize;
+ joinedSizes.set(i, aSize + bSize);
else if (aSize != 0 && bSize != 0 && aSize!=bSize )
throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when " +
"concatenating " + a.type() + " and " + b.type() + " along dimension " +
concatDimension + ", but was " + aSize + " and " + bSize);
else
- joinedSizes[i] = Math.max(aSize, bSize);
+ joinedSizes.set(i, Math.max(aSize, bSize));
}
- return joinedSizes;
+ return joinedSizes.build();
}
/**
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 9c92ca00eac..d95feb29af4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -1,6 +1,7 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
+import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -68,11 +69,11 @@ public class Generate extends PrimitiveTensorFunction {
return builder.build();
}
- private int[] dimensionSizes(TensorType type) {
- int dimensionSizes[] = new int[type.dimensions().size()];
- for (int i = 0; i < dimensionSizes.length; i++)
- dimensionSizes[i] = type.dimensions().get(i).size().get();
- return dimensionSizes;
+ private DimensionSizes dimensionSizes(TensorType type) {
+ DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
+ for (int i = 0; i < b.dimensions(); i++)
+ b.set(i, type.dimensions().get(i).size().get());
+ return b.build();
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 0844877ba29..23865e1cc1c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -2,13 +2,13 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
-import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
@@ -88,10 +88,10 @@ public class Join extends PrimitiveTensorFunction {
}
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
- int joinedLength = Math.min(a.size(0), b.size(0));
+ int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
- IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new int[] { joinedLength});
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build());
for (int i = 0; i < joinedLength; i++)
builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
return builder.build();
@@ -119,9 +119,9 @@ public class Join extends PrimitiveTensorFunction {
private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
- return Tensor.Builder.of(joinedType, new int[joinedType.dimensions().size()]).build();
+ return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
- int[] joinedSizes = joinedSize(joinedType, subspace, superspace);
+ DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);
@@ -156,16 +156,16 @@ public class Join extends PrimitiveTensorFunction {
}
}
- private int[] joinedSize(TensorType joinedType, IndexedTensor subspace, IndexedTensor superspace) {
- int[] joinedSizes = new int[joinedType.dimensions().size()];
- for (int i = 0; i < joinedSizes.length; i++) {
+ private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor subspace, IndexedTensor superspace) {
+ DimensionSizes.Builder b = new DimensionSizes.Builder(joinedType.dimensions().size());
+ for (int i = 0; i < b.dimensions(); i++) {
Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name());
if (subspaceIndex.isPresent())
- joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get()));
+ b.set(i, Math.min(superspace.dimensionSizes().size(i), subspace.dimensionSizes().size(subspaceIndex.get())));
else
- joinedSizes[i] = superspace.size(i);
+ b.set(i, superspace.dimensionSizes().size(i));
}
- return joinedSizes;
+ return b.build();
}
private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index afe98d4bc07..e9566eb3ddf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -147,7 +147,7 @@ public class Reduce extends PrimitiveTensorFunction {
private Tensor reduceIndexedVector(IndexedTensor argument) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
- for (int i = 0; i < argument.size(0); i++)
+ for (int i = 0; i < argument.dimensionSizes().size(0); i++)
valueAggregator.aggregate(argument.get(i));
return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
index 59f86e063ff..3f7f02c6c00 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
@@ -56,8 +56,8 @@ public class IndexedTensorTestCase {
assertEquals(emptyWithDimensions, emptyWithDimensionsFromString);
IndexedTensor emptyWithDimensionsIndexed = (IndexedTensor)emptyWithDimensions;
- assertEquals(0, emptyWithDimensionsIndexed.size(0));
- assertEquals(0, emptyWithDimensionsIndexed.size(1));
+ assertEquals(0, emptyWithDimensionsIndexed.dimensionSizes().size(0));
+ assertEquals(0, emptyWithDimensionsIndexed.dimensionSizes().size(1));
}
@Test