From 03bce1fe1a494f2ac9d4268d4c90b08011b3f600 Mon Sep 17 00:00:00 2001 From: gjoranv Date: Sun, 17 Dec 2017 21:44:49 +0100 Subject: Revert "Bratseth/tensorflow models" --- .../main/java/com/yahoo/tensor/DimensionSizes.java | 4 +- .../main/java/com/yahoo/tensor/IndexedTensor.java | 164 ++++++++++----------- .../main/java/com/yahoo/tensor/MappedTensor.java | 16 +- .../main/java/com/yahoo/tensor/MixedTensor.java | 6 +- .../src/main/java/com/yahoo/tensor/Tensor.java | 46 +++--- .../main/java/com/yahoo/tensor/TensorAddress.java | 24 +-- .../main/java/com/yahoo/tensor/TensorParser.java | 2 +- .../src/main/java/com/yahoo/tensor/TensorType.java | 46 +++--- .../yahoo/tensor/evaluation/EvaluationContext.java | 9 +- .../tensor/evaluation/MapEvaluationContext.java | 4 +- .../yahoo/tensor/evaluation/VariableTensor.java | 8 +- .../tensor/functions/CompositeTensorFunction.java | 2 +- .../java/com/yahoo/tensor/functions/Concat.java | 10 +- .../com/yahoo/tensor/functions/ConstantTensor.java | 6 +- .../main/java/com/yahoo/tensor/functions/Diag.java | 8 +- .../java/com/yahoo/tensor/functions/Generate.java | 12 +- .../main/java/com/yahoo/tensor/functions/Join.java | 54 ++----- .../main/java/com/yahoo/tensor/functions/Map.java | 3 - .../java/com/yahoo/tensor/functions/Matmul.java | 9 +- .../tensor/functions/PrimitiveTensorFunction.java | 4 +- .../java/com/yahoo/tensor/functions/Random.java | 6 +- .../java/com/yahoo/tensor/functions/Range.java | 8 +- .../java/com/yahoo/tensor/functions/Reduce.java | 41 ++---- .../java/com/yahoo/tensor/functions/Rename.java | 24 ++- .../yahoo/tensor/functions/ScalarFunctions.java | 97 ++++++------ .../java/com/yahoo/tensor/functions/Softmax.java | 6 - .../com/yahoo/tensor/functions/TensorFunction.java | 6 +- .../yahoo/tensor/serialization/BinaryFormat.java | 2 +- .../tensor/serialization/DenseBinaryFormat.java | 6 +- .../tensor/serialization/TypedBinaryFormat.java | 6 +- 30 files changed, 297 insertions(+), 342 deletions(-) (limited to 'vespajlib/src/main/java/com') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index f6237a1977a..00e106dd035 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -7,7 +7,7 @@ import java.util.Arrays; /** * The sizes of a set of dimensions. - * + * * @author bratseth */ @Beta @@ -48,7 +48,7 @@ public final class DimensionSizes { @Override public int hashCode() { return Arrays.hashCode(sizes); } - /** + /** * Builder of a set of dimension sizes. * Dimensions whose size is not set before building will get size 0. */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6b0d769de9f..c207dabca3a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -25,12 +25,12 @@ public class IndexedTensor implements Tensor { /** The prescribed and possibly abstract type this is an instance of */ private final TensorType type; - + /** The sizes of the dimensions of this in the order of the dimensions of the type */ private final DimensionSizes dimensionSizes; - + private final double[] values; - + private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) { this.type = type; this.dimensionSizes = dimensionSizes; @@ -43,8 +43,8 @@ public class IndexedTensor implements Tensor { } /** - * Returns an iterator over the cells of this. - * Cells are returned in order of increasing indexes in each dimension, increasing + * Returns an iterator over the cells of this. + * Cells are returned in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. */ @Override @@ -69,7 +69,7 @@ public class IndexedTensor implements Tensor { /** * Returns an iterator over the values of this. - * Values are returned in order of increasing indexes in each dimension, increasing + * Values are returned in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. */ @Override @@ -81,7 +81,7 @@ public class IndexedTensor implements Tensor { * Returns an iterator over value iterators where the outer iterator is over each unique value of the dimensions * given and the inner iterator is over each unique value of the rest of the dimensions, in the same order as * other iterator. - * + * * @param dimensions the names of the dimensions of the superspace * @param sizes the size of each dimension in the space we are returning values for, containing * one value per dimension of this tensor (in order). Each size may be the same or smaller @@ -96,9 +96,9 @@ public class IndexedTensor implements Tensor { return subspaceIterator(dimensions, dimensionSizes); } - /** + /** * Returns the value at the given indexes - * + * * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ @@ -119,7 +119,7 @@ public class IndexedTensor implements Tensor { } private double get(int valueIndex) { return values[valueIndex]; } - + private static int toValueIndex(int[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed if (indexes.length == 0) return 0; // for speed @@ -165,7 +165,7 @@ public class IndexedTensor implements Tensor { public Map cells() { if (dimensionSizes.dimensions() == 0) return Collections.singletonMap(TensorAddress.of(), values[0]); - + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); for (int i = 0; i < values.length; i++) { @@ -174,13 +174,13 @@ public class IndexedTensor implements Tensor { } return builder.build(); } - + @Override public int hashCode() { return Arrays.hashCode(values); } @Override public String toString() { return Tensor.toStandardString(this); } - + @Override public boolean equals(Object other) { if ( ! ( other instanceof Tensor)) return false; @@ -188,9 +188,9 @@ public class IndexedTensor implements Tensor { } public abstract static class Builder implements Tensor.Builder { - + final TensorType type; - + private Builder(TensorType type) { this.type = type; } @@ -202,7 +202,7 @@ public class IndexedTensor implements Tensor { return new UnboundBuilder(type); } - /** + /** * Create a builder with dimension size information for this instance. Must be one size entry per dimension, * and, agree with the type size information when specified in the type. * If sizes are completely specified in the type this size information is redundant. @@ -210,16 +210,16 @@ public class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes) { // validate if (sizes.dimensions() != type.dimensions().size()) - throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + + throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + "for " + type); for (int i = 0; i < sizes.dimensions(); i++ ) { Optional size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) - throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + + throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + sizes.size(i) + " but cannot be larger than " + size.get() + " in " + type); } - + return new BoundBuilder(type, sizes); } @@ -232,7 +232,7 @@ public class IndexedTensor implements Tensor { public abstract IndexedTensor build(); } - + /** A bound builder can create the double array directly */ public static class BoundBuilder extends Builder { @@ -257,13 +257,13 @@ public class IndexedTensor implements Tensor { this.sizes = sizes; values = new double[sizes.totalSize()]; } - + @Override public BoundBuilder cell(double value, int ... indexes) { values[toValueIndex(indexes, sizes)] = value; return this; } - + @Override public CellBuilder cell() { return new CellBuilder(type, this); @@ -294,8 +294,8 @@ public class IndexedTensor implements Tensor { return this; } - /** - * Set a cell value by the index in the internal layout of this cell. + /** + * Set a cell value by the index in the internal layout of this cell. * This requires knowledge of the internal layout of cells in this implementation, and should therefore * probably not be used (but when it can be used it is fast). */ @@ -330,7 +330,7 @@ public class IndexedTensor implements Tensor { fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } - + private DimensionSizes findDimensionSizes(List firstDimension) { List dimensionSizeList = new ArrayList<>(type.dimensions().size()); findDimensionSizes(0, dimensionSizeList, firstDimension); @@ -347,16 +347,16 @@ public class IndexedTensor implements Tensor { if (currentDimensionIndex == dimensionSizes.size()) dimensionSizes.add(currentDimension.size()); else if (dimensionSizes.get(currentDimensionIndex) != currentDimension.size()) - throw new IllegalArgumentException("Missing values in dimension " + + throw new IllegalArgumentException("Missing values in dimension " + type.dimensions().get(currentDimensionIndex) + " in " + type); - + for (Object value : currentDimension) if (value instanceof List) findDimensionSizes(currentDimensionIndex + 1, dimensionSizes, (List)value); } @SuppressWarnings("unchecked") - private void fillValues(int currentDimensionIndex, int offset, List currentDimension, + private void fillValues(int currentDimensionIndex, int offset, List currentDimension, DimensionSizes sizes, double[] values) { if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension for (int i = 0; i < currentDimension.size(); i++) @@ -369,7 +369,7 @@ public class IndexedTensor implements Tensor { } } } - + private double nullAsZero(Double value) { if (value == null) return 0; return value; @@ -431,7 +431,7 @@ public class IndexedTensor implements Tensor { } } - + private final class CellIterator implements Iterator { private int count = 0; @@ -451,7 +451,7 @@ public class IndexedTensor implements Tensor { reusedCell.value = get(indexes.toSourceValueIndex()); return reusedCell; } - + } private final class ValueIterator implements Iterator { @@ -474,25 +474,25 @@ public class IndexedTensor implements Tensor { } } - + private final class SuperspaceIterator implements Iterator { private final Indexes superindexes; /** Those indexes this should iterate over */ private final List subdimensionIndexes; - - /** + + /** * The sizes of the space we'll return values of, one value for each dimension of this tensor, - * which may be equal to or smaller than the sizes of this tensor + * which may be equal to or smaller than the sizes of this tensor */ private final DimensionSizes iterateSizes; private int count = 0; - + private SuperspaceIterator(Set superdimensionNames, DimensionSizes iterateSizes) { this.iterateSizes = iterateSizes; - + List superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length) for (int i = type.dimensions().size() - 1; i >= 0; i-- ) { // iterate inner dimensions first @@ -501,10 +501,10 @@ public class IndexedTensor implements Tensor { else subdimensionIndexes.add(i); } - + superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, superdimensionIndexes); } - + @Override public boolean hasNext() { return count < superindexes.size(); @@ -527,7 +527,7 @@ public class IndexedTensor implements Tensor { */ public final class SubspaceIterator implements Iterator { - /** + /** * This iterator will iterate over the given dimensions, in the order given * (the first dimension index given is incremented to exhaustion first (i.e is etc.). * This may be any subset of the dimensions given by address and dimensionSizes. @@ -538,21 +538,21 @@ public class IndexedTensor implements Tensor { private Indexes indexes; private int count = 0; - + /** A lazy cell for reuse */ private final LazyCell reusedCell; - - /** + + /** * Creates a new subspace iterator - * + * * @param iterateDimensions the dimensions to iterate over, given as indexes in the dimension order of the * type of the tensor this iterates over. This iterator will iterate over these - * dimensions to exhaustion in the order given (the first dimension index given is + * dimensions to exhaustion in the order given (the first dimension index given is * incremented to exhaustion first (i.e is etc.), while other dimensions will be held * at a constant position. * This may be any subset of the dimensions given by address and dimensionSizes. * This is treated as immutable. - * @param address the address of the first cell of this subspace. + * @param address the address of the first cell of this subspace. */ private SubspaceIterator(List iterateDimensions, int[] address, DimensionSizes iterateSizes) { this.iterateDimensions = iterateDimensions; @@ -561,26 +561,26 @@ public class IndexedTensor implements Tensor { this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); reusedCell = new LazyCell(indexes, Double.NaN); } - + /** Returns the total number of cells in this subspace */ - public int size() { + public int size() { return indexes.size(); } - + /** Returns the address of the cell this currently points to (which may be an invalid position) */ public TensorAddress address() { return indexes.toAddress(); } - + /** Rewind this iterator to the first element */ - public void reset() { + public void reset() { this.count = 0; - this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); + this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address); } - + @Override public boolean hasNext() { - return count < indexes.size(); + return count < indexes.size(); } - + /** Returns the next cell, which is valid until next() is called again */ @Override public Cell next() { @@ -611,15 +611,15 @@ public class IndexedTensor implements Tensor { public TensorAddress getKey() { return indexes.toAddress(); } - + @Override public Double getValue() { return value; } } // TODO: Make dimensionSizes a class - - /** + + /** * An array of indexes into this tensor which are able to find the next index in the value order. * next() can be called once per element in the dimensions we iterate over. It must be called once * before accessing the first position. @@ -631,7 +631,7 @@ public class IndexedTensor implements Tensor { private final DimensionSizes iterationSizes; protected final int[] indexes; - + public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); } @@ -676,14 +676,14 @@ public class IndexedTensor implements Tensor { return new MultiDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, size); } } - + private static List completeIterationOrder(int length) { List iterationDimensions = new ArrayList<>(length); for (int i = 0; i < length; i++) iterationDimensions.add(length - 1 - i); return iterationDimensions; } - + private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) { this.sourceSizes = sourceSizes; this.iterationSizes = iterationSizes; @@ -708,9 +708,9 @@ public class IndexedTensor implements Tensor { /** Returns a copy of the indexes of this which must not be modified */ public int[] indexesForReading() { return indexes; } - - int toSourceValueIndex() { - return IndexedTensor.toValueIndex(indexes, sourceSizes); + + int toSourceValueIndex() { + return IndexedTensor.toValueIndex(indexes, sourceSizes); } int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); } @@ -729,9 +729,9 @@ public class IndexedTensor implements Tensor { public String toString() { return "indexes " + Arrays.toString(indexes); } - + public abstract int size(); - + public abstract void next(); } @@ -763,18 +763,18 @@ public class IndexedTensor implements Tensor { public void next() {} } - + private static class MultiDimensionIndexes extends Indexes { private final int size; private final List iterateDimensions; - + private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List iterateDimensions, int[] initialIndexes, int size) { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimensions = iterateDimensions; this.size = size; - + // Initialize to the (virtual) position before the first cell indexes[iterateDimensions.get(0)]--; } @@ -785,10 +785,10 @@ public class IndexedTensor implements Tensor { return size; } - /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. - * + /** + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. + * * @throws RuntimeException if this is called more times than its size */ @Override @@ -802,12 +802,12 @@ public class IndexedTensor implements Tensor { } } - + /** In this case we can reuse the source index computation for the iteration index */ private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes { private int lastComputedSourceValueIndex = -1; - + private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List iterateDimensions, int[] initialIndexes, int size) { super(sizes, sizes, iterateDimensions, initialIndexes, size); } @@ -827,7 +827,7 @@ public class IndexedTensor implements Tensor { private final int size; private final int iterateDimension; - + /** Maintain this directly as an optimization for 1-d iteration */ private int currentSourceValueIndex, currentIterationValueIndex; @@ -847,7 +847,7 @@ public class IndexedTensor implements Tensor { currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceSizes); currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateSizes); } - + /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override public int size() { @@ -855,8 +855,8 @@ public class IndexedTensor implements Tensor { } /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. * * @throws RuntimeException if this is called more times than its size */ @@ -888,7 +888,7 @@ public class IndexedTensor implements Tensor { /** The iteration step in the value index space */ private final int step; - private EqualSizeSingleDimensionIndexes(DimensionSizes sizes, + private EqualSizeSingleDimensionIndexes(DimensionSizes sizes, int iterateDimension, int[] initialIndexes, int size) { super(sizes, sizes, initialIndexes); this.iterateDimension = iterateDimension; @@ -907,8 +907,8 @@ public class IndexedTensor implements Tensor { } /** - * Advances this to the next cell in the standard indexed tensor cell order. - * The first call to this will put it at the first position. + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. * * @throws RuntimeException if this is called more times than its size */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index aba61478e69..618bff0caae 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -27,7 +27,7 @@ public class MappedTensor implements Tensor { @Override public TensorType type() { return type; } - + @Override public int size() { return cells.size(); } @@ -56,16 +56,16 @@ public class MappedTensor implements Tensor { } public static class Builder implements Tensor.Builder { - + private final TensorType type; private final ImmutableMap.Builder cells = new ImmutableMap.Builder<>(); - + public static Builder of(TensorType type) { return new Builder(type); } private Builder(TensorType type) { this.type = type; } - + public CellBuilder cell() { return new CellBuilder(type, this); } @@ -89,24 +89,24 @@ public class MappedTensor implements Tensor { public MappedTensor build() { return new MappedTensor(type, cells.build()); } - + } private static class CellIteratorAdaptor implements Iterator { private final Iterator> adaptedIterator; - + private CellIteratorAdaptor(Iterator> adaptedIterator) { this.adaptedIterator = adaptedIterator; } - + @Override public boolean hasNext() { return adaptedIterator.hasNext(); } @Override public Cell next() { Map.Entry entry = adaptedIterator.next(); - return new Cell(entry.getKey(), entry.getValue()); + return new Cell(entry.getKey(), entry.getValue()); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 9a751e078e0..79bb27fcd1b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -117,7 +117,7 @@ public class MixedTensor implements Tensor { return index.denseSubspaceSize(); } - + /** * Base class for building mixed tensors. */ @@ -286,7 +286,7 @@ public class MixedTensor implements Tensor { } return typeBuilder.build(); } - + } /** @@ -360,7 +360,7 @@ public class MixedTensor implements Tensor { } return denseSubspaceSize; } - + private TensorAddress sparsePartialAddress(TensorAddress address) { if (type.dimensions().size() != address.size()) { throw new IllegalArgumentException("Tensor type and address are not of same size."); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 1b60e01cf7e..2ed211539d8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -52,7 +52,7 @@ import java.util.function.Function; public interface Tensor { // ----------------- Accessors - + TensorType type(); /** Returns whether this have any cells */ @@ -70,13 +70,13 @@ public interface Tensor { /** Returns the values of this in some undefined order */ Iterator valueIterator(); - /** + /** * Returns an immutable map of the cells of this in no particular order. - * This may be expensive for some implementations - avoid when possible + * This may be expensive for some implementations - avoid when possible */ Map cells(); - /** + /** * Returns the value of this as a double if it has no dimensions and one value * * @throws IllegalStateException if this does not have zero dimensions and one value @@ -87,9 +87,9 @@ public interface Tensor { if (size() == 0) return Double.NaN; return valueIterator().next(); } - + // ----------------- Primitive tensor functions - + default Tensor map(DoubleUnaryOperator mapper) { return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate(); } @@ -108,7 +108,7 @@ public interface Tensor { } default Tensor rename(String fromDimension, String toDimension) { - return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), + return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), Collections.singletonList(toDimension)).evaluate(); } @@ -123,13 +123,13 @@ public interface Tensor { default Tensor rename(List fromDimensions, List toDimensions) { return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); } - + static Tensor generate(TensorType type, Function, Double> valueSupplier) { return new Generate(type, valueSupplier).evaluate(); } - + // ----------------- Composite tensor functions which have a defined primitive mapping - + default Tensor l1Normalize(String dimension) { return new L1Normalize(new ConstantTensor(this), dimension).evaluate(); } @@ -231,7 +231,7 @@ public interface Tensor { if (cellEntries.isEmpty()) return "{}"; return "{" + cellEntries.get(0).getValue() +"}"; } - + Collections.sort(cellEntries, java.util.Map.Entry.comparingByKey()); StringBuilder b = new StringBuilder("{"); @@ -253,7 +253,7 @@ public interface Tensor { */ boolean equals(Object o); - /** + /** * Implement here to make this work across implementations. * Implementations must override equals and call this because this is an interface and cannot override equals. */ @@ -328,13 +328,13 @@ public interface Tensor { @Override public TensorAddress getKey() { return address; } - /** + /** * Returns the direct index which can be used to locate this cell, or -1 if not available. * This is for optimizations mapping between tensors where this is possible without creating a * TensorAddress. */ int getDirectIndex() { return -1; } - + @Override public Double getValue() { return value; } @@ -388,20 +388,20 @@ public interface Tensor { /** Returns the type this is building */ TensorType type(); - + /** Return a cell builder */ CellBuilder cell(); /** Add a cell */ Builder cell(TensorAddress address, double value); - + /** Add a cell */ Builder cell(double value, int ... labels); - /** - * Add a cell - * - * @param cell a cell providing the location at which to add this cell + /** + * Add a cell + * + * @param cell a cell providing the location at which to add this cell * @param value the value to assign to the cell */ default Builder cell(Cell cell, double value) { @@ -409,12 +409,12 @@ public interface Tensor { } Tensor build(); - + class CellBuilder { private final TensorAddress.Builder addressBuilder; private final Tensor.Builder tensorBuilder; - + CellBuilder(TensorType type, Tensor.Builder tensorBuilder) { addressBuilder = new TensorAddress.Builder(type); this.tensorBuilder = tensorBuilder; @@ -436,5 +436,5 @@ public interface Tensor { } } - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index ff1202463f2..7161450d5d5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -32,10 +32,10 @@ public abstract class TensorAddress implements Comparable { /** Returns the number of labels in this */ public abstract int size(); - + /** - * Returns the i'th label in this - * + * Returns the i'th label in this + * * @throws IllegalArgumentException if there is no label at this index */ public abstract String label(int i); @@ -102,23 +102,23 @@ public abstract class TensorAddress implements Comparable { private StringTensorAddress(String ... labels) { this.labels = Arrays.copyOf(labels, labels.length); } - + @Override public int size() { return labels.length; } - + @Override public String label(int i) { return labels[i]; } - + @Override - public int intLabel(int i) { + public int intLabel(int i) { try { return Integer.parseInt(labels[i]); - } + } catch (NumberFormatException e) { throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i); } } - + @Override public TensorAddress withLabel(int index, int label) { String[] labels = Arrays.copyOf(this.labels, this.labels.length); @@ -169,7 +169,7 @@ public abstract class TensorAddress implements Comparable { private final TensorType type; private final String[] labels; - + public Builder(TensorType type) { this(type, new String[type.dimensions().size()]); } @@ -193,7 +193,7 @@ public abstract class TensorAddress implements Comparable { labels[labelIndex.get()] = label; return this; } - + /** Creates a copy of this which can be modified separately */ public Builder copy() { return new Builder(type, Arrays.copyOf(labels, labels.length)); @@ -202,7 +202,7 @@ public abstract class TensorAddress implements Comparable { public TensorAddress build() { for (int i = 0; i < labels.length; i++) if (labels[i] == null) - throw new IllegalArgumentException("Missing a value for dimension " + + throw new IllegalArgumentException("Missing a value for dimension " + type.dimensions().get(i).name() + " for " + type); return TensorAddress.of(labels); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 9b3a9328f07..da8ab3bb0ec 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -96,7 +96,7 @@ class TensorParser { if (valueEnd < 0) throw new IllegalArgumentException("A tensor string must end by '}'"); } - + TensorAddress address = addressBuilder.build(); Double value = asDouble(address, s.substring(0, valueEnd).trim()); builder.cell(address, value); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 914d853aeca..c05c35d6df3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -53,17 +53,14 @@ public class TensorType { return TensorTypeParser.fromSpec(specString); } - /** Returns the number of dimensions of this: dimensions().size() */ - public int rank() { return dimensions.size(); } - /** Returns an immutable list of the dimensions of this */ public List dimensions() { return dimensions; } - + /** Returns an immutable set of the names of the dimensions of this */ public Set dimensionNames() { return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); } - + /** Returns the dimension with this name, or empty if not present */ public Optional dimension(String name) { return indexOfDimension(name).map(i -> dimensions.get(i)); @@ -77,7 +74,7 @@ public class TensorType { return Optional.empty(); } - /** + /** * Returns whether this type can be assigned to the given type, * i.e if the given type is a generalization of this type. */ @@ -131,9 +128,9 @@ public class TensorType { private final String name; - private Dimension(String name) { + private Dimension(String name) { Objects.requireNonNull(name, "A tensor name cannot be null"); - this.name = name; + this.name = name; } public final String name() { return name; } @@ -149,7 +146,7 @@ public class TensorType { /** Returns true if this is an indexed bound or unboun type */ public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; } - /** + /** * Returns the dimension resulting from combining two dimensions having the same name but possibly different * types. This works by degrading to the type making the fewer promises. * [N] + [M] = [min(N, M)] @@ -168,7 +165,7 @@ public class TensorType { IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get(); return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; } - + @Override public abstract String toString(); @@ -178,21 +175,21 @@ public class TensorType { if (other == null || getClass() != other.getClass()) return false; return name.equals(((Dimension)other).name); } - + @Override public int hashCode() { return name.hashCode(); } - + @Override public int compareTo(Dimension other) { return this.name.compareTo(other.name); } - + public static Dimension indexed(String name, int size) { return new IndexedBoundDimension(name, size); } - + } public static class IndexedBoundDimension extends TensorType.Dimension { @@ -292,9 +289,9 @@ public class TensorType { public Builder() { } - /** - * Creates a builder containing a combination of the dimensions of the given types - * + /** + * Creates a builder containing a combination of the dimensions of the given types + * * If the same dimension is indexed with different size restrictions the largest size will be used. * If it is size restricted in one argument but not the other it will not be size restricted. * If it is indexed in one and mapped in the other it will become mapped. @@ -328,12 +325,9 @@ public class TensorType { } } - /** Returns the current number of dimensions in this */ - public int rank() { return dimensions.size(); } - - /** + /** * Adds a new dimension to this - * + * * @throws IllegalArgumentException if the dimension is already present */ private Builder add(Dimension dimension) { @@ -352,7 +346,7 @@ public class TensorType { return this; } - /** + /** * Adds a bound indexed dimension to this * * @throws IllegalArgumentException if the dimension is already present @@ -361,7 +355,7 @@ public class TensorType { /** * Adds an unbound indexed dimension to this - * + * * @throws IllegalArgumentException if the dimension is already present */ public Builder indexed(String name) { @@ -381,7 +375,7 @@ public class TensorType { public Builder dimension(Dimension dimension) { return add(dimension); } - + /** Returns the given dimension, or empty if none is present */ public Optional getDimension(String dimension) { return Optional.ofNullable(dimensions.get(dimension)); @@ -399,7 +393,7 @@ public class TensorType { public TensorType build() { return new TensorType(dimensions.values()); } - + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 3db661f8a23..84caca78fb2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -2,17 +2,16 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; -import com.yahoo.tensor.Tensor; + +import java.util.HashMap; /** * An evaluation context which is passed down to all nested functions during evaluation. - * + * The default context is empty to allow various evaluation frameworks to support their own implementation. + * * @author bratseth */ @Beta public interface EvaluationContext { - /** Returns the tensor bound to this name, or null if none */ - Tensor getTensor(String name); - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index db8a66a5fa2..cf704c15f4f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -18,7 +18,7 @@ public class MapEvaluationContext implements EvaluationContext { public void put(String name, Tensor tensor) { bindings.put(name, tensor); } - @Override - public Tensor getTensor(String name) { return bindings.get(name); } + /** Returns the tensor bound to this name, or null if none */ + public Tensor get(String name) { return bindings.get(name); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 1f6ad050368..8ade181bdb7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -12,18 +12,18 @@ import java.util.List; /** * A tensor variable name which resolves to a tensor in the context at evaluation time - * + * * @author bratseth */ @Beta public class VariableTensor extends PrimitiveTensorFunction { private final String name; - + public VariableTensor(String name) { this.name = name; } - + @Override public List functionArguments() { return Collections.emptyList(); } @@ -35,7 +35,7 @@ public class VariableTensor extends PrimitiveTensorFunction { @Override public Tensor evaluate(EvaluationContext context) { - return context.getTensor(name); + return ((MapEvaluationContext)context).get(name); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 191c7988443..8f4dbf014a7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -8,7 +8,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; /** * A composite tensor function is a tensor function which can be expressed (less tersely) * as a tree of primitive tensor functions. - * + * * @author bratseth */ @Beta diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index faa0ca36cb6..1dbb94fdb20 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -15,7 +15,7 @@ import java.util.stream.Collectors; /** * Concatenation of two tensors along an (indexed) dimension - * + * * @author bratseth */ @Beta @@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction { concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); return builder.build(); } - + private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { Set otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); @@ -112,7 +112,7 @@ public class Concat extends PrimitiveTensorFunction { Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); return tensor.multiply(unitTensor); } - + } /** Returns the type resulting from concatenating a and b */ @@ -144,7 +144,7 @@ public class Concat extends PrimitiveTensorFunction { /** * Combine two addresses, adding the offset to the concat dimension * - * @return the combined address or null if the addresses are incompatible + * @return the combined address or null if the addresses are incompatible * (in some other dimension than the concat dimension) */ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, @@ -161,7 +161,7 @@ public class Concat extends PrimitiveTensorFunction { /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then + * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 14ed38718ce..4ac7b21ba90 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -10,18 +10,18 @@ import java.util.List; /** * A function which returns a constant tensor. - * + * * @author bratseth */ @Beta public class ConstantTensor extends PrimitiveTensorFunction { private final Tensor constant; - + public ConstantTensor(String tensorString) { this.constant = Tensor.from(tensorString); } - + public ConstantTensor(Tensor tensor) { this.constant = tensor; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index c75d8ee4753..bbdbd5c3df1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -11,19 +11,19 @@ import java.util.stream.Stream; /** * A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere. - * + * * @author bratseth */ public class Diag extends CompositeTensorFunction { private final TensorType type; private final Function, Double> diagFunction; - + public Diag(TensorType type) { this.type = type; this.diagFunction = ScalarFunctions.equal(dimensionNames().collect(Collectors.toList())); } - + @Override public List functionArguments() { return Collections.emptyList(); } @@ -43,7 +43,7 @@ public class Diag extends CompositeTensorFunction { public String toString(ToStringContext context) { return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction; } - + private Stream dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::name); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index e42d25197e2..6ea73b7f310 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -15,7 +15,7 @@ import java.util.function.Function; /** * An indexed tensor whose values are generated by a function - * + * * @author bratseth */ @Beta @@ -26,7 +26,7 @@ public class Generate extends PrimitiveTensorFunction { /** * Creates a generated tensor - * + * * @param type the type of the tensor * @param generator the function generating values from a list of ints specifying the indexes of the * tensor cell which will receive the value @@ -39,7 +39,7 @@ public class Generate extends PrimitiveTensorFunction { this.type = type; this.generator = generator; } - + private void validateType(TensorType type) { for (TensorType.Dimension dimension : type.dimensions()) if (dimension.type() != TensorType.Dimension.Type.indexedBound) @@ -58,7 +58,7 @@ public class Generate extends PrimitiveTensorFunction { @Override public PrimitiveTensorFunction toPrimitive() { return this; } - + @Override public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type); @@ -69,14 +69,14 @@ public class Generate extends PrimitiveTensorFunction { } return builder.build(); } - + private DimensionSizes dimensionSizes(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); for (int i = 0; i < b.dimensions(); i++) b.set(i, type.dimensions().get(i).size().get()); return b.build(); } - + @Override public String toString(ToStringContext context) { return type + "(" + generator + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index ff887e3e9a6..8c4dbfb0acb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator; * The join tensor operation produces a tensor from the argument tensors containing the set of cells * given by the cross product of the cells of the given tensors, having as values the value produced by * applying the given combinator function on the values from the two source cells. - * + * * @author bratseth */ @Beta public class Join extends PrimitiveTensorFunction { - + private final TensorFunction argumentA, argumentB; private final DoubleBinaryOperator combinator; @@ -46,30 +46,6 @@ public class Join extends PrimitiveTensorFunction { this.combinator = combinator; } - /** Returns the type resulting from applying Join to the two given types */ - public static TensorType outputType(TensorType a, TensorType b) { - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (int i = 0; i < a.dimensions().size(); ++i) { - TensorType.Dimension aDim = a.dimensions().get(i); - for (int j = 0; j < b.dimensions().size(); ++j) { - TensorType.Dimension bDim = b.dimensions().get(j); - if (aDim.name().equals(bDim.name())) { // include - if (aDim.isIndexed() && bDim.isIndexed()) { - if (aDim.size().isPresent() || bDim.size().isPresent()) - typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), - bDim.size().orElse(Integer.MAX_VALUE))); - else - typeBuilder.indexed(aDim.name()); - } - else { - typeBuilder.mapped(aDim.name()); - } - } - } - } - return typeBuilder.build(); - } - public TensorFunction argumentA() { return argumentA; } public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } @@ -112,11 +88,11 @@ public class Join extends PrimitiveTensorFunction { else return generalJoin(a, b, joinedType); } - + private boolean hasSingleIndexedDimension(Tensor tensor) { return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); } - + private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator aIterator = a.valueIterator(); @@ -138,7 +114,7 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } - + /** Join a tensor into a superspace */ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) @@ -150,7 +126,7 @@ public class Join extends PrimitiveTensorFunction { private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); - + DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace); IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes); @@ -158,14 +134,14 @@ public class Join extends PrimitiveTensorFunction { // Find dimensions which are only in the supertype Set superDimensionNames = new HashSet<>(superspace.type().dimensionNames()); superDimensionNames.removeAll(subspace.type().dimensionNames()); - + for (Iterator i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) { IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); joinSubspaces(subspace.valueIterator(), subspace.size(), subspaceInSuper, subspaceInSuper.size(), reversedArgumentOrder, builder); } - + return builder.build(); } @@ -224,7 +200,7 @@ public class Join extends PrimitiveTensorFunction { subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get(); return subspaceIndexes; } - + private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { String[] subspaceLabels = new String[subspaceIndexes.length]; for (int i = 0; i < subspaceIndexes.length; i++) @@ -259,7 +235,7 @@ public class Join extends PrimitiveTensorFunction { DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); // for each combination of dimensions only in a - for (Iterator ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { + for (Iterator ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { IndexedTensor.SubspaceIterator aSubspace = ia.next(); // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { @@ -276,7 +252,7 @@ public class Join extends PrimitiveTensorFunction { } } } - + private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set retainDimensions) { PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); for (int i = 0; i < addressType.dimensions().size(); i++) @@ -284,7 +260,7 @@ public class Join extends PrimitiveTensorFunction { builder.add(addressType.dimensions().get(i).name(), address.intLabel(i)); return builder.build(); } - + /** Returns the sizes from the joined sizes which are present in the type argument */ private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); @@ -295,7 +271,7 @@ public class Join extends PrimitiveTensorFunction { } return builder.build(); } - + private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); @@ -364,7 +340,7 @@ public class Join extends PrimitiveTensorFunction { /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then + * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ @@ -384,7 +360,7 @@ public class Join extends PrimitiveTensorFunction { return TensorAddress.of(joinedLabels); } - /** + /** * Maps the content in the given list to the given array, using the given index map. * * @return true if the mapping was successful, false if one of the destination positions was diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index a5e1a016a41..a9872bb42d8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Collections; @@ -33,8 +32,6 @@ public class Map extends PrimitiveTensorFunction { this.mapper = mapper; } - public static TensorType outputType(TensorType inputType) { return inputType; } - public TensorFunction argument() { return argument; } public DoubleUnaryOperator mapper() { return mapper; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 4071917c2b5..bb27e937699 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -3,7 +3,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; -import com.yahoo.tensor.TensorType; import java.util.List; @@ -15,17 +14,13 @@ public class Matmul extends CompositeTensorFunction { private final TensorFunction argument1, argument2; private final String dimension; - + public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { this.argument1 = argument1; this.argument2 = argument2; this.dimension = dimension; } - public static TensorType outputType(TensorType a, TensorType b, String dimension) { - return Join.outputType(a, b); - } - @Override public List functionArguments() { return ImmutableList.of(argument1, argument2); } @@ -44,7 +39,7 @@ public class Matmul extends CompositeTensorFunction { Reduce.Aggregator.sum, dimension); } - + @Override public String toString(ToStringContext context) { return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java index b7c9a5d2342..efb7b9e500c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java @@ -8,10 +8,10 @@ import com.yahoo.tensor.Tensor; * A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions. * All tensor implementations must implement all primitive tensor functions. * Primitive tensor functions are fully inspectable. - * + * * @author bratseth */ @Beta public abstract class PrimitiveTensorFunction extends TensorFunction { - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 958ef85d1dc..457763e97ba 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -22,11 +22,11 @@ import java.util.stream.Stream; public class Random extends CompositeTensorFunction { private final TensorType type; - + public Random(TensorType type) { this.type = type; } - + @Override public List functionArguments() { return Collections.emptyList(); } @@ -46,7 +46,7 @@ public class Random extends CompositeTensorFunction { public String toString(ToStringContext context) { return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; } - + private Stream dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index a56f82b026a..e2b39a2048d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -12,19 +12,19 @@ import java.util.stream.Stream; /** * A tensor generator which returns a tensor of any dimension filled with the sum of the tensor * indexes of each position. - * + * * @author bratseth */ public class Range extends CompositeTensorFunction { private final TensorType type; private final Function, Double> rangeFunction; - + public Range(TensorType type) { this.type = type; this.rangeFunction = ScalarFunctions.sum(dimensionNames().collect(Collectors.toList())); } - + @Override public List functionArguments() { return Collections.emptyList(); } @@ -44,7 +44,7 @@ public class Range extends CompositeTensorFunction { public String toString(ToStringContext context) { return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction; } - + private Stream dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index de9f90a5804..cfc78be7e0c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -19,7 +19,7 @@ import java.util.Objects; import java.util.Set; /** - * The reduce tensor operation returns a tensor produced from the argument tensor where some dimensions + * The reduce tensor operation returns a tensor produced from the argument tensor where some dimensions * are collapsed to a single value using an aggregator function. * * @author bratseth @@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction { /** * Creates a reduce function. - * + * * @param argument the tensor to reduce * @param aggregator the aggregator function to use * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced, @@ -61,15 +61,6 @@ public class Reduce extends PrimitiveTensorFunction { this.dimensions = ImmutableList.copyOf(dimensions); } - public static TensorType outputType(TensorType inputType, List reduceDimensions) { - TensorType.Builder b = new TensorType.Builder(); - for (TensorType.Dimension dimension : inputType.dimensions()) { - if ( ! reduceDimensions.contains(dimension.name())) - b.dimension(dimension); - } - return b.build(); - } - public TensorFunction argument() { return argument; } @Override @@ -91,7 +82,7 @@ public class Reduce extends PrimitiveTensorFunction { public String toString(ToStringContext context) { return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } - + private String commaSeparated(List list) { StringBuilder b = new StringBuilder(); for (String element : list) @@ -103,7 +94,7 @@ public class Reduce extends PrimitiveTensorFunction { public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) - throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + + throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor"); // Special case: Reduce all @@ -112,14 +103,14 @@ public class Reduce extends PrimitiveTensorFunction { return reduceIndexedVector((IndexedTensor)argument); else return reduceAllGeneral(argument); - + // Reduce type TensorType.Builder builder = new TensorType.Builder(); for (TensorType.Dimension dimension : argument.type().dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); TensorType reducedType = builder.build(); - + // Reduce cells Map aggregatingCells = new HashMap<>(); for (Iterator i = argument.cellIterator(); i.hasNext(); ) { @@ -131,10 +122,10 @@ public class Reduce extends PrimitiveTensorFunction { Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); for (Map.Entry aggregatingCell : aggregatingCells.entrySet()) reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); - + return reducedBuilder.build(); } - + private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) { Set indexesToRemove = new HashSet<>(); for (String dimensionToRemove : this.dimensions) @@ -147,7 +138,7 @@ public class Reduce extends PrimitiveTensorFunction { reducedLabels[reducedLabelIndex++] = address.label(i); return TensorAddress.of(reducedLabels); } - + private Tensor reduceAllGeneral(Tensor argument) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator i = argument.valueIterator(); i.hasNext(); ) @@ -163,7 +154,7 @@ public class Reduce extends PrimitiveTensorFunction { } private static abstract class ValueAggregator { - + private static ValueAggregator ofType(Aggregator aggregator) { switch (aggregator) { case avg : return new AvgAggregator(); @@ -174,22 +165,22 @@ public class Reduce extends PrimitiveTensorFunction { case min : return new MinAggregator(); default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented"); } - + } /** Add a new value to those aggregated by this */ public abstract void aggregate(double value); - + /** Returns the value aggregated by this */ public abstract double aggregatedValue(); - + } - + private static class AvgAggregator extends ValueAggregator { private int valueCount = 0; private double valueSum = 0.0; - + @Override public void aggregate(double value) { valueCount++; @@ -197,7 +188,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public double aggregatedValue() { + public double aggregatedValue() { return valueSum / valueCount; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index ec9b762a41c..6b0daf1b49a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -3,6 +3,8 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -17,7 +19,7 @@ import java.util.Objects; /** * The rename tensor function returns a tensor where some dimensions are assigned new names. - * + * * @author bratseth */ @Beta @@ -27,10 +29,6 @@ public class Rename extends PrimitiveTensorFunction { private final List fromDimensions; private final List toDimensions; - public Rename(TensorFunction argument, String fromDimension, String toDimension) { - this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); - } - public Rename(TensorFunction argument, List fromDimensions, List toDimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); @@ -44,7 +42,7 @@ public class Rename extends PrimitiveTensorFunction { this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); } - + @Override public List functionArguments() { return Collections.singletonList(argument); } @@ -64,7 +62,7 @@ public class Rename extends PrimitiveTensorFunction { Map fromToMap = fromToMap(); TensorType renamedType = rename(tensor.type(), fromToMap); - + // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; for (int i = 0; i < tensor.type().dimensions().size(); i++) { @@ -72,7 +70,7 @@ public class Rename extends PrimitiveTensorFunction { String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName); toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get(); } - + Tensor.Builder builder = Tensor.Builder.of(renamedType); for (Iterator i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry cell = i.next(); @@ -88,7 +86,7 @@ public class Rename extends PrimitiveTensorFunction { builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); return builder.build(); } - + private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) @@ -97,18 +95,18 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { - return "rename(" + argument.toString(context) + ", " + + public String toString(ToStringContext context) { + return "rename(" + argument.toString(context) + ", " + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } - + private Map fromToMap() { Map map = new HashMap<>(); for (int i = 0; i < fromDimensions.size(); i++) map.put(fromDimensions.get(i), toDimensions.get(i)); return map; } - + private String toVectorString(List elements) { if (elements.size() == 1) return elements.get(0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index fb5029fbfd6..99f79cb735a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -21,87 +21,101 @@ import java.util.stream.Collectors; @Beta public class ScalarFunctions { - public static DoubleBinaryOperator add() { return new Add(); } - public static DoubleBinaryOperator divide() { return new Divide(); } + public static DoubleBinaryOperator add() { return new Addition(); } + public static DoubleBinaryOperator multiply() { return new Multiplication(); } + public static DoubleBinaryOperator divide() { return new Division(); } public static DoubleBinaryOperator equal() { return new Equal(); } - public static DoubleBinaryOperator multiply() { return new Multiply(); } - - public static DoubleUnaryOperator acos() { return new Acos(); } - public static DoubleUnaryOperator exp() { return new Exp(); } - public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator square() { return new Square(); } - + public static DoubleUnaryOperator sqrt() { return new Sqrt(); } + public static DoubleUnaryOperator exp() { return new Exponent(); } public static Function, Double> random() { return new Random(); } public static Function, Double> equal(List argumentNames) { return new EqualElements(argumentNames); } public static Function, Double> sum(List argumentNames) { return new SumElements(argumentNames); } - // Binary operators ----------------------------------------------------------------------------- + public static class Addition implements DoubleBinaryOperator { - public static class Add implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left + right; } + @Override public String toString() { return "f(a,b)(a + b)"; } - } - public static class Equal implements DoubleBinaryOperator { - @Override - public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; } - @Override - public String toString() { return "f(a,b)(a==b)"; } } - public static class Exp implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return Math.exp(operand); } - @Override - public String toString() { return "f(a)(exp(a))"; } - } + public static class Multiplication implements DoubleBinaryOperator { - public static class Multiply implements DoubleBinaryOperator { @Override - public double applyAsDouble(double left, double right) { return left * right; } + public double applyAsDouble(double left, double right) { return left * right; } + @Override public String toString() { return "f(a,b)(a * b)"; } + } - public static class Divide implements DoubleBinaryOperator { + public static class Division implements DoubleBinaryOperator { + @Override public double applyAsDouble(double left, double right) { return left / right; } + @Override public String toString() { return "f(a,b)(a / b)"; } } - // Unary operators ------------------------------------------------------------------------------ + public static class Equal implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; } + + @Override + public String toString() { return "f(a,b)(a==b)"; } + } + + public static class Square implements DoubleUnaryOperator { - public static class Acos implements DoubleUnaryOperator { @Override - public double applyAsDouble(double operand) { return Math.acos(operand); } + public double applyAsDouble(double operand) { return operand * operand; } + @Override - public String toString() { return "f(a)(acos(a))"; } + public String toString() { return "f(a)(a * a)"; } + } public static class Sqrt implements DoubleUnaryOperator { + @Override public double applyAsDouble(double operand) { return Math.sqrt(operand); } + @Override public String toString() { return "f(a)(sqrt(a))"; } + } - public static class Square implements DoubleUnaryOperator { + public static class Exponent implements DoubleUnaryOperator { @Override - public double applyAsDouble(double operand) { return operand * operand; } + public double applyAsDouble(double operand) { return Math.exp(operand); } @Override - public String toString() { return "f(a)(a * a)"; } + public String toString() { return "f(a)(exp(a))"; } } - // Variable-length operators ----------------------------------------------------------------------------- + public static class Random implements Function, Double> { + + @Override + public Double apply(List values) { + return ThreadLocalRandom.current().nextDouble(); + } + + @Override + public String toString() { return "random"; } - public static class EqualElements implements Function, Double> { - private final ImmutableList argumentNames; + } + + public static class EqualElements implements Function, Double> { + + private final ImmutableList argumentNames; + private EqualElements(List argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @@ -114,6 +128,7 @@ public class ScalarFunctions { return 0.0; return 1.0; } + @Override public String toString() { if (argumentNames.size() == 0) return "1"; @@ -128,19 +143,13 @@ public class ScalarFunctions { } return b.toString(); } - } - public static class Random implements Function, Double> { - @Override - public Double apply(List values) { - return ThreadLocalRandom.current().nextDouble(); - } - @Override - public String toString() { return "random"; } } public static class SumElements implements Function, Double> { + private final ImmutableList argumentNames; + private SumElements(List argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @@ -152,10 +161,12 @@ public class ScalarFunctions { sum += value; return (double)sum; } + @Override public String toString() { return argumentNames.stream().collect(Collectors.joining("+")); } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index c856b548180..bf279eb24d8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -2,8 +2,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableList; -import com.yahoo.tensor.TensorType; import java.util.Collections; import java.util.List; @@ -21,10 +19,6 @@ public class Softmax extends CompositeTensorFunction { this.argument = argument; this.dimension = dimension; } - - public static TensorType outputType(TensorType inputType, String dimension) { - return Reduce.outputType(inputType, ImmutableList.of(dimension)); - } @Override public List functionArguments() { return Collections.singletonList(argument); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 533a46f87fe..cabcce198d1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -12,7 +12,7 @@ import java.util.List; * A representation of a tensor function which is able to be translated to a set of primitive * tensor functions if necessary. * All tensor functions are immutable. - * + * * @author bratseth */ @Beta @@ -48,11 +48,11 @@ public abstract class TensorFunction { /** * Return a string representation of this context. - * + * * @param context a context which must be passed to all nexted functions when requesting the string value */ public abstract String toString(ToStringContext context); - + @Override public String toString() { return toString(ToStringContext.empty()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java index 416b28afa22..e8c425d49e0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java @@ -24,7 +24,7 @@ interface BinaryFormat { /** * Deserialize the given binary data into a Tensor object. - * + * * @param type the expected abstract type of the tensor to serialize, or empty to use type information from the data * @param buffer the buffer containing the tensor binary data */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index aabb53d1c67..8b7325ec211 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -16,9 +16,9 @@ import java.util.Optional; * * Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]* * Cell_values = [double, double, double, ...]* - * where values are encoded in order of increasing indexes in each dimension, increasing + * where values are encoded in order of increasing indexes in each dimension, increasing * indexes of later dimensions in the dimension type before earlier. - * + * * @author bratseth */ @Beta @@ -54,7 +54,7 @@ public class DenseBinaryFormat implements BinaryFormat { type = optionalType.get(); TensorType serializedType = decodeType(buffer); if ( ! serializedType.isAssignableTo(type)) - throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + + throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + " cannot be assigned to type " + type); sizes = sizesFromType(serializedType); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 01a1d023f2b..7467554790a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -46,9 +46,9 @@ public class TypedBinaryFormat { return result; } - /** - * Decode some data to a tensor - * + /** + * Decode some data to a tensor + * * @param type the type to decode and validate to, or empty to use the type given in the data * @param buffer the buffer containing the data, use GrowableByteByffer.wrap(byte[]) if you have a byte array * @return the resulting tensor -- cgit v1.2.3