diff options
author | Lester Solbakken <lesters@oath.com> | 2019-04-11 15:40:01 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-04-11 15:40:01 +0200 |
commit | 59bbd539683c93f2e2f5a554c3092c0e87be142b (patch) | |
tree | c0e4655a10ae005f7a1d215a196dbcc22beeb900 /vespajlib | |
parent | accba940ad0884d4f52880f75a7509f60908ab84 (diff) |
Add tensor value serialization for mapped and mixed tensors
Diffstat (limited to 'vespajlib')
6 files changed, 199 insertions, 114 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index ecd4f7d1965..5072484567d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -9,6 +9,8 @@ import com.yahoo.tensor.TensorType; import java.util.Iterator; import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Supplier; /** * Implementation of a dense binary format for a tensor on the form: @@ -22,40 +24,23 @@ import java.util.Optional; */ public class DenseBinaryFormat implements BinaryFormat { - static private final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing - static private final int FLOAT_VALUE_TYPE = 1; + private final TensorType.Value serializationValueType; - enum EncodeType {NO_DEFAULT, DOUBLE_IS_DEFAULT} - private final EncodeType encodeType; DenseBinaryFormat() { - encodeType = EncodeType.DOUBLE_IS_DEFAULT; + this(TensorType.Value.DOUBLE); } - DenseBinaryFormat(EncodeType encodeType) { - this.encodeType = encodeType; + DenseBinaryFormat(TensorType.Value serializationValueType) { + this.serializationValueType = serializationValueType; } @Override public void encode(GrowableByteBuffer buffer, Tensor tensor) { if ( ! ( tensor instanceof IndexedTensor)) throw new RuntimeException("The dense format is only supported for indexed tensors"); - encodeValueType(buffer, tensor.type().valueType()); encodeDimensions(buffer, (IndexedTensor)tensor); encodeCells(buffer, tensor); } - private void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) { - switch (valueType) { - case DOUBLE: - if (encodeType != EncodeType.DOUBLE_IS_DEFAULT) { - buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); - } - break; - case FLOAT: - buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); - break; - } - } - private void encodeDimensions(GrowableByteBuffer buffer, IndexedTensor tensor) { buffer.putInt1_4Bytes(tensor.type().dimensions().size()); for (int i = 0; i < tensor.type().dimensions().size(); i++) { @@ -65,26 +50,17 @@ public class DenseBinaryFormat implements BinaryFormat { } private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { - switch (tensor.type().valueType()) { - case DOUBLE: - encodeCellsAsDouble(buffer, tensor); - break; - case FLOAT: - encodeCellsAsFloat(buffer, tensor); - break; + switch (serializationValueType) { + case DOUBLE: encodeCells(tensor, buffer::putDouble); break; + case FLOAT: encodeCells(tensor, (i) -> buffer.putFloat(i.floatValue())); break; } } - private void encodeCellsAsDouble(GrowableByteBuffer buffer, Tensor tensor) { + private void encodeCells(Tensor tensor, Consumer<Double> consumer) { Iterator<Double> i = tensor.valueIterator(); - while (i.hasNext()) - buffer.putDouble(i.next()); - } - - private void encodeCellsAsFloat(GrowableByteBuffer buffer, Tensor tensor) { - Iterator<Double> i = tensor.valueIterator(); - while (i.hasNext()) - buffer.putFloat(i.next().floatValue()); + while (i.hasNext()) { + consumer.accept(i.next()); + } } @Override @@ -93,40 +69,27 @@ public class DenseBinaryFormat implements BinaryFormat { DimensionSizes sizes; if (optionalType.isPresent()) { type = optionalType.get(); - TensorType serializedType = decodeType(buffer, type.valueType()); + if (type.valueType() != this.serializationValueType) { + throw new IllegalArgumentException("Tensor value type mismatch. Value type " + type.valueType() + + " is not " + this.serializationValueType); + } + TensorType serializedType = decodeType(buffer); if ( ! serializedType.isAssignableTo(type)) throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + " cannot be assigned to type " + type); sizes = sizesFromType(serializedType); } else { - type = decodeType(buffer, TensorType.Value.DOUBLE); + type = decodeType(buffer); sizes = sizesFromType(type); } Tensor.Builder builder = Tensor.Builder.of(type, sizes); - decodeCells(type.valueType(), sizes, buffer, (IndexedTensor.BoundBuilder)builder); + decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder); return builder.build(); } - private TensorType decodeType(GrowableByteBuffer buffer, TensorType.Value valueType) { - TensorType.Value serializedValueType = TensorType.Value.DOUBLE; - if ((valueType != TensorType.Value.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) { - int type = buffer.getInt1_4Bytes(); - switch (type) { - case DOUBLE_VALUE_TYPE: - serializedValueType = TensorType.Value.DOUBLE; - break; - case FLOAT_VALUE_TYPE: - serializedValueType = TensorType.Value.FLOAT; - break; - default: - throw new IllegalArgumentException("Received tensor value type '" + serializedValueType + "'. Only 0(double), or 1(float) are legal."); - } - } - if (valueType != serializedValueType) { - throw new IllegalArgumentException("Expected " + valueType + ", got " + serializedValueType); - } - TensorType.Builder builder = new TensorType.Builder(serializedValueType); + private TensorType decodeType(GrowableByteBuffer buffer) { + TensorType.Builder builder = new TensorType.Builder(serializationValueType); int dimensionCount = buffer.getInt1_4Bytes(); for (int i = 0; i < dimensionCount; i++) builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation @@ -141,24 +104,16 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } - private void decodeCells(TensorType.Value valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { - switch (valueType) { - case DOUBLE: - decodeCellsAsDouble(sizes, buffer, builder); - break; - case FLOAT: - decodeCellsAsFloat(sizes, buffer, builder); - break; + private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { + switch (serializationValueType) { + case DOUBLE: decodeCells(sizes, builder, buffer::getDouble); break; + case FLOAT: decodeCells(sizes, builder, () -> (double)buffer.getFloat()); break; } } - private void decodeCellsAsDouble(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { - for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.getDouble()); - } - private void decodeCellsAsFloat(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { + private void decodeCells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, Supplier<Double> supplier) { for (long i = 0; i < sizes.totalSize(); i++) - builder.cellByDirectIndex(i, buffer.getFloat()); + builder.cellByDirectIndex(i, supplier.get()); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index 284dfea2141..bc247e5561f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -11,6 +11,8 @@ import com.yahoo.tensor.TensorType; import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -21,6 +23,15 @@ import java.util.stream.Collectors; */ class MixedBinaryFormat implements BinaryFormat { + private final TensorType.Value serializationValueType; + + MixedBinaryFormat() { + this(TensorType.Value.DOUBLE); + } + MixedBinaryFormat(TensorType.Value serializationValueType) { + this.serializationValueType = serializationValueType; + } + @Override public void encode(GrowableByteBuffer buffer, Tensor tensor) { if ( ! ( tensor instanceof MixedTensor)) @@ -50,6 +61,13 @@ class MixedBinaryFormat implements BinaryFormat { } private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor) { + switch (serializationValueType) { + case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; + case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; + } + } + + private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor, Consumer<Double> consumer) { List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); long denseSubspaceSize = tensor.denseSubspaceSize(); if (sparseDimensions.size() > 0) { @@ -63,9 +81,9 @@ class MixedBinaryFormat implements BinaryFormat { new IllegalStateException("Dimension not found in address.")); buffer.putUtf8String(cell.getKey().label(index)); } - buffer.putDouble(cell.getValue()); + consumer.accept(cell.getValue()); for (int i = 1; i < denseSubspaceSize; ++i ) { - buffer.putDouble(cellIterator.next().getValue()); + consumer.accept(cellIterator.next().getValue()); } } } @@ -75,6 +93,10 @@ class MixedBinaryFormat implements BinaryFormat { TensorType type; if (optionalType.isPresent()) { type = optionalType.get(); + if (type.valueType() != this.serializationValueType) { + throw new IllegalArgumentException("Tensor value type mismatch. Value type " + type.valueType() + + " is not " + this.serializationValueType); + } TensorType serializedType = decodeType(buffer); if ( ! serializedType.isAssignableTo(type)) throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + @@ -89,7 +111,7 @@ class MixedBinaryFormat implements BinaryFormat { } private TensorType decodeType(GrowableByteBuffer buffer) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(serializationValueType); int numMappedDimensions = buffer.getInt1_4Bytes(); for (int i = 0; i < numMappedDimensions; ++i) { builder.mapped(buffer.getUtf8String()); @@ -102,6 +124,13 @@ class MixedBinaryFormat implements BinaryFormat { } private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) { + switch (serializationValueType) { + case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; + case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; + } + } + + private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type, Supplier<Double> supplier) { List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions); long denseSubspaceSize = builder.denseSubspaceSize(); @@ -118,7 +147,7 @@ class MixedBinaryFormat implements BinaryFormat { sparseAddress.add(sparseDimension.name(), buffer.getUtf8String()); } for (long denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) { - denseSubspace[(int)denseOffset] = buffer.getDouble(); + denseSubspace[(int)denseOffset] = supplier.get(); } builder.block(sparseAddress.build(), denseSubspace); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 190b31b7b35..cd671f824fa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -8,8 +8,9 @@ import com.yahoo.tensor.TensorType; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Supplier; /** * Implementation of a sparse binary format for a tensor on the form: @@ -24,6 +25,15 @@ import java.util.Optional; */ class SparseBinaryFormat implements BinaryFormat { + private final TensorType.Value serializationValueType; + + SparseBinaryFormat() { + this(TensorType.Value.DOUBLE); + } + SparseBinaryFormat(TensorType.Value serializationValueType) { + this.serializationValueType = serializationValueType; + } + @Override public void encode(GrowableByteBuffer buffer, Tensor tensor) { encodeDimensions(buffer, tensor.type().dimensions()); @@ -39,10 +49,17 @@ class SparseBinaryFormat implements BinaryFormat { private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation + switch (serializationValueType) { + case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; + case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; + } + } + + private void encodeCells(GrowableByteBuffer buffer, Tensor tensor, Consumer<Double> consumer) { for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { - Map.Entry<TensorAddress, Double> cell = i.next(); + Tensor.Cell cell = i.next(); encodeAddress(buffer, cell.getKey()); - buffer.putDouble(cell.getValue()); + consumer.accept(cell.getValue()); } } @@ -56,6 +73,10 @@ class SparseBinaryFormat implements BinaryFormat { TensorType type; if (optionalType.isPresent()) { type = optionalType.get(); + if (type.valueType() != this.serializationValueType) { + throw new IllegalArgumentException("Tensor value type mismatch. Value type " + type.valueType() + + " is not " + this.serializationValueType); + } TensorType serializedType = decodeType(buffer); if ( ! serializedType.isAssignableTo(type)) throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + @@ -71,18 +92,25 @@ class SparseBinaryFormat implements BinaryFormat { private TensorType decodeType(GrowableByteBuffer buffer) { int numDimensions = buffer.getInt1_4Bytes(); - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(serializationValueType); for (int i = 0; i < numDimensions; ++i) builder.mapped(buffer.getUtf8String()); return builder.build(); } private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { + switch (serializationValueType) { + case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break; + case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break; + } + } + + private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type, Supplier<Double> supplier) { long numCells = buffer.getInt1_4Bytes(); // XXX: Size truncation for (long i = 0; i < numCells; ++i) { Tensor.Builder.CellBuilder cellBuilder = builder.cell(); decodeAddress(buffer, cellBuilder, type); - cellBuilder.value(buffer.getDouble()); + cellBuilder.value(supplier.get()); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 9b298f1dffb..bcff4392c9a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -27,32 +27,14 @@ public class TypedBinaryFormat { private static final int DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6; private static final int MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7; + private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing + private static final int FLOAT_VALUE_TYPE = 1; + public static byte[] encode(Tensor tensor) { GrowableByteBuffer buffer = new GrowableByteBuffer(); - if (tensor instanceof MixedTensor) { - buffer.putInt1_4Bytes(MIXED_BINARY_FORMAT_TYPE); - new MixedBinaryFormat().encode(buffer, tensor); - } - else if (tensor instanceof IndexedTensor) { - switch (tensor.type().valueType()) { - case DOUBLE: - buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE); - new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).encode(buffer, tensor); - break; - default: - buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE); - new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).encode(buffer, tensor); - break; - } - } - else { - buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE); - new SparseBinaryFormat().encode(buffer, tensor); - } - buffer.flip(); - byte[] result = new byte[buffer.remaining()]; - buffer.get(result); - return result; + BinaryFormat encoder = getFormatEncoder(buffer, tensor); + encoder.encode(buffer, tensor); + return asByteArray(buffer); } /** @@ -64,14 +46,82 @@ public class TypedBinaryFormat { * @throws IllegalArgumentException if the tensor data was invalid */ public static Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer) { - int formatType = buffer.getInt1_4Bytes(); + BinaryFormat decoder = getFormatDecoder(buffer); + return decoder.decode(type, buffer); + } + + private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) { + if (tensor instanceof MixedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) { + encodeFormatType(buffer, MIXED_BINARY_FORMAT_TYPE); + return new MixedBinaryFormat(); + } + if (tensor instanceof MixedTensor) { + encodeFormatType(buffer, MIXED_BINARY_FORMAT_WITH_CELLTYPE); + encodeValueType(buffer, tensor.type().valueType()); + return new MixedBinaryFormat(tensor.type().valueType()); + } + if (tensor instanceof IndexedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) { + encodeFormatType(buffer, DENSE_BINARY_FORMAT_TYPE); + return new DenseBinaryFormat(); + } + if (tensor instanceof IndexedTensor) { + encodeFormatType(buffer, DENSE_BINARY_FORMAT_WITH_CELLTYPE); + encodeValueType(buffer, tensor.type().valueType()); + return new DenseBinaryFormat(tensor.type().valueType()); + } + if (tensor.type().valueType() == TensorType.Value.DOUBLE) { + encodeFormatType(buffer, SPARSE_BINARY_FORMAT_TYPE); + return new SparseBinaryFormat(); + } + encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE); + encodeValueType(buffer, tensor.type().valueType()); + return new SparseBinaryFormat(tensor.type().valueType()); + } + + private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) { + int formatType = decodeFormatType(buffer); switch (formatType) { - case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat().decode(type, buffer); - case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer); - case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).decode(type, buffer); - case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).decode(type, buffer); - default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown"); + case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat(); + case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat(); + case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat(); + case SPARSE_BINARY_FORMAT_WITH_CELLTYPE: return new SparseBinaryFormat(decodeValueType(buffer)); + case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(decodeValueType(buffer)); + case MIXED_BINARY_FORMAT_WITH_CELLTYPE: return new MixedBinaryFormat(decodeValueType(buffer)); } + throw new IllegalArgumentException("Binary format type " + formatType + " is unknown"); + } + + private static void encodeFormatType(GrowableByteBuffer buffer, int formatType) { + buffer.putInt1_4Bytes(formatType); + } + + private static int decodeFormatType(GrowableByteBuffer buffer) { + return buffer.getInt1_4Bytes(); + } + + private static void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) { + switch (valueType) { + case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break; + case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break; + default: + throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType); + } + } + + private static TensorType.Value decodeValueType(GrowableByteBuffer buffer) { + int valueType = buffer.getInt1_4Bytes(); + switch (valueType) { + case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE; + case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT; + } + throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal."); + } + + private static byte[] asByteArray(GrowableByteBuffer buffer) { + buffer.flip(); + byte[] result = new byte[buffer.remaining()]; + buffer.get(result); + return result; } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java index 33dfca017f4..69ef4922d8d 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java @@ -77,6 +77,12 @@ public class MixedBinaryFormatTestCase { assertSerialization(tensor); } + @Test + public void testSerializationOfDifferentValueTypes() { + assertSerialization("tensor<double>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<float>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + } + private void assertSerialization(String tensorString) { assertSerialization(Tensor.from(tensorString)); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index f895b64379b..9074579094c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; -import com.google.common.collect.Sets; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -9,7 +8,6 @@ import org.junit.Test; import java.util.Arrays; import java.util.Optional; -import java.util.Set; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -55,6 +53,25 @@ public class SparseBinaryFormatTestCase { Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); } + @Test + public void requireThatFloatSerializationFormatDoNotChange() { + byte[] encodedTensor = new byte[] {5, // binary format type + 1, // float type + 2, // num dimensions + 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions + 2, // num cells, + 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0 + 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1 + assertEquals(Arrays.toString(encodedTensor), + Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + } + + @Test + public void testSerializationOfDifferentValueTypes() { + assertSerialization("tensor<double>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + assertSerialization("tensor<float>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); + } + private void assertSerialization(String tensorString) { assertSerialization(Tensor.from(tensorString)); } |