diff options
Diffstat (limited to 'searchlib')
13 files changed, 52 insertions, 22 deletions
diff --git a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp index 24be21f65ec..47728c9785c 100644 --- a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp +++ b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp @@ -295,7 +295,7 @@ public: request_ctx.set_query_tensor("query_tensor", tensor_spec); } Blueprint::UP create_blueprint() { - query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7); + query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33); return source.createBlueprint(request_ctx, FieldSpec(attr_name, 0, 0), term); } }; diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp index 39e381c0942..edbc29be784 100644 --- a/searchlib/src/tests/query/query_visitor_test.cpp +++ b/searchlib/src/tests/query/query_visitor_test.cpp @@ -99,7 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() { checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0))); checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0))); checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0))); - checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123)); + checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321)); } } // namespace diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp index 7f496b3493c..8560cb0e091 100644 --- a/searchlib/src/tests/query/querybuilder_test.cpp +++ b/searchlib/src/tests/query/querybuilder_test.cpp @@ -111,7 +111,7 @@ Node::UP createQueryTree() { builder.addStringTerm(str[5], view[5], id[5], weight[6]); builder.addStringTerm(str[6], view[6], id[6], weight[7]); } - builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7); + builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33); } Node::UP node = builder.build(); ASSERT_TRUE(node.get()); @@ -395,8 +395,9 @@ struct MyRegExpTerm : RegExpTerm { }; struct MyNearestNeighborTerm : NearestNeighborTerm { MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, - int32_t i, Weight w, uint32_t target_num_hits) - : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits) + int32_t i, Weight w, uint32_t target_num_hits, + bool allow_approximate, uint32_t explore_additional_hits) + : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits) {} }; diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 8595b0eff7f..9af05059bef 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -646,7 +646,9 @@ public: query_tensor.release(); setResult(std::make_unique<queryeval::NearestNeighborBlueprint>(_field, *dense_attr_tensor, std::move(dense_query_tensor_up), - n.get_target_num_hits())); + n.get_target_num_hits(), + n.get_allow_approximate(), + n.get_explore_additional_hits())); } }; diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp index 70a3097ae05..f0fb53a5640 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp @@ -274,7 +274,9 @@ SimpleQueryStackDumpIterator::next() try { _curr_index_name = read_stringref(p); _curr_term = read_stringref(p); // query_tensor_name - _currArg1 = readCompressedPositiveInt(p); // target_num_hits; + _currArg1 = readCompressedPositiveInt(p); // target_num_hits + _currArg2 = readCompressedPositiveInt(p); // allow_approximate + _currArg3 = readCompressedPositiveInt(p); // explore_additional_hits _currArity = 0; } catch (...) { return false; diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h index 797defc39f5..8e6f2944ec9 100644 --- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h +++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h @@ -205,8 +205,11 @@ createRegExpTerm(vespalib::stringref term, vespalib::stringref view, int32_t id, template <class NodeTypes> typename NodeTypes::NearestNeighborTerm * create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name, - int32_t id, Weight weight, uint32_t target_num_hits) { - return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits); + int32_t id, Weight weight, uint32_t target_num_hits, + bool allow_approximate, uint32_t explore_additional_hits) +{ + return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, + target_num_hits, allow_approximate, explore_additional_hits); } template <class NodeTypes> @@ -317,9 +320,10 @@ public: return addTerm(createRegExpTerm<NodeTypes>(term, view, id, weight)); } typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name, - int32_t id, Weight weight, uint32_t target_num_hits) { + int32_t id, Weight weight, uint32_t target_num_hits, + bool allow_approximate, uint32_t explore_additional_hits) { adjustWeight(weight); - return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits)); + return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits)); } }; diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h index 0bf923960b9..9289df7cbe9 100644 --- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h +++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h @@ -165,7 +165,8 @@ private: void visit(NearestNeighborTerm &node) override { replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(), - node.getId(), node.getWeight(), node.get_target_num_hits())); + node.getId(), node.getWeight(), node.get_target_num_hits(), + node.get_allow_approximate(), node.get_explore_additional_hits())); } }; diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h index 8663bede4d6..4953f1a5b7c 100644 --- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h +++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h @@ -105,8 +105,10 @@ struct SimpleRegExpTerm : RegExpTerm { }; struct SimpleNearestNeighborTerm : NearestNeighborTerm { SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, - int32_t id, Weight weight, uint32_t target_num_hits) - : NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits) + int32_t id, Weight weight, uint32_t target_num_hits, + bool allow_approximate, uint32_t explore_additional_hits) + : NearestNeighborTerm(query_tensor_name, field_name, id, weight, + target_num_hits, allow_approximate, explore_additional_hits) {} }; diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp index 63acf532144..aafeaa46a22 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp @@ -263,6 +263,8 @@ class QueryNodeConverter : public QueryVisitor { createTermNode(node, ParseItem::ITEM_NEAREST_NEIGHBOR); appendString(node.get_query_tensor_name()); appendCompressedPositiveNumber(node.get_target_num_hits()); + appendCompressedPositiveNumber(node.get_allow_approximate() ? 1 : 0); + appendCompressedPositiveNumber(node.get_explore_additional_hits()); } public: diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h index 791da010720..a57c24584cc 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h @@ -114,7 +114,10 @@ private: uint32_t target_num_hits = queryStack.getArg1(); int32_t id = queryStack.getUniqueId(); Weight weight = queryStack.GetWeight(); - builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, target_num_hits); + uint32_t allow_approximate = (queryStack.getArg2() != 0); + uint32_t explore_additional_hits = queryStack.getArg3(); + builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, + target_num_hits, allow_approximate, explore_additional_hits); } else { vespalib::stringref term = queryStack.getTerm(); vespalib::stringref view = queryStack.getIndexName(); diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h index a82b1e14d76..9af424716fb 100644 --- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h +++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h @@ -128,17 +128,24 @@ class NearestNeighborTerm : public QueryNodeMixin<NearestNeighborTerm, TermNode> private: vespalib::string _query_tensor_name; uint32_t _target_num_hits; + bool _allow_approximate; + uint32_t _explore_additional_hits; public: NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, - int32_t id, Weight weight, uint32_t target_num_hits) + int32_t id, Weight weight, uint32_t target_num_hits, + bool allow_approximate, uint32_t explore_additional_hits) : QueryNodeMixinType(field_name, id, weight), _query_tensor_name(query_tensor_name), - _target_num_hits(target_num_hits) + _target_num_hits(target_num_hits), + _allow_approximate(allow_approximate), + _explore_additional_hits(explore_additional_hits) {} virtual ~NearestNeighborTerm() {} const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; } uint32_t get_target_num_hits() const { return _target_num_hits; } + bool get_allow_approximate() const { return _allow_approximate; } + uint32_t get_explore_additional_hits() const { return _explore_additional_hits; } }; diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index d4aa2aaa1d7..c160f8d5485 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -13,17 +13,22 @@ namespace search::queryeval { NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field, const tensor::DenseTensorAttribute& attr_tensor, std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor, - uint32_t target_num_hits) + uint32_t target_num_hits, bool approximate, uint32_t explore_k) : ComplexLeafBlueprint(field), _attr_tensor(attr_tensor), _query_tensor(std::move(query_tensor)), _target_num_hits(target_num_hits), + _approximate(approximate), + _explore_k(explore_k), _distance_heap(target_num_hits), _found_hits() { uint32_t est_hits = _attr_tensor.getNumDocs(); if (_attr_tensor.nearest_neighbor_index()) { est_hits = std::min(target_num_hits, est_hits); + if (_explore_k == 0) { + _explore_k = 100; + } } setEstimate(HitEstimate(est_hits, false)); } @@ -34,15 +39,14 @@ void NearestNeighborBlueprint::perform_top_k() { auto nns_index = _attr_tensor.nearest_neighbor_index(); - if (nns_index) { + if (_approximate && nns_index) { auto lhs_type = _query_tensor->fast_type(); auto rhs_type = _attr_tensor.getTensorType(); // XXX deal with different cell types later if (lhs_type == rhs_type) { auto lhs = _query_tensor->cellsRef(); uint32_t k = _target_num_hits; - uint32_t explore_k = k + 100; // XXX hardcoded for now - _found_hits = nns_index->find_top_k(k, lhs, explore_k); + _found_hits = nns_index->find_top_k(k, lhs, k + _explore_k); } } } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index ab4413c487a..a782633ccc3 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -21,6 +21,8 @@ private: const tensor::DenseTensorAttribute& _attr_tensor; std::unique_ptr<vespalib::tensor::DenseTensorView> _query_tensor; uint32_t _target_num_hits; + bool _approximate; + uint32_t _explore_k; mutable NearestNeighborDistanceHeap _distance_heap; std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits; @@ -29,7 +31,7 @@ public: NearestNeighborBlueprint(const queryeval::FieldSpec& field, const tensor::DenseTensorAttribute& attr_tensor, std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor, - uint32_t target_num_hits); + uint32_t target_num_hits, bool approximate, uint32_t explore_k); NearestNeighborBlueprint(const NearestNeighborBlueprint&) = delete; NearestNeighborBlueprint& operator=(const NearestNeighborBlueprint&) = delete; ~NearestNeighborBlueprint(); |