aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-22 10:41:20 +0200
committerLester Solbakken <lesters@oath.com>2021-04-22 10:41:20 +0200
commit6b1286359bcdaed6c870f109450cb9934c110144 (patch)
tree686c15cc8a2625d4b7ef56e52adc220fb7738b6c /vespajlib
parent29fd55d44ac78032406a2c7b9b46c7373189f8dd (diff)
Concat: find value type from TypeResolver
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java9
1 files changed, 4 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index e71d1c717d3..59a452588ca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -68,14 +68,13 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type());
- a = ensureIndexedDimension(dimension, a, combinedValueType);
- b = ensureIndexedDimension(dimension, b, combinedValueType);
+ TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
+
+ a = ensureIndexedDimension(dimension, a, concatType.valueType());
+ b = ensureIndexedDimension(dimension, b, concatType.valueType());
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
-
- TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);