summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-29 12:53:37 +0200
committerGitHub <noreply@github.com>2018-09-29 12:53:37 +0200
commitbfc4feb4f5b9b2cec76bd08027cdf3c2ca339f39 (patch)
tree3bd0250db894ae103c85062556b4b34c59e38e84 /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
parent44313db093c7954c6fa979296832ed44dd28991d (diff)
parent35da647c2c32ea3a81b5c65bd440af1dc59b1b3c (diff)
Merge pull request #7145 from vespa-engine/lesters/add-java-reduce-join-optimization
Add reduce-join optimization in Java
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java43
1 files changed, 42 insertions, 1 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index de98b01287e..3a3410aeebb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -12,6 +12,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
+import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
@@ -90,7 +91,47 @@ public class LambdaFunctionNode extends CompositeNode {
if (arguments.size() > 2)
throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " +
"Must have at most two argument " + " but has " + arguments);
- return new DoubleBinaryLambda();
+
+ // Optimization: if possible, calculate directly rather than creating a context and evaluating the expression
+ return getDirectEvaluator().orElseGet(DoubleBinaryLambda::new);
+ }
+
+ private Optional<DoubleBinaryOperator> getDirectEvaluator() {
+ if ( ! (functionExpression instanceof ArithmeticNode)) {
+ return Optional.empty();
+ }
+ ArithmeticNode node = (ArithmeticNode) functionExpression;
+ if ( ! (node.children().get(0) instanceof ReferenceNode) || ! (node.children().get(1) instanceof ReferenceNode)) {
+ return Optional.empty();
+ }
+ if (node.operators().size() != 1) {
+ return Optional.empty();
+ }
+ ArithmeticOperator operator = node.operators().get(0);
+ switch (operator) {
+ case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
+ case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
+ case PLUS: return asFunctionExpression((left, right) -> left + right);
+ case MINUS: return asFunctionExpression((left, right) -> left - right);
+ case MULTIPLY: return asFunctionExpression((left, right) -> left * right);
+ case DIVIDE: return asFunctionExpression((left, right) -> left / right);
+ case MODULO: return asFunctionExpression((left, right) -> left % right);
+ case POWER: return asFunctionExpression(Math::pow);
+ }
+ return Optional.empty();
+ }
+
+ private Optional<DoubleBinaryOperator> asFunctionExpression(DoubleBinaryOperator operator) {
+ return Optional.of(new DoubleBinaryOperator() {
+ @Override
+ public double applyAsDouble(double left, double right) {
+ return operator.applyAsDouble(left, right);
+ }
+ @Override
+ public String toString() {
+ return LambdaFunctionNode.this.toString();
+ }
+ });
}
private class DoubleUnaryLambda implements DoubleUnaryOperator {