aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2023-04-19 16:24:31 +0200
committerGitHub <noreply@github.com>2023-04-19 16:24:31 +0200
commit4bf83d5e87a8896ce3b6a14fb0889a2891053bf1 (patch)
treee63949be6f3ee75a1f16088ddfe9fc7d583a91ee
parent6c98021a888d31632eeb2140c771b4a07a60ed73 (diff)
parente9baa89a4807e6ddfbc49331a27ae143982d8d9f (diff)
Merge pull request #26786 from vespa-engine/toregge/unpack-match-data-for-nearest-neighbor-query-node-in-streaming-searchv8.155.19
Unpack match data for nearest neighbor query node in streaming search.
-rw-r--r--streamingvisitors/CMakeLists.txt1
-rw-r--r--streamingvisitors/src/tests/rank_processor/CMakeLists.txt9
-rw-r--r--streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp95
-rw-r--r--streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp26
-rw-r--r--streamingvisitors/src/vespa/searchvisitor/rankprocessor.h3
5 files changed, 126 insertions, 8 deletions
diff --git a/streamingvisitors/CMakeLists.txt b/streamingvisitors/CMakeLists.txt
index fede7087d8d..1e62203ea03 100644
--- a/streamingvisitors/CMakeLists.txt
+++ b/streamingvisitors/CMakeLists.txt
@@ -26,6 +26,7 @@ vespa_define_module(
src/tests/docsum
src/tests/document
src/tests/query_term_filter_factory
+ src/tests/rank_processor
src/tests/searcher
src/tests/textutil
)
diff --git a/streamingvisitors/src/tests/rank_processor/CMakeLists.txt b/streamingvisitors/src/tests/rank_processor/CMakeLists.txt
new file mode 100644
index 00000000000..6ae79c6382d
--- /dev/null
+++ b/streamingvisitors/src/tests/rank_processor/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(streamingvisitors_rank_processor_test_app TEST
+ SOURCES
+ rank_processor_test.cpp
+ DEPENDS
+ streamingvisitors
+ GTest::GTest
+)
+vespa_add_test(NAME streamingvisitors_rank_processor_test_app COMMAND streamingvisitors_rank_processor_test_app)
diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
new file mode 100644
index 00000000000..9f3f3d770e4
--- /dev/null
+++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
@@ -0,0 +1,95 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/searchvisitor/rankprocessor.h>
+#include <vespa/searchlib/query/streaming/query.h>
+#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h>
+#include <vespa/searchlib/query/tree/querybuilder.h>
+#include <vespa/searchlib/query/tree/simplequery.h>
+#include <vespa/searchlib/query/tree/stackdumpcreator.h>
+#include <vespa/searchvisitor/querytermdata.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using search::fef::MatchData;
+using search::fef::TermFieldHandle;
+using search::fef::TermFieldMatchData;
+using search::query::Weight;
+using search::query::QueryBuilder;
+using search::query::SimpleQueryNodeTypes;
+using search::query::StackDumpCreator;
+using search::streaming::NearestNeighborQueryNode;
+using search::streaming::Query;
+using streaming::RankProcessor;
+using streaming::QueryTermData;
+using streaming::QueryTermDataFactory;
+using streaming::QueryWrapper;
+
+class RankProcessorTest : public testing::Test
+{
+protected:
+ QueryTermDataFactory _factory;
+ std::unique_ptr<Query> _query;
+ std::unique_ptr<QueryWrapper> _query_wrapper;
+
+ RankProcessorTest();
+ ~RankProcessorTest() override;
+
+ void build_query(QueryBuilder<SimpleQueryNodeTypes> &builder);
+};
+
+RankProcessorTest::RankProcessorTest()
+ : testing::Test(),
+ _factory(),
+ _query(),
+ _query_wrapper()
+{
+}
+
+RankProcessorTest::~RankProcessorTest() = default;
+
+void
+RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder)
+{
+ auto build_node = builder.build();
+ auto stack_dump = StackDumpCreator::create(*build_node);
+ _query = std::make_unique<Query>(_factory, stack_dump);
+ _query_wrapper = std::make_unique<QueryWrapper>(*_query);
+}
+
+
+TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node)
+{
+ QueryBuilder<SimpleQueryNodeTypes> builder;
+ constexpr double distance_threshold = 35.5;
+ constexpr int32_t id = 42;
+ constexpr int32_t weight = 1;
+ constexpr uint32_t target_num_hits = 100;
+ constexpr bool allow_approximate = false;
+ constexpr uint32_t explore_additional_hits = 800;
+ builder.add_nearest_neighbor_term("qtensor", "field", id, Weight(weight), target_num_hits, allow_approximate, explore_additional_hits, distance_threshold);
+ build_query(builder);
+ auto& term_list = _query_wrapper->getTermList();
+ EXPECT_EQ(1u, term_list.size());
+ auto node = dynamic_cast<NearestNeighborQueryNode*>(term_list.front().getTerm());
+ EXPECT_NE(nullptr, node);
+ auto& qtd = static_cast<QueryTermData &>(node->getQueryItem());
+ auto& td = qtd.getTermData();
+ constexpr TermFieldHandle handle = 27;
+ constexpr uint32_t field_id = 12;
+ td.addField(field_id).setHandle(handle);
+ auto md = MatchData::makeTestInstance(handle + 1, handle + 1);
+ auto tfmd = md->resolveTermField(handle);
+ auto invalid_id = TermFieldMatchData::invalidId();
+ EXPECT_EQ(invalid_id, tfmd->getDocId());
+ RankProcessor::unpack_match_data(1, *md, *_query_wrapper);
+ EXPECT_EQ(invalid_id, tfmd->getDocId());
+ constexpr double raw_score = 1.5;
+ node->set_raw_score(raw_score);
+ RankProcessor::unpack_match_data(2, *md, *_query_wrapper);
+ EXPECT_EQ(2, tfmd->getDocId());
+ EXPECT_EQ(raw_score, tfmd->getRawScore());
+ node->reset();
+ RankProcessor::unpack_match_data(3, *md, *_query_wrapper);
+ EXPECT_EQ(2, tfmd->getDocId());
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
index 24925bd67ee..b41eb041c57 100644
--- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
+++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
@@ -4,6 +4,7 @@
#include "rankprocessor.h"
#include <vespa/searchlib/fef/handle.h>
#include <vespa/searchlib/fef/simpletermfielddata.h>
+#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h>
#include <vespa/vsm/vsm/fieldsearchspec.h>
#include <cmath>
#include <vespa/log/log.h>
@@ -227,14 +228,27 @@ void
RankProcessor::unpackMatchData(uint32_t docId)
{
_docId = docId;
- unpackMatchData(*_match_data);
+ unpack_match_data(docId, *_match_data, _query);
}
void
-RankProcessor::unpackMatchData(MatchData &matchData)
+RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrapper& query)
{
- for (QueryWrapper::Term & term: _query.getTermList()) {
- if (!term.isPhraseTerm() || term.isFirstPhraseTerm()) { // consider 1 term data per phrase
+ for (QueryWrapper::Term & term: query.getTermList()) {
+ auto nn_node = term.getTerm()->as_nearest_neighbor_query_node();
+ if (nn_node != nullptr) {
+ auto& raw_score = nn_node->get_raw_score();
+ if (raw_score.has_value()) {
+ auto& qtd = static_cast<QueryTermData &>(term.getTerm()->getQueryItem());
+ auto& td = qtd.getTermData();
+ if (td.numFields() == 1u) {
+ auto tfd = td.field(0u);
+ auto tmd = matchData.resolveTermField(tfd.getHandle());
+ assert(tmd != nullptr);
+ tmd->setRawScore(docid, raw_score.value());
+ }
+ }
+ } else if (!term.isPhraseTerm() || term.isFirstPhraseTerm()) { // consider 1 term data per phrase
bool isPhrase = term.isFirstPhraseTerm();
QueryTermData & qtd = static_cast<QueryTermData &>(term.getTerm()->getQueryItem());
const ITermData &td = qtd.getTermData();
@@ -266,8 +280,8 @@ RankProcessor::unpackMatchData(MatchData &matchData)
tmd = matchData.resolveTermField(tfd->getHandle());
tmd->setFieldId(fieldId);
// reset field match data, but only once per docId
- if (tmd->getDocId() != _docId) {
- tmd->reset(_docId);
+ if (tmd->getDocId() != docid) {
+ tmd->reset(docid);
}
}
// find fieldLen for new field
diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h
index c2b3d8adedf..039a9386539 100644
--- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h
+++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h
@@ -50,8 +50,6 @@ private:
**/
void init(bool forRanking, size_t wantedHitCount);
- void unpackMatchData(search::fef::MatchData &matchData);
-
public:
using UP = std::unique_ptr<RankProcessor>;
@@ -65,6 +63,7 @@ public:
void initForRanking(size_t wantedHitCount);
void initForDumping(size_t wantedHitCount);
void unpackMatchData(uint32_t docId);
+ static void unpack_match_data(uint32_t docid, search::fef::MatchData& matchData, QueryWrapper& query);
void runRankProgram(uint32_t docId);
search::FeatureSet::SP calculateFeatureSet();
void fillSearchResult(vdslib::SearchResult & searchResult);