summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2017-12-18 09:14:37 +0100
committerGitHub <noreply@github.com>2017-12-18 09:14:37 +0100
commit9347da6b81bd1f723d754fea2add617268ea90fa (patch)
tree6b6489e089b2ff9c5d67599d6be55d694c9ee99b /vespajlib
parentdbf5328bdc8daed3e4111742e2f6e0a48277e3d3 (diff)
Revert "Revert "Bratseth/tensorflow models""
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java164
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java46
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java46
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java54
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java41
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java97
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java14
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java14
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java10
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java97
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java2
39 files changed, 467 insertions, 325 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 00e106dd035..f6237a1977a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -7,7 +7,7 @@ import java.util.Arrays;
/**
* The sizes of a set of dimensions.
- *
+ *
* @author bratseth
*/
@Beta
@@ -48,7 +48,7 @@ public final class DimensionSizes {
@Override
public int hashCode() { return Arrays.hashCode(sizes); }
- /**
+ /**
* Builder of a set of dimension sizes.
* Dimensions whose size is not set before building will get size 0.
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index c207dabca3a..6b0d769de9f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -25,12 +25,12 @@ public class IndexedTensor implements Tensor {
/** The prescribed and possibly abstract type this is an instance of */
private final TensorType type;
-
+
/** The sizes of the dimensions of this in the order of the dimensions of the type */
private final DimensionSizes dimensionSizes;
-
+
private final double[] values;
-
+
private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) {
this.type = type;
this.dimensionSizes = dimensionSizes;
@@ -43,8 +43,8 @@ public class IndexedTensor implements Tensor {
}
/**
- * Returns an iterator over the cells of this.
- * Cells are returned in order of increasing indexes in each dimension, increasing
+ * Returns an iterator over the cells of this.
+ * Cells are returned in order of increasing indexes in each dimension, increasing
* indexes of later dimensions in the dimension type before earlier.
*/
@Override
@@ -69,7 +69,7 @@ public class IndexedTensor implements Tensor {
/**
* Returns an iterator over the values of this.
- * Values are returned in order of increasing indexes in each dimension, increasing
+ * Values are returned in order of increasing indexes in each dimension, increasing
* indexes of later dimensions in the dimension type before earlier.
*/
@Override
@@ -81,7 +81,7 @@ public class IndexedTensor implements Tensor {
* Returns an iterator over value iterators where the outer iterator is over each unique value of the dimensions
* given and the inner iterator is over each unique value of the rest of the dimensions, in the same order as
* other iterator.
- *
+ *
* @param dimensions the names of the dimensions of the superspace
* @param sizes the size of each dimension in the space we are returning values for, containing
* one value per dimension of this tensor (in order). Each size may be the same or smaller
@@ -96,9 +96,9 @@ public class IndexedTensor implements Tensor {
return subspaceIterator(dimensions, dimensionSizes);
}
- /**
+ /**
* Returns the value at the given indexes
- *
+ *
* @param indexes the indexes into the dimensions of this. Must be one number per dimension of this
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
@@ -119,7 +119,7 @@ public class IndexedTensor implements Tensor {
}
private double get(int valueIndex) { return values[valueIndex]; }
-
+
private static int toValueIndex(int[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
if (indexes.length == 0) return 0; // for speed
@@ -165,7 +165,7 @@ public class IndexedTensor implements Tensor {
public Map<TensorAddress, Double> cells() {
if (dimensionSizes.dimensions() == 0)
return Collections.singletonMap(TensorAddress.of(), values[0]);
-
+
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
for (int i = 0; i < values.length; i++) {
@@ -174,13 +174,13 @@ public class IndexedTensor implements Tensor {
}
return builder.build();
}
-
+
@Override
public int hashCode() { return Arrays.hashCode(values); }
@Override
public String toString() { return Tensor.toStandardString(this); }
-
+
@Override
public boolean equals(Object other) {
if ( ! ( other instanceof Tensor)) return false;
@@ -188,9 +188,9 @@ public class IndexedTensor implements Tensor {
}
public abstract static class Builder implements Tensor.Builder {
-
+
final TensorType type;
-
+
private Builder(TensorType type) {
this.type = type;
}
@@ -202,7 +202,7 @@ public class IndexedTensor implements Tensor {
return new UnboundBuilder(type);
}
- /**
+ /**
* Create a builder with dimension size information for this instance. Must be one size entry per dimension,
* and, agree with the type size information when specified in the type.
* If sizes are completely specified in the type this size information is redundant.
@@ -210,16 +210,16 @@ public class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes) {
// validate
if (sizes.dimensions() != type.dimensions().size())
- throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
+ throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
"for " + type);
for (int i = 0; i < sizes.dimensions(); i++ ) {
Optional<Integer> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
- throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
+ throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
sizes.size(i) +
" but cannot be larger than " + size.get() + " in " + type);
}
-
+
return new BoundBuilder(type, sizes);
}
@@ -232,7 +232,7 @@ public class IndexedTensor implements Tensor {
public abstract IndexedTensor build();
}
-
+
/** A bound builder can create the double array directly */
public static class BoundBuilder extends Builder {
@@ -257,13 +257,13 @@ public class IndexedTensor implements Tensor {
this.sizes = sizes;
values = new double[sizes.totalSize()];
}
-
+
@Override
public BoundBuilder cell(double value, int ... indexes) {
values[toValueIndex(indexes, sizes)] = value;
return this;
}
-
+
@Override
public CellBuilder cell() {
return new CellBuilder(type, this);
@@ -294,8 +294,8 @@ public class IndexedTensor implements Tensor {
return this;
}
- /**
- * Set a cell value by the index in the internal layout of this cell.
+ /**
+ * Set a cell value by the index in the internal layout of this cell.
* This requires knowledge of the internal layout of cells in this implementation, and should therefore
* probably not be used (but when it can be used it is fast).
*/
@@ -330,7 +330,7 @@ public class IndexedTensor implements Tensor {
fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
-
+
private DimensionSizes findDimensionSizes(List<Object> firstDimension) {
List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
findDimensionSizes(0, dimensionSizeList, firstDimension);
@@ -347,16 +347,16 @@ public class IndexedTensor implements Tensor {
if (currentDimensionIndex == dimensionSizes.size())
dimensionSizes.add(currentDimension.size());
else if (dimensionSizes.get(currentDimensionIndex) != currentDimension.size())
- throw new IllegalArgumentException("Missing values in dimension " +
+ throw new IllegalArgumentException("Missing values in dimension " +
type.dimensions().get(currentDimensionIndex) + " in " + type);
-
+
for (Object value : currentDimension)
if (value instanceof List)
findDimensionSizes(currentDimensionIndex + 1, dimensionSizes, (List<Object>)value);
}
@SuppressWarnings("unchecked")
- private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension,
+ private void fillValues(int currentDimensionIndex, int offset, List<Object> currentDimension,
DimensionSizes sizes, double[] values) {
if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension
for (int i = 0; i < currentDimension.size(); i++)
@@ -369,7 +369,7 @@ public class IndexedTensor implements Tensor {
}
}
}
-
+
private double nullAsZero(Double value) {
if (value == null) return 0;
return value;
@@ -431,7 +431,7 @@ public class IndexedTensor implements Tensor {
}
}
-
+
private final class CellIterator implements Iterator<Cell> {
private int count = 0;
@@ -451,7 +451,7 @@ public class IndexedTensor implements Tensor {
reusedCell.value = get(indexes.toSourceValueIndex());
return reusedCell;
}
-
+
}
private final class ValueIterator implements Iterator<Double> {
@@ -474,25 +474,25 @@ public class IndexedTensor implements Tensor {
}
}
-
+
private final class SuperspaceIterator implements Iterator<SubspaceIterator> {
private final Indexes superindexes;
/** Those indexes this should iterate over */
private final List<Integer> subdimensionIndexes;
-
- /**
+
+ /**
* The sizes of the space we'll return values of, one value for each dimension of this tensor,
- * which may be equal to or smaller than the sizes of this tensor
+ * which may be equal to or smaller than the sizes of this tensor
*/
private final DimensionSizes iterateSizes;
private int count = 0;
-
+
private SuperspaceIterator(Set<String> superdimensionNames, DimensionSizes iterateSizes) {
this.iterateSizes = iterateSizes;
-
+
List<Integer> superdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for outer iterator
subdimensionIndexes = new ArrayList<>(superdimensionNames.size()); // for inner iterator (max length)
for (int i = type.dimensions().size() - 1; i >= 0; i-- ) { // iterate inner dimensions first
@@ -501,10 +501,10 @@ public class IndexedTensor implements Tensor {
else
subdimensionIndexes.add(i);
}
-
+
superindexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, superdimensionIndexes);
}
-
+
@Override
public boolean hasNext() {
return count < superindexes.size();
@@ -527,7 +527,7 @@ public class IndexedTensor implements Tensor {
*/
public final class SubspaceIterator implements Iterator<Tensor.Cell> {
- /**
+ /**
* This iterator will iterate over the given dimensions, in the order given
* (the first dimension index given is incremented to exhaustion first (i.e is etc.).
* This may be any subset of the dimensions given by address and dimensionSizes.
@@ -538,21 +538,21 @@ public class IndexedTensor implements Tensor {
private Indexes indexes;
private int count = 0;
-
+
/** A lazy cell for reuse */
private final LazyCell reusedCell;
-
- /**
+
+ /**
* Creates a new subspace iterator
- *
+ *
* @param iterateDimensions the dimensions to iterate over, given as indexes in the dimension order of the
* type of the tensor this iterates over. This iterator will iterate over these
- * dimensions to exhaustion in the order given (the first dimension index given is
+ * dimensions to exhaustion in the order given (the first dimension index given is
* incremented to exhaustion first (i.e is etc.), while other dimensions will be held
* at a constant position.
* This may be any subset of the dimensions given by address and dimensionSizes.
* This is treated as immutable.
- * @param address the address of the first cell of this subspace.
+ * @param address the address of the first cell of this subspace.
*/
private SubspaceIterator(List<Integer> iterateDimensions, int[] address, DimensionSizes iterateSizes) {
this.iterateDimensions = iterateDimensions;
@@ -561,26 +561,26 @@ public class IndexedTensor implements Tensor {
this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
reusedCell = new LazyCell(indexes, Double.NaN);
}
-
+
/** Returns the total number of cells in this subspace */
- public int size() {
+ public int size() {
return indexes.size();
}
-
+
/** Returns the address of the cell this currently points to (which may be an invalid position) */
public TensorAddress address() { return indexes.toAddress(); }
-
+
/** Rewind this iterator to the first element */
- public void reset() {
+ public void reset() {
this.count = 0;
- this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
+ this.indexes = Indexes.of(IndexedTensor.this.dimensionSizes, iterateSizes, iterateDimensions, address);
}
-
+
@Override
public boolean hasNext() {
- return count < indexes.size();
+ return count < indexes.size();
}
-
+
/** Returns the next cell, which is valid until next() is called again */
@Override
public Cell next() {
@@ -611,15 +611,15 @@ public class IndexedTensor implements Tensor {
public TensorAddress getKey() {
return indexes.toAddress();
}
-
+
@Override
public Double getValue() { return value; }
}
// TODO: Make dimensionSizes a class
-
- /**
+
+ /**
* An array of indexes into this tensor which are able to find the next index in the value order.
* next() can be called once per element in the dimensions we iterate over. It must be called once
* before accessing the first position.
@@ -631,7 +631,7 @@ public class IndexedTensor implements Tensor {
private final DimensionSizes iterationSizes;
protected final int[] indexes;
-
+
public static Indexes of(DimensionSizes sizes) {
return of(sizes, sizes);
}
@@ -676,14 +676,14 @@ public class IndexedTensor implements Tensor {
return new MultiDimensionIndexes(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, size);
}
}
-
+
private static List<Integer> completeIterationOrder(int length) {
List<Integer> iterationDimensions = new ArrayList<>(length);
for (int i = 0; i < length; i++)
iterationDimensions.add(length - 1 - i);
return iterationDimensions;
}
-
+
private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) {
this.sourceSizes = sourceSizes;
this.iterationSizes = iterationSizes;
@@ -708,9 +708,9 @@ public class IndexedTensor implements Tensor {
/** Returns a copy of the indexes of this which must not be modified */
public int[] indexesForReading() { return indexes; }
-
- int toSourceValueIndex() {
- return IndexedTensor.toValueIndex(indexes, sourceSizes);
+
+ int toSourceValueIndex() {
+ return IndexedTensor.toValueIndex(indexes, sourceSizes);
}
int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); }
@@ -729,9 +729,9 @@ public class IndexedTensor implements Tensor {
public String toString() {
return "indexes " + Arrays.toString(indexes);
}
-
+
public abstract int size();
-
+
public abstract void next();
}
@@ -763,18 +763,18 @@ public class IndexedTensor implements Tensor {
public void next() {}
}
-
+
private static class MultiDimensionIndexes extends Indexes {
private final int size;
private final List<Integer> iterateDimensions;
-
+
private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimensions = iterateDimensions;
this.size = size;
-
+
// Initialize to the (virtual) position before the first cell
indexes[iterateDimensions.get(0)]--;
}
@@ -785,10 +785,10 @@ public class IndexedTensor implements Tensor {
return size;
}
- /**
- * Advances this to the next cell in the standard indexed tensor cell order.
- * The first call to this will put it at the first position.
- *
+ /**
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
+ *
* @throws RuntimeException if this is called more times than its size
*/
@Override
@@ -802,12 +802,12 @@ public class IndexedTensor implements Tensor {
}
}
-
+
/** In this case we can reuse the source index computation for the iteration index */
private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes {
private int lastComputedSourceValueIndex = -1;
-
+
private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, int[] initialIndexes, int size) {
super(sizes, sizes, iterateDimensions, initialIndexes, size);
}
@@ -827,7 +827,7 @@ public class IndexedTensor implements Tensor {
private final int size;
private final int iterateDimension;
-
+
/** Maintain this directly as an optimization for 1-d iteration */
private int currentSourceValueIndex, currentIterationValueIndex;
@@ -847,7 +847,7 @@ public class IndexedTensor implements Tensor {
currentSourceValueIndex = IndexedTensor.toValueIndex(indexes, sourceSizes);
currentIterationValueIndex = IndexedTensor.toValueIndex(indexes, iterateSizes);
}
-
+
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
@Override
public int size() {
@@ -855,8 +855,8 @@ public class IndexedTensor implements Tensor {
}
/**
- * Advances this to the next cell in the standard indexed tensor cell order.
- * The first call to this will put it at the first position.
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
*
* @throws RuntimeException if this is called more times than its size
*/
@@ -888,7 +888,7 @@ public class IndexedTensor implements Tensor {
/** The iteration step in the value index space */
private final int step;
- private EqualSizeSingleDimensionIndexes(DimensionSizes sizes,
+ private EqualSizeSingleDimensionIndexes(DimensionSizes sizes,
int iterateDimension, int[] initialIndexes, int size) {
super(sizes, sizes, initialIndexes);
this.iterateDimension = iterateDimension;
@@ -907,8 +907,8 @@ public class IndexedTensor implements Tensor {
}
/**
- * Advances this to the next cell in the standard indexed tensor cell order.
- * The first call to this will put it at the first position.
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
*
* @throws RuntimeException if this is called more times than its size
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index 618bff0caae..aba61478e69 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -27,7 +27,7 @@ public class MappedTensor implements Tensor {
@Override
public TensorType type() { return type; }
-
+
@Override
public int size() { return cells.size(); }
@@ -56,16 +56,16 @@ public class MappedTensor implements Tensor {
}
public static class Builder implements Tensor.Builder {
-
+
private final TensorType type;
private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
-
+
public static Builder of(TensorType type) { return new Builder(type); }
private Builder(TensorType type) {
this.type = type;
}
-
+
public CellBuilder cell() {
return new CellBuilder(type, this);
}
@@ -89,24 +89,24 @@ public class MappedTensor implements Tensor {
public MappedTensor build() {
return new MappedTensor(type, cells.build());
}
-
+
}
private static class CellIteratorAdaptor implements Iterator<Cell> {
private final Iterator<Map.Entry<TensorAddress, Double>> adaptedIterator;
-
+
private CellIteratorAdaptor(Iterator<Map.Entry<TensorAddress, Double>> adaptedIterator) {
this.adaptedIterator = adaptedIterator;
}
-
+
@Override
public boolean hasNext() { return adaptedIterator.hasNext(); }
@Override
public Cell next() {
Map.Entry<TensorAddress, Double> entry = adaptedIterator.next();
- return new Cell(entry.getKey(), entry.getValue());
+ return new Cell(entry.getKey(), entry.getValue());
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 79bb27fcd1b..9a751e078e0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -117,7 +117,7 @@ public class MixedTensor implements Tensor {
return index.denseSubspaceSize();
}
-
+
/**
* Base class for building mixed tensors.
*/
@@ -286,7 +286,7 @@ public class MixedTensor implements Tensor {
}
return typeBuilder.build();
}
-
+
}
/**
@@ -360,7 +360,7 @@ public class MixedTensor implements Tensor {
}
return denseSubspaceSize;
}
-
+
private TensorAddress sparsePartialAddress(TensorAddress address) {
if (type.dimensions().size() != address.size()) {
throw new IllegalArgumentException("Tensor type and address are not of same size.");
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 2ed211539d8..1b60e01cf7e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -52,7 +52,7 @@ import java.util.function.Function;
public interface Tensor {
// ----------------- Accessors
-
+
TensorType type();
/** Returns whether this have any cells */
@@ -70,13 +70,13 @@ public interface Tensor {
/** Returns the values of this in some undefined order */
Iterator<Double> valueIterator();
- /**
+ /**
* Returns an immutable map of the cells of this in no particular order.
- * This may be expensive for some implementations - avoid when possible
+ * This may be expensive for some implementations - avoid when possible
*/
Map<TensorAddress, Double> cells();
- /**
+ /**
* Returns the value of this as a double if it has no dimensions and one value
*
* @throws IllegalStateException if this does not have zero dimensions and one value
@@ -87,9 +87,9 @@ public interface Tensor {
if (size() == 0) return Double.NaN;
return valueIterator().next();
}
-
+
// ----------------- Primitive tensor functions
-
+
default Tensor map(DoubleUnaryOperator mapper) {
return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate();
}
@@ -108,7 +108,7 @@ public interface Tensor {
}
default Tensor rename(String fromDimension, String toDimension) {
- return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension),
+ return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension),
Collections.singletonList(toDimension)).evaluate();
}
@@ -123,13 +123,13 @@ public interface Tensor {
default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
}
-
+
static Tensor generate(TensorType type, Function<List<Integer>, Double> valueSupplier) {
return new Generate(type, valueSupplier).evaluate();
}
-
+
// ----------------- Composite tensor functions which have a defined primitive mapping
-
+
default Tensor l1Normalize(String dimension) {
return new L1Normalize(new ConstantTensor(this), dimension).evaluate();
}
@@ -231,7 +231,7 @@ public interface Tensor {
if (cellEntries.isEmpty()) return "{}";
return "{" + cellEntries.get(0).getValue() +"}";
}
-
+
Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
StringBuilder b = new StringBuilder("{");
@@ -253,7 +253,7 @@ public interface Tensor {
*/
boolean equals(Object o);
- /**
+ /**
* Implement here to make this work across implementations.
* Implementations must override equals and call this because this is an interface and cannot override equals.
*/
@@ -328,13 +328,13 @@ public interface Tensor {
@Override
public TensorAddress getKey() { return address; }
- /**
+ /**
* Returns the direct index which can be used to locate this cell, or -1 if not available.
* This is for optimizations mapping between tensors where this is possible without creating a
* TensorAddress.
*/
int getDirectIndex() { return -1; }
-
+
@Override
public Double getValue() { return value; }
@@ -388,20 +388,20 @@ public interface Tensor {
/** Returns the type this is building */
TensorType type();
-
+
/** Return a cell builder */
CellBuilder cell();
/** Add a cell */
Builder cell(TensorAddress address, double value);
-
+
/** Add a cell */
Builder cell(double value, int ... labels);
- /**
- * Add a cell
- *
- * @param cell a cell providing the location at which to add this cell
+ /**
+ * Add a cell
+ *
+ * @param cell a cell providing the location at which to add this cell
* @param value the value to assign to the cell
*/
default Builder cell(Cell cell, double value) {
@@ -409,12 +409,12 @@ public interface Tensor {
}
Tensor build();
-
+
class CellBuilder {
private final TensorAddress.Builder addressBuilder;
private final Tensor.Builder tensorBuilder;
-
+
CellBuilder(TensorType type, Tensor.Builder tensorBuilder) {
addressBuilder = new TensorAddress.Builder(type);
this.tensorBuilder = tensorBuilder;
@@ -436,5 +436,5 @@ public interface Tensor {
}
}
-
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 7161450d5d5..ff1202463f2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -32,10 +32,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
/** Returns the number of labels in this */
public abstract int size();
-
+
/**
- * Returns the i'th label in this
- *
+ * Returns the i'th label in this
+ *
* @throws IllegalArgumentException if there is no label at this index
*/
public abstract String label(int i);
@@ -102,23 +102,23 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
private StringTensorAddress(String ... labels) {
this.labels = Arrays.copyOf(labels, labels.length);
}
-
+
@Override
public int size() { return labels.length; }
-
+
@Override
public String label(int i) { return labels[i]; }
-
+
@Override
- public int intLabel(int i) {
+ public int intLabel(int i) {
try {
return Integer.parseInt(labels[i]);
- }
+ }
catch (NumberFormatException e) {
throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i);
}
}
-
+
@Override
public TensorAddress withLabel(int index, int label) {
String[] labels = Arrays.copyOf(this.labels, this.labels.length);
@@ -169,7 +169,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
private final TensorType type;
private final String[] labels;
-
+
public Builder(TensorType type) {
this(type, new String[type.dimensions().size()]);
}
@@ -193,7 +193,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
labels[labelIndex.get()] = label;
return this;
}
-
+
/** Creates a copy of this which can be modified separately */
public Builder copy() {
return new Builder(type, Arrays.copyOf(labels, labels.length));
@@ -202,7 +202,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
public TensorAddress build() {
for (int i = 0; i < labels.length; i++)
if (labels[i] == null)
- throw new IllegalArgumentException("Missing a value for dimension " +
+ throw new IllegalArgumentException("Missing a value for dimension " +
type.dimensions().get(i).name() + " for " + type);
return TensorAddress.of(labels);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index da8ab3bb0ec..9b3a9328f07 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -96,7 +96,7 @@ class TensorParser {
if (valueEnd < 0)
throw new IllegalArgumentException("A tensor string must end by '}'");
}
-
+
TensorAddress address = addressBuilder.build();
Double value = asDouble(address, s.substring(0, valueEnd).trim());
builder.cell(address, value);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index c05c35d6df3..914d853aeca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -53,14 +53,17 @@ public class TensorType {
return TensorTypeParser.fromSpec(specString);
}
+ /** Returns the number of dimensions of this: dimensions().size() */
+ public int rank() { return dimensions.size(); }
+
/** Returns an immutable list of the dimensions of this */
public List<Dimension> dimensions() { return dimensions; }
-
+
/** Returns an immutable set of the names of the dimensions of this */
public Set<String> dimensionNames() {
return dimensions.stream().map(Dimension::name).collect(Collectors.toSet());
}
-
+
/** Returns the dimension with this name, or empty if not present */
public Optional<Dimension> dimension(String name) {
return indexOfDimension(name).map(i -> dimensions.get(i));
@@ -74,7 +77,7 @@ public class TensorType {
return Optional.empty();
}
- /**
+ /**
* Returns whether this type can be assigned to the given type,
* i.e if the given type is a generalization of this type.
*/
@@ -128,9 +131,9 @@ public class TensorType {
private final String name;
- private Dimension(String name) {
+ private Dimension(String name) {
Objects.requireNonNull(name, "A tensor name cannot be null");
- this.name = name;
+ this.name = name;
}
public final String name() { return name; }
@@ -146,7 +149,7 @@ public class TensorType {
/** Returns true if this is an indexed bound or unboun type */
public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; }
- /**
+ /**
* Returns the dimension resulting from combining two dimensions having the same name but possibly different
* types. This works by degrading to the type making the fewer promises.
* [N] + [M] = [min(N, M)]
@@ -165,7 +168,7 @@ public class TensorType {
IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get();
return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb;
}
-
+
@Override
public abstract String toString();
@@ -175,21 +178,21 @@ public class TensorType {
if (other == null || getClass() != other.getClass()) return false;
return name.equals(((Dimension)other).name);
}
-
+
@Override
public int hashCode() {
return name.hashCode();
}
-
+
@Override
public int compareTo(Dimension other) {
return this.name.compareTo(other.name);
}
-
+
public static Dimension indexed(String name, int size) {
return new IndexedBoundDimension(name, size);
}
-
+
}
public static class IndexedBoundDimension extends TensorType.Dimension {
@@ -289,9 +292,9 @@ public class TensorType {
public Builder() {
}
- /**
- * Creates a builder containing a combination of the dimensions of the given types
- *
+ /**
+ * Creates a builder containing a combination of the dimensions of the given types
+ *
* If the same dimension is indexed with different size restrictions the largest size will be used.
* If it is size restricted in one argument but not the other it will not be size restricted.
* If it is indexed in one and mapped in the other it will become mapped.
@@ -325,9 +328,12 @@ public class TensorType {
}
}
- /**
+ /** Returns the current number of dimensions in this */
+ public int rank() { return dimensions.size(); }
+
+ /**
* Adds a new dimension to this
- *
+ *
* @throws IllegalArgumentException if the dimension is already present
*/
private Builder add(Dimension dimension) {
@@ -346,7 +352,7 @@ public class TensorType {
return this;
}
- /**
+ /**
* Adds a bound indexed dimension to this
*
* @throws IllegalArgumentException if the dimension is already present
@@ -355,7 +361,7 @@ public class TensorType {
/**
* Adds an unbound indexed dimension to this
- *
+ *
* @throws IllegalArgumentException if the dimension is already present
*/
public Builder indexed(String name) {
@@ -375,7 +381,7 @@ public class TensorType {
public Builder dimension(Dimension dimension) {
return add(dimension);
}
-
+
/** Returns the given dimension, or empty if none is present */
public Optional<Dimension> getDimension(String dimension) {
return Optional.ofNullable(dimensions.get(dimension));
@@ -393,7 +399,7 @@ public class TensorType {
public TensorType build() {
return new TensorType(dimensions.values());
}
-
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
index 84caca78fb2..3db661f8a23 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
@@ -2,16 +2,17 @@
package com.yahoo.tensor.evaluation;
import com.google.common.annotations.Beta;
-
-import java.util.HashMap;
+import com.yahoo.tensor.Tensor;
/**
* An evaluation context which is passed down to all nested functions during evaluation.
- * The default context is empty to allow various evaluation frameworks to support their own implementation.
- *
+ *
* @author bratseth
*/
@Beta
public interface EvaluationContext {
+ /** Returns the tensor bound to this name, or null if none */
+ Tensor getTensor(String name);
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
index cf704c15f4f..db8a66a5fa2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
@@ -18,7 +18,7 @@ public class MapEvaluationContext implements EvaluationContext {
public void put(String name, Tensor tensor) { bindings.put(name, tensor); }
- /** Returns the tensor bound to this name, or null if none */
- public Tensor get(String name) { return bindings.get(name); }
+ @Override
+ public Tensor getTensor(String name) { return bindings.get(name); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index 8ade181bdb7..1f6ad050368 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -12,18 +12,18 @@ import java.util.List;
/**
* A tensor variable name which resolves to a tensor in the context at evaluation time
- *
+ *
* @author bratseth
*/
@Beta
public class VariableTensor extends PrimitiveTensorFunction {
private final String name;
-
+
public VariableTensor(String name) {
this.name = name;
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -35,7 +35,7 @@ public class VariableTensor extends PrimitiveTensorFunction {
@Override
public Tensor evaluate(EvaluationContext context) {
- return ((MapEvaluationContext)context).get(name);
+ return context.getTensor(name);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 8f4dbf014a7..191c7988443 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -8,7 +8,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
/**
* A composite tensor function is a tensor function which can be expressed (less tersely)
* as a tree of primitive tensor functions.
- *
+ *
* @author bratseth
*/
@Beta
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 1dbb94fdb20..faa0ca36cb6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -15,7 +15,7 @@ import java.util.stream.Collectors;
/**
* Concatenation of two tensors along an (indexed) dimension
- *
+ *
* @author bratseth
*/
@Beta
@@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction {
concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
return builder.build();
}
-
+
private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
@@ -112,7 +112,7 @@ public class Concat extends PrimitiveTensorFunction {
Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
return tensor.multiply(unitTensor);
}
-
+
}
/** Returns the type resulting from concatenating a and b */
@@ -144,7 +144,7 @@ public class Concat extends PrimitiveTensorFunction {
/**
* Combine two addresses, adding the offset to the concat dimension
*
- * @return the combined address or null if the addresses are incompatible
+ * @return the combined address or null if the addresses are incompatible
* (in some other dimension than the concat dimension)
*/
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
@@ -161,7 +161,7 @@ public class Concat extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 4ac7b21ba90..14ed38718ce 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -10,18 +10,18 @@ import java.util.List;
/**
* A function which returns a constant tensor.
- *
+ *
* @author bratseth
*/
@Beta
public class ConstantTensor extends PrimitiveTensorFunction {
private final Tensor constant;
-
+
public ConstantTensor(String tensorString) {
this.constant = Tensor.from(tensorString);
}
-
+
public ConstantTensor(Tensor tensor) {
this.constant = tensor;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index bbdbd5c3df1..c75d8ee4753 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -11,19 +11,19 @@ import java.util.stream.Stream;
/**
* A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere.
- *
+ *
* @author bratseth
*/
public class Diag extends CompositeTensorFunction {
private final TensorType type;
private final Function<List<Integer>, Double> diagFunction;
-
+
public Diag(TensorType type) {
this.type = type;
this.diagFunction = ScalarFunctions.equal(dimensionNames().collect(Collectors.toList()));
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -43,7 +43,7 @@ public class Diag extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::name);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 6ea73b7f310..e42d25197e2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -15,7 +15,7 @@ import java.util.function.Function;
/**
* An indexed tensor whose values are generated by a function
- *
+ *
* @author bratseth
*/
@Beta
@@ -26,7 +26,7 @@ public class Generate extends PrimitiveTensorFunction {
/**
* Creates a generated tensor
- *
+ *
* @param type the type of the tensor
* @param generator the function generating values from a list of ints specifying the indexes of the
* tensor cell which will receive the value
@@ -39,7 +39,7 @@ public class Generate extends PrimitiveTensorFunction {
this.type = type;
this.generator = generator;
}
-
+
private void validateType(TensorType type) {
for (TensorType.Dimension dimension : type.dimensions())
if (dimension.type() != TensorType.Dimension.Type.indexedBound)
@@ -58,7 +58,7 @@ public class Generate extends PrimitiveTensorFunction {
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
-
+
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor.Builder builder = Tensor.Builder.of(type);
@@ -69,14 +69,14 @@ public class Generate extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private DimensionSizes dimensionSizes(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < b.dimensions(); i++)
b.set(i, type.dimensions().get(i).size().get());
return b.build();
}
-
+
@Override
public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 8c4dbfb0acb..ff887e3e9a6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator;
* The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells
* given by the cross product of the cells of the given tensors, having as values the value produced by
* applying the given combinator function on the values from the two source cells.
- *
+ *
* @author bratseth
*/
@Beta
public class Join extends PrimitiveTensorFunction {
-
+
private final TensorFunction argumentA, argumentB;
private final DoubleBinaryOperator combinator;
@@ -46,6 +46,30 @@ public class Join extends PrimitiveTensorFunction {
this.combinator = combinator;
}
+ /** Returns the type resulting from applying Join to the two given types */
+ public static TensorType outputType(TensorType a, TensorType b) {
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (int i = 0; i < a.dimensions().size(); ++i) {
+ TensorType.Dimension aDim = a.dimensions().get(i);
+ for (int j = 0; j < b.dimensions().size(); ++j) {
+ TensorType.Dimension bDim = b.dimensions().get(j);
+ if (aDim.name().equals(bDim.name())) { // include
+ if (aDim.isIndexed() && bDim.isIndexed()) {
+ if (aDim.size().isPresent() || bDim.size().isPresent())
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
+ bDim.size().orElse(Integer.MAX_VALUE)));
+ else
+ typeBuilder.indexed(aDim.name());
+ }
+ else {
+ typeBuilder.mapped(aDim.name());
+ }
+ }
+ }
+ }
+ return typeBuilder.build();
+ }
+
public TensorFunction argumentA() { return argumentA; }
public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }
@@ -88,11 +112,11 @@ public class Join extends PrimitiveTensorFunction {
else
return generalJoin(a, b, joinedType);
}
-
+
private boolean hasSingleIndexedDimension(Tensor tensor) {
return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
}
-
+
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
@@ -114,7 +138,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
@@ -126,7 +150,7 @@ public class Join extends PrimitiveTensorFunction {
private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
-
+
DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);
@@ -134,14 +158,14 @@ public class Join extends PrimitiveTensorFunction {
// Find dimensions which are only in the supertype
Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
superDimensionNames.removeAll(subspace.type().dimensionNames());
-
+
for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
joinSubspaces(subspace.valueIterator(), subspace.size(),
subspaceInSuper, subspaceInSuper.size(),
reversedArgumentOrder, builder);
}
-
+
return builder.build();
}
@@ -200,7 +224,7 @@ public class Join extends PrimitiveTensorFunction {
subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
return subspaceIndexes;
}
-
+
private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
String[] subspaceLabels = new String[subspaceIndexes.length];
for (int i = 0; i < subspaceIndexes.length; i++)
@@ -235,7 +259,7 @@ public class Join extends PrimitiveTensorFunction {
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
// for each combination of dimensions only in a
- for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
IndexedTensor.SubspaceIterator aSubspace = ia.next();
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
@@ -252,7 +276,7 @@ public class Join extends PrimitiveTensorFunction {
}
}
}
-
+
private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
@@ -260,7 +284,7 @@ public class Join extends PrimitiveTensorFunction {
builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
return builder.build();
}
-
+
/** Returns the sizes from the joined sizes which are present in the type argument */
private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
@@ -271,7 +295,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
@@ -340,7 +364,7 @@ public class Join extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
@@ -360,7 +384,7 @@ public class Join extends PrimitiveTensorFunction {
return TensorAddress.of(joinedLabels);
}
- /**
+ /**
* Maps the content in the given list to the given array, using the given index map.
*
* @return true if the mapping was successful, false if one of the destination positions was
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index a9872bb42d8..a5e1a016a41 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -6,6 +6,7 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Collections;
@@ -32,6 +33,8 @@ public class Map extends PrimitiveTensorFunction {
this.mapper = mapper;
}
+ public static TensorType outputType(TensorType inputType) { return inputType; }
+
public TensorFunction argument() { return argument; }
public DoubleUnaryOperator mapper() { return mapper; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index bb27e937699..4071917c2b5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.TensorType;
import java.util.List;
@@ -14,13 +15,17 @@ public class Matmul extends CompositeTensorFunction {
private final TensorFunction argument1, argument2;
private final String dimension;
-
+
public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
this.argument1 = argument1;
this.argument2 = argument2;
this.dimension = dimension;
}
+ public static TensorType outputType(TensorType a, TensorType b, String dimension) {
+ return Join.outputType(a, b);
+ }
+
@Override
public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
@@ -39,7 +44,7 @@ public class Matmul extends CompositeTensorFunction {
Reduce.Aggregator.sum,
dimension);
}
-
+
@Override
public String toString(ToStringContext context) {
return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
index efb7b9e500c..b7c9a5d2342 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
@@ -8,10 +8,10 @@ import com.yahoo.tensor.Tensor;
* A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions.
* All tensor implementations must implement all primitive tensor functions.
* Primitive tensor functions are fully inspectable.
- *
+ *
* @author bratseth
*/
@Beta
public abstract class PrimitiveTensorFunction extends TensorFunction {
-
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 457763e97ba..958ef85d1dc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -22,11 +22,11 @@ import java.util.stream.Stream;
public class Random extends CompositeTensorFunction {
private final TensorType type;
-
+
public Random(TensorType type) {
this.type = type;
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -46,7 +46,7 @@ public class Random extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index e2b39a2048d..a56f82b026a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -12,19 +12,19 @@ import java.util.stream.Stream;
/**
* A tensor generator which returns a tensor of any dimension filled with the sum of the tensor
* indexes of each position.
- *
+ *
* @author bratseth
*/
public class Range extends CompositeTensorFunction {
private final TensorType type;
private final Function<List<Integer>, Double> rangeFunction;
-
+
public Range(TensorType type) {
this.type = type;
this.rangeFunction = ScalarFunctions.sum(dimensionNames().collect(Collectors.toList()));
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -44,7 +44,7 @@ public class Range extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction;
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index cfc78be7e0c..de9f90a5804 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -19,7 +19,7 @@ import java.util.Objects;
import java.util.Set;
/**
- * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
+ * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
* are collapsed to a single value using an aggregator function.
*
* @author bratseth
@@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction {
/**
* Creates a reduce function.
- *
+ *
* @param argument the tensor to reduce
* @param aggregator the aggregator function to use
* @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
@@ -61,6 +61,15 @@ public class Reduce extends PrimitiveTensorFunction {
this.dimensions = ImmutableList.copyOf(dimensions);
}
+ public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
+ TensorType.Builder b = new TensorType.Builder();
+ for (TensorType.Dimension dimension : inputType.dimensions()) {
+ if ( ! reduceDimensions.contains(dimension.name()))
+ b.dimension(dimension);
+ }
+ return b.build();
+ }
+
public TensorFunction argument() { return argument; }
@Override
@@ -82,7 +91,7 @@ public class Reduce extends PrimitiveTensorFunction {
public String toString(ToStringContext context) {
return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
}
-
+
private String commaSeparated(List<String> list) {
StringBuilder b = new StringBuilder();
for (String element : list)
@@ -94,7 +103,7 @@ public class Reduce extends PrimitiveTensorFunction {
public Tensor evaluate(EvaluationContext context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
- throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
+ throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
dimensions + ": Not all those dimensions are present in this tensor");
// Special case: Reduce all
@@ -103,14 +112,14 @@ public class Reduce extends PrimitiveTensorFunction {
return reduceIndexedVector((IndexedTensor)argument);
else
return reduceAllGeneral(argument);
-
+
// Reduce type
TensorType.Builder builder = new TensorType.Builder();
for (TensorType.Dimension dimension : argument.type().dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
builder.dimension(dimension);
TensorType reducedType = builder.build();
-
+
// Reduce cells
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
@@ -122,10 +131,10 @@ public class Reduce extends PrimitiveTensorFunction {
Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
-
+
return reducedBuilder.build();
}
-
+
private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) {
Set<Integer> indexesToRemove = new HashSet<>();
for (String dimensionToRemove : this.dimensions)
@@ -138,7 +147,7 @@ public class Reduce extends PrimitiveTensorFunction {
reducedLabels[reducedLabelIndex++] = address.label(i);
return TensorAddress.of(reducedLabels);
}
-
+
private Tensor reduceAllGeneral(Tensor argument) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )
@@ -154,7 +163,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
private static abstract class ValueAggregator {
-
+
private static ValueAggregator ofType(Aggregator aggregator) {
switch (aggregator) {
case avg : return new AvgAggregator();
@@ -165,22 +174,22 @@ public class Reduce extends PrimitiveTensorFunction {
case min : return new MinAggregator();
default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
}
-
+
}
/** Add a new value to those aggregated by this */
public abstract void aggregate(double value);
-
+
/** Returns the value aggregated by this */
public abstract double aggregatedValue();
-
+
}
-
+
private static class AvgAggregator extends ValueAggregator {
private int valueCount = 0;
private double valueSum = 0.0;
-
+
@Override
public void aggregate(double value) {
valueCount++;
@@ -188,7 +197,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public double aggregatedValue() {
+ public double aggregatedValue() {
return valueSum / valueCount;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 6b0daf1b49a..ec9b762a41c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -3,8 +3,6 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -19,7 +17,7 @@ import java.util.Objects;
/**
* The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names.
- *
+ *
* @author bratseth
*/
@Beta
@@ -29,6 +27,10 @@ public class Rename extends PrimitiveTensorFunction {
private final List<String> fromDimensions;
private final List<String> toDimensions;
+ public Rename(TensorFunction argument, String fromDimension, String toDimension) {
+ this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
+ }
+
public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
@@ -42,7 +44,7 @@ public class Rename extends PrimitiveTensorFunction {
this.fromDimensions = ImmutableList.copyOf(fromDimensions);
this.toDimensions = ImmutableList.copyOf(toDimensions);
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@@ -62,7 +64,7 @@ public class Rename extends PrimitiveTensorFunction {
Map<String, String> fromToMap = fromToMap();
TensorType renamedType = rename(tensor.type(), fromToMap);
-
+
// an array which lists the index of each label in the renamed type
int[] toIndexes = new int[tensor.type().dimensions().size()];
for (int i = 0; i < tensor.type().dimensions().size(); i++) {
@@ -70,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction {
String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName);
toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get();
}
-
+
Tensor.Builder builder = Tensor.Builder.of(renamedType);
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -86,7 +88,7 @@ public class Rename extends PrimitiveTensorFunction {
builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
return builder.build();
}
-
+
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
@@ -95,18 +97,18 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) {
- return "rename(" + argument.toString(context) + ", " +
+ public String toString(ToStringContext context) {
+ return "rename(" + argument.toString(context) + ", " +
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
-
+
private Map<String, String> fromToMap() {
Map<String, String> map = new HashMap<>();
for (int i = 0; i < fromDimensions.size(); i++)
map.put(fromDimensions.get(i), toDimensions.get(i));
return map;
}
-
+
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 99f79cb735a..fb5029fbfd6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -21,101 +21,87 @@ import java.util.stream.Collectors;
@Beta
public class ScalarFunctions {
- public static DoubleBinaryOperator add() { return new Addition(); }
- public static DoubleBinaryOperator multiply() { return new Multiplication(); }
- public static DoubleBinaryOperator divide() { return new Division(); }
+ public static DoubleBinaryOperator add() { return new Add(); }
+ public static DoubleBinaryOperator divide() { return new Divide(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
- public static DoubleUnaryOperator square() { return new Square(); }
+ public static DoubleBinaryOperator multiply() { return new Multiply(); }
+
+ public static DoubleUnaryOperator acos() { return new Acos(); }
+ public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
- public static DoubleUnaryOperator exp() { return new Exponent(); }
+ public static DoubleUnaryOperator square() { return new Square(); }
+
public static Function<List<Integer>, Double> random() { return new Random(); }
public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
- public static class Addition implements DoubleBinaryOperator {
+ // Binary operators -----------------------------------------------------------------------------
+ public static class Add implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left + right; }
-
@Override
public String toString() { return "f(a,b)(a + b)"; }
-
}
- public static class Multiplication implements DoubleBinaryOperator {
-
+ public static class Equal implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left * right; }
-
+ public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
@Override
- public String toString() { return "f(a,b)(a * b)"; }
-
+ public String toString() { return "f(a,b)(a==b)"; }
}
- public static class Division implements DoubleBinaryOperator {
-
+ public static class Exp implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left / right; }
-
+ public double applyAsDouble(double operand) { return Math.exp(operand); }
@Override
- public String toString() { return "f(a,b)(a / b)"; }
+ public String toString() { return "f(a)(exp(a))"; }
}
- public static class Equal implements DoubleBinaryOperator {
-
+ public static class Multiply implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
+ public double applyAsDouble(double left, double right) { return left * right; }
+ @Override
+ public String toString() { return "f(a,b)(a * b)"; }
+ }
+ public static class Divide implements DoubleBinaryOperator {
@Override
- public String toString() { return "f(a,b)(a==b)"; }
+ public double applyAsDouble(double left, double right) { return left / right; }
+ @Override
+ public String toString() { return "f(a,b)(a / b)"; }
}
- public static class Square implements DoubleUnaryOperator {
+ // Unary operators ------------------------------------------------------------------------------
+ public static class Acos implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double operand) { return operand * operand; }
-
+ public double applyAsDouble(double operand) { return Math.acos(operand); }
@Override
- public String toString() { return "f(a)(a * a)"; }
-
+ public String toString() { return "f(a)(acos(a))"; }
}
public static class Sqrt implements DoubleUnaryOperator {
-
@Override
public double applyAsDouble(double operand) { return Math.sqrt(operand); }
-
@Override
public String toString() { return "f(a)(sqrt(a))"; }
-
}
- public static class Exponent implements DoubleUnaryOperator {
+ public static class Square implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double operand) { return Math.exp(operand); }
+ public double applyAsDouble(double operand) { return operand * operand; }
@Override
- public String toString() { return "f(a)(exp(a))"; }
+ public String toString() { return "f(a)(a * a)"; }
}
- public static class Random implements Function<List<Integer>, Double> {
-
- @Override
- public Double apply(List<Integer> values) {
- return ThreadLocalRandom.current().nextDouble();
- }
-
- @Override
- public String toString() { return "random"; }
+ // Variable-length operators -----------------------------------------------------------------------------
- }
-
- public static class EqualElements implements Function<List<Integer>, Double> {
-
- private final ImmutableList<String> argumentNames;
-
+ public static class EqualElements implements Function<List<Integer>, Double> {
+ private final ImmutableList<String> argumentNames;
private EqualElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@@ -128,7 +114,6 @@ public class ScalarFunctions {
return 0.0;
return 1.0;
}
-
@Override
public String toString() {
if (argumentNames.size() == 0) return "1";
@@ -143,13 +128,19 @@ public class ScalarFunctions {
}
return b.toString();
}
+ }
+ public static class Random implements Function<List<Integer>, Double> {
+ @Override
+ public Double apply(List<Integer> values) {
+ return ThreadLocalRandom.current().nextDouble();
+ }
+ @Override
+ public String toString() { return "random"; }
}
public static class SumElements implements Function<List<Integer>, Double> {
-
private final ImmutableList<String> argumentNames;
-
private SumElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@@ -161,12 +152,10 @@ public class ScalarFunctions {
sum += value;
return (double)sum;
}
-
@Override
public String toString() {
return argumentNames.stream().collect(Collectors.joining("+"));
}
-
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index bf279eb24d8..c856b548180 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -2,6 +2,8 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.List;
@@ -19,6 +21,10 @@ public class Softmax extends CompositeTensorFunction {
this.argument = argument;
this.dimension = dimension;
}
+
+ public static TensorType outputType(TensorType inputType, String dimension) {
+ return Reduce.outputType(inputType, ImmutableList.of(dimension));
+ }
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index cabcce198d1..533a46f87fe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -12,7 +12,7 @@ import java.util.List;
* A representation of a tensor function which is able to be translated to a set of primitive
* tensor functions if necessary.
* All tensor functions are immutable.
- *
+ *
* @author bratseth
*/
@Beta
@@ -48,11 +48,11 @@ public abstract class TensorFunction {
/**
* Return a string representation of this context.
- *
+ *
* @param context a context which must be passed to all nexted functions when requesting the string value
*/
public abstract String toString(ToStringContext context);
-
+
@Override
public String toString() { return toString(ToStringContext.empty()); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
index e8c425d49e0..416b28afa22 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
@@ -24,7 +24,7 @@ interface BinaryFormat {
/**
* Deserialize the given binary data into a Tensor object.
- *
+ *
* @param type the expected abstract type of the tensor to serialize, or empty to use type information from the data
* @param buffer the buffer containing the tensor binary data
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 8b7325ec211..aabb53d1c67 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -16,9 +16,9 @@ import java.util.Optional;
*
* Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]*
* Cell_values = [double, double, double, ...]*
- * where values are encoded in order of increasing indexes in each dimension, increasing
+ * where values are encoded in order of increasing indexes in each dimension, increasing
* indexes of later dimensions in the dimension type before earlier.
- *
+ *
* @author bratseth
*/
@Beta
@@ -54,7 +54,7 @@ public class DenseBinaryFormat implements BinaryFormat {
type = optionalType.get();
TensorType serializedType = decodeType(buffer);
if ( ! serializedType.isAssignableTo(type))
- throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
+ throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
" cannot be assigned to type " + type);
sizes = sizesFromType(serializedType);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 7467554790a..01a1d023f2b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -46,9 +46,9 @@ public class TypedBinaryFormat {
return result;
}
- /**
- * Decode some data to a tensor
- *
+ /**
+ * Decode some data to a tensor
+ *
* @param type the type to decode and validate to, or empty to use the type given in the data
* @param buffer the buffer containing the data, use GrowableByteByffer.wrap(byte[]) if you have a byte array
* @return the resulting tensor
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index d199dd3a876..abdb3071bf7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -13,14 +13,14 @@ import java.util.stream.Collectors;
/**
* Microbenchmark of tensor operations.
- *
+ *
* @author bratseth
*/
public class TensorFunctionBenchmark {
private final static Random random = new Random();
-
- public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType,
+
+ public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType,
boolean extraSpace) {
Tensor queryVector = vectors(1, 300, dimensionType).get(0);
if (extraSpace) {
@@ -34,7 +34,7 @@ public class TensorFunctionBenchmark {
long totalTime = System.currentTimeMillis() - startTime;
return (double)totalTime / (double)iterations;
}
-
+
private Tensor unitVector(String dimension) {
return Tensor.Builder.of(new TensorType.Builder().indexed(dimension, 1).build())
.cell().label(dimension, 0).value(1).build();
@@ -49,11 +49,11 @@ public class TensorFunctionBenchmark {
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double largest = Double.MIN_VALUE;
- TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
- new VariableTensor("argument"), (a, b) -> a * b),
+ TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
+ new VariableTensor("argument"), (a, b) -> a * b),
Reduce.Aggregator.sum).toPrimitive();
MapEvaluationContext context = new MapEvaluationContext();
-
+
for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor
context.put("argument", tensorElement);
double dotProduct = dotProductFunction.evaluate(context).asDouble();
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 30078b4a826..693b0f09351 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -25,7 +25,7 @@ import static org.junit.Assert.fail;
/**
* Tests tensor functionality
- *
+ *
* @author bratseth
*/
public class TensorTestCase {
@@ -108,7 +108,7 @@ public class TensorTestCase {
Tensor.diag(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build()));
assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x"));
}
-
+
/** Test the same computation made in various ways which are implemented with special-case optimizations */
@Test
public void testOptimizedComputation() {
@@ -130,7 +130,7 @@ public class TensorTestCase {
assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.indexedUnbound), vectors(Type.mapped, 2)));
assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2)));
assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2)));
-
+
// Test the unoptimized path by joining in another dimension
Tensor unitJ = Tensor.Builder.of(new TensorType.Builder().mapped("j").build()).cell().label("j", 0).value(1).build();
Tensor unitK = Tensor.Builder.of(new TensorType.Builder().mapped("k").build()).cell().label("k", 0).value(1).build();
@@ -138,7 +138,7 @@ public class TensorTestCase {
Tensor matrixInKSpace = matrix(Type.mapped, 2).get(0).multiply(unitK);
assertEquals("Generic computation implementation", 42, (int)dotProduct(vectorInJSpace, Collections.singletonList(matrixInKSpace)));
}
-
+
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double sum = 0;
TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
@@ -161,7 +161,7 @@ public class TensorTestCase {
private Tensor vector(int vectorSize, TensorType.Dimension.Type dimensionType) {
return vectors(vectorSize, dimensionType, 1).get(0);
}
-
+
/** Create a list of vectors having a single dimension x */
private List<Tensor> vectors(TensorType.Dimension.Type dimensionType, int vectorCount) {
return vectors(3, dimensionType, vectorCount);
@@ -179,8 +179,8 @@ public class TensorTestCase {
}
return tensors;
}
-
- /**
+
+ /**
* Create a matrix of vectors (in dimension i) where each vector has the dimension x.
* This matrix contains the same vectors as returned by createVectors, in a single list element for convenience.
*/
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
index fab53218b2c..f11c068bd74 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
@@ -10,12 +10,12 @@ import static org.junit.Assert.assertEquals;
* @author bratseth
*/
public class JoinTestCase {
-
+
/** Test the indexed subspace join optimization */
@Test
public void testJoinIndexedSubspace() {
Tensor t1, t2;
-
+
t1 = Tensor.from("tensor(x[]):{{x:0}:1.0,{x:1}:2.0}");
t2 = Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10,{x:1,y:1,z:0}:0.0}");
assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:20.0,{x:1,y:1,z:0}:0.0}"),
@@ -34,10 +34,10 @@ public class JoinTestCase {
assertEquals(Tensor.from("tensor(x[],y[],z[]):{{x:0,y:0,z:0}:6,{x:0,y:1,z:0}:0.0,{x:1,y:0,z:0}:10.0,{x:1,y:1,z:0}:0.0}"),
t2.divide(t1));
}
-
+
@Test
public void testGeneralJoin() {
- assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"),
+ assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"),
Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:4, {x:2}:6 }")
.divide(Tensor.from("tensor(y[]):{{y:0}:2}")));
@@ -45,5 +45,5 @@ public class JoinTestCase {
Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:6, {x:1,y:0}:8, {x:0,y:1}:20, {x:1,y:1}:24 }")
.divide(Tensor.from("tensor(y[],z[]):{ {y:0,z:0}:2, {y:1,z:0}:4, {y:2,z:0}:6 }")));
}
-
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
new file mode 100644
index 00000000000..9643c0a56e7
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
@@ -0,0 +1,97 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class MatmulTestCase {
+
+ @Test
+ public void testMatmul2d() {
+ // d0 is the 'outermost' dimension, etc.
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])"));
+ ab.cell( 1,0, 0);
+ ab.cell( 2,0, 1);
+ ab.cell( 3,0, 2);
+ ab.cell( 4,1, 0);
+ ab.cell( 5,1, 1);
+ ab.cell( 6,1, 2);
+ Tensor a = ab.build();
+
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[3],d1[2])"));
+ bb.cell( 7,0, 0);
+ bb.cell( 8,0, 1);
+ bb.cell( 9,1, 0);
+ bb.cell(10,1, 1);
+ bb.cell(11,2, 0);
+ bb.cell(12,2, 1);
+ Tensor b = bb.build();
+
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2])"));
+ rb.cell( 58,0, 0);
+ rb.cell( 64,0, 1);
+ rb.cell(139,1, 0);
+ rb.cell(154,1, 1);
+ Tensor r = rb.build();
+
+ Tensor result = a.matmul(b.rename(ImmutableList.of("d0","d1"), ImmutableList.of("d1","d2")), "d1")
+ .rename("d2","d1");
+ assertEquals(r, result);
+ }
+
+ @Test
+ public void testMatmul3d() {
+ // Convention: a is the 'outermost' dimension, etc.
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[3])"));
+ ab.cell( 1,0, 0, 0);
+ ab.cell( 2,0, 0, 1);
+ ab.cell( 3,0, 0, 2);
+ ab.cell( 4,0, 1, 0);
+ ab.cell( 5,0, 1, 1);
+ ab.cell( 6,0, 1, 2);
+ ab.cell( 7,1, 0, 0);
+ ab.cell( 8,1, 0, 1);
+ ab.cell( 9,1, 0, 2);
+ ab.cell(10,1, 1, 0);
+ ab.cell(11,1, 1, 1);
+ ab.cell(12,1, 1, 2);
+ Tensor a = ab.build();
+
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3],d2[2])"));
+ bb.cell(13,0, 0, 0);
+ bb.cell(14,0, 0, 1);
+ bb.cell(15,0, 1, 0);
+ bb.cell(16,0, 1, 1);
+ bb.cell(17,0, 2, 0);
+ bb.cell(18,0, 2, 1);
+ bb.cell(19,1, 0, 0);
+ bb.cell(20,1, 0, 1);
+ bb.cell(21,1, 1, 0);
+ bb.cell(22,1, 1, 1);
+ bb.cell(23,1, 2, 0);
+ bb.cell(24,1, 2, 1);
+ Tensor b = bb.build();
+
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[2])"));
+ rb.cell( 94,0, 0, 0);
+ rb.cell(100,0, 0, 1);
+ rb.cell(229,0, 1, 0);
+ rb.cell(244,0, 1, 1);
+ rb.cell(508,1, 0, 0);
+ rb.cell(532,1, 0, 1);
+ rb.cell(697,1, 1, 0);
+ rb.cell(730,1, 1, 1);
+ Tensor r = rb.build();
+
+ Tensor result = a.matmul(b.rename(ImmutableList.of("d1","d2"), ImmutableList.of("d2","d3")), "d2")
+ .rename("d3","d2");
+ assertEquals(r, result);
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
index 8a58cb0bbed..55069eaced7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
@@ -7,7 +7,7 @@ import static org.junit.Assert.assertEquals;
/**
* Tests translation of composite to primitive tensor function translation.
- *
+ *
* @author bratseth
*/
public class TensorFunctionTestCase {
@@ -16,12 +16,12 @@ public class TensorFunctionTestCase {
public void testTranslation() {
assertTranslated("join({{x:1}:1.0}, reduce({{x:1}:1.0}, sum, x), f(a,b)(a / b))",
new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x"));
- assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))",
+ assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))",
new Diag(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build()));
assertTranslated("join({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce({{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))",
new Argmax(new ConstantTensor("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x"));
}
-
+
private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) {
assertEquals(expectedTranslation, inputFunction.toPrimitive().toString());
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 349309a5052..15a872e439f 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -30,7 +30,7 @@ public class DenseBinaryFormatTestCase {
assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}");
}
-
+
@Test
public void testSerializationToSeparateType() {
assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[],y[])"));
@@ -64,7 +64,7 @@ public class DenseBinaryFormatTestCase {
private void assertSerialization(Tensor tensor) {
assertSerialization(tensor, tensor.type());
}
-
+
private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor));
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
index b1d7d797b3e..33dfca017f4 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
@@ -84,7 +84,7 @@ public class MixedBinaryFormatTestCase {
private void assertSerialization(Tensor tensor) {
assertSerialization(tensor, tensor.type());
}
-
+
private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor));
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
index 68bf59e3ed9..f002637847b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
@@ -50,7 +50,7 @@ public class SerializationTestCase {
JsonNode node = mapper.readTree(test);
if (node.has("tensor") && node.has("binary")) {
System.out.println("Running test: " + test);
-
+
Tensor tensor = buildTensor(node.get("tensor"));
String spec = getSpec(node.get("tensor"));
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
@@ -123,7 +123,7 @@ public class SerializationTestCase {
private byte[] getBytes(String binaryRepresentation) {
return parseHexValue(binaryRepresentation.substring(2));
}
-
+
private byte[] parseHexValue(String s) {
final int len = s.length();
byte[] bytes = new byte[len/2];
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index d17148cf8dc..f895b64379b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -65,7 +65,7 @@ public class SparseBinaryFormatTestCase {
private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
- Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType),
+ Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType),
GrowableByteBuffer.wrap(encodedTensor));
assertEquals(tensor, decodedTensor);
}