aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-09-28 22:54:13 +0200
committerGitHub <noreply@github.com>2022-09-28 22:54:13 +0200
commit12992ecdc0e77968eb5c5544f2ae7d855e443162 (patch)
treeac8cec3ae02f27ae638876940399f490b4ac4ab1
parentd50f7bd9c99ed9d8edeabb71825f3966f9cd6bd9 (diff)
parentfb0074925e9e8358d38145dc5753de1c935f737d (diff)
Merge pull request #24251 from vespa-engine/bratseth/operatorsv8.61.17
Bratseth/operators
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java40
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java22
-rw-r--r--config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java29
-rw-r--r--config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java6
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java71
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java6
-rw-r--r--searchlib/abi-spec.json226
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java131
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java73
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java139
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java27
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java61
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java63
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java69
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java22
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java)38
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java67
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java50
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java40
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj67
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java18
30 files changed, 711 insertions, 638 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
index 8fa4b469590..ad050d4ca63 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
@@ -2,8 +2,8 @@
package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -35,20 +35,20 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
if (node instanceof CompositeNode composite)
node = transformChildren(composite, context);
- if (node instanceof ArithmeticNode arithmetic)
+ if (node instanceof OperationNode arithmetic)
node = transformBooleanArithmetics(arithmetic);
return node;
}
- private ExpressionNode transformBooleanArithmetics(ArithmeticNode node) {
+ private ExpressionNode transformBooleanArithmetics(OperationNode node) {
Iterator<ExpressionNode> child = node.children().iterator();
// Transform in precedence order:
Deque<ChildNode> stack = new ArrayDeque<>();
stack.push(new ChildNode(null, child.next()));
- for (Iterator<ArithmeticOperator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) {
- ArithmeticOperator op = it.next();
+ for (Iterator<Operator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) {
+ Operator op = it.next();
if ( ! stack.isEmpty()) {
while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) {
popStack(stack);
@@ -66,9 +66,9 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
ChildNode lhs = stack.peek();
ExpressionNode combination;
- if (rhs.op == ArithmeticOperator.AND)
+ if (rhs.op == Operator.and)
combination = andByIfNode(lhs.child, rhs.child);
- else if (rhs.op == ArithmeticOperator.OR)
+ else if (rhs.op == Operator.or)
combination = orByIfNode(lhs.child, rhs.child);
else {
combination = resolve(lhs, rhs);
@@ -77,28 +77,28 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
lhs.child = combination;
}
- private static ArithmeticNode resolve(ChildNode left, ChildNode right) {
- if ( ! (left.child instanceof ArithmeticNode) && ! (right.child instanceof ArithmeticNode))
- return new ArithmeticNode(left.child, right.op, right.child);
+ private static OperationNode resolve(ChildNode left, ChildNode right) {
+ if (! (left.child instanceof OperationNode) && ! (right.child instanceof OperationNode))
+ return new OperationNode(left.child, right.op, right.child);
// Collapse inserted ArithmeticNodes
- List<ArithmeticOperator> joinedOps = new ArrayList<>();
+ List<Operator> joinedOps = new ArrayList<>();
joinOps(left, joinedOps);
joinedOps.add(right.op);
joinOps(right, joinedOps);
List<ExpressionNode> joinedChildren = new ArrayList<>();
joinChildren(left, joinedChildren);
joinChildren(right, joinedChildren);
- return new ArithmeticNode(joinedChildren, joinedOps);
+ return new OperationNode(joinedChildren, joinedOps);
}
- private static void joinOps(ChildNode node, List<ArithmeticOperator> joinedOps) {
- if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode)
- joinedOps.addAll(arithmeticNode.operators());
+ private static void joinOps(ChildNode node, List<Operator> joinedOps) {
+ if (node.artificial && node.child instanceof OperationNode operationNode)
+ joinedOps.addAll(operationNode.operators());
}
private static void joinChildren(ChildNode node, List<ExpressionNode> joinedChildren) {
- if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode)
- joinedChildren.addAll(arithmeticNode.children());
+ if (node.artificial && node.child instanceof OperationNode operationNode)
+ joinedChildren.addAll(operationNode.children());
else
joinedChildren.add(node.child);
}
@@ -115,11 +115,11 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
/** A child with the operator to be applied to it when combining it with the previous child. */
private static class ChildNode {
- final ArithmeticOperator op;
+ final Operator op;
ExpressionNode child;
boolean artificial;
- public ChildNode(ArithmeticOperator op, ExpressionNode child) {
+ public ChildNode(Operator op, ExpressionNode child) {
this.op = op;
this.child = child;
}
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
index 30d9a3766b3..cf354a05a93 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
@@ -3,9 +3,8 @@ package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
-import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
@@ -13,7 +12,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
@@ -141,10 +139,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
ExpressionNode expr = new IfNode(
- new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr),
+ new OperationNode(new ReferenceNode("d1"), Operator.smaller, queryLengthExpr),
ZERO,
new IfNode(
- new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr),
+ new OperationNode(new ReferenceNode("d1"), Operator.smaller, restLengthExpr),
ONE,
ZERO
)
@@ -176,7 +174,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
List<ExpressionNode> tokenSequence = createTokenSequence(feature);
ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
- ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
+ OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr);
ExpressionNode expr = new IfNode(comparison, ONE, ZERO);
return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
}
@@ -256,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
*/
private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) {
ExpressionNode lengthExpr = createLengthExpr(iter, sequence);
- ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
+ OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr);
ExpressionNode trueExpr = sequence.get(iter);
if (sequence.get(iter) instanceof ReferenceNode) {
@@ -280,7 +278,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
*/
private ExpressionNode createLengthExpr(int iter, List<ExpressionNode> sequence) {
List<ExpressionNode> factors = new ArrayList<>();
- List<ArithmeticOperator> operators = new ArrayList<>();
+ List<Operator> operators = new ArrayList<>();
for (int i = 0; i < iter + 1; ++i) {
if (sequence.get(i) instanceof ConstantNode) {
factors.add(ONE);
@@ -288,10 +286,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i))));
}
if (i >= 1) {
- operators.add(ArithmeticOperator.PLUS);
+ operators.add(Operator.plus);
}
}
- return new ArithmeticNode(factors, operators);
+ return new OperationNode(factors, operators);
}
/**
@@ -301,7 +299,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
ExpressionNode expr;
if (iter >= 1) {
ExpressionNode lengthExpr = new EmbracedNode(createLengthExpr(iter - 1, sequence));
- expr = new EmbracedNode(new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.MINUS, lengthExpr));
+ expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.minus, lengthExpr));
} else {
expr = new ReferenceNode("d1");
}
diff --git a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
index 5eecee516ec..13d21884c7d 100644
--- a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
@@ -68,7 +68,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase {
Schema s = builder.getSchema();
RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels());
- assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))",
+ assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + (if (7.0 < attribute(a), 1, 2) == 0)))",
parent.getFirstPhaseRanking().getRoot().toString());
RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels());
assertEquals("7.0 * (9 + attribute(a))",
@@ -76,6 +76,33 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase {
}
@Test
+ void testInlinedComparison() throws ParseException {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry);
+ builder.addSchema("search test {\n" +
+ " document test { \n" +
+ " }\n" +
+ " \n" +
+ " rank-profile parent {\n" +
+ "function foo() {\n" +
+ " expression: 3 * bar\n" +
+ "}\n" +
+ "\n" +
+ "function inline bar() {\n" +
+ " expression: query(test) > 2.0\n" +
+ "}\n" +
+ "}\n" +
+ "}\n");
+ builder.build(true);
+ Schema s = builder.getSchema();
+
+ RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels());
+ assertEquals("3 * (query(test) > 2.0)",
+ parent.getFunctions().get("foo").function().getBody().getRoot().toString());
+
+ }
+
+ @Test
void testConstants() throws ParseException {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry);
diff --git a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
index 65d12f1cdcf..d692b69d3c8 100644
--- a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
@@ -4,7 +4,7 @@ package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import org.junit.jupiter.api.Test;
@@ -43,8 +43,8 @@ public class BooleanExpressionTransformerTestCase {
var expr = new BooleanExpressionTransformer()
.transform(new RankingExpression("a + b + c * d + e + f"),
new TransformContext(Map.of(), new MapTypeContext()));
- assertTrue(expr.getRoot() instanceof ArithmeticNode);
- ArithmeticNode root = (ArithmeticNode) expr.getRoot();
+ assertTrue(expr.getRoot() instanceof OperationNode);
+ OperationNode root = (OperationNode) expr.getRoot();
assertEquals(5, root.operators().size());
assertEquals(6, root.children().size());
}
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
index 22681858fc3..e9b674a8c87 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -171,7 +171,7 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name());
- assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 0.0, if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value());
+ assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value());
assertEquals("test_unbound_model", config.rankprofile(8).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name());
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
index 8788eda572d..3aa60e0b9a2 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
@@ -4,7 +4,6 @@ package ai.vespa.models.evaluation;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -70,53 +69,83 @@ class LazyValue extends Value {
}
@Override
- public Value add(Value value) {
- return computedValue().add(value);
+ public Value not() {
+ return computedValue().not();
}
@Override
- public Value subtract(Value value) {
- return computedValue().subtract(value);
+ public Value or(Value value) {
+ return computedValue().or(value);
}
@Override
- public Value multiply(Value value) {
- return computedValue().multiply(value);
+ public Value and(Value value) {
+ return computedValue().and(value);
}
@Override
- public Value divide(Value value) {
- return computedValue().divide(value);
+ public Value largerOrEqual(Value value) {
+ return computedValue().largerOrEqual(value);
}
@Override
- public Value modulo(Value value) {
- return computedValue().modulo(value);
+ public Value larger(Value value) {
+ return computedValue().larger(value);
}
@Override
- public Value and(Value value) {
- return computedValue().and(value);
+ public Value smallerOrEqual(Value value) {
+ return computedValue().smallerOrEqual(value);
}
@Override
- public Value or(Value value) {
- return computedValue().or(value);
+ public Value smaller(Value value) {
+ return computedValue().smaller(value);
}
@Override
- public Value not() {
- return computedValue().not();
+ public Value approxEqual(Value value) {
+ return computedValue().approxEqual(value);
}
@Override
- public Value power(Value value) {
- return computedValue().power(value);
+ public Value notEqual(Value value) {
+ return computedValue().notEqual(value);
+ }
+
+ @Override
+ public Value equal(Value value) {
+ return computedValue().equal(value);
}
@Override
- public Value compare(TruthOperator operator, Value value) {
- return computedValue().compare(operator, value);
+ public Value add(Value value) {
+ return computedValue().add(value);
+ }
+
+ @Override
+ public Value subtract(Value value) {
+ return computedValue().subtract(value);
+ }
+
+ @Override
+ public Value multiply(Value value) {
+ return computedValue().multiply(value);
+ }
+
+ @Override
+ public Value divide(Value value) {
+ return computedValue().divide(value);
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ return computedValue().modulo(value);
+ }
+
+ @Override
+ public Value power(Value value) {
+ return computedValue().power(value);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
index cd0c4da6d0f..7ac2f4bff84 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -92,7 +92,7 @@ public class Gather extends IntermediateOperation {
ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantValue));
if (constantValue < 0) {
ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get()));
- indexExpression = new EmbracedNode(new ArithmeticNode(indexExpression, ArithmeticOperator.PLUS, axisSize));
+ indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.plus, axisSize));
}
addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression);
} else {
@@ -125,8 +125,8 @@ public class Gather extends IntermediateOperation {
/** to support negative indexing */
private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) {
ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get()));
- ExpressionNode plus = new EmbracedNode(new ArithmeticNode(slice, ArithmeticOperator.PLUS, axisSize));
- ExpressionNode mod = new ArithmeticNode(plus, ArithmeticOperator.MODULO, axisSize);
+ ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.plus, axisSize));
+ ExpressionNode mod = new OperationNode(plus, Operator.modulo, axisSize);
return mod;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index 1f447f2a575..97bfdda385e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
@@ -97,17 +97,17 @@ public class Gemm extends IntermediateOperation {
TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension);
TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
- new ArithmeticNode(
+ new OperationNode(
new TensorFunctionNode(AxB),
- ArithmeticOperator.MULTIPLY,
+ Operator.multiply,
new ConstantNode(new DoubleValue(alpha))));
if (inputs.size() == 3) {
Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function();
TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction(
- new ArithmeticNode(
+ new OperationNode(
new TensorFunctionNode(cFunction.get()),
- ArithmeticOperator.MULTIPLY,
+ Operator.multiply,
new ConstantNode(new DoubleValue(beta))));
return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
index 5c4e8cd6cd0..9a38ab9dfde 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -65,8 +65,8 @@ public class Range extends IntermediateOperation {
ExpressionNode startExpr = new ConstantNode(new DoubleValue(start));
ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta));
ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName));
- ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr);
- ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr);
+ ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.multiply, dimExpr);
+ ExpressionNode addExpr = new OperationNode(startExpr, Operator.plus, stepExpr);
TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(addExpr));
return function;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 7b675fa79af..bc94fc6aa76 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -6,13 +6,11 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
@@ -159,13 +157,13 @@ public class Reshape extends IntermediateOperation {
inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
} else if (dim == (inputType.rank() - 1)) {
ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
- ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size);
+ ExpressionNode div = new OperationNode(unrolled, Operator.modulo, size);
inputDimensionExpression = new EmbracedNode(div);
} else {
ExpressionNode size = new ConstantNode(new DoubleValue(innerSize));
ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize));
- ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize);
- ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size);
+ ExpressionNode mod = new OperationNode(unrolled, Operator.modulo, previousSize);
+ ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.divide, size);
inputDimensionExpression = new EmbracedNode(div);
}
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression)));
@@ -183,21 +181,21 @@ public class Reshape extends IntermediateOperation {
return new ConstantNode(DoubleValue.zero);
List<ExpressionNode> children = new ArrayList<>();
- List<ArithmeticOperator> operators = new ArrayList<>();
+ List<Operator> operators = new ArrayList<>();
int size = 1;
for (int i = type.dimensions().size() - 1; i >= 0; --i) {
TensorType.Dimension dimension = type.dimensions().get(i);
children.add(0, new ReferenceNode(dimension.name()));
if (size > 1) {
- operators.add(0, ArithmeticOperator.MULTIPLY);
+ operators.add(0, Operator.multiply);
children.add(0, new ConstantNode(new DoubleValue(size)));
}
size *= OrderedTensorType.dimensionSize(dimension);
if (i > 0) {
- operators.add(0, ArithmeticOperator.PLUS);
+ operators.add(0, Operator.plus);
}
}
- return new ArithmeticNode(children, operators);
+ return new OperationNode(children, operators);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
index 91b7064b19c..916e9980131 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
@@ -6,8 +6,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -166,8 +166,8 @@ public class Slice extends IntermediateOperation {
// step * (d0 + start)
ExpressionNode reference = new ReferenceNode(outputDimensionName);
- ExpressionNode plus = new EmbracedNode(new ArithmeticNode(reference, ArithmeticOperator.PLUS, startIndex));
- ExpressionNode mul = new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, plus);
+ ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.plus, startIndex));
+ ExpressionNode mul = new OperationNode(stepSize, Operator.multiply, plus);
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul))));
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
index 6f720716adb..ef46d222941 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -96,7 +96,7 @@ public class Split extends IntermediateOperation {
for (int i = 0; i < inputType.rank(); ++i) {
String inputDimensionName = inputType.dimensions().get(i).name();
ExpressionNode reference = new ReferenceNode(inputDimensionName);
- ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
+ ExpressionNode offset = new OperationNode(reference, Operator.plus, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset))));
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
index 4bfab284cc2..a880bff87be 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
@@ -4,8 +4,8 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -77,7 +77,7 @@ public class Tile extends IntermediateOperation {
ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
ExpressionNode reference = new ReferenceNode(inputDimensionName);
- ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size);
+ ExpressionNode mod = new OperationNode(reference, Operator.modulo, size);
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod))));
}
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 7caf5a06032..a0517f408bf 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -496,16 +496,22 @@
"public boolean hasDouble()",
"public com.yahoo.tensor.Tensor asTensor()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value negate()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value largerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value larger(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value smallerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value smaller(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value approxEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)"
],
"fields": []
@@ -691,16 +697,22 @@
"public boolean hasDouble()",
"public boolean asBoolean()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value negate()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value largerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value larger(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value smallerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value smaller(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value approxEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value asMutable()",
"public java.lang.String toString()",
@@ -723,17 +735,23 @@
"public boolean hasDouble()",
"public boolean asBoolean()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value negate()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value largerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value larger(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value smallerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value smaller(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value approxEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.tensor.Tensor asTensor()",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value asMutable()",
"public java.lang.String toString()",
@@ -760,16 +778,22 @@
"public abstract boolean hasDouble()",
"public abstract boolean asBoolean()",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value negate()",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value largerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value larger(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value smallerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value smaller(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value approxEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value freeze()",
"public final boolean isFrozen()",
@@ -891,9 +915,8 @@
"public final java.util.List featureList()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode rankingExpression()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode expression()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode arithmeticExpression()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator arithmetic()",
- "public final com.yahoo.searchlib.rankingexpression.rule.TruthOperator comparator()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode operationExpression()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.Operator binaryOperator()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode value()",
"public final com.yahoo.searchlib.rankingexpression.rule.IfNode ifExpression()",
"public final com.yahoo.searchlib.rankingexpression.rule.ReferenceNode feature()",
@@ -1013,13 +1036,13 @@
"public static final int DOLLAR",
"public static final int COMMA",
"public static final int COLON",
- "public static final int LE",
- "public static final int LT",
- "public static final int EQ",
- "public static final int NQ",
- "public static final int AQ",
- "public static final int GE",
- "public static final int GT",
+ "public static final int GREATEREQUAL",
+ "public static final int GREATER",
+ "public static final int LESSEQUAL",
+ "public static final int LESS",
+ "public static final int APPROX",
+ "public static final int NOTEQUAL",
+ "public static final int EQUAL",
"public static final int STRING",
"public static final int IF",
"public static final int IN",
@@ -1220,54 +1243,6 @@
"public static final com.yahoo.searchlib.rankingexpression.rule.Arguments EMPTY"
]
},
- "com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode": {
- "superClass": "com.yahoo.searchlib.rankingexpression.rule.CompositeNode",
- "interfaces": [],
- "attributes": [
- "public",
- "final"
- ],
- "methods": [
- "public void <init>(java.util.List, java.util.List)",
- "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
- "public java.util.List operators()",
- "public java.util.List children()",
- "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
- "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
- "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)",
- "public int hashCode()",
- "public static com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode resolve(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
- ],
- "fields": []
- },
- "com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator": {
- "superClass": "java.lang.Enum",
- "interfaces": [],
- "attributes": [
- "public",
- "abstract",
- "enum"
- ],
- "methods": [
- "public static com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator[] values()",
- "public static com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator valueOf(java.lang.String)",
- "public boolean hasPrecedenceOver(com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Value, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public java.lang.String toString()"
- ],
- "fields": [
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator OR",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator AND",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator PLUS",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator MINUS",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator MULTIPLY",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator DIVIDE",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator MODULO",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator POWER",
- "public static final java.util.List operatorsByPrecedence"
- ]
- },
"com.yahoo.searchlib.rankingexpression.rule.BooleanNode": {
"superClass": "com.yahoo.searchlib.rankingexpression.rule.CompositeNode",
"interfaces": [],
@@ -1280,27 +1255,6 @@
],
"fields": []
},
- "com.yahoo.searchlib.rankingexpression.rule.ComparisonNode": {
- "superClass": "com.yahoo.searchlib.rankingexpression.rule.BooleanNode",
- "interfaces": [],
- "attributes": [
- "public"
- ],
- "methods": [
- "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
- "public java.util.List children()",
- "public com.yahoo.searchlib.rankingexpression.rule.TruthOperator getOperator()",
- "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode getLeftCondition()",
- "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode getRightCondition()",
- "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
- "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
- "public com.yahoo.searchlib.rankingexpression.rule.ComparisonNode setChildren(java.util.List)",
- "public int hashCode()",
- "public bridge synthetic com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)"
- ],
- "fields": []
- },
"com.yahoo.searchlib.rankingexpression.rule.CompositeNode": {
"superClass": "com.yahoo.searchlib.rankingexpression.rule.ExpressionNode",
"interfaces": [],
@@ -1581,6 +1535,61 @@
],
"fields": []
},
+ "com.yahoo.searchlib.rankingexpression.rule.OperationNode": {
+ "superClass": "com.yahoo.searchlib.rankingexpression.rule.CompositeNode",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "final"
+ ],
+ "methods": [
+ "public void <init>(java.util.List, java.util.List)",
+ "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.Operator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
+ "public java.util.List operators()",
+ "public java.util.List children()",
+ "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
+ "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)",
+ "public int hashCode()",
+ "public static com.yahoo.searchlib.rankingexpression.rule.OperationNode resolve(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.Operator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
+ ],
+ "fields": []
+ },
+ "com.yahoo.searchlib.rankingexpression.rule.Operator": {
+ "superClass": "java.lang.Enum",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "final",
+ "enum"
+ ],
+ "methods": [
+ "public static com.yahoo.searchlib.rankingexpression.rule.Operator[] values()",
+ "public static com.yahoo.searchlib.rankingexpression.rule.Operator valueOf(java.lang.String)",
+ "public boolean hasPrecedenceOver(com.yahoo.searchlib.rankingexpression.rule.Operator)",
+ "public final com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Value, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public java.lang.String toString()"
+ ],
+ "fields": [
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator or",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator and",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator largerOrEqual",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator larger",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator smallerOrEqual",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator smaller",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator approxEqual",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator notEqual",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator equal",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator plus",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator minus",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator multiply",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator divide",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator modulo",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator power",
+ "public static final java.util.List operatorsByPrecedence"
+ ]
+ },
"com.yahoo.searchlib.rankingexpression.rule.ReferenceNode": {
"superClass": "com.yahoo.searchlib.rankingexpression.rule.CompositeNode",
"interfaces": [],
@@ -1693,32 +1702,5 @@
"public int hashCode()"
],
"fields": []
- },
- "com.yahoo.searchlib.rankingexpression.rule.TruthOperator": {
- "superClass": "java.lang.Enum",
- "interfaces": [
- "java.io.Serializable"
- ],
- "attributes": [
- "public",
- "abstract",
- "enum"
- ],
- "methods": [
- "public static com.yahoo.searchlib.rankingexpression.rule.TruthOperator[] values()",
- "public static com.yahoo.searchlib.rankingexpression.rule.TruthOperator valueOf(java.lang.String)",
- "public abstract boolean evaluate(double, double)",
- "public java.lang.String toString()",
- "public static com.yahoo.searchlib.rankingexpression.rule.TruthOperator fromString(java.lang.String)"
- ],
- "fields": [
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator SMALLER",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator SMALLEREQUAL",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator EQUAL",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator APPROX_EQUAL",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator LARGER",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator LARGEREQUAL",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator NOTEQUAL"
- ]
}
} \ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index afd263f1553..2405d1bd528 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -2,7 +2,6 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -28,58 +27,146 @@ public abstract class DoubleCompatibleValue extends Value {
public Value negate() { return new DoubleValue(-asDouble()); }
@Override
- public Value add(Value value) {
- return new DoubleValue(asDouble() + value.asDouble());
+ public Value not() {
+ return new BooleanValue(!asBoolean());
}
@Override
- public Value subtract(Value value) {
- return new DoubleValue(asDouble() - value.asDouble());
+ public Value or(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.or(this);
+ else
+ return new BooleanValue(asBoolean() || value.asBoolean());
}
@Override
- public Value multiply(Value value) {
- return new DoubleValue(asDouble() * value.asDouble());
+ public Value and(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.and(this);
+ else
+ return new BooleanValue(asBoolean() && value.asBoolean());
}
@Override
- public Value divide(Value value) {
- return new DoubleValue(asDouble() / value.asDouble());
+ public Value largerOrEqual(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.largerOrEqual(this);
+ else
+ return new BooleanValue(this.asDouble() >= value.asDouble());
}
@Override
- public Value modulo(Value value) {
- return new DoubleValue(asDouble() % value.asDouble());
+ public Value larger(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.larger(this);
+ else
+ return new BooleanValue(this.asDouble() > value.asDouble());
}
@Override
- public Value and(Value value) {
- return new BooleanValue(asBoolean() && value.asBoolean());
+ public Value smallerOrEqual(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.smallerOrEqual(this);
+ else
+ return new BooleanValue(this.asDouble() <= value.asDouble());
}
@Override
- public Value or(Value value) {
- return new BooleanValue(asBoolean() || value.asBoolean());
+ public Value smaller(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.smaller(this);
+ else
+ return new BooleanValue(this.asDouble() < value.asDouble());
}
@Override
- public Value not() {
- return new BooleanValue(!asBoolean());
+ public Value approxEqual(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.approxEqual(this);
+ else
+ return new BooleanValue(approxEqual(this.asDouble(), value.asDouble()));
}
@Override
- public Value power(Value value) {
- return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble()));
+ public Value notEqual(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.notEqual(this);
+ else
+ return new BooleanValue(this.asDouble() != value.asDouble());
+ }
+
+ @Override
+ public Value equal(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.equal(this);
+ else
+ return new BooleanValue(this.asDouble() == value.asDouble());
+ }
+
+ @Override
+ public Value add(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.add(this);
+ else
+ return new DoubleValue(asDouble() + value.asDouble());
+ }
+
+ @Override
+ public Value subtract(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.subtract(this);
+ else
+ return new DoubleValue(asDouble() - value.asDouble());
+ }
+
+ @Override
+ public Value multiply(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.multiply(this);
+ else
+ return new DoubleValue(asDouble() * value.asDouble());
+ }
+
+ @Override
+ public Value divide(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.divide(this);
+ else
+ return new DoubleValue(asDouble() / value.asDouble());
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.modulo(this);
+ else
+ return new DoubleValue(asDouble() % value.asDouble());
}
@Override
- public Value compare(TruthOperator operator, Value value) {
- return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
+ public Value power(Value value) {
+ if (value instanceof TensorValue tensor)
+ return tensor.power(this);
+ else
+ return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble()));
}
@Override
public Value function(Function function, Value value) {
- return new DoubleValue(function.evaluate(asDouble(),value.asDouble()));
+ if (value instanceof TensorValue tensor)
+ return tensor.function(function, this);
+ else
+ return new DoubleValue(function.evaluate(asDouble(), value.asDouble()));
+ }
+
+ static boolean approxEqual(double x, double y) {
+ if (y < -1.0 || y > 1.0) {
+ x = Math.nextAfter(x/y, 1.0);
+ y = 1.0;
+ } else {
+ x = Math.nextAfter(x, y);
+ }
+ return x == y;
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index 2c2d5eead05..a585c989954 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -57,55 +56,83 @@ public class StringValue extends Value {
}
@Override
- public Value add(Value value) {
- return new StringValue(value + value.toString());
+ public Value not() {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support not");
}
@Override
- public Value subtract(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction");
+ public Value or(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support or");
}
@Override
- public Value multiply(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication");
+ public Value and(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support and");
}
@Override
- public Value divide(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support division");
+ public Value largerOrEqual(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support greaterEqual");
}
@Override
- public Value modulo(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo");
+ public Value larger(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support greater");
}
@Override
- public Value and(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support and");
+ public Value smallerOrEqual(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support lessEqual");
}
@Override
- public Value or(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support or");
+ public Value smaller(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support less");
}
@Override
- public Value not() {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support not");
+ public Value approxEqual(Value argument) {
+ return new BooleanValue(this.asDouble() == argument.asDouble());
}
@Override
- public Value power(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support ^");
+ public Value notEqual(Value argument) {
+ return new BooleanValue(this.asDouble() != argument.asDouble());
+ }
+
+ @Override
+ public Value equal(Value argument) {
+ return new BooleanValue(this.asDouble() == argument.asDouble());
}
@Override
- public Value compare(TruthOperator operator, Value value) {
- if (operator.equals(TruthOperator.EQUAL))
- return new BooleanValue(this.equals(value));
- throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='");
+ public Value add(Value value) {
+ return new StringValue(value + value.toString());
+ }
+
+ @Override
+ public Value subtract(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction");
+ }
+
+ @Override
+ public Value multiply(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication");
+ }
+
+ @Override
+ public Value divide(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support division");
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo");
+ }
+
+ @Override
+ public Value power(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support ^");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index b37bbb543eb..25e03c75376 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.api.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -52,6 +51,83 @@ public class TensorValue extends Value {
}
@Override
+ public Value not() {
+ return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value or(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value and(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value largerOrEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.largerOrEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value >= argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value larger(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.larger(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value > argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value smallerOrEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value <= argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value smaller(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.smaller(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value < argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value approxEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.approxEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> DoubleCompatibleValue.approxEqual(value, argument.asDouble()) ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value notEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.notEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value != argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value equal(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.equal(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value == argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
public Value add(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.add(((TensorValue)argument).value));
@@ -92,27 +168,6 @@ public class TensorValue extends Value {
}
@Override
- public Value and(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
- else
- return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
- }
-
- @Override
- public Value or(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
- else
- return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
- }
-
- @Override
- public Value not() {
- return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0));
- }
-
- @Override
public Value power(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.pow(((TensorValue)argument).value));
@@ -123,24 +178,6 @@ public class TensorValue extends Value {
public Tensor asTensor() { return value; }
@Override
- public Value compare(TruthOperator operator, Value argument) {
- return new TensorValue(compareTensor(operator, argument.asTensor()));
- }
-
- private Tensor compareTensor(TruthOperator operator, Tensor argument) {
- switch (operator) {
- case LARGER: return value.larger(argument);
- case LARGEREQUAL: return value.largerOrEqual(argument);
- case SMALLER: return value.smaller(argument);
- case SMALLEREQUAL: return value.smallerOrEqual(argument);
- case EQUAL: return value.equal(argument);
- case NOTEQUAL: return value.notEqual(argument);
- case APPROX_EQUAL: return value.approxEqual(argument);
- default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
- }
- }
-
- @Override
public Value function(Function function, Value arg) {
if (arg instanceof TensorValue)
return new TensorValue(functionOnTensor(function, arg.asTensor()));
@@ -149,17 +186,17 @@ public class TensorValue extends Value {
}
private Tensor functionOnTensor(Function function, Tensor argument) {
- switch (function) {
- case min: return value.min(argument);
- case max: return value.max(argument);
- case atan2: return value.atan2(argument);
- case pow: return value.pow(argument);
- case fmod: return value.fmod(argument);
- case ldexp: return value.ldexp(argument);
- case bit: return value.bit(argument);
- case hamming: return value.hamming(argument);
- default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
- }
+ return switch (function) {
+ case min -> value.min(argument);
+ case max -> value.max(argument);
+ case atan2 -> value.atan2(argument);
+ case pow -> value.pow(argument);
+ case fmod -> value.fmod(argument);
+ case ldexp -> value.ldexp(argument);
+ case bit -> value.bit(argument);
+ case hamming -> value.hamming(argument);
+ default -> throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
+ };
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 207603c5038..5de2138147e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -1,9 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;
-import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -51,27 +49,24 @@ public abstract class Value {
public abstract Value negate();
- public abstract Value add(Value value);
+ public abstract Value not();
+ public abstract Value or(Value value);
+ public abstract Value and(Value value);
+ public abstract Value largerOrEqual(Value value);
+ public abstract Value larger(Value value);
+ public abstract Value smallerOrEqual(Value value);
+ public abstract Value smaller(Value value);
+ public abstract Value approxEqual(Value value);
+ public abstract Value notEqual(Value value);
+ public abstract Value equal(Value value);
+ public abstract Value add(Value value);
public abstract Value subtract(Value value);
-
public abstract Value multiply(Value value);
-
public abstract Value divide(Value value);
-
public abstract Value modulo(Value value);
-
- public abstract Value and(Value value);
-
- public abstract Value or(Value value);
-
- public abstract Value not();
-
public abstract Value power(Value value);
- /** Perform the comparison specified by the operator between this value and the given value */
- public abstract Value compare(TruthOperator operator, Value value);
-
/** Perform the given binary function on this value and the given value */
public abstract Value function(Function function, Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
index 6ab483800a7..a3fc6aae9ac 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
@@ -5,8 +5,8 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -84,12 +84,12 @@ public class GBDTForestOptimizer extends Optimizer {
currentTreesOptimized++;
return true;
}
- if (!(node instanceof ArithmeticNode)) {
+ if (!(node instanceof OperationNode)) {
return false;
}
- ArithmeticNode aNode = (ArithmeticNode)node;
- for (ArithmeticOperator op : aNode.operators()) {
- if (op != ArithmeticOperator.PLUS) {
+ OperationNode aNode = (OperationNode)node;
+ for (Operator op : aNode.operators()) {
+ if (op != Operator.plus) {
return false;
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
index 420f1f459f3..7ba671e62eb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
@@ -61,13 +61,13 @@ public class GBDTOptimizer extends Optimizer {
* @return the optimized expression
*/
private ExpressionNode optimize(ExpressionNode node, ContextIndex context) {
- if (node instanceof ArithmeticNode) {
- Iterator<ExpressionNode> childIt = ((ArithmeticNode)node).children().iterator();
+ if (node instanceof OperationNode) {
+ Iterator<ExpressionNode> childIt = ((OperationNode)node).children().iterator();
ExpressionNode ret = optimize(childIt.next(), context);
- Iterator<ArithmeticOperator> operIt = ((ArithmeticNode)node).operators().iterator();
+ Iterator<Operator> operIt = ((OperationNode)node).operators().iterator();
while (childIt.hasNext() && operIt.hasNext()) {
- ret = ArithmeticNode.resolve(ret, operIt.next(), optimize(childIt.next(), context));
+ ret = OperationNode.resolve(ret, operIt.next(), optimize(childIt.next(), context));
}
return ret;
}
@@ -114,15 +114,15 @@ public class GBDTOptimizer extends Optimizer {
/** Consumes the if condition and return the size of the values resulting, for convenience */
private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) {
- if (condition instanceof ComparisonNode) {
- ComparisonNode comparison = (ComparisonNode)condition;
- if (comparison.getOperator() == TruthOperator.SMALLER)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.getLeftCondition(), context));
- else if (comparison.getOperator() == TruthOperator.EQUAL)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.getLeftCondition(), context));
+ if (isBinaryComparison(condition)) {
+ OperationNode comparison = (OperationNode)condition;
+ if (comparison.operators().get(0) == Operator.smaller)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context));
+ else if (comparison.operators().get(0) == Operator.equal)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context));
else
- throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.getOperator());
- values.add(toValue(comparison.getRightCondition()));
+ throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0));
+ values.add(toValue(comparison.children().get(1)));
}
else if (condition instanceof SetMembershipNode) {
SetMembershipNode setMembership = (SetMembershipNode)condition;
@@ -131,17 +131,15 @@ public class GBDTOptimizer extends Optimizer {
for (ExpressionNode setElementNode : setMembership.getSetValues())
values.add(toValue(setElementNode));
}
- else if (condition instanceof NotNode) { // handle if inversion: !(a >= b)
- NotNode notNode = (NotNode)condition;
- if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) {
- EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0);
- if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) {
- ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0);
- if (comparison.getOperator() == TruthOperator.LARGEREQUAL)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context));
+ else if (condition instanceof NotNode notNode) { // handle if inversion: !(a >= b)
+ if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) {
+ if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) {
+ OperationNode comparison = (OperationNode)embracedNode.children().get(0);
+ if (comparison.operators().get(0) == Operator.largerOrEqual)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context));
else
- throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator());
- values.add(toValue(comparison.getRightCondition()));
+ throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0));
+ values.add(toValue(comparison.children().get(1)));
}
}
}
@@ -152,12 +150,24 @@ public class GBDTOptimizer extends Optimizer {
return values.size();
}
+ private boolean isBinaryComparison(ExpressionNode condition) {
+ if ( ! (condition instanceof OperationNode binaryNode)) return false;
+ if (binaryNode.operators().size() != 1) return false;
+ if (binaryNode.operators().get(0) == Operator.largerOrEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.larger) return true;
+ if (binaryNode.operators().get(0) == Operator.smallerOrEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.smaller) return true;
+ if (binaryNode.operators().get(0) == Operator.approxEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.notEqual) return true;
+ if (binaryNode.operators().get(0) == Operator.equal) return true;
+ return false;
+ }
+
private double getVariableIndex(ExpressionNode node, ContextIndex context) {
- if (!(node instanceof ReferenceNode)) {
+ if (!(node instanceof ReferenceNode fNode)) {
throw new IllegalArgumentException("Contained a left-hand comparison expression " +
"which was not a feature value but was: " + node);
}
- ReferenceNode fNode = (ReferenceNode)node;
Integer index = context.getIndex(fNode.toString());
if (index == null) {
throw new IllegalStateException("The ranking expression contained feature '" + fNode.getName() +
@@ -177,8 +187,7 @@ public class GBDTOptimizer extends Optimizer {
value.getClass().getSimpleName() + " (" + value + ") in a set test: " + node);
}
- if (node instanceof NegativeNode) {
- NegativeNode nNode = (NegativeNode)node;
+ if (node instanceof NegativeNode nNode) {
if (!(nNode.getValue() instanceof ConstantNode)) {
throw new IllegalArgumentException("Contained a negation of a non-number: " + nNode.getValue());
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
deleted file mode 100644
index 959045a63a0..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-
-import java.util.List;
-
-/**
- * A mathematical operator
- *
- * @author bratseth
- */
-public enum ArithmeticOperator {
-
- OR(0, "||") { public Value evaluate(Value x, Value y) {
- return x.or(y);
- }},
- AND(1, "&&") { public Value evaluate(Value x, Value y) {
- return x.and(y);
- }},
- PLUS(2, "+") { public Value evaluate(Value x, Value y) {
- return x.add(y);
- }},
- MINUS(3, "-") { public Value evaluate(Value x, Value y) {
- return x.subtract(y);
- }},
- MULTIPLY(4, "*") { public Value evaluate(Value x, Value y) {
- return x.multiply(y);
- }},
- DIVIDE(5, "/") { public Value evaluate(Value x, Value y) {
- return x.divide(y);
- }},
- MODULO(6, "%") { public Value evaluate(Value x, Value y) {
- return x.modulo(y);
- }},
- POWER(7, "^") { public Value evaluate(Value x, Value y) {
- return x.power(y);
- }};
-
- /** A list of all the operators in this in order of decreasing precedence */
- public static final List<ArithmeticOperator> operatorsByPrecedence = List.of(POWER, MODULO, DIVIDE, MULTIPLY, MINUS, PLUS, AND, OR);
-
- private final int precedence;
- private final String image;
-
- private ArithmeticOperator(int precedence, String image) {
- this.precedence = precedence;
- this.image = image;
- }
-
- /** Returns true if this operator has precedence over the given operator */
- public boolean hasPrecedenceOver(ArithmeticOperator op) {
- return precedence > op.precedence;
- }
-
- public abstract Value evaluate(Value x, Value y);
-
- @Override
- public String toString() {
- return image;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
deleted file mode 100644
index e726a351f74..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-
-import java.util.Deque;
-import java.util.List;
-import java.util.Objects;
-
-/**
- * A node which returns the outcome of a comparison.
- *
- * @author bratseth
- */
-public class ComparisonNode extends BooleanNode {
-
- /** The operator string of this condition. */
- private final TruthOperator operator;
-
- private final List<ExpressionNode> conditions;
-
- public ComparisonNode(ExpressionNode leftCondition, TruthOperator operator, ExpressionNode rightCondition) {
- conditions = List.of(leftCondition, rightCondition);
- this.operator = operator;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return conditions;
- }
-
- public TruthOperator getOperator() { return operator; }
-
- public ExpressionNode getLeftCondition() { return conditions.get(0); }
-
- public ExpressionNode getRightCondition() { return conditions.get(1); }
-
- @Override
- public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
- getLeftCondition().toString(string, context, path, this).append(' ').append(operator).append(' ');
- return getRightCondition().toString(string, context, path, this);
- }
-
- @Override
- public TensorType type(TypeContext<Reference> context) {
- return TensorType.empty; // by definition
- }
-
- @Override
- public Value evaluate(Context context) {
- Value leftValue = getLeftCondition().evaluate(context);
- Value rightValue = getRightCondition().evaluate(context);
- return leftValue.compare(operator,rightValue);
- }
-
- @Override
- public ComparisonNode setChildren(List<ExpressionNode> children) {
- if (children.size() != 2) throw new IllegalArgumentException("A comparison test must have 2 children");
- return new ComparisonNode(children.get(0), operator, children.get(1));
- }
-
- @Override
- public int hashCode() { return Objects.hash(operator, conditions); }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 9f07f146264..97e9a74f9c8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -106,10 +106,10 @@ public class LambdaFunctionNode extends CompositeNode {
}
private Optional<DoubleBinaryOperator> getDirectEvaluator() {
- if ( ! (functionExpression instanceof ArithmeticNode)) {
+ if ( ! (functionExpression instanceof OperationNode)) {
return Optional.empty();
}
- ArithmeticNode node = (ArithmeticNode) functionExpression;
+ OperationNode node = (OperationNode) functionExpression;
if ( ! (node.children().get(0) instanceof ReferenceNode) || ! (node.children().get(1) instanceof ReferenceNode)) {
return Optional.empty();
}
@@ -121,16 +121,16 @@ public class LambdaFunctionNode extends CompositeNode {
if (node.operators().size() != 1) {
return Optional.empty();
}
- ArithmeticOperator operator = node.operators().get(0);
+ Operator operator = node.operators().get(0);
switch (operator) {
- case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
- case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
- case PLUS: return asFunctionExpression((left, right) -> left + right);
- case MINUS: return asFunctionExpression((left, right) -> left - right);
- case MULTIPLY: return asFunctionExpression((left, right) -> left * right);
- case DIVIDE: return asFunctionExpression((left, right) -> left / right);
- case MODULO: return asFunctionExpression((left, right) -> left % right);
- case POWER: return asFunctionExpression(Math::pow);
+ case or: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
+ case and: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
+ case plus: return asFunctionExpression((left, right) -> left + right);
+ case minus: return asFunctionExpression((left, right) -> left - right);
+ case multiply: return asFunctionExpression((left, right) -> left * right);
+ case divide: return asFunctionExpression((left, right) -> left / right);
+ case modulo: return asFunctionExpression((left, right) -> left % right);
+ case power: return asFunctionExpression(Math::pow);
}
return Optional.empty();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
index c3e39197316..0512e1dad2f 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
@@ -16,26 +16,26 @@ import java.util.List;
import java.util.Objects;
/**
- * A binary mathematical operation
+ * A sequence of binary operations.
*
* @author bratseth
*/
-public final class ArithmeticNode extends CompositeNode {
+public final class OperationNode extends CompositeNode {
private final List<ExpressionNode> children;
- private final List<ArithmeticOperator> operators;
+ private final List<Operator> operators;
- public ArithmeticNode(List<ExpressionNode> children, List<ArithmeticOperator> operators) {
+ public OperationNode(List<ExpressionNode> children, List<Operator> operators) {
this.children = List.copyOf(children);
this.operators = List.copyOf(operators);
}
- public ArithmeticNode(ExpressionNode leftExpression, ArithmeticOperator operator, ExpressionNode rightExpression) {
+ public OperationNode(ExpressionNode leftExpression, Operator operator, ExpressionNode rightExpression) {
this.children = List.of(leftExpression, rightExpression);
this.operators = List.of(operator);
}
- public List<ArithmeticOperator> operators() { return operators; }
+ public List<Operator> operators() { return operators; }
@Override
public List<ExpressionNode> children() { return children; }
@@ -50,7 +50,7 @@ public final class ArithmeticNode extends CompositeNode {
child.next().toString(string, context, path, this);
if (child.hasNext())
string.append(" ");
- for (Iterator<ArithmeticOperator> op = operators.iterator(); op.hasNext() && child.hasNext();) {
+ for (Iterator<Operator> op = operators.iterator(); op.hasNext() && child.hasNext();) {
string.append(op.next().toString()).append(" ");
child.next().toString(string, context, path, this);
if (op.hasNext())
@@ -68,14 +68,14 @@ public final class ArithmeticNode extends CompositeNode {
*/
private boolean nonDefaultPrecedence(CompositeNode parent) {
if ( parent == null) return false;
- if ( ! (parent instanceof ArithmeticNode arithmeticParent)) return false;
+ if ( ! (parent instanceof OperationNode operationParent)) return false;
// The line below can only be correct in both only have one operator.
// Getting this correct is impossible without more work.
// So for now we only handle the simple case correctly, and use a safe approach by adding
// extra parenthesis just in case....
- return arithmeticParent.operators.get(0).hasPrecedenceOver(this.operators.get(0))
- || ((arithmeticParent.operators.size() > 1) || (operators.size() > 1));
+ return operationParent.operators.get(0).hasPrecedenceOver(this.operators.get(0))
+ || ((operationParent.operators.size() > 1) || (operators.size() > 1));
}
@Override
@@ -96,8 +96,8 @@ public final class ArithmeticNode extends CompositeNode {
// Apply in precedence order:
Deque<ValueItem> stack = new ArrayDeque<>();
stack.push(new ValueItem(null, child.next().evaluate(context)));
- for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
- ArithmeticOperator op = it.next();
+ for (Iterator<Operator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
+ Operator op = it.next();
if ( ! stack.isEmpty()) {
while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) {
popStack(stack);
@@ -121,30 +121,30 @@ public final class ArithmeticNode extends CompositeNode {
public CompositeNode setChildren(List<ExpressionNode> newChildren) {
if (children.size() != newChildren.size())
throw new IllegalArgumentException("Expected " + children.size() + " children but got " + newChildren.size());
- return new ArithmeticNode(newChildren, operators);
+ return new OperationNode(newChildren, operators);
}
@Override
public int hashCode() { return Objects.hash(children, operators); }
- public static ArithmeticNode resolve(ExpressionNode left, ArithmeticOperator op, ExpressionNode right) {
- if ( ! (left instanceof ArithmeticNode leftArithmetic)) return new ArithmeticNode(left, op, right);
+ public static OperationNode resolve(ExpressionNode left, Operator op, ExpressionNode right) {
+ if ( ! (left instanceof OperationNode leftArithmetic)) return new OperationNode(left, op, right);
List<ExpressionNode> newChildren = new ArrayList<>(leftArithmetic.children());
newChildren.add(right);
- List<ArithmeticOperator> newOperators = new ArrayList<>(leftArithmetic.operators());
+ List<Operator> newOperators = new ArrayList<>(leftArithmetic.operators());
newOperators.add(op);
- return new ArithmeticNode(newChildren, newOperators);
+ return new OperationNode(newChildren, newOperators);
}
private static class ValueItem {
- final ArithmeticOperator op;
+ final Operator op;
Value value;
- public ValueItem(ArithmeticOperator op, Value value) {
+ public ValueItem(Operator op, Value value) {
this.op = op;
this.value = value;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
new file mode 100644
index 00000000000..02af88b2c58
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
@@ -0,0 +1,67 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.function.BiFunction;
+
+/**
+ * A mathematical operator
+ *
+ * @author bratseth
+ */
+public enum Operator {
+
+ // In order from lowest to highest precedence
+ or("||", (x, y) -> x.or(y)),
+ and("&&", (x, y) -> x.and(y)),
+ largerOrEqual(">=", (x, y) -> x.largerOrEqual(y)),
+ larger(">", (x, y) -> x.larger(y)),
+ smallerOrEqual("<=", (x, y) -> x.smallerOrEqual(y)),
+ smaller("<", (x, y) -> x.smaller(y)),
+ approxEqual("~=", (x, y) -> x.approxEqual(y)),
+ notEqual("!=", (x, y) -> x.notEqual(y)),
+ equal("==", (x, y) -> x.equal(y)),
+ plus("+", (x, y) -> x.add(y)),
+ minus("-", (x, y) -> x.subtract(y)),
+ multiply("*", (x, y) -> x.multiply(y)),
+ divide("/", (x, y) -> x.divide(y)),
+ modulo("%", (x, y) -> x.modulo(y)),
+ power("^", true, (x, y) -> x.power(y));
+
+ /** A list of all the operators in this in order of increasing precedence */
+ public static final List<Operator> operatorsByPrecedence = Arrays.stream(Operator.values()).toList();
+
+ private final String image;
+ private final boolean rightPrecedence;
+ private final BiFunction<Value, Value, Value> function;
+
+ Operator(String image, BiFunction<Value, Value, Value> function) {
+ this(image, false, function);
+ }
+
+ Operator(String image, boolean rightPrecedence, BiFunction<Value, Value, Value> function) {
+ this.image = image;
+ this.rightPrecedence = rightPrecedence;
+ this.function = function;
+ }
+
+ /** Returns true if this operator has precedence over the given operator */
+ public boolean hasPrecedenceOver(Operator other) {
+ if (operatorsByPrecedence.indexOf(this) == operatorsByPrecedence.indexOf(other))
+ return rightPrecedence;
+ return operatorsByPrecedence.indexOf(this) > operatorsByPrecedence.indexOf(other);
+ }
+
+ public final Value evaluate(Value x, Value y) {
+ return function.apply(x, y);
+ }
+
+ @Override
+ public String toString() {
+ return image;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
deleted file mode 100644
index fc259867923..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import java.io.Serializable;
-
-/**
- * A mathematical operator
- *
- * @author bratseth
- */
-public enum TruthOperator implements Serializable {
-
- SMALLER("<") { public boolean evaluate(double x, double y) { return x<y; } },
- SMALLEREQUAL("<=") { public boolean evaluate(double x, double y) { return x<=y; } },
- EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } },
- APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } },
- LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } },
- LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } },
- NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } };
-
- private final String operatorString;
-
- TruthOperator(String operatorString) {
- this.operatorString = operatorString;
- }
-
- /** Perform the truth operation on the input */
- public abstract boolean evaluate(double x, double y);
-
- @Override
- public String toString() { return operatorString; }
-
- public static TruthOperator fromString(String string) {
- for (TruthOperator operator : values())
- if (operator.toString().equals(string))
- return operator;
- throw new IllegalArgumentException("Illegal truth operator '" + string + "'");
- }
-
- private static boolean approxEqual(double x,double y) {
- if (y < -1.0 || y > 1.0) {
- x = Math.nextAfter(x/y, 1.0);
- y = 1.0;
- } else {
- x = Math.nextAfter(x, y);
- }
- return x==y;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index 90861e64164..1c1f7509ce8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -4,9 +4,8 @@ package com.yahoo.searchlib.rankingexpression.transform;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
-import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
@@ -34,8 +33,8 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
node = transformIf((IfNode) node);
if (node instanceof EmbracedNode e && hasSingleUndividableChild(e))
node = e.children().get(0);
- if (node instanceof ArithmeticNode)
- node = transformArithmetic((ArithmeticNode) node);
+ if (node instanceof OperationNode)
+ node = transformArithmetic((OperationNode) node);
if (node instanceof NegativeNode)
node = transformNegativeNode((NegativeNode) node);
return node;
@@ -43,19 +42,18 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean hasSingleUndividableChild(EmbracedNode node) {
if (node.children().size() > 1) return false;
- if (node.children().get(0) instanceof ArithmeticNode) return false;
- if (node.children().get(0) instanceof ComparisonNode) return false;
+ if (node.children().get(0) instanceof OperationNode) return false;
return true;
}
- private ExpressionNode transformArithmetic(ArithmeticNode node) {
+ private ExpressionNode transformArithmetic(OperationNode node) {
// Fold the subset of expressions that are constant (such that in "1 + 2 + var")
if (node.children().size() > 1) {
List<ExpressionNode> children = new ArrayList<>(node.children());
- List<ArithmeticOperator> operators = new ArrayList<>(node.operators());
- for (ArithmeticOperator operator : ArithmeticOperator.operatorsByPrecedence)
+ List<Operator> operators = new ArrayList<>(node.operators());
+ for (Operator operator : Operator.operatorsByPrecedence)
transform(operator, children, operators);
- node = new ArithmeticNode(children, operators);
+ node = new OperationNode(children, operators);
}
if (isConstant(node) && ! node.evaluate(null).isNaN())
@@ -66,8 +64,8 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
return node;
}
- private void transform(ArithmeticOperator operatorToTransform,
- List<ExpressionNode> children, List<ArithmeticOperator> operators) {
+ private void transform(Operator operatorToTransform,
+ List<ExpressionNode> children, List<Operator> operators) {
int i = 0;
while (i < children.size()-1) {
boolean transformed = false;
@@ -75,7 +73,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
ExpressionNode child1 = children.get(i);
ExpressionNode child2 = children.get(i + 1);
if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) {
- Value evaluated = new ArithmeticNode(child1, operators.get(i), child2).evaluate(null);
+ Value evaluated = new OperationNode(child1, operators.get(i), child2).evaluate(null);
if ( ! evaluated.isNaN()) { // Don't replace by NaN
operators.remove(i);
children.set(i, new ConstantNode(evaluated.freeze()));
@@ -94,7 +92,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
* This check works because we simplify by decreasing precedence, so neighbours will either be single constant values
* or a more complex expression that can't be simplified and hence also prevents the simplification in question here.
*/
- private boolean hasPrecedence(List<ArithmeticOperator> operators, int i) {
+ private boolean hasPrecedence(List<Operator> operators, int i) {
if (i > 0 && operators.get(i-1).hasPrecedenceOver(operators.get(i))) return false;
if (i < operators.size()-1 && operators.get(i+1).hasPrecedenceOver(operators.get(i))) return false;
return true;
@@ -116,14 +114,14 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
return new ConstantNode(constant.getValue().negate() );
}
- private boolean allMultiplicationOrDivision(ArithmeticNode node) {
- for (ArithmeticOperator o : node.operators())
- if (o != ArithmeticOperator.MULTIPLY && o != ArithmeticOperator.DIVIDE)
+ private boolean allMultiplicationOrDivision(OperationNode node) {
+ for (Operator o : node.operators())
+ if (o != Operator.multiply && o != Operator.divide)
return false;
return true;
}
- private boolean hasZero(ArithmeticNode node) {
+ private boolean hasZero(OperationNode node) {
for (ExpressionNode child : node.children()) {
if (isZero(child))
return true;
@@ -131,9 +129,9 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
return false;
}
- private boolean hasDivisionByZero(ArithmeticNode node) {
+ private boolean hasDivisionByZero(OperationNode node) {
for (int i = 1; i < node.children().size(); i++) {
- if ( node.operators().get(i - 1) == ArithmeticOperator.DIVIDE && isZero(node.children().get(i)))
+ if (node.operators().get(i - 1) == Operator.divide && isZero(node.children().get(i)))
return true;
}
return false;
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index ebe1e048247..42b5f2c191a 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -72,13 +72,13 @@ TOKEN :
<COMMA: ","> |
<COLON: ":"> |
- <LE: "<="> |
- <LT: "<"> |
- <EQ: "=="> |
- <NQ: "!="> |
- <AQ: "~="> |
- <GE: ">="> |
- <GT: ">"> |
+ <GREATEREQUAL: ">="> |
+ <GREATER: ">"> |
+ <LESSEQUAL: "<="> |
+ <LESS: "<"> |
+ <APPROX: "~="> |
+ <NOTEQUAL: "!="> |
+ <EQUAL: "=="> |
<STRING: ("\"" (~["\""] | "\\\"")* "\"") |
("'" (~["'"] | "\\'")* "'")> |
@@ -188,55 +188,50 @@ ExpressionNode expression() :
{
ExpressionNode left, right;
List<ExpressionNode> rightList;
- TruthOperator comparatorOp;
}
{
- ( left = arithmeticExpression()
+ ( left = operationExpression()
(
- ( comparatorOp = comparator() right = arithmeticExpression() { left = new ComparisonNode(left, comparatorOp, right); } ) |
( <IN> rightList = expressionList() { left = new SetMembershipNode(left, rightList); } )
- ) *
+ ) ?
)
{ return left; }
}
-ExpressionNode arithmeticExpression() :
+ExpressionNode operationExpression() :
{
ExpressionNode left, right = null;
- ArithmeticOperator arithmeticOp;
+ Operator operator;
}
{
( left = value()
- ( arithmeticOp = arithmetic() right = value() { left = ArithmeticNode.resolve(left, arithmeticOp, right); } ) *
+ ( operator = binaryOperator() right = value() { left = OperationNode.resolve(left, operator, right); } ) *
)
{ return left; }
}
-ArithmeticOperator arithmetic() : { }
+Operator binaryOperator() : { }
{
- ( <ADD> { return ArithmeticOperator.PLUS; } |
- <SUB> { return ArithmeticOperator.MINUS; } |
- <DIV> { return ArithmeticOperator.DIVIDE; } |
- <MUL> { return ArithmeticOperator.MULTIPLY; } |
- <MOD> { return ArithmeticOperator.MODULO; } |
- <AND> { return ArithmeticOperator.AND; } |
- <OR> { return ArithmeticOperator.OR; } |
- <POWOP> { return ArithmeticOperator.POWER; } )
+ (
+ <OR> { return Operator.or; } |
+ <AND> { return Operator.and; } |
+ <GREATEREQUAL> { return Operator.largerOrEqual; } |
+ <GREATER> { return Operator.larger; } |
+ <LESSEQUAL> { return Operator.smallerOrEqual; } |
+ <LESS> { return Operator.smaller; } |
+ <APPROX> { return Operator.approxEqual; } |
+ <NOTEQUAL> { return Operator.notEqual; } |
+ <EQUAL> { return Operator.equal; } |
+ <ADD> { return Operator.plus; } |
+ <SUB> { return Operator.minus; } |
+ <DIV> { return Operator.divide; } |
+ <MUL> { return Operator.multiply; } |
+ <MOD> { return Operator.modulo; } |
+ <POWOP> { return Operator.power; }
+ )
{ return null; }
}
-TruthOperator comparator() : { }
-{
- ( <LE> { return TruthOperator.SMALLEREQUAL; } |
- <LT> { return TruthOperator.SMALLER; } |
- <EQ> { return TruthOperator.EQUAL; } |
- <NQ> { return TruthOperator.NOTEQUAL; } |
- <AQ> { return TruthOperator.APPROX_EQUAL; } |
- <GE> { return TruthOperator.LARGEREQUAL; } |
- <GT> { return TruthOperator.LARGER; } )
- { return null; }
-}
-
ExpressionNode value() :
{
ExpressionNode value;
@@ -665,7 +660,7 @@ TensorType.Value optionalTensorValueTypeParameter() :
String valueType = "double";
}
{
- ( <LT> valueType = identifier() <GT> )?
+ ( <LESS> valueType = identifier() <GREATER> )?
{ return TensorType.Value.fromId(valueType); }
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index 1ab9ee11252..f18240c3222 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -2,8 +2,8 @@
package com.yahoo.searchlib.rankingexpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -66,7 +66,7 @@ public class RankingExpressionTestCase {
public void testProgrammaticBuilding() throws ParseException {
ReferenceNode input = new ReferenceNode("input");
ReferenceNode constant = new ReferenceNode("constant");
- ArithmeticNode product = new ArithmeticNode(input, ArithmeticOperator.MULTIPLY, constant);
+ OperationNode product = new OperationNode(input, Operator.multiply, constant);
Reduce<Reference> sum = new Reduce<>(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum);
RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index ad50a423eb9..acab7bb38c6 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -5,8 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
@@ -88,6 +88,7 @@ public class EvaluationTestCase {
tester.assertEvaluates(10.0, "3 ^ 2 + 1");
tester.assertEvaluates(18.0, "2 * 3 ^ 2");
tester.assertEvaluates(-4, "1 - 2 - 3"); // Means 1 + -2 + -3
+ tester.assertEvaluates(Math.pow(4, 9), "4^3^2"); // Right precedence, by 51% majority
// Conditionals
tester.assertEvaluates(2 * (3 * 4 + 3) * (4 * 5 - 4 * 200) / 10, "2*(3*4+3)*(4*5-4*200)/10");
@@ -185,6 +186,11 @@ public class EvaluationTestCase {
"map(tensor0, f(x) (log10(x)))", "{ {d1:0}:10, {d1:1}:100, {d1:2}:1000 }");
tester.assertEvaluates("{ {d1:0}:4, {d1:1}:9, {d1:2 }:16 }",
"map(tensor0, f(x) (x * x))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ // -- tensor map shorthands
+ tester.assertEvaluates("{ {d1:0}:0, {d1:1}:1, {d1:2 }:0 }",
+ "tensor0 == 3", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ tester.assertEvaluates("{ {d1:0}:0, {d1:1}:1, {d1:2 }:0 }",
+ "3 == tensor0", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
// -- tensor map composites
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:2, {d1:2 }:3 }",
"log10(tensor0)", "{ {d1:0}:10, {d1:1}:100, {d1:2}:1000 }");
@@ -198,9 +204,9 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
"tensor0 % 2 == map(tensor0, f(x) (x % 2))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
- "tensor0 || 1 == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ "(tensor0 || 1) == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
- "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ "(tensor0 && 1) == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
"!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }");
@@ -779,8 +785,8 @@ public class EvaluationTestCase {
@Test
public void testProgrammaticBuildingAndPrecedence() {
- RankingExpression standardPrecedence = new RankingExpression(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, new ArithmeticNode(constant(3), ArithmeticOperator.MULTIPLY, constant(4))));
- RankingExpression oppositePrecedence = new RankingExpression(new ArithmeticNode(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, constant(3)), ArithmeticOperator.MULTIPLY, constant(4)));
+ RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.plus, new OperationNode(constant(3), Operator.multiply, constant(4))));
+ RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.plus, constant(3)), Operator.multiply, constant(4)));
assertEquals(14.0, standardPrecedence.evaluate(null).asDouble(), tolerance);
assertEquals(20.0, oppositePrecedence.evaluate(null).asDouble(), tolerance);
assertEquals("2.0 + 3.0 * 4.0", standardPrecedence.toString());