diff options
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index e71d1c717d3..59a452588ca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -68,14 +68,13 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type()); - a = ensureIndexedDimension(dimension, a, combinedValueType); - b = ensureIndexedDimension(dimension, b, combinedValueType); + TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); + + a = ensureIndexedDimension(dimension, a, concatType.valueType()); + b = ensureIndexedDimension(dimension, b, concatType.valueType()); IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; - - TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); |