summary refs log tree commit diff stats
path: root/container-search
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@vespa.ai> 2024-02-21 18:06:42 +0100
committer Jon Bratseth <bratseth@vespa.ai> 2024-02-21 18:06:42 +0100
commit 4d57acfb75d230f4219add5c915fdf56179bdc94 (patch)
tree 015b2cea20af0761687ac28c081d78e6f6cc520c /container-search
parent d9bca44b67a6911cd4d363217b50f827bd8d8a95 (diff)
Add embed + NN test
Diffstat (limited to 'container-search')
-rw-r--r-- container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java 178
1 files changed, 132 insertions, 46 deletions
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
index 3c1a500bfe8..ac2d66e6487 100644
--- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -2,6 +2,8 @@
package com.yahoo.search.searchers;
import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
import com.yahoo.prelude.IndexFacts;
import com.yahoo.prelude.IndexModel;
import com.yahoo.prelude.SearchDefinition;
@@ -10,7 +12,16 @@ import com.yahoo.search.Result;
import com.yahoo.search.query.QueryTree;
import com.yahoo.search.query.parser.Parsable;
import com.yahoo.search.query.parser.ParserEnvironment;
+import com.yahoo.search.query.profile.QueryProfile;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.search.query.profile.types.FieldDescription;
+import com.yahoo.search.query.profile.types.QueryProfileType;
+import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry;
import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.schema.Cluster;
+import com.yahoo.search.schema.RankProfile;
+import com.yahoo.search.schema.Schema;
+import com.yahoo.search.schema.SchemaInfo;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.search.yql.YqlParser;
import com.yahoo.tensor.Tensor;
@@ -18,8 +29,13 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.AttributesConfig;
import org.junit.jupiter.api.Test;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.fail;
/**
* @author arnej
@@ -63,6 +79,16 @@ public class ValidateNearestNeighborTestCase {
));
}
+ private static SchemaInfo createSchemaInfo() { // Builds schema info with one schema "document" and one cluster "test" that includes it
+ List<Schema> schemas = new ArrayList<>();
+ RankProfile.Builder common = new RankProfile.Builder("default") // the only rank profile, named "default"
+ .addInput("query(qvector)", TensorType.fromSpec("tensor<float>(x[3])")); // declares query(qvector) as a 3-cell float vector; presumably this is the target type handed to the embedder — confirm
+ schemas.add(new Schema.Builder("document").add(common.build()).build());
+ List<Cluster> clusters = new ArrayList<>();
+ clusters.add(new Cluster.Builder("test").addSchema("document").build());
+ return new SchemaInfo(schemas, clusters);
+ }
+
private static final TensorType tt_dense_dvector_42 = TensorType.fromSpec("tensor(x[42])");
private static final TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])");
private static final TensorType tt_dense_dvector_2 = TensorType.fromSpec("tensor(x[2])");
@@ -70,41 +96,6 @@ public class ValidateNearestNeighborTestCase {
private static final TensorType tt_dense_matrix_xy = TensorType.fromSpec("tensor(x[3],y[1])");
private static final TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})");
- private Tensor makeTensor(TensorType tensorType) {
- return makeTensor(tensorType, 3);
- }
-
- private Tensor makeTensor(TensorType tensorType, int numCells) {
- Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
- double dv = 1.0;
- String tensorDimension = "x";
- for (long label = 0; label < numCells; label++) {
- tensorBuilder.cell()
- .label(tensorDimension, label)
- .value(dv);
- dv += 1.0;
- }
- return tensorBuilder.build();
- }
-
- private Tensor makeMatrix(TensorType tensorType) {
- Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
- double dv = 1.0;
- String tensorDimension = "x";
- for (long label = 0; label < 3; label++) {
- tensorBuilder.cell()
- .label("y", 0L)
- .label(tensorDimension, label)
- .value(dv);
- dv += 1.0;
- }
- return tensorBuilder.build();
- }
-
- private String makeQuery(String attributeTensor, String queryTensor) {
- return "select * from sources * where [{targetHits:1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ")";
- }
-
@Test
void testValidQueryDoubleVectors() {
String q = makeQuery("dvector", "qvector");
@@ -150,16 +141,14 @@ public class ValidateNearestNeighborTestCase {
}
static String desc(String field, String qt, int th, String errmsg) {
- StringBuilder r = new StringBuilder();
- r.append("NEAREST_NEIGHBOR {");
- r.append("field=").append(field);
- r.append(",queryTensorName=").append(qt);
- r.append(",hnsw.exploreAdditionalHits=0");
- r.append(",distanceThreshold=Infinity");
- r.append(",approximate=true");
- r.append(",targetHits=").append(th);
- r.append("} ").append(errmsg);
- return r.toString();
+ return "NEAREST_NEIGHBOR {" + // expected error text: item description followed by errmsg; same pieces, in the same order, as the removed StringBuilder version
+ "field=" + field +
+ ",queryTensorName=" + qt +
+ ",hnsw.exploreAdditionalHits=0" +
+ ",distanceThreshold=Infinity" +
+ ",approximate=true" +
+ ",targetHits=" + th +
+ "} " + errmsg;
}
@Test
@@ -232,14 +221,111 @@ public class ValidateNearestNeighborTestCase {
assertErrMsg(desc("matrix", "qvector", 1, "field type tensor(x[3],y[1]) is not supported by nearest neighbor searcher"), r);
}
+ @Test
+ void testWithQueryProfileArgument() { // query(qvector) is supplied through a query profile "embed(...)" value instead of being put on the query directly
+ var embedder = new MockEmbedder("test text", // the mock asserts it is asked to embed exactly this text
+ Language.UNKNOWN,
+ Tensor.from("tensor<float>(x[3]):[1.0, 2.0, 3.0]")); // canned embedding; its type equals the query(qvector) input type declared in createSchemaInfo()
+ var registry = new QueryProfileRegistry();
+ var profile = new QueryProfile("test");
+ profile.set("ranking.features.query(qvector)", "embed(@foo)", registry); // presumably @foo is replaced by the request parameter "foo" set below — confirm
+ registry.register(profile);
+ var queryString = makeQuery("fvector", "qvector");
+ var query = new Query.Builder()
+ .setSchemaInfo(createSchemaInfo())
+ .setQueryProfile(registry.compile().findQueryProfile("test"))
+ .setEmbedder(embedder)
+ .setRequestMap(Map.of("foo", "test text")) // same text the mock embedder expects
+ .build();
+ setYqlQuery(query, queryString);
+ var result = doSearch(searcher, query); // overload that does not set query(qvector) itself
+ assertNull(result.hits().getError()); // the search must complete without an error on the hits
+ }
+
+ private Tensor makeTensor(TensorType tensorType) { // three-cell shorthand for makeTensor(tensorType, numCells)
+ return makeTensor(tensorType, 3);
+ }
+
+ private Tensor makeTensor(TensorType tensorType, int numCells) { // builds a tensor of the given type with numCells cells along dimension "x"
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
+ double dv = 1.0; // value of the first cell; each following cell is 1.0 higher
+ String tensorDimension = "x";
+ for (long label = 0; label < numCells; label++) { // labels run 0 .. numCells-1
+ tensorBuilder.cell()
+ .label(tensorDimension, label)
+ .value(dv);
+ dv += 1.0;
+ }
+ return tensorBuilder.build();
+ }
+
+ private Tensor makeMatrix(TensorType tensorType) { // builds three cells at y=0, x=0..2 holding 1.0, 2.0, 3.0
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
+ double dv = 1.0;
+ String tensorDimension = "x";
+ for (long label = 0; label < 3; label++) {
+ tensorBuilder.cell()
+ .label("y", 0L) // every cell uses label 0 in dimension "y"
+ .label(tensorDimension, label)
+ .value(dv);
+ dv += 1.0;
+ }
+ return tensorBuilder.build();
+ }
+
+ private String makeQuery(String attributeTensor, String queryTensor) { // YQL text for a nearestNeighbor item over the two given names, annotated with targetHits 1
+ return "select * from sources * where [{targetHits:1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ")";
+ }
+
private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Tensor qTensor) {
+ return doSearch(searcher, setYqlQuery(new Query(), yqlQuery), qTensor); // parse the YQL into a fresh Query, then delegate to the Query-based overload
+ }
+
+ private static Query setYqlQuery(Query query, String yqlQuery) { // parses yqlQuery and installs the parsed root in the given query's tree
QueryTree queryTree = new YqlParser(new ParserEnvironment()).parse(new Parsable().setQuery(yqlQuery));
- Query query = new Query();
query.getModel().getQueryTree().setRoot(queryTree.getRoot());
+ return query; // the same instance that was passed in, so the call can be nested in an argument list
+ }
+
+ private static Result doSearch(ValidateNearestNeighborSearcher searcher, Query query, Tensor qTensor) { // sets qTensor as rank feature query(qvector), then searches
query.getRanking().getFeatures().put("query(qvector)", qTensor);
+ return doSearch(searcher, query);
+ }
+
+ private static Result doSearch(ValidateNearestNeighborSearcher searcher, Query query) { // runs the searcher on the query as given, in a stub execution context built from a single search definition "document"
SearchDefinition searchDefinition = new SearchDefinition("document");
IndexFacts indexFacts = new IndexFacts(new IndexModel(searchDefinition));
return new Execution(searcher, Execution.Context.createContextStub(indexFacts)).search(query);
}
+ private static final class MockEmbedder implements Embedder { // test double: checks the embed request against expected values and returns a fixed tensor
+
+ private final String expectedText;
+ private final Language expectedLanguage;
+ private final Tensor tensorToReturn;
+
+ public MockEmbedder(String expectedText,
+ Language expectedLanguage,
+ Tensor tensorToReturn) {
+ this.expectedText = expectedText;
+ this.expectedLanguage = expectedLanguage;
+ this.tensorToReturn = tensorToReturn;
+ }
+
+ @Override
+ public List<Integer> embed(String text, Embedder.Context context) { // token-id variant: any call fails the test
+ fail("Unexpected call");
+ return null; // only reached if fail() returns normally; needed to satisfy the return type
+ }
+
+ @Override
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { // tensor variant: asserts text, language and requested type, then returns the canned tensor
+ assertEquals(expectedText, text);
+ assertEquals(expectedLanguage, context.getLanguage());
+ assertEquals(tensorToReturn.type(), tensorType); // the requested type must equal the canned tensor's own type
+ return tensorToReturn;
+ }
+
+ }
+
}