summary | refs | log | tree | commit | diff | stats
path: root/vespajlib
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@gmail.com> 2024-01-21 13:18:18 +0100
committer GitHub <noreply@github.com> 2024-01-21 13:18:18 +0100
commit 43c05215e666f47c15d9d73aadc80a9735b1b426 (patch)
tree f6111f79d9f7648c5963f3138e6b9797756036a6 /vespajlib
parent 0db9b671b13857c77bdb08b026fcf1413dd5b3ae (diff)
parent 87dd4177a06f31a97156c8851eddfd96668f8b60 (diff)
Merge pull request #29993 from vespa-engine/balder/precompute-type-related-information-once
- Extract dimension names in a set to avoid recomputing it in dimensi…
Diffstat (limited to 'vespajlib')
-rw-r--r-- vespajlib/abi-spec.json 3
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java 15
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java 2
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/Tensor.java 10
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/TensorType.java 29
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java 2
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java 8
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java 4
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java 13
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java 4
10 files changed, 57 insertions, 33 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 174ce6332db..4a65b00a6a4 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1470,6 +1470,9 @@
],
"methods" : [
"public void <init>(com.yahoo.tensor.TensorType$Value, java.util.Collection)",
+ "public boolean hasIndexedDimensions()",
+ "public boolean hasMappedDimensions()",
+ "public boolean hasOnlyIndexedBoundDimensions()",
"public static varargs com.yahoo.tensor.TensorType$Value combinedValueType(com.yahoo.tensor.TensorType[])",
"public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
"public com.yahoo.tensor.TensorType$Value valueType()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 93cdc3f630f..5d384e0329b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -164,9 +164,10 @@ public abstract class IndexedTensor implements Tensor {
long valueIndex = 0;
for (int i = 0; i < address.size(); i++) {
- if (address.numericLabel(i) >= sizes.size(i))
+ long label = address.numericLabel(i);
+ if (label >= sizes.size(i))
throw new IllegalArgumentException(address + " is not within the bounds of " + type);
- valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i);
+ valueIndex += sizes.productOfDimensionsAfter(i) * label;
}
return valueIndex;
}
@@ -281,7 +282,7 @@ public abstract class IndexedTensor implements Tensor {
}
public static Builder of(TensorType type) {
- if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ if (type.hasOnlyIndexedBoundDimensions())
return of(type, BoundBuilder.dimensionSizesOf(type));
else
return new UnboundBuilder(type);
@@ -295,7 +296,7 @@ public abstract class IndexedTensor implements Tensor {
* must not be further mutated by the caller
*/
public static Builder of(TensorType type, float[] values) {
- if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ if (type.hasOnlyIndexedBoundDimensions())
return of(type, BoundBuilder.dimensionSizesOf(type), values);
else
return new UnboundBuilder(type);
@@ -309,7 +310,7 @@ public abstract class IndexedTensor implements Tensor {
* must not be further mutated by the caller
*/
public static Builder of(TensorType type, double[] values) {
- if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ if (type.hasOnlyIndexedBoundDimensions())
return of(type, BoundBuilder.dimensionSizesOf(type), values);
else
return new UnboundBuilder(type);
@@ -615,11 +616,11 @@ public abstract class IndexedTensor implements Tensor {
private final class ValueIterator implements Iterator<Double> {
- private long count = 0;
+ private int count = 0;
@Override
public boolean hasNext() {
- return count < size();
+ return count < sizeAsInt();
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index e6315dbef80..30dd1d6dc29 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -267,7 +267,7 @@ public class MixedTensor implements Tensor {
* a temporary structure while finding dimension bounds.
*/
public static Builder of(TensorType type) {
- if (type.dimensions().stream().anyMatch(d -> d instanceof TensorType.IndexedUnboundDimension)) {
+ if (type.hasIndexedUnboundDimensions()) {
return new UnboundBuilder(type);
} else {
return new BoundBuilder(type);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index cc8e1602adb..cff17fdfd7c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -113,7 +113,7 @@ public interface Tensor {
* @throws IllegalStateException if this does not have zero dimensions and one value
*/
default double asDouble() {
- if (type().dimensions().size() > 0)
+ if (!type().dimensions().isEmpty())
throw new IllegalStateException("Require a dimensionless tensor but has " + type());
if (size() == 0) return Double.NaN;
return valueIterator().next();
@@ -553,8 +553,8 @@ public interface Tensor {
/** Creates a suitable builder for the given type */
static Builder of(TensorType type) {
- boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
- boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
+ boolean containsIndexed = type.hasIndexedDimensions();
+ boolean containsMapped = type.hasMappedDimensions();
if (containsIndexed && containsMapped)
return MixedTensor.Builder.of(type);
if (containsMapped)
@@ -565,8 +565,8 @@ public interface Tensor {
/** Creates a suitable builder for the given type */
static Builder of(TensorType type, DimensionSizes dimensionSizes) {
- boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
- boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
+ boolean containsIndexed = type.hasIndexedDimensions();
+ boolean containsMapped = type.hasMappedDimensions();
if (containsIndexed && containsMapped)
return MixedTensor.Builder.of(type);
if (containsMapped)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index b30b664a5f7..82968476296 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.google.common.collect.ImmutableSet;
import com.yahoo.text.Ascii7BitMatcher;
import java.util.ArrayList;
@@ -86,16 +87,20 @@ public class TensorType {
/** Sorted list of the dimensions of this */
private final List<Dimension> dimensions;
+ private final Set<String> dimensionNames;
private final TensorType mappedSubtype;
private final TensorType indexedSubtype;
+ private final int indexedUnBoundCount;
// only used to initialize the "empty" instance
private TensorType() {
this.valueType = Value.DOUBLE;
this.dimensions = List.of();
+ this.dimensionNames = Set.of();
this.mappedSubtype = this;
this.indexedSubtype = this;
+ indexedUnBoundCount = 0;
}
public TensorType(Value valueType, Collection<Dimension> dimensions) {
@@ -103,12 +108,25 @@ public class TensorType {
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
this.dimensions = List.copyOf(dimensionList);
+ ImmutableSet.Builder<String> namesbuilder = new ImmutableSet.Builder<>();
+ int indexedBoundCount = 0, indexedUnBoundCount = 0, mappedCount = 0;
+ for (Dimension dimension : dimensionList) {
+ namesbuilder.add(dimension.name());
+ Dimension.Type type = dimension.type();
+ switch (type) {
+ case indexedUnbound -> indexedUnBoundCount++;
+ case indexedBound -> indexedBoundCount++;
+ case mapped -> mappedCount++;
+ }
+ }
+ this.indexedUnBoundCount = indexedUnBoundCount;
+ dimensionNames = namesbuilder.build();
- if (dimensionList.stream().allMatch(Dimension::isIndexed)) {
+ if (mappedCount == 0) {
mappedSubtype = empty;
indexedSubtype = this;
}
- else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) {
+ else if ((indexedBoundCount + indexedUnBoundCount) == 0) {
mappedSubtype = this;
indexedSubtype = empty;
}
@@ -118,6 +136,11 @@ public class TensorType {
}
}
+ public boolean hasIndexedDimensions() { return indexedSubtype != empty; }
+ public boolean hasMappedDimensions() { return mappedSubtype != empty; }
+ public boolean hasOnlyIndexedBoundDimensions() { return !hasMappedDimensions() && ! hasIndexedUnboundDimensions(); }
+ boolean hasIndexedUnboundDimensions() { return indexedUnBoundCount > 0; }
+
static public Value combinedValueType(TensorType ... types) {
List<Value> valueTypes = new ArrayList<>();
for (TensorType type : types) {
@@ -161,7 +184,7 @@ public class TensorType {
/** Returns an immutable set of the names of the dimensions of this */
public Set<String> dimensionNames() {
- return dimensions.stream().map(Dimension::name).collect(Collectors.toSet());
+ return dimensionNames;
}
/** Returns the dimension with this name, or empty if not present */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 8d8fe2b356f..866b710b72e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -134,7 +134,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return tensor;
}
else { // extend tensor with this dimension
- if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
+ if (tensor.type().hasMappedDimensions())
throw new IllegalArgumentException("Concat requires an indexed tensor, " +
"but got a tensor with type " + tensor.type());
Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 3b6e03186a3..b595b1a40cd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -40,7 +40,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
@Override
public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if (arguments.size() != 0)
+ if (!arguments.isEmpty())
throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size());
return this;
}
@@ -79,7 +79,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() {
var result = new ArrayList<TensorFunction<NAMETYPE>>();
for (var fun : cells.values()) {
- fun.asTensorFunction().ifPresent(tf -> result.add(tf));
+ fun.asTensorFunction().ifPresent(result::add);
}
return result;
}
@@ -133,7 +133,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) {
super(type);
- if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))
+ if ( ! type.hasOnlyIndexedBoundDimensions())
throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " +
"only indexed, bound dimensions, but this has " + type);
this.cells = List.copyOf(cells);
@@ -142,7 +142,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() {
var result = new ArrayList<TensorFunction<NAMETYPE>>();
for (var fun : cells) {
- fun.asTensorFunction().ifPresent(tf -> result.add(tf));
+ fun.asTensorFunction().ifPresent(result::add);
}
return result;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index aece782d296..2d5a0518747 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -92,11 +92,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
return false;
if ( ! (a instanceof IndexedTensor))
return false;
- if ( ! (a.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)))
+ if ( ! (a.type().hasOnlyIndexedBoundDimensions()))
return false;
if ( ! (b instanceof IndexedTensor))
return false;
- if ( ! (b.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)))
+ if ( ! (b.type().hasOnlyIndexedBoundDimensions()))
return false;
TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 444ce02b14a..771b74633d9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -21,10 +21,8 @@ import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Slice;
import java.util.ArrayList;
-import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
-import java.util.Set;
/**
* Writes tensors on the JSON format used in Vespa tensor document fields:
@@ -60,8 +58,7 @@ public class JsonFormat {
// Short form for a single mapped dimension
Cursor parent = root == null ? slime.setObject() : root.setObject("cells");
encodeSingleDimensionCells((MappedTensor) tensor, parent);
- } else if (tensor instanceof MixedTensor &&
- tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped)) {
+ } else if (tensor instanceof MixedTensor && tensor.type().hasMappedDimensions()) {
// Short form for a mixed tensor
boolean singleMapped = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1;
Cursor parent = root == null ? ( singleMapped ? slime.setObject() : slime.setArray() )
@@ -204,7 +201,7 @@ public class JsonFormat {
if (root.field("cells").valid() && ! primitiveContent(root.field("cells")))
decodeCells(root.field("cells"), builder);
- else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed()))
+ else if (root.field("values").valid() && ! builder.type().hasMappedDimensions())
decodeValuesAtTop(root.field("values"), builder);
else if (root.field("blocks").valid())
decodeBlocks(root.field("blocks"), builder);
@@ -298,14 +295,14 @@ public class JsonFormat {
/** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */
private static void decodeDirectValue(Inspector root, Tensor.Builder builder) {
- boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
- boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
+ boolean hasIndexed = builder.type().hasIndexedDimensions();
+ boolean hasMapped = builder.type().hasMappedDimensions();
if (isArrayOfObjects(root))
decodeCells(root, builder);
else if ( ! hasMapped)
decodeValuesAtTop(root, builder);
- else if (hasMapped && hasIndexed)
+ else if (hasIndexed)
decodeBlocks(root, builder);
else
decodeCells(root, builder);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index d4b18c73f11..0a5c713f3e2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -55,8 +55,8 @@ public class TypedBinaryFormat {
}
private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) {
- boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
- boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
+ boolean hasMappedDimensions = tensor.type().hasMappedDimensions();
+ boolean hasIndexedDimensions = tensor.type().hasIndexedDimensions();
boolean isMixed = hasMappedDimensions && hasIndexedDimensions;
// TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead