diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-17 23:09:30 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-17 23:13:27 +0100 |
commit | fd4a580ae08bcf8bcbc21fff8fd80fa1205286a7 (patch) | |
tree | 96e58994bd081f386d4bd8e64c0b185b7241f0e2 /vespajlib/src/main/java/com/yahoo | |
parent | 865dac49dcf420cd00b9be0f08ee219e4f519aee (diff) |
Reverse the problem from indexes-to-remove to indexes-to-keep. Then you avoid hash lookup in inner loop.
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 32 |
1 files changed, 19 insertions, 13 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 76ca98124b8..e5472209933 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -128,36 +128,42 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET TensorType reducedType = outputType(argument.type(), dimensions); // Reduce cells - Set<Integer> indexesToRemove = createIndexesOfDimensionsToRemove(argument.type(), dimensions); + int [] indexesToKeep = createIndexesToKeep(argument.type(), dimensions); // TODO cells.size() is most likely an overestimate, and might need a better heuristic - // But the upside is larger than teh downside. + // But the upside is larger than the downside. Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>((int)argument.size()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); - TensorAddress reducedAddress = reduceDimensions(indexesToRemove, cell.getKey(), reducedType); - aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); - aggregatingCells.get(reducedAddress).aggregate(cell.getValue()); + TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey()); + ValueAggregator aggr = aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); + if (aggr == null) + aggr = aggregatingCells.get(reducedAddress); + aggr.aggregate(cell.getValue()); } Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); return reducedBuilder.build(); - } - private static Set<Integer> createIndexesOfDimensionsToRemove(TensorType argumentType, List<String> dimensions) { + private static int [] createIndexesToKeep(TensorType argumentType, List<String> dimensions) { Set<Integer> indexesToRemove = new HashSet<>(dimensions.size()*2); for (String dimensionToRemove : dimensions) indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get()); - return indexesToRemove; + int [] indexesToKeep = new int[argumentType.rank() - indexesToRemove.size()]; + int toKeepIndex = 0; + for (int i = 0;i < argumentType.rank(); i++) { + if ( ! indexesToRemove.contains(i)) + indexesToKeep[toKeepIndex++] = i; + } + return indexesToKeep; } - private static TensorAddress reduceDimensions(Set<Integer> indexesToRemove, TensorAddress address, TensorType reducedType) { - String[] reducedLabels = new String[reducedType.dimensions().size()]; + private static TensorAddress reduceDimensions(int [] indexesToKeep, TensorAddress address) { + String[] reducedLabels = new String[indexesToKeep.length]; int reducedLabelIndex = 0; - for (int i = 0; i < address.size(); i++) - if ( ! indexesToRemove.contains(i)) - reducedLabels[reducedLabelIndex++] = address.label(i); + for (int toKeep : indexesToKeep) + reducedLabels[reducedLabelIndex++] = address.label(toKeep); return TensorAddress.of(reducedLabels); } |