summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorHarald Musum <musum@vespa.ai>2024-02-23 09:10:57 +0100
committerGitHub <noreply@github.com>2024-02-23 09:10:57 +0100
commit7bfb8834fb328bb6350411e39281e6f5372a13e9 (patch)
tree5a257bb4415cab3f5f19e41fec3013d9b4679f8e /container-search
parent86f5d187f64868fecc69af4fa2c2677f04044a5e (diff)
Revert "Add embed + NN test"
Diffstat (limited to 'container-search')
-rw-r--r--container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java178
1 files changed, 46 insertions, 132 deletions
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
index ac2d66e6487..3c1a500bfe8 100644
--- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -2,8 +2,6 @@
package com.yahoo.search.searchers;
import com.yahoo.config.subscription.ConfigGetter;
-import com.yahoo.language.Language;
-import com.yahoo.language.process.Embedder;
import com.yahoo.prelude.IndexFacts;
import com.yahoo.prelude.IndexModel;
import com.yahoo.prelude.SearchDefinition;
@@ -12,16 +10,7 @@ import com.yahoo.search.Result;
import com.yahoo.search.query.QueryTree;
import com.yahoo.search.query.parser.Parsable;
import com.yahoo.search.query.parser.ParserEnvironment;
-import com.yahoo.search.query.profile.QueryProfile;
-import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.search.query.profile.types.FieldDescription;
-import com.yahoo.search.query.profile.types.QueryProfileType;
-import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry;
import com.yahoo.search.result.ErrorMessage;
-import com.yahoo.search.schema.Cluster;
-import com.yahoo.search.schema.RankProfile;
-import com.yahoo.search.schema.Schema;
-import com.yahoo.search.schema.SchemaInfo;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.search.yql.YqlParser;
import com.yahoo.tensor.Tensor;
@@ -29,13 +18,8 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.AttributesConfig;
import org.junit.jupiter.api.Test;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
-import static org.junit.jupiter.api.Assertions.fail;
/**
* @author arnej
@@ -79,16 +63,6 @@ public class ValidateNearestNeighborTestCase {
));
}
- private static SchemaInfo createSchemaInfo() {
- List<Schema> schemas = new ArrayList<>();
- RankProfile.Builder common = new RankProfile.Builder("default")
- .addInput("query(qvector)", TensorType.fromSpec("tensor<float>(x[3])"));
- schemas.add(new Schema.Builder("document").add(common.build()).build());
- List<Cluster> clusters = new ArrayList<>();
- clusters.add(new Cluster.Builder("test").addSchema("document").build());
- return new SchemaInfo(schemas, clusters);
- }
-
private static final TensorType tt_dense_dvector_42 = TensorType.fromSpec("tensor(x[42])");
private static final TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])");
private static final TensorType tt_dense_dvector_2 = TensorType.fromSpec("tensor(x[2])");
@@ -96,6 +70,41 @@ public class ValidateNearestNeighborTestCase {
private static final TensorType tt_dense_matrix_xy = TensorType.fromSpec("tensor(x[3],y[1])");
private static final TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})");
+ private Tensor makeTensor(TensorType tensorType) {
+ return makeTensor(tensorType, 3);
+ }
+
+ private Tensor makeTensor(TensorType tensorType, int numCells) {
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
+ double dv = 1.0;
+ String tensorDimension = "x";
+ for (long label = 0; label < numCells; label++) {
+ tensorBuilder.cell()
+ .label(tensorDimension, label)
+ .value(dv);
+ dv += 1.0;
+ }
+ return tensorBuilder.build();
+ }
+
+ private Tensor makeMatrix(TensorType tensorType) {
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
+ double dv = 1.0;
+ String tensorDimension = "x";
+ for (long label = 0; label < 3; label++) {
+ tensorBuilder.cell()
+ .label("y", 0L)
+ .label(tensorDimension, label)
+ .value(dv);
+ dv += 1.0;
+ }
+ return tensorBuilder.build();
+ }
+
+ private String makeQuery(String attributeTensor, String queryTensor) {
+ return "select * from sources * where [{targetHits:1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ")";
+ }
+
@Test
void testValidQueryDoubleVectors() {
String q = makeQuery("dvector", "qvector");
@@ -141,14 +150,16 @@ public class ValidateNearestNeighborTestCase {
}
static String desc(String field, String qt, int th, String errmsg) {
- return "NEAREST_NEIGHBOR {" +
- "field=" + field +
- ",queryTensorName=" + qt +
- ",hnsw.exploreAdditionalHits=0" +
- ",distanceThreshold=Infinity" +
- ",approximate=true" +
- ",targetHits=" + th +
- "} " + errmsg;
+ StringBuilder r = new StringBuilder();
+ r.append("NEAREST_NEIGHBOR {");
+ r.append("field=").append(field);
+ r.append(",queryTensorName=").append(qt);
+ r.append(",hnsw.exploreAdditionalHits=0");
+ r.append(",distanceThreshold=Infinity");
+ r.append(",approximate=true");
+ r.append(",targetHits=").append(th);
+ r.append("} ").append(errmsg);
+ return r.toString();
}
@Test
@@ -221,111 +232,14 @@ public class ValidateNearestNeighborTestCase {
assertErrMsg(desc("matrix", "qvector", 1, "field type tensor(x[3],y[1]) is not supported by nearest neighbor searcher"), r);
}
- @Test
- void testWithQueryProfileArgument() {
- var embedder = new MockEmbedder("test text",
- Language.UNKNOWN,
- Tensor.from("tensor<float>(x[3]):[1.0, 2.0, 3.0]"));
- var registry = new QueryProfileRegistry();
- var profile = new QueryProfile("test");
- profile.set("ranking.features.query(qvector)", "embed(@foo)", registry);
- registry.register(profile);
- var queryString = makeQuery("fvector", "qvector");
- var query = new Query.Builder()
- .setSchemaInfo(createSchemaInfo())
- .setQueryProfile(registry.compile().findQueryProfile("test"))
- .setEmbedder(embedder)
- .setRequestMap(Map.of("foo", "test text"))
- .build();
- setYqlQuery(query, queryString);
- var result = doSearch(searcher, query);
- assertNull(result.hits().getError());
- }
-
- private Tensor makeTensor(TensorType tensorType) {
- return makeTensor(tensorType, 3);
- }
-
- private Tensor makeTensor(TensorType tensorType, int numCells) {
- Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
- double dv = 1.0;
- String tensorDimension = "x";
- for (long label = 0; label < numCells; label++) {
- tensorBuilder.cell()
- .label(tensorDimension, label)
- .value(dv);
- dv += 1.0;
- }
- return tensorBuilder.build();
- }
-
- private Tensor makeMatrix(TensorType tensorType) {
- Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
- double dv = 1.0;
- String tensorDimension = "x";
- for (long label = 0; label < 3; label++) {
- tensorBuilder.cell()
- .label("y", 0L)
- .label(tensorDimension, label)
- .value(dv);
- dv += 1.0;
- }
- return tensorBuilder.build();
- }
-
- private String makeQuery(String attributeTensor, String queryTensor) {
- return "select * from sources * where [{targetHits:1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ")";
- }
-
private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Tensor qTensor) {
- return doSearch(searcher, setYqlQuery(new Query(), yqlQuery), qTensor);
- }
-
- private static Query setYqlQuery(Query query, String yqlQuery) {
QueryTree queryTree = new YqlParser(new ParserEnvironment()).parse(new Parsable().setQuery(yqlQuery));
+ Query query = new Query();
query.getModel().getQueryTree().setRoot(queryTree.getRoot());
- return query;
- }
-
- private static Result doSearch(ValidateNearestNeighborSearcher searcher, Query query, Tensor qTensor) {
query.getRanking().getFeatures().put("query(qvector)", qTensor);
- return doSearch(searcher, query);
- }
-
- private static Result doSearch(ValidateNearestNeighborSearcher searcher, Query query) {
SearchDefinition searchDefinition = new SearchDefinition("document");
IndexFacts indexFacts = new IndexFacts(new IndexModel(searchDefinition));
return new Execution(searcher, Execution.Context.createContextStub(indexFacts)).search(query);
}
- private static final class MockEmbedder implements Embedder {
-
- private final String expectedText;
- private final Language expectedLanguage;
- private final Tensor tensorToReturn;
-
- public MockEmbedder(String expectedText,
- Language expectedLanguage,
- Tensor tensorToReturn) {
- this.expectedText = expectedText;
- this.expectedLanguage = expectedLanguage;
- this.tensorToReturn = tensorToReturn;
- }
-
- @Override
- public List<Integer> embed(String text, Embedder.Context context) {
- fail("Unexpected call");
- return null;
- }
-
- @Override
- public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
- assertEquals(expectedText, text);
- assertEquals(expectedLanguage, context.getLanguage());
- assertEquals(tensorToReturn.type(), tensorType);
- return tensorToReturn;
- }
-
- }
-
}