aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-11-27 15:58:06 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-11-27 15:58:06 +0200
commit77bb8f5117b7a0f78b2dc99a3937430339e4291d (patch)
tree9037b54f17e3175a8d11e1b43b55b71887f867a4 /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
parentf4203c3cc571722f08ee65047437c1290ed63f69 (diff)
Support index generating expressions in tensor value functions
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java20
1 files changed, 10 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 2939b964f04..29239957260 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -31,12 +31,12 @@ import java.util.function.DoubleBinaryOperator;
*
* @author bratseth
*/
-public class Join extends PrimitiveTensorFunction {
+public class Join<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argumentA, argumentB;
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
private final DoubleBinaryOperator combinator;
- public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
+ public Join(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator) {
Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
Objects.requireNonNull(combinator, "The combinator function cannot be null");
@@ -53,18 +53,18 @@ public class Join extends PrimitiveTensorFunction {
public DoubleBinaryOperator combinator() { return combinator; }
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
- return new Join(arguments.get(0), arguments.get(1), combinator);
+ return new Join<>(arguments.get(0), arguments.get(1), combinator);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
}
@Override
@@ -73,12 +73,12 @@ public class Join extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();