diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-29 12:53:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-29 12:53:37 +0200 |
commit | bfc4feb4f5b9b2cec76bd08027cdf3c2ca339f39 (patch) | |
tree | 3bd0250db894ae103c85062556b4b34c59e38e84 /searchlib | |
parent | 44313db093c7954c6fa979296832ed44dd28991d (diff) | |
parent | 35da647c2c32ea3a81b5c65bd440af1dc59b1b3c (diff) |
Merge pull request #7145 from vespa-engine/lesters/add-java-reduce-join-optimization
Add reduce-join optimization in Java
Diffstat (limited to 'searchlib')
5 files changed, 272 insertions, 5 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java index 7060cfc2132..84a90ee64c2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTOptimizer; +import com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization.TensorOptimizer; /** * This class will perform various optimizations on the ranking expressions. Clients using optimized expressions @@ -32,8 +33,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTOpt public class ExpressionOptimizer { private GBDTOptimizer gbdtOptimizer = new GBDTOptimizer(); - private GBDTForestOptimizer gbdtForestOptimizer = new GBDTForestOptimizer(); + private TensorOptimizer tensorOptimizer = new TensorOptimizer(); /** Gets an optimizer instance used by this by class name, or null if the optimizer is not known */ public Optimizer getOptimizer(Class<?> clazz) { @@ -41,6 +42,8 @@ public class ExpressionOptimizer { return gbdtOptimizer; if (clazz == gbdtForestOptimizer.getClass()) return gbdtForestOptimizer; + if (clazz == tensorOptimizer.getClass()) + return tensorOptimizer; return null; } @@ -49,6 +52,7 @@ public class ExpressionOptimizer { // Note: Order of optimizations matter gbdtOptimizer.optimize(expression, contextIndex, report); gbdtForestOptimizer.optimize(expression, contextIndex, report); + tensorOptimizer.optimize(expression, contextIndex, report); return report; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java new file mode 100644 index 00000000000..63cea371d14 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java @@ -0,0 +1,85 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; +import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; +import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ReduceJoin; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; + +/** + * Recognizes and optimizes tensor expressions. + * + * @author lesters + */ +public class TensorOptimizer extends Optimizer { + + private OptimizationReport report; + + @Override + public void optimize(RankingExpression expression, ContextIndex context, OptimizationReport report) { + if (!isEnabled()) return; + this.report = report; + expression.setRoot(optimize(expression.getRoot(), context)); + report.note("Tensor expression optimization done"); + } + + private ExpressionNode optimize(ExpressionNode node, ContextIndex context) { + node = optimizeReduceJoin(node); + if (node instanceof CompositeNode) { + return optimizeChildren((CompositeNode)node, context); + } + return node; + } + + private ExpressionNode optimizeChildren(CompositeNode node, ContextIndex context) { + List<ExpressionNode> children = node.children(); + List<ExpressionNode> optimizedChildren = new ArrayList<>(children.size()); + for (ExpressionNode child : children) + optimizedChildren.add(optimize(child, context)); + return node.setChildren(optimizedChildren); + } + + /** + * Recognized a reduce followed by a join. In many cases, chunking these + * two operations together is significantly more efficient than evaluating + * each on its own, avoiding the cost of a temporary tensor. + * + * Note that this does not guarantee that the optimization is performed. + * The ReduceJoin class determines whether or not the arguments are + * compatible with the optimization. + */ + private ExpressionNode optimizeReduceJoin(ExpressionNode node) { + if ( ! (node instanceof TensorFunctionNode)) { + return node; + } + TensorFunction function = ((TensorFunctionNode) node).function(); + if ( ! (function instanceof Reduce)) { + return node; + } + List<ExpressionNode> children = ((TensorFunctionNode) node).children(); + if (children.size() != 1) { + return node; + } + ExpressionNode child = children.get(0); + if ( ! (child instanceof TensorFunctionNode)) { + return node; + } + TensorFunction argument = ((TensorFunctionNode) child).function(); + if (argument instanceof Join) { + report.incMetric("Replaced reduce->join", 1); + return new TensorFunctionNode(new ReduceJoin((Reduce)function, (Join)argument)); + } + return node; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index de98b01287e..3a3410aeebb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -12,6 +12,7 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.Deque; import java.util.List; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; @@ -90,7 +91,47 @@ public class LambdaFunctionNode extends CompositeNode { if (arguments.size() > 2) throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " + "Must have at most two argument " + " but has " + arguments); - return new DoubleBinaryLambda(); + + // Optimization: if possible, calculate directly rather than creating a context and evaluating the expression + return getDirectEvaluator().orElseGet(DoubleBinaryLambda::new); + } + + private Optional<DoubleBinaryOperator> getDirectEvaluator() { + if ( ! (functionExpression instanceof ArithmeticNode)) { + return Optional.empty(); + } + ArithmeticNode node = (ArithmeticNode) functionExpression; + if ( ! (node.children().get(0) instanceof ReferenceNode) || ! (node.children().get(1) instanceof ReferenceNode)) { + return Optional.empty(); + } + if (node.operators().size() != 1) { + return Optional.empty(); + } + ArithmeticOperator operator = node.operators().get(0); + switch (operator) { + case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0); + case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0); + case PLUS: return asFunctionExpression((left, right) -> left + right); + case MINUS: return asFunctionExpression((left, right) -> left - right); + case MULTIPLY: return asFunctionExpression((left, right) -> left * right); + case DIVIDE: return asFunctionExpression((left, right) -> left / right); + case MODULO: return asFunctionExpression((left, right) -> left % right); + case POWER: return asFunctionExpression(Math::pow); + } + return Optional.empty(); + } + + private Optional<DoubleBinaryOperator> asFunctionExpression(DoubleBinaryOperator operator) { + return Optional.of(new DoubleBinaryOperator() { + @Override + public double applyAsDouble(double left, double right) { + return operator.applyAsDouble(left, right); + } + @Override + public String toString() { + return LambdaFunctionNode.this.toString(); + } + }); } private class DoubleUnaryLambda implements DoubleUnaryOperator { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java new file mode 100644 index 00000000000..f29083bddc9 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java @@ -0,0 +1,116 @@ +package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext; +import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ReduceJoin; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author lesters + */ +public class TensorOptimizerTestCase { + + @Test + public void testReduceJoinOptimization() throws ParseException { + assertWillOptimize("d0[3]", "d0[3]"); + assertWillOptimize("d0[1]", "d0[1]", "d0"); + assertWillOptimize("d0[2]", "d0[2]", "d0"); + assertWillOptimize("d0[1]", "d0[3]", "d0"); + assertWillOptimize("d0[3]", "d0[3]", "d0"); + assertWillOptimize("d0[3]", "d0[3],d1[2]", "d0"); + assertWillOptimize("d0[3],d1[2]", "d0[3]", "d0"); + assertWillOptimize("d1[3]", "d0[2],d1[3]", "d1"); + assertWillOptimize("d0[2],d1[3]", "d1[3]", "d1"); + assertWillOptimize("d0[2],d2[2]", "d1[3],d2[2]", "d2"); + assertWillOptimize("d1[2],d2[2]", "d0[3],d2[2]", "d2"); + assertWillOptimize("d0[1],d2[2]", "d1[3],d2[4]", "d2"); + assertWillOptimize("d0[2],d2[2]", "d1[3],d2[4]", "d2"); + assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]"); + assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]", "d0,d1"); + assertWillOptimize("d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3"); + assertWillOptimize("d0[1],d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3"); + assertWillOptimize("d0[1],d1[2],d2[3]", "d2[3],d3[4],d4[5]", "d2"); + assertWillOptimize("d0[1],d1[2],d2[3]", "d1[2],d2[3],d4[4]", "d1,d2"); + assertWillOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]"); + assertWillOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]", "d0,d1,d2"); + + // Will not currently use reduce-join optimization + assertCantOptimize("d0[2],d1[3]", "d1[3]", "d0"); // reducing on a dimension not joining on + assertCantOptimize("d0[1],d1[2]", "d1[2],d2[3]", "d2"); // same + assertCantOptimize("d0[3]", "d0[3],d1[2]"); // reducing on more then we are combining + assertCantOptimize("d0[1],d2[3]", "d1[2],d2[3]"); // same + assertCantOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]", "d1,d2"); // reducing on less then joining on + } + + private void assertWillOptimize(String aType, String bType) throws ParseException { + assertWillOptimize(aType, bType, "", "sum"); + } + + private void assertWillOptimize(String aType, String bType, String reduceDim) throws ParseException { + assertWillOptimize(aType, bType, reduceDim, "sum"); + } + + private void assertWillOptimize(String aType, String bType, String reduceDim, String aggregator) throws ParseException { + assertReduceJoin(aType, bType, reduceDim, aggregator, true); + } + + private void assertCantOptimize(String aType, String bType) throws ParseException { + assertCantOptimize(aType, bType, "", "sum"); + } + + private void assertCantOptimize(String aType, String bType, String reduceDim) throws ParseException { + assertCantOptimize(aType, bType, reduceDim, "sum"); + } + + private void assertCantOptimize(String aType, String bType, String reduceDim, String aggregator) throws ParseException { + assertReduceJoin(aType, bType, reduceDim, aggregator, false); + } + + private void assertReduceJoin(String aType, String bType, String reduceDim, String aggregator, boolean assertOptimize) throws ParseException { + Tensor a = generateRandomTensor(aType); + Tensor b = generateRandomTensor(bType); + RankingExpression expression = generateRankingExpression(reduceDim, aggregator); + assert ((TensorFunctionNode)expression.getRoot()).function() instanceof Reduce; + + ArrayContext context = generateContext(a, b, expression); + Tensor result = expression.evaluate(context).asTensor(); + + ExpressionOptimizer optimizer = new ExpressionOptimizer(); + OptimizationReport report = optimizer.optimize(expression, context); + assertEquals(1, report.getMetric("Replaced reduce->join")); + assert ((TensorFunctionNode)expression.getRoot()).function() instanceof ReduceJoin; + + assertEquals(result, expression.evaluate(context).asTensor()); + assertEquals(assertOptimize, ((ReduceJoin)((TensorFunctionNode)expression.getRoot()).function()).canOptimize(a, b)); + } + + private RankingExpression generateRankingExpression(String reduceDim, String aggregator) throws ParseException { + String dimensions = ""; + if (reduceDim.length() > 0) { + dimensions = ", " + reduceDim; + } + return new RankingExpression("reduce(join(a, b, f(a,b)(a * b)), " + aggregator + dimensions + ")"); + } + + private ArrayContext generateContext(Tensor a, Tensor b, RankingExpression expression) { + ArrayContext context = new ArrayContext(expression); + context.put("a", new TensorValue(a)); + context.put("b", new TensorValue(b)); + return context; + } + + private Tensor generateRandomTensor(String type) { + return Tensor.random(TensorType.fromSpec("tensor(" + type + ")")); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java index 5447e5240f7..19b3329e8a4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java @@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; +import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; @@ -50,7 +52,11 @@ public class TestableTensorFlowModel { model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); - Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); + RankingExpression expression = model.expressions().get(operationName); + ExpressionOptimizer optimizer = new ExpressionOptimizer(); + optimizer.optimize(expression, (ContextIndex)context); + + Tensor vespaResult = expression.evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", tfResult.sum().asDouble(), vespaResult.sum().asDouble(), delta); } @@ -64,7 +70,11 @@ public class TestableTensorFlowModel { model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); - Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); + RankingExpression expression = model.expressions().get(operationName); + ExpressionOptimizer optimizer = new ExpressionOptimizer(); + optimizer.optimize(expression, (ContextIndex)context); + + Tensor vespaResult = expression.evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); } @@ -82,7 +92,7 @@ public class TestableTensorFlowModel { } private Context contextFrom(ImportedModel result) { - MapContext context = new MapContext(); + TestableModelContext context = new TestableModelContext(); result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); return context; @@ -118,4 +128,15 @@ public class TestableTensorFlowModel { } } + private static class TestableModelContext extends MapContext implements ContextIndex { + @Override + public int size() { + return bindings().size(); + } + @Override + public int getIndex(String name) { + throw new UnsupportedOperationException(this + " does not support index lookup by name"); + } + } + } |