diff options
Diffstat (limited to 'config-model')
5 files changed, 62 insertions, 37 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java index 8fa4b469590..ad050d4ca63 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java @@ -2,8 +2,8 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -35,20 +35,20 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor if (node instanceof CompositeNode composite) node = transformChildren(composite, context); - if (node instanceof ArithmeticNode arithmetic) + if (node instanceof OperationNode arithmetic) node = transformBooleanArithmetics(arithmetic); return node; } - private ExpressionNode transformBooleanArithmetics(ArithmeticNode node) { + private ExpressionNode transformBooleanArithmetics(OperationNode node) { Iterator<ExpressionNode> child = node.children().iterator(); // Transform in precedence order: Deque<ChildNode> stack = new ArrayDeque<>(); stack.push(new ChildNode(null, child.next())); - for (Iterator<ArithmeticOperator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) { - ArithmeticOperator op = it.next(); + for (Iterator<Operator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) { + Operator op = it.next(); if ( ! stack.isEmpty()) { while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) { popStack(stack); @@ -66,9 +66,9 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor ChildNode lhs = stack.peek(); ExpressionNode combination; - if (rhs.op == ArithmeticOperator.AND) + if (rhs.op == Operator.and) combination = andByIfNode(lhs.child, rhs.child); - else if (rhs.op == ArithmeticOperator.OR) + else if (rhs.op == Operator.or) combination = orByIfNode(lhs.child, rhs.child); else { combination = resolve(lhs, rhs); @@ -77,28 +77,28 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor lhs.child = combination; } - private static ArithmeticNode resolve(ChildNode left, ChildNode right) { - if ( ! (left.child instanceof ArithmeticNode) && ! (right.child instanceof ArithmeticNode)) - return new ArithmeticNode(left.child, right.op, right.child); + private static OperationNode resolve(ChildNode left, ChildNode right) { + if (! (left.child instanceof OperationNode) && ! (right.child instanceof OperationNode)) + return new OperationNode(left.child, right.op, right.child); // Collapse inserted ArithmeticNodes - List<ArithmeticOperator> joinedOps = new ArrayList<>(); + List<Operator> joinedOps = new ArrayList<>(); joinOps(left, joinedOps); joinedOps.add(right.op); joinOps(right, joinedOps); List<ExpressionNode> joinedChildren = new ArrayList<>(); joinChildren(left, joinedChildren); joinChildren(right, joinedChildren); - return new ArithmeticNode(joinedChildren, joinedOps); + return new OperationNode(joinedChildren, joinedOps); } - private static void joinOps(ChildNode node, List<ArithmeticOperator> joinedOps) { - if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode) - joinedOps.addAll(arithmeticNode.operators()); + private static void joinOps(ChildNode node, List<Operator> joinedOps) { + if (node.artificial && node.child instanceof OperationNode operationNode) + joinedOps.addAll(operationNode.operators()); } private static void joinChildren(ChildNode node, List<ExpressionNode> joinedChildren) { - if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode) - joinedChildren.addAll(arithmeticNode.children()); + if (node.artificial && node.child instanceof OperationNode operationNode) + joinedChildren.addAll(operationNode.children()); else joinedChildren.add(node.child); } @@ -115,11 +115,11 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor /** A child with the operator to be applied to it when combining it with the previous child. */ private static class ChildNode { - final ArithmeticOperator op; + final Operator op; ExpressionNode child; boolean artificial; - public ChildNode(ArithmeticOperator op, ExpressionNode child) { + public ChildNode(Operator op, ExpressionNode child) { this.op = op; this.child = child; } diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java index 30d9a3766b3..cf354a05a93 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java @@ -3,9 +3,8 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; -import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; @@ -13,7 +12,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Generate; @@ -141,10 +139,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence); ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); ExpressionNode expr = new IfNode( - new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr), + new OperationNode(new ReferenceNode("d1"), Operator.smaller, queryLengthExpr), ZERO, new IfNode( - new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr), + new OperationNode(new ReferenceNode("d1"), Operator.smaller, restLengthExpr), ONE, ZERO ) @@ -176,7 +174,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform List<ExpressionNode> tokenSequence = createTokenSequence(feature); ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); - ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr); ExpressionNode expr = new IfNode(comparison, ONE, ZERO); return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); } @@ -256,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform */ private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) { ExpressionNode lengthExpr = createLengthExpr(iter, sequence); - ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr); ExpressionNode trueExpr = sequence.get(iter); if (sequence.get(iter) instanceof ReferenceNode) { @@ -280,7 +278,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform */ private ExpressionNode createLengthExpr(int iter, List<ExpressionNode> sequence) { List<ExpressionNode> factors = new ArrayList<>(); - List<ArithmeticOperator> operators = new ArrayList<>(); + List<Operator> operators = new ArrayList<>(); for (int i = 0; i < iter + 1; ++i) { if (sequence.get(i) instanceof ConstantNode) { factors.add(ONE); @@ -288,10 +286,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i)))); } if (i >= 1) { - operators.add(ArithmeticOperator.PLUS); + operators.add(Operator.plus); } } - return new ArithmeticNode(factors, operators); + return new OperationNode(factors, operators); } /** @@ -301,7 +299,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform ExpressionNode expr; if (iter >= 1) { ExpressionNode lengthExpr = new EmbracedNode(createLengthExpr(iter - 1, sequence)); - expr = new EmbracedNode(new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.MINUS, lengthExpr)); + expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.minus, lengthExpr)); } else { expr = new ReferenceNode("d1"); } diff --git a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java index 5eecee516ec..13d21884c7d 100644 --- a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java @@ -68,7 +68,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase { Schema s = builder.getSchema(); RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels()); - assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))", + assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + (if (7.0 < attribute(a), 1, 2) == 0)))", parent.getFirstPhaseRanking().getRoot().toString()); RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("7.0 * (9 + attribute(a))", @@ -76,6 +76,33 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase { } @Test + void testInlinedComparison() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry); + builder.addSchema("search test {\n" + + " document test { \n" + + " }\n" + + " \n" + + " rank-profile parent {\n" + + "function foo() {\n" + + " expression: 3 * bar\n" + + "}\n" + + "\n" + + "function inline bar() {\n" + + " expression: query(test) > 2.0\n" + + "}\n" + + "}\n" + + "}\n"); + builder.build(true); + Schema s = builder.getSchema(); + + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels()); + assertEquals("3 * (query(test) > 2.0)", + parent.getFunctions().get("foo").function().getBody().getRoot().toString()); + + } + + @Test void testConstants() throws ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry); diff --git a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java index 65d12f1cdcf..d692b69d3c8 100644 --- a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java @@ -4,7 +4,7 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import org.junit.jupiter.api.Test; @@ -43,8 +43,8 @@ public class BooleanExpressionTransformerTestCase { var expr = new BooleanExpressionTransformer() .transform(new RankingExpression("a + b + c * d + e + f"), new TransformContext(Map.of(), new MapTypeContext())); - assertTrue(expr.getRoot() instanceof ArithmeticNode); - ArithmeticNode root = (ArithmeticNode) expr.getRoot(); + assertTrue(expr.getRoot() instanceof OperationNode); + OperationNode root = (OperationNode) expr.getRoot(); assertEquals(5, root.operators().size()); assertEquals(6, root.children().size()); } diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java index 22681858fc3..e9b674a8c87 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java @@ -171,7 +171,7 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name()); - assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 0.0, if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value()); + assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value()); assertEquals("test_unbound_model", config.rankprofile(8).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); |