aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-09-28 22:54:13 +0200
committerGitHub <noreply@github.com>2022-09-28 22:54:13 +0200
commit12992ecdc0e77968eb5c5544f2ae7d855e443162 (patch)
treeac8cec3ae02f27ae638876940399f490b4ac4ab1 /config-model
parentd50f7bd9c99ed9d8edeabb71825f3966f9cd6bd9 (diff)
parentfb0074925e9e8358d38145dc5753de1c935f737d (diff)
Merge pull request #24251 from vespa-engine/bratseth/operatorsv8.61.17
Bratseth/operators
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java40
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java22
-rw-r--r--config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java29
-rw-r--r--config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java6
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java2
5 files changed, 62 insertions, 37 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
index 8fa4b469590..ad050d4ca63 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
@@ -2,8 +2,8 @@
package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -35,20 +35,20 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
if (node instanceof CompositeNode composite)
node = transformChildren(composite, context);
- if (node instanceof ArithmeticNode arithmetic)
+ if (node instanceof OperationNode arithmetic)
node = transformBooleanArithmetics(arithmetic);
return node;
}
- private ExpressionNode transformBooleanArithmetics(ArithmeticNode node) {
+ private ExpressionNode transformBooleanArithmetics(OperationNode node) {
Iterator<ExpressionNode> child = node.children().iterator();
// Transform in precedence order:
Deque<ChildNode> stack = new ArrayDeque<>();
stack.push(new ChildNode(null, child.next()));
- for (Iterator<ArithmeticOperator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) {
- ArithmeticOperator op = it.next();
+ for (Iterator<Operator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) {
+ Operator op = it.next();
if ( ! stack.isEmpty()) {
while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) {
popStack(stack);
@@ -66,9 +66,9 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
ChildNode lhs = stack.peek();
ExpressionNode combination;
- if (rhs.op == ArithmeticOperator.AND)
+ if (rhs.op == Operator.and)
combination = andByIfNode(lhs.child, rhs.child);
- else if (rhs.op == ArithmeticOperator.OR)
+ else if (rhs.op == Operator.or)
combination = orByIfNode(lhs.child, rhs.child);
else {
combination = resolve(lhs, rhs);
@@ -77,28 +77,28 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
lhs.child = combination;
}
- private static ArithmeticNode resolve(ChildNode left, ChildNode right) {
- if ( ! (left.child instanceof ArithmeticNode) && ! (right.child instanceof ArithmeticNode))
- return new ArithmeticNode(left.child, right.op, right.child);
+ private static OperationNode resolve(ChildNode left, ChildNode right) {
+ if (! (left.child instanceof OperationNode) && ! (right.child instanceof OperationNode))
+ return new OperationNode(left.child, right.op, right.child);
// Collapse inserted ArithmeticNodes
- List<ArithmeticOperator> joinedOps = new ArrayList<>();
+ List<Operator> joinedOps = new ArrayList<>();
joinOps(left, joinedOps);
joinedOps.add(right.op);
joinOps(right, joinedOps);
List<ExpressionNode> joinedChildren = new ArrayList<>();
joinChildren(left, joinedChildren);
joinChildren(right, joinedChildren);
- return new ArithmeticNode(joinedChildren, joinedOps);
+ return new OperationNode(joinedChildren, joinedOps);
}
- private static void joinOps(ChildNode node, List<ArithmeticOperator> joinedOps) {
- if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode)
- joinedOps.addAll(arithmeticNode.operators());
+ private static void joinOps(ChildNode node, List<Operator> joinedOps) {
+ if (node.artificial && node.child instanceof OperationNode operationNode)
+ joinedOps.addAll(operationNode.operators());
}
private static void joinChildren(ChildNode node, List<ExpressionNode> joinedChildren) {
- if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode)
- joinedChildren.addAll(arithmeticNode.children());
+ if (node.artificial && node.child instanceof OperationNode operationNode)
+ joinedChildren.addAll(operationNode.children());
else
joinedChildren.add(node.child);
}
@@ -115,11 +115,11 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
/** A child with the operator to be applied to it when combining it with the previous child. */
private static class ChildNode {
- final ArithmeticOperator op;
+ final Operator op;
ExpressionNode child;
boolean artificial;
- public ChildNode(ArithmeticOperator op, ExpressionNode child) {
+ public ChildNode(Operator op, ExpressionNode child) {
this.op = op;
this.child = child;
}
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
index 30d9a3766b3..cf354a05a93 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
@@ -3,9 +3,8 @@ package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
-import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
@@ -13,7 +12,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
@@ -141,10 +139,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
ExpressionNode expr = new IfNode(
- new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr),
+ new OperationNode(new ReferenceNode("d1"), Operator.smaller, queryLengthExpr),
ZERO,
new IfNode(
- new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr),
+ new OperationNode(new ReferenceNode("d1"), Operator.smaller, restLengthExpr),
ONE,
ZERO
)
@@ -176,7 +174,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
List<ExpressionNode> tokenSequence = createTokenSequence(feature);
ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
- ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
+ OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr);
ExpressionNode expr = new IfNode(comparison, ONE, ZERO);
return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
}
@@ -256,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
*/
private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) {
ExpressionNode lengthExpr = createLengthExpr(iter, sequence);
- ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
+ OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr);
ExpressionNode trueExpr = sequence.get(iter);
if (sequence.get(iter) instanceof ReferenceNode) {
@@ -280,7 +278,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
*/
private ExpressionNode createLengthExpr(int iter, List<ExpressionNode> sequence) {
List<ExpressionNode> factors = new ArrayList<>();
- List<ArithmeticOperator> operators = new ArrayList<>();
+ List<Operator> operators = new ArrayList<>();
for (int i = 0; i < iter + 1; ++i) {
if (sequence.get(i) instanceof ConstantNode) {
factors.add(ONE);
@@ -288,10 +286,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i))));
}
if (i >= 1) {
- operators.add(ArithmeticOperator.PLUS);
+ operators.add(Operator.plus);
}
}
- return new ArithmeticNode(factors, operators);
+ return new OperationNode(factors, operators);
}
/**
@@ -301,7 +299,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
ExpressionNode expr;
if (iter >= 1) {
ExpressionNode lengthExpr = new EmbracedNode(createLengthExpr(iter - 1, sequence));
- expr = new EmbracedNode(new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.MINUS, lengthExpr));
+ expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.minus, lengthExpr));
} else {
expr = new ReferenceNode("d1");
}
diff --git a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
index 5eecee516ec..13d21884c7d 100644
--- a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
@@ -68,7 +68,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase {
Schema s = builder.getSchema();
RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels());
- assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))",
+ assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + (if (7.0 < attribute(a), 1, 2) == 0)))",
parent.getFirstPhaseRanking().getRoot().toString());
RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels());
assertEquals("7.0 * (9 + attribute(a))",
@@ -76,6 +76,33 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase {
}
@Test
+ void testInlinedComparison() throws ParseException {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry);
+ builder.addSchema("search test {\n" +
+ " document test { \n" +
+ " }\n" +
+ " \n" +
+ " rank-profile parent {\n" +
+ "function foo() {\n" +
+ " expression: 3 * bar\n" +
+ "}\n" +
+ "\n" +
+ "function inline bar() {\n" +
+ " expression: query(test) > 2.0\n" +
+ "}\n" +
+ "}\n" +
+ "}\n");
+ builder.build(true);
+ Schema s = builder.getSchema();
+
+ RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels());
+ assertEquals("3 * (query(test) > 2.0)",
+ parent.getFunctions().get("foo").function().getBody().getRoot().toString());
+
+ }
+
+ @Test
void testConstants() throws ParseException {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry);
diff --git a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
index 65d12f1cdcf..d692b69d3c8 100644
--- a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
@@ -4,7 +4,7 @@ package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import org.junit.jupiter.api.Test;
@@ -43,8 +43,8 @@ public class BooleanExpressionTransformerTestCase {
var expr = new BooleanExpressionTransformer()
.transform(new RankingExpression("a + b + c * d + e + f"),
new TransformContext(Map.of(), new MapTypeContext()));
- assertTrue(expr.getRoot() instanceof ArithmeticNode);
- ArithmeticNode root = (ArithmeticNode) expr.getRoot();
+ assertTrue(expr.getRoot() instanceof OperationNode);
+ OperationNode root = (OperationNode) expr.getRoot();
assertEquals(5, root.operators().size());
assertEquals(6, root.children().size());
}
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
index 22681858fc3..e9b674a8c87 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -171,7 +171,7 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name());
- assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 0.0, if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value());
+ assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value());
assertEquals("test_unbound_model", config.rankprofile(8).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name());