summaryrefslogtreecommitdiffstats
path: root/streamingvisitors/src/tests/rank_processor
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@online.no>2023-07-19 12:27:26 +0200
committerTor Egge <Tor.Egge@online.no>2023-07-19 12:27:26 +0200
commit9f3fa50f106fc2b24bca1d132b82f86375b58c89 (patch)
treecf429e78f62b18ad906b3d82f8bf6f2cb71c070f /streamingvisitors/src/tests/rank_processor
parent69d705ef0d9679a1a73f6c00ec2eabb584a6576a (diff)
Unpack interleaved features for streaming search.
Diffstat (limited to 'streamingvisitors/src/tests/rank_processor')
-rw-r--r--streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp59
1 files changed, 59 insertions, 0 deletions
diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
index 4d425d9dedd..f408910d8c2 100644
--- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
+++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
@@ -18,6 +18,7 @@ using search::query::SimpleQueryNodeTypes;
using search::query::StackDumpCreator;
using search::streaming::NearestNeighborQueryNode;
using search::streaming::Query;
+using search::streaming::QueryTerm;
using streaming::RankProcessor;
using streaming::QueryTermData;
using streaming::QueryTermDataFactory;
@@ -34,6 +35,7 @@ protected:
~RankProcessorTest() override;
void build_query(QueryBuilder<SimpleQueryNodeTypes> &builder);
+ void test_unpack_match_data_for_term_node(bool interleaved_features);
};
RankProcessorTest::RankProcessorTest()
@@ -55,6 +57,63 @@ RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder)
_query_wrapper = std::make_unique<QueryWrapper>(*_query);
}
+void
+RankProcessorTest::test_unpack_match_data_for_term_node(bool interleaved_features)
+{
+ QueryBuilder<SimpleQueryNodeTypes> builder;
+ constexpr int32_t id = 42;
+ constexpr int32_t weight = 1;
+ builder.addStringTerm("term", "field", id, Weight(weight));
+ build_query(builder);
+ auto& term_list = _query_wrapper->getTermList();
+ EXPECT_EQ(1u, term_list.size());
+ auto node = dynamic_cast<QueryTerm*>(term_list.front().getTerm());
+ EXPECT_NE(nullptr, node);
+ auto& qtd = static_cast<QueryTermData &>(node->getQueryItem());
+ auto& td = qtd.getTermData();
+ constexpr TermFieldHandle handle = 27;
+ constexpr uint32_t field_id = 12;
+ constexpr uint32_t mock_num_occs = 2;
+ constexpr uint32_t mock_field_length = 101;
+ td.addField(field_id).setHandle(handle);
+ node->resizeFieldId(field_id);
+ auto md = MatchData::makeTestInstance(handle + 1, handle + 1);
+ auto tfmd = md->resolveTermField(handle);
+ tfmd->setNeedInterleavedFeatures(interleaved_features);
+ auto invalid_id = TermFieldMatchData::invalidId();
+ EXPECT_EQ(invalid_id, tfmd->getDocId());
+ RankProcessor::unpack_match_data(1, *md, *_query_wrapper);
+ EXPECT_EQ(invalid_id, tfmd->getDocId());
+ node->add(0, field_id, 0, 1);
+ auto& field_info = node->getFieldInfo(field_id);
+ field_info.setHitCount(mock_num_occs);
+ field_info.setFieldLength(mock_field_length);
+ RankProcessor::unpack_match_data(2, *md, *_query_wrapper);
+ EXPECT_EQ(2, tfmd->getDocId());
+ if (interleaved_features) {
+ EXPECT_EQ(mock_num_occs, tfmd->getNumOccs());
+ EXPECT_EQ(mock_field_length, tfmd->getFieldLength());
+ } else {
+ EXPECT_EQ(0, tfmd->getNumOccs());
+ EXPECT_EQ(0, tfmd->getFieldLength());
+ }
+ EXPECT_EQ(1, tfmd->size());
+ node->reset();
+ RankProcessor::unpack_match_data(3, *md, *_query_wrapper);
+ EXPECT_EQ(2, tfmd->getDocId());
+}
+
+
+TEST_F(RankProcessorTest, unpack_normal_match_data_for_term_node)
+{
+ test_unpack_match_data_for_term_node(false);
+}
+
+TEST_F(RankProcessorTest, unpack_interleaved_match_data_for_term_node)
+{
+ test_unpack_match_data_for_term_node(true);
+}
+
class MockRawScoreCalculator : public search::streaming::NearestNeighborQueryNode::RawScoreCalculator {
public:
double to_raw_score(double distance) override { return distance * 2; }