diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-18 10:38:08 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-18 10:38:08 +0100 |
commit | 5bd40b8cdb0c025e439483bd7f246b68fee0e478 (patch) | |
tree | ca8ea288b55dd801aed7b95bc6e061ca6b8eee8f /vespajlib/src | |
parent | bad625e3565d83a72436224ed5ccbc2649ab89db (diff) |
Simplify and test type check
Diffstat (limited to 'vespajlib/src')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java | 18 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java | 21 |
2 files changed, 22 insertions, 17 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 6b0443c9bfe..6419cb04497 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -57,7 +57,8 @@ class SparseBinaryFormat implements BinaryFormat { TensorType type; if (optionalType.isPresent()) { type = optionalType.get(); - consumeAndValidateDimensions(optionalType.get(), buffer); + TensorType serializedType = decodeType(buffer); + serializedType.isAssignableTo(type); } else { type = decodeType(buffer); @@ -67,21 +68,6 @@ class SparseBinaryFormat implements BinaryFormat { return builder.build(); } - private void consumeAndValidateDimensions(TensorType type, GrowableByteBuffer buffer) { - int dimensionCount = buffer.getInt1_4Bytes(); - if (type.dimensions().size() != dimensionCount) - throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + - " dimensions but type is " + type); - - for (int i = 0; i < dimensionCount; ++i) { - TensorType.Dimension expectedDimension = type.dimensions().get(i); - String encodedName = buffer.getUtf8String(); - if ( ! expectedDimension.name().equals(encodedName)) - throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName + - "' as dimension " + i + " but type is " + type); - } - } - private TensorType decodeType(GrowableByteBuffer buffer) { int numDimensions = buffer.getInt1_4Bytes(); TensorType.Builder builder = new TensorType.Builder(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 8a3d2879201..1ff8b3315b7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization; import com.google.common.collect.Sets; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import org.junit.Ignore; import org.junit.Test; @@ -12,6 +13,7 @@ import java.util.Optional; import java.util.Set; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * Tests for the dense binary format. @@ -28,6 +30,19 @@ public class DenseBinaryFormatTestCase { assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); assertSerialization("tensor(x[1],y[2],z[3]):{{y:0,x:0,z:0}:2.0}"); } + + @Test + public void testSerializationToSeparateType() { + assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[],y[])")); + assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[2],y[2])")); + try { + assertSerialization(Tensor.from("tensor(x[2],y[2]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[1],y[1])")); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Type/instance mismatch: Instance has size 2 in x[1] in type tensor(x[1],y[1])", expected.getMessage()); + } + } @Test public void requireThatSerializationFormatDoNotChange() { @@ -47,8 +62,12 @@ public class DenseBinaryFormatTestCase { } private void assertSerialization(Tensor tensor) { + assertSerialization(tensor, tensor.type()); + } + + private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(tensor.type()), GrowableByteBuffer.wrap(encodedTensor)); + Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); assertEquals(tensor, decodedTensor); } |