diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-06-12 12:32:46 +0000 |
---|---|---|
committer | Geir Storli <geirst@verizonmedia.com> | 2019-06-12 12:32:46 +0000 |
commit | 3704383893ba63e3ea674881f8cd24da6d7683d8 (patch) | |
tree | eaefab6d75e9d5b519dc9c915a289f5cce94c487 /searchlib | |
parent | 9c500e2e0f518a91e3bf67dd5394e8b385d11c26 (diff) |
Use inverse document frequency in calculation of bm25 score.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/tests/features/bm25/bm25_test.cpp | 43 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/bm25_feature.cpp | 14 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/bm25_feature.h | 2 |
3 files changed, 46 insertions, 13 deletions
diff --git a/searchlib/src/tests/features/bm25/bm25_test.cpp b/searchlib/src/tests/features/bm25/bm25_test.cpp index 1a3895d7e28..6f0d4b80051 100644 --- a/searchlib/src/tests/features/bm25/bm25_test.cpp +++ b/searchlib/src/tests/features/bm25/bm25_test.cpp @@ -86,6 +86,7 @@ struct Bm25ExecutorTest : public ::testing::Test { BlueprintFactory factory; FtFeatureTest test; test::MatchDataBuilder::UP match_data; + static constexpr uint32_t total_doc_count = 100; Bm25ExecutorTest() : factory(), @@ -95,11 +96,15 @@ struct Bm25ExecutorTest : public ::testing::Test { setup_search_features(factory); test.getIndexEnv().getBuilder().addField(FieldType::INDEX, CollectionType::SINGLE, "foo"); test.getIndexEnv().getBuilder().addField(FieldType::INDEX, CollectionType::SINGLE, "bar"); - test.getQueryEnv().getBuilder().addIndexNode({"foo"}); - test.getQueryEnv().getBuilder().addIndexNode({"foo"}); - test.getQueryEnv().getBuilder().addIndexNode({"bar"}); + add_query_term("foo", 25); + add_query_term("foo", 35); + add_query_term("bar", 45); test.getQueryEnv().getBuilder().set_avg_field_length("foo", 10); } + void add_query_term(const vespalib::string& field_name, uint32_t matching_doc_count) { + auto* term = test.getQueryEnv().getBuilder().addIndexNode({field_name}); + term->field(0).setDocFreq(matching_doc_count, total_doc_count); + } void setup() { EXPECT_TRUE(test.setup()); match_data = test.createMatchDataBuilder(); @@ -108,7 +113,7 @@ struct Bm25ExecutorTest : public ::testing::Test { clear_term(2, 1); } bool execute(feature_t exp_score) { - return test.execute(exp_score); + return test.execute(exp_score, 0.000001); } void clear_term(uint32_t term_id, uint32_t field_id) { auto* tfmd = match_data->getTermFieldMatchData(term_id, field_id); @@ -123,8 +128,13 @@ struct Bm25ExecutorTest : public ::testing::Test { tfmd->setFieldLength(field_length); } - feature_t get_score(feature_t num_occs, feature_t field_length, double avg_field_length = 10) const { - return (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / avg_field_length))); + double idf(uint32_t matching_doc_count) const { + return Bm25Executor::calculate_inverse_document_frequency(matching_doc_count, total_doc_count); + } + + feature_t get_score(feature_t num_occs, feature_t field_length, + double inverse_doc_freq, double avg_field_length = 10) const { + return inverse_doc_freq * (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / avg_field_length))); } }; @@ -132,7 +142,7 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_a_single_term) { setup(); prepare_term(0, 0, 3, 20); - EXPECT_TRUE(execute(get_score(3.0, 20))); + EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)))); } TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms) @@ -140,15 +150,16 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms) setup(); prepare_term(0, 0, 3, 20); prepare_term(1, 0, 7, 5); - EXPECT_TRUE(execute(get_score(3.0, 20) + get_score(7.0, 5.0))); + EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)) + get_score(7.0, 5.0, idf(35)))); } TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored) { setup(); prepare_term(0, 0, 3, 20); - prepare_term(1, 0, 7, 5, 123); - EXPECT_TRUE(execute(get_score(3.0, 20))); + uint32_t unmatched_doc_id = 123; + prepare_term(1, 0, 7, 5, unmatched_doc_id); + EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)))); } TEST_F(Bm25ExecutorTest, term_searching_another_field_is_ignored) @@ -163,7 +174,17 @@ TEST_F(Bm25ExecutorTest, uses_average_field_length_from_shared_state_if_found) test.getQueryEnv().getObjectStore().add("bm25.afl.foo", std::make_unique<AnyWrapper<double>>(15)); setup(); prepare_term(0, 0, 3, 20); - EXPECT_TRUE(execute(get_score(3.0, 20, 15))); + EXPECT_TRUE(execute(get_score(3.0, 20, idf(25), 15))); +} + +TEST_F(Bm25ExecutorTest, calculates_inverse_document_frequency) +{ + EXPECT_DOUBLE_EQ(std::log(1 + (99 + 0.5) / (1 + 0.5)), + Bm25Executor::calculate_inverse_document_frequency(1, 100)); + EXPECT_DOUBLE_EQ(std::log(1 + (60 + 0.5) / (40 + 0.5)), + Bm25Executor::calculate_inverse_document_frequency(40, 100)); + EXPECT_DOUBLE_EQ(std::log(1 + (0.5) / (100 + 0.5)), + Bm25Executor::calculate_inverse_document_frequency(100, 100)); } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp index 0be3c2876f7..5a9e8455d73 100644 --- a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp @@ -4,6 +4,7 @@ #include <vespa/searchlib/fef/itermdata.h> #include <vespa/searchlib/fef/itermfielddata.h> #include <vespa/searchlib/fef/objectstore.h> +#include <cmath> #include <memory> namespace search::features { @@ -32,13 +33,22 @@ Bm25Executor::Bm25Executor(const fef::FieldInfo& field, for (size_t j = 0; j < term->numFields(); ++j) { const ITermFieldData& term_field = term->field(j); if (field.id() == term_field.getFieldId()) { - // TODO: Add proper calculation of IDF - _terms.emplace_back(term_field.getHandle(MatchDataDetails::Cheap), 1.0); + // TODO: Add support for using significance instead of default idf if specified in the query + _terms.emplace_back(term_field.getHandle(MatchDataDetails::Cheap), + calculate_inverse_document_frequency(term_field.get_matching_doc_count(), + term_field.get_total_doc_count())); } } } } +double +Bm25Executor::calculate_inverse_document_frequency(uint32_t matching_doc_count, uint32_t total_doc_count) +{ + return std::log(1 + (static_cast<double>(total_doc_count - matching_doc_count + 0.5) / + static_cast<double>(matching_doc_count + 0.5))); +} + void Bm25Executor::handle_bind_match_data(const fef::MatchData& match_data) { diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.h b/searchlib/src/vespa/searchlib/features/bm25_feature.h index 4b1576b57b3..4b5ea0214bc 100644 --- a/searchlib/src/vespa/searchlib/features/bm25_feature.h +++ b/searchlib/src/vespa/searchlib/features/bm25_feature.h @@ -33,6 +33,8 @@ public: const fef::IQueryEnvironment& env, double avg_field_length); + double static calculate_inverse_document_frequency(uint32_t matching_doc_count, uint32_t total_doc_count); + void handle_bind_match_data(const fef::MatchData& match_data) override; void execute(uint32_t docId) override; }; |