diff options
author | Arne Juul <arnej@verizonmedia.com> | 2019-11-20 11:15:58 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2019-11-20 11:50:09 +0000 |
commit | d67eae3b41eda7e90ad5b31255b8d3c7b12eec48 (patch) | |
tree | 7ab6d95b3384564c5c3ee82476bbbae7d8ae12f5 /container-search | |
parent | 7d8c0ced9bbb71d00a9a6eae2f070ba685f20164 (diff) |
add unit test + minor fixes for ValidateNearestNeighborSearcher
Diffstat (limited to 'container-search')
2 files changed, 192 insertions, 4 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java index dfd8205cd05..25c65783821 100644 --- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java @@ -7,6 +7,7 @@ import com.google.common.annotations.Beta; import com.yahoo.container.QrSearchersConfig; import com.yahoo.prelude.query.Item; import com.yahoo.prelude.query.NearestNeighborItem; +import com.yahoo.prelude.query.QueryCanonicalizer; import com.yahoo.prelude.query.ToolBox; import com.yahoo.search.Query; import com.yahoo.search.Result; @@ -17,12 +18,17 @@ import com.yahoo.search.searchchain.Execution; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.AttributesConfig; +import com.yahoo.yolean.chain.After; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +// This depends on tensors in query.getRanking which are moved to rank.properties during query.prepare() +// Query.prepare is done at the same time as canonicalization (by GroupingExecutor), so use that constraint. +@After(QueryCanonicalizer.queryCanonicalization) + /** * Validates any NearestNeighborItem query items. * @@ -79,6 +85,15 @@ public class ValidateNearestNeighborSearcher extends Searcher { errorMessage = Optional.of(ErrorMessage.createIllegalQuery(description)); } + private static boolean isDenseVector(TensorType tt) { + List<TensorType.Dimension> dims = tt.dimensions(); + if (dims.size() != 1) return false; + for (var d : dims) { + if (d.type() != TensorType.Dimension.Type.indexedBound) return false; + } + return true; + } + private void validate(NearestNeighborItem item) { int target = item.getTargetNumHits(); if (target < 1) { @@ -88,16 +103,17 @@ public class ValidateNearestNeighborSearcher extends Searcher { String qprop = item.getQueryTensorName(); List<Object> rankPropValList = rankProperties.asMap().get(qprop); if (rankPropValList == null) { - setError(item.toString() + " query property not found"); + setError(item.toString() + " query tensor not found"); return; } if (rankPropValList.size() != 1) { - setError(item.toString() + " query property does not have a single value"); + setError(item.toString() + " query tensor does not have a single value"); return; } Object rankPropValue = rankPropValList.get(0); if (! (rankPropValue instanceof Tensor)) { - setError(item.toString() + " query property should be a tensor, was: "+rankPropValue); + setError(item.toString() + " query tensor should be a tensor, was: "+ + (rankPropValue == null ? "null" : rankPropValue.getClass().toString())); return; } Tensor qTensor = (Tensor)rankPropValue; @@ -111,7 +127,11 @@ public class ValidateNearestNeighborSearcher extends Searcher { return; } if (! fTensorType.equals(qTensorType)) { - setError(item.toString() + " field type "+fTensorType+" does not match query property type "+qTensorType); + setError(item.toString() + " field type "+fTensorType+" does not match query tensor type "+qTensorType); + return; + } + if (! isDenseVector(fTensorType)) { + setError(item.toString() + " tensor type "+fTensorType+" is not a dense vector"); return; } } else { diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java new file mode 100644 index 00000000000..cd1849a3586 --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -0,0 +1,168 @@ + +package com.yahoo.prelude.searcher; + +import com.google.common.util.concurrent.MoreExecutors; + +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.RawSource; +import com.yahoo.language.Linguistics; +import com.yahoo.language.simple.SimpleLinguistics; +import com.yahoo.prelude.Index; +import com.yahoo.prelude.IndexFacts; +import com.yahoo.prelude.IndexModel; +import com.yahoo.prelude.SearchDefinition; +import com.yahoo.search.Query; +import com.yahoo.search.query.parser.Parsable; +import com.yahoo.search.query.parser.ParserEnvironment; +import com.yahoo.search.query.QueryTree; +import com.yahoo.search.rendering.RendererRegistry; +import com.yahoo.search.Result; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.searchchain.Execution; +import com.yahoo.search.Searcher; +import com.yahoo.search.searchers.ValidateNearestNeighborSearcher; +import com.yahoo.search.yql.YqlParser; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.AttributesConfig; + +import java.util.*; + +import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** + * @author arnej + */ +public class ValidateNearestNeighborTestCase { + + ValidateNearestNeighborSearcher searcher; + + public ValidateNearestNeighborTestCase() { + searcher = new ValidateNearestNeighborSearcher( + ConfigGetter.getConfig(AttributesConfig.class, + "raw:", + new RawSource("attribute[4]\n" + + "attribute[0].name simple\n" + + "attribute[0].datatype INT32\n" + + "attribute[1].name dvector\n" + + "attribute[1].datatype TENSOR\n" + + "attribute[1].tensortype tensor(x[3])\n" + + "attribute[2].name fvector\n" + + "attribute[2].datatype TENSOR\n" + + "attribute[2].tensortype tensor<float>(x[3])\n" + + "attribute[3].name sparse\n" + + "attribute[3].datatype TENSOR\n" + + "attribute[3].tensortype tensor(x{})" + ))); + } + + private static TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])"); + private static TensorType tt_dense_fvector_3 = TensorType.fromSpec("tensor<float>(x[3])"); + private static TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})"); + + private Tensor makeTensor(TensorType tensorType) { + Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); + double dv = 1.0; + String tensorDimension = "x"; + for (int label = 0; label < 3; label++) { + tensorBuilder.cell().label(tensorDimension, Integer.toString(label)).value(dv); + dv += 1.0; + } + return tensorBuilder.build(); + } + + @Test + public void testValidQueryDV() { + String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);"; + Tensor t = makeTensor(tt_dense_dvector_3); + Result r = doSearch(searcher, q, t); + assertNull(r.hits().getError()); + } + + @Test + public void testValidQueryFV() { + String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);"; + Tensor t = makeTensor(tt_dense_fvector_3); + Result r = doSearch(searcher, q, t); + assertNull(r.hits().getError()); + } + + private static void assertErrMsg(String message, Result r) { + assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError()); + } + + @Test + public void testMissingTargetNumHits() { + String q = "select * from sources * where nearestNeighbor(dvector,qvector);"; + Tensor t = makeTensor(tt_dense_dvector_3); + Result r = doSearch(searcher, q, t); + assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=0} has invalid targetNumHits", r); + } + + @Test + public void testMissingQueryTensor() { + String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,foo);"; + Tensor t = makeTensor(tt_dense_dvector_3); + Result r = doSearch(searcher, q, t); + assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r); + } + + @Test + public void testQueryTensorWrongType() { + String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);"; + Result r = doSearch(searcher, q, "tensor string"); + assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r); + r = doSearch(searcher, q, null); + assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: null", r); + } + + @Test + public void testWrongTensorType() { + String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);"; + Tensor t = makeTensor(tt_dense_dvector_3); + Result r = doSearch(searcher, q, t); + assertErrMsg("NEAREST_NEIGHBOR {field=fvector,queryTensorName=qvector,targetNumHits=1} field type tensor<float>(x[3]) does not match query tensor type tensor(x[3])", r); + } + + @Test + public void testNotAttribute() { + String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(foo,qvector);"; + Tensor t = makeTensor(tt_dense_dvector_3); + Result r = doSearch(searcher, q, t); + assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r); + } + + @Test + public void testWrongAttributeType() { + String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(simple,qvector);"; + Tensor t = makeTensor(tt_dense_dvector_3); + Result r = doSearch(searcher, q, t); + assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r); + } + + @Test + public void testSparseTensor() { + String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(sparse,qvector);"; + Tensor t = makeTensor(tt_sparse_vector_x); + Result r = doSearch(searcher, q, t); + assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r); + } + + private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Object qTensor) { + QueryTree queryTree = new YqlParser(new ParserEnvironment()).parse(new Parsable().setQuery(yqlQuery)); + Query query = new Query(); + query.getModel().getQueryTree().setRoot(queryTree.getRoot()); + query.getRanking().getProperties().put("qvector", qTensor); + TreeMap<String, List<String>> masterClusters = new TreeMap<>(); + masterClusters.put("cluster", Arrays.asList("document")); + SearchDefinition searchDefinition = new SearchDefinition("document"); + Map<String, SearchDefinition> searchDefinitionMap = new HashMap<>(); + searchDefinitionMap.put("document", searchDefinition); + IndexFacts indexFacts = new IndexFacts(new IndexModel(masterClusters, searchDefinitionMap, searchDefinition)); + Execution.Context context = new Execution.Context(null, indexFacts, null, new RendererRegistry(MoreExecutors.directExecutor()), new SimpleLinguistics()); + return new Execution(searcher, context).search(query); + } + +} |