diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-09-28 16:19:30 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-09-28 16:19:30 +0200 |
commit | 3d49f155fccfa4fc08882b01e7a6e3a982c55212 (patch) | |
tree | 865d6e301e5fcd3fba248807ff980bcc7e18d41f /searchlib/src | |
parent | 7cfc4fa47828261ee1f839a27a437d8bc49eb26f (diff) |
Fold comparisons into the other operators
Diffstat (limited to 'searchlib/src')
11 files changed, 275 insertions, 304 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index afd263f1553..e1db6378fcf 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -2,7 +2,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -28,53 +27,83 @@ public abstract class DoubleCompatibleValue extends Value { public Value negate() { return new DoubleValue(-asDouble()); } @Override - public Value add(Value value) { - return new DoubleValue(asDouble() + value.asDouble()); + public Value not() { + return new BooleanValue(!asBoolean()); } @Override - public Value subtract(Value value) { - return new DoubleValue(asDouble() - value.asDouble()); + public Value or(Value value) { + return new BooleanValue(asBoolean() || value.asBoolean()); } @Override - public Value multiply(Value value) { - return new DoubleValue(asDouble() * value.asDouble()); + public Value and(Value value) { + return new BooleanValue(asBoolean() && value.asBoolean()); } @Override - public Value divide(Value value) { - return new DoubleValue(asDouble() / value.asDouble()); + public Value greaterEqual(Value value) { + return new BooleanValue(this.asDouble() >= value.asDouble()); } @Override - public Value modulo(Value value) { - return new DoubleValue(asDouble() % value.asDouble()); + public Value greater(Value value) { + return new BooleanValue(this.asDouble() > value.asDouble()); } @Override - public Value and(Value value) { - return new BooleanValue(asBoolean() && value.asBoolean()); + public Value lessEqual(Value value) { + return new BooleanValue(this.asDouble() <= value.asDouble()); } @Override - public Value or(Value value) { - return new BooleanValue(asBoolean() || value.asBoolean()); + public Value less(Value value) { + return new BooleanValue(this.asDouble() < value.asDouble()); } @Override - public Value not() { - return new BooleanValue(!asBoolean()); + public Value approx(Value value) { + return new BooleanValue(approxEqual(this.asDouble(), value.asDouble())); } @Override - public Value power(Value value) { - return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble())); + public Value notEqual(Value value) { + return new BooleanValue(this.asDouble() != value.asDouble()); + } + + @Override + public Value equal(Value value) { + return new BooleanValue(this.asDouble() == value.asDouble()); + } + + @Override + public Value add(Value value) { + return new DoubleValue(asDouble() + value.asDouble()); + } + + @Override + public Value subtract(Value value) { + return new DoubleValue(asDouble() - value.asDouble()); + } + + @Override + public Value multiply(Value value) { + return new DoubleValue(asDouble() * value.asDouble()); } @Override - public Value compare(TruthOperator operator, Value value) { - return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); + public Value divide(Value value) { + return new DoubleValue(asDouble() / value.asDouble()); + } + + @Override + public Value modulo(Value value) { + return new DoubleValue(asDouble() % value.asDouble()); + } + + @Override + public Value power(Value value) { + return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble())); } @Override @@ -82,4 +111,14 @@ public abstract class DoubleCompatibleValue extends Value { return new DoubleValue(function.evaluate(asDouble(),value.asDouble())); } + static boolean approxEqual(double x, double y) { + if (y < -1.0 || y > 1.0) { + x = Math.nextAfter(x/y, 1.0); + y = 1.0; + } else { + x = Math.nextAfter(x, y); + } + return x == y; + } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index 2c2d5eead05..3c09c644147 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -57,55 +56,83 @@ public class StringValue extends Value { } @Override - public Value add(Value value) { - return new StringValue(value + value.toString()); + public Value not() { + throw new UnsupportedOperationException("String values ('" + value + "') do not support not"); } @Override - public Value subtract(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction"); + public Value or(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support or"); } @Override - public Value multiply(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication"); + public Value and(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support and"); } @Override - public Value divide(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support division"); + public Value greaterEqual(Value argument) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support greaterEqual"); } @Override - public Value modulo(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo"); + public Value greater(Value argument) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support greater"); } @Override - public Value and(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support and"); + public Value lessEqual(Value argument) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support lessEqual"); } @Override - public Value or(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support or"); + public Value less(Value argument) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support less"); } @Override - public Value not() { - throw new UnsupportedOperationException("String values ('" + value + "') do not support not"); + public Value approx(Value argument) { + return new BooleanValue(this.asDouble() == argument.asDouble()); } @Override - public Value power(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support ^"); + public Value notEqual(Value argument) { + return new BooleanValue(this.asDouble() != argument.asDouble()); + } + + @Override + public Value equal(Value argument) { + return new BooleanValue(this.asDouble() == argument.asDouble()); } @Override - public Value compare(TruthOperator operator, Value value) { - if (operator.equals(TruthOperator.EQUAL)) - return new BooleanValue(this.equals(value)); - throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='"); + public Value add(Value value) { + return new StringValue(value + value.toString()); + } + + @Override + public Value subtract(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction"); + } + + @Override + public Value multiply(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication"); + } + + @Override + public Value divide(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support division"); + } + + @Override + public Value modulo(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo"); + } + + @Override + public Value power(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support ^"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index b37bbb543eb..73ea0b23986 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.api.annotations.Beta; import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -52,6 +51,83 @@ public class TensorValue extends Value { } @Override + public Value not() { + return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0)); + } + + @Override + public Value or(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 )); + else + return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0)); + } + + @Override + public Value and(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 )); + else + return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0)); + } + + @Override + public Value greaterEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.largerOrEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value >= argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value greater(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.larger(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value > argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value lessEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value <= argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value less(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.smaller(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value < argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value approx(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.approxEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> DoubleCompatibleValue.approxEqual(value, argument.asDouble()) ? 1.0 : 0.0)); + } + + @Override + public Value notEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.notEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value != argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value equal(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.equal(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value == argument.asDouble() ? 1.0 : 0.0)); + } + + @Override public Value add(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.add(((TensorValue)argument).value)); @@ -92,27 +168,6 @@ public class TensorValue extends Value { } @Override - public Value and(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 )); - else - return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0)); - } - - @Override - public Value or(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 )); - else - return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0)); - } - - @Override - public Value not() { - return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0)); - } - - @Override public Value power(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.pow(((TensorValue)argument).value)); @@ -123,24 +178,6 @@ public class TensorValue extends Value { public Tensor asTensor() { return value; } @Override - public Value compare(TruthOperator operator, Value argument) { - return new TensorValue(compareTensor(operator, argument.asTensor())); - } - - private Tensor compareTensor(TruthOperator operator, Tensor argument) { - switch (operator) { - case LARGER: return value.larger(argument); - case LARGEREQUAL: return value.largerOrEqual(argument); - case SMALLER: return value.smaller(argument); - case SMALLEREQUAL: return value.smallerOrEqual(argument); - case EQUAL: return value.equal(argument); - case NOTEQUAL: return value.notEqual(argument); - case APPROX_EQUAL: return value.approxEqual(argument); - default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); - } - } - - @Override public Value function(Function function, Value arg) { if (arg instanceof TensorValue) return new TensorValue(functionOnTensor(function, arg.asTensor())); @@ -149,17 +186,17 @@ public class TensorValue extends Value { } private Tensor functionOnTensor(Function function, Tensor argument) { - switch (function) { - case min: return value.min(argument); - case max: return value.max(argument); - case atan2: return value.atan2(argument); - case pow: return value.pow(argument); - case fmod: return value.fmod(argument); - case ldexp: return value.ldexp(argument); - case bit: return value.bit(argument); - case hamming: return value.hamming(argument); - default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); - } + return switch (function) { + case min -> value.min(argument); + case max -> value.max(argument); + case atan2 -> value.atan2(argument); + case pow -> value.pow(argument); + case fmod -> value.fmod(argument); + case ldexp -> value.ldexp(argument); + case bit -> value.bit(argument); + case hamming -> value.hamming(argument); + default -> throw new UnsupportedOperationException("Cannot combine two tensors using " + function); + }; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 207603c5038..99663fe8d0d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -1,9 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation; -import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -51,27 +49,24 @@ public abstract class Value { public abstract Value negate(); - public abstract Value add(Value value); + public abstract Value not(); + public abstract Value or(Value value); + public abstract Value and(Value value); + public abstract Value greaterEqual(Value value); + public abstract Value greater(Value value); + public abstract Value lessEqual(Value value); + public abstract Value less(Value value); + public abstract Value approx(Value value); + public abstract Value notEqual(Value value); + public abstract Value equal(Value value); + public abstract Value add(Value value); public abstract Value subtract(Value value); - public abstract Value multiply(Value value); - public abstract Value divide(Value value); - public abstract Value modulo(Value value); - - public abstract Value and(Value value); - - public abstract Value or(Value value); - - public abstract Value not(); - public abstract Value power(Value value); - /** Perform the comparison specified by the operator between this value and the given value */ - public abstract Value compare(TruthOperator operator, Value value); - /** Perform the given binary function on this value and the given value */ public abstract Value function(Function function, Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java index 420f1f459f3..cf4c35d94af 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java @@ -114,15 +114,15 @@ public class GBDTOptimizer extends Optimizer { /** Consumes the if condition and return the size of the values resulting, for convenience */ private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) { - if (condition instanceof ComparisonNode) { - ComparisonNode comparison = (ComparisonNode)condition; - if (comparison.getOperator() == TruthOperator.SMALLER) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.getLeftCondition(), context)); - else if (comparison.getOperator() == TruthOperator.EQUAL) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.getLeftCondition(), context)); + if (isBinaryComparison(condition)) { + ArithmeticNode comparison = (ArithmeticNode)condition; + if (comparison.operators().get(0) == ArithmeticOperator.LESS) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context)); + else if (comparison.operators().get(0) == ArithmeticOperator.EQUAL) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context)); else - throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.getOperator()); - values.add(toValue(comparison.getRightCondition())); + throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0)); + values.add(toValue(comparison.children().get(1))); } else if (condition instanceof SetMembershipNode) { SetMembershipNode setMembership = (SetMembershipNode)condition; @@ -131,17 +131,15 @@ public class GBDTOptimizer extends Optimizer { for (ExpressionNode setElementNode : setMembership.getSetValues()) values.add(toValue(setElementNode)); } - else if (condition instanceof NotNode) { // handle if inversion: !(a >= b) - NotNode notNode = (NotNode)condition; - if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) { - EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0); - if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) { - ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0); - if (comparison.getOperator() == TruthOperator.LARGEREQUAL) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context)); + else if (condition instanceof NotNode notNode) { // handle if inversion: !(a >= b) + if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) { + if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) { + ArithmeticNode comparison = (ArithmeticNode)embracedNode.children().get(0); + if (comparison.operators().get(0) == ArithmeticOperator.GREATEREQUAL) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context)); else - throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator()); - values.add(toValue(comparison.getRightCondition())); + throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0)); + values.add(toValue(comparison.children().get(1))); } } } @@ -152,12 +150,24 @@ public class GBDTOptimizer extends Optimizer { return values.size(); } + private boolean isBinaryComparison(ExpressionNode condition) { + if ( ! (condition instanceof ArithmeticNode binaryNode)) return false; + if (binaryNode.operators().size() != 1) return false; + if (binaryNode.operators().get(0) == ArithmeticOperator.GREATEREQUAL) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.GREATER) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.LESSEQUAL) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.LESS) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.APPROX) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.NOTEQUAL) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.EQUAL) return true; + return false; + } + private double getVariableIndex(ExpressionNode node, ContextIndex context) { - if (!(node instanceof ReferenceNode)) { + if (!(node instanceof ReferenceNode fNode)) { throw new IllegalArgumentException("Contained a left-hand comparison expression " + "which was not a feature value but was: " + node); } - ReferenceNode fNode = (ReferenceNode)node; Integer index = context.getIndex(fNode.toString()); if (index == null) { throw new IllegalStateException("The ranking expression contained feature '" + fNode.getName() + @@ -177,8 +187,7 @@ public class GBDTOptimizer extends Optimizer { value.getClass().getSimpleName() + " (" + value + ") in a set test: " + node); } - if (node instanceof NegativeNode) { - NegativeNode nNode = (NegativeNode)node; + if (node instanceof NegativeNode nNode) { if (!(nNode.getValue() instanceof ConstantNode)) { throw new IllegalArgumentException("Contained a negation of a non-number: " + nNode.getValue()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java index a2521398529..435c92ff7da 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java @@ -14,26 +14,16 @@ import java.util.function.BiFunction; */ public enum ArithmeticOperator { -/* -struct Sub : OperatorHelper<Sub> { Sub() : Helper("-", 101, LEFT) {}}; -struct Mul : OperatorHelper<Mul> { Mul() : Helper("*", 102, LEFT) {}}; -struct Div : OperatorHelper<Div> { Div() : Helper("/", 102, LEFT) {}}; -struct Mod : OperatorHelper<Mod> { Mod() : Helper("%", 102, LEFT) {}}; -struct Pow : OperatorHelper<Pow> { Pow() : Helper("^", 103, RIGHT) {}}; -struct Equal : OperatorHelper<Equal> { Equal() : Helper("==", 10, LEFT) {}}; -struct NotEqual : OperatorHelper<NotEqual> { NotEqual() : Helper("!=", 10, LEFT) {}}; -struct Approx : OperatorHelper<Approx> { Approx() : Helper("~=", 10, LEFT) {}}; -struct Less : OperatorHelper<Less> { Less() : Helper("<", 10, LEFT) {}}; -struct LessEqual : OperatorHelper<LessEqual> { LessEqual() : Helper("<=", 10, LEFT) {}}; -struct Greater : OperatorHelper<Greater> { Greater() : Helper(">", 10, LEFT) {}}; -struct GreaterEqual : OperatorHelper<GreaterEqual> { GreaterEqual() : Helper(">=", 10, LEFT) {}}; -struct And : OperatorHelper<And> { And() : Helper("&&", 2, LEFT) {}}; -struct Or : OperatorHelper<Or> { Or() : Helper("||", 1, LEFT) {}}; - */ - // In order from lowest to highest precedence OR("||", (x, y) -> x.or(y)), AND("&&", (x, y) -> x.and(y)), + GREATEREQUAL(">=", (x, y) -> x.greaterEqual(y)), + GREATER(">", (x, y) -> x.greater(y)), + LESSEQUAL("<=", (x, y) -> x.lessEqual(y)), + LESS("<", (x, y) -> x.less(y)), + APPROX("~=", (x, y) -> x.approx(y)), + NOTEQUAL("!=", (x, y) -> x.notEqual(y)), + EQUAL("==", (x, y) -> x.equal(y)), PLUS("+", (x, y) -> x.add(y)), MINUS("-", (x, y) -> x.subtract(y)), MULTIPLY("*", (x, y) -> x.multiply(y)), diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java deleted file mode 100644 index e726a351f74..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; - -import java.util.Deque; -import java.util.List; -import java.util.Objects; - -/** - * A node which returns the outcome of a comparison. - * - * @author bratseth - */ -public class ComparisonNode extends BooleanNode { - - /** The operator string of this condition. */ - private final TruthOperator operator; - - private final List<ExpressionNode> conditions; - - public ComparisonNode(ExpressionNode leftCondition, TruthOperator operator, ExpressionNode rightCondition) { - conditions = List.of(leftCondition, rightCondition); - this.operator = operator; - } - - @Override - public List<ExpressionNode> children() { - return conditions; - } - - public TruthOperator getOperator() { return operator; } - - public ExpressionNode getLeftCondition() { return conditions.get(0); } - - public ExpressionNode getRightCondition() { return conditions.get(1); } - - @Override - public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { - getLeftCondition().toString(string, context, path, this).append(' ').append(operator).append(' '); - return getRightCondition().toString(string, context, path, this); - } - - @Override - public TensorType type(TypeContext<Reference> context) { - return TensorType.empty; // by definition - } - - @Override - public Value evaluate(Context context) { - Value leftValue = getLeftCondition().evaluate(context); - Value rightValue = getRightCondition().evaluate(context); - return leftValue.compare(operator,rightValue); - } - - @Override - public ComparisonNode setChildren(List<ExpressionNode> children) { - if (children.size() != 2) throw new IllegalArgumentException("A comparison test must have 2 children"); - return new ComparisonNode(children.get(0), operator, children.get(1)); - } - - @Override - public int hashCode() { return Objects.hash(operator, conditions); } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java deleted file mode 100644 index fc259867923..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import java.io.Serializable; - -/** - * A mathematical operator - * - * @author bratseth - */ -public enum TruthOperator implements Serializable { - - SMALLER("<") { public boolean evaluate(double x, double y) { return x<y; } }, - SMALLEREQUAL("<=") { public boolean evaluate(double x, double y) { return x<=y; } }, - EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } }, - APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } }, - LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } }, - LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }, - NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } }; - - private final String operatorString; - - TruthOperator(String operatorString) { - this.operatorString = operatorString; - } - - /** Perform the truth operation on the input */ - public abstract boolean evaluate(double x, double y); - - @Override - public String toString() { return operatorString; } - - public static TruthOperator fromString(String string) { - for (TruthOperator operator : values()) - if (operator.toString().equals(string)) - return operator; - throw new IllegalArgumentException("Illegal truth operator '" + string + "'"); - } - - private static boolean approxEqual(double x,double y) { - if (y < -1.0 || y > 1.0) { - x = Math.nextAfter(x/y, 1.0); - y = 1.0; - } else { - x = Math.nextAfter(x, y); - } - return x==y; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index 90861e64164..7a34f5b7b03 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -6,7 +6,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; -import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; @@ -44,7 +43,6 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { private boolean hasSingleUndividableChild(EmbracedNode node) { if (node.children().size() > 1) return false; if (node.children().get(0) instanceof ArithmeticNode) return false; - if (node.children().get(0) instanceof ComparisonNode) return false; return true; } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index ebe1e048247..2261d39829c 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -72,13 +72,13 @@ TOKEN : <COMMA: ","> | <COLON: ":"> | - <LE: "<="> | - <LT: "<"> | - <EQ: "=="> | - <NQ: "!="> | - <AQ: "~="> | - <GE: ">="> | - <GT: ">"> | + <GREATEREQUAL: ">="> | + <GREATER: ">"> | + <LESSEQUAL: "<="> | + <LESS: "<"> | + <APPROX: "~="> | + <NOTEQUAL: "!="> | + <EQUAL: "=="> | <STRING: ("\"" (~["\""] | "\\\"")* "\"") | ("'" (~["'"] | "\\'")* "'")> | @@ -188,14 +188,12 @@ ExpressionNode expression() : { ExpressionNode left, right; List<ExpressionNode> rightList; - TruthOperator comparatorOp; } { ( left = arithmeticExpression() ( - ( comparatorOp = comparator() right = arithmeticExpression() { left = new ComparisonNode(left, comparatorOp, right); } ) | ( <IN> rightList = expressionList() { left = new SetMembershipNode(left, rightList); } ) - ) * + ) ? ) { return left; } } @@ -214,29 +212,26 @@ ExpressionNode arithmeticExpression() : ArithmeticOperator arithmetic() : { } { - ( <ADD> { return ArithmeticOperator.PLUS; } | - <SUB> { return ArithmeticOperator.MINUS; } | - <DIV> { return ArithmeticOperator.DIVIDE; } | - <MUL> { return ArithmeticOperator.MULTIPLY; } | - <MOD> { return ArithmeticOperator.MODULO; } | - <AND> { return ArithmeticOperator.AND; } | - <OR> { return ArithmeticOperator.OR; } | - <POWOP> { return ArithmeticOperator.POWER; } ) + ( + <OR> { return ArithmeticOperator.OR; } | + <AND> { return ArithmeticOperator.AND; } | + <GREATEREQUAL> { return ArithmeticOperator.GREATEREQUAL; } | + <GREATER> { return ArithmeticOperator.GREATER; } | + <LESSEQUAL> { return ArithmeticOperator.LESSEQUAL; } | + <LESS> { return ArithmeticOperator.LESS; } | + <APPROX> { return ArithmeticOperator.APPROX; } | + <NOTEQUAL> { return ArithmeticOperator.NOTEQUAL; } | + <EQUAL> { return ArithmeticOperator.EQUAL; } | + <ADD> { return ArithmeticOperator.PLUS; } | + <SUB> { return ArithmeticOperator.MINUS; } | + <DIV> { return ArithmeticOperator.DIVIDE; } | + <MUL> { return ArithmeticOperator.MULTIPLY; } | + <MOD> { return ArithmeticOperator.MODULO; } | + <POWOP> { return ArithmeticOperator.POWER; } + ) { return null; } } -TruthOperator comparator() : { } -{ - ( <LE> { return TruthOperator.SMALLEREQUAL; } | - <LT> { return TruthOperator.SMALLER; } | - <EQ> { return TruthOperator.EQUAL; } | - <NQ> { return TruthOperator.NOTEQUAL; } | - <AQ> { return TruthOperator.APPROX_EQUAL; } | - <GE> { return TruthOperator.LARGEREQUAL; } | - <GT> { return TruthOperator.LARGER; } ) - { return null; } -} - ExpressionNode value() : { ExpressionNode value; @@ -665,7 +660,7 @@ TensorType.Value optionalTensorValueTypeParameter() : String valueType = "double"; } { - ( <LT> valueType = identifier() <GT> )? + ( <LESS> valueType = identifier() <GREATER> )? { return TensorType.Value.fromId(valueType); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index ad50a423eb9..b1ac4b9e3ca 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -198,9 +198,9 @@ public class EvaluationTestCase { tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", "tensor0 % 2 == map(tensor0, f(x) (x % 2))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", - "tensor0 || 1 == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + "(tensor0 || 1) == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", - "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + "(tensor0 && 1) == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }"); |