diff options
19 files changed, 114 insertions, 33 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 82d3223c8fe..51fee99a743 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -858,8 +858,12 @@ "public void <init>(java.lang.String, java.lang.String)", "public int getTargetNumHits()", "public java.lang.String getIndexName()", + "public int getHnswExploreAdditionalHits()", + "public boolean getAllowApproximate()", "public java.lang.String getQueryTensorName()", "public void setTargetNumHits(int)", + "public void setHnswExploreAdditionalHits(int)", + "public void setAllowApproximate(boolean)", "public void setIndexName(java.lang.String)", "public com.yahoo.prelude.query.Item$ItemType getItemType()", "public java.lang.String getName()", diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java index 35b87ec0190..836107138d0 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java @@ -20,6 +20,8 @@ import java.nio.ByteBuffer; public class NearestNeighborItem extends SimpleTaggableItem { private int targetNumHits = 0; + private int hnswExploreAdditionalHits = 0; + private boolean approximate = true; private String field; private String queryTensorName; @@ -34,12 +36,24 @@ public class NearestNeighborItem extends SimpleTaggableItem { /** Returns the field name */ public String getIndexName() { return field; } + /** Returns the number of extra hits to explore in HNSW algorithm */ + public int getHnswExploreAdditionalHits() { return hnswExploreAdditionalHits; } + + /** Returns whether approximation is allowed */ + public boolean getAllowApproximate() { return approximate; } + /** Returns the name of the query tensor */ public String getQueryTensorName() { return queryTensorName; } /** Set the K number of hits to produce */ public void setTargetNumHits(int target) { this.targetNumHits = target; } + /** Set the number of extra hits to explore in HNSW algorithm */ + public void setHnswExploreAdditionalHits(int num) { this.hnswExploreAdditionalHits = num; } + + /** Set whether approximation is allowed */ + public void setAllowApproximate(boolean value) { this.approximate = value; } + @Override public void setIndexName(String index) { this.field = index; } @@ -58,6 +72,8 @@ public class NearestNeighborItem extends SimpleTaggableItem { putString(field, buffer); putString(queryTensorName, buffer); IntegerCompressor.putCompressedPositiveNumber(targetNumHits, buffer); + IntegerCompressor.putCompressedPositiveNumber((approximate ? 1 : 0), buffer); + IntegerCompressor.putCompressedPositiveNumber(hnswExploreAdditionalHits, buffer); return 1; // number of encoded stack dump items } @@ -65,6 +81,8 @@ public class NearestNeighborItem extends SimpleTaggableItem { protected void appendBodyString(StringBuilder buffer) { buffer.append("{field=").append(field); buffer.append(",queryTensorName=").append(queryTensorName); + buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits); + buffer.append(",approximate=").append(String.valueOf(approximate)); buffer.append(",targetNumHits=").append(targetNumHits).append("}"); } } diff --git a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java index 6eef1252998..38b207cc7eb 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java +++ b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java @@ -702,6 +702,11 @@ public class VespaSerializer { comma(destination, initLen); int targetNumHits = item.getTargetNumHits(); destination.append("\"targetNumHits\": ").append(targetNumHits); + int explore = item.getHnswExploreAdditionalHits(); + if (explore != 0) { + destination.append(",\"hnsw.exploreAdditionalHits\": ").append(explore); + } + destination.append(",\"approximate\": ").append(item.getAllowApproximate()); destination.append("}]"); destination.append(NEAREST_NEIGHBOR).append('('); destination.append(item.getIndexName()).append(", "); diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java index 8d013e501e8..f4560806dd2 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java +++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java @@ -137,6 +137,7 @@ public class YqlParser implements Parser { static final String ACCENT_DROP = "accentDrop"; static final String ALTERNATIVES = "alternatives"; static final String AND_SEGMENTING = "andSegmenting"; + static final String APPROXIMATE = "approximate"; static final String BOUNDS = "bounds"; static final String BOUNDS_LEFT_OPEN = "leftOpen"; static final String BOUNDS_OPEN = "open"; @@ -149,6 +150,7 @@ public class YqlParser implements Parser { static final String EQUIV = "equiv"; static final String FILTER = "filter"; static final String HIT_LIMIT = "hitLimit"; + static final String HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.exploreAdditionalHits"; static final String IMPLICIT_TRANSFORMS = "implicitTransforms"; static final String LABEL = "label"; static final String NEAR = "near"; @@ -421,6 +423,14 @@ public class YqlParser implements Parser { if (targetNumHits != null) { item.setTargetNumHits(targetNumHits); } + Integer hnswExploreAdditionalHits = getAnnotation(ast, HNSW_EXPLORE_ADDITIONAL_HITS, + Integer.class, null, "number of extra hits to explore for HNSW algorithm"); + if (hnswExploreAdditionalHits != null) { + item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits); + } + Boolean allowApproximate = getAnnotation(ast, APPROXIMATE, + Boolean.class, Boolean.TRUE, "allow approximate nearest neighbor search"); + item.setAllowApproximate(allowApproximate); String label = getAnnotation(ast, LABEL, String.class, null, "item label"); if (label != null) { item.setLabel(label); diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index 0cbf3a6f92c..c6233ffa50b 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -139,12 +139,24 @@ public class ValidateNearestNeighborTestCase { assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError()); } + static String desc(String field, String qt, int th, String errmsg) { + StringBuilder r = new StringBuilder(); + r.append("NEAREST_NEIGHBOR {"); + r.append("field=").append(field); + r.append(",queryTensorName=").append(qt); + r.append(",hnsw.exploreAdditionalHits=0"); + r.append(",approximate=true"); + r.append(",targetNumHits=").append(th); + r.append("} ").append(errmsg); + return r.toString(); + } + @Test public void testMissingTargetNumHits() { String q = "select * from sources * where nearestNeighbor(dvector,qvector);"; Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=0} has invalid targetNumHits", r); + assertErrMsg(desc("dvector", "qvector", 0, "has invalid targetNumHits"), r); } @Test @@ -152,16 +164,16 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("dvector", "foo"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r); + assertErrMsg(desc("dvector", "foo", 1, "query tensor not found"), r); } @Test public void testQueryTensorWrongType() { String q = makeQuery("dvector", "qvector"); Result r = doSearch(searcher, q, "tensor string"); - assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r); + assertErrMsg(desc("dvector", "qvector", 1, "query tensor should be a tensor, was: class java.lang.String"), r); r = doSearch(searcher, q, null); - assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: null", r); + assertErrMsg(desc("dvector", "qvector", 1, "query tensor should be a tensor, was: null"), r); } @Test @@ -169,7 +181,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("dvector", "qvector"); Tensor t = makeTensor(tt_dense_dvector_2, 2); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} field type tensor(x[3]) does not match query tensor type tensor(x[2])", r); + assertErrMsg(desc("dvector", "qvector", 1, "field type tensor(x[3]) does not match query tensor type tensor(x[2])"), r); } @Test @@ -177,7 +189,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("foo", "qvector"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r); + assertErrMsg(desc("foo", "qvector", 1, "field is not an attribute"), r); } @Test @@ -185,7 +197,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("simple", "qvector"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r); + assertErrMsg(desc("simple", "qvector", 1, "field is not a tensor"), r); } @Test @@ -193,7 +205,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("sparse", "qvector"); Tensor t = makeTensor(tt_sparse_vector_x); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r); + assertErrMsg(desc("sparse", "qvector", 1, "tensor type tensor(x{}) is not a dense vector"), r); } @Test @@ -201,7 +213,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("matrix", "qvector"); Tensor t = makeMatrix(tt_dense_matrix_xy); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=matrix,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x[3],y[1]) is not a dense vector", r); + assertErrMsg(desc("matrix", "qvector", 1, "tensor type tensor(x[3],y[1]) is not a dense vector"), r); } private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Object qTensor) { diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java index 5eb1f3e3de1..e43dbd4e266 100644 --- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java @@ -550,9 +550,11 @@ public class YqlParserTestCase { @Test public void testNearestNeighbor() { assertParse("select foo from bar where nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,targetNumHits=0}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=0}"); assertParse("select foo from bar where [{\"targetNumHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,targetNumHits=37}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=37}"); + assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetNumHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);", + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetNumHits=3}"); } @Test diff --git a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp index 24be21f65ec..47728c9785c 100644 --- a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp +++ b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp @@ -295,7 +295,7 @@ public: request_ctx.set_query_tensor("query_tensor", tensor_spec); } Blueprint::UP create_blueprint() { - query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7); + query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33); return source.createBlueprint(request_ctx, FieldSpec(attr_name, 0, 0), term); } }; diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp index 39e381c0942..edbc29be784 100644 --- a/searchlib/src/tests/query/query_visitor_test.cpp +++ b/searchlib/src/tests/query/query_visitor_test.cpp @@ -99,7 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() { checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0))); checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0))); checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0))); - checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123)); + checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321)); } } // namespace diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp index 7f496b3493c..8560cb0e091 100644 --- a/searchlib/src/tests/query/querybuilder_test.cpp +++ b/searchlib/src/tests/query/querybuilder_test.cpp @@ -111,7 +111,7 @@ Node::UP createQueryTree() { builder.addStringTerm(str[5], view[5], id[5], weight[6]); builder.addStringTerm(str[6], view[6], id[6], weight[7]); } - builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7); + builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33); } Node::UP node = builder.build(); ASSERT_TRUE(node.get()); @@ -395,8 +395,9 @@ struct MyRegExpTerm : RegExpTerm { }; struct MyNearestNeighborTerm : NearestNeighborTerm { MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, - int32_t i, Weight w, uint32_t target_num_hits) - : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits) + int32_t i, Weight w, uint32_t target_num_hits, + bool allow_approximate, uint32_t explore_additional_hits) + : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits) {} }; diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 8595b0eff7f..9af05059bef 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -646,7 +646,9 @@ public: query_tensor.release(); setResult(std::make_unique<queryeval::NearestNeighborBlueprint>(_field, *dense_attr_tensor, std::move(dense_query_tensor_up), - n.get_target_num_hits())); + n.get_target_num_hits(), + n.get_allow_approximate(), + n.get_explore_additional_hits())); } }; diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp index 70a3097ae05..f0fb53a5640 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp @@ -274,7 +274,9 @@ SimpleQueryStackDumpIterator::next() try { _curr_index_name = read_stringref(p); _curr_term = read_stringref(p); // query_tensor_name - _currArg1 = readCompressedPositiveInt(p); // target_num_hits; + _currArg1 = readCompressedPositiveInt(p); // target_num_hits + _currArg2 = readCompressedPositiveInt(p); // allow_approximate + _currArg3 = readCompressedPositiveInt(p); // explore_additional_hits _currArity = 0; } catch (...) { return false; diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h index 797defc39f5..8e6f2944ec9 100644 --- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h +++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h @@ -205,8 +205,11 @@ createRegExpTerm(vespalib::stringref term, vespalib::stringref view, int32_t id, template <class NodeTypes> typename NodeTypes::NearestNeighborTerm * create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name, - int32_t id, Weight weight, uint32_t target_num_hits) { - return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits); + int32_t id, Weight weight, uint32_t target_num_hits, + bool allow_approximate, uint32_t explore_additional_hits) +{ + return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, + target_num_hits, allow_approximate, explore_additional_hits); } template <class NodeTypes> @@ -317,9 +320,10 @@ public: return addTerm(createRegExpTerm<NodeTypes>(term, view, id, weight)); } typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name, - int32_t id, Weight weight, uint32_t target_num_hits) { + int32_t id, Weight weight, uint32_t target_num_hits, + bool allow_approximate, uint32_t explore_additional_hits) { adjustWeight(weight); - return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits)); + return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits)); } }; diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h index 0bf923960b9..9289df7cbe9 100644 --- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h +++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h @@ -165,7 +165,8 @@ private: void visit(NearestNeighborTerm &node) override { replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(), - node.getId(), node.getWeight(), node.get_target_num_hits())); + node.getId(), node.getWeight(), node.get_target_num_hits(), + node.get_allow_approximate(), node.get_explore_additional_hits())); } }; diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h index 8663bede4d6..4953f1a5b7c 100644 --- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h +++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h @@ -105,8 +105,10 @@ struct SimpleRegExpTerm : RegExpTerm { }; struct SimpleNearestNeighborTerm : NearestNeighborTerm { SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, - int32_t id, Weight weight, uint32_t target_num_hits) - : NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits) + int32_t id, Weight weight, uint32_t target_num_hits, + bool allow_approximate, uint32_t explore_additional_hits) + : NearestNeighborTerm(query_tensor_name, field_name, id, weight, + target_num_hits, allow_approximate, explore_additional_hits) {} }; diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp index 63acf532144..aafeaa46a22 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp @@ -263,6 +263,8 @@ class QueryNodeConverter : public QueryVisitor { createTermNode(node, ParseItem::ITEM_NEAREST_NEIGHBOR); appendString(node.get_query_tensor_name()); appendCompressedPositiveNumber(node.get_target_num_hits()); + appendCompressedPositiveNumber(node.get_allow_approximate() ? 1 : 0); + appendCompressedPositiveNumber(node.get_explore_additional_hits()); } public: diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h index 791da010720..a57c24584cc 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h @@ -114,7 +114,10 @@ private: uint32_t target_num_hits = queryStack.getArg1(); int32_t id = queryStack.getUniqueId(); Weight weight = queryStack.GetWeight(); - builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, target_num_hits); + uint32_t allow_approximate = (queryStack.getArg2() != 0); + uint32_t explore_additional_hits = queryStack.getArg3(); + builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, + target_num_hits, allow_approximate, explore_additional_hits); } else { vespalib::stringref term = queryStack.getTerm(); vespalib::stringref view = queryStack.getIndexName(); diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h index a82b1e14d76..9af424716fb 100644 --- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h +++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h @@ -128,17 +128,24 @@ class NearestNeighborTerm : public QueryNodeMixin<NearestNeighborTerm, TermNode> private: vespalib::string _query_tensor_name; uint32_t _target_num_hits; + bool _allow_approximate; + uint32_t _explore_additional_hits; public: NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, - int32_t id, Weight weight, uint32_t target_num_hits) + int32_t id, Weight weight, uint32_t target_num_hits, + bool allow_approximate, uint32_t explore_additional_hits) : QueryNodeMixinType(field_name, id, weight), _query_tensor_name(query_tensor_name), - _target_num_hits(target_num_hits) + _target_num_hits(target_num_hits), + _allow_approximate(allow_approximate), + _explore_additional_hits(explore_additional_hits) {} virtual ~NearestNeighborTerm() {} const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; } uint32_t get_target_num_hits() const { return _target_num_hits; } + bool get_allow_approximate() const { return _allow_approximate; } + uint32_t get_explore_additional_hits() const { return _explore_additional_hits; } }; diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index d4aa2aaa1d7..c160f8d5485 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -13,17 +13,22 @@ namespace search::queryeval { NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field, const tensor::DenseTensorAttribute& attr_tensor, std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor, - uint32_t target_num_hits) + uint32_t target_num_hits, bool approximate, uint32_t explore_k) : ComplexLeafBlueprint(field), _attr_tensor(attr_tensor), _query_tensor(std::move(query_tensor)), _target_num_hits(target_num_hits), + _approximate(approximate), + _explore_k(explore_k), _distance_heap(target_num_hits), _found_hits() { uint32_t est_hits = _attr_tensor.getNumDocs(); if (_attr_tensor.nearest_neighbor_index()) { est_hits = std::min(target_num_hits, est_hits); + if (_explore_k == 0) { + _explore_k = 100; + } } setEstimate(HitEstimate(est_hits, false)); } @@ -34,15 +39,14 @@ void NearestNeighborBlueprint::perform_top_k() { auto nns_index = _attr_tensor.nearest_neighbor_index(); - if (nns_index) { + if (_approximate && nns_index) { auto lhs_type = _query_tensor->fast_type(); auto rhs_type = _attr_tensor.getTensorType(); // XXX deal with different cell types later if (lhs_type == rhs_type) { auto lhs = _query_tensor->cellsRef(); uint32_t k = _target_num_hits; - uint32_t explore_k = k + 100; // XXX hardcoded for now - _found_hits = nns_index->find_top_k(k, lhs, explore_k); + _found_hits = nns_index->find_top_k(k, lhs, k + _explore_k); } } } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index ab4413c487a..a782633ccc3 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -21,6 +21,8 @@ private: const tensor::DenseTensorAttribute& _attr_tensor; std::unique_ptr<vespalib::tensor::DenseTensorView> _query_tensor; uint32_t _target_num_hits; + bool _approximate; + uint32_t _explore_k; mutable NearestNeighborDistanceHeap _distance_heap; std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits; @@ -29,7 +31,7 @@ public: NearestNeighborBlueprint(const queryeval::FieldSpec& field, const tensor::DenseTensorAttribute& attr_tensor, std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor, - uint32_t target_num_hits); + uint32_t target_num_hits, bool approximate, uint32_t explore_k); NearestNeighborBlueprint(const NearestNeighborBlueprint&) = delete; NearestNeighborBlueprint& operator=(const NearestNeighborBlueprint&) = delete; ~NearestNeighborBlueprint(); |