summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2019-11-20 11:15:58 +0000
committerArne Juul <arnej@verizonmedia.com>2019-11-20 11:50:09 +0000
commitd67eae3b41eda7e90ad5b31255b8d3c7b12eec48 (patch)
tree7ab6d95b3384564c5c3ee82476bbbae7d8ae12f5 /container-search
parent7d8c0ced9bbb71d00a9a6eae2f070ba685f20164 (diff)
add unit test + minor fixes for ValidateNearestNeighborSearcher
Diffstat (limited to 'container-search')
-rw-r--r--container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java28
-rw-r--r--container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java168
2 files changed, 192 insertions, 4 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
index dfd8205cd05..25c65783821 100644
--- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
+++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
@@ -7,6 +7,7 @@ import com.google.common.annotations.Beta;
import com.yahoo.container.QrSearchersConfig;
import com.yahoo.prelude.query.Item;
import com.yahoo.prelude.query.NearestNeighborItem;
+import com.yahoo.prelude.query.QueryCanonicalizer;
import com.yahoo.prelude.query.ToolBox;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
@@ -17,12 +18,17 @@ import com.yahoo.search.searchchain.Execution;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.AttributesConfig;
+import com.yahoo.yolean.chain.After;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+// This depends on tensors in query.getRanking which are moved to rank.properties during query.prepare()
+// Query.prepare is done at the same time as canonicalization (by GroupingExecutor), so use that constraint.
+@After(QueryCanonicalizer.queryCanonicalization)
+
/**
* Validates any NearestNeighborItem query items.
*
@@ -79,6 +85,15 @@ public class ValidateNearestNeighborSearcher extends Searcher {
errorMessage = Optional.of(ErrorMessage.createIllegalQuery(description));
}
+ private static boolean isDenseVector(TensorType tt) {
+ List<TensorType.Dimension> dims = tt.dimensions();
+ if (dims.size() != 1) return false;
+ for (var d : dims) {
+ if (d.type() != TensorType.Dimension.Type.indexedBound) return false;
+ }
+ return true;
+ }
+
private void validate(NearestNeighborItem item) {
int target = item.getTargetNumHits();
if (target < 1) {
@@ -88,16 +103,17 @@ public class ValidateNearestNeighborSearcher extends Searcher {
String qprop = item.getQueryTensorName();
List<Object> rankPropValList = rankProperties.asMap().get(qprop);
if (rankPropValList == null) {
- setError(item.toString() + " query property not found");
+ setError(item.toString() + " query tensor not found");
return;
}
if (rankPropValList.size() != 1) {
- setError(item.toString() + " query property does not have a single value");
+ setError(item.toString() + " query tensor does not have a single value");
return;
}
Object rankPropValue = rankPropValList.get(0);
if (! (rankPropValue instanceof Tensor)) {
- setError(item.toString() + " query property should be a tensor, was: "+rankPropValue);
+ setError(item.toString() + " query tensor should be a tensor, was: "+
+ (rankPropValue == null ? "null" : rankPropValue.getClass().toString()));
return;
}
Tensor qTensor = (Tensor)rankPropValue;
@@ -111,7 +127,11 @@ public class ValidateNearestNeighborSearcher extends Searcher {
return;
}
if (! fTensorType.equals(qTensorType)) {
- setError(item.toString() + " field type "+fTensorType+" does not match query property type "+qTensorType);
+ setError(item.toString() + " field type "+fTensorType+" does not match query tensor type "+qTensorType);
+ return;
+ }
+ if (! isDenseVector(fTensorType)) {
+ setError(item.toString() + " tensor type "+fTensorType+" is not a dense vector");
return;
}
} else {
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
new file mode 100644
index 00000000000..cd1849a3586
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -0,0 +1,168 @@
+
+package com.yahoo.prelude.searcher;
+
+import com.google.common.util.concurrent.MoreExecutors;
+
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.RawSource;
+import com.yahoo.language.Linguistics;
+import com.yahoo.language.simple.SimpleLinguistics;
+import com.yahoo.prelude.Index;
+import com.yahoo.prelude.IndexFacts;
+import com.yahoo.prelude.IndexModel;
+import com.yahoo.prelude.SearchDefinition;
+import com.yahoo.search.Query;
+import com.yahoo.search.query.parser.Parsable;
+import com.yahoo.search.query.parser.ParserEnvironment;
+import com.yahoo.search.query.QueryTree;
+import com.yahoo.search.rendering.RendererRegistry;
+import com.yahoo.search.Result;
+import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.searchchain.Execution;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.searchers.ValidateNearestNeighborSearcher;
+import com.yahoo.search.yql.YqlParser;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.AttributesConfig;
+
+import java.util.*;
+
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/**
+ * @author arnej
+ */
+public class ValidateNearestNeighborTestCase {
+
+ ValidateNearestNeighborSearcher searcher;
+
+ public ValidateNearestNeighborTestCase() {
+ searcher = new ValidateNearestNeighborSearcher(
+ ConfigGetter.getConfig(AttributesConfig.class,
+ "raw:",
+ new RawSource("attribute[4]\n" +
+ "attribute[0].name simple\n" +
+ "attribute[0].datatype INT32\n" +
+ "attribute[1].name dvector\n" +
+ "attribute[1].datatype TENSOR\n" +
+ "attribute[1].tensortype tensor(x[3])\n" +
+ "attribute[2].name fvector\n" +
+ "attribute[2].datatype TENSOR\n" +
+ "attribute[2].tensortype tensor<float>(x[3])\n" +
+ "attribute[3].name sparse\n" +
+ "attribute[3].datatype TENSOR\n" +
+ "attribute[3].tensortype tensor(x{})"
+ )));
+ }
+
+ private static TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])");
+ private static TensorType tt_dense_fvector_3 = TensorType.fromSpec("tensor<float>(x[3])");
+ private static TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})");
+
+ private Tensor makeTensor(TensorType tensorType) {
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
+ double dv = 1.0;
+ String tensorDimension = "x";
+ for (int label = 0; label < 3; label++) {
+ tensorBuilder.cell().label(tensorDimension, Integer.toString(label)).value(dv);
+ dv += 1.0;
+ }
+ return tensorBuilder.build();
+ }
+
+ @Test
+ public void testValidQueryDV() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertNull(r.hits().getError());
+ }
+
+ @Test
+ public void testValidQueryFV() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);";
+ Tensor t = makeTensor(tt_dense_fvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertNull(r.hits().getError());
+ }
+
+ private static void assertErrMsg(String message, Result r) {
+ assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError());
+ }
+
+ @Test
+ public void testMissingTargetNumHits() {
+ String q = "select * from sources * where nearestNeighbor(dvector,qvector);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=0} has invalid targetNumHits", r);
+ }
+
+ @Test
+ public void testMissingQueryTensor() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,foo);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r);
+ }
+
+ @Test
+ public void testQueryTensorWrongType() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);";
+ Result r = doSearch(searcher, q, "tensor string");
+ assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r);
+ r = doSearch(searcher, q, null);
+ assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: null", r);
+ }
+
+ @Test
+ public void testWrongTensorType() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=fvector,queryTensorName=qvector,targetNumHits=1} field type tensor<float>(x[3]) does not match query tensor type tensor(x[3])", r);
+ }
+
+ @Test
+ public void testNotAttribute() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(foo,qvector);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r);
+ }
+
+ @Test
+ public void testWrongAttributeType() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(simple,qvector);";
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r);
+ }
+
+ @Test
+ public void testSparseTensor() {
+ String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(sparse,qvector);";
+ Tensor t = makeTensor(tt_sparse_vector_x);
+ Result r = doSearch(searcher, q, t);
+ assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r);
+ }
+
+ private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Object qTensor) {
+ QueryTree queryTree = new YqlParser(new ParserEnvironment()).parse(new Parsable().setQuery(yqlQuery));
+ Query query = new Query();
+ query.getModel().getQueryTree().setRoot(queryTree.getRoot());
+ query.getRanking().getProperties().put("qvector", qTensor);
+ TreeMap<String, List<String>> masterClusters = new TreeMap<>();
+ masterClusters.put("cluster", Arrays.asList("document"));
+ SearchDefinition searchDefinition = new SearchDefinition("document");
+ Map<String, SearchDefinition> searchDefinitionMap = new HashMap<>();
+ searchDefinitionMap.put("document", searchDefinition);
+ IndexFacts indexFacts = new IndexFacts(new IndexModel(masterClusters, searchDefinitionMap, searchDefinition));
+ Execution.Context context = new Execution.Context(null, indexFacts, null, new RendererRegistry(MoreExecutors.directExecutor()), new SimpleLinguistics());
+ return new Execution(searcher, context).search(query);
+ }
+
+}