aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-17 22:55:30 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-17 22:55:30 +0100
commit865dac49dcf420cd00b9be0f08ee219e4f519aee (patch)
treed0614c2ecbaeec40a119195785bd3a105859fc60 /vespajlib/src
parent6b32fef5bf7e8850ca59a52ea023cf1c9dc17b75 (diff)
Create set of indexes to remove once.
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java13
2 files changed, 14 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index d0d01d29b26..8a4179cdc1a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -62,7 +62,11 @@ public interface Tensor {
/** Returns whether this have any cells */
default boolean isEmpty() { return size() == 0; }
- /** Returns the number of cells in this */
+ /**
+ * Returns the number of cells in this.
+ * TODO Figure how to best return an int instead of a long
+ * An int is large enough, and java is far better at int base loops than long
+ **/
long size();
/** Returns the value of a cell, or 0.0 if this cell does not exist */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 5327457a438..76ca98124b8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -128,10 +128,13 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
TensorType reducedType = outputType(argument.type(), dimensions);
// Reduce cells
- Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
+ Set<Integer> indexesToRemove = createIndexesOfDimensionsToRemove(argument.type(), dimensions);
+ // TODO cells.size() is most likely an overestimate, and might need a better heuristic
+ // But the upside is larger than teh downside.
+ Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>((int)argument.size());
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
- TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType, dimensions);
+ TensorAddress reducedAddress = reduceDimensions(indexesToRemove, cell.getKey(), reducedType);
aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
aggregatingCells.get(reducedAddress).aggregate(cell.getValue());
}
@@ -142,12 +145,14 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return reducedBuilder.build();
}
-
- private static TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType, List<String> dimensions) {
+ private static Set<Integer> createIndexesOfDimensionsToRemove(TensorType argumentType, List<String> dimensions) {
Set<Integer> indexesToRemove = new HashSet<>(dimensions.size()*2);
for (String dimensionToRemove : dimensions)
indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get());
+ return indexesToRemove;
+ }
+ private static TensorAddress reduceDimensions(Set<Integer> indexesToRemove, TensorAddress address, TensorType reducedType) {
String[] reducedLabels = new String[reducedType.dimensions().size()];
int reducedLabelIndex = 0;
for (int i = 0; i < address.size(); i++)