diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 9aa764a0b36..becec1a4493 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -99,7 +99,7 @@ class TensorParser { if (type.isEmpty()) throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); - if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() != 1) + if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1) throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " + "but got " + type.get()); @@ -310,7 +310,7 @@ class TensorParser { } private void parse() { - TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); + TensorType.Dimension mappedDimension = findMappedDimension(); TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension)); if (dimensionOrder != null) dimensionOrder.remove(mappedDimension.name()); @@ -332,6 +332,15 @@ class TensorParser { } } + private TensorType.Dimension findMappedDimension() { + Optional<TensorType.Dimension> mappedDimension = builder.type().dimensions().stream().filter(d -> d.isMapped()).findAny(); + if (mappedDimension.isPresent()) return mappedDimension.get(); + if (builder.type().rank() == 1 && builder.type().dimensions().get(0).size().isEmpty()) + return builder.type().dimensions().get(0); + throw new IllegalStateException("No suitable dimension in " + builder.type() + + " for parsing as a mixed tensor. This is a bug."); + } + private void parseDenseSubspace(TensorAddress mappedAddress, List<String> denseDimensionOrder) { DenseValueParser denseParser = new DenseValueParser(string.substring(position), denseDimensionOrder, |