summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-30 22:56:30 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-30 22:56:30 +0100
commite7dc5ca189d373c7b402205601ebd89be5c28d01 (patch)
tree0288070ac775c68b89c1b60adabd9c9bd60cc4c5 /vespajlib
parent0a72854f44982c7bb51ba60d752b6345a581690c (diff)
Use already optimize TensorAddress.partialCopy to reduce dimensions.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java11
1 files changed, 2 insertions, 9 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 0985e48c4e4..c394274032e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.Convert;
+import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
import java.util.Collections;
@@ -135,7 +136,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(argument.sizeAsInt());
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
- TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey());
+ TensorAddress reducedAddress = cell.getKey().partialCopy(indexesToKeep);
ValueAggregator aggr = aggregatingCells.computeIfAbsent(reducedAddress, (key) ->ValueAggregator.ofType(aggregator));
aggr.aggregate(cell.getValue());
}
@@ -158,14 +159,6 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return indexesToKeep;
}
- private static TensorAddress reduceDimensions(int[] indexesToKeep, TensorAddress address) {
- String[] reducedLabels = new String[indexesToKeep.length];
- int reducedLabelIndex = 0;
- for (int toKeep : indexesToKeep)
- reducedLabels[reducedLabelIndex++] = address.label(toKeep);
- return TensorAddress.of(reducedLabels);
- }
-
private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )