aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2019-11-15 13:15:55 +0000
committerArne Juul <arnej@verizonmedia.com>2019-11-19 11:00:33 +0000
commit8b5afa32df34cc3e65318515d6d71e416425b07b (patch)
tree6517d1f9a16ec6ae8f3e246c9af5987876892985 /container-search/src/main/java/com/yahoo/search
parent957a7cc70ba85568618fb2b5282d38f009c688ea (diff)
add NearestNeighborItem with validation
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search')
-rw-r--r--container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java1
-rw-r--r--container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java127
-rw-r--r--container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java25
-rw-r--r--container-search/src/main/java/com/yahoo/search/yql/YqlParser.java24
4 files changed, 176 insertions, 1 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java b/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java
index aa075b94af1..4c36ca9b4da 100644
--- a/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java
+++ b/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java
@@ -40,6 +40,7 @@ public class LocalProviderSpec {
com.yahoo.search.querytransform.WandSearcher.class,
com.yahoo.search.querytransform.BooleanSearcher.class,
com.yahoo.prelude.searcher.ValidatePredicateSearcher.class,
+ com.yahoo.search.searchers.ValidateNearestNeighborSearcher.class,
com.yahoo.search.searchers.ValidateMatchPhaseSearcher.class,
com.yahoo.search.yql.FieldFiller.class,
com.yahoo.search.searchers.InputCheckingSearcher.class,
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
new file mode 100644
index 00000000000..c2e05b15f33
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
@@ -0,0 +1,127 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.searchers;
+
+import com.google.common.annotations.Beta;
+
+import com.yahoo.container.QrSearchersConfig;
+import com.yahoo.prelude.query.Item;
+import com.yahoo.prelude.query.NearestNeighborItem;
+import com.yahoo.prelude.query.ToolBox;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.query.ranking.RankProperties;
+import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.searchchain.Execution;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.AttributesConfig;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Validates any NearestNeighborItem query items.
+ *
+ * @author arnej
+ */
+@Beta
+public class ValidateNearestNeighborSearcher extends Searcher {
+
+ private Map<String, TensorType> validAttributes = new HashMap<>();
+
+ public ValidateNearestNeighborSearcher(AttributesConfig attributesConfig) {
+ for (AttributesConfig.Attribute a : attributesConfig.attribute()) {
+ TensorType tt = null;
+ if (a.datatype() == AttributesConfig.Attribute.Datatype.TENSOR) {
+ tt = TensorType.fromSpec(a.tensortype());
+ }
+ validAttributes.put(a.name(), tt);
+ }
+ }
+
+ @Override
+ public Result search(Query query, Execution execution) {
+ Optional<ErrorMessage> e = validate(query);
+ return e.isEmpty() ? execution.search(query) : new Result(query, e.get());
+ }
+
+ private Optional<ErrorMessage> validate(Query query) {
+ NNVisitor visitor = new NNVisitor(query.getRanking().getProperties(), validAttributes);
+ ToolBox.visit(visitor, query.getModel().getQueryTree().getRoot());
+ return visitor.errorMessage;
+ }
+
+ private static class NNVisitor extends ToolBox.QueryVisitor {
+
+ public Optional<ErrorMessage> errorMessage = Optional.empty();
+
+ private RankProperties rankProperties;
+ private Map<String, TensorType> validAttributes;
+
+ public NNVisitor(RankProperties rankProperties, Map<String, TensorType> validAttributes) {
+ this.rankProperties = rankProperties;
+ this.validAttributes = validAttributes;
+ }
+
+ @Override
+ public boolean visit(Item item) {
+ if (item instanceof NearestNeighborItem) {
+ validate((NearestNeighborItem) item);
+ }
+ return true;
+ }
+
+ private void setError(String description) {
+ errorMessage = Optional.of(ErrorMessage.createIllegalQuery(description));
+ }
+
+ private void validate(NearestNeighborItem item) {
+ int target = item.getTargetNumHits();
+ if (target < 1) {
+ setError(item.toString() + " has invalid targetNumHits");
+ return;
+ }
+ String qprop = item.getQueryRankFeatureName();
+ List<Object> rankPropValList = rankProperties.asMap().get(qprop);
+ if (rankPropValList == null) {
+ setError(item.toString() + " query property not found");
+ return;
+ }
+ if (rankPropValList.size() != 1) {
+ setError(item.toString() + " query property does not have a single value");
+ return;
+ }
+ Object rankPropValue = rankPropValList.get(0);
+ if (! (rankPropValue instanceof Tensor)) {
+ setError(item.toString() + " query property should be a tensor, was: "+rankPropValue);
+ return;
+ }
+ Tensor qTensor = (Tensor)rankPropValue;
+ TensorType qTensorType = qTensor.type();
+
+ String field = item.getIndexName();
+ if (validAttributes.containsKey(field)) {
+ TensorType fTensorType = validAttributes.get(field);
+ if (fTensorType == null) {
+ setError(item.toString() + " field is not a tensor");
+ return;
+ }
+ if (! fTensorType.equals(qTensorType)) {
+ setError(item.toString() + " field type "+fTensorType+" does not match query property type "+qTensorType);
+ return;
+ }
+ } else {
+ setError(item.toString() + " field is not an attribute");
+ return;
+ }
+ }
+
+ @Override
+ public void onExit() {}
+
+ }
+
+}
diff --git a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
index 5565c680efb..6ab07d95c76 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
@@ -20,6 +20,7 @@ import static com.yahoo.search.yql.YqlParser.HIT_LIMIT;
import static com.yahoo.search.yql.YqlParser.IMPLICIT_TRANSFORMS;
import static com.yahoo.search.yql.YqlParser.LABEL;
import static com.yahoo.search.yql.YqlParser.NEAR;
+import static com.yahoo.search.yql.YqlParser.NEAREST_NEIGHBOR;
import static com.yahoo.search.yql.YqlParser.NORMALIZE_CASE;
import static com.yahoo.search.yql.YqlParser.ONEAR;
import static com.yahoo.search.yql.YqlParser.ORIGIN;
@@ -73,6 +74,7 @@ import com.yahoo.prelude.query.IntItem;
import com.yahoo.prelude.query.Item;
import com.yahoo.prelude.query.MarkerWordItem;
import com.yahoo.prelude.query.NearItem;
+import com.yahoo.prelude.query.NearestNeighborItem;
import com.yahoo.prelude.query.NotItem;
import com.yahoo.prelude.query.NullItem;
import com.yahoo.prelude.query.ONearItem;
@@ -687,6 +689,28 @@ public class VespaSerializer {
}
+ private static class NearestNeighborSerializer extends Serializer<NearestNeighborItem> {
+
+ @Override
+ void onExit(StringBuilder destination, NearestNeighborItem item) { }
+
+ @Override
+ boolean serialize(StringBuilder destination, NearestNeighborItem item) {
+ destination.append("[{");
+ int initLen = destination.length();
+ destination.append(leafAnnotations(item));
+ comma(destination, initLen);
+ int targetNumHits = item.getTargetNumHits();
+ destination.append("\"targetNumHits\": ").append(targetNumHits);
+ destination.append("}]");
+ destination.append(NEAREST_NEIGHBOR).append('(');
+ destination.append(item.getIndexName()).append(", ");
+ destination.append(item.getQueryRankFeatureName()).append(')');
+ return false;
+ }
+
+ }
+
private static class PredicateQuerySerializer extends Serializer<PredicateQueryItem> {
@Override
@@ -1131,6 +1155,7 @@ public class VespaSerializer {
dispatchBuilder.put(BoolItem.class, new BoolSerializer());
dispatchBuilder.put(MarkerWordItem.class, new WordSerializer()); // gotcha
dispatchBuilder.put(NearItem.class, new NearSerializer());
+ dispatchBuilder.put(NearestNeighborItem.class, new NearestNeighborSerializer());
dispatchBuilder.put(NotItem.class, new NotSerializer());
dispatchBuilder.put(NullItem.class, new NullSerializer());
dispatchBuilder.put(ONearItem.class, new ONearSerializer());
diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
index 306ac4e42f2..8d013e501e8 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
@@ -35,6 +35,7 @@ import com.yahoo.prelude.query.IntItem;
import com.yahoo.prelude.query.Item;
import com.yahoo.prelude.query.Limit;
import com.yahoo.prelude.query.NearItem;
+import com.yahoo.prelude.query.NearestNeighborItem;
import com.yahoo.prelude.query.NotItem;
import com.yahoo.prelude.query.NullItem;
import com.yahoo.prelude.query.ONearItem;
@@ -151,6 +152,7 @@ public class YqlParser implements Parser {
static final String IMPLICIT_TRANSFORMS = "implicitTransforms";
static final String LABEL = "label";
static final String NEAR = "near";
+ static final String NEAREST_NEIGHBOR = "nearestNeighbor";
static final String NORMALIZE_CASE = "normalizeCase";
static final String ONEAR = "onear";
static final String ORIGIN_LENGTH = "length";
@@ -367,6 +369,8 @@ public class YqlParser implements Parser {
return buildWeightedSet(ast);
case DOT_PRODUCT:
return buildDotProduct(ast);
+ case NEAREST_NEIGHBOR:
+ return buildNearestNeighbor(ast);
case PREDICATE:
return buildPredicate(ast);
case RANK:
@@ -378,7 +382,7 @@ public class YqlParser implements Parser {
case NON_EMPTY:
return ensureNonEmpty(ast);
default:
- throw newUnexpectedArgumentException(names.get(0), DOT_PRODUCT,
+ throw newUnexpectedArgumentException(names.get(0), DOT_PRODUCT, NEAREST_NEIGHBOR,
RANGE, RANK, USER_QUERY, WAND, WEAK_AND, WEIGHTED_SET,
PREDICATE, USER_INPUT, NON_EMPTY);
}
@@ -406,6 +410,24 @@ public class YqlParser implements Parser {
return fillWeightedSet(ast, args.get(1), new DotProductItem(getIndex(args.get(0))));
}
+ private Item buildNearestNeighbor(OperatorNode<ExpressionOperator> ast) {
+ List<OperatorNode<ExpressionOperator>> args = ast.getArgument(1);
+ Preconditions.checkArgument(args.size() == 2, "Expected 2 arguments, got %s.", args.size());
+ String field = fetchFieldRead(args.get(0));
+ String property = fetchFieldRead(args.get(1));
+ NearestNeighborItem item = new NearestNeighborItem(field, property);
+ Integer targetNumHits = getAnnotation(ast, TARGET_NUM_HITS,
+ Integer.class, null, "desired minimum hits to produce");
+ if (targetNumHits != null) {
+ item.setTargetNumHits(targetNumHits);
+ }
+ String label = getAnnotation(ast, LABEL, String.class, null, "item label");
+ if (label != null) {
+ item.setLabel(label);
+ }
+ return item;
+ }
+
private Item buildPredicate(OperatorNode<ExpressionOperator> ast) {
List<OperatorNode<ExpressionOperator>> args = ast.getArgument(1);
Preconditions.checkArgument(args.size() == 3, "Expected 3 arguments, got %s.", args.size());