aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java20
1 files changed, 10 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 42c6fe2f4aa..a31a7da67e5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -23,12 +23,12 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-public class Concat extends PrimitiveTensorFunction {
+public class Concat<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argumentA, argumentB;
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
private final String dimension;
- public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) {
+ public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) {
Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
Objects.requireNonNull(dimension, "The dimension cannot be null");
@@ -38,18 +38,18 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 2)
throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
- return new Concat(arguments.get(0), arguments.get(1), dimension);
+ return new Concat<>(arguments.get(0), arguments.get(1), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Concat(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
}
@Override
@@ -58,7 +58,7 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return type(argumentA.type(context), argumentB.type(context));
}
@@ -86,7 +86,7 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type());