diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
3 files changed, 7 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index a48ac19fbff..42c6fe2f4aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -89,8 +89,9 @@ public class Concat extends PrimitiveTensorFunction { public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - a = ensureIndexedDimension(dimension, a); - b = ensureIndexedDimension(dimension, b); + TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type()); + a = ensureIndexedDimension(dimension, a, combinedValueType); + b = ensureIndexedDimension(dimension, b, combinedValueType); IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; @@ -128,7 +129,7 @@ public class Concat extends PrimitiveTensorFunction { } } - private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor) { + private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) { Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName); if ( dimension.isPresent() ) { if ( ! dimension.get().isIndexed()) @@ -141,7 +142,7 @@ public class Concat extends PrimitiveTensorFunction { if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); - Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType()) + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) .indexed(dimensionName, 1) .build()) .cell(1,0) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 062e0d92e80..2939b964f04 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -390,8 +390,7 @@ public class Join extends PrimitiveTensorFunction { private static TensorType commonDimensions(Tensor a, Tensor b) { TensorType aType = a.type(); TensorType bType = b.type(); - TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(), - bType.valueType())); + TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.combinedValueType(aType, bType)); for (int i = 0; i < aType.dimensions().size(); ++i) { TensorType.Dimension aDim = aType.dimensions().get(i); for (int j = 0; j < bType.dimensions().size(); ++j) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index db950e6c8b9..1134e8177ad 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -268,8 +268,7 @@ public class ReduceJoin extends CompositeTensorFunction { } private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) { - TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(), - b.type().valueType())); + TensorType.Builder builder = new TensorType.Builder(TensorType.combinedValueType(a.type(), b.type())); for (TensorType.Dimension aDim : a.type().dimensions()) { for (TensorType.Dimension bDim : b.type().dimensions()) { if (aDim.name().equals(bDim.name())) { |