summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java9
1 files changed, 4 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index fa9dcda9179..e0ac549651c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -335,12 +335,11 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
int[] bIndexesInJoined = mapIndexes(b.type(), joinedType);
// Iterate once through the smaller tensor and construct a hash map for common dimensions
- Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>();
+ Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt());
for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell aCell = cellIterator.next();
TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon);
- aCellsByCommonAddress.putIfAbsent(partialCommonAddress, new ArrayList<>());
- aCellsByCommonAddress.get(partialCommonAddress).add(aCell);
+ aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell);
}
// Iterate once through the larger tensor and use the hash map to find joinable cells
@@ -363,7 +362,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
/**
- * Returns the an array having one entry in order for each dimension of fromType
+ * Returns an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
* That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
@@ -372,7 +371,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
static int[] mapIndexes(TensorType fromType, TensorType toType) {
int[] toIndexes = new int[fromType.dimensions().size()];
for (int i = 0; i < fromType.dimensions().size(); i++)
- toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
+ toIndexes[i] = toType.indexOfDimensionAsInt(fromType.dimensions().get(i).name());
return toIndexes;
}