package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
 * The reduce tensor operation returns a tensor produced from the argument tensor where some dimensions
 * are collapsed to a single value using an aggregator function.
 *
 * @author bratseth
 */
@Beta
public class Reduce extends PrimitiveTensorFunction {

    /** The functions which may be used to combine the values of the collapsed cells into a single value */
    public enum Aggregator { avg, count, prod, sum, max, min; }

    private final TensorFunction argument;
    private final List<String> dimensions;
    private final Aggregator aggregator;

    /** Creates a reduce function reducing all dimensions */
    public Reduce(TensorFunction argument, Aggregator aggregator) {
        this(argument, aggregator, Collections.emptyList());
    }

    /** Creates a reduce function reducing a single dimension */
    public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
        this(argument, aggregator, Collections.singletonList(dimension));
    }

    /**
     * Creates a reduce function.
     *
     * @param argument the tensor to reduce
     * @param aggregator the aggregator function to use
     * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
     *                   producing a dimensionless tensor (a scalar). The list is copied.
     * @throws NullPointerException if any argument is null, or if the dimension list contains null
     */
    public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
        Objects.requireNonNull(argument, "The argument tensor cannot be null");
        Objects.requireNonNull(aggregator, "The aggregator cannot be null");
        Objects.requireNonNull(dimensions, "The dimensions cannot be null");
        this.argument = argument;
        this.aggregator = aggregator;
        this.dimensions = ImmutableList.copyOf(dimensions);
    }

    public TensorFunction argument() { return argument; }

    @Override
    public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }

    @Override
    public TensorFunction replaceArguments(List<TensorFunction> arguments) {
        if ( arguments.size() != 1)
            throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
        return new Reduce(arguments.get(0), aggregator, dimensions);
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Reduce(argument.toPrimitive(), aggregator, dimensions);
    }

    @Override
    public String toString(ToStringContext context) {
        return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
    }

    /** Returns the given elements each preceded by ", ", or the empty string if there are none */
    private String commaSeparated(List<String> list) {
        StringBuilder b = new StringBuilder();
        for (String element : list)
            b.append(", ").append(element);
        return b.toString();
    }

    /**
     * Evaluates the argument in the given context and collapses the configured dimensions of the result.
     *
     * @throws IllegalArgumentException if any of the dimensions to reduce is not present in the evaluated argument
     */
    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor argument = this.argument.evaluate(context);
        if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
            throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
                                               dimensions + ": Not all those dimensions are present in this tensor");

        // Special case: Reduce all.
        // Distinct names are counted so that a dimension listed twice is not mistaken for "every dimension"
        if (dimensions.isEmpty() || new HashSet<>(dimensions).size() == argument.type().dimensions().size()) {
            if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor)
                return reduceIndexedVector((IndexedTensor)argument);
            else
                return reduceAllGeneral(argument);
        }

        // Reduce type
        TensorType.Builder builder = new TensorType.Builder();
        for (TensorType.Dimension dimension : argument.type().dimensions())
            if ( ! dimensions.contains(dimension.name())) // keep
                builder.dimension(dimension);
        TensorType reducedType = builder.build();

        // Reduce cells. The positions to drop are the same for every cell, so compute them once
        Set<Integer> indexesToRemove = indexesOf(dimensions, argument.type());
        Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
        for (Iterator<? extends Map.Entry<TensorAddress, Double>> i = argument.cellIterator(); i.hasNext(); ) {
            Map.Entry<TensorAddress, Double> cell = i.next();
            TensorAddress reducedAddress = reduceDimensions(cell.getKey(), indexesToRemove, reducedType);
            aggregatingCells.computeIfAbsent(reducedAddress, address -> ValueAggregator.ofType(aggregator))
                            .aggregate(cell.getValue());
        }
        Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
        for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
            reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
        return reducedBuilder.build();
    }

    /** Returns the positions in the given type of the given dimension names, all of which must be present in it */
    private static Set<Integer> indexesOf(List<String> dimensionNames, TensorType type) {
        Set<Integer> indexes = new HashSet<>();
        for (String dimensionName : dimensionNames)
            indexes.add(type.indexOfDimension(dimensionName).get());
        return indexes;
    }

    /** Returns a copy of the given address where the labels at the given positions are removed */
    private static TensorAddress reduceDimensions(TensorAddress address, Set<Integer> indexesToRemove,
                                                  TensorType reducedType) {
        String[] reducedLabels = new String[reducedType.dimensions().size()];
        int reducedLabelIndex = 0;
        for (int i = 0; i < address.size(); i++)
            if ( ! indexesToRemove.contains(i))
                reducedLabels[reducedLabelIndex++] = address.label(i);
        return TensorAddress.of(reducedLabels);
    }

    /** Aggregates all values of the given tensor into a scalar tensor, visiting them through its value iterator */
    private Tensor reduceAllGeneral(Tensor argument) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )
            valueAggregator.aggregate(i.next());
        return scalar(valueAggregator.aggregatedValue());
    }

    /** Aggregates all values of a one-dimensional indexed tensor into a scalar tensor, reading them by index */
    private Tensor reduceIndexedVector(IndexedTensor argument) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        for (int i = 0; i < argument.dimensionSizes().size(0); i++)
            valueAggregator.aggregate(argument.get(i));
        return scalar(valueAggregator.aggregatedValue());
    }

    /** Returns a dimensionless tensor holding the given value */
    private static Tensor scalar(double value) {
        return Tensor.Builder.of(TensorType.empty).cell(value).build();
    }

    /** Accumulates a sequence of values into a single value; a new instance is needed per aggregation */
    private static abstract class ValueAggregator {

        private static ValueAggregator ofType(Aggregator aggregator) {
            switch (aggregator) {
                case avg : return new AvgAggregator();
                case count : return new CountAggregator();
                case prod : return new ProdAggregator();
                case sum : return new SumAggregator();
                case max : return new MaxAggregator();
                case min : return new MinAggregator();
                default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
            }
        }

        /** Add a new value to those aggregated by this */
        public abstract void aggregate(double value);

        /** Returns the value aggregated by this */
        public abstract double aggregatedValue();

    }

    /** Arithmetic mean of the aggregated values; NaN if no values were aggregated (0.0 / 0) */
    private static class AvgAggregator extends ValueAggregator {

        private int valueCount = 0;
        private double valueSum = 0.0;

        @Override
        public void aggregate(double value) {
            valueCount++;
            valueSum += value;
        }

        @Override
        public double aggregatedValue() {
            return valueSum / valueCount;
        }

    }

    /** Number of aggregated values */
    private static class CountAggregator extends ValueAggregator {

        private int valueCount = 0;

        @Override
        public void aggregate(double value) {
            valueCount++;
        }

        @Override
        public double aggregatedValue() {
            return valueCount;
        }

    }

    /** Product of the aggregated values; 1.0 if no values were aggregated */
    private static class ProdAggregator extends ValueAggregator {

        private double valueProd = 1.0;

        @Override
        public void aggregate(double value) {
            valueProd *= value;
        }

        @Override
        public double aggregatedValue() {
            return valueProd;
        }

    }

    /** Sum of the aggregated values; 0.0 if no values were aggregated */
    private static class SumAggregator extends ValueAggregator {

        private double valueSum = 0.0;

        @Override
        public void aggregate(double value) {
            valueSum += value;
        }

        @Override
        public double aggregatedValue() {
            return valueSum;
        }

    }

    /** Largest aggregated value; negative infinity if no values were aggregated */
    private static class MaxAggregator extends ValueAggregator {

        // Not Double.MIN_VALUE: that is the smallest *positive* double, which would hide all-negative inputs
        private double maxValue = Double.NEGATIVE_INFINITY;

        @Override
        public void aggregate(double value) {
            if (value > maxValue)
                maxValue = value;
        }

        @Override
        public double aggregatedValue() {
            return maxValue;
        }

    }

    /** Smallest aggregated value; positive infinity if no values were aggregated */
    private static class MinAggregator extends ValueAggregator {

        // Not Double.MAX_VALUE: an input of +Infinity must be reported as the minimum when it is the only value
        private double minValue = Double.POSITIVE_INFINITY;

        @Override
        public void aggregate(double value) {
            if (value < minValue)
                minValue = value;
        }

        @Override
        public double aggregatedValue() {
            return minValue;
        }

    }

}