summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java85
1 files changed, 85 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
new file mode 100644
index 00000000000..63cea371d14
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
@@ -0,0 +1,85 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
+import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ReduceJoin;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Recognizes and optimizes tensor expressions.
+ *
+ * @author lesters
+ */
+public class TensorOptimizer extends Optimizer {
+
+ private OptimizationReport report;
+
+ @Override
+ public void optimize(RankingExpression expression, ContextIndex context, OptimizationReport report) {
+ if (!isEnabled()) return;
+ this.report = report;
+ expression.setRoot(optimize(expression.getRoot(), context));
+ report.note("Tensor expression optimization done");
+ }
+
+ private ExpressionNode optimize(ExpressionNode node, ContextIndex context) {
+ node = optimizeReduceJoin(node);
+ if (node instanceof CompositeNode) {
+ return optimizeChildren((CompositeNode)node, context);
+ }
+ return node;
+ }
+
+ private ExpressionNode optimizeChildren(CompositeNode node, ContextIndex context) {
+ List<ExpressionNode> children = node.children();
+ List<ExpressionNode> optimizedChildren = new ArrayList<>(children.size());
+ for (ExpressionNode child : children)
+ optimizedChildren.add(optimize(child, context));
+ return node.setChildren(optimizedChildren);
+ }
+
+ /**
+ * Recognized a reduce followed by a join. In many cases, chunking these
+ * two operations together is significantly more efficient than evaluating
+ * each on its own, avoiding the cost of a temporary tensor.
+ *
+ * Note that this does not guarantee that the optimization is performed.
+ * The ReduceJoin class determines whether or not the arguments are
+ * compatible with the optimization.
+ */
+ private ExpressionNode optimizeReduceJoin(ExpressionNode node) {
+ if ( ! (node instanceof TensorFunctionNode)) {
+ return node;
+ }
+ TensorFunction function = ((TensorFunctionNode) node).function();
+ if ( ! (function instanceof Reduce)) {
+ return node;
+ }
+ List<ExpressionNode> children = ((TensorFunctionNode) node).children();
+ if (children.size() != 1) {
+ return node;
+ }
+ ExpressionNode child = children.get(0);
+ if ( ! (child instanceof TensorFunctionNode)) {
+ return node;
+ }
+ TensorFunction argument = ((TensorFunctionNode) child).function();
+ if (argument instanceof Join) {
+ report.incMetric("Replaced reduce->join", 1);
+ return new TensorFunctionNode(new ReduceJoin((Reduce)function, (Join)argument));
+ }
+ return node;
+ }
+
+}