diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-31 13:52:41 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-31 14:22:21 +0100 |
commit | 43eed85aae7dfb7f515fe300f18457b8115aba69 (patch) | |
tree | 5185b211a5c4fad4919f882c21e86c5fd0924f7f /vespajlib | |
parent | 83c3d3716f34de515c9d03e90c34f2085b777268 (diff) |
- Make separate path for reduction of indexed tensors.
- This brought reduce time down from 60s to 3.7s, a factor of 16x, for the splade embedder when using reduce.
This is on top of the 1.5x improvement from https://github.com/vespa-engine/vespa/pull/30112
- A reduce-based splade embedder now takes 8.0s vs 7.0s for a custom hand-rolled version, which translates to a 14% overhead.
Please enter the commit message for your changes. Lines starting
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java | 10 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 70 |
2 files changed, 74 insertions, 6 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java index 4379d50520c..cda3be47ddb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java @@ -22,7 +22,7 @@ public final class DirectIndexedAddress { directIndex = 0; } - static DirectIndexedAddress of(DimensionSizes sizes) { + public static DirectIndexedAddress of(DimensionSizes sizes) { return new DirectIndexedAddress(sizes); } @@ -39,6 +39,14 @@ public final class DirectIndexedAddress { /** Retrieve the index that can be used for direct lookup in an indexed tensor. */ public long getDirectIndex() { return directIndex; } + public long [] getIndexes() { + long[] asLong = new long[indexes.length]; + for (int i=0; i < indexes.length; i++) { + asLong[i] = indexes[i]; + } + return asLong; + } + /** returns the stride to be used for the given dimension */ public long getStride(int dimension) { return sizes.productOfDimensionsAfter(dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index c394274032e..17cacb8f009 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -1,6 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.tensor.functions; +import com.yahoo.tensor.DimensionSizes; +import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -10,7 +12,6 @@ import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.impl.Convert; -import com.yahoo.tensor.impl.TensorAddressAny; import java.util.ArrayList; import java.util.Collections; @@ -114,19 +115,78 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) { - if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) + if (!dimensions.isEmpty() && !argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + - dimensions + ": Not all those dimensions are present in this tensor"); + dimensions + ": Not all those dimensions are present in this tensor"); // Special case: Reduce all - if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) + if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) { if (argument.isEmpty()) return Tensor.from(0.0); else if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor) - return reduceIndexedVector((IndexedTensor)argument, aggregator); + return reduceIndexedVector((IndexedTensor) argument, aggregator); else return reduceAllGeneral(argument, aggregator); + } + if (argument instanceof IndexedTensor indexedTensor) { + return reduceIndexedTensor(indexedTensor, dimensions, aggregator); + } else { + return reduceGeneral(argument, dimensions, aggregator); + } + } + + private static void reduce(IndexedTensor argument, ValueAggregator aggregator, DirectIndexedAddress address, 
int[] reduce, int reduceIndex) { + int currentIndex = reduce[reduceIndex]; + int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex)); + if (reduceIndex + 1 < reduce.length) { + int nextDimension = reduceIndex + 1; + for (int i = 0; i < dimSize; i++) { + address.setIndex(currentIndex, i); + reduce(argument, aggregator, address, reduce, nextDimension); + } + } else { + address.setIndex(currentIndex, 0); + long increment = address.getStride(currentIndex); + long directIndex = address.getDirectIndex(); + for (int i = 0; i < dimSize; i++) { + aggregator.aggregate(argument.get(directIndex + i * increment)); + } + } + } + + private static void reduce(IndexedTensor.Builder builder, DirectIndexedAddress destAddress, IndexedTensor argument, Aggregator aggregator, DirectIndexedAddress address, int[] toKeep, int keepIndex, int[] toReduce) { + if (keepIndex < toKeep.length) { + int currentIndex = toKeep[keepIndex]; + int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex)); + + int nextKeep = keepIndex + 1; + for (int i = 0; i < dimSize; i++) { + address.setIndex(currentIndex, i); + destAddress.setIndex(keepIndex, i); + reduce(builder, destAddress, argument, aggregator, address, toKeep, nextKeep, toReduce); + } + } else { + ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); + reduce(argument, valueAggregator, address, toReduce, 0); + builder.cell(valueAggregator.aggregatedValue(), destAddress.getIndexes()); + } + + } + + private static Tensor reduceIndexedTensor(IndexedTensor argument, List<String> dimensions, Aggregator aggregator) { + TensorType reducedType = outputType(argument.type(), dimensions); + var reducedBuilder = IndexedTensor.Builder.of(reducedType); + DirectIndexedAddress reducedAddress = DirectIndexedAddress.of(DimensionSizes.of(reducedType)); + int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions); + int[] indexesToReduce = new int[dimensions.size()]; + for (int i = 0; i < 
dimensions.size(); i++) { + indexesToReduce[i] = argument.type().indexOfDimension(dimensions.get(i)).get(); + } + reduce(reducedBuilder, reducedAddress, argument, aggregator, argument.directAddress(), indexesToKeep, 0, indexesToReduce); + return reducedBuilder.build(); + } + private static Tensor reduceGeneral(Tensor argument, List<String> dimensions, Aggregator aggregator) { TensorType reducedType = outputType(argument.type(), dimensions); // Reduce cells |