summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-01-29 14:51:23 +0100
committerJon Bratseth <bratseth@oath.com>2018-01-29 14:51:23 +0100
commit1b4fde01d98bf724a54b6c1cfe3ffa4b29aec90e (patch)
tree20a127542b004eceb94e4d1344b3446df8092bd2 /vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
parent28e3545728977a0be82159b8f278be8e772cb59b (diff)
Propagate type information
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java24
1 files changed, 16 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index de9f90a5804..591a6e4649e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -73,10 +73,10 @@ public class Reduce extends PrimitiveTensorFunction {
public TensorFunction argument() { return argument; }
@Override
- public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
return new Reduce(arguments.get(0), aggregator, dimensions);
@@ -100,6 +100,19 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
+ public TensorType type(EvaluationContext context) {
+ return type(argument.type(context));
+ }
+
+ private TensorType type(TensorType argumentType) {
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : argumentType.dimensions())
+ if ( ! dimensions.contains(dimension.name())) // keep
+ builder.dimension(dimension);
+ return builder.build();
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
@@ -113,12 +126,7 @@ public class Reduce extends PrimitiveTensorFunction {
else
return reduceAllGeneral(argument);
- // Reduce type
- TensorType.Builder builder = new TensorType.Builder();
- for (TensorType.Dimension dimension : argument.type().dimensions())
- if ( ! dimensions.contains(dimension.name())) // keep
- builder.dimension(dimension);
- TensorType reducedType = builder.build();
+ TensorType reducedType = type(argument.type());
// Reduce cells
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();