diff options
Diffstat (limited to 'streamingvisitors/src')
5 files changed, 120 insertions, 3 deletions
diff --git a/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp b/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp index 059d3c9f597..362aaeda938 100644 --- a/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp +++ b/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp @@ -11,6 +11,7 @@ #include <vespa/eval/eval/value_codec.h> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/objects/nbostream.h> +#include <vespa/vespalib/util/stringfmt.h> using namespace document; using namespace search::fef; @@ -22,7 +23,39 @@ using vespalib::eval::DoubleValue; using vespalib::eval::SimpleValue; using vespalib::eval::TensorSpec; using vespalib::eval::Value; +using vespalib::make_string_short::fmt; +using FeatureValue = FeatureSet::Value; + +namespace { + +double as_double(const FeatureValue& v) { + EXPECT_TRUE(v.is_double()); + return v.as_double(); +} + +TensorSpec as_spec(const FeatureValue& v) { + EXPECT_TRUE(v.is_data()); + auto mem = v.as_data(); + nbostream buf(mem.data, mem.size); + return spec_from_value(*SimpleValue::from_stream(buf)); +} + +ConstArrayRef<FeatureValue> as_value_slice(FeatureValues& mf, uint32_t index, uint32_t num_features) +{ + return { mf.values.data() + index * num_features, num_features }; +} + +void check_match_features(ConstArrayRef<FeatureValue> v, uint32_t docid) +{ + SCOPED_TRACE(fmt("Checking docid %u for expected match features", docid)); + // The following values should have been set by MyRankProgram::run() + EXPECT_EQ(10 + docid, as_double(v[0])); + EXPECT_EQ(30 + docid, as_double(v[1])); + EXPECT_EQ(TensorSpec("tensor(x{})").add({{"x", "a"}}, 20 + docid), as_spec(v[2])); +} + +} namespace streaming { @@ -326,6 +359,33 @@ TEST_F(HitCollectorTest, feature_set) assertHit(30, 4, 2, sr); } +TEST_F(HitCollectorTest, match_features) +{ + HitCollector hc(3); + + addHit(hc, 0, 10); + addHit(hc, 1, 50); // on heap + addHit(hc, 2, 20); + addHit(hc, 3, 40); // on heap + addHit(hc, 4, 30); // on heap + + MyRankProgram rankProgram; + FeatureResolver resolver(rankProgram.get_resolver()); + search::StringStringMap renames; + renames["bar"] = "qux"; + auto mf = hc.get_match_features(rankProgram, resolver, renames); + auto num_features = resolver.num_features(); + + EXPECT_EQ(num_features, mf.names.size()); + EXPECT_EQ("foo", mf.names[0]); + EXPECT_EQ("qux", mf.names[1]); + EXPECT_EQ("baz", mf.names[2]); + EXPECT_EQ(num_features * 3, mf.values.size()); + check_match_features(as_value_slice(mf, 0, num_features), 1); + check_match_features(as_value_slice(mf, 1, num_features), 3); + check_match_features(as_value_slice(mf, 2, num_features), 4); +} + } // namespace streaming GTEST_MAIN_RUN_ALL_TESTS() diff --git a/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp b/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp index 362d5a26611..babc6fbd697 100644 --- a/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp @@ -13,6 +13,7 @@ LOG_SETUP(".searchvisitor.hitcollector"); using search::fef::MatchData; using vespalib::FeatureSet; +using vespalib::FeatureValues; using vdslib::SearchResult; using FefUtils = search::fef::Utils; @@ -126,7 +127,7 @@ HitCollector::addHit(Hit && hit) } void -HitCollector::fillSearchResult(vdslib::SearchResult & searchResult) +HitCollector::fillSearchResult(vdslib::SearchResult & searchResult, FeatureValues&& match_features) { sortByDocId(); for (const Hit & hit : _hits) { @@ -142,6 +143,13 @@ HitCollector::fillSearchResult(vdslib::SearchResult & searchResult) searchResult.addHit(docId, documentId.c_str(), rank, hit.getSortBlob().c_str(), hit.getSortBlob().size()); } } + searchResult.set_match_features(std::move(match_features)); +} + +void +HitCollector::fillSearchResult(vdslib::SearchResult & searchResult) +{ + fillSearchResult(searchResult, FeatureValues()); } FeatureSet::SP @@ -164,5 +172,28 @@ HitCollector::getFeatureSet(IRankProgram &rankProgram, return retval; } +FeatureValues +HitCollector::get_match_features(IRankProgram& rank_program, + const search::fef::FeatureResolver& resolver, + const search::StringStringMap& feature_rename_map) +{ + FeatureValues match_features; + if (resolver.num_features() == 0 || _hits.empty()) { + return match_features; + } + sortByDocId(); + match_features.names = FefUtils::extract_feature_names(resolver, feature_rename_map); + match_features.values.resize(resolver.num_features() * _hits.size()); + auto f = match_features.values.data(); + for (const Hit & hit : _hits) { + auto docid = hit.getDocId(); + rank_program.run(docid, hit.getMatchData()); + FefUtils::extract_feature_values(resolver, docid, f); + f += resolver.num_features(); + } + assert(f == match_features.values.data() + match_features.values.size()); + return match_features; +} + } // namespace streaming diff --git a/streamingvisitors/src/vespa/searchvisitor/hitcollector.h b/streamingvisitors/src/vespa/searchvisitor/hitcollector.h index 2918f815811..07418b85c75 100644 --- a/streamingvisitors/src/vespa/searchvisitor/hitcollector.h +++ b/streamingvisitors/src/vespa/searchvisitor/hitcollector.h @@ -121,6 +121,7 @@ public: * Fills the given search result with the m best hits from the hit heap. * Invoking this method will destroy the heap property of the hit heap. **/ + void fillSearchResult(vdslib::SearchResult & searchResult, vespalib::FeatureValues&& match_features); void fillSearchResult(vdslib::SearchResult & searchResult); /** @@ -136,6 +137,9 @@ public: const search::fef::FeatureResolver &resolver, const search::StringStringMap &feature_rename_map); + vespalib::FeatureValues get_match_features(IRankProgram& rank_program, + const search::fef::FeatureResolver& resolver, + const search::StringStringMap& feature_rename_map); }; } // namespace streaming diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 3ce137bffe5..55638c3ec44 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -11,6 +11,7 @@ LOG_SETUP(".searchvisitor.rankprocessor"); using vespalib::FeatureSet; +using vespalib::FeatureValues; using search::fef::FeatureHandle; using search::fef::ITermData; using search::fef::ITermFieldData; @@ -131,6 +132,10 @@ RankProcessor::init(bool forRanking, size_t wantedHitCount) _rankScore = getFeature(*_rankProgram); _summaryProgram = _rankSetup.create_summary_program(); setupRankProgram(*_summaryProgram); + if (_rankSetup.has_match_features()) { + _match_features_program = _rankSetup.create_match_program(); + setupRankProgram(*_match_features_program); + } } else { _rankProgram = _rankSetup.create_dump_program(); setupRankProgram(*_rankProgram); @@ -157,7 +162,8 @@ RankProcessor::RankProcessor(RankManager::Snapshot::SP snapshot, _summaryProgram(), _zeroScore(), _rankScore(&_zeroScore), - _hitCollector() + _hitCollector(), + _match_features_program() { } @@ -222,10 +228,21 @@ RankProcessor::calculateFeatureSet() return sf; } +FeatureValues +RankProcessor::calculate_match_features() +{ + if (!_match_features_program) { + return FeatureValues(); + } + RankProgramWrapper wrapper(*_match_data); + search::fef::FeatureResolver resolver(_match_features_program->get_seeds(false)); + return _hitCollector->get_match_features(wrapper, resolver, _rankSetup.get_feature_rename_map()); +} + void RankProcessor::fillSearchResult(vdslib::SearchResult & searchResult) { - _hitCollector->fillSearchResult(searchResult); + _hitCollector->fillSearchResult(searchResult, calculate_match_features()); } void diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h index c74a2d1e3ee..5307b66e1d5 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.h @@ -23,6 +23,9 @@ namespace streaming { class RankProcessor { private: + using RankProgram = search::fef::RankProgram; + using FeatureSet = vespalib::FeatureSet; + using FeatureValues = vespalib::FeatureValues; RankManager::Snapshot::SP _rankManagerSnapshot; const search::fef::RankSetup & _rankSetup; QueryWrapper _query; @@ -37,10 +40,12 @@ private: search::fef::NumberOrObject _zeroScore; search::fef::LazyValue _rankScore; HitCollector::UP _hitCollector; + std::unique_ptr<RankProgram> _match_features_program; void initQueryEnvironment(); void initHitCollector(size_t wantedHitCount); void setupRankProgram(search::fef::RankProgram &program); + FeatureValues calculate_match_features(); /** * Initializes this rank processor. |