diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 017dc3920e6..1d24333623b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -24,21 +24,21 @@ import java.util.Set; * * @author bratseth */ -public class Reduce extends PrimitiveTensorFunction { +public class Reduce<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { public enum Aggregator { avg, count, prod, sum, max, min; } - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final List<String> dimensions; private final Aggregator aggregator; /** Creates a reduce function reducing all dimensions */ - public Reduce(TensorFunction argument, Aggregator aggregator) { + public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator) { this(argument, aggregator, Collections.emptyList()); } /** Creates a reduce function reducing a single dimension */ - public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) { + public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, String dimension) { this(argument, aggregator, Collections.singletonList(dimension)); } @@ -51,7 +51,7 @@ public class Reduce extends PrimitiveTensorFunction { * producing a dimensionless tensor (a scalar). * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor */ - public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) { + public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, List<String> dimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(aggregator, "The aggregator cannot be null"); Objects.requireNonNull(dimensions, "The dimensions cannot be null"); @@ -70,25 +70,25 @@ public class Reduce extends PrimitiveTensorFunction { return b.build(); } - public TensorFunction argument() { return argument; } + public TensorFunction<NAMETYPE> argument() { return argument; } Aggregator aggregator() { return aggregator; } List<String> dimensions() { return dimensions; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size()); - return new Reduce(arguments.get(0), aggregator, dimensions); + return new Reduce<>(arguments.get(0), aggregator, dimensions); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Reduce(argument.toPrimitive(), aggregator, dimensions); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Reduce<>(argument.toPrimitive(), aggregator, dimensions); } @Override @@ -104,7 +104,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context), dimensions); } @@ -118,7 +118,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return evaluate(this.argument.evaluate(context), dimensions, aggregator); } |