// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; /** * The reduce tensor operation returns a tensor produced from the argument tensor where some dimensions * are collapsed to a single value using an aggregator function. * * @author bratseth */ public class Reduce extends PrimitiveTensorFunction { public enum Aggregator { avg, count, max, median, min, prod, sum ; } private final TensorFunction argument; private final List dimensions; private final Aggregator aggregator; /** Creates a reduce function reducing all dimensions */ public Reduce(TensorFunction argument, Aggregator aggregator) { this(argument, aggregator, List.of()); } /** Creates a reduce function reducing a single dimension */ public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) { this(argument, aggregator, List.of(dimension)); } /** * Creates a reduce function. * * @param argument the tensor to reduce * @param aggregator the aggregator function to use * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced, * producing a dimensionless tensor (a scalar). * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor */ public Reduce(TensorFunction argument, Aggregator aggregator, List dimensions) { this.argument = Objects.requireNonNull(argument, "The argument tensor cannot be null"); this.aggregator = Objects.requireNonNull(aggregator, "The aggregator cannot be null"); this.dimensions = List.copyOf(dimensions); } public static TensorType outputType(TensorType inputType, List reduceDimensions) { return TypeResolver.reduce(inputType, reduceDimensions); } public TensorFunction argument() { return argument; } Aggregator aggregator() { return aggregator; } List dimensions() { return dimensions; } @Override public List> arguments() { return List.of(argument); } @Override public TensorFunction withArguments(List> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size()); return new Reduce<>(arguments.get(0), aggregator, dimensions); } @Override public PrimitiveTensorFunction toPrimitive() { return new Reduce<>(argument.toPrimitive(), aggregator, dimensions); } @Override public String toString(ToStringContext context) { return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } static String commaSeparated(List list) { StringBuilder b = new StringBuilder(); for (String element : list) b.append(", ").append(element); return b.toString(); } @Override public TensorType type(TypeContext context) { return outputType(argument.type(context), dimensions); } @Override public Tensor evaluate(EvaluationContext context) { return evaluate(this.argument.evaluate(context), dimensions, aggregator); } @Override public int hashCode() { return Objects.hash("reduce", argument, dimensions, aggregator); } static Tensor evaluate(Tensor argument, List dimensions, Aggregator aggregator) { if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor"); // Special case: Reduce all if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) if (argument.isEmpty()) return Tensor.from(0.0); else if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor) return reduceIndexedVector((IndexedTensor)argument, aggregator); else return reduceAllGeneral(argument, aggregator); TensorType reducedType = outputType(argument.type(), dimensions); // Reduce cells Map aggregatingCells = new HashMap<>(); for (Iterator i = argument.cellIterator(); i.hasNext(); ) { Map.Entry cell = i.next(); TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType, dimensions); aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); aggregatingCells.get(reducedAddress).aggregate(cell.getValue()); } Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); for (Map.Entry aggregatingCell : aggregatingCells.entrySet()) reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); return reducedBuilder.build(); } private static TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType, List dimensions) { Set indexesToRemove = new HashSet<>(); for (String dimensionToRemove : dimensions) indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get()); String[] reducedLabels = new String[reducedType.dimensions().size()]; int reducedLabelIndex = 0; for (int i = 0; i < address.size(); i++) if ( ! indexesToRemove.contains(i)) reducedLabels[reducedLabelIndex++] = address.label(i); return TensorAddress.of(reducedLabels); } private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator i = argument.valueIterator(); i.hasNext(); ) valueAggregator.aggregate(i.next()); return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); } private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (int i = 0; i < argument.dimensionSizes().size(0); i++) valueAggregator.aggregate(argument.get(i)); return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); } static abstract class ValueAggregator { static ValueAggregator ofType(Aggregator aggregator) { return switch (aggregator) { case avg -> new AvgAggregator(); case count -> new CountAggregator(); case max -> new MaxAggregator(); case median -> new MedianAggregator(); case min -> new MinAggregator(); case prod -> new ProdAggregator(); case sum -> new SumAggregator(); default -> throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented"); }; } /** Add a new value to those aggregated by this */ public abstract void aggregate(double value); /** Returns the value aggregated by this */ public abstract double aggregatedValue(); /** Resets the aggregator */ public abstract void reset(); /** Returns a hash of this aggregator which only depends on its identity */ @Override public abstract int hashCode(); } private static class AvgAggregator extends ValueAggregator { private int valueCount = 0; private double valueSum = 0.0; @Override public void aggregate(double value) { valueCount++; valueSum+= value; } @Override public double aggregatedValue() { return valueSum / valueCount; } @Override public void reset() { valueCount = 0; valueSum = 0.0; } @Override public int hashCode() { return "avgAggregator".hashCode(); } } private static class CountAggregator extends ValueAggregator { private int valueCount = 0; @Override public void aggregate(double value) { valueCount++; } @Override public double aggregatedValue() { return valueCount; } @Override public void reset() { valueCount = 0; } @Override public int hashCode() { return "countAggregator".hashCode(); } } private static class MaxAggregator extends ValueAggregator { private double maxValue = Double.NEGATIVE_INFINITY; @Override public void aggregate(double value) { if (value > maxValue) maxValue = value; } @Override public double aggregatedValue() { return maxValue; } @Override public void reset() { maxValue = Double.NEGATIVE_INFINITY; } @Override public int hashCode() { return "maxAggregator".hashCode(); } } private static class MedianAggregator extends ValueAggregator { /** If any NaN is added, the result should be NaN */ private boolean isNaN = false; private List values = new ArrayList<>(); @Override public void aggregate(double value) { if ( Double.isNaN(value)) isNaN = true; if ( ! isNaN) values.add(value); } @Override public double aggregatedValue() { if (isNaN || values.isEmpty()) return Double.NaN; Collections.sort(values); if (values.size() % 2 == 0) // even: average the two middle values return ( values.get(values.size() / 2 - 1) + values.get(values.size() / 2) ) / 2; else return values.get((values.size() - 1)/ 2); } @Override public void reset() { isNaN = false; values = new ArrayList<>(); } @Override public int hashCode() { return "medianAggregator".hashCode(); } } private static class MinAggregator extends ValueAggregator { private double minValue = Double.POSITIVE_INFINITY; @Override public void aggregate(double value) { if (value < minValue) minValue = value; } @Override public double aggregatedValue() { return minValue; } @Override public void reset() { minValue = Double.POSITIVE_INFINITY; } @Override public int hashCode() { return "minAggregator".hashCode(); } } private static class ProdAggregator extends ValueAggregator { private double valueProd = 1.0; @Override public void aggregate(double value) { valueProd *= value; } @Override public double aggregatedValue() { return valueProd; } @Override public void reset() { valueProd = 1.0; } @Override public int hashCode() { return "prodAggregator".hashCode(); } } private static class SumAggregator extends ValueAggregator { private double valueSum = 0.0; @Override public void aggregate(double value) { valueSum += value; } @Override public double aggregatedValue() { return valueSum; } @Override public void reset() { valueSum = 0.0; } @Override public int hashCode() { return "sumAggregator".hashCode(); } } }