From 865dac49dcf420cd00b9be0f08ee219e4f519aee Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Wed, 17 Jan 2024 22:55:30 +0100 Subject: Create set of indexes to remove once. --- vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 6 +++++- .../src/main/java/com/yahoo/tensor/functions/Reduce.java | 13 +++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) (limited to 'vespajlib/src') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index d0d01d29b26..8a4179cdc1a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -62,7 +62,11 @@ public interface Tensor { /** Returns whether this have any cells */ default boolean isEmpty() { return size() == 0; } - /** Returns the number of cells in this */ + /** + * Returns the number of cells in this. + * TODO Figure how to best return an int instead of a long + * An int is large enough, and java is far better at int base loops than long + **/ long size(); /** Returns the value of a cell, or 0.0 if this cell does not exist */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 5327457a438..76ca98124b8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -128,10 +128,13 @@ public class Reduce extends PrimitiveTensorFunction aggregatingCells = new HashMap<>(); + Set indexesToRemove = createIndexesOfDimensionsToRemove(argument.type(), dimensions); + // TODO cells.size() is most likely an overestimate, and might need a better heuristic + // But the upside is larger than teh downside. + Map aggregatingCells = new HashMap<>((int)argument.size()); for (Iterator i = argument.cellIterator(); i.hasNext(); ) { Map.Entry cell = i.next(); - TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType, dimensions); + TensorAddress reducedAddress = reduceDimensions(indexesToRemove, cell.getKey(), reducedType); aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); aggregatingCells.get(reducedAddress).aggregate(cell.getValue()); } @@ -142,12 +145,14 @@ public class Reduce extends PrimitiveTensorFunction dimensions) { + private static Set createIndexesOfDimensionsToRemove(TensorType argumentType, List dimensions) { Set indexesToRemove = new HashSet<>(dimensions.size()*2); for (String dimensionToRemove : dimensions) indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get()); + return indexesToRemove; + } + private static TensorAddress reduceDimensions(Set indexesToRemove, TensorAddress address, TensorType reducedType) { String[] reducedLabels = new String[reducedType.dimensions().size()]; int reducedLabelIndex = 0; for (int i = 0; i < address.size(); i++) -- cgit v1.2.3