aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java')
-rw-r--r--config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java29
-rw-r--r--config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java6
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java2
3 files changed, 32 insertions, 5 deletions
diff --git a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
index 5eecee516ec..13d21884c7d 100644
--- a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
@@ -68,7 +68,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase {
Schema s = builder.getSchema();
RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels());
- assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))",
+ assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + (if (7.0 < attribute(a), 1, 2) == 0)))",
parent.getFirstPhaseRanking().getRoot().toString());
RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels());
assertEquals("7.0 * (9 + attribute(a))",
@@ -76,6 +76,33 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase {
}
@Test
+ void testInlinedComparison() throws ParseException {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry);
+ builder.addSchema("search test {\n" +
+ " document test { \n" +
+ " }\n" +
+ " \n" +
+ " rank-profile parent {\n" +
+ "function foo() {\n" +
+ " expression: 3 * bar\n" +
+ "}\n" +
+ "\n" +
+ "function inline bar() {\n" +
+ " expression: query(test) > 2.0\n" +
+ "}\n" +
+ "}\n" +
+ "}\n");
+ builder.build(true);
+ Schema s = builder.getSchema();
+
+ RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels());
+ assertEquals("3 * (query(test) > 2.0)",
+ parent.getFunctions().get("foo").function().getBody().getRoot().toString());
+
+ }
+
+ @Test
void testConstants() throws ParseException {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry);
diff --git a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
index 65d12f1cdcf..d692b69d3c8 100644
--- a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
@@ -4,7 +4,7 @@ package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import org.junit.jupiter.api.Test;
@@ -43,8 +43,8 @@ public class BooleanExpressionTransformerTestCase {
var expr = new BooleanExpressionTransformer()
.transform(new RankingExpression("a + b + c * d + e + f"),
new TransformContext(Map.of(), new MapTypeContext()));
- assertTrue(expr.getRoot() instanceof ArithmeticNode);
- ArithmeticNode root = (ArithmeticNode) expr.getRoot();
+ assertTrue(expr.getRoot() instanceof OperationNode);
+ OperationNode root = (OperationNode) expr.getRoot();
assertEquals(5, root.operators().size());
assertEquals(6, root.children().size());
}
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
index 22681858fc3..e9b674a8c87 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -171,7 +171,7 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name());
- assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 0.0, if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value());
+ assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value());
assertEquals("test_unbound_model", config.rankprofile(8).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name());