aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-03 11:12:50 -0800
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-03 11:12:50 -0800
commitf289d6375fec9376c453ab06fbf19a2d1c8f9b49 (patch)
treebdc5c8638b616fbf93ef35b954991b05ce8ddb98 /vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
parent9a4929ad69be79fa60127f62a849e846951738a8 (diff)
Store labels only in dimension addresses
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java37
1 files changed, 20 insertions, 17 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 416b6ec1473..9251521f6bf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -98,34 +98,37 @@ public class Reduce extends PrimitiveTensorFunction {
if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size())
return reduceAll(argument);
- // Reduce dimensions
- Set<String> reducedDimensions = new HashSet<>(argument.type().dimensionNames());
- reducedDimensions.removeAll(dimensions);
+ // Reduce type
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : argument.type().dimensions())
+ if ( ! dimensions.contains(dimension.name())) // keep
+ builder.dimension(dimension);
+ TensorType reducedType = builder.build();
// Reduce cells
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
for (Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) {
- TensorAddress reducedAddress = reduceDimensions(cell.getKey(), reducedDimensions);
+ TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType);
aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
aggregatingCells.get(reducedAddress).aggregate(cell.getValue());
}
ImmutableMap.Builder<TensorAddress, Double> reducedCells = new ImmutableMap.Builder<>();
for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
reducedCells.put(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
- return new MapTensor(asMappedDimensions(reducedDimensions), reducedCells.build());
- }
-
- private TensorType asMappedDimensions(Set<String> dimensionNames) {
- TensorType.Builder builder = new TensorType.Builder();
- for (String dimensionName : dimensionNames)
- builder.mapped(dimensionName);
- return builder.build();
+ return new MapTensor(reducedType, reducedCells.build());
}
- private TensorAddress reduceDimensions(TensorAddress address, Set<String> reducedDimensions) {
- return TensorAddress.fromSorted(address.elements().stream()
- .filter(e -> reducedDimensions.contains(e.dimension()))
- .collect(Collectors.toList()));
+ private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) {
+ Set<Integer> indexesToRemove = new HashSet<>();
+ for (String dimensionToRemove : this.dimensions)
+ indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get());
+
+ String[] reducedLabels = new String[reducedType.dimensions().size()];
+ int reducedLabelIndex = 0;
+ for (int i = 0; i < address.elements().size(); i++)
+ if ( ! indexesToRemove.contains(i))
+ reducedLabels[reducedLabelIndex++] = address.elements().get(i);
+ return new TensorAddress(reducedLabels);
}
private Tensor reduceAll(Tensor argument) {
@@ -137,7 +140,7 @@ public class Reduce extends PrimitiveTensorFunction {
private static abstract class ValueAggregator {
- public static ValueAggregator ofType(Aggregator aggregator) {
+ private static ValueAggregator ofType(Aggregator aggregator) {
switch (aggregator) {
case avg : return new AvgAggregator();
case count : return new CountAggregator();