aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-08 21:46:42 -0800
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-08 21:46:42 -0800
commite444e8e80296ed12fc8a530d0525ca6ca7566632 (patch)
treec23706b2cf884364b3ff2aa8d4c6f014396ba538 /vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
parent65feef43135153d5ec3b0dcc911d928cb13089d2 (diff)
Optimize some indexed tensor operations
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java19
1 files changed, 15 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 0a94a1e7b5e..3da236a2624 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -2,6 +2,7 @@ package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MapTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -95,7 +96,10 @@ public class Reduce extends PrimitiveTensorFunction {
dimensions + ": Not all those dimensions are present in this tensor");
if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size())
- return reduceAll(argument);
+ if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor)
+ return reduceIndexedVector((IndexedTensor)argument);
+ else
+ return reduceAllGeneral(argument);
// Reduce type
TensorType.Builder builder = new TensorType.Builder();
@@ -130,13 +134,20 @@ public class Reduce extends PrimitiveTensorFunction {
return new TensorAddress(reducedLabels);
}
- private Tensor reduceAll(Tensor argument) {
+ private Tensor reduceAllGeneral(Tensor argument) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Double cellValue : argument.cells().values())
valueAggregator.aggregate(cellValue);
- return new MapTensor(TensorType.empty, ImmutableMap.of(TensorAddress.empty, valueAggregator.aggregatedValue()));
+ return new IndexedTensor.Builder(TensorType.empty).set((valueAggregator.aggregatedValue())).build();
}
-
+
+ private Tensor reduceIndexedVector(IndexedTensor argument) {
+ ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
+ for (int i = 0; i < argument.length(0); i++)
+ valueAggregator.aggregate(argument.get(i));
+ return new IndexedTensor.Builder(TensorType.empty).set((valueAggregator.aggregatedValue())).build();
+ }
+
private static abstract class ValueAggregator {
private static ValueAggregator ofType(Aggregator aggregator) {