diff options
author | Jon Bratseth <bratseth@gmail.com> | 2020-01-13 15:48:30 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-13 15:48:30 +0100 |
commit | 6ba68c27681b36ef4c8fd1b3f5b7b03ec8459fc3 (patch) | |
tree | cace1e1df8b5c723d1c74aa99b3b3db01d247dab /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | |
parent | 1f2a575ec6a02c15275b89adb3e610e20c776e8f (diff) |
Revert "Revert "Revert "Revert "Require equal sizes in join""""
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 1e0eaa7fad3..5419d04a4fb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -48,7 +48,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP /** Returns the type resulting from applying Join to the two given types */ public static TensorType outputType(TensorType a, TensorType b) { - return new TensorType.Builder(a, b).build(); + try { + return new TensorType.Builder(false, a, b).build(); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Can not join " + a + " and " + b, e); + } } public DoubleBinaryOperator combinator() { return combinator; } @@ -75,14 +80,14 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP @Override public TensorType type(TypeContext<NAMETYPE> context) { - return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); + return outputType(argumentA.type(context), argumentB.type(context)); } @Override public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); + TensorType joinedType = outputType(a.type(), b.type()); return evaluate(a, b, joinedType, combinator); } |