diff options
author | Arne Juul <arnej@verizonmedia.com> | 2019-11-15 13:15:55 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2019-11-19 11:00:33 +0000 |
commit | 8b5afa32df34cc3e65318515d6d71e416425b07b (patch) | |
tree | 6517d1f9a16ec6ae8f3e246c9af5987876892985 /container-search/src/main/java/com/yahoo/search | |
parent | 957a7cc70ba85568618fb2b5282d38f009c688ea (diff) |
add NearestNeighborItem with validation
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search')
4 files changed, 176 insertions, 1 deletion
diff --git a/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java b/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java index aa075b94af1..4c36ca9b4da 100644 --- a/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java +++ b/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java @@ -40,6 +40,7 @@ public class LocalProviderSpec { com.yahoo.search.querytransform.WandSearcher.class, com.yahoo.search.querytransform.BooleanSearcher.class, com.yahoo.prelude.searcher.ValidatePredicateSearcher.class, + com.yahoo.search.searchers.ValidateNearestNeighborSearcher.class, com.yahoo.search.searchers.ValidateMatchPhaseSearcher.class, com.yahoo.search.yql.FieldFiller.class, com.yahoo.search.searchers.InputCheckingSearcher.class, diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java new file mode 100644 index 00000000000..c2e05b15f33 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java @@ -0,0 +1,127 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.search.searchers; + +import com.google.common.annotations.Beta; + +import com.yahoo.container.QrSearchersConfig; +import com.yahoo.prelude.query.Item; +import com.yahoo.prelude.query.NearestNeighborItem; +import com.yahoo.prelude.query.ToolBox; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.query.ranking.RankProperties; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.searchchain.Execution; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.AttributesConfig; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Validates any NearestNeighborItem query items. + * + * @author arnej + */ +@Beta +public class ValidateNearestNeighborSearcher extends Searcher { + + private Map<String, TensorType> validAttributes = new HashMap<>(); + + public ValidateNearestNeighborSearcher(AttributesConfig attributesConfig) { + for (AttributesConfig.Attribute a : attributesConfig.attribute()) { + TensorType tt = null; + if (a.datatype() == AttributesConfig.Attribute.Datatype.TENSOR) { + tt = TensorType.fromSpec(a.tensortype()); + } + validAttributes.put(a.name(), tt); + } + } + + @Override + public Result search(Query query, Execution execution) { + Optional<ErrorMessage> e = validate(query); + return e.isEmpty() ? 
execution.search(query) : new Result(query, e.get()); + } + + private Optional<ErrorMessage> validate(Query query) { + NNVisitor visitor = new NNVisitor(query.getRanking().getProperties(), validAttributes); + ToolBox.visit(visitor, query.getModel().getQueryTree().getRoot()); + return visitor.errorMessage; + } + + private static class NNVisitor extends ToolBox.QueryVisitor { + + public Optional<ErrorMessage> errorMessage = Optional.empty(); + + private RankProperties rankProperties; + private Map<String, TensorType> validAttributes; + + public NNVisitor(RankProperties rankProperties, Map<String, TensorType> validAttributes) { + this.rankProperties = rankProperties; + this.validAttributes = validAttributes; + } + + @Override + public boolean visit(Item item) { + if (item instanceof NearestNeighborItem) { + validate((NearestNeighborItem) item); + } + return true; + } + + private void setError(String description) { + errorMessage = Optional.of(ErrorMessage.createIllegalQuery(description)); + } + + private void validate(NearestNeighborItem item) { + int target = item.getTargetNumHits(); + if (target < 1) { + setError(item.toString() + " has invalid targetNumHits"); + return; + } + String qprop = item.getQueryRankFeatureName(); + List<Object> rankPropValList = rankProperties.asMap().get(qprop); + if (rankPropValList == null) { + setError(item.toString() + " query property not found"); + return; + } + if (rankPropValList.size() != 1) { + setError(item.toString() + " query property does not have a single value"); + return; + } + Object rankPropValue = rankPropValList.get(0); + if (! 
(rankPropValue instanceof Tensor)) { + setError(item.toString() + " query property should be a tensor, was: "+rankPropValue); + return; + } + Tensor qTensor = (Tensor)rankPropValue; + TensorType qTensorType = qTensor.type(); + + String field = item.getIndexName(); + if (validAttributes.containsKey(field)) { + TensorType fTensorType = validAttributes.get(field); + if (fTensorType == null) { + setError(item.toString() + " field is not a tensor"); + return; + } + if (! fTensorType.equals(qTensorType)) { + setError(item.toString() + " field type "+fTensorType+" does not match query property type "+qTensorType); + return; + } + } else { + setError(item.toString() + " field is not an attribute"); + return; + } + } + + @Override + public void onExit() {} + + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java index 5565c680efb..6ab07d95c76 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java +++ b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java @@ -20,6 +20,7 @@ import static com.yahoo.search.yql.YqlParser.HIT_LIMIT; import static com.yahoo.search.yql.YqlParser.IMPLICIT_TRANSFORMS; import static com.yahoo.search.yql.YqlParser.LABEL; import static com.yahoo.search.yql.YqlParser.NEAR; +import static com.yahoo.search.yql.YqlParser.NEAREST_NEIGHBOR; import static com.yahoo.search.yql.YqlParser.NORMALIZE_CASE; import static com.yahoo.search.yql.YqlParser.ONEAR; import static com.yahoo.search.yql.YqlParser.ORIGIN; @@ -73,6 +74,7 @@ import com.yahoo.prelude.query.IntItem; import com.yahoo.prelude.query.Item; import com.yahoo.prelude.query.MarkerWordItem; import com.yahoo.prelude.query.NearItem; +import com.yahoo.prelude.query.NearestNeighborItem; import com.yahoo.prelude.query.NotItem; import com.yahoo.prelude.query.NullItem; import com.yahoo.prelude.query.ONearItem; @@ -687,6 +689,28 @@ public class 
VespaSerializer { } + private static class NearestNeighborSerializer extends Serializer<NearestNeighborItem> { + + @Override + void onExit(StringBuilder destination, NearestNeighborItem item) { } + + @Override + boolean serialize(StringBuilder destination, NearestNeighborItem item) { + destination.append("[{"); + int initLen = destination.length(); + destination.append(leafAnnotations(item)); + comma(destination, initLen); + int targetNumHits = item.getTargetNumHits(); + destination.append("\"targetNumHits\": ").append(targetNumHits); + destination.append("}]"); + destination.append(NEAREST_NEIGHBOR).append('('); + destination.append(item.getIndexName()).append(", "); + destination.append(item.getQueryRankFeatureName()).append(')'); + return false; + } + + } + private static class PredicateQuerySerializer extends Serializer<PredicateQueryItem> { @Override @@ -1131,6 +1155,7 @@ public class VespaSerializer { dispatchBuilder.put(BoolItem.class, new BoolSerializer()); dispatchBuilder.put(MarkerWordItem.class, new WordSerializer()); // gotcha dispatchBuilder.put(NearItem.class, new NearSerializer()); + dispatchBuilder.put(NearestNeighborItem.class, new NearestNeighborSerializer()); dispatchBuilder.put(NotItem.class, new NotSerializer()); dispatchBuilder.put(NullItem.class, new NullSerializer()); dispatchBuilder.put(ONearItem.class, new ONearSerializer()); diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java index 306ac4e42f2..8d013e501e8 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java +++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java @@ -35,6 +35,7 @@ import com.yahoo.prelude.query.IntItem; import com.yahoo.prelude.query.Item; import com.yahoo.prelude.query.Limit; import com.yahoo.prelude.query.NearItem; +import com.yahoo.prelude.query.NearestNeighborItem; import com.yahoo.prelude.query.NotItem; import 
com.yahoo.prelude.query.NullItem; import com.yahoo.prelude.query.ONearItem; @@ -151,6 +152,7 @@ public class YqlParser implements Parser { static final String IMPLICIT_TRANSFORMS = "implicitTransforms"; static final String LABEL = "label"; static final String NEAR = "near"; + static final String NEAREST_NEIGHBOR = "nearestNeighbor"; static final String NORMALIZE_CASE = "normalizeCase"; static final String ONEAR = "onear"; static final String ORIGIN_LENGTH = "length"; @@ -367,6 +369,8 @@ public class YqlParser implements Parser { return buildWeightedSet(ast); case DOT_PRODUCT: return buildDotProduct(ast); + case NEAREST_NEIGHBOR: + return buildNearestNeighbor(ast); case PREDICATE: return buildPredicate(ast); case RANK: @@ -378,7 +382,7 @@ public class YqlParser implements Parser { case NON_EMPTY: return ensureNonEmpty(ast); default: - throw newUnexpectedArgumentException(names.get(0), DOT_PRODUCT, + throw newUnexpectedArgumentException(names.get(0), DOT_PRODUCT, NEAREST_NEIGHBOR, RANGE, RANK, USER_QUERY, WAND, WEAK_AND, WEIGHTED_SET, PREDICATE, USER_INPUT, NON_EMPTY); } @@ -406,6 +410,24 @@ public class YqlParser implements Parser { return fillWeightedSet(ast, args.get(1), new DotProductItem(getIndex(args.get(0)))); } + private Item buildNearestNeighbor(OperatorNode<ExpressionOperator> ast) { + List<OperatorNode<ExpressionOperator>> args = ast.getArgument(1); + Preconditions.checkArgument(args.size() == 2, "Expected 2 arguments, got %s.", args.size()); + String field = fetchFieldRead(args.get(0)); + String property = fetchFieldRead(args.get(1)); + NearestNeighborItem item = new NearestNeighborItem(field, property); + Integer targetNumHits = getAnnotation(ast, TARGET_NUM_HITS, + Integer.class, null, "desired minimum hits to produce"); + if (targetNumHits != null) { + item.setTargetNumHits(targetNumHits); + } + String label = getAnnotation(ast, LABEL, String.class, null, "item label"); + if (label != null) { + item.setLabel(label); + } + return item; + } + private 
Item buildPredicate(OperatorNode<ExpressionOperator> ast) { List<OperatorNode<ExpressionOperator>> args = ast.getArgument(1); Preconditions.checkArgument(args.size() == 3, "Expected 3 arguments, got %s.", args.size()); |