diff options
Diffstat (limited to 'vespajlib/src/main/java/com')
3 files changed, 11 insertions, 30 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index a66caa8dd35..a4b1a02f95c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -207,8 +207,9 @@ public class IndexedTensor implements Tensor { for (int i = 0; i < sizes.dimensions(); i++ ) { Optional<Integer> size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) - throw new IllegalArgumentException("Size of " + type.dimensions() + " is " + sizes.size(i) + - " but cannot be larger than " + size.get()); + throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + + sizes.size(i) + + " but cannot be larger than " + size.get() + " in " + type); } return new BoundBuilder(type, sizes); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index 3ff82ea774b..1c6d8170885 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -51,7 +51,11 @@ public class DenseBinaryFormat implements BinaryFormat { DimensionSizes sizes; if (optionalType.isPresent()) { type = optionalType.get(); - sizes = decodeAndValidateDimensionSizes(type, buffer); + TensorType serializedType = decodeType(buffer); + if ( ! type.isAssignableTo(serializedType)) + throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + + " cannot be assigned to type " + type); + sizes = sizesFromType(serializedType); } else { type = decodeType(buffer); @@ -62,32 +66,6 @@ public class DenseBinaryFormat implements BinaryFormat { return builder.build(); } - private DimensionSizes decodeAndValidateDimensionSizes(TensorType type, GrowableByteBuffer buffer) { - int dimensionCount = buffer.getInt1_4Bytes(); - if (type.dimensions().size() != dimensionCount) - throw new IllegalArgumentException("Type/instance mismatch: Instance has " + dimensionCount + - " dimensions but type is " + type); - - DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensionCount); - for (int i = 0; i < dimensionCount; i++) { - TensorType.Dimension expectedDimension = type.dimensions().get(i); - - String encodedName = buffer.getUtf8String(); - int encodedSize = buffer.getInt1_4Bytes(); - - if ( ! expectedDimension.name().equals(encodedName)) - throw new IllegalArgumentException("Type/instance mismatch: Instance has '" + encodedName + - "' as dimension " + i + " but type is " + type); - - if (expectedDimension.size().isPresent() && expectedDimension.size().get() < encodedSize) - throw new IllegalArgumentException("Type/instance mismatch: Instance has size " + encodedSize + - " in " + expectedDimension + " in type " + type); - - builder.set(i, encodedSize); - } - return builder.build(); - } - private TensorType decodeType(GrowableByteBuffer buffer) { int dimensionCount = buffer.getInt1_4Bytes(); TensorType.Builder builder = new TensorType.Builder(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 6419cb04497..4442b5521c3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -58,7 +58,9 @@ class SparseBinaryFormat implements BinaryFormat { if (optionalType.isPresent()) { type = optionalType.get(); TensorType serializedType = decodeType(buffer); - serializedType.isAssignableTo(type); + if ( ! type.isAssignableTo(serializedType)) + throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + + " cannot be assigned to type " + type); } else { type = decodeType(buffer); |