diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-06-12 16:17:34 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-12 16:17:34 +0200 |
commit | 1f06d0d5f6218dafc1a0a0dc1dae85e86cb6ac90 (patch) | |
tree | e93c5dc8a8581aa17e48e5a0abecbab5f13b6ea5 /searchlib | |
parent | d3e5bf770cf893f7a9196fa49b07eae651d3c688 (diff) | |
parent | 3704383893ba63e3ea674881f8cd24da6d7683d8 (diff) |
Merge pull request #9763 from vespa-engine/geirst/idf-calculation-in-bm25-feature
Geirst/idf calculation in bm25 feature
Diffstat (limited to 'searchlib')
8 files changed, 82 insertions, 54 deletions
diff --git a/searchlib/src/tests/features/bm25/bm25_test.cpp b/searchlib/src/tests/features/bm25/bm25_test.cpp index 1a3895d7e28..6f0d4b80051 100644 --- a/searchlib/src/tests/features/bm25/bm25_test.cpp +++ b/searchlib/src/tests/features/bm25/bm25_test.cpp @@ -86,6 +86,7 @@ struct Bm25ExecutorTest : public ::testing::Test { BlueprintFactory factory; FtFeatureTest test; test::MatchDataBuilder::UP match_data; + static constexpr uint32_t total_doc_count = 100; Bm25ExecutorTest() : factory(), @@ -95,11 +96,15 @@ struct Bm25ExecutorTest : public ::testing::Test { setup_search_features(factory); test.getIndexEnv().getBuilder().addField(FieldType::INDEX, CollectionType::SINGLE, "foo"); test.getIndexEnv().getBuilder().addField(FieldType::INDEX, CollectionType::SINGLE, "bar"); - test.getQueryEnv().getBuilder().addIndexNode({"foo"}); - test.getQueryEnv().getBuilder().addIndexNode({"foo"}); - test.getQueryEnv().getBuilder().addIndexNode({"bar"}); + add_query_term("foo", 25); + add_query_term("foo", 35); + add_query_term("bar", 45); test.getQueryEnv().getBuilder().set_avg_field_length("foo", 10); } + void add_query_term(const vespalib::string& field_name, uint32_t matching_doc_count) { + auto* term = test.getQueryEnv().getBuilder().addIndexNode({field_name}); + term->field(0).setDocFreq(matching_doc_count, total_doc_count); + } void setup() { EXPECT_TRUE(test.setup()); match_data = test.createMatchDataBuilder(); @@ -108,7 +113,7 @@ struct Bm25ExecutorTest : public ::testing::Test { clear_term(2, 1); } bool execute(feature_t exp_score) { - return test.execute(exp_score); + return test.execute(exp_score, 0.000001); } void clear_term(uint32_t term_id, uint32_t field_id) { auto* tfmd = match_data->getTermFieldMatchData(term_id, field_id); @@ -123,8 +128,13 @@ struct Bm25ExecutorTest : public ::testing::Test { tfmd->setFieldLength(field_length); } - feature_t get_score(feature_t num_occs, feature_t field_length, double avg_field_length = 10) const { - return (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / avg_field_length))); + double idf(uint32_t matching_doc_count) const { + return Bm25Executor::calculate_inverse_document_frequency(matching_doc_count, total_doc_count); + } + + feature_t get_score(feature_t num_occs, feature_t field_length, + double inverse_doc_freq, double avg_field_length = 10) const { + return inverse_doc_freq * (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / avg_field_length))); } }; @@ -132,7 +142,7 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_a_single_term) { setup(); prepare_term(0, 0, 3, 20); - EXPECT_TRUE(execute(get_score(3.0, 20))); + EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)))); } TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms) @@ -140,15 +150,16 @@ TEST_F(Bm25ExecutorTest, score_is_calculated_for_multiple_terms) setup(); prepare_term(0, 0, 3, 20); prepare_term(1, 0, 7, 5); - EXPECT_TRUE(execute(get_score(3.0, 20) + get_score(7.0, 5.0))); + EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)) + get_score(7.0, 5.0, idf(35)))); } TEST_F(Bm25ExecutorTest, term_that_does_not_match_document_is_ignored) { setup(); prepare_term(0, 0, 3, 20); - prepare_term(1, 0, 7, 5, 123); - EXPECT_TRUE(execute(get_score(3.0, 20))); + uint32_t unmatched_doc_id = 123; + prepare_term(1, 0, 7, 5, unmatched_doc_id); + EXPECT_TRUE(execute(get_score(3.0, 20, idf(25)))); } TEST_F(Bm25ExecutorTest, term_searching_another_field_is_ignored) @@ -163,7 +174,17 @@ TEST_F(Bm25ExecutorTest, uses_average_field_length_from_shared_state_if_found) test.getQueryEnv().getObjectStore().add("bm25.afl.foo", std::make_unique<AnyWrapper<double>>(15)); setup(); prepare_term(0, 0, 3, 20); - EXPECT_TRUE(execute(get_score(3.0, 20, 15))); + EXPECT_TRUE(execute(get_score(3.0, 20, idf(25), 15))); +} + +TEST_F(Bm25ExecutorTest, calculates_inverse_document_frequency) +{ + EXPECT_DOUBLE_EQ(std::log(1 + (99 + 0.5) / (1 + 0.5)), + Bm25Executor::calculate_inverse_document_frequency(1, 100)); + EXPECT_DOUBLE_EQ(std::log(1 + (60 + 0.5) / (40 + 0.5)), + Bm25Executor::calculate_inverse_document_frequency(40, 100)); + EXPECT_DOUBLE_EQ(std::log(1 + (0.5) / (100 + 0.5)), + Bm25Executor::calculate_inverse_document_frequency(100, 100)); } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/features/prod_features.cpp b/searchlib/src/tests/features/prod_features.cpp index 626a470cb5c..70250b05bf1 100644 --- a/searchlib/src/tests/features/prod_features.cpp +++ b/searchlib/src/tests/features/prod_features.cpp @@ -1968,8 +1968,10 @@ Test::testTerm() .addField(FieldType::INDEX, CollectionType::SINGLE, "idx2") // field 1 .addField(FieldType::ATTRIBUTE, CollectionType::SINGLE, "attr"); // field 2 ft.getQueryEnv().getBuilder().addAllFields().setUniqueId(0); - ft.getQueryEnv().getBuilder().addAllFields().setUniqueId(1).setWeight(search::query::Weight(200)).lookupField(0)->setDocFreq(0.5); - ft.getQueryEnv().getBuilder().addAttributeNode("attr")->setUniqueId(2).setWeight(search::query::Weight(400)).lookupField(2)->setDocFreq(0.25); + ft.getQueryEnv().getBuilder().addAllFields().setUniqueId(1) + .setWeight(search::query::Weight(200)).lookupField(0)->setDocFreq(50, 100); + ft.getQueryEnv().getBuilder().addAttributeNode("attr")->setUniqueId(2) + .setWeight(search::query::Weight(400)).lookupField(2)->setDocFreq(25, 100); // setup connectedness between term 1 and term 0 ft.getQueryEnv().getProperties().add("vespa.term.1.connexity", "0"); ft.getQueryEnv().getProperties().add("vespa.term.1.connexity", "0.7"); diff --git a/searchlib/src/tests/fef/termfieldmodel/termfieldmodel_test.cpp b/searchlib/src/tests/fef/termfieldmodel/termfieldmodel_test.cpp index 9ed94c02287..3a0c334fbba 100644 --- a/searchlib/src/tests/fef/termfieldmodel/termfieldmodel_test.cpp +++ b/searchlib/src/tests/fef/termfieldmodel/termfieldmodel_test.cpp @@ -50,7 +50,7 @@ void testSetup(State &state) { { int i = 1; for (SFR iter(state.term); iter.valid(); iter.next()) { - iter.get().setDocFreq(0.25 * i++); + iter.get().setDocFreq(25 * i++, 100); } } diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp index 0be3c2876f7..5a9e8455d73 100644 --- a/searchlib/src/vespa/searchlib/features/bm25_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/bm25_feature.cpp @@ -4,6 +4,7 @@ #include <vespa/searchlib/fef/itermdata.h> #include <vespa/searchlib/fef/itermfielddata.h> #include <vespa/searchlib/fef/objectstore.h> +#include <cmath> #include <memory> namespace search::features { @@ -32,13 +33,22 @@ Bm25Executor::Bm25Executor(const fef::FieldInfo& field, for (size_t j = 0; j < term->numFields(); ++j) { const ITermFieldData& term_field = term->field(j); if (field.id() == term_field.getFieldId()) { - // TODO: Add proper calculation of IDF - _terms.emplace_back(term_field.getHandle(MatchDataDetails::Cheap), 1.0); + // TODO: Add support for using significance instead of default idf if specified in the query + _terms.emplace_back(term_field.getHandle(MatchDataDetails::Cheap), + calculate_inverse_document_frequency(term_field.get_matching_doc_count(), + term_field.get_total_doc_count())); } } } } +double +Bm25Executor::calculate_inverse_document_frequency(uint32_t matching_doc_count, uint32_t total_doc_count) +{ + return std::log(1 + (static_cast<double>(total_doc_count - matching_doc_count + 0.5) / + static_cast<double>(matching_doc_count + 0.5))); +} + void Bm25Executor::handle_bind_match_data(const fef::MatchData& match_data) { diff --git a/searchlib/src/vespa/searchlib/features/bm25_feature.h b/searchlib/src/vespa/searchlib/features/bm25_feature.h index 4b1576b57b3..4b5ea0214bc 100644 --- a/searchlib/src/vespa/searchlib/features/bm25_feature.h +++ b/searchlib/src/vespa/searchlib/features/bm25_feature.h @@ -33,6 +33,8 @@ public: const fef::IQueryEnvironment& env, double avg_field_length); + double static calculate_inverse_document_frequency(uint32_t matching_doc_count, uint32_t total_doc_count); + void handle_bind_match_data(const fef::MatchData& match_data) override; void execute(uint32_t docId) override; }; diff --git a/searchlib/src/vespa/searchlib/fef/itermfielddata.h b/searchlib/src/vespa/searchlib/fef/itermfielddata.h index 80343db2250..6fb467ce25c 100644 --- a/searchlib/src/vespa/searchlib/fef/itermfielddata.h +++ b/searchlib/src/vespa/searchlib/fef/itermfielddata.h @@ -27,13 +27,26 @@ public: **/ virtual uint32_t getFieldId() const = 0; + + /** + * Returns the number of documents matching this term. + */ + virtual uint32_t get_matching_doc_count() const = 0; + + /** + * Returns the total number of documents in the corpus. + */ + virtual uint32_t get_total_doc_count() const = 0; + /** * Obtain the document frequency. This is a value between 0 and 1 * indicating the ratio of the matching documents to the corpus. * * @return document frequency - **/ - virtual double getDocFreq() const = 0; + **/ + double getDocFreq() const { + return (double)get_matching_doc_count() / (double)get_total_doc_count(); + } /** * Obtain the match handle for this field, diff --git a/searchlib/src/vespa/searchlib/fef/simpletermfielddata.cpp b/searchlib/src/vespa/searchlib/fef/simpletermfielddata.cpp index d1edee7fd07..64906eed22e 100644 --- a/searchlib/src/vespa/searchlib/fef/simpletermfielddata.cpp +++ b/searchlib/src/vespa/searchlib/fef/simpletermfielddata.cpp @@ -2,22 +2,22 @@ #include "simpletermfielddata.h" -namespace search { -namespace fef { +namespace search::fef { SimpleTermFieldData::SimpleTermFieldData(uint32_t fieldId) : _fieldId(fieldId), - _docFreq(0), + _matching_doc_count(0), + _total_doc_count(1), _handle(IllegalHandle) { } SimpleTermFieldData::SimpleTermFieldData(const ITermFieldData &rhs) : _fieldId(rhs.getFieldId()), - _docFreq(rhs.getDocFreq()), + _matching_doc_count(rhs.get_matching_doc_count()), + _total_doc_count(rhs.get_total_doc_count()), _handle(rhs.getHandle()) { } -} // namespace fef -} // namespace search +} diff --git a/searchlib/src/vespa/searchlib/fef/simpletermfielddata.h b/searchlib/src/vespa/searchlib/fef/simpletermfielddata.h index 6f0fbc9af64..d92d3a48f03 100644 --- a/searchlib/src/vespa/searchlib/fef/simpletermfielddata.h +++ b/searchlib/src/vespa/searchlib/fef/simpletermfielddata.h @@ -4,8 +4,7 @@ #include "itermfielddata.h" -namespace search { -namespace fef { +namespace search::fef { /** * Information about a single field that is being searched for a term @@ -17,7 +16,8 @@ class SimpleTermFieldData : public ITermFieldData { private: uint32_t _fieldId; - double _docFreq; + uint32_t _matching_doc_count; + uint32_t _total_doc_count; TermFieldHandle _handle; public: @@ -33,28 +33,14 @@ public: **/ SimpleTermFieldData(uint32_t fieldId); - /** - * Obtain the field id. - * - * @return field id - **/ uint32_t getFieldId() const override final { return _fieldId; } - /** - * Obtain the document frequency. - * - * @return document frequency - **/ - double getDocFreq() const override final { return _docFreq; } + uint32_t get_matching_doc_count() const override { return _matching_doc_count; } + + uint32_t get_total_doc_count() const override { return _total_doc_count; } using ITermFieldData::getHandle; - /** - * Obtain the match handle for this field, - * requesting match data with the given details in the corresponding TermFieldMatchData. - * - * @return match handle (or IllegalHandle) - **/ TermFieldHandle getHandle(MatchDataDetails requestedDetails) const override { (void) requestedDetails; return _handle; @@ -62,20 +48,15 @@ public: /** * Sets the document frequency. - * - * @return this object (for chaining) - * @param docFreq document frequency **/ - SimpleTermFieldData &setDocFreq(double docFreq) { - _docFreq = docFreq; + SimpleTermFieldData &setDocFreq(uint32_t matching_doc_count, uint32_t total_doc_count) { + _matching_doc_count = matching_doc_count; + _total_doc_count = total_doc_count; return *this; } /** * Sets the match handle for this field. - * - * @return this object (for chaining) - * @param handle match handle **/ SimpleTermFieldData &setHandle(TermFieldHandle handle) { _handle = handle; @@ -83,6 +64,5 @@ public: } }; -} // namespace fef -} // namespace search +} |