summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/CMakeLists.txt1
-rw-r--r--searchlib/src/apps/vespa-query-analyzer/.gitignore3
-rw-r--r--searchlib/src/apps/vespa-query-analyzer/CMakeLists.txt9
-rw-r--r--searchlib/src/apps/vespa-query-analyzer/vespa-query-analyzer.cpp361
-rw-r--r--searchlib/src/tests/hitcollector/CMakeLists.txt2
-rw-r--r--searchlib/src/tests/hitcollector/hitcollector_test.cpp269
-rw-r--r--searchlib/src/tests/hitcollector/sorted_hit_sequence_test.cpp14
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp197
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/hitcollector.h11
9 files changed, 753 insertions, 114 deletions
diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt
index 570bffa59c2..43e17417c51 100644
--- a/searchlib/CMakeLists.txt
+++ b/searchlib/CMakeLists.txt
@@ -60,6 +60,7 @@ vespa_define_module(
src/apps/vespa-attribute-inspect
src/apps/vespa-fileheader-inspect
src/apps/vespa-index-inspect
+ src/apps/vespa-query-analyzer
src/apps/vespa-ranking-expression-analyzer
TESTS
diff --git a/searchlib/src/apps/vespa-query-analyzer/.gitignore b/searchlib/src/apps/vespa-query-analyzer/.gitignore
new file mode 100644
index 00000000000..e5a31caab09
--- /dev/null
+++ b/searchlib/src/apps/vespa-query-analyzer/.gitignore
@@ -0,0 +1,3 @@
+/.depend
+/Makefile
+/vespa-query-analyzer
diff --git a/searchlib/src/apps/vespa-query-analyzer/CMakeLists.txt b/searchlib/src/apps/vespa-query-analyzer/CMakeLists.txt
new file mode 100644
index 00000000000..f84a413ee70
--- /dev/null
+++ b/searchlib/src/apps/vespa-query-analyzer/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(searchlib_vespa-query-analyzer_app
+ SOURCES
+ vespa-query-analyzer.cpp
+ OUTPUT_NAME vespa-query-analyzer
+ INSTALL bin
+ DEPENDS
+ searchlib
+)
diff --git a/searchlib/src/apps/vespa-query-analyzer/vespa-query-analyzer.cpp b/searchlib/src/apps/vespa-query-analyzer/vespa-query-analyzer.cpp
new file mode 100644
index 00000000000..178c09c02ac
--- /dev/null
+++ b/searchlib/src/apps/vespa-query-analyzer/vespa-query-analyzer.cpp
@@ -0,0 +1,361 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/data/simple_buffer.h>
+#include <vespa/vespalib/data/slime/json_format.h>
+#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/io/mapped_file_input.h>
+#include <vespa/vespalib/util/overload.h>
+#include <vespa/vespalib/util/signalhandler.h>
+#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/searchlib/queryeval/flow.h>
+#include <variant>
+#include <vector>
+#include <map>
+
+using namespace vespalib::slime::convenience;
+using vespalib::make_string_short::fmt;
+using vespalib::slime::JsonFormat;
+using vespalib::slime::ARRAY;
+using vespalib::slime::OBJECT;
+using vespalib::slime::STRING;
+using vespalib::slime::DOUBLE;
+using vespalib::slime::BOOL;
+using search::queryeval::FlowStats;
+using search::queryeval::InFlow;
+
+//-----------------------------------------------------------------------------
+
+using Path = std::vector<std::variant<size_t,vespalib::stringref>>;
+using Paths = std::vector<Path>;
+
+template <typename F>
+struct Matcher : vespalib::slime::ObjectTraverser {
+ Path path;
+ Paths result;
+ F match;
+ ~Matcher();
+ Matcher(F match_in) noexcept : path(), result(), match(match_in) {}
+ void search(const Inspector &node) {
+ if (path.empty() && match(path, node)) {
+ result.push_back(path);
+ }
+ if (node.type() == OBJECT()) {
+ node.traverse(*this);
+ }
+ if (node.type() == ARRAY()) {
+ size_t size = node.entries();
+ for (size_t i = 0; i < size; ++i) {
+ path.emplace_back(i);
+ if (match(path, node[i])) {
+ result.push_back(path);
+ }
+ search(node[i]);
+ path.pop_back();
+ }
+ }
+ }
+ void field(const Memory &symbol, const Inspector &inspector) final {
+ path.emplace_back(symbol.make_stringref());
+ if (match(path, inspector)) {
+ result.push_back(path);
+ }
+ search(inspector);
+ path.pop_back();
+ }
+};
+template <typename F> Matcher<F>::~Matcher() = default;
+
+std::vector<Path> find_field(const Inspector &root, const vespalib::string &name) {
+ auto matcher = Matcher([&](const Path &path, const Inspector &){
+ return ((path.size() > 0) &&
+ (std::holds_alternative<vespalib::stringref>(path.back())) &&
+ (std::get<vespalib::stringref>(path.back()) == name));
+ });
+ matcher.search(root);
+ return matcher.result;
+}
+
+std::vector<Path> find_tag(const Inspector &root, const vespalib::string &name) {
+ auto matcher = Matcher([&](const Path &path, const Inspector &value){
+ return ((path.size() > 0) &&
+ (std::holds_alternative<vespalib::stringref>(path.back())) &&
+ (std::get<vespalib::stringref>(path.back()) == "tag") &&
+ (value.asString().make_stringref() == name));
+ });
+ matcher.search(root);
+ return matcher.result;
+}
+
+vespalib::string path_to_str(const Path &path) {
+ size_t cnt = 0;
+ vespalib::string str("[");
+ for (const auto &item: path) {
+ if (cnt++ > 0) {
+ str.append(",");
+ }
+ std::visit(vespalib::overload{
+ [&str](size_t value)noexcept{ str.append(fmt("%zu", value)); },
+ [&str](vespalib::stringref value)noexcept{ str.append(value); }}, item);
+ }
+ str.append("]");
+ return str;
+}
+
+vespalib::string strip_name(vespalib::stringref name) {
+ auto end = name.find("<");
+ auto ns = name.rfind("::", end);
+ size_t begin = (ns > name.size()) ? 0 : ns + 2;
+ return name.substr(begin, end - begin);
+}
+
+const Inspector &apply_path(const Inspector &node, const Path &path, size_t max = -1) {
+ size_t cnt = 0;
+ const Inspector *ptr = &node;
+ for (const auto &elem: path) {
+ if (cnt++ >= max) {
+ return *ptr;
+ }
+ if (std::holds_alternative<size_t>(elem)) {
+ ptr = &((*ptr)[std::get<size_t>(elem)]);
+ }
+ if (std::holds_alternative<vespalib::stringref>(elem)) {
+ auto ref = std::get<vespalib::stringref>(elem);
+ ptr = &((*ptr)[Memory(ref.data(), ref.size())]);
+ }
+ }
+ return *ptr;
+}
+
+void extract(vespalib::string &value, const Inspector &data) {
+ if (data.valid() && data.type() == STRING()) {
+ value = data.asString().make_stringref();
+ }
+}
+
+struct Sample {
+ enum class Type { INVALID, INIT, SEEK, UNPACK, TERMWISE };
+ Type type = Type::INVALID;
+ std::vector<size_t> path;
+ double self_time_ms = 0.0;
+ double total_time_ms = 0.0;
+ size_t count = 0;
+ Sample(const Inspector &sample) {
+ auto name = sample["name"].asString().make_stringref();
+ if (ends_with(name, "/init")) {
+ type = Type::INIT;
+ }
+ if (ends_with(name, "/seek")) {
+ type = Type::SEEK;
+ }
+ if (ends_with(name, "/unpack")) {
+ type = Type::UNPACK;
+ }
+ if (ends_with(name, "/termwise")) {
+ type = Type::TERMWISE;
+ }
+ if (starts_with(name, "/")) {
+ size_t child = 0;
+ for (size_t pos = 1; pos < name.size(); ++pos) {
+ char c = name[pos];
+ if (c == '/') {
+ path.push_back(child);
+ child = 0;
+ } else {
+ if (c < '0' || c > '9') {
+ break;
+ }
+ child = child * 10 + (c - '0');
+ }
+ }
+ }
+ self_time_ms = sample["self_time_ms"].asDouble();
+ total_time_ms = sample["total_time_ms"].asDouble();
+ count = sample["count"].asLong();
+ }
+ static vespalib::string type_to_str(Type type) {
+ switch(type) {
+ case Type::INVALID: return "<invalid>";
+ case Type::INIT: return "init";
+ case Type::SEEK: return "seek";
+ case Type::UNPACK: return "unpack";
+ case Type::TERMWISE: return "termwise";
+ }
+ abort();
+ }
+ static vespalib::string path_to_str(const std::vector<size_t> &path) {
+ vespalib::string result("/");
+ for (size_t elem: path) {
+ result += fmt("%zu/", elem);
+ }
+ return result;
+ }
+ vespalib::string to_string() const {
+ return fmt("type: %s, path: %s, count: %zu, total_time_ms: %g\n",
+ type_to_str(type).c_str(), path_to_str(path).c_str(), count, total_time_ms);
+ }
+};
+
+struct Node {
+ vespalib::string type = "unknown";
+ bool strict = false;
+ FlowStats flow_stats = FlowStats(0.0, 0.0, 0.0);
+ InFlow in_flow = InFlow(0.0);
+ size_t count = 0;
+ double self_time_ms = 0.0;
+ double total_time_ms = 0.0;
+ std::vector<Node> children;
+ Node(const Inspector &obj) {
+ extract(type, obj["[type]"]);
+ type = strip_name(type);
+ strict = obj["strict"].asBool();
+ flow_stats.estimate = obj["relative_estimate"].asDouble();
+ flow_stats.cost = obj["cost"].asDouble();
+ flow_stats.strict_cost = obj["strict_cost"].asDouble();
+ const Inspector &list = obj["children"];
+ for (size_t i = 0; true; ++i) {
+ const Inspector &child = list[fmt("[%zu]", i)];
+ if (child.valid()) {
+ children.emplace_back(child);
+ } else {
+ break;
+ }
+ }
+ }
+ ~Node();
+ void add_sample(const Sample &sample) {
+ Node *node = this;
+ for (size_t child: sample.path) {
+ if (child < node->children.size()) {
+ node = &node->children[child];
+ } else {
+ fprintf(stderr, "... ignoring bad sample: %s\n", sample.to_string().c_str());
+ return;
+ }
+ }
+ node->count += sample.count;
+ node->self_time_ms += sample.self_time_ms;
+ node->total_time_ms += sample.total_time_ms;
+ }
+ void dump_line(size_t indent) const {
+ fprintf(stderr, "|%10zu ", count);
+ fprintf(stderr, "|%11.3f ", total_time_ms);
+ fprintf(stderr, "|%10.3f | ", self_time_ms);
+ for (size_t i = 0; i < indent; ++i) {
+ fprintf(stderr, " ");
+ }
+ fprintf(stderr, "%s\n", type.c_str());
+ for (const Node &child: children) {
+ child.dump_line(indent + 1);
+ }
+ }
+ void dump() const {
+ fprintf(stderr, "| count | total_time | self_time | structure\n");
+ fprintf(stderr, "+-----------+------------+-----------+-------------------------------\n");
+ dump_line(0);
+ fprintf(stderr, "+-----------+------------+-----------+-------------------------------\n");
+ }
+};
+Node::~Node() = default;
+
+void each_sample_list(const Inspector &list, auto f) {
+ for (size_t i = 0; i < list.entries(); ++i) {
+ f(Sample(list[i]));
+ each_sample_list(list[i]["children"], f);
+ }
+}
+
+void each_sample(const Inspector &prof, auto f) {
+ each_sample_list(prof["roots"], f);
+}
+
+struct State {
+ void analyze(const Inspector &root) {
+ auto bp_list = find_field(root, "optimized");
+ for (const Path &path: bp_list) {
+ const Inspector &node = apply_path(root, path, path.size()-3);
+ const Inspector &key_field = node["distribution-key"];
+ if (key_field.valid()) {
+ int key = key_field.asLong();
+ Node data(apply_path(root, path));
+ auto prof_list = find_tag(node, "match_profiling");
+ double total_ms = 0.0;
+ std::map<Sample::Type,double> time_map;
+ for (const Path &prof_path: prof_list) {
+ const Inspector &prof = apply_path(node, prof_path, prof_path.size()-1);
+ if (prof["profiler"].asString().make_stringref() == "tree") {
+ total_ms += prof["total_time_ms"].asDouble();
+ each_sample(prof, [&](const Sample &sample) {
+ if (sample.type == Sample::Type::SEEK) {
+ data.add_sample(sample);
+ }
+ if (sample.path.empty()) {
+ time_map[sample.type] += sample.total_time_ms;
+ }
+ });
+ }
+ }
+ data.dump();
+ fprintf(stderr, "distribution key: %d, total_time_ms: %g\n", key, total_ms);
+ for (auto [type, time]: time_map) {
+ fprintf(stderr, "sample type %s used %g ms total\n", Sample::type_to_str(type).c_str(), time);
+ }
+ }
+ }
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+void usage(const char *self) {
+ fprintf(stderr, "usage: %s <json query result file>\n", self);
+ fprintf(stderr, " analyze query cost (planning vs profiling)\n");
+ fprintf(stderr, " query result must contain optimized blueprint dump\n");
+ fprintf(stderr, " query result must contain match phase tree profiling\n\n");
+}
+
+struct MyApp {
+ vespalib::string file_name;
+ bool parse_params(int argc, char **argv);
+ int main();
+};
+
+bool
+MyApp::parse_params(int argc, char **argv) {
+ if (argc != 2) {
+ return false;
+ }
+ file_name = argv[1];
+ return true;
+}
+
+int
+MyApp::main()
+{
+ vespalib::MappedFileInput file(file_name);
+ if (!file.valid()) {
+ fprintf(stderr, "could not read input file: '%s'\n",
+ file_name.c_str());
+ return 1;
+ }
+ Slime slime;
+ if(JsonFormat::decode(file, slime) == 0) {
+ fprintf(stderr, "file contains invalid json: '%s'\n",
+ file_name.c_str());
+ return 1;
+ }
+ State state;
+ state.analyze(slime.get());
+ return 0;
+}
+
+int main(int argc, char **argv) {
+ MyApp my_app;
+ vespalib::SignalHandler::PIPE.ignore();
+ if (!my_app.parse_params(argc, argv)) {
+ usage(argv[0]);
+ return 1;
+ }
+ return my_app.main();
+}
+
+//-----------------------------------------------------------------------------
diff --git a/searchlib/src/tests/hitcollector/CMakeLists.txt b/searchlib/src/tests/hitcollector/CMakeLists.txt
index 5cedbcbd7e6..cc62dd82af4 100644
--- a/searchlib/src/tests/hitcollector/CMakeLists.txt
+++ b/searchlib/src/tests/hitcollector/CMakeLists.txt
@@ -4,6 +4,7 @@ vespa_add_executable(searchlib_hitcollector_test_app TEST
hitcollector_test.cpp
DEPENDS
searchlib
+ GTest::gtest
)
vespa_add_test(NAME searchlib_hitcollector_test_app COMMAND searchlib_hitcollector_test_app)
vespa_add_executable(searchlib_sorted_hit_sequence_test_app TEST
@@ -11,5 +12,6 @@ vespa_add_executable(searchlib_sorted_hit_sequence_test_app TEST
sorted_hit_sequence_test.cpp
DEPENDS
searchlib
+ GTest::gtest
)
vespa_add_test(NAME searchlib_sorted_hit_sequence_test_app COMMAND searchlib_sorted_hit_sequence_test_app)
diff --git a/searchlib/src/tests/hitcollector/hitcollector_test.cpp b/searchlib/src/tests/hitcollector/hitcollector_test.cpp
index e6e38181412..60daa571f1d 100644
--- a/searchlib/src/tests/hitcollector/hitcollector_test.cpp
+++ b/searchlib/src/tests/hitcollector/hitcollector_test.cpp
@@ -1,9 +1,9 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <vespa/vespalib/testkit/testapp.h>
#include <vespa/searchlib/common/bitvector.h>
#include <vespa/searchlib/fef/fef.h>
#include <vespa/searchlib/queryeval/hitcollector.h>
+#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/log/log.h>
LOG_SETUP("hitcollector_test");
@@ -13,6 +13,8 @@ using namespace search::fef;
using namespace search::queryeval;
using ScoreMap = std::map<uint32_t, feature_t>;
+using DocidVector = std::vector<uint32_t>;
+using RankedHitVector = std::vector<RankedHit>;
using Ranges = std::pair<Scores, Scores>;
@@ -67,11 +69,11 @@ void checkResult(const ResultSet & rs, const std::vector<RankedHit> & exp)
if ( ! exp.empty()) {
const RankedHit * rh = rs.getArray();
ASSERT_TRUE(rh != nullptr);
- ASSERT_EQUAL(rs.getArrayUsed(), exp.size());
+ ASSERT_EQ(rs.getArrayUsed(), exp.size());
for (uint32_t i = 0; i < exp.size(); ++i) {
- EXPECT_EQUAL(rh[i].getDocId(), exp[i].getDocId());
- EXPECT_EQUAL(rh[i].getRank() + 1.0, exp[i].getRank() + 1.0);
+ EXPECT_EQ(rh[i].getDocId(), exp[i].getDocId());
+ EXPECT_DOUBLE_EQ(rh[i].getRank() + 64.0, exp[i].getRank() + 64.0);
}
} else {
ASSERT_TRUE(rs.getArray() == nullptr);
@@ -93,21 +95,24 @@ void checkResult(ResultSet & rs, BitVector * exp)
}
}
-void testAddHit(uint32_t numDocs, uint32_t maxHitsSize)
+void testAddHit(uint32_t numDocs, uint32_t maxHitsSize, const vespalib::string& label)
{
+ SCOPED_TRACE(label);
LOG(info, "testAddHit: no hits");
- { // no hits
+ {
+ SCOPED_TRACE("no hits");
HitCollector hc(numDocs, maxHitsSize);
std::vector<RankedHit> expRh;
std::unique_ptr<ResultSet> rs = hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
- TEST_DO(checkResult(*rs, nullptr));
+ checkResult(*rs, expRh);
+ checkResult(*rs, nullptr);
}
LOG(info, "testAddHit: only ranked hits");
- { // only ranked hits
+ {
+ SCOPED_TRACE("only ranked hits");
HitCollector hc(numDocs, maxHitsSize);
std::vector<RankedHit> expRh;
@@ -121,12 +126,13 @@ void testAddHit(uint32_t numDocs, uint32_t maxHitsSize)
}
std::unique_ptr<ResultSet> rs = hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
- TEST_DO(checkResult(*rs, nullptr));
+ checkResult(*rs, expRh);
+ checkResult(*rs, nullptr);
}
LOG(info, "testAddHit: both ranked hits and bit vector hits");
- { // both ranked hits and bit vector hits
+ {
+ SCOPED_TRACE("both ranked hits and bitvector hits");
HitCollector hc(numDocs, maxHitsSize);
std::vector<RankedHit> expRh;
BitVector::UP expBv(BitVector::create(numDocs));
@@ -144,14 +150,15 @@ void testAddHit(uint32_t numDocs, uint32_t maxHitsSize)
}
std::unique_ptr<ResultSet> rs = hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
- TEST_DO(checkResult(*rs, expBv.get()));
+ checkResult(*rs, expRh);
+ checkResult(*rs, expBv.get());
}
}
-TEST("testAddHit") {
- TEST_DO(testAddHit(30, 10));
- TEST_DO(testAddHit(400, 10)); // 400/32 = 12 which is bigger than 10.
+TEST(HitCollectorTest, testAddHit)
+{
+ testAddHit(30, 10, "numDocs==30");
+ testAddHit(400, 10, "numDocs==400"); // 400/32 = 12 which is bigger than 10.
}
struct Fixture {
@@ -197,14 +204,17 @@ struct DescendingScoreFixture : Fixture {
DescendingScoreFixture::~DescendingScoreFixture() = default;
-TEST_F("testReRank - empty", Fixture) {
- EXPECT_EQUAL(0u, f.reRank());
+TEST(HitCollectorTest, rerank_empty)
+{
+ Fixture f;
+ EXPECT_EQ(0u, f.reRank());
}
-TEST_F("testReRank - ascending", AscendingScoreFixture)
+TEST(HitCollectorTest, rerank_ascending)
{
+ AscendingScoreFixture f;
f.addHits();
- EXPECT_EQUAL(5u, f.reRank());
+ EXPECT_EQ(5u, f.reRank());
std::vector<RankedHit> expRh;
for (uint32_t i = 10; i < 20; ++i) { // 10 last are the best
@@ -213,17 +223,18 @@ TEST_F("testReRank - ascending", AscendingScoreFixture)
expRh.back()._rankValue = i + 200; // after reranking
}
}
- EXPECT_EQUAL(expRh.size(), 10u);
+ EXPECT_EQ(expRh.size(), 10u);
std::unique_ptr<ResultSet> rs = f.hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
- TEST_DO(checkResult(*rs, f.expBv.get()));
+ checkResult(*rs, expRh);
+ checkResult(*rs, f.expBv.get());
}
-TEST_F("testReRank - descending", DescendingScoreFixture)
+TEST(HitCollectorTest, rerank_descending)
{
+ DescendingScoreFixture f;
f.addHits();
- EXPECT_EQUAL(5u, f.reRank());
+ EXPECT_EQ(5u, f.reRank());
std::vector<RankedHit> expRh;
for (uint32_t i = 0; i < 10; ++i) { // 10 first are the best
@@ -232,17 +243,18 @@ TEST_F("testReRank - descending", DescendingScoreFixture)
expRh.back()._rankValue = i + 200; // after reranking
}
}
- EXPECT_EQUAL(expRh.size(), 10u);
+ EXPECT_EQ(expRh.size(), 10u);
std::unique_ptr<ResultSet> rs = f.hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
- TEST_DO(checkResult(*rs, f.expBv.get()));
+ checkResult(*rs, expRh);
+ checkResult(*rs, f.expBv.get());
}
-TEST_F("testReRank - partial", AscendingScoreFixture)
+TEST(HitCollectorTest, rerank_partial)
{
+ AscendingScoreFixture f;
f.addHits();
- EXPECT_EQUAL(3u, f.reRank(3));
+ EXPECT_EQ(3u, f.reRank(3));
std::vector<RankedHit> expRh;
for (uint32_t i = 10; i < 20; ++i) { // 10 last are the best
@@ -251,36 +263,39 @@ TEST_F("testReRank - partial", AscendingScoreFixture)
expRh.back()._rankValue = i + 200; // after reranking
}
}
- EXPECT_EQUAL(expRh.size(), 10u);
+ EXPECT_EQ(expRh.size(), 10u);
std::unique_ptr<ResultSet> rs = f.hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
- TEST_DO(checkResult(*rs, f.expBv.get()));
+ checkResult(*rs, expRh);
+ checkResult(*rs, f.expBv.get());
}
-TEST_F("require that hits for 2nd phase candidates can be retrieved", DescendingScoreFixture)
+TEST(HitCollectorTest, require_that_hits_for_2nd_phase_candidates_can_be_retrieved)
{
+ DescendingScoreFixture f;
f.addHits();
std::vector<HitCollector::Hit> scores = extract(f.hc.getSortedHitSequence(5));
- ASSERT_EQUAL(5u, scores.size());
- EXPECT_EQUAL(100, scores[0].second);
- EXPECT_EQUAL(99, scores[1].second);
- EXPECT_EQUAL(98, scores[2].second);
- EXPECT_EQUAL(97, scores[3].second);
- EXPECT_EQUAL(96, scores[4].second);
+ ASSERT_EQ(5u, scores.size());
+ EXPECT_EQ(100, scores[0].second);
+ EXPECT_EQ(99, scores[1].second);
+ EXPECT_EQ(98, scores[2].second);
+ EXPECT_EQ(97, scores[3].second);
+ EXPECT_EQ(96, scores[4].second);
}
-TEST("require that score ranges can be read and set.") {
+TEST(HitCollectorTest, require_that_score_ranges_can_be_read_and_set)
+{
std::pair<Scores, Scores> ranges = std::make_pair(Scores(1.0, 2.0), Scores(3.0, 4.0));
HitCollector hc(20, 10);
hc.setRanges(ranges);
- EXPECT_EQUAL(ranges.first.low, hc.getRanges().first.low);
- EXPECT_EQUAL(ranges.first.high, hc.getRanges().first.high);
- EXPECT_EQUAL(ranges.second.low, hc.getRanges().second.low);
- EXPECT_EQUAL(ranges.second.high, hc.getRanges().second.high);
+ EXPECT_EQ(ranges.first.low, hc.getRanges().first.low);
+ EXPECT_EQ(ranges.first.high, hc.getRanges().first.high);
+ EXPECT_EQ(ranges.second.low, hc.getRanges().second.low);
+ EXPECT_EQ(ranges.second.high, hc.getRanges().second.high);
}
-TEST("testNoHitsToReRank") {
+TEST(HitCollectorTest, no_hits_to_rerank)
+{
uint32_t numDocs = 20;
uint32_t maxHitsSize = 10;
@@ -299,8 +314,8 @@ TEST("testNoHitsToReRank") {
}
std::unique_ptr<ResultSet> rs = hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
- TEST_DO(checkResult(*rs, nullptr));
+ checkResult(*rs, expRh);
+ checkResult(*rs, nullptr);
}
}
@@ -317,14 +332,15 @@ void testScaling(const std::vector<feature_t> &initScores,
PredefinedScorer scorer(std::move(finalScores));
// perform second phase ranking
- EXPECT_EQUAL(2u, do_reRank(scorer, hc, 2));
+ EXPECT_EQ(2u, do_reRank(scorer, hc, 2));
// check results
std::unique_ptr<ResultSet> rs = hc.getResultSet();
- TEST_DO(checkResult(*rs, expected));
+ checkResult(*rs, expected);
}
-TEST("testScaling") {
+TEST(HitCollectorTest, scaling)
+{
std::vector<feature_t> initScores(5);
initScores[0] = 1000;
initScores[1] = 2000;
@@ -338,7 +354,8 @@ TEST("testScaling") {
exp[i]._docId = i;
}
- { // scale down and adjust down
+ {
+ SCOPED_TRACE("scale down and adjust down");
exp[0]._rankValue = 0; // scaled
exp[1]._rankValue = 100; // scaled
exp[2]._rankValue = 200; // scaled
@@ -350,9 +367,10 @@ TEST("testScaling") {
finalScores[3] = 300;
finalScores[4] = 400;
- TEST_DO(testScaling(initScores, std::move(finalScores), exp));
+ testScaling(initScores, std::move(finalScores), exp);
}
- { // scale down and adjust up
+ {
+ SCOPED_TRACE("scale down and adjust up");
exp[0]._rankValue = 200; // scaled
exp[1]._rankValue = 300; // scaled
exp[2]._rankValue = 400; // scaled
@@ -364,10 +382,10 @@ TEST("testScaling") {
finalScores[3] = 500;
finalScores[4] = 600;
- TEST_DO(testScaling(initScores, std::move(finalScores), exp));
+ testScaling(initScores, std::move(finalScores), exp);
}
- { // scale up and adjust down
-
+ {
+ SCOPED_TRACE("scale up and adjust down");
exp[0]._rankValue = -500; // scaled (-500)
exp[1]._rankValue = 750; // scaled
exp[2]._rankValue = 2000; // scaled
@@ -379,9 +397,10 @@ TEST("testScaling") {
finalScores[3] = 3250;
finalScores[4] = 4500;
- TEST_DO(testScaling(initScores, std::move(finalScores), exp));
+ testScaling(initScores, std::move(finalScores), exp);
}
- { // minimal scale (second phase range = 0 (4 - 4) -> 1)
+ {
+ SCOPED_TRACE("minimal scale (second phase range = 0 (4 - 4) -> 1)");
exp[0]._rankValue = 1; // scaled
exp[1]._rankValue = 2; // scaled
exp[2]._rankValue = 3; // scaled
@@ -393,9 +412,10 @@ TEST("testScaling") {
finalScores[3] = 4;
finalScores[4] = 4;
- TEST_DO(testScaling(initScores, std::move(finalScores), exp));
+ testScaling(initScores, std::move(finalScores), exp);
}
- { // minimal scale (first phase range = 0 (4000 - 4000) -> 1)
+ {
+ SCOPED_TRACE("minimal scale (first phase range = 0 (4000 - 4000) -> 1)");
std::vector<feature_t> is(initScores);
is[4] = 4000;
exp[0]._rankValue = -299600; // scaled
@@ -409,11 +429,12 @@ TEST("testScaling") {
finalScores[3] = 400;
finalScores[4] = 500;
- TEST_DO(testScaling(is, std::move(finalScores), exp));
+ testScaling(is, std::move(finalScores), exp);
}
}
-TEST("testOnlyBitVector") {
+TEST(HitCollectorTest, only_bitvector)
+{
uint32_t numDocs = 20;
LOG(info, "testOnlyBitVector: test it");
{
@@ -428,8 +449,8 @@ TEST("testOnlyBitVector") {
std::unique_ptr<ResultSet> rs = hc.getResultSet();
std::vector<RankedHit> expRh;
- TEST_DO(checkResult(*rs, expRh)); // no ranked hits
- TEST_DO(checkResult(*rs, expBv.get())); // only bit vector
+ checkResult(*rs, expRh); // no ranked hits
+ checkResult(*rs, expBv.get()); // only bit vector
}
}
@@ -443,9 +464,9 @@ struct MergeResultSetFixture {
{}
};
-TEST_F("require that result set is merged correctly with first phase ranking",
- MergeResultSetFixture)
+TEST(HitCollectorTest, require_that_result_set_is_merged_correctly_with_first_phase_ranking)
{
+ MergeResultSetFixture f;
std::vector<RankedHit> expRh;
for (uint32_t i = 0; i < f.numDocs; ++i) {
f.hc.addHit(i, i + 1000);
@@ -457,7 +478,7 @@ TEST_F("require that result set is merged correctly with first phase ranking",
expRh.back()._rankValue = (i < f.numDocs - f.maxHitsSize) ? default_rank_value : i + 1000;
}
std::unique_ptr<ResultSet> rs = f.hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
+ checkResult(*rs, expRh);
}
void
@@ -474,9 +495,9 @@ addExpectedHitForMergeTest(const MergeResultSetFixture &f, std::vector<RankedHit
}
}
-TEST_F("require that result set is merged correctly with second phase ranking (document scorer)",
- MergeResultSetFixture)
+TEST(HitCollectorTest, require_that_result_set_is_merged_correctly_with_second_phase_ranking_using_document_scorer)
{
+ MergeResultSetFixture f;
// with second phase ranking that triggers rescoring / scaling
BasicScorer scorer(500); // second phase ranking setting score to docId + 500
std::vector<RankedHit> expRh;
@@ -484,12 +505,13 @@ TEST_F("require that result set is merged correctly with second phase ranking (d
f.hc.addHit(i, i + 1000);
addExpectedHitForMergeTest(f, expRh, i);
}
- EXPECT_EQUAL(f.maxHeapSize, do_reRank(scorer, f.hc, f.maxHeapSize));
+ EXPECT_EQ(f.maxHeapSize, do_reRank(scorer, f.hc, f.maxHeapSize));
std::unique_ptr<ResultSet> rs = f.hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
+ checkResult(*rs, expRh);
}
-TEST("require that hits can be added out of order") {
+TEST(HitCollectorTest, require_that_hits_can_be_added_out_of_order)
+{
HitCollector hc(1000, 100);
std::vector<RankedHit> expRh;
// produce expected result in normal order
@@ -503,11 +525,12 @@ TEST("require that hits can be added out of order") {
hc.addHit(i, i + 100);
}
std::unique_ptr<ResultSet> rs = hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
- TEST_DO(checkResult(*rs, nullptr));
+ checkResult(*rs, expRh);
+ checkResult(*rs, nullptr);
}
-TEST("require that hits can be added out of order when passing array limit") {
+TEST(HitCollectorTest, require_that_hits_can_be_added_out_of_order_when_passing_array_limit)
+{
HitCollector hc(10000, 100);
std::vector<RankedHit> expRh;
// produce expected result in normal order
@@ -525,11 +548,12 @@ TEST("require that hits can be added out of order when passing array limit") {
hc.addHit(i, i + 100);
}
std::unique_ptr<ResultSet> rs = hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
- TEST_DO(checkResult(*rs, nullptr));
+ checkResult(*rs, expRh);
+ checkResult(*rs, nullptr);
}
-TEST("require that hits can be added out of order only after passing array limit") {
+TEST(HitCollectorTest, require_that_hits_can_be_added_out_of_order_only_after_passing_array_limit)
+{
HitCollector hc(10000, 100);
std::vector<RankedHit> expRh;
// produce expected result in normal order
@@ -548,8 +572,87 @@ TEST("require that hits can be added out of order only after passing array limit
hc.addHit(i, i + 100);
}
std::unique_ptr<ResultSet> rs = hc.getResultSet();
- TEST_DO(checkResult(*rs, expRh));
- TEST_DO(checkResult(*rs, nullptr));
+ checkResult(*rs, expRh);
+ checkResult(*rs, nullptr);
+}
+
+struct RankDropFixture {
+ uint32_t _docid_limit;
+ HitCollector _hc;
+ std::vector<uint32_t> _dropped;
+ RankDropFixture(uint32_t docid_limit, uint32_t max_hits_size)
+ : _docid_limit(docid_limit),
+ _hc(docid_limit, max_hits_size)
+ {
+ }
+ void add(std::vector<RankedHit> hits) {
+ for (const auto& hit : hits) {
+ _hc.addHit(hit.getDocId(), hit.getRank());
+ }
+ }
+ void rerank(ScoreMap score_map, size_t count) {
+ PredefinedScorer scorer(score_map);
+ EXPECT_EQ(count, do_reRank(scorer, _hc, count));
+ }
+ std::unique_ptr<BitVector> make_bv(DocidVector docids) {
+ auto bv = BitVector::create(_docid_limit);
+ for (auto& docid : docids) {
+ bv->setBit(docid);
+ }
+ return bv;
+ }
+
+ void setup() {
+ // Initial 7 hits from first phase
+ add({{5, 1100},{10, 1200},{11, 1300},{12, 1400},{14, 500},{15, 900},{16,1000}});
+ // Rerank two best hits, calculate old and new ranges for reranked
+ // hits that will cause hits not reranked to later be rescored by
+ // dividing by 100.
+ rerank({{11,14},{12,13}}, 2);
+ }
+ void check_result(std::optional<double> rank_drop_limit, RankedHitVector exp_array,
+ std::unique_ptr<BitVector> exp_bv, DocidVector exp_dropped) {
+ auto rs = _hc.get_result_set(rank_drop_limit, &_dropped);
+ checkResult(*rs, exp_array);
+ checkResult(*rs, exp_bv.get());
+ EXPECT_EQ(exp_dropped, _dropped);
+ }
+};
+
+TEST(HitCollectorTest, require_that_second_phase_rank_drop_limit_is_enforced)
+{
+ // Track rank score for all 7 hits from first phase
+ RankDropFixture f(10000, 10);
+ f.setup();
+ f.check_result(9.0, {{5,11},{10,12},{11,14},{12,13},{16,10}},
+ {}, {14, 15});
+}
+
+TEST(HitCollectorTest, require_that_second_phase_rank_drop_limit_is_enforced_when_docid_vector_is_used)
+{
+ // Track rank score for 4 best hits from first phase, overflow to docid vector
+ RankDropFixture f(10000, 4);
+ f.setup();
+ f.check_result(13.0, {{11,14}},
+ {}, {5,10,12,14,15,16});
+}
+
+TEST(HitCollectorTest, require_that_bitvector_is_not_dropped_without_second_phase_rank_drop_limit)
+{
+ // Track rank score for 4 best hits from first phase, overflow to bitvector
+ RankDropFixture f(20, 4);
+ f.setup();
+ f.check_result(std::nullopt, {{5,11},{10,12},{11,14},{12,13}},
+ f.make_bv({5,10,11,12,14,15,16}), {});
+}
+
+TEST(HitCollectorTest, require_that_bitvector_is_dropped_with_second_phase_rank_drop_limit)
+{
+ // Track rank for 4 best hits from first phase, overflow to bitvector
+ RankDropFixture f(20, 4);
+ f.setup();
+ f.check_result(9.0, {{5,11},{10,12},{11,14},{12,13}},
+ {}, {14,15,16});
}
-TEST_MAIN() { TEST_RUN_ALL(); }
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/tests/hitcollector/sorted_hit_sequence_test.cpp b/searchlib/src/tests/hitcollector/sorted_hit_sequence_test.cpp
index c1c3a550d9b..4eefa5b5dfa 100644
--- a/searchlib/src/tests/hitcollector/sorted_hit_sequence_test.cpp
+++ b/searchlib/src/tests/hitcollector/sorted_hit_sequence_test.cpp
@@ -1,7 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/searchlib/queryeval/sorted_hit_sequence.h>
+#include <vespa/vespalib/gtest/gtest.h>
using search::queryeval::SortedHitSequence;
using Hits = std::vector<SortedHitSequence::Hit>;
@@ -10,20 +10,22 @@ using Refs = std::vector<SortedHitSequence::Ref>;
Hits hits({{1,10.0},{2,30.0},{3,20.0}});
Refs refs({1,2,0});
-TEST("require that empty hit sequence is empty") {
+TEST(SortedHitsSEquenceTest, require_that_empty_hit_sequence_is_empty)
+{
EXPECT_TRUE(!SortedHitSequence(nullptr, nullptr, 0).valid());
EXPECT_TRUE(!SortedHitSequence(&hits[0], &refs[0], 0).valid());
}
-TEST("require that sorted hit sequence can be iterated") {
+TEST(SortedHitsSEquenceTest, require_that_sorted_hit_sequence_can_be_iterated)
+{
SortedHitSequence seq(&hits[0], &refs[0], refs.size());
for (const auto &expect: Hits({{2,30.0},{3,20.0},{1,10.0}})) {
ASSERT_TRUE(seq.valid());
- EXPECT_EQUAL(expect.first, seq.get().first);
- EXPECT_EQUAL(expect.second, seq.get().second);
+ EXPECT_EQ(expect.first, seq.get().first);
+ EXPECT_EQ(expect.second, seq.get().second);
seq.next();
}
EXPECT_TRUE(!seq.valid());
}
-TEST_MAIN() { TEST_RUN_ALL(); }
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp b/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp
index 3154f95bbe1..01587ef485a 100644
--- a/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp
@@ -193,63 +193,206 @@ struct NoRescorer
};
template <typename Rescorer>
+class RerankRescorer {
+ Rescorer _rescorer;
+ using HitVector = std::vector<HitCollector::Hit>;
+ using Iterator = typename HitVector::const_iterator;
+ Iterator _reranked_cur;
+ Iterator _reranked_end;
+public:
+ RerankRescorer(const Rescorer& rescorer,
+ const HitVector& reranked_hits)
+ : _rescorer(rescorer),
+ _reranked_cur(reranked_hits.begin()),
+ _reranked_end(reranked_hits.end())
+ {
+ }
+
+ double rescore(uint32_t docid, double score) noexcept {
+ if (_reranked_cur != _reranked_end && _reranked_cur->first == docid) {
+ double result = _reranked_cur->second;
+ ++_reranked_cur;
+ return result;
+ } else {
+ return _rescorer.rescore(docid, score);
+ }
+ }
+};
+
+class SimpleHitAdder {
+protected:
+ ResultSet& _rs;
+public:
+ SimpleHitAdder(ResultSet& rs)
+ : _rs(rs)
+ {
+ }
+ void add(uint32_t docid, double rank_value) {
+ _rs.push_back({docid, rank_value});
+ }
+};
+
+class ConditionalHitAdder : public SimpleHitAdder {
+protected:
+ double _second_phase_rank_drop_limit;
+public:
+ ConditionalHitAdder(ResultSet& rs, double second_phase_rank_drop_limit)
+ : SimpleHitAdder(rs),
+ _second_phase_rank_drop_limit(second_phase_rank_drop_limit)
+ {
+ }
+ void add(uint32_t docid, double rank_value) {
+ if (rank_value > _second_phase_rank_drop_limit) {
+ _rs.push_back({docid, rank_value});
+ }
+ }
+};
+
+class TrackingConditionalHitAdder : public ConditionalHitAdder {
+ std::vector<uint32_t>& _dropped;
+public:
+ TrackingConditionalHitAdder(ResultSet& rs, double second_phase_rank_drop_limit, std::vector<uint32_t>& dropped)
+ : ConditionalHitAdder(rs, second_phase_rank_drop_limit),
+ _dropped(dropped)
+ {
+ }
+ void add(uint32_t docid, double rank_value) {
+ if (rank_value > _second_phase_rank_drop_limit) {
+ _rs.push_back({docid, rank_value});
+ } else {
+ _dropped.emplace_back(docid);
+ }
+ }
+};
+
+template <typename HitAdder, typename Rescorer>
void
-add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, Rescorer rescorer)
+add_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, Rescorer rescorer)
{
for (auto& hit : hits) {
- rs.push_back({hit.first, rescorer.rescore(hit.first, hit.second)});
+ hit_adder.add(hit.first, rescorer.rescore(hit.first, hit.second));
+ }
+}
+
+template <typename HitAdder, typename Rescorer>
+void
+add_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer)
+{
+ if (reranked_hits.empty()) {
+ add_rescored_hits(hit_adder, hits, rescorer);
+ } else {
+ add_rescored_hits(hit_adder, hits, RerankRescorer(rescorer, reranked_hits));
}
}
template <typename Rescorer>
void
-mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, Rescorer rescorer)
+add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<HitCollector::Hit>& reranked_hits, std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped, Rescorer rescorer)
+{
+ if (second_phase_rank_drop_limit.has_value()) {
+ if (dropped != nullptr) {
+ add_rescored_hits(TrackingConditionalHitAdder(rs, second_phase_rank_drop_limit.value(), *dropped), hits, reranked_hits, rescorer);
+ } else {
+ add_rescored_hits(ConditionalHitAdder(rs, second_phase_rank_drop_limit.value()), hits, reranked_hits, rescorer);
+ }
+ } else {
+ add_rescored_hits(SimpleHitAdder(rs), hits, reranked_hits, rescorer);
+ }
+}
+
+template <typename HitAdder, typename Rescorer>
+void
+mixin_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, Rescorer rescorer)
{
auto hits_cur = hits.begin();
auto hits_end = hits.end();
for (auto docid : docids) {
if (hits_cur != hits_end && docid == hits_cur->first) {
- rs.push_back({docid, rescorer.rescore(docid, hits_cur->second)});
+ hit_adder.add(docid, rescorer.rescore(docid, hits_cur->second));
++hits_cur;
} else {
- rs.push_back({docid, default_value});
+ hit_adder.add(docid, default_value);
+ }
+ }
+}
+
+template <typename HitAdder, typename Rescorer>
+void
+mixin_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer)
+{
+ if (reranked_hits.empty()) {
+ mixin_rescored_hits(hit_adder, hits, docids, default_value, rescorer);
+ } else {
+ mixin_rescored_hits(hit_adder, hits, docids, default_value, RerankRescorer(rescorer, reranked_hits));
+ }
+}
+
+template <typename Rescorer>
+void
+mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, const std::vector<HitCollector::Hit>& reranked_hits, std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped, Rescorer rescorer)
+{
+ if (second_phase_rank_drop_limit.has_value()) {
+ if (dropped != nullptr) {
+ mixin_rescored_hits(TrackingConditionalHitAdder(rs, second_phase_rank_drop_limit.value(), *dropped), hits, docids, default_value, reranked_hits, rescorer);
+ } else {
+ mixin_rescored_hits(ConditionalHitAdder(rs, second_phase_rank_drop_limit.value()), hits, docids, default_value, reranked_hits, rescorer);
}
+ } else {
+ mixin_rescored_hits(SimpleHitAdder(rs), hits, docids, default_value, reranked_hits, rescorer);
}
}
void
-mergeHitsIntoResultSet(const std::vector<HitCollector::Hit> &hits, ResultSet &result)
+add_bitvector_to_dropped(std::vector<uint32_t>& dropped, vespalib::ConstArrayRef<RankedHit> hits, const BitVector& bv)
{
- uint32_t rhCur(0);
- uint32_t rhEnd(result.getArrayUsed());
- for (const auto &hit : hits) {
- while (rhCur != rhEnd && result[rhCur].getDocId() != hit.first) {
- // just set the iterators right
- ++rhCur;
+ auto hits_cur = hits.begin();
+ auto hits_end = hits.end();
+ auto docid = bv.getFirstTrueBit();
+ auto docid_limit = bv.size();
+ while (docid < docid_limit) {
+ if (hits_cur != hits_end && hits_cur->getDocId() == docid) {
+ ++hits_cur;
+ } else {
+ dropped.emplace_back(docid);
}
- assert(rhCur != rhEnd); // the hits should be a subset of the hits in ranked hit array.
- result[rhCur]._rankValue = hit.second;
+ docid = bv.getNextTrueBit(docid + 1);
}
}
}
std::unique_ptr<ResultSet>
-HitCollector::getResultSet(HitRank default_value)
+HitCollector::get_result_set(std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped)
{
+ /*
+ * Use default_rank_value (i.e. -HUGE_VAL) when hit collector saves
+ * rank scores, otherwise use zero_rank_value (i.e. 0.0).
+ */
+ auto default_value = save_rank_scores() ? search::default_rank_value : search::zero_rank_value;
+
bool needReScore = FirstPhaseRescorer::need_rescore(_ranges);
FirstPhaseRescorer rescorer(_ranges);
+ if (dropped != nullptr) {
+ dropped->clear();
+ }
+
// destroys the heap property or score sort order
sortHitsByDocId();
auto rs = std::make_unique<ResultSet>();
- if ( ! _collector->isDocIdCollector() ) {
+ if ( ! _collector->isDocIdCollector() ||
+ (second_phase_rank_drop_limit.has_value() &&
+ (_bitVector || dropped == nullptr))) {
rs->allocArray(_hits.size());
+ auto* dropped_or_null = dropped;
+ if (second_phase_rank_drop_limit.has_value() && _bitVector) {
+ dropped_or_null = nullptr;
+ }
if (needReScore) {
- add_rescored_hits(*rs, _hits, rescorer);
+ add_rescored_hits(*rs, _hits, _reRankedHits, second_phase_rank_drop_limit, dropped_or_null, rescorer);
} else {
- add_rescored_hits(*rs, _hits, NoRescorer());
+ add_rescored_hits(*rs, _hits, _reRankedHits, second_phase_rank_drop_limit, dropped_or_null, NoRescorer());
}
} else {
if (_unordered) {
@@ -257,14 +400,18 @@ HitCollector::getResultSet(HitRank default_value)
}
rs->allocArray(_docIdVector.size());
if (needReScore) {
- mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, rescorer);
+ mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, second_phase_rank_drop_limit, dropped, rescorer);
} else {
- mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, NoRescorer());
+ mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, second_phase_rank_drop_limit, dropped, NoRescorer());
}
}
- if (!_reRankedHits.empty()) {
- mergeHitsIntoResultSet(_reRankedHits, *rs);
+ if (second_phase_rank_drop_limit.has_value() && _bitVector) {
+ if (dropped != nullptr) {
+ assert(dropped->empty());
+ add_bitvector_to_dropped(*dropped, {rs->getArray(), rs->getArrayUsed()}, *_bitVector);
+ }
+ _bitVector.reset();
}
if (_bitVector) {
@@ -274,4 +421,10 @@ HitCollector::getResultSet(HitRank default_value)
return rs;
}
+std::unique_ptr<ResultSet>
+HitCollector::getResultSet()
+{
+ return get_result_set(std::nullopt, nullptr);
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/hitcollector.h b/searchlib/src/vespa/searchlib/queryeval/hitcollector.h
index 903c2ab5b13..c23fb0a6ef6 100644
--- a/searchlib/src/vespa/searchlib/queryeval/hitcollector.h
+++ b/searchlib/src/vespa/searchlib/queryeval/hitcollector.h
@@ -8,6 +8,7 @@
#include <vespa/searchlib/common/resultset.h>
#include <vespa/vespalib/util/sort.h>
#include <algorithm>
+#include <optional>
#include <vector>
namespace search::queryeval {
@@ -121,6 +122,8 @@ private:
VESPA_DLL_LOCAL void sortHitsByScore(size_t topn);
VESPA_DLL_LOCAL void sortHitsByDocId();
+ bool save_rank_scores() const noexcept { return _maxHitsSize != 0; }
+
public:
HitCollector(const HitCollector &) = delete;
HitCollector &operator=(const HitCollector &) = delete;
@@ -164,15 +167,17 @@ public:
const std::pair<Scores, Scores> &getRanges() const { return _ranges; }
void setRanges(const std::pair<Scores, Scores> &ranges);
+ std::unique_ptr<ResultSet>
+ get_result_set(std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped);
+
/**
* Returns a result set based on the content of this collector.
* Invoking this method will destroy the heap property of the
* ranked hits and the match data heap.
*
- * @param auto pointer to the result set
- * @param default_value rank value to be used for results without rank value
+ * @return unique pointer to the result set
**/
- std::unique_ptr<ResultSet> getResultSet(HitRank default_value = default_rank_value);
+ std::unique_ptr<ResultSet> getResultSet();
};
}