diff options
11 files changed, 150 insertions, 53 deletions
diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index 2c202d9131b..210f32af15e 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -814,7 +814,7 @@ TEST("test_nearest_neighbor_query_node") constexpr uint32_t target_num_hits = 100; constexpr bool allow_approximate = false; constexpr uint32_t explore_additional_hits = 800; - constexpr double raw_score = 0.5; + constexpr double distance = 0.5; builder.add_nearest_neighbor_term("qtensor", "field", id, Weight(weight), target_num_hits, allow_approximate, explore_additional_hits, distance_threshold); auto build_node = builder.build(); auto stack_dump = StackDumpCreator::create(*build_node); @@ -830,14 +830,14 @@ TEST("test_nearest_neighbor_query_node") EXPECT_EQUAL(id, static_cast<int32_t>(node->uniqueId())); EXPECT_EQUAL(weight, node->weight().percent()); EXPECT_EQUAL(distance_threshold, node->get_distance_threshold()); - EXPECT_FALSE(node->get_raw_score().has_value()); + EXPECT_FALSE(node->get_distance().has_value()); EXPECT_FALSE(node->evaluate()); - node->set_raw_score(raw_score); - EXPECT_TRUE(node->get_raw_score().has_value()); - EXPECT_EQUAL(raw_score, node->get_raw_score().value()); + node->set_distance(distance); + EXPECT_TRUE(node->get_distance().has_value()); + EXPECT_EQUAL(distance, node->get_distance().value()); EXPECT_TRUE(node->evaluate()); node->reset(); - EXPECT_FALSE(node->get_raw_score().has_value()); + EXPECT_FALSE(node->get_distance().has_value()); EXPECT_FALSE(node->evaluate()); } diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h index 46b89fdfeb4..9bef389a278 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h @@ -109,7 +109,7 @@ public: uint32_t getArity() const { return _currArity; } uint32_t getNearDistance() const { return _extraIntArg1; } - uint32_t getTargetNumHits() const { return _extraIntArg1; } + uint32_t getTargetHits() const { return _extraIntArg1; } double getDistanceThreshold() const { return _extraDoubleArg4; } double getScoreThreshold() const { return _extraDoubleArg4; } double getThresholdBoostFactor() const { return _extraDoubleArg5; } diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp index d1c37cd6dcd..b2d8a0ee4be 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp @@ -1,15 +1,21 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "nearest_neighbor_query_node.h" +#include <cassert> namespace search::streaming { -NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold) - : QueryTerm(std::move(resultBase), term, index, Type::NEAREST_NEIGHBOR), +NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, + const string& query_tensor_name, const string& field_name, + uint32_t target_hits, double distance_threshold, + int32_t unique_id, search::query::Weight weight) + : QueryTerm(std::move(resultBase), query_tensor_name, field_name, Type::NEAREST_NEIGHBOR), + _target_hits(target_hits), _distance_threshold(distance_threshold), - _raw_score() + _distance(), + _calc() { - setUniqueId(id); + setUniqueId(unique_id); setWeight(weight); } @@ -18,13 +24,13 @@ NearestNeighborQueryNode::~NearestNeighborQueryNode() = default; bool NearestNeighborQueryNode::evaluate() const { - return _raw_score.has_value(); + return _distance.has_value(); } void NearestNeighborQueryNode::reset() { - _raw_score.reset(); + _distance.reset(); } NearestNeighborQueryNode* @@ -33,4 +39,14 @@ NearestNeighborQueryNode::as_nearest_neighbor_query_node() noexcept return this; } +std::optional<double> +NearestNeighborQueryNode::get_raw_score() const +{ + if (_distance.has_value()) { + assert(_calc != nullptr); + return _calc->to_raw_score(_distance.value()); + } + return std::nullopt; +} + } diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h index 0beb130c53d..c66364b0c52 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h +++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h @@ -8,16 +8,34 @@ namespace search::streaming { /* - * Nearest neighbor query node. + * Nearest neighbor query node for streaming search. */ class NearestNeighborQueryNode: public QueryTerm { +public: + class RawScoreCalculator { + public: + virtual ~RawScoreCalculator() = default; + /** + * Convert the given distance to a raw score. + * + * This is used during unpacking, and also signals that the entire document was a match. + */ + virtual double to_raw_score(double distance) = 0; + }; + private: + uint32_t _target_hits; double _distance_threshold; - // When this value is set it also indicates a match - std::optional<double> _raw_score; + // When this value is set it also indicates a match for this query node. + std::optional<double> _distance; + RawScoreCalculator* _calc; + public: - NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold); + NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, + const string& query_tensor_name, const string& field_name, + uint32_t target_hits, double distance_threshold, + int32_t unique_id, search::query::Weight weight); NearestNeighborQueryNode(const NearestNeighborQueryNode &) = delete; NearestNeighborQueryNode & operator = (const NearestNeighborQueryNode &) = delete; NearestNeighborQueryNode(NearestNeighborQueryNode &&) = delete; @@ -27,9 +45,13 @@ public: void reset() override; NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept override; const vespalib::string& get_query_tensor_name() const { return getTermString(); } + uint32_t get_target_hits() const { return _target_hits; } double get_distance_threshold() const { return _distance_threshold; } - void set_raw_score(double value) { _raw_score = value; } - const std::optional<double>& get_raw_score() const noexcept { return _raw_score; } + void set_raw_score_calc(RawScoreCalculator* calc_in) { _calc = calc_in; } + void set_distance(double value) { _distance = value; } + const std::optional<double>& get_distance() const { return _distance; } + // This is used during unpacking, and also signals to the RawScoreCalculator that the entire document was a match. + std::optional<double> get_raw_score() const; }; } diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index 226cb92c894..84344831cbc 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -200,15 +200,17 @@ QueryNode::build_nearest_neighbor_query_node(const QueryNodeResultFactory& facto { vespalib::stringref query_tensor_name = query_rep.getTerm(); vespalib::stringref field_name = query_rep.getIndexName(); - int32_t id = query_rep.getUniqueId(); - search::query::Weight weight = query_rep.GetWeight(); + int32_t unique_id = query_rep.getUniqueId(); + auto weight = query_rep.GetWeight(); + uint32_t target_hits = query_rep.getTargetHits(); double distance_threshold = query_rep.getDistanceThreshold(); return std::make_unique<NearestNeighborQueryNode>(factory.create(), query_tensor_name, field_name, - id, - weight, - distance_threshold); + target_hits, + distance_threshold, + unique_id, + weight); } } diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h index 90bd87979c7..a552a650704 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h @@ -89,7 +89,7 @@ private: pureTermView = view; } else if (type == ParseItem::ITEM_WEAK_AND) { vespalib::stringref view = queryStack.getIndexName(); - uint32_t targetNumHits = queryStack.getTargetNumHits(); + uint32_t targetNumHits = queryStack.getTargetHits(); builder.addWeakAnd(arity, targetNumHits, view); pureTermView = view; } else if (type == ParseItem::ITEM_EQUIV) { @@ -134,7 +134,7 @@ private: vespalib::stringref view = queryStack.getIndexName(); int32_t id = queryStack.getUniqueId(); Weight weight = queryStack.GetWeight(); - uint32_t targetNumHits = queryStack.getTargetNumHits(); + uint32_t targetNumHits = queryStack.getTargetHits(); double scoreThreshold = queryStack.getScoreThreshold(); double thresholdBoostFactor = queryStack.getThresholdBoostFactor(); auto & wand = builder.addWandTerm(arity, view, id, weight, targetNumHits, scoreThreshold, thresholdBoostFactor); @@ -146,7 +146,7 @@ private: } else if (type == ParseItem::ITEM_NEAREST_NEIGHBOR) { vespalib::stringref query_tensor_name = queryStack.getTerm(); vespalib::stringref field_name = queryStack.getIndexName(); - uint32_t target_num_hits = queryStack.getTargetNumHits(); + uint32_t target_num_hits = queryStack.getTargetHits(); int32_t id = queryStack.getUniqueId(); Weight weight = queryStack.GetWeight(); bool allow_approximate = queryStack.getAllowApproximate(); diff --git a/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp b/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp index 43c77398be8..b64d477fd4c 100644 --- a/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp +++ b/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp @@ -31,9 +31,11 @@ struct MockQuery { std::vector<std::unique_ptr<NearestNeighborQueryNode>> nodes; QueryTermList term_list; MockQuery& add(const vespalib::string& query_tensor_name, + uint32_t target_hits, double distance_threshold) { std::unique_ptr<QueryNodeResultBase> base; - auto node = std::make_unique<NearestNeighborQueryNode>(std::move(base), query_tensor_name, "my_tensor_field", 7, search::query::Weight(11), distance_threshold); + auto node = std::make_unique<NearestNeighborQueryNode>(std::move(base), query_tensor_name, "my_tensor_field", + target_hits, distance_threshold, 7, search::query::Weight(100)); nodes.push_back(std::move(node)); term_list.push_back(nodes.back().get()); return *this; @@ -90,34 +92,71 @@ public: query.reset(); searcher.onValue(fv); } + void expect_match(const vespalib::string& spec_expr, double exp_square_distance, const NearestNeighborQueryNode& node) { + match(spec_expr); + expect_match(exp_square_distance, node); + } void expect_match(double exp_square_distance, const NearestNeighborQueryNode& node) { double exp_raw_score = dist_func.to_rawscore(exp_square_distance); EXPECT_TRUE(node.evaluate()); + EXPECT_DOUBLE_EQ(exp_square_distance, node.get_distance().value()); EXPECT_DOUBLE_EQ(exp_raw_score, node.get_raw_score().value()); } + void expect_not_match(const vespalib::string& spec_expr, const NearestNeighborQueryNode& node) { + match(spec_expr); + EXPECT_FALSE(node.evaluate()); + } }; -TEST_F(NearestNeighborSearcherTest, raw_score_calculated_with_distance_threshold) +TEST_F(NearestNeighborSearcherTest, distance_heap_keeps_the_best_target_hits) { - query.add("qt1", 3); + query.add("qt1", 2, 100.0); + const auto& node = query.get(0); set_query_tensor("qt1", "tensor(x[2]):[1,3]"); prepare(); - match("tensor(x[2]):[1,5]"); - expect_match((5-3)*(5-3), query.get(0)); + expect_match("tensor(x[2]):[1,7]", (7-3)*(7-3), node); + expect_match("tensor(x[2]):[1,9]", (9-3)*(9-3), node); - match("tensor(x[2]):[1,6]"); - expect_match((6-3)*(6-3), query.get(0)); + // The distance limit is now (9-3)*(9-3) = 36, so this is not good enough. + expect_not_match("tensor(x[2]):[1,10]", node); + + expect_match("tensor(x[2]):[1,5]", (5-3)*(5-3), node); + + // The distance limit is now (7-3)*(7-3) = 16, so this is not good enough. + expect_not_match("tensor(x[2]):[1,8]", node); + + // This is not considered a document match as get_raw_score() is not called, + // and the distance heap is not updated. + match("tensor(x[2]):[1,4]"); + EXPECT_EQ(1, node.get_distance().value()); + EXPECT_TRUE(node.evaluate()); + + // The distance limit is still (7-3)*(7-3) = 16, so this is in fact good enough. + expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node); + + // The distance limit is (6-3)*(6-3) = 4, and a similar distance is a match. + expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node); +} + +TEST_F(NearestNeighborSearcherTest, raw_score_calculated_with_distance_threshold) +{ + query.add("qt1", 10, 3.0); + const auto& node = query.get(0); + set_query_tensor("qt1", "tensor(x[2]):[1,3]"); + prepare(); + + expect_match("tensor(x[2]):[1,5]", (5-3)*(5-3), node); + expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node); - match("tensor(x[2]):[1,7]"); // This is not a match since ((7-3)*(7-3) = 16) is larger than the internal distance threshold of (3*3 = 9). - EXPECT_FALSE(query.get(0).evaluate()); + expect_not_match("tensor(x[2]):[1,7]", node); } TEST_F(NearestNeighborSearcherTest, raw_score_calculated_for_two_query_operators) { - query.add("qt1", 3); - query.add("qt2", 4); + query.add("qt1", 10, 3.0); + query.add("qt2", 10, 4.0); set_query_tensor("qt1", "tensor(x[2]):[1,3]"); set_query_tensor("qt2", "tensor(x[2]):[1,4]"); prepare(); diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp index 9f3f3d770e4..4d425d9dedd 100644 --- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -55,6 +55,10 @@ RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder) _query_wrapper = std::make_unique<QueryWrapper>(*_query); } +class MockRawScoreCalculator : public search::streaming::NearestNeighborQueryNode::RawScoreCalculator { +public: + double to_raw_score(double distance) override { return distance * 2; } +}; TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) { @@ -71,6 +75,8 @@ TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) EXPECT_EQ(1u, term_list.size()); auto node = dynamic_cast<NearestNeighborQueryNode*>(term_list.front().getTerm()); EXPECT_NE(nullptr, node); + MockRawScoreCalculator calc; + node->set_raw_score_calc(&calc); auto& qtd = static_cast<QueryTermData &>(node->getQueryItem()); auto& td = qtd.getTermData(); constexpr TermFieldHandle handle = 27; @@ -82,11 +88,11 @@ TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) EXPECT_EQ(invalid_id, tfmd->getDocId()); RankProcessor::unpack_match_data(1, *md, *_query_wrapper); EXPECT_EQ(invalid_id, tfmd->getDocId()); - constexpr double raw_score = 1.5; - node->set_raw_score(raw_score); + constexpr double distance = 1.5; + node->set_distance(distance); RankProcessor::unpack_match_data(2, *md, *_query_wrapper); EXPECT_EQ(2, tfmd->getDocId()); - EXPECT_EQ(raw_score, tfmd->getRawScore()); + EXPECT_EQ(distance * 2, tfmd->getRawScore()); node->reset(); RankProcessor::unpack_match_data(3, *md, *_query_wrapper); EXPECT_EQ(2, tfmd->getDocId()); diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 01b21edc1ba..3ce137bffe5 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -241,7 +241,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap for (QueryWrapper::Term & term: query.getTermList()) { auto nn_node = term.getTerm()->as_nearest_neighbor_query_node(); if (nn_node != nullptr) { - auto& raw_score = nn_node->get_raw_score(); + auto raw_score = nn_node->get_raw_score(); if (raw_score.has_value()) { auto& qtd = static_cast<QueryTermData &>(term.getTerm()->getQueryItem()); auto& td = qtd.getTermData(); diff --git a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp index f064760e55d..db4ee12438e 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp @@ -48,8 +48,17 @@ NearestNeighborFieldSearcher::NodeAndCalc::NodeAndCalc(search::streaming::Neares std::unique_ptr<search::tensor::DistanceCalculator> calc_in) : node(node_in), calc(std::move(calc_in)), - distance_threshold(calc->function().convert_threshold(node->get_distance_threshold())) + heap(node->get_target_hits()) { + node->set_raw_score_calc(this); + heap.set_distance_threshold(calc->function().convert_threshold(node->get_distance_threshold())); +} + +double +NearestNeighborFieldSearcher::NodeAndCalc::to_raw_score(double distance) +{ + heap.used(distance); + return calc->function().to_rawscore(distance); } NearestNeighborFieldSearcher::NearestNeighborFieldSearcher(FieldIdT fid, @@ -100,7 +109,7 @@ NearestNeighborFieldSearcher::prepare(search::streaming::QueryTermList& qtl, } try { auto calc = DistanceCalculator::make_with_validation(*_attr, *tensor_value); - _calcs.emplace_back(nn_term, std::move(calc)); + _calcs.push_back(std::make_unique<NodeAndCalc>(nn_term, std::move(calc))); } catch (const vespalib::IllegalArgumentException& ex) { vespalib::Issue::report("Could not create DistanceCalculator for NearestNeighborQueryNode(%s, %s): %s", nn_term->index().c_str(), nn_term->get_query_tensor_name().c_str(), ex.what()); @@ -116,10 +125,10 @@ NearestNeighborFieldSearcher::onValue(const document::FieldValue& fv) if (tfv && tfv->getAsTensorPtr()) { _attr->add(*tfv->getAsTensorPtr(), 1); for (auto& elem : _calcs) { - double distance = elem.calc->calc_with_limit(scratch_docid, elem.distance_threshold); - if (distance <= elem.distance_threshold) { - double score = elem.calc->function().to_rawscore(distance); - elem.node->set_raw_score(score); + double distance_limit = elem->heap.distanceLimit(); + double distance = elem->calc->calc_with_limit(scratch_docid, distance_limit); + if (distance <= distance_limit) { + elem->node->set_distance(distance); } } } diff --git a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h index ba39b91c677..d5d751cd637 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h +++ b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h @@ -5,6 +5,8 @@ #include "fieldsearcher.h" #include <vespa/eval/eval/value_type.h> #include <vespa/searchcommon/attribute/distance_metric.h> +#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> +#include <vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h> #include <vespa/searchlib/tensor/distance_calculator.h> #include <vespa/searchlib/tensor/tensor_ext_attribute.h> @@ -14,8 +16,6 @@ namespace search::tensor { class TensorExtAttribute; } -namespace search::streaming { class NearestNeighborQueryNode; } - namespace vsm { /** @@ -26,16 +26,19 @@ namespace vsm { */ class NearestNeighborFieldSearcher : public FieldSearcher { private: - struct NodeAndCalc { + class NodeAndCalc : search::streaming::NearestNeighborQueryNode::RawScoreCalculator { + public: search::streaming::NearestNeighborQueryNode* node; std::unique_ptr<search::tensor::DistanceCalculator> calc; - double distance_threshold; + search::queryeval::NearestNeighborDistanceHeap heap; NodeAndCalc(search::streaming::NearestNeighborQueryNode* node_in, std::unique_ptr<search::tensor::DistanceCalculator> calc_in); + + double to_raw_score(double distance) override; }; search::attribute::DistanceMetric _metric; std::unique_ptr<search::tensor::TensorExtAttribute> _attr; - std::vector<NodeAndCalc> _calcs; + std::vector<std::unique_ptr<NodeAndCalc>> _calcs; public: NearestNeighborFieldSearcher(FieldIdT fid, |