summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-31 16:41:51 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-31 16:41:51 +0100
commit0f57ee63dd82bffd98b6b7ad9cdbd20cacaaf72b (patch)
tree62381352a14828fc5c0c631a25fadbda2ef27188 /vespajlib
parent5746b9a8f5cf53fa21568dea9f0e798ae17f0bab (diff)
Limit optimization to bound types, which is the prevalent kind.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java50
1 files changed, 27 insertions, 23 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 17cacb8f009..947fd6e0012 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -16,12 +16,10 @@ import com.yahoo.tensor.impl.Convert;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
-import java.util.Set;
/**
* The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
@@ -128,10 +126,14 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
else
return reduceAllGeneral(argument, aggregator);
}
- if (argument instanceof IndexedTensor indexedTensor) {
- return reduceIndexedTensor(indexedTensor, dimensions, aggregator);
+
+ TensorType reducedType = outputType(argument.type(), dimensions);
+ int[] indexesToReduce = createIndexesToReduce(argument.type(), dimensions);
+ int[] indexesToKeep = createIndexesToKeep(argument.type(), indexesToReduce);
+ if (argument instanceof IndexedTensor indexedTensor && reducedType.hasOnlyIndexedBoundDimensions()) {
+ return reduceIndexedTensor(indexedTensor, reducedType, indexesToKeep, indexesToReduce, aggregator);
} else {
- return reduceGeneral(argument, dimensions, aggregator);
+ return reduceGeneral(argument, reducedType, indexesToKeep, aggregator);
}
}
@@ -173,24 +175,15 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
- private static Tensor reduceIndexedTensor(IndexedTensor argument, List<String> dimensions, Aggregator aggregator) {
- TensorType reducedType = outputType(argument.type(), dimensions);
+ private static Tensor reduceIndexedTensor(IndexedTensor argument, TensorType reducedType, int[] indexesToKeep, int[] indexesToReduce, Aggregator aggregator) {
+
var reducedBuilder = IndexedTensor.Builder.of(reducedType);
DirectIndexedAddress reducedAddress = DirectIndexedAddress.of(DimensionSizes.of(reducedType));
- int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions);
- int[] indexesToReduce = new int[dimensions.size()];
- for (int i = 0; i < dimensions.size(); i++) {
- indexesToReduce[i] = argument.type().indexOfDimension(dimensions.get(i)).get();
- }
reduce(reducedBuilder, reducedAddress, argument, aggregator, argument.directAddress(), indexesToKeep, 0, indexesToReduce);
return reducedBuilder.build();
}
- private static Tensor reduceGeneral(Tensor argument, List<String> dimensions, Aggregator aggregator) {
- TensorType reducedType = outputType(argument.type(), dimensions);
-
- // Reduce cells
- int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions);
+ private static Tensor reduceGeneral(Tensor argument, TensorType reducedType, int[] indexesToKeep, Aggregator aggregator) {
// TODO cells.size() is most likely an overestimate, and might need a better heuristic
// But the upside is larger than the downside.
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(argument.sizeAsInt());
@@ -206,18 +199,29 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return reducedBuilder.build();
}
- private static int[] createIndexesToKeep(TensorType argumentType, List<String> dimensions) {
- Set<Integer> indexesToRemove = new HashSet<>(dimensions.size()*2);
- for (String dimensionToRemove : dimensions)
- indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get());
- int[] indexesToKeep = new int[argumentType.rank() - indexesToRemove.size()];
+
+ private static int[] createIndexesToReduce(TensorType tensorType, List<String> dimensions) {
+ int[] indexesToReduce = new int[dimensions.size()];
+ for (int i = 0; i < dimensions.size(); i++) {
+ indexesToReduce[i] = tensorType.indexOfDimension(dimensions.get(i)).get();
+ }
+ return indexesToReduce;
+ }
+ private static int[] createIndexesToKeep(TensorType argumentType, int[] indexesToReduce) {
+ int[] indexesToKeep = new int[argumentType.rank() - indexesToReduce.length];
int toKeepIndex = 0;
for (int i = 0; i < argumentType.rank(); i++) {
- if ( ! indexesToRemove.contains(i))
+ if ( ! contains(indexesToReduce, i))
indexesToKeep[toKeepIndex++] = i;
}
return indexesToKeep;
}
+ private static boolean contains(int[] list, int key) {
+ for (int candidate : list) {
+ if (candidate == key) return true;
+ }
+ return false;
+ }
private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);