From c6cf14a79ebb2312490bf6bf11289e15669875f2 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 23 Feb 2024 09:16:17 +0100 Subject: Revert "Merge pull request #30384 from vespa-engine/revert-30361-bratseth/resolve-from-query-profile" This reverts commit 9956c1867a8d36a67e15a416d1b75bec8aa53ba3, reversing changes made to 86f5d187f64868fecc69af4fa2c2677f04044a5e. --- .../searchers/ValidateNearestNeighborTestCase.java | 178 +++++++++++++++------ 1 file changed, 132 insertions(+), 46 deletions(-) (limited to 'container-search/src') diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index 3c1a500bfe8..ac2d66e6487 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -2,6 +2,8 @@ package com.yahoo.search.searchers; import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; import com.yahoo.prelude.IndexFacts; import com.yahoo.prelude.IndexModel; import com.yahoo.prelude.SearchDefinition; @@ -10,7 +12,16 @@ import com.yahoo.search.Result; import com.yahoo.search.query.QueryTree; import com.yahoo.search.query.parser.Parsable; import com.yahoo.search.query.parser.ParserEnvironment; +import com.yahoo.search.query.profile.QueryProfile; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.types.FieldDescription; +import com.yahoo.search.query.profile.types.QueryProfileType; +import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry; import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.schema.Cluster; +import com.yahoo.search.schema.RankProfile; +import com.yahoo.search.schema.Schema; +import com.yahoo.search.schema.SchemaInfo; import com.yahoo.search.searchchain.Execution; import com.yahoo.search.yql.YqlParser; import com.yahoo.tensor.Tensor; @@ -18,8 +29,13 @@ import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.AttributesConfig; import org.junit.jupiter.api.Test; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.fail; /** * @author arnej @@ -63,6 +79,16 @@ public class ValidateNearestNeighborTestCase { )); } + private static SchemaInfo createSchemaInfo() { + List schemas = new ArrayList<>(); + RankProfile.Builder common = new RankProfile.Builder("default") + .addInput("query(qvector)", TensorType.fromSpec("tensor(x[3])")); + schemas.add(new Schema.Builder("document").add(common.build()).build()); + List clusters = new ArrayList<>(); + clusters.add(new Cluster.Builder("test").addSchema("document").build()); + return new SchemaInfo(schemas, clusters); + } + private static final TensorType tt_dense_dvector_42 = TensorType.fromSpec("tensor(x[42])"); private static final TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])"); private static final TensorType tt_dense_dvector_2 = TensorType.fromSpec("tensor(x[2])"); @@ -70,41 +96,6 @@ public class ValidateNearestNeighborTestCase { private static final TensorType tt_dense_matrix_xy = TensorType.fromSpec("tensor(x[3],y[1])"); private static final TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})"); - private Tensor makeTensor(TensorType tensorType) { - return makeTensor(tensorType, 3); - } - - private Tensor makeTensor(TensorType tensorType, int numCells) { - Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); - double dv = 1.0; - String tensorDimension = "x"; - for (long label = 0; label < numCells; label++) { - tensorBuilder.cell() - .label(tensorDimension, label) - .value(dv); - dv += 1.0; - } - return tensorBuilder.build(); - } - - private Tensor makeMatrix(TensorType tensorType) { - Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); - double dv = 1.0; - String tensorDimension = "x"; - for (long label = 0; label < 3; label++) { - tensorBuilder.cell() - .label("y", 0L) - .label(tensorDimension, label) - .value(dv); - dv += 1.0; - } - return tensorBuilder.build(); - } - - private String makeQuery(String attributeTensor, String queryTensor) { - return "select * from sources * where [{targetHits:1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ")"; - } - @Test void testValidQueryDoubleVectors() { String q = makeQuery("dvector", "qvector"); @@ -150,16 +141,14 @@ public class ValidateNearestNeighborTestCase { } static String desc(String field, String qt, int th, String errmsg) { - StringBuilder r = new StringBuilder(); - r.append("NEAREST_NEIGHBOR {"); - r.append("field=").append(field); - r.append(",queryTensorName=").append(qt); - r.append(",hnsw.exploreAdditionalHits=0"); - r.append(",distanceThreshold=Infinity"); - r.append(",approximate=true"); - r.append(",targetHits=").append(th); - r.append("} ").append(errmsg); - return r.toString(); + return "NEAREST_NEIGHBOR {" + + "field=" + field + + ",queryTensorName=" + qt + + ",hnsw.exploreAdditionalHits=0" + + ",distanceThreshold=Infinity" + + ",approximate=true" + + ",targetHits=" + th + + "} " + errmsg; } @Test @@ -232,14 +221,111 @@ public class ValidateNearestNeighborTestCase { assertErrMsg(desc("matrix", "qvector", 1, "field type tensor(x[3],y[1]) is not supported by nearest neighbor searcher"), r); } + @Test + void testWithQueryProfileArgument() { + var embedder = new MockEmbedder("test text", + Language.UNKNOWN, + Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]")); + var registry = new QueryProfileRegistry(); + var profile = new QueryProfile("test"); + profile.set("ranking.features.query(qvector)", "embed(@foo)", registry); + registry.register(profile); + var queryString = makeQuery("fvector", "qvector"); + var query = new Query.Builder() + .setSchemaInfo(createSchemaInfo()) + .setQueryProfile(registry.compile().findQueryProfile("test")) + .setEmbedder(embedder) + .setRequestMap(Map.of("foo", "test text")) + .build(); + setYqlQuery(query, queryString); + var result = doSearch(searcher, query); + assertNull(result.hits().getError()); + } + + private Tensor makeTensor(TensorType tensorType) { + return makeTensor(tensorType, 3); + } + + private Tensor makeTensor(TensorType tensorType, int numCells) { + Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); + double dv = 1.0; + String tensorDimension = "x"; + for (long label = 0; label < numCells; label++) { + tensorBuilder.cell() + .label(tensorDimension, label) + .value(dv); + dv += 1.0; + } + return tensorBuilder.build(); + } + + private Tensor makeMatrix(TensorType tensorType) { + Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); + double dv = 1.0; + String tensorDimension = "x"; + for (long label = 0; label < 3; label++) { + tensorBuilder.cell() + .label("y", 0L) + .label(tensorDimension, label) + .value(dv); + dv += 1.0; + } + return tensorBuilder.build(); + } + + private String makeQuery(String attributeTensor, String queryTensor) { + return "select * from sources * where [{targetHits:1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ")"; + } + private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Tensor qTensor) { + return doSearch(searcher, setYqlQuery(new Query(), yqlQuery), qTensor); + } + + private static Query setYqlQuery(Query query, String yqlQuery) { QueryTree queryTree = new YqlParser(new ParserEnvironment()).parse(new Parsable().setQuery(yqlQuery)); - Query query = new Query(); query.getModel().getQueryTree().setRoot(queryTree.getRoot()); + return query; + } + + private static Result doSearch(ValidateNearestNeighborSearcher searcher, Query query, Tensor qTensor) { query.getRanking().getFeatures().put("query(qvector)", qTensor); + return doSearch(searcher, query); + } + + private static Result doSearch(ValidateNearestNeighborSearcher searcher, Query query) { SearchDefinition searchDefinition = new SearchDefinition("document"); IndexFacts indexFacts = new IndexFacts(new IndexModel(searchDefinition)); return new Execution(searcher, Execution.Context.createContextStub(indexFacts)).search(query); } + private static final class MockEmbedder implements Embedder { + + private final String expectedText; + private final Language expectedLanguage; + private final Tensor tensorToReturn; + + public MockEmbedder(String expectedText, + Language expectedLanguage, + Tensor tensorToReturn) { + this.expectedText = expectedText; + this.expectedLanguage = expectedLanguage; + this.tensorToReturn = tensorToReturn; + } + + @Override + public List embed(String text, Embedder.Context context) { + fail("Unexpected call"); + return null; + } + + @Override + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + assertEquals(expectedText, text); + assertEquals(expectedLanguage, context.getLanguage()); + assertEquals(tensorToReturn.type(), tensorType); + return tensorToReturn; + } + + } + } -- cgit v1.2.3 From 8dad821312de7d2f75435b26cbb8c4ba359627e1 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 23 Feb 2024 09:41:39 +0100 Subject: Use InputType --- .../com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'container-search/src') diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index ac2d66e6487..8e7c7276de1 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -82,7 +82,7 @@ public class ValidateNearestNeighborTestCase { private static SchemaInfo createSchemaInfo() { List schemas = new ArrayList<>(); RankProfile.Builder common = new RankProfile.Builder("default") - .addInput("query(qvector)", TensorType.fromSpec("tensor(x[3])")); + .addInput("query(qvector)", RankProfile.InputType.fromSpec("tensor(x[3])")); schemas.add(new Schema.Builder("document").add(common.build()).build()); List clusters = new ArrayList<>(); clusters.add(new Cluster.Builder("test").addSchema("document").build()); -- cgit v1.2.3