aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-21 13:48:31 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-21 13:48:31 +0100
commit965a96d30aa606b70ce37767f7922cd8809b0ba3 (patch)
tree89d3b7c03633044917804b96b1468cffaf794212 /vespajlib
parent43c05215e666f47c15d9d73aadc80a9735b1b426 (diff)
Cache size of intersected sets, as they are recomputed every time.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java12
1 files changed, 7 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 1ded16636d3..7a336233de0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -114,7 +114,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
- long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
+ int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
@@ -170,7 +170,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Iterator<Tensor.Cell> superspace, long superspaceSize,
boolean reversedArgumentOrder, IndexedTensor.Builder builder,
DoubleBinaryOperator combinator) {
- long joinedLength = Math.min(subspaceSize, superspaceSize);
+ int joinedLength = (int)Math.min(subspaceSize, superspaceSize);
if (reversedArgumentOrder) {
for (int i = 0; i < joinedLength; i++) {
Tensor.Cell supercell = superspace.next();
@@ -252,6 +252,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
DoubleBinaryOperator combinator) {
Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
+ int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection
Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
@@ -263,7 +264,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
Tensor.Cell aCell = aSubspace.next();
- PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions);
+ PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize);
// for each matching combination of dimensions ony in b
for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
Tensor.Cell bCell = bSubspace.next();
@@ -275,8 +276,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
}
- private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
- PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
+ private static PartialAddress partialAddress(TensorType addressType, TensorAddress address,
+ Set<String> retainDimensions, int sharedDimensionSize) {
+ PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize);
for (int i = 0; i < addressType.dimensions().size(); i++)
if (retainDimensions.contains(addressType.dimensions().get(i).name()))
builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i));