From b2ea86592479a78a35e58151d8f9f51a5baa300c Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 6 Jul 2017 09:59:04 +0200 Subject: Small fixes to mapped hash join --- .../main/java/com/yahoo/tensor/functions/Join.java | 59 ++++++++++++++-------- 1 file changed, 37 insertions(+), 22 deletions(-) (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index a7aa1d66a3d..95e3c444c1b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -12,8 +12,16 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; import java.util.function.DoubleBinaryOperator; /** @@ -284,35 +292,44 @@ public class Join extends PrimitiveTensorFunction { private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) { TensorType commonDimensionType = commonDimensions(a, b); if (commonDimensionType.dimensions().isEmpty()) { - return mappedGeneralJoin(a, b, joinedType); // If no common dimensions, fall back to mappedGeneralJoin + return mappedGeneralJoin(a, b, joinedType); // fallback } - Tensor smallerTensor = a.size() > b.size() ? b : a; - Tensor largerTensor = a.size() > b.size() ? a : b; - a = smallerTensor; - b = largerTensor; + boolean switchTensors = a.size() > b.size(); + if (switchTensors) { + Tensor temp = a; + a = b; + b = temp; + } + + // Map dimension indexes to common and joined type + int[] aIndexesInCommon = mapIndexes(commonDimensionType, a.type()); + int[] bIndexesInCommon = mapIndexes(commonDimensionType, b.type()); + int[] aIndexesInJoined = mapIndexes(a.type(), joinedType); + int[] bIndexesInJoined = mapIndexes(b.type(), joinedType); // Iterate once through the smaller tensor and construct a hash map for common dimensions Map> aCellsByCommonAddress = new HashMap<>(); for (Iterator cellIterator = a.cellIterator(); cellIterator.hasNext(); ) { - Tensor.Cell cell = cellIterator.next(); - TensorAddress partialCommonAddress = partialCommonAddress(cell, a.type(), commonDimensionType); + Tensor.Cell aCell = cellIterator.next(); + TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon); aCellsByCommonAddress.putIfAbsent(partialCommonAddress, new ArrayList<>()); - aCellsByCommonAddress.get(partialCommonAddress).add(cell); + aCellsByCommonAddress.get(partialCommonAddress).add(aCell); } - // Iterate once through the larger tensor and use hash map to find joinable cells - int[] aToIndexes = mapIndexes(a.type(), joinedType); - int[] bToIndexes = mapIndexes(b.type(), joinedType); + // Iterate once through the larger tensor and use the hash map to find joinable cells Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator cellIterator = b.cellIterator(); cellIterator.hasNext(); ) { - Tensor.Cell cell = cellIterator.next(); - TensorAddress partialCommonAddress = partialCommonAddress(cell, b.type(), commonDimensionType); + Tensor.Cell bCell = cellIterator.next(); + TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon); for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, Collections.emptyList())) { - TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes, - cell.getKey(), bToIndexes, joinedType); + TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined, + bCell.getKey(), bIndexesInJoined, joinedType); if (combinedAddress == null) continue; // not combinable - builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), cell.getValue())); + double combinedValue = switchTensors ? + combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) : + combinator.applyAsDouble(aCell.getValue(), bCell.getValue()); + builder.cell(combinedAddress, combinedValue); } } @@ -378,13 +395,11 @@ public class Join extends PrimitiveTensorFunction { return typeBuilder.build(); } - private TensorAddress partialCommonAddress(Tensor.Cell cell, TensorType type, TensorType commonDimensions) { + private TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) { TensorAddress address = cell.getKey(); - String[] labels = new String[commonDimensions.dimensions().size()]; + String[] labels = new String[indexMap.length]; for (int i = 0; i < labels.length; ++i) { - String name = commonDimensions.dimensions().get(i).name(); - int index = type.indexOfDimension(name).orElseThrow(RuntimeException::new); // invariant, must exist - labels[i] = address.label(index); + labels[i] = address.label(indexMap[i]); } return TensorAddress.of(labels); -- cgit v1.2.3