diff options
Diffstat (limited to 'container-search/src/test/java/com/yahoo/search')
3 files changed, 508 insertions, 0 deletions
diff --git a/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTest.java b/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTest.java
new file mode 100644
index 00000000000..728ebbf8f7f
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTest.java
@@ -0,0 +1,57 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.config;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.yolean.Exceptions;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests that schema info can be built from config and that it resolves rank profile input types.
+ *
+ * @author bratseth
+ */
+public class SchemaInfoTest {
+
+    /** The schema info built from config objects must equal the one built programmatically. */
+    @Test
+    public void testSchemaInfoConfiguration() {
+        assertEquals(SchemaInfoTester.createSchemaInfoFromConfig(), SchemaInfoTester.createSchemaInfo());
+    }
+
+    /** Resolves the type of query(myTensor1) for different combinations of sources, restrict and rank profile. */
+    @Test
+    public void testInputResolution() {
+        var tester = new SchemaInfoTester();
+        // 'commonProfile' is added to both schemas, so every source selection resolves to the same type
+        tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"),
+                           "", "", "commonProfile", "query(myTensor1)");
+        tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"),
+                           "ab", "", "commonProfile", "query(myTensor1)");
+        tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"),
+                           "a", "", "commonProfile", "query(myTensor1)");
+        tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"),
+                           "b", "", "commonProfile", "query(myTensor1)");
+
+        // 'inconsistent' declares different types in schema 'a' and 'b': resolving across both is a conflict
+        tester.assertInputConflict(TensorType.fromSpec("tensor(a{},b{})"),
+                                   "", "", "inconsistent", "query(myTensor1)");
+        tester.assertInputConflict(TensorType.fromSpec("tensor(a{},b{})"),
+                                   "ab", "", "inconsistent", "query(myTensor1)");
+        tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"),
+                           "ab", "a", "inconsistent", "query(myTensor1)");
+        tester.assertInput(TensorType.fromSpec("tensor(x[10])"),
+                           "ab", "b", "inconsistent", "query(myTensor1)");
+        tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"),
+                           "a", "", "inconsistent", "query(myTensor1)");
+        tester.assertInput(TensorType.fromSpec("tensor(x[10])"),
+                           "b", "", "inconsistent", "query(myTensor1)");
+        // 'bOnly' is declared in schema 'b' only, so the expected type is null when the sources are limited to 'a'
+        tester.assertInput(null,
+                           "a", "", "bOnly", "query(myTensor1)");
+        tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"),
+                           "ab", "", "bOnly", "query(myTensor1)");
+    }
+
+}
diff --git a/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTester.java b/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTester.java
new file mode 100644
index 00000000000..d5b4522f3aa
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTester.java
@@ -0,0 +1,143 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.config;
+
+import com.yahoo.container.QrSearchersConfig;
+import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig;
+import com.yahoo.search.Query;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.yolean.Exceptions;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Test helper holding a schema info instance with both consistent and conflicting input declarations.
+ *
+ * @author bratseth
+ */
+public class SchemaInfoTester {
+
+    private final SchemaInfo schemaInfo;
+
+    SchemaInfoTester() {
+        this.schemaInfo = createSchemaInfo();
+    }
+
+    SchemaInfo schemaInfo() { return schemaInfo; }
+
+    /** Creates a query over this schema info with the given sources and restrict parameters; empty values are left unset. */
+    Query query(String sources, String restrict) {
+        Map<String, String> params = new HashMap<>();
+        if ( ! sources.isEmpty())
+            params.put("sources", sources);
+        if ( ! restrict.isEmpty())
+            params.put("restrict", restrict);
+        return new Query.Builder().setSchemaInfo(schemaInfo)
+                                  .setRequestMap(params)
+                                  .build();
+    }
+
+    /** Asserts that the given input resolves to the expected type for a query with these sources and restrict values. */
+    void assertInput(TensorType expectedType, String sources, String restrict, String rankProfile, String feature) {
+        assertEquals(expectedType,
+                     schemaInfo.newSession(query(sources, restrict)).rankProfileInput(feature, rankProfile));
+    }
+
+    /**
+     * Asserts that resolving the given input fails due to conflicting type declarations.
+     * If resolution completes without throwing this fails too, since the conflict was then not detected.
+     */
+    void assertInputConflict(TensorType expectedType, String sources, String restrict, String rankProfile, String feature) {
+        try {
+            assertInput(expectedType, sources, restrict, rankProfile, feature);
+            throw new AssertionError("Expected conflicting input type declarations for '" + feature + "'");
+        }
+        catch (IllegalArgumentException e) {
+            assertEquals("Conflicting input type declarations for '" + feature + "'",
+                         e.getMessage().split(":")[0]);
+        }
+    }
+
+    /** Creates a schema info with the schemas 'a' and 'b', and the clusters 'ab' (both schemas) and 'a'. */
+    static SchemaInfo createSchemaInfo() {
+        List<Schema> schemas = new ArrayList<>();
+        RankProfile common = new RankProfile.Builder("commonProfile")
+                                     .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})"))
+                                     .addInput("query(myTensor2)", TensorType.fromSpec("tensor(x[2],y[2])"))
+                                     .addInput("query(myTensor3)", TensorType.fromSpec("tensor(x[2],y[2])"))
+                                     .addInput("query(myTensor4)", TensorType.fromSpec("tensor<float>(x[5])"))
+                                     .build();
+        schemas.add(new Schema.Builder("a")
+                            .add(common)
+                            .add(new RankProfile.Builder("inconsistent")
+                                         .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})"))
+                                         .build())
+                            .build());
+        schemas.add(new Schema.Builder("b")
+                            .add(common)
+                            .add(new RankProfile.Builder("inconsistent")
+                                         .addInput("query(myTensor1)", TensorType.fromSpec("tensor(x[10])"))
+                                         .build())
+                            .add(new RankProfile.Builder("bOnly")
+                                         .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})"))
+                                         .build())
+                            .build());
+        Map<String, List<String>> clusters = new HashMap<>();
+        clusters.put("ab", List.of("a", "b"));
+        clusters.put("a", List.of("a"));
+        return new SchemaInfo(schemas, clusters);
+    }
+
+    /** Creates the same schema info as createSchemaInfo from config objects. */
+    static SchemaInfo createSchemaInfoFromConfig() {
+        var indexInfoConfig = new IndexInfoConfig.Builder();
+
+        var rankProfileCommon = new DocumentdbInfoConfig.Documentdb.Rankprofile.Builder();
+        rankProfileCommon.name("commonProfile");
+        rankProfileCommon.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor1)").type("tensor(a{},b{})"));
+        rankProfileCommon.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor2)").type("tensor(x[2],y[2])"));
+        rankProfileCommon.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor3)").type("tensor(x[2],y[2])"));
+        rankProfileCommon.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor4)").type("tensor<float>(x[5])"));
+
+        var documentDbInfoInfoConfig = new DocumentdbInfoConfig.Builder();
+
+        var documentDbA = new DocumentdbInfoConfig.Documentdb.Builder();
+        documentDbA.name("a");
+        documentDbA.rankprofile(rankProfileCommon);
+        var rankProfileInconsistentA = new DocumentdbInfoConfig.Documentdb.Rankprofile.Builder();
+        rankProfileInconsistentA.name("inconsistent");
+        rankProfileInconsistentA.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor1)").type("tensor(a{},b{})"));
+        documentDbA.rankprofile(rankProfileInconsistentA);
+        documentDbInfoInfoConfig.documentdb(documentDbA);
+
+        var documentDbB = new DocumentdbInfoConfig.Documentdb.Builder();
+        documentDbB.name("b");
+        documentDbB.rankprofile(rankProfileCommon);
+        var rankProfileInconsistentB = new DocumentdbInfoConfig.Documentdb.Rankprofile.Builder();
+        rankProfileInconsistentB.name("inconsistent");
+        rankProfileInconsistentB.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor1)").type("tensor(x[10])"));
+        documentDbB.rankprofile(rankProfileInconsistentB);
+        var rankProfileBOnly = new DocumentdbInfoConfig.Documentdb.Rankprofile.Builder();
+        rankProfileBOnly.name("bOnly");
+        rankProfileBOnly.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor1)").type("tensor(a{},b{})"));
+        documentDbB.rankprofile(rankProfileBOnly);
+        documentDbInfoInfoConfig.documentdb(documentDbB);
+
+        var qrSearchersConfig = new QrSearchersConfig.Builder();
+        var clusterAB = new QrSearchersConfig.Searchcluster.Builder();
+        clusterAB.name("ab");
+        clusterAB.searchdef("a").searchdef("b");
+        qrSearchersConfig.searchcluster(clusterAB);
+        var clusterA = new QrSearchersConfig.Searchcluster.Builder();
+        clusterA.name("a");
+        clusterA.searchdef("a");
+        qrSearchersConfig.searchcluster(clusterA);
+
+        return new SchemaInfo(indexInfoConfig.build(), documentDbInfoInfoConfig.build(), qrSearchersConfig.build());
+    }
+
+}
diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
new file mode 100644
index 00000000000..1b10e4cd0ba
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
@@ -0,0 +1,308 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.query;
+
+import com.yahoo.container.jdisc.HttpRequest;
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.search.Query;
+import com.yahoo.search.config.RankProfile;
+import com.yahoo.search.config.Schema;
+import com.yahoo.search.config.SchemaInfo;
+import com.yahoo.search.query.profile.QueryProfile;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.search.query.profile.compiled.CompiledQueryProfile;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.yolean.Exceptions;
+import org.junit.Test;
+
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.fail;
+
+/**
+ * Tests queries towards rank profiles using input declarations.
+ *
+ * @author bratseth
+ */
+public class RankProfileInputTest {
+
+    /** A tensor string in the request is converted to a tensor only when a queried schema declares the input. */
+    @Test
+    public void testTensorRankFeatureInRequest() {
+        String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}";
+
+        {
+            Query query = createTensor1Query(tensorString, "commonProfile", "");
+            assertEquals(0, query.errors().size());
+            assertEquals(Tensor.from(tensorString), query.properties().get("ranking.features.query(myTensor1)"));
+            assertEquals(Tensor.from(tensorString), query.getRanking().getFeatures().getTensor("query(myTensor1)").get());
+        }
+
+        { // Partial resolution is sufficient
+            Query query = createTensor1Query(tensorString, "bOnly", "");
+            assertEquals(0, query.errors().size());
+            assertEquals(Tensor.from(tensorString), query.properties().get("ranking.features.query(myTensor1)"));
+            assertEquals(Tensor.from(tensorString), query.getRanking().getFeatures().getTensor("query(myTensor1)").get());
+        }
+
+        { // Resolution is limited to the correct sources
+            Query query = createTensor1Query(tensorString, "bOnly", "sources=a");
+            assertEquals(0, query.errors().size());
+            assertEquals("Not converted to tensor",
+                         tensorString, query.properties().get("ranking.features.query(myTensor1)"));
+        }
+    }
+
+    /** Conflicting declarations of the same input in the queried schemas are rejected, naming both declarations. */
+    @Test
+    public void testTensorRankFeatureInRequestInconsistentInput() {
+        String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}";
+        try {
+            createTensor1Query(tensorString, "inconsistent", "");
+            fail("Expected exception");
+        }
+        catch (IllegalArgumentException e) {
+            assertEquals("Conflicting input type declarations for 'query(myTensor1)': " +
+                         "Declared as tensor(a{},b{}) in rank profile 'inconsistent' in schema 'a', " +
+                         "and as tensor(x[10]) in rank profile 'inconsistent' in schema 'b'",
+                         Exceptions.toMessageString(e));
+        }
+    }
+
+    /** Limiting the sources to a single schema avoids the conflict; selecting both schemas does not. */
+    @Test
+    public void testTensorRankFeatureWithSourceResolution() {
+        String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}";
+
+        {
+            createTensor1Query(tensorString, "inconsistent", "sources=a");
+            // Success: No exception
+        }
+
+        // Cluster 'ab' contains both schemas, whose declarations of this input conflict
+        assertThrows(IllegalArgumentException.class,
+                     () -> createTensor1Query(tensorString, "inconsistent", "sources=ab"));
+
+        {
+            createTensor1Query(tensorString, "inconsistent", "sources=a&restrict=a");
+            // Success: No exception
+        }
+    }
+
+    /** A tensor of the declared type can be set through the query properties after the query is built. */
+    @Test
+    public void testTensorRankFeatureSetProgrammatically() {
+        String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}";
+        Query query = new Query.Builder()
+                .setSchemaInfo(createSchemaInfo())
+                .setQueryProfile(createQueryProfile()) // Use the instantiation path with query profiles
+                .setRequest(HttpRequest.createTestRequest("?" +
+                                                          "&ranking=commonProfile",
+                                                          com.yahoo.jdisc.http.HttpRequest.Method.GET))
+                .build();
+
+        query.properties().set("ranking.features.query(myTensor1)", Tensor.from(tensorString));
+        assertEquals(Tensor.from(tensorString), query.getRanking().getFeatures().getTensor("query(myTensor1)").get());
+    }
+
+    /** Setting a tensor of another type than the declared one is rejected, both through the features and the properties. */
+    @Test
+    public void testTensorRankFeatureSetProgrammaticallyWithWrongType() {
+        Query query = new Query.Builder()
+                .setSchemaInfo(createSchemaInfo())
+                .setQueryProfile(createQueryProfile()) // Use the instantiation path with query profiles
+                .setRequest(HttpRequest.createTestRequest("?" +
+                                                          "&ranking=commonProfile",
+                                                          com.yahoo.jdisc.http.HttpRequest.Method.GET))
+                .build();
+
+        String tensorString = "tensor(x[3]):[0.1, 0.2, 0.3]";
+        try {
+            query.getRanking().getFeatures().put("query(myTensor1)",Tensor.from(tensorString));
+            fail("Expected exception");
+        }
+        catch (IllegalArgumentException e) {
+            assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " +
+                         "This input is declared in rank profile 'commonProfile' as tensor(a{},b{})",
+                         Exceptions.toMessageString(e));
+        }
+        try {
+            query.properties().set("ranking.features.query(myTensor1)", Tensor.from(tensorString));
+            fail("Expected exception");
+        }
+        catch (IllegalArgumentException e) {
+            assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " +
+                         "Require a tensor of type tensor(a{},b{})",
+                         Exceptions.toMessageString(e));
+        }
+    }
+
+    /** An embed(...) expression passed as an input is replaced by the tensor returned by the selected embedder. */
+    @Test
+    public void testUnembeddedTensorRankFeatureInRequest() {
+        String text = "text to embed into a tensor";
+        Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]");
+        Tensor embedding2 = Tensor.from("tensor<float>(x[5]):[1,2,3,4,0]]");
+
+        Map<String, Embedder> embedders = Map.of(
+                "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1)
+        );
+        assertEmbedQuery("embed(" + text + ")", embedding1, embedders);
+        assertEmbedQuery("embed('" + text + "')", embedding1, embedders);
+        assertEmbedQuery("embed(\"" + text + "\")", embedding1, embedders);
+        assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders);
+        assertEmbedQuery("embed(emb1, \"" + text + "\")", embedding1, embedders);
+        assertEmbedQueryFails("embed(emb2, \"" + text + "\")", embedding1, embedders,
+                              "Can't find embedder 'emb2'. Valid embedders are emb1");
+
+        embedders = Map.of(
+                "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1),
+                "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2)
+        );
+        assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders);
+        assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders);
+        assertEmbedQueryFails("embed(emb3, \"" + text + "\")", embedding1, embedders,
+                              "Can't find embedder 'emb3'. Valid embedders are emb1,emb2");
+
+        // And with specified language
+        embedders = Map.of(
+                "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1)
+        );
+        assertEmbedQuery("embed(" + text + ")", embedding1, embedders, Language.ENGLISH.languageCode());
+
+        embedders = Map.of(
+                "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1),
+                "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2)
+        );
+        assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders, Language.ENGLISH.languageCode());
+        assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode());
+    }
+
+    /** Creates a query passing the given string as query(myTensor1), using the given rank profile and extra request parameters. */
+    private Query createTensor1Query(String tensorString, String profile, String additionalParams) {
+        return new Query.Builder()
+                .setSchemaInfo(createSchemaInfo())
+                .setQueryProfile(createQueryProfile()) // Use the instantiation path with query profiles
+                .setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features.query(myTensor1)") +
+                                                          "=" + urlEncode(tensorString) +
+                                                          "&ranking=" + profile +
+                                                          "&" + additionalParams,
+                                                          com.yahoo.jdisc.http.HttpRequest.Method.GET))
+                .build();
+    }
+
+    private String urlEncode(String s) {
+        return URLEncoder.encode(s, StandardCharsets.UTF_8);
+    }
+
+    private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) {
+        assertEmbedQuery(embed, expected, embedders, null);
+    }
+
+    /** Asserts that passing the embed expression as query(myTensor4) produces the expected tensor; a null language is not sent. */
+    private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) {
+        String languageParam = language == null ? "" : "&language=" + language;
+        String destination = "query(myTensor4)";
+
+        Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest(
+                "?" + urlEncode("ranking.features." + destination) +
+                "=" + urlEncode(embed) +
+                "&ranking=commonProfile" +
+                languageParam,
+                com.yahoo.jdisc.http.HttpRequest.Method.GET))
+                .setSchemaInfo(createSchemaInfo())
+                .setQueryProfile(createQueryProfile())
+                .setEmbedders(embedders)
+                .build();
+        assertEquals(0, query.errors().size());
+        assertEquals(expected, query.properties().get("ranking.features." + destination));
+        assertEquals(expected, query.getRanking().getFeatures().getTensor(destination).get());
+    }
+
+    /** Asserts that the embed expression is rejected with the given message on the exception or one of its causes. */
+    private void assertEmbedQueryFails(String embed, Tensor expected, Map<String, Embedder> embedders, String errMsg) {
+        Throwable t = assertThrows(IllegalArgumentException.class, () -> assertEmbedQuery(embed, expected, embedders));
+        while (t != null) {
+            // Compare from the expected side: a throwable in the chain may have a null message
+            if (errMsg.equals(t.getMessage())) return;
+            t = t.getCause();
+        }
+        fail("Error '" + errMsg + "' not thrown");
+    }
+
+    private CompiledQueryProfile createQueryProfile() {
+        var registry = new QueryProfileRegistry();
+        registry.register(new QueryProfile("test"));
+        return registry.compile().findQueryProfile("test");
+    }
+
+    /** Creates schema info with schemas 'a' and 'b', where 'inconsistent' declares different types for query(myTensor1). */
+    private SchemaInfo createSchemaInfo() {
+        List<Schema> schemas = new ArrayList<>();
+        RankProfile common = new RankProfile.Builder("commonProfile")
+                                     .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})"))
+                                     .addInput("query(myTensor2)", TensorType.fromSpec("tensor(x[2],y[2])"))
+                                     .addInput("query(myTensor3)", TensorType.fromSpec("tensor(x[2],y[2])"))
+                                     .addInput("query(myTensor4)", TensorType.fromSpec("tensor<float>(x[5])"))
+                                     .build();
+        schemas.add(new Schema.Builder("a")
+                            .add(common)
+                            .add(new RankProfile.Builder("inconsistent")
+                                         .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})"))
+                                         .build())
+                            .build());
+        schemas.add(new Schema.Builder("b")
+                            .add(common)
+                            .add(new RankProfile.Builder("inconsistent")
+                                         .addInput("query(myTensor1)", TensorType.fromSpec("tensor(x[10])"))
+                                         .build())
+                            .add(new RankProfile.Builder("bOnly")
+                                         .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})"))
+                                         .build())
+                            .build());
+        Map<String, List<String>> clusters = new HashMap<>();
+        clusters.put("ab", List.of("a", "b"));
+        clusters.put("a", List.of("a"));
+        return new SchemaInfo(schemas, clusters);
+    }
+
+    /** An embedder which asserts that it receives the expected text, language and tensor type, and returns a fixed tensor. */
+    private static final class MockEmbedder implements Embedder {
+
+        private final String expectedText;
+        private final Language expectedLanguage;
+        private final Tensor tensorToReturn;
+
+        public MockEmbedder(String expectedText,
+                            Language expectedLanguage,
+                            Tensor tensorToReturn) {
+            this.expectedText = expectedText;
+            this.expectedLanguage = expectedLanguage;
+            this.tensorToReturn = tensorToReturn;
+        }
+
+        @Override
+        public List<Integer> embed(String text, Embedder.Context context) {
+            fail("Unexpected call");
+            return null;
+        }
+
+        @Override
+        public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+            assertEquals(expectedText, text);
+            assertEquals(expectedLanguage, context.getLanguage());
+            assertEquals(tensorToReturn.type(), tensorType);
+            return tensorToReturn;
+        }
+
+    }
+
+}