summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-04-11 15:40:01 +0200
committerLester Solbakken <lesters@oath.com>2019-04-11 15:40:01 +0200
commit59bbd539683c93f2e2f5a554c3092c0e87be142b (patch)
treec0e4655a10ae005f7a1d215a196dbcc22beeb900 /vespajlib
parentaccba940ad0884d4f52880f75a7509f60908ab84 (diff)
Add tensor value serialization for mapped and mixed tensors
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java101
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java37
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java110
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java21
6 files changed, 199 insertions, 114 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index ecd4f7d1965..5072484567d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -9,6 +9,8 @@ import com.yahoo.tensor.TensorType;
import java.util.Iterator;
import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
/**
* Implementation of a dense binary format for a tensor on the form:
@@ -22,40 +24,23 @@ import java.util.Optional;
*/
public class DenseBinaryFormat implements BinaryFormat {
- static private final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing
- static private final int FLOAT_VALUE_TYPE = 1;
+ private final TensorType.Value serializationValueType;
- enum EncodeType {NO_DEFAULT, DOUBLE_IS_DEFAULT}
- private final EncodeType encodeType;
DenseBinaryFormat() {
- encodeType = EncodeType.DOUBLE_IS_DEFAULT;
+ this(TensorType.Value.DOUBLE);
}
- DenseBinaryFormat(EncodeType encodeType) {
- this.encodeType = encodeType;
+ DenseBinaryFormat(TensorType.Value serializationValueType) {
+ this.serializationValueType = serializationValueType;
}
@Override
public void encode(GrowableByteBuffer buffer, Tensor tensor) {
if ( ! ( tensor instanceof IndexedTensor))
throw new RuntimeException("The dense format is only supported for indexed tensors");
- encodeValueType(buffer, tensor.type().valueType());
encodeDimensions(buffer, (IndexedTensor)tensor);
encodeCells(buffer, tensor);
}
- private void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) {
- switch (valueType) {
- case DOUBLE:
- if (encodeType != EncodeType.DOUBLE_IS_DEFAULT) {
- buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE);
- }
- break;
- case FLOAT:
- buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE);
- break;
- }
- }
-
private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) {
buffer.putInt1_4Bytes(tensor.type().dimensions().size());
for (int i = 0; i < tensor.type().dimensions().size(); i++) {
@@ -65,26 +50,17 @@ public class DenseBinaryFormat implements BinaryFormat {
}
private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
- switch (tensor.type().valueType()) {
- case DOUBLE:
- encodeCellsAsDouble(buffer, tensor);
- break;
- case FLOAT:
- encodeCellsAsFloat(buffer, tensor);
- break;
+ switch (serializationValueType) {
+ case DOUBLE: encodeCells(tensor, buffer::putDouble); break;
+ case FLOAT: encodeCells(tensor, (i) -> buffer.putFloat(i.floatValue())); break;
}
}
- private void encodeCellsAsDouble(GrowableByteBuffer buffer, Tensor tensor) {
+ private void encodeCells(Tensor tensor, Consumer<Double> consumer) {
Iterator<Double> i = tensor.valueIterator();
- while (i.hasNext())
- buffer.putDouble(i.next());
- }
-
- private void encodeCellsAsFloat(GrowableByteBuffer buffer, Tensor tensor) {
- Iterator<Double> i = tensor.valueIterator();
- while (i.hasNext())
- buffer.putFloat(i.next().floatValue());
+ while (i.hasNext()) {
+ consumer.accept(i.next());
+ }
}
@Override
@@ -93,40 +69,27 @@ public class DenseBinaryFormat implements BinaryFormat {
DimensionSizes sizes;
if (optionalType.isPresent()) {
type = optionalType.get();
- TensorType serializedType = decodeType(buffer, type.valueType());
+ if (type.valueType() != this.serializationValueType) {
+ throw new IllegalArgumentException("Tensor value type mismatch. Value type " + type.valueType() +
+ " is not " + this.serializationValueType);
+ }
+ TensorType serializedType = decodeType(buffer);
if ( ! serializedType.isAssignableTo(type))
throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
" cannot be assigned to type " + type);
sizes = sizesFromType(serializedType);
}
else {
- type = decodeType(buffer, TensorType.Value.DOUBLE);
+ type = decodeType(buffer);
sizes = sizesFromType(type);
}
Tensor.Builder builder = Tensor.Builder.of(type, sizes);
- decodeCells(type.valueType(), sizes, buffer, (IndexedTensor.BoundBuilder)builder);
+ decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder);
return builder.build();
}
- private TensorType decodeType(GrowableByteBuffer buffer, TensorType.Value valueType) {
- TensorType.Value serializedValueType = TensorType.Value.DOUBLE;
- if ((valueType != TensorType.Value.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) {
- int type = buffer.getInt1_4Bytes();
- switch (type) {
- case DOUBLE_VALUE_TYPE:
- serializedValueType = TensorType.Value.DOUBLE;
- break;
- case FLOAT_VALUE_TYPE:
- serializedValueType = TensorType.Value.FLOAT;
- break;
- default:
- throw new IllegalArgumentException("Received tensor value type '" + serializedValueType + "'. Only 0(double), or 1(float) are legal.");
- }
- }
- if (valueType != serializedValueType) {
- throw new IllegalArgumentException("Expected " + valueType + ", got " + serializedValueType);
- }
- TensorType.Builder builder = new TensorType.Builder(serializedValueType);
+ private TensorType decodeType(GrowableByteBuffer buffer) {
+ TensorType.Builder builder = new TensorType.Builder(serializationValueType);
int dimensionCount = buffer.getInt1_4Bytes();
for (int i = 0; i < dimensionCount; i++)
builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation
@@ -141,24 +104,16 @@ public class DenseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private void decodeCells(TensorType.Value valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
- switch (valueType) {
- case DOUBLE:
- decodeCellsAsDouble(sizes, buffer, builder);
- break;
- case FLOAT:
- decodeCellsAsFloat(sizes, buffer, builder);
- break;
+ private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
+ switch (serializationValueType) {
+ case DOUBLE: decodeCells(sizes, builder, buffer::getDouble); break;
+ case FLOAT: decodeCells(sizes, builder, () -> (double)buffer.getFloat()); break;
}
}
- private void decodeCellsAsDouble(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
- for (long i = 0; i < sizes.totalSize(); i++)
- builder.cellByDirectIndex(i, buffer.getDouble());
- }
- private void decodeCellsAsFloat(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
+ private void decodeCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, Supplier<Double> supplier) {
for (long i = 0; i < sizes.totalSize(); i++)
- builder.cellByDirectIndex(i, buffer.getFloat());
+ builder.cellByDirectIndex(i, supplier.get());
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index 284dfea2141..bc247e5561f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -11,6 +11,8 @@ import com.yahoo.tensor.TensorType;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
import java.util.stream.Collectors;
/**
@@ -21,6 +23,15 @@ import java.util.stream.Collectors;
*/
class MixedBinaryFormat implements BinaryFormat {
+ private final TensorType.Value serializationValueType;
+
+ MixedBinaryFormat() {
+ this(TensorType.Value.DOUBLE);
+ }
+ MixedBinaryFormat(TensorType.Value serializationValueType) {
+ this.serializationValueType = serializationValueType;
+ }
+
@Override
public void encode(GrowableByteBuffer buffer, Tensor tensor) {
if ( ! ( tensor instanceof MixedTensor))
@@ -50,6 +61,13 @@ class MixedBinaryFormat implements BinaryFormat {
}
private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor) {
+ switch (serializationValueType) {
+ case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
+ case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;
+ }
+ }
+
+ private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor, Consumer<Double> consumer) {
List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
long denseSubspaceSize = tensor.denseSubspaceSize();
if (sparseDimensions.size() > 0) {
@@ -63,9 +81,9 @@ class MixedBinaryFormat implements BinaryFormat {
new IllegalStateException("Dimension not found in address."));
buffer.putUtf8String(cell.getKey().label(index));
}
- buffer.putDouble(cell.getValue());
+ consumer.accept(cell.getValue());
for (int i = 1; i < denseSubspaceSize; ++i ) {
- buffer.putDouble(cellIterator.next().getValue());
+ consumer.accept(cellIterator.next().getValue());
}
}
}
@@ -75,6 +93,10 @@ class MixedBinaryFormat implements BinaryFormat {
TensorType type;
if (optionalType.isPresent()) {
type = optionalType.get();
+ if (type.valueType() != this.serializationValueType) {
+ throw new IllegalArgumentException("Tensor value type mismatch. Value type " + type.valueType() +
+ " is not " + this.serializationValueType);
+ }
TensorType serializedType = decodeType(buffer);
if ( ! serializedType.isAssignableTo(type))
throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
@@ -89,7 +111,7 @@ class MixedBinaryFormat implements BinaryFormat {
}
private TensorType decodeType(GrowableByteBuffer buffer) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(serializationValueType);
int numMappedDimensions = buffer.getInt1_4Bytes();
for (int i = 0; i < numMappedDimensions; ++i) {
builder.mapped(buffer.getUtf8String());
@@ -102,6 +124,13 @@ class MixedBinaryFormat implements BinaryFormat {
}
private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) {
+ switch (serializationValueType) {
+ case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break;
+ case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break;
+ }
+ }
+
+ private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type, Supplier<Double> supplier) {
List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions);
long denseSubspaceSize = builder.denseSubspaceSize();
@@ -118,7 +147,7 @@ class MixedBinaryFormat implements BinaryFormat {
sparseAddress.add(sparseDimension.name(), buffer.getUtf8String());
}
for (long denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) {
- denseSubspace[(int)denseOffset] = buffer.getDouble();
+ denseSubspace[(int)denseOffset] = supplier.get();
}
builder.block(sparseAddress.build(), denseSubspace);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 190b31b7b35..cd671f824fa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -8,8 +8,9 @@ import com.yahoo.tensor.TensorType;
import java.util.Iterator;
import java.util.List;
-import java.util.Map;
import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
/**
* Implementation of a sparse binary format for a tensor on the form:
@@ -24,6 +25,15 @@ import java.util.Optional;
*/
class SparseBinaryFormat implements BinaryFormat {
+ private final TensorType.Value serializationValueType;
+
+ SparseBinaryFormat() {
+ this(TensorType.Value.DOUBLE);
+ }
+ SparseBinaryFormat(TensorType.Value serializationValueType) {
+ this.serializationValueType = serializationValueType;
+ }
+
@Override
public void encode(GrowableByteBuffer buffer, Tensor tensor) {
encodeDimensions(buffer, tensor.type().dimensions());
@@ -39,10 +49,17 @@ class SparseBinaryFormat implements BinaryFormat {
private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation
+ switch (serializationValueType) {
+ case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
+ case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;
+ }
+ }
+
+ private void encodeCells(GrowableByteBuffer buffer, Tensor tensor, Consumer<Double> consumer) {
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
- Map.Entry<TensorAddress, Double> cell = i.next();
+ Tensor.Cell cell = i.next();
encodeAddress(buffer, cell.getKey());
- buffer.putDouble(cell.getValue());
+ consumer.accept(cell.getValue());
}
}
@@ -56,6 +73,10 @@ class SparseBinaryFormat implements BinaryFormat {
TensorType type;
if (optionalType.isPresent()) {
type = optionalType.get();
+ if (type.valueType() != this.serializationValueType) {
+ throw new IllegalArgumentException("Tensor value type mismatch. Value type " + type.valueType() +
+ " is not " + this.serializationValueType);
+ }
TensorType serializedType = decodeType(buffer);
if ( ! serializedType.isAssignableTo(type))
throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
@@ -71,18 +92,25 @@ class SparseBinaryFormat implements BinaryFormat {
private TensorType decodeType(GrowableByteBuffer buffer) {
int numDimensions = buffer.getInt1_4Bytes();
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(serializationValueType);
for (int i = 0; i < numDimensions; ++i)
builder.mapped(buffer.getUtf8String());
return builder.build();
}
private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
+ switch (serializationValueType) {
+ case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break;
+ case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break;
+ }
+ }
+
+ private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type, Supplier<Double> supplier) {
long numCells = buffer.getInt1_4Bytes(); // XXX: Size truncation
for (long i = 0; i < numCells; ++i) {
Tensor.Builder.CellBuilder cellBuilder = builder.cell();
decodeAddress(buffer, cellBuilder, type);
- cellBuilder.value(buffer.getDouble());
+ cellBuilder.value(supplier.get());
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 9b298f1dffb..bcff4392c9a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -27,32 +27,14 @@ public class TypedBinaryFormat {
private static final int DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6;
private static final int MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7;
+ private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing
+ private static final int FLOAT_VALUE_TYPE = 1;
+
public static byte[] encode(Tensor tensor) {
GrowableByteBuffer buffer = new GrowableByteBuffer();
- if (tensor instanceof MixedTensor) {
- buffer.putInt1_4Bytes(MIXED_BINARY_FORMAT_TYPE);
- new MixedBinaryFormat().encode(buffer, tensor);
- }
- else if (tensor instanceof IndexedTensor) {
- switch (tensor.type().valueType()) {
- case DOUBLE:
- buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
- new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).encode(buffer, tensor);
- break;
- default:
- buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE);
- new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).encode(buffer, tensor);
- break;
- }
- }
- else {
- buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
- new SparseBinaryFormat().encode(buffer, tensor);
- }
- buffer.flip();
- byte[] result = new byte[buffer.remaining()];
- buffer.get(result);
- return result;
+ BinaryFormat encoder = getFormatEncoder(buffer, tensor);
+ encoder.encode(buffer, tensor);
+ return asByteArray(buffer);
}
/**
@@ -64,14 +46,82 @@ public class TypedBinaryFormat {
* @throws IllegalArgumentException if the tensor data was invalid
*/
public static Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer) {
- int formatType = buffer.getInt1_4Bytes();
+ BinaryFormat decoder = getFormatDecoder(buffer);
+ return decoder.decode(type, buffer);
+ }
+
+ private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) {
+ if (tensor instanceof MixedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) {
+ encodeFormatType(buffer, MIXED_BINARY_FORMAT_TYPE);
+ return new MixedBinaryFormat();
+ }
+ if (tensor instanceof MixedTensor) {
+ encodeFormatType(buffer, MIXED_BINARY_FORMAT_WITH_CELLTYPE);
+ encodeValueType(buffer, tensor.type().valueType());
+ return new MixedBinaryFormat(tensor.type().valueType());
+ }
+ if (tensor instanceof IndexedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) {
+ encodeFormatType(buffer, DENSE_BINARY_FORMAT_TYPE);
+ return new DenseBinaryFormat();
+ }
+ if (tensor instanceof IndexedTensor) {
+ encodeFormatType(buffer, DENSE_BINARY_FORMAT_WITH_CELLTYPE);
+ encodeValueType(buffer, tensor.type().valueType());
+ return new DenseBinaryFormat(tensor.type().valueType());
+ }
+ if (tensor.type().valueType() == TensorType.Value.DOUBLE) {
+ encodeFormatType(buffer, SPARSE_BINARY_FORMAT_TYPE);
+ return new SparseBinaryFormat();
+ }
+ encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE);
+ encodeValueType(buffer, tensor.type().valueType());
+ return new SparseBinaryFormat(tensor.type().valueType());
+ }
+
+ private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) {
+ int formatType = decodeFormatType(buffer);
switch (formatType) {
- case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat().decode(type, buffer);
- case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer);
- case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).decode(type, buffer);
- case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).decode(type, buffer);
- default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
+ case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat();
+ case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat();
+ case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat();
+ case SPARSE_BINARY_FORMAT_WITH_CELLTYPE: return new SparseBinaryFormat(decodeValueType(buffer));
+ case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(decodeValueType(buffer));
+ case MIXED_BINARY_FORMAT_WITH_CELLTYPE: return new MixedBinaryFormat(decodeValueType(buffer));
}
+ throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
+ }
+
+ private static void encodeFormatType(GrowableByteBuffer buffer, int formatType) {
+ buffer.putInt1_4Bytes(formatType);
+ }
+
+ private static int decodeFormatType(GrowableByteBuffer buffer) {
+ return buffer.getInt1_4Bytes();
+ }
+
+ private static void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) {
+ switch (valueType) {
+ case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break;
+ case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break;
+ default:
+ throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType);
+ }
+ }
+
+ private static TensorType.Value decodeValueType(GrowableByteBuffer buffer) {
+ int valueType = buffer.getInt1_4Bytes();
+ switch (valueType) {
+ case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE;
+ case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT;
+ }
+ throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal.");
+ }
+
+ private static byte[] asByteArray(GrowableByteBuffer buffer) {
+ buffer.flip();
+ byte[] result = new byte[buffer.remaining()];
+ buffer.get(result);
+ return result;
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
index 33dfca017f4..69ef4922d8d 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
@@ -77,6 +77,12 @@ public class MixedBinaryFormatTestCase {
assertSerialization(tensor);
}
+ @Test
+ public void testSerializationOfDifferentValueTypes() {
+ assertSerialization("tensor<double>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<float>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ }
+
private void assertSerialization(String tensorString) {
assertSerialization(Tensor.from(tensorString));
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index f895b64379b..9074579094c 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -1,7 +1,6 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
-import com.google.common.collect.Sets;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -9,7 +8,6 @@ import org.junit.Test;
import java.util.Arrays;
import java.util.Optional;
-import java.util.Set;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@@ -55,6 +53,25 @@ public class SparseBinaryFormatTestCase {
Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
}
+ @Test
+ public void requireThatFloatSerializationFormatDoNotChange() {
+ byte[] encodedTensor = new byte[] {5, // binary format type
+ 1, // float type
+ 2, // num dimensions
+ 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0
+ 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1
+ assertEquals(Arrays.toString(encodedTensor),
+ Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ }
+
+ @Test
+ public void testSerializationOfDifferentValueTypes() {
+ assertSerialization("tensor<double>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<float>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ }
+
private void assertSerialization(String tensorString) {
assertSerialization(Tensor.from(tensorString));
}