aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-09-28 19:00:08 +0200
committerJon Bratseth <bratseth@gmail.com>2022-09-28 19:00:08 +0200
commita1912b44d0b800f96b334a24ddefd0026f3af356 (patch)
tree352e2b4d026ae9373d73dc4fd7e9892c81943f7f
parentbcbb2009c44380055b2670e7cdefcad232f9ece4 (diff)
Use tensor vocabulary
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java4
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java12
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java16
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java4
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java4
22 files changed, 96 insertions, 101 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
index 9ffc73e1863..ad050d4ca63 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
@@ -66,9 +66,9 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
ChildNode lhs = stack.peek();
ExpressionNode combination;
- if (rhs.op == Operator.AND)
+ if (rhs.op == Operator.and)
combination = andByIfNode(lhs.child, rhs.child);
- else if (rhs.op == Operator.OR)
+ else if (rhs.op == Operator.or)
combination = orByIfNode(lhs.child, rhs.child);
else {
combination = resolve(lhs, rhs);
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
index dd714501a41..cf354a05a93 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
@@ -139,10 +139,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
ExpressionNode expr = new IfNode(
- new OperationNode(new ReferenceNode("d1"), Operator.LESS, queryLengthExpr),
+ new OperationNode(new ReferenceNode("d1"), Operator.smaller, queryLengthExpr),
ZERO,
new IfNode(
- new OperationNode(new ReferenceNode("d1"), Operator.LESS, restLengthExpr),
+ new OperationNode(new ReferenceNode("d1"), Operator.smaller, restLengthExpr),
ONE,
ZERO
)
@@ -174,7 +174,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
List<ExpressionNode> tokenSequence = createTokenSequence(feature);
ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
- OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.LESS, lengthExpr);
+ OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr);
ExpressionNode expr = new IfNode(comparison, ONE, ZERO);
return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
}
@@ -254,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
*/
private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) {
ExpressionNode lengthExpr = createLengthExpr(iter, sequence);
- OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.LESS, lengthExpr);
+ OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr);
ExpressionNode trueExpr = sequence.get(iter);
if (sequence.get(iter) instanceof ReferenceNode) {
@@ -286,7 +286,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i))));
}
if (i >= 1) {
- operators.add(Operator.PLUS);
+ operators.add(Operator.plus);
}
}
return new OperationNode(factors, operators);
@@ -299,7 +299,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
ExpressionNode expr;
if (iter >= 1) {
ExpressionNode lengthExpr = new EmbracedNode(createLengthExpr(iter - 1, sequence));
- expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.MINUS, lengthExpr));
+ expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.minus, lengthExpr));
} else {
expr = new ReferenceNode("d1");
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
index e0c99706e4a..3aa60e0b9a2 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
@@ -84,28 +84,28 @@ class LazyValue extends Value {
}
@Override
- public Value greaterEqual(Value value) {
- return computedValue().greaterEqual(value);
+ public Value largerOrEqual(Value value) {
+ return computedValue().largerOrEqual(value);
}
@Override
- public Value greater(Value value) {
- return computedValue().greater(value);
+ public Value larger(Value value) {
+ return computedValue().larger(value);
}
@Override
- public Value lessEqual(Value value) {
- return computedValue().lessEqual(value);
+ public Value smallerOrEqual(Value value) {
+ return computedValue().smallerOrEqual(value);
}
@Override
- public Value less(Value value) {
- return computedValue().less(value);
+ public Value smaller(Value value) {
+ return computedValue().smaller(value);
}
@Override
- public Value approx(Value value) {
- return computedValue().approx(value);
+ public Value approxEqual(Value value) {
+ return computedValue().approxEqual(value);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
index c66022975c7..7ac2f4bff84 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
@@ -92,7 +92,7 @@ public class Gather extends IntermediateOperation {
ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantValue));
if (constantValue < 0) {
ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get()));
- indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.PLUS, axisSize));
+ indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.plus, axisSize));
}
addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression);
} else {
@@ -125,8 +125,8 @@ public class Gather extends IntermediateOperation {
/** to support negative indexing */
private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) {
ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get()));
- ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.PLUS, axisSize));
- ExpressionNode mod = new OperationNode(plus, Operator.MODULO, axisSize);
+ ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.plus, axisSize));
+ ExpressionNode mod = new OperationNode(plus, Operator.modulo, axisSize);
return mod;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index 81d633dea4b..97bfdda385e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -99,7 +99,7 @@ public class Gemm extends IntermediateOperation {
TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
new OperationNode(
new TensorFunctionNode(AxB),
- Operator.MULTIPLY,
+ Operator.multiply,
new ConstantNode(new DoubleValue(alpha))));
if (inputs.size() == 3) {
@@ -107,7 +107,7 @@ public class Gemm extends IntermediateOperation {
TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction(
new OperationNode(
new TensorFunctionNode(cFunction.get()),
- Operator.MULTIPLY,
+ Operator.multiply,
new ConstantNode(new DoubleValue(beta))));
return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
index 66e810b954e..9a38ab9dfde 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
@@ -65,8 +65,8 @@ public class Range extends IntermediateOperation {
ExpressionNode startExpr = new ConstantNode(new DoubleValue(start));
ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta));
ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName));
- ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.MULTIPLY, dimExpr);
- ExpressionNode addExpr = new OperationNode(startExpr, Operator.PLUS, stepExpr);
+ ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.multiply, dimExpr);
+ ExpressionNode addExpr = new OperationNode(startExpr, Operator.plus, stepExpr);
TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(addExpr));
return function;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index ce93461bff3..bc94fc6aa76 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -157,13 +157,13 @@ public class Reshape extends IntermediateOperation {
inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
} else if (dim == (inputType.rank() - 1)) {
ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
- ExpressionNode div = new OperationNode(unrolled, Operator.MODULO, size);
+ ExpressionNode div = new OperationNode(unrolled, Operator.modulo, size);
inputDimensionExpression = new EmbracedNode(div);
} else {
ExpressionNode size = new ConstantNode(new DoubleValue(innerSize));
ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize));
- ExpressionNode mod = new OperationNode(unrolled, Operator.MODULO, previousSize);
- ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.DIVIDE, size);
+ ExpressionNode mod = new OperationNode(unrolled, Operator.modulo, previousSize);
+ ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.divide, size);
inputDimensionExpression = new EmbracedNode(div);
}
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression)));
@@ -187,12 +187,12 @@ public class Reshape extends IntermediateOperation {
TensorType.Dimension dimension = type.dimensions().get(i);
children.add(0, new ReferenceNode(dimension.name()));
if (size > 1) {
- operators.add(0, Operator.MULTIPLY);
+ operators.add(0, Operator.multiply);
children.add(0, new ConstantNode(new DoubleValue(size)));
}
size *= OrderedTensorType.dimensionSize(dimension);
if (i > 0) {
- operators.add(0, Operator.PLUS);
+ operators.add(0, Operator.plus);
}
}
return new OperationNode(children, operators);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
index 617e1f00c94..916e9980131 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
@@ -166,8 +166,8 @@ public class Slice extends IntermediateOperation {
// step * (d0 + start)
ExpressionNode reference = new ReferenceNode(outputDimensionName);
- ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.PLUS, startIndex));
- ExpressionNode mul = new OperationNode(stepSize, Operator.MULTIPLY, plus);
+ ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.plus, startIndex));
+ ExpressionNode mul = new OperationNode(stepSize, Operator.multiply, plus);
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul))));
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
index 42901259821..ef46d222941 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
@@ -96,7 +96,7 @@ public class Split extends IntermediateOperation {
for (int i = 0; i < inputType.rank(); ++i) {
String inputDimensionName = inputType.dimensions().get(i).name();
ExpressionNode reference = new ReferenceNode(inputDimensionName);
- ExpressionNode offset = new OperationNode(reference, Operator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
+ ExpressionNode offset = new OperationNode(reference, Operator.plus, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset))));
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
index cd88c625d81..a880bff87be 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
@@ -77,7 +77,7 @@ public class Tile extends IntermediateOperation {
ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
ExpressionNode reference = new ReferenceNode(inputDimensionName);
- ExpressionNode mod = new OperationNode(reference, Operator.MODULO, size);
+ ExpressionNode mod = new OperationNode(reference, Operator.modulo, size);
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod))));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index e1db6378fcf..186208e036f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -42,27 +42,27 @@ public abstract class DoubleCompatibleValue extends Value {
}
@Override
- public Value greaterEqual(Value value) {
+ public Value largerOrEqual(Value value) {
return new BooleanValue(this.asDouble() >= value.asDouble());
}
@Override
- public Value greater(Value value) {
+ public Value larger(Value value) {
return new BooleanValue(this.asDouble() > value.asDouble());
}
@Override
- public Value lessEqual(Value value) {
+ public Value smallerOrEqual(Value value) {
return new BooleanValue(this.asDouble() <= value.asDouble());
}
@Override
- public Value less(Value value) {
+ public Value smaller(Value value) {
return new BooleanValue(this.asDouble() < value.asDouble());
}
@Override
- public Value approx(Value value) {
+ public Value approxEqual(Value value) {
return new BooleanValue(approxEqual(this.asDouble(), value.asDouble()));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index 3c09c644147..a585c989954 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -71,27 +71,27 @@ public class StringValue extends Value {
}
@Override
- public Value greaterEqual(Value argument) {
+ public Value largerOrEqual(Value argument) {
throw new UnsupportedOperationException("String values ('" + value + "') do not support greaterEqual");
}
@Override
- public Value greater(Value argument) {
+ public Value larger(Value argument) {
throw new UnsupportedOperationException("String values ('" + value + "') do not support greater");
}
@Override
- public Value lessEqual(Value argument) {
+ public Value smallerOrEqual(Value argument) {
throw new UnsupportedOperationException("String values ('" + value + "') do not support lessEqual");
}
@Override
- public Value less(Value argument) {
+ public Value smaller(Value argument) {
throw new UnsupportedOperationException("String values ('" + value + "') do not support less");
}
@Override
- public Value approx(Value argument) {
+ public Value approxEqual(Value argument) {
return new BooleanValue(this.asDouble() == argument.asDouble());
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 73ea0b23986..25e03c75376 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -72,7 +72,7 @@ public class TensorValue extends Value {
}
@Override
- public Value greaterEqual(Value argument) {
+ public Value largerOrEqual(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.largerOrEqual(((TensorValue)argument).value));
else
@@ -80,7 +80,7 @@ public class TensorValue extends Value {
}
@Override
- public Value greater(Value argument) {
+ public Value larger(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.larger(((TensorValue)argument).value));
else
@@ -88,7 +88,7 @@ public class TensorValue extends Value {
}
@Override
- public Value lessEqual(Value argument) {
+ public Value smallerOrEqual(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value));
else
@@ -96,7 +96,7 @@ public class TensorValue extends Value {
}
@Override
- public Value less(Value argument) {
+ public Value smaller(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.smaller(((TensorValue)argument).value));
else
@@ -104,7 +104,7 @@ public class TensorValue extends Value {
}
@Override
- public Value approx(Value argument) {
+ public Value approxEqual(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.approxEqual(((TensorValue)argument).value));
else
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 99663fe8d0d..5de2138147e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -53,11 +53,11 @@ public abstract class Value {
public abstract Value or(Value value);
public abstract Value and(Value value);
- public abstract Value greaterEqual(Value value);
- public abstract Value greater(Value value);
- public abstract Value lessEqual(Value value);
- public abstract Value less(Value value);
- public abstract Value approx(Value value);
+ public abstract Value largerOrEqual(Value value);
+ public abstract Value larger(Value value);
+ public abstract Value smallerOrEqual(Value value);
+ public abstract Value smaller(Value value);
+ public abstract Value approxEqual(Value value);
public abstract Value notEqual(Value value);
public abstract Value equal(Value value);
public abstract Value add(Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
index 78acd2e5af1..a3fc6aae9ac 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
@@ -89,7 +89,7 @@ public class GBDTForestOptimizer extends Optimizer {
}
OperationNode aNode = (OperationNode)node;
for (Operator op : aNode.operators()) {
- if (op != Operator.PLUS) {
+ if (op != Operator.plus) {
return false;
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
index b3cbe252dfc..7ba671e62eb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
@@ -116,9 +116,9 @@ public class GBDTOptimizer extends Optimizer {
private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) {
if (isBinaryComparison(condition)) {
OperationNode comparison = (OperationNode)condition;
- if (comparison.operators().get(0) == Operator.LESS)
+ if (comparison.operators().get(0) == Operator.smaller)
values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context));
- else if (comparison.operators().get(0) == Operator.EQUAL)
+ else if (comparison.operators().get(0) == Operator.equal)
values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context));
else
throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0));
@@ -135,7 +135,7 @@ public class GBDTOptimizer extends Optimizer {
if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) {
if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) {
OperationNode comparison = (OperationNode)embracedNode.children().get(0);
- if (comparison.operators().get(0) == Operator.GREATEREQUAL)
+ if (comparison.operators().get(0) == Operator.largerOrEqual)
values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context));
else
throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0));
@@ -153,13 +153,13 @@ public class GBDTOptimizer extends Optimizer {
private boolean isBinaryComparison(ExpressionNode condition) {
if ( ! (condition instanceof OperationNode binaryNode)) return false;
if (binaryNode.operators().size() != 1) return false;
- if (binaryNode.operators().get(0) == Operator.GREATEREQUAL) return true;
- if (binaryNode.operators().get(0) == Operator.GREATER) return true;
- if (binaryNode.operators().get(0) == Operator.LESSEQUAL) return true;
- if (binaryNode.operators().get(0) == Operator.LESS) return true;
- if (binaryNode.operators().get(0) == Operator.APPROX) return true;
- if (binaryNode.operators().get(0) == Operator.NOTEQUAL) return true;
- if (binaryNode.operators().get(0) == Operator.EQUAL) return true;
+ if (binaryNode.operators().get(0) == Operator.largerOrEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.larger) return true;
+ if (binaryNode.operators().get(0) == Operator.smallerOrEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.smaller) return true;
+ if (binaryNode.operators().get(0) == Operator.approxEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.notEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.equal) return true;
return false;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 3c7f48aa38c..97e9a74f9c8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -123,14 +123,14 @@ public class LambdaFunctionNode extends CompositeNode {
}
Operator operator = node.operators().get(0);
switch (operator) {
- case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
- case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
- case PLUS: return asFunctionExpression((left, right) -> left + right);
- case MINUS: return asFunctionExpression((left, right) -> left - right);
- case MULTIPLY: return asFunctionExpression((left, right) -> left * right);
- case DIVIDE: return asFunctionExpression((left, right) -> left / right);
- case MODULO: return asFunctionExpression((left, right) -> left % right);
- case POWER: return asFunctionExpression(Math::pow);
+ case or: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
+ case and: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
+ case plus: return asFunctionExpression((left, right) -> left + right);
+ case minus: return asFunctionExpression((left, right) -> left - right);
+ case multiply: return asFunctionExpression((left, right) -> left * right);
+ case divide: return asFunctionExpression((left, right) -> left / right);
+ case modulo: return asFunctionExpression((left, right) -> left % right);
+ case power: return asFunctionExpression(Math::pow);
}
return Optional.empty();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
index 392f42f6cbe..d08e2270935 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
@@ -127,14 +127,6 @@ public final class OperationNode extends CompositeNode {
@Override
public int hashCode() { return Objects.hash(children, operators); }
- @Override
- public boolean equals(Object o) {
- if ( ! (o instanceof OperationNode other)) return false;
- if ( ! this.children().equals(other.children())) return false;
- if ( ! this.operators().equals(other.operators())) return false;
- return true;
- }
-
public static OperationNode resolve(ExpressionNode left, Operator op, ExpressionNode right) {
if ( ! (left instanceof OperationNode leftArithmetic)) return new OperationNode(left, op, right);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
index 4ddbfa4ea9f..63144f0ef4a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
@@ -15,21 +15,21 @@ import java.util.function.BiFunction;
public enum Operator {
// In order from lowest to highest precedence
- OR("||", (x, y) -> x.or(y)),
- AND("&&", (x, y) -> x.and(y)),
- GREATEREQUAL(">=", (x, y) -> x.greaterEqual(y)),
- GREATER(">", (x, y) -> x.greater(y)),
- LESSEQUAL("<=", (x, y) -> x.lessEqual(y)),
- LESS("<", (x, y) -> x.less(y)),
- APPROX("~=", (x, y) -> x.approx(y)),
- NOTEQUAL("!=", (x, y) -> x.notEqual(y)),
- EQUAL("==", (x, y) -> x.equal(y)),
- PLUS("+", (x, y) -> x.add(y)),
- MINUS("-", (x, y) -> x.subtract(y)),
- MULTIPLY("*", (x, y) -> x.multiply(y)),
- DIVIDE("/", (x, y) -> x.divide(y)),
- MODULO("%", (x, y) -> x.modulo(y)),
- POWER("^", true, (x, y) -> x.power(y));
+ or("||", (x, y) -> x.or(y)),
+ and("&&", (x, y) -> x.and(y)),
+ largerOrEqual(">=", (x, y) -> x.largerOrEqual(y)),
+ larger(">", (x, y) -> x.larger(y)),
+ smallerOrEqual("<=", (x, y) -> x.smallerOrEqual(y)),
+ smaller("<", (x, y) -> x.smaller(y)),
+ approxEqual("~=", (x, y) -> x.approxEqual(y)),
+ notEqual("!=", (x, y) -> x.notEqual(y)),
+ equal("==", (x, y) -> x.equal(y)),
+ plus("+", (x, y) -> x.add(y)),
+ minus("-", (x, y) -> x.subtract(y)),
+ multiply("*", (x, y) -> x.multiply(y)),
+ divide("/", (x, y) -> x.divide(y)),
+ modulo("%", (x, y) -> x.modulo(y)),
+ power("^", true, (x, y) -> x.power(y));
/** A list of all the operators in this in order of increasing precedence */
public static final List<Operator> operatorsByPrecedence = Arrays.stream(Operator.values()).toList();
@@ -53,6 +53,9 @@ public enum Operator {
return operatorsByPrecedence.indexOf(this) > operatorsByPrecedence.indexOf(op);
}
+ /** Returns true if a sequence of these operations should be evaluated from right to left rather than left to right. */
+ public boolean bindsRight() { return bindsRight; }
+
public final Value evaluate(Value x, Value y) {
return function.apply(x, y);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index 04728966bc1..1c1f7509ce8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -116,7 +116,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean allMultiplicationOrDivision(OperationNode node) {
for (Operator o : node.operators())
- if (o != Operator.MULTIPLY && o != Operator.DIVIDE)
+ if (o != Operator.multiply && o != Operator.divide)
return false;
return true;
}
@@ -131,7 +131,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean hasDivisionByZero(OperationNode node) {
for (int i = 1; i < node.children().size(); i++) {
- if (node.operators().get(i - 1) == Operator.DIVIDE && isZero(node.children().get(i)))
+ if (node.operators().get(i - 1) == Operator.divide && isZero(node.children().get(i)))
return true;
}
return false;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index 83de8e04a7d..f18240c3222 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -66,7 +66,7 @@ public class RankingExpressionTestCase {
public void testProgrammaticBuilding() throws ParseException {
ReferenceNode input = new ReferenceNode("input");
ReferenceNode constant = new ReferenceNode("constant");
- OperationNode product = new OperationNode(input, Operator.MULTIPLY, constant);
+ OperationNode product = new OperationNode(input, Operator.multiply, constant);
Reduce<Reference> sum = new Reduce<>(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum);
RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 019f76521e9..dac7393a168 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -779,8 +779,8 @@ public class EvaluationTestCase {
@Test
public void testProgrammaticBuildingAndPrecedence() {
- RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.PLUS, new OperationNode(constant(3), Operator.MULTIPLY, constant(4))));
- RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.PLUS, constant(3)), Operator.MULTIPLY, constant(4)));
+ RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.plus, new OperationNode(constant(3), Operator.multiply, constant(4))));
+ RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.plus, constant(3)), Operator.multiply, constant(4)));
assertEquals(14.0, standardPrecedence.evaluate(null).asDouble(), tolerance);
assertEquals(20.0, oppositePrecedence.evaluate(null).asDouble(), tolerance);
assertEquals("2.0 + 3.0 * 4.0", standardPrecedence.toString());