diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-06-13 11:45:11 +0000 |
---|---|---|
committer | Geir Storli <geirst@verizonmedia.com> | 2019-06-13 11:45:11 +0000 |
commit | 704338454b8051fca33a3a8f55a12b88a017701a (patch) | |
tree | bccc2dd3776cd7ba3b037c07cbb9a5d350779042 /searchlib | |
parent | cee7e5ca940bd4db6cd38efaf5f04058c0b9376a (diff) |
Add support for overriding k1 and b parameters in bm25 feature via rank properties.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/tests/features/bm25/bm25_test.cpp | 63 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/bm25_feature.cpp | 46 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/bm25_feature.h | 8 |
3 files changed, 103 insertions, 14 deletions
diff --git a/searchlib/src/tests/features/bm25/bm25_test.cpp b/searchlib/src/tests/features/bm25/bm25_test.cpp index 6f0d4b80051..eb2f46650a6 100644 --- a/searchlib/src/tests/features/bm25/bm25_test.cpp +++ b/searchlib/src/tests/features/bm25/bm25_test.cpp @@ -65,6 +65,18 @@ TEST_F(Bm25BlueprintTest, blueprint_setup_fails_when_parameter_list_is_not_valid expect_setup_fail({"is", "ia"}); // wrong parameter number } +TEST_F(Bm25BlueprintTest, blueprint_setup_fails_when_k1_param_is_malformed) +{ + index_env.getProperties().add("bm25(is).k1", "malformed"); + expect_setup_fail({"is"}); +} + +TEST_F(Bm25BlueprintTest, blueprint_setup_fails_when_b_param_is_malformed) +{ + index_env.getProperties().add("bm25(is).b", "malformed"); + expect_setup_fail({"is"}); +} + TEST_F(Bm25BlueprintTest, blueprint_setup_succeeds_for_index_field) { expect_setup_succeed({"is"}); @@ -82,10 +94,29 @@ TEST_F(Bm25BlueprintTest, blueprint_can_prepare_shared_state_with_average_field_ EXPECT_DOUBLE_EQ(10, as_value<double>(*store.get("bm25.afl.is"))); } +struct Scorer { + + double avg_field_length; + double k1_param; + double b_param; + + Scorer() : + avg_field_length(10), + k1_param(1.2), + b_param(0.75) + {} + + feature_t score(feature_t num_occs, feature_t field_length, double inverse_doc_freq) const { + return inverse_doc_freq * (num_occs * (1 + k1_param)) / + (num_occs + (k1_param * ((1 - b_param) + b_param * field_length / avg_field_length))); + } +}; + struct Bm25ExecutorTest : public ::testing::Test { BlueprintFactory factory; FtFeatureTest test; test::MatchDataBuilder::UP match_data; + Scorer scorer; static constexpr uint32_t total_doc_count = 100; Bm25ExecutorTest() @@ -132,9 +163,8 @@ struct Bm25ExecutorTest : public ::testing::Test { return Bm25Executor::calculate_inverse_document_frequency(matching_doc_count, total_doc_count); } - feature_t get_score(feature_t num_occs, feature_t field_length, - double inverse_doc_freq, double avg_field_length = 10) const { - return inverse_doc_freq * (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / avg_field_length))); + feature_t score(feature_t num_occs, feature_t field_length, double inverse_doc_freq) const { + return scorer.score(num_occs, field_length, inverse_doc_freq); } }; @@ -142,7 +172,7 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_a_single_term) { setup(); prepare_term(0, 0, 3, 20); - EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)))); + EXPECT_TRUE(execute(score(3.0, 20, idf(25)))); } TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms) @@ -150,7 +180,7 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms) setup(); prepare_term(0, 0, 3, 20); prepare_term(1, 0, 7, 5); - EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)) + get_score(7.0, 5.0, idf(35)))); + EXPECT_TRUE(execute(score(3.0, 20, idf(25)) + score(7.0, 5.0, idf(35)))); } TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored) @@ -159,7 +189,7 @@ TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored) prepare_term(0, 0, 3, 20); uint32_t unmatched_doc_id = 123; prepare_term(1, 0, 7, 5, unmatched_doc_id); - EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)))); + EXPECT_TRUE(execute(score(3.0, 20, idf(25)))); } TEST_F(Bm25ExecutorTest, term_searching_another_field_is_ignored) @@ -174,7 +204,8 @@ TEST_F(Bm25ExecutorTest, uses_average_field_length_from_shared_state_if_found) test.getQueryEnv().getObjectStore().add("bm25.afl.foo", std::make_unique<AnyWrapper<double>>(15)); setup(); prepare_term(0, 0, 3, 20); - EXPECT_TRUE(execute(get_score(3.0, 20, idf(25), 15))); + scorer.avg_field_length = 15; + EXPECT_TRUE(execute(score(3.0, 20, idf(25)))); } TEST_F(Bm25ExecutorTest, calculates_inverse_document_frequency) @@ -187,4 +218,22 @@ TEST_F(Bm25ExecutorTest, calculates_inverse_document_frequency) Bm25Executor::calculate_inverse_document_frequency(100, 100)); } +TEST_F(Bm25ExecutorTest, k1_param_can_be_overriden) +{ + test.getIndexEnv().getProperties().add("bm25(foo).k1", "2.5"); + setup(); + prepare_term(0, 0, 3, 20); + scorer.k1_param = 2.5; + EXPECT_TRUE(execute(score(3.0, 20, idf(25)))); +} + +TEST_F(Bm25ExecutorTest, b_param_can_be_overriden) +{ + test.getIndexEnv().getProperties().add("bm25(foo).b", "0.9"); + setup(); + prepare_term(0, 0, 3, 20); + scorer.b_param = 0.9; + EXPECT_TRUE(execute(score(3.0, 20, idf(25)))); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp index 5a9e8455d73..e89655a75bb 100644 --- a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp @@ -4,8 +4,13 @@ #include <vespa/searchlib/fef/itermdata.h> #include <vespa/searchlib/fef/itermfielddata.h> #include <vespa/searchlib/fef/objectstore.h> +#include <vespa/searchlib/fef/properties.h> #include <cmath> #include <memory> +#include <stdexcept> + +#include <vespa/log/log.h> +LOG_SETUP(".features.bm25_feature"); namespace search::features { @@ -20,14 +25,15 @@ using fef::objectstore::as_value; Bm25Executor::Bm25Executor(const fef::FieldInfo& field, const fef::IQueryEnvironment& env, - double avg_field_length) + double avg_field_length, + double k1_param, + double b_param) : FeatureExecutor(), _terms(), _avg_field_length(avg_field_length), - _k1_param(1.2), - _b_param(0.75) + _k1_param(k1_param), + _b_param(b_param) { - // TODO: Add support for setting k1 and b for (size_t i = 0; i < env.getNumTerms(); ++i) { const ITermData* term = env.getTerm(i); for (size_t j = 0; j < term->numFields(); ++j) { @@ -75,10 +81,31 @@ Bm25Executor::execute(uint32_t doc_id) outputs().set_number(0, score); } +bool +Bm25Blueprint::lookup_param(const fef::Properties& props, const vespalib::string& param, double& result) const +{ + vespalib::string key = getBaseName() + "(" + _field->name() + ")." + param; + auto value = props.lookup(key); + if (value.found()) { + try { + result = std::stod(value.get()); + } catch (const std::invalid_argument& ex) { + LOG(warning, "Not able to convert rank property '%s': '%s' to a double value", + key.c_str(), value.get().c_str()); + return false; + } + } + return true; +} + +double constexpr default_k1_param = 1.2; +double constexpr default_b_param = 0.75; Bm25Blueprint::Bm25Blueprint() : Blueprint("bm25"), - _field(nullptr) + _field(nullptr), + _k1_param(default_k1_param), + _b_param(default_b_param) { } @@ -102,6 +129,13 @@ Bm25Blueprint::setup(const fef::IIndexEnvironment& env, const fef::ParameterList const auto& field_name = params[0].getValue(); _field = env.getFieldByName(field_name); + if (!lookup_param(env.getProperties(), "k1", _k1_param)) { + return false; + } + if (!lookup_param(env.getProperties(), "b", _b_param)) { + return false; + } + describeOutput("score", "The bm25 score for all terms searching in the given index field"); return (_field != nullptr); } @@ -132,7 +166,7 @@ Bm25Blueprint::createExecutor(const fef::IQueryEnvironment& env, vespalib::Stash double avg_field_length = lookup_result != nullptr ? as_value<double>(*lookup_result) : env.get_average_field_length(_field->name()); - return stash.create<Bm25Executor>(*_field, env, avg_field_length); + return stash.create<Bm25Executor>(*_field, env, avg_field_length, _k1_param, _b_param); } } diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.h b/searchlib/src/vespa/searchlib/features/bm25_feature.h index 4b5ea0214bc..533c7487a2f 100644 --- a/searchlib/src/vespa/searchlib/features/bm25_feature.h +++ b/searchlib/src/vespa/searchlib/features/bm25_feature.h @@ -31,7 +31,9 @@ private: public: Bm25Executor(const fef::FieldInfo& field, const fef::IQueryEnvironment& env, - double avg_field_length); + double avg_field_length, + double k1_param, + double b_param); double static calculate_inverse_document_frequency(uint32_t matching_doc_count, uint32_t total_doc_count); @@ -46,6 +48,10 @@ public: class Bm25Blueprint : public fef::Blueprint { private: const fef::FieldInfo* _field; + double _k1_param; + double _b_param; + + bool lookup_param(const fef::Properties& props, const vespalib::string& param, double& result) const; public: Bm25Blueprint(); |