aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-09-28 15:54:17 +0200
committerLester Solbakken <lesters@oath.com>2018-09-28 15:54:17 +0200
commited3bc0556f0ec8b867c8f5d80b26f8e79b6a8b27 (patch)
tree88724f3c25c7eb9e2d8cfba40ac7a4e4eb1866ad /vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
parente6f46440bd697d78921379dba4f7e55ca2d85c7a (diff)
Add reduce-join optimization in Java
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java366
1 files changed, 366 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
new file mode 100644
index 00000000000..9f53ced7719
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -0,0 +1,366 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.DimensionSizes;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+import java.util.stream.Collectors;
+
+/**
+ * An optimization for tensor expressions where a join immediately follows a
+ * reduce. Evaluating this as one operations is significantly more efficient
+ * than evaluating each separately.
+ *
+ * This implementation optimizes the case where the reduce is done on the same
+ * dimensions as the join. A particularly efficient evaluation is done if there
+ * is one common dimension that is joined and reduced on, which is a common
+ * case as it covers vector and matrix like multiplications.
+ *
+ * @author lesters
+ */
+public class ReduceJoin extends CompositeTensorFunction {
+
+ private final TensorFunction argumentA, argumentB;
+ private final DoubleBinaryOperator combinator;
+ private final Reduce.Aggregator aggregator;
+ private final List<String> dimensions;
+
+ public ReduceJoin(Reduce reduce, Join join) {
+ this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions());
+ }
+
+ public ReduceJoin(TensorFunction argumentA,
+ TensorFunction argumentB,
+ DoubleBinaryOperator combinator,
+ Reduce.Aggregator aggregator,
+ List<String> dimensions) {
+ this.argumentA = argumentA;
+ this.argumentB = argumentB;
+ this.combinator = combinator;
+ this.aggregator = aggregator;
+ this.dimensions = ImmutableList.copyOf(dimensions);
+ }
+
+ @Override
+ public List<TensorFunction> arguments() {
+ return ImmutableList.of(argumentA, argumentB);
+ }
+
+ @Override
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 2)
+ throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + arguments.size());
+ return new ReduceJoin(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions);
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ Join join = new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
+ return new Reduce(join, aggregator, dimensions);
+ }
+
+ @Override
+ public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor a = argumentA.evaluate(context);
+ Tensor b = argumentB.evaluate(context);
+ TensorType joinType = new TensorType.Builder(a.type(), b.type()).build();
+
+ if (canOptimize(a, b)) {
+ return evaluate((IndexedTensor)a, (IndexedTensor)b, joinType);
+ }
+ return Reduce.evaluate(Join.evaluate(a, b, joinType, combinator), dimensions, aggregator);
+ }
+
+ /**
+ * Tests whether or not the reduce is over the join dimensions. The
+ * remaining logic in this class assumes this to return true.
+ *
+ * If no dimensions are given, the join must be on all tensor dimensions.
+ */
+ public boolean canOptimize(Tensor a, Tensor b) {
+ if (a.type().dimensions().size() == 0 || b.type().dimensions().size() == 0) // TODO: support scalars
+ return false;
+ if ( ! (a instanceof IndexedTensor))
+ return false;
+ if ( ! (a.type().dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)))
+ return false;
+ if ( ! (b instanceof IndexedTensor))
+ return false;
+ if ( ! (b.type().dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)))
+ return false;
+
+ TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b);
+ if (dimensions.isEmpty()) {
+ if (a.type().dimensions().size() != commonDimensions.dimensions().size())
+ return false;
+ if (b.type().dimensions().size() != commonDimensions.dimensions().size())
+ return false;
+ } else {
+ for (TensorType.Dimension dimension : commonDimensions.dimensions()) {
+ if (!dimensions.contains(dimension.name()))
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Tests if the reduce dimension is the innermost dimension in both tensors
+ */
+ private boolean reduceDimensionIsInnermost(Tensor a, Tensor b) {
+ List<String> reducingDimensions = dimensions;
+ if (reducingDimensions.isEmpty()) {
+ reducingDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b).dimensions().stream()
+ .map(TensorType.Dimension::name)
+ .collect(Collectors.toList());
+ }
+ if (reducingDimensions.size() != 1) {
+ return false;
+ }
+ String dimension = reducingDimensions.get(0);
+ int indexInA = a.type().indexOfDimension(dimension).orElseThrow(() -> new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor A."));
+ if (indexInA != (a.type().dimensions().size() - 1)) {
+ return false;
+ }
+ int indexInB = b.type().indexOfDimension(dimension).orElseThrow(() -> new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor B."));
+ if (indexInB < (b.type().dimensions().size() - 1)) {
+ return false;
+ }
+ return true;
+ }
+
+ /**
+ * Evaluates the reduce-join. Special handling for common cases where the
+ * reduce dimension is the innermost dimension in both tensors.
+ */
+ private Tensor evaluate(IndexedTensor a, IndexedTensor b, TensorType joinType) {
+ TensorType reducedType = Reduce.outputType(joinType, dimensions);
+
+ if (reduceDimensionIsInnermost(a, b)) {
+ if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 1) {
+ return vectorVectorProduct(a, b, reducedType);
+ }
+ if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 2) {
+ return vectorMatrixProduct(a, b, reducedType, false);
+ }
+ if (a.type().dimensions().size() == 2 && b.type().dimensions().size() == 1) {
+ return vectorMatrixProduct(b, a, reducedType, true);
+ }
+ if (a.type().dimensions().size() == 2 && b.type().dimensions().size() == 2) {
+ return matrixMatrixProduct(a, b, reducedType);
+ }
+ }
+ return evaluateGeneral(a, b, reducedType);
+ }
+
+ private Tensor vectorVectorProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
+ if ( a.type().dimensions().size() != 1 || b.type().dimensions().size() != 1) {
+ throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-vector product");
+ }
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
+ long commonSize = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
+
+ Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
+ for (int ic = 0; ic < commonSize; ++ic) {
+ double va = a.get(ic);
+ double vb = b.get(ic);
+ agg.aggregate(combinator.applyAsDouble(va, vb));
+ }
+ builder.cellByDirectIndex(0, agg.aggregatedValue());
+ return builder.build();
+ }
+
+ private Tensor vectorMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType, boolean swapped) {
+ if ( a.type().dimensions().size() != 1 || b.type().dimensions().size() != 2) {
+ throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-matrix product");
+ }
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
+ DimensionSizes sizesA = a.dimensionSizes();
+ DimensionSizes sizesB = b.dimensionSizes();
+
+ Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
+ for (int ib = 0; ib < sizesB.size(0); ++ib) {
+ agg.reset();
+ for (int ic = 0; ic < Math.min(sizesA.size(0), sizesB.size(1)); ++ic) {
+ double va = a.get(ic);
+ double vb = b.get(ib * sizesB.size(1) + ic);
+ double result = swapped ? combinator.applyAsDouble(vb, va) : combinator.applyAsDouble(va, vb);
+ agg.aggregate(result);
+ }
+ builder.cellByDirectIndex(ib, agg.aggregatedValue());
+ }
+ return builder.build();
+ }
+
+ private Tensor matrixMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reduceType) {
+ if ( a.type().dimensions().size() != 2 || b.type().dimensions().size() != 2) {
+ throw new IllegalArgumentException("Wrong dimension sizes for tensors for matrix-matrix product");
+ }
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reduceType);
+ DimensionSizes sizesA = a.dimensionSizes();
+ DimensionSizes sizesB = b.dimensionSizes();
+ int iaToReduced = reduceType.indexOfDimension(a.type().dimensions().get(0).name()).get();
+ int ibToReduced = reduceType.indexOfDimension(b.type().dimensions().get(0).name()).get();
+ long strideA = iaToReduced < ibToReduced ? sizesB.size(0) : 1;
+ long strideB = ibToReduced < iaToReduced ? sizesA.size(0) : 1;
+
+ Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
+ for (int ia = 0; ia < sizesA.size(0); ++ia) {
+ for (int ib = 0; ib < sizesB.size(0); ++ib) {
+ agg.reset();
+ for (int ic = 0; ic < Math.min(sizesA.size(1), sizesB.size(1)); ++ic) {
+ double va = a.get(ia * sizesA.size(1) + ic);
+ double vb = b.get(ib * sizesB.size(1) + ic);
+ agg.aggregate(combinator.applyAsDouble(va, vb));
+ }
+ builder.cellByDirectIndex(ia * strideA + ib * strideB, agg.aggregatedValue());
+ }
+ }
+ return builder.build();
+ }
+
+ private Tensor evaluateGeneral(IndexedTensor a, IndexedTensor b, TensorType reduceType) {
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reduceType);
+ TensorType onlyInA = Reduce.outputType(a.type(), dimensions);
+ TensorType onlyInB = Reduce.outputType(b.type(), dimensions);
+ TensorType common = dimensionsInCommon(a, b);
+
+ // pre-calculate strides for each index position
+ long[] stridesA = strides(a.type());
+ long[] stridesB = strides(b.type());
+ long[] stridesResult = strides(reduceType);
+
+ // mapping of dimension indexes
+ int[] mapOnlyAToA = Join.mapIndexes(onlyInA, a.type());
+ int[] mapCommonToA = Join.mapIndexes(common, a.type());
+ int[] mapOnlyBToB = Join.mapIndexes(onlyInB, b.type());
+ int[] mapCommonToB = Join.mapIndexes(common, b.type());
+ int[] mapOnlyAToResult = Join.mapIndexes(onlyInA, reduceType);
+ int[] mapOnlyBToResult = Join.mapIndexes(onlyInB, reduceType);
+
+ // TODO: refactor with code in IndexedTensor and Join
+
+ MultiDimensionIterator ic = new MultiDimensionIterator(common);
+ Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
+ for (MultiDimensionIterator ia = new MultiDimensionIterator(onlyInA); ia.hasNext(); ia.next()) {
+ for (MultiDimensionIterator ib = new MultiDimensionIterator(onlyInB); ib.hasNext(); ib.next()) {
+ agg.reset();
+ for (ic.reset(); ic.hasNext(); ic.next()) {
+ double va = a.get(toDirectIndex(ia, ic, stridesA, mapOnlyAToA, mapCommonToA));
+ double vb = b.get(toDirectIndex(ib, ic, stridesB, mapOnlyBToB, mapCommonToB));
+ agg.aggregate(combinator.applyAsDouble(va, vb));
+ }
+ builder.cellByDirectIndex(toDirectIndex(ia, ib, stridesResult, mapOnlyAToResult, mapOnlyBToResult),
+ agg.aggregatedValue());
+ }
+ }
+ return builder.build();
+ }
+
+ private long toDirectIndex(MultiDimensionIterator iter, MultiDimensionIterator common, long[] strides, int[] map, int[] commonmap) {
+ long directIndex = 0;
+ for (int i = 0; i < iter.length(); ++i) {
+ directIndex += strides[map[i]] * iter.iterator[i];
+ }
+ for (int i = 0; i < common.length(); ++i) {
+ directIndex += strides[commonmap[i]] * common.iterator[i];
+ }
+ return directIndex;
+ }
+
+ private long[] strides(TensorType type) {
+ long[] strides = new long[type.dimensions().size()];
+ if (strides.length > 0) {
+ long previous = 1;
+ strides[strides.length - 1] = previous;
+ for (int i = strides.length - 2; i >= 0; --i) {
+ strides[i] = previous * type.dimensions().get(i + 1).size().get();
+ previous = strides[i];
+ }
+ }
+ return strides;
+ }
+
+ private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) {
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension aDim : a.type().dimensions()) {
+ for (TensorType.Dimension bDim : b.type().dimensions()) {
+ if (aDim.name().equals(bDim.name())) {
+ if ( ! aDim.size().isPresent()) {
+ builder.set(aDim);
+ } else if ( ! bDim.size().isPresent()) {
+ builder.set(bDim);
+ } else {
+ builder.set(aDim.size().get() < bDim.size().get() ? aDim : bDim); // minimum size of dimension
+ }
+ }
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "reduce_join(" + argumentA.toString(context) + ", " +
+ argumentB.toString(context) + ", " +
+ combinator + ", " +
+ aggregator +
+ Reduce.commaSeparated(dimensions) + ")";
+ }
+
+ private static class MultiDimensionIterator {
+
+ private long[] max;
+ private long[] iterator;
+ private int remaining;
+
+ MultiDimensionIterator(TensorType type) {
+ max = new long[type.dimensions().size()];
+ iterator = new long[type.dimensions().size()];
+ for (int i = 0; i < max.length; ++i) {
+ max[i] = type.dimensions().get(i).size().get();
+ }
+ reset();
+ }
+
+ public int length() {
+ return iterator.length;
+ }
+
+ public boolean hasNext() {
+ return remaining > 0;
+ }
+
+ public void reset() {
+ remaining = 1;
+ for (int i = iterator.length - 1; i >= 0; --i) {
+ iterator[i] = 0;
+ remaining *= max[i];
+ }
+ }
+
+ public void next() {
+ for (int i = iterator.length - 1; i >= 0; --i) {
+ iterator[i] += 1;
+ if (iterator[i] < max[i]) {
+ break;
+ }
+ iterator[i] = 0;
+ }
+ remaining -= 1;
+ }
+
+ public String toString() {
+ return Arrays.toString(iterator);
+ }
+ }
+
+}