aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
author Lester Solbakken <lesters@yahoo-inc.com> 2017-07-05 19:36:16 +0200
committer Lester Solbakken <lesters@yahoo-inc.com> 2017-07-05 19:36:16 +0200
commit 27aff3764de178e952634466733c5b4ac6b252c1 (patch)
tree 894beb2ae72ce2300c7785f49b92f42e34bd3c26 /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
parent 61c0690fc5538257c6729f98e3695fa47a586437 (diff)
Optimize general mapped tensor join
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java 82
1 files changed, 74 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index afb2bb6fb8c..a7aa1d66a3d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -12,13 +12,8 @@ import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.List;
+import java.util.*;
import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
@@ -210,7 +205,7 @@ public class Join extends PrimitiveTensorFunction {
if (a instanceof IndexedTensor && b instanceof IndexedTensor)
return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType);
else
- return mappedGeneralJoin(a, b, joinedType);
+ return mappedHashJoin(a, b, joinedType);
}
private Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
@@ -286,6 +281,45 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
+    /**
+     * Joins two tensors which are not both indexed by hashing the cells of the smaller tensor on
+     * their labels in the dimensions shared by both tensors, and then probing that hash once for
+     * each cell of the larger tensor. Falls back to mappedGeneralJoin when the tensors have no
+     * dimensions in common.
+     *
+     * The combinator is always applied as (value from a, value from b), regardless of which of
+     * the two tensors ends up being hashed, so non-commutative combinators give the same result
+     * as they would with mappedGeneralJoin.
+     *
+     * @param a the left operand of the join
+     * @param b the right operand of the join
+     * @param joinedType the type of the tensor to produce
+     * @return the joined tensor
+     */
+    private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) {
+        TensorType commonDimensionType = commonDimensions(a, b);
+        if (commonDimensionType.dimensions().isEmpty()) {
+            return mappedGeneralJoin(a, b, joinedType); // If no common dimensions, fall back to mappedGeneralJoin
+        }
+
+        // Hash the smaller tensor and probe with the larger one. Remember whether this swaps the
+        // operands, so the combinator can still be given its arguments in the caller's order.
+        boolean swapped = a.size() > b.size();
+        Tensor smallerTensor = swapped ? b : a;
+        Tensor largerTensor = swapped ? a : b;
+
+        // Iterate once through the smaller tensor and construct a hash map for common dimensions
+        Map<TensorAddress, List<Tensor.Cell>> smallerCellsByCommonAddress = new HashMap<>();
+        for (Iterator<Tensor.Cell> cellIterator = smallerTensor.cellIterator(); cellIterator.hasNext(); ) {
+            Tensor.Cell smallerCell = cellIterator.next();
+            TensorAddress partialCommonAddress = partialCommonAddress(smallerCell, smallerTensor.type(), commonDimensionType);
+            smallerCellsByCommonAddress.computeIfAbsent(partialCommonAddress, key -> new ArrayList<>()).add(smallerCell);
+        }
+
+        // Iterate once through the larger tensor and use hash map to find joinable cells
+        int[] smallerToIndexes = mapIndexes(smallerTensor.type(), joinedType);
+        int[] largerToIndexes = mapIndexes(largerTensor.type(), joinedType);
+        Tensor.Builder builder = Tensor.Builder.of(joinedType);
+        for (Iterator<Tensor.Cell> cellIterator = largerTensor.cellIterator(); cellIterator.hasNext(); ) {
+            Tensor.Cell largerCell = cellIterator.next();
+            TensorAddress partialCommonAddress = partialCommonAddress(largerCell, largerTensor.type(), commonDimensionType);
+            for (Tensor.Cell smallerCell : smallerCellsByCommonAddress.getOrDefault(partialCommonAddress, Collections.emptyList())) {
+                TensorAddress combinedAddress = joinAddresses(smallerCell.getKey(), smallerToIndexes,
+                                                              largerCell.getKey(), largerToIndexes, joinedType);
+                if (combinedAddress == null) continue; // not combinable
+                // Undo the swap for the values: the first argument must come from a, the second from b
+                double combinedValue = swapped
+                        ? combinator.applyAsDouble(largerCell.getValue(), smallerCell.getValue())
+                        : combinator.applyAsDouble(smallerCell.getValue(), largerCell.getValue());
+                builder.cell(combinedAddress, combinedValue);
+            }
+        }
+
+        return builder.build();
+    }
+
+
/**
* Returns an array having one entry, in order, for each dimension of fromType,
* containing the index at which toType contains the same dimension name.
@@ -323,5 +357,37 @@ public class Join extends PrimitiveTensorFunction {
}
return true;
}
-
+
+
+    /**
+     * Builds a tensor type holding the dimensions which the types of a and b have in common,
+     * that is, each dimension of a's type for which b's type contains an equal dimension.
+     *
+     * @param a the first tensor
+     * @param b the second tensor
+     * @return a new tensor type containing only the shared dimensions; it has no dimensions
+     *         when nothing is shared
+     */
+    private TensorType commonDimensions(Tensor a, Tensor b) {
+        TensorType typeOfA = a.type();
+        TensorType typeOfB = b.type();
+        TensorType.Builder shared = new TensorType.Builder();
+        for (TensorType.Dimension dimensionOfA : typeOfA.dimensions()) {
+            for (TensorType.Dimension dimensionOfB : typeOfB.dimensions()) {
+                if ( ! dimensionOfA.equals(dimensionOfB)) continue;
+                shared.set(dimensionOfB);
+            }
+        }
+        return shared.build();
+    }
+
+    /**
+     * Extracts, from the address of a cell, the labels of the given common dimensions and returns
+     * them as a new address. The labels appear in the order the dimensions have in commonDimensions.
+     *
+     * @param cell the cell whose address (key) the labels are read from
+     * @param type the type used to find the index of each common dimension in the cell's address
+     * @param commonDimensions the dimensions to extract labels for; each must be present in type
+     * @return an address holding one label for each dimension in commonDimensions
+     * @throws IllegalStateException if a dimension in commonDimensions is not present in type
+     */
+    private TensorAddress partialCommonAddress(Tensor.Cell cell, TensorType type, TensorType commonDimensions) {
+        TensorAddress address = cell.getKey();
+        String[] labels = new String[commonDimensions.dimensions().size()];
+        for (int i = 0; i < labels.length; ++i) {
+            String name = commonDimensions.dimensions().get(i).name();
+            // Invariant: a common dimension must exist in type. Fail with context if it does not.
+            int index = type.indexOfDimension(name)
+                            .orElseThrow(() -> new IllegalStateException("Common dimension '" + name +
+                                                                         "' is not present in " + type));
+            labels[i] = address.label(index);
+        }
+        return TensorAddress.of(labels);
+    }
+
}