summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-18 10:56:57 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-18 10:56:57 +0100
commitac728c6a77543ea618bee127221f950670e84eb8 (patch)
treed21c3cfd66ac2ec92aeb4189c0d5a7396c4bcea4 /vespajlib
parent5bd40b8cdb0c025e439483bd7f246b68fee0e478 (diff)
Simplify and test type check
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java32
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java19
5 files changed, 30 insertions, 32 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index a66caa8dd35..a4b1a02f95c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -207,8 +207,9 @@ public class IndexedTensor implements Tensor {
for (int i = 0; i < sizes.dimensions(); i++ ) {
Optional<Integer> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
- throw new IllegalArgumentException("Size of " + type.dimensions() + " is " + sizes.size(i) +
- " but cannot be larger than " + size.get());
+ throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
+ sizes.size(i) +
+ " but cannot be larger than " + size.get() + " in " + type);
}
return new BoundBuilder(type, sizes);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 3ff82ea774b..1c6d8170885 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -51,7 +51,11 @@ public class DenseBinaryFormat implements BinaryFormat {
DimensionSizes sizes;
if (optionalType.isPresent()) {
type = optionalType.get();
- sizes = decodeAndValidateDimensionSizes(type, buffer);
+ TensorType serializedType = decodeType(buffer);
+ if ( ! type.isAssignableTo(serializedType))
+ throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
+ " cannot be assigned to type " + type);
+ sizes = sizesFromType(serializedType);
}
else {
type = decodeType(buffer);
@@ -62,32 +66,6 @@ public class DenseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private DimensionSizes decodeAndValidateDimensionSizes(TensorType type, GrowableByteBuffer buffer) {
- int dimensionCount = buffer.getInt1_4Bytes();
- if (type.dimensions().size() != dimensionCount)
- throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount +
- " dimensions but type is " + type);
-
- DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount);
- for (int i = 0; i < dimensionCount; i++) {
- TensorType.Dimension expectedDimension = type.dimensions().get(i);
-
- String encodedName = buffer.getUtf8String();
- int encodedSize = buffer.getInt1_4Bytes();
-
- if ( ! expectedDimension.name().equals(encodedName))
- throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName +
- "' as dimension " + i + " but type is " + type);
-
- if (expectedDimension.size().isPresent() && expectedDimension.size().get() < encodedSize)
- throw new IllegalArgumentException("Type/instance mismatch: Instance has size " + encodedSize +
- " in " + expectedDimension + " in type " + type);
-
- builder.set(i, encodedSize);
- }
- return builder.build();
- }
-
private TensorType decodeType(GrowableByteBuffer buffer) {
int dimensionCount = buffer.getInt1_4Bytes();
TensorType.Builder builder = new TensorType.Builder();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 6419cb04497..4442b5521c3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -58,7 +58,9 @@ class SparseBinaryFormat implements BinaryFormat {
if (optionalType.isPresent()) {
type = optionalType.get();
TensorType serializedType = decodeType(buffer);
- serializedType.isAssignableTo(type);
+ if ( ! type.isAssignableTo(serializedType))
+ throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
+ " cannot be assigned to type " + type);
}
else {
type = decodeType(buffer);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 1ff8b3315b7..9cf48bd0fdf 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -40,7 +40,7 @@ public class DenseBinaryFormatTestCase {
fail("Expected exception");
}
catch (IllegalArgumentException expected) {
- assertEquals("Type/instance mismatch: Instance has size 2 in x[1] in type tensor(x[1],y[1])", expected.getMessage());
+ assertEquals("Type/instance mismatch: A tensor of type tensor(x[2],y[2]) cannot be assigned to type tensor(x[1],y[1])", expected.getMessage());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index 65f6b92f91e..79c4c7938c1 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization;
import com.google.common.collect.Sets;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import org.junit.Test;
import java.util.Arrays;
@@ -11,6 +12,7 @@ import java.util.Optional;
import java.util.Set;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
/**
* Tests for the sparse binary format.
@@ -31,6 +33,17 @@ public class SparseBinaryFormatTestCase {
}
@Test
+ public void testSerializationToSeparateType() {
+ try {
+ assertSerialization(Tensor.from("tensor(x{},y{}):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x{})"));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Type/instance mismatch: A tensor of type tensor(x{},y{}) cannot be assigned to type tensor(x{})", expected.getMessage());
+ }
+ }
+
+ @Test
public void requireThatSerializationFormatDoNotChange() {
byte[] encodedTensor = new byte[] {1, // binary format type
2, // num dimensions
@@ -47,8 +60,12 @@ public class SparseBinaryFormatTestCase {
}
private void assertSerialization(Tensor tensor) {
+ assertSerialization(tensor, tensor.type());
+ }
+
+ private void assertSerialization(Tensor tensor, TensorType expectedType) {
byte[] encodedTensor = TypedBinaryFormat.encode(tensor);
- Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(tensor.type()),
+ Tensor decodedTensor = TypedBinaryFormat.decode(Optional.of(expectedType),
GrowableByteBuffer.wrap(encodedTensor));
assertEquals(tensor, decodedTensor);
}