aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-09-20 11:00:48 +0200
committerHenning Baldersheim <balder@yahoo-inc.com>2022-09-21 16:27:49 +0200
commit083aa54d59aecc9f9d045d4bde6cdb6c6cbe4dec (patch)
tree8c90676eb3e6cb01cf87d9ee40db4c60f14aad2c
parentd9db475220d68a54ba2c9f820d3bae78f80abd96 (diff)
Short circuit boolean expressions
Short circuit boolean expressions by converting them to (nested) if expressions. This also fixes a bug in Java expression evaluation where evaluation of arithmetic operations with the same precedence would be from right to left rather than left to right.
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java105
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java3
-rw-r--r--config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java57
-rwxr-xr-xdocumentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/GetDocumentReply.java3
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ArithmeticExpression.java20
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java27
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java12
9 files changed, 207 insertions, 30 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
new file mode 100644
index 00000000000..336156a11bd
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
@@ -0,0 +1,105 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.expressiontransforms;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.IfNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Transforms
+ * a &amp;&amp; b into if(a, b, false)
+ * and
+ * a || b into if(a, true, b)
+ * to avoid computing b if a is false and true respectively.
+ *
+ * This may increase performance since boolean expressions are not short-circuited.
+ *
+ * @author bratseth
+ */
+public class BooleanExpressionTransformer extends ExpressionTransformer<TransformContext> {
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node, TransformContext context) {
+ if (node instanceof CompositeNode composite)
+ node = transformChildren(composite, context);
+
+ if (node instanceof ArithmeticNode arithmetic)
+ node = transformBooleanArithmetics(arithmetic);
+
+ return node;
+ }
+
+ private ExpressionNode transformBooleanArithmetics(ArithmeticNode node) {
+ Iterator<ExpressionNode> child = node.children().iterator();
+
+ // Transform in precedence order:
+ Deque<ChildNode> stack = new ArrayDeque<>();
+ stack.push(new ChildNode(ArithmeticOperator.OR, child.next()));
+ for (Iterator<ArithmeticOperator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) {
+ ArithmeticOperator op = it.next();
+ if ( ! stack.isEmpty()) {
+ while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) {
+ popStack(stack);
+ }
+ }
+ stack.push(new ChildNode(op, child.next()));
+ }
+ while (stack.size() > 1)
+ popStack(stack);
+ return stack.getFirst().child;
+ }
+
+ private void popStack(Deque<ChildNode> stack) {
+ ChildNode rhs = stack.pop();
+ ChildNode lhs = stack.peek();
+
+ ExpressionNode combination;
+ if (rhs.op == ArithmeticOperator.AND)
+ combination = andByIfNode(lhs.child, rhs.child);
+ else if (rhs.op == ArithmeticOperator.OR)
+ combination = orByIfNode(lhs.child, rhs.child);
+ else
+ combination = new ArithmeticNode(List.of(lhs.child, rhs.child), List.of(rhs.op));
+ lhs.child = combination;
+ }
+
+
+ private IfNode andByIfNode(ExpressionNode a, ExpressionNode b) {
+ return new IfNode(a, b, new ConstantNode(new BooleanValue(false)));
+ }
+
+ private IfNode orByIfNode(ExpressionNode a, ExpressionNode b) {
+ return new IfNode(a, new ConstantNode(new BooleanValue(true)), b);
+ }
+
+ /** A child with the operator to be applied to it when combining it with the previous child. */
+ private static class ChildNode {
+
+ final ArithmeticOperator op;
+ ExpressionNode child;
+
+ public ChildNode(ArithmeticOperator op, ExpressionNode child) {
+ this.op = op;
+ this.child = child;
+ }
+
+ @Override
+ public String toString() {
+ return child.toString();
+ }
+
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java
index 86aedd4332a..132597ee75e 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java
@@ -35,7 +35,8 @@ public class ExpressionTransforms {
new FunctionInliner(),
new FunctionShadower(),
new TensorMaxMinTransformer(),
- new Simplifier());
+ new Simplifier(),
+ new BooleanExpressionTransformer());
}
public RankingExpression transform(RankingExpression expression, RankProfileTransformContext context) {
diff --git a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
index 789c4ac5577..5eecee516ec 100644
--- a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
@@ -163,7 +163,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase {
" \n" +
" rank-profile test {\n" +
" first-phase {\n" +
- " expression: A + C + D\n" +
+ " expression: A + C - D\n" +
" }\n" +
" function inline D() {\n" +
" expression: B + 1\n" +
@@ -184,7 +184,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase {
Schema s = builder.getSchema();
RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels());
- assertEquals("attribute(a) + C + (attribute(b) + 1)", test.getFirstPhaseRanking().getRoot().toString());
+ assertEquals("attribute(a) + C - (attribute(b) + 1)", test.getFirstPhaseRanking().getRoot().toString());
assertEquals("attribute(a) + attribute(b)", getRankingExpression("C", test, s));
assertEquals("attribute(b) + 1", getRankingExpression("D", test, s));
}
diff --git a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
new file mode 100644
index 00000000000..71d0657c701
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
@@ -0,0 +1,57 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.expressiontransforms;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
+import org.junit.jupiter.api.Test;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import java.util.Map;
+
+/**
+ * @author bratseth
+ */
+public class BooleanExpressionTransformerTestCase {
+
+ @Test
+ public void testTransformer() throws Exception {
+ assertTransformed("if (a, b, false)", "a && b");
+ assertTransformed("if (a, true, b)", "a || b");
+ assertTransformed("if (a, true, b + c)", "a || b + c");
+ assertTransformed("if (c + a, true, b)", "c + a || b");
+ assertTransformed("if (c + a, true, b + c)", "c + a || b + c");
+ assertTransformed("if (a + b, true, if (c - d * e, f, false))", "a + b || c - d * e && f");
+ assertTransformed("if (a, true, if (b, c, false))", "a || b && c");
+ assertTransformed("if (a + b, true, if (if (c, d, false), e * f - g, false))", "a + b || c && d && e * f - g");
+ assertTransformed("if(1 - 1, true, 1 - 1)", "1 - 1 || 1 - 1");
+ }
+
+ @Test
+ public void testIt() throws Exception {
+ assertTransformed("if(1 - 1, true, 1 - 1)", "1 - 1 || 1 - 1");
+ }
+
+ private void assertTransformed(String expected, String input) throws Exception {
+ var transformedExpression = new BooleanExpressionTransformer()
+ .transform(new RankingExpression(input),
+ new TransformContext(Map.of(), new MapTypeContext()));
+
+ assertEquals(new RankingExpression(expected), transformedExpression, "Transformed as expected");
+
+ MapContext context = contextWithSingleLetterVariables();
+ var inputExpression = new RankingExpression(input);
+ assertEquals(inputExpression.evaluate(new MapContext()).asBoolean(),
+ transformedExpression.evaluate(new MapContext()).asBoolean(),
+ "Transform and original input are equivalent");
+ }
+
+ private MapContext contextWithSingleLetterVariables() {
+ var context = new MapContext();
+ for (int i = 0; i < 26; i++)
+ context.put(Character.toString(i + 97), Math.floorMod(i, 2));
+ return context;
+ }
+
+}
diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/GetDocumentReply.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/GetDocumentReply.java
index 2f2d90f2052..4613dfc472d 100755
--- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/GetDocumentReply.java
+++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/GetDocumentReply.java
@@ -25,7 +25,8 @@ public class GetDocumentReply extends DocumentAcceptedReply {
/**
* Constructs a new reply to lazily deserialize from a byte buffer.
- * @param decoder The decoder to use for deserialization.
+ *
+ * @param decoder The decoder to use for deserialization.
* @param buf A byte buffer that contains a serialized reply.
*/
GetDocumentReply(LazyDecoder decoder, DocumentDeserializer buf) {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ArithmeticExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ArithmeticExpression.java
index fa82c4d88ee..e4bc2dae965 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ArithmeticExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ArithmeticExpression.java
@@ -166,19 +166,13 @@ public final class ArithmeticExpression extends CompositeExpression {
}
BigDecimal lhsVal = asBigDecimal((NumericFieldValue)lhs);
BigDecimal rhsVal = asBigDecimal((NumericFieldValue)rhs);
- switch (op) {
- case ADD:
- return createFieldValue(lhs, rhs, lhsVal.add(rhsVal));
- case SUB:
- return createFieldValue(lhs, rhs, lhsVal.subtract(rhsVal));
- case MUL:
- return createFieldValue(lhs, rhs, lhsVal.multiply(rhsVal));
- case DIV:
- return createFieldValue(lhs, rhs, lhsVal.divide(rhsVal, MathContext.DECIMAL64));
- case MOD:
- return createFieldValue(lhs, rhs, lhsVal.remainder(rhsVal));
- }
- throw new IllegalStateException("Unsupported operation: " + op);
+ return switch (op) {
+ case ADD -> createFieldValue(lhs, rhs, lhsVal.add(rhsVal));
+ case SUB -> createFieldValue(lhs, rhs, lhsVal.subtract(rhsVal));
+ case MUL -> createFieldValue(lhs, rhs, lhsVal.multiply(rhsVal));
+ case DIV -> createFieldValue(lhs, rhs, lhsVal.divide(rhsVal, MathContext.DECIMAL64));
+ case MOD -> createFieldValue(lhs, rhs, lhsVal.remainder(rhsVal));
+ };
}
private FieldValue createFieldValue(FieldValue lhs, FieldValue rhs, BigDecimal val) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
index 580f42e67cb..ce5853155d4 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
@@ -47,7 +47,9 @@ public final class ArithmeticNode extends CompositeNode {
string.append("(");
Iterator<ExpressionNode> child = children.iterator();
- child.next().toString(string, context, path, this).append(" ");
+ child.next().toString(string, context, path, this);
+ if (child.hasNext())
+ string.append(" ");
for (Iterator<ArithmeticOperator> op = operators.iterator(); op.hasNext() && child.hasNext();) {
string.append(op.next().toString()).append(" ");
child.next().toString(string, context, path, this);
@@ -65,16 +67,15 @@ public final class ArithmeticNode extends CompositeNode {
* (even though by virtue of being a node it will be calculated before the parent).
*/
private boolean nonDefaultPrecedence(CompositeNode parent) {
- if ( parent==null) return false;
- if ( ! (parent instanceof ArithmeticNode)) return false;
+ if ( parent == null) return false;
+ if ( ! (parent instanceof ArithmeticNode arithmeticParent)) return false;
- ArithmeticNode arithParent = (ArithmeticNode) parent;
// The line below can only be correct in both only have one operator.
// Getting this correct is impossible without more work.
- // So for now now we only handle the simple case correctly, and use a safe approach by adding
+ // So for now we only handle the simple case correctly, and use a safe approach by adding
// extra parenthesis just in case....
- return arithParent.operators.get(0).hasPrecedenceOver(this.operators.get(0))
- || ((arithParent.operators.size() > 1) || (operators.size() > 1));
+ return arithmeticParent.operators.get(0).hasPrecedenceOver(this.operators.get(0))
+ || ((arithmeticParent.operators.size() > 1) || (operators.size() > 1));
}
@Override
@@ -98,7 +99,7 @@ public final class ArithmeticNode extends CompositeNode {
for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
ArithmeticOperator op = it.next();
if ( ! stack.isEmpty()) {
- while (stack.peek().op.hasPrecedenceOver(op)) {
+ while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) {
popStack(stack);
}
}
@@ -127,9 +128,7 @@ public final class ArithmeticNode extends CompositeNode {
public int hashCode() { return Objects.hash(children, operators); }
public static ArithmeticNode resolve(ExpressionNode left, ArithmeticOperator op, ExpressionNode right) {
- if ( ! (left instanceof ArithmeticNode)) return new ArithmeticNode(left, op, right);
-
- ArithmeticNode leftArithmetic = (ArithmeticNode)left;
+ if ( ! (left instanceof ArithmeticNode leftArithmetic)) return new ArithmeticNode(left, op, right);
List<ExpressionNode> newChildren = new ArrayList<>(leftArithmetic.children());
newChildren.add(right);
@@ -149,6 +148,12 @@ public final class ArithmeticNode extends CompositeNode {
this.op = op;
this.value = value;
}
+
+ @Override
+ public String toString() {
+ return value.toString();
+ }
+
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
index 4aee3268111..e23c6ec5dd5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
@@ -20,7 +20,11 @@ public abstract class ExpressionTransformer<CONTEXT extends TransformContext> {
return new RankingExpression(expression.getName(), transform(expression.getRoot(), context));
}
- /** Transforms an expression node and returns the transformed node */
+ /**
+ * Transforms an expression node and returns the transformed node.
+ * This ic called with the root node of an expression to transform by clients of transformers.
+ * Transforming nested expression nodes are left to each transformer.
+ */
public abstract ExpressionNode transform(ExpressionNode node, CONTEXT context);
/**
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 19e32c23234..ad50a423eb9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -56,6 +56,15 @@ public class EvaluationTestCase {
}
@Test
+ public void testEvaluationOrder() {
+ EvaluationTester tester = new EvaluationTester();
+ tester.assertEvaluates(-4, "1 + -2 + -3");
+ tester.assertEvaluates(2, "1 - (2 - 3)");
+ tester.assertEvaluates(-4, "(1 - 2) - 3");
+ tester.assertEvaluates(-4, "1 - 2 - 3");
+ }
+
+ @Test
public void testEvaluation() {
EvaluationTester tester = new EvaluationTester();
tester.assertEvaluates(0.5, "0.5");
@@ -78,6 +87,7 @@ public class EvaluationTestCase {
tester.assertEvaluates(3, "1 + 10 % 6 / 2");
tester.assertEvaluates(10.0, "3 ^ 2 + 1");
tester.assertEvaluates(18.0, "2 * 3 ^ 2");
+ tester.assertEvaluates(-4, "1 - 2 - 3"); // Means 1 + -2 + -3
// Conditionals
tester.assertEvaluates(2 * (3 * 4 + 3) * (4 * 5 - 4 * 200) / 10, "2*(3*4+3)*(4*5-4*200)/10");
@@ -106,7 +116,7 @@ public class EvaluationTestCase {
// Conditionals with branch probabilities
RankingExpression e = tester.assertEvaluates(3.5, "if(1.0-1.0, 2.5, 3.5, 0.3)");
- assertEquals(0.3d, (double)((IfNode) e.getRoot()).getTrueProbability(), tolerance);
+ assertEquals(0.3d, ((IfNode) e.getRoot()).getTrueProbability(), tolerance);
// Conditionals as expressions
tester.assertEvaluates(new BooleanValue(true), "2<3");