From 91b2cf49d8e25c378f4aa00833ca8245f9c1ca65 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Thu, 3 Sep 2020 12:01:39 +0200 Subject: Use the tensor type to switch tensor binary format The binary format of a tensor should depend on the tensor type, not the implementation type, as the API permits the user to choose the implementation type (and the two may not map 1-1 anyway). This commit makes that change for sparse tensors using the mixed implementation type, but not for dense tensors using the mixed implementation type, as that would be more work given the unfinished state of the mixed implementation. --- .../main/java/com/yahoo/tensor/IndexedTensor.java | 4 +--- .../main/java/com/yahoo/tensor/MixedTensor.java | 1 - .../tensor/serialization/TypedBinaryFormat.java | 26 +++++++++++++++------- .../serialization/SerializationTestCase.java | 3 ++- .../serialization/SparseBinaryFormatTestCase.java | 18 +++++++++++++-- 5 files changed, 37 insertions(+), 15 deletions(-) (limited to 'vespajlib') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index ad82dd6c3ac..dc17c657db9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -178,9 +178,7 @@ public abstract class IndexedTensor implements Tensor { @Override public abstract IndexedTensor withType(TensorType type); - public DimensionSizes dimensionSizes() { - return dimensionSizes; - } + public DimensionSizes dimensionSizes() { return dimensionSizes; } @Override public Map cells() { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 1ec4993bf57..f608aead347 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -152,7 +152,6 @@ public class MixedTensor implements Tensor { return index.denseSubspaceSize(); } - /** * Base 
class for building mixed tensors. */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index bcff4392c9a..5c47572c779 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -51,31 +51,41 @@ public class TypedBinaryFormat { } private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) { - if (tensor instanceof MixedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) { + boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(d -> d.isMapped()); + boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(d -> d.isIndexed()); + boolean isMixed = hasMappedDimensions && hasIndexedDimensions; + + // TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead + if (tensor instanceof MixedTensor && ! 
isMixed && hasIndexedDimensions) + isMixed = true; + + if (isMixed && tensor.type().valueType() == TensorType.Value.DOUBLE) { encodeFormatType(buffer, MIXED_BINARY_FORMAT_TYPE); return new MixedBinaryFormat(); } - if (tensor instanceof MixedTensor) { + else if (isMixed) { encodeFormatType(buffer, MIXED_BINARY_FORMAT_WITH_CELLTYPE); encodeValueType(buffer, tensor.type().valueType()); return new MixedBinaryFormat(tensor.type().valueType()); } - if (tensor instanceof IndexedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) { + else if (hasIndexedDimensions && tensor.type().valueType() == TensorType.Value.DOUBLE) { encodeFormatType(buffer, DENSE_BINARY_FORMAT_TYPE); return new DenseBinaryFormat(); } - if (tensor instanceof IndexedTensor) { + else if (hasIndexedDimensions) { encodeFormatType(buffer, DENSE_BINARY_FORMAT_WITH_CELLTYPE); encodeValueType(buffer, tensor.type().valueType()); return new DenseBinaryFormat(tensor.type().valueType()); } - if (tensor.type().valueType() == TensorType.Value.DOUBLE) { + else if (tensor.type().valueType() == TensorType.Value.DOUBLE) { encodeFormatType(buffer, SPARSE_BINARY_FORMAT_TYPE); return new SparseBinaryFormat(); } - encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE); - encodeValueType(buffer, tensor.type().valueType()); - return new SparseBinaryFormat(tensor.type().valueType()); + else { + encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE); + encodeValueType(buffer, tensor.type().valueType()); + return new SparseBinaryFormat(tensor.type().valueType()); + } } private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java index f002637847b..066a63b6d90 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java +++ 
b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java @@ -71,7 +71,8 @@ public class SerializationTestCase { serializedToABinaryRepresentation = true; } } - assertTrue("Tensor did not serialize to one of the given representations", serializedToABinaryRepresentation); + assertTrue("Tensor serialized to one of the given representations", + serializedToABinaryRepresentation); } } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 9074579094c..50b71024ddf 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -2,7 +2,9 @@ package com.yahoo.tensor.serialization; import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -30,6 +32,17 @@ public class SparseBinaryFormatTestCase { assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0,{y:1,x:0,z:6}:3.0}"); } + @Test + public void testSerializationFormatIsDecidedByTensorTypeNotImplementationType() { + Tensor sparse = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")) + .cell(TensorAddress.ofLabels("key1"), 9.1).build(); + Tensor sparseAsMixed = MixedTensor.Builder.of(TensorType.fromSpec("tensor(x{})")) + .cell(TensorAddress.ofLabels("key1"), 9.1).build(); + byte[] sparseEncoded = TypedBinaryFormat.encode(sparse); + byte[] sparseAsMixedEncoded = TypedBinaryFormat.encode(sparseAsMixed); + assertEquals(Arrays.toString(sparseEncoded), Arrays.toString(sparseAsMixedEncoded)); + } + @Test public void testSerializationToSeparateType() { try { @@ -55,7 +68,8 @@ public class SparseBinaryFormatTestCase { @Test public void 
requireThatFloatSerializationFormatDoNotChange() { - byte[] encodedTensor = new byte[] {5, // binary format type + byte[] encodedTensor = new byte[] { + 5, // binary format type 1, // float type 2, // num dimensions 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions @@ -63,7 +77,7 @@ public class SparseBinaryFormatTestCase { 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1 assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); } @Test -- cgit v1.2.3