// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
* The reduce tensor operation returns a tensor produced from the argument tensor where some dimensions
* are collapsed to a single value using an aggregator function.
*
* @author bratseth
*/
public class Reduce extends PrimitiveTensorFunction {
public enum Aggregator { avg, count, max, median, min, prod, sum ; }
private final TensorFunction argument;
private final List dimensions;
private final Aggregator aggregator;
/** Creates a reduce function reducing all dimensions */
public Reduce(TensorFunction argument, Aggregator aggregator) {
this(argument, aggregator, List.of());
}
/** Creates a reduce function reducing a single dimension */
public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
this(argument, aggregator, List.of(dimension));
}
/**
* Creates a reduce function.
*
* @param argument the tensor to reduce
* @param aggregator the aggregator function to use
* @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
* producing a dimensionless tensor (a scalar).
* @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor
*/
public Reduce(TensorFunction argument, Aggregator aggregator, List dimensions) {
this.argument = Objects.requireNonNull(argument, "The argument tensor cannot be null");
this.aggregator = Objects.requireNonNull(aggregator, "The aggregator cannot be null");
this.dimensions = List.copyOf(dimensions);
}
public static TensorType outputType(TensorType inputType, List reduceDimensions) {
return TypeResolver.reduce(inputType, reduceDimensions);
}
public TensorFunction argument() { return argument; }
Aggregator aggregator() { return aggregator; }
List dimensions() { return dimensions; }
@Override
public List> arguments() { return List.of(argument); }
@Override
public TensorFunction withArguments(List> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
return new Reduce<>(arguments.get(0), aggregator, dimensions);
}
@Override
public PrimitiveTensorFunction toPrimitive() {
return new Reduce<>(argument.toPrimitive(), aggregator, dimensions);
}
@Override
public String toString(ToStringContext context) {
return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
}
static String commaSeparated(List list) {
StringBuilder b = new StringBuilder();
for (String element : list)
b.append(", ").append(element);
return b.toString();
}
@Override
public TensorType type(TypeContext context) {
return outputType(argument.type(context), dimensions);
}
@Override
public Tensor evaluate(EvaluationContext context) {
return evaluate(this.argument.evaluate(context), dimensions, aggregator);
}
@Override
public int hashCode() {
return Objects.hash("reduce", argument, dimensions, aggregator);
}
static Tensor evaluate(Tensor argument, List dimensions, Aggregator aggregator) {
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
dimensions + ": Not all those dimensions are present in this tensor");
// Special case: Reduce all
if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size())
if (argument.isEmpty())
return Tensor.from(0.0);
else if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor)
return reduceIndexedVector((IndexedTensor)argument, aggregator);
else
return reduceAllGeneral(argument, aggregator);
TensorType reducedType = outputType(argument.type(), dimensions);
// Reduce cells
Map aggregatingCells = new HashMap<>();
for (Iterator i = argument.cellIterator(); i.hasNext(); ) {
Map.Entry cell = i.next();
TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType, dimensions);
aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
aggregatingCells.get(reducedAddress).aggregate(cell.getValue());
}
Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
for (Map.Entry aggregatingCell : aggregatingCells.entrySet())
reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
return reducedBuilder.build();
}
private static TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType, List dimensions) {
Set indexesToRemove = new HashSet<>(dimensions.size()*2);
for (String dimensionToRemove : dimensions)
indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get());
String[] reducedLabels = new String[reducedType.dimensions().size()];
int reducedLabelIndex = 0;
for (int i = 0; i < address.size(); i++)
if ( ! indexesToRemove.contains(i))
reducedLabels[reducedLabelIndex++] = address.label(i);
return TensorAddress.of(reducedLabels);
}
private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Iterator i = argument.valueIterator(); i.hasNext(); )
valueAggregator.aggregate(i.next());
return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
}
private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (int i = 0; i < argument.dimensionSizes().size(0); i++)
valueAggregator.aggregate(argument.get(i));
return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
}
static abstract class ValueAggregator {
static ValueAggregator ofType(Aggregator aggregator) {
return switch (aggregator) {
case avg -> new AvgAggregator();
case count -> new CountAggregator();
case max -> new MaxAggregator();
case median -> new MedianAggregator();
case min -> new MinAggregator();
case prod -> new ProdAggregator();
case sum -> new SumAggregator();
default -> throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
};
}
/** Add a new value to those aggregated by this */
public abstract void aggregate(double value);
/** Returns the value aggregated by this */
public abstract double aggregatedValue();
/** Resets the aggregator */
public abstract void reset();
/** Returns a hash of this aggregator which only depends on its identity */
@Override
public abstract int hashCode();
}
private static class AvgAggregator extends ValueAggregator {
private int valueCount = 0;
private double valueSum = 0.0;
@Override
public void aggregate(double value) {
valueCount++;
valueSum+= value;
}
@Override
public double aggregatedValue() {
return valueSum / valueCount;
}
@Override
public void reset() {
valueCount = 0;
valueSum = 0.0;
}
@Override
public int hashCode() { return "avgAggregator".hashCode(); }
}
private static class CountAggregator extends ValueAggregator {
private int valueCount = 0;
@Override
public void aggregate(double value) {
valueCount++;
}
@Override
public double aggregatedValue() {
return valueCount;
}
@Override
public void reset() {
valueCount = 0;
}
@Override
public int hashCode() { return "countAggregator".hashCode(); }
}
private static class MaxAggregator extends ValueAggregator {
private double maxValue = Double.NEGATIVE_INFINITY;
@Override
public void aggregate(double value) {
if (value > maxValue)
maxValue = value;
}
@Override
public double aggregatedValue() {
return maxValue;
}
@Override
public void reset() {
maxValue = Double.NEGATIVE_INFINITY;
}
@Override
public int hashCode() { return "maxAggregator".hashCode(); }
}
private static class MedianAggregator extends ValueAggregator {
/** If any NaN is added, the result should be NaN */
private boolean isNaN = false;
private List values = new ArrayList<>();
@Override
public void aggregate(double value) {
if ( Double.isNaN(value))
isNaN = true;
if ( ! isNaN)
values.add(value);
}
@Override
public double aggregatedValue() {
if (isNaN || values.isEmpty()) return Double.NaN;
Collections.sort(values);
if (values.size() % 2 == 0) // even: average the two middle values
return ( values.get(values.size() / 2 - 1) + values.get(values.size() / 2) ) / 2;
else
return values.get((values.size() - 1)/ 2);
}
@Override
public void reset() {
isNaN = false;
values = new ArrayList<>();
}
@Override
public int hashCode() { return "medianAggregator".hashCode(); }
}
private static class MinAggregator extends ValueAggregator {
private double minValue = Double.POSITIVE_INFINITY;
@Override
public void aggregate(double value) {
if (value < minValue)
minValue = value;
}
@Override
public double aggregatedValue() {
return minValue;
}
@Override
public void reset() {
minValue = Double.POSITIVE_INFINITY;
}
@Override
public int hashCode() { return "minAggregator".hashCode(); }
}
private static class ProdAggregator extends ValueAggregator {
private double valueProd = 1.0;
@Override
public void aggregate(double value) {
valueProd *= value;
}
@Override
public double aggregatedValue() {
return valueProd;
}
@Override
public void reset() {
valueProd = 1.0;
}
@Override
public int hashCode() { return "prodAggregator".hashCode(); }
}
private static class SumAggregator extends ValueAggregator {
private double valueSum = 0.0;
@Override
public void aggregate(double value) {
valueSum += value;
}
@Override
public double aggregatedValue() {
return valueSum;
}
@Override
public void reset() {
valueSum = 0.0;
}
@Override
public int hashCode() { return "sumAggregator".hashCode(); }
}
}