diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-18 10:56:57 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-18 10:56:57 +0100 |
commit | ac728c6a77543ea618bee127221f950670e84eb8 (patch) | |
tree | d21c3cfd66ac2ec92aeb4189c0d5a7396c4bcea4 | |
parent | 5bd40b8cdb0c025e439483bd7f246b68fee0e478 (diff) |
Simplify and test type check
5 files changed, 30 insertions, 32 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index a66caa8dd35..a4b1a02f95c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -207,8 +207,9 @@ public class IndexedTensor implements Tensor { for (int i = 0; i < sizes.dimensions(); i++ ) { Optional<Integer> size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) - throw new IllegalArgumentException("Size of " + type.dimensions() + " is " + sizes.size(i) + - " but cannot be larger than " + size.get()); + throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + + sizes.size(i) + + " but cannot be larger than " + size.get() + " in " + type); } return new BoundBuilder(type, sizes); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 3ff82ea774b..1c6d8170885 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -51,7 +51,11 @@ public class DenseBinaryFormat implements BinaryFormat { DimensionSizes sizes; if (optionalType.isPresent()) { type = optionalType.get(); - sizes = decodeAndValidateDimensionSizes(type, buffer); + TensorType serializedType = decodeType(buffer); + if ( ! type.isAssignableTo(serializedType)) + throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + + " cannot be assigned to type " + type); + sizes = sizesFromType(serializedType); } else { type = decodeType(buffer); @@ -62,32 +66,6 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } - private DimensionSizes decodeAndValidateDimensionSizes(TensorType type, GrowableByteBuffer buffer) { - int dimensionCount = buffer.getInt1_4Bytes(); - if (type.dimensions().size() != dimensionCount) - throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + - " dimensions but type is " + type); - - DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount); - for (int i = 0; i < dimensionCount; i++) { - TensorType.Dimension expectedDimension = type.dimensions().get(i); - - String encodedName = buffer.getUtf8String(); - int encodedSize = buffer.getInt1_4Bytes(); - - if ( ! expectedDimension.name().equals(encodedName)) - throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName + - "' as dimension " + i + " but type is " + type); - - if (expectedDimension.size().isPresent() && expectedDimension.size().get() < encodedSize) - throw new IllegalArgumentException("Type/instance mismatch: Instance has size " + encodedSize + - " in " + expectedDimension + " in type " + type); - - builder.set(i, encodedSize); - } - return builder.build(); - } - private TensorType decodeType(GrowableByteBuffer buffer) { int dimensionCount = buffer.getInt1_4Bytes(); TensorType.Builder builder = new TensorType.Builder(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 6419cb04497..4442b5521c3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -58,7 +58,9 @@ class SparseBinaryFormat implements BinaryFormat { if (optionalType.isPresent()) { type = optionalType.get(); TensorType serializedType = decodeType(buffer); - serializedType.isAssignableTo(type); + if ( ! type.isAssignableTo(serializedType)) + throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + + " cannot be assigned to type " + type); } else { type = decodeType(buffer); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 1ff8b3315b7..9cf48bd0fdf 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -40,7 +40,7 @@ public class DenseBinaryFormatTestCase { fail("Expected exception"); } catch (IllegalArgumentException expected) { - assertEquals("Type/instance mismatch: Instance has size 2 in x[1] in type tensor(x[1],y[1])", expected.getMessage()); + assertEquals("Type/instance mismatch: A tensor of type tensor(x[2],y[2]) cannot be assigned to type tensor(x[1],y[1])", expected.getMessage()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 65f6b92f91e..79c4c7938c1 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization; import com.google.common.collect.Sets; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import org.junit.Test; import java.util.Arrays; @@ -11,6 +12,7 @@ import java.util.Optional; import java.util.Set; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * Tests for the sparse binary format. @@ -31,6 +33,17 @@ public class SparseBinaryFormatTestCase { } @Test + public void testSerializationToSeparateType() { + try { + assertSerialization(Tensor.from("tensor(x{},y{}):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x{})")); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Type/instance mismatch: A tensor of type tensor(x{},y{}) cannot be assigned to type tensor(x{})", expected.getMessage()); + } + } + + @Test public void requireThatSerializationFormatDoNotChange() { byte[] encodedTensor = new byte[] {1, // binary format type 2, // num dimensions @@ -47,8 +60,12 @@ public class SparseBinaryFormatTestCase { } private void assertSerialization(Tensor tensor) { + assertSerialization(tensor, tensor.type()); + } + + private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(tensor.type()), + Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); assertEquals(tensor, decodedTensor); } |