diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-01-29 14:51:23 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-01-29 14:51:23 +0100 |
commit | 1b4fde01d98bf724a54b6c1cfe3ffa4b29aec90e (patch) | |
tree | 20a127542b004eceb94e4d1344b3446df8092bd2 /vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | |
parent | 28e3545728977a0be82159b8f278be8e772cb59b (diff) |
Propagate type information
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index de9f90a5804..591a6e4649e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -73,10 +73,10 @@ public class Reduce extends PrimitiveTensorFunction { public TensorFunction argument() { return argument; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size()); return new Reduce(arguments.get(0), aggregator, dimensions); @@ -100,6 +100,19 @@ public class Reduce extends PrimitiveTensorFunction { } @Override + public TensorType type(EvaluationContext context) { + return type(argument.type(context)); + } + + private TensorType type(TensorType argumentType) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension : argumentType.dimensions()) + if ( ! dimensions.contains(dimension.name())) // keep + builder.dimension(dimension); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) @@ -113,12 +126,7 @@ public class Reduce extends PrimitiveTensorFunction { else return reduceAllGeneral(argument); - // Reduce type - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension : argument.type().dimensions()) - if ( ! dimensions.contains(dimension.name())) // keep - builder.dimension(dimension); - TensorType reducedType = builder.build(); + TensorType reducedType = type(argument.type()); // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); |