From ac728c6a77543ea618bee127221f950670e84eb8 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 18 Jan 2017 10:56:57 +0100 Subject: Simplify and test type check --- .../main/java/com/yahoo/tensor/IndexedTensor.java | 5 ++-- .../tensor/serialization/DenseBinaryFormat.java | 32 ++++------------------ .../tensor/serialization/SparseBinaryFormat.java | 4 ++- .../serialization/DenseBinaryFormatTestCase.java | 2 +- .../serialization/SparseBinaryFormatTestCase.java | 19 ++++++++++++- 5 files changed, 30 insertions(+), 32 deletions(-) (limited to 'vespajlib') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index a66caa8dd35..a4b1a02f95c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -207,8 +207,9 @@ public class IndexedTensor implements Tensor { for (int i = 0; i < sizes.dimensions(); i++ ) { Optional size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) - throw new IllegalArgumentException("Size of " + type.dimensions() + " is " + sizes.size(i) + - " but cannot be larger than " + size.get()); + throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + + sizes.size(i) + + " but cannot be larger than " + size.get() + " in " + type); } return new BoundBuilder(type, sizes); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 3ff82ea774b..1c6d8170885 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -51,7 +51,11 @@ public class DenseBinaryFormat implements BinaryFormat { DimensionSizes sizes; if (optionalType.isPresent()) { type = optionalType.get(); - sizes = decodeAndValidateDimensionSizes(type, buffer); + TensorType serializedType = decodeType(buffer); + if ( ! type.isAssignableTo(serializedType)) + throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + + " cannot be assigned to type " + type); + sizes = sizesFromType(serializedType); } else { type = decodeType(buffer); @@ -62,32 +66,6 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } - private DimensionSizes decodeAndValidateDimensionSizes(TensorType type, GrowableByteBuffer buffer) { - int dimensionCount = buffer.getInt1_4Bytes(); - if (type.dimensions().size() != dimensionCount) - throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + - " dimensions but type is " + type); - - DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount); - for (int i = 0; i < dimensionCount; i++) { - TensorType.Dimension expectedDimension = type.dimensions().get(i); - - String encodedName = buffer.getUtf8String(); - int encodedSize = buffer.getInt1_4Bytes(); - - if ( ! expectedDimension.name().equals(encodedName)) - throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName + - "' as dimension " + i + " but type is " + type); - - if (expectedDimension.size().isPresent() && expectedDimension.size().get() < encodedSize) - throw new IllegalArgumentException("Type/instance mismatch: Instance has size " + encodedSize + - " in " + expectedDimension + " in type " + type); - - builder.set(i, encodedSize); - } - return builder.build(); - } - private TensorType decodeType(GrowableByteBuffer buffer) { int dimensionCount = buffer.getInt1_4Bytes(); TensorType.Builder builder = new TensorType.Builder(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 6419cb04497..4442b5521c3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -58,7 +58,9 @@ class SparseBinaryFormat implements BinaryFormat { if (optionalType.isPresent()) { type = optionalType.get(); TensorType serializedType = decodeType(buffer); - serializedType.isAssignableTo(type); + if ( ! type.isAssignableTo(serializedType)) + throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + + " cannot be assigned to type " + type); } else { type = decodeType(buffer); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 1ff8b3315b7..9cf48bd0fdf 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -40,7 +40,7 @@ public class DenseBinaryFormatTestCase { fail("Expected exception"); } catch (IllegalArgumentException expected) { - assertEquals("Type/instance mismatch: Instance has size 2 in x[1] in type tensor(x[1],y[1])", expected.getMessage()); + assertEquals("Type/instance mismatch: A tensor of type tensor(x[2],y[2]) cannot be assigned to type tensor(x[1],y[1])", expected.getMessage()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index 65f6b92f91e..79c4c7938c1 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization; import com.google.common.collect.Sets; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import org.junit.Test; import java.util.Arrays; @@ -11,6 +12,7 @@ import java.util.Optional; import java.util.Set; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * Tests for the sparse binary format. @@ -30,6 +32,17 @@ public class SparseBinaryFormatTestCase { assertSerialization("tensor(x{},y{},z{}):{{y:0,x:0,z:3}:2.0,{y:1,x:0,z:6}:3.0}"); } + @Test + public void testSerializationToSeparateType() { + try { + assertSerialization(Tensor.from("tensor(x{},y{}):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x{})")); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Type/instance mismatch: A tensor of type tensor(x{},y{}) cannot be assigned to type tensor(x{})", expected.getMessage()); + } + } + @Test public void requireThatSerializationFormatDoNotChange() { byte[] encodedTensor = new byte[] {1, // binary format type @@ -47,8 +60,12 @@ public class SparseBinaryFormatTestCase { } private void assertSerialization(Tensor tensor) { + assertSerialization(tensor, tensor.type()); + } + + private void assertSerialization(Tensor tensor, TensorType expectedType) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(tensor.type()), + Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType), GrowableByteBuffer.wrap(encodedTensor)); assertEquals(tensor, decodedTensor); } -- cgit v1.2.3