diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 42c6fe2f4aa..a31a7da67e5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -23,12 +23,12 @@ import java.util.stream.Collectors; * * @author bratseth */ -public class Concat extends PrimitiveTensorFunction { +public class Concat<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argumentA, argumentB; + private final TensorFunction<NAMETYPE> argumentA, argumentB; private final String dimension; - public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) { + public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) { Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); Objects.requireNonNull(dimension, "The dimension cannot be null"); @@ -38,18 +38,18 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 2) throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); - return new Concat(arguments.get(0), arguments.get(1), dimension); + return new Concat<>(arguments.get(0), arguments.get(1), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Concat(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); } @Override @@ -58,7 +58,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return type(argumentA.type(context), argumentB.type(context)); } @@ -86,7 +86,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type()); |