aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2016-11-25 18:21:25 +0100
committerGitHub <noreply@github.com>2016-11-25 18:21:25 +0100
commit11b208db7d2422828c90aafa638f059306acbc24 (patch)
tree63d3f766b7a046b13b2b4fdc8e633fe71134847c /vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
parent5400980ea6bbac6ef385d089b5e9f9b100ecae71 (diff)
Revert "Bratseth/tensor functions 3"
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java236
1 files changed, 14 insertions, 222 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index ef18cb61b17..4b306d376a6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -1,246 +1,38 @@
package com.yahoo.tensor.functions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.MapTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Set;
-import java.util.stream.Collectors;
+import java.util.Optional;
+import java.util.function.DoubleBinaryOperator;
/**
- * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
- * are collapsed to a single value using an aggregator function.
+ * The reduce tensor function.
*
* @author bratseth
*/
public class Reduce extends PrimitiveTensorFunction {
- public enum Aggregator { avg, count, prod, sum, max, min; }
-
private final TensorFunction argument;
- private final List<String> dimensions;
- private final Aggregator aggregator;
-
- /** Creates a reduce function reducing aLL dimensions */
- public Reduce(TensorFunction argument, Aggregator aggregator) {
- this(argument, aggregator, Collections.emptyList());
- }
+ private final String dimension;
+ private final DoubleBinaryOperator reductor;
+ private final Optional<DoubleBinaryOperator> postTransformation;
- /** Creates a reduce function reducing a single dimension */
- public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
- this(argument, aggregator, Collections.singletonList(dimension));
- }
-
- /**
- * Creates a reduce function.
- *
- * @param argument the tensor to reduce
- * @param aggregator the aggregator function to use
- * @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
- * producing a dimensionless tensor (a scalar).
- * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor
- */
- public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
- Objects.requireNonNull(argument, "The argument tensor cannot be null");
- Objects.requireNonNull(aggregator, "The aggregator cannot be null");
- Objects.requireNonNull(dimensions, "The dimensions cannot be null");
+ public Reduce(TensorFunction argument, String dimension,
+ DoubleBinaryOperator reductor, Optional<DoubleBinaryOperator> postTransformation) {
this.argument = argument;
- this.aggregator = aggregator;
- this.dimensions = ImmutableList.copyOf(dimensions);
+ this.dimension = dimension;
+ this.reductor = reductor;
+ this.postTransformation = postTransformation;
}
public TensorFunction argument() { return argument; }
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
-
- @Override
public PrimitiveTensorFunction toPrimitive() {
- return new Reduce(argument.toPrimitive(), aggregator, dimensions);
- }
-
- @Override
- public String toString(ToStringContext context) {
- return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
- }
-
- private String commaSeparated(List<String> list) {
- StringBuilder b = new StringBuilder();
- for (String element : list)
- b.append(", ").append(element);
- return b.toString();
+ return new Reduce(argument.toPrimitive(), dimension, reductor, postTransformation);
}
@Override
- public Tensor evaluate(EvaluationContext context) {
- Tensor argument = this.argument.evaluate(context);
-
- if ( ! dimensions.isEmpty() && ! argument.dimensions().containsAll(dimensions))
- throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
- dimensions + ": Not all those dimensions are present in this tensor");
-
- if (dimensions.isEmpty() || dimensions.size() == argument.dimensions().size())
- return reduceAll(argument);
-
- // Reduce dimensions
- Set<String> reducedDimensions = new HashSet<>(argument.dimensions());
- reducedDimensions.removeAll(dimensions);
-
- // Reduce cells
- Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
- for (Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) {
- TensorAddress reducedAddress = reduceDimensions(cell.getKey(), reducedDimensions);
- aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
- aggregatingCells.get(reducedAddress).aggregate(cell.getValue());
- }
- ImmutableMap.Builder<TensorAddress, Double> reducedCells = new ImmutableMap.Builder<>();
- for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
- reducedCells.put(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
- return new MapTensor(reducedDimensions, reducedCells.build());
- }
-
- private TensorAddress reduceDimensions(TensorAddress address, Set<String> reducedDimensions) {
- return TensorAddress.fromSorted(address.elements().stream()
- .filter(e -> reducedDimensions.contains(e.dimension()))
- .collect(Collectors.toList()));
- }
-
- private Tensor reduceAll(Tensor argument) {
- ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
- for (Double cellValue : argument.cells().values())
- valueAggregator.aggregate(cellValue);
- return new MapTensor(ImmutableMap.of(TensorAddress.empty, valueAggregator.aggregatedValue()));
- }
-
- private static abstract class ValueAggregator {
-
- public static ValueAggregator ofType(Aggregator aggregator) {
- switch (aggregator) {
- case avg : return new AvgAggregator();
- case count : return new CountAggregator();
- case prod : return new ProdAggregator();
- case sum : return new SumAggregator();
- case max : return new MaxAggregator();
- case min : return new MinAggregator();
- default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
- }
-
- }
-
- /** Add a new value to those aggregated by this */
- public abstract void aggregate(double value);
-
- /** Returns the value aggregated by this */
- public abstract double aggregatedValue();
-
- }
-
- private static class AvgAggregator extends ValueAggregator {
-
- private int valueCount = 0;
- private double valueSum = 0.0;
-
- @Override
- public void aggregate(double value) {
- valueCount++;
- valueSum+= value;
- }
-
- @Override
- public double aggregatedValue() {
- return valueSum / valueCount;
- }
-
- }
-
- private static class CountAggregator extends ValueAggregator {
-
- private int valueCount = 0;
-
- @Override
- public void aggregate(double value) {
- valueCount++;
- }
-
- @Override
- public double aggregatedValue() {
- return valueCount;
- }
-
- }
-
- private static class ProdAggregator extends ValueAggregator {
-
- private double valueProd = 1.0;
-
- @Override
- public void aggregate(double value) {
- valueProd *= value;
- }
-
- @Override
- public double aggregatedValue() {
- return valueProd;
- }
-
- }
-
- private static class SumAggregator extends ValueAggregator {
-
- private double valueSum = 0.0;
-
- @Override
- public void aggregate(double value) {
- valueSum += value;
- }
-
- @Override
- public double aggregatedValue() {
- return valueSum;
- }
-
- }
-
- private static class MaxAggregator extends ValueAggregator {
-
- private double maxValue = Double.MIN_VALUE;
-
- @Override
- public void aggregate(double value) {
- if (value > maxValue)
- maxValue = value;
- }
-
- @Override
- public double aggregatedValue() {
- return maxValue;
- }
-
- }
-
- private static class MinAggregator extends ValueAggregator {
-
- private double minValue = Double.MAX_VALUE;
-
- @Override
- public void aggregate(double value) {
- if (value < minValue)
- minValue = value;
- }
-
- @Override
- public double aggregatedValue() {
- return minValue;
- }
-
+ public String toString() {
+ return "reduce(" + argument.toString() + ", " + dimension + ", lambda(a, b) (...), lambda(a, b) (...))";
}
}