diff options
4 files changed, 51 insertions, 22 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 0768999c52f..aca88ff864e 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -728,6 +728,7 @@ "public abstract com.yahoo.tensor.TensorType type()", "public abstract double asDouble()", "public com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue asDoubleValue()", + "public boolean isNaN()", "public abstract com.yahoo.tensor.Tensor asTensor()", "protected com.yahoo.tensor.Tensor doubleAsTensor(double)", "public abstract boolean hasDouble()", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 382cbb7ce9a..071fc6e9fc2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -30,6 +30,11 @@ public abstract class Value { return new DoubleValue(asDouble()); } + /** Returns true if this has a double value which is NaN */ + public boolean isNaN() { + return hasDouble() && Double.isNaN(asDouble()); + } + /** Returns this as a tensor value */ public abstract Tensor asTensor(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index 394faee71a8..1522e0025c7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.transform; +import com.yahoo.document.update.ArithmeticValueUpdate; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; @@ -46,6 +47,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { } private ExpressionNode transformArithmetic(ArithmeticNode node) { + // Fold the subset of expressions that are constant (such that in "1 + 2 + var") if (node.children().size() > 1) { List<ExpressionNode> children = new ArrayList<>(node.children()); List<ArithmeticOperator> operators = new ArrayList<>(node.operators()); @@ -54,32 +56,34 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { node = new ArithmeticNode(children, operators); } - if (isConstant(node)) + if (isConstant(node) && ! node.evaluate(null).isNaN()) return new ConstantNode(node.evaluate(null)); - else if (allMultiplicationOrDivision(node) && hasZero(node)) // disregarding the /0 case + else if (allMultiplicationOrDivision(node) && hasZero(node) && ! hasDivisionByZero(node)) return new ConstantNode(new DoubleValue(0)); else return node; } - private void transform(ArithmeticOperator operator, List<ExpressionNode> children, List<ArithmeticOperator> operators) { + private void transform(ArithmeticOperator operatorToTransform, + List<ExpressionNode> children, List<ArithmeticOperator> operators) { int i = 0; while (i < children.size()-1) { - if ( ! operators.get(i).equals(operator)) { - i++; - continue; - } - - ExpressionNode child1 = children.get(i); - ExpressionNode child2 = children.get(i + 1); - if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) { - Value evaluated = new ArithmeticNode(child1, operators.remove(i), child2).evaluate(null); - children.set(i, new ConstantNode(evaluated.freeze())); - children.remove(i+1); + boolean transformed = false; + if ( operators.get(i).equals(operatorToTransform)) { + ExpressionNode child1 = children.get(i); + ExpressionNode child2 = children.get(i + 1); + if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) { + Value evaluated = new ArithmeticNode(child1, operators.get(i), child2).evaluate(null); + if ( ! evaluated.isNaN()) { // Don't replace by NaN + operators.remove(i); + children.set(i, new ConstantNode(evaluated.freeze())); + children.remove(i + 1); + transformed = true; + } + } } - else { // try the next index + if ( ! transformed) // try the next index i++; - } } } @@ -105,22 +109,34 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { private boolean allMultiplicationOrDivision(ArithmeticNode node) { for (ArithmeticOperator o : node.operators()) - if (o == ArithmeticOperator.PLUS || o == ArithmeticOperator.MINUS) + if (o != ArithmeticOperator.MULTIPLY && o != ArithmeticOperator.DIVIDE) return false; return true; } private boolean hasZero(ArithmeticNode node) { for (ExpressionNode child : node.children()) { - if ( ! (child instanceof ConstantNode)) continue; - ConstantNode constant = (ConstantNode)child; - if ( ! constant.getValue().hasDouble()) return false; - if (constant.getValue().asDouble() == 0.0) + if (isZero(child)) return true; } return false; } + private boolean hasDivisionByZero(ArithmeticNode node) { + for (int i = 1; i < node.children().size(); i++) { + if ( node.operators().get(i - 1) == ArithmeticOperator.DIVIDE && isZero(node.children().get(i))) + return true; + } + return false; + } + + private boolean isZero(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) return false; + ConstantNode constant = (ConstantNode)node; + if ( ! constant.getValue().hasDouble()) return false; + return constant.getValue().asDouble() == 0.0; + } + private boolean isConstant(ExpressionNode node) { if (node instanceof ConstantNode) return true; if (node instanceof ReferenceNode) return false; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java index cf761aac2d3..7861da717f3 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java @@ -42,7 +42,14 @@ public class SimplifierTestCase { assertEquals("attribute(number) * 1.5 - 0.001 * attribute(number)", s.transform(new RankingExpression("attribute(number) * 1.5 - 0.001 * attribute(number)"), c).toString()); } - // A black box test verifying we are not screwing up real expressions + @Test + public void testNaNExpression() throws ParseException { + Simplifier s = new Simplifier(); + TransformContext c = new TransformContext(Collections.emptyMap(), new MapTypeContext()); + assertEquals("0 / 0", s.transform(new RankingExpression("0/0"), c).toString()); + assertEquals("1 + 0.0 / 0.0", s.transform(new RankingExpression("1 + (1-1)/(2-2)"), c).toString()); + } + @Test public void testSimplifyComplexExpression() throws ParseException { RankingExpression initial = new RankingExpression("sqrt(if (if (INFERRED * 0.9 < INFERRED, GMP, (1 + 1.1) * INFERRED) < INFERRED * INFERRED - INFERRED, if (GMP < 85.80799542793133 * GMP, INFERRED, if (GMP < GMP, tanh(INFERRED), log(76.89956221113943))), tanh(tanh(INFERRED))) * sqrt(sqrt(GMP + INFERRED)) * GMP ) + 13.5 * (1 - GMP) * pow(GMP * 0.1, 2 + 1.1 * 0)"); |