From 3d49f155fccfa4fc08882b01e7a6e3a982c55212 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 28 Sep 2022 16:19:30 +0200 Subject: Fold comparisons into the other operators --- .../evaluation/gbdtoptimization/GBDTOptimizer.java | 53 +++++++++++++--------- 1 file changed, 31 insertions(+), 22 deletions(-) (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java index 420f1f459f3..cf4c35d94af 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java @@ -114,15 +114,15 @@ public class GBDTOptimizer extends Optimizer { /** Consumes the if condition and return the size of the values resulting, for convenience */ private int consumeIfCondition(ExpressionNode condition, List values, ContextIndex context) { - if (condition instanceof ComparisonNode) { - ComparisonNode comparison = (ComparisonNode)condition; - if (comparison.getOperator() == TruthOperator.SMALLER) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.getLeftCondition(), context)); - else if (comparison.getOperator() == TruthOperator.EQUAL) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.getLeftCondition(), context)); + if (isBinaryComparison(condition)) { + ArithmeticNode comparison = (ArithmeticNode)condition; + if (comparison.operators().get(0) == ArithmeticOperator.LESS) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context)); + else if (comparison.operators().get(0) == ArithmeticOperator.EQUAL) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context)); else - throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.getOperator()); - values.add(toValue(comparison.getRightCondition())); + throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0)); + values.add(toValue(comparison.children().get(1))); } else if (condition instanceof SetMembershipNode) { SetMembershipNode setMembership = (SetMembershipNode)condition; @@ -131,17 +131,15 @@ public class GBDTOptimizer extends Optimizer { for (ExpressionNode setElementNode : setMembership.getSetValues()) values.add(toValue(setElementNode)); } - else if (condition instanceof NotNode) { // handle if inversion: !(a >= b) - NotNode notNode = (NotNode)condition; - if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) { - EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0); - if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) { - ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0); - if (comparison.getOperator() == TruthOperator.LARGEREQUAL) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context)); + else if (condition instanceof NotNode notNode) { // handle if inversion: !(a >= b) + if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) { + if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) { + ArithmeticNode comparison = (ArithmeticNode)embracedNode.children().get(0); + if (comparison.operators().get(0) == ArithmeticOperator.GREATEREQUAL) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context)); else - throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator()); - values.add(toValue(comparison.getRightCondition())); + throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0)); + values.add(toValue(comparison.children().get(1))); } } } @@ -152,12 +150,24 @@ public class GBDTOptimizer extends Optimizer { return values.size(); } + private boolean isBinaryComparison(ExpressionNode condition) { + if ( ! (condition instanceof ArithmeticNode binaryNode)) return false; + if (binaryNode.operators().size() != 1) return false; + if (binaryNode.operators().get(0) == ArithmeticOperator.GREATEREQUAL) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.GREATER) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.LESSEQUAL) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.LESS) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.APPROX) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.NOTEQUAL) return true; + if (binaryNode.operators().get(0) == ArithmeticOperator.EQUAL) return true; + return false; + } + private double getVariableIndex(ExpressionNode node, ContextIndex context) { - if (!(node instanceof ReferenceNode)) { + if (!(node instanceof ReferenceNode fNode)) { throw new IllegalArgumentException("Contained a left-hand comparison expression " + "which was not a feature value but was: " + node); } - ReferenceNode fNode = (ReferenceNode)node; Integer index = context.getIndex(fNode.toString()); if (index == null) { throw new IllegalStateException("The ranking expression contained feature '" + fNode.getName() + @@ -177,8 +187,7 @@ public class GBDTOptimizer extends Optimizer { value.getClass().getSimpleName() + " (" + value + ") in a set test: " + node); } - if (node instanceof NegativeNode) { - NegativeNode nNode = (NegativeNode)node; + if (node instanceof NegativeNode nNode) { if (!(nNode.getValue() instanceof ConstantNode)) { throw new IllegalArgumentException("Contained a negation of a non-number: " + nNode.getValue()); } -- cgit v1.2.3 From bcbb2009c44380055b2670e7cdefcad232f9ece4 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 28 Sep 2022 17:42:55 +0200 Subject: Drop 'arithmetic' from name --- .../BooleanExpressionTransformer.java | 40 ++--- .../expressiontransforms/TokenTransformer.java | 20 +-- .../BooleanExpressionTransformerTestCase.java | 6 +- .../importer/operations/Gather.java | 10 +- .../importer/operations/Gemm.java | 12 +- .../importer/operations/Range.java | 8 +- .../importer/operations/Reshape.java | 20 ++- .../importer/operations/Slice.java | 8 +- .../importer/operations/Split.java | 6 +- .../importer/operations/Tile.java | 6 +- .../gbdtoptimization/GBDTForestOptimizer.java | 12 +- .../evaluation/gbdtoptimization/GBDTOptimizer.java | 34 ++--- .../rankingexpression/rule/ArithmeticNode.java | 160 -------------------- .../rankingexpression/rule/ArithmeticOperator.java | 65 -------- .../rankingexpression/rule/LambdaFunctionNode.java | 6 +- .../rankingexpression/rule/OperationNode.java | 168 +++++++++++++++++++++ .../searchlib/rankingexpression/rule/Operator.java | 65 ++++++++ .../rankingexpression/transform/Simplifier.java | 38 ++--- .../RankingExpressionTestCase.java | 6 +- .../evaluation/EvaluationTestCase.java | 8 +- 20 files changed, 352 insertions(+), 346 deletions(-) delete mode 100755 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java delete mode 100644 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java create mode 100755 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java create mode 100644 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java') diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java index 8fa4b469590..9ffc73e1863 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java @@ -2,8 +2,8 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -35,20 +35,20 @@ public class BooleanExpressionTransformer extends ExpressionTransformer child = node.children().iterator(); // Transform in precedence order: Deque stack = new ArrayDeque<>(); stack.push(new ChildNode(null, child.next())); - for (Iterator it = node.operators().iterator(); it.hasNext() && child.hasNext();) { - ArithmeticOperator op = it.next(); + for (Iterator it = node.operators().iterator(); it.hasNext() && child.hasNext();) { + Operator op = it.next(); if ( ! stack.isEmpty()) { while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) { popStack(stack); @@ -66,9 +66,9 @@ public class BooleanExpressionTransformer extends ExpressionTransformer joinedOps = new ArrayList<>(); + List joinedOps = new ArrayList<>(); joinOps(left, joinedOps); joinedOps.add(right.op); joinOps(right, joinedOps); List joinedChildren = new ArrayList<>(); joinChildren(left, joinedChildren); joinChildren(right, joinedChildren); - return new ArithmeticNode(joinedChildren, joinedOps); + return new OperationNode(joinedChildren, joinedOps); } - private static void joinOps(ChildNode node, List joinedOps) { - if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode) - joinedOps.addAll(arithmeticNode.operators()); + private static void joinOps(ChildNode node, List joinedOps) { + if (node.artificial && node.child instanceof OperationNode operationNode) + joinedOps.addAll(operationNode.operators()); } private static void joinChildren(ChildNode node, List joinedChildren) { - if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode) - joinedChildren.addAll(arithmeticNode.children()); + if (node.artificial && node.child instanceof OperationNode operationNode) + joinedChildren.addAll(operationNode.children()); else joinedChildren.add(node.child); } @@ -115,11 +115,11 @@ public class BooleanExpressionTransformer extends ExpressionTransformer tokenSequence = createTokenSequence(feature); ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); - ArithmeticNode comparison = new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, lengthExpr); + OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.LESS, lengthExpr); ExpressionNode expr = new IfNode(comparison, ONE, ZERO); return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); } @@ -254,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer sequence) { ExpressionNode lengthExpr = createLengthExpr(iter, sequence); - ArithmeticNode comparison = new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, lengthExpr); + OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.LESS, lengthExpr); ExpressionNode trueExpr = sequence.get(iter); if (sequence.get(iter) instanceof ReferenceNode) { @@ -278,7 +278,7 @@ public class TokenTransformer extends ExpressionTransformer sequence) { List factors = new ArrayList<>(); - List operators = new ArrayList<>(); + List operators = new ArrayList<>(); for (int i = 0; i < iter + 1; ++i) { if (sequence.get(i) instanceof ConstantNode) { factors.add(ONE); @@ -286,10 +286,10 @@ public class TokenTransformer extends ExpressionTransformer= 1) { - operators.add(ArithmeticOperator.PLUS); + operators.add(Operator.PLUS); } } - return new ArithmeticNode(factors, operators); + return new OperationNode(factors, operators); } /** @@ -299,7 +299,7 @@ public class TokenTransformer extends ExpressionTransformer= 1) { ExpressionNode lengthExpr = new EmbracedNode(createLengthExpr(iter - 1, sequence)); - expr = new EmbracedNode(new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.MINUS, lengthExpr)); + expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.MINUS, lengthExpr)); } else { expr = new ReferenceNode("d1"); } diff --git a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java index 65d12f1cdcf..d692b69d3c8 100644 --- a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java @@ -4,7 +4,7 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import org.junit.jupiter.api.Test; @@ -43,8 +43,8 @@ public class BooleanExpressionTransformerTestCase { var expr = new BooleanExpressionTransformer() .transform(new RankingExpression("a + b + c * d + e + f"), new TransformContext(Map.of(), new MapTypeContext())); - assertTrue(expr.getRoot() instanceof ArithmeticNode); - ArithmeticNode root = (ArithmeticNode) expr.getRoot(); + assertTrue(expr.getRoot() instanceof OperationNode); + OperationNode root = (OperationNode) expr.getRoot(); assertEquals(5, root.operators().size()); assertEquals(6, root.children().size()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java index cd0c4da6d0f..c66022975c7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -92,7 +92,7 @@ public class Gather extends IntermediateOperation { ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantValue)); if (constantValue < 0) { ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); - indexExpression = new EmbracedNode(new ArithmeticNode(indexExpression, ArithmeticOperator.PLUS, axisSize)); + indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.PLUS, axisSize)); } addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); } else { @@ -125,8 +125,8 @@ public class Gather extends IntermediateOperation { /** to support negative indexing */ private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) { ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); - ExpressionNode plus = new EmbracedNode(new ArithmeticNode(slice, ArithmeticOperator.PLUS, axisSize)); - ExpressionNode mod = new ArithmeticNode(plus, ArithmeticOperator.MODULO, axisSize); + ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.PLUS, axisSize)); + ExpressionNode mod = new OperationNode(plus, Operator.MODULO, axisSize); return mod; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 1f447f2a575..81d633dea4b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; @@ -97,17 +97,17 @@ public class Gemm extends IntermediateOperation { TensorFunction AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension); TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( - new ArithmeticNode( + new OperationNode( new TensorFunctionNode(AxB), - ArithmeticOperator.MULTIPLY, + Operator.MULTIPLY, new ConstantNode(new DoubleValue(alpha)))); if (inputs.size() == 3) { Optional> cFunction = inputs.get(2).function(); TensorFunction betaxC = new TensorFunctionNode.ExpressionTensorFunction( - new ArithmeticNode( + new OperationNode( new TensorFunctionNode(cFunction.get()), - ArithmeticOperator.MULTIPLY, + Operator.MULTIPLY, new ConstantNode(new DoubleValue(beta)))); return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java index 5c4e8cd6cd0..66e810b954e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -65,8 +65,8 @@ public class Range extends IntermediateOperation { ExpressionNode startExpr = new ConstantNode(new DoubleValue(start)); ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta)); ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName)); - ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr); - ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr); + ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.MULTIPLY, dimExpr); + ExpressionNode addExpr = new OperationNode(startExpr, Operator.PLUS, stepExpr); TensorFunction function = Generate.bound(type.type(), wrapScalar(addExpr)); return function; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 7b675fa79af..ce93461bff3 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -6,13 +6,11 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; @@ -159,13 +157,13 @@ public class Reshape extends IntermediateOperation { inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); } else if (dim == (inputType.rank() - 1)) { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); - ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size); + ExpressionNode div = new OperationNode(unrolled, Operator.MODULO, size); inputDimensionExpression = new EmbracedNode(div); } else { ExpressionNode size = new ConstantNode(new DoubleValue(innerSize)); ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize)); - ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize); - ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size); + ExpressionNode mod = new OperationNode(unrolled, Operator.MODULO, previousSize); + ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.DIVIDE, size); inputDimensionExpression = new EmbracedNode(div); } dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression))); @@ -183,21 +181,21 @@ public class Reshape extends IntermediateOperation { return new ConstantNode(DoubleValue.zero); List children = new ArrayList<>(); - List operators = new ArrayList<>(); + List operators = new ArrayList<>(); int size = 1; for (int i = type.dimensions().size() - 1; i >= 0; --i) { TensorType.Dimension dimension = type.dimensions().get(i); children.add(0, new ReferenceNode(dimension.name())); if (size > 1) { - operators.add(0, ArithmeticOperator.MULTIPLY); + operators.add(0, Operator.MULTIPLY); children.add(0, new ConstantNode(new DoubleValue(size))); } size *= OrderedTensorType.dimensionSize(dimension); if (i > 0) { - operators.add(0, ArithmeticOperator.PLUS); + operators.add(0, Operator.PLUS); } } - return new ArithmeticNode(children, operators); + return new OperationNode(children, operators); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java index 91b7064b19c..617e1f00c94 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -6,8 +6,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -166,8 +166,8 @@ public class Slice extends IntermediateOperation { // step * (d0 + start) ExpressionNode reference = new ReferenceNode(outputDimensionName); - ExpressionNode plus = new EmbracedNode(new ArithmeticNode(reference, ArithmeticOperator.PLUS, startIndex)); - ExpressionNode mul = new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, plus); + ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.PLUS, startIndex)); + ExpressionNode mul = new OperationNode(stepSize, Operator.MULTIPLY, plus); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul)))); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java index 6f720716adb..42901259821 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -96,7 +96,7 @@ public class Split extends IntermediateOperation { for (int i = 0; i < inputType.rank(); ++i) { String inputDimensionName = inputType.dimensions().get(i).name(); ExpressionNode reference = new ReferenceNode(inputDimensionName); - ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0))); + ExpressionNode offset = new OperationNode(reference, Operator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0))); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset)))); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java index 4bfab284cc2..cd88c625d81 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -4,8 +4,8 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -77,7 +77,7 @@ public class Tile extends IntermediateOperation { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); ExpressionNode reference = new ReferenceNode(inputDimensionName); - ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size); + ExpressionNode mod = new OperationNode(reference, Operator.MODULO, size); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod)))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java index 6ab483800a7..78acd2e5af1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java @@ -5,8 +5,8 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -84,12 +84,12 @@ public class GBDTForestOptimizer extends Optimizer { currentTreesOptimized++; return true; } - if (!(node instanceof ArithmeticNode)) { + if (!(node instanceof OperationNode)) { return false; } - ArithmeticNode aNode = (ArithmeticNode)node; - for (ArithmeticOperator op : aNode.operators()) { - if (op != ArithmeticOperator.PLUS) { + OperationNode aNode = (OperationNode)node; + for (Operator op : aNode.operators()) { + if (op != Operator.PLUS) { return false; } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java index cf4c35d94af..b3cbe252dfc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java @@ -61,13 +61,13 @@ public class GBDTOptimizer extends Optimizer { * @return the optimized expression */ private ExpressionNode optimize(ExpressionNode node, ContextIndex context) { - if (node instanceof ArithmeticNode) { - Iterator childIt = ((ArithmeticNode)node).children().iterator(); + if (node instanceof OperationNode) { + Iterator childIt = ((OperationNode)node).children().iterator(); ExpressionNode ret = optimize(childIt.next(), context); - Iterator operIt = ((ArithmeticNode)node).operators().iterator(); + Iterator operIt = ((OperationNode)node).operators().iterator(); while (childIt.hasNext() && operIt.hasNext()) { - ret = ArithmeticNode.resolve(ret, operIt.next(), optimize(childIt.next(), context)); + ret = OperationNode.resolve(ret, operIt.next(), optimize(childIt.next(), context)); } return ret; } @@ -115,10 +115,10 @@ public class GBDTOptimizer extends Optimizer { /** Consumes the if condition and return the size of the values resulting, for convenience */ private int consumeIfCondition(ExpressionNode condition, List values, ContextIndex context) { if (isBinaryComparison(condition)) { - ArithmeticNode comparison = (ArithmeticNode)condition; - if (comparison.operators().get(0) == ArithmeticOperator.LESS) + OperationNode comparison = (OperationNode)condition; + if (comparison.operators().get(0) == Operator.LESS) values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context)); - else if (comparison.operators().get(0) == ArithmeticOperator.EQUAL) + else if (comparison.operators().get(0) == Operator.EQUAL) values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context)); else throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0)); @@ -134,8 +134,8 @@ public class GBDTOptimizer extends Optimizer { else if (condition instanceof NotNode notNode) { // handle if inversion: !(a >= b) if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) { if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) { - ArithmeticNode comparison = (ArithmeticNode)embracedNode.children().get(0); - if (comparison.operators().get(0) == ArithmeticOperator.GREATEREQUAL) + OperationNode comparison = (OperationNode)embracedNode.children().get(0); + if (comparison.operators().get(0) == Operator.GREATEREQUAL) values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context)); else throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0)); @@ -151,15 +151,15 @@ public class GBDTOptimizer extends Optimizer { } private boolean isBinaryComparison(ExpressionNode condition) { - if ( ! (condition instanceof ArithmeticNode binaryNode)) return false; + if ( ! (condition instanceof OperationNode binaryNode)) return false; if (binaryNode.operators().size() != 1) return false; - if (binaryNode.operators().get(0) == ArithmeticOperator.GREATEREQUAL) return true; - if (binaryNode.operators().get(0) == ArithmeticOperator.GREATER) return true; - if (binaryNode.operators().get(0) == ArithmeticOperator.LESSEQUAL) return true; - if (binaryNode.operators().get(0) == ArithmeticOperator.LESS) return true; - if (binaryNode.operators().get(0) == ArithmeticOperator.APPROX) return true; - if (binaryNode.operators().get(0) == ArithmeticOperator.NOTEQUAL) return true; - if (binaryNode.operators().get(0) == ArithmeticOperator.EQUAL) return true; + if (binaryNode.operators().get(0) == Operator.GREATEREQUAL) return true; + if (binaryNode.operators().get(0) == Operator.GREATER) return true; + if (binaryNode.operators().get(0) == Operator.LESSEQUAL) return true; + if (binaryNode.operators().get(0) == Operator.LESS) return true; + if (binaryNode.operators().get(0) == Operator.APPROX) return true; + if (binaryNode.operators().get(0) == Operator.NOTEQUAL) return true; + if (binaryNode.operators().get(0) == Operator.EQUAL) return true; return false; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java deleted file mode 100755 index c3e39197316..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.Join; - -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Deque; -import java.util.Iterator; -import java.util.List; -import java.util.Objects; - -/** - * A binary mathematical operation - * - * @author bratseth - */ -public final class ArithmeticNode extends CompositeNode { - - private final List children; - private final List operators; - - public ArithmeticNode(List children, List operators) { - this.children = List.copyOf(children); - this.operators = List.copyOf(operators); - } - - public ArithmeticNode(ExpressionNode leftExpression, ArithmeticOperator operator, ExpressionNode rightExpression) { - this.children = List.of(leftExpression, rightExpression); - this.operators = List.of(operator); - } - - public List operators() { return operators; } - - @Override - public List children() { return children; } - - @Override - public StringBuilder toString(StringBuilder string, SerializationContext context, Deque path, CompositeNode parent) { - boolean nonDefaultPrecedence = nonDefaultPrecedence(parent); - if (nonDefaultPrecedence) - string.append("("); - - Iterator child = children.iterator(); - child.next().toString(string, context, path, this); - if (child.hasNext()) - string.append(" "); - for (Iterator op = operators.iterator(); op.hasNext() && child.hasNext();) { - string.append(op.next().toString()).append(" "); - child.next().toString(string, context, path, this); - if (op.hasNext()) - string.append(" "); - } - if (nonDefaultPrecedence) - string.append(")"); - - return string; - } - - /** - * Returns true if this node has lower precedence than the parent - * (even though by virtue of being a node it will be calculated before the parent). - */ - private boolean nonDefaultPrecedence(CompositeNode parent) { - if ( parent == null) return false; - if ( ! (parent instanceof ArithmeticNode arithmeticParent)) return false; - - // The line below can only be correct in both only have one operator. - // Getting this correct is impossible without more work. - // So for now we only handle the simple case correctly, and use a safe approach by adding - // extra parenthesis just in case.... - return arithmeticParent.operators.get(0).hasPrecedenceOver(this.operators.get(0)) - || ((arithmeticParent.operators.size() > 1) || (operators.size() > 1)); - } - - @Override - public TensorType type(TypeContext context) { - // Compute type using tensor types as arithmetic operators are supported on tensors - // and is correct also in the special case of doubles. - // As all our functions are type-commutative, we don't need to take operator precedence into account - TensorType type = children.get(0).type(context); - for (int i = 1; i < children.size(); i++) - type = Join.outputType(type, children.get(i).type(context)); - return type; - } - - @Override - public Value evaluate(Context context) { - Iterator child = children.iterator(); - - // Apply in precedence order: - Deque stack = new ArrayDeque<>(); - stack.push(new ValueItem(null, child.next().evaluate(context))); - for (Iterator it = operators.iterator(); it.hasNext() && child.hasNext();) { - ArithmeticOperator op = it.next(); - if ( ! stack.isEmpty()) { - while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) { - popStack(stack); - } - } - stack.push(new ValueItem(op, child.next().evaluate(context))); - } - while (stack.size() > 1) { - popStack(stack); - } - return stack.getFirst().value; - } - - private void popStack(Deque stack) { - ValueItem rhs = stack.pop(); - ValueItem lhs = stack.peek(); - lhs.value = rhs.op.evaluate(lhs.value, rhs.value); - } - - @Override - public CompositeNode setChildren(List newChildren) { - if (children.size() != newChildren.size()) - throw new IllegalArgumentException("Expected " + children.size() + " children but got " + newChildren.size()); - return new ArithmeticNode(newChildren, operators); - } - - @Override - public int hashCode() { return Objects.hash(children, operators); } - - public static ArithmeticNode resolve(ExpressionNode left, ArithmeticOperator op, ExpressionNode right) { - if ( ! (left instanceof ArithmeticNode leftArithmetic)) return new ArithmeticNode(left, op, right); - - List newChildren = new ArrayList<>(leftArithmetic.children()); - newChildren.add(right); - - List newOperators = new ArrayList<>(leftArithmetic.operators()); - newOperators.add(op); - - return new ArithmeticNode(newChildren, newOperators); - } - - private static class ValueItem { - - final ArithmeticOperator op; - Value value; - - public ValueItem(ArithmeticOperator op, Value value) { - this.op = op; - this.value = value; - } - - @Override - public String toString() { - return value.toString(); - } - - } - -} - diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java deleted file mode 100644 index 435c92ff7da..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.yahoo.searchlib.rankingexpression.evaluation.Value; - -import java.util.Arrays; -import java.util.List; -import java.util.function.BiFunction; - -/** - * A mathematical operator - * - * @author bratseth - */ -public enum ArithmeticOperator { - - // In order from lowest to highest precedence - OR("||", (x, y) -> x.or(y)), - AND("&&", (x, y) -> x.and(y)), - GREATEREQUAL(">=", (x, y) -> x.greaterEqual(y)), - GREATER(">", (x, y) -> x.greater(y)), - LESSEQUAL("<=", (x, y) -> x.lessEqual(y)), - LESS("<", (x, y) -> x.less(y)), - APPROX("~=", (x, y) -> x.approx(y)), - NOTEQUAL("!=", (x, y) -> x.notEqual(y)), - EQUAL("==", (x, y) -> x.equal(y)), - PLUS("+", (x, y) -> x.add(y)), - MINUS("-", (x, y) -> x.subtract(y)), - MULTIPLY("*", (x, y) -> x.multiply(y)), - DIVIDE("/", (x, y) -> x.divide(y)), - MODULO("%", (x, y) -> x.modulo(y)), - POWER("^", true, (x, y) -> x.power(y)); - - /** A list of all the operators in this in order of decreasing precedence */ - public static final List operatorsByPrecedence = Arrays.stream(ArithmeticOperator.values()).toList(); - - private final String image; - private final boolean bindsRight; // TODO: Implement - private final BiFunction function; - - ArithmeticOperator(String image, BiFunction function) { - this(image, false, function); - } - - ArithmeticOperator(String image, boolean bindsRight, BiFunction function) { - this.image = image; - this.bindsRight = bindsRight; - this.function = function; - } - - /** Returns true if this operator has precedence over the given operator */ - public boolean hasPrecedenceOver(ArithmeticOperator op) { - return operatorsByPrecedence.indexOf(this) > operatorsByPrecedence.indexOf(op); - } - - public final Value evaluate(Value x, Value y) { - return function.apply(x, y); - } - - @Override - public String toString() { - return image; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 9f07f146264..3c7f48aa38c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -106,10 +106,10 @@ public class LambdaFunctionNode extends CompositeNode { } private Optional getDirectEvaluator() { - if ( ! (functionExpression instanceof ArithmeticNode)) { + if ( ! (functionExpression instanceof OperationNode)) { return Optional.empty(); } - ArithmeticNode node = (ArithmeticNode) functionExpression; + OperationNode node = (OperationNode) functionExpression; if ( ! (node.children().get(0) instanceof ReferenceNode) || ! (node.children().get(1) instanceof ReferenceNode)) { return Optional.empty(); } @@ -121,7 +121,7 @@ public class LambdaFunctionNode extends CompositeNode { if (node.operators().size() != 1) { return Optional.empty(); } - ArithmeticOperator operator = node.operators().get(0); + Operator operator = node.operators().get(0); switch (operator) { case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0); case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java new file mode 100755 index 00000000000..392f42f6cbe --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java @@ -0,0 +1,168 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Join; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +/** + * A sequence of binary operations. + * + * @author bratseth + */ +public final class OperationNode extends CompositeNode { + + private final List children; + private final List operators; + + public OperationNode(List children, List operators) { + this.children = List.copyOf(children); + this.operators = List.copyOf(operators); + } + + public OperationNode(ExpressionNode leftExpression, Operator operator, ExpressionNode rightExpression) { + this.children = List.of(leftExpression, rightExpression); + this.operators = List.of(operator); + } + + public List operators() { return operators; } + + @Override + public List children() { return children; } + + @Override + public StringBuilder toString(StringBuilder string, SerializationContext context, Deque path, CompositeNode parent) { + boolean nonDefaultPrecedence = nonDefaultPrecedence(parent); + if (nonDefaultPrecedence) + string.append("("); + + Iterator child = children.iterator(); + child.next().toString(string, context, path, this); + if (child.hasNext()) + string.append(" "); + for (Iterator op = operators.iterator(); op.hasNext() && child.hasNext();) { + string.append(op.next().toString()).append(" "); + child.next().toString(string, context, path, this); + if (op.hasNext()) + string.append(" "); + } + if (nonDefaultPrecedence) + string.append(")"); + + return string; + } + + /** + * Returns true if this node has lower precedence than the parent + * (even though by virtue of being a node it will be calculated before the parent). + */ + private boolean nonDefaultPrecedence(CompositeNode parent) { + if ( parent == null) return false; + if ( ! (parent instanceof OperationNode arithmeticParent)) return false; + + // The line below can only be correct in both only have one operator. + // Getting this correct is impossible without more work. + // So for now we only handle the simple case correctly, and use a safe approach by adding + // extra parenthesis just in case.... + return arithmeticParent.operators.get(0).hasPrecedenceOver(this.operators.get(0)) + || ((arithmeticParent.operators.size() > 1) || (operators.size() > 1)); + } + + @Override + public TensorType type(TypeContext context) { + // Compute type using tensor types as arithmetic operators are supported on tensors + // and is correct also in the special case of doubles. + // As all our functions are type-commutative, we don't need to take operator precedence into account + TensorType type = children.get(0).type(context); + for (int i = 1; i < children.size(); i++) + type = Join.outputType(type, children.get(i).type(context)); + return type; + } + + @Override + public Value evaluate(Context context) { + Iterator child = children.iterator(); + + // Apply in precedence order: + Deque stack = new ArrayDeque<>(); + stack.push(new ValueItem(null, child.next().evaluate(context))); + for (Iterator it = operators.iterator(); it.hasNext() && child.hasNext();) { + Operator op = it.next(); + if ( ! stack.isEmpty()) { + while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) { + popStack(stack); + } + } + stack.push(new ValueItem(op, child.next().evaluate(context))); + } + while (stack.size() > 1) { + popStack(stack); + } + return stack.getFirst().value; + } + + private void popStack(Deque stack) { + ValueItem rhs = stack.pop(); + ValueItem lhs = stack.peek(); + lhs.value = rhs.op.evaluate(lhs.value, rhs.value); + } + + @Override + public CompositeNode setChildren(List newChildren) { + if (children.size() != newChildren.size()) + throw new IllegalArgumentException("Expected " + children.size() + " children but got " + newChildren.size()); + return new OperationNode(newChildren, operators); + } + + @Override + public int hashCode() { return Objects.hash(children, operators); } + + @Override + public boolean equals(Object o) { + if ( ! (o instanceof OperationNode other)) return false; + if ( ! this.children().equals(other.children())) return false; + if ( ! this.operators().equals(other.operators())) return false; + return true; + } + + public static OperationNode resolve(ExpressionNode left, Operator op, ExpressionNode right) { + if ( ! (left instanceof OperationNode leftArithmetic)) return new OperationNode(left, op, right); + + List newChildren = new ArrayList<>(leftArithmetic.children()); + newChildren.add(right); + + List newOperators = new ArrayList<>(leftArithmetic.operators()); + newOperators.add(op); + + return new OperationNode(newChildren, newOperators); + } + + private static class ValueItem { + + final Operator op; + Value value; + + public ValueItem(Operator op, Value value) { + this.op = op; + this.value = value; + } + + @Override + public String toString() { + return value.toString(); + } + + } + +} + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java new file mode 100644 index 00000000000..4ddbfa4ea9f --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java @@ -0,0 +1,65 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Arrays; +import java.util.List; +import java.util.function.BiFunction; + +/** + * A mathematical operator + * + * @author bratseth + */ +public enum Operator { + + // In order from lowest to highest precedence + OR("||", (x, y) -> x.or(y)), + AND("&&", (x, y) -> x.and(y)), + GREATEREQUAL(">=", (x, y) -> x.greaterEqual(y)), + GREATER(">", (x, y) -> x.greater(y)), + LESSEQUAL("<=", (x, y) -> x.lessEqual(y)), + LESS("<", (x, y) -> x.less(y)), + APPROX("~=", (x, y) -> x.approx(y)), + NOTEQUAL("!=", (x, y) -> x.notEqual(y)), + EQUAL("==", (x, y) -> x.equal(y)), + PLUS("+", (x, y) -> x.add(y)), + MINUS("-", (x, y) -> x.subtract(y)), + MULTIPLY("*", (x, y) -> x.multiply(y)), + DIVIDE("/", (x, y) -> x.divide(y)), + MODULO("%", (x, y) -> x.modulo(y)), + POWER("^", true, (x, y) -> x.power(y)); + + /** A list of all the operators in this in order of increasing precedence */ + public static final List operatorsByPrecedence = Arrays.stream(Operator.values()).toList(); + + private final String image; + private final boolean bindsRight; // TODO: Implement + private final BiFunction function; + + Operator(String image, BiFunction function) { + this(image, false, function); + } + + Operator(String image, boolean bindsRight, BiFunction function) { + this.image = image; + this.bindsRight = bindsRight; + this.function = function; + } + + /** Returns true if this operator has precedence over the given operator */ + public boolean hasPrecedenceOver(Operator op) { + return operatorsByPrecedence.indexOf(this) > operatorsByPrecedence.indexOf(op); + } + + public final Value evaluate(Value x, Value y) { + return function.apply(x, y); + } + + @Override + public String toString() { + return image; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index 7a34f5b7b03..04728966bc1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -4,8 +4,8 @@ package com.yahoo.searchlib.rankingexpression.transform; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; @@ -33,8 +33,8 @@ public class Simplifier extends ExpressionTransformer { node = transformIf((IfNode) node); if (node instanceof EmbracedNode e && hasSingleUndividableChild(e)) node = e.children().get(0); - if (node instanceof ArithmeticNode) - node = transformArithmetic((ArithmeticNode) node); + if (node instanceof OperationNode) + node = transformArithmetic((OperationNode) node); if (node instanceof NegativeNode) node = transformNegativeNode((NegativeNode) node); return node; @@ -42,18 +42,18 @@ public class Simplifier extends ExpressionTransformer { private boolean hasSingleUndividableChild(EmbracedNode node) { if (node.children().size() > 1) return false; - if (node.children().get(0) instanceof ArithmeticNode) return false; + if (node.children().get(0) instanceof OperationNode) return false; return true; } - private ExpressionNode transformArithmetic(ArithmeticNode node) { + private ExpressionNode transformArithmetic(OperationNode node) { // Fold the subset of expressions that are constant (such that in "1 + 2 + var") if (node.children().size() > 1) { List children = new ArrayList<>(node.children()); - List operators = new ArrayList<>(node.operators()); - for (ArithmeticOperator operator : ArithmeticOperator.operatorsByPrecedence) + List operators = new ArrayList<>(node.operators()); + for (Operator operator : Operator.operatorsByPrecedence) transform(operator, children, operators); - node = new ArithmeticNode(children, operators); + node = new OperationNode(children, operators); } if (isConstant(node) && ! node.evaluate(null).isNaN()) @@ -64,8 +64,8 @@ public class Simplifier extends ExpressionTransformer { return node; } - private void transform(ArithmeticOperator operatorToTransform, - List children, List operators) { + private void transform(Operator operatorToTransform, + List children, List operators) { int i = 0; while (i < children.size()-1) { boolean transformed = false; @@ -73,7 +73,7 @@ public class Simplifier extends ExpressionTransformer { ExpressionNode child1 = children.get(i); ExpressionNode child2 = children.get(i + 1); if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) { - Value evaluated = new ArithmeticNode(child1, operators.get(i), child2).evaluate(null); + Value evaluated = new OperationNode(child1, operators.get(i), child2).evaluate(null); if ( ! evaluated.isNaN()) { // Don't replace by NaN operators.remove(i); children.set(i, new ConstantNode(evaluated.freeze())); @@ -92,7 +92,7 @@ public class Simplifier extends ExpressionTransformer { * This check works because we simplify by decreasing precedence, so neighbours will either be single constant values * or a more complex expression that can't be simplified and hence also prevents the simplification in question here. */ - private boolean hasPrecedence(List operators, int i) { + private boolean hasPrecedence(List operators, int i) { if (i > 0 && operators.get(i-1).hasPrecedenceOver(operators.get(i))) return false; if (i < operators.size()-1 && operators.get(i+1).hasPrecedenceOver(operators.get(i))) return false; return true; @@ -114,14 +114,14 @@ public class Simplifier extends ExpressionTransformer { return new ConstantNode(constant.getValue().negate() ); } - private boolean allMultiplicationOrDivision(ArithmeticNode node) { - for (ArithmeticOperator o : node.operators()) - if (o != ArithmeticOperator.MULTIPLY && o != ArithmeticOperator.DIVIDE) + private boolean allMultiplicationOrDivision(OperationNode node) { + for (Operator o : node.operators()) + if (o != Operator.MULTIPLY && o != Operator.DIVIDE) return false; return true; } - private boolean hasZero(ArithmeticNode node) { + private boolean hasZero(OperationNode node) { for (ExpressionNode child : node.children()) { if (isZero(child)) return true; @@ -129,9 +129,9 @@ public class Simplifier extends ExpressionTransformer { return false; } - private boolean hasDivisionByZero(ArithmeticNode node) { + private boolean hasDivisionByZero(OperationNode node) { for (int i = 1; i < node.children().size(); i++) { - if ( node.operators().get(i - 1) == ArithmeticOperator.DIVIDE && isZero(node.children().get(i))) + if (node.operators().get(i - 1) == Operator.DIVIDE && isZero(node.children().get(i))) return true; } return false; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 1ab9ee11252..83de8e04a7d 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -2,8 +2,8 @@ package com.yahoo.searchlib.rankingexpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -66,7 +66,7 @@ public class RankingExpressionTestCase { public void testProgrammaticBuilding() throws ParseException { ReferenceNode input = new ReferenceNode("input"); ReferenceNode constant = new ReferenceNode("constant"); - ArithmeticNode product = new ArithmeticNode(input, ArithmeticOperator.MULTIPLY, constant); + OperationNode product = new OperationNode(input, Operator.MULTIPLY, constant); Reduce sum = new Reduce<>(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum); RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum)); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index b1ac4b9e3ca..019f76521e9 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -5,8 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.Arguments; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; @@ -779,8 +779,8 @@ public class EvaluationTestCase { @Test public void testProgrammaticBuildingAndPrecedence() { - RankingExpression standardPrecedence = new RankingExpression(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, new ArithmeticNode(constant(3), ArithmeticOperator.MULTIPLY, constant(4)))); - RankingExpression oppositePrecedence = new RankingExpression(new ArithmeticNode(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, constant(3)), ArithmeticOperator.MULTIPLY, constant(4))); + RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.PLUS, new OperationNode(constant(3), Operator.MULTIPLY, constant(4)))); + RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.PLUS, constant(3)), Operator.MULTIPLY, constant(4))); assertEquals(14.0, standardPrecedence.evaluate(null).asDouble(), tolerance); assertEquals(20.0, oppositePrecedence.evaluate(null).asDouble(), tolerance); assertEquals("2.0 + 3.0 * 4.0", standardPrecedence.toString()); -- cgit v1.2.3 From a1912b44d0b800f96b334a24ddefd0026f3af356 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 28 Sep 2022 19:00:08 +0200 Subject: Use tensor vocabulary --- .../BooleanExpressionTransformer.java | 4 +-- .../expressiontransforms/TokenTransformer.java | 12 ++++---- .../java/ai/vespa/models/evaluation/LazyValue.java | 20 ++++++------- .../importer/operations/Gather.java | 6 ++-- .../importer/operations/Gemm.java | 4 +-- .../importer/operations/Range.java | 4 +-- .../importer/operations/Reshape.java | 10 +++---- .../importer/operations/Slice.java | 4 +-- .../importer/operations/Split.java | 2 +- .../importer/operations/Tile.java | 2 +- .../evaluation/DoubleCompatibleValue.java | 10 +++---- .../rankingexpression/evaluation/StringValue.java | 10 +++---- .../rankingexpression/evaluation/TensorValue.java | 10 +++---- .../rankingexpression/evaluation/Value.java | 10 +++---- .../gbdtoptimization/GBDTForestOptimizer.java | 2 +- .../evaluation/gbdtoptimization/GBDTOptimizer.java | 20 ++++++------- .../rankingexpression/rule/LambdaFunctionNode.java | 16 +++++------ .../rankingexpression/rule/OperationNode.java | 8 ------ .../searchlib/rankingexpression/rule/Operator.java | 33 ++++++++++++---------- .../rankingexpression/transform/Simplifier.java | 4 +-- .../RankingExpressionTestCase.java | 2 +- .../evaluation/EvaluationTestCase.java | 4 +-- 22 files changed, 96 insertions(+), 101 deletions(-) (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java') diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java index 9ffc73e1863..ad050d4ca63 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java @@ -66,9 +66,9 @@ public class BooleanExpressionTransformer extends ExpressionTransformer tokenSequence = createTokenSequence(feature); ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); - OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.LESS, lengthExpr); + OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr); ExpressionNode expr = new IfNode(comparison, ONE, ZERO); return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); } @@ -254,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer sequence) { ExpressionNode lengthExpr = createLengthExpr(iter, sequence); - OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.LESS, lengthExpr); + OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr); ExpressionNode trueExpr = sequence.get(iter); if (sequence.get(iter) instanceof ReferenceNode) { @@ -286,7 +286,7 @@ public class TokenTransformer extends ExpressionTransformer= 1) { - operators.add(Operator.PLUS); + operators.add(Operator.plus); } } return new OperationNode(factors, operators); @@ -299,7 +299,7 @@ public class TokenTransformer extends ExpressionTransformer= 1) { ExpressionNode lengthExpr = new EmbracedNode(createLengthExpr(iter - 1, sequence)); - expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.MINUS, lengthExpr)); + expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.minus, lengthExpr)); } else { expr = new ReferenceNode("d1"); } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java index e0c99706e4a..3aa60e0b9a2 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java @@ -84,28 +84,28 @@ class LazyValue extends Value { } @Override - public Value greaterEqual(Value value) { - return computedValue().greaterEqual(value); + public Value largerOrEqual(Value value) { + return computedValue().largerOrEqual(value); } @Override - public Value greater(Value value) { - return computedValue().greater(value); + public Value larger(Value value) { + return computedValue().larger(value); } @Override - public Value lessEqual(Value value) { - return computedValue().lessEqual(value); + public Value smallerOrEqual(Value value) { + return computedValue().smallerOrEqual(value); } @Override - public Value less(Value value) { - return computedValue().less(value); + public Value smaller(Value value) { + return computedValue().smaller(value); } @Override - public Value approx(Value value) { - return computedValue().approx(value); + public Value approxEqual(Value value) { + return computedValue().approxEqual(value); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java index c66022975c7..7ac2f4bff84 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -92,7 +92,7 @@ public class Gather extends IntermediateOperation { ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantValue)); if (constantValue < 0) { ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); - indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.PLUS, axisSize)); + indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.plus, axisSize)); } addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); } else { @@ -125,8 +125,8 @@ public class Gather extends IntermediateOperation { /** to support negative indexing */ private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) { ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); - ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.PLUS, axisSize)); - ExpressionNode mod = new OperationNode(plus, Operator.MODULO, axisSize); + ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.plus, axisSize)); + ExpressionNode mod = new OperationNode(plus, Operator.modulo, axisSize); return mod; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 81d633dea4b..97bfdda385e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -99,7 +99,7 @@ public class Gemm extends IntermediateOperation { TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( new OperationNode( new TensorFunctionNode(AxB), - Operator.MULTIPLY, + Operator.multiply, new ConstantNode(new DoubleValue(alpha)))); if (inputs.size() == 3) { @@ -107,7 +107,7 @@ public class Gemm extends IntermediateOperation { TensorFunction betaxC = new TensorFunctionNode.ExpressionTensorFunction( new OperationNode( new TensorFunctionNode(cFunction.get()), - Operator.MULTIPLY, + Operator.multiply, new ConstantNode(new DoubleValue(beta)))); return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java index 66e810b954e..9a38ab9dfde 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java @@ -65,8 +65,8 @@ public class Range extends IntermediateOperation { ExpressionNode startExpr = new ConstantNode(new DoubleValue(start)); ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta)); ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName)); - ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.MULTIPLY, dimExpr); - ExpressionNode addExpr = new OperationNode(startExpr, Operator.PLUS, stepExpr); + ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.multiply, dimExpr); + ExpressionNode addExpr = new OperationNode(startExpr, Operator.plus, stepExpr); TensorFunction function = Generate.bound(type.type(), wrapScalar(addExpr)); return function; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index ce93461bff3..bc94fc6aa76 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -157,13 +157,13 @@ public class Reshape extends IntermediateOperation { inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); } else if (dim == (inputType.rank() - 1)) { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); - ExpressionNode div = new OperationNode(unrolled, Operator.MODULO, size); + ExpressionNode div = new OperationNode(unrolled, Operator.modulo, size); inputDimensionExpression = new EmbracedNode(div); } else { ExpressionNode size = new ConstantNode(new DoubleValue(innerSize)); ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize)); - ExpressionNode mod = new OperationNode(unrolled, Operator.MODULO, previousSize); - ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.DIVIDE, size); + ExpressionNode mod = new OperationNode(unrolled, Operator.modulo, previousSize); + ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.divide, size); inputDimensionExpression = new EmbracedNode(div); } dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression))); @@ -187,12 +187,12 @@ public class Reshape extends IntermediateOperation { TensorType.Dimension dimension = type.dimensions().get(i); children.add(0, new ReferenceNode(dimension.name())); if (size > 1) { - operators.add(0, Operator.MULTIPLY); + operators.add(0, Operator.multiply); children.add(0, new ConstantNode(new DoubleValue(size))); } size *= OrderedTensorType.dimensionSize(dimension); if (i > 0) { - operators.add(0, Operator.PLUS); + operators.add(0, Operator.plus); } } return new OperationNode(children, operators); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java index 617e1f00c94..916e9980131 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -166,8 +166,8 @@ public class Slice extends IntermediateOperation { // step * (d0 + start) ExpressionNode reference = new ReferenceNode(outputDimensionName); - ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.PLUS, startIndex)); - ExpressionNode mul = new OperationNode(stepSize, Operator.MULTIPLY, plus); + ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.plus, startIndex)); + ExpressionNode mul = new OperationNode(stepSize, Operator.multiply, plus); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul)))); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java index 42901259821..ef46d222941 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java @@ -96,7 +96,7 @@ public class Split extends IntermediateOperation { for (int i = 0; i < inputType.rank(); ++i) { String inputDimensionName = inputType.dimensions().get(i).name(); ExpressionNode reference = new ReferenceNode(inputDimensionName); - ExpressionNode offset = new OperationNode(reference, Operator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0))); + ExpressionNode offset = new OperationNode(reference, Operator.plus, new ConstantNode(new DoubleValue(i == axis ? start : 0))); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset)))); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java index cd88c625d81..a880bff87be 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -77,7 +77,7 @@ public class Tile extends IntermediateOperation { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); ExpressionNode reference = new ReferenceNode(inputDimensionName); - ExpressionNode mod = new OperationNode(reference, Operator.MODULO, size); + ExpressionNode mod = new OperationNode(reference, Operator.modulo, size); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod)))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index e1db6378fcf..186208e036f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -42,27 +42,27 @@ public abstract class DoubleCompatibleValue extends Value { } @Override - public Value greaterEqual(Value value) { + public Value largerOrEqual(Value value) { return new BooleanValue(this.asDouble() >= value.asDouble()); } @Override - public Value greater(Value value) { + public Value larger(Value value) { return new BooleanValue(this.asDouble() > value.asDouble()); } @Override - public Value lessEqual(Value value) { + public Value smallerOrEqual(Value value) { return new BooleanValue(this.asDouble() <= value.asDouble()); } @Override - public Value less(Value value) { + public Value smaller(Value value) { return new BooleanValue(this.asDouble() < value.asDouble()); } @Override - public Value approx(Value value) { + public Value approxEqual(Value value) { return new BooleanValue(approxEqual(this.asDouble(), value.asDouble())); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index 3c09c644147..a585c989954 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -71,27 +71,27 @@ public class StringValue extends Value { } @Override - public Value greaterEqual(Value argument) { + public Value largerOrEqual(Value argument) { throw new UnsupportedOperationException("String values ('" + value + "') do not support greaterEqual"); } @Override - public Value greater(Value argument) { + public Value larger(Value argument) { throw new UnsupportedOperationException("String values ('" + value + "') do not support greater"); } @Override - public Value lessEqual(Value argument) { + public Value smallerOrEqual(Value argument) { throw new UnsupportedOperationException("String values ('" + value + "') do not support lessEqual"); } @Override - public Value less(Value argument) { + public Value smaller(Value argument) { throw new UnsupportedOperationException("String values ('" + value + "') do not support less"); } @Override - public Value approx(Value argument) { + public Value approxEqual(Value argument) { return new BooleanValue(this.asDouble() == argument.asDouble()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 73ea0b23986..25e03c75376 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -72,7 +72,7 @@ public class TensorValue extends Value { } @Override - public Value greaterEqual(Value argument) { + public Value largerOrEqual(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.largerOrEqual(((TensorValue)argument).value)); else @@ -80,7 +80,7 @@ public class TensorValue extends Value { } @Override - public Value greater(Value argument) { + public Value larger(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.larger(((TensorValue)argument).value)); else @@ -88,7 +88,7 @@ public class TensorValue extends Value { } @Override - public Value lessEqual(Value argument) { + public Value smallerOrEqual(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value)); else @@ -96,7 +96,7 @@ public class TensorValue extends Value { } @Override - public Value less(Value argument) { + public Value smaller(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.smaller(((TensorValue)argument).value)); else @@ -104,7 +104,7 @@ public class TensorValue extends Value { } @Override - public Value approx(Value argument) { + public Value approxEqual(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.approxEqual(((TensorValue)argument).value)); else diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 99663fe8d0d..5de2138147e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -53,11 +53,11 @@ public abstract class Value { public abstract Value or(Value value); public abstract Value and(Value value); - public abstract Value greaterEqual(Value value); - public abstract Value greater(Value value); - public abstract Value lessEqual(Value value); - public abstract Value less(Value value); - public abstract Value approx(Value value); + public abstract Value largerOrEqual(Value value); + public abstract Value larger(Value value); + public abstract Value smallerOrEqual(Value value); + public abstract Value smaller(Value value); + public abstract Value approxEqual(Value value); public abstract Value notEqual(Value value); public abstract Value equal(Value value); public abstract Value add(Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java index 78acd2e5af1..a3fc6aae9ac 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java @@ -89,7 +89,7 @@ public class GBDTForestOptimizer extends Optimizer { } OperationNode aNode = (OperationNode)node; for (Operator op : aNode.operators()) { - if (op != Operator.PLUS) { + if (op != Operator.plus) { return false; } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java index b3cbe252dfc..7ba671e62eb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java @@ -116,9 +116,9 @@ public class GBDTOptimizer extends Optimizer { private int consumeIfCondition(ExpressionNode condition, List values, ContextIndex context) { if (isBinaryComparison(condition)) { OperationNode comparison = (OperationNode)condition; - if (comparison.operators().get(0) == Operator.LESS) + if (comparison.operators().get(0) == Operator.smaller) values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context)); - else if (comparison.operators().get(0) == Operator.EQUAL) + else if (comparison.operators().get(0) == Operator.equal) values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context)); else throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0)); @@ -135,7 +135,7 @@ public class GBDTOptimizer extends Optimizer { if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) { if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) { OperationNode comparison = (OperationNode)embracedNode.children().get(0); - if (comparison.operators().get(0) == Operator.GREATEREQUAL) + if (comparison.operators().get(0) == Operator.largerOrEqual) values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context)); else throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0)); @@ -153,13 +153,13 @@ public class GBDTOptimizer extends Optimizer { private boolean isBinaryComparison(ExpressionNode condition) { if ( ! (condition instanceof OperationNode binaryNode)) return false; if (binaryNode.operators().size() != 1) return false; - if (binaryNode.operators().get(0) == Operator.GREATEREQUAL) return true; - if (binaryNode.operators().get(0) == Operator.GREATER) return true; - if (binaryNode.operators().get(0) == Operator.LESSEQUAL) return true; - if (binaryNode.operators().get(0) == Operator.LESS) return true; - if (binaryNode.operators().get(0) == Operator.APPROX) return true; - if (binaryNode.operators().get(0) == Operator.NOTEQUAL) return true; - if (binaryNode.operators().get(0) == Operator.EQUAL) return true; + if (binaryNode.operators().get(0) == Operator.largerOrEqual) return true; + if (binaryNode.operators().get(0) == Operator.larger) return true; + if (binaryNode.operators().get(0) == Operator.smallerOrEqual) return true; + if (binaryNode.operators().get(0) == Operator.smaller) return true; + if (binaryNode.operators().get(0) == Operator.approxEqual) return true; + if (binaryNode.operators().get(0) == Operator.notEqual) return true; + if (binaryNode.operators().get(0) == Operator.equal) return true; return false; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 3c7f48aa38c..97e9a74f9c8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -123,14 +123,14 @@ public class LambdaFunctionNode extends CompositeNode { } Operator operator = node.operators().get(0); switch (operator) { - case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0); - case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0); - case PLUS: return asFunctionExpression((left, right) -> left + right); - case MINUS: return asFunctionExpression((left, right) -> left - right); - case MULTIPLY: return asFunctionExpression((left, right) -> left * right); - case DIVIDE: return asFunctionExpression((left, right) -> left / right); - case MODULO: return asFunctionExpression((left, right) -> left % right); - case POWER: return asFunctionExpression(Math::pow); + case or: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0); + case and: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0); + case plus: return asFunctionExpression((left, right) -> left + right); + case minus: return asFunctionExpression((left, right) -> left - right); + case multiply: return asFunctionExpression((left, right) -> left * right); + case divide: return asFunctionExpression((left, right) -> left / right); + case modulo: return asFunctionExpression((left, right) -> left % right); + case power: return asFunctionExpression(Math::pow); } return Optional.empty(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java index 392f42f6cbe..d08e2270935 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java @@ -127,14 +127,6 @@ public final class OperationNode extends CompositeNode { @Override public int hashCode() { return Objects.hash(children, operators); } - @Override - public boolean equals(Object o) { - if ( ! (o instanceof OperationNode other)) return false; - if ( ! this.children().equals(other.children())) return false; - if ( ! this.operators().equals(other.operators())) return false; - return true; - } - public static OperationNode resolve(ExpressionNode left, Operator op, ExpressionNode right) { if ( ! (left instanceof OperationNode leftArithmetic)) return new OperationNode(left, op, right); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java index 4ddbfa4ea9f..63144f0ef4a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java @@ -15,21 +15,21 @@ import java.util.function.BiFunction; public enum Operator { // In order from lowest to highest precedence - OR("||", (x, y) -> x.or(y)), - AND("&&", (x, y) -> x.and(y)), - GREATEREQUAL(">=", (x, y) -> x.greaterEqual(y)), - GREATER(">", (x, y) -> x.greater(y)), - LESSEQUAL("<=", (x, y) -> x.lessEqual(y)), - LESS("<", (x, y) -> x.less(y)), - APPROX("~=", (x, y) -> x.approx(y)), - NOTEQUAL("!=", (x, y) -> x.notEqual(y)), - EQUAL("==", (x, y) -> x.equal(y)), - PLUS("+", (x, y) -> x.add(y)), - MINUS("-", (x, y) -> x.subtract(y)), - MULTIPLY("*", (x, y) -> x.multiply(y)), - DIVIDE("/", (x, y) -> x.divide(y)), - MODULO("%", (x, y) -> x.modulo(y)), - POWER("^", true, (x, y) -> x.power(y)); + or("||", (x, y) -> x.or(y)), + and("&&", (x, y) -> x.and(y)), + largerOrEqual(">=", (x, y) -> x.largerOrEqual(y)), + larger(">", (x, y) -> x.larger(y)), + smallerOrEqual("<=", (x, y) -> x.smallerOrEqual(y)), + smaller("<", (x, y) -> x.smaller(y)), + approxEqual("~=", (x, y) -> x.approxEqual(y)), + notEqual("!=", (x, y) -> x.notEqual(y)), + equal("==", (x, y) -> x.equal(y)), + plus("+", (x, y) -> x.add(y)), + minus("-", (x, y) -> x.subtract(y)), + multiply("*", (x, y) -> x.multiply(y)), + divide("/", (x, y) -> x.divide(y)), + modulo("%", (x, y) -> x.modulo(y)), + power("^", true, (x, y) -> x.power(y)); /** A list of all the operators in this in order of increasing precedence */ public static final List operatorsByPrecedence = Arrays.stream(Operator.values()).toList(); @@ -53,6 +53,9 @@ public enum Operator { return operatorsByPrecedence.indexOf(this) > operatorsByPrecedence.indexOf(op); } + /** Returns true if a sequence of these operations should be evaluated from right to left rather than left to right. */ + public boolean bindsRight() { return bindsRight; } + public final Value evaluate(Value x, Value y) { return function.apply(x, y); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index 04728966bc1..1c1f7509ce8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -116,7 +116,7 @@ public class Simplifier extends ExpressionTransformer { private boolean allMultiplicationOrDivision(OperationNode node) { for (Operator o : node.operators()) - if (o != Operator.MULTIPLY && o != Operator.DIVIDE) + if (o != Operator.multiply && o != Operator.divide) return false; return true; } @@ -131,7 +131,7 @@ public class Simplifier extends ExpressionTransformer { private boolean hasDivisionByZero(OperationNode node) { for (int i = 1; i < node.children().size(); i++) { - if (node.operators().get(i - 1) == Operator.DIVIDE && isZero(node.children().get(i))) + if (node.operators().get(i - 1) == Operator.divide && isZero(node.children().get(i))) return true; } return false; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 83de8e04a7d..f18240c3222 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -66,7 +66,7 @@ public class RankingExpressionTestCase { public void testProgrammaticBuilding() throws ParseException { ReferenceNode input = new ReferenceNode("input"); ReferenceNode constant = new ReferenceNode("constant"); - OperationNode product = new OperationNode(input, Operator.MULTIPLY, constant); + OperationNode product = new OperationNode(input, Operator.multiply, constant); Reduce sum = new Reduce<>(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum); RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum)); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 019f76521e9..dac7393a168 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -779,8 +779,8 @@ public class EvaluationTestCase { @Test public void testProgrammaticBuildingAndPrecedence() { - RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.PLUS, new OperationNode(constant(3), Operator.MULTIPLY, constant(4)))); - RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.PLUS, constant(3)), Operator.MULTIPLY, constant(4))); + RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.plus, new OperationNode(constant(3), Operator.multiply, constant(4)))); + RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.plus, constant(3)), Operator.multiply, constant(4))); assertEquals(14.0, standardPrecedence.evaluate(null).asDouble(), tolerance); assertEquals(20.0, oppositePrecedence.evaluate(null).asDouble(), tolerance); assertEquals("2.0 + 3.0 * 4.0", standardPrecedence.toString()); -- cgit v1.2.3