summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2019-11-15 13:15:55 +0000
committerArne Juul <arnej@verizonmedia.com>2019-11-19 11:00:33 +0000
commit8b5afa32df34cc3e65318515d6d71e416425b07b (patch)
tree6517d1f9a16ec6ae8f3e246c9af5987876892985 /container-search
parent957a7cc70ba85568618fb2b5282d38f009c688ea (diff)
add NearestNeighborItem with validation
Diffstat (limited to 'container-search')
-rw-r--r--container-search/abi-spec.json34
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/query/Item.java3
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java69
-rw-r--r--container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java1
-rw-r--r--container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java127
-rw-r--r--container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java25
-rw-r--r--container-search/src/main/java/com/yahoo/search/yql/YqlParser.java24
-rw-r--r--container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java6
-rw-r--r--container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java8
9 files changed, 295 insertions, 2 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 51ce07c40bd..ddd754ce419 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -696,6 +696,7 @@
"public static final enum com.yahoo.prelude.query.Item$ItemType PREDICATE_QUERY",
"public static final enum com.yahoo.prelude.query.Item$ItemType REGEXP",
"public static final enum com.yahoo.prelude.query.Item$ItemType WORD_ALTERNATIVES",
+ "public static final enum com.yahoo.prelude.query.Item$ItemType NEAREST_NEIGHBOR",
"public final int code"
]
},
@@ -847,6 +848,27 @@
"public static final int defaultDistance"
]
},
+ "com.yahoo.prelude.query.NearestNeighborItem": {
+ "superClass": "com.yahoo.prelude.query.SimpleTaggableItem",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public int getTargetNumHits()",
+ "public java.lang.String getIndexName()",
+ "public java.lang.String getQueryRankFeatureName()",
+ "public void <init>(java.lang.String, java.lang.String)",
+ "public void setTargetNumHits(int)",
+ "public void setIndexName(java.lang.String)",
+ "public com.yahoo.prelude.query.Item$ItemType getItemType()",
+ "public java.lang.String getName()",
+ "public int getTermCount()",
+ "public int encode(java.nio.ByteBuffer)",
+ "protected void appendBodyString(java.lang.StringBuilder)"
+ ],
+ "fields": []
+ },
"com.yahoo.prelude.query.NonReducibleCompositeItem": {
"superClass": "com.yahoo.prelude.query.CompositeItem",
"interfaces": [],
@@ -7922,6 +7944,18 @@
],
"fields": []
},
+ "com.yahoo.search.searchers.ValidateNearestNeighborSearcher": {
+ "superClass": "com.yahoo.search.Searcher",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(com.yahoo.vespa.config.search.AttributesConfig)",
+ "public com.yahoo.search.Result search(com.yahoo.search.Query, com.yahoo.search.searchchain.Execution)"
+ ],
+ "fields": []
+ },
"com.yahoo.search.statistics.ElapsedTime": {
"superClass": "java.lang.Object",
"interfaces": [],
diff --git a/container-search/src/main/java/com/yahoo/prelude/query/Item.java b/container-search/src/main/java/com/yahoo/prelude/query/Item.java
index 9d8ccce1b76..ea65bc7d7d2 100644
--- a/container-search/src/main/java/com/yahoo/prelude/query/Item.java
+++ b/container-search/src/main/java/com/yahoo/prelude/query/Item.java
@@ -59,7 +59,8 @@ public abstract class Item implements Cloneable {
WAND(22),
PREDICATE_QUERY(23),
REGEXP(24),
- WORD_ALTERNATIVES(25);
+ WORD_ALTERNATIVES(25),
+ NEAREST_NEIGHBOR(26);
public final int code;
diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
new file mode 100644
index 00000000000..51a336fa4ba
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
@@ -0,0 +1,69 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.prelude.query;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.compress.IntegerCompressor;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Represent a query item matching the K nearest neighbors in a multi-dimensional vector space.
+ * The query point vector is referenced by the name of a tensor rank feature passed in the query,
+ * so specifying "myvector" as the name means the query must set "ranking.features.query(myvector)",
+ * which must be configured with the correct tensor type in the active query profile.
+ * The field name (AKA the index name) given must be an attribute, with the exact same tensor type.
+ *
+ * @author arnej
+ */
+@Beta
+public class NearestNeighborItem extends SimpleTaggableItem {
+
+ private int targetNumber = 0;
+ private String field;
+ private String property;
+
+ /** @return the K number of hits to produce */
+ public int getTargetNumHits() { return targetNumber; }
+
+ /** @return the field name */
+ public String getIndexName() { return field; }
+
+ /** @return the name of the query ranking feature */
+ public String getQueryRankFeatureName() { return property; }
+
+ public NearestNeighborItem(String fieldName, String queryRankFeatureName) {
+ this.field = fieldName;
+ this.property = queryRankFeatureName;
+ }
+
+ public void setTargetNumHits(int target) { this.targetNumber = target; }
+
+ @Override
+ public void setIndexName(String index) { this.field = index; }
+
+ @Override
+ public ItemType getItemType() { return ItemType.NEAREST_NEIGHBOR; }
+
+ @Override
+ public String getName() { return "NEAREST_NEIGHBOR"; }
+
+ @Override
+ public int getTermCount() { return 1; }
+
+ @Override
+ public int encode(ByteBuffer buffer) {
+ super.encodeThis(buffer);
+ putString(field, buffer);
+ putString(property, buffer);
+ IntegerCompressor.putCompressedPositiveNumber(targetNumber, buffer);
+ return 1; // number of encoded stack dump items
+ }
+
+ @Override
+ protected void appendBodyString(StringBuilder buffer) {
+ buffer.append("{field=").append(field);
+ buffer.append(",property=").append(property);
+ buffer.append(",targetNumHits=").append(targetNumber).append("}");
+ }
+}
diff --git a/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java b/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java
index aa075b94af1..4c36ca9b4da 100644
--- a/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java
+++ b/container-search/src/main/java/com/yahoo/search/searchchain/model/federation/LocalProviderSpec.java
@@ -40,6 +40,7 @@ public class LocalProviderSpec {
com.yahoo.search.querytransform.WandSearcher.class,
com.yahoo.search.querytransform.BooleanSearcher.class,
com.yahoo.prelude.searcher.ValidatePredicateSearcher.class,
+ com.yahoo.search.searchers.ValidateNearestNeighborSearcher.class,
com.yahoo.search.searchers.ValidateMatchPhaseSearcher.class,
com.yahoo.search.yql.FieldFiller.class,
com.yahoo.search.searchers.InputCheckingSearcher.class,
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
new file mode 100644
index 00000000000..c2e05b15f33
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
@@ -0,0 +1,127 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.searchers;
+
+import com.google.common.annotations.Beta;
+
+import com.yahoo.container.QrSearchersConfig;
+import com.yahoo.prelude.query.Item;
+import com.yahoo.prelude.query.NearestNeighborItem;
+import com.yahoo.prelude.query.ToolBox;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.query.ranking.RankProperties;
+import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.searchchain.Execution;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.AttributesConfig;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Validates any NearestNeighborItem query items.
+ *
+ * @author arnej
+ */
+@Beta
+public class ValidateNearestNeighborSearcher extends Searcher {
+
+ private Map<String, TensorType> validAttributes = new HashMap<>();
+
+ public ValidateNearestNeighborSearcher(AttributesConfig attributesConfig) {
+ for (AttributesConfig.Attribute a : attributesConfig.attribute()) {
+ TensorType tt = null;
+ if (a.datatype() == AttributesConfig.Attribute.Datatype.TENSOR) {
+ tt = TensorType.fromSpec(a.tensortype());
+ }
+ validAttributes.put(a.name(), tt);
+ }
+ }
+
+ @Override
+ public Result search(Query query, Execution execution) {
+ Optional<ErrorMessage> e = validate(query);
+ return e.isEmpty() ? execution.search(query) : new Result(query, e.get());
+ }
+
+ private Optional<ErrorMessage> validate(Query query) {
+ NNVisitor visitor = new NNVisitor(query.getRanking().getProperties(), validAttributes);
+ ToolBox.visit(visitor, query.getModel().getQueryTree().getRoot());
+ return visitor.errorMessage;
+ }
+
+ private static class NNVisitor extends ToolBox.QueryVisitor {
+
+ public Optional<ErrorMessage> errorMessage = Optional.empty();
+
+ private RankProperties rankProperties;
+ private Map<String, TensorType> validAttributes;
+
+ public NNVisitor(RankProperties rankProperties, Map<String, TensorType> validAttributes) {
+ this.rankProperties = rankProperties;
+ this.validAttributes = validAttributes;
+ }
+
+ @Override
+ public boolean visit(Item item) {
+ if (item instanceof NearestNeighborItem) {
+ validate((NearestNeighborItem) item);
+ }
+ return true;
+ }
+
+ private void setError(String description) {
+ errorMessage = Optional.of(ErrorMessage.createIllegalQuery(description));
+ }
+
+ private void validate(NearestNeighborItem item) {
+ int target = item.getTargetNumHits();
+ if (target < 1) {
+ setError(item.toString() + " has invalid targetNumHits");
+ return;
+ }
+ String qprop = item.getQueryRankFeatureName();
+ List<Object> rankPropValList = rankProperties.asMap().get(qprop);
+ if (rankPropValList == null) {
+ setError(item.toString() + " query property not found");
+ return;
+ }
+ if (rankPropValList.size() != 1) {
+ setError(item.toString() + " query property does not have a single value");
+ return;
+ }
+ Object rankPropValue = rankPropValList.get(0);
+ if (! (rankPropValue instanceof Tensor)) {
+ setError(item.toString() + " query property should be a tensor, was: "+rankPropValue);
+ return;
+ }
+ Tensor qTensor = (Tensor)rankPropValue;
+ TensorType qTensorType = qTensor.type();
+
+ String field = item.getIndexName();
+ if (validAttributes.containsKey(field)) {
+ TensorType fTensorType = validAttributes.get(field);
+ if (fTensorType == null) {
+ setError(item.toString() + " field is not a tensor");
+ return;
+ }
+ if (! fTensorType.equals(qTensorType)) {
+ setError(item.toString() + " field type "+fTensorType+" does not match query property type "+qTensorType);
+ return;
+ }
+ } else {
+ setError(item.toString() + " field is not an attribute");
+ return;
+ }
+ }
+
+ @Override
+ public void onExit() {}
+
+ }
+
+}
diff --git a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
index 5565c680efb..6ab07d95c76 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
@@ -20,6 +20,7 @@ import static com.yahoo.search.yql.YqlParser.HIT_LIMIT;
import static com.yahoo.search.yql.YqlParser.IMPLICIT_TRANSFORMS;
import static com.yahoo.search.yql.YqlParser.LABEL;
import static com.yahoo.search.yql.YqlParser.NEAR;
+import static com.yahoo.search.yql.YqlParser.NEAREST_NEIGHBOR;
import static com.yahoo.search.yql.YqlParser.NORMALIZE_CASE;
import static com.yahoo.search.yql.YqlParser.ONEAR;
import static com.yahoo.search.yql.YqlParser.ORIGIN;
@@ -73,6 +74,7 @@ import com.yahoo.prelude.query.IntItem;
import com.yahoo.prelude.query.Item;
import com.yahoo.prelude.query.MarkerWordItem;
import com.yahoo.prelude.query.NearItem;
+import com.yahoo.prelude.query.NearestNeighborItem;
import com.yahoo.prelude.query.NotItem;
import com.yahoo.prelude.query.NullItem;
import com.yahoo.prelude.query.ONearItem;
@@ -687,6 +689,28 @@ public class VespaSerializer {
}
+ private static class NearestNeighborSerializer extends Serializer<NearestNeighborItem> {
+
+ @Override
+ void onExit(StringBuilder destination, NearestNeighborItem item) { }
+
+ @Override
+ boolean serialize(StringBuilder destination, NearestNeighborItem item) {
+ destination.append("[{");
+ int initLen = destination.length();
+ destination.append(leafAnnotations(item));
+ comma(destination, initLen);
+ int targetNumHits = item.getTargetNumHits();
+ destination.append("\"targetNumHits\": ").append(targetNumHits);
+ destination.append("}]");
+ destination.append(NEAREST_NEIGHBOR).append('(');
+ destination.append(item.getIndexName()).append(", ");
+ destination.append(item.getQueryRankFeatureName()).append(')');
+ return false;
+ }
+
+ }
+
private static class PredicateQuerySerializer extends Serializer<PredicateQueryItem> {
@Override
@@ -1131,6 +1155,7 @@ public class VespaSerializer {
dispatchBuilder.put(BoolItem.class, new BoolSerializer());
dispatchBuilder.put(MarkerWordItem.class, new WordSerializer()); // gotcha
dispatchBuilder.put(NearItem.class, new NearSerializer());
+ dispatchBuilder.put(NearestNeighborItem.class, new NearestNeighborSerializer());
dispatchBuilder.put(NotItem.class, new NotSerializer());
dispatchBuilder.put(NullItem.class, new NullSerializer());
dispatchBuilder.put(ONearItem.class, new ONearSerializer());
diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
index 306ac4e42f2..8d013e501e8 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
@@ -35,6 +35,7 @@ import com.yahoo.prelude.query.IntItem;
import com.yahoo.prelude.query.Item;
import com.yahoo.prelude.query.Limit;
import com.yahoo.prelude.query.NearItem;
+import com.yahoo.prelude.query.NearestNeighborItem;
import com.yahoo.prelude.query.NotItem;
import com.yahoo.prelude.query.NullItem;
import com.yahoo.prelude.query.ONearItem;
@@ -151,6 +152,7 @@ public class YqlParser implements Parser {
static final String IMPLICIT_TRANSFORMS = "implicitTransforms";
static final String LABEL = "label";
static final String NEAR = "near";
+ static final String NEAREST_NEIGHBOR = "nearestNeighbor";
static final String NORMALIZE_CASE = "normalizeCase";
static final String ONEAR = "onear";
static final String ORIGIN_LENGTH = "length";
@@ -367,6 +369,8 @@ public class YqlParser implements Parser {
return buildWeightedSet(ast);
case DOT_PRODUCT:
return buildDotProduct(ast);
+ case NEAREST_NEIGHBOR:
+ return buildNearestNeighbor(ast);
case PREDICATE:
return buildPredicate(ast);
case RANK:
@@ -378,7 +382,7 @@ public class YqlParser implements Parser {
case NON_EMPTY:
return ensureNonEmpty(ast);
default:
- throw newUnexpectedArgumentException(names.get(0), DOT_PRODUCT,
+ throw newUnexpectedArgumentException(names.get(0), DOT_PRODUCT, NEAREST_NEIGHBOR,
RANGE, RANK, USER_QUERY, WAND, WEAK_AND, WEIGHTED_SET,
PREDICATE, USER_INPUT, NON_EMPTY);
}
@@ -406,6 +410,24 @@ public class YqlParser implements Parser {
return fillWeightedSet(ast, args.get(1), new DotProductItem(getIndex(args.get(0))));
}
+ private Item buildNearestNeighbor(OperatorNode<ExpressionOperator> ast) {
+ List<OperatorNode<ExpressionOperator>> args = ast.getArgument(1);
+ Preconditions.checkArgument(args.size() == 2, "Expected 2 arguments, got %s.", args.size());
+ String field = fetchFieldRead(args.get(0));
+ String property = fetchFieldRead(args.get(1));
+ NearestNeighborItem item = new NearestNeighborItem(field, property);
+ Integer targetNumHits = getAnnotation(ast, TARGET_NUM_HITS,
+ Integer.class, null, "desired minimum hits to produce");
+ if (targetNumHits != null) {
+ item.setTargetNumHits(targetNumHits);
+ }
+ String label = getAnnotation(ast, LABEL, String.class, null, "item label");
+ if (label != null) {
+ item.setLabel(label);
+ }
+ return item;
+ }
+
private Item buildPredicate(OperatorNode<ExpressionOperator> ast) {
List<OperatorNode<ExpressionOperator>> args = ast.getArgument(1);
Preconditions.checkArgument(args.size() == 3, "Expected 3 arguments, got %s.", args.size());
diff --git a/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java
index 0bbcabe4107..1106d8c3999 100644
--- a/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java
@@ -125,6 +125,12 @@ public class VespaSerializerTestCase {
}
@Test
+ public void testNearestNeighbor() {
+ parseAndConfirm("[{\"label\": \"foo\", \"targetNumHits\": 1000}]nearestNeighbor(semantic_embedding, my_property)");
+ parseAndConfirm("[{\"targetNumHits\": 42}]nearestNeighbor(semantic_embedding, my_property)");
+ }
+
+ @Test
public void testNumbers() {
parseAndConfirm("title = 500");
parseAndConfirm("title > 500");
diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
index 8f717e74ab2..e191829b875 100644
--- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
@@ -548,6 +548,14 @@ public class YqlParserTestCase {
}
@Test
+ public void testNearestNeighbor() {
+ assertParse("select foo from bar where nearestNeighbor(semantic_embedding, my_vector);",
+ "NEAREST_NEIGHBOR {field=semantic_embedding,property=my_vector,targetNumHits=0}");
+ assertParse("select foo from bar where [{\"targetNumHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);",
+ "NEAREST_NEIGHBOR {field=semantic_embedding,property=my_vector,targetNumHits=37}");
+ }
+
+ @Test
public void testPredicate() {
assertParse("select foo from bar where predicate(predicate_field, " +
"{\"gender\":\"male\", \"hobby\":[\"music\", \"hiking\"]}, {\"age\":23L});",