From 7bfb8834fb328bb6350411e39281e6f5372a13e9 Mon Sep 17 00:00:00 2001 From: Harald Musum Date: Fri, 23 Feb 2024 09:10:57 +0100 Subject: Revert "Add embed + NN test" --- .../searchers/ValidateNearestNeighborTestCase.java | 178 ++++++--------------- 1 file changed, 46 insertions(+), 132 deletions(-) diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index ac2d66e6487..3c1a500bfe8 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -2,8 +2,6 @@ package com.yahoo.search.searchers; import com.yahoo.config.subscription.ConfigGetter; -import com.yahoo.language.Language; -import com.yahoo.language.process.Embedder; import com.yahoo.prelude.IndexFacts; import com.yahoo.prelude.IndexModel; import com.yahoo.prelude.SearchDefinition; @@ -12,16 +10,7 @@ import com.yahoo.search.Result; import com.yahoo.search.query.QueryTree; import com.yahoo.search.query.parser.Parsable; import com.yahoo.search.query.parser.ParserEnvironment; -import com.yahoo.search.query.profile.QueryProfile; -import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.search.query.profile.types.FieldDescription; -import com.yahoo.search.query.profile.types.QueryProfileType; -import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry; import com.yahoo.search.result.ErrorMessage; -import com.yahoo.search.schema.Cluster; -import com.yahoo.search.schema.RankProfile; -import com.yahoo.search.schema.Schema; -import com.yahoo.search.schema.SchemaInfo; import com.yahoo.search.searchchain.Execution; import com.yahoo.search.yql.YqlParser; import com.yahoo.tensor.Tensor; @@ -29,13 +18,8 @@ import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.AttributesConfig; import org.junit.jupiter.api.Test; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.fail; /** * @author arnej @@ -79,16 +63,6 @@ public class ValidateNearestNeighborTestCase { )); } - private static SchemaInfo createSchemaInfo() { - List schemas = new ArrayList<>(); - RankProfile.Builder common = new RankProfile.Builder("default") - .addInput("query(qvector)", TensorType.fromSpec("tensor(x[3])")); - schemas.add(new Schema.Builder("document").add(common.build()).build()); - List clusters = new ArrayList<>(); - clusters.add(new Cluster.Builder("test").addSchema("document").build()); - return new SchemaInfo(schemas, clusters); - } - private static final TensorType tt_dense_dvector_42 = TensorType.fromSpec("tensor(x[42])"); private static final TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])"); private static final TensorType tt_dense_dvector_2 = TensorType.fromSpec("tensor(x[2])"); @@ -96,6 +70,41 @@ public class ValidateNearestNeighborTestCase { private static final TensorType tt_dense_matrix_xy = TensorType.fromSpec("tensor(x[3],y[1])"); private static final TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})"); + private Tensor makeTensor(TensorType tensorType) { + return makeTensor(tensorType, 3); + } + + private Tensor makeTensor(TensorType tensorType, int numCells) { + Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); + double dv = 1.0; + String tensorDimension = "x"; + for (long label = 0; label < numCells; label++) { + tensorBuilder.cell() + .label(tensorDimension, label) + .value(dv); + dv += 1.0; + } + return tensorBuilder.build(); + } + + private Tensor makeMatrix(TensorType tensorType) { + Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); + double dv = 1.0; + String tensorDimension = "x"; + for (long label = 0; label < 3; label++) { + tensorBuilder.cell() + .label("y", 0L) + .label(tensorDimension, label) + .value(dv); + dv += 1.0; + } + return tensorBuilder.build(); + } + + private String makeQuery(String attributeTensor, String queryTensor) { + return "select * from sources * where [{targetHits:1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ")"; + } + @Test void testValidQueryDoubleVectors() { String q = makeQuery("dvector", "qvector"); @@ -141,14 +150,16 @@ public class ValidateNearestNeighborTestCase { } static String desc(String field, String qt, int th, String errmsg) { - return "NEAREST_NEIGHBOR {" + - "field=" + field + - ",queryTensorName=" + qt + - ",hnsw.exploreAdditionalHits=0" + - ",distanceThreshold=Infinity" + - ",approximate=true" + - ",targetHits=" + th + - "} " + errmsg; + StringBuilder r = new StringBuilder(); + r.append("NEAREST_NEIGHBOR {"); + r.append("field=").append(field); + r.append(",queryTensorName=").append(qt); + r.append(",hnsw.exploreAdditionalHits=0"); + r.append(",distanceThreshold=Infinity"); + r.append(",approximate=true"); + r.append(",targetHits=").append(th); + r.append("} ").append(errmsg); + return r.toString(); } @Test @@ -221,111 +232,14 @@ public class ValidateNearestNeighborTestCase { assertErrMsg(desc("matrix", "qvector", 1, "field type tensor(x[3],y[1]) is not supported by nearest neighbor searcher"), r); } - @Test - void testWithQueryProfileArgument() { - var embedder = new MockEmbedder("test text", - Language.UNKNOWN, - Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]")); - var registry = new QueryProfileRegistry(); - var profile = new QueryProfile("test"); - profile.set("ranking.features.query(qvector)", "embed(@foo)", registry); - registry.register(profile); - var queryString = makeQuery("fvector", "qvector"); - var query = new Query.Builder() - .setSchemaInfo(createSchemaInfo()) - .setQueryProfile(registry.compile().findQueryProfile("test")) - .setEmbedder(embedder) - .setRequestMap(Map.of("foo", "test text")) - .build(); - setYqlQuery(query, queryString); - var result = doSearch(searcher, query); - assertNull(result.hits().getError()); - } - - private Tensor makeTensor(TensorType tensorType) { - return makeTensor(tensorType, 3); - } - - private Tensor makeTensor(TensorType tensorType, int numCells) { - Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); - double dv = 1.0; - String tensorDimension = "x"; - for (long label = 0; label < numCells; label++) { - tensorBuilder.cell() - .label(tensorDimension, label) - .value(dv); - dv += 1.0; - } - return tensorBuilder.build(); - } - - private Tensor makeMatrix(TensorType tensorType) { - Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); - double dv = 1.0; - String tensorDimension = "x"; - for (long label = 0; label < 3; label++) { - tensorBuilder.cell() - .label("y", 0L) - .label(tensorDimension, label) - .value(dv); - dv += 1.0; - } - return tensorBuilder.build(); - } - - private String makeQuery(String attributeTensor, String queryTensor) { - return "select * from sources * where [{targetHits:1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ")"; - } - private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Tensor qTensor) { - return doSearch(searcher, setYqlQuery(new Query(), yqlQuery), qTensor); - } - - private static Query setYqlQuery(Query query, String yqlQuery) { QueryTree queryTree = new YqlParser(new ParserEnvironment()).parse(new Parsable().setQuery(yqlQuery)); + Query query = new Query(); query.getModel().getQueryTree().setRoot(queryTree.getRoot()); - return query; - } - - private static Result doSearch(ValidateNearestNeighborSearcher searcher, Query query, Tensor qTensor) { query.getRanking().getFeatures().put("query(qvector)", qTensor); - return doSearch(searcher, query); - } - - private static Result doSearch(ValidateNearestNeighborSearcher searcher, Query query) { SearchDefinition searchDefinition = new SearchDefinition("document"); IndexFacts indexFacts = new IndexFacts(new IndexModel(searchDefinition)); return new Execution(searcher, Execution.Context.createContextStub(indexFacts)).search(query); } - private static final class MockEmbedder implements Embedder { - - private final String expectedText; - private final Language expectedLanguage; - private final Tensor tensorToReturn; - - public MockEmbedder(String expectedText, - Language expectedLanguage, - Tensor tensorToReturn) { - this.expectedText = expectedText; - this.expectedLanguage = expectedLanguage; - this.tensorToReturn = tensorToReturn; - } - - @Override - public List embed(String text, Embedder.Context context) { - fail("Unexpected call"); - return null; - } - - @Override - public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { - assertEquals(expectedText, text); - assertEquals(expectedLanguage, context.getLanguage()); - assertEquals(tensorToReturn.type(), tensorType); - return tensorToReturn; - } - - } - } -- cgit v1.2.3