diff options
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 6 |
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 33 |
2 files changed, 27 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index d0d01d29b26..8a4179cdc1a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -62,7 +62,11 @@ public interface Tensor {
     /** Returns whether this have any cells */
     default boolean isEmpty() { return size() == 0; }
 
-    /** Returns the number of cells in this */
+    /**
+     * Returns the number of cells in this.
+     * TODO Figure how to best return an int instead of a long
+     * An int is large enough, and java is far better at int base loops than long
+     **/
     long size();
 
     /** Returns the value of a cell, or 0.0 if this cell does not exist */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 5327457a438..8cf88610599 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -128,31 +128,42 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
         TensorType reducedType = outputType(argument.type(), dimensions);
 
         // Reduce cells
-        Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
+        int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions);
+        // TODO cells.size() is most likely an overestimate, and might need a better heuristic
+        // But the upside is larger than the downside.
+        Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>((int)argument.size());
         for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
             Map.Entry<TensorAddress, Double> cell = i.next();
-            TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType, dimensions);
-            aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
-            aggregatingCells.get(reducedAddress).aggregate(cell.getValue());
+            TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey());
+            ValueAggregator aggr = aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
+            if (aggr == null)
+                aggr = aggregatingCells.get(reducedAddress);
+            aggr.aggregate(cell.getValue());
         }
         Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
         for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
             reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
 
         return reducedBuilder.build();
-
     }
-
-    private static TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType, List<String> dimensions) {
+    private static int[] createIndexesToKeep(TensorType argumentType, List<String> dimensions) {
         Set<Integer> indexesToRemove = new HashSet<>(dimensions.size()*2);
         for (String dimensionToRemove : dimensions)
             indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get());
+        int[] indexesToKeep = new int[argumentType.rank() - indexesToRemove.size()];
+        int toKeepIndex = 0;
+        for (int i = 0; i < argumentType.rank(); i++) {
+            if ( ! indexesToRemove.contains(i))
+                indexesToKeep[toKeepIndex++] = i;
+        }
+        return indexesToKeep;
+    }
 
-        String[] reducedLabels = new String[reducedType.dimensions().size()];
+    private static TensorAddress reduceDimensions(int[] indexesToKeep, TensorAddress address) {
+        String[] reducedLabels = new String[indexesToKeep.length];
         int reducedLabelIndex = 0;
-        for (int i = 0; i < address.size(); i++)
-            if ( ! indexesToRemove.contains(i))
-                reducedLabels[reducedLabelIndex++] = address.label(i);
+        for (int toKeep : indexesToKeep)
+            reducedLabels[reducedLabelIndex++] = address.label(toKeep);
         return TensorAddress.of(reducedLabels);
     }
 