diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/searchers')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java | 72 |
1 files changed, 25 insertions, 47 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java index 8cae081cada..ef46ee5e5ea 100644 --- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java @@ -4,7 +4,6 @@ package com.yahoo.search.searchers; import com.google.common.annotations.Beta; -import com.yahoo.container.QrSearchersConfig; import com.yahoo.prelude.query.Item; import com.yahoo.prelude.query.NearestNeighborItem; import com.yahoo.prelude.query.QueryCanonicalizer; @@ -25,16 +24,15 @@ import java.util.List; import java.util.Map; import java.util.Optional; -// This depends on tensors in query.getRanking which are moved to rank.properties during query.prepare() -// Query.prepare is done at the same time as canonicalization (by GroupingExecutor), so use that constraint. -@After(QueryCanonicalizer.queryCanonicalization) - /** * Validates any NearestNeighborItem query items. * * @author arnej */ @Beta +// This depends on tensors in query.getRanking which are moved to rank.properties during query.prepare() +// Query.prepare is done at the same time as canonicalization (by GroupingExecutor), so use that constraint. +@After(QueryCanonicalizer.queryCanonicalization) public class ValidateNearestNeighborSearcher extends Searcher { private Map<String, TensorType> validAttributes = new HashMap<>(); @@ -76,15 +74,13 @@ public class ValidateNearestNeighborSearcher extends Searcher { @Override public boolean visit(Item item) { if (item instanceof NearestNeighborItem) { - validate((NearestNeighborItem) item); + String error = validate((NearestNeighborItem) item); + if (error != null) + errorMessage = Optional.of(ErrorMessage.createIllegalQuery(error)); } return true; } - private void setError(String description) { - errorMessage = Optional.of(ErrorMessage.createIllegalQuery(description)); - } - private static boolean isCompatible(TensorType lhs, TensorType rhs) { return lhs.dimensions().equals(rhs.dimensions()); } @@ -98,50 +94,32 @@ public class ValidateNearestNeighborSearcher extends Searcher { return true; } - private void validate(NearestNeighborItem item) { + /** Returns an error message if this is invalid, or null if it is valid */ + private String validate(NearestNeighborItem item) { int target = item.getTargetNumHits(); - if (target < 1) { - setError(item.toString() + " has invalid targetNumHits"); - return; - } - String qprop = item.getQueryTensorName(); - List<Object> rankPropValList = rankProperties.asMap().get(qprop); - if (rankPropValList == null) { - setError(item.toString() + " query tensor not found"); - return; - } - if (rankPropValList.size() != 1) { - setError(item.toString() + " query tensor does not have a single value"); - return; - } + if (target < 1) return item + " has invalid targetNumHits"; + + List<Object> rankPropValList = rankProperties.asMap().get(item.getQueryTensorName()); + if (rankPropValList == null) return item + " query tensor not found"; + if (rankPropValList.size() != 1) return item + " query tensor does not have a single value"; + Object rankPropValue = rankPropValList.get(0); if (! (rankPropValue instanceof Tensor)) { - setError(item.toString() + " query tensor should be a tensor, was: "+ - (rankPropValue == null ? "null" : rankPropValue.getClass().toString())); - return; + return item + " query tensor should be a tensor, was: " + + (rankPropValue == null ? "null" : rankPropValue.getClass()); } - Tensor qTensor = (Tensor)rankPropValue; - TensorType qTensorType = qTensor.type(); String field = item.getIndexName(); - if (validAttributes.containsKey(field)) { - TensorType fTensorType = validAttributes.get(field); - if (fTensorType == null) { - setError(item.toString() + " field is not a tensor"); - return; - } - if (! isCompatible(fTensorType, qTensorType)) { - setError(item.toString() + " field type "+fTensorType+" does not match query tensor type "+qTensorType); - return; - } - if (! isDenseVector(fTensorType)) { - setError(item.toString() + " tensor type "+fTensorType+" is not a dense vector"); - return; - } - } else { - setError(item.toString() + " field is not an attribute"); - return; + if ( ! validAttributes.containsKey(field)) return item + " field is not an attribute"; + + TensorType fTensorType = validAttributes.get(field); + TensorType qTensorType = ((Tensor)rankPropValue).type(); + if (fTensorType == null) return item + " field is not a tensor"; + if ( ! isCompatible(fTensorType, qTensorType)) { + return item + " field type " + fTensorType + " does not match query tensor type " + qTensorType; } + if (! isDenseVector(fTensorType)) return item + " tensor type " + fTensorType+" is not a dense vector"; + return null; } @Override |