diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-29 12:53:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-29 12:53:37 +0200 |
commit | bfc4feb4f5b9b2cec76bd08027cdf3c2ca339f39 (patch) | |
tree | 3bd0250db894ae103c85062556b4b34c59e38e84 | |
parent | 44313db093c7954c6fa979296832ed44dd28991d (diff) | |
parent | 35da647c2c32ea3a81b5c65bd440af1dc59b1b3c (diff) |
Merge pull request #7145 from vespa-engine/lesters/add-java-reduce-join-optimization
Add reduce-join optimization in Java
11 files changed, 757 insertions, 70 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java index 7060cfc2132..84a90ee64c2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTOptimizer; +import com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization.TensorOptimizer; /** * This class will perform various optimizations on the ranking expressions. Clients using optimized expressions @@ -32,8 +33,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTOpt public class ExpressionOptimizer { private GBDTOptimizer gbdtOptimizer = new GBDTOptimizer(); - private GBDTForestOptimizer gbdtForestOptimizer = new GBDTForestOptimizer(); + private TensorOptimizer tensorOptimizer = new TensorOptimizer(); /** Gets an optimizer instance used by this by class name, or null if the optimizer is not known */ public Optimizer getOptimizer(Class<?> clazz) { @@ -41,6 +42,8 @@ public class ExpressionOptimizer { return gbdtOptimizer; if (clazz == gbdtForestOptimizer.getClass()) return gbdtForestOptimizer; + if (clazz == tensorOptimizer.getClass()) + return tensorOptimizer; return null; } @@ -49,6 +52,7 @@ public class ExpressionOptimizer { // Note: Order of optimizations matter gbdtOptimizer.optimize(expression, contextIndex, report); gbdtForestOptimizer.optimize(expression, contextIndex, report); + tensorOptimizer.optimize(expression, contextIndex, report); return report; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java new file mode 100644 index 00000000000..63cea371d14 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java @@ -0,0 +1,85 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; +import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; +import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ReduceJoin; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; + +/** + * Recognizes and optimizes tensor expressions. + * + * @author lesters + */ +public class TensorOptimizer extends Optimizer { + + private OptimizationReport report; + + @Override + public void optimize(RankingExpression expression, ContextIndex context, OptimizationReport report) { + if (!isEnabled()) return; + this.report = report; + expression.setRoot(optimize(expression.getRoot(), context)); + report.note("Tensor expression optimization done"); + } + + private ExpressionNode optimize(ExpressionNode node, ContextIndex context) { + node = optimizeReduceJoin(node); + if (node instanceof CompositeNode) { + return optimizeChildren((CompositeNode)node, context); + } + return node; + } + + private ExpressionNode optimizeChildren(CompositeNode node, ContextIndex context) { + List<ExpressionNode> children = node.children(); + List<ExpressionNode> optimizedChildren = new ArrayList<>(children.size()); + for (ExpressionNode child : children) + optimizedChildren.add(optimize(child, context)); + return node.setChildren(optimizedChildren); + } + + /** + * Recognized a reduce followed by a join. In many cases, chunking these + * two operations together is significantly more efficient than evaluating + * each on its own, avoiding the cost of a temporary tensor. + * + * Note that this does not guarantee that the optimization is performed. + * The ReduceJoin class determines whether or not the arguments are + * compatible with the optimization. + */ + private ExpressionNode optimizeReduceJoin(ExpressionNode node) { + if ( ! (node instanceof TensorFunctionNode)) { + return node; + } + TensorFunction function = ((TensorFunctionNode) node).function(); + if ( ! (function instanceof Reduce)) { + return node; + } + List<ExpressionNode> children = ((TensorFunctionNode) node).children(); + if (children.size() != 1) { + return node; + } + ExpressionNode child = children.get(0); + if ( ! (child instanceof TensorFunctionNode)) { + return node; + } + TensorFunction argument = ((TensorFunctionNode) child).function(); + if (argument instanceof Join) { + report.incMetric("Replaced reduce->join", 1); + return new TensorFunctionNode(new ReduceJoin((Reduce)function, (Join)argument)); + } + return node; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index de98b01287e..3a3410aeebb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -12,6 +12,7 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; import java.util.List; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; @@ -90,7 +91,47 @@ public class LambdaFunctionNode extends CompositeNode { if (arguments.size() > 2) throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " + "Must have at most two argument " + " but has " + arguments); - return new DoubleBinaryLambda(); + + // Optimization: if possible, calculate directly rather than creating a context and evaluating the expression + return getDirectEvaluator().orElseGet(DoubleBinaryLambda::new); + } + + private Optional<DoubleBinaryOperator> getDirectEvaluator() { + if ( ! (functionExpression instanceof ArithmeticNode)) { + return Optional.empty(); + } + ArithmeticNode node = (ArithmeticNode) functionExpression; + if ( ! (node.children().get(0) instanceof ReferenceNode) || ! (node.children().get(1) instanceof ReferenceNode)) { + return Optional.empty(); + } + if (node.operators().size() != 1) { + return Optional.empty(); + } + ArithmeticOperator operator = node.operators().get(0); + switch (operator) { + case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0); + case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0); + case PLUS: return asFunctionExpression((left, right) -> left + right); + case MINUS: return asFunctionExpression((left, right) -> left - right); + case MULTIPLY: return asFunctionExpression((left, right) -> left * right); + case DIVIDE: return asFunctionExpression((left, right) -> left / right); + case MODULO: return asFunctionExpression((left, right) -> left % right); + case POWER: return asFunctionExpression(Math::pow); + } + return Optional.empty(); + } + + private Optional<DoubleBinaryOperator> asFunctionExpression(DoubleBinaryOperator operator) { + return Optional.of(new DoubleBinaryOperator() { + @Override + public double applyAsDouble(double left, double right) { + return operator.applyAsDouble(left, right); + } + @Override + public String toString() { + return LambdaFunctionNode.this.toString(); + } + }); } private class DoubleUnaryLambda implements DoubleUnaryOperator { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java new file mode 100644 index 00000000000..f29083bddc9 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java @@ -0,0 +1,116 @@ +package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext; +import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ReduceJoin; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author lesters + */ +public class TensorOptimizerTestCase { + + @Test + public void testReduceJoinOptimization() throws ParseException { + assertWillOptimize("d0[3]", "d0[3]"); + assertWillOptimize("d0[1]", "d0[1]", "d0"); + assertWillOptimize("d0[2]", "d0[2]", "d0"); + assertWillOptimize("d0[1]", "d0[3]", "d0"); + assertWillOptimize("d0[3]", "d0[3]", "d0"); + assertWillOptimize("d0[3]", "d0[3],d1[2]", "d0"); + assertWillOptimize("d0[3],d1[2]", "d0[3]", "d0"); + assertWillOptimize("d1[3]", "d0[2],d1[3]", "d1"); + assertWillOptimize("d0[2],d1[3]", "d1[3]", "d1"); + assertWillOptimize("d0[2],d2[2]", "d1[3],d2[2]", "d2"); + assertWillOptimize("d1[2],d2[2]", "d0[3],d2[2]", "d2"); + assertWillOptimize("d0[1],d2[2]", "d1[3],d2[4]", "d2"); + assertWillOptimize("d0[2],d2[2]", "d1[3],d2[4]", "d2"); + assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]"); + assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]", "d0,d1"); + assertWillOptimize("d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3"); + assertWillOptimize("d0[1],d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3"); + assertWillOptimize("d0[1],d1[2],d2[3]", "d2[3],d3[4],d4[5]", "d2"); + assertWillOptimize("d0[1],d1[2],d2[3]", "d1[2],d2[3],d4[4]", "d1,d2"); + assertWillOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]"); + assertWillOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]", "d0,d1,d2"); + + // Will not currently use reduce-join optimization + assertCantOptimize("d0[2],d1[3]", "d1[3]", "d0"); // reducing on a dimension not joining on + assertCantOptimize("d0[1],d1[2]", "d1[2],d2[3]", "d2"); // same + assertCantOptimize("d0[3]", "d0[3],d1[2]"); // reducing on more then we are combining + assertCantOptimize("d0[1],d2[3]", "d1[2],d2[3]"); // same + assertCantOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]", "d1,d2"); // reducing on less then joining on + } + + private void assertWillOptimize(String aType, String bType) throws ParseException { + assertWillOptimize(aType, bType, "", "sum"); + } + + private void assertWillOptimize(String aType, String bType, String reduceDim) throws ParseException { + assertWillOptimize(aType, bType, reduceDim, "sum"); + } + + private void assertWillOptimize(String aType, String bType, String reduceDim, String aggregator) throws ParseException { + assertReduceJoin(aType, bType, reduceDim, aggregator, true); + } + + private void assertCantOptimize(String aType, String bType) throws ParseException { + assertCantOptimize(aType, bType, "", "sum"); + } + + private void assertCantOptimize(String aType, String bType, String reduceDim) throws ParseException { + assertCantOptimize(aType, bType, reduceDim, "sum"); + } + + private void assertCantOptimize(String aType, String bType, String reduceDim, String aggregator) throws ParseException { + assertReduceJoin(aType, bType, reduceDim, aggregator, false); + } + + private void assertReduceJoin(String aType, String bType, String reduceDim, String aggregator, boolean assertOptimize) throws ParseException { + Tensor a = generateRandomTensor(aType); + Tensor b = generateRandomTensor(bType); + RankingExpression expression = generateRankingExpression(reduceDim, aggregator); + assert ((TensorFunctionNode)expression.getRoot()).function() instanceof Reduce; + + ArrayContext context = generateContext(a, b, expression); + Tensor result = expression.evaluate(context).asTensor(); + + ExpressionOptimizer optimizer = new ExpressionOptimizer(); + OptimizationReport report = optimizer.optimize(expression, context); + assertEquals(1, report.getMetric("Replaced reduce->join")); + assert ((TensorFunctionNode)expression.getRoot()).function() instanceof ReduceJoin; + + assertEquals(result, expression.evaluate(context).asTensor()); + assertEquals(assertOptimize, ((ReduceJoin)((TensorFunctionNode)expression.getRoot()).function()).canOptimize(a, b)); + } + + private RankingExpression generateRankingExpression(String reduceDim, String aggregator) throws ParseException { + String dimensions = ""; + if (reduceDim.length() > 0) { + dimensions = ", " + reduceDim; + } + return new RankingExpression("reduce(join(a, b, f(a,b)(a * b)), " + aggregator + dimensions + ")"); + } + + private ArrayContext generateContext(Tensor a, Tensor b, RankingExpression expression) { + ArrayContext context = new ArrayContext(expression); + context.put("a", new TensorValue(a)); + context.put("b", new TensorValue(b)); + return context; + } + + private Tensor generateRandomTensor(String type) { + return Tensor.random(TensorType.fromSpec("tensor(" + type + ")")); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java index 5447e5240f7..19b3329e8a4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java @@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; +import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; @@ -50,7 +52,11 @@ public class TestableTensorFlowModel { model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); - Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); + RankingExpression expression = model.expressions().get(operationName); + ExpressionOptimizer optimizer = new ExpressionOptimizer(); + optimizer.optimize(expression, (ContextIndex)context); + + Tensor vespaResult = expression.evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", tfResult.sum().asDouble(), vespaResult.sum().asDouble(), delta); } @@ -64,7 +70,11 @@ public class TestableTensorFlowModel { model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); - Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); + RankingExpression expression = model.expressions().get(operationName); + ExpressionOptimizer optimizer = new ExpressionOptimizer(); + optimizer.optimize(expression, (ContextIndex)context); + + Tensor vespaResult = expression.evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); } @@ -82,7 +92,7 @@ public class TestableTensorFlowModel { } private Context contextFrom(ImportedModel result) { - MapContext context = new MapContext(); + TestableModelContext context = new TestableModelContext(); result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); return context; @@ -118,4 +128,15 @@ public class TestableTensorFlowModel { } } + private static class TestableModelContext extends MapContext implements ContextIndex { + @Override + public int size() { + return bindings().size(); + } + @Override + public int getIndex(String name) { + throw new UnsupportedOperationException(this + " does not support index lookup by name"); + } + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 38979411313..2d127eb86cf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -116,7 +116,14 @@ public class IndexedTensor implements Tensor { } } - private double get(long valueIndex) { return values[(int)valueIndex]; } + /** + * Returns the value at the given index by direct lookup. Only use + * if you know the underlying data layout. + * + * @param valueIndex the direct index into the underlying data. + * @throws IndexOutOfBoundsException if index is out of bounds + */ + public double get(long valueIndex) { return values[(int)valueIndex]; } private static long toValueIndex(long[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index c846364df29..1d447ed3eed 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -462,7 +462,7 @@ public class TensorType { return add(new MappedDimension(name)); } - /** Adds the give dimension */ + /** Adds the given dimension */ public Builder dimension(Dimension dimension) { return add(dimension); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 2a52df20108..5dd2cc442aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -22,7 +22,7 @@ public abstract class CompositeTensorFunction extends TensorFunction { /** Evaluates this by first converting it to a primitive function */ @Override - public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return toPrimitive().evaluate(context); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index be323313369..62ee471fcf4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -82,25 +82,29 @@ public class Join extends PrimitiveTensorFunction { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); + return evaluate(a, b, joinedType, combinator); + } + static Tensor evaluate(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { // Choose join algorithm if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) - return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType); + return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType, combinator); else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size()) - return singleSpaceJoin(a, b, joinedType); + return singleSpaceJoin(a, b, joinedType, combinator); else if (a.type().dimensions().containsAll(b.type().dimensions())) - return subspaceJoin(b, a, joinedType, true); + return subspaceJoin(b, a, joinedType, true, combinator); else if (b.type().dimensions().containsAll(a.type().dimensions())) - return subspaceJoin(a, b, joinedType, false); + return subspaceJoin(a, b, joinedType, false, combinator); else - return generalJoin(a, b, joinedType); + return generalJoin(a, b, joinedType, combinator); + } - private boolean hasSingleIndexedDimension(Tensor tensor) { + private static boolean hasSingleIndexedDimension(Tensor tensor) { return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed(); } - private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { + private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) { long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); @@ -111,7 +115,7 @@ public class Join extends PrimitiveTensorFunction { } /** When both tensors have the same dimensions, at most one cell matches a cell in the other tensor */ - private Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType) { + private static Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); @@ -123,14 +127,14 @@ public class Join extends PrimitiveTensorFunction { } /** Join a tensor into a superspace */ - private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { + private static Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) - return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder); + return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder, combinator); else - return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder); + return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder, combinator); } - private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { + private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); @@ -145,16 +149,17 @@ public class Join extends PrimitiveTensorFunction { for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) { IndexedTensor.SubspaceIterator subspaceInSuper = i.next(); joinSubspaces(subspace.valueIterator(), subspace.size(), - subspaceInSuper, subspaceInSuper.size(), - reversedArgumentOrder, builder); + subspaceInSuper, subspaceInSuper.size(), + reversedArgumentOrder, builder, combinator); } return builder.build(); } - private void joinSubspaces(Iterator<Double> subspace, long subspaceSize, - Iterator<Tensor.Cell> superspace, long superspaceSize, - boolean reversedArgumentOrder, IndexedTensor.Builder builder) { + private static void joinSubspaces(Iterator<Double> subspace, long subspaceSize, + Iterator<Tensor.Cell> superspace, long superspaceSize, + boolean reversedArgumentOrder, IndexedTensor.Builder builder, + DoubleBinaryOperator combinator) { long joinedLength = Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { @@ -169,7 +174,7 @@ public class Join extends PrimitiveTensorFunction { } } - private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) { + private static DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) { DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size()); for (int i = 0; i < builder.dimensions(); i++) { String dimensionName = joinedType.dimensions().get(i).name(); @@ -185,7 +190,7 @@ public class Join extends PrimitiveTensorFunction { return builder.build(); } - private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { + private static Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type()); Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) { @@ -194,21 +199,21 @@ public class Join extends PrimitiveTensorFunction { double subspaceValue = subspace.get(subaddress); if ( ! Double.isNaN(subspaceValue)) builder.cell(supercell.getKey(), - reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) - : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) + : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } return builder.build(); } /** Returns the indexes in the superspace type which should be retained to create the subspace type */ - private int[] subspaceIndexes(TensorType supertype, TensorType subtype) { + private static int[] subspaceIndexes(TensorType supertype, TensorType subtype) { int[] subspaceIndexes = new int[subtype.dimensions().size()]; for (int i = 0; i < subtype.dimensions().size(); i++) subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get(); return subspaceIndexes; } - private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { + private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { String[] subspaceLabels = new String[subspaceIndexes.length]; for (int i = 0; i < subspaceIndexes.length; i++) subspaceLabels[i] = superAddress.label(subspaceIndexes[i]); @@ -216,25 +221,25 @@ public class Join extends PrimitiveTensorFunction { } /** Slow join which works for any two tensors */ - private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) { + private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { if (a instanceof IndexedTensor && b instanceof IndexedTensor) - return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType); + return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator); else - return mappedHashJoin(a, b, joinedType); + return mappedHashJoin(a, b, joinedType, combinator); } - private Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) { + private static Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType, DoubleBinaryOperator combinator) { DimensionSizes joinedSize = joinedSize(joinedType, a, b); Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize); int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); - joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder); -// joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder); + joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, builder, combinator); return builder.build(); } - private void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, - int[] aToIndexes, int[] bToIndexes, boolean reversedOrder, Tensor.Builder builder) { + private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, + int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, + DoubleBinaryOperator combinator) { Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()); Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()); @@ -252,15 +257,14 @@ public class Join extends PrimitiveTensorFunction { for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { Tensor.Cell bCell = bSubspace.next(); TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType); - double joinedValue = reversedOrder ? combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) - : combinator.applyAsDouble(aCell.getValue(), bCell.getValue()); + double joinedValue = combinator.applyAsDouble(aCell.getValue(), bCell.getValue()); builder.cell(joinedAddress, joinedValue); } } } } - private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { + private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); for (int i = 0; i < addressType.dimensions().size(); i++) if (retainDimensions.contains(addressType.dimensions().get(i).name())) @@ -269,7 +273,7 @@ public class Join extends PrimitiveTensorFunction { } /** Returns the sizes from the joined sizes which are present in the type argument */ - private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { + private static DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); int dimensionIndex = 0; for (int i = 0; i < joinedType.dimensions().size(); i++) { @@ -279,7 +283,7 @@ public class Join extends PrimitiveTensorFunction { return builder.build(); } - private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) { + private static Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); Tensor.Builder builder = Tensor.Builder.of(joinedType); @@ -288,7 +292,7 @@ public class Join extends PrimitiveTensorFunction { for (Iterator<Tensor.Cell> bIterator = b.cellIterator(); bIterator.hasNext(); ) { Map.Entry<TensorAddress, Double> bCell = bIterator.next(); TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes, - bCell.getKey(), bToIndexes, joinedType); + bCell.getKey(), bToIndexes, joinedType); if (combinedAddress == null) continue; // not combinable builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue())); } @@ -296,10 +300,10 @@ public class Join extends PrimitiveTensorFunction { return builder.build(); } - private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) { + private static Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { TensorType commonDimensionType = commonDimensions(a, b); if (commonDimensionType.dimensions().isEmpty()) { - return mappedGeneralJoin(a, b, joinedType); // fallback + return mappedGeneralJoin(a, b, joinedType, combinator); // fallback } boolean swapTensors = a.size() > b.size(); @@ -351,15 +355,15 @@ public class Join extends PrimitiveTensorFunction { * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ - private int[] mapIndexes(TensorType fromType, TensorType toType) { + static int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); return toIndexes; } - private TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, - TensorType joinedType) { + private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, + TensorType joinedType) { String[] joinedLabels = new String[joinedType.dimensions().size()]; mapContent(a, joinedLabels, aToIndexes); boolean compatible = mapContent(b, joinedLabels, bToIndexes); @@ -373,7 +377,7 @@ public class Join extends PrimitiveTensorFunction { * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { + private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { for (int i = 0; i < from.size(); i++) { int toIndex = indexMap[i]; if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false; @@ -382,11 +386,10 @@ public class Join extends PrimitiveTensorFunction { return true; } - /** * Returns common dimension of a and b as a new tensor type */ - private TensorType commonDimensions(Tensor a, Tensor b) { + private static TensorType commonDimensions(Tensor a, Tensor b) { TensorType.Builder typeBuilder = new TensorType.Builder(); TensorType aType = a.type(); TensorType bType = b.type(); @@ -402,14 +405,14 @@ public class Join extends PrimitiveTensorFunction { return typeBuilder.build(); } - private TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) { + private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) { TensorAddress address = cell.getKey(); String[] labels = new String[indexMap.length]; for (int i = 0; i < labels.length; ++i) { labels[i] = address.label(indexMap[i]); } return TensorAddress.of(labels); - } } + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 332150f957d..54d7710c9dc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -32,7 +32,7 @@ public class Reduce extends PrimitiveTensorFunction { private final List<String> dimensions; private final Aggregator aggregator; - /** Creates a reduce function reducing aLL dimensions */ + /** Creates a reduce function reducing all dimensions */ public Reduce(TensorFunction argument, Aggregator aggregator) { this(argument, aggregator, Collections.emptyList()); } @@ -61,6 +61,7 @@ public class Reduce extends PrimitiveTensorFunction { } public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) { + if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all TensorType.Builder b = new TensorType.Builder(); for (TensorType.Dimension dimension : inputType.dimensions()) { if ( ! reduceDimensions.contains(dimension.name())) @@ -71,6 +72,10 @@ public class Reduce extends PrimitiveTensorFunction { public TensorFunction argument() { return argument; } + Aggregator aggregator() { return aggregator; } + + List<String> dimensions() { return dimensions; } + @Override public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @@ -91,7 +96,7 @@ public class Reduce extends PrimitiveTensorFunction { return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } - private String commaSeparated(List<String> list) { + static String commaSeparated(List<String> list) { StringBuilder b = new StringBuilder(); for (String element : list) b.append(", ").append(element); @@ -100,10 +105,10 @@ public class Reduce extends PrimitiveTensorFunction { @Override public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { - return type(argument.type(context)); + return type(argument.type(context), dimensions); } - private TensorType type(TensorType argumentType) { + private static TensorType type(TensorType argumentType, List<String> dimensions) { if (dimensions.isEmpty()) return TensorType.empty; // means reduce all TensorType.Builder builder = new TensorType.Builder(); for (TensorType.Dimension dimension : argumentType.dimensions()) @@ -114,7 +119,10 @@ public class Reduce extends PrimitiveTensorFunction { @Override public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { - Tensor argument = this.argument.evaluate(context); + return evaluate(this.argument.evaluate(context), dimensions, aggregator); + } + + static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) { if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor"); @@ -122,17 +130,17 @@ public class Reduce extends PrimitiveTensorFunction { // Special case: Reduce all if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor) - return reduceIndexedVector((IndexedTensor)argument); + return reduceIndexedVector((IndexedTensor)argument, aggregator); else - return reduceAllGeneral(argument); + return reduceAllGeneral(argument, aggregator); - TensorType reducedType = type(argument.type()); + TensorType reducedType = type(argument.type(), dimensions); // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); - TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType); + TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType, dimensions); aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); aggregatingCells.get(reducedAddress).aggregate(cell.getValue()); } @@ -141,11 +149,12 @@ public class Reduce extends PrimitiveTensorFunction { reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue()); return reducedBuilder.build(); + } - private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) { + private static TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType, List<String> dimensions) { Set<Integer> indexesToRemove = new HashSet<>(); - for (String dimensionToRemove : this.dimensions) + for (String dimensionToRemove : dimensions) indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get()); String[] reducedLabels = new String[reducedType.dimensions().size()]; @@ -156,23 +165,23 @@ public class Reduce extends PrimitiveTensorFunction { return TensorAddress.of(reducedLabels); } - private Tensor reduceAllGeneral(Tensor argument) { + private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) valueAggregator.aggregate(i.next()); return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); } - private Tensor reduceIndexedVector(IndexedTensor argument) { + private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (int i = 0; i < argument.dimensionSizes().size(0); i++) valueAggregator.aggregate(argument.get(i)); return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); } - private static abstract class ValueAggregator { + static abstract class ValueAggregator { - private static ValueAggregator ofType(Aggregator aggregator) { + static ValueAggregator ofType(Aggregator aggregator) { switch (aggregator) { case avg : return new AvgAggregator(); case count : return new CountAggregator(); @@ -191,6 +200,9 @@ public class Reduce extends PrimitiveTensorFunction { /** Returns the value aggregated by this */ public abstract double aggregatedValue(); + /** Resets the aggregator */ + public abstract void reset(); + } private static class AvgAggregator extends ValueAggregator { @@ -209,6 +221,11 @@ public class Reduce extends PrimitiveTensorFunction { return valueSum / valueCount; } + @Override + public void reset() { + valueCount = 0; + valueSum = 0.0; + } } private static class CountAggregator extends ValueAggregator { @@ -225,6 +242,10 @@ public class Reduce extends PrimitiveTensorFunction { return valueCount; } + @Override + public void reset() { + valueCount = 0; + } } private static class ProdAggregator extends ValueAggregator { @@ -241,6 +262,10 @@ public class Reduce extends PrimitiveTensorFunction { return valueProd; } + @Override + public void reset() { + valueProd = 1.0; + } } private static class SumAggregator extends ValueAggregator { @@ -257,6 +282,10 @@ public class Reduce extends PrimitiveTensorFunction { return valueSum; } + @Override + public void reset() { + valueSum = 0.0; + } } private static class MaxAggregator extends ValueAggregator { @@ -274,6 +303,10 @@ public class Reduce extends PrimitiveTensorFunction { return maxValue; } + @Override + public void reset() { + maxValue = Double.MIN_VALUE; + } } private static class MinAggregator extends ValueAggregator { @@ -291,6 +324,11 @@ public class Reduce extends PrimitiveTensorFunction { return minValue; } + @Override + public void reset() { + minValue = Double.MAX_VALUE; + } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java new file mode 100644 index 00000000000..b268e33b418 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -0,0 +1,372 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.DimensionSizes; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Arrays; +import java.util.List; +import java.util.function.DoubleBinaryOperator; +import java.util.stream.Collectors; + +/** + * An optimization for tensor expressions where a join immediately follows a + * reduce. Evaluating this as one operation is significantly more efficient + * than evaluating each separately. + * + * This implementation optimizes the case where the reduce is done on the same + * dimensions as the join. A particularly efficient evaluation is done if there + * is one common dimension that is joined and reduced on, which is a common + * case as it covers vector and matrix like multiplications. + * + * @author lesters + */ +public class ReduceJoin extends CompositeTensorFunction { + + private final TensorFunction argumentA, argumentB; + private final DoubleBinaryOperator combinator; + private final Reduce.Aggregator aggregator; + private final List<String> dimensions; + + public ReduceJoin(Reduce reduce, Join join) { + this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions()); + } + + public ReduceJoin(TensorFunction argumentA, + TensorFunction argumentB, + DoubleBinaryOperator combinator, + Reduce.Aggregator aggregator, + List<String> dimensions) { + this.argumentA = argumentA; + this.argumentB = argumentB; + this.combinator = combinator; + this.aggregator = aggregator; + this.dimensions = ImmutableList.copyOf(dimensions); + } + + @Override + public List<TensorFunction> arguments() { + return ImmutableList.of(argumentA, argumentB); + } + + @Override + public TensorFunction withArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 2) + throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + arguments.size()); + return new ReduceJoin(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions); + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + Join join = new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); + return new Reduce(join, aggregator, dimensions); + } + + @Override + public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor a = argumentA.evaluate(context); + Tensor b = argumentB.evaluate(context); + TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); + + if (canOptimize(a, b)) { + return evaluate((IndexedTensor)a, (IndexedTensor)b, joinedType); + } + return Reduce.evaluate(Join.evaluate(a, b, joinedType, combinator), dimensions, aggregator); + } + + /** + * Tests whether or not the reduce is over the join dimensions. The + * remaining logic in this class assumes this to be true. + * + * If no dimensions are given, the join must be on all tensor dimensions. + * + * @return {@code true} if the implementation can optimize evaluation + * given the two tensors. + */ + public boolean canOptimize(Tensor a, Tensor b) { + if (a.type().dimensions().isEmpty() || b.type().dimensions().isEmpty()) // TODO: support scalars + return false; + if ( ! (a instanceof IndexedTensor)) + return false; + if ( ! (a.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + return false; + if ( ! (b instanceof IndexedTensor)) + return false; + if ( ! (b.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + return false; + + TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b); + if (dimensions.isEmpty()) { + if (a.type().dimensions().size() != commonDimensions.dimensions().size()) + return false; + if (b.type().dimensions().size() != commonDimensions.dimensions().size()) + return false; + } else { + for (TensorType.Dimension dimension : commonDimensions.dimensions()) { + if (!dimensions.contains(dimension.name())) + return false; + } + } + return true; + } + + /** + * Evaluates the reduce-join. Special handling for common cases where the + * reduce dimension is the innermost dimension in both tensors. + */ + private Tensor evaluate(IndexedTensor a, IndexedTensor b, TensorType joinedType) { + TensorType reducedType = Reduce.outputType(joinedType, dimensions); + + if (reduceDimensionIsInnermost(a, b)) { + if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 1) { + return vectorVectorProduct(a, b, reducedType); + } + if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 2) { + return vectorMatrixProduct(a, b, reducedType, false); + } + if (a.type().dimensions().size() == 2 && b.type().dimensions().size() == 1) { + return vectorMatrixProduct(b, a, reducedType, true); + } + if (a.type().dimensions().size() == 2 && b.type().dimensions().size() == 2) { + return matrixMatrixProduct(a, b, reducedType); + } + } + return evaluateGeneral(a, b, reducedType); + } + + private Tensor vectorVectorProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) { + if ( a.type().dimensions().size() != 1 || b.type().dimensions().size() != 1) { + throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-vector product"); + } + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType); + long commonSize = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); + + Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator); + for (int ic = 0; ic < commonSize; ++ic) { + double va = a.get(ic); + double vb = b.get(ic); + agg.aggregate(combinator.applyAsDouble(va, vb)); + } + builder.cellByDirectIndex(0, agg.aggregatedValue()); + return builder.build(); + } + + private Tensor vectorMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType, boolean swapped) { + if ( a.type().dimensions().size() != 1 || b.type().dimensions().size() != 2) { + throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-matrix product"); + } + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType); + DimensionSizes sizesA = a.dimensionSizes(); + DimensionSizes sizesB = b.dimensionSizes(); + + Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator); + for (int ib = 0; ib < sizesB.size(0); ++ib) { + agg.reset(); + for (int ic = 0; ic < Math.min(sizesA.size(0), sizesB.size(1)); ++ic) { + double va = a.get(ic); + double vb = b.get(ib * sizesB.size(1) + ic); + double result = swapped ? combinator.applyAsDouble(vb, va) : combinator.applyAsDouble(va, vb); + agg.aggregate(result); + } + builder.cellByDirectIndex(ib, agg.aggregatedValue()); + } + return builder.build(); + } + + private Tensor matrixMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) { + if ( a.type().dimensions().size() != 2 || b.type().dimensions().size() != 2) { + throw new IllegalArgumentException("Wrong dimension sizes for tensors for matrix-matrix product"); + } + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType); + DimensionSizes sizesA = a.dimensionSizes(); + DimensionSizes sizesB = b.dimensionSizes(); + int iaToReduced = reducedType.indexOfDimension(a.type().dimensions().get(0).name()).get(); + int ibToReduced = reducedType.indexOfDimension(b.type().dimensions().get(0).name()).get(); + long strideA = iaToReduced < ibToReduced ? sizesB.size(0) : 1; + long strideB = ibToReduced < iaToReduced ? sizesA.size(0) : 1; + + Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator); + for (int ia = 0; ia < sizesA.size(0); ++ia) { + for (int ib = 0; ib < sizesB.size(0); ++ib) { + agg.reset(); + for (int ic = 0; ic < Math.min(sizesA.size(1), sizesB.size(1)); ++ic) { + double va = a.get(ia * sizesA.size(1) + ic); + double vb = b.get(ib * sizesB.size(1) + ic); + agg.aggregate(combinator.applyAsDouble(va, vb)); + } + builder.cellByDirectIndex(ia * strideA + ib * strideB, agg.aggregatedValue()); + } + } + return builder.build(); + } + + private Tensor evaluateGeneral(IndexedTensor a, IndexedTensor b, TensorType reducedType) { + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType); + TensorType onlyInA = Reduce.outputType(a.type(), dimensions); + TensorType onlyInB = Reduce.outputType(b.type(), dimensions); + TensorType common = dimensionsInCommon(a, b); + + // pre-calculate strides for each index position + long[] stridesA = strides(a.type()); + long[] stridesB = strides(b.type()); + long[] stridesResult = strides(reducedType); + + // mapping of dimension indexes + int[] mapOnlyAToA = Join.mapIndexes(onlyInA, a.type()); + int[] mapCommonToA = Join.mapIndexes(common, a.type()); + int[] mapOnlyBToB = Join.mapIndexes(onlyInB, b.type()); + int[] mapCommonToB = Join.mapIndexes(common, b.type()); + int[] mapOnlyAToResult = Join.mapIndexes(onlyInA, reducedType); + int[] mapOnlyBToResult = Join.mapIndexes(onlyInB, reducedType); + + // TODO: refactor with code in IndexedTensor and Join + + MultiDimensionIterator ic = new MultiDimensionIterator(common); + Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator); + for (MultiDimensionIterator ia = new MultiDimensionIterator(onlyInA); ia.hasNext(); ia.next()) { + for (MultiDimensionIterator ib = new MultiDimensionIterator(onlyInB); ib.hasNext(); ib.next()) { + agg.reset(); + for (ic.reset(); ic.hasNext(); ic.next()) { + double va = a.get(toDirectIndex(ia, ic, stridesA, mapOnlyAToA, mapCommonToA)); + double vb = b.get(toDirectIndex(ib, ic, stridesB, mapOnlyBToB, mapCommonToB)); + agg.aggregate(combinator.applyAsDouble(va, vb)); + } + builder.cellByDirectIndex(toDirectIndex(ia, ib, stridesResult, mapOnlyAToResult, mapOnlyBToResult), + agg.aggregatedValue()); + } + } + return builder.build(); + } + + private long toDirectIndex(MultiDimensionIterator iter, MultiDimensionIterator common, long[] strides, int[] map, int[] commonmap) { + long directIndex = 0; + for (int i = 0; i < iter.length(); ++i) { + directIndex += strides[map[i]] * iter.iterator[i]; + } + for (int i = 0; i < common.length(); ++i) { + directIndex += strides[commonmap[i]] * common.iterator[i]; + } + return directIndex; + } + + private long[] strides(TensorType type) { + long[] strides = new long[type.dimensions().size()]; + if (strides.length > 0) { + long previous = 1; + strides[strides.length - 1] = previous; + for (int i = strides.length - 2; i >= 0; --i) { + strides[i] = previous * type.dimensions().get(i + 1).size().get(); + previous = strides[i]; + } + } + return strides; + } + + private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension aDim : a.type().dimensions()) { + for (TensorType.Dimension bDim : b.type().dimensions()) { + if (aDim.name().equals(bDim.name())) { + if ( ! aDim.size().isPresent()) { + builder.set(aDim); + } else if ( ! bDim.size().isPresent()) { + builder.set(bDim); + } else { + builder.set(aDim.size().get() < bDim.size().get() ? aDim : bDim); // minimum size of dimension + } + } + } + } + return builder.build(); + } + + /** + * Tests if there is exactly one reduce dimension and it is the innermost + * dimension in both tensors. + */ + private boolean reduceDimensionIsInnermost(Tensor a, Tensor b) { + List<String> reducingDimensions = dimensions; + if (reducingDimensions.isEmpty()) { + reducingDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b).dimensions().stream() + .map(TensorType.Dimension::name) + .collect(Collectors.toList()); + } + if (reducingDimensions.size() != 1) { + return false; + } + String dimension = reducingDimensions.get(0); + int indexInA = a.type().indexOfDimension(dimension).orElseThrow(() -> + new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor A.")); + if (indexInA != (a.type().dimensions().size() - 1)) { + return false; + } + int indexInB = b.type().indexOfDimension(dimension).orElseThrow(() -> + new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor B.")); + if (indexInB < (b.type().dimensions().size() - 1)) { + return false; + } + return true; + } + + @Override + public String toString(ToStringContext context) { + return "reduce_join(" + argumentA.toString(context) + ", " + + argumentB.toString(context) + ", " + + combinator + ", " + + aggregator + + Reduce.commaSeparated(dimensions) + ")"; + } + + private static class MultiDimensionIterator { + + private long[] bounds; + private long[] iterator; + private int remaining; + + MultiDimensionIterator(TensorType type) { + bounds = new long[type.dimensions().size()]; + iterator = new long[type.dimensions().size()]; + for (int i = 0; i < bounds.length; ++i) { + bounds[i] = type.dimensions().get(i).size().get(); + } + reset(); + } + + public int length() { + return iterator.length; + } + + public boolean hasNext() { + return remaining > 0; + } + + public void reset() { + remaining = 1; + for (int i = iterator.length - 1; i >= 0; --i) { + iterator[i] = 0; + remaining *= bounds[i]; + } + } + + public void next() { + for (int i = iterator.length - 1; i >= 0; --i) { + iterator[i] += 1; + if (iterator[i] < bounds[i]) { + break; + } + iterator[i] = 0; + } + remaining -= 1; + } + + public String toString() { + return Arrays.toString(iterator); + } + } + +} |