diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-09-28 19:00:08 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-09-28 19:00:08 +0200 |
commit | a1912b44d0b800f96b334a24ddefd0026f3af356 (patch) | |
tree | 352e2b4d026ae9373d73dc4fd7e9892c81943f7f | |
parent | bcbb2009c44380055b2670e7cdefcad232f9ece4 (diff) |
Use tensor vocabulary
22 files changed, 96 insertions, 101 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java index 9ffc73e1863..ad050d4ca63 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java @@ -66,9 +66,9 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor ChildNode lhs = stack.peek(); ExpressionNode combination; - if (rhs.op == Operator.AND) + if (rhs.op == Operator.and) combination = andByIfNode(lhs.child, rhs.child); - else if (rhs.op == Operator.OR) + else if (rhs.op == Operator.or) combination = orByIfNode(lhs.child, rhs.child); else { combination = resolve(lhs, rhs); diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java index dd714501a41..cf354a05a93 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java @@ -139,10 +139,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence); ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); ExpressionNode expr = new IfNode( - new OperationNode(new ReferenceNode("d1"), Operator.LESS, queryLengthExpr), + new OperationNode(new ReferenceNode("d1"), Operator.smaller, queryLengthExpr), ZERO, new IfNode( - new OperationNode(new ReferenceNode("d1"), Operator.LESS, restLengthExpr), + new OperationNode(new ReferenceNode("d1"), Operator.smaller, restLengthExpr), ONE, ZERO ) @@ -174,7 +174,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform List<ExpressionNode> tokenSequence = createTokenSequence(feature); ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); - OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.LESS, lengthExpr); + OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr); ExpressionNode expr = new IfNode(comparison, ONE, ZERO); return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); } @@ -254,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform */ private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) { ExpressionNode lengthExpr = createLengthExpr(iter, sequence); - OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.LESS, lengthExpr); + OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr); ExpressionNode trueExpr = sequence.get(iter); if (sequence.get(iter) instanceof ReferenceNode) { @@ -286,7 +286,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i)))); } if (i >= 1) { - operators.add(Operator.PLUS); + operators.add(Operator.plus); } } return new OperationNode(factors, operators); @@ -299,7 +299,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform ExpressionNode expr; if (iter >= 1) { ExpressionNode lengthExpr = new EmbracedNode(createLengthExpr(iter - 1, sequence)); - expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.MINUS, lengthExpr)); + expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.minus, lengthExpr)); } else { expr = new ReferenceNode("d1"); } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java index e0c99706e4a..3aa60e0b9a2 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java @@ -84,28 +84,28 @@ class LazyValue extends Value { } @Override - public Value greaterEqual(Value value) { - return computedValue().greaterEqual(value); + public Value largerOrEqual(Value value) { + return computedValue().largerOrEqual(value); } @Override - public Value greater(Value value) { - return computedValue().greater(value); + public Value larger(Value value) { + return computedValue().larger(value); } @Override - public Value lessEqual(Value value) { - return computedValue().lessEqual(value); + public Value smallerOrEqual(Value value) { + return computedValue().smallerOrEqual(value); } @Override - public Value less(Value value) { - return computedValue().less(value); + public Value smaller(Value value) { + return computedValue().smaller(value); } @Override - public Value approx(Value value) { - return computedValue().approx(value); + public Value approxEqual(Value value) { + return computedValue().approxEqual(value); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java index c66022975c7..7ac2f4bff84 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -92,7 +92,7 @@ public class Gather extends IntermediateOperation { ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantValue)); if (constantValue < 0) { ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); - indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.PLUS, axisSize)); + indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.plus, axisSize)); } addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); } else { @@ -125,8 +125,8 @@ public class Gather extends IntermediateOperation { /** to support negative indexing */ private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) { ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); - ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.PLUS, axisSize)); - ExpressionNode mod = new OperationNode(plus, Operator.MODULO, axisSize); + ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.plus, axisSize)); + ExpressionNode mod = new OperationNode(plus, Operator.modulo, axisSize); return mod; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 81d633dea4b..97bfdda385e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -99,7 +99,7 @@ public class Gemm extends IntermediateOperation { TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( new OperationNode( new TensorFunctionNode(AxB), - Operator.MULTIPLY, + Operator.multiply, new ConstantNode(new DoubleValue(alpha)))); if (inputs.size() == 3) { @@ -107,7 +107,7 @@ public class Gemm extends IntermediateOperation { TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction( new OperationNode( new TensorFunctionNode(cFunction.get()), - Operator.MULTIPLY, + Operator.multiply, new ConstantNode(new DoubleValue(beta)))); return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java index 66e810b954e..9a38ab9dfde 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java @@ -65,8 +65,8 @@ public class Range extends IntermediateOperation { ExpressionNode startExpr = new ConstantNode(new DoubleValue(start)); ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta)); ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName)); - ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.MULTIPLY, dimExpr); - ExpressionNode addExpr = new OperationNode(startExpr, Operator.PLUS, stepExpr); + ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.multiply, dimExpr); + ExpressionNode addExpr = new OperationNode(startExpr, Operator.plus, stepExpr); TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(addExpr)); return function; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index ce93461bff3..bc94fc6aa76 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -157,13 +157,13 @@ public class Reshape extends IntermediateOperation { inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); } else if (dim == (inputType.rank() - 1)) { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); - ExpressionNode div = new OperationNode(unrolled, Operator.MODULO, size); + ExpressionNode div = new OperationNode(unrolled, Operator.modulo, size); inputDimensionExpression = new EmbracedNode(div); } else { ExpressionNode size = new ConstantNode(new DoubleValue(innerSize)); ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize)); - ExpressionNode mod = new OperationNode(unrolled, Operator.MODULO, previousSize); - ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.DIVIDE, size); + ExpressionNode mod = new OperationNode(unrolled, Operator.modulo, previousSize); + ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.divide, size); inputDimensionExpression = new EmbracedNode(div); } dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression))); @@ -187,12 +187,12 @@ public class Reshape extends IntermediateOperation { TensorType.Dimension dimension = type.dimensions().get(i); children.add(0, new ReferenceNode(dimension.name())); if (size > 1) { - operators.add(0, Operator.MULTIPLY); + operators.add(0, Operator.multiply); children.add(0, new ConstantNode(new DoubleValue(size))); } size *= OrderedTensorType.dimensionSize(dimension); if (i > 0) { - operators.add(0, Operator.PLUS); + operators.add(0, Operator.plus); } } return new OperationNode(children, operators); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java index 617e1f00c94..916e9980131 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -166,8 +166,8 @@ public class Slice extends IntermediateOperation { // step * (d0 + start) ExpressionNode reference = new ReferenceNode(outputDimensionName); - ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.PLUS, startIndex)); - ExpressionNode mul = new OperationNode(stepSize, Operator.MULTIPLY, plus); + ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.plus, startIndex)); + ExpressionNode mul = new OperationNode(stepSize, Operator.multiply, plus); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul)))); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java index 42901259821..ef46d222941 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java @@ -96,7 +96,7 @@ public class Split extends IntermediateOperation { for (int i = 0; i < inputType.rank(); ++i) { String inputDimensionName = inputType.dimensions().get(i).name(); ExpressionNode reference = new ReferenceNode(inputDimensionName); - ExpressionNode offset = new OperationNode(reference, Operator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0))); + ExpressionNode offset = new OperationNode(reference, Operator.plus, new ConstantNode(new DoubleValue(i == axis ? start : 0))); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset)))); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java index cd88c625d81..a880bff87be 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -77,7 +77,7 @@ public class Tile extends IntermediateOperation { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); ExpressionNode reference = new ReferenceNode(inputDimensionName); - ExpressionNode mod = new OperationNode(reference, Operator.MODULO, size); + ExpressionNode mod = new OperationNode(reference, Operator.modulo, size); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod)))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index e1db6378fcf..186208e036f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -42,27 +42,27 @@ public abstract class DoubleCompatibleValue extends Value { } @Override - public Value greaterEqual(Value value) { + public Value largerOrEqual(Value value) { return new BooleanValue(this.asDouble() >= value.asDouble()); } @Override - public Value greater(Value value) { + public Value larger(Value value) { return new BooleanValue(this.asDouble() > value.asDouble()); } @Override - public Value lessEqual(Value value) { + public Value smallerOrEqual(Value value) { return new BooleanValue(this.asDouble() <= value.asDouble()); } @Override - public Value less(Value value) { + public Value smaller(Value value) { return new BooleanValue(this.asDouble() < value.asDouble()); } @Override - public Value approx(Value value) { + public Value approxEqual(Value value) { return new BooleanValue(approxEqual(this.asDouble(), value.asDouble())); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index 3c09c644147..a585c989954 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -71,27 +71,27 @@ public class StringValue extends Value { } @Override - public Value greaterEqual(Value argument) { + public Value largerOrEqual(Value argument) { throw new UnsupportedOperationException("String values ('" + value + "') do not support greaterEqual"); } @Override - public Value greater(Value argument) { + public Value larger(Value argument) { throw new UnsupportedOperationException("String values ('" + value + "') do not support greater"); } @Override - public Value lessEqual(Value argument) { + public Value smallerOrEqual(Value argument) { throw new UnsupportedOperationException("String values ('" + value + "') do not support lessEqual"); } @Override - public Value less(Value argument) { + public Value smaller(Value argument) { throw new UnsupportedOperationException("String values ('" + value + "') do not support less"); } @Override - public Value approx(Value argument) { + public Value approxEqual(Value argument) { return new BooleanValue(this.asDouble() == argument.asDouble()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 73ea0b23986..25e03c75376 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -72,7 +72,7 @@ public class TensorValue extends Value { } @Override - public Value greaterEqual(Value argument) { + public Value largerOrEqual(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.largerOrEqual(((TensorValue)argument).value)); else @@ -80,7 +80,7 @@ public class TensorValue extends Value { } @Override - public Value greater(Value argument) { + public Value larger(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.larger(((TensorValue)argument).value)); else @@ -88,7 +88,7 @@ public class TensorValue extends Value { } @Override - public Value lessEqual(Value argument) { + public Value smallerOrEqual(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value)); else @@ -96,7 +96,7 @@ public class TensorValue extends Value { } @Override - public Value less(Value argument) { + public Value smaller(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.smaller(((TensorValue)argument).value)); else @@ -104,7 +104,7 @@ public class TensorValue extends Value { } @Override - public Value approx(Value argument) { + public Value approxEqual(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.approxEqual(((TensorValue)argument).value)); else diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 99663fe8d0d..5de2138147e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -53,11 +53,11 @@ public abstract class Value { public abstract Value or(Value value); public abstract Value and(Value value); - public abstract Value greaterEqual(Value value); - public abstract Value greater(Value value); - public abstract Value lessEqual(Value value); - public abstract Value less(Value value); - public abstract Value approx(Value value); + public abstract Value largerOrEqual(Value value); + public abstract Value larger(Value value); + public abstract Value smallerOrEqual(Value value); + public abstract Value smaller(Value value); + public abstract Value approxEqual(Value value); public abstract Value notEqual(Value value); public abstract Value equal(Value value); public abstract Value add(Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java index 78acd2e5af1..a3fc6aae9ac 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java @@ -89,7 +89,7 @@ public class GBDTForestOptimizer extends Optimizer { } OperationNode aNode = (OperationNode)node; for (Operator op : aNode.operators()) { - if (op != Operator.PLUS) { + if (op != Operator.plus) { return false; } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java index b3cbe252dfc..7ba671e62eb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java @@ -116,9 +116,9 @@ public class GBDTOptimizer extends Optimizer { private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) { if (isBinaryComparison(condition)) { OperationNode comparison = (OperationNode)condition; - if (comparison.operators().get(0) == Operator.LESS) + if (comparison.operators().get(0) == Operator.smaller) values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context)); - else if (comparison.operators().get(0) == Operator.EQUAL) + else if (comparison.operators().get(0) == Operator.equal) values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context)); else throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0)); @@ -135,7 +135,7 @@ public class GBDTOptimizer extends Optimizer { if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) { if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) { OperationNode comparison = (OperationNode)embracedNode.children().get(0); - if (comparison.operators().get(0) == Operator.GREATEREQUAL) + if (comparison.operators().get(0) == Operator.largerOrEqual) values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context)); else throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0)); @@ -153,13 +153,13 @@ public class GBDTOptimizer extends Optimizer { private boolean isBinaryComparison(ExpressionNode condition) { if ( ! (condition instanceof OperationNode binaryNode)) return false; if (binaryNode.operators().size() != 1) return false; - if (binaryNode.operators().get(0) == Operator.GREATEREQUAL) return true; - if (binaryNode.operators().get(0) == Operator.GREATER) return true; - if (binaryNode.operators().get(0) == Operator.LESSEQUAL) return true; - if (binaryNode.operators().get(0) == Operator.LESS) return true; - if (binaryNode.operators().get(0) == Operator.APPROX) return true; - if (binaryNode.operators().get(0) == Operator.NOTEQUAL) return true; - if (binaryNode.operators().get(0) == Operator.EQUAL) return true; + if (binaryNode.operators().get(0) == Operator.largerOrEqual) return true; + if (binaryNode.operators().get(0) == Operator.larger) return true; + if (binaryNode.operators().get(0) == Operator.smallerOrEqual) return true; + if (binaryNode.operators().get(0) == Operator.smaller) return true; + if (binaryNode.operators().get(0) == Operator.approxEqual) return true; + if (binaryNode.operators().get(0) == Operator.notEqual) return true; + if (binaryNode.operators().get(0) == Operator.equal) return true; return false; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 3c7f48aa38c..97e9a74f9c8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -123,14 +123,14 @@ public class LambdaFunctionNode extends CompositeNode { } Operator operator = node.operators().get(0); switch (operator) { - case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0); - case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0); - case PLUS: return asFunctionExpression((left, right) -> left + right); - case MINUS: return asFunctionExpression((left, right) -> left - right); - case MULTIPLY: return asFunctionExpression((left, right) -> left * right); - case DIVIDE: return asFunctionExpression((left, right) -> left / right); - case MODULO: return asFunctionExpression((left, right) -> left % right); - case POWER: return asFunctionExpression(Math::pow); + case or: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0); + case and: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0); + case plus: return asFunctionExpression((left, right) -> left + right); + case minus: return asFunctionExpression((left, right) -> left - right); + case multiply: return asFunctionExpression((left, right) -> left * right); + case divide: return asFunctionExpression((left, right) -> left / right); + case modulo: return asFunctionExpression((left, right) -> left % right); + case power: return asFunctionExpression(Math::pow); } return Optional.empty(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java index 392f42f6cbe..d08e2270935 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java @@ -127,14 +127,6 @@ public final class OperationNode extends CompositeNode { @Override public int hashCode() { return Objects.hash(children, operators); } - @Override - public boolean equals(Object o) { - if ( ! (o instanceof OperationNode other)) return false; - if ( ! this.children().equals(other.children())) return false; - if ( ! this.operators().equals(other.operators())) return false; - return true; - } - public static OperationNode resolve(ExpressionNode left, Operator op, ExpressionNode right) { if ( ! (left instanceof OperationNode leftArithmetic)) return new OperationNode(left, op, right); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java index 4ddbfa4ea9f..63144f0ef4a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java @@ -15,21 +15,21 @@ import java.util.function.BiFunction; public enum Operator { // In order from lowest to highest precedence - OR("||", (x, y) -> x.or(y)), - AND("&&", (x, y) -> x.and(y)), - GREATEREQUAL(">=", (x, y) -> x.greaterEqual(y)), - GREATER(">", (x, y) -> x.greater(y)), - LESSEQUAL("<=", (x, y) -> x.lessEqual(y)), - LESS("<", (x, y) -> x.less(y)), - APPROX("~=", (x, y) -> x.approx(y)), - NOTEQUAL("!=", (x, y) -> x.notEqual(y)), - EQUAL("==", (x, y) -> x.equal(y)), - PLUS("+", (x, y) -> x.add(y)), - MINUS("-", (x, y) -> x.subtract(y)), - MULTIPLY("*", (x, y) -> x.multiply(y)), - DIVIDE("/", (x, y) -> x.divide(y)), - MODULO("%", (x, y) -> x.modulo(y)), - POWER("^", true, (x, y) -> x.power(y)); + or("||", (x, y) -> x.or(y)), + and("&&", (x, y) -> x.and(y)), + largerOrEqual(">=", (x, y) -> x.largerOrEqual(y)), + larger(">", (x, y) -> x.larger(y)), + smallerOrEqual("<=", (x, y) -> x.smallerOrEqual(y)), + smaller("<", (x, y) -> x.smaller(y)), + approxEqual("~=", (x, y) -> x.approxEqual(y)), + notEqual("!=", (x, y) -> x.notEqual(y)), + equal("==", (x, y) -> x.equal(y)), + plus("+", (x, y) -> x.add(y)), + minus("-", (x, y) -> x.subtract(y)), + multiply("*", (x, y) -> x.multiply(y)), + divide("/", (x, y) -> x.divide(y)), + modulo("%", (x, y) -> x.modulo(y)), + power("^", true, (x, y) -> x.power(y)); /** A list of all the operators in this in order of increasing precedence */ public static final List<Operator> operatorsByPrecedence = Arrays.stream(Operator.values()).toList(); @@ -53,6 +53,9 @@ public enum Operator { return operatorsByPrecedence.indexOf(this) > operatorsByPrecedence.indexOf(op); } + /** Returns true if a sequence of these operations should be evaluated from right to left rather than left to right. */ + public boolean bindsRight() { return bindsRight; } + public final Value evaluate(Value x, Value y) { return function.apply(x, y); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index 04728966bc1..1c1f7509ce8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -116,7 +116,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { private boolean allMultiplicationOrDivision(OperationNode node) { for (Operator o : node.operators()) - if (o != Operator.MULTIPLY && o != Operator.DIVIDE) + if (o != Operator.multiply && o != Operator.divide) return false; return true; } @@ -131,7 +131,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { private boolean hasDivisionByZero(OperationNode node) { for (int i = 1; i < node.children().size(); i++) { - if (node.operators().get(i - 1) == Operator.DIVIDE && isZero(node.children().get(i))) + if (node.operators().get(i - 1) == Operator.divide && isZero(node.children().get(i))) return true; } return false; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 83de8e04a7d..f18240c3222 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -66,7 +66,7 @@ public class RankingExpressionTestCase { public void testProgrammaticBuilding() throws ParseException { ReferenceNode input = new ReferenceNode("input"); ReferenceNode constant = new ReferenceNode("constant"); - OperationNode product = new OperationNode(input, Operator.MULTIPLY, constant); + OperationNode product = new OperationNode(input, Operator.multiply, constant); Reduce<Reference> sum = new Reduce<>(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum); RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum)); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 019f76521e9..dac7393a168 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -779,8 +779,8 @@ public class EvaluationTestCase { @Test public void testProgrammaticBuildingAndPrecedence() { - RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.PLUS, new OperationNode(constant(3), Operator.MULTIPLY, constant(4)))); - RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.PLUS, constant(3)), Operator.MULTIPLY, constant(4))); + RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.plus, new OperationNode(constant(3), Operator.multiply, constant(4)))); + RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.plus, constant(3)), Operator.multiply, constant(4))); assertEquals(14.0, standardPrecedence.evaluate(null).asDouble(), tolerance); assertEquals(20.0, oppositePrecedence.evaluate(null).asDouble(), tolerance); assertEquals("2.0 + 3.0 * 4.0", standardPrecedence.toString()); |