aboutsummaryrefslogtreecommitdiffstats
path: root/streamingvisitors
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@online.no>2023-07-19 12:27:26 +0200
committerTor Egge <Tor.Egge@online.no>2023-07-19 12:27:26 +0200
commit9f3fa50f106fc2b24bca1d132b82f86375b58c89 (patch)
treecf429e78f62b18ad906b3d82f8bf6f2cb71c070f /streamingvisitors
parent69d705ef0d9679a1a73f6c00ec2eabb584a6576a (diff)
Unpack interleaved features for streaming search.
Diffstat (limited to 'streamingvisitors')
-rw-r--r--streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp59
-rw-r--r--streamingvisitors/src/vespa/searchvisitor/queryenvironment.h2
-rw-r--r--streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp20
3 files changed, 78 insertions, 3 deletions
diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
index 4d425d9dedd..f408910d8c2 100644
--- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
+++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
@@ -18,6 +18,7 @@ using search::query::SimpleQueryNodeTypes;
using search::query::StackDumpCreator;
using search::streaming::NearestNeighborQueryNode;
using search::streaming::Query;
+using search::streaming::QueryTerm;
using streaming::RankProcessor;
using streaming::QueryTermData;
using streaming::QueryTermDataFactory;
@@ -34,6 +35,7 @@ protected:
~RankProcessorTest() override;
void build_query(QueryBuilder<SimpleQueryNodeTypes> &builder);
+ void test_unpack_match_data_for_term_node(bool interleaved_features);
};
RankProcessorTest::RankProcessorTest()
@@ -55,6 +57,63 @@ RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder)
_query_wrapper = std::make_unique<QueryWrapper>(*_query);
}
+void
+RankProcessorTest::test_unpack_match_data_for_term_node(bool interleaved_features)
+{
+ QueryBuilder<SimpleQueryNodeTypes> builder;
+ constexpr int32_t id = 42;
+ constexpr int32_t weight = 1;
+ builder.addStringTerm("term", "field", id, Weight(weight));
+ build_query(builder);
+ auto& term_list = _query_wrapper->getTermList();
+ EXPECT_EQ(1u, term_list.size());
+ auto node = dynamic_cast<QueryTerm*>(term_list.front().getTerm());
+ EXPECT_NE(nullptr, node);
+ auto& qtd = static_cast<QueryTermData &>(node->getQueryItem());
+ auto& td = qtd.getTermData();
+ constexpr TermFieldHandle handle = 27;
+ constexpr uint32_t field_id = 12;
+ constexpr uint32_t mock_num_occs = 2;
+ constexpr uint32_t mock_field_length = 101;
+ td.addField(field_id).setHandle(handle);
+ node->resizeFieldId(field_id);
+ auto md = MatchData::makeTestInstance(handle + 1, handle + 1);
+ auto tfmd = md->resolveTermField(handle);
+ tfmd->setNeedInterleavedFeatures(interleaved_features);
+ auto invalid_id = TermFieldMatchData::invalidId();
+ EXPECT_EQ(invalid_id, tfmd->getDocId());
+ RankProcessor::unpack_match_data(1, *md, *_query_wrapper);
+ EXPECT_EQ(invalid_id, tfmd->getDocId());
+ node->add(0, field_id, 0, 1);
+ auto& field_info = node->getFieldInfo(field_id);
+ field_info.setHitCount(mock_num_occs);
+ field_info.setFieldLength(mock_field_length);
+ RankProcessor::unpack_match_data(2, *md, *_query_wrapper);
+ EXPECT_EQ(2, tfmd->getDocId());
+ if (interleaved_features) {
+ EXPECT_EQ(mock_num_occs, tfmd->getNumOccs());
+ EXPECT_EQ(mock_field_length, tfmd->getFieldLength());
+ } else {
+ EXPECT_EQ(0, tfmd->getNumOccs());
+ EXPECT_EQ(0, tfmd->getFieldLength());
+ }
+ EXPECT_EQ(1, tfmd->size());
+ node->reset();
+ RankProcessor::unpack_match_data(3, *md, *_query_wrapper);
+ EXPECT_EQ(2, tfmd->getDocId());
+}
+
+
+TEST_F(RankProcessorTest, unpack_normal_match_data_for_term_node)
+{
+ test_unpack_match_data_for_term_node(false);
+}
+
+TEST_F(RankProcessorTest, unpack_interleaved_match_data_for_term_node)
+{
+ test_unpack_match_data_for_term_node(true);
+}
+
class MockRawScoreCalculator : public search::streaming::NearestNeighborQueryNode::RawScoreCalculator {
public:
double to_raw_score(double distance) override { return distance * 2; }
diff --git a/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h b/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h
index c5dc442e424..8084d776efe 100644
--- a/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h
+++ b/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h
@@ -55,7 +55,7 @@ public:
// inherit documentation
virtual const search::attribute::IAttributeContext & getAttributeContext() const override { return *_attrCtx; }
- double get_average_field_length(const vespalib::string &) const override { return 1.0; }
+ double get_average_field_length(const vespalib::string &) const override { return 100.0; }
// inherit documentation
virtual const search::fef::IIndexEnvironment & getIndexEnvironment() const override { return _indexEnv; }
diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
index 78d72102fe9..17056d9d4b7 100644
--- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
+++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
@@ -6,6 +6,7 @@
#include <vespa/searchlib/fef/simpletermfielddata.h>
#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h>
#include <vespa/vsm/vsm/fieldsearchspec.h>
+#include <algorithm>
#include <cmath>
#include <vespa/log/log.h>
LOG_SETUP(".searchvisitor.rankprocessor");
@@ -50,6 +51,11 @@ getFeature(const RankProgram &rankProgram) {
return resolver.resolve(0);
}
+uint16_t
+cap_16_bits(uint32_t value) {
+ return std::min(value, static_cast<uint32_t>(std::numeric_limits<uint16_t>::max()));
+}
+
}
void
@@ -284,6 +290,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap
uint32_t lastFieldId = -1;
TermFieldMatchData *tmd = nullptr;
uint32_t fieldLen = search::fef::FieldPositionsIterator::UNKNOWN_LENGTH;
+ uint32_t num_occs = 0;
// optimize for hitlist giving all hits for a single field in one chunk
for (const Hit & hit : hitList) {
@@ -292,6 +299,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap
// reset to notfound/unknown values
tmd = nullptr;
fieldLen = search::fef::FieldPositionsIterator::UNKNOWN_LENGTH;
+ num_occs = 0;
// setup for new field that had a hit
const ITermFieldData *tfd = td.lookupField(fieldId);
@@ -306,11 +314,15 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap
// find fieldLen for new field
if (isPhrase) {
if (fieldId < term.getParent()->getFieldInfoSize()) {
- fieldLen = term.getParent()->getFieldInfo(fieldId).getFieldLength();
+ auto& field_info = term.getParent()->getFieldInfo(fieldId);
+ fieldLen = field_info.getFieldLength();
+ num_occs = field_info.getHitCount();
}
} else {
if (fieldId < term.getTerm()->getFieldInfoSize()) {
- fieldLen = term.getTerm()->getFieldInfo(fieldId).getFieldLength();
+ auto& field_info = term.getTerm()->getFieldInfo(fieldId);
+ fieldLen = field_info.getFieldLength();
+ num_occs = field_info.getHitCount();
}
}
lastFieldId = fieldId;
@@ -322,6 +334,10 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap
tmd->appendPosition(pos);
LOG(debug, "Append elemId(%u),position(%u), weight(%d), tfmd.weight(%d)",
pos.getElementId(), pos.getPosition(), pos.getElementWeight(), tmd->getWeight());
+ if (tmd->needs_interleaved_features()) {
+ tmd->setFieldLength(cap_16_bits(fieldLen));
+ tmd->setNumOccs(cap_16_bits(num_occs));
+ }
}
}
}