aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer42.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java32
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java45
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java87
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java55
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java30
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java15
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java5
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java55
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java59
18 files changed, 297 insertions, 150 deletions
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer42.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer42.java
index 9e764aae798..753008de7e0 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer42.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer42.java
@@ -278,7 +278,7 @@ public class VespaDocumentDeserializer42 extends VespaDocumentSerializer42 imple
int encodedTensorLength = buf.getInt1_4Bytes();
if (encodedTensorLength > 0) {
byte[] encodedTensor = getBytes(null, encodedTensorLength);
- value.assign(TypedBinaryFormat.decode(encodedTensor));
+ value.assign(TypedBinaryFormat.decode(null, encodedTensor)); // TODO: Pass type
} else {
value.clear();
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index b0e30cf2043..d0a188c0760 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -177,7 +177,7 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {x:0,y:0}:15, {x:1,y:0}:35 }", "join(tensor0, tensor1, f(x,y) (x*y))", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
// -- join composites
tester.assertEvaluates("{ }", "tensor0 * tensor0", "{}");
- tester.assertEvaluates("tensor(x{},y{},z{}):{}", "( tensor0 * tensor1 ) * ( tensor2 * tensor1 )",
+ tester.assertEvaluates("{{x:0,y:0,z:0}:0.0}", "( tensor0 * tensor1 ) * ( tensor2 * tensor1 )",
"{{x:0}:1}", "{}", "{{y:0,z:0}:1}");
tester.assertEvaluates("tensor(x{}):{}",
"tensor0 * tensor1", "{ {x:0}:3 }", "tensor(x{}):{ {x:1}:5 }");
@@ -262,7 +262,8 @@ public class EvaluationTestCase {
"{ {x:0}:1, {x:1}:2 }", "{ {y:0}:3, {y:1}:4 }", "{ {z:0}:5 }",
"{ {x:0,y:0,z:0}:0.5, {x:1,y:0,z:0}:1.5, {x:0,y:0,z:1}:4.5, {x:0,y:1,z:0}:0, {x:1,y:0,z:1}:0, {x:0,y:1,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:0 }");
tester.assertEvaluates("1.0", "sum(tensor0 * tensor1 + 0.5)", "{ {x:0}:0, {x:1}:0 }", "{ {x:0}:1, {x:1}:1 }");
- tester.assertEvaluates("0.0", "sum(tensor0 * tensor1 + 0.5)", "{}", "{ {x:0}:1, {x:1}:1 }");
+ tester.assertEvaluates("1.0", "sum(tensor0 * tensor1 + 0.5)", "{}", "{ {x:0}:1, {x:1}:1 }");
+ tester.assertEvaluates("0.0", "sum(tensor0 * tensor1 + 0.5)", "tensor(x{}):{}", "{ {x:0}:1, {x:1}:1 }");
// tensor result dimensions are given from argument dimensions, not the resulting values
tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{}):{ {x:1}:1 }");
diff --git a/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java b/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java
index c33882052b4..eba749bd14e 100644
--- a/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java
+++ b/vespajlib/src/main/java/com/yahoo/io/GrowableByteBuffer.java
@@ -1,6 +1,8 @@
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.io;
+import com.yahoo.text.Utf8;
+
import java.nio.*;
/**
@@ -20,21 +22,22 @@ import java.nio.*;
* No methods except getByteBuffer() expose the encapsulated
* ByteBuffer, which is intentional.
*
- * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
+ * @author Einar M R Rosenvinge
*/
public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> {
+
public static final int DEFAULT_BASE_SIZE = 64*1024;
public static final float DEFAULT_GROW_FACTOR = 2.0f;
private ByteBuffer buffer;
private float growFactor;
private int mark = -1;
- //NOTE: It might have been better to subclass HeapByteBuffer,
- //but that class is package-private. Subclassing ByteBuffer would involve
- //implementing a lot of abstract methods, which would mean reinventing
- //some (too many) wheels.
+ // NOTE: It might have been better to subclass HeapByteBuffer,
+ // but that class is package-private. Subclassing ByteBuffer would involve
+ // implementing a lot of abstract methods, which would mean reinventing
+ // some (too many) wheels.
- //CONSTRUCTORS:
+ // CONSTRUCTORS:
public GrowableByteBuffer() {
this(DEFAULT_BASE_SIZE, DEFAULT_GROW_FACTOR);
@@ -61,7 +64,7 @@ public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> {
}
- //ACCESSORS:
+ // ACCESSORS:
public float getGrowFactor() {
return growFactor;
@@ -364,6 +367,21 @@ public class GrowableByteBuffer implements Comparable<GrowableByteBuffer> {
}
}
+ /** Writes this string to the buffer as a 1_4 encoded length in bytes followed by the utf8 bytes */
+ public void putUtf8String(String value) {
+ byte[] stringBytes = Utf8.toBytes(value);
+ putInt1_4Bytes(stringBytes.length);
+ put(stringBytes);
+ }
+
+ /** Reads a string from the buffer as a 1_4 encoded length in bytes followed by the utf8 bytes */
+ public String getUtf8String() {
+ int stringLength = getInt1_4Bytes();
+ byte[] stringBytes = new byte[stringLength];
+ get(stringBytes);
+ return Utf8.toString(stringBytes);
+ }
+
/**
* Computes the size used for storing the given integer using 1 or 4 bytes.
*
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index daa85cc51e4..7570a357452 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -29,6 +29,14 @@ public final class DimensionSizes {
/** Returns the number of dimensions this provides the size of */
public int dimensions() { return sizes.length; }
+ /** Returns the product of the sizes of this */
+ public int totalSize() {
+ int productSize = 1;
+ for (int dimensionSize : sizes )
+ productSize *= dimensionSize;
+ return productSize;
+ }
+
@Override
public boolean equals(Object o) {
if (o == this) return true;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 9315922f57a..bee93ddb4e0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -103,7 +103,6 @@ public class IndexedTensor implements Tensor {
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
public double get(int ... indexes) {
- if (values.length == 0) return Double.NaN;
return values[toValueIndex(indexes, dimensionSizes)];
}
@@ -157,7 +156,7 @@ public class IndexedTensor implements Tensor {
@Override
public Map<TensorAddress, Double> cells() {
if (dimensionSizes.dimensions() == 0)
- return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]);
+ return Collections.singletonMap(TensorAddress.empty, values[0]);
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
@@ -217,13 +216,6 @@ public class IndexedTensor implements Tensor {
public abstract Builder cell(double value, int ... indexes);
- protected double[] arrayFor(DimensionSizes sizes) {
- int productSize = 1;
- for (int i = 0; i < sizes.dimensions(); i++ )
- productSize *= sizes.size(i);
- return new double[productSize];
- }
-
@Override
public TensorType type() { return type; }
@@ -233,7 +225,7 @@ public class IndexedTensor implements Tensor {
}
/** A bound builder can create the double array directly */
- private static class BoundBuilder extends Builder {
+ public static class BoundBuilder extends Builder {
private DimensionSizes sizes;
private double[] values;
@@ -242,7 +234,7 @@ public class IndexedTensor implements Tensor {
this(type, dimensionSizesOf(type));
}
- public static DimensionSizes dimensionSizesOf(TensorType type) {
+ static DimensionSizes dimensionSizesOf(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < type.dimensions().size(); i++)
b.set(i, type.dimensions().get(i).size().get());
@@ -254,8 +246,7 @@ public class IndexedTensor implements Tensor {
if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
- values = arrayFor(sizes);
- Arrays.fill(values, Double.NaN);
+ values = new double[sizes.totalSize()];
}
@Override
@@ -277,10 +268,6 @@ public class IndexedTensor implements Tensor {
@Override
public IndexedTensor build() {
- // Note that we do not check for no NaN's here for performance reasons.
- // NaN's don't get lost so leaving them in place should be quite benign
- if (values.length == 1 && Double.isNaN(values[0]))
- values = new double[0];
IndexedTensor tensor = new IndexedTensor(type, sizes, values);
// prevent further modification
sizes = null;
@@ -290,9 +277,6 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(Cell cell, double value) {
- // TODO: Use internal index if applicable
- // values[internalIndex] = value;
- // return this;
int directIndex = cell.getDirectIndex();
if (directIndex >= 0) // optimization
values[directIndex] = value;
@@ -301,6 +285,15 @@ public class IndexedTensor implements Tensor {
return this;
}
+ /**
+ * Set a cell value by the index in the internal layout of this cell.
+ * This requires knowledge of the internal layout of cells in this implementation, and should therefore
+ * probably not be used (but when it can be used it is fast).
+ */
+ public void cellByDirectIndex(int index, double value) {
+ values[index] = value;
+ }
+
}
/**
@@ -318,13 +311,13 @@ public class IndexedTensor implements Tensor {
@Override
public IndexedTensor build() {
- if (firstDimension == null) // empty
- return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {});
+ if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values");
+
if (type.dimensions().isEmpty()) // single number
return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
- double[] values = arrayFor(dimensionSizes);
+ double[] values = new double[dimensionSizes.totalSize()];
fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
@@ -333,8 +326,10 @@ public class IndexedTensor implements Tensor {
List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
findDimensionSizes(0, dimensionSizeList, firstDimension);
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct
- for (int i = 0; i < b.dimensions(); i++)
- b.set(i, dimensionSizeList.get(i));
+ for (int i = 0; i < b.dimensions(); i++) {
+ if (i < dimensionSizeList.size())
+ b.set(i, dimensionSizeList.get(i));
+ }
return b.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 51d40a89f3b..29c508ce12f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -213,10 +213,9 @@ public interface Tensor {
static String contentToString(Tensor tensor) {
List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
- if (tensor.type().dimensions().isEmpty()) { // TODO: Decide on one way to represent degeneration to number
+ if (tensor.type().dimensions().isEmpty()) {
if (cellEntries.isEmpty()) return "{}";
- double value = cellEntries.get(0).getValue();
- return value == 0.0 ? "{}" : "{" + value +"}";
+ return "{" + cellEntries.get(0).getValue() +"}";
}
Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 82f36972a47..fbc469c1829 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -53,9 +53,6 @@ public class TensorType {
return TensorTypeParser.fromSpec(specString);
}
- /** Returns true if all dimensions of this are indexed */
- public boolean isIndexed() { return dimensions().stream().allMatch(d -> d.isIndexed()); }
-
/** Returns an immutable list of the dimensions of this */
public List<Dimension> dimensions() { return dimensions; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index ceade39ce42..f295e129a0f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -113,7 +113,7 @@ public class Join extends PrimitiveTensorFunction {
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
- if (subspace.type().isIndexed() && superspace.type().isIndexed())
+ if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder);
else
return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
index f3adf63739a..9b0ccdcb6c8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/BinaryFormat.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization;
import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
/**
* Representation of a specific binary format with functions for serializing a Tensor object into
@@ -21,7 +22,10 @@ interface BinaryFormat {
/**
* Deserialize the given binary data into a Tensor object.
+ *
+ * @param type the expected abstract type of the tensor to serialize
+ * @param buffer the buffer containing the tensor binary data
*/
- Tensor decode(GrowableByteBuffer buffer);
+ Tensor decode(TensorType type, GrowableByteBuffer buffer);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
new file mode 100644
index 00000000000..0a97576d5b7
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -0,0 +1,87 @@
+package com.yahoo.tensor.serialization;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.tensor.DimensionSizes;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.text.Utf8;
+
+import java.util.Iterator;
+
+/**
+ * Implementation of a dense binary format for a tensor on the form:
+ *
+ * Sorted dimensions = num_dimensions [dimension_str_len dimension_str_bytes dimension_size_int]*
+ * Cell_values = [double, double, double, ...]*
+ * where values are encoded in order of increasing indexes in each dimension, increasing
+ * indexes of later dimensions in the dimension type before earlier.
+ *
+ * @author bratseth
+ */
+@Beta
+public class DenseBinaryFormat implements BinaryFormat {
+
+ @Override
+ public void encode(GrowableByteBuffer buffer, Tensor tensor) {
+ if ( ! ( tensor instanceof IndexedTensor))
+ throw new RuntimeException("The dense format is only supported for indexed tensors");
+ encodeDimensions(buffer, (IndexedTensor)tensor);
+ encodeCells(buffer, tensor);
+ }
+
+ private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) {
+ buffer.putInt1_4Bytes(tensor.type().dimensions().size());
+ for (int i = 0; i < tensor.type().dimensions().size(); i++) {
+ buffer.putUtf8String(tensor.type().dimensions().get(i).name());
+ buffer.putInt1_4Bytes(tensor.dimensionSizes().size(i));
+ }
+ }
+
+ private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
+ Iterator<Double> i = tensor.valueIterator();
+ while (i.hasNext())
+ buffer.putDouble(i.next());
+ }
+
+ @Override
+ public Tensor decode(TensorType type, GrowableByteBuffer buffer) {
+ DimensionSizes sizes = decodeDimensionSizes(type, buffer);
+ Tensor.Builder builder = Tensor.Builder.of(type, sizes);
+ decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder);
+ return builder.build();
+ }
+
+ private DimensionSizes decodeDimensionSizes(TensorType type, GrowableByteBuffer buffer) {
+ int dimensionCount = buffer.getInt1_4Bytes();
+ if (type.dimensions().size() != dimensionCount)
+ throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
+ " dimensions but type is " + type);
+
+ DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount);
+ for (int i = 0; i < dimensionCount; i++) {
+ TensorType.Dimension expectedDimension = type.dimensions().get(i);
+
+ String encodedName = buffer.getUtf8String();
+ int encodedSize = buffer.getInt1_4Bytes();
+
+ if ( ! expectedDimension.name().equals(encodedName))
+ throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName +
+ "' as dimension " + i + " but type is " + type);
+
+ if (expectedDimension.size().isPresent() && expectedDimension.size().get() < encodedSize)
+ throw new IllegalArgumentException("Type/instance mismatch: Instance has size " + encodedSize +
+ " in " + expectedDimension + " in type " + type);
+
+ builder.set(i, encodedSize);
+ }
+ return builder.build();
+ }
+
+ private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
+ for (int i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.getDouble());
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 27a009b5e7e..30b36e83457 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -31,14 +31,14 @@ class SparseBinaryFormat implements BinaryFormat {
encodeCells(buffer, tensor);
}
- private static void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) {
+ private void encodeDimensions(GrowableByteBuffer buffer, List<TensorType.Dimension> sortedDimensions) {
buffer.putInt1_4Bytes(sortedDimensions.size());
for (TensorType.Dimension dimension : sortedDimensions) {
- encodeString(buffer, dimension.name());
+ buffer.putUtf8String(dimension.name());
}
}
- private static void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
+ private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
buffer.putInt1_4Bytes(tensor.size());
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -47,35 +47,47 @@ class SparseBinaryFormat implements BinaryFormat {
}
}
- private static void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) {
+ private void encodeAddress(GrowableByteBuffer buffer, TensorAddress address) {
for (int i = 0; i < address.size(); i++)
- encodeString(buffer, address.label(i));
- }
-
- private static void encodeString(GrowableByteBuffer buffer, String value) {
- byte[] stringBytes = Utf8.toBytes(value);
- buffer.putInt1_4Bytes(stringBytes.length);
- buffer.put(stringBytes);
+ buffer.putUtf8String(address.label(i));
}
@Override
- public Tensor decode(GrowableByteBuffer buffer) {
- TensorType type = decodeDimensions(buffer);
+ public Tensor decode(TensorType type, GrowableByteBuffer buffer) {
+ if (type == null) // TODO (January 2017): Remove this when types are available
+ type = decodeDimensionsToType(buffer);
+ else
+ consumeAndValidateDimensions(type, buffer);
Tensor.Builder builder = Tensor.Builder.of(type);
decodeCells(buffer, builder, type);
return builder.build();
}
- private static TensorType decodeDimensions(GrowableByteBuffer buffer) {
+ private TensorType decodeDimensionsToType(GrowableByteBuffer buffer) {
TensorType.Builder builder = new TensorType.Builder();
int numDimensions = buffer.getInt1_4Bytes();
for (int i = 0; i < numDimensions; ++i) {
- builder.mapped(decodeString(buffer)); // TODO: Support indexed
+ builder.mapped(buffer.getUtf8String());
}
return builder.build();
}
- private static void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
+ private void consumeAndValidateDimensions(TensorType type, GrowableByteBuffer buffer) {
+ int dimensionCount = buffer.getInt1_4Bytes();
+ if (type.dimensions().size() != dimensionCount)
+ throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
+ " dimensions but type is " + type);
+
+ for (int i = 0; i < dimensionCount; ++i) {
+ TensorType.Dimension expectedDimension = type.dimensions().get(i);
+ String encodedName = buffer.getUtf8String();
+ if ( ! expectedDimension.name().equals(encodedName))
+ throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName +
+ "' as dimension " + i + " but type is " + type);
+ }
+ }
+
+ private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
int numCells = buffer.getInt1_4Bytes();
for (int i = 0; i < numCells; ++i) {
Tensor.Builder.CellBuilder cellBuilder = builder.cell();
@@ -84,20 +96,13 @@ class SparseBinaryFormat implements BinaryFormat {
}
}
- private static void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) {
+ private void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) {
for (TensorType.Dimension dimension : type.dimensions()) {
- String label = decodeString(buffer);
+ String label = buffer.getUtf8String();
if ( ! label.isEmpty()) {
builder.label(dimension.name(), label);
}
}
}
- private static String decodeString(GrowableByteBuffer buffer) {
- int stringLength = buffer.getInt1_4Bytes();
- byte[] stringBytes = new byte[stringLength];
- buffer.get(stringBytes);
- return Utf8.toString(stringBytes);
- }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 5a45f20b6d8..65216aa2fcd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -3,7 +3,9 @@ package com.yahoo.tensor.serialization;
import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
/**
* Class used by clients for serializing a Tensor object into binary format or
@@ -18,25 +20,31 @@ import com.yahoo.tensor.Tensor;
public class TypedBinaryFormat {
private static final int SPARSE_BINARY_FORMAT_TYPE = 1;
+ private static final int DENSE_BINARY_FORMAT_TYPE = 2;
public static byte[] encode(Tensor tensor) {
GrowableByteBuffer buffer = new GrowableByteBuffer();
- buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
- new SparseBinaryFormat().encode(buffer, tensor);
+ if (tensor instanceof IndexedTensor && 1==2) { // TODO: Activate when we have type information everywhere
+ buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
+ new DenseBinaryFormat().encode(buffer, tensor);
+ }
+ else {
+ buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
+ new SparseBinaryFormat().encode(buffer, tensor);
+ }
buffer.flip();
byte[] result = new byte[buffer.remaining()];
buffer.get(result);
return result;
}
- public static Tensor decode(byte[] data) {
+ public static Tensor decode(TensorType type, byte[] data) {
GrowableByteBuffer buffer = GrowableByteBuffer.wrap(data);
int formatType = buffer.getInt1_4Bytes();
switch (formatType) {
- case SPARSE_BINARY_FORMAT_TYPE:
- return new SparseBinaryFormat().decode(buffer);
- default:
- throw new IllegalArgumentException("Binary format type " + formatType + " is not a known format");
+ case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer);
+ case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat().decode(type, buffer);
+ default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
index 3f7f02c6c00..e150b1cf24f 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
@@ -1,5 +1,6 @@
package com.yahoo.tensor;
+import junit.framework.TestCase;
import org.junit.Test;
import java.util.HashMap;
@@ -7,6 +8,7 @@ import java.util.Iterator;
import java.util.Map;
import static junit.framework.TestCase.assertTrue;
+import static junit.framework.TestCase.fail;
import static org.junit.Assert.assertEquals;
/**
@@ -23,16 +25,12 @@ public class IndexedTensorTestCase {
@Test
public void testEmpty() {
Tensor empty = Tensor.Builder.of(TensorType.empty).build();
- assertTrue(empty instanceof IndexedTensor);
- assertTrue(empty.isEmpty());
- assertEquals("{}", empty.toString());
+ assertEquals(1, empty.size());
+ assertEquals((double)0.0, (double)empty.valueIterator().next(), 0.00000001);
Tensor emptyFromString = Tensor.from(TensorType.empty, "{}");
- assertEquals("{}", Tensor.from(TensorType.empty, "{}").toString());
- assertTrue(emptyFromString.isEmpty());
- assertTrue(emptyFromString instanceof IndexedTensor);
assertEquals(empty, emptyFromString);
}
-
+
@Test
public void testSingleValue() {
Tensor singleValue = Tensor.Builder.of(TensorType.empty).cell(TensorAddress.empty, 3.5).build();
@@ -45,22 +43,6 @@ public class IndexedTensorTestCase {
}
@Test
- public void testSingleValueWithDimensions() {
- TensorType type = new TensorType.Builder().indexed("x").indexed("y").build();
- Tensor emptyWithDimensions = Tensor.Builder.of(type).build();
- assertTrue(emptyWithDimensions instanceof IndexedTensor);
- assertEquals("tensor(x[],y[]):{}", emptyWithDimensions.toString());
- Tensor emptyWithDimensionsFromString = Tensor.from("tensor(x[],y[]):{}");
- assertEquals("tensor(x[],y[]):{}", emptyWithDimensionsFromString.toString());
- assertTrue(emptyWithDimensionsFromString instanceof IndexedTensor);
- assertEquals(emptyWithDimensions, emptyWithDimensionsFromString);
-
- IndexedTensor emptyWithDimensionsIndexed = (IndexedTensor)emptyWithDimensions;
- assertEquals(0, emptyWithDimensionsIndexed.dimensionSizes().size(0));
- assertEquals(0, emptyWithDimensionsIndexed.dimensionSizes().size(1));
- }
-
- @Test
public void testBoundBuilding() {
TensorType type = new TensorType.Builder().indexed("v", vSize)
.indexed("w", wSize)
@@ -91,7 +73,7 @@ public class IndexedTensorTestCase {
for (int z = 0; z < zSize; z++)
builder.cell(value(v, w, x, y, z), v, w, x, y, z);
- IndexedTensor tensor = builder.build();
+ IndexedTensor tensor = (IndexedTensor)builder.build();
// Lookup by index arguments
for (int v = 0; v < vSize; v++)
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
index 4c32a80dc11..5c2c3b9db32 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor;
import com.google.common.collect.Sets;
+import junit.framework.TestCase;
import org.junit.Test;
import java.util.Set;
@@ -18,6 +19,20 @@ import static org.junit.Assert.fail;
public class MappedTensorTestCase {
@Test
+ public void testEmpty() {
+ TensorType type = new TensorType.Builder().mapped("x").build();
+ Tensor empty = Tensor.Builder.of(type).build();
+ TestCase.assertTrue(empty instanceof MappedTensor);
+ TestCase.assertTrue(empty.isEmpty());
+ assertEquals("tensor(x{}):{}", empty.toString());
+ Tensor emptyFromString = Tensor.from(type, "{}");
+ assertEquals("tensor(x{}):{}", Tensor.from("tensor(x{}):{}").toString());
+ TestCase.assertTrue(emptyFromString.isEmpty());
+ TestCase.assertTrue(emptyFromString instanceof MappedTensor);
+ assertEquals(empty, emptyFromString);
+ }
+
+ @Test
public void testOneDimensionalBuilding() {
TensorType type = new TensorType.Builder().mapped("x").build();
Tensor tensor = Tensor.Builder.of(type).
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index 2f060239eb1..e2baa1d5ac3 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -27,6 +27,7 @@ public class TensorFunctionBenchmark {
modelVectors = modelVectors.stream().map(t -> t.multiply(unitVector("k"))).collect(Collectors.toList());
}
dotProduct(queryVector, modelVectors, Math.max(iterations/10, 10)); // warmup
+ System.gc();
long startTime = System.currentTimeMillis();
dotProduct(queryVector, modelVectors, iterations);
long totalTime = System.currentTimeMillis() - startTime;
@@ -106,51 +107,41 @@ public class TensorFunctionBenchmark {
// ---------------- Mapped with extra space (sidesteps current special-case optimizations):
// 410 ms
- System.gc();
time = new TensorFunctionBenchmark().benchmark(20, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true);
System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time);
// 770 ms
- System.gc();
time = new TensorFunctionBenchmark().benchmark(20, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true);
System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time);
// ---------------- Mapped:
// 2.6 ms
- System.gc();
time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false);
System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time);
// 6.8 ms
- System.gc();
time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false);
System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time);
// ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations):
// 30 ms
- System.gc();
time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true);
System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time);
// 27 ms
- System.gc();
time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true);
System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time);
// ---------------- Indexed unbound:
// 0.14 ms
- System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false);
System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time);
// 0.14 ms
- System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false);
System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time);
// ---------------- Indexed bound:
// 0.14 ms
- System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false);
System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time);
// 0.14 ms
- System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false);
System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time);
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index feeba1a7a10..b35220cf013 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -21,7 +21,7 @@ import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
- * Tests Tensor functionality
+ * Tests tensor functionality
*
* @author bratseth
*/
@@ -29,7 +29,8 @@ public class TensorTestCase {
@Test
public void testStringForm() {
- assertEquals("{}", Tensor.from("{}").toString());
+ assertEquals("{5.7}", Tensor.from("{5.7}").toString());
+ assertTrue(Tensor.from("{5.7}") instanceof IndexedTensor);
assertEquals("{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString());
assertEquals("{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString());
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
new file mode 100644
index 00000000000..d2b2044f3ed
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -0,0 +1,55 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.serialization;
+
+import com.google.common.collect.Sets;
+import com.yahoo.tensor.Tensor;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for the dense binary format.
+ *
+ * @author bratseth
+ */
+public class DenseBinaryFormatTestCase {
+
+ @Test
+ public void testSerialization() {
+ assertSerialization("{-5.37}");
+ assertSerialization("tensor(x[]):{{x:0}:2.0}");
+ assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0}");
+ assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}");
+ }
+
+ @Test
+ @Ignore // TODO: Activate when encoding in this format is activated
+ public void requireThatSerializationFormatDoNotChange() {
+ byte[] encodedTensor = new byte[]{2, // binary format type
+ 2, // dimension count
+ 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size
+ 1, (byte) 'z', 1, // dimension z with size
+ 64, 0, 0, 0, 0, 0, 0, 0, // value 1
+ 64, 8, 0, 0, 0, 0, 0, 0 // value 2
+ };
+ assertEquals(Arrays.toString(encodedTensor),
+ Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}"))));
+ }
+
+ private void assertSerialization(String tensorString) {
+ assertSerialization(Tensor.from(tensorString));
+ }
+
+ private void assertSerialization(Tensor tensor) {
+ byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
+ Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor);
+ assertEquals(tensor, decodedTensor);
+ }
+
+}
+
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index ad908101329..283aa90cf65 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -13,52 +13,23 @@ import static org.junit.Assert.assertEquals;
/**
* Tests for the sparse binary format.
*
- * TODO: When new formats are added we should refactor this test to test all formats
- * with the same set of tensor inputs (if feasible).
- *
* @author geirst
*/
public class SparseBinaryFormatTestCase {
- private static void assertSerialization(String tensorString) {
- assertSerialization(Tensor.from(tensorString));
- }
-
- private static void assertSerialization(String tensorString, Set<String> dimensions) {
- Tensor tensor = Tensor.from(tensorString);
- assertEquals(dimensions, tensor.type().dimensionNames());
- assertSerialization(tensor);
- }
-
- private static void assertSerialization(Tensor tensor) {
- byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
- Tensor decodedTensor = TypedBinaryFormat.decode(encodedTensor);
- assertEquals(tensor, decodedTensor);
- }
-
@Test
- public void testSerializationOfTensorsWithDenseTensorAddresses() {
- assertSerialization("{}");
- assertSerialization("{{x:0}:2.0}");
- assertSerialization("{{x:0}:2.0,{x:1}:3.0}");
- assertSerialization("{{x:0,y:0}:2.0}");
- assertSerialization("{{x:0,y:0}:2.0,{x:0,y:1}:3.0}");
- assertSerialization("{{y:0,x:0}:2.0}");
- assertSerialization("{{y:0,x:0}:2.0,{y:1,x:0}:3.0}");
- assertSerialization("{{dimX:labelA,dimY:labelB}:2.0,{dimY:labelC,dimX:labelD}:3.0}");
+ public void testSerialization() {
+ assertSerialization("tensor(x{}):{}");
+ assertSerialization("tensor(x{}):{{x:0}:2.0}");
+ assertSerialization("tensor(dimX{},dimY{}):{{dimX:labelA,dimY:labelB}:2.0,{dimY:labelC,dimX:labelD}:3.0}");
+ assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0}");
+ assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0,{x:1,y:4}:3.0}");
+ assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0}");
+ assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0,{y:1,x:0,z:6}:3.0}");
}
@Test
- public void testSerializationOfTensorsWithSparseTensorAddresses() {
- assertSerialization("{{x:0}:2.0, {x:1}:3.0}", Sets.newHashSet("x"));
- assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0}", Sets.newHashSet("x", "y"));
- assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0,{x:1,y:4}:3.0}", Sets.newHashSet("x", "y"));
- assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0}", Sets.newHashSet("x", "y", "z"));
- assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0,{y:1,x:0,z:6}:3.0}", Sets.newHashSet("x", "y", "z"));
- }
-
- @Test
- public void requireThatCompactSerializationFormatDoNotChange() {
+ public void requireThatSerializationFormatDoNotChange() {
byte[] encodedTensor = new byte[] {1, // binary format type
2, // num dimensions
2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
@@ -66,7 +37,17 @@ public class SparseBinaryFormatTestCase {
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1
assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ }
+
+ private void assertSerialization(String tensorString) {
+ assertSerialization(Tensor.from(tensorString));
+ }
+
+ private void assertSerialization(Tensor tensor) {
+ byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
+ Tensor decodedTensor = TypedBinaryFormat.decode(tensor.type(), encodedTensor);
+ assertEquals(tensor, decodedTensor);
}
}