diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-03 11:12:50 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-03 11:12:50 -0800 |
commit | f289d6375fec9376c453ab06fbf19a2d1c8f9b49 (patch) | |
tree | bdc5c8638b616fbf93ef35b954991b05ce8ddb98 /vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | |
parent | 9a4929ad69be79fa60127f62a849e846951738a8 (diff) |
Store labels only in dimension addresses
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 37 |
1 files changed, 20 insertions, 17 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 416b6ec1473..9251521f6bf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -98,34 +98,37 @@ public class Reduce extends PrimitiveTensorFunction { if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) return reduceAll(argument); - // Reduce dimensions - Set<String> reducedDimensions = new HashSet<>(argument.type().dimensionNames()); - reducedDimensions.removeAll(dimensions); + // Reduce type + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension : argument.type().dimensions()) + if ( ! dimensions.contains(dimension.name())) // keep + builder.dimension(dimension); + TensorType reducedType = builder.build(); // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); for (Map.Entry<TensorAddress, Double> cell : argument.cells().entrySet()) { - TensorAddress reducedAddress = reduceDimensions(cell.getKey(), reducedDimensions); + TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType); aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); aggregatingCells.get(reducedAddress).aggregate(cell.getValue()); } ImmutableMap.Builder<TensorAddress, Double> reducedCells = new ImmutableMap.Builder<>(); for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) reducedCells.put(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); - return new MapTensor(asMappedDimensions(reducedDimensions), reducedCells.build()); - } - - private TensorType asMappedDimensions(Set<String> dimensionNames) { - TensorType.Builder builder = new TensorType.Builder(); - for (String dimensionName : dimensionNames) - builder.mapped(dimensionName); - return builder.build(); + return new MapTensor(reducedType, reducedCells.build()); } - private TensorAddress reduceDimensions(TensorAddress address, Set<String> reducedDimensions) { - return TensorAddress.fromSorted(address.elements().stream() - .filter(e -> reducedDimensions.contains(e.dimension())) - .collect(Collectors.toList())); + private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) { + Set<Integer> indexesToRemove = new HashSet<>(); + for (String dimensionToRemove : this.dimensions) + indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get()); + + String[] reducedLabels = new String[reducedType.dimensions().size()]; + int reducedLabelIndex = 0; + for (int i = 0; i < address.elements().size(); i++) + if ( ! indexesToRemove.contains(i)) + reducedLabels[reducedLabelIndex++] = address.elements().get(i); + return new TensorAddress(reducedLabels); } private Tensor reduceAll(Tensor argument) { @@ -137,7 +140,7 @@ public class Reduce extends PrimitiveTensorFunction { private static abstract class ValueAggregator { - public static ValueAggregator ofType(Aggregator aggregator) { + private static ValueAggregator ofType(Aggregator aggregator) { switch (aggregator) { case avg : return new AvgAggregator(); case count : return new CountAggregator(); |