diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorType.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 46 |
1 files changed, 26 insertions, 20 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index c05c35d6df3..914d853aeca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -53,14 +53,17 @@ public class TensorType { return TensorTypeParser.fromSpec(specString); } + /** Returns the number of dimensions of this: dimensions().size() */ + public int rank() { return dimensions.size(); } + /** Returns an immutable list of the dimensions of this */ public List<Dimension> dimensions() { return dimensions; } - + /** Returns an immutable set of the names of the dimensions of this */ public Set<String> dimensionNames() { return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); } - + /** Returns the dimension with this name, or empty if not present */ public Optional<Dimension> dimension(String name) { return indexOfDimension(name).map(i -> dimensions.get(i)); @@ -74,7 +77,7 @@ public class TensorType { return Optional.empty(); } - /** + /** * Returns whether this type can be assigned to the given type, * i.e if the given type is a generalization of this type. */ @@ -128,9 +131,9 @@ public class TensorType { private final String name; - private Dimension(String name) { + private Dimension(String name) { Objects.requireNonNull(name, "A tensor name cannot be null"); - this.name = name; + this.name = name; } public final String name() { return name; } @@ -146,7 +149,7 @@ public class TensorType { /** Returns true if this is an indexed bound or unboun type */ public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; } - /** + /** * Returns the dimension resulting from combining two dimensions having the same name but possibly different * types. This works by degrading to the type making the fewer promises. * [N] + [M] = [min(N, M)] @@ -165,7 +168,7 @@ public class TensorType { IndexedBoundDimension otherIb = (IndexedBoundDimension)other.get(); return thisIb.size().get() < otherIb.size().get() ? thisIb : otherIb; } - + @Override public abstract String toString(); @@ -175,21 +178,21 @@ public class TensorType { if (other == null || getClass() != other.getClass()) return false; return name.equals(((Dimension)other).name); } - + @Override public int hashCode() { return name.hashCode(); } - + @Override public int compareTo(Dimension other) { return this.name.compareTo(other.name); } - + public static Dimension indexed(String name, int size) { return new IndexedBoundDimension(name, size); } - + } public static class IndexedBoundDimension extends TensorType.Dimension { @@ -289,9 +292,9 @@ public class TensorType { public Builder() { } - /** - * Creates a builder containing a combination of the dimensions of the given types - * + /** + * Creates a builder containing a combination of the dimensions of the given types + * * If the same dimension is indexed with different size restrictions the largest size will be used. * If it is size restricted in one argument but not the other it will not be size restricted. * If it is indexed in one and mapped in the other it will become mapped. @@ -325,9 +328,12 @@ public class TensorType { } } - /** + /** Returns the current number of dimensions in this */ + public int rank() { return dimensions.size(); } + + /** * Adds a new dimension to this - * + * * @throws IllegalArgumentException if the dimension is already present */ private Builder add(Dimension dimension) { @@ -346,7 +352,7 @@ public class TensorType { return this; } - /** + /** * Adds a bound indexed dimension to this * * @throws IllegalArgumentException if the dimension is already present @@ -355,7 +361,7 @@ public class TensorType { /** * Adds an unbound indexed dimension to this - * + * * @throws IllegalArgumentException if the dimension is already present */ public Builder indexed(String name) { @@ -375,7 +381,7 @@ public class TensorType { public Builder dimension(Dimension dimension) { return add(dimension); } - + /** Returns the given dimension, or empty if none is present */ public Optional<Dimension> getDimension(String dimension) { return Optional.ofNullable(dimensions.get(dimension)); @@ -393,7 +399,7 @@ public class TensorType { public TensorType build() { return new TensorType(dimensions.values()); } - + } } |