summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java26
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java3
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java18
5 files changed, 37 insertions, 15 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index ad82dd6c3ac..dc17c657db9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -178,9 +178,7 @@ public abstract class IndexedTensor implements Tensor {
@Override
public abstract IndexedTensor withType(TensorType type);
- public DimensionSizes dimensionSizes() {
- return dimensionSizes;
- }
+ public DimensionSizes dimensionSizes() { return dimensionSizes; }
@Override
public Map<TensorAddress, Double> cells() {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 1ec4993bf57..f608aead347 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -152,7 +152,6 @@ public class MixedTensor implements Tensor {
return index.denseSubspaceSize();
}
-
/**
* Base class for building mixed tensors.
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index bcff4392c9a..5c47572c779 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -51,31 +51,41 @@ public class TypedBinaryFormat {
}
private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) {
- if (tensor instanceof MixedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) {
+ boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(d -> d.isMapped());
+ boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(d -> d.isIndexed());
+ boolean isMixed = hasMappedDimensions && hasIndexedDimensions;
+
+ // TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead
+ if (tensor instanceof MixedTensor && ! isMixed && hasIndexedDimensions)
+ isMixed = true;
+
+ if (isMixed && tensor.type().valueType() == TensorType.Value.DOUBLE) {
encodeFormatType(buffer, MIXED_BINARY_FORMAT_TYPE);
return new MixedBinaryFormat();
}
- if (tensor instanceof MixedTensor) {
+ else if (isMixed) {
encodeFormatType(buffer, MIXED_BINARY_FORMAT_WITH_CELLTYPE);
encodeValueType(buffer, tensor.type().valueType());
return new MixedBinaryFormat(tensor.type().valueType());
}
- if (tensor instanceof IndexedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) {
+ else if (hasIndexedDimensions && tensor.type().valueType() == TensorType.Value.DOUBLE) {
encodeFormatType(buffer, DENSE_BINARY_FORMAT_TYPE);
return new DenseBinaryFormat();
}
- if (tensor instanceof IndexedTensor) {
+ else if (hasIndexedDimensions) {
encodeFormatType(buffer, DENSE_BINARY_FORMAT_WITH_CELLTYPE);
encodeValueType(buffer, tensor.type().valueType());
return new DenseBinaryFormat(tensor.type().valueType());
}
- if (tensor.type().valueType() == TensorType.Value.DOUBLE) {
+ else if (tensor.type().valueType() == TensorType.Value.DOUBLE) {
encodeFormatType(buffer, SPARSE_BINARY_FORMAT_TYPE);
return new SparseBinaryFormat();
}
- encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE);
- encodeValueType(buffer, tensor.type().valueType());
- return new SparseBinaryFormat(tensor.type().valueType());
+ else {
+ encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE);
+ encodeValueType(buffer, tensor.type().valueType());
+ return new SparseBinaryFormat(tensor.type().valueType());
+ }
}
private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
index f002637847b..066a63b6d90 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SerializationTestCase.java
@@ -71,7 +71,8 @@ public class SerializationTestCase {
serializedToABinaryRepresentation = true;
}
}
- assertTrue("Tensor did not serialize to one of the given representations", serializedToABinaryRepresentation);
+ assertTrue("Tensor serialized to one of the given representations",
+ serializedToABinaryRepresentation);
}
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index 9074579094c..50b71024ddf 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -2,7 +2,9 @@
package com.yahoo.tensor.serialization;
import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -31,6 +33,17 @@ public class SparseBinaryFormatTestCase {
}
@Test
+ public void testSerializationFormatIsDecidedByTensorTypeNotImplementationType() {
+ Tensor sparse = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})"))
+ .cell(TensorAddress.ofLabels("key1"), 9.1).build();
+ Tensor sparseAsMixed = MixedTensor.Builder.of(TensorType.fromSpec("tensor(x{})"))
+ .cell(TensorAddress.ofLabels("key1"), 9.1).build();
+ byte[] sparseEncoded = TypedBinaryFormat.encode(sparse);
+ byte[] sparseAsMixedEncoded = TypedBinaryFormat.encode(sparseAsMixed);
+ assertEquals(Arrays.toString(sparseEncoded), Arrays.toString(sparseAsMixedEncoded));
+ }
+
+ @Test
public void testSerializationToSeparateType() {
try {
assertSerialization(Tensor.from("tensor(x{},y{}):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x{})"));
@@ -55,7 +68,8 @@ public class SparseBinaryFormatTestCase {
@Test
public void requireThatFloatSerializationFormatDoNotChange() {
- byte[] encodedTensor = new byte[] {5, // binary format type
+ byte[] encodedTensor = new byte[] {
+ 5, // binary format type
1, // float type
2, // num dimensions
2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
@@ -63,7 +77,7 @@ public class SparseBinaryFormatTestCase {
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1
assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
}
@Test