aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java9
1 files changed, 4 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index e71d1c717d3..59a452588ca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -68,14 +68,13 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type());
- a = ensureIndexedDimension(dimension, a, combinedValueType);
- b = ensureIndexedDimension(dimension, b, combinedValueType);
+ TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
+
+ a = ensureIndexedDimension(dimension, a, concatType.valueType());
+ b = ensureIndexedDimension(dimension, b, concatType.valueType());
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
-
- TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);