diff options
author | Lester Solbakken <lesters@yahoo-inc.com> | 2017-07-05 19:36:16 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@yahoo-inc.com> | 2017-07-05 19:36:16 +0200 |
commit | 27aff3764de178e952634466733c5b4ac6b252c1 (patch) | |
tree | 894beb2ae72ce2300c7785f49b92f42e34bd3c26 /vespajlib | |
parent | 61c0690fc5538257c6729f98e3695fa47a586437 (diff) |
Optimize general mapped tensor join
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 82 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java | 8 |
2 files changed, 78 insertions, 12 deletions
```diff
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index afb2bb6fb8c..a7aa1d66a3d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -12,13 +12,8 @@ import com.yahoo.tensor.TensorAddress;
 import com.yahoo.tensor.TensorType;
 import com.yahoo.tensor.evaluation.EvaluationContext;
 
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.List;
+import java.util.*;
 import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
 import java.util.function.DoubleBinaryOperator;
 
 /**
@@ -210,7 +205,7 @@ public class Join extends PrimitiveTensorFunction {
         if (a instanceof IndexedTensor && b instanceof IndexedTensor)
             return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType);
         else
-            return mappedGeneralJoin(a, b, joinedType);
+            return mappedHashJoin(a, b, joinedType);
     }
 
     private Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
@@ -286,6 +281,45 @@ public class Join extends PrimitiveTensorFunction {
         return builder.build();
     }
 
+    private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) {
+        TensorType commonDimensionType = commonDimensions(a, b);
+        if (commonDimensionType.dimensions().isEmpty()) {
+            return mappedGeneralJoin(a, b, joinedType); // If no common dimensions, fall back to mappedGeneralJoin
+        }
+
+        Tensor smallerTensor = a.size() > b.size() ? b : a;
+        Tensor largerTensor = a.size() > b.size() ? a : b;
+        a = smallerTensor;
+        b = largerTensor;
+
+        // Iterate once through the smaller tensor and construct a hash map for common dimensions
+        Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>();
+        for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
+            Tensor.Cell cell = cellIterator.next();
+            TensorAddress partialCommonAddress = partialCommonAddress(cell, a.type(), commonDimensionType);
+            aCellsByCommonAddress.putIfAbsent(partialCommonAddress, new ArrayList<>());
+            aCellsByCommonAddress.get(partialCommonAddress).add(cell);
+        }
+
+        // Iterate once through the larger tensor and use hash map to find joinable cells
+        int[] aToIndexes = mapIndexes(a.type(), joinedType);
+        int[] bToIndexes = mapIndexes(b.type(), joinedType);
+        Tensor.Builder builder = Tensor.Builder.of(joinedType);
+        for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) {
+            Tensor.Cell cell = cellIterator.next();
+            TensorAddress partialCommonAddress = partialCommonAddress(cell, b.type(), commonDimensionType);
+            for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, Collections.emptyList())) {
+                TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes,
+                                                              cell.getKey(), bToIndexes, joinedType);
+                if (combinedAddress == null) continue; // not combinable
+                builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), cell.getValue()));
+            }
+        }
+
+        return builder.build();
+    }
+
+
     /**
      * Returns the an array having one entry in order for each dimension of fromType
      * containing the index at which toType contains the same dimension name.
@@ -323,5 +357,37 @@ public class Join extends PrimitiveTensorFunction {
         }
         return true;
     }
-
+
+
+    /**
+     * Returns common dimension of a and b as a new tensor type
+     */
+    private TensorType commonDimensions(Tensor a, Tensor b) {
+        TensorType.Builder typeBuilder = new TensorType.Builder();
+        TensorType aType = a.type();
+        TensorType bType = b.type();
+        for (int i = 0; i < aType.dimensions().size(); ++i) {
+            TensorType.Dimension aDim = aType.dimensions().get(i);
+            for (int j = 0; j < bType.dimensions().size(); ++j) {
+                TensorType.Dimension bDim = bType.dimensions().get(j);
+                if (aDim.equals(bDim)) {
+                    typeBuilder.set(bDim);
+                }
+            }
+        }
+        return typeBuilder.build();
+    }
+
+    private TensorAddress partialCommonAddress(Tensor.Cell cell, TensorType type, TensorType commonDimensions) {
+        TensorAddress address = cell.getKey();
+        String[] labels = new String[commonDimensions.dimensions().size()];
+        for (int i = 0; i < labels.length; ++i) {
+            String name = commonDimensions.dimensions().get(i).name();
+            int index = type.indexOfDimension(name).orElseThrow(RuntimeException::new); // invariant, must exist
+            labels[i] = address.label(index);
+        }
+        return TensorAddress.of(labels);
+
+    }
+
 }
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index 32b19ecf44f..c3086fd8c05 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -107,11 +107,11 @@ public class TensorFunctionBenchmark {
         double time = 0;
 
         // ---------------- Mapped with extra space (sidesteps current special-case optimizations):
-        // 410 ms
-        time = new TensorFunctionBenchmark().benchmark(20, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true);
+        // 11.2 ms
+        time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true);
         System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time);
-        // 770 ms
-        time = new TensorFunctionBenchmark().benchmark(20, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true);
+        // 10.8 ms
+        time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true);
         System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time);
 
         // ---------------- Mapped:
```