diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-03-02 09:24:02 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-03-02 11:44:22 +0000 |
commit | feeb478f356b0c2d6c3b7e0d80ef15620dd019b1 (patch) | |
tree | e097f5e897bed98452ff9bb8a26fb37bab59c5c5 /container-search | |
parent | 25707b0248f895e17058c09a782e1c88914a542f (diff) |
extend NearestNeighborItem
Diffstat (limited to 'container-search')
6 files changed, 62 insertions, 11 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 82d3223c8fe..51fee99a743 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -858,8 +858,12 @@ "public void <init>(java.lang.String, java.lang.String)", "public int getTargetNumHits()", "public java.lang.String getIndexName()", + "public int getHnswExploreAdditionalHits()", + "public boolean getAllowApproximate()", "public java.lang.String getQueryTensorName()", "public void setTargetNumHits(int)", + "public void setHnswExploreAdditionalHits(int)", + "public void setAllowApproximate(boolean)", "public void setIndexName(java.lang.String)", "public com.yahoo.prelude.query.Item$ItemType getItemType()", "public java.lang.String getName()", diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java index 35b87ec0190..836107138d0 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java @@ -20,6 +20,8 @@ import java.nio.ByteBuffer; public class NearestNeighborItem extends SimpleTaggableItem { private int targetNumHits = 0; + private int hnswExploreAdditionalHits = 0; + private boolean approximate = true; private String field; private String queryTensorName; @@ -34,12 +36,24 @@ public class NearestNeighborItem extends SimpleTaggableItem { /** Returns the field name */ public String getIndexName() { return field; } + /** Returns the number of extra hits to explore in HNSW algorithm */ + public int getHnswExploreAdditionalHits() { return hnswExploreAdditionalHits; } + + /** Returns whether approximation is allowed */ + public boolean getAllowApproximate() { return approximate; } + /** Returns the name of the query tensor */ public String getQueryTensorName() { return queryTensorName; } /** Set the K number of hits to produce */ public void setTargetNumHits(int target) { this.targetNumHits = target; } + /** Set the number of extra hits to explore in HNSW algorithm */ + public void setHnswExploreAdditionalHits(int num) { this.hnswExploreAdditionalHits = num; } + + /** Set whether approximation is allowed */ + public void setAllowApproximate(boolean value) { this.approximate = value; } + @Override public void setIndexName(String index) { this.field = index; } @@ -58,6 +72,8 @@ public class NearestNeighborItem extends SimpleTaggableItem { putString(field, buffer); putString(queryTensorName, buffer); IntegerCompressor.putCompressedPositiveNumber(targetNumHits, buffer); + IntegerCompressor.putCompressedPositiveNumber((approximate ? 1 : 0), buffer); + IntegerCompressor.putCompressedPositiveNumber(hnswExploreAdditionalHits, buffer); return 1; // number of encoded stack dump items } @@ -65,6 +81,8 @@ public class NearestNeighborItem extends SimpleTaggableItem { protected void appendBodyString(StringBuilder buffer) { buffer.append("{field=").append(field); buffer.append(",queryTensorName=").append(queryTensorName); + buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits); + buffer.append(",approximate=").append(String.valueOf(approximate)); buffer.append(",targetNumHits=").append(targetNumHits).append("}"); } } diff --git a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java index 6eef1252998..38b207cc7eb 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java +++ b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java @@ -702,6 +702,11 @@ public class VespaSerializer { comma(destination, initLen); int targetNumHits = item.getTargetNumHits(); destination.append("\"targetNumHits\": ").append(targetNumHits); + int explore = item.getHnswExploreAdditionalHits(); + if (explore != 0) { + destination.append(",\"hnsw.exploreAdditionalHits\": ").append(explore); + } + destination.append(",\"approximate\": ").append(item.getAllowApproximate()); destination.append("}]"); destination.append(NEAREST_NEIGHBOR).append('('); destination.append(item.getIndexName()).append(", "); diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java index 8d013e501e8..f4560806dd2 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java +++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java @@ -137,6 +137,7 @@ public class YqlParser implements Parser { static final String ACCENT_DROP = "accentDrop"; static final String ALTERNATIVES = "alternatives"; static final String AND_SEGMENTING = "andSegmenting"; + static final String APPROXIMATE = "approximate"; static final String BOUNDS = "bounds"; static final String BOUNDS_LEFT_OPEN = "leftOpen"; static final String BOUNDS_OPEN = "open"; @@ -149,6 +150,7 @@ public class YqlParser implements Parser { static final String EQUIV = "equiv"; static final String FILTER = "filter"; static final String HIT_LIMIT = "hitLimit"; + static final String HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.exploreAdditionalHits"; static final String IMPLICIT_TRANSFORMS = "implicitTransforms"; static final String LABEL = "label"; static final String NEAR = "near"; @@ -421,6 +423,14 @@ public class YqlParser implements Parser { if (targetNumHits != null) { item.setTargetNumHits(targetNumHits); } + Integer hnswExploreAdditionalHits = getAnnotation(ast, HNSW_EXPLORE_ADDITIONAL_HITS, + Integer.class, null, "number of extra hits to explore for HNSW algorithm"); + if (hnswExploreAdditionalHits != null) { + item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits); + } + Boolean allowApproximate = getAnnotation(ast, APPROXIMATE, + Boolean.class, Boolean.TRUE, "allow approximate nearest neighbor search"); + item.setAllowApproximate(allowApproximate); String label = getAnnotation(ast, LABEL, String.class, null, "item label"); if (label != null) { item.setLabel(label); diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index 0cbf3a6f92c..c6233ffa50b 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -139,12 +139,24 @@ public class ValidateNearestNeighborTestCase { assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError()); } + static String desc(String field, String qt, int th, String errmsg) { + StringBuilder r = new StringBuilder(); + r.append("NEAREST_NEIGHBOR {"); + r.append("field=").append(field); + r.append(",queryTensorName=").append(qt); + r.append(",hnsw.exploreAdditionalHits=0"); + r.append(",approximate=true"); + r.append(",targetNumHits=").append(th); + r.append("} ").append(errmsg); + return r.toString(); + } + @Test public void testMissingTargetNumHits() { String q = "select * from sources * where nearestNeighbor(dvector,qvector);"; Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=0} has invalid targetNumHits", r); + assertErrMsg(desc("dvector", "qvector", 0, "has invalid targetNumHits"), r); } @Test @@ -152,16 +164,16 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("dvector", "foo"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r); + assertErrMsg(desc("dvector", "foo", 1, "query tensor not found"), r); } @Test public void testQueryTensorWrongType() { String q = makeQuery("dvector", "qvector"); Result r = doSearch(searcher, q, "tensor string"); - assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r); + assertErrMsg(desc("dvector", "qvector", 1, "query tensor should be a tensor, was: class java.lang.String"), r); r = doSearch(searcher, q, null); - assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: null", r); + assertErrMsg(desc("dvector", "qvector", 1, "query tensor should be a tensor, was: null"), r); } @Test @@ -169,7 +181,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("dvector", "qvector"); Tensor t = makeTensor(tt_dense_dvector_2, 2); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} field type tensor(x[3]) does not match query tensor type tensor(x[2])", r); + assertErrMsg(desc("dvector", "qvector", 1, "field type tensor(x[3]) does not match query tensor type tensor(x[2])"), r); } @Test @@ -177,7 +189,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("foo", "qvector"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r); + assertErrMsg(desc("foo", "qvector", 1, "field is not an attribute"), r); } @Test @@ -185,7 +197,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("simple", "qvector"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r); + assertErrMsg(desc("simple", "qvector", 1, "field is not a tensor"), r); } @Test @@ -193,7 +205,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("sparse", "qvector"); Tensor t = makeTensor(tt_sparse_vector_x); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r); + assertErrMsg(desc("sparse", "qvector", 1, "tensor type tensor(x{}) is not a dense vector"), r); } @Test @@ -201,7 +213,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("matrix", "qvector"); Tensor t = makeMatrix(tt_dense_matrix_xy); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=matrix,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x[3],y[1]) is not a dense vector", r); + assertErrMsg(desc("matrix", "qvector", 1, "tensor type tensor(x[3],y[1]) is not a dense vector"), r); } private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Object qTensor) { diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java index 5eb1f3e3de1..e43dbd4e266 100644 --- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java @@ -550,9 +550,11 @@ public class YqlParserTestCase { @Test public void testNearestNeighbor() { assertParse("select foo from bar where nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,targetNumHits=0}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=0}"); assertParse("select foo from bar where [{\"targetNumHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,targetNumHits=37}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=37}"); + assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetNumHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);", + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetNumHits=3}"); } @Test |