diff options
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java | 61 |
1 files changed, 35 insertions, 26 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java index 420f1f459f3..7ba671e62eb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java @@ -61,13 +61,13 @@ public class GBDTOptimizer extends Optimizer { * @return the optimized expression */ private ExpressionNode optimize(ExpressionNode node, ContextIndex context) { - if (node instanceof ArithmeticNode) { - Iterator<ExpressionNode> childIt = ((ArithmeticNode)node).children().iterator(); + if (node instanceof OperationNode) { + Iterator<ExpressionNode> childIt = ((OperationNode)node).children().iterator(); ExpressionNode ret = optimize(childIt.next(), context); - Iterator<ArithmeticOperator> operIt = ((ArithmeticNode)node).operators().iterator(); + Iterator<Operator> operIt = ((OperationNode)node).operators().iterator(); while (childIt.hasNext() && operIt.hasNext()) { - ret = ArithmeticNode.resolve(ret, operIt.next(), optimize(childIt.next(), context)); + ret = OperationNode.resolve(ret, operIt.next(), optimize(childIt.next(), context)); } return ret; } @@ -114,15 +114,15 @@ public class GBDTOptimizer extends Optimizer { /** Consumes the if condition and return the size of the values resulting, for convenience */ private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) { - if (condition instanceof ComparisonNode) { - ComparisonNode comparison = (ComparisonNode)condition; - if (comparison.getOperator() == TruthOperator.SMALLER) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.getLeftCondition(), context)); - else if (comparison.getOperator() == TruthOperator.EQUAL) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.getLeftCondition(), context)); + if (isBinaryComparison(condition)) { + OperationNode comparison = (OperationNode)condition; + if (comparison.operators().get(0) == Operator.smaller) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context)); + else if (comparison.operators().get(0) == Operator.equal) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context)); else - throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.getOperator()); - values.add(toValue(comparison.getRightCondition())); + throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0)); + values.add(toValue(comparison.children().get(1))); } else if (condition instanceof SetMembershipNode) { SetMembershipNode setMembership = (SetMembershipNode)condition; @@ -131,17 +131,15 @@ public class GBDTOptimizer extends Optimizer { for (ExpressionNode setElementNode : setMembership.getSetValues()) values.add(toValue(setElementNode)); } - else if (condition instanceof NotNode) { // handle if inversion: !(a >= b) - NotNode notNode = (NotNode)condition; - if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) { - EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0); - if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) { - ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0); - if (comparison.getOperator() == TruthOperator.LARGEREQUAL) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context)); + else if (condition instanceof NotNode notNode) { // handle if inversion: !(a >= b) + if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) { + if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) { + OperationNode comparison = (OperationNode)embracedNode.children().get(0); + if (comparison.operators().get(0) == Operator.largerOrEqual) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context)); else - throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator()); - values.add(toValue(comparison.getRightCondition())); + throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0)); + values.add(toValue(comparison.children().get(1))); } } } @@ -152,12 +150,24 @@ public class GBDTOptimizer extends Optimizer { return values.size(); } + private boolean isBinaryComparison(ExpressionNode condition) { + if ( ! (condition instanceof OperationNode binaryNode)) return false; + if (binaryNode.operators().size() != 1) return false; + if (binaryNode.operators().get(0) == Operator.largerOrEqual) return true; + if (binaryNode.operators().get(0) == Operator.larger) return true; + if (binaryNode.operators().get(0) == Operator.smallerOrEqual) return true; + if (binaryNode.operators().get(0) == Operator.smaller) return true; + if (binaryNode.operators().get(0) == Operator.approxEqual) return true; + if (binaryNode.operators().get(0) == Operator.notEqual) return true; + if (binaryNode.operators().get(0) == Operator.equal) return true; + return false; + } + private double getVariableIndex(ExpressionNode node, ContextIndex context) { - if (!(node instanceof ReferenceNode)) { + if (!(node instanceof ReferenceNode fNode)) { throw new IllegalArgumentException("Contained a left-hand comparison expression " + "which was not a feature value but was: " + node); } - ReferenceNode fNode = (ReferenceNode)node; Integer index = context.getIndex(fNode.toString()); if (index == null) { throw new IllegalStateException("The ranking expression contained feature '" + fNode.getName() + @@ -177,8 +187,7 @@ public class GBDTOptimizer extends Optimizer { value.getClass().getSimpleName() + " (" + value + ") in a set test: " + node); } - if (node instanceof NegativeNode) { - NegativeNode nNode = (NegativeNode)node; + if (node instanceof NegativeNode nNode) { if (!(nNode.getValue() instanceof ConstantNode)) { throw new IllegalArgumentException("Contained a negation of a non-number: " + nNode.getValue()); } |