summary refs log tree commit diff stats
path: root/searchlib
diff options
context:
space:
mode:
author Geir Storli <geirst@verizonmedia.com> 2019-06-13 11:45:11 +0000
committer Geir Storli <geirst@verizonmedia.com> 2019-06-13 11:45:11 +0000
commit 704338454b8051fca33a3a8f55a12b88a017701a (patch)
tree bccc2dd3776cd7ba3b037c07cbb9a5d350779042 /searchlib
parent cee7e5ca940bd4db6cd38efaf5f04058c0b9376a (diff)
Add support for overriding k1 and b parameters in bm25 feature via rank properties.
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/tests/features/bm25/bm25_test.cpp 63
-rw-r--r-- searchlib/src/vespa/searchlib/features/bm25_feature.cpp 46
-rw-r--r-- searchlib/src/vespa/searchlib/features/bm25_feature.h 8
3 files changed, 103 insertions, 14 deletions
diff --git a/searchlib/src/tests/features/bm25/bm25_test.cpp b/searchlib/src/tests/features/bm25/bm25_test.cpp
index 6f0d4b80051..eb2f46650a6 100644
--- a/searchlib/src/tests/features/bm25/bm25_test.cpp
+++ b/searchlib/src/tests/features/bm25/bm25_test.cpp
@@ -65,6 +65,18 @@ TEST_F(Bm25BlueprintTest, blueprint_setup_fails_when_parameter_list_is_not_valid
expect_setup_fail({"is", "ia"}); // wrong parameter number
}
+TEST_F(Bm25BlueprintTest, blueprint_setup_fails_when_k1_param_is_malformed)
+{
+    // A k1 rank property that is not a number must make blueprint setup fail.
+    auto& properties = index_env.getProperties();
+    properties.add("bm25(is).k1", "malformed");
+    expect_setup_fail({"is"});
+}
+
+TEST_F(Bm25BlueprintTest, blueprint_setup_fails_when_b_param_is_malformed)
+{
+    // A b rank property that is not a number must make blueprint setup fail.
+    auto& properties = index_env.getProperties();
+    properties.add("bm25(is).b", "malformed");
+    expect_setup_fail({"is"});
+}
+
TEST_F(Bm25BlueprintTest, blueprint_setup_succeeds_for_index_field)
{
expect_setup_succeed({"is"});
@@ -82,10 +94,29 @@ TEST_F(Bm25BlueprintTest, blueprint_can_prepare_shared_state_with_average_field_
EXPECT_DOUBLE_EQ(10, as_value<double>(*store.get("bm25.afl.is")));
}
+struct Scorer {
+
+    // Tunable inputs of the reference formula; tests overwrite individual
+    // fields to mirror the configuration handed to the executor under test.
+    double avg_field_length = 10;
+    double k1_param = 1.2;
+    double b_param = 0.75;
+
+    // Returns the expected score contribution of one term, computed from the
+    // term's occurrence count, the field length and the inverse document
+    // frequency, using the parameters currently stored in this object.
+    feature_t score(feature_t num_occs, feature_t field_length, double inverse_doc_freq) const {
+        const auto numerator = inverse_doc_freq * (num_occs * (1 + k1_param));
+        const auto denominator =
+            num_occs + (k1_param * ((1 - b_param) + b_param * field_length / avg_field_length));
+        return numerator / denominator;
+    }
+};
+
struct Bm25ExecutorTest : public ::testing::Test {
BlueprintFactory factory;
FtFeatureTest test;
test::MatchDataBuilder::UP match_data;
+ Scorer scorer;
static constexpr uint32_t total_doc_count = 100;
Bm25ExecutorTest()
@@ -132,9 +163,8 @@ struct Bm25ExecutorTest : public ::testing::Test {
return Bm25Executor::calculate_inverse_document_frequency(matching_doc_count, total_doc_count);
}
- feature_t get_score(feature_t num_occs, feature_t field_length,
- double inverse_doc_freq, double avg_field_length = 10) const {
- return inverse_doc_freq * (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / avg_field_length)));
+ feature_t score(feature_t num_occs, feature_t field_length, double inverse_doc_freq) const {
+ return scorer.score(num_occs, field_length, inverse_doc_freq);
}
};
@@ -142,7 +172,7 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_a_single_term)
{
setup();
prepare_term(0, 0, 3, 20);
- EXPECT_TRUE(execute(get_score(3.0, 20, idf(25))));
+ EXPECT_TRUE(execute(score(3.0, 20, idf(25))));
}
TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms)
@@ -150,7 +180,7 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms)
setup();
prepare_term(0, 0, 3, 20);
prepare_term(1, 0, 7, 5);
- EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)) + get_score(7.0, 5.0, idf(35))));
+ EXPECT_TRUE(execute(score(3.0, 20, idf(25)) + score(7.0, 5.0, idf(35))));
}
TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored)
@@ -159,7 +189,7 @@ TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored)
prepare_term(0, 0, 3, 20);
uint32_t unmatched_doc_id = 123;
prepare_term(1, 0, 7, 5, unmatched_doc_id);
- EXPECT_TRUE(execute(get_score(3.0, 20, idf(25))));
+ EXPECT_TRUE(execute(score(3.0, 20, idf(25))));
}
TEST_F(Bm25ExecutorTest, term_searching_another_field_is_ignored)
@@ -174,7 +204,8 @@ TEST_F(Bm25ExecutorTest, uses_average_field_length_from_shared_state_if_found)
test.getQueryEnv().getObjectStore().add("bm25.afl.foo", std::make_unique<AnyWrapper<double>>(15));
setup();
prepare_term(0, 0, 3, 20);
- EXPECT_TRUE(execute(get_score(3.0, 20, idf(25), 15)));
+ scorer.avg_field_length = 15;
+ EXPECT_TRUE(execute(score(3.0, 20, idf(25))));
}
TEST_F(Bm25ExecutorTest, calculates_inverse_document_frequency)
@@ -187,4 +218,22 @@ TEST_F(Bm25ExecutorTest, calculates_inverse_document_frequency)
Bm25Executor::calculate_inverse_document_frequency(100, 100));
}
+TEST_F(Bm25ExecutorTest, k1_param_can_be_overriden)
+{
+    // Override k1 through a rank property, then expect the score that the
+    // reference scorer yields with the same k1 value.
+    auto& properties = test.getIndexEnv().getProperties();
+    properties.add("bm25(foo).k1", "2.5");
+    setup();
+    prepare_term(0, 0, 3, 20);
+    scorer.k1_param = 2.5;
+    EXPECT_TRUE(execute(score(3.0, 20, idf(25))));
+}
+}
+
+TEST_F(Bm25ExecutorTest, b_param_can_be_overriden)
+{
+    // Override b through a rank property, then expect the score that the
+    // reference scorer yields with the same b value.
+    auto& properties = test.getIndexEnv().getProperties();
+    properties.add("bm25(foo).b", "0.9");
+    setup();
+    prepare_term(0, 0, 3, 20);
+    scorer.b_param = 0.9;
+    EXPECT_TRUE(execute(score(3.0, 20, idf(25))));
+}
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
index 5a9e8455d73..e89655a75bb 100644
--- a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
@@ -4,8 +4,13 @@
#include <vespa/searchlib/fef/itermdata.h>
#include <vespa/searchlib/fef/itermfielddata.h>
#include <vespa/searchlib/fef/objectstore.h>
+#include <vespa/searchlib/fef/properties.h>
#include <cmath>
#include <memory>
+#include <stdexcept>
+
+#include <vespa/log/log.h>
+LOG_SETUP(".features.bm25_feature");
namespace search::features {
@@ -20,14 +25,15 @@ using fef::objectstore::as_value;
Bm25Executor::Bm25Executor(const fef::FieldInfo& field,
const fef::IQueryEnvironment& env,
- double avg_field_length)
+ double avg_field_length,
+ double k1_param,
+ double b_param)
: FeatureExecutor(),
_terms(),
_avg_field_length(avg_field_length),
- _k1_param(1.2),
- _b_param(0.75)
+ _k1_param(k1_param),
+ _b_param(b_param)
{
- // TODO: Add support for setting k1 and b
for (size_t i = 0; i < env.getNumTerms(); ++i) {
const ITermData* term = env.getTerm(i);
for (size_t j = 0; j < term->numFields(); ++j) {
@@ -75,10 +81,31 @@ Bm25Executor::execute(uint32_t doc_id)
outputs().set_number(0, score);
}
+/**
+ * Looks up the rank property "<base name>(<field name>).<param>" in 'props' and,
+ * when it is present, converts its value to a double stored in 'result'.
+ *
+ * @param props  the properties to search.
+ * @param param  the parameter suffix appended to the key (e.g. "k1" or "b").
+ * @param result receives the converted value; left untouched when the property
+ *               is absent or cannot be converted.
+ * @return true when the property is absent or was converted; false when no
+ *         field has been resolved or the value cannot be converted.
+ */
+bool
+Bm25Blueprint::lookup_param(const fef::Properties& props, const vespalib::string& param, double& result) const
+{
+    if (_field == nullptr) {
+        // setup() calls this before it checks that the field was resolved;
+        // without a field there is no key to build, so report failure
+        // instead of dereferencing a null pointer.
+        return false;
+    }
+    vespalib::string key = getBaseName() + "(" + _field->name() + ")." + param;
+    auto value = props.lookup(key);
+    if (value.found()) {
+        try {
+            result = std::stod(value.get());
+        } catch (const std::logic_error&) {
+            // std::stod throws std::invalid_argument for non-numeric input and
+            // std::out_of_range for values that do not fit in a double
+            // (e.g. "1e999"). Both derive from std::logic_error, and both must
+            // fail the lookup rather than escape from blueprint setup.
+            LOG(warning, "Not able to convert rank property '%s': '%s' to a double value",
+                key.c_str(), value.get().c_str());
+            return false;
+        }
+    }
+    return true;
+}
+
+double constexpr default_k1_param = 1.2;
+double constexpr default_b_param = 0.75;
Bm25Blueprint::Bm25Blueprint()
: Blueprint("bm25"),
- _field(nullptr)
+ _field(nullptr),
+ _k1_param(default_k1_param),
+ _b_param(default_b_param)
{
}
@@ -102,6 +129,13 @@ Bm25Blueprint::setup(const fef::IIndexEnvironment& env, const fef::ParameterList
const auto& field_name = params[0].getValue();
_field = env.getFieldByName(field_name);
+ if (!lookup_param(env.getProperties(), "k1", _k1_param)) {
+ return false;
+ }
+ if (!lookup_param(env.getProperties(), "b", _b_param)) {
+ return false;
+ }
+
describeOutput("score", "The bm25 score for all terms searching in the given index field");
return (_field != nullptr);
}
@@ -132,7 +166,7 @@ Bm25Blueprint::createExecutor(const fef::IQueryEnvironment& env, vespalib::Stash
double avg_field_length = lookup_result != nullptr ?
as_value<double>(*lookup_result) :
env.get_average_field_length(_field->name());
- return stash.create<Bm25Executor>(*_field, env, avg_field_length);
+ return stash.create<Bm25Executor>(*_field, env, avg_field_length, _k1_param, _b_param);
}
}
diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.h b/searchlib/src/vespa/searchlib/features/bm25_feature.h
index 4b5ea0214bc..533c7487a2f 100644
--- a/searchlib/src/vespa/searchlib/features/bm25_feature.h
+++ b/searchlib/src/vespa/searchlib/features/bm25_feature.h
@@ -31,7 +31,9 @@ private:
public:
Bm25Executor(const fef::FieldInfo& field,
const fef::IQueryEnvironment& env,
- double avg_field_length);
+ double avg_field_length,
+ double k1_param,
+ double b_param);
double static calculate_inverse_document_frequency(uint32_t matching_doc_count, uint32_t total_doc_count);
@@ -46,6 +48,10 @@ public:
class Bm25Blueprint : public fef::Blueprint {
private:
const fef::FieldInfo* _field;
+ double _k1_param;
+ double _b_param;
+
+ bool lookup_param(const fef::Properties& props, const vespalib::string& param, double& result) const;
public:
Bm25Blueprint();