diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-08 21:46:42 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-08 21:46:42 -0800 |
commit | e444e8e80296ed12fc8a530d0525ca6ca7566632 (patch) | |
tree | c23706b2cf884364b3ff2aa8d4c6f014396ba538 /vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | |
parent | 65feef43135153d5ec3b0dcc911d928cb13089d2 (diff) |
Optimize some indexed tensor operations
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 0a94a1e7b5e..3da236a2624 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MapTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -95,7 +96,10 @@ public class Reduce extends PrimitiveTensorFunction { dimensions + ": Not all those dimensions are present in this tensor"); if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) - return reduceAll(argument); + if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor) + return reduceIndexedVector((IndexedTensor)argument); + else + return reduceAllGeneral(argument); // Reduce type TensorType.Builder builder = new TensorType.Builder(); @@ -130,13 +134,20 @@ public class Reduce extends PrimitiveTensorFunction { return new TensorAddress(reducedLabels); } - private Tensor reduceAll(Tensor argument) { + private Tensor reduceAllGeneral(Tensor argument) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Double cellValue : argument.cells().values()) valueAggregator.aggregate(cellValue); - return new MapTensor(TensorType.empty, ImmutableMap.of(TensorAddress.empty, valueAggregator.aggregatedValue())); + return new IndexedTensor.Builder(TensorType.empty).set((valueAggregator.aggregatedValue())).build(); } - + + private Tensor reduceIndexedVector(IndexedTensor argument) { + ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); + for (int i = 0; i < argument.length(0); i++) + valueAggregator.aggregate(argument.get(i)); + return new IndexedTensor.Builder(TensorType.empty).set((valueAggregator.aggregatedValue())).build(); + } + private static abstract class ValueAggregator { private static ValueAggregator ofType(Aggregator aggregator) { |