aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-01-10 13:07:54 +0100
committerLester Solbakken <lesters@oath.com>2020-01-10 13:07:54 +0100
commitf9e6262cd5f8b919db33f8a43cad45c25546d2e4 (patch)
treefdbd5e4a7b918f7195f4d01f11e694009760741c /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
parent1bfeb920e039dd22f586c382c66fef90af6f4459 (diff)
Revert "Revert "Require equal sizes in join""
This reverts commit d78f8b089753025421524539e86ca96b7bf3369c.
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java11
1 files changed, 8 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 1e0eaa7fad3..5419d04a4fb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -48,7 +48,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
/** Returns the type resulting from applying Join to the two given types */
public static TensorType outputType(TensorType a, TensorType b) {
- return new TensorType.Builder(a, b).build();
+ try {
+ return new TensorType.Builder(false, a, b).build();
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Can not join " + a + " and " + b, e);
+ }
}
public DoubleBinaryOperator combinator() { return combinator; }
@@ -75,14 +80,14 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
@Override
public TensorType type(TypeContext<NAMETYPE> context) {
- return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
+ return outputType(argumentA.type(context), argumentB.type(context));
}
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
+ TensorType joinedType = outputType(a.type(), b.type());
return evaluate(a, b, joinedType, combinator);
}