summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java131
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java73
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java139
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java27
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java61
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java63
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java69
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java22
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java)38
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java67
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java50
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java40
13 files changed, 415 insertions, 377 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index afd263f1553..2405d1bd528 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -2,7 +2,6 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -28,58 +27,146 @@ public abstract class DoubleCompatibleValue extends Value {
public Value negate() { return new DoubleValue(-asDouble()); }
@Override
- public Value add(Value value) {
- return new DoubleValue(asDouble() + value.asDouble());
+ public Value not() {
+ return new BooleanValue(!asBoolean());
}
@Override
- public Value subtract(Value value) {
- return new DoubleValue(asDouble() - value.asDouble());
+ public Value or(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.or(this);
+ else
+ return new BooleanValue(asBoolean() || value.asBoolean());
}
@Override
- public Value multiply(Value value) {
- return new DoubleValue(asDouble() * value.asDouble());
+ public Value and(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.and(this);
+ else
+ return new BooleanValue(asBoolean() && value.asBoolean());
}
@Override
- public Value divide(Value value) {
- return new DoubleValue(asDouble() / value.asDouble());
+ public Value largerOrEqual(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.largerOrEqual(this);
+ else
+ return new BooleanValue(this.asDouble() >= value.asDouble());
}
@Override
- public Value modulo(Value value) {
- return new DoubleValue(asDouble() % value.asDouble());
+ public Value larger(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.larger(this);
+ else
+ return new BooleanValue(this.asDouble() > value.asDouble());
}
@Override
- public Value and(Value value) {
- return new BooleanValue(asBoolean() && value.asBoolean());
+ public Value smallerOrEqual(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.smallerOrEqual(this);
+ else
+ return new BooleanValue(this.asDouble() <= value.asDouble());
}
@Override
- public Value or(Value value) {
- return new BooleanValue(asBoolean() || value.asBoolean());
+ public Value smaller(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.smaller(this);
+ else
+ return new BooleanValue(this.asDouble() < value.asDouble());
}
@Override
- public Value not() {
- return new BooleanValue(!asBoolean());
+ public Value approxEqual(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.approxEqual(this);
+ else
+ return new BooleanValue(approxEqual(this.asDouble(), value.asDouble()));
}
@Override
- public Value power(Value value) {
- return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble()));
+ public Value notEqual(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.notEqual(this);
+ else
+ return new BooleanValue(this.asDouble() != value.asDouble());
+ }
+
+ @Override
+ public Value equal(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.equal(this);
+ else
+ return new BooleanValue(this.asDouble() == value.asDouble());
+ }
+
+ @Override
+ public Value add(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.add(this);
+ else
+ return new DoubleValue(asDouble() + value.asDouble());
+ }
+
+ @Override
+ public Value subtract(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.subtract(this);
+ else
+ return new DoubleValue(asDouble() - value.asDouble());
+ }
+
+ @Override
+ public Value multiply(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.multiply(this);
+ else
+ return new DoubleValue(asDouble() * value.asDouble());
+ }
+
+ @Override
+ public Value divide(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.divide(this);
+ else
+ return new DoubleValue(asDouble() / value.asDouble());
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.modulo(this);
+ else
+ return new DoubleValue(asDouble() % value.asDouble());
}
@Override
- public Value compare(TruthOperator operator, Value value) {
- return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
+ public Value power(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.power(this);
+ else
+ return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble()));
}
@Override
public Value function(Function function, Value value) {
- return new DoubleValue(function.evaluate(asDouble(),value.asDouble()));
+ if (value instanceof TensorValue tensor)
+ return tensor.function(function, this);
+ else
+ return new DoubleValue(function.evaluate(asDouble(), value.asDouble()));
+ }
+
+ static boolean approxEqual(double x, double y) {
+ if (y < -1.0 || y > 1.0) {
+ x = Math.nextAfter(x/y, 1.0);
+ y = 1.0;
+ } else {
+ x = Math.nextAfter(x, y);
+ }
+ return x == y;
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index 2c2d5eead05..a585c989954 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -57,55 +56,83 @@ public class StringValue extends Value {
}
@Override
- public Value add(Value value) {
- return new StringValue(value + value.toString());
+ public Value not() {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support not");
}
@Override
- public Value subtract(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction");
+ public Value or(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support or");
}
@Override
- public Value multiply(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication");
+ public Value and(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support and");
}
@Override
- public Value divide(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support division");
+ public Value largerOrEqual(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support greaterEqual");
}
@Override
- public Value modulo(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo");
+ public Value larger(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support greater");
}
@Override
- public Value and(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support and");
+ public Value smallerOrEqual(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support lessEqual");
}
@Override
- public Value or(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support or");
+ public Value smaller(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support less");
}
@Override
- public Value not() {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support not");
+ public Value approxEqual(Value argument) {
+ return new BooleanValue(this.asDouble() == argument.asDouble());
}
@Override
- public Value power(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support ^");
+ public Value notEqual(Value argument) {
+ return new BooleanValue(this.asDouble() != argument.asDouble());
+ }
+
+ @Override
+ public Value equal(Value argument) {
+ return new BooleanValue(this.asDouble() == argument.asDouble());
}
@Override
- public Value compare(TruthOperator operator, Value value) {
- if (operator.equals(TruthOperator.EQUAL))
- return new BooleanValue(this.equals(value));
- throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='");
+ public Value add(Value value) {
+ return new StringValue(value + value.toString());
+ }
+
+ @Override
+ public Value subtract(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction");
+ }
+
+ @Override
+ public Value multiply(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication");
+ }
+
+ @Override
+ public Value divide(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support division");
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo");
+ }
+
+ @Override
+ public Value power(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support ^");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index b37bbb543eb..25e03c75376 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.api.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -52,6 +51,83 @@ public class TensorValue extends Value {
}
@Override
+ public Value not() {
+ return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value or(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value and(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value largerOrEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.largerOrEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value >= argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value larger(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.larger(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value > argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value smallerOrEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value <= argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value smaller(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.smaller(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value < argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value approxEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.approxEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> DoubleCompatibleValue.approxEqual(value, argument.asDouble()) ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value notEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.notEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value != argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value equal(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.equal(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value == argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
public Value add(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.add(((TensorValue)argument).value));
@@ -92,27 +168,6 @@ public class TensorValue extends Value {
}
@Override
- public Value and(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
- else
- return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
- }
-
- @Override
- public Value or(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
- else
- return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
- }
-
- @Override
- public Value not() {
- return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0));
- }
-
- @Override
public Value power(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.pow(((TensorValue)argument).value));
@@ -123,24 +178,6 @@ public class TensorValue extends Value {
public Tensor asTensor() { return value; }
@Override
- public Value compare(TruthOperator operator, Value argument) {
- return new TensorValue(compareTensor(operator, argument.asTensor()));
- }
-
- private Tensor compareTensor(TruthOperator operator, Tensor argument) {
- switch (operator) {
- case LARGER: return value.larger(argument);
- case LARGEREQUAL: return value.largerOrEqual(argument);
- case SMALLER: return value.smaller(argument);
- case SMALLEREQUAL: return value.smallerOrEqual(argument);
- case EQUAL: return value.equal(argument);
- case NOTEQUAL: return value.notEqual(argument);
- case APPROX_EQUAL: return value.approxEqual(argument);
- default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
- }
- }
-
- @Override
public Value function(Function function, Value arg) {
if (arg instanceof TensorValue)
return new TensorValue(functionOnTensor(function, arg.asTensor()));
@@ -149,17 +186,17 @@ public class TensorValue extends Value {
}
private Tensor functionOnTensor(Function function, Tensor argument) {
- switch (function) {
- case min: return value.min(argument);
- case max: return value.max(argument);
- case atan2: return value.atan2(argument);
- case pow: return value.pow(argument);
- case fmod: return value.fmod(argument);
- case ldexp: return value.ldexp(argument);
- case bit: return value.bit(argument);
- case hamming: return value.hamming(argument);
- default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
- }
+ return switch (function) {
+ case min -> value.min(argument);
+ case max -> value.max(argument);
+ case atan2 -> value.atan2(argument);
+ case pow -> value.pow(argument);
+ case fmod -> value.fmod(argument);
+ case ldexp -> value.ldexp(argument);
+ case bit -> value.bit(argument);
+ case hamming -> value.hamming(argument);
+ default -> throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
+ };
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 207603c5038..5de2138147e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -1,9 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;
-import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -51,27 +49,24 @@ public abstract class Value {
public abstract Value negate();
- public abstract Value add(Value value);
+ public abstract Value not();
+ public abstract Value or(Value value);
+ public abstract Value and(Value value);
+ public abstract Value largerOrEqual(Value value);
+ public abstract Value larger(Value value);
+ public abstract Value smallerOrEqual(Value value);
+ public abstract Value smaller(Value value);
+ public abstract Value approxEqual(Value value);
+ public abstract Value notEqual(Value value);
+ public abstract Value equal(Value value);
+ public abstract Value add(Value value);
public abstract Value subtract(Value value);
-
public abstract Value multiply(Value value);
-
public abstract Value divide(Value value);
-
public abstract Value modulo(Value value);
-
- public abstract Value and(Value value);
-
- public abstract Value or(Value value);
-
- public abstract Value not();
-
public abstract Value power(Value value);
- /** Perform the comparison specified by the operator between this value and the given value */
- public abstract Value compare(TruthOperator operator, Value value);
-
/** Perform the given binary function on this value and the given value */
public abstract Value function(Function function, Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
index 6ab483800a7..a3fc6aae9ac 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
@@ -5,8 +5,8 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -84,12 +84,12 @@ public class GBDTForestOptimizer extends Optimizer {
currentTreesOptimized++;
return true;
}
- if (!(node instanceof ArithmeticNode)) {
+ if (!(node instanceof OperationNode)) {
return false;
}
- ArithmeticNode aNode = (ArithmeticNode)node;
- for (ArithmeticOperator op : aNode.operators()) {
- if (op != ArithmeticOperator.PLUS) {
+ OperationNode aNode = (OperationNode)node;
+ for (Operator op : aNode.operators()) {
+ if (op != Operator.plus) {
return false;
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
index 420f1f459f3..7ba671e62eb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
@@ -61,13 +61,13 @@ public class GBDTOptimizer extends Optimizer {
* @return the optimized expression
*/
private ExpressionNode optimize(ExpressionNode node, ContextIndex context) {
- if (node instanceof ArithmeticNode) {
- Iterator<ExpressionNode> childIt = ((ArithmeticNode)node).children().iterator();
+ if (node instanceof OperationNode) {
+ Iterator<ExpressionNode> childIt = ((OperationNode)node).children().iterator();
ExpressionNode ret = optimize(childIt.next(), context);
- Iterator<ArithmeticOperator> operIt = ((ArithmeticNode)node).operators().iterator();
+ Iterator<Operator> operIt = ((OperationNode)node).operators().iterator();
while (childIt.hasNext() && operIt.hasNext()) {
- ret = ArithmeticNode.resolve(ret, operIt.next(), optimize(childIt.next(), context));
+ ret = OperationNode.resolve(ret, operIt.next(), optimize(childIt.next(), context));
}
return ret;
}
@@ -114,15 +114,15 @@ public class GBDTOptimizer extends Optimizer {
/** Consumes the if condition and return the size of the values resulting, for convenience */
private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) {
- if (condition instanceof ComparisonNode) {
- ComparisonNode comparison = (ComparisonNode)condition;
- if (comparison.getOperator() == TruthOperator.SMALLER)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.getLeftCondition(), context));
- else if (comparison.getOperator() == TruthOperator.EQUAL)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.getLeftCondition(), context));
+ if (isBinaryComparison(condition)) {
+ OperationNode comparison = (OperationNode)condition;
+ if (comparison.operators().get(0) == Operator.smaller)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context));
+ else if (comparison.operators().get(0) == Operator.equal)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context));
else
- throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.getOperator());
- values.add(toValue(comparison.getRightCondition()));
+ throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0));
+ values.add(toValue(comparison.children().get(1)));
}
else if (condition instanceof SetMembershipNode) {
SetMembershipNode setMembership = (SetMembershipNode)condition;
@@ -131,17 +131,15 @@ public class GBDTOptimizer extends Optimizer {
for (ExpressionNode setElementNode : setMembership.getSetValues())
values.add(toValue(setElementNode));
}
- else if (condition instanceof NotNode) { // handle if inversion: !(a >= b)
- NotNode notNode = (NotNode)condition;
- if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) {
- EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0);
- if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) {
- ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0);
- if (comparison.getOperator() == TruthOperator.LARGEREQUAL)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context));
+ else if (condition instanceof NotNode notNode) { // handle if inversion: !(a >= b)
+ if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) {
+ if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) {
+ OperationNode comparison = (OperationNode)embracedNode.children().get(0);
+ if (comparison.operators().get(0) == Operator.largerOrEqual)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context));
else
- throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator());
- values.add(toValue(comparison.getRightCondition()));
+ throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0));
+ values.add(toValue(comparison.children().get(1)));
}
}
}
@@ -152,12 +150,24 @@ public class GBDTOptimizer extends Optimizer {
return values.size();
}
+ private boolean isBinaryComparison(ExpressionNode condition) {
+ if ( ! (condition instanceof OperationNode binaryNode)) return false;
+ if (binaryNode.operators().size() != 1) return false;
+ if (binaryNode.operators().get(0) == Operator.largerOrEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.larger) return true;
+ if (binaryNode.operators().get(0) == Operator.smallerOrEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.smaller) return true;
+ if (binaryNode.operators().get(0) == Operator.approxEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.notEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.equal) return true;
+ return false;
+ }
+
private double getVariableIndex(ExpressionNode node, ContextIndex context) {
- if (!(node instanceof ReferenceNode)) {
+ if (!(node instanceof ReferenceNode fNode)) {
throw new IllegalArgumentException("Contained a left-hand comparison expression " +
"which was not a feature value but was: " + node);
}
- ReferenceNode fNode = (ReferenceNode)node;
Integer index = context.getIndex(fNode.toString());
if (index == null) {
throw new IllegalStateException("The ranking expression contained feature '" + fNode.getName() +
@@ -177,8 +187,7 @@ public class GBDTOptimizer extends Optimizer {
value.getClass().getSimpleName() + " (" + value + ") in a set test: " + node);
}
- if (node instanceof NegativeNode) {
- NegativeNode nNode = (NegativeNode)node;
+ if (node instanceof NegativeNode nNode) {
if (!(nNode.getValue() instanceof ConstantNode)) {
throw new IllegalArgumentException("Contained a negation of a non-number: " + nNode.getValue());
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
deleted file mode 100644
index 959045a63a0..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-
-import java.util.List;
-
-/**
- * A mathematical operator
- *
- * @author bratseth
- */
-public enum ArithmeticOperator {
-
- OR(0, "||") { public Value evaluate(Value x, Value y) {
- return x.or(y);
- }},
- AND(1, "&&") { public Value evaluate(Value x, Value y) {
- return x.and(y);
- }},
- PLUS(2, "+") { public Value evaluate(Value x, Value y) {
- return x.add(y);
- }},
- MINUS(3, "-") { public Value evaluate(Value x, Value y) {
- return x.subtract(y);
- }},
- MULTIPLY(4, "*") { public Value evaluate(Value x, Value y) {
- return x.multiply(y);
- }},
- DIVIDE(5, "/") { public Value evaluate(Value x, Value y) {
- return x.divide(y);
- }},
- MODULO(6, "%") { public Value evaluate(Value x, Value y) {
- return x.modulo(y);
- }},
- POWER(7, "^") { public Value evaluate(Value x, Value y) {
- return x.power(y);
- }};
-
- /** A list of all the operators in this in order of decreasing precedence */
- public static final List<ArithmeticOperator> operatorsByPrecedence = List.of(POWER, MODULO, DIVIDE, MULTIPLY, MINUS, PLUS, AND, OR);
-
- private final int precedence;
- private final String image;
-
- private ArithmeticOperator(int precedence, String image) {
- this.precedence = precedence;
- this.image = image;
- }
-
- /** Returns true if this operator has precedence over the given operator */
- public boolean hasPrecedenceOver(ArithmeticOperator op) {
- return precedence > op.precedence;
- }
-
- public abstract Value evaluate(Value x, Value y);
-
- @Override
- public String toString() {
- return image;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
deleted file mode 100644
index e726a351f74..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-
-import java.util.Deque;
-import java.util.List;
-import java.util.Objects;
-
-/**
- * A node which returns the outcome of a comparison.
- *
- * @author bratseth
- */
-public class ComparisonNode extends BooleanNode {
-
- /** The operator string of this condition. */
- private final TruthOperator operator;
-
- private final List<ExpressionNode> conditions;
-
- public ComparisonNode(ExpressionNode leftCondition, TruthOperator operator, ExpressionNode rightCondition) {
- conditions = List.of(leftCondition, rightCondition);
- this.operator = operator;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return conditions;
- }
-
- public TruthOperator getOperator() { return operator; }
-
- public ExpressionNode getLeftCondition() { return conditions.get(0); }
-
- public ExpressionNode getRightCondition() { return conditions.get(1); }
-
- @Override
- public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
- getLeftCondition().toString(string, context, path, this).append(' ').append(operator).append(' ');
- return getRightCondition().toString(string, context, path, this);
- }
-
- @Override
- public TensorType type(TypeContext<Reference> context) {
- return TensorType.empty; // by definition
- }
-
- @Override
- public Value evaluate(Context context) {
- Value leftValue = getLeftCondition().evaluate(context);
- Value rightValue = getRightCondition().evaluate(context);
- return leftValue.compare(operator,rightValue);
- }
-
- @Override
- public ComparisonNode setChildren(List<ExpressionNode> children) {
- if (children.size() != 2) throw new IllegalArgumentException("A comparison test must have 2 children");
- return new ComparisonNode(children.get(0), operator, children.get(1));
- }
-
- @Override
- public int hashCode() { return Objects.hash(operator, conditions); }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 9f07f146264..97e9a74f9c8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -106,10 +106,10 @@ public class LambdaFunctionNode extends CompositeNode {
}
private Optional<DoubleBinaryOperator> getDirectEvaluator() {
- if ( ! (functionExpression instanceof ArithmeticNode)) {
+ if ( ! (functionExpression instanceof OperationNode)) {
return Optional.empty();
}
- ArithmeticNode node = (ArithmeticNode) functionExpression;
+ OperationNode node = (OperationNode) functionExpression;
if ( ! (node.children().get(0) instanceof ReferenceNode) || ! (node.children().get(1) instanceof ReferenceNode)) {
return Optional.empty();
}
@@ -121,16 +121,16 @@ public class LambdaFunctionNode extends CompositeNode {
if (node.operators().size() != 1) {
return Optional.empty();
}
- ArithmeticOperator operator = node.operators().get(0);
+ Operator operator = node.operators().get(0);
switch (operator) {
- case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
- case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
- case PLUS: return asFunctionExpression((left, right) -> left + right);
- case MINUS: return asFunctionExpression((left, right) -> left - right);
- case MULTIPLY: return asFunctionExpression((left, right) -> left * right);
- case DIVIDE: return asFunctionExpression((left, right) -> left / right);
- case MODULO: return asFunctionExpression((left, right) -> left % right);
- case POWER: return asFunctionExpression(Math::pow);
+ case or: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
+ case and: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
+ case plus: return asFunctionExpression((left, right) -> left + right);
+ case minus: return asFunctionExpression((left, right) -> left - right);
+ case multiply: return asFunctionExpression((left, right) -> left * right);
+ case divide: return asFunctionExpression((left, right) -> left / right);
+ case modulo: return asFunctionExpression((left, right) -> left % right);
+ case power: return asFunctionExpression(Math::pow);
}
return Optional.empty();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
index c3e39197316..0512e1dad2f 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
@@ -16,26 +16,26 @@ import java.util.List;
import java.util.Objects;
/**
- * A binary mathematical operation
+ * A sequence of binary operations.
*
* @author bratseth
*/
-public final class ArithmeticNode extends CompositeNode {
+public final class OperationNode extends CompositeNode {
private final List<ExpressionNode> children;
- private final List<ArithmeticOperator> operators;
+ private final List<Operator> operators;
- public ArithmeticNode(List<ExpressionNode> children, List<ArithmeticOperator> operators) {
+ public OperationNode(List<ExpressionNode> children, List<Operator> operators) {
this.children = List.copyOf(children);
this.operators = List.copyOf(operators);
}
- public ArithmeticNode(ExpressionNode leftExpression, ArithmeticOperator operator, ExpressionNode rightExpression) {
+ public OperationNode(ExpressionNode leftExpression, Operator operator, ExpressionNode rightExpression) {
this.children = List.of(leftExpression, rightExpression);
this.operators = List.of(operator);
}
- public List<ArithmeticOperator> operators() { return operators; }
+ public List<Operator> operators() { return operators; }
@Override
public List<ExpressionNode> children() { return children; }
@@ -50,7 +50,7 @@ public final class ArithmeticNode extends CompositeNode {
child.next().toString(string, context, path, this);
if (child.hasNext())
string.append(" ");
- for (Iterator<ArithmeticOperator> op = operators.iterator(); op.hasNext() && child.hasNext();) {
+ for (Iterator<Operator> op = operators.iterator(); op.hasNext() && child.hasNext();) {
string.append(op.next().toString()).append(" ");
child.next().toString(string, context, path, this);
if (op.hasNext())
@@ -68,14 +68,14 @@ public final class ArithmeticNode extends CompositeNode {
*/
private boolean nonDefaultPrecedence(CompositeNode parent) {
if ( parent == null) return false;
- if ( ! (parent instanceof ArithmeticNode arithmeticParent)) return false;
+ if ( ! (parent instanceof OperationNode operationParent)) return false;
// The line below can only be correct in both only have one operator.
// Getting this correct is impossible without more work.
// So for now we only handle the simple case correctly, and use a safe approach by adding
// extra parenthesis just in case....
- return arithmeticParent.operators.get(0).hasPrecedenceOver(this.operators.get(0))
- || ((arithmeticParent.operators.size() > 1) || (operators.size() > 1));
+ return operationParent.operators.get(0).hasPrecedenceOver(this.operators.get(0))
+ || ((operationParent.operators.size() > 1) || (operators.size() > 1));
}
@Override
@@ -96,8 +96,8 @@ public final class ArithmeticNode extends CompositeNode {
// Apply in precedence order:
Deque<ValueItem> stack = new ArrayDeque<>();
stack.push(new ValueItem(null, child.next().evaluate(context)));
- for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
- ArithmeticOperator op = it.next();
+ for (Iterator<Operator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
+ Operator op = it.next();
if ( ! stack.isEmpty()) {
while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) {
popStack(stack);
@@ -121,30 +121,30 @@ public final class ArithmeticNode extends CompositeNode {
public CompositeNode setChildren(List<ExpressionNode> newChildren) {
if (children.size() != newChildren.size())
throw new IllegalArgumentException("Expected " + children.size() + " children but got " + newChildren.size());
- return new ArithmeticNode(newChildren, operators);
+ return new OperationNode(newChildren, operators);
}
@Override
public int hashCode() { return Objects.hash(children, operators); }
- public static ArithmeticNode resolve(ExpressionNode left, ArithmeticOperator op, ExpressionNode right) {
- if ( ! (left instanceof ArithmeticNode leftArithmetic)) return new ArithmeticNode(left, op, right);
+ public static OperationNode resolve(ExpressionNode left, Operator op, ExpressionNode right) {
+ if ( ! (left instanceof OperationNode leftArithmetic)) return new OperationNode(left, op, right);
List<ExpressionNode> newChildren = new ArrayList<>(leftArithmetic.children());
newChildren.add(right);
- List<ArithmeticOperator> newOperators = new ArrayList<>(leftArithmetic.operators());
+ List<Operator> newOperators = new ArrayList<>(leftArithmetic.operators());
newOperators.add(op);
- return new ArithmeticNode(newChildren, newOperators);
+ return new OperationNode(newChildren, newOperators);
}
private static class ValueItem {
- final ArithmeticOperator op;
+ final Operator op;
Value value;
- public ValueItem(ArithmeticOperator op, Value value) {
+ public ValueItem(Operator op, Value value) {
this.op = op;
this.value = value;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
new file mode 100644
index 00000000000..02af88b2c58
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
@@ -0,0 +1,67 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.function.BiFunction;
+
+/**
+ * A mathematical operator
+ *
+ * @author bratseth
+ */
+public enum Operator {
+
+ // In order from lowest to highest precedence
+ or("||", (x, y) -> x.or(y)),
+ and("&&", (x, y) -> x.and(y)),
+ largerOrEqual(">=", (x, y) -> x.largerOrEqual(y)),
+ larger(">", (x, y) -> x.larger(y)),
+ smallerOrEqual("<=", (x, y) -> x.smallerOrEqual(y)),
+ smaller("<", (x, y) -> x.smaller(y)),
+ approxEqual("~=", (x, y) -> x.approxEqual(y)),
+ notEqual("!=", (x, y) -> x.notEqual(y)),
+ equal("==", (x, y) -> x.equal(y)),
+ plus("+", (x, y) -> x.add(y)),
+ minus("-", (x, y) -> x.subtract(y)),
+ multiply("*", (x, y) -> x.multiply(y)),
+ divide("/", (x, y) -> x.divide(y)),
+ modulo("%", (x, y) -> x.modulo(y)),
+ power("^", true, (x, y) -> x.power(y));
+
+ /** A list of all the operators in this in order of increasing precedence */
+ public static final List<Operator> operatorsByPrecedence = Arrays.stream(Operator.values()).toList();
+
+ private final String image;
+ private final boolean rightPrecedence;
+ private final BiFunction<Value, Value, Value> function;
+
+ Operator(String image, BiFunction<Value, Value, Value> function) {
+ this(image, false, function);
+ }
+
+ Operator(String image, boolean rightPrecedence, BiFunction<Value, Value, Value> function) {
+ this.image = image;
+ this.rightPrecedence = rightPrecedence;
+ this.function = function;
+ }
+
+ /** Returns true if this operator has precedence over the given operator */
+ public boolean hasPrecedenceOver(Operator other) {
+ if (operatorsByPrecedence.indexOf(this) == operatorsByPrecedence.indexOf(other))
+ return rightPrecedence;
+ return operatorsByPrecedence.indexOf(this) > operatorsByPrecedence.indexOf(other);
+ }
+
+ public final Value evaluate(Value x, Value y) {
+ return function.apply(x, y);
+ }
+
+ @Override
+ public String toString() {
+ return image;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
deleted file mode 100644
index fc259867923..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import java.io.Serializable;
-
-/**
- * A mathematical operator
- *
- * @author bratseth
- */
-public enum TruthOperator implements Serializable {
-
- SMALLER("<") { public boolean evaluate(double x, double y) { return x<y; } },
- SMALLEREQUAL("<=") { public boolean evaluate(double x, double y) { return x<=y; } },
- EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } },
- APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } },
- LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } },
- LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } },
- NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } };
-
- private final String operatorString;
-
- TruthOperator(String operatorString) {
- this.operatorString = operatorString;
- }
-
- /** Perform the truth operation on the input */
- public abstract boolean evaluate(double x, double y);
-
- @Override
- public String toString() { return operatorString; }
-
- public static TruthOperator fromString(String string) {
- for (TruthOperator operator : values())
- if (operator.toString().equals(string))
- return operator;
- throw new IllegalArgumentException("Illegal truth operator '" + string + "'");
- }
-
- private static boolean approxEqual(double x,double y) {
- if (y < -1.0 || y > 1.0) {
- x = Math.nextAfter(x/y, 1.0);
- y = 1.0;
- } else {
- x = Math.nextAfter(x, y);
- }
- return x==y;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index 90861e64164..1c1f7509ce8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -4,9 +4,8 @@ package com.yahoo.searchlib.rankingexpression.transform;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
-import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
@@ -34,8 +33,8 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
node = transformIf((IfNode) node);
if (node instanceof EmbracedNode e && hasSingleUndividableChild(e))
node = e.children().get(0);
- if (node instanceof ArithmeticNode)
- node = transformArithmetic((ArithmeticNode) node);
+ if (node instanceof OperationNode)
+ node = transformArithmetic((OperationNode) node);
if (node instanceof NegativeNode)
node = transformNegativeNode((NegativeNode) node);
return node;
@@ -43,19 +42,18 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean hasSingleUndividableChild(EmbracedNode node) {
if (node.children().size() > 1) return false;
- if (node.children().get(0) instanceof ArithmeticNode) return false;
- if (node.children().get(0) instanceof ComparisonNode) return false;
+ if (node.children().get(0) instanceof OperationNode) return false;
return true;
}
- private ExpressionNode transformArithmetic(ArithmeticNode node) {
+ private ExpressionNode transformArithmetic(OperationNode node) {
// Fold the subset of expressions that are constant (such that in "1 + 2 + var")
if (node.children().size() > 1) {
List<ExpressionNode> children = new ArrayList<>(node.children());
- List<ArithmeticOperator> operators = new ArrayList<>(node.operators());
- for (ArithmeticOperator operator : ArithmeticOperator.operatorsByPrecedence)
+ List<Operator> operators = new ArrayList<>(node.operators());
+ for (Operator operator : Operator.operatorsByPrecedence)
transform(operator, children, operators);
- node = new ArithmeticNode(children, operators);
+ node = new OperationNode(children, operators);
}
if (isConstant(node) && ! node.evaluate(null).isNaN())
@@ -66,8 +64,8 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
return node;
}
- private void transform(ArithmeticOperator operatorToTransform,
- List<ExpressionNode> children, List<ArithmeticOperator> operators) {
+ private void transform(Operator operatorToTransform,
+ List<ExpressionNode> children, List<Operator> operators) {
int i = 0;
while (i < children.size()-1) {
boolean transformed = false;
@@ -75,7 +73,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
ExpressionNode child1 = children.get(i);
ExpressionNode child2 = children.get(i + 1);
if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) {
- Value evaluated = new ArithmeticNode(child1, operators.get(i), child2).evaluate(null);
+ Value evaluated = new OperationNode(child1, operators.get(i), child2).evaluate(null);
if ( ! evaluated.isNaN()) { // Don't replace by NaN
operators.remove(i);
children.set(i, new ConstantNode(evaluated.freeze()));
@@ -94,7 +92,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
* This check works because we simplify by decreasing precedence, so neighbours will either be single constant values
* or a more complex expression that can't be simplified and hence also prevents the simplification in question here.
*/
- private boolean hasPrecedence(List<ArithmeticOperator> operators, int i) {
+ private boolean hasPrecedence(List<Operator> operators, int i) {
if (i > 0 && operators.get(i-1).hasPrecedenceOver(operators.get(i))) return false;
if (i < operators.size()-1 && operators.get(i+1).hasPrecedenceOver(operators.get(i))) return false;
return true;
@@ -116,14 +114,14 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
return new ConstantNode(constant.getValue().negate() );
}
- private boolean allMultiplicationOrDivision(ArithmeticNode node) {
- for (ArithmeticOperator o : node.operators())
- if (o != ArithmeticOperator.MULTIPLY && o != ArithmeticOperator.DIVIDE)
+ private boolean allMultiplicationOrDivision(OperationNode node) {
+ for (Operator o : node.operators())
+ if (o != Operator.multiply && o != Operator.divide)
return false;
return true;
}
- private boolean hasZero(ArithmeticNode node) {
+ private boolean hasZero(OperationNode node) {
for (ExpressionNode child : node.children()) {
if (isZero(child))
return true;
@@ -131,9 +129,9 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
return false;
}
- private boolean hasDivisionByZero(ArithmeticNode node) {
+ private boolean hasDivisionByZero(OperationNode node) {
for (int i = 1; i < node.children().size(); i++) {
- if ( node.operators().get(i - 1) == ArithmeticOperator.DIVIDE && isZero(node.children().get(i)))
+ if (node.operators().get(i - 1) == Operator.divide && isZero(node.children().get(i)))
return true;
}
return false;