package com.yahoo.tensor.functions;

import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * The reduce tensor operation returns a tensor produced from the argument tensor where some dimensions
 * are collapsed to a single value using an aggregator function.
 *
 * @author bratseth
 */
@Beta
public class Reduce extends PrimitiveTensorFunction {

    /** The aggregation functions which can be used to collapse the cells of the reduced dimensions. */
    public enum Aggregator { avg, count, prod, sum, max, min; }

    private final TensorFunction argument;
    private final List<String> dimensions;
    private final Aggregator aggregator;

    /** Creates a reduce function reducing all dimensions */
    public Reduce(TensorFunction argument, Aggregator aggregator) {
        this(argument, aggregator, Collections.emptyList());
    }

    /** Creates a reduce function reducing a single dimension */
    public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
        this(argument, aggregator, Collections.singletonList(dimension));
    }

    /**
     * Creates a reduce function.
     * Whether the given dimensions are present in the argument tensor is not checked here
     * but when this function is evaluated, see {@link #evaluate(EvaluationContext)}.
     *
     * @param argument the tensor to reduce
     * @param aggregator the aggregator function to use
     * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
     *                   producing a dimensionless tensor (a scalar).
     * @throws NullPointerException if any of the arguments are null
     */
    public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
        Objects.requireNonNull(argument, "The argument tensor cannot be null");
        Objects.requireNonNull(aggregator, "The aggregator cannot be null");
        Objects.requireNonNull(dimensions, "The dimensions cannot be null");
        this.argument = argument;
        this.aggregator = aggregator;
        this.dimensions = ImmutableList.copyOf(dimensions);
    }

    /** Returns the function producing the tensor to reduce */
    public TensorFunction argument() { return argument; }

    /** Returns the single argument of this function as a list */
    @Override
    public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }

    /**
     * Returns a copy of this function having the given argument.
     *
     * @throws IllegalArgumentException if the given list does not contain exactly one argument
     */
    @Override
    public TensorFunction replaceArguments(List<TensorFunction> arguments) {
        if ( arguments.size() != 1)
            throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
        return new Reduce(arguments.get(0), aggregator, dimensions);
    }

    /** Returns a reduce function over the primitive form of the argument */
    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Reduce(argument.toPrimitive(), aggregator, dimensions);
    }

    @Override
    public String toString(ToStringContext context) {
        return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
    }

    /** Returns the elements of the list as a string where each element is preceded by ", " */
    private String commaSeparated(List<String> list) {
        StringBuilder b = new StringBuilder();
        for (String element : list)
            b.append(", ").append(element);
        return b.toString();
    }

    /**
     * Evaluates the argument and reduces the resulting tensor over the dimensions of this.
     *
     * @param context the context in which the argument is evaluated
     * @return the reduced tensor: A dimensionless tensor if all dimensions are reduced, and otherwise
     *         a tensor having the dimensions of the argument which are not reduced
     * @throws IllegalArgumentException if any of the dimensions to reduce are not present in the argument tensor
     */
    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor argument = this.argument.evaluate(context);
        TensorType argumentType = argument.type();

        if ( ! dimensions.isEmpty() && ! argumentType.dimensionNames().containsAll(dimensions))
            throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
                                               dimensions + ": Not all those dimensions are present in this tensor");

        // Use the distinct names, such that a repeated dimension name is not mistaken for "all dimensions"
        Set<String> dimensionsToReduce = new HashSet<>(dimensions);

        // Special case: Reduce all
        if (dimensionsToReduce.isEmpty() || dimensionsToReduce.size() == argumentType.dimensions().size()) {
            if (argumentType.dimensions().size() == 1 && argument instanceof IndexedTensor)
                return reduceIndexedVector((IndexedTensor)argument);
            else
                return reduceAllGeneral(argument);
        }

        // Reduce type
        TensorType.Builder builder = new TensorType.Builder();
        for (TensorType.Dimension dimension : argumentType.dimensions()) {
            if ( ! dimensionsToReduce.contains(dimension.name())) // keep
                builder.dimension(dimension);
        }
        TensorType reducedType = builder.build();

        // The positions of the removed dimensions are the same for every cell, so find them once
        Set<Integer> indexesToRemove = new HashSet<>();
        for (String dimensionToRemove : dimensionsToReduce)
            indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get()); // presence validated above
        int reducedSize = reducedType.dimensions().size();

        // Reduce cells
        Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
        for (Iterator<Map.Entry<TensorAddress, Double>> i = argument.cellIterator(); i.hasNext(); ) {
            Map.Entry<TensorAddress, Double> cell = i.next();
            TensorAddress reducedAddress = reduceDimensions(cell.getKey(), indexesToRemove, reducedSize);
            aggregatingCells.computeIfAbsent(reducedAddress, address -> ValueAggregator.ofType(aggregator))
                            .aggregate(cell.getValue());
        }

        Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
        for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
            reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
        return reducedBuilder.build();
    }

    /**
     * Returns the given address with the labels at the given indexes removed.
     *
     * @param address the address of a cell in the argument tensor
     * @param indexesToRemove the indexes in the address of the labels to leave out
     * @param reducedSize the number of labels in the resulting address
     */
    private TensorAddress reduceDimensions(TensorAddress address, Set<Integer> indexesToRemove, int reducedSize) {
        String[] reducedLabels = new String[reducedSize];
        int reducedLabelIndex = 0;
        for (int i = 0; i < address.size(); i++) {
            if ( ! indexesToRemove.contains(i))
                reducedLabels[reducedLabelIndex++] = address.label(i);
        }
        return TensorAddress.of(reducedLabels);
    }

    /** Aggregates all the values of the given tensor into a dimensionless tensor */
    private Tensor reduceAllGeneral(Tensor argument) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )
            valueAggregator.aggregate(i.next());
        return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build();
    }

    /** Aggregates all the values of the given one-dimensional indexed tensor into a dimensionless tensor */
    private Tensor reduceIndexedVector(IndexedTensor argument) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        for (int i = 0; i < argument.dimensionSizes().size(0); i++)
            valueAggregator.aggregate(argument.get(i));
        return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build();
    }

    /** Accumulates a sequence of values into a single aggregated value */
    private static abstract class ValueAggregator {

        /**
         * Returns a new, empty aggregator of the given type.
         *
         * @throws UnsupportedOperationException if there is no implementation of the given aggregator
         */
        private static ValueAggregator ofType(Aggregator aggregator) {
            switch (aggregator) {
                case avg : return new AvgAggregator();
                case count : return new CountAggregator();
                case prod : return new ProdAggregator();
                case sum : return new SumAggregator();
                case max : return new MaxAggregator();
                case min : return new MinAggregator();
                default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
            }
        }

        /** Add a new value to those aggregated by this */
        public abstract void aggregate(double value);

        /** Returns the value aggregated by this */
        public abstract double aggregatedValue();

    }

    /** Computes the sum of the values divided by the number of values */
    private static class AvgAggregator extends ValueAggregator {

        private int valueCount = 0;
        private double valueSum = 0.0;

        @Override
        public void aggregate(double value) {
            valueCount++;
            valueSum += value;
        }

        @Override
        public double aggregatedValue() {
            return valueSum / valueCount;
        }

    }

    /** Counts the number of values, disregarding what the values are */
    private static class CountAggregator extends ValueAggregator {

        private int valueCount = 0;

        @Override
        public void aggregate(double value) {
            valueCount++;
        }

        @Override
        public double aggregatedValue() {
            return valueCount;
        }

    }

    /** Computes the product of the values, starting from 1.0 */
    private static class ProdAggregator extends ValueAggregator {

        private double valueProd = 1.0;

        @Override
        public void aggregate(double value) {
            valueProd *= value;
        }

        @Override
        public double aggregatedValue() {
            return valueProd;
        }

    }

    /** Computes the sum of the values, starting from 0.0 */
    private static class SumAggregator extends ValueAggregator {

        private double valueSum = 0.0;

        @Override
        public void aggregate(double value) {
            valueSum += value;
        }

        @Override
        public double aggregatedValue() {
            return valueSum;
        }

    }

    /** Finds the largest of the values */
    private static class MaxAggregator extends ValueAggregator {

        // Not Double.MIN_VALUE: That is the smallest *positive* double, which is larger than zero and all negatives
        private double maxValue = Double.NEGATIVE_INFINITY;

        @Override
        public void aggregate(double value) {
            if (value > maxValue)
                maxValue = value;
        }

        @Override
        public double aggregatedValue() {
            return maxValue;
        }

    }

    /** Finds the smallest of the values */
    private static class MinAggregator extends ValueAggregator {

        // Positive infinity rather than Double.MAX_VALUE, such that infinite values are also aggregated correctly
        private double minValue = Double.POSITIVE_INFINITY;

        @Override
        public void aggregate(double value) {
            if (value < minValue)
                minValue = value;
        }

        @Override
        public double aggregatedValue() {
            return minValue;
        }

    }

}