summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java9
1 files changed, 5 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index a48ac19fbff..42c6fe2f4aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -89,8 +89,9 @@ public class Concat extends PrimitiveTensorFunction {
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- a = ensureIndexedDimension(dimension, a);
- b = ensureIndexedDimension(dimension, b);
+ TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type());
+ a = ensureIndexedDimension(dimension, a, combinedValueType);
+ b = ensureIndexedDimension(dimension, b, combinedValueType);
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
@@ -128,7 +129,7 @@ public class Concat extends PrimitiveTensorFunction {
}
}
- private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor) {
+ private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) {
Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName);
if ( dimension.isPresent() ) {
if ( ! dimension.get().isIndexed())
@@ -141,7 +142,7 @@ public class Concat extends PrimitiveTensorFunction {
if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
throw new IllegalArgumentException("Concat requires an indexed tensor, " +
"but got a tensor with type " + tensor.type());
- Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType())
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
.indexed(dimensionName, 1)
.build())
.cell(1,0)