summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-31 13:52:41 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-31 14:22:21 +0100
commit43eed85aae7dfb7f515fe300f18457b8115aba69 (patch)
tree5185b211a5c4fad4919f882c21e86c5fd0924f7f /vespajlib
parent83c3d3716f34de515c9d03e90c34f2085b777268 (diff)
- Make separate path for reduction of indexed tensors.
- This brough down reduce time from 60s to 3.7s or a factor of 16x for the splade embedder if using reduce. This is ontop of the 1.5x improvement with https://github.com/vespa-engine/vespa/pull/30112 - A reduce based splade embedder now uses 8.0s vs 7.0s for a custom rolled out version, translated to 14% overhead. Please enter the commit message for your changes. Lines starting
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java70
2 files changed, 74 insertions, 6 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
index 4379d50520c..cda3be47ddb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
@@ -22,7 +22,7 @@ public final class DirectIndexedAddress {
directIndex = 0;
}
- static DirectIndexedAddress of(DimensionSizes sizes) {
+ public static DirectIndexedAddress of(DimensionSizes sizes) {
return new DirectIndexedAddress(sizes);
}
@@ -39,6 +39,14 @@ public final class DirectIndexedAddress {
/** Retrieve the index that can be used for direct lookup in an indexed tensor. */
public long getDirectIndex() { return directIndex; }
+ public long [] getIndexes() {
+ long[] asLong = new long[indexes.length];
+ for (int i=0; i < indexes.length; i++) {
+ asLong[i] = indexes[i];
+ }
+ return asLong;
+ }
+
/** returns the stride to be used for the given dimension */
public long getStride(int dimension) {
return sizes.productOfDimensionsAfter(dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index c394274032e..17cacb8f009 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -1,6 +1,8 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.DimensionSizes;
+import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -10,7 +12,6 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.Convert;
-import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
import java.util.Collections;
@@ -114,19 +115,78 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) {
- if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
+ if (!dimensions.isEmpty() && !argument.type().dimensionNames().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
- dimensions + ": Not all those dimensions are present in this tensor");
+ dimensions + ": Not all those dimensions are present in this tensor");
// Special case: Reduce all
- if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size())
+ if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) {
if (argument.isEmpty())
return Tensor.from(0.0);
else if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor)
- return reduceIndexedVector((IndexedTensor)argument, aggregator);
+ return reduceIndexedVector((IndexedTensor) argument, aggregator);
else
return reduceAllGeneral(argument, aggregator);
+ }
+ if (argument instanceof IndexedTensor indexedTensor) {
+ return reduceIndexedTensor(indexedTensor, dimensions, aggregator);
+ } else {
+ return reduceGeneral(argument, dimensions, aggregator);
+ }
+ }
+
+ private static void reduce(IndexedTensor argument, ValueAggregator aggregator, DirectIndexedAddress address, int[] reduce, int reduceIndex) {
+ int currentIndex = reduce[reduceIndex];
+ int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex));
+ if (reduceIndex + 1 < reduce.length) {
+ int nextDimension = reduceIndex + 1;
+ for (int i = 0; i < dimSize; i++) {
+ address.setIndex(currentIndex, i);
+ reduce(argument, aggregator, address, reduce, nextDimension);
+ }
+ } else {
+ address.setIndex(currentIndex, 0);
+ long increment = address.getStride(currentIndex);
+ long directIndex = address.getDirectIndex();
+ for (int i = 0; i < dimSize; i++) {
+ aggregator.aggregate(argument.get(directIndex + i * increment));
+ }
+ }
+ }
+
+ private static void reduce(IndexedTensor.Builder builder, DirectIndexedAddress destAddress, IndexedTensor argument, Aggregator aggregator, DirectIndexedAddress address, int[] toKeep, int keepIndex, int[] toReduce) {
+ if (keepIndex < toKeep.length) {
+ int currentIndex = toKeep[keepIndex];
+ int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex));
+
+ int nextKeep = keepIndex + 1;
+ for (int i = 0; i < dimSize; i++) {
+ address.setIndex(currentIndex, i);
+ destAddress.setIndex(keepIndex, i);
+ reduce(builder, destAddress, argument, aggregator, address, toKeep, nextKeep, toReduce);
+ }
+ } else {
+ ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
+ reduce(argument, valueAggregator, address, toReduce, 0);
+ builder.cell(valueAggregator.aggregatedValue(), destAddress.getIndexes());
+ }
+
+ }
+
+ private static Tensor reduceIndexedTensor(IndexedTensor argument, List<String> dimensions, Aggregator aggregator) {
+ TensorType reducedType = outputType(argument.type(), dimensions);
+ var reducedBuilder = IndexedTensor.Builder.of(reducedType);
+ DirectIndexedAddress reducedAddress = DirectIndexedAddress.of(DimensionSizes.of(reducedType));
+ int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions);
+ int[] indexesToReduce = new int[dimensions.size()];
+ for (int i = 0; i < dimensions.size(); i++) {
+ indexesToReduce[i] = argument.type().indexOfDimension(dimensions.get(i)).get();
+ }
+ reduce(reducedBuilder, reducedAddress, argument, aggregator, argument.directAddress(), indexesToKeep, 0, indexesToReduce);
+ return reducedBuilder.build();
+ }
+ private static Tensor reduceGeneral(Tensor argument, List<String> dimensions, Aggregator aggregator) {
TensorType reducedType = outputType(argument.type(), dimensions);
// Reduce cells