diff options
author | Tor Egge <Tor.Egge@online.no> | 2023-04-19 15:19:30 +0200 |
---|---|---|
committer | Tor Egge <Tor.Egge@online.no> | 2023-04-19 15:19:30 +0200 |
commit | e9baa89a4807e6ddfbc49331a27ae143982d8d9f (patch) | |
tree | e63949be6f3ee75a1f16088ddfe9fc7d583a91ee | |
parent | 6c98021a888d31632eeb2140c771b4a07a60ed73 (diff) |
Unpack match data for nearest neighbor query node in streaming search.
5 files changed, 126 insertions, 8 deletions
diff --git a/streamingvisitors/CMakeLists.txt b/streamingvisitors/CMakeLists.txt index fede7087d8d..1e62203ea03 100644 --- a/streamingvisitors/CMakeLists.txt +++ b/streamingvisitors/CMakeLists.txt @@ -26,6 +26,7 @@ vespa_define_module( src/tests/docsum src/tests/document src/tests/query_term_filter_factory + src/tests/rank_processor src/tests/searcher src/tests/textutil ) diff --git a/streamingvisitors/src/tests/rank_processor/CMakeLists.txt b/streamingvisitors/src/tests/rank_processor/CMakeLists.txt new file mode 100644 index 00000000000..6ae79c6382d --- /dev/null +++ b/streamingvisitors/src/tests/rank_processor/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(streamingvisitors_rank_processor_test_app TEST + SOURCES + rank_processor_test.cpp + DEPENDS + streamingvisitors + GTest::GTest +) +vespa_add_test(NAME streamingvisitors_rank_processor_test_app COMMAND streamingvisitors_rank_processor_test_app) diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp new file mode 100644 index 00000000000..9f3f3d770e4 --- /dev/null +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -0,0 +1,95 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchvisitor/rankprocessor.h> +#include <vespa/searchlib/query/streaming/query.h> +#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> +#include <vespa/searchlib/query/tree/querybuilder.h> +#include <vespa/searchlib/query/tree/simplequery.h> +#include <vespa/searchlib/query/tree/stackdumpcreator.h> +#include <vespa/searchvisitor/querytermdata.h> +#include <vespa/vespalib/gtest/gtest.h> + +using search::fef::MatchData; +using search::fef::TermFieldHandle; +using search::fef::TermFieldMatchData; +using search::query::Weight; +using search::query::QueryBuilder; +using search::query::SimpleQueryNodeTypes; +using search::query::StackDumpCreator; +using search::streaming::NearestNeighborQueryNode; +using search::streaming::Query; +using streaming::RankProcessor; +using streaming::QueryTermData; +using streaming::QueryTermDataFactory; +using streaming::QueryWrapper; + +class RankProcessorTest : public testing::Test +{ +protected: + QueryTermDataFactory _factory; + std::unique_ptr<Query> _query; + std::unique_ptr<QueryWrapper> _query_wrapper; + + RankProcessorTest(); + ~RankProcessorTest() override; + + void build_query(QueryBuilder<SimpleQueryNodeTypes> &builder); +}; + +RankProcessorTest::RankProcessorTest() + : testing::Test(), + _factory(), + _query(), + _query_wrapper() +{ +} + +RankProcessorTest::~RankProcessorTest() = default; + +void +RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder) +{ + auto build_node = builder.build(); + auto stack_dump = StackDumpCreator::create(*build_node); + _query = std::make_unique<Query>(_factory, stack_dump); + _query_wrapper = std::make_unique<QueryWrapper>(*_query); +} + + +TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) +{ + QueryBuilder<SimpleQueryNodeTypes> builder; + constexpr double distance_threshold = 35.5; + constexpr int32_t id = 42; + constexpr int32_t weight = 1; + constexpr uint32_t target_num_hits = 100; + constexpr bool allow_approximate = false; + constexpr uint32_t explore_additional_hits = 800; + builder.add_nearest_neighbor_term("qtensor", "field", id, Weight(weight), target_num_hits, allow_approximate, explore_additional_hits, distance_threshold); + build_query(builder); + auto& term_list = _query_wrapper->getTermList(); + EXPECT_EQ(1u, term_list.size()); + auto node = dynamic_cast<NearestNeighborQueryNode*>(term_list.front().getTerm()); + EXPECT_NE(nullptr, node); + auto& qtd = static_cast<QueryTermData &>(node->getQueryItem()); + auto& td = qtd.getTermData(); + constexpr TermFieldHandle handle = 27; + constexpr uint32_t field_id = 12; + td.addField(field_id).setHandle(handle); + auto md = MatchData::makeTestInstance(handle + 1, handle + 1); + auto tfmd = md->resolveTermField(handle); + auto invalid_id = TermFieldMatchData::invalidId(); + EXPECT_EQ(invalid_id, tfmd->getDocId()); + RankProcessor::unpack_match_data(1, *md, *_query_wrapper); + EXPECT_EQ(invalid_id, tfmd->getDocId()); + constexpr double raw_score = 1.5; + node->set_raw_score(raw_score); + RankProcessor::unpack_match_data(2, *md, *_query_wrapper); + EXPECT_EQ(2, tfmd->getDocId()); + EXPECT_EQ(raw_score, tfmd->getRawScore()); + node->reset(); + RankProcessor::unpack_match_data(3, *md, *_query_wrapper); + EXPECT_EQ(2, tfmd->getDocId()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 24925bd67ee..b41eb041c57 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -4,6 +4,7 @@ #include "rankprocessor.h" #include <vespa/searchlib/fef/handle.h> #include <vespa/searchlib/fef/simpletermfielddata.h> +#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> #include <vespa/vsm/vsm/fieldsearchspec.h> #include <cmath> #include <vespa/log/log.h> @@ -227,14 +228,27 @@ void RankProcessor::unpackMatchData(uint32_t docId) { _docId = docId; - unpackMatchData(*_match_data); + unpack_match_data(docId, *_match_data, _query); } void -RankProcessor::unpackMatchData(MatchData &matchData) +RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrapper& query) { - for (QueryWrapper::Term & term: _query.getTermList()) { - if (!term.isPhraseTerm() || term.isFirstPhraseTerm()) { // consider 1 term data per phrase + for (QueryWrapper::Term & term: query.getTermList()) { + auto nn_node = term.getTerm()->as_nearest_neighbor_query_node(); + if (nn_node != nullptr) { + auto& raw_score = nn_node->get_raw_score(); + if (raw_score.has_value()) { + auto& qtd = static_cast<QueryTermData &>(term.getTerm()->getQueryItem()); + auto& td = qtd.getTermData(); + if (td.numFields() == 1u) { + auto tfd = td.field(0u); + auto tmd = matchData.resolveTermField(tfd.getHandle()); + assert(tmd != nullptr); + tmd->setRawScore(docid, raw_score.value()); + } + } + } else if (!term.isPhraseTerm() || term.isFirstPhraseTerm()) { // consider 1 term data per phrase bool isPhrase = term.isFirstPhraseTerm(); QueryTermData & qtd = static_cast<QueryTermData &>(term.getTerm()->getQueryItem()); const ITermData &td = qtd.getTermData(); @@ -266,8 +280,8 @@ RankProcessor::unpackMatchData(MatchData &matchData) tmd = matchData.resolveTermField(tfd->getHandle()); tmd->setFieldId(fieldId); // reset field match data, but only once per docId - if (tmd->getDocId() != _docId) { - tmd->reset(_docId); + if (tmd->getDocId() != docid) { + tmd->reset(docid); } } // find fieldLen for new field diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h index c2b3d8adedf..039a9386539 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h @@ -50,8 +50,6 @@ private: **/ void init(bool forRanking, size_t wantedHitCount); - void unpackMatchData(search::fef::MatchData &matchData); - public: using UP = std::unique_ptr<RankProcessor>; @@ -65,6 +63,7 @@ public: void initForRanking(size_t wantedHitCount); void initForDumping(size_t wantedHitCount); void unpackMatchData(uint32_t docId); + static void unpack_match_data(uint32_t docid, search::fef::MatchData& matchData, QueryWrapper& query); void runRankProgram(uint32_t docId); search::FeatureSet::SP calculateFeatureSet(); void fillSearchResult(vdslib::SearchResult & searchResult); |