summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-03-02 09:24:02 +0000
committerArne Juul <arnej@verizonmedia.com>2020-03-02 11:44:22 +0000
commitfeeb478f356b0c2d6c3b7e0d80ef15620dd019b1 (patch)
treee097f5e897bed98452ff9bb8a26fb37bab59c5c5 /container-search
parent25707b0248f895e17058c09a782e1c88914a542f (diff)
extend NearestNeighborItem
Diffstat (limited to 'container-search')
-rw-r--r--container-search/abi-spec.json4
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java18
-rw-r--r--container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java5
-rw-r--r--container-search/src/main/java/com/yahoo/search/yql/YqlParser.java10
-rw-r--r--container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java30
-rw-r--r--container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java6
6 files changed, 62 insertions, 11 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 82d3223c8fe..51fee99a743 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -858,8 +858,12 @@
"public void <init>(java.lang.String, java.lang.String)",
"public int getTargetNumHits()",
"public java.lang.String getIndexName()",
+ "public int getHnswExploreAdditionalHits()",
+ "public boolean getAllowApproximate()",
"public java.lang.String getQueryTensorName()",
"public void setTargetNumHits(int)",
+ "public void setHnswExploreAdditionalHits(int)",
+ "public void setAllowApproximate(boolean)",
"public void setIndexName(java.lang.String)",
"public com.yahoo.prelude.query.Item$ItemType getItemType()",
"public java.lang.String getName()",
diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
index 35b87ec0190..836107138d0 100644
--- a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
+++ b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
@@ -20,6 +20,8 @@ import java.nio.ByteBuffer;
public class NearestNeighborItem extends SimpleTaggableItem {
private int targetNumHits = 0;
+ private int hnswExploreAdditionalHits = 0;
+ private boolean approximate = true;
private String field;
private String queryTensorName;
@@ -34,12 +36,24 @@ public class NearestNeighborItem extends SimpleTaggableItem {
/** Returns the field name */
public String getIndexName() { return field; }
+ /** Returns the number of extra hits to explore in HNSW algorithm */
+ public int getHnswExploreAdditionalHits() { return hnswExploreAdditionalHits; }
+
+ /** Returns whether approximation is allowed */
+ public boolean getAllowApproximate() { return approximate; }
+
/** Returns the name of the query tensor */
public String getQueryTensorName() { return queryTensorName; }
/** Set the K number of hits to produce */
public void setTargetNumHits(int target) { this.targetNumHits = target; }
+ /** Set the number of extra hits to explore in HNSW algorithm */
+ public void setHnswExploreAdditionalHits(int num) { this.hnswExploreAdditionalHits = num; }
+
+ /** Set whether approximation is allowed */
+ public void setAllowApproximate(boolean value) { this.approximate = value; }
+
@Override
public void setIndexName(String index) { this.field = index; }
@@ -58,6 +72,8 @@ public class NearestNeighborItem extends SimpleTaggableItem {
putString(field, buffer);
putString(queryTensorName, buffer);
IntegerCompressor.putCompressedPositiveNumber(targetNumHits, buffer);
+ IntegerCompressor.putCompressedPositiveNumber((approximate ? 1 : 0), buffer);
+ IntegerCompressor.putCompressedPositiveNumber(hnswExploreAdditionalHits, buffer);
return 1; // number of encoded stack dump items
}
@@ -65,6 +81,8 @@ public class NearestNeighborItem extends SimpleTaggableItem {
protected void appendBodyString(StringBuilder buffer) {
buffer.append("{field=").append(field);
buffer.append(",queryTensorName=").append(queryTensorName);
+ buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits);
+ buffer.append(",approximate=").append(String.valueOf(approximate));
buffer.append(",targetNumHits=").append(targetNumHits).append("}");
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
index 6eef1252998..38b207cc7eb 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
@@ -702,6 +702,11 @@ public class VespaSerializer {
comma(destination, initLen);
int targetNumHits = item.getTargetNumHits();
destination.append("\"targetNumHits\": ").append(targetNumHits);
+ int explore = item.getHnswExploreAdditionalHits();
+ if (explore != 0) {
+ destination.append(",\"hnsw.exploreAdditionalHits\": ").append(explore);
+ }
+ destination.append(",\"approximate\": ").append(item.getAllowApproximate());
destination.append("}]");
destination.append(NEAREST_NEIGHBOR).append('(');
destination.append(item.getIndexName()).append(", ");
diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
index 8d013e501e8..f4560806dd2 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
@@ -137,6 +137,7 @@ public class YqlParser implements Parser {
static final String ACCENT_DROP = "accentDrop";
static final String ALTERNATIVES = "alternatives";
static final String AND_SEGMENTING = "andSegmenting";
+ static final String APPROXIMATE = "approximate";
static final String BOUNDS = "bounds";
static final String BOUNDS_LEFT_OPEN = "leftOpen";
static final String BOUNDS_OPEN = "open";
@@ -149,6 +150,7 @@ public class YqlParser implements Parser {
static final String EQUIV = "equiv";
static final String FILTER = "filter";
static final String HIT_LIMIT = "hitLimit";
+ static final String HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.exploreAdditionalHits";
static final String IMPLICIT_TRANSFORMS = "implicitTransforms";
static final String LABEL = "label";
static final String NEAR = "near";
@@ -421,6 +423,14 @@ public class YqlParser implements Parser {
if (targetNumHits != null) {
item.setTargetNumHits(targetNumHits);
}
+ Integer hnswExploreAdditionalHits = getAnnotation(ast, HNSW_EXPLORE_ADDITIONAL_HITS,
+ Integer.class, null, "number of extra hits to explore for HNSW algorithm");
+ if (hnswExploreAdditionalHits != null) {
+ item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits);
+ }
+ Boolean allowApproximate = getAnnotation(ast, APPROXIMATE,
+ Boolean.class, Boolean.TRUE, "allow approximate nearest neighbor search");
+ item.setAllowApproximate(allowApproximate);
String label = getAnnotation(ast, LABEL, String.class, null, "item label");
if (label != null) {
item.setLabel(label);
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
index 0cbf3a6f92c..c6233ffa50b 100644
--- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -139,12 +139,24 @@ public class ValidateNearestNeighborTestCase {
assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError());
}
+ static String desc(String field, String qt, int th, String errmsg) {
+ StringBuilder r = new StringBuilder();
+ r.append("NEAREST_NEIGHBOR {");
+ r.append("field=").append(field);
+ r.append(",queryTensorName=").append(qt);
+ r.append(",hnsw.exploreAdditionalHits=0");
+ r.append(",approximate=true");
+ r.append(",targetNumHits=").append(th);
+ r.append("} ").append(errmsg);
+ return r.toString();
+ }
+
@Test
public void testMissingTargetNumHits() {
String q = "select * from sources * where nearestNeighbor(dvector,qvector);";
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=0} has invalid targetNumHits", r);
+ assertErrMsg(desc("dvector", "qvector", 0, "has invalid targetNumHits"), r);
}
@Test
@@ -152,16 +164,16 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("dvector", "foo");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r);
+ assertErrMsg(desc("dvector", "foo", 1, "query tensor not found"), r);
}
@Test
public void testQueryTensorWrongType() {
String q = makeQuery("dvector", "qvector");
Result r = doSearch(searcher, q, "tensor string");
- assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r);
+ assertErrMsg(desc("dvector", "qvector", 1, "query tensor should be a tensor, was: class java.lang.String"), r);
r = doSearch(searcher, q, null);
- assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: null", r);
+ assertErrMsg(desc("dvector", "qvector", 1, "query tensor should be a tensor, was: null"), r);
}
@Test
@@ -169,7 +181,7 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("dvector", "qvector");
Tensor t = makeTensor(tt_dense_dvector_2, 2);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} field type tensor(x[3]) does not match query tensor type tensor(x[2])", r);
+ assertErrMsg(desc("dvector", "qvector", 1, "field type tensor(x[3]) does not match query tensor type tensor(x[2])"), r);
}
@Test
@@ -177,7 +189,7 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("foo", "qvector");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r);
+ assertErrMsg(desc("foo", "qvector", 1, "field is not an attribute"), r);
}
@Test
@@ -185,7 +197,7 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("simple", "qvector");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r);
+ assertErrMsg(desc("simple", "qvector", 1, "field is not a tensor"), r);
}
@Test
@@ -193,7 +205,7 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("sparse", "qvector");
Tensor t = makeTensor(tt_sparse_vector_x);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r);
+ assertErrMsg(desc("sparse", "qvector", 1, "tensor type tensor(x{}) is not a dense vector"), r);
}
@Test
@@ -201,7 +213,7 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("matrix", "qvector");
Tensor t = makeMatrix(tt_dense_matrix_xy);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=matrix,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x[3],y[1]) is not a dense vector", r);
+ assertErrMsg(desc("matrix", "qvector", 1, "tensor type tensor(x[3],y[1]) is not a dense vector"), r);
}
private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Object qTensor) {
diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
index 5eb1f3e3de1..e43dbd4e266 100644
--- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
@@ -550,9 +550,11 @@ public class YqlParserTestCase {
@Test
public void testNearestNeighbor() {
assertParse("select foo from bar where nearestNeighbor(semantic_embedding, my_vector);",
- "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,targetNumHits=0}");
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=0}");
assertParse("select foo from bar where [{\"targetNumHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);",
- "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,targetNumHits=37}");
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=37}");
+ assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetNumHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);",
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetNumHits=3}");
}
@Test