diff options
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 50 |
1 files changed, 27 insertions, 23 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 17cacb8f009..947fd6e0012 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -16,12 +16,10 @@ import com.yahoo.tensor.impl.Convert; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; /** * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions @@ -128,10 +126,14 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET else return reduceAllGeneral(argument, aggregator); } - if (argument instanceof IndexedTensor indexedTensor) { - return reduceIndexedTensor(indexedTensor, dimensions, aggregator); + + TensorType reducedType = outputType(argument.type(), dimensions); + int[] indexesToReduce = createIndexesToReduce(argument.type(), dimensions); + int[] indexesToKeep = createIndexesToKeep(argument.type(), indexesToReduce); + if (argument instanceof IndexedTensor indexedTensor && reducedType.hasOnlyIndexedBoundDimensions()) { + return reduceIndexedTensor(indexedTensor, reducedType, indexesToKeep, indexesToReduce, aggregator); } else { - return reduceGeneral(argument, dimensions, aggregator); + return reduceGeneral(argument, reducedType, indexesToKeep, aggregator); } } @@ -173,24 +175,15 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } - private static Tensor reduceIndexedTensor(IndexedTensor argument, List<String> dimensions, Aggregator aggregator) { - TensorType reducedType = outputType(argument.type(), dimensions); + private static Tensor reduceIndexedTensor(IndexedTensor argument, TensorType reducedType, int[] indexesToKeep, int[] indexesToReduce, Aggregator aggregator) { + var reducedBuilder = IndexedTensor.Builder.of(reducedType); DirectIndexedAddress reducedAddress = DirectIndexedAddress.of(DimensionSizes.of(reducedType)); - int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions); - int[] indexesToReduce = new int[dimensions.size()]; - for (int i = 0; i < dimensions.size(); i++) { - indexesToReduce[i] = argument.type().indexOfDimension(dimensions.get(i)).get(); - } reduce(reducedBuilder, reducedAddress, argument, aggregator, argument.directAddress(), indexesToKeep, 0, indexesToReduce); return reducedBuilder.build(); } - private static Tensor reduceGeneral(Tensor argument, List<String> dimensions, Aggregator aggregator) { - TensorType reducedType = outputType(argument.type(), dimensions); - - // Reduce cells - int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions); + private static Tensor reduceGeneral(Tensor argument, TensorType reducedType, int[] indexesToKeep, Aggregator aggregator) { // TODO cells.size() is most likely an overestimate, and might need a better heuristic // But the upside is larger than the downside. Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(argument.sizeAsInt()); @@ -206,18 +199,29 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return reducedBuilder.build(); } - private static int[] createIndexesToKeep(TensorType argumentType, List<String> dimensions) { - Set<Integer> indexesToRemove = new HashSet<>(dimensions.size()*2); - for (String dimensionToRemove : dimensions) - indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get()); - int[] indexesToKeep = new int[argumentType.rank() - indexesToRemove.size()]; + + private static int[] createIndexesToReduce(TensorType tensorType, List<String> dimensions) { + int[] indexesToReduce = new int[dimensions.size()]; + for (int i = 0; i < dimensions.size(); i++) { + indexesToReduce[i] = tensorType.indexOfDimension(dimensions.get(i)).get(); + } + return indexesToReduce; + } + private static int[] createIndexesToKeep(TensorType argumentType, int[] indexesToReduce) { + int[] indexesToKeep = new int[argumentType.rank() - indexesToReduce.length]; int toKeepIndex = 0; for (int i = 0; i < argumentType.rank(); i++) { - if ( ! indexesToRemove.contains(i)) + if ( ! contains(indexesToReduce, i)) indexesToKeep[toKeepIndex++] = i; } return indexesToKeep; } + private static boolean contains(int[] list, int key) { + for (int candidate : list) { + if (candidate == key) return true; + } + return false; + } private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); |