diff options
Diffstat (limited to 'streamingvisitors/src/tests/rank_processor')
-rw-r--r-- | streamingvisitors/src/tests/rank_processor/CMakeLists.txt | 9 | ||||
-rw-r--r-- | streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp | 95 |
2 files changed, 104 insertions, 0 deletions
diff --git a/streamingvisitors/src/tests/rank_processor/CMakeLists.txt b/streamingvisitors/src/tests/rank_processor/CMakeLists.txt new file mode 100644 index 00000000000..6ae79c6382d --- /dev/null +++ b/streamingvisitors/src/tests/rank_processor/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(streamingvisitors_rank_processor_test_app TEST + SOURCES + rank_processor_test.cpp + DEPENDS + streamingvisitors + GTest::GTest +) +vespa_add_test(NAME streamingvisitors_rank_processor_test_app COMMAND streamingvisitors_rank_processor_test_app) diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp new file mode 100644 index 00000000000..9f3f3d770e4 --- /dev/null +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -0,0 +1,95 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/searchvisitor/rankprocessor.h> +#include <vespa/searchlib/query/streaming/query.h> +#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> +#include <vespa/searchlib/query/tree/querybuilder.h> +#include <vespa/searchlib/query/tree/simplequery.h> +#include <vespa/searchlib/query/tree/stackdumpcreator.h> +#include <vespa/searchvisitor/querytermdata.h> +#include <vespa/vespalib/gtest/gtest.h> + +using search::fef::MatchData; +using search::fef::TermFieldHandle; +using search::fef::TermFieldMatchData; +using search::query::Weight; +using search::query::QueryBuilder; +using search::query::SimpleQueryNodeTypes; +using search::query::StackDumpCreator; +using search::streaming::NearestNeighborQueryNode; +using search::streaming::Query; +using streaming::RankProcessor; +using streaming::QueryTermData; +using streaming::QueryTermDataFactory; +using streaming::QueryWrapper; + +class RankProcessorTest : public testing::Test +{ +protected: + QueryTermDataFactory _factory; + std::unique_ptr<Query> _query; + std::unique_ptr<QueryWrapper> _query_wrapper; + + RankProcessorTest(); + ~RankProcessorTest() override; + + void build_query(QueryBuilder<SimpleQueryNodeTypes> &builder); +}; + +RankProcessorTest::RankProcessorTest() + : testing::Test(), + _factory(), + _query(), + _query_wrapper() +{ +} + +RankProcessorTest::~RankProcessorTest() = default; + +void +RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder) +{ + auto build_node = builder.build(); + auto stack_dump = StackDumpCreator::create(*build_node); + _query = std::make_unique<Query>(_factory, stack_dump); + _query_wrapper = std::make_unique<QueryWrapper>(*_query); +} + + +TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) +{ + QueryBuilder<SimpleQueryNodeTypes> builder; + constexpr double distance_threshold = 35.5; + constexpr int32_t id = 42; + constexpr int32_t weight = 1; + constexpr uint32_t target_num_hits = 100; + constexpr bool allow_approximate = false; + constexpr uint32_t explore_additional_hits = 800; + builder.add_nearest_neighbor_term("qtensor", "field", id, Weight(weight), target_num_hits, allow_approximate, explore_additional_hits, distance_threshold); + build_query(builder); + auto& term_list = _query_wrapper->getTermList(); + EXPECT_EQ(1u, term_list.size()); + auto node = dynamic_cast<NearestNeighborQueryNode*>(term_list.front().getTerm()); + EXPECT_NE(nullptr, node); + auto& qtd = static_cast<QueryTermData &>(node->getQueryItem()); + auto& td = qtd.getTermData(); + constexpr TermFieldHandle handle = 27; + constexpr uint32_t field_id = 12; + td.addField(field_id).setHandle(handle); + auto md = MatchData::makeTestInstance(handle + 1, handle + 1); + auto tfmd = md->resolveTermField(handle); + auto invalid_id = TermFieldMatchData::invalidId(); + EXPECT_EQ(invalid_id, tfmd->getDocId()); + RankProcessor::unpack_match_data(1, *md, *_query_wrapper); + EXPECT_EQ(invalid_id, tfmd->getDocId()); + constexpr double raw_score = 1.5; + node->set_raw_score(raw_score); + RankProcessor::unpack_match_data(2, *md, *_query_wrapper); + EXPECT_EQ(2, tfmd->getDocId()); + EXPECT_EQ(raw_score, tfmd->getRawScore()); + node->reset(); + RankProcessor::unpack_match_data(3, *md, *_query_wrapper); + EXPECT_EQ(2, tfmd->getDocId()); +} + +GTEST_MAIN_RUN_ALL_TESTS() |