aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-18 10:38:08 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-18 10:38:08 +0100
commit5bd40b8cdb0c025e439483bd7f246b68fee0e478 (patch)
treeca8ea288b55dd801aed7b95bc6e061ca6b8eee8f /vespajlib/src
parentbad625e3565d83a72436224ed5ccbc2649ab89db (diff)
Simplify and test type check
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java18
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java21
2 files changed, 22 insertions, 17 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 6b0443c9bfe..6419cb04497 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -57,7 +57,8 @@ class SparseBinaryFormat implements BinaryFormat {
TensorType type;
if (optionalType.isPresent()) {
type = optionalType.get();
- consumeAndValidateDimensions(optionalType.get(), buffer);
+ TensorType serializedType = decodeType(buffer);
+ serializedType.isAssignableTo(type);
}
else {
type = decodeType(buffer);
@@ -67,21 +68,6 @@ class SparseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private void consumeAndValidateDimensions(TensorType type, GrowableByteBuffer buffer) {
- int dimensionCount = buffer.getInt1_4Bytes();
- if (type.dimensions().size() != dimensionCount)
- throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
- " dimensions but type is " + type);
-
- for (int i = 0; i < dimensionCount; ++i) {
- TensorType.Dimension expectedDimension = type.dimensions().get(i);
- String encodedName = buffer.getUtf8String();
- if ( ! expectedDimension.name().equals(encodedName))
- throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName +
- "' as dimension " + i + " but type is " + type);
- }
- }
-
private TensorType decodeType(GrowableByteBuffer buffer) {
int numDimensions = buffer.getInt1_4Bytes();
TensorType.Builder builder = new TensorType.Builder();
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 8a3d2879201..1ff8b3315b7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization;
import com.google.common.collect.Sets;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import org.junit.Ignore;
import org.junit.Test;
@@ -12,6 +13,7 @@ import java.util.Optional;
import java.util.Set;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
/**
* Tests for the dense binary format.
@@ -28,6 +30,19 @@ public class DenseBinaryFormatTestCase {
assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}");
}
+
+ @Test
+ public void testSerializationToSeparateType() {
+ assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[],y[])"));
+ assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[2],y[2])"));
+ try {
+ assertSerialization(Tensor.from("tensor(x[2],y[2]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[1],y[1])"));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Type/instance mismatch: Instance has size 2 in x[1] in type tensor(x[1],y[1])", expected.getMessage());
+ }
+ }
@Test
public void requireThatSerializationFormatDoNotChange() {
@@ -47,8 +62,12 @@ public class DenseBinaryFormatTestCase {
}
private void assertSerialization(Tensor tensor) {
+ assertSerialization(tensor, tensor.type());
+ }
+
+ private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
- Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(tensor.type()), GrowableByteBuffer.wrap(encodedTensor));
+ Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor));
assertEquals(tensor, decodedTensor);
}