aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--searchlib/abi-spec.json1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java58
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java9
4 files changed, 51 insertions, 22 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 0768999c52f..aca88ff864e 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -728,6 +728,7 @@
"public abstract com.yahoo.tensor.TensorType type()",
"public abstract double asDouble()",
"public com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue asDoubleValue()",
+ "public boolean isNaN()",
"public abstract com.yahoo.tensor.Tensor asTensor()",
"protected com.yahoo.tensor.Tensor doubleAsTensor(double)",
"public abstract boolean hasDouble()",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 382cbb7ce9a..071fc6e9fc2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -30,6 +30,11 @@ public abstract class Value {
return new DoubleValue(asDouble());
}
+ /** Returns true if this has a double value which is NaN */
+ public boolean isNaN() {
+ return hasDouble() && Double.isNaN(asDouble());
+ }
+
/** Returns this as a tensor value */
public abstract Tensor asTensor();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index 394faee71a8..1522e0025c7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.transform;
+import com.yahoo.document.update.ArithmeticValueUpdate;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
@@ -46,6 +47,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
}
private ExpressionNode transformArithmetic(ArithmeticNode node) {
+ // Fold the subset of expressions that are constant (such that in "1 + 2 + var")
if (node.children().size() > 1) {
List<ExpressionNode> children = new ArrayList<>(node.children());
List<ArithmeticOperator> operators = new ArrayList<>(node.operators());
@@ -54,32 +56,34 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
node = new ArithmeticNode(children, operators);
}
- if (isConstant(node))
+ if (isConstant(node) && ! node.evaluate(null).isNaN())
return new ConstantNode(node.evaluate(null));
- else if (allMultiplicationOrDivision(node) && hasZero(node)) // disregarding the /0 case
+ else if (allMultiplicationOrDivision(node) && hasZero(node) && ! hasDivisionByZero(node))
return new ConstantNode(new DoubleValue(0));
else
return node;
}
- private void transform(ArithmeticOperator operator, List<ExpressionNode> children, List<ArithmeticOperator> operators) {
+ private void transform(ArithmeticOperator operatorToTransform,
+ List<ExpressionNode> children, List<ArithmeticOperator> operators) {
int i = 0;
while (i < children.size()-1) {
- if ( ! operators.get(i).equals(operator)) {
- i++;
- continue;
- }
-
- ExpressionNode child1 = children.get(i);
- ExpressionNode child2 = children.get(i + 1);
- if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) {
- Value evaluated = new ArithmeticNode(child1, operators.remove(i), child2).evaluate(null);
- children.set(i, new ConstantNode(evaluated.freeze()));
- children.remove(i+1);
+ boolean transformed = false;
+ if ( operators.get(i).equals(operatorToTransform)) {
+ ExpressionNode child1 = children.get(i);
+ ExpressionNode child2 = children.get(i + 1);
+ if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) {
+ Value evaluated = new ArithmeticNode(child1, operators.get(i), child2).evaluate(null);
+ if ( ! evaluated.isNaN()) { // Don't replace by NaN
+ operators.remove(i);
+ children.set(i, new ConstantNode(evaluated.freeze()));
+ children.remove(i + 1);
+ transformed = true;
+ }
+ }
}
- else { // try the next index
+ if ( ! transformed) // try the next index
i++;
- }
}
}
@@ -105,22 +109,34 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean allMultiplicationOrDivision(ArithmeticNode node) {
for (ArithmeticOperator o : node.operators())
- if (o == ArithmeticOperator.PLUS || o == ArithmeticOperator.MINUS)
+ if (o != ArithmeticOperator.MULTIPLY && o != ArithmeticOperator.DIVIDE)
return false;
return true;
}
private boolean hasZero(ArithmeticNode node) {
for (ExpressionNode child : node.children()) {
- if ( ! (child instanceof ConstantNode)) continue;
- ConstantNode constant = (ConstantNode)child;
- if ( ! constant.getValue().hasDouble()) return false;
- if (constant.getValue().asDouble() == 0.0)
+ if (isZero(child))
return true;
}
return false;
}
+ private boolean hasDivisionByZero(ArithmeticNode node) {
+ for (int i = 1; i < node.children().size(); i++) {
+ if ( node.operators().get(i - 1) == ArithmeticOperator.DIVIDE && isZero(node.children().get(i)))
+ return true;
+ }
+ return false;
+ }
+
+ private boolean isZero(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode)) return false;
+ ConstantNode constant = (ConstantNode)node;
+ if ( ! constant.getValue().hasDouble()) return false;
+ return constant.getValue().asDouble() == 0.0;
+ }
+
private boolean isConstant(ExpressionNode node) {
if (node instanceof ConstantNode) return true;
if (node instanceof ReferenceNode) return false;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
index cf761aac2d3..7861da717f3 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
@@ -42,7 +42,14 @@ public class SimplifierTestCase {
assertEquals("attribute(number) * 1.5 - 0.001 * attribute(number)", s.transform(new RankingExpression("attribute(number) * 1.5 - 0.001 * attribute(number)"), c).toString());
}
- // A black box test verifying we are not screwing up real expressions
+ @Test
+ public void testNaNExpression() throws ParseException {
+ Simplifier s = new Simplifier();
+ TransformContext c = new TransformContext(Collections.emptyMap(), new MapTypeContext());
+ assertEquals("0 / 0", s.transform(new RankingExpression("0/0"), c).toString());
+ assertEquals("1 + 0.0 / 0.0", s.transform(new RankingExpression("1 + (1-1)/(2-2)"), c).toString());
+ }
+
@Test
public void testSimplifyComplexExpression() throws ParseException {
RankingExpression initial = new RankingExpression("sqrt(if (if (INFERRED * 0.9 < INFERRED, GMP, (1 + 1.1) * INFERRED) < INFERRED * INFERRED - INFERRED, if (GMP < 85.80799542793133 * GMP, INFERRED, if (GMP < GMP, tanh(INFERRED), log(76.89956221113943))), tanh(tanh(INFERRED))) * sqrt(sqrt(GMP + INFERRED)) * GMP ) + 13.5 * (1 - GMP) * pow(GMP * 0.1, 2 + 1.1 * 0)");