From e9baa89a4807e6ddfbc49331a27ae143982d8d9f Mon Sep 17 00:00:00 2001 From: Tor Egge Date: Wed, 19 Apr 2023 15:19:30 +0200 Subject: Unpack match data for nearest neighbor query node in streaming search. --- streamingvisitors/CMakeLists.txt | 1 + .../src/tests/rank_processor/CMakeLists.txt | 9 ++ .../tests/rank_processor/rank_processor_test.cpp | 95 ++++++++++++++++++++++ .../src/vespa/searchvisitor/rankprocessor.cpp | 26 ++++-- .../src/vespa/searchvisitor/rankprocessor.h | 3 +- 5 files changed, 126 insertions(+), 8 deletions(-) create mode 100644 streamingvisitors/src/tests/rank_processor/CMakeLists.txt create mode 100644 streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp diff --git a/streamingvisitors/CMakeLists.txt b/streamingvisitors/CMakeLists.txt index fede7087d8d..1e62203ea03 100644 --- a/streamingvisitors/CMakeLists.txt +++ b/streamingvisitors/CMakeLists.txt @@ -26,6 +26,7 @@ vespa_define_module( src/tests/docsum src/tests/document src/tests/query_term_filter_factory + src/tests/rank_processor src/tests/searcher src/tests/textutil ) diff --git a/streamingvisitors/src/tests/rank_processor/CMakeLists.txt b/streamingvisitors/src/tests/rank_processor/CMakeLists.txt new file mode 100644 index 00000000000..6ae79c6382d --- /dev/null +++ b/streamingvisitors/src/tests/rank_processor/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(streamingvisitors_rank_processor_test_app TEST + SOURCES + rank_processor_test.cpp + DEPENDS + streamingvisitors + GTest::GTest +) +vespa_add_test(NAME streamingvisitors_rank_processor_test_app COMMAND streamingvisitors_rank_processor_test_app) diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp new file mode 100644 index 00000000000..9f3f3d770e4 --- /dev/null +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -0,0 +1,95 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include +#include +#include +#include +#include +#include +#include +#include + +using search::fef::MatchData; +using search::fef::TermFieldHandle; +using search::fef::TermFieldMatchData; +using search::query::Weight; +using search::query::QueryBuilder; +using search::query::SimpleQueryNodeTypes; +using search::query::StackDumpCreator; +using search::streaming::NearestNeighborQueryNode; +using search::streaming::Query; +using streaming::RankProcessor; +using streaming::QueryTermData; +using streaming::QueryTermDataFactory; +using streaming::QueryWrapper; + +class RankProcessorTest : public testing::Test +{ +protected: + QueryTermDataFactory _factory; + std::unique_ptr _query; + std::unique_ptr _query_wrapper; + + RankProcessorTest(); + ~RankProcessorTest() override; + + void build_query(QueryBuilder &builder); +}; + +RankProcessorTest::RankProcessorTest() + : testing::Test(), + _factory(), + _query(), + _query_wrapper() +{ +} + +RankProcessorTest::~RankProcessorTest() = default; + +void +RankProcessorTest::build_query(QueryBuilder &builder) +{ + auto build_node = builder.build(); + auto stack_dump = StackDumpCreator::create(*build_node); + _query = std::make_unique(_factory, stack_dump); + _query_wrapper = std::make_unique(*_query); +} + + +TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) +{ + QueryBuilder builder; + constexpr double distance_threshold = 35.5; + constexpr int32_t id = 42; + constexpr int32_t weight = 1; + constexpr uint32_t target_num_hits = 100; + constexpr bool allow_approximate = false; + constexpr uint32_t explore_additional_hits = 800; + builder.add_nearest_neighbor_term("qtensor", "field", id, Weight(weight), target_num_hits, allow_approximate, explore_additional_hits, distance_threshold); + build_query(builder); + auto& term_list = _query_wrapper->getTermList(); + EXPECT_EQ(1u, term_list.size()); + auto node = dynamic_cast(term_list.front().getTerm()); + EXPECT_NE(nullptr, node); + auto& qtd = static_cast(node->getQueryItem()); + auto& td = qtd.getTermData(); + constexpr TermFieldHandle handle = 27; + constexpr uint32_t field_id = 12; + td.addField(field_id).setHandle(handle); + auto md = MatchData::makeTestInstance(handle + 1, handle + 1); + auto tfmd = md->resolveTermField(handle); + auto invalid_id = TermFieldMatchData::invalidId(); + EXPECT_EQ(invalid_id, tfmd->getDocId()); + RankProcessor::unpack_match_data(1, *md, *_query_wrapper); + EXPECT_EQ(invalid_id, tfmd->getDocId()); + constexpr double raw_score = 1.5; + node->set_raw_score(raw_score); + RankProcessor::unpack_match_data(2, *md, *_query_wrapper); + EXPECT_EQ(2, tfmd->getDocId()); + EXPECT_EQ(raw_score, tfmd->getRawScore()); + node->reset(); + RankProcessor::unpack_match_data(3, *md, *_query_wrapper); + EXPECT_EQ(2, tfmd->getDocId()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 24925bd67ee..b41eb041c57 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -4,6 +4,7 @@ #include "rankprocessor.h" #include #include +#include #include #include #include @@ -227,14 +228,27 @@ void RankProcessor::unpackMatchData(uint32_t docId) { _docId = docId; - unpackMatchData(*_match_data); + unpack_match_data(docId, *_match_data, _query); } void -RankProcessor::unpackMatchData(MatchData &matchData) +RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrapper& query) { - for (QueryWrapper::Term & term: _query.getTermList()) { - if (!term.isPhraseTerm() || term.isFirstPhraseTerm()) { // consider 1 term data per phrase + for (QueryWrapper::Term & term: query.getTermList()) { + auto nn_node = term.getTerm()->as_nearest_neighbor_query_node(); + if (nn_node != nullptr) { + auto& raw_score = nn_node->get_raw_score(); + if (raw_score.has_value()) { + auto& qtd = static_cast(term.getTerm()->getQueryItem()); + auto& td = qtd.getTermData(); + if (td.numFields() == 1u) { + auto tfd = td.field(0u); + auto tmd = matchData.resolveTermField(tfd.getHandle()); + assert(tmd != nullptr); + tmd->setRawScore(docid, raw_score.value()); + } + } + } else if (!term.isPhraseTerm() || term.isFirstPhraseTerm()) { // consider 1 term data per phrase bool isPhrase = term.isFirstPhraseTerm(); QueryTermData & qtd = static_cast(term.getTerm()->getQueryItem()); const ITermData &td = qtd.getTermData(); @@ -266,8 +280,8 @@ RankProcessor::unpackMatchData(MatchData &matchData) tmd = matchData.resolveTermField(tfd->getHandle()); tmd->setFieldId(fieldId); // reset field match data, but only once per docId - if (tmd->getDocId() != _docId) { - tmd->reset(_docId); + if (tmd->getDocId() != docid) { + tmd->reset(docid); } } // find fieldLen for new field diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h index c2b3d8adedf..039a9386539 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h @@ -50,8 +50,6 @@ private: **/ void init(bool forRanking, size_t wantedHitCount); - void unpackMatchData(search::fef::MatchData &matchData); - public: using UP = std::unique_ptr; @@ -65,6 +63,7 @@ public: void initForRanking(size_t wantedHitCount); void initForDumping(size_t wantedHitCount); void unpackMatchData(uint32_t docId); + static void unpack_match_data(uint32_t docid, search::fef::MatchData& matchData, QueryWrapper& query); void runRankProgram(uint32_t docId); search::FeatureSet::SP calculateFeatureSet(); void fillSearchResult(vdslib::SearchResult & searchResult); -- cgit v1.2.3