diff options
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 11 |
1 files changed, 2 insertions, 9 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 0985e48c4e4..c394274032e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.impl.Convert; +import com.yahoo.tensor.impl.TensorAddressAny; import java.util.ArrayList; import java.util.Collections; @@ -135,7 +136,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(argument.sizeAsInt()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); - TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey()); + TensorAddress reducedAddress = cell.getKey().partialCopy(indexesToKeep); ValueAggregator aggr = aggregatingCells.computeIfAbsent(reducedAddress, (key) ->ValueAggregator.ofType(aggregator)); aggr.aggregate(cell.getValue()); } @@ -158,14 +159,6 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return indexesToKeep; } - private static TensorAddress reduceDimensions(int[] indexesToKeep, TensorAddress address) { - String[] reducedLabels = new String[indexesToKeep.length]; - int reducedLabelIndex = 0; - for (int toKeep : indexesToKeep) - reducedLabels[reducedLabelIndex++] = address.label(toKeep); - return TensorAddress.of(reducedLabels); - } - private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) |