summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java2
5 files changed, 15 insertions, 11 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 91ab4f9d046..a0a257bb909 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -141,7 +141,11 @@ public class Concat extends PrimitiveTensorFunction {
if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
throw new IllegalArgumentException("Concat requires an indexed tensor, " +
"but got a tensor with type " + tensor.type());
- Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType())
+ .indexed(dimensionName, 1)
+ .build())
+ .cell(1,0)
+ .build();
return tensor.multiply(unitTensor);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 62ee471fcf4..062e0d92e80 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -386,13 +386,12 @@ public class Join extends PrimitiveTensorFunction {
return true;
}
- /**
- * Returns common dimension of a and b as a new tensor type
- */
+ /** Returns common dimension of a and b as a new tensor type */
private static TensorType commonDimensions(Tensor a, Tensor b) {
- TensorType.Builder typeBuilder = new TensorType.Builder();
TensorType aType = a.type();
TensorType bType = b.type();
+ TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(),
+ bType.valueType()));
for (int i = 0; i < aType.dimensions().size(); ++i) {
TensorType.Dimension aDim = aType.dimensions().get(i);
for (int j = 0; j < bType.dimensions().size(); ++j) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 54d7710c9dc..017dc3920e6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -61,8 +61,8 @@ public class Reduce extends PrimitiveTensorFunction {
}
public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
- if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all
- TensorType.Builder b = new TensorType.Builder();
+ TensorType.Builder b = new TensorType.Builder(inputType.valueType());
+ if (reduceDimensions.isEmpty()) return b.build(); // means reduce all
for (TensorType.Dimension dimension : inputType.dimensions()) {
if ( ! reduceDimensions.contains(dimension.name()))
b.dimension(dimension);
@@ -109,8 +109,8 @@ public class Reduce extends PrimitiveTensorFunction {
}
private static TensorType type(TensorType argumentType, List<String> dimensions) {
- if (dimensions.isEmpty()) return TensorType.empty; // means reduce all
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(argumentType.valueType());
+ if (dimensions.isEmpty()) return builder.build(); // means reduce all
for (TensorType.Dimension dimension : argumentType.dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
builder.dimension(dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index b268e33b418..db950e6c8b9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -268,7 +268,8 @@ public class ReduceJoin extends CompositeTensorFunction {
}
private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(),
+ b.type().valueType()));
for (TensorType.Dimension aDim : a.type().dimensions()) {
for (TensorType.Dimension bDim : b.type().dimensions()) {
if (aDim.name().equals(bDim.name())) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index e18af235d59..5694684956e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -75,7 +75,7 @@ public class Rename extends PrimitiveTensorFunction {
}
private TensorType type(TensorType type) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(type.valueType());
for (TensorType.Dimension dimension : type.dimensions())
builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
return builder.build();