diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-05 16:44:53 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-05 16:44:53 +0100 |
commit | e25d723262ed8702be60ade30d87c2da75fbadf2 (patch) | |
tree | fbfb8cc3327b9abab638fc513cb6fd93b69d8ab9 /vespajlib | |
parent | fd22e7e254528bea682a2e585f5cbb1fc625c93d (diff) |
Type DimensionSizes
Diffstat (limited to 'vespajlib')
8 files changed, 220 insertions, 164 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java new file mode 100644 index 00000000000..76340bb7d8f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -0,0 +1,71 @@ +package com.yahoo.tensor; + +import java.util.Arrays; + +/** + * The sizes of a set of dimensions. + * + * @author bratseth + */ +public final class DimensionSizes { + + private final int[] sizes; + + private DimensionSizes(Builder builder) { + this.sizes = builder.sizes; + builder.sizes = null; // invalidate builder to avoid copying the array + } + + /** + * Returns the length of this in the nth dimension + * + * @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one + */ + public int size(int dimensionIndex) { return sizes[dimensionIndex]; } + + /** Returns the number of dimensions this provides the size of */ + public int dimensions() { return sizes.length; } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (!(o instanceof DimensionSizes)) return false; + return Arrays.equals(((DimensionSizes) o).sizes, this.sizes); + } + + @Override + public int hashCode() { return Arrays.hashCode(sizes); } + + /** + * Builder of a set of dimension sizes. + * Dimensions whose size is not set before building will get size 0. + */ + public final static class Builder { + + private int[] sizes; + + public Builder(int dimensions) { + this.sizes = new int[dimensions]; + } + + public Builder set(int dimensionIndex, int size) { + sizes[dimensionIndex] = size; + return this; + } + + /** + * Returns the length of this in the nth dimension + * + * @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one + */ + public int size(int dimensionIndex) { return sizes[dimensionIndex]; } + + /** Returns the number of dimensions this provides the size of */ + public int dimensions() { return sizes.length; } + + /** Build this. This builder becomes invalid after calling this. */ + public DimensionSizes build() { return new DimensionSizes(this); } + + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 2c3cb6ebde2..d69cf65ee8d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -27,11 +27,11 @@ public class IndexedTensor implements Tensor { private final TensorType type; /** The sizes of the dimensions of this in the order of the dimensions of the type */ - private final int[] dimensionSizes; + private final DimensionSizes dimensionSizes; private final double[] values; - private IndexedTensor(TensorType type, int[] dimensionSizes, double[] values) { + private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) { this.type = type; this.dimensionSizes = dimensionSizes; this.values = values; @@ -68,12 +68,12 @@ public class IndexedTensor implements Tensor { * other iterator. * * @param dimensions the names of the dimensions of the superspace - * @param dimensionSizes the size of each dimension in the space we are returning values for, containing - * one value per dimension of this tensor (in order). Each size may be the same or smaller - * than the corresponding size of this tensor + * @param sizes the size of each dimension in the space we are returning values for, containing + * one value per dimension of this tensor (in order). Each size may be the same or smaller + * than the corresponding size of this tensor */ - public Iterator<SubspaceIterator> subspaceIterator(Set<String> dimensions, int[] dimensionSizes) { - return new SuperspaceIterator(dimensions, dimensionSizes); + public Iterator<SubspaceIterator> subspaceIterator(Set<String> dimensions, DimensionSizes sizes) { + return new SuperspaceIterator(dimensions, sizes); } /** Returns a subspace iterator having the sizes of the dimensions of this tensor */ @@ -81,12 +81,6 @@ public class IndexedTensor implements Tensor { return subspaceIterator(dimensions, dimensionSizes); } - /** Returns whether the dimensions sizes of this are equal to the given sizes */ - // TODO: Replace by returning immutable sizes when DimensionSizes are a class - public boolean dimensionSizesAre(int[] dimensionSizes) { - return Arrays.equals(dimensionSizes, this.dimensionSizes); - } - /** * Returns the value at the given indexes * @@ -110,54 +104,44 @@ public class IndexedTensor implements Tensor { } } - double get(int valueIndex) { return values[valueIndex]; } + private double get(int valueIndex) { return values[valueIndex]; } - /** Returns the value at these indexes */ - private double get(Indexes indexes) { - return values[toValueIndex(indexes.indexesForReading(), dimensionSizes)]; - } - - private static int toValueIndex(int[] indexes, int[] dimensionSizes) { + private static int toValueIndex(int[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed if (indexes.length == 0) return 0; // for speed int valueIndex = 0; for (int i = 0; i < indexes.length; i++) - valueIndex += productOfDimensionsAfter(i, dimensionSizes) * indexes[i]; + valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i]; return valueIndex; } - private static int toValueIndex(TensorAddress address, int[] dimensionSizes) { + private static int toValueIndex(TensorAddress address, DimensionSizes sizes) { if (address.isEmpty()) return 0; int valueIndex = 0; for (int i = 0; i < address.size(); i++) - valueIndex += productOfDimensionsAfter(i, dimensionSizes) * address.intLabel(i); + valueIndex += productOfDimensionsAfter(i, sizes) * address.intLabel(i); return valueIndex; } - private static int productOfDimensionsAfter(int afterIndex, int[] dimensionSizes) { + private static int productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) { int product = 1; - for (int i = afterIndex + 1; i < dimensionSizes.length; i++) - product *= dimensionSizes[i]; + for (int i = afterIndex + 1; i < sizes.dimensions(); i++) + product *= sizes.size(i); return product; } @Override public TensorType type() { return type; } - /** - * Returns the length of this in the nth dimension - * - * @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one - */ - public int size(int dimension) { - return dimensionSizes[dimension]; + public DimensionSizes dimensionSizes() { + return dimensionSizes; } @Override public Map<TensorAddress, Double> cells() { - if (dimensionSizes.length == 0) + if (dimensionSizes.dimensions() == 0) return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]); ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); @@ -201,27 +185,27 @@ public class IndexedTensor implements Tensor { * and, agree with the type size information when specified in the type. * If sizes are completely specified in the type this size information is redundant. */ - public static Builder of(TensorType type, int[] dimensionSizes) { + public static Builder of(TensorType type, DimensionSizes sizes) { // validate - if (dimensionSizes.length != type.dimensions().size()) - throw new IllegalArgumentException(dimensionSizes.length + " is the wrong number of dimension sizes " + - " for " + type); - for (int i = 0; i < dimensionSizes.length; i++ ) { + if (sizes.dimensions() != type.dimensions().size()) + throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + + "for " + type); + for (int i = 0; i < sizes.dimensions(); i++ ) { Optional<Integer> size = type.dimensions().get(i).size(); - if (size.isPresent() && size.get() < dimensionSizes[i]) - throw new IllegalArgumentException("Size of " + type.dimensions() + " is " + dimensionSizes[i] + + if (size.isPresent() && size.get() < sizes.size(i)) + throw new IllegalArgumentException("Size of " + type.dimensions() + " is " + sizes.size(i) + " but cannot be larger than " + size.get()); } - return new BoundBuilder(type, dimensionSizes); + return new BoundBuilder(type, sizes); } public abstract Builder cell(double value, int ... indexes); - protected double[] arrayFor(int[] dimensionSizes) { + protected double[] arrayFor(DimensionSizes sizes) { int productSize = 1; - for (int dimensionSize : dimensionSizes) - productSize *= dimensionSize; + for (int i = 0; i < sizes.dimensions(); i++ ) + productSize *= sizes.size(i); return new double[productSize]; } @@ -236,32 +220,32 @@ public class IndexedTensor implements Tensor { /** A bound builder can create the double array directly */ private static class BoundBuilder extends Builder { - private int[] dimensionSizes; + private DimensionSizes sizes; private double[] values; private BoundBuilder(TensorType type) { this(type, dimensionSizesOf(type)); } - private BoundBuilder(TensorType type, int[] dimensionSizes) { + public static DimensionSizes dimensionSizesOf(TensorType type) { + DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); + for (int i = 0; i < type.dimensions().size(); i++) + b.set(i, type.dimensions().get(i).size().get()); + return b.build(); + } + + private BoundBuilder(TensorType type, DimensionSizes sizes) { super(type); - if ( dimensionSizes.length != type.dimensions().size()) + if ( sizes.dimensions() != type.dimensions().size()) throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); - this.dimensionSizes = dimensionSizes; - values = arrayFor(dimensionSizes); + this.sizes = sizes; + values = arrayFor(sizes); Arrays.fill(values, Double.NaN); } - private static int[] dimensionSizesOf(TensorType type) { - int[] dimensionSizes = new int[type.dimensions().size()]; - for (int i = 0; i < type.dimensions().size(); i++) - dimensionSizes[i] = type.dimensions().get(i).size().get(); - return dimensionSizes; - } - @Override public BoundBuilder cell(double value, int ... indexes) { - values[toValueIndex(indexes, dimensionSizes)] = value; + values[toValueIndex(indexes, sizes)] = value; return this; } @@ -272,7 +256,7 @@ public class IndexedTensor implements Tensor { @Override public Builder cell(TensorAddress address, double value) { - values[toValueIndex(address, dimensionSizes)] = value; + values[toValueIndex(address, sizes)] = value; return this; } @@ -282,9 +266,9 @@ public class IndexedTensor implements Tensor { // NaN's don't get lost so leaving them in place should be quite benign if (values.length == 1 && Double.isNaN(values[0])) values = new double[0]; - IndexedTensor tensor = new IndexedTensor(type, dimensionSizes, values); + IndexedTensor tensor = new IndexedTensor(type, sizes, values); // prevent further modification - dimensionSizes = null; + sizes = null; values = null; return tensor; } @@ -320,23 +304,23 @@ public class IndexedTensor implements Tensor { @Override public IndexedTensor build() { if (firstDimension == null) // empty - return new IndexedTensor(type, new int[type.dimensions().size()], new double[] {}); + return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {}); if (type.dimensions().isEmpty()) // single number - return new IndexedTensor(type, new int[type.dimensions().size()], new double[] {(Double) firstDimension.get(0) }); + return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); - int[] dimensionSizes = findDimensionSizes(firstDimension); + DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); double[] values = arrayFor(dimensionSizes); fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } - private int[] findDimensionSizes(List<Object> firstDimension) { + private DimensionSizes findDimensionSizes(List<Object> firstDimension) { List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size()); findDimensionSizes(0, dimensionSizeList, firstDimension); - int[] dimensionSizes = new int[type.dimensions().size()]; // may be longer than the list but that's correct - for (int i = 0; i < dimensionSizes.length; i++) - dimensionSizes[i] = dimensionSizeList.get(i); - return dimensionSizes; + DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct + for (int i = 0; i < b.dimensions(); i++) + b.set(i, dimensionSizeList.get(i)); + return b.build(); } @SuppressWarnings("unchecked") @@ -354,12 +338,12 @@ public class IndexedTensor implements Tensor { @SuppressWarnings("unchecked") private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension, - int[] dimensionSizes, double[] values) { - if (currentDimensionIndex < dimensionSizes.length - 1) { // recurse to next dimension + DimensionSizes sizes, double[] values) { + if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension for (int i = 0; i < currentDimension.size(); i++) fillValues(currentDimensionIndex + 1, - offset + productOfDimensionsAfter(currentDimensionIndex, dimensionSizes) * i, - (List<Object>) currentDimension.get(i), dimensionSizes, values); + offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i, + (List<Object>) currentDimension.get(i), sizes, values); } else { // last dimension - fill values for (int i = 0; i < currentDimension.size(); i++) values[offset + i] = (double) currentDimension.get(i); @@ -477,12 +461,12 @@ public class IndexedTensor implements Tensor { * The sizes of the space we'll return values of, one value for each dimension of this tensor, * which may be equal to or smaller than the sizes of this tensor */ - private final int[] iterateDimensionSizes; + private final DimensionSizes iterateSizes; private int count = 0; - private SuperspaceIterator(Set<String> superdimensionNames, int[] iterateDimensionSizes) { - this.iterateDimensionSizes = iterateDimensionSizes; + private SuperspaceIterator(Set<String> superdimensionNames, DimensionSizes iterateSizes) { + this.iterateSizes = iterateSizes; List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length) @@ -493,7 +477,7 @@ public class IndexedTensor implements Tensor { subdimensionIndexes.add(i); } - superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, superdimensionIndexes); + superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, superdimensionIndexes); } @Override @@ -506,7 +490,7 @@ public class IndexedTensor implements Tensor { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + superindexes); count++; superindexes.next(); - return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), iterateDimensionSizes); + return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), iterateSizes); } } @@ -525,7 +509,7 @@ public class IndexedTensor implements Tensor { */ private final List<Integer> iterateDimensions; private final int[] address; - private final int[] iterateDimensionSizes; + private final DimensionSizes iterateSizes; private Indexes indexes; private int count = 0; @@ -545,11 +529,11 @@ public class IndexedTensor implements Tensor { * This is treated as immutable. * @param address the address of the first cell of this subspace. */ - private SubspaceIterator(List<Integer> iterateDimensions, int[] address, int[] iterateDimensionSizes) { + private SubspaceIterator(List<Integer> iterateDimensions, int[] address, DimensionSizes iterateSizes) { this.iterateDimensions = iterateDimensions; this.address = address; - this.iterateDimensionSizes = iterateDimensionSizes; - this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, iterateDimensions, address); + this.iterateSizes = iterateSizes; + this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); reusedCell = new LazyCell(indexes, Double.NaN); } @@ -564,7 +548,7 @@ public class IndexedTensor implements Tensor { /** Rewind this iterator to the first element */ public void reset() { this.count = 0; - this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, iterateDimensions, address); + this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); } @Override @@ -617,54 +601,54 @@ public class IndexedTensor implements Tensor { */ public abstract static class Indexes { - private final int[] sourceDimensionSizes; + private final DimensionSizes sourceSizes; - private final int[] iterationDimensionSizes; + private final DimensionSizes iterationSizes; protected final int[] indexes; - public static Indexes of(int[] dimensionSizes) { - return of(dimensionSizes, dimensionSizes); + public static Indexes of(DimensionSizes sizes) { + return of(sizes, sizes); } - private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes) { - return of(sourceDimensionSizes, iterateDimensionSizes, completeIterationOrder(iterateDimensionSizes.length)); + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes) { + return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions())); } - private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int size) { - return of(sourceDimensionSizes, iterateDimensionSizes, completeIterationOrder(iterateDimensionSizes.length), size); + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int size) { + return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size); } - private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions) { - return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, computeSize(iterateDimensionSizes, iterateDimensions)); + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions) { + return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions)); } - private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int size) { - return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, new int[iterateDimensionSizes.length], size); + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int size) { + return of(sourceSizes, iterateSizes, iterateDimensions, new int[iterateSizes.dimensions()], size); } - private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes) { - return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, initialIndexes, computeSize(iterateDimensionSizes, iterateDimensions)); + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes) { + return of(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, computeSize(iterateSizes, iterateDimensions)); } - private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { if (size == 0) { - return new EmptyIndexes(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); // we're told explicitly there are truly no values available + return new EmptyIndexes(sourceSizes, iterateSizes, initialIndexes); // we're told explicitly there are truly no values available } else if (size == 1) { - return new SingleValueIndexes(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); // with no (iterating) dimensions, we still return one value, not zero + return new SingleValueIndexes(sourceSizes, iterateSizes, initialIndexes); // with no (iterating) dimensions, we still return one value, not zero } else if (iterateDimensions.size() == 1) { - if (Arrays.equals(sourceDimensionSizes, iterateDimensionSizes)) - return new EqualSizeSingleDimensionIndexes(sourceDimensionSizes, iterateDimensions.get(0), initialIndexes, size); + if (sourceSizes.equals(iterateSizes)) + return new EqualSizeSingleDimensionIndexes(sourceSizes, iterateDimensions.get(0), initialIndexes, size); else - return new SingleDimensionIndexes(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions.get(0), initialIndexes, size); // optimization + return new SingleDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions.get(0), initialIndexes, size); // optimization } else { - if (Arrays.equals(sourceDimensionSizes, iterateDimensionSizes)) - return new EqualSizeMultiDimensionIndexes(sourceDimensionSizes, iterateDimensions, initialIndexes, size); + if (sourceSizes.equals(iterateSizes)) + return new EqualSizeMultiDimensionIndexes(sourceSizes, iterateDimensions, initialIndexes, size); else - return new MultiDimensionIndexes(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, initialIndexes, size); + return new MultiDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, size); } } @@ -675,16 +659,16 @@ public class IndexedTensor implements Tensor { return iterationDimensions; } - private Indexes(int[] sourceDimensionSizes, int[] iterationDimensionSizes, int[] indexes) { - this.sourceDimensionSizes = sourceDimensionSizes; - this.iterationDimensionSizes = iterationDimensionSizes; + private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) { + this.sourceSizes = sourceSizes; + this.iterationSizes = iterationSizes; this.indexes = indexes; } - private static int computeSize(int[] dimensionSizes, List<Integer> iterateDimensions) { + private static int computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) { int size = 1; for (int iterateDimension : iterateDimensions) - size *= dimensionSizes[iterateDimension]; + size *= sizes.size(iterateDimension); return size; } @@ -701,13 +685,12 @@ public class IndexedTensor implements Tensor { public int[] indexesForReading() { return indexes; } int toSourceValueIndex() { - return IndexedTensor.toValueIndex(indexes, sourceDimensionSizes); + return IndexedTensor.toValueIndex(indexes, sourceSizes); } - int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationDimensionSizes); } + int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); } - /** Returns the dimension sizes of this. Do not modify the return value */ - int[] dimensionSizes() { return iterationDimensionSizes; } + DimensionSizes dimensionSizes() { return iterationSizes; } /** Returns an immutable list containing a copy of the indexes in this */ public List<Integer> toList() { @@ -730,8 +713,8 @@ public class IndexedTensor implements Tensor { private final static class EmptyIndexes extends Indexes { - private EmptyIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) { - super(sourceDimensionSizes, iterateDimensionSizes, indexes); + private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) { + super(sourceSizes, iterateSizes, indexes); } @Override @@ -744,8 +727,8 @@ public class IndexedTensor implements Tensor { private final static class SingleValueIndexes extends Indexes { - private SingleValueIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) { - super(sourceDimensionSizes, iterateDimensionSizes, indexes); + private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) { + super(sourceSizes, iterateSizes, indexes); } @Override @@ -762,8 +745,8 @@ public class IndexedTensor implements Tensor { private final List<Integer> iterateDimensions; - private MultiDimensionIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { - super(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); + private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { + super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimensions = iterateDimensions; this.size = size; @@ -786,7 +769,7 @@ public class IndexedTensor implements Tensor { @Override public void next() { int iterateDimensionsIndex = 0; - while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes()[iterateDimensions.get(iterateDimensionsIndex)]) { + while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes().size(iterateDimensions.get(iterateDimensionsIndex))) { indexes[iterateDimensions.get(iterateDimensionsIndex)] = 0; // carry over iterateDimensionsIndex++; } @@ -800,8 +783,8 @@ public class IndexedTensor implements Tensor { private int lastComputedSourceValueIndex = -1; - private EqualSizeMultiDimensionIndexes(int[] dimensionSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { - super(dimensionSizes, dimensionSizes, iterateDimensions, initialIndexes, size); + private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) { + super(sizes, sizes, iterateDimensions, initialIndexes, size); } int toSourceValueIndex() { @@ -826,18 +809,18 @@ public class IndexedTensor implements Tensor { /** The iteration step in the value index space */ private final int sourceStep, iterationStep; - private SingleDimensionIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, + private SingleDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int iterateDimension, int[] initialIndexes, int size) { - super(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); + super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; - this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceDimensionSizes); - this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateDimensionSizes); + this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceSizes); + this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateSizes); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; - currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceDimensionSizes); - currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateDimensionSizes); + currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceSizes); + currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateSizes); } /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @@ -880,16 +863,16 @@ public class IndexedTensor implements Tensor { /** The iteration step in the value index space */ private final int step; - private EqualSizeSingleDimensionIndexes(int[] dimensionSizes, + private EqualSizeSingleDimensionIndexes(DimensionSizes sizes, int iterateDimension, int[] initialIndexes, int size) { - super(dimensionSizes, dimensionSizes, initialIndexes); + super(sizes, sizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; - this.step = productOfDimensionsAfter(iterateDimension, dimensionSizes); + this.step = productOfDimensionsAfter(iterateDimension, sizes); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; - currentValueIndex = IndexedTensor.toValueIndex(indexes, dimensionSizes); + currentValueIndex = IndexedTensor.toValueIndex(indexes, sizes); } /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 800de360369..51d40a89f3b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -346,7 +346,7 @@ public interface Tensor { } /** Creates a suitable builder for the given type */ - static Builder of(TensorType type, int[] dimensionSizes) { + static Builder of(TensorType type, DimensionSizes dimensionSizes) { boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed()); boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); if (containsIndexed && containsMapped) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index f212e66fc86..05999ff1240 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -62,10 +63,10 @@ public class Concat extends PrimitiveTensorFunction { IndexedTensor bIndexed = (IndexedTensor) b; TensorType concatType = concatType(a, b); - int[] concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); + DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); - int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(aIndexed::size).orElseThrow(RuntimeException::new); + int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); int[] aToIndexes = mapIndexes(a.type(), concatType); int[] bToIndexes = mapIndexes(b.type(), concatType); concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); @@ -123,22 +124,22 @@ public class Concat extends PrimitiveTensorFunction { } /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ - private int[] concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { - int[] joinedSizes = new int[concatType.dimensions().size()]; - for (int i = 0; i < joinedSizes.length; i++) { + private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { + DimensionSizes.Builder joinedSizes = new DimensionSizes.Builder(concatType.dimensions().size()); + for (int i = 0; i < joinedSizes.dimensions(); i++) { String currentDimension = concatType.dimensions().get(i).name(); - int aSize = a.type().indexOfDimension(currentDimension).map(a::size).orElse(0); - int bSize = b.type().indexOfDimension(currentDimension).map(b::size).orElse(0); + int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0); + int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0); if (currentDimension.equals(concatDimension)) - joinedSizes[i] = aSize + bSize; + joinedSizes.set(i, aSize + bSize); else if (aSize != 0 && bSize != 0 && aSize!=bSize ) throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when " + "concatenating " + a.type() + " and " + b.type() + " along dimension " + concatDimension + ", but was " + aSize + " and " + bSize); else - joinedSizes[i] = Math.max(aSize, bSize); + joinedSizes.set(i, Math.max(aSize, bSize)); } - return joinedSizes; + return joinedSizes.build(); } /** diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 9c92ca00eac..d95feb29af4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -1,6 +1,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; +import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -68,11 +69,11 @@ public class Generate extends PrimitiveTensorFunction { return builder.build(); } - private int[] dimensionSizes(TensorType type) { - int dimensionSizes[] = new int[type.dimensions().size()]; - for (int i = 0; i < dimensionSizes.length; i++) - dimensionSizes[i] = type.dimensions().get(i).size().get(); - return dimensionSizes; + private DimensionSizes dimensionSizes(TensorType type) { + DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); + for (int i = 0; i < b.dimensions(); i++) + b.set(i, type.dimensions().get(i).size().get()); + return b.build(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 0844877ba29..23865e1cc1c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -2,13 +2,13 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; -import java.util.Arrays; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -88,10 +88,10 @@ public class Join extends PrimitiveTensorFunction { } private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { - int joinedLength = Math.min(a.size(0), b.size(0)); + int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); - IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new int[] { joinedLength}); + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build()); for (int i = 0; i < joinedLength; i++) builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i); return builder.build(); @@ -119,9 +119,9 @@ public class Join extends PrimitiveTensorFunction { private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes - return Tensor.Builder.of(joinedType, new int[joinedType.dimensions().size()]).build(); + return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); - int[] joinedSizes = joinedSize(joinedType, subspace, superspace); + DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace); IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes); @@ -156,16 +156,16 @@ public class Join extends PrimitiveTensorFunction { } } - private int[] joinedSize(TensorType joinedType, IndexedTensor subspace, IndexedTensor superspace) { - int[] joinedSizes = new int[joinedType.dimensions().size()]; - for (int i = 0; i < joinedSizes.length; i++) { + private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor subspace, IndexedTensor superspace) { + DimensionSizes.Builder b = new DimensionSizes.Builder(joinedType.dimensions().size()); + for (int i = 0; i < b.dimensions(); i++) { Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name()); if (subspaceIndex.isPresent()) - joinedSizes[i] = Math.min(superspace.size(i), subspace.size(subspaceIndex.get())); + b.set(i, Math.min(superspace.dimensionSizes().size(i), subspace.dimensionSizes().size(subspaceIndex.get()))); else - joinedSizes[i] = superspace.size(i); + b.set(i, superspace.dimensionSizes().size(i)); } - return joinedSizes; + return b.build(); } private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index afe98d4bc07..e9566eb3ddf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -147,7 +147,7 @@ public class Reduce extends PrimitiveTensorFunction { private Tensor reduceIndexedVector(IndexedTensor argument) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); - for (int i = 0; i < argument.size(0); i++) + for (int i = 0; i < argument.dimensionSizes().size(0); i++) valueAggregator.aggregate(argument.get(i)); return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index 59f86e063ff..3f7f02c6c00 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -56,8 +56,8 @@ public class IndexedTensorTestCase { assertEquals(emptyWithDimensions, emptyWithDimensionsFromString); IndexedTensor emptyWithDimensionsIndexed = (IndexedTensor)emptyWithDimensions; - assertEquals(0, emptyWithDimensionsIndexed.size(0)); - assertEquals(0, emptyWithDimensionsIndexed.size(1)); + assertEquals(0, emptyWithDimensionsIndexed.dimensionSizes().size(0)); + assertEquals(0, emptyWithDimensionsIndexed.dimensionSizes().size(1)); } @Test |