diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2020-03-11 14:12:24 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2020-03-11 14:12:24 +0100 |
commit | 9066658d74a995ae960b0bee919f0a698bb720b8 (patch) | |
tree | 7a828ca0fc85579c3922f76636c6f58a68701677 | |
parent | 3da641f5e5295f0f67ebf062618464327c2d7f40 (diff) |
Validate before prepare to avoid depending on the properties API
2 files changed, 24 insertions, 87 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java index 0f7f163dce0..76b8c1ef8a2 100644 --- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java @@ -6,24 +6,18 @@ import com.google.common.annotations.Beta; import com.yahoo.prelude.query.Item; import com.yahoo.prelude.query.NearestNeighborItem; -import com.yahoo.prelude.query.QueryCanonicalizer; import com.yahoo.prelude.query.ToolBox; -import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; -import com.yahoo.search.query.profile.QueryProfileProperties; -import com.yahoo.search.query.profile.compiled.CompiledQueryProfile; -import com.yahoo.search.query.profile.types.FieldDescription; -import com.yahoo.search.query.profile.types.QueryProfileFieldType; -import com.yahoo.search.query.profile.types.QueryProfileType; +import com.yahoo.search.grouping.vespa.GroupingExecutor; import com.yahoo.search.query.ranking.RankProperties; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.searchchain.Execution; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.AttributesConfig; -import com.yahoo.yolean.chain.After; +import com.yahoo.yolean.chain.Before; import java.util.HashMap; import java.util.List; @@ -36,9 +30,7 @@ import java.util.Optional; * @author arnej */ @Beta -// This depends on tensors in query.getRanking which are moved to rank.properties during query.prepare() -// Query.prepare is done at the same time as canonicalization (by GroupingExecutor), so use that constraint. -@After(QueryCanonicalizer.queryCanonicalization) +@Before(GroupingExecutor.COMPONENT_NAME) // Must happen before query.prepare() public class ValidateNearestNeighborSearcher extends Searcher { private Map<String, TensorType> validAttributes = new HashMap<>(); @@ -104,76 +96,30 @@ public class ValidateNearestNeighborSearcher extends Searcher { /** Returns an error message if this is invalid, or null if it is valid */ private String validate(NearestNeighborItem item) { - int target = item.getTargetNumHits(); - if (target < 1) return item + " has invalid targetNumHits"; - - List<Object> rankPropValList = rankProperties.asMap().get(item.getQueryTensorName()); - if (rankPropValList == null) return item + " query tensor not found"; - if (rankPropValList.size() != 1) return item + " query tensor does not have a single value"; - - Object rankPropValue = rankPropValList.get(0); - if (! (rankPropValue instanceof Tensor)) { - return item + " expected a query tensor but got " + - (rankPropValue == null ? "null" : rankPropValue.getClass()) + - resolvedTypeInfo(); - } + if (item.getTargetNumHits() < 1) + return item + " has invalid targetNumHits " + item.getTargetNumHits() + ": Must be >= 1"; - String field = item.getIndexName(); - if ( ! validAttributes.containsKey(field)) return item + " field is not an attribute"; + String queryFeatureName = "query(" + item.getQueryTensorName() + ")"; + Optional<Tensor> queryTensor = query.getRanking().getFeatures().getTensor(queryFeatureName); + if (queryTensor.isEmpty()) + return item + " requires a tensor rank feature " + queryFeatureName + " but this is not present"; + + if ( ! validAttributes.containsKey(item.getIndexName())) + return item + " field is not an attribute"; + TensorType fieldType = validAttributes.get(item.getIndexName()); + if (fieldType == null) return item + " field is not a tensor"; + if ( ! isDenseVector(fieldType)) + return item + " tensor type " + fieldType + " is not a dense vector"; + + if ( ! isCompatible(fieldType, queryTensor.get().type())) + return item + " field type " + fieldType + " does not match query type " + queryTensor.get().type(); - TensorType fTensorType = validAttributes.get(field); - TensorType qTensorType = ((Tensor)rankPropValue).type(); - if (fTensorType == null) return item + " field is not a tensor"; - if ( ! isCompatible(fTensorType, qTensorType)) { - return item + " field type " + fTensorType + " does not match query tensor type " + qTensorType; - } - if (! isDenseVector(fTensorType)) return item + " tensor type " + fTensorType+" is not a dense vector"; return null; } @Override public void onExit() {} - // TODO: Remove - private String resolvedTypeInfo() { - StringBuilder b = new StringBuilder(); - QueryProfileProperties properties = query.properties().getInstance(QueryProfileProperties.class); - if (properties == null) return b.toString(); - CompiledQueryProfile profile = properties.getQueryProfile(); - b.append(", profile: ").append(profile); - - CompoundName name = new CompoundName("ranking.features.query(q_vec)"); - - if ( ! profile.getTypes().isEmpty()) { - QueryProfileType type = null; - for (int i = 0; i < name.size(); i++) { - if (type == null) // We're on the first iteration, or no type is explicitly specified - type = profile.getType(name.first(i), new HashMap<>()); - if (type == null) continue; - String localName = name.get(i); - FieldDescription fieldDescription = type.getField(localName); - if (fieldDescription == null && type.isStrict()) - throw new IllegalArgumentException("'" + localName + "' is not declared in " + type + ", and the type is strict"); - - // TODO: In addition to strictness, check legality along the way - - if (fieldDescription != null) { - if (i == name.size() - 1) { // at the end of the path, check the assignment type - b.append(", field description: ").append(fieldDescription); - } else if (fieldDescription.getType() instanceof QueryProfileFieldType) { - // If a type is specified, use that instead of the type implied by the name - type = ((QueryProfileFieldType) fieldDescription.getType()).getQueryProfileType(); - } - } - - } - } - else { - b.append(", profile types is empty"); - } - return b.toString(); - } - } } diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index 0abb8d281b4..2c849a9b52c 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -149,7 +149,7 @@ public class ValidateNearestNeighborTestCase { String q = "select * from sources * where nearestNeighbor(dvector,qvector);"; Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg(desc("dvector", "qvector", 0, "has invalid targetNumHits"), r); + assertErrMsg(desc("dvector", "qvector", 0, "has invalid targetNumHits 0: Must be >= 1"), r); } @Test @@ -157,16 +157,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("dvector", "foo"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg(desc("dvector", "foo", 1, "query tensor not found"), r); - } - - @Test - public void testQueryTensorWrongType() { - String q = makeQuery("dvector", "qvector"); - Result r = doSearch(searcher, q, "tensor string"); - assertErrMsg(desc("dvector", "qvector", 1, "expected a query tensor but got class java.lang.String"), r); - r = doSearch(searcher, q, null); - assertErrMsg(desc("dvector", "qvector", 1, "expected a query tensor but got null"), r); + assertErrMsg(desc("dvector", "foo", 1, "requires a tensor rank feature query(foo) but this is not present"), r); } @Test @@ -174,7 +165,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("dvector", "qvector"); Tensor t = makeTensor(tt_dense_dvector_2, 2); Result r = doSearch(searcher, q, t); - assertErrMsg(desc("dvector", "qvector", 1, "field type tensor(x[3]) does not match query tensor type tensor(x[2])"), r); + assertErrMsg(desc("dvector", "qvector", 1, "field type tensor(x[3]) does not match query type tensor(x[2])"), r); } @Test @@ -209,11 +200,11 @@ public class ValidateNearestNeighborTestCase { assertErrMsg(desc("matrix", "qvector", 1, "tensor type tensor(x[3],y[1]) is not a dense vector"), r); } - private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Object qTensor) { + private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Tensor qTensor) { QueryTree queryTree = new YqlParser(new ParserEnvironment()).parse(new Parsable().setQuery(yqlQuery)); Query query = new Query(); query.getModel().getQueryTree().setRoot(queryTree.getRoot()); - query.getRanking().getProperties().put("qvector", qTensor); + query.getRanking().getFeatures().put("query(qvector)", qTensor); SearchDefinition searchDefinition = new SearchDefinition("document"); IndexFacts indexFacts = new IndexFacts(new IndexModel(searchDefinition)); Execution.Context context = new Execution.Context(null, indexFacts, null, new RendererRegistry(MoreExecutors.directExecutor()), new SimpleLinguistics()); |