summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-29 12:53:37 +0200
committerGitHub <noreply@github.com>2018-09-29 12:53:37 +0200
commitbfc4feb4f5b9b2cec76bd08027cdf3c2ca339f39 (patch)
tree3bd0250db894ae103c85062556b4b34c59e38e84 /searchlib
parent44313db093c7954c6fa979296832ed44dd28991d (diff)
parent35da647c2c32ea3a81b5c65bd440af1dc59b1b3c (diff)
Merge pull request #7145 from vespa-engine/lesters/add-java-reduce-join-optimization
Add reduce-join optimization in Java
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java85
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java43
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java116
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java27
5 files changed, 272 insertions, 5 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java
index 7060cfc2132..84a90ee64c2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java
@@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization.TensorOptimizer;
/**
* This class will perform various optimizations on the ranking expressions. Clients using optimized expressions
@@ -32,8 +33,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTOpt
public class ExpressionOptimizer {
private GBDTOptimizer gbdtOptimizer = new GBDTOptimizer();
-
private GBDTForestOptimizer gbdtForestOptimizer = new GBDTForestOptimizer();
+ private TensorOptimizer tensorOptimizer = new TensorOptimizer();
/** Gets an optimizer instance used by this by class name, or null if the optimizer is not known */
public Optimizer getOptimizer(Class<?> clazz) {
@@ -41,6 +42,8 @@ public class ExpressionOptimizer {
return gbdtOptimizer;
if (clazz == gbdtForestOptimizer.getClass())
return gbdtForestOptimizer;
+ if (clazz == tensorOptimizer.getClass())
+ return tensorOptimizer;
return null;
}
@@ -49,6 +52,7 @@ public class ExpressionOptimizer {
// Note: Order of optimizations matter
gbdtOptimizer.optimize(expression, contextIndex, report);
gbdtForestOptimizer.optimize(expression, contextIndex, report);
+ tensorOptimizer.optimize(expression, contextIndex, report);
return report;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
new file mode 100644
index 00000000000..63cea371d14
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
@@ -0,0 +1,85 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
+import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ReduceJoin;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Recognizes and optimizes tensor expressions.
+ *
+ * @author lesters
+ */
+public class TensorOptimizer extends Optimizer {
+
+    /**
+     * Rewrites the expression tree of the given ranking expression in place, replacing
+     * recognized tensor function patterns with fused equivalents.
+     * Does nothing if this optimizer is disabled.
+     *
+     * @param expression the expression whose root is replaced by the optimized tree
+     * @param context the context index of the expression; only passed along, not consulted by the current rewrites
+     * @param report receives a "Replaced reduce->join" metric increment per rewrite and a note on completion
+     */
+    @Override
+    public void optimize(RankingExpression expression, ContextIndex context, OptimizationReport report) {
+        if ( ! isEnabled()) return;
+        // The report is passed down explicitly rather than kept in a field, so that this instance
+        // holds no per-call state and does not retain the report after returning
+        expression.setRoot(optimize(expression.getRoot(), context, report));
+        report.note("Tensor expression optimization done");
+    }
+
+    /** Optimizes the given node first, then (if it has children) each of its children recursively */
+    private ExpressionNode optimize(ExpressionNode node, ContextIndex context, OptimizationReport report) {
+        node = optimizeReduceJoin(node, report);
+        if (node instanceof CompositeNode) {
+            return optimizeChildren((CompositeNode)node, context, report);
+        }
+        return node;
+    }
+
+    /** Returns the node produced by replacing the children of the given node by their optimized versions */
+    private ExpressionNode optimizeChildren(CompositeNode node, ContextIndex context, OptimizationReport report) {
+        List<ExpressionNode> children = node.children();
+        List<ExpressionNode> optimizedChildren = new ArrayList<>(children.size());
+        for (ExpressionNode child : children)
+            optimizedChildren.add(optimize(child, context, report));
+        return node.setChildren(optimizedChildren);
+    }
+
+    /**
+     * Recognizes a reduce applied directly to the result of a join. In many cases, chunking these
+     * two operations together is significantly more efficient than evaluating
+     * each on its own, avoiding the cost of a temporary tensor.
+     *
+     * Note that this does not guarantee that the optimization is performed.
+     * The ReduceJoin class determines whether or not the arguments are
+     * compatible with the optimization.
+     *
+     * @param node the node to inspect
+     * @param report the report in which each replacement is counted
+     * @return a node wrapping a ReduceJoin if the pattern matched, otherwise the given node unchanged
+     */
+    private ExpressionNode optimizeReduceJoin(ExpressionNode node, OptimizationReport report) {
+        if ( ! (node instanceof TensorFunctionNode)) {
+            return node;
+        }
+        TensorFunction function = ((TensorFunctionNode) node).function();
+        if ( ! (function instanceof Reduce)) {
+            return node;
+        }
+        List<ExpressionNode> children = ((TensorFunctionNode) node).children();
+        if (children.size() != 1) {
+            return node;
+        }
+        ExpressionNode child = children.get(0);
+        if ( ! (child instanceof TensorFunctionNode)) {
+            return node;
+        }
+        TensorFunction argument = ((TensorFunctionNode) child).function();
+        if (argument instanceof Join) {
+            report.incMetric("Replaced reduce->join", 1);
+            return new TensorFunctionNode(new ReduceJoin((Reduce)function, (Join)argument));
+        }
+        return node;
+    }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index de98b01287e..3a3410aeebb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -12,6 +12,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
+import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
@@ -90,7 +91,47 @@ public class LambdaFunctionNode extends CompositeNode {
if (arguments.size() > 2)
throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " +
"Must have at most two argument " + " but has " + arguments);
- return new DoubleBinaryLambda();
+
+ // Optimization: if possible, calculate directly rather than creating a context and evaluating the expression
+ return getDirectEvaluator().orElseGet(DoubleBinaryLambda::new);
+ }
+
+    /**
+     * Returns an operator computing this lambda directly on its two operands, without
+     * creating a context and evaluating the expression, if the lambda has a form where that is safe.
+     *
+     * The shortcut applies "left op right" positionally. That is only what the lambda means
+     * when its body is exactly "firstArgument op secondArgument", so anything else (swapped or
+     * repeated arguments such as f(a,b)(b - a) or f(a,b)(a * a), fewer than two arguments,
+     * references to other names, or more than one operator) yields empty and leaves
+     * evaluation to the general lambda.
+     *
+     * @return the direct evaluator, or empty if this lambda must be evaluated the general way
+     */
+    private Optional<DoubleBinaryOperator> getDirectEvaluator() {
+        if ( ! (functionExpression instanceof ArithmeticNode)) {
+            return Optional.empty();
+        }
+        ArithmeticNode node = (ArithmeticNode) functionExpression;
+        // Validate the shape before indexing into the children
+        if (node.operators().size() != 1 || node.children().size() != 2) {
+            return Optional.empty();
+        }
+        if (arguments.size() != 2
+            || ! isPlainReferenceTo(arguments.get(0), node.children().get(0))
+            || ! isPlainReferenceTo(arguments.get(1), node.children().get(1))) {
+            return Optional.empty();
+        }
+        ArithmeticOperator operator = node.operators().get(0);
+        switch (operator) {
+            case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
+            case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
+            case PLUS: return asFunctionExpression((left, right) -> left + right);
+            case MINUS: return asFunctionExpression((left, right) -> left - right);
+            case MULTIPLY: return asFunctionExpression((left, right) -> left * right);
+            case DIVIDE: return asFunctionExpression((left, right) -> left / right);
+            case MODULO: return asFunctionExpression((left, right) -> left % right);
+            case POWER: return asFunctionExpression(Math::pow);
+        }
+        // Any other operator: no direct form, use the general lambda
+        return Optional.empty();
+    }
+
+    /** Returns whether the given node is a reference, taking no arguments of its own, to exactly the given name */
+    private boolean isPlainReferenceTo(String name, ExpressionNode node) {
+        if ( ! (node instanceof ReferenceNode)) return false;
+        ReferenceNode reference = (ReferenceNode) node;
+        return reference.getName().equals(name) && reference.getArguments().expressions().isEmpty();
+    }
+
+ private Optional<DoubleBinaryOperator> asFunctionExpression(DoubleBinaryOperator operator) {
+ return Optional.of(new DoubleBinaryOperator() {
+ @Override
+ public double applyAsDouble(double left, double right) {
+ return operator.applyAsDouble(left, right);
+ }
+ @Override
+ public String toString() {
+ return LambdaFunctionNode.this.toString();
+ }
+ });
}
private class DoubleUnaryLambda implements DoubleUnaryOperator {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java
new file mode 100644
index 00000000000..f29083bddc9
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java
@@ -0,0 +1,116 @@
+package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ReduceJoin;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author lesters
+ */
+public class TensorOptimizerTestCase {
+
+    /**
+     * Verifies, for a range of argument types and reduce dimensions, that reduce(join(a, b))
+     * is always rewritten to a ReduceJoin producing the same result, and that the ReduceJoin
+     * reports the expected ability to use its optimized evaluation.
+     */
+    @Test
+    public void testReduceJoinOptimization() throws ParseException {
+        assertWillOptimize("d0[3]", "d0[3]");
+        assertWillOptimize("d0[1]", "d0[1]", "d0");
+        assertWillOptimize("d0[2]", "d0[2]", "d0");
+        assertWillOptimize("d0[1]", "d0[3]", "d0");
+        assertWillOptimize("d0[3]", "d0[3]", "d0");
+        assertWillOptimize("d0[3]", "d0[3],d1[2]", "d0");
+        assertWillOptimize("d0[3],d1[2]", "d0[3]", "d0");
+        assertWillOptimize("d1[3]", "d0[2],d1[3]", "d1");
+        assertWillOptimize("d0[2],d1[3]", "d1[3]", "d1");
+        assertWillOptimize("d0[2],d2[2]", "d1[3],d2[2]", "d2");
+        assertWillOptimize("d1[2],d2[2]", "d0[3],d2[2]", "d2");
+        assertWillOptimize("d0[1],d2[2]", "d1[3],d2[4]", "d2");
+        assertWillOptimize("d0[2],d2[2]", "d1[3],d2[4]", "d2");
+        assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]");
+        assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]", "d0,d1");
+        assertWillOptimize("d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3");
+        assertWillOptimize("d0[1],d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3");
+        assertWillOptimize("d0[1],d1[2],d2[3]", "d2[3],d3[4],d4[5]", "d2");
+        assertWillOptimize("d0[1],d1[2],d2[3]", "d1[2],d2[3],d4[4]", "d1,d2");
+        assertWillOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]");
+        assertWillOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]", "d0,d1,d2");
+
+        // Will not currently use reduce-join optimization
+        assertCantOptimize("d0[2],d1[3]", "d1[3]", "d0"); // reducing on a dimension not joining on
+        assertCantOptimize("d0[1],d1[2]", "d1[2],d2[3]", "d2"); // same
+        assertCantOptimize("d0[3]", "d0[3],d1[2]"); // reducing on more than we are combining
+        assertCantOptimize("d0[1],d2[3]", "d1[2],d2[3]"); // same
+        assertCantOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]", "d1,d2"); // reducing on less than joining on
+    }
+
+    /** Asserts that a sum over all dimensions of the join of the given types can use the optimization */
+    private void assertWillOptimize(String aType, String bType) throws ParseException {
+        assertWillOptimize(aType, bType, "", "sum");
+    }
+
+    /** Asserts that a sum over the given dimensions of the join of the given types can use the optimization */
+    private void assertWillOptimize(String aType, String bType, String reduceDim) throws ParseException {
+        assertWillOptimize(aType, bType, reduceDim, "sum");
+    }
+
+    private void assertWillOptimize(String aType, String bType, String reduceDim, String aggregator) throws ParseException {
+        assertReduceJoin(aType, bType, reduceDim, aggregator, true);
+    }
+
+    /** Asserts that a sum over all dimensions of the join of the given types can not use the optimization */
+    private void assertCantOptimize(String aType, String bType) throws ParseException {
+        assertCantOptimize(aType, bType, "", "sum");
+    }
+
+    /** Asserts that a sum over the given dimensions of the join of the given types can not use the optimization */
+    private void assertCantOptimize(String aType, String bType, String reduceDim) throws ParseException {
+        assertCantOptimize(aType, bType, reduceDim, "sum");
+    }
+
+    private void assertCantOptimize(String aType, String bType, String reduceDim, String aggregator) throws ParseException {
+        assertReduceJoin(aType, bType, reduceDim, aggregator, false);
+    }
+
+    /**
+     * Evaluates reduce(join(a, b)) on random tensors before and after running the expression optimizer,
+     * and asserts that the root was rewritten from Reduce to ReduceJoin, that the rewrite was counted once,
+     * that the result is unchanged, and that ReduceJoin.canOptimize agrees with the expectation.
+     *
+     * @param aType the dimensions of the first tensor, e.g. "d0[2],d1[3]"
+     * @param bType the dimensions of the second tensor
+     * @param reduceDim comma separated dimensions to reduce over, or empty to reduce over all
+     * @param aggregator the name of the reduce aggregator, e.g. "sum"
+     * @param assertOptimize the expected outcome of canOptimize for these tensors
+     */
+    private void assertReduceJoin(String aType, String bType, String reduceDim, String aggregator, boolean assertOptimize) throws ParseException {
+        Tensor a = generateRandomTensor(aType);
+        Tensor b = generateRandomTensor(bType);
+        RankingExpression expression = generateRankingExpression(reduceDim, aggregator);
+        // JUnit assertions rather than the assert keyword: the latter is skipped unless the JVM runs with -ea
+        assertTrue(((TensorFunctionNode)expression.getRoot()).function() instanceof Reduce);
+
+        ArrayContext context = generateContext(a, b, expression);
+        Tensor result = expression.evaluate(context).asTensor();
+
+        ExpressionOptimizer optimizer = new ExpressionOptimizer();
+        OptimizationReport report = optimizer.optimize(expression, context);
+        assertEquals(1, report.getMetric("Replaced reduce->join"));
+        assertTrue(((TensorFunctionNode)expression.getRoot()).function() instanceof ReduceJoin);
+
+        assertEquals(result, expression.evaluate(context).asTensor());
+        assertEquals(assertOptimize, ((ReduceJoin)((TensorFunctionNode)expression.getRoot()).function()).canOptimize(a, b));
+    }
+
+    /** Builds the expression reduce(join(a, b, f(a,b)(a * b)), aggregator[, dimensions]) */
+    private RankingExpression generateRankingExpression(String reduceDim, String aggregator) throws ParseException {
+        String dimensions = "";
+        if (reduceDim.length() > 0) {
+            dimensions = ", " + reduceDim;
+        }
+        return new RankingExpression("reduce(join(a, b, f(a,b)(a * b)), " + aggregator + dimensions + ")");
+    }
+
+    /** Creates a context for the expression with the given tensors bound to "a" and "b" */
+    private ArrayContext generateContext(Tensor a, Tensor b, RankingExpression expression) {
+        ArrayContext context = new ArrayContext(expression);
+        context.put("a", new TensorValue(a));
+        context.put("b", new TensorValue(b));
+        return context;
+    }
+
+    /** Creates a tensor with random content having the given dimensions, e.g. "d0[2],d1[3]" */
+    private Tensor generateRandomTensor(String type) {
+        return Tensor.random(TensorType.fromSpec("tensor(" + type + ")"));
+    }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
index 5447e5240f7..19b3329e8a4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
@@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
@@ -50,7 +52,11 @@ public class TestableTensorFlowModel {
model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
- Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor();
+ RankingExpression expression = model.expressions().get(operationName);
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+
+ Tensor vespaResult = expression.evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results",
tfResult.sum().asDouble(), vespaResult.sum().asDouble(), delta);
}
@@ -64,7 +70,11 @@ public class TestableTensorFlowModel {
model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
- Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor();
+ RankingExpression expression = model.expressions().get(operationName);
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+
+ Tensor vespaResult = expression.evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
}
@@ -82,7 +92,7 @@ public class TestableTensorFlowModel {
}
private Context contextFrom(ImportedModel result) {
- MapContext context = new MapContext();
+ TestableModelContext context = new TestableModelContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
return context;
@@ -118,4 +128,15 @@ public class TestableTensorFlowModel {
}
}
+ private static class TestableModelContext extends MapContext implements ContextIndex {
+ @Override
+ public int size() {
+ return bindings().size();
+ }
+ @Override
+ public int getIndex(String name) {
+ throw new UnsupportedOperationException(this + " does not support index lookup by name");
+ }
+ }
+
}