summary refs log tree commit diff stats
path: root/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@gmail.com> 2022-04-21 15:56:41 +0200
committer Jon Bratseth <bratseth@gmail.com> 2022-04-21 15:56:41 +0200
commit 4f2994d9301034e943620e106540fa80a6c3f01e (patch)
tree 6b0a76afd976bbd314f1ba5c3764af12f88ef724 /container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
parent 50c7dfee0a9f32debb34d06191808cbd6ae67e4c (diff)
Resolve rank profile inputs
Diffstat (limited to 'container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java')
-rw-r--r-- container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java 300
1 files changed, 300 insertions, 0 deletions
diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
new file mode 100644
index 00000000000..1b10e4cd0ba
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
@@ -0,0 +1,300 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.query;
+
+import com.yahoo.container.jdisc.HttpRequest;
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.search.Query;
+import com.yahoo.search.config.RankProfile;
+import com.yahoo.search.config.Schema;
+import com.yahoo.search.config.SchemaInfo;
+import com.yahoo.search.query.profile.QueryProfile;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.search.query.profile.compiled.CompiledQueryProfile;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.yolean.Exceptions;
+import org.junit.Test;
+
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.fail;
+
+/**
+ * Tests queries towards rank profiles using input declarations.
+ *
+ * @author bratseth
+ */
+public class RankProfileInputTest {
+
+ @Test
+ public void testTensorRankFeatureInRequest() {
+ String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}";
+
+ {
+ Query query = createTensor1Query(tensorString, "commonProfile", "");
+ assertEquals(0, query.errors().size());
+ assertEquals(Tensor.from(tensorString), query.properties().get("ranking.features.query(myTensor1)"));
+ assertEquals(Tensor.from(tensorString), query.getRanking().getFeatures().getTensor("query(myTensor1)").get());
+ }
+
+ { // Partial resolution is sufficient
+ Query query = createTensor1Query(tensorString, "bOnly", "");
+ assertEquals(0, query.errors().size());
+ assertEquals(Tensor.from(tensorString), query.properties().get("ranking.features.query(myTensor1)"));
+ assertEquals(Tensor.from(tensorString), query.getRanking().getFeatures().getTensor("query(myTensor1)").get());
+ }
+
+ { // Resolution is limited to the correct sources
+ Query query = createTensor1Query(tensorString, "bOnly", "sources=a");
+ assertEquals(0, query.errors().size());
+ assertEquals("Not converted to tensor",
+ tensorString, query.properties().get("ranking.features.query(myTensor1)"));
+ }
+ }
+
+ @Test
+ public void testTensorRankFeatureInRequestInconsistentInput() {
+ String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}";
+ try {
+ createTensor1Query(tensorString, "inconsistent", "");
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("Conflicting input type declarations for 'query(myTensor1)': " +
+ "Declared as tensor(a{},b{}) in rank profile 'inconsistent' in schema 'a', " +
+ "and as tensor(x[10]) in rank profile 'inconsistent' in schema 'b'",
+ Exceptions.toMessageString(e));
+ }
+ }
+
+ @Test
+ public void testTensorRankFeatureWithSourceResolution() {
+ String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}";
+
+ {
+ createTensor1Query(tensorString, "inconsistent", "sources=a");
+ // Success: No exception
+ }
+
+ try {
+ createTensor1Query(tensorString, "inconsistent", "sources=ab");
+ fail("Excpected exception");
+ }
+ catch (IllegalArgumentException e) {
+ // success
+ }
+
+ {
+ createTensor1Query(tensorString, "inconsistent", "sources=a&restrict=a");
+ // Success: No exception
+ }
+ }
+
+    /**
+     * Verifies that a tensor set through the query properties after the query is built
+     * can be read back from the ranking features.
+     */
+    @Test
+    public void testTensorRankFeatureSetProgrammatically() {
+        final String featureName = "query(myTensor1)";
+        final String tensorSpec = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}";
+
+        var request = HttpRequest.createTestRequest("?&ranking=commonProfile",
+                                                    com.yahoo.jdisc.http.HttpRequest.Method.GET);
+        Query programmatic = new Query.Builder()
+                .setSchemaInfo(createSchemaInfo())
+                .setQueryProfile(createQueryProfile()) // Use the instantiation path with query profiles
+                .setRequest(request)
+                .build();
+
+        Tensor assigned = Tensor.from(tensorSpec);
+        programmatic.properties().set("ranking.features." + featureName, assigned);
+
+        Tensor wanted = Tensor.from(tensorSpec);
+        assertEquals(wanted, programmatic.getRanking().getFeatures().getTensor(featureName).get());
+    }
+
+ @Test
+ public void testTensorRankFeatureSetProgrammaticallyWithWrongType() {
+ Query query = new Query.Builder()
+ .setSchemaInfo(createSchemaInfo())
+ .setQueryProfile(createQueryProfile()) // Use the instantiation path with query profiles
+ .setRequest(HttpRequest.createTestRequest("?" +
+ "&ranking=commonProfile",
+ com.yahoo.jdisc.http.HttpRequest.Method.GET))
+ .build();
+
+ String tensorString = "tensor(x[3]):[0.1, 0.2, 0.3]";
+ try {
+ query.getRanking().getFeatures().put("query(myTensor1)",Tensor.from(tensorString));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " +
+ "This input is declared in rank profile 'commonProfile' as tensor(a{},b{})",
+ Exceptions.toMessageString(e));
+ }
+ try {
+ query.properties().set("ranking.features.query(myTensor1)", Tensor.from(tensorString));
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " +
+ "Require a tensor of type tensor(a{},b{})",
+ Exceptions.toMessageString(e));
+ }
+ }
+
+    /**
+     * Verifies that embed(...) expressions given for a tensor input are turned into the
+     * tensor returned by the chosen embedder, for different quoting styles, with and without
+     * an explicit embedder id, and with and without a language parameter. Also verifies the
+     * error given when an unknown embedder id is referenced.
+     */
+    @Test
+    public void testUnembeddedTensorRankFeatureInRequest() {
+        final String text = "text to embed into a tensor";
+        final Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]");
+        final Tensor embedding2 = Tensor.from("tensor<float>(x[5]):[1,2,3,4,0]]");
+
+        // The text argument in the two quoted forms used below
+        final String singleQuoted = "'" + text + "'";
+        final String doubleQuoted = "\"" + text + "\"";
+
+        // A single embedder: the embedder id is optional
+        Map<String, Embedder> single = Map.of(
+                "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1)
+        );
+        assertEmbedQuery("embed(" + text + ")", embedding1, single);
+        assertEmbedQuery("embed(" + singleQuoted + ")", embedding1, single);
+        assertEmbedQuery("embed(" + doubleQuoted + ")", embedding1, single);
+        assertEmbedQuery("embed(emb1, " + singleQuoted + ")", embedding1, single);
+        assertEmbedQuery("embed(emb1, " + doubleQuoted + ")", embedding1, single);
+        assertEmbedQueryFails("embed(emb2, " + doubleQuoted + ")", embedding1, single,
+                              "Can't find embedder 'emb2'. Valid embedders are emb1");
+
+        // Two embedders: selected by id
+        Map<String, Embedder> pair = Map.of(
+                "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1),
+                "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2)
+        );
+        assertEmbedQuery("embed(emb1, " + singleQuoted + ")", embedding1, pair);
+        assertEmbedQuery("embed(emb2, " + singleQuoted + ")", embedding2, pair);
+        assertEmbedQueryFails("embed(emb3, " + doubleQuoted + ")", embedding1, pair,
+                              "Can't find embedder 'emb3'. Valid embedders are emb1,emb2");
+
+        // And with specified language
+        Map<String, Embedder> englishOnly = Map.of(
+                "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1)
+        );
+        assertEmbedQuery("embed(" + text + ")", embedding1, englishOnly, Language.ENGLISH.languageCode());
+
+        Map<String, Embedder> mixedLanguages = Map.of(
+                "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1),
+                "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2)
+        );
+        assertEmbedQuery("embed(emb1, " + singleQuoted + ")", embedding1, mixedLanguages,
+                         Language.ENGLISH.languageCode());
+        assertEmbedQuery("embed(emb2, " + singleQuoted + ")", embedding2, mixedLanguages,
+                         Language.UNKNOWN.languageCode());
+    }
+
+    /**
+     * Builds a query from a GET request which sets ranking.features.query(myTensor1).
+     *
+     * @param tensorString the value of the tensor feature, URL encoded before use
+     * @param profile the value of the ranking parameter, appended as-is
+     * @param additionalParams further request parameters appended as-is after a '&amp;', may be empty
+     * @return the query built from the test schema info, the test query profile and the request
+     */
+    private Query createTensor1Query(String tensorString, String profile, String additionalParams) {
+        StringBuilder uri = new StringBuilder("?");
+        uri.append(urlEncode("ranking.features.query(myTensor1)")).append("=").append(urlEncode(tensorString));
+        uri.append("&ranking=").append(profile);
+        uri.append("&").append(additionalParams);
+
+        var request = HttpRequest.createTestRequest(uri.toString(),
+                                                    com.yahoo.jdisc.http.HttpRequest.Method.GET);
+        return new Query.Builder()
+                .setSchemaInfo(createSchemaInfo())
+                .setQueryProfile(createQueryProfile()) // Use the instantiation path with query profiles
+                .setRequest(request)
+                .build();
+    }
+
+ private String urlEncode(String s) {
+ return URLEncoder.encode(s, StandardCharsets.UTF_8);
+ }
+
+ private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) {
+ assertEmbedQuery(embed, expected, embedders, null);
+ }
+
+ private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) {
+ String languageParam = language == null ? "" : "&language=" + language;
+ String destination = "query(myTensor4)";
+
+ Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest(
+ "?" + urlEncode("ranking.features." + destination) +
+ "=" + urlEncode(embed) +
+ "&ranking=commonProfile" +
+ languageParam,
+ com.yahoo.jdisc.http.HttpRequest.Method.GET))
+ .setSchemaInfo(createSchemaInfo())
+ .setQueryProfile(createQueryProfile())
+ .setEmbedders(embedders)
+ .build();
+ assertEquals(0, query.errors().size());
+ assertEquals(expected, query.properties().get("ranking.features." + destination));
+ assertEquals(expected, query.getRanking().getFeatures().getTensor(destination).get());
+ }
+
+    /**
+     * Asserts that building the embed query throws an IllegalArgumentException and that the
+     * exception, or one of its causes, has exactly the given message.
+     *
+     * @param embed the embed expression to pass for the feature
+     * @param expected passed on to the embed query assertion
+     * @param embedders the embedders made available to the query builder
+     * @param errMsg the message that must be found somewhere in the cause chain
+     */
+    private void assertEmbedQueryFails(String embed, Tensor expected, Map<String, Embedder> embedders, String errMsg) {
+        Throwable thrown = assertThrows(IllegalArgumentException.class,
+                                        () -> assertEmbedQuery(embed, expected, embedders));
+        for (Throwable t = thrown; t != null; t = t.getCause()) {
+            // A throwable may have a null message: skip it instead of failing with a NullPointerException
+            String message = t.getMessage();
+            if (message != null && message.equals(errMsg)) return;
+        }
+        fail("Error '" + errMsg + "' not thrown");
+    }
+
+ private CompiledQueryProfile createQueryProfile() {
+ var registry = new QueryProfileRegistry();
+ registry.register(new QueryProfile("test"));
+ return registry.compile().findQueryProfile("test");
+ }
+
+    /**
+     * Creates the schema info used by these tests:
+     * schema "a" has the profiles "commonProfile" and "inconsistent",
+     * schema "b" has "commonProfile", "inconsistent" (declaring query(myTensor1) with another type
+     * than in schema "a") and "bOnly". Cluster "ab" contains both schemas and cluster "a" only schema "a".
+     */
+    private SchemaInfo createSchemaInfo() {
+        // The same profile instance is added to both schemas
+        RankProfile common = new RankProfile.Builder("commonProfile")
+                .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})"))
+                .addInput("query(myTensor2)", TensorType.fromSpec("tensor(x[2],y[2])"))
+                .addInput("query(myTensor3)", TensorType.fromSpec("tensor(x[2],y[2])"))
+                .addInput("query(myTensor4)", TensorType.fromSpec("tensor<float>(x[5])"))
+                .build();
+
+        RankProfile inconsistentInA = new RankProfile.Builder("inconsistent")
+                .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})"))
+                .build();
+        Schema schemaA = new Schema.Builder("a")
+                .add(common)
+                .add(inconsistentInA)
+                .build();
+
+        RankProfile inconsistentInB = new RankProfile.Builder("inconsistent")
+                .addInput("query(myTensor1)", TensorType.fromSpec("tensor(x[10])"))
+                .build();
+        RankProfile bOnly = new RankProfile.Builder("bOnly")
+                .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})"))
+                .build();
+        Schema schemaB = new Schema.Builder("b")
+                .add(common)
+                .add(inconsistentInB)
+                .add(bOnly)
+                .build();
+
+        List<Schema> schemas = new ArrayList<>();
+        schemas.add(schemaA);
+        schemas.add(schemaB);
+
+        Map<String, List<String>> clusters = new HashMap<>();
+        clusters.put("ab", List.of("a", "b"));
+        clusters.put("a", List.of("a"));
+
+        return new SchemaInfo(schemas, clusters);
+    }
+
+    /**
+     * An embedder which returns a fixed tensor, after asserting that it is called with
+     * the text, language and tensor type it was set up to expect.
+     */
+    private static final class MockEmbedder implements Embedder {
+
+        private final String wantedText;
+        private final Language wantedLanguage;
+        private final Tensor cannedTensor;
+
+        /**
+         * @param expectedText the text this must be asked to embed
+         * @param expectedLanguage the language the embedding context must carry
+         * @param tensorToReturn the tensor to return; its type is also the type that must be requested
+         */
+        public MockEmbedder(String expectedText,
+                            Language expectedLanguage,
+                            Tensor tensorToReturn) {
+            wantedText = expectedText;
+            wantedLanguage = expectedLanguage;
+            cannedTensor = tensorToReturn;
+        }
+
+        /** Not expected to be used by these tests: fails the test if called. */
+        @Override
+        public List<Integer> embed(String text, Embedder.Context context) {
+            fail("Unexpected call");
+            return null;
+        }
+
+        /** Checks the arguments against the expectations and returns the fixed tensor. */
+        @Override
+        public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+            assertEquals(wantedText, text);
+            assertEquals(wantedLanguage, context.getLanguage());
+            assertEquals(cannedTensor.type(), tensorType);
+            return cannedTensor;
+        }
+
+    }
+
+}