aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2019-11-20 11:15:58 +0000
committerArne Juul <arnej@verizonmedia.com>2019-11-20 11:50:09 +0000
commitd67eae3b41eda7e90ad5b31255b8d3c7b12eec48 (patch)
tree7ab6d95b3384564c5c3ee82476bbbae7d8ae12f5 /container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
parent7d8c0ced9bbb71d00a9a6eae2f070ba685f20164 (diff)
add unit test + minor fixes for ValidateNearestNeighborSearcher
Diffstat (limited to 'container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java')
-rw-r--r--container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java168
1 files changed, 168 insertions, 0 deletions
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
new file mode 100644
index 00000000000..cd1849a3586
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -0,0 +1,168 @@
+
+package com.yahoo.prelude.searcher;
+
+import com.google.common.util.concurrent.MoreExecutors;
+
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.RawSource;
+import com.yahoo.language.Linguistics;
+import com.yahoo.language.simple.SimpleLinguistics;
+import com.yahoo.prelude.Index;
+import com.yahoo.prelude.IndexFacts;
+import com.yahoo.prelude.IndexModel;
+import com.yahoo.prelude.SearchDefinition;
+import com.yahoo.search.Query;
+import com.yahoo.search.query.parser.Parsable;
+import com.yahoo.search.query.parser.ParserEnvironment;
+import com.yahoo.search.query.QueryTree;
+import com.yahoo.search.rendering.RendererRegistry;
+import com.yahoo.search.Result;
+import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.searchchain.Execution;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.searchers.ValidateNearestNeighborSearcher;
+import com.yahoo.search.yql.YqlParser;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.AttributesConfig;
+
+import java.util.*;
+
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/**
+ * @author arnej
+ */
+public class ValidateNearestNeighborTestCase {
+
+ ValidateNearestNeighborSearcher searcher;
+
+ public ValidateNearestNeighborTestCase() {
+ searcher = new ValidateNearestNeighborSearcher(
+ ConfigGetter.getConfig(AttributesConfig.class,
+ "raw:",
+ new RawSource("attribute[4]\n" +
+ "attribute[0].name simple\n" +
+ "attribute[0].datatype INT32\n" +
+ "attribute[1].name dvector\n" +
+ "attribute[1].datatype TENSOR\n" +
+ "attribute[1].tensortype tensor(x[3])\n" +
+ "attribute[2].name fvector\n" +
+ "attribute[2].datatype TENSOR\n" +
+ "attribute[2].tensortype tensor<float>(x[3])\n" +
+ "attribute[3].name sparse\n" +
+ "attribute[3].datatype TENSOR\n" +
+ "attribute[3].tensortype tensor(x{})"
+ )));
+ }
+
+ private static TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])");
+ private static TensorType tt_dense_fvector_3 = TensorType.fromSpec("tensor<float>(x[3])");
+ private static TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})");
+
+ private Tensor makeTensor(TensorType tensorType) {
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
+ double dv = 1.0;
+ String tensorDimension = "x";
+ for (int label = 0; label < 3; label++) {
+ tensorBuilder.cell().label(tensorDimension, Integer.toString(label)).value(dv);
+ dv += 1.0;
+ }
+ return tensorBuilder.build();
+ }
+
+ @Test
+ public void testValidQueryDV() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertNull(r.hits().getError());
+ }
+
+ @Test
+ public void testValidQueryFV() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);";
+ Tensor t = makeTensor(tt_dense_fvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertNull(r.hits().getError());
+ }
+
+ private static void assertErrMsg(String message, Result r) {
+ assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError());
+ }
+
+ @Test
+ public void testMissingTargetNumHits() {
+ String q = "select * from sources * where nearestNeighbor(dvector,qvector);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=0} has invalid targetNumHits", r);
+ }
+
+ @Test
+ public void testMissingQueryTensor() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,foo);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r);
+ }
+
+ @Test
+ public void testQueryTensorWrongType() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);";
+ Result r = doSearch(searcher, q, "tensor string");
+ assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r);
+ r = doSearch(searcher, q, null);
+ assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: null", r);
+ }
+
+ @Test
+ public void testWrongTensorType() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=fvector,queryTensorName=qvector,targetNumHits=1} field type tensor<float>(x[3]) does not match query tensor type tensor(x[3])", r);
+ }
+
+ @Test
+ public void testNotAttribute() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(foo,qvector);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r);
+ }
+
+ @Test
+ public void testWrongAttributeType() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(simple,qvector);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r);
+ }
+
+ @Test
+ public void testSparseTensor() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(sparse,qvector);";
+ Tensor t = makeTensor(tt_sparse_vector_x);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r);
+ }
+
+ private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Object qTensor) {
+ QueryTree queryTree = new YqlParser(new ParserEnvironment()).parse(new Parsable().setQuery(yqlQuery));
+ Query query = new Query();
+ query.getModel().getQueryTree().setRoot(queryTree.getRoot());
+ query.getRanking().getProperties().put("qvector", qTensor);
+ TreeMap<String, List<String>> masterClusters = new TreeMap<>();
+ masterClusters.put("cluster", Arrays.asList("document"));
+ SearchDefinition searchDefinition = new SearchDefinition("document");
+ Map<String, SearchDefinition> searchDefinitionMap = new HashMap<>();
+ searchDefinitionMap.put("document", searchDefinition);
+ IndexFacts indexFacts = new IndexFacts(new IndexModel(masterClusters, searchDefinitionMap, searchDefinition));
+ Execution.Context context = new Execution.Context(null, indexFacts, null, new RendererRegistry(MoreExecutors.directExecutor()), new SimpleLinguistics());
+ return new Execution(searcher, context).search(query);
+ }
+
+}