aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionTypeResolverTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionTypeResolverTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionTypeResolverTestCase.java521
1 files changed, 521 insertions, 0 deletions
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionTypeResolverTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionTypeResolverTestCase.java
new file mode 100644
index 00000000000..4b6a22fc81a
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionTypeResolverTestCase.java
@@ -0,0 +1,521 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.processing;
+
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.search.query.profile.types.FieldDescription;
+import com.yahoo.search.query.profile.types.QueryProfileType;
+import com.yahoo.search.query.profile.types.TensorFieldType;
+import com.yahoo.schema.RankProfile;
+import com.yahoo.schema.RankProfileRegistry;
+import com.yahoo.schema.ApplicationBuilder;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.yolean.Exceptions;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.logging.Level;
+import java.util.stream.Collectors;
+
+import static com.yahoo.config.model.test.TestUtil.joinLines;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+/**
+ * @author bratseth
+ */
+public class RankingExpressionTypeResolverTestCase {
+
+ @Test
+ public void tensorFirstPhaseMustProduceDouble() throws Exception {
+ try {
+ ApplicationBuilder builder = new ApplicationBuilder();
+ builder.addSchema(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[10],y[3]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: attribute(a)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In schema 'test', rank profile 'my_rank_profile': The first-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[10],y[3])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+
+ @Test
+ public void tensorFirstPhaseFromConstantMustProduceDouble() throws Exception {
+ try {
+ ApplicationBuilder builder = new ApplicationBuilder();
+ builder.addSchema(joinLines(
+ "schema test {",
+ " document test { ",
+ " field a type tensor(d0[3]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " function my_func() {",
+ " expression: x_tensor*2.0",
+ " }",
+ " function inline other_func() {",
+ " expression: z_tensor+3.0",
+ " }",
+ " first-phase {",
+ " expression: reduce(attribute(a),sum,d0)+y_tensor+my_func+other_func",
+ " }",
+ " constants {",
+ " x_tensor {", // legacy form
+ " type: tensor(x{})",
+ " value: { {x:bar}:17 }",
+ " }",
+ " y_tensor tensor(y{}):{{y:foo}:42 }",
+ " z_tensor tensor(z{}):{qux:666}",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In schema 'test', rank profile 'my_rank_profile': The first-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x{},y{},z{})",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+
+
+ @Test
+ public void tensorSecondPhaseMustProduceDouble() throws Exception {
+ try {
+ ApplicationBuilder builder = new ApplicationBuilder();
+ builder.addSchema(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[10],y[3]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(attribute(a))",
+ " }",
+ " second-phase {",
+ " expression: attribute(a)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In schema 'test', rank profile 'my_rank_profile': The second-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[10],y[3])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void tensorConditionsMustHaveTypeCompatibleBranches() throws Exception {
+ try {
+ ApplicationBuilder schemaBuilder = new ApplicationBuilder();
+ schemaBuilder.addSchema(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[10],y[5]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(if(1>0, attribute(a), attribute(b)))",
+ " }",
+ " }",
+ "}"
+ ));
+ schemaBuilder.build(true);
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In schema 'test', rank profile 'my_rank_profile': The first-phase expression is invalid: An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[10],y[5]) while the 'false' type is tensor(z[10])" +
+ "\n'true' branch: attribute(a)" +
+ "\n'false' branch: attribute(b)",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void testFunctionInvocationTypes() throws Exception {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry);
+ builder.addSchema(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[10],y[3]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " function macro1(attribute_to_use) {",
+ " expression: attribute(attribute_to_use)",
+ " }",
+ " summary-features {",
+ " macro1(a)",
+ " macro1(b)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ RankProfile profile =
+ builder.getRankProfileRegistry().get(builder.getSchema(), "my_rank_profile");
+ assertEquals(TensorType.fromSpec("tensor(x[10],y[3])"),
+ summaryFeatures(profile).get("macro1(a)").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ assertEquals(TensorType.fromSpec("tensor(z[10])"),
+ summaryFeatures(profile).get("macro1(b)").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ }
+
+ @Test
+ public void testTensorFunctionInvocationTypes_Nested() throws Exception {
+ ApplicationBuilder builder = new ApplicationBuilder();
+ builder.addSchema(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[10],y[1]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " function return_a() {",
+ " expression: return_first(attribute(a), attribute(b))",
+ " }",
+ " function return_b() {",
+ " expression: return_second(attribute(a), attribute(b))",
+ " }",
+ " function return_first(e1, e2) {",
+ " expression: e1",
+ " }",
+ " function return_second(e1, e2) {",
+ " expression: return_first(e2, e1)",
+ " }",
+ " summary-features {",
+ " return_a",
+ " return_b",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ RankProfile profile =
+ builder.getRankProfileRegistry().get(builder.getSchema(), "my_rank_profile");
+ assertEquals(TensorType.fromSpec("tensor(x[10],y[1])"),
+ summaryFeatures(profile).get("return_a").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ assertEquals(TensorType.fromSpec("tensor(z[10])"),
+ summaryFeatures(profile).get("return_b").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ }
+
+ @Test
+ public void testAttributeInvocationViaBoundIdentifier() throws Exception {
+ ApplicationBuilder builder = new ApplicationBuilder();
+ builder.addSchema(joinLines(
+ "search newsarticle {",
+ " document newsarticle {",
+ " field title type string {",
+ " indexing {",
+ " input title | index",
+ " }",
+ " weight: 30",
+ " }",
+ " field usstaticrank type int {",
+ " indexing: summary | attribute",
+ " }",
+ " field eustaticrank type int {",
+ " indexing: summary | attribute",
+ " }",
+ " }",
+ " rank-profile default {",
+ " macro newsboost() { ",
+ " expression: 200 * matches(title)",
+ " }",
+ " macro commonboost(mystaticrank) { ",
+ " expression: attribute(mystaticrank) + newsboost",
+ " }",
+ " macro commonfirstphase(mystaticrank) { ",
+ " expression: nativeFieldMatch(title) + commonboost(mystaticrank) ",
+ " }",
+ " first-phase { expression: commonfirstphase(usstaticrank) }",
+ " }",
+ " rank-profile eurank inherits default {",
+ " first-phase { expression: commonfirstphase(eustaticrank) }",
+ " }",
+ "}"));
+ builder.build(true);
+ RankProfile profile = builder.getRankProfileRegistry().get(builder.getSchema(), "eurank");
+ }
+
+ @Test
+ public void testTensorFunctionInvocationTypes_NestedSameName() throws Exception {
+ ApplicationBuilder builder = new ApplicationBuilder();
+ builder.addSchema(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[10],y[1]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " function return_a() {",
+ " expression: return_first(attribute(a), attribute(b))",
+ " }",
+ " function return_b() {",
+ " expression: return_second(attribute(a), attribute(b))",
+ " }",
+ " function return_first(e1, e2) {",
+ " expression: just_return(e1)",
+ " }",
+ " function just_return(e1) {",
+ " expression: e1",
+ " }",
+ " function return_second(e1, e2) {",
+ " expression: return_first(e2+0, e1)",
+ " }",
+ " summary-features {",
+ " return_a",
+ " return_b",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ RankProfile profile =
+ builder.getRankProfileRegistry().get(builder.getSchema(), "my_rank_profile");
+ assertEquals(TensorType.fromSpec("tensor(x[10],y[1])"),
+ summaryFeatures(profile).get("return_a").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ assertEquals(TensorType.fromSpec("tensor(z[10])"),
+ summaryFeatures(profile).get("return_b").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ }
+
+ @Test
+ public void testTensorFunctionInvocationTypes_viaFuncWithExpr() throws Exception {
+ ApplicationBuilder builder = new ApplicationBuilder();
+ builder.addSchema(joinLines(
+ "search test {",
+ " document test {",
+ " field t1 type tensor<float>(y{}) { indexing: attribute | summary }",
+ " field t2 type tensor<float>(x{}) { indexing: attribute | summary }",
+ " }",
+ " rank-profile test {",
+ " function my_func(t) { expression: sum(t, x) + 1 }",
+ " function test_func_via_func_with_expr() { expression: call_func_with_expr( attribute(t1), attribute(t2) ) }",
+ " function call_func_with_expr(a, b) { expression: my_func( a * b ) }",
+ " summary-features { test_func_via_func_with_expr }",
+ " }",
+ "}"));
+ builder.build(true);
+ RankProfile profile = builder.getRankProfileRegistry().get(builder.getSchema(), "test");
+ assertEquals(TensorType.fromSpec("tensor<float>(y{})"),
+ summaryFeatures(profile).get("test_func_via_func_with_expr").type(profile.typeContext(builder.getQueryProfileRegistry())));
+ }
+
+ @Test
+ public void importedFieldsAreAvailable() throws Exception {
+ ApplicationBuilder builder = new ApplicationBuilder();
+ builder.addSchema(joinLines(
+ "search parent {",
+ " document parent {",
+ " field a type tensor(x[5],y[1000]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.addSchema(joinLines(
+ "search child {",
+ " document child { ",
+ " field ref type reference<parent> {",
+ "indexing: attribute | summary",
+ " }",
+ " }",
+ " import field ref.a as imported_a {}",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(attribute(imported_a))",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ }
+
+ @Test
+ public void undeclaredQueryFeaturesAreAccepted() throws Exception {
+ InspectableDeployLogger logger = new InspectableDeployLogger();
+ ApplicationBuilder builder = new ApplicationBuilder(logger);
+ builder.addSchema(joinLines(
+ "search test {",
+ " document test { ",
+ " field anyfield type double {" +
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: query(foo) + f() + sum(attribute(anyfield))",
+ " }",
+ " function f() {",
+ " expression: query(bar) + query(baz)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ String message = logger.findMessage("The following query features");
+ assertNull(message);
+ }
+
+ @Test
+ public void undeclaredQueryFeaturesAreNotAcceptedWhenStrict() throws Exception {
+ try {
+ InspectableDeployLogger logger = new InspectableDeployLogger();
+ ApplicationBuilder builder = new ApplicationBuilder(logger);
+ builder.addSchema(joinLines(
+ "search test {",
+ " document test { ",
+ " field anyfield type double {" +
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " strict: true" +
+ " first-phase {",
+ " expression: query(foo) + f() + sum(attribute(anyfield))",
+ " }",
+ " function f() {",
+ " expression: query(bar) + query(baz)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("In schema 'test', rank profile 'my_rank_profile': rank profile 'my_rank_profile' is strict but is missing a query profile type declaration of features [query(bar), query(baz), query(foo)]",
+ Exceptions.toMessageString(e));
+ }
+ }
+
+ @Test
+ public void undeclaredQueryFeaturesAreAcceptedWithWarningWhenUsingTensors() throws Exception {
+ InspectableDeployLogger logger = new InspectableDeployLogger();
+ ApplicationBuilder builder = new ApplicationBuilder(logger);
+ builder.addSchema(joinLines(
+ "search test {",
+ " document test { ",
+ " field anyfield type tensor(d[2]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: query(foo) + f() + sum(attribute(anyfield))",
+ " }",
+ " function f() {",
+ " expression: query(bar) + query(baz)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ String message = logger.findMessage("The following query features");
+ assertNotNull(message);
+ assertEquals("WARNING: The following query features used in rank profile 'my_rank_profile' are not declared in query profile types and " +
+ "will be interpreted as scalars, not tensors: [query(bar), query(baz), query(foo)]",
+ message);
+ }
+
+ @Test
+ public void noWarningWhenUsingTensorsWhenQueryFeaturesAreDeclared() throws Exception {
+ InspectableDeployLogger logger = new InspectableDeployLogger();
+ ApplicationBuilder builder = new ApplicationBuilder(logger);
+ QueryProfileType myType = new QueryProfileType("mytype");
+ myType.addField(new FieldDescription("rank.feature.query(foo)",
+ new TensorFieldType(TensorType.fromSpec("tensor(d[2])"))),
+ builder.getQueryProfileRegistry().getTypeRegistry());
+ myType.addField(new FieldDescription("rank.feature.query(bar)",
+ new TensorFieldType(TensorType.fromSpec("tensor(d[2])"))),
+ builder.getQueryProfileRegistry().getTypeRegistry());
+ myType.addField(new FieldDescription("rank.feature.query(baz)",
+ new TensorFieldType(TensorType.fromSpec("tensor(d[2])"))),
+ builder.getQueryProfileRegistry().getTypeRegistry());
+ builder.getQueryProfileRegistry().getTypeRegistry().register(myType);
+ builder.addSchema(joinLines(
+ "search test {",
+ " document test { ",
+ " field anyfield type tensor(d[2]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(query(foo) + f() + sum(attribute(anyfield)))",
+ " }",
+ " function f() {",
+ " expression: query(bar) + query(baz)",
+ " }",
+ " }",
+ "}"
+ ));
+ builder.build(true);
+ String message = logger.findMessage("The following query features");
+ assertNull(message);
+ }
+
+ private Map<String, ReferenceNode> summaryFeatures(RankProfile profile) {
+ return profile.getSummaryFeatures().stream().collect(Collectors.toMap(f -> f.toString(), f -> f));
+ }
+
+ private static class InspectableDeployLogger implements DeployLogger {
+
+ private List<String> messages = new ArrayList<>();
+
+ @Override
+ public void log(Level level, String message) {
+ messages.add(level + ": " + message);
+ }
+
+ /** Returns the first message containing the given string, or null if none */
+ public String findMessage(String substring) {
+ return messages.stream().filter(message -> message.contains(substring)).findFirst().orElse(null);
+ }
+
+ }
+
+}