summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java61
1 files changed, 35 insertions, 26 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
index 420f1f459f3..7ba671e62eb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
@@ -61,13 +61,13 @@ public class GBDTOptimizer extends Optimizer {
* @return the optimized expression
*/
private ExpressionNode optimize(ExpressionNode node, ContextIndex context) {
- if (node instanceof ArithmeticNode) {
- Iterator<ExpressionNode> childIt = ((ArithmeticNode)node).children().iterator();
+ if (node instanceof OperationNode) {
+ Iterator<ExpressionNode> childIt = ((OperationNode)node).children().iterator();
ExpressionNode ret = optimize(childIt.next(), context);
- Iterator<ArithmeticOperator> operIt = ((ArithmeticNode)node).operators().iterator();
+ Iterator<Operator> operIt = ((OperationNode)node).operators().iterator();
while (childIt.hasNext() && operIt.hasNext()) {
- ret = ArithmeticNode.resolve(ret, operIt.next(), optimize(childIt.next(), context));
+ ret = OperationNode.resolve(ret, operIt.next(), optimize(childIt.next(), context));
}
return ret;
}
@@ -114,15 +114,15 @@ public class GBDTOptimizer extends Optimizer {
/** Consumes the if condition and return the size of the values resulting, for convenience */
private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) {
- if (condition instanceof ComparisonNode) {
- ComparisonNode comparison = (ComparisonNode)condition;
- if (comparison.getOperator() == TruthOperator.SMALLER)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.getLeftCondition(), context));
- else if (comparison.getOperator() == TruthOperator.EQUAL)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.getLeftCondition(), context));
+ if (isBinaryComparison(condition)) {
+ OperationNode comparison = (OperationNode)condition;
+ if (comparison.operators().get(0) == Operator.smaller)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context));
+ else if (comparison.operators().get(0) == Operator.equal)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context));
else
- throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.getOperator());
- values.add(toValue(comparison.getRightCondition()));
+ throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0));
+ values.add(toValue(comparison.children().get(1)));
}
else if (condition instanceof SetMembershipNode) {
SetMembershipNode setMembership = (SetMembershipNode)condition;
@@ -131,17 +131,15 @@ public class GBDTOptimizer extends Optimizer {
for (ExpressionNode setElementNode : setMembership.getSetValues())
values.add(toValue(setElementNode));
}
- else if (condition instanceof NotNode) { // handle if inversion: !(a >= b)
- NotNode notNode = (NotNode)condition;
- if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) {
- EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0);
- if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) {
- ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0);
- if (comparison.getOperator() == TruthOperator.LARGEREQUAL)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context));
+ else if (condition instanceof NotNode notNode) { // handle if inversion: !(a >= b)
+ if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) {
+ if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) {
+ OperationNode comparison = (OperationNode)embracedNode.children().get(0);
+ if (comparison.operators().get(0) == Operator.largerOrEqual)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context));
else
- throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator());
- values.add(toValue(comparison.getRightCondition()));
+ throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0));
+ values.add(toValue(comparison.children().get(1)));
}
}
}
@@ -152,12 +150,24 @@ public class GBDTOptimizer extends Optimizer {
return values.size();
}
+ private boolean isBinaryComparison(ExpressionNode condition) {
+ if ( ! (condition instanceof OperationNode binaryNode)) return false;
+ if (binaryNode.operators().size() != 1) return false;
+ if (binaryNode.operators().get(0) == Operator.largerOrEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.larger) return true;
+ if (binaryNode.operators().get(0) == Operator.smallerOrEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.smaller) return true;
+ if (binaryNode.operators().get(0) == Operator.approxEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.notEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.equal) return true;
+ return false;
+ }
+
private double getVariableIndex(ExpressionNode node, ContextIndex context) {
- if (!(node instanceof ReferenceNode)) {
+ if (!(node instanceof ReferenceNode fNode)) {
throw new IllegalArgumentException("Contained a left-hand comparison expression " +
"which was not a feature value but was: " + node);
}
- ReferenceNode fNode = (ReferenceNode)node;
Integer index = context.getIndex(fNode.toString());
if (index == null) {
throw new IllegalStateException("The ranking expression contained feature '" + fNode.getName() +
@@ -177,8 +187,7 @@ public class GBDTOptimizer extends Optimizer {
value.getClass().getSimpleName() + " (" + value + ") in a set test: " + node);
}
- if (node instanceof NegativeNode) {
- NegativeNode nNode = (NegativeNode)node;
+ if (node instanceof NegativeNode nNode) {
if (!(nNode.getValue() instanceof ConstantNode)) {
throw new IllegalArgumentException("Contained a negation of a non-number: " + nNode.getValue());
}