diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index a48ac19fbff..42c6fe2f4aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -89,8 +89,9 @@ public class Concat extends PrimitiveTensorFunction { public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - a = ensureIndexedDimension(dimension, a); - b = ensureIndexedDimension(dimension, b); + TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type()); + a = ensureIndexedDimension(dimension, a, combinedValueType); + b = ensureIndexedDimension(dimension, b, combinedValueType); IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; @@ -128,7 +129,7 @@ public class Concat extends PrimitiveTensorFunction { } } - private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor) { + private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) { Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName); if ( dimension.isPresent() ) { if ( ! dimension.get().isIndexed()) @@ -141,7 +142,7 @@ public class Concat extends PrimitiveTensorFunction { if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); - Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType()) + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) .indexed(dimensionName, 1) .build()) .cell(1,0) |