diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorType.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 36693280183..57d276f278e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -88,6 +88,7 @@ public class TensorType { private final List<Dimension> dimensions; private final TensorType mappedSubtype; + private final TensorType indexedSubtype; public TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; @@ -95,12 +96,18 @@ public class TensorType { Collections.sort(dimensionList); this.dimensions = List.copyOf(dimensionList); - if (dimensionList.stream().allMatch(d -> d.isIndexed())) + if (dimensionList.stream().allMatch(d -> d.isIndexed())) { mappedSubtype = empty; - else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) + indexedSubtype = this; + } + else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) { mappedSubtype = this; - else - mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).toList()); + indexedSubtype = empty; + } + else { + mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> !d.isIndexed()).toList()); + indexedSubtype = new TensorType(valueType, dimensions.stream().filter(Dimension::isIndexed).toList()); + } } static public Value combinedValueType(TensorType ... types) { @@ -135,6 +142,9 @@ public class TensorType { /** The type representing the mapped subset of dimensions of this. */ public TensorType mappedSubtype() { return mappedSubtype; } + /** The type representing the indexed subset of dimensions of this. */ + public TensorType indexedSubtype() { return indexedSubtype; } + /** Returns the number of dimensions of this: dimensions().size() */ public int rank() { return dimensions.size(); } |