diff options
Diffstat (limited to 'vespajlib')
5 files changed, 37 insertions, 15 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index ad82dd6c3ac..dc17c657db9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -178,9 +178,7 @@ public abstract class IndexedTensor implements Tensor { @Override public abstract IndexedTensor withType(TensorType type); - public DimensionSizes dimensionSizes() { - return dimensionSizes; - } + public DimensionSizes dimensionSizes() { return dimensionSizes; } @Override public Map<TensorAddress, Double> cells() { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 1ec4993bf57..f608aead347 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -152,7 +152,6 @@ public class MixedTensor implements Tensor { return index.denseSubspaceSize(); } - /** * Base class for building mixed tensors. */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index bcff4392c9a..5c47572c779 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -51,31 +51,41 @@ public class TypedBinaryFormat { } private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) { - if (tensor instanceof MixedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) { + boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(d -> d.isMapped()); + boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(d -> d.isIndexed()); + boolean isMixed = hasMappedDimensions && hasIndexedDimensions; + + // TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead + if (tensor instanceof MixedTensor && ! isMixed && hasIndexedDimensions) + isMixed = true; + + if (isMixed && tensor.type().valueType() == TensorType.Value.DOUBLE) { encodeFormatType(buffer, MIXED_BINARY_FORMAT_TYPE); return new MixedBinaryFormat(); } - if (tensor instanceof MixedTensor) { + else if (isMixed) { encodeFormatType(buffer, MIXED_BINARY_FORMAT_WITH_CELLTYPE); encodeValueType(buffer, tensor.type().valueType()); return new MixedBinaryFormat(tensor.type().valueType()); } - if (tensor instanceof IndexedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) { + else if (hasIndexedDimensions && tensor.type().valueType() == TensorType.Value.DOUBLE) { encodeFormatType(buffer, DENSE_BINARY_FORMAT_TYPE); return new DenseBinaryFormat(); } - if (tensor instanceof IndexedTensor) { + else if (hasIndexedDimensions) { encodeFormatType(buffer, DENSE_BINARY_FORMAT_WITH_CELLTYPE); encodeValueType(buffer, tensor.type().valueType()); return new DenseBinaryFormat(tensor.type().valueType()); } - if (tensor.type().valueType() == TensorType.Value.DOUBLE) { + else if (tensor.type().valueType() == TensorType.Value.DOUBLE) { encodeFormatType(buffer, SPARSE_BINARY_FORMAT_TYPE); return new SparseBinaryFormat(); } - encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE); - encodeValueType(buffer, tensor.type().valueType()); - return new SparseBinaryFormat(tensor.type().valueType()); + else { + encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE); + encodeValueType(buffer, tensor.type().valueType()); + return new SparseBinaryFormat(tensor.type().valueType()); + } } private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java index f002637847b..066a63b6d90 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java @@ -71,7 +71,8 @@ public class SerializationTestCase { serializedToABinaryRepresentation = true; } } - assertTrue("Tensor did not serialize to one of the given representations", serializedToABinaryRepresentation); + assertTrue("Tensor serialized to one of the given representations", + serializedToABinaryRepresentation); } } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 9074579094c..50b71024ddf 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -2,7 +2,9 @@ package com.yahoo.tensor.serialization; import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -31,6 +33,17 @@ public class SparseBinaryFormatTestCase { } @Test + public void testSerializationFormatIsDecidedByTensorTypeNotImplementationType() { + Tensor sparse = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")) + .cell(TensorAddress.ofLabels("key1"), 9.1).build(); + Tensor sparseAsMixed = MixedTensor.Builder.of(TensorType.fromSpec("tensor(x{})")) + .cell(TensorAddress.ofLabels("key1"), 9.1).build(); + byte[] sparseEncoded = TypedBinaryFormat.encode(sparse); + byte[] sparseAsMixedEncoded = TypedBinaryFormat.encode(sparseAsMixed); + assertEquals(Arrays.toString(sparseEncoded), Arrays.toString(sparseAsMixedEncoded)); + } + + @Test public void testSerializationToSeparateType() { try { assertSerialization(Tensor.from("tensor(x{},y{}):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x{})")); @@ -55,7 +68,8 @@ public class SparseBinaryFormatTestCase { @Test public void requireThatFloatSerializationFormatDoNotChange() { - byte[] encodedTensor = new byte[] {5, // binary format type + byte[] encodedTensor = new byte[] { + 5, // binary format type 1, // float type 2, // num dimensions 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions @@ -63,7 +77,7 @@ public class SparseBinaryFormatTestCase { 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1 assertEquals(Arrays.toString(encodedTensor), - Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); + Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}")))); } @Test |