summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-17 23:09:30 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-17 23:13:27 +0100
commitfd4a580ae08bcf8bcbc21fff8fd80fa1205286a7 (patch)
tree96e58994bd081f386d4bd8e64c0b185b7241f0e2 /vespajlib
parent865dac49dcf420cd00b9be0f08ee219e4f519aee (diff)
Reverse the problem from indexes-to-remove to indexes-to-keep. Then you avoid hash lookup in inner loop.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java32
1 files changed, 19 insertions, 13 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 76ca98124b8..e5472209933 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -128,36 +128,42 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
TensorType reducedType = outputType(argument.type(), dimensions);
// Reduce cells
- Set<Integer> indexesToRemove = createIndexesOfDimensionsToRemove(argument.type(), dimensions);
+ int [] indexesToKeep = createIndexesToKeep(argument.type(), dimensions);
// TODO cells.size() is most likely an overestimate, and might need a better heuristic
- // But the upside is larger than teh downside.
+ // But the upside is larger than the downside.
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>((int)argument.size());
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
- TensorAddress reducedAddress = reduceDimensions(indexesToRemove, cell.getKey(), reducedType);
- aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
- aggregatingCells.get(reducedAddress).aggregate(cell.getValue());
+ TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey());
+ ValueAggregator aggr = aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
+ if (aggr == null)
+ aggr = aggregatingCells.get(reducedAddress);
+ aggr.aggregate(cell.getValue());
}
Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
return reducedBuilder.build();
-
}
- private static Set<Integer> createIndexesOfDimensionsToRemove(TensorType argumentType, List<String> dimensions) {
+ private static int [] createIndexesToKeep(TensorType argumentType, List<String> dimensions) {
Set<Integer> indexesToRemove = new HashSet<>(dimensions.size()*2);
for (String dimensionToRemove : dimensions)
indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get());
- return indexesToRemove;
+ int [] indexesToKeep = new int[argumentType.rank() - indexesToRemove.size()];
+ int toKeepIndex = 0;
+ for (int i = 0;i < argumentType.rank(); i++) {
+ if ( ! indexesToRemove.contains(i))
+ indexesToKeep[toKeepIndex++] = i;
+ }
+ return indexesToKeep;
}
- private static TensorAddress reduceDimensions(Set<Integer> indexesToRemove, TensorAddress address, TensorType reducedType) {
- String[] reducedLabels = new String[reducedType.dimensions().size()];
+ private static TensorAddress reduceDimensions(int [] indexesToKeep, TensorAddress address) {
+ String[] reducedLabels = new String[indexesToKeep.length];
int reducedLabelIndex = 0;
- for (int i = 0; i < address.size(); i++)
- if ( ! indexesToRemove.contains(i))
- reducedLabels[reducedLabelIndex++] = address.label(i);
+ for (int toKeep : indexesToKeep)
+ reducedLabels[reducedLabelIndex++] = address.label(toKeep);
return TensorAddress.of(reducedLabels);
}