summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorType.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java18
1 files changed, 14 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 36693280183..57d276f278e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -88,6 +88,7 @@ public class TensorType {
private final List<Dimension> dimensions;
private final TensorType mappedSubtype;
+ private final TensorType indexedSubtype;
public TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
@@ -95,12 +96,18 @@ public class TensorType {
Collections.sort(dimensionList);
this.dimensions = List.copyOf(dimensionList);
- if (dimensionList.stream().allMatch(d -> d.isIndexed()))
+ if (dimensionList.stream().allMatch(d -> d.isIndexed())) {
mappedSubtype = empty;
- else if (dimensionList.stream().noneMatch(d -> d.isIndexed()))
+ indexedSubtype = this;
+ }
+ else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) {
mappedSubtype = this;
- else
- mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).toList());
+ indexedSubtype = empty;
+ }
+ else {
+ mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> !d.isIndexed()).toList());
+ indexedSubtype = new TensorType(valueType, dimensions.stream().filter(Dimension::isIndexed).toList());
+ }
}
static public Value combinedValueType(TensorType ... types) {
@@ -135,6 +142,9 @@ public class TensorType {
/** The type representing the mapped subset of dimensions of this. */
public TensorType mappedSubtype() { return mappedSubtype; }
+ /** The type representing the indexed subset of dimensions of this. */
+ public TensorType indexedSubtype() { return indexedSubtype; }
+
/** Returns the number of dimensions of this: dimensions().size() */
public int rank() { return dimensions.size(); }