diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 54d7710c9dc..017dc3920e6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -61,8 +61,8 @@ public class Reduce extends PrimitiveTensorFunction { } public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) { - if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all - TensorType.Builder b = new TensorType.Builder(); + TensorType.Builder b = new TensorType.Builder(inputType.valueType()); + if (reduceDimensions.isEmpty()) return b.build(); // means reduce all for (TensorType.Dimension dimension : inputType.dimensions()) { if ( ! reduceDimensions.contains(dimension.name())) b.dimension(dimension); @@ -109,8 +109,8 @@ public class Reduce extends PrimitiveTensorFunction { } private static TensorType type(TensorType argumentType, List<String> dimensions) { - if (dimensions.isEmpty()) return TensorType.empty; // means reduce all - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(argumentType.valueType()); + if (dimensions.isEmpty()) return builder.build(); // means reduce all for (TensorType.Dimension dimension : argumentType.dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); |