summary | refs | log | tree | commit | diff | stats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
author: Lester Solbakken <lesters@yahoo-inc.com> 2017-07-06 09:59:04 +0200
committer: Lester Solbakken <lesters@yahoo-inc.com> 2017-07-06 09:59:04 +0200
commit: b2ea86592479a78a35e58151d8f9f51a5baa300c (patch)
tree: 8833ddbb94d3ea5ceb593072594ec2a872384e0d /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
parent: 27aff3764de178e952634466733c5b4ac6b252c1 (diff)
Small fixes to mapped hash join
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 59
1 files changed, 37 insertions, 22 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index a7aa1d66a3d..95e3c444c1b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -12,8 +12,16 @@ import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
@@ -284,35 +292,44 @@ public class Join extends PrimitiveTensorFunction {
private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) {
TensorType commonDimensionType = commonDimensions(a, b);
if (commonDimensionType.dimensions().isEmpty()) {
- return mappedGeneralJoin(a, b, joinedType); // If no common dimensions, fall back to mappedGeneralJoin
+ return mappedGeneralJoin(a, b, joinedType); // fallback
}
- Tensor smallerTensor = a.size() > b.size() ? b : a;
- Tensor largerTensor = a.size() > b.size() ? a : b;
- a = smallerTensor;
- b = largerTensor;
+ boolean switchTensors = a.size() > b.size();
+ if (switchTensors) {
+ Tensor temp = a;
+ a = b;
+ b = temp;
+ }
+
+ // Map dimension indexes to common and joined type
+ int[] aIndexesInCommon = mapIndexes(commonDimensionType, a.type());
+ int[] bIndexesInCommon = mapIndexes(commonDimensionType, b.type());
+ int[] aIndexesInJoined = mapIndexes(a.type(), joinedType);
+ int[] bIndexesInJoined = mapIndexes(b.type(), joinedType);
// Iterate once through the smaller tensor and construct a hash map for common dimensions
Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>();
for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
- Tensor.Cell cell = cellIterator.next();
- TensorAddress partialCommonAddress = partialCommonAddress(cell, a.type(), commonDimensionType);
+ Tensor.Cell aCell = cellIterator.next();
+ TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon);
aCellsByCommonAddress.putIfAbsent(partialCommonAddress, new ArrayList<>());
- aCellsByCommonAddress.get(partialCommonAddress).add(cell);
+ aCellsByCommonAddress.get(partialCommonAddress).add(aCell);
}
- // Iterate once through the larger tensor and use hash map to find joinable cells
- int[] aToIndexes = mapIndexes(a.type(), joinedType);
- int[] bToIndexes = mapIndexes(b.type(), joinedType);
+ // Iterate once through the larger tensor and use the hash map to find joinable cells
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) {
- Tensor.Cell cell = cellIterator.next();
- TensorAddress partialCommonAddress = partialCommonAddress(cell, b.type(), commonDimensionType);
+ Tensor.Cell bCell = cellIterator.next();
+ TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon);
for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, Collections.emptyList())) {
- TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes,
- cell.getKey(), bToIndexes, joinedType);
+ TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined,
+ bCell.getKey(), bIndexesInJoined, joinedType);
if (combinedAddress == null) continue; // not combinable
- builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), cell.getValue()));
+ double combinedValue = switchTensors ?
+ combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) :
+ combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
+ builder.cell(combinedAddress, combinedValue);
}
}
@@ -378,13 +395,11 @@ public class Join extends PrimitiveTensorFunction {
return typeBuilder.build();
}
- private TensorAddress partialCommonAddress(Tensor.Cell cell, TensorType type, TensorType commonDimensions) {
+ private TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
TensorAddress address = cell.getKey();
- String[] labels = new String[commonDimensions.dimensions().size()];
+ String[] labels = new String[indexMap.length];
for (int i = 0; i < labels.length; ++i) {
- String name = commonDimensions.dimensions().get(i).name();
- int index = type.indexOfDimension(name).orElseThrow(RuntimeException::new); // invariant, must exist
- labels[i] = address.label(index);
+ labels[i] = address.label(indexMap[i]);
}
return TensorAddress.of(labels);