diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-27 15:58:06 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-27 15:58:06 +0200 |
commit | 77bb8f5117b7a0f78b2dc99a3937430339e4291d (patch) | |
tree | 9037b54f17e3175a8d11e1b43b55b71887f867a4 /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | |
parent | f4203c3cc571722f08ee65047437c1290ed63f69 (diff) |
Support index generating expressions in tensor value functions
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 2939b964f04..29239957260 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -31,12 +31,12 @@ import java.util.function.DoubleBinaryOperator; * * @author bratseth */ -public class Join extends PrimitiveTensorFunction { +public class Join<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argumentA, argumentB; + private final TensorFunction<NAMETYPE> argumentA, argumentB; private final DoubleBinaryOperator combinator; - public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) { + public Join(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator) { Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); Objects.requireNonNull(combinator, "The combinator function cannot be null"); @@ -53,18 +53,18 @@ public class Join extends PrimitiveTensorFunction { public DoubleBinaryOperator combinator() { return combinator; } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size()); - return new Join(arguments.get(0), arguments.get(1), combinator); + return new Join<>(arguments.get(0), arguments.get(1), combinator); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); } @Override @@ -73,12 +73,12 @@ public class Join extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); |