summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java70
1 files changed, 65 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index c394274032e..17cacb8f009 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -1,6 +1,8 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.DimensionSizes;
+import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -10,7 +12,6 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.Convert;
-import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
import java.util.Collections;
@@ -114,19 +115,78 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) {
- if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
+ if (!dimensions.isEmpty() && !argument.type().dimensionNames().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
- dimensions + ": Not all those dimensions are present in this tensor");
+ dimensions + ": Not all those dimensions are present in this tensor");
// Special case: Reduce all
- if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size())
+ if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) {
if (argument.isEmpty())
return Tensor.from(0.0);
else if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor)
- return reduceIndexedVector((IndexedTensor)argument, aggregator);
+ return reduceIndexedVector((IndexedTensor) argument, aggregator);
else
return reduceAllGeneral(argument, aggregator);
+ }
+ if (argument instanceof IndexedTensor indexedTensor) {
+ return reduceIndexedTensor(indexedTensor, dimensions, aggregator);
+ } else {
+ return reduceGeneral(argument, dimensions, aggregator);
+ }
+ }
+
+ private static void reduce(IndexedTensor argument, ValueAggregator aggregator, DirectIndexedAddress address, int[] reduce, int reduceIndex) {
+ int currentIndex = reduce[reduceIndex];
+ int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex));
+ if (reduceIndex + 1 < reduce.length) {
+ int nextDimension = reduceIndex + 1;
+ for (int i = 0; i < dimSize; i++) {
+ address.setIndex(currentIndex, i);
+ reduce(argument, aggregator, address, reduce, nextDimension);
+ }
+ } else {
+ address.setIndex(currentIndex, 0);
+ long increment = address.getStride(currentIndex);
+ long directIndex = address.getDirectIndex();
+ for (int i = 0; i < dimSize; i++) {
+ aggregator.aggregate(argument.get(directIndex + i * increment));
+ }
+ }
+ }
+
+ private static void reduce(IndexedTensor.Builder builder, DirectIndexedAddress destAddress, IndexedTensor argument, Aggregator aggregator, DirectIndexedAddress address, int[] toKeep, int keepIndex, int[] toReduce) {
+ if (keepIndex < toKeep.length) {
+ int currentIndex = toKeep[keepIndex];
+ int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex));
+
+ int nextKeep = keepIndex + 1;
+ for (int i = 0; i < dimSize; i++) {
+ address.setIndex(currentIndex, i);
+ destAddress.setIndex(keepIndex, i);
+ reduce(builder, destAddress, argument, aggregator, address, toKeep, nextKeep, toReduce);
+ }
+ } else {
+ ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
+ reduce(argument, valueAggregator, address, toReduce, 0);
+ builder.cell(valueAggregator.aggregatedValue(), destAddress.getIndexes());
+ }
+
+ }
+
+ private static Tensor reduceIndexedTensor(IndexedTensor argument, List<String> dimensions, Aggregator aggregator) {
+ TensorType reducedType = outputType(argument.type(), dimensions);
+ var reducedBuilder = IndexedTensor.Builder.of(reducedType);
+ DirectIndexedAddress reducedAddress = DirectIndexedAddress.of(DimensionSizes.of(reducedType));
+ int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions);
+ int[] indexesToReduce = new int[dimensions.size()];
+ for (int i = 0; i < dimensions.size(); i++) {
+ indexesToReduce[i] = argument.type().indexOfDimension(dimensions.get(i)).get();
+ }
+ reduce(reducedBuilder, reducedAddress, argument, aggregator, argument.directAddress(), indexesToKeep, 0, indexesToReduce);
+ return reducedBuilder.build();
+ }
+ private static Tensor reduceGeneral(Tensor argument, List<String> dimensions, Aggregator aggregator) {
TensorType reducedType = outputType(argument.type(), dimensions);
// Reduce cells