diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-30 22:56:30 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-30 22:56:30 +0100 |
commit | e7dc5ca189d373c7b402205601ebd89be5c28d01 (patch) | |
tree | 0288070ac775c68b89c1b60adabd9c9bd60cc4c5 /vespajlib | |
parent | 0a72854f44982c7bb51ba60d752b6345a581690c (diff) |
Use already optimize TensorAddress.partialCopy to reduce dimensions.
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 11 |
1 files changed, 2 insertions, 9 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 0985e48c4e4..c394274032e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.impl.Convert; +import com.yahoo.tensor.impl.TensorAddressAny; import java.util.ArrayList; import java.util.Collections; @@ -135,7 +136,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(argument.sizeAsInt()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); - TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey()); + TensorAddress reducedAddress = cell.getKey().partialCopy(indexesToKeep); ValueAggregator aggr = aggregatingCells.computeIfAbsent(reducedAddress, (key) ->ValueAggregator.ofType(aggregator)); aggr.aggregate(cell.getValue()); } @@ -158,14 +159,6 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return indexesToKeep; } - private static TensorAddress reduceDimensions(int[] indexesToKeep, TensorAddress address) { - String[] reducedLabels = new String[indexesToKeep.length]; - int reducedLabelIndex = 0; - for (int toKeep : indexesToKeep) - reducedLabels[reducedLabelIndex++] = address.label(toKeep); - return TensorAddress.of(reducedLabels); - } - private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) |