aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-21 11:20:59 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-21 11:38:53 +0100
commit351e0f62fe4f9b1b1015c0c1289a3f519fa9f868 (patch)
tree28c7e3c8c3b00ddc52cb5a80ecaccffc9dab1c24 /vespajlib
parentf481110c07a4759b98452deb35fe719ce6c49be4 (diff)
- Extract dimension names in a set to avoid recomputing it in dimensionNames() mathod.
- Also extract some information about what kind of dimensions a tensor has construcion time. This avoids streamin through all diemsion later on, and keeps this local in the TensorType.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java15
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java29
4 files changed, 40 insertions, 16 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 93cdc3f630f..5d384e0329b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -164,9 +164,10 @@ public abstract class IndexedTensor implements Tensor {
long valueIndex = 0;
for (int i = 0; i < address.size(); i++) {
- if (address.numericLabel(i) >= sizes.size(i))
+ long label = address.numericLabel(i);
+ if (label >= sizes.size(i))
throw new IllegalArgumentException(address + " is not within the bounds of " + type);
- valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i);
+ valueIndex += sizes.productOfDimensionsAfter(i) * label;
}
return valueIndex;
}
@@ -281,7 +282,7 @@ public abstract class IndexedTensor implements Tensor {
}
public static Builder of(TensorType type) {
- if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ if (type.hasOnlyIndexedBoundDimensions())
return of(type, BoundBuilder.dimensionSizesOf(type));
else
return new UnboundBuilder(type);
@@ -295,7 +296,7 @@ public abstract class IndexedTensor implements Tensor {
* must not be further mutated by the caller
*/
public static Builder of(TensorType type, float[] values) {
- if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ if (type.hasOnlyIndexedBoundDimensions())
return of(type, BoundBuilder.dimensionSizesOf(type), values);
else
return new UnboundBuilder(type);
@@ -309,7 +310,7 @@ public abstract class IndexedTensor implements Tensor {
* must not be further mutated by the caller
*/
public static Builder of(TensorType type, double[] values) {
- if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ if (type.hasOnlyIndexedBoundDimensions())
return of(type, BoundBuilder.dimensionSizesOf(type), values);
else
return new UnboundBuilder(type);
@@ -615,11 +616,11 @@ public abstract class IndexedTensor implements Tensor {
private final class ValueIterator implements Iterator<Double> {
- private long count = 0;
+ private int count = 0;
@Override
public boolean hasNext() {
- return count < size();
+ return count < sizeAsInt();
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 74b338fb503..d755d6d6337 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -266,7 +266,7 @@ public class MixedTensor implements Tensor {
* a temporary structure while finding dimension bounds.
*/
public static Builder of(TensorType type) {
- if (type.dimensions().stream().anyMatch(d -> d instanceof TensorType.IndexedUnboundDimension)) {
+ if (type.hasIndexedUnboundDimensions()) {
return new UnboundBuilder(type);
} else {
return new BoundBuilder(type);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index cc8e1602adb..cff17fdfd7c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -113,7 +113,7 @@ public interface Tensor {
* @throws IllegalStateException if this does not have zero dimensions and one value
*/
default double asDouble() {
- if (type().dimensions().size() > 0)
+ if (!type().dimensions().isEmpty())
throw new IllegalStateException("Require a dimensionless tensor but has " + type());
if (size() == 0) return Double.NaN;
return valueIterator().next();
@@ -553,8 +553,8 @@ public interface Tensor {
/** Creates a suitable builder for the given type */
static Builder of(TensorType type) {
- boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
- boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
+ boolean containsIndexed = type.hasIndexedDimensions();
+ boolean containsMapped = type.hasMappedDimensions();
if (containsIndexed && containsMapped)
return MixedTensor.Builder.of(type);
if (containsMapped)
@@ -565,8 +565,8 @@ public interface Tensor {
/** Creates a suitable builder for the given type */
static Builder of(TensorType type, DimensionSizes dimensionSizes) {
- boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
- boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
+ boolean containsIndexed = type.hasIndexedDimensions();
+ boolean containsMapped = type.hasMappedDimensions();
if (containsIndexed && containsMapped)
return MixedTensor.Builder.of(type);
if (containsMapped)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index b30b664a5f7..b7346348672 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.google.common.collect.ImmutableSet;
import com.yahoo.text.Ascii7BitMatcher;
import java.util.ArrayList;
@@ -86,16 +87,20 @@ public class TensorType {
/** Sorted list of the dimensions of this */
private final List<Dimension> dimensions;
+ private final Set<String> dimensionNames;
private final TensorType mappedSubtype;
private final TensorType indexedSubtype;
+ private final int indexedUnBoundCount;
// only used to initialize the "empty" instance
private TensorType() {
this.valueType = Value.DOUBLE;
this.dimensions = List.of();
+ this.dimensionNames = Set.of();
this.mappedSubtype = this;
this.indexedSubtype = this;
+ indexedUnBoundCount = 0;
}
public TensorType(Value valueType, Collection<Dimension> dimensions) {
@@ -103,12 +108,25 @@ public class TensorType {
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
this.dimensions = List.copyOf(dimensionList);
+ ImmutableSet.Builder<String> namesbuilder = new ImmutableSet.Builder<>();
+ int indexedBoundCount = 0, indexedUnBoundCount = 0, mappedCount = 0;
+ for (Dimension dimension : dimensionList) {
+ namesbuilder.add(dimension.name());
+ Dimension.Type type = dimension.type();
+ switch (type) {
+ case indexedUnbound -> indexedUnBoundCount++;
+ case indexedBound -> indexedBoundCount++;
+ case mapped -> mappedCount++;
+ }
+ }
+ this.indexedUnBoundCount = indexedUnBoundCount;
+ dimensionNames = namesbuilder.build();
- if (dimensionList.stream().allMatch(Dimension::isIndexed)) {
+ if (mappedCount == 0) {
mappedSubtype = empty;
indexedSubtype = this;
}
- else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) {
+ else if ((indexedBoundCount + indexedUnBoundCount) == 0) {
mappedSubtype = this;
indexedSubtype = empty;
}
@@ -118,6 +136,11 @@ public class TensorType {
}
}
+ boolean hasIndexedDimensions() { return indexedSubtype != empty; }
+ boolean hasMappedDimensions() { return mappedSubtype != empty; }
+ boolean hasOnlyIndexedBoundDimensions() { return !hasMappedDimensions() && ! hasIndexedUnboundDimensions(); }
+ boolean hasIndexedUnboundDimensions() { return indexedUnBoundCount > 0; }
+
static public Value combinedValueType(TensorType ... types) {
List<Value> valueTypes = new ArrayList<>();
for (TensorType type : types) {
@@ -161,7 +184,7 @@ public class TensorType {
/** Returns an immutable set of the names of the dimensions of this */
public Set<String> dimensionNames() {
- return dimensions.stream().map(Dimension::name).collect(Collectors.toSet());
+ return dimensionNames;
}
/** Returns the dimension with this name, or empty if not present */