summary refs log tree commit diff stats
path: root/streamingvisitors
diff options
context:
space:
mode:
author Tor Egge <Tor.Egge@online.no> 2023-04-27 15:36:01 +0200
committer Tor Egge <Tor.Egge@online.no> 2023-04-27 15:36:01 +0200
commit 00dc72987c75579cb85bca577a0d3648191b8203 (patch)
tree fe0ad41893816baffdc1df07600c98e48f7c44e2 /streamingvisitors
parent f5dd3cb5d31875cf596adc01f2207f690afe553f (diff)
Populate match features in search result for streaming search.
Diffstat (limited to 'streamingvisitors')
-rw-r--r-- streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp 60
-rw-r--r-- streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp 33
-rw-r--r-- streamingvisitors/src/vespa/searchvisitor/hitcollector.h 4
-rw-r--r-- streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp 21
-rw-r--r-- streamingvisitors/src/vespa/searchvisitor/rankprocessor.h 5
5 files changed, 120 insertions, 3 deletions
diff --git a/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp b/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp
index 059d3c9f597..362aaeda938 100644
--- a/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp
+++ b/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp
@@ -11,6 +11,7 @@
#include <vespa/eval/eval/value_codec.h>
#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/vespalib/objects/nbostream.h>
+#include <vespa/vespalib/util/stringfmt.h>
using namespace document;
using namespace search::fef;
@@ -22,7 +23,39 @@ using vespalib::eval::DoubleValue;
using vespalib::eval::SimpleValue;
using vespalib::eval::TensorSpec;
using vespalib::eval::Value;
+using vespalib::make_string_short::fmt;
+using FeatureValue = FeatureSet::Value;
+
+namespace {
+
+double as_double(const FeatureValue& v) { // Test helper: returns the double held by v.
+ EXPECT_TRUE(v.is_double()); // non-fatal gtest check; the value is still read below on failure
+ return v.as_double();
+}
+
+TensorSpec as_spec(const FeatureValue& v) { // Test helper: decodes the data payload of v into a TensorSpec.
+ EXPECT_TRUE(v.is_data()); // non-fatal gtest check that v carries a data payload
+ auto mem = v.as_data();
+ nbostream buf(mem.data, mem.size); // stream built from the payload pointer and its length
+ return spec_from_value(*SimpleValue::from_stream(buf)); // read a value from the stream, then convert it to a spec
+}
+
+ConstArrayRef<FeatureValue> as_value_slice(FeatureValues& mf, uint32_t index, uint32_t num_features) // Test helper: view of the index'th run of num_features entries in mf.values.
+{
+ return { mf.values.data() + index * num_features, num_features }; // starts at element index * num_features; no bounds check is done here
+}
+
+void check_match_features(ConstArrayRef<FeatureValue> v, uint32_t docid) // Test helper: checks the three feature values expected for docid.
+{
+ SCOPED_TRACE(fmt("Checking docid %u for expected match features", docid)); // attaches the docid to any failure reported in this scope
+ // The following values should have been set by MyRankProgram::run()
+ EXPECT_EQ(10 + docid, as_double(v[0])); // v[0]: double equal to 10 + docid
+ EXPECT_EQ(30 + docid, as_double(v[1])); // v[1]: double equal to 30 + docid
+ EXPECT_EQ(TensorSpec("tensor(x{})").add({{"x", "a"}}, 20 + docid), as_spec(v[2])); // v[2]: tensor with the single cell x=a holding 20 + docid
+}
+
+}
namespace streaming {
@@ -326,6 +359,33 @@ TEST_F(HitCollectorTest, feature_set)
assertHit(30, 4, 2, sr);
}
+TEST_F(HitCollectorTest, match_features) // Checks names and per-hit values returned by HitCollector::get_match_features().
+{
+ HitCollector hc(3); // NOTE(review): 3 is presumably the wanted hit count; 5 hits are added and 3 are checked below
+
+ addHit(hc, 0, 10);
+ addHit(hc, 1, 50); // on heap
+ addHit(hc, 2, 20);
+ addHit(hc, 3, 40); // on heap
+ addHit(hc, 4, 30); // on heap
+
+ MyRankProgram rankProgram;
+ FeatureResolver resolver(rankProgram.get_resolver());
+ search::StringStringMap renames;
+ renames["bar"] = "qux"; // feature "bar" is expected to be reported as "qux"
+ auto mf = hc.get_match_features(rankProgram, resolver, renames);
+ auto num_features = resolver.num_features();
+
+ EXPECT_EQ(num_features, mf.names.size()); // one name per feature
+ EXPECT_EQ("foo", mf.names[0]);
+ EXPECT_EQ("qux", mf.names[1]); // renamed from "bar"
+ EXPECT_EQ("baz", mf.names[2]);
+ EXPECT_EQ(num_features * 3, mf.values.size()); // values for exactly 3 hits
+ check_match_features(as_value_slice(mf, 0, num_features), 1); // slice 0 holds docid 1
+ check_match_features(as_value_slice(mf, 1, num_features), 3); // slice 1 holds docid 3
+ check_match_features(as_value_slice(mf, 2, num_features), 4); // slice 2 holds docid 4
+}
+
} // namespace streaming
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp b/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp
index 362d5a26611..babc6fbd697 100644
--- a/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp
+++ b/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp
@@ -13,6 +13,7 @@ LOG_SETUP(".searchvisitor.hitcollector");
using search::fef::MatchData;
using vespalib::FeatureSet;
+using vespalib::FeatureValues;
using vdslib::SearchResult;
using FefUtils = search::fef::Utils;
@@ -126,7 +127,7 @@ HitCollector::addHit(Hit && hit)
}
void
-HitCollector::fillSearchResult(vdslib::SearchResult & searchResult)
+HitCollector::fillSearchResult(vdslib::SearchResult & searchResult, FeatureValues&& match_features)
{
sortByDocId();
for (const Hit & hit : _hits) {
@@ -142,6 +143,13 @@ HitCollector::fillSearchResult(vdslib::SearchResult & searchResult)
searchResult.addHit(docId, documentId.c_str(), rank, hit.getSortBlob().c_str(), hit.getSortBlob().size());
}
}
+ searchResult.set_match_features(std::move(match_features));
+}
+
+void
+HitCollector::fillSearchResult(vdslib::SearchResult & searchResult) // Overload for callers that have no match features to pass.
+{
+ fillSearchResult(searchResult, FeatureValues()); // delegates with a default-constructed FeatureValues
}
FeatureSet::SP
@@ -164,5 +172,28 @@ HitCollector::getFeatureSet(IRankProgram &rankProgram,
return retval;
}
+FeatureValues
+HitCollector::get_match_features(IRankProgram& rank_program, // Builds feature names plus one run of feature values per collected hit.
+ const search::fef::FeatureResolver& resolver,
+ const search::StringStringMap& feature_rename_map)
+{
+ FeatureValues match_features;
+ if (resolver.num_features() == 0 || _hits.empty()) { // no features or no hits: return with names and values left empty
+ return match_features;
+ }
+ sortByDocId(); // values below are written in the order _hits has after this call
+ match_features.names = FefUtils::extract_feature_names(resolver, feature_rename_map);
+ match_features.values.resize(resolver.num_features() * _hits.size()); // flat array: num_features consecutive entries per hit
+ auto f = match_features.values.data(); // write cursor into the flat array
+ for (const Hit & hit : _hits) {
+ auto docid = hit.getDocId();
+ rank_program.run(docid, hit.getMatchData()); // run the program for this hit before its features are read
+ FefUtils::extract_feature_values(resolver, docid, f);
+ f += resolver.num_features(); // advance the cursor to the next hit's run
+ }
+ assert(f == match_features.values.data() + match_features.values.size()); // the cursor must end exactly at the end of the array
+ return match_features;
+}
+
} // namespace streaming
diff --git a/streamingvisitors/src/vespa/searchvisitor/hitcollector.h b/streamingvisitors/src/vespa/searchvisitor/hitcollector.h
index 2918f815811..07418b85c75 100644
--- a/streamingvisitors/src/vespa/searchvisitor/hitcollector.h
+++ b/streamingvisitors/src/vespa/searchvisitor/hitcollector.h
@@ -121,6 +121,7 @@ public:
* Fills the given search result with the m best hits from the hit heap.
* Invoking this method will destroy the heap property of the hit heap.
**/
+ void fillSearchResult(vdslib::SearchResult & searchResult, vespalib::FeatureValues&& match_features);
void fillSearchResult(vdslib::SearchResult & searchResult);
/**
@@ -136,6 +137,9 @@ public:
const search::fef::FeatureResolver &resolver,
const search::StringStringMap &feature_rename_map);
+ vespalib::FeatureValues get_match_features(IRankProgram& rank_program, // Returns match feature names and per-hit values for the collected hits.
+ const search::fef::FeatureResolver& resolver,
+ const search::StringStringMap& feature_rename_map);
};
} // namespace streaming
diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
index 3ce137bffe5..55638c3ec44 100644
--- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
+++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
@@ -11,6 +11,7 @@
LOG_SETUP(".searchvisitor.rankprocessor");
using vespalib::FeatureSet;
+using vespalib::FeatureValues;
using search::fef::FeatureHandle;
using search::fef::ITermData;
using search::fef::ITermFieldData;
@@ -131,6 +132,10 @@ RankProcessor::init(bool forRanking, size_t wantedHitCount)
_rankScore = getFeature(*_rankProgram);
_summaryProgram = _rankSetup.create_summary_program();
setupRankProgram(*_summaryProgram);
+ if (_rankSetup.has_match_features()) {
+ _match_features_program = _rankSetup.create_match_program();
+ setupRankProgram(*_match_features_program);
+ }
} else {
_rankProgram = _rankSetup.create_dump_program();
setupRankProgram(*_rankProgram);
@@ -157,7 +162,8 @@ RankProcessor::RankProcessor(RankManager::Snapshot::SP snapshot,
_summaryProgram(),
_zeroScore(),
_rankScore(&_zeroScore),
- _hitCollector()
+ _hitCollector(),
+ _match_features_program()
{
}
@@ -222,10 +228,21 @@ RankProcessor::calculateFeatureSet()
return sf;
}
+FeatureValues
+RankProcessor::calculate_match_features() // Computes match features for the collected hits via the match features program.
+{
+ if (!_match_features_program) { // program pointer is unset: return a default-constructed result
+ return FeatureValues();
+ }
+ RankProgramWrapper wrapper(*_match_data);
+ search::fef::FeatureResolver resolver(_match_features_program->get_seeds(false)); // resolver over the seeds of the match features program
+ return _hitCollector->get_match_features(wrapper, resolver, _rankSetup.get_feature_rename_map());
+}
+
void
RankProcessor::fillSearchResult(vdslib::SearchResult & searchResult)
{
- _hitCollector->fillSearchResult(searchResult);
+ _hitCollector->fillSearchResult(searchResult, calculate_match_features()); // passes the computed match features along with the hits
}
void
diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h
index c74a2d1e3ee..5307b66e1d5 100644
--- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h
+++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h
@@ -23,6 +23,9 @@ namespace streaming {
class RankProcessor
{
private:
+ using RankProgram = search::fef::RankProgram;
+ using FeatureSet = vespalib::FeatureSet;
+ using FeatureValues = vespalib::FeatureValues;
RankManager::Snapshot::SP _rankManagerSnapshot;
const search::fef::RankSetup & _rankSetup;
QueryWrapper _query;
@@ -37,10 +40,12 @@ private:
search::fef::NumberOrObject _zeroScore;
search::fef::LazyValue _rankScore;
HitCollector::UP _hitCollector;
+ std::unique_ptr<RankProgram> _match_features_program;
void initQueryEnvironment();
void initHitCollector(size_t wantedHitCount);
void setupRankProgram(search::fef::RankProgram &program);
+ FeatureValues calculate_match_features();
/**
* Initializes this rank processor.