diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
5 files changed, 15 insertions, 11 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 91ab4f9d046..a0a257bb909 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -141,7 +141,11 @@ public class Concat extends PrimitiveTensorFunction { if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); - Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType()) + .indexed(dimensionName, 1) + .build()) + .cell(1,0) + .build(); return tensor.multiply(unitTensor); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 62ee471fcf4..062e0d92e80 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -386,13 +386,12 @@ public class Join extends PrimitiveTensorFunction { return true; } - /** - * Returns common dimension of a and b as a new tensor type - */ + /** Returns common dimension of a and b as a new tensor type */ private static TensorType commonDimensions(Tensor a, Tensor b) { - TensorType.Builder typeBuilder = new TensorType.Builder(); TensorType aType = a.type(); TensorType bType = b.type(); + TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(), + bType.valueType())); for (int i = 0; i < aType.dimensions().size(); ++i) { TensorType.Dimension aDim = aType.dimensions().get(i); for (int j = 0; j < bType.dimensions().size(); ++j) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 54d7710c9dc..017dc3920e6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -61,8 +61,8 @@ public class Reduce extends PrimitiveTensorFunction { } public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) { - if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all - TensorType.Builder b = new TensorType.Builder(); + TensorType.Builder b = new TensorType.Builder(inputType.valueType()); + if (reduceDimensions.isEmpty()) return b.build(); // means reduce all for (TensorType.Dimension dimension : inputType.dimensions()) { if ( ! reduceDimensions.contains(dimension.name())) b.dimension(dimension); @@ -109,8 +109,8 @@ public class Reduce extends PrimitiveTensorFunction { } private static TensorType type(TensorType argumentType, List<String> dimensions) { - if (dimensions.isEmpty()) return TensorType.empty; // means reduce all - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(argumentType.valueType()); + if (dimensions.isEmpty()) return builder.build(); // means reduce all for (TensorType.Dimension dimension : argumentType.dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index b268e33b418..db950e6c8b9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -268,7 +268,8 @@ public class ReduceJoin extends CompositeTensorFunction { } private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(), + b.type().valueType())); for (TensorType.Dimension aDim : a.type().dimensions()) { for (TensorType.Dimension bDim : b.type().dimensions()) { if (aDim.name().equals(bDim.name())) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index e18af235d59..5694684956e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -75,7 +75,7 @@ public class Rename extends PrimitiveTensorFunction { } private TensorType type(TensorType type) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(type.valueType()); for (TensorType.Dimension dimension : type.dimensions()) builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); return builder.build(); |