summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-11-28 21:35:16 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-11-28 21:35:16 +0100
commit1d6791e6fa004ae80e85dbc6a6c7c2e4b8037a4f (patch)
tree650307f35d321145410248f703943ef7525f94fb /vespajlib
parent0606896d63cc8bbe4919c7c37126fb9bc3f6e34e (diff)
parent7e8f8da8f249cf3c529cec8ecdcf13b69c99da13 (diff)
Merge with master
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/concurrent/DaemonThreadFactory.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java441
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java129
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java155
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java7
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java95
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java149
11 files changed, 1013 insertions, 15 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/concurrent/DaemonThreadFactory.java b/vespajlib/src/main/java/com/yahoo/concurrent/DaemonThreadFactory.java
index 4c15b6e2365..6c5dd5e3ba5 100644
--- a/vespajlib/src/main/java/com/yahoo/concurrent/DaemonThreadFactory.java
+++ b/vespajlib/src/main/java/com/yahoo/concurrent/DaemonThreadFactory.java
@@ -45,4 +45,5 @@ public class DaemonThreadFactory implements ThreadFactory {
}
return t;
}
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 6750c99bf98..c207dabca3a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -125,8 +125,12 @@ public class IndexedTensor implements Tensor {
if (indexes.length == 0) return 0; // for speed
int valueIndex = 0;
- for (int i = 0; i < indexes.length; i++)
+ for (int i = 0; i < indexes.length; i++) {
+ if (indexes[i] >= sizes.size(i)) {
+ throw new IndexOutOfBoundsException();
+ }
valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i];
+ }
return valueIndex;
}
@@ -134,8 +138,12 @@ public class IndexedTensor implements Tensor {
if (address.isEmpty()) return 0;
int valueIndex = 0;
- for (int i = 0; i < address.size(); i++)
+ for (int i = 0; i < address.size(); i++) {
+ if (address.intLabel(i) >= sizes.size(i)) {
+ throw new IndexOutOfBoundsException();
+ }
valueIndex += productOfDimensionsAfter(i, sizes) * address.intLabel(i);
+ }
return valueIndex;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
new file mode 100644
index 00000000000..79bb27fcd1b
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -0,0 +1,441 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor;
+
+import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * A mixed tensor type. This is class is currently suitable for serialization
+ * and deserialization, not yet for computation.
+ *
+ * A mixed tensor has a combination of mapped and indexed dimensions. By
+ * reordering the mapped dimensions before the indexed dimensions, one can
+ * think of mixed tensors as the mapped dimensions mapping to a
+ * dense tensor. This dense tensor is called a dense subspace.
+ *
+ * @author lesters
+ */
+@Beta
+public class MixedTensor implements Tensor {
+
+ /** The dimension specification for this tensor */
+ private final TensorType type;
+
+ /** The list of cells in the tensor */
+ private final ImmutableList<Cell> cells;
+
+ /** An index structure over the cell list */
+ private final Index index;
+
+ private MixedTensor(TensorType type, ImmutableList<Cell> cells, Index index) {
+ this.type = type;
+ this.cells = ImmutableList.copyOf(cells);
+ this.index = index;
+ }
+
+ /** Returns the tensor type */
+ @Override
+ public TensorType type() { return type; }
+
+ /** Returns the size of the tensor measured in number of cells */
+ @Override
+ public int size() { return cells.size(); }
+
+ /** Returns the value at the given address */
+ @Override
+ public double get(TensorAddress address) {
+ int cellIndex = index.indexOf(address);
+ Cell cell = cells.get(cellIndex);
+ if (!address.equals(cell.getKey())) {
+ throw new IllegalStateException("Unable to find correct cell by direct index.");
+ }
+ return cell.getValue();
+ }
+
+ /**
+ * Returns an iterator over the cells of this tensor.
+ * Cells are returned in order of increasing indexes in the
+ * indexed dimensions, increasing indexes of later dimensions
+ * in the dimension type before earlier. No guarantee is
+ * given for the order of sparse dimensions.
+ */
+ @Override
+ public Iterator<Cell> cellIterator() {
+ return cells.iterator();
+ }
+
+ /**
+ * Returns an iterator over the values of this tensor.
+ * The iteration order is the same as for cellIterator.
+ */
+ @Override
+ public Iterator<Double> valueIterator() {
+ return new Iterator<Double>() {
+ Iterator<Cell> cellIterator = cellIterator();
+ @Override
+ public boolean hasNext() {
+ return cellIterator.hasNext();
+ }
+ @Override
+ public Double next() {
+ return cellIterator.next().getValue();
+ }
+ };
+ }
+
+ @Override
+ public Map<TensorAddress, Double> cells() {
+ ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
+ for (Cell cell : cells) {
+ builder.put(cell.getKey(), cell.getValue());
+ }
+ return builder.build();
+ }
+
+ @Override
+ public int hashCode() { return cells.hashCode(); }
+
+ @Override
+ public String toString() { return Tensor.toStandardString(this); }
+
+ @Override
+ public boolean equals(Object other) {
+ if ( ! ( other instanceof Tensor)) return false;
+ return Tensor.equals(this, ((Tensor)other));
+ }
+
+ /** Returns the size of dense subspaces */
+ public int denseSubspaceSize() {
+ return index.denseSubspaceSize();
+ }
+
+
+ /**
+ * Base class for building mixed tensors.
+ */
+ public abstract static class Builder implements Tensor.Builder {
+
+ final TensorType type;
+
+ /**
+ * Create a builder depending upon the type of indexed dimensions.
+ * If at least one indexed dimension is unbound, we create
+ * a temporary structure while finding dimension bounds.
+ */
+ public static Builder of(TensorType type) {
+ if (type.dimensions().stream().anyMatch(d -> d instanceof TensorType.IndexedUnboundDimension)) {
+ return new UnboundBuilder(type);
+ } else {
+ return new BoundBuilder(type);
+ }
+ }
+
+ private Builder(TensorType type) {
+ this.type = type;
+ }
+
+ @Override
+ public TensorType type() {
+ return type;
+ }
+
+ @Override
+ public Tensor.Builder cell(double value, int... labels) {
+ throw new UnsupportedOperationException("Not implemented.");
+ }
+
+ @Override
+ public CellBuilder cell() {
+ return new CellBuilder(type(), this);
+ }
+
+ @Override
+ public abstract MixedTensor build();
+
+ }
+
+
+ /**
+ * Builder for mixed tensors with bound indexed dimensions.
+ */
+ public static class BoundBuilder extends Builder {
+
+ /** For each sparse partial address, hold a dense subspace */
+ final private Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>();
+ final private Index.Builder indexBuilder;
+ final private Index index;
+
+ private BoundBuilder(TensorType type) {
+ super(type);
+ indexBuilder = new Index.Builder(type);
+ index = indexBuilder.index();
+ }
+
+ public int denseSubspaceSize() {
+ return index.denseSubspaceSize();
+ }
+
+ private double[] denseSubspace(TensorAddress sparsePartial) {
+ if (!denseSubspaceMap.containsKey(sparsePartial)) {
+ denseSubspaceMap.put(sparsePartial, new double[denseSubspaceSize()]);
+ }
+ return denseSubspaceMap.get(sparsePartial);
+ }
+
+ @Override
+ public Tensor.Builder cell(TensorAddress address, double value) {
+ TensorAddress sparsePart = index.sparsePartialAddress(address);
+ int denseOffset = index.denseOffset(address);
+ double[] denseSubspace = denseSubspace(sparsePart);
+ denseSubspace[denseOffset] = value;
+ return this;
+ }
+
+ public Tensor.Builder block(TensorAddress sparsePart, double[] values) {
+ double[] denseSubspace = denseSubspace(sparsePart);
+ System.arraycopy(values, 0, denseSubspace, 0, denseSubspaceSize());
+ return this;
+ }
+
+ @Override
+ public MixedTensor build() {
+ int count = 0;
+ ImmutableList.Builder<Cell> builder = new ImmutableList.Builder<>();
+
+ for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) {
+ TensorAddress sparsePart = entry.getKey();
+ indexBuilder.put(sparsePart, count);
+
+ double[] denseSubspace = entry.getValue();
+ for (int offset = 0; offset < denseSubspace.length; ++offset) {
+ TensorAddress cellAddress = index.addressOf(sparsePart, offset);
+ double value = denseSubspace[offset];
+ builder.add(new Cell(cellAddress, value));
+ count++;
+ }
+ }
+ return new MixedTensor(type, builder.build(), indexBuilder.build());
+ }
+
+ }
+
+
+ /**
+ * Temporarily stores all cells to find bounds of indexed dimensions,
+ * then creates a tensor using BoundBuilder. This is due to the
+ * fact that for serialization the size of the dense subspace must be
+ * known, and equal for all dense subspaces. A side effect is that the
+ * tensor type is effectively changed, such that unbound indexed
+ * dimensions become bound.
+ */
+ public static class UnboundBuilder extends Builder {
+
+ private Map<TensorAddress, Double> cells;
+ private final int[] dimensionBounds;
+
+ private UnboundBuilder(TensorType type) {
+ super(type);
+ cells = new HashMap<>();
+ dimensionBounds = new int[type.dimensions().size()];
+ }
+
+ @Override
+ public Tensor.Builder cell(TensorAddress address, double value) {
+ cells.put(address, value);
+ trackBounds(address);
+ return this;
+ }
+
+ @Override
+ public MixedTensor build() {
+ TensorType boundType = createBoundType();
+ BoundBuilder builder = new BoundBuilder(boundType);
+ for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) {
+ builder.cell(cell.getKey(), cell.getValue());
+ }
+ return builder.build();
+ }
+
+ public void trackBounds(TensorAddress address) {
+ for (int i = 0; i < type.dimensions().size(); ++i) {
+ TensorType.Dimension dimension = type.dimensions().get(i);
+ if (dimension.isIndexed()) {
+ dimensionBounds[i] = Math.max(address.intLabel(i), dimensionBounds[i]);
+ }
+ }
+ }
+
+ public TensorType createBoundType() {
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (int i = 0; i < type.dimensions().size(); ++i) {
+ TensorType.Dimension dimension = type.dimensions().get(i);
+ if (!dimension.isIndexed()) {
+ typeBuilder.mapped(dimension.name());
+ } else {
+ int size = dimension.size().orElse(dimensionBounds[i] + 1);
+ typeBuilder.indexed(dimension.name(), size);
+ }
+ }
+ return typeBuilder.build();
+ }
+
+ }
+
+ /**
+ * An immutable index into a list of cells.
+ * Contains additional information required
+ * for handling mixed tensor addresses.
+ * Assumes indexed dimensions are bound.
+ */
+ private static class Index {
+
+ private final TensorType type;
+ private final TensorType sparseType;
+ private final TensorType denseType;
+ private final List<TensorType.Dimension> mappedDimensions;
+ private final List<TensorType.Dimension> indexedDimensions;
+
+ private ImmutableMap<TensorAddress, Integer> sparseMap;
+ private int denseSubspaceSize = -1;
+
+ private Index(TensorType type) {
+ this.type = type;
+ this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
+ this.indexedDimensions = type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList());
+ this.sparseType = createPartialType(mappedDimensions);
+ this.denseType = createPartialType(indexedDimensions);
+ }
+
+ public int indexOf(TensorAddress address) {
+ TensorAddress sparsePart = sparsePartialAddress(address);
+ if (!sparseMap.containsKey(sparsePart)) {
+ throw new IllegalArgumentException("Address not found");
+ }
+ int base = sparseMap.get(sparsePart);
+ int offset = denseOffset(address);
+ return base + offset;
+ }
+
+ public static class Builder {
+ private final Index index;
+ private final ImmutableMap.Builder<TensorAddress, Integer> builder;
+
+ public Builder(TensorType type) {
+ index = new Index(type);
+ builder = new ImmutableMap.Builder<>();
+ }
+
+ public void put(TensorAddress address, int index) {
+ builder.put(address, index);
+ }
+
+ public Index build() {
+ index.sparseMap = builder.build();
+ return index;
+ }
+
+ public Index index() {
+ return index;
+ }
+ }
+
+ public int denseSubspaceSize() {
+ if (denseSubspaceSize == -1) {
+ denseSubspaceSize = 1;
+ for (int i = 0; i < type.dimensions().size(); ++i) {
+ TensorType.Dimension dimension = type.dimensions().get(i);
+ if (dimension.isIndexed()) {
+ denseSubspaceSize *= dimension.size().orElseThrow(() ->
+ new IllegalArgumentException("Unknown size of indexed dimension."));
+ }
+ }
+ }
+ return denseSubspaceSize;
+ }
+
+ private TensorAddress sparsePartialAddress(TensorAddress address) {
+ if (type.dimensions().size() != address.size()) {
+ throw new IllegalArgumentException("Tensor type and address are not of same size.");
+ }
+ TensorAddress.Builder builder = new TensorAddress.Builder(sparseType);
+ for (int i = 0; i < type.dimensions().size(); ++i) {
+ TensorType.Dimension dimension = type.dimensions().get(i);
+ if (!dimension.isIndexed()) {
+ builder.add(dimension.name(), address.label(i));
+ }
+ }
+ return builder.build();
+ }
+
+ private int denseOffset(TensorAddress address) {
+ int innerSize = 1;
+ int offset = 0;
+ for (int i = type.dimensions().size(); --i >= 0; ) {
+ TensorType.Dimension dimension = type.dimensions().get(i);
+ if (dimension.isIndexed()) {
+ int label = address.intLabel(i);
+ offset += label * innerSize;
+ innerSize *= dimension.size().orElseThrow(() ->
+ new IllegalArgumentException("Unknown size of indexed dimension."));
+ }
+ }
+ return offset;
+ }
+
+ private TensorAddress denseOffsetToAddress(int denseOffset) {
+ if (denseOffset < 0 || denseOffset > denseSubspaceSize) {
+ throw new IllegalArgumentException("Offset out of bounds");
+ }
+
+ int restSize = denseOffset;
+ int innerSize = denseSubspaceSize;
+ int[] labels = new int[indexedDimensions.size()];
+
+ for (int i = 0; i < labels.length; ++i) {
+ TensorType.Dimension dimension = indexedDimensions.get(i);
+ int dimensionSize = dimension.size().orElseThrow(() ->
+ new IllegalArgumentException("Unknown size of indexed dimension."));
+
+ innerSize /= dimensionSize;
+ labels[i] = restSize / innerSize;
+ restSize %= innerSize;
+ }
+ return TensorAddress.of(labels);
+ }
+
+ private TensorAddress addressOf(TensorAddress sparsePart, int denseOffset) {
+ TensorAddress densePart = denseOffsetToAddress(denseOffset);
+ String[] labels = new String[type.dimensions().size()];
+ int mappedIndex = 0;
+ int indexedIndex = 0;
+ for (TensorType.Dimension d : type.dimensions()) {
+ if (d.isIndexed()) {
+ labels[mappedIndex + indexedIndex] = densePart.label(indexedIndex);
+ indexedIndex++;
+ } else {
+ labels[mappedIndex + indexedIndex] = sparsePart.label(mappedIndex);
+ mappedIndex++;
+ }
+ }
+ return TensorAddress.of(labels);
+ }
+
+ }
+
+ public static TensorType createPartialType(List<TensorType.Dimension> dimensions) {
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : dimensions) {
+ builder.set(dimension);
+ }
+ return builder.build();
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 5e3af70cba4..2ed211539d8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -171,12 +171,16 @@ public interface Tensor {
default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); }
default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); }
default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); }
+ default Tensor pow(Tensor argument) { return join(argument, Math::pow); }
+ default Tensor fmod(Tensor argument) { return join(argument, (a, b) -> ( a % b )); }
+ default Tensor ldexp(Tensor argument) { return join(argument, (a, b) -> ( a * Math.pow(2.0, (int)b) )); }
default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); }
default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); }
default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); }
default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); }
default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); }
default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); }
+ default Tensor approxEqual(Tensor argument) { return join(argument, (a, b) -> ( approxEquals(a,b) ? 1.0 : 0.0)); }
default Tensor avg(String dimension) { return avg(Collections.singletonList(dimension)); }
default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); }
@@ -259,11 +263,27 @@ public interface Tensor {
if ( a.size() != b.size()) return false;
for (Iterator<Cell> aIterator = a.cellIterator(); aIterator.hasNext(); ) {
Cell aCell = aIterator.next();
- if ( ! aCell.getValue().equals(b.get(aCell.getKey()))) return false;
+ double aValue = aCell.getValue();
+ double bValue = b.get(aCell.getKey());
+ if (!approxEquals(aValue, bValue, 1e-6)) return false;
}
return true;
}
+ static boolean approxEquals(double x, double y, double tolerance) {
+ return Math.abs(x-y) < tolerance;
+ }
+
+ static boolean approxEquals(double x, double y) {
+ if (y < -1.0 || y > 1.0) {
+ x = Math.nextAfter(x/y, 1.0);
+ y = 1.0;
+ } else {
+ x = Math.nextAfter(x, y);
+ }
+ return x==y;
+ }
+
// ----------------- Factories
/**
@@ -347,7 +367,7 @@ public interface Tensor {
boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
if (containsIndexed && containsMapped)
- throw new IllegalArgumentException("Combining indexed and mapped dimensions is not supported yet");
+ return MixedTensor.Builder.of(type);
if (containsMapped)
return MappedTensor.Builder.of(type);
else // indexed or empty
@@ -359,7 +379,7 @@ public interface Tensor {
boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
if (containsIndexed && containsMapped)
- throw new IllegalArgumentException("Combining indexed and mapped dimensions is not supported yet");
+ return MixedTensor.Builder.of(type);
if (containsMapped)
return MappedTensor.Builder.of(type);
else // indexed or empty
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 401f9a10eda..1dbb94fdb20 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -134,9 +134,7 @@ public class Concat extends PrimitiveTensorFunction {
if (currentDimension.equals(concatDimension))
concatSizes.set(i, aSize + bSize);
else if (aSize != 0 && bSize != 0 && aSize!=bSize )
- throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when " +
- "concatenating " + a.type() + " and " + b.type() + " along dimension " +
- concatDimension + ", but was " + aSize + " and " + bSize);
+ concatSizes.set(i, Math.min(aSize, bSize));
else
concatSizes.set(i, Math.max(aSize, bSize));
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
new file mode 100644
index 00000000000..61dfa888567
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -0,0 +1,129 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.serialization;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.tensor.MixedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Implementation of a mixed binary format for a tensor.
+ * See eval/src/vespa/eval/tensor/serialization/format.txt for format.
+ *
+ * @author lesters
+ */
+@Beta
+class MixedBinaryFormat implements BinaryFormat {
+
+ @Override
+ public void encode(GrowableByteBuffer buffer, Tensor tensor) {
+ if ( ! ( tensor instanceof MixedTensor))
+ throw new RuntimeException("The mixed format is only supported for mixed tensors");
+ MixedTensor mixed = (MixedTensor) tensor;
+ encodeSparseDimensions(buffer, mixed);
+ encodeDenseDimensions(buffer, mixed);
+ encodeCells(buffer, mixed);
+ }
+
+ private void encodeSparseDimensions(GrowableByteBuffer buffer, MixedTensor tensor) {
+ List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
+ buffer.putInt1_4Bytes(sparseDimensions.size());
+ for (TensorType.Dimension dimension : sparseDimensions) {
+ buffer.putUtf8String(dimension.name());
+ }
+ }
+
+ private void encodeDenseDimensions(GrowableByteBuffer buffer, MixedTensor tensor) {
+ List<TensorType.Dimension> denseDimensions = tensor.type().dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList());
+ buffer.putInt1_4Bytes(denseDimensions.size());
+ for (TensorType.Dimension dimension : denseDimensions) {
+ buffer.putUtf8String(dimension.name());
+ buffer.putInt1_4Bytes(dimension.size().orElseThrow(() ->
+ new IllegalArgumentException("Unknown size of indexed dimension.")));
+ }
+ }
+
+ private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor) {
+ List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
+ int denseSubspaceSize = tensor.denseSubspaceSize();
+ if (sparseDimensions.size() > 0) {
+ buffer.putInt1_4Bytes(tensor.size() / denseSubspaceSize);
+ }
+ Iterator<Tensor.Cell> cellIterator = tensor.cellIterator();
+ while (cellIterator.hasNext()) {
+ Tensor.Cell cell = cellIterator.next();
+ for (TensorType.Dimension dimension : sparseDimensions) {
+ int index = tensor.type().indexOfDimension(dimension.name()).orElseThrow(() ->
+ new IllegalStateException("Dimension not found in address."));
+ buffer.putUtf8String(cell.getKey().label(index));
+ }
+ buffer.putDouble(cell.getValue());
+ for (int i = 1; i < denseSubspaceSize; ++i ) {
+ buffer.putDouble(cellIterator.next().getValue());
+ }
+ }
+ }
+
+ @Override
+ public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) {
+ TensorType type;
+ if (optionalType.isPresent()) {
+ type = optionalType.get();
+ TensorType serializedType = decodeType(buffer);
+ if ( ! serializedType.isAssignableTo(type))
+ throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
+ " cannot be assigned to type " + type);
+ }
+ else {
+ type = decodeType(buffer);
+ }
+ MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder)MixedTensor.Builder.of(type);
+ decodeCells(buffer, builder, type);
+ return builder.build();
+ }
+
+ private TensorType decodeType(GrowableByteBuffer buffer) {
+ TensorType.Builder builder = new TensorType.Builder();
+ int numMappedDimensions = buffer.getInt1_4Bytes();
+ for (int i = 0; i < numMappedDimensions; ++i) {
+ builder.mapped(buffer.getUtf8String());
+ }
+ int numIndexedDimensions = buffer.getInt1_4Bytes();
+ for (int i = 0; i < numIndexedDimensions; ++i) {
+ builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes());
+ }
+ return builder.build();
+ }
+
+ private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) {
+ List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
+ TensorType sparseType = MixedTensor.createPartialType(sparseDimensions);
+ int denseSubspaceSize = builder.denseSubspaceSize();
+
+ int numBlocks = 1;
+ if (sparseDimensions.size() > 0) {
+ numBlocks = buffer.getInt1_4Bytes();
+ }
+
+ double[] denseSubspace = new double[denseSubspaceSize];
+ for (int i = 0; i < numBlocks; ++i) {
+ TensorAddress.Builder sparseAddress = new TensorAddress.Builder(sparseType);
+ for (TensorType.Dimension sparseDimension : sparseDimensions) {
+ sparseAddress.add(sparseDimension.name(), buffer.getUtf8String());
+ }
+ for (int denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) {
+ denseSubspace[denseOffset] = buffer.getDouble();
+ }
+ builder.block(sparseAddress.build(), denseSubspace);
+ }
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 657d262b401..7467554790a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization;
import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -23,10 +24,15 @@ public class TypedBinaryFormat {
private static final int SPARSE_BINARY_FORMAT_TYPE = 1;
private static final int DENSE_BINARY_FORMAT_TYPE = 2;
+ private static final int MIXED_BINARY_FORMAT_TYPE = 3;
public static byte[] encode(Tensor tensor) {
GrowableByteBuffer buffer = new GrowableByteBuffer();
- if (tensor instanceof IndexedTensor) {
+ if (tensor instanceof MixedTensor) {
+ buffer.putInt1_4Bytes(MIXED_BINARY_FORMAT_TYPE);
+ new MixedBinaryFormat().encode(buffer, tensor);
+ }
+ else if (tensor instanceof IndexedTensor) {
buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
new DenseBinaryFormat().encode(buffer, tensor);
}
@@ -51,6 +57,7 @@ public class TypedBinaryFormat {
public static Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer) {
int formatType = buffer.getInt1_4Bytes();
switch (formatType) {
+ case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat().decode(type, buffer);
case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer);
case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat().decode(type, buffer);
default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java
new file mode 100644
index 00000000000..fef8f05f4e1
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java
@@ -0,0 +1,155 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor;
+
+import com.google.common.collect.Sets;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Basic mixed tensor tests. Tensor operations are tested in EvaluationTestCase
+ *
+ * @author lesters
+ */
+public class MixedTensorTestCase {
+
+ @Test
+ public void testEmpty() {
+ TensorType type = new TensorType.Builder().mapped("x").indexed("y", 3).build();
+ Tensor empty = Tensor.Builder.of(type).build();
+ assertTrue(empty instanceof MixedTensor);
+ assertTrue(empty.isEmpty());
+ assertEquals("tensor(x{},y[3]):{}", empty.toString());
+ assertEquals("tensor(x{},y[3]):{}", Tensor.from("tensor(x{},y[3]):{}").toString());
+ }
+
+ @Test
+ public void testScalar() {
+ TensorType type = new TensorType.Builder().build();
+ Tensor scalar = MixedTensor.Builder.of(type).cell().value(42.0).build();
+ assertEquals(scalar.asDouble(), 42.0, 1e-6);
+ }
+
+ @Test
+ public void testOneIndexedBuilding() {
+ TensorType type = new TensorType.Builder().indexed("y", 3).build();
+ Tensor tensor = MixedTensor.Builder.of(type).
+ cell().label("y", 0).value(1).
+ cell().label("y", 1).value(2).
+ // {y:2} should be 0.0 and non NaN since we specify indexed size
+ build();
+ assertEquals(Sets.newHashSet("y"), tensor.type().dimensionNames());
+ assertEquals("{{y:0}:1.0,{y:1}:2.0,{y:2}:0.0}",
+ tensor.toString());
+ }
+
+ @Test
+ public void testTwoIndexedBuilding() {
+ TensorType type = new TensorType.Builder().indexed("x").indexed("y", 3).build();
+ Tensor tensor = MixedTensor.Builder.of(type).
+ cell().label("x", 0).label("y", 0).value(1).
+ cell().label("x", 0).label("y", 1).value(2).
+ // {x:1,y:2} should be 0.0 and non NaN since we specify indexed size
+ cell().label("x", 1).label("y", 0).value(4).
+ cell().label("x", 1).label("y", 1).value(5).
+ cell().label("x", 1).label("y", 2).value(6).
+ build();
+ assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames());
+ assertEquals("{{x:0,y:0}:1.0,{x:0,y:1}:2.0,{x:0,y:2}:0.0,{x:1,y:0}:4.0,{x:1,y:1}:5.0,{x:1,y:2}:6.0}",
+ tensor.toString());
+ }
+
+ @Test
+ public void testOneMappedBuilding() {
+ TensorType type = new TensorType.Builder().mapped("x").build();
+ Tensor tensor = MixedTensor.Builder.of(type).
+ cell().label("x", "0").value(1).
+ cell().label("x", "1").value(2).
+ build();
+ assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames());
+ assertEquals("{{x:0}:1.0,{x:1}:2.0}",
+ tensor.toString());
+ }
+
+ @Test
+ public void testTwoMappedBuilding() {
+ TensorType type = new TensorType.Builder().mapped("x").mapped("y").build();
+ Tensor tensor = MixedTensor.Builder.of(type).
+ cell().label("x", "0").label("y", "0").value(1).
+ cell().label("x", "0").label("y", "1").value(2).
+ cell().label("x", "1").label("y", "0").value(4).
+ cell().label("x", "1").label("y", "1").value(5).
+ cell().label("x", "1").label("y", "2").value(6).
+ build();
+ assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames());
+ assertEquals("{{x:0,y:0}:1.0,{x:0,y:1}:2.0,{x:1,y:0}:4.0,{x:1,y:1}:5.0,{x:1,y:2}:6.0}",
+ tensor.toString());
+ }
+
+ @Test
+ public void testOneMappedOneIndexedBuilding() {
+ TensorType type = new TensorType.Builder().mapped("x").indexed("y", 3).build();
+ Tensor tensor = MixedTensor.Builder.of(type).
+ cell().label("x", "1").label("y", 0).value(1).
+ cell().label("x", "1").label("y", 1).value(2).
+ // {x:1,y:2} should be 0.0 and non NaN since we specify indexed size
+ cell().label("x", "2").label("y", 0).value(4).
+ cell().label("x", "2").label("y", 1).value(5).
+ cell().label("x", "2").label("y", 2).value(6).
+ build();
+ assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames());
+ assertEquals("{{x:1,y:0}:1.0,{x:1,y:1}:2.0,{x:1,y:2}:0.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}",
+ tensor.toString());
+ }
+
+ @Test
+ public void testTwoMappedOneIndexedBuilding() {
+ TensorType type = new TensorType.Builder().mapped("x").indexed("y").mapped("z").build();
+ Tensor tensor = Tensor.Builder.of(type).
+ cell().label("x", "x1").label("y", 0).label("z","z1").value(1).
+ cell().label("x", "x1").label("y", 0).label("z","z2").value(2).
+ cell().label("x", "x1").label("y", 1).label("z","z1").value(3).
+ cell().label("x", "x1").label("y", 1).label("z","z2").value(4).
+ cell().label("x", "x1").label("y", 2).label("z","z1").value(5).
+ cell().label("x", "x1").label("y", 2).label("z","z2").value(6).
+ cell().label("x", "x2").label("y", 0).label("z","z1").value(11).
+ cell().label("x", "x2").label("y", 0).label("z","z2").value(12).
+ cell().label("x", "x2").label("y", 1).label("z","z1").value(13).
+ cell().label("x", "x2").label("y", 1).label("z","z2").value(14).
+ cell().label("x", "x2").label("y", 2).label("z","z1").value(15).
+ cell().label("x", "x2").label("y", 2).label("z","z2").value(16).
+ build();
+ assertEquals(Sets.newHashSet("x", "y", "z"), tensor.type().dimensionNames());
+ assertEquals("{{x:x1,y:0,z:z1}:1.0,{x:x1,y:0,z:z2}:2.0,{x:x1,y:1,z:z1}:3.0,{x:x1,y:1,z:z2}:4.0,{x:x1,y:2,z:z1}:5.0,{x:x1,y:2,z:z2}:6.0,{x:x2,y:0,z:z1}:11.0,{x:x2,y:0,z:z2}:12.0,{x:x2,y:1,z:z1}:13.0,{x:x2,y:1,z:z2}:14.0,{x:x2,y:2,z:z1}:15.0,{x:x2,y:2,z:z2}:16.0}",
+ tensor.toString());
+ }
+
+ @Test
+ public void testTwoMappedTwoIndexedBuilding() {
+ TensorType type = new TensorType.Builder().mapped("i").indexed("j", 2).mapped("k").indexed("l", 2).build();
+ Tensor tensor = Tensor.Builder.of(type).
+ cell().label("i", "a").label("k","c").label("j",0).label("l",0).value(1).
+ cell().label("i", "a").label("k","c").label("j",0).label("l",1).value(2).
+ cell().label("i", "a").label("k","c").label("j",1).label("l",0).value(3).
+ cell().label("i", "a").label("k","c").label("j",1).label("l",1).value(4).
+ cell().label("i", "a").label("k","d").label("j",0).label("l",0).value(5).
+ cell().label("i", "a").label("k","d").label("j",0).label("l",1).value(6).
+ cell().label("i", "a").label("k","d").label("j",1).label("l",0).value(7).
+ cell().label("i", "a").label("k","d").label("j",1).label("l",1).value(8).
+ cell().label("i", "b").label("k","c").label("j",0).label("l",0).value(9).
+ cell().label("i", "b").label("k","c").label("j",0).label("l",1).value(10).
+ cell().label("i", "b").label("k","c").label("j",1).label("l",0).value(11).
+ cell().label("i", "b").label("k","c").label("j",1).label("l",1).value(12).
+ cell().label("i", "b").label("k","d").label("j",0).label("l",0).value(13).
+ cell().label("i", "b").label("k","d").label("j",0).label("l",1).value(14).
+ cell().label("i", "b").label("k","d").label("j",1).label("l",0).value(15).
+ cell().label("i", "b").label("k","d").label("j",1).label("l",1).value(16).
+ build();
+ assertEquals(Sets.newHashSet("i", "j", "k", "l"), tensor.type().dimensionNames());
+ assertEquals("{{i:a,j:0,k:c,l:0}:1.0,{i:a,j:0,k:c,l:1}:2.0,{i:a,j:0,k:d,l:0}:5.0,{i:a,j:0,k:d,l:1}:6.0,{i:a,j:1,k:c,l:0}:3.0,{i:a,j:1,k:c,l:1}:4.0,{i:a,j:1,k:d,l:0}:7.0,{i:a,j:1,k:d,l:1}:8.0,{i:b,j:0,k:c,l:0}:9.0,{i:b,j:0,k:c,l:1}:10.0,{i:b,j:0,k:d,l:0}:13.0,{i:b,j:0,k:d,l:1}:14.0,{i:b,j:1,k:c,l:0}:11.0,{i:b,j:1,k:c,l:1}:12.0,{i:b,j:1,k:d,l:0}:15.0,{i:b,j:1,k:d,l:1}:16.0}",
+ tensor.toString());
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
index a653ef97734..7e1f292eb7b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
@@ -43,12 +43,7 @@ public class ConcatTestCase {
Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x"));
- try {
- a.concat(b, "y");
- fail("Expected exception");
- } catch (IllegalArgumentException expected) {
- // success
- }
+ assertEquals(Tensor.from("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }"), a.concat(b, "y"));
}
@Test
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
new file mode 100644
index 00000000000..b1d7d797b3e
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
@@ -0,0 +1,95 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.serialization;
+
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.tensor.MixedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import java.util.Optional;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for the mixed binary format.
+ *
+ * @author lesters
+ */
+public class MixedBinaryFormatTestCase {
+
+ @Test
+ public void testSerialization() {
+ assertSerialization("tensor(x{},y[3]):{{x:1,y:0}:1.0,{x:1,y:1}:2.0,{x:1,y:2}:0.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}");
+ assertSerialization("tensor(x{},y[]):{{x:1,y:0}:1.0,{x:1,y:1}:2.0,{x:1,y:2}:0.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}");
+
+ assertSerialization("tensor(x{},y[3],z{}):{{x:x1,y:0,z:z1}:1.0,{x:x1,y:0,z:z2}:2.0,{x:x1,y:1,z:z1}:3.0,{x:x1,y:1,z:z2}:4.0,{x:x1,y:2,z:z1}:5.0,{x:x1,y:2,z:z2}:6.0,{x:x2,y:0,z:z1}:11.0,{x:x2,y:0,z:z2}:12.0,{x:x2,y:1,z:z1}:13.0,{x:x2,y:1,z:z2}:14.0,{x:x2,y:2,z:z1}:15.0,{x:x2,y:2,z:z2}:16.0}");
+ assertSerialization("tensor(x{},y[],z{}):{{x:x1,y:0,z:z1}:1.0,{x:x1,y:0,z:z2}:2.0,{x:x1,y:1,z:z1}:3.0,{x:x1,y:1,z:z2}:4.0,{x:x1,y:2,z:z1}:5.0,{x:x1,y:2,z:z2}:6.0,{x:x2,y:0,z:z1}:11.0,{x:x2,y:0,z:z2}:12.0,{x:x2,y:1,z:z1}:13.0,{x:x2,y:1,z:z2}:14.0,{x:x2,y:2,z:z1}:15.0,{x:x2,y:2,z:z2}:16.0}");
+
+ assertSerialization("tensor(i{},j[2],k{},l[2]):{{i:a,j:0,k:c,l:0}:1.0,{i:a,j:0,k:c,l:1}:2.0,{i:a,j:0,k:d,l:0}:5.0,{i:a,j:0,k:d,l:1}:6.0,{i:a,j:1,k:c,l:0}:3.0,{i:a,j:1,k:c,l:1}:4.0,{i:a,j:1,k:d,l:0}:7.0,{i:a,j:1,k:d,l:1}:8.0,{i:b,j:0,k:c,l:0}:9.0,{i:b,j:0,k:c,l:1}:10.0,{i:b,j:0,k:d,l:0}:13.0,{i:b,j:0,k:d,l:1}:14.0,{i:b,j:1,k:c,l:0}:11.0,{i:b,j:1,k:c,l:1}:12.0,{i:b,j:1,k:d,l:0}:15.0,{i:b,j:1,k:d,l:1}:16.0}");
+ assertSerialization("tensor(i{},j[],k{},l[]):{{i:a,j:0,k:c,l:0}:1.0,{i:a,j:0,k:c,l:1}:2.0,{i:a,j:0,k:d,l:0}:5.0,{i:a,j:0,k:d,l:1}:6.0,{i:a,j:1,k:c,l:0}:3.0,{i:a,j:1,k:c,l:1}:4.0,{i:a,j:1,k:d,l:0}:7.0,{i:a,j:1,k:d,l:1}:8.0,{i:b,j:0,k:c,l:0}:9.0,{i:b,j:0,k:c,l:1}:10.0,{i:b,j:0,k:d,l:0}:13.0,{i:b,j:0,k:d,l:1}:14.0,{i:b,j:1,k:c,l:0}:11.0,{i:b,j:1,k:c,l:1}:12.0,{i:b,j:1,k:d,l:0}:15.0,{i:b,j:1,k:d,l:1}:16.0}");
+ }
+
+ @Test
+ public void testOneIndexedSerialization() {
+ TensorType type = new TensorType.Builder().indexed("y", 3).build();
+ Tensor tensor = MixedTensor.Builder.of(type).
+ cell().label("y", 0).value(1).
+ cell().label("y", 1).value(2).
+ build();
+ assertSerialization(tensor);
+ }
+
+ @Test
+ public void testTwoIndexedSerialization() {
+ TensorType type = new TensorType.Builder().indexed("x").indexed("y", 3).build();
+ Tensor tensor = MixedTensor.Builder.of(type).
+ cell().label("x", 0).label("y", 0).value(1).
+ cell().label("x", 0).label("y", 1).value(2).
+ cell().label("x", 1).label("y", 0).value(4).
+ cell().label("x", 1).label("y", 1).value(5).
+ cell().label("x", 1).label("y", 2).value(6).
+ build();
+ assertSerialization(tensor);
+ }
+
+ @Test
+ public void testOneMappedSerialization() {
+ TensorType type = new TensorType.Builder().mapped("x").build();
+ Tensor tensor = MixedTensor.Builder.of(type).
+ cell().label("x", "0").value(1).
+ cell().label("x", "1").value(2).
+ build();
+ assertSerialization(tensor);
+ }
+
+ @Test
+ public void testTwoMappedSerialization() {
+ TensorType type = new TensorType.Builder().mapped("x").mapped("y").build();
+ Tensor tensor = MixedTensor.Builder.of(type).
+ cell().label("x", "0").label("y", "0").value(1).
+ cell().label("x", "0").label("y", "1").value(2).
+ cell().label("x", "1").label("y", "0").value(4).
+ cell().label("x", "1").label("y", "1").value(5).
+ cell().label("x", "1").label("y", "2").value(6).
+ build();
+ assertSerialization(tensor);
+ }
+
+ private void assertSerialization(String tensorString) {
+ assertSerialization(Tensor.from(tensorString));
+ }
+
+ private void assertSerialization(Tensor tensor) {
+ assertSerialization(tensor, tensor.type());
+ }
+
+ private void assertSerialization(Tensor tensor, TensorType expectedType) {
+ byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
+ Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor));
+ assertEquals(tensor, decodedTensor);
+ }
+
+}
+
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
new file mode 100644
index 00000000000..68bf59e3ed9
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
@@ -0,0 +1,149 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.serialization;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class SerializationTestCase {
+
+ private static String testPath = "eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json";
+ private static List<String> tests = new ArrayList<>();
+
+ @Before
+ public void loadTests() throws IOException {
+ File testSpec = new File(testPath);
+ if (!testSpec.exists()) {
+ testSpec = new File("../" + testPath);
+ }
+ try(BufferedReader br = new BufferedReader(new FileReader(testSpec))) {
+ String test = br.readLine();
+ while (test != null) {
+ tests.add(test);
+ test = br.readLine();
+ }
+ }
+ }
+
+ @Test
+ public void testSerialization() throws IOException {
+ for (String test : tests) {
+ ObjectMapper mapper = new ObjectMapper();
+ JsonNode node = mapper.readTree(test);
+ if (node.has("tensor") && node.has("binary")) {
+ System.out.println("Running test: " + test);
+
+ Tensor tensor = buildTensor(node.get("tensor"));
+ String spec = getSpec(node.get("tensor"));
+ byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
+ boolean serializedToABinaryRepresentation = false;
+
+ JsonNode binaryNode = node.get("binary");
+ for (int i = 0; i < binaryNode.size(); ++i) {
+ byte[] bin = getBytes(binaryNode.get(i).asText());
+ Tensor decodedTensor = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(bin));
+
+ if (spec.equalsIgnoreCase("double")) {
+ assertEquals(tensor.asDouble(), decodedTensor.asDouble(), 1e-6);
+ } else {
+ assertEquals(tensor, decodedTensor);
+ }
+
+ if (Arrays.equals(encodedTensor, bin)) {
+ serializedToABinaryRepresentation = true;
+ }
+ }
+ assertTrue("Tensor did not serialize to one of the given representations", serializedToABinaryRepresentation);
+ }
+ }
+ }
+
+ private Tensor buildTensor(JsonNode tensor) {
+ TensorType type = tensorType(tensor);
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ tensorCells(tensor, builder);
+ return builder.build();
+ }
+
+ private TensorType tensorType(JsonNode tensor) {
+ String spec = getSpec(tensor);
+ if (spec.equalsIgnoreCase("double")) {
+ spec = "tensor()";
+ }
+ return TensorType.fromSpec(spec);
+ }
+
+ private String getSpec(JsonNode tensor) {
+ return tensor.get("type").asText();
+ }
+
+ private void tensorCells(JsonNode tensor, Tensor.Builder builder) {
+ JsonNode cells = tensor.get("cells");
+ for (JsonNode cell : cells) {
+ tensorCell(cell, builder.cell());
+ }
+ }
+
+ private void tensorCell(JsonNode cell, Tensor.Builder.CellBuilder cellBuilder) {
+ tensorCellAddress(cellBuilder, cell.get("address"));
+ tensorCellValue(cellBuilder, cell.get("value"));
+ }
+
+ private void tensorCellValue(Tensor.Builder.CellBuilder cellBuilder, JsonNode value) {
+ cellBuilder.value(value.doubleValue());
+ }
+
+ private void tensorCellAddress(Tensor.Builder.CellBuilder cellBuilder, JsonNode address) {
+ Iterator<String> dimension = address.fieldNames();
+ while (dimension.hasNext()) {
+ String name = dimension.next();
+ JsonNode label = address.get(name);
+ cellBuilder.label(name, label.asText());
+ }
+ }
+
+ private byte[] getBytes(String binaryRepresentation) {
+ return parseHexValue(binaryRepresentation.substring(2));
+ }
+
+ private byte[] parseHexValue(String s) {
+ final int len = s.length();
+ byte[] bytes = new byte[len/2];
+ for (int i = 0; i < len; i += 2) {
+ int c1 = hexValue(s.charAt(i)) << 4;
+ int c2 = hexValue(s.charAt(i + 1));
+ bytes[i/2] = (byte)(c1 + c2);
+ }
+ return bytes;
+ }
+
+ private int hexValue(Character c) {
+ if (c >= 'a' && c <= 'f') {
+ return c - 'a' + 10;
+ } else if (c >= 'A' && c <= 'F') {
+ return c - 'A' + 10;
+ } else if (c >= '0' && c <= '9') {
+ return c - '0';
+ }
+ throw new IllegalArgumentException("Hex contains illegal characters");
+ }
+
+}