diff options
author | Tor Egge <Tor.Egge@online.no> | 2023-07-19 12:27:26 +0200 |
---|---|---|
committer | Tor Egge <Tor.Egge@online.no> | 2023-07-19 12:27:26 +0200 |
commit | 9f3fa50f106fc2b24bca1d132b82f86375b58c89 (patch) | |
tree | cf429e78f62b18ad906b3d82f8bf6f2cb71c070f /streamingvisitors/src | |
parent | 69d705ef0d9679a1a73f6c00ec2eabb584a6576a (diff) |
Unpack interleaved features for streaming search.
Diffstat (limited to 'streamingvisitors/src')
3 files changed, 78 insertions, 3 deletions
diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp index 4d425d9dedd..f408910d8c2 100644 --- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -18,6 +18,7 @@ using search::query::SimpleQueryNodeTypes; using search::query::StackDumpCreator; using search::streaming::NearestNeighborQueryNode; using search::streaming::Query; +using search::streaming::QueryTerm; using streaming::RankProcessor; using streaming::QueryTermData; using streaming::QueryTermDataFactory; @@ -34,6 +35,7 @@ protected: ~RankProcessorTest() override; void build_query(QueryBuilder<SimpleQueryNodeTypes> &builder); + void test_unpack_match_data_for_term_node(bool interleaved_features); }; RankProcessorTest::RankProcessorTest() @@ -55,6 +57,63 @@ RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder) _query_wrapper = std::make_unique<QueryWrapper>(*_query); } +void +RankProcessorTest::test_unpack_match_data_for_term_node(bool interleaved_features) +{ + QueryBuilder<SimpleQueryNodeTypes> builder; + constexpr int32_t id = 42; + constexpr int32_t weight = 1; + builder.addStringTerm("term", "field", id, Weight(weight)); + build_query(builder); + auto& term_list = _query_wrapper->getTermList(); + EXPECT_EQ(1u, term_list.size()); + auto node = dynamic_cast<QueryTerm*>(term_list.front().getTerm()); + EXPECT_NE(nullptr, node); + auto& qtd = static_cast<QueryTermData &>(node->getQueryItem()); + auto& td = qtd.getTermData(); + constexpr TermFieldHandle handle = 27; + constexpr uint32_t field_id = 12; + constexpr uint32_t mock_num_occs = 2; + constexpr uint32_t mock_field_length = 101; + td.addField(field_id).setHandle(handle); + node->resizeFieldId(field_id); + auto md = MatchData::makeTestInstance(handle + 1, handle + 1); + auto tfmd = md->resolveTermField(handle); + tfmd->setNeedInterleavedFeatures(interleaved_features); + auto invalid_id = TermFieldMatchData::invalidId(); + EXPECT_EQ(invalid_id, tfmd->getDocId()); + RankProcessor::unpack_match_data(1, *md, *_query_wrapper); + EXPECT_EQ(invalid_id, tfmd->getDocId()); + node->add(0, field_id, 0, 1); + auto& field_info = node->getFieldInfo(field_id); + field_info.setHitCount(mock_num_occs); + field_info.setFieldLength(mock_field_length); + RankProcessor::unpack_match_data(2, *md, *_query_wrapper); + EXPECT_EQ(2, tfmd->getDocId()); + if (interleaved_features) { + EXPECT_EQ(mock_num_occs, tfmd->getNumOccs()); + EXPECT_EQ(mock_field_length, tfmd->getFieldLength()); + } else { + EXPECT_EQ(0, tfmd->getNumOccs()); + EXPECT_EQ(0, tfmd->getFieldLength()); + } + EXPECT_EQ(1, tfmd->size()); + node->reset(); + RankProcessor::unpack_match_data(3, *md, *_query_wrapper); + EXPECT_EQ(2, tfmd->getDocId()); +} + + +TEST_F(RankProcessorTest, unpack_normal_match_data_for_term_node) +{ + test_unpack_match_data_for_term_node(false); +} + +TEST_F(RankProcessorTest, unpack_interleaved_match_data_for_term_node) +{ + test_unpack_match_data_for_term_node(true); +} + class MockRawScoreCalculator : public search::streaming::NearestNeighborQueryNode::RawScoreCalculator { public: double to_raw_score(double distance) override { return distance * 2; } diff --git a/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h b/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h index c5dc442e424..8084d776efe 100644 --- a/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h +++ b/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h @@ -55,7 +55,7 @@ public: // inherit documentation virtual const search::attribute::IAttributeContext & getAttributeContext() const override { return *_attrCtx; } - double get_average_field_length(const vespalib::string &) const override { return 1.0; } + double get_average_field_length(const vespalib::string &) const override { return 100.0; } // inherit documentation virtual const search::fef::IIndexEnvironment & getIndexEnvironment() const override { return _indexEnv; } diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 78d72102fe9..17056d9d4b7 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -6,6 +6,7 @@ #include <vespa/searchlib/fef/simpletermfielddata.h> #include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> #include <vespa/vsm/vsm/fieldsearchspec.h> +#include <algorithm> #include <cmath> #include <vespa/log/log.h> LOG_SETUP(".searchvisitor.rankprocessor"); @@ -50,6 +51,11 @@ getFeature(const RankProgram &rankProgram) { return resolver.resolve(0); } +uint16_t +cap_16_bits(uint32_t value) { + return std::min(value, static_cast<uint32_t>(std::numeric_limits<uint16_t>::max())); +} + } void @@ -284,6 +290,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap uint32_t lastFieldId = -1; TermFieldMatchData *tmd = nullptr; uint32_t fieldLen = search::fef::FieldPositionsIterator::UNKNOWN_LENGTH; + uint32_t num_occs = 0; // optimize for hitlist giving all hits for a single field in one chunk for (const Hit & hit : hitList) { @@ -292,6 +299,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap // reset to notfound/unknown values tmd = nullptr; fieldLen = search::fef::FieldPositionsIterator::UNKNOWN_LENGTH; + num_occs = 0; // setup for new field that had a hit const ITermFieldData *tfd = td.lookupField(fieldId); @@ -306,11 +314,15 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap // find fieldLen for new field if (isPhrase) { if (fieldId < term.getParent()->getFieldInfoSize()) { - fieldLen = term.getParent()->getFieldInfo(fieldId).getFieldLength(); + auto& field_info = term.getParent()->getFieldInfo(fieldId); + fieldLen = field_info.getFieldLength(); + num_occs = field_info.getHitCount(); } } else { if (fieldId < term.getTerm()->getFieldInfoSize()) { - fieldLen = term.getTerm()->getFieldInfo(fieldId).getFieldLength(); + auto& field_info = term.getTerm()->getFieldInfo(fieldId); + fieldLen = field_info.getFieldLength(); + num_occs = field_info.getHitCount(); } } lastFieldId = fieldId; @@ -322,6 +334,10 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap tmd->appendPosition(pos); LOG(debug, "Append elemId(%u),position(%u), weight(%d), tfmd.weight(%d)", pos.getElementId(), pos.getPosition(), pos.getElementWeight(), tmd->getWeight()); + if (tmd->needs_interleaved_features()) { + tmd->setFieldLength(cap_16_bits(fieldLen)); + tmd->setNumOccs(cap_16_bits(num_occs)); + } } } } |