summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java8
1 files changed, 4 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 54d7710c9dc..017dc3920e6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -61,8 +61,8 @@ public class Reduce extends PrimitiveTensorFunction {
}
public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
- if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all
- TensorType.Builder b = new TensorType.Builder();
+ TensorType.Builder b = new TensorType.Builder(inputType.valueType());
+ if (reduceDimensions.isEmpty()) return b.build(); // means reduce all
for (TensorType.Dimension dimension : inputType.dimensions()) {
if ( ! reduceDimensions.contains(dimension.name()))
b.dimension(dimension);
@@ -109,8 +109,8 @@ public class Reduce extends PrimitiveTensorFunction {
}
private static TensorType type(TensorType argumentType, List<String> dimensions) {
- if (dimensions.isEmpty()) return TensorType.empty; // means reduce all
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(argumentType.valueType());
+ if (dimensions.isEmpty()) return builder.build(); // means reduce all
for (TensorType.Dimension dimension : argumentType.dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
builder.dimension(dimension);