From f6007e2cce9c6048ae27a8af3df6fdd917162f75 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Thu, 5 Jan 2017 09:59:58 +0100 Subject: Optimize subspace iteration --- .../main/java/com/yahoo/tensor/IndexedTensor.java | 155 ++++++++++++++------- .../src/main/java/com/yahoo/tensor/Tensor.java | 4 +- .../main/java/com/yahoo/tensor/TensorAddress.java | 28 +++- .../main/java/com/yahoo/tensor/functions/Join.java | 58 +++++--- 4 files changed, 169 insertions(+), 76 deletions(-) (limited to 'vespajlib/src/main/java') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6e03c27af75..b89185b5131 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -81,6 +81,12 @@ public class IndexedTensor implements Tensor { return subspaceIterator(dimensions, dimensionSizes); } + /** Returns whether the dimensions sizes of this are equal to the given sizes */ + // TODO: Replace by returning immutable sizes when DimensionSizes are a class + public boolean dimensionSizesAre(int[] dimensionSizes) { + return Arrays.equals(dimensionSizes, this.dimensionSizes); + } + /** * Returns the value at the given indexes * @@ -95,7 +101,7 @@ public class IndexedTensor implements Tensor { /** Returns the value at this address, or NaN if there is no value at this address */ @Override public double get(TensorAddress address) { - // optimize for fast lookup within bounds + // optimize for fast lookup within bounds: try { return values[toValueIndex(address, dimensionSizes)]; } @@ -104,6 +110,8 @@ public class IndexedTensor implements Tensor { } } + double get(int valueIndex) { return values[valueIndex]; } + /** Returns the value at these indexes */ private double get(Indexes indexes) { return values[toValueIndex(indexes.indexesForReading(), dimensionSizes)]; @@ -153,10 +161,10 @@ public class IndexedTensor implements Tensor { return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]); ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); - Indexes indexes = Indexes.of(dimensionSizes, values.length); + Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); for (int i = 0; i < values.length; i++) { indexes.next(); - builder.put(indexes.toAddress(), values[i]); + builder.put(indexes.toAddress(i), values[i]); } return builder.build(); } @@ -209,7 +217,10 @@ public class IndexedTensor implements Tensor { } public abstract Builder cell(double value, int ... indexes); - + + /** Add a cell by internal index */ + public abstract Builder cellWithInternalIndex(int internalIndex, double value); + protected double[] arrayFor(int[] dimensionSizes) { int productSize = 1; for (int dimensionSize : dimensionSizes) @@ -281,6 +292,12 @@ public class IndexedTensor implements Tensor { return tensor; } + @Override + public Builder cellWithInternalIndex(int internalIndex, double value) { + values[internalIndex] = value; + return this; + } + } /** @@ -400,13 +417,17 @@ public class IndexedTensor implements Tensor { list.add(list.size(), null); } + @Override + public Builder cellWithInternalIndex(int internalIndex, double value) { + throw new UnsupportedOperationException("Not supoprted for unbound builders"); + } + } - // TODO: Generalize to vector cell iterator? private final class CellIterator implements Iterator> { private int count = 0; - private final Indexes indexes = Indexes.of(dimensionSizes, values.length); + private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); @Override public boolean hasNext() { @@ -418,7 +439,9 @@ public class IndexedTensor implements Tensor { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes); count++; indexes.next(); - return new Cell(indexes.toAddress(), get(indexes)); + int valueIndex = toValueIndex(indexes.indexesForReading(), IndexedTensor.this.dimensionSizes); + TensorAddress address = indexes.toAddress(valueIndex); + return new Cell(address, get(valueIndex)); } } @@ -493,12 +516,12 @@ public class IndexedTensor implements Tensor { * The sizes of the space we'll return values of, one value for each dimension of this tensor, * which may be equal to or smaller than the sizes of this tensor */ - private final int[] dimensionSizes; + private final int[] iterateDimensionSizes; private int count = 0; - private SuperspaceIterator(Set superdimensionNames, int[] dimensionSizes) { - this.dimensionSizes = dimensionSizes; + private SuperspaceIterator(Set superdimensionNames, int[] iterateDimensionSizes) { + this.iterateDimensionSizes = iterateDimensionSizes; List superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length) @@ -509,7 +532,7 @@ public class IndexedTensor implements Tensor { subdimensionIndexes.add(i); } - superindexes = Indexes.of(dimensionSizes, superdimensionIndexes); + superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, superdimensionIndexes); } @Override @@ -522,7 +545,7 @@ public class IndexedTensor implements Tensor { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + superindexes); count++; superindexes.next(); - return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), dimensionSizes); + return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), iterateDimensionSizes); } } @@ -539,7 +562,7 @@ public class IndexedTensor implements Tensor { */ private final List iterateDimensions; private final int[] address; - private final int[] dimensionSizes; + private final int[] iterateDimensionSizes; private Indexes indexes; private int count = 0; @@ -556,11 +579,11 @@ public class IndexedTensor implements Tensor { * This is treated as immutable. * @param address the address of the first cell of this subspace. */ - private SubspaceIterator(List iterateDimensions, int[] address, int[] dimensionSizes) { + private SubspaceIterator(List iterateDimensions, int[] address, int[] iterateDimensionSizes) { this.iterateDimensions = iterateDimensions; this.address = address; - this.dimensionSizes = dimensionSizes; - this.indexes = Indexes.of(dimensionSizes, iterateDimensions, address); + this.iterateDimensionSizes = iterateDimensionSizes; + this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, iterateDimensions, address); } /** Returns the total number of cells in this subspace */ @@ -569,12 +592,12 @@ public class IndexedTensor implements Tensor { } /** Returns the address of the cell this currently points to (which may be an invalid position) */ - public TensorAddress address() { return indexes.toAddress(); } + public TensorAddress address() { return indexes.toAddress(-1); } /** Rewind this iterator to the first element */ public void reset() { this.count = 0; - this.indexes = Indexes.of(dimensionSizes, iterateDimensions, address); + this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateDimensionSizes, iterateDimensions, address); } @Override @@ -587,10 +610,14 @@ public class IndexedTensor implements Tensor { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes); count++; indexes.next(); - return new Cell(indexes.toAddress(), get(indexes)); + int valueIndex = indexes.toValueIndex(); + TensorAddress address = indexes.toAddress(valueIndex); + return new Cell(address, get(valueIndex)); // TODO: Change type to Cell, then change Cell to work with indexes + valueIndex instead of creating an address? } } + + // TODO: Make dimensionSizes a class /** * An array of indexes into this tensor which are able to find the next index in the value order. @@ -599,37 +626,45 @@ public class IndexedTensor implements Tensor { */ public abstract static class Indexes { + private final int[] sourceDimensionSizes; + + private final int[] iterateDimensionSizes; + protected final int[] indexes; public static Indexes of(int[] dimensionSizes) { - return of(dimensionSizes, completeIterationOrder(dimensionSizes.length)); + return of(dimensionSizes, dimensionSizes); + } + + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes) { + return of(sourceDimensionSizes, iterateDimensionSizes, completeIterationOrder(iterateDimensionSizes.length)); } - private static Indexes of(int[] dimensionSizes, int size) { - return of(dimensionSizes, completeIterationOrder(dimensionSizes.length), size); + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int size) { + return of(sourceDimensionSizes, iterateDimensionSizes, completeIterationOrder(iterateDimensionSizes.length), size); } - private static Indexes of(int[] dimensionSizes, List iterateDimensions) { - return of(dimensionSizes, iterateDimensions, computeSize(dimensionSizes, iterateDimensions)); + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List iterateDimensions) { + return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, computeSize(iterateDimensionSizes, iterateDimensions)); } - private static Indexes of(int[] dimensionSizes, List iterateDimensions, int size) { - return of(dimensionSizes, iterateDimensions, new int[dimensionSizes.length], size); + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List iterateDimensions, int size) { + return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, new int[iterateDimensionSizes.length], size); } - private static Indexes of(int[] dimensionSizes, List iterateDimensions, int[] initialIndexes) { - return of(dimensionSizes, iterateDimensions, initialIndexes, computeSize(dimensionSizes, iterateDimensions)); + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List iterateDimensions, int[] initialIndexes) { + return of(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, initialIndexes, computeSize(iterateDimensionSizes, iterateDimensions)); } - private static Indexes of(int[] dimensionSizes, List iterateDimensions, int[] initialIndexes, int size) { + private static Indexes of(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List iterateDimensions, int[] initialIndexes, int size) { if (size == 0) - return new EmptyIndexes(initialIndexes); // we're told explicitly there are truly no values available + return new EmptyIndexes(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); // we're told explicitly there are truly no values available else if (size == 1) - return new SingleValueIndexes(initialIndexes); // with no (iterating) dimensions, we still return one value, not zero + return new SingleValueIndexes(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); // with no (iterating) dimensions, we still return one value, not zero else if (iterateDimensions.size() == 1) - return new SingleDimensionIndexes(iterateDimensions.get(0), initialIndexes, size); // optimization + return new SingleDimensionIndexes(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions.get(0), initialIndexes, size); // optimization else - return new MultivalueIndexes(dimensionSizes, iterateDimensions, initialIndexes, size); + return new MultivalueIndexes(sourceDimensionSizes, iterateDimensionSizes, iterateDimensions, initialIndexes, size); } private static List completeIterationOrder(int length) { @@ -639,7 +674,9 @@ public class IndexedTensor implements Tensor { return iterationDimensions; } - private Indexes(int[] indexes) { + private Indexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) { + this.sourceDimensionSizes = sourceDimensionSizes; + this.iterateDimensionSizes = iterateDimensionSizes; this.indexes = indexes; } @@ -651,8 +688,8 @@ public class IndexedTensor implements Tensor { } /** Returns the address of the current position of these indexes */ - private TensorAddress toAddress() { - return TensorAddress.of(indexes); + private TensorAddress toAddress(int valueIndex) { + return TensorAddress.withValueIndex(valueIndex, indexes); } public int[] indexesCopy() { @@ -661,6 +698,14 @@ public class IndexedTensor implements Tensor { /** Returns a copy of the indexes of this which must not be modified */ public int[] indexesForReading() { return indexes; } + + /** Returns the value index for this in the tensor we are iterating over */ + int toValueIndex() { + return IndexedTensor.toValueIndex(indexes, sourceDimensionSizes); + } + + /** Returns the dimension sizes of this. Do not modify the return value */ + int[] dimensionSizes() { return iterateDimensionSizes; } /** Returns an immutable list containing a copy of the indexes in this */ public List toList() { @@ -683,8 +728,8 @@ public class IndexedTensor implements Tensor { private final static class EmptyIndexes extends Indexes { - private EmptyIndexes(int[] indexes) { - super(indexes); + private EmptyIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) { + super(sourceDimensionSizes, iterateDimensionSizes, indexes); } @Override @@ -697,8 +742,8 @@ public class IndexedTensor implements Tensor { private final static class SingleValueIndexes extends Indexes { - private SingleValueIndexes(int[] indexes) { - super(indexes); + private SingleValueIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, int[] indexes) { + super(sourceDimensionSizes, iterateDimensionSizes, indexes); } @Override @@ -713,13 +758,10 @@ public class IndexedTensor implements Tensor { private final int size; - private final int[] dimensionSizes; - private final List iterateDimensions; - private MultivalueIndexes(int[] dimensionSizes, List iterateDimensions, int[] initialIndexes, int size) { - super(initialIndexes); - this.dimensionSizes = dimensionSizes; + private MultivalueIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, List iterateDimensions, int[] initialIndexes, int size) { + super(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); this.iterateDimensions = iterateDimensions; this.size = size; @@ -742,7 +784,7 @@ public class IndexedTensor implements Tensor { @Override public void next() { int iterateDimensionsIndex = 0; - while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes[iterateDimensions.get(iterateDimensionsIndex)]) { + while ( indexes[iterateDimensions.get(iterateDimensionsIndex)] + 1 == dimensionSizes()[iterateDimensions.get(iterateDimensionsIndex)]) { indexes[iterateDimensions.get(iterateDimensionsIndex)] = 0; // carry over iterateDimensionsIndex++; } @@ -756,16 +798,25 @@ public class IndexedTensor implements Tensor { private final int size; private final int iterateDimension; + + /** Maintain this directly as an optimization for 1-d iteration */ + private int currentValueIndex; - private SingleDimensionIndexes(int iterateDimension, int[] initialIndexes, int size) { - super(initialIndexes); + /** The iteration step in the value index space */ + private final int step; + + private SingleDimensionIndexes(int[] sourceDimensionSizes, int[] iterateDimensionSizes, + int iterateDimension, int[] initialIndexes, int size) { + super(sourceDimensionSizes, iterateDimensionSizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; + this.step = productOfDimensionsAfter(iterateDimension, sourceDimensionSizes); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; + currentValueIndex = IndexedTensor.toValueIndex(indexes, sourceDimensionSizes); } - + /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override public int size() { @@ -781,6 +832,12 @@ public class IndexedTensor implements Tensor { @Override public void next() { indexes[iterateDimension]++; + currentValueIndex += step; + } + + @Override + int toValueIndex() { + return currentValueIndex; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index bbe6cf7d017..4a3302d7a71 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -323,9 +323,9 @@ public interface Tensor { /** Add a cell */ Builder cell(double value, int ... labels); - + Tensor build(); - + class CellBuilder { private final TensorAddress.Builder addressBuilder; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 1c5eec01834..9bf88cc8213 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -15,8 +15,7 @@ import java.util.Set; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension - * in a particular tensor type. As it is just a list of cell labels, it has no independenty meaning without - * its accompanying type. + * in a particular tensor type. By itself it is just a list of cell labels, it's meaning depends on its accompanying type. * * @author bratseth */ @@ -33,6 +32,11 @@ public abstract class TensorAddress implements Comparable { return new IntTensorAddress(labels); } + /** A tensor address which knows its value index (computed from the labels and the size) in some tensor */ + static TensorAddress withValueIndex(int valueIndex, int[] labels) { + return new IntTensorAddress(valueIndex, labels); + } + /** Returns the number of labels in this */ public abstract int size(); @@ -54,7 +58,13 @@ public abstract class TensorAddress implements Comparable { public abstract TensorAddress withLabel(int labelIndex, int label); public final boolean isEmpty() { return size() == 0; } - + + /** + * Returns the value index of this address (computed from the labels and the size) in some tensor. + * This may be retained as an optimization. It is -1 if not set. + */ + public int valueIndex() { return -1; } + @Override public int compareTo(TensorAddress other) { // TODO: Formal issue (only): Ordering with different address sizes @@ -138,9 +148,16 @@ public abstract class TensorAddress implements Comparable { private static final class IntTensorAddress extends TensorAddress { + private final int valueIndex; + private final int[] labels; - private IntTensorAddress(int ... labels) { + private IntTensorAddress(int[] labels) { + this(-1, labels); + } + + private IntTensorAddress(int valueIndex, int[] labels) { + this.valueIndex = valueIndex; this.labels = Arrays.copyOf(labels, labels.length); } @@ -165,6 +182,9 @@ public abstract class TensorAddress implements Comparable { return Arrays.toString(labels); } + @Override + public int valueIndex() { return valueIndex; } + } /** Supports building of a tensor address */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index d7f00f2d6f2..030fdb754de 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -122,8 +122,9 @@ public class Join extends PrimitiveTensorFunction { return Tensor.Builder.of(joinedType, new int[joinedType.dimensions().size()]).build(); int[] joinedSizes = joinedSize(joinedType, subspace, superspace); + boolean equalTargetAndSourceSize = superspace.dimensionSizesAre(joinedSizes); - Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSizes); + IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes); // Find dimensions which are only in the supertype Set superDimensionNames = new HashSet<>(superspace.type().dimensionNames()); @@ -133,12 +134,45 @@ public class Join extends PrimitiveTensorFunction { IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); joinSubspaces(subspace.valueIterator(), subspace.size(), subspaceInSuper, subspaceInSuper.size(), - reversedArgumentOrder, builder); + reversedArgumentOrder, builder, equalTargetAndSourceSize); } return builder.build(); } - + + private void joinSubspaces(Iterator subspace, int subspaceSize, + Iterator> superspace, int superspaceSize, + boolean reversedArgumentOrder, IndexedTensor.Builder builder, boolean equalTargetAndSourceSize) { + int joinedLength = Math.min(subspaceSize, superspaceSize); + // This is inner loop and therefore it is suplicated four times to move checks out of it + if (equalTargetAndSourceSize) { // we can write cells without recomputing the address index + if (reversedArgumentOrder) { + for (int i = 0; i < joinedLength; i++) { + Map.Entry supercell = superspace.next(); + builder.cellWithInternalIndex(supercell.getKey().valueIndex(), combinator.applyAsDouble(supercell.getValue(), subspace.next())); + } + } else { + for (int i = 0; i < joinedLength; i++) { + Map.Entry supercell = superspace.next(); + builder.cellWithInternalIndex(supercell.getKey().valueIndex(), combinator.applyAsDouble(subspace.next(), supercell.getValue())); + } + } + } + else { + if (reversedArgumentOrder) { + for (int i = 0; i < joinedLength; i++) { + Map.Entry supercell = superspace.next(); + builder.cell(supercell.getKey(), combinator.applyAsDouble(supercell.getValue(), subspace.next())); + } + } else { + for (int i = 0; i < joinedLength; i++) { + Map.Entry supercell = superspace.next(); + builder.cell(supercell.getKey(), combinator.applyAsDouble(subspace.next(), supercell.getValue())); + } + } + } + } + private int[] joinedSize(TensorType joinedType, IndexedTensor subspace, IndexedTensor superspace) { int[] joinedSizes = new int[joinedType.dimensions().size()]; for (int i = 0; i < joinedSizes.length; i++) { @@ -166,24 +200,6 @@ public class Join extends PrimitiveTensorFunction { return builder.build(); } - private void joinSubspaces(Iterator subspace, int subspaceSize, - Iterator> superspace, int superspaceSize, - boolean reversedArgumentOrder, Tensor.Builder builder) { - int joinedLength = Math.min(subspaceSize, superspaceSize); - if (reversedArgumentOrder) { - for (int i = 0; i < joinedLength; i++) { - Map.Entry supercell = superspace.next(); - builder.cell(supercell.getKey(), combinator.applyAsDouble(supercell.getValue(), subspace.next())); - } - } - else { - for (int i = 0; i < joinedLength; i++) { - Map.Entry supercell = superspace.next(); - builder.cell(supercell.getKey(), combinator.applyAsDouble(subspace.next(), supercell.getValue())); - } - } - } - /** Returns the indexes in the superspace type which should be retained to create the subspace type */ private int[] subspaceIndexes(TensorType supertype, TensorType subtype) { int[] subspaceIndexes = new int[subtype.dimensions().size()]; -- cgit v1.2.3