diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-18 09:27:05 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-18 09:27:05 +0100 |
commit | a8c35d35066f76f69a9254ee3957b3dd7aefb753 (patch) | |
tree | afb2810b9419d123cf035928565ead64aa1b393b /vespajlib | |
parent | 6b32fef5bf7e8850ca59a52ea023cf1c9dc17b75 (diff) | |
parent | a9cb1fe9dff8c09c62e3402cc261a57e6c88e751 (diff) |
Merge pull request #29959 from vespa-engine/balder/create-indexes-to-remove-once
Balder/create indexes to remove once
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 6 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 33 |
2 files changed, 27 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index d0d01d29b26..8a4179cdc1a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -62,7 +62,11 @@ public interface Tensor { /** Returns whether this have any cells */ default boolean isEmpty() { return size() == 0; } - /** Returns the number of cells in this */ + /** + * Returns the number of cells in this. + * TODO Figure how to best return an int instead of a long + * An int is large enough, and java is far better at int base loops than long + **/ long size(); /** Returns the value of a cell, or 0.0 if this cell does not exist */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 5327457a438..8cf88610599 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -128,31 +128,42 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET TensorType reducedType = outputType(argument.type(), dimensions); // Reduce cells - Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); + int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions); + // TODO cells.size() is most likely an overestimate, and might need a better heuristic + // But the upside is larger than the downside. + Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>((int)argument.size()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); - TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType, dimensions); - aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); - aggregatingCells.get(reducedAddress).aggregate(cell.getValue()); + TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey()); + ValueAggregator aggr = aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); + if (aggr == null) + aggr = aggregatingCells.get(reducedAddress); + aggr.aggregate(cell.getValue()); } Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); return reducedBuilder.build(); - } - - private static TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType, List<String> dimensions) { + private static int[] createIndexesToKeep(TensorType argumentType, List<String> dimensions) { Set<Integer> indexesToRemove = new HashSet<>(dimensions.size()*2); for (String dimensionToRemove : dimensions) indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get()); + int[] indexesToKeep = new int[argumentType.rank() - indexesToRemove.size()]; + int toKeepIndex = 0; + for (int i = 0; i < argumentType.rank(); i++) { + if ( ! indexesToRemove.contains(i)) + indexesToKeep[toKeepIndex++] = i; + } + return indexesToKeep; + } - String[] reducedLabels = new String[reducedType.dimensions().size()]; + private static TensorAddress reduceDimensions(int[] indexesToKeep, TensorAddress address) { + String[] reducedLabels = new String[indexesToKeep.length]; int reducedLabelIndex = 0; - for (int i = 0; i < address.size(); i++) - if ( ! indexesToRemove.contains(i)) - reducedLabels[reducedLabelIndex++] = address.label(i); + for (int toKeep : indexesToKeep) + reducedLabels[reducedLabelIndex++] = address.label(toKeep); return TensorAddress.of(reducedLabels); } |