diff options
author | Lester Solbakken <lesters@oath.com> | 2018-09-28 15:54:17 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-09-28 15:54:17 +0200 |
commit | ed3bc0556f0ec8b867c8f5d80b26f8e79b6a8b27 (patch) | |
tree | 88724f3c25c7eb9e2d8cfba40ac7a4e4eb1866ad /vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java | |
parent | e6f46440bd697d78921379dba4f7e55ca2d85c7a (diff) |
Add reduce-join optimization in Java
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java | 366 |
1 files changed, 366 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java new file mode 100644 index 00000000000..9f53ced7719 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -0,0 +1,366 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.DimensionSizes; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Arrays; +import java.util.List; +import java.util.function.DoubleBinaryOperator; +import java.util.stream.Collectors; + +/** + * An optimization for tensor expressions where a join immediately follows a + * reduce. Evaluating this as one operations is significantly more efficient + * than evaluating each separately. + * + * This implementation optimizes the case where the reduce is done on the same + * dimensions as the join. A particularly efficient evaluation is done if there + * is one common dimension that is joined and reduced on, which is a common + * case as it covers vector and matrix like multiplications. + * + * @author lesters + */ +public class ReduceJoin extends CompositeTensorFunction { + + private final TensorFunction argumentA, argumentB; + private final DoubleBinaryOperator combinator; + private final Reduce.Aggregator aggregator; + private final List<String> dimensions; + + public ReduceJoin(Reduce reduce, Join join) { + this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions()); + } + + public ReduceJoin(TensorFunction argumentA, + TensorFunction argumentB, + DoubleBinaryOperator combinator, + Reduce.Aggregator aggregator, + List<String> dimensions) { + this.argumentA = argumentA; + this.argumentB = argumentB; + this.combinator = combinator; + this.aggregator = aggregator; + this.dimensions = ImmutableList.copyOf(dimensions); + } + + @Override + public List<TensorFunction> arguments() { + return ImmutableList.of(argumentA, argumentB); + } + + @Override + public TensorFunction withArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 2) + throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + arguments.size()); + return new ReduceJoin(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions); + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + Join join = new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); + return new Reduce(join, aggregator, dimensions); + } + + @Override + public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor a = argumentA.evaluate(context); + Tensor b = argumentB.evaluate(context); + TensorType joinType = new TensorType.Builder(a.type(), b.type()).build(); + + if (canOptimize(a, b)) { + return evaluate((IndexedTensor)a, (IndexedTensor)b, joinType); + } + return Reduce.evaluate(Join.evaluate(a, b, joinType, combinator), dimensions, aggregator); + } + + /** + * Tests whether or not the reduce is over the join dimensions. The + * remaining logic in this class assumes this to return true. + * + * If no dimensions are given, the join must be on all tensor dimensions. + */ + public boolean canOptimize(Tensor a, Tensor b) { + if (a.type().dimensions().size() == 0 || b.type().dimensions().size() == 0) // TODO: support scalars + return false; + if ( ! (a instanceof IndexedTensor)) + return false; + if ( ! (a.type().dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))) + return false; + if ( ! (b instanceof IndexedTensor)) + return false; + if ( ! (b.type().dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))) + return false; + + TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b); + if (dimensions.isEmpty()) { + if (a.type().dimensions().size() != commonDimensions.dimensions().size()) + return false; + if (b.type().dimensions().size() != commonDimensions.dimensions().size()) + return false; + } else { + for (TensorType.Dimension dimension : commonDimensions.dimensions()) { + if (!dimensions.contains(dimension.name())) + return false; + } + } + return true; + } + + /** + * Tests if the reduce dimension is the innermost dimension in both tensors + */ + private boolean reduceDimensionIsInnermost(Tensor a, Tensor b) { + List<String> reducingDimensions = dimensions; + if (reducingDimensions.isEmpty()) { + reducingDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b).dimensions().stream() + .map(TensorType.Dimension::name) + .collect(Collectors.toList()); + } + if (reducingDimensions.size() != 1) { + return false; + } + String dimension = reducingDimensions.get(0); + int indexInA = a.type().indexOfDimension(dimension).orElseThrow(() -> new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor A.")); + if (indexInA != (a.type().dimensions().size() - 1)) { + return false; + } + int indexInB = b.type().indexOfDimension(dimension).orElseThrow(() -> new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor B.")); + if (indexInB < (b.type().dimensions().size() - 1)) { + return false; + } + return true; + } + + /** + * Evaluates the reduce-join. Special handling for common cases where the + * reduce dimension is the innermost dimension in both tensors. + */ + private Tensor evaluate(IndexedTensor a, IndexedTensor b, TensorType joinType) { + TensorType reducedType = Reduce.outputType(joinType, dimensions); + + if (reduceDimensionIsInnermost(a, b)) { + if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 1) { + return vectorVectorProduct(a, b, reducedType); + } + if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 2) { + return vectorMatrixProduct(a, b, reducedType, false); + } + if (a.type().dimensions().size() == 2 && b.type().dimensions().size() == 1) { + return vectorMatrixProduct(b, a, reducedType, true); + } + if (a.type().dimensions().size() == 2 && b.type().dimensions().size() == 2) { + return matrixMatrixProduct(a, b, reducedType); + } + } + return evaluateGeneral(a, b, reducedType); + } + + private Tensor vectorVectorProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) { + if ( a.type().dimensions().size() != 1 || b.type().dimensions().size() != 1) { + throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-vector product"); + } + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType); + long commonSize = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); + + Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator); + for (int ic = 0; ic < commonSize; ++ic) { + double va = a.get(ic); + double vb = b.get(ic); + agg.aggregate(combinator.applyAsDouble(va, vb)); + } + builder.cellByDirectIndex(0, agg.aggregatedValue()); + return builder.build(); + } + + private Tensor vectorMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType, boolean swapped) { + if ( a.type().dimensions().size() != 1 || b.type().dimensions().size() != 2) { + throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-matrix product"); + } + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType); + DimensionSizes sizesA = a.dimensionSizes(); + DimensionSizes sizesB = b.dimensionSizes(); + + Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator); + for (int ib = 0; ib < sizesB.size(0); ++ib) { + agg.reset(); + for (int ic = 0; ic < Math.min(sizesA.size(0), sizesB.size(1)); ++ic) { + double va = a.get(ic); + double vb = b.get(ib * sizesB.size(1) + ic); + double result = swapped ? combinator.applyAsDouble(vb, va) : combinator.applyAsDouble(va, vb); + agg.aggregate(result); + } + builder.cellByDirectIndex(ib, agg.aggregatedValue()); + } + return builder.build(); + } + + private Tensor matrixMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reduceType) { + if ( a.type().dimensions().size() != 2 || b.type().dimensions().size() != 2) { + throw new IllegalArgumentException("Wrong dimension sizes for tensors for matrix-matrix product"); + } + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reduceType); + DimensionSizes sizesA = a.dimensionSizes(); + DimensionSizes sizesB = b.dimensionSizes(); + int iaToReduced = reduceType.indexOfDimension(a.type().dimensions().get(0).name()).get(); + int ibToReduced = reduceType.indexOfDimension(b.type().dimensions().get(0).name()).get(); + long strideA = iaToReduced < ibToReduced ? sizesB.size(0) : 1; + long strideB = ibToReduced < iaToReduced ? sizesA.size(0) : 1; + + Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator); + for (int ia = 0; ia < sizesA.size(0); ++ia) { + for (int ib = 0; ib < sizesB.size(0); ++ib) { + agg.reset(); + for (int ic = 0; ic < Math.min(sizesA.size(1), sizesB.size(1)); ++ic) { + double va = a.get(ia * sizesA.size(1) + ic); + double vb = b.get(ib * sizesB.size(1) + ic); + agg.aggregate(combinator.applyAsDouble(va, vb)); + } + builder.cellByDirectIndex(ia * strideA + ib * strideB, agg.aggregatedValue()); + } + } + return builder.build(); + } + + private Tensor evaluateGeneral(IndexedTensor a, IndexedTensor b, TensorType reduceType) { + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reduceType); + TensorType onlyInA = Reduce.outputType(a.type(), dimensions); + TensorType onlyInB = Reduce.outputType(b.type(), dimensions); + TensorType common = dimensionsInCommon(a, b); + + // pre-calculate strides for each index position + long[] stridesA = strides(a.type()); + long[] stridesB = strides(b.type()); + long[] stridesResult = strides(reduceType); + + // mapping of dimension indexes + int[] mapOnlyAToA = Join.mapIndexes(onlyInA, a.type()); + int[] mapCommonToA = Join.mapIndexes(common, a.type()); + int[] mapOnlyBToB = Join.mapIndexes(onlyInB, b.type()); + int[] mapCommonToB = Join.mapIndexes(common, b.type()); + int[] mapOnlyAToResult = Join.mapIndexes(onlyInA, reduceType); + int[] mapOnlyBToResult = Join.mapIndexes(onlyInB, reduceType); + + // TODO: refactor with code in IndexedTensor and Join + + MultiDimensionIterator ic = new MultiDimensionIterator(common); + Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator); + for (MultiDimensionIterator ia = new MultiDimensionIterator(onlyInA); ia.hasNext(); ia.next()) { + for (MultiDimensionIterator ib = new MultiDimensionIterator(onlyInB); ib.hasNext(); ib.next()) { + agg.reset(); + for (ic.reset(); ic.hasNext(); ic.next()) { + double va = a.get(toDirectIndex(ia, ic, stridesA, mapOnlyAToA, mapCommonToA)); + double vb = b.get(toDirectIndex(ib, ic, stridesB, mapOnlyBToB, mapCommonToB)); + agg.aggregate(combinator.applyAsDouble(va, vb)); + } + builder.cellByDirectIndex(toDirectIndex(ia, ib, stridesResult, mapOnlyAToResult, mapOnlyBToResult), + agg.aggregatedValue()); + } + } + return builder.build(); + } + + private long toDirectIndex(MultiDimensionIterator iter, MultiDimensionIterator common, long[] strides, int[] map, int[] commonmap) { + long directIndex = 0; + for (int i = 0; i < iter.length(); ++i) { + directIndex += strides[map[i]] * iter.iterator[i]; + } + for (int i = 0; i < common.length(); ++i) { + directIndex += strides[commonmap[i]] * common.iterator[i]; + } + return directIndex; + } + + private long[] strides(TensorType type) { + long[] strides = new long[type.dimensions().size()]; + if (strides.length > 0) { + long previous = 1; + strides[strides.length - 1] = previous; + for (int i = strides.length - 2; i >= 0; --i) { + strides[i] = previous * type.dimensions().get(i + 1).size().get(); + previous = strides[i]; + } + } + return strides; + } + + private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension aDim : a.type().dimensions()) { + for (TensorType.Dimension bDim : b.type().dimensions()) { + if (aDim.name().equals(bDim.name())) { + if ( ! aDim.size().isPresent()) { + builder.set(aDim); + } else if ( ! bDim.size().isPresent()) { + builder.set(bDim); + } else { + builder.set(aDim.size().get() < bDim.size().get() ? aDim : bDim); // minimum size of dimension + } + } + } + } + return builder.build(); + } + + @Override + public String toString(ToStringContext context) { + return "reduce_join(" + argumentA.toString(context) + ", " + + argumentB.toString(context) + ", " + + combinator + ", " + + aggregator + + Reduce.commaSeparated(dimensions) + ")"; + } + + private static class MultiDimensionIterator { + + private long[] max; + private long[] iterator; + private int remaining; + + MultiDimensionIterator(TensorType type) { + max = new long[type.dimensions().size()]; + iterator = new long[type.dimensions().size()]; + for (int i = 0; i < max.length; ++i) { + max[i] = type.dimensions().get(i).size().get(); + } + reset(); + } + + public int length() { + return iterator.length; + } + + public boolean hasNext() { + return remaining > 0; + } + + public void reset() { + remaining = 1; + for (int i = iterator.length - 1; i >= 0; --i) { + iterator[i] = 0; + remaining *= max[i]; + } + } + + public void next() { + for (int i = iterator.length - 1; i >= 0; --i) { + iterator[i] += 1; + if (iterator[i] < max[i]) { + break; + } + iterator[i] = 0; + } + remaining -= 1; + } + + public String toString() { + return Arrays.toString(iterator); + } + } + +} |