From 6b1286359bcdaed6c870f109450cb9934c110144 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 22 Apr 2021 10:41:20 +0200 Subject: Concat: find value type from TypeResolver --- vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'vespajlib/src') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index e71d1c717d3..59a452588ca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -68,14 +68,13 @@ public class Concat extends PrimitiveTensorFunction context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type()); - a = ensureIndexedDimension(dimension, a, combinedValueType); - b = ensureIndexedDimension(dimension, b, combinedValueType); + TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); + + a = ensureIndexedDimension(dimension, a, concatType.valueType()); + b = ensureIndexedDimension(dimension, b, concatType.valueType()); IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; - - TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); -- cgit v1.2.3