diff options
Diffstat (limited to 'vespajlib/src')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java | 25 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java | 2 |
2 files changed, 17 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index f96fd65e15c..aef3b90af56 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -104,10 +104,12 @@ class TensorParser { if (type.isEmpty()) throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); - if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1) + if (type.get().dimensions().stream().filter(d -> d.isMapped()).count() > 1) throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " + "but got " + type.get()); - + if (! MixedValueParser.findMappedDimension(type.get()).isPresent()) + throw new IllegalArgumentException("No suitable dimension in " + type.get() + " for parsing a tensor on " + + "the mixed form: Should have one mapped dimension"); try { valueString = valueString.trim(); @@ -426,7 +428,7 @@ class TensorParser { } private void parse() { - TensorType.Dimension mappedDimension = findMappedDimension(); + TensorType.Dimension mappedDimension = findMappedDimension().get(); TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension)); if (dimensionOrder != null) dimensionOrder.remove(mappedDimension.name()); @@ -448,13 +450,16 @@ class TensorParser { } } - private TensorType.Dimension findMappedDimension() { - Optional<TensorType.Dimension> mappedDimension = builder.type().dimensions().stream().filter(TensorType.Dimension::isMapped).findAny(); - if (mappedDimension.isPresent()) return mappedDimension.get(); - if (builder.type().rank() == 1 && builder.type().dimensions().get(0).size().isEmpty()) - return builder.type().dimensions().get(0); - throw new IllegalStateException("No suitable dimension in " + builder.type() + - " for parsing as a mixed tensor. This is a bug."); + private Optional<TensorType.Dimension> findMappedDimension() { + return findMappedDimension(builder.type()); + } + + static Optional<TensorType.Dimension> findMappedDimension(TensorType type) { + Optional<TensorType.Dimension> mappedDimension = type.dimensions().stream().filter(TensorType.Dimension::isMapped).findAny(); + if (mappedDimension.isPresent()) return Optional.of(mappedDimension.get()); + if (type.rank() == 1 && type.dimensions().get(0).size().isEmpty()) + return Optional.of(type.dimensions().get(0)); + return Optional.empty(); } private void parseDenseSubspace(TensorAddress mappedAddress, List<String> denseDimensionOrder) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java index 5a049eeca04..7bc0556987b 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -261,6 +261,8 @@ public class TensorParserTestCase { "tensor(x[3]):[1, 2]"); assertIllegal("At value position 8: Expected a ']' but got ','", "tensor(x[3]):[1, 2, 3, 4]"); + assertIllegal("No suitable dimension in tensor(x[3]) for parsing a tensor on the mixed form: Should have one mapped dimension", + "tensor(x[3]):{1:[1,2,3], 2:[2,3,4], 3:[3,4,5]}"); } private void assertIllegal(String message, String tensor) { |